In [18]:
import torch
import resnet.resnet as resnet

from torch import nn
from torchvision import transforms
from model import Net
from PIL import Image

import warnings
warnings.filterwarnings("ignore", category=UserWarning) # Cleaner demos : Don't do this normally...

In [19]:
device = torch.device('cuda')
state_path = '../2017-08-04_00.55.19.pth'
saved_state = torch.load(state_path)
print(saved_state.keys())
qtoken_to_index = saved_state['vocab']['question']
answer_words = ['UNDEF'] * len(saved_state['vocab']['answer'])
print(len(answer_words))
for w,idx in (saved_state['vocab']['answer']).items():
    answer_words[idx]=w

dict_keys(['name', 'tracker', 'config', 'weights', 'eval', 'vocab'])
3000


In [20]:
class WrappedModel(nn.Module):
    
    def __init__(self, embedding_tokens):
        super().__init__()
        self.module = Net(embedding_tokens)
    
    def forward(self, v, q, q_len):
        return self.module.forward(v, q, q_len)

In [21]:
class ResNet152(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.model = resnet.resnet152(pretrained=True)

        image_size = 448
        # output_features = 2048
        central_fraction = 0.875

        self.transform = get_transform(image_size, central_fraction)

        def save_output(module, input, output):
            self.buffer = output
        self.model.layer4.register_forward_hook(save_output)

    def forward(self, x):
        self.model(x)
        return self.buffer

    def image_to_features(self, image_file):
        img = Image.open(image_file).convert('RGB')
        img_transformed = self.transform(img)
        img_batch = img_transformed.unsqueeze(dim=0).to(device)
        return self.forward(img_batch)


In [22]:
def get_transform(target_size, central_fraction=1.0):
    return transforms.Compose([
        transforms.Scale(int(target_size / central_fraction)),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


In [23]:
def encode_question(question_str):
    tokens = question_str.lower().split(' ')
    vec = torch.zeros(len(tokens)).long()
    for i, token in enumerate(tokens):
        vec[i] = qtoken_to_index.get(token, 0)
    return vec.to(device), torch.tensor(len(tokens)).to(device)


In [24]:
# model = WrappedModel(len(qtoken_to_index) + 1)
model = torch.nn.DataParallel(Net(len(qtoken_to_index) + 1))
model.load_state_dict(saved_state['weights'])
model.to(device)
model.eval()

resnet152 = ResNet152().to(device)

In [25]:
vqa_examples = [
    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/banana.jpg',
        'question': 'what is this fruit?',
        'answer': 'banana'
    },
    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/apple.jpg',
        'question': 'what is this fruit?',
        'answer': 'apple'
    },
    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/gray.jpeg',
        'question': 'what is this fruit?',
        'answer': 'apple'
    },
    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/how many person in the image 1.jpg',
        'question': 'how many person in the image ?',
        'answer': 'one'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/how many person in the image 2.jpg',
        'question': 'how many person in the image ?',
        'answer': 'three'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/how many person in the image 3.jpg',
        'question': 'how many person in the image ?',
        'answer': 'four'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/how many person in the image 4.jpg',
        'question': 'how many person in the image ?',
        'answer': 'six'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/how many person in the image 5.jpg',
        'question': 'how many person in the image ?',
        'answer': 'seven'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/how many person in the image.jpg',
        'question': 'how many person in the image ?',
        'answer': 'five'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is the child happy 1.jpg',
        'question': 'is the child happy ?',
        'answer': 'no'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is the child happy.jpg',
        'question': 'is the child happy ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (1).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (1).png',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (1).webp',
        'question': 'is there a sea in the image ?',
        'answer': 'no'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (10).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (11).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'no'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (2).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (3).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'no'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (4).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (5).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (6).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (7).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (8).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'no'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image (9).jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'no'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/is there a sea in the image.jpg',
        'question': 'is there a sea in the image ?',
        'answer': 'yes'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what dose the child do.jpg',
        'question': 'what dose the child do ?',
        'answer': 'sing'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of child eyes.jpg',
        'question': 'what is the color of child eyes ?',
        'answer': 'blue'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of child hair 1.jpg',
        'question': 'what is the color of child hair ?',
        'answer': 'black'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of child hair.jpg',
        'question': 'what is the color of child hair ?',
        'answer': 'blonde'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of the cat  (1).jpg',
        'question': 'what is the color of the cat ?',
        'answer': 'brown'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of the cat  (2).jpg',
        'question': 'what is the color of the cat ?',
        'answer': 'black'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of the cat  (3).jpg',
        'question': 'what is the color of the cat ?',
        'answer': 'gray'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of the cat  (4).jpg',
        'question': 'what is the color of the cat ?',
        'answer': 'white'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of the cat  (5).jpg',
        'question': 'what is the color of the cat ?',
        'answer': 'brown'
    },

    {
        'img_path': 'C:/Users/mhala/Desktop/Projects/img/what is the color of the cat  (6).jpg',
        'question': 'what is the color of the cat ?',
        'answer': 'black'
    }
]


In [26]:
%%time
print('-'*50)
for d in vqa_examples:
    v = resnet152.image_to_features(d['img_path'])
    q, q_len = encode_question(d['question'])
    ans = model(v, q.unsqueeze(0), q_len.unsqueeze(0))
    _, answer_idx = ans.data.cpu().max(dim=1)
    print('question:', d['question'] , ', vqa answer:', answer_words[ answer_idx ], ', actuel answer:', d['answer'])
print('-'*50)

--------------------------------------------------


RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu