In [None]:
import torch
import pickle 
import os
from torchvision import transforms 
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
from PIL import Image
import torch.utils.data as data
import json
import nltk
from pycocotools.coco import COCO

In [None]:
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Hyperparameters handed to EncoderCNN / DecoderRNN when the networks are
# rebuilt; the saved state dicts are loaded into those networks afterwards.
embed_size = 256
hidden_size = 512
num_layers = 1

# The captioning loop processes one image per step.
batch_size = 1

# Saved weights for the encoder and the decoder.
encoder_path = 'resnet152+lstm+hidden_size512+lr_1e-3/models/encoder-5-3000.pkl'
decoder_path = 'resnet152+lstm+hidden_size512+lr_1e-3/models/decoder-5-3000.pkl'

# Images to caption and the COCO JSON file that lists them.
# Alternative (validation split) locations, kept for convenience:
#   image_dir    = '/datasets/COCO-2015/val2014'
#   caption_path = '/datasets/ee285f-public/COCO-Annotations/annotations_trainval2014/captions_val2014.json'
image_dir = '/datasets/COCO-2015/test2015'
caption_path = '/datasets/ee285f-public/COCO-Annotations/image_info_test2015/image_info_test2015.json'

# Pickled vocabulary object (unpickling relies on the Vocabulary import above).
vocab_path = './vocab.pkl'
with open(vocab_path, 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [None]:
# Preprocessing applied to every image before it reaches the encoder:
# resize to a fixed 240x240, convert to a tensor, then normalise each
# channel with the given per-channel mean and standard deviation.
transform = transforms.Compose([
    transforms.Resize((240, 240)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

In [None]:
class CocoDataset(data.Dataset):
    """Image-only COCO dataset used for caption inference.

    The dataset is indexed by the image ids found in the COCO JSON file
    (``coco.imgs``), not by annotation ids, so no caption annotations are
    read.
    """

    def __init__(self, root, json, vocab, transform=None):
        """Index the images listed in a COCO JSON file.

        Args:
            root: directory that contains the image files.
            json: path to the COCO JSON file, passed straight to ``COCO``.
            vocab: vocabulary object; stored on the instance but not used
                by this class itself.
            transform: optional callable applied to each loaded RGB image.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.imgs.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Load one image and return ``(image, img_id)``.

        When a transform is set, the transformed image additionally gets a
        leading dimension of size 1 (``unsqueeze(0)``); without a transform
        the RGB PIL image is returned unchanged.
        """
        img_id = self.ids[index]
        path = self.coco.loadImgs(img_id)[0]['file_name']

        # Open inside a context manager so the underlying file is closed as
        # soon as the pixel data has been converted, instead of whenever the
        # garbage collector gets to it (matters with many DataLoader workers).
        with Image.open(os.path.join(self.root, path)) as img_file:
            image = img_file.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
            image = image.unsqueeze(0)
        return image, img_id

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.ids)

In [None]:
def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Wrap a CocoDataset in a torch DataLoader.

    Args:
        root: directory that contains the image files.
        json: path to the COCO JSON file describing the images.
        vocab: vocabulary object forwarded to the dataset.
        transform: callable applied to every loaded image.
        batch_size: number of samples per batch.
        shuffle: whether the loader shuffles the samples.
        num_workers: number of worker processes used for loading.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset.
    """
    dataset = CocoDataset(root=root, json=json, vocab=vocab, transform=transform)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

In [None]:
# Image loader in dataset order (shuffle=False).
data_loader = get_loader(image_dir, caption_path,
                         vocab, transform, batch_size,
                         shuffle=False, num_workers=5)
print(len(data_loader))

# Rebuild the networks with the configured sizes; the encoder is switched to
# eval mode for inference.
# NOTE(review): the decoder is left in its default (training) mode, as before.
# If DecoderRNN contains dropout or normalisation layers it should probably
# get .eval() as well - confirm against model.py.
encoder = EncoderCNN(embed_size).eval().to(device)
decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers).to(device)

# map_location remaps the stored tensors onto `device`, so checkpoints saved
# from CUDA tensors still load when this notebook falls back to the CPU.
encoder.load_state_dict(torch.load(encoder_path, map_location=device))
decoder.load_state_dict(torch.load(decoder_path, map_location=device))

In [None]:
# Caption every image from the loader and collect one
# {"image_id": ..., "caption": ...} record per image.
total_step = len(data_loader)
output_list = []

# Inference only: no_grad() stops autograd from recording the forward passes.
with torch.no_grad():
    for index, (images, image_id) in enumerate(data_loader, start=1):
        # Lightweight progress indicator.
        if index % 100 == 0:
            print(index)

        # One image per batch is assumed here: .item() only works on a
        # single-element tensor.
        image_id = image_id.item()

        # The dataset adds an extra leading dimension to every sample, so the
        # collated batch has one axis too many. Fold the leading axes together
        # and keep the last three dims as they are, rather than hard-coding
        # the resize size a second time.
        images = images.view(-1, *images.shape[-3:]).to(device)

        features = encoder(images)
        sampled_ids = decoder.sample(features)
        sampled_ids = sampled_ids[0].tolist()

        # Map ids to words, dropping '<start>' and stopping at the first '<end>'.
        sampled_caption = []
        for word_id in sampled_ids:
            word = vocab.idx2word[word_id]
            if word == '<end>':
                break
            if word != '<start>':
                sampled_caption.append(word)
        sentence = ' '.join(sampled_caption)

        output_list.append({"image_id": image_id, "caption": sentence})

# Write the results only after the loop has finished, inside a context manager:
# a failure part-way through no longer leaks the handle or leaves behind a
# truncated results file.
with open('captions_test2014_cnnrnn_results.json', 'w') as f:
    json.dump(output_list, f)