In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
import numpy as np
import os
from torch.utils.data import Dataset
from torchvision import transforms

class VQADataset(Dataset):

    def __init__(self, input_dir, input_file, max_qu_len = 30, transform = None):

        self.input_data = np.load(os.path.join(input_dir, input_file), allow_pickle=True)
        self.qu_vocab = Vocab(input_dir+'/question_vocabs.txt')
        self.ans_vocab = Vocab(input_dir+'/annotation_vocabs.txt')
        self.max_qu_len = max_qu_len
        self.labeled = True if not "test" in input_file else False
        self.transform = transform

    def __getitem__(self, idx):

        path = self.input_data[idx]['img_path']
        img = np.array(Image.open(path).convert('RGB'))
        qu_id = self.input_data[idx]['qu_id']
        qu_tokens = self.input_data[idx]['qu_tokens']
        qu2idx = np.array([self.qu_vocab.word2idx('<pad>')] * self.max_qu_len)
        qu2idx[:len(qu_tokens)] = [self.qu_vocab.word2idx(token) for token in qu_tokens]
        sample = {'image': img, 'question': qu2idx, 'question_id': qu_id}

        if self.labeled:
            ans2idx = [self.ans_vocab.word2idx(ans) for ans in self.input_data[idx]['valid_ans']]
            ans2idx = np.random.choice(ans2idx)
            sample['answer'] = ans2idx

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

    def __len__(self):

        return len(self.input_data)


class Vocab:

    def __init__(self, vocab_file):

        self.vocab = self.load_vocab(vocab_file)
        self.vocab2idx = {vocab: idx for idx, vocab in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)

    def load_vocab(self, vocab_file):

        with open(vocab_file) as f:
            vocab = [v.strip() for v in f]

        return vocab

    def word2idx(self, vocab):

        if vocab in self.vocab2idx:
            return self.vocab2idx[vocab]
        else:
            return self.vocab2idx['<unk>']

    def idx2word(self, idx):

        return self.vocab[idx]


In [6]:
data_dir = '/home/cyz/Project/nlp_lab/Final project/2025_Spring_NLP_Final_Project/demo_model'
vqa_dataset = VQADataset(data_dir, 'val.npy', max_qu_len=30)
qu_vocab = vqa_dataset.qu_vocab
ans_vocab = vqa_dataset.ans_vocab
qu_vocab_size = qu_vocab.vocab_size

In [7]:


class ImgEncoder(nn.Module):

    def __init__(self, embed_dim):

        super(ImgEncoder, self).__init__()
        self.model = models.vgg19(pretrained=True)
        in_features = self.model.classifier[-1].in_features
        self.model.classifier = nn.Sequential(*list(self.model.classifier.children())[:-1]) # remove vgg19 last layer
        self.fc = nn.Linear(in_features, embed_dim)

    def forward(self, image):

        with torch.no_grad():
            img_feature = self.model(image) # (batch, channel, height, width)
        img_feature = self.fc(img_feature)

        l2_norm = F.normalize(img_feature, p=2, dim=1).detach()
        return l2_norm

class QuEncoder(nn.Module):

    def __init__(self, qu_vocab_size, word_embed, hidden_size, num_hidden, qu_feature_size):

        super(QuEncoder, self).__init__()
        self.word_embedding = nn.Embedding(qu_vocab_size, word_embed)
        self.tanh = nn.Tanh()
        self.lstm = nn.LSTM(word_embed, hidden_size, num_hidden) # input_feature, hidden_feature, num_layer
        self.fc = nn.Linear(2*num_hidden*hidden_size, qu_feature_size)

    def forward(self, question):

        qu_embedding = self.word_embedding(question)                # (batchsize, qu_length=30, word_embed=300)
        qu_embedding = self.tanh(qu_embedding)
        qu_embedding = qu_embedding.transpose(0, 1)                 # (qu_length=30, batchsize, word_embed=300)
        _, (hidden, cell) = self.lstm(qu_embedding)                 # (num_layer=2, batchsize, hidden_size=1024)
        qu_feature = torch.cat((hidden, cell), dim=2)               # (num_layer=2, batchsize, 2*hidden_size=1024)
        qu_feature = qu_feature.transpose(0, 1)                     # (batchsize, num_layer=2, 2*hidden_size=1024)
        qu_feature = qu_feature.reshape(qu_feature.size()[0], -1)   # (batchsize, 2*num_layer*hidden_size=2048)
        qu_feature = self.tanh(qu_feature)
        qu_feature = self.fc(qu_feature)                            # (batchsize, qu_feature_size=1024)

        return qu_feature

class VQAModel(nn.Module):

    def __init__(self, feature_size, qu_vocab_size, ans_vocab_size, word_embed, hidden_size, num_hidden):

        super(VQAModel, self).__init__()
        self.img_encoder = ImgEncoder(feature_size)
        self.qu_encoder = QuEncoder(qu_vocab_size, word_embed, hidden_size, num_hidden, feature_size)
        self.dropout = nn.Dropout(0.5)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(feature_size, ans_vocab_size)
        self.fc2 = nn.Linear(ans_vocab_size, ans_vocab_size)

    def forward(self, image, question):

        img_feature = self.img_encoder(image)               # (batchsize, feature_size=1024)
        qst_feature = self.qu_encoder(question)
        combined_feature = img_feature * qst_feature
        combined_feature = self.dropout(combined_feature)
        combined_feature = self.tanh(combined_feature)
        combined_feature = self.fc1(combined_feature)       
        combined_feature = self.dropout(combined_feature)
        combined_feature = self.tanh(combined_feature)
        logits = self.fc2(combined_feature)                 

        return logits


In [9]:
FEATURE_SIZE, WORD_EMBED = 1024, 300
MAX_QU_LEN, NUM_HIDDEN, HIDDEN_SIZE = 30, 2, 512
ckpt_dir = '/home/cyz/Project/nlp_lab/Final project/2025_Spring_NLP_Final_Project/CNN_LSTM_Demo.pth'

model = VQAModel(
    feature_size=FEATURE_SIZE,
    qu_vocab_size=17856,
    ans_vocab_size=30000,
    word_embed=WORD_EMBED,
    hidden_size=HIDDEN_SIZE,
    num_hidden=NUM_HIDDEN
).to(device)
model.load_state_dict(torch.load(ckpt_dir, map_location=device))
model.eval()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/cyz/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [08:09<00:00, 1.17MB/s] 


VQAModel(
  (img_encoder): ImgEncoder(
    (model): VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1,

In [10]:
def get_answer(image_path, question_str):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Adjust size if needed
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    tokens = question_str.split()  # Simple tokenization
    question_idx = []
    for token in tokens:
        try:
            idx = qu_vocab.word2idx(token)
        except KeyError:  # Adjust based on word2idx behavior
            idx = qu_vocab.word2idx('<unk>')
        question_idx.append(idx)
    
    # Pad or truncate to MAX_QU_LEN
    question_idx = question_idx[:MAX_QU_LEN] + [0] * (MAX_QU_LEN - len(question_idx))
    question = torch.tensor([question_idx], dtype=torch.long).to(device)

    with torch.no_grad():
        logits = model(image, question)
        predict = torch.log_softmax(logits, dim=1)
        predicted_idx = torch.argmax(predict, dim=1).item()

    answer = ans_vocab.idx2word(predicted_idx)
    return answer

In [14]:
image_path = "/home/cyz/Project/nlp_lab/Final project/2025_Spring_NLP_Final_Project/demo_model/001.jpg"
# question = "Where is he looking?"
# question = "What are the people in the background doing?"
question =  "What is he on top of?"
answer = get_answer(image_path, question)
print(f'Answer: {answer}')

Answer: skateboard
