<a href="https://colab.research.google.com/github/hy-e/2025-ai-expert/blob/main/%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98/%EC%84%9C%ED%99%8D%EC%84%9D%20%EA%B5%90%EC%88%98%EB%8B%98/VQA_with_quiz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Code walkthrough of the PyTorch implementation of the paper *VQA: Visual Question Answering* (ICCV 2015)

![](assets/transformer1.png)

## Import Libraries

In [None]:
import os
import argparse
from PIL import Image
import numpy as np
import json
import re
from collections import defaultdict
import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.optim as optim
import tqdm
import easydict
from torch.optim import lr_scheduler
import gc
import matplotlib.pyplot as plt

## Prepare Data

In [None]:
#Make Folders
######################################################################
DATASETS_DIR = "./datasets"

ANNOTATIONS_DIR = os.path.join(DATASETS_DIR,'Annotations')
QUESTIONS_DIR = os.path.join(DATASETS_DIR,'Questions')
IMAGES_DIR = os.path.join(DATASETS_DIR,'Images')

!mkdir $DATASETS_DIR
!mkdir $ANNOTATIONS_DIR
!mkdir $QUESTIONS_DIR
!mkdir $IMAGES_DIR
######################################################################

# Download datasets from VQA official url: https://visualqa.org/download.html

#VQA Annotations
!wget -O $ANNOTATIONS_DIR/v2_Annotations_Train_mscoco.zip "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip"
!wget -O $ANNOTATIONS_DIR/v2_Annotations_Val_mscoco.zip "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip"

# VQA Input Questions
!wget -O $QUESTIONS_DIR/v2_Questions_Train_mscoco.zip "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip"
!wget -O $QUESTIONS_DIR/v2_Questions_Val_mscoco.zip "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip"
!wget -O $QUESTIONS_DIR/v2_Questions_Test_mscoco.zip "https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip"

# VQA Input Images (COCO)
!gdown https://drive.google.com/uc?id=1-7wzyJhs5ZaHLmHwqjR2-OLS4EeGyVih

######################################################################

!unzip $ANNOTATIONS_DIR/v2_Annotations_Train_mscoco.zip -d $ANNOTATIONS_DIR
!unzip $ANNOTATIONS_DIR/v2_Annotations_Val_mscoco.zip -d $ANNOTATIONS_DIR

!rm $ANNOTATIONS_DIR/v2_Annotations_Train_mscoco.zip
!rm $ANNOTATIONS_DIR/v2_Annotations_Val_mscoco.zip

!unzip $QUESTIONS_DIR/v2_Questions_Train_mscoco.zip -d $QUESTIONS_DIR
!unzip $QUESTIONS_DIR/v2_Questions_Val_mscoco.zip -d $QUESTIONS_DIR
!unzip $QUESTIONS_DIR/v2_Questions_Test_mscoco.zip -d $QUESTIONS_DIR

!rm $QUESTIONS_DIR/v2_Questions_Train_mscoco.zip
!rm $QUESTIONS_DIR/v2_Questions_Val_mscoco.zip
!rm $QUESTIONS_DIR/v2_Questions_Test_mscoco.zip

!unzip total.zip
!mv /content/content/datasets/ResizedImages /content/datasets/Images/Resized_Images
import shutil
shutil.rmtree('/content/content')

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000446909.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000409331.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000077473.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000156370.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000045976.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000005184.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000261888.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000326308.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000472399.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000459265.jpg  
  inflating: content/datasets/ResizedImages/val2014/COCO_val2014_000000020177.jpg  
  inflating: content/datas

## Preprocess Input Data
Preprocess input images, questions and answers.<br>
  1. Resize Images
  2. Make vocabs for questions and answers
  3. Build VQA Inputs

### Resize Images
Resize each COCO image (typically about 640x480) to 224x224.

In [None]:
# Only needed when starting from the raw COCO images from the official site.
# The pre-resized archive downloaded above already provides the resized images,
# so this stays off by default.
RESIZE_RAW_IMAGES = False


def resize_image(image, size):
    """Return `image` resized to a square `size` x `size` image.

    LANCZOS resampling limits aliasing. It replaces `Image.ANTIALIAS`, which was
    an alias of the same filter and was removed in Pillow 10.
    """
    return image.resize((size, size), Image.LANCZOS)


def resize_images(input_dir, output_dir, size, max_images=None):
    """Resize the images in each sub-directory of `input_dir` into `output_dir`.

    Args:
        input_dir: directory whose sub-directories (e.g. train2014, val2014) hold images.
        output_dir: destination root; each sub-directory name is mirrored under it.
            It may live inside `input_dir`, in which case it is skipped while scanning.
        size: side length of the square output images.
        max_images: optional cap on the number of images per sub-directory (None = all).

    Returns:
        The number of images that were resized and saved.
    """
    output_root = os.path.abspath(output_dir)
    n_saved = 0
    for idir in os.scandir(input_dir):
        # Skip plain files, and the output folder itself when it sits inside input_dir.
        if not idir.is_dir() or os.path.abspath(idir.path) == output_root:
            continue
        out_subdir = os.path.join(output_dir, idir.name)
        os.makedirs(out_subdir, exist_ok=True)
        images = os.listdir(idir.path)
        if max_images is not None:
            images = images[:max_images]
        n_images = len(images)
        for iimage, image in enumerate(tqdm.tqdm(images)):
            try:
                with Image.open(os.path.join(idir.path, image)) as img:
                    resized = resize_image(img, size)
                    # The resized copy has no `format`, so save with the source image's format.
                    resized.save(os.path.join(out_subdir, image), img.format)
                n_saved += 1
            except (IOError, SyntaxError) as e:
                # Report unreadable/corrupt files instead of silently skipping them.
                print("Skipped '{}': {}".format(image, e))
            if (iimage + 1) % 1000 == 0:
                print("[{}/{}] Resized the images and saved into '{}'."
                      .format(iimage + 1, n_images, out_subdir))
    return n_saved


if RESIZE_RAW_IMAGES:
    # Writes to the same location the preprocessing step reads images from.
    resize_images("./datasets/Images", "./datasets/Images/Resized_Images", 224)

### Make Vocabularies for Questions and Answers
Build one vocabulary for the questions and one for the answers, and save each to a .txt file.

## 1. Build the dictionary for questions.
After running the code, the words in the dictionary can be observed in the vocab_questions.txt.

In [None]:
def make_vocab_questions(input_dir, output_file='./datasets/vocab_questions.txt'):
    """Build the question vocabulary and save it to a text file, one word per line.

    Every ``.json`` file in `input_dir` is read as a VQA question file
    (``{'questions': [{'question': str, ...}, ...]}``). Each question is
    lower-cased, split on runs of non-word characters (the capturing group keeps
    punctuation such as '?' as tokens), and whitespace-only pieces are dropped.

    Args:
        input_dir: directory containing the question JSON files.
        output_file: path of the vocabulary file to write.

    Returns:
        The vocabulary list: '<pad>' at index 0, '<unk>' at index 1, then the
        sorted unique tokens.
    """
    vocab_set = set()
    SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)')
    question_length = []
    # Sorted so the processing order and print-out do not depend on the filesystem.
    datasets = sorted(os.listdir(input_dir))
    print(datasets)
    for dataset in datasets:
        # Stricter than a `'json' in name` substring test.
        if not dataset.endswith('.json'):
            continue
        with open(os.path.join(input_dir, dataset)) as f:
            questions = json.load(f)['questions']
        for question in questions:
            words = SENTENCE_SPLIT_REGEX.split(question['question'].lower())
            words = [w.strip() for w in words if len(w.strip()) > 0]
            vocab_set.update(words)
            question_length.append(len(words))

    vocab_list = ['<pad>', '<unk>'] + sorted(vocab_set)

    with open(output_file, 'w') as f:
        f.writelines(w + '\n' for w in vocab_list)

    print('Make vocabulary for questions')
    print('The number of total words of questions: %d' % len(vocab_set))
    # default=0: max() of an empty list (no question files) would raise ValueError.
    print('Maximum length of question: %d' % max(question_length, default=0))
    return vocab_list

# Build the question vocabulary from the files in ./datasets/Questions.
input_dir = "./datasets/"
make_vocab_questions(input_dir=input_dir + 'Questions')

['v2_OpenEnded_mscoco_test2015_questions.json', 'v2_OpenEnded_mscoco_test-dev2015_questions.json', 'v2_OpenEnded_mscoco_train2014_questions.json', 'v2_OpenEnded_mscoco_val2014_questions.json']
Make vocabulary for questions
The number of total words of questions: 17854
Maximum length of question: 26


## 2. Build the dictionary for answers.
Note that the answer vocabulary is cut to the `n_answers` most frequent answers, with `<unk>` counted as one of them; answers containing punctuation are skipped.<br>
After running the code, the words in the dictionary can be observed in the vocab_answers.txt.

In [None]:
def make_vocab_answers(input_dir, n_answers, output_file='./datasets/vocab_answers.txt'):
    """Build the vocabulary of the most frequent answers and save it to a text file.

    Every ``.json`` file in `input_dir` is read as a VQA annotation file
    (``{'annotations': [{'answers': [{'answer': str}, ...]}, ...]}``). Answers
    containing any character other than word characters or whitespace
    (e.g. 't-shirt', '2:00') are skipped.

    Args:
        input_dir: directory containing the annotation JSON files.
        n_answers: vocabulary size, including the '<unk>' entry.
        output_file: path of the vocabulary file to write (one answer per line).

    Returns:
        ['<unk>'] followed by up to n_answers - 1 answers, most frequent first.
    """
    answers = defaultdict(lambda: 0)
    # Sorted listing makes the tie order among equally frequent answers deterministic.
    for dataset in sorted(os.listdir(input_dir)):
        if not dataset.endswith('.json'):
            continue
        with open(os.path.join(input_dir, dataset)) as f:
            annotations = json.load(f)['annotations']
        for annotation in annotations:
            for answer in annotation['answers']:
                word = answer['answer']
                if re.search(r"[^\w\s]", word):
                    continue
                answers[word] += 1

    ranked = sorted(answers, key=answers.get, reverse=True)
    assert '<unk>' not in ranked
    # n_answers counts '<unk>'. Clamp at 0 so n_answers <= 1 keeps only '<unk>' instead of
    # slicing with a negative bound (ranked[:-1] would keep almost every answer).
    top_answers = ['<unk>'] + ranked[:max(n_answers - 1, 0)]

    with open(output_file, 'w') as f:
        f.writelines(w + '\n' for w in top_answers)

    print('Make vocabulary for answers')
    print('The number of total words of answers: %d' % len(ranked))
    print('Keep top %d answers into vocab' % n_answers)
    return top_answers

# Keep the 1000 most frequent answers (including '<unk>') from ./datasets/Annotations.
input_dir = "./datasets/"
n_answers = 1000

make_vocab_answers(input_dir=input_dir + 'Annotations', n_answers=n_answers)

Make vocabulary for answers
The number of total words of answers: 181102
Keep top 1000 answers into vocab


### Build VQA Inputs.

Define functions and classes that make building the inputs easier.

In [None]:
SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)')


def tokenize(sentence):
    """Split `sentence` into lower-cased word and punctuation tokens.

    Question 1: the sentence is lower-cased and split on runs of non-word
    characters with SENTENCE_SPLIT_REGEX. The capturing group keeps separators
    such as '?' as tokens. Whitespace is then stripped and empty pieces dropped.
    This is the same tokenisation make_vocab_questions uses to build the question
    vocabulary.

    Example: 'The world is changing.' -> ['the', 'world', 'is', 'changing', '.']
    """
    tokens = SENTENCE_SPLIT_REGEX.split(sentence.lower())
    tokens = [t.strip() for t in tokens if len(t.strip()) > 0]
    return tokens


def load_str_list(fname):
    """Return the lines of the text file `fname`, each stripped of surrounding whitespace."""
    with open(fname) as f:
        lines = f.readlines()
    lines = [l.strip() for l in lines]
    return lines

SyntaxError: invalid syntax (ipython-input-1405578578.py, line 7)

In [None]:
# Try the tokenizer on a sample sentence.
# A dedicated name avoids shadowing Python's built-in `input` for the rest of the notebook.
sample_sentence = 'The world is changing.'
print(sample_sentence)
print(tokenize(sample_sentence))

## Question 2 & 3
### Q2. Define the `__init__` function of the class `VocabDict`, which takes a `vocab_file` as input and builds the mappings from word to index and from index to word.

### The `vocab_file` is the .txt file written by `make_vocab_questions` or `make_vocab_answers`. Try passing these files as input once you have filled in the blanks.

### Q3. Define the tokenize_and_index function.

In [None]:
class VocabDict:
    """Bidirectional word <-> index mapping loaded from a one-word-per-line file.

    A word's index is its 0-based line number in `vocab_file`. Words not in the
    vocabulary map to the index of '<unk>' when the file contains that entry.
    """

    def __init__(self, vocab_file):
        self.word_list = load_str_list(vocab_file)
        # Question 2: word -> index lookup, vocabulary size and index of '<unk>'.
        self.word2idx_dict = {w: n_w for n_w, w in enumerate(self.word_list)}
        self.vocab_size = len(self.word_list)
        # None when the file has no '<unk>' line; word2idx then raises for unknown words.
        self.unk2idx = self.word2idx_dict.get('<unk>')

    def idx2word(self, n_w):
        """Return the word stored at index `n_w`."""
        return self.word_list[n_w]

    def word2idx(self, w):
        """Return the index of `w`, falling back to the '<unk>' index.

        Raises:
            ValueError: if `w` is unknown and the vocabulary has no '<unk>'.
        """
        if w in self.word2idx_dict:
            return self.word2idx_dict[w]
        elif self.unk2idx is not None:
            return self.unk2idx
        else:
            raise ValueError('word %s not in dictionary (while dictionary does not contain <unk>)' % w)

    def tokenize_and_index(self, sentence):
        """Question 3: tokenize `sentence` with `tokenize` and return the index of each token."""
        inds = [self.word2idx(w) for w in tokenize(sentence)]
        return inds

In [None]:
# Answer vocabulary written by make_vocab_answers ('<unk>' is its first line).
answer_vocab = VocabDict(vocab_file='./datasets/vocab_answers.txt')

In [None]:
# Only a cell's last expression is displayed, so the lookup used to be silently
# discarded. Show both values as one tuple. word2idx falls back to the '<unk>'
# index instead of raising KeyError when 'corn' is not among the kept answers.
answer_vocab.word2idx('corn'), answer_vocab.unk2idx

## Understanding the VocabDict class.
### Try playing with the VocabDict class with the text files from make_vocab_questions, make_vocab_answers.

### Try finding a sentence with all known vocabulary in the dictionary.

In [None]:
# Load both vocabularies so they can be explored interactively.
question_dict = VocabDict(vocab_file='./datasets/vocab_questions.txt')
answer_dict = VocabDict(vocab_file='./datasets/vocab_answers.txt')

# Indices of the tokens of 'Yes No' in the answer vocabulary.
answer_dict.tokenize_and_index(sentence='Yes No')

In [None]:
def extract_answers(q_answers, valid_answer_set):
    """Return (all annotator answers, those answers that are in `valid_answer_set`).

    Both lists keep duplicates and the original order.
    """
    everything = [entry["answer"] for entry in q_answers]
    kept = list(filter(valid_answer_set.__contains__, everything))
    return everything, kept


def vqa_processing(image_dir, annotation_file, question_file, valid_answer_set, image_set):
    """Build one record per question for the split `image_set`.

    `image_dir`, `annotation_file` and `question_file` are '%'-templates filled with
    the split name. Images of 'test-dev2015' are looked up under the 'test2015'
    name. Annotations are read only for 'train2014' and 'val2014'.

    Returns:
        A list of dicts with image_name, image_path, question_id, question_str and
        question_tokens. Train/val records also carry all_answers and
        valid_answers; valid_answers is ['<unk>'] when no answer is in the vocabulary.
    """
    print('building vqa %s dataset' % image_set)
    has_answers = image_set in ('train2014', 'val2014')
    ann_by_qid = {}
    if has_answers:
        with open(annotation_file % image_set) as fh:
            ann_by_qid = {a['question_id']: a for a in json.load(fh)['annotations']}
    with open(question_file % image_set) as fh:
        question_list = json.load(fh)['questions']

    coco_split = image_set.replace('-dev', '')
    image_root = os.path.abspath(image_dir % coco_split)
    name_fmt = 'COCO_' + coco_split + '_%012d'
    total = len(question_list)

    records = []
    n_unknown = 0
    for position, q in enumerate(question_list, start=1):
        if position % 10000 == 0:
            print('processing %d / %d' % (position, total))
        name = name_fmt % q['image_id']
        record = dict(image_name=name,
                      image_path=os.path.join(image_root, name + '.jpg'),
                      question_id=q['question_id'],
                      question_str=q['question'],
                      question_tokens=tokenize(q['question']))
        if has_answers:
            all_ans, valid_ans = extract_answers(ann_by_qid[q['question_id']]['answers'], valid_answer_set)
            if not valid_ans:
                valid_ans = ['<unk>']
                n_unknown += 1
            record['all_answers'] = all_ans
            record['valid_answers'] = valid_ans
        records.append(record)

    print('total %d out of %d answers are <unk>' % (n_unknown, total))
    return records

#############################################################################################################
input_dir = "./datasets/"
output_dir = "./datasets/"

# Path templates. vqa_processing fills '%s' with the split name.
# os.path.join avoids the doubled '/' produced by concatenating onto "./datasets/".
image_dir = os.path.join(input_dir, 'Images', 'Resized_Images', '%s')
annotation_file = os.path.join(input_dir, 'Annotations', 'v2_mscoco_%s_annotations.json')
question_file = os.path.join(input_dir, 'Questions', 'v2_OpenEnded_mscoco_%s_questions.json')

vocab_answer_file = os.path.join(output_dir, 'vocab_answers.txt')
answer_dict = VocabDict(vocab_answer_file)
valid_answer_set = set(answer_dict.word_list)

# NOTE: `train` is rebound to the training function in a later cell; inspect it before then.
train = vqa_processing(image_dir, annotation_file, question_file, valid_answer_set, 'train2014')
valid = vqa_processing(image_dir, annotation_file, question_file, valid_answer_set, 'val2014')
test = vqa_processing(image_dir, annotation_file, question_file, valid_answer_set, 'test2015')
test_dev = vqa_processing(image_dir, annotation_file, question_file, valid_answer_set, 'test-dev2015')

# Every record is a dict, so store explicit 1-D object arrays. Reading them back
# requires np.load(..., allow_pickle=True), which VqaDataset does.
np.save(os.path.join(output_dir, 'train.npy'), np.array(train, dtype=object))
np.save(os.path.join(output_dir, 'valid.npy'), np.array(valid, dtype=object))
np.save(os.path.join(output_dir, 'train_valid.npy'), np.array(train + valid, dtype=object))
np.save(os.path.join(output_dir, 'test.npy'), np.array(test, dtype=object))
np.save(os.path.join(output_dir, 'test-dev.npy'), np.array(test_dev, dtype=object))


## Inspect the train, validation and test datasets

In [None]:
# First preprocessed training record: a dict with image_name, image_path,
# question_id, question_str, question_tokens, all_answers and valid_answers.
# NOTE: `train` is rebound to the training function in a later cell, so run this
# cell before that one.
train[0]

# DataLoader

In [None]:
class VqaDataset(data.Dataset):
    """Map-style dataset over the VQA records saved by the preprocessing step.

    Each item is a dict with:
        'image': the RGB image (after `transform`, if given);
        'question': int64 array of length max_qst_length holding question-token
            indices, padded with the question '<pad>' index and truncated if longer;
    and, when the records carry answers (train/val):
        'answer_label': one index drawn at random from the valid answers
            (duplicates make frequent answers more likely), used for training;
        'answer_multi_choice': list of max_num_ans answer indices padded with -1,
            used for the multiple-choice accuracy.
    """

    def __init__(self, input_dir, input_vqa, max_qst_length=30, max_num_ans=10, transform=None):
        self.input_dir = input_dir
        # allow_pickle is needed for the object arrays of dicts written with np.save.
        # Unpickling can run arbitrary code: only load .npy files you generated yourself.
        self.vqa = np.load(os.path.join(input_dir, input_vqa), allow_pickle=True)
        self.qst_vocab = VocabDict(os.path.join(input_dir, 'vocab_questions.txt'))
        self.ans_vocab = VocabDict(os.path.join(input_dir, 'vocab_answers.txt'))
        self.max_qst_length = max_qst_length
        self.max_num_ans = max_num_ans
        # Test splits have no answers. An empty split must not raise IndexError here.
        self.load_ans = (len(self.vqa) > 0
                         and 'valid_answers' in self.vqa[0]
                         and self.vqa[0]['valid_answers'] is not None)
        self.transform = transform

    def _encode_question(self, tokens):
        """Return question-token indices padded with the '<pad>' index to max_qst_length."""
        pad_idx = self.qst_vocab.word2idx('<pad>')
        qst2idc = np.full(self.max_qst_length, pad_idx, dtype=np.int64)
        # Truncate: a question longer than max_qst_length would otherwise fail to broadcast.
        tokens = tokens[:self.max_qst_length]
        qst2idc[:len(tokens)] = [self.qst_vocab.word2idx(w) for w in tokens]
        return qst2idc

    def _encode_answers(self, answers):
        """Return (randomly sampled answer index, -1-padded list of max_num_ans answer indices)."""
        ans2idc = [self.ans_vocab.word2idx(w) for w in answers]
        ans2idx = np.random.choice(ans2idc)
        # -1 is not a vocabulary index, so the model can never "predict" the padding.
        mul2idc = [-1] * self.max_num_ans
        # Cap at max_num_ans so every sample has the same length (default collate requires it).
        n_kept = min(len(ans2idc), self.max_num_ans)
        mul2idc[:n_kept] = ans2idc[:n_kept]
        return ans2idx, mul2idc

    def __getitem__(self, idx):
        record = self.vqa[idx]
        # The context manager closes the file once the RGB copy has been made.
        with Image.open(record['image_path']) as img:
            image = img.convert('RGB')
        sample = {'image': image, 'question': self._encode_question(record['question_tokens'])}

        if self.load_ans:
            ans2idx, mul2idc = self._encode_answers(record['valid_answers'])
            sample['answer_label'] = ans2idx         # for training
            sample['answer_multi_choice'] = mul2idc  # for evaluation metric of 'multiple choice'

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

    def __len__(self):
        return len(self.vqa)


def get_loader(input_dir, input_vqa_train, input_vqa_valid, max_qst_length, max_num_ans, batch_size, num_workers):
    """Build shuffled DataLoaders for the 'train' and 'valid' splits.

    Both splits use the same preprocessing: ToTensor followed by Normalize with
    the ImageNet mean/std. There is no resize step, so the images must already
    share one size for batching.

    Returns:
        dict mapping 'train' / 'valid' to a torch DataLoader over a VqaDataset.
    """
    vqa_files = {'train': input_vqa_train, 'valid': input_vqa_valid}

    def make_transform():
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    data_loader = {}
    for phase in ('train', 'valid'):
        dataset = VqaDataset(input_dir=input_dir,
                             input_vqa=vqa_files[phase],
                             max_qst_length=max_qst_length,
                             max_num_ans=max_num_ans,
                             transform=make_transform())
        data_loader[phase] = torch.utils.data.DataLoader(dataset=dataset,
                                                         batch_size=batch_size,
                                                         shuffle=True,
                                                         num_workers=num_workers)
    return data_loader

## Model
Our model consists of three main parts:
1. Image Encoder : Returns the image features.
2. Question Encoder : Encodes the questions with the defined dictionary.
3. FC Layers: classify the combined (element-wise product) features of both encoders into one of the candidate answers.

### 1. Image Encoder

In [None]:
# Inspect the pretrained VGG-19 classifier head. The in_features of its last layer
# is the size of the feature vector ImgEncoder keeps after removing that layer.
# NOTE: pretrained=True downloads the ImageNet weights on first use (newer
# torchvision versions prefer the `weights=` argument).
model = models.vgg19(pretrained=True)
print(model.classifier[-1].in_features)

In [None]:
# List the layers of the VGG-19 classifier head (iterating an nn.Sequential yields its children).
print(*model.classifier)

In [None]:
class ImgEncoder(nn.Module):
    """Image encoder: frozen pretrained VGG-19 features projected to `embed_size`.

    forward() returns L2-normalised feature vectors of shape [batch_size, embed_size].
    """

    def __init__(self, embed_size):
        """(1) Load the pretrained VGG-19 (weights are downloaded on first use).
           (2) Replace its final fc layer (ImageNet class scores) with a new,
               trainable fc layer that outputs the image feature.
           (3) forward() L2-normalises the feature vector.
        """
        super(ImgEncoder, self).__init__()
        model = models.vgg19(pretrained=True)
        # Question 4
        in_features = model.classifier[-1].in_features  # input size of the last fc layer
        model.classifier = nn.Sequential(
            *list(model.classifier.children())[:-1])    # remove last fc layer

        self.model = model                                # loaded model without last fc layer
        self.fc = nn.Linear(in_features, embed_size)     # feature vector of image

    def forward(self, image):
        """Extract an L2-normalised feature vector from an image batch.

        Args:
            image: float tensor, presumably [batch_size, 3, 224, 224] and
                ImageNet-normalised (TODO confirm against the data loader).

        Returns:
            Tensor of shape [batch_size, embed_size] whose rows have unit L2 norm.
        """
        # The backbone is used as a fixed feature extractor: no gradients flow into it.
        with torch.no_grad():
            img_feature = self.model(image)                  # [batch_size, vgg16(19)_fc=4096]
        img_feature = self.fc(img_feature)                   # [batch_size, embed_size]

        # Question 5: L2 normalisation. The clamp avoids division by zero for an all-zero row.
        l2_norm = img_feature.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
        img_feature = img_feature.div(l2_norm)               # l2-normalized feature vector

        return img_feature

In [None]:
# Smoke test: 20 random 3x224x224 inputs should produce features of shape [20, 1024].
# A dedicated name keeps the `test` split built during preprocessing from being overwritten.
x = torch.rand(20, 3, 224, 224)
img_encoder_demo = ImgEncoder(1024)
out = img_encoder_demo(x)
print(out.shape)

### 2. Question Encoder

In [None]:
class QstEncoder(nn.Module):
    """Question encoder: word embeddings -> multi-layer LSTM -> linear projection.

    The final hidden and cell states of every LSTM layer are concatenated and
    projected to `embed_size`.
    """

    def __init__(self, qst_vocab_size, word_embed_size, embed_size, num_layers, hidden_size):

        super(QstEncoder, self).__init__()
        # Question 6
        # word2vec: learnable lookup from a question-vocabulary index to a word_embed_size vector.
        self.word2vec = nn.Embedding(qst_vocab_size, word_embed_size)
        self.tanh = nn.Tanh()
        # batch_first is left False, so the LSTM expects [seq_len, batch, feature] input.
        self.lstm = nn.LSTM(word_embed_size, hidden_size, num_layers)
        self.fc = nn.Linear(2*num_layers*hidden_size, embed_size)     # 2 for hidden and cell states

    def forward(self, question):
        """Encode a batch of question-index tensors [batch_size, max_qst_length] to [batch_size, embed_size].

        NOTE: padding positions are fed through the LSTM like ordinary tokens.
        """
        qst_vec = self.word2vec(question)                             # [batch_size, max_qst_length=30, word_embed_size=300]
        qst_vec = self.tanh(qst_vec)
        qst_vec = qst_vec.transpose(0, 1)                             # [max_qst_length=30, batch_size, word_embed_size=300]
        _, (hidden, cell) = self.lstm(qst_vec)                        # [num_layers=2, batch_size, hidden_size=512]
        qst_feature = torch.cat((hidden, cell), 2)                    # [num_layers=2, batch_size, 2*hidden_size=1024]
        qst_feature = qst_feature.transpose(0, 1)                     # [batch_size, num_layers=2, 2*hidden_size=1024]
        qst_feature = qst_feature.reshape(qst_feature.size()[0], -1)  # [batch_size, 2*num_layers*hidden_size=2048]
        qst_feature = self.tanh(qst_feature)
        qst_feature = self.fc(qst_feature)                            # [batch_size, embed_size]

        return qst_feature

## VQA Model

In [None]:
class VqaModel(nn.Module):
    """VQA classifier: fuse image and question features, then score every answer.

    The two encoders' [batch_size, embed_size] outputs are multiplied
    element-wise, then passed twice through tanh -> dropout -> linear
    (fc1: embed_size -> ans_vocab_size, fc2: ans_vocab_size -> ans_vocab_size).
    """

    def __init__(self, embed_size, qst_vocab_size, ans_vocab_size, word_embed_size, num_layers, hidden_size):

        super(VqaModel, self).__init__()
        # Attribute names are the state_dict keys, so they must stay stable for saved checkpoints.
        self.img_encoder = ImgEncoder(embed_size)
        self.qst_encoder = QstEncoder(qst_vocab_size, word_embed_size, embed_size, num_layers, hidden_size)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=embed_size, out_features=ans_vocab_size)
        self.fc2 = nn.Linear(in_features=ans_vocab_size, out_features=ans_vocab_size)

    def forward(self, img, qst):
        """Return answer logits of shape [batch_size, ans_vocab_size]."""
        fused = self.img_encoder(img) * self.qst_encoder(qst)   # [batch_size, embed_size]
        for layer in (self.fc1, self.fc2):
            fused = layer(self.dropout(self.tanh(fused)))
        return fused

# Start Training

In [None]:
# Paths and hyper-parameters for training and evaluation. EasyDict allows
# attribute access (args.batch_size), like the result of argparse.
args = easydict.EasyDict(
    input_dir='./datasets',    # input directory for VQA
    log_dir='./logs',          # directory for logs
    model_dir='./models',      # directory for saved models
    max_qst_length=30,         # maximum question length in tokens
    max_num_ans=10,            # maximum number of answers per question
    embed_size=1024,           # size of the image and question feature vectors
    word_embed_size=300,       # size of the word embeddings fed to the LSTM
    num_layers=2,              # number of LSTM layers
    hidden_size=512,           # LSTM hidden size
    learning_rate=0.001,       # learning rate
    step_size=10,              # period of learning-rate decay
    gamma=0.1,                 # multiplicative factor of learning-rate decay
    num_epochs=1,              # number of epochs
    batch_size=256,            # batch size
    num_workers=0,             # number of DataLoader worker processes
    save_step=1,               # save a checkpoint every `save_step` epochs
)

print(args)

In [None]:
# Question 7: define the loss function.
def custom_ce(pred, label):
    """Mean cross-entropy between logits `pred` [batch, classes] and int64 targets `label` [batch].

    Per sample this is logsumexp(pred) - pred[label], i.e. -log_softmax(pred)[label].
    It is computed without exponentiating raw logits, so it stays finite for large
    values, and it matches nn.CrossEntropyLoss() with its default mean reduction.
    """
    log_norm = torch.logsumexp(pred, dim=1)                        # [batch]
    target_logit = pred.gather(1, label.unsqueeze(1)).squeeze(1)   # [batch]
    return (log_norm - target_logit).mean()

In [None]:
# Sanity check: custom_ce should agree with PyTorch's built-in cross-entropy.
# torch / torch.nn are already imported at the top of the notebook.
# A local generator keeps the check reproducible without reseeding the global RNG.
gen = torch.Generator().manual_seed(0)
x = torch.rand((20, 10), generator=gen)
# randint's `high` is exclusive: high=10 covers all 10 classes (high=9 never produced class 9).
y = torch.randint(low=0, high=10, size=(20,), generator=gen).to(torch.int64)
real = nn.CrossEntropyLoss()
print(custom_ce(x, y))
print(real(x, y))
print(torch.allclose(custom_ce(x, y), real(x, y), atol=1e-6))

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(args):
    """Train VqaModel on the 'train' split and save checkpoints to args.model_dir.

    Args:
        args: EasyDict (or namespace) with input_dir, log_dir, model_dir,
            max_qst_length, max_num_ans, embed_size, word_embed_size, num_layers,
            hidden_size, learning_rate, step_size, gamma, num_epochs, batch_size,
            num_workers and save_step. Optional `max_batches` (default 500, the
            demo limit) ends each epoch after that batch index; set it to None to
            run full epochs.

    The StepLR scheduler is stepped once per epoch, so the learning rate is
    multiplied by `gamma` every `step_size` epochs. Checkpoints are dicts
    {'epoch': int, 'state_dict': model.state_dict()}.
    """
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    max_batches = getattr(args, 'max_batches', 500)

    data_loader = get_loader(
        input_dir=args.input_dir,
        input_vqa_train='train.npy',
        input_vqa_valid='valid.npy',
        max_qst_length=args.max_qst_length,
        max_num_ans=args.max_num_ans,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    qst_vocab_size = data_loader['train'].dataset.qst_vocab.vocab_size
    ans_vocab_size = data_loader['train'].dataset.ans_vocab.vocab_size
    ans_unk_idx = data_loader['train'].dataset.ans_vocab.unk2idx

    model = VqaModel(
        embed_size=args.embed_size,
        qst_vocab_size=qst_vocab_size,
        ans_vocab_size=ans_vocab_size,
        word_embed_size=args.word_embed_size,
        num_layers=args.num_layers,
        hidden_size=args.hidden_size).to(device)

    criterion = custom_ce

    # Only the new layers are optimised; the pretrained backbone (img_encoder.model) is left out.
    params = (list(model.img_encoder.fc.parameters())
              + list(model.qst_encoder.parameters())
              + list(model.fc1.parameters())
              + list(model.fc2.parameters()))

    optimizer = optim.Adam(params, lr=args.learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    phase = 'train'
    num_batches = len(data_loader[phase])

    for epoch in range(args.num_epochs):
        running_loss = 0.0
        running_corr_exp1 = 0
        running_corr_exp2 = 0
        num_samples = 0
        num_steps = 0

        model.train()

        for batch_idx, batch_sample in enumerate(data_loader[phase]):
            image = batch_sample['image'].to(device)
            question = batch_sample['question'].to(device)
            label = batch_sample['answer_label'].to(device)
            multi_choice = batch_sample['answer_multi_choice']  # not tensor, list (one tensor per answer slot).

            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == 'train'):
                output = model(image, question)      # [batch_size, ans_vocab_size=1000]
                _, pred_exp1 = torch.max(output, 1)  # [batch_size]
                loss = criterion(output, label)

                loss.backward()
                optimizer.step()

            # Evaluation metric of 'multiple choice'
            # Exp1: our model prediction to '<unk>' IS accepted as the answer.
            # Exp2: our model prediction to '<unk>' is NOT accepted as the answer.
            pred_exp1 = pred_exp1.cpu()
            pred_exp2 = pred_exp1.clone()
            if ans_unk_idx is not None:
                # -9999 never matches an answer index nor the -1 padding.
                pred_exp2[pred_exp2 == ans_unk_idx] = -9999
            running_loss += loss.item()
            running_corr_exp1 += torch.stack([(ans == pred_exp1) for ans in multi_choice]).any(dim=0).sum().item()
            running_corr_exp2 += torch.stack([(ans == pred_exp2) for ans in multi_choice]).any(dim=0).sum().item()
            num_samples += image.size(0)
            num_steps += 1

            # Print the loss of the current mini-batch.
            if batch_idx % 100 == 0:
                print('| {} SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f}'
                      .format(phase.upper(), epoch+1, args.num_epochs, batch_idx, num_batches, loss.item()))

            if max_batches is not None and batch_idx >= max_batches:
                break

        # Step once per epoch. Stepping per batch decayed the LR every `step_size` batches.
        scheduler.step()

        epoch_loss = running_loss / max(num_steps, 1)
        epoch_acc_exp1 = running_corr_exp1 / max(num_samples, 1)
        epoch_acc_exp2 = running_corr_exp2 / max(num_samples, 1)
        print('| {} SET | Epoch [{:02d}/{:02d}], Loss: {:.4f}, Acc(Exp1): {:.4f}, Acc(Exp2): {:.4f}'
              .format(phase.upper(), epoch+1, args.num_epochs, epoch_loss, epoch_acc_exp1, epoch_acc_exp2))

        # Save every `save_step` epochs, and always after the last epoch.
        if (epoch + 1) % args.save_step == 0 or epoch + 1 == args.num_epochs:
            torch.save({'epoch': epoch+1, 'state_dict': model.state_dict()},
                       os.path.join(args.model_dir, 'model-epoch-{:02d}.ckpt'.format(epoch+1)))


In [None]:
# Train with the hyper-parameters defined above; checkpoints are written to args.model_dir.
train(args=args)

# Evaluate our model on the validation data.

In [None]:
# Release memory left over from training before evaluation (gc is imported at the top).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
!gdown https://drive.google.com/uc?id=17C8uZTm6WHW0c2q4pi5v-gHzqbegA6LV
!mv model-epoch-25.ckpt models/

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
phase = 'valid'
answer_dict = VocabDict('./datasets/vocab_answers.txt')
question_dict = VocabDict('./datasets/vocab_questions.txt')
tensor2img = transforms.ToPILImage()

# Mean/std used by get_loader's Normalize; needed to undo it before displaying images.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _load_checkpoint(model, path):
    """Load weights from `path` into `model`, accepting train()'s checkpoint layout.

    train() saves {'epoch': ..., 'state_dict': ...}. Passing that wrapper dict to
    load_state_dict(strict=False) would silently load nothing, so the inner
    state_dict is unwrapped. A bare state_dict is also accepted. Mismatched keys
    are printed instead of being ignored silently.
    """
    checkpoint = torch.load(path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    result = model.load_state_dict(checkpoint, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print('Checkpoint key mismatch - missing: {}, unexpected: {}'
              .format(result.missing_keys, result.unexpected_keys))


def _show_example(image, question_idx, label_idx, pred_idx):
    """Display one de-normalised image with its question, sampled answer and prediction."""
    img = (image.cpu() * _IMAGENET_STD + _IMAGENET_MEAN).clamp(0, 1)
    plt.imshow(tensor2img(img))
    plt.axis('off')
    plt.show()
    words = [question_dict.idx2word(i) for i in question_idx.tolist()]
    print("The question is : ")
    print([w for w in words if w != '<pad>'])
    print("The answer is : " + answer_dict.idx2word(label_idx))
    print("The output of our model is :" + answer_dict.idx2word(pred_idx))


def eval(args):
    """Show the model's predictions on samples from the validation split.

    Loads './models/model-epoch-25.ckpt'. The optional args.num_display
    (default 10; None = every sample) limits how many examples are shown.
    NOTE: this name shadows Python's built-in eval() in the notebook namespace.
    """
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    num_display = getattr(args, 'num_display', 10)

    data_loader = get_loader(
        input_dir=args.input_dir,
        input_vqa_train='train.npy',
        input_vqa_valid='valid.npy',
        max_qst_length=args.max_qst_length,
        max_num_ans=args.max_num_ans,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    qst_vocab_size = data_loader['valid'].dataset.qst_vocab.vocab_size
    ans_vocab_size = data_loader['valid'].dataset.ans_vocab.vocab_size

    model = VqaModel(
        embed_size=args.embed_size,
        qst_vocab_size=qst_vocab_size,
        ans_vocab_size=ans_vocab_size,
        word_embed_size=args.word_embed_size,
        num_layers=args.num_layers,
        hidden_size=args.hidden_size)

    # path = './models/model-epoch-01.ckpt'
    path = './models/model-epoch-25.ckpt'
    _load_checkpoint(model, path)
    model.to(device)
    model.eval()

    shown = 0
    with torch.no_grad():  # inference only: no autograd graph, less memory
        for batch_sample in data_loader[phase]:
            image = batch_sample['image'].to(device)
            question = batch_sample['question'].to(device)
            label = batch_sample['answer_label']

            output = model(image, question)      # [batch_size, ans_vocab_size]
            # Argmax over the answers (dim 1). dim 0 would take the argmax over the batch.
            _, pred = torch.max(output, 1)        # [batch_size]
            pred = pred.cpu()

            # The last batch can be smaller than args.batch_size.
            for i in range(image.size(0)):
                if num_display is not None and shown >= num_display:
                    return
                _show_example(image[i], question[i], label[i].item(), pred[i].item())
                shown += 1


In [None]:
# Calls the notebook's eval(args) defined in the evaluation cell, which shadows
# Python's built-in eval().
eval(args)