In [1]:
# !unzip /content/easyvqa.zip

In [2]:
# !pip install -qqq easy-vqa
# !pip install -qqq sentence_transformers transformers timm

In [3]:
from easy_vqa import get_train_questions, get_test_questions

train_questions, train_answers, train_image_ids = get_train_questions()
test_questions, test_answers, test_image_ids = get_test_questions()

In [4]:
import pandas as pd
pd.set_option("max_colwidth", None)

def gen_dataframes(questions, answers, image_ids, mode="train"):
    records = []
    for question, answer, image_id in zip(questions, answers, image_ids):
        image_path = f"/content/data/{mode}/images/{image_id}.png"
        records.append({"question" : question, "answer": answer, "image_path": image_path})
    return pd.DataFrame(records)

df =  gen_dataframes(train_questions, train_answers, train_image_ids)
from sklearn.model_selection import train_test_split
df = df.sample(frac=1)
train_df, eval_df = train_test_split(df)
test_df =  gen_dataframes(test_questions, test_answers, test_image_ids, mode="test")


In [5]:
from easy_vqa import get_answers
answers = get_answers()
print("Total labels", len(answers))
label2idx = {answer:i for i, answer in enumerate(answers)}

Total labels 13


In [6]:
train_df["label"] = train_df["answer"].apply(lambda x: label2idx.get(x))
eval_df["label"] = eval_df["answer"].apply(lambda x: label2idx.get(x))
test_df["label"] = test_df["answer"].apply(lambda x: label2idx.get(x))

In [7]:
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModel
import torchvision.transforms as T
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text_encoder = AutoModel.from_pretrained("bert-base-uncased")
for p in text_encoder.parameters():
    p.requires_grad = False


image_processor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
image_encoder = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k")

for p in image_encoder.parameters():
    p.requires_grad = False

image_encoder.to(device)
text_encoder.to(device)




BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [8]:
from PIL import Image
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

class EasyQADataset(Dataset):
    def __init__(self,df,image_encoder,text_encoder,image_processor,tokenizer,):
        self.df = df
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_file = self.df["image_path"][idx]
        question = self.df['question'][idx]
        image = Image.open(image_file).convert("RGB")
        label = self.df['label'][idx]
        image_inputs = self.image_processor(image, return_tensors="pt")
        image_inputs = {k:v.to(device) for k,v in image_inputs.items()}
        image_outputs = self.image_encoder(**image_inputs)
        image_embedding = image_outputs.pooler_output
        image_embedding = image_embedding.view(-1)
        image_embedding = image_embedding.detach()
        text_inputs = self.tokenizer(question, return_tensors="pt")
        text_inputs = {k:v.to(device) for k,v in text_inputs.items()}
        text_outputs = self.text_encoder(**text_inputs)
        text_embedding = text_outputs.pooler_output
        text_embedding = text_embedding.view(-1)
        text_embedding = text_embedding.detach()
        encoding={}
        encoding["image_emb"] = image_embedding
        encoding["text_emb"] = text_embedding
        encoding["label"] = torch.tensor(label)
        return encoding

In [9]:
train_df.reset_index(drop=True, inplace=True)
eval_df.reset_index(drop=True, inplace=True)
train_dataset = EasyQADataset(df=train_df,image_encoder = image_encoder,text_encoder = text_encoder,tokenizer = tokenizer,image_processor = image_processor,)
eval_dataset = EasyQADataset(df=eval_df,image_encoder = image_encoder,text_encoder = text_encoder,tokenizer = tokenizer,image_processor = image_processor,)

In [10]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

batch_size = 32
eval_batch_size = 32
dataloader_train = DataLoader(train_dataset,sampler=RandomSampler(train_dataset),batch_size=batch_size)
dataloader_validation = DataLoader(eval_dataset,sampler=SequentialSampler(eval_dataset),batch_size=eval_batch_size)

In [11]:
from sklearn.metrics import accuracy_score
def accuracy_score_func(preds, labels):
    return accuracy_score(labels, preds)

In [13]:

import random
from torch import nn
from tqdm.notebook import tqdm
import numpy as np

criterion = nn.CrossEntropyLoss()

def VQA_Test(dataloader_val):
    model.eval()
    total_Computedlossvalue= 0
    predictions, true_values, confidence = [], [], []
    for Vqa_input_data_row in dataloader_val:
        Vqa_input_data_row = tuple(b.to(device) for b in Vqa_input_data_row.values())
        inputs_img_text = {'image_emb':  Vqa_input_data_row[0],'text_emb': Vqa_input_data_row[1]}
        with torch.no_grad():
            outputs = model(**inputs_img_text)
        labeled_Categories =  Vqa_input_data_row[2]
        computed_Lossvalue = criterion(outputs.view(-1, 13), labeled_Categories.view(-1))
        total_Computedlossvalue+= computed_Lossvalue.item()
        max_Pdctscore   = torch.max(outputs.softmax(dim=1), dim=-1)[0].detach().cpu().numpy()
        outputs = outputs.argmax(-1)
        pdtd_valuelog = outputs.detach().cpu().numpy()
        labelcategory_key = labeled_Categories.cpu().numpy()
        predictions.append(pdtd_valuelog)
        true_values.append(labelcategory_key)
        confidence.append(max_Pdctscore)
    avervage_Loss_val = total_Computedlossvalue/len(dataloader_val)
    predictions = np.concatenate(predictions, axis=0)
    true_values = np.concatenate(true_values, axis=0)
    confidence = np.concatenate(confidence, axis=0)
    return avervage_Loss_val, predictions, true_values, confidence

def train():
  inputed_Dataset = open("/content/models/train_history.csv", "w")
  titles  = "Epoch, train_loss, train_acc, val_loss, val_acc"
  inputed_Dataset.write(titles  + "\n")
  items_Accuracy_vals = []
  items_train_Acc = []
  max_Accuracy = 0
  train_losses = []
  pdct_Loss_list = []
  min_Loss = -1
  max_Epoches = 3
  epoch_incr = 0
  Iter_cdtn = False
  for current_epoch in tqdm(range(1, 5)):
      model.train()
      calc_Total_loss = 0
      predictions_of_data, train_true_vals_log = [], []
      progress_Bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(current_epoch), leave=False, disable=False)
      for Vqa_input_data_row in progress_Bar:
          model.zero_grad()
          Vqa_input_data_row = tuple(b.to(device) for b in Vqa_input_data_row.values())
          inputs_img_text = {'image_emb':  Vqa_input_data_row[0],'text_emb': Vqa_input_data_row[1]}
          labeled_Categories =  Vqa_input_data_row[2]
          outputs = model(**inputs_img_text)
          calc_item_loss = criterion(outputs.view(-1, 13), labeled_Categories.view(-1))
          calc_Total_loss += calc_item_loss.item()
          calc_item_loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
          pdtd_valuelog = outputs.argmax(-1)
          pdtd_valuelog = pdtd_valuelog.detach().cpu().numpy()
          labelcategory_key = labeled_Categories.cpu().numpy()
          predictions_of_data.append(pdtd_valuelog)
          train_true_vals_log.append(labelcategory_key)
          optimizer.step()
          scheduler.step()
          progress_Bar.set_postfix({'Training_loss': '{:.3f}'.format(calc_item_loss.item()/len(Vqa_input_data_row))})
      predictions_of_data = np.concatenate(predictions_of_data, axis=0)
      train_true_vals_log = np.concatenate(train_true_vals_log, axis=0)

      tqdm.write(f'\nEpoch {current_epoch}')
      Average_loss = calc_Total_loss/len(dataloader_train)
      tqdm.write(f'Training Avg loss: {Average_loss}')
      train_acc = accuracy_score_func(predictions_of_data, train_true_vals_log)
      tqdm.write(f'Train Accuracy: {train_acc}')

      val_loss, predictions, true_vals,_ = VQA_Test(dataloader_validation)
      Accuracy_vals = accuracy_score_func(predictions, true_vals)
      tqdm.write(f'Validation loss: {val_loss}')
      tqdm.write(f'Value Accuracy: {Accuracy_vals}')

      if Accuracy_vals >= max_Accuracy:
          tqdm.write('\nModel')
          torch.save(model.state_dict(), f'./easyvqa_finetuned_epoch_{current_epoch}.model')
          max_Accuracy = Accuracy_vals

      train_losses.append(Average_loss)
      pdct_Loss_list.append(val_loss)
      items_train_Acc.append(train_acc)
      items_Accuracy_vals.append(Accuracy_vals)
      log_str  = "{}, {}, {}, {}, {}".format(current_epoch, Average_loss, train_acc, val_loss, Accuracy_vals)
      inputed_Dataset.write(log_str + "\n")

      if min_Loss < 0:
          min_Loss = val_loss
      else:
        if val_loss < min_Loss:
            min_Loss = val_loss
        else:
            epoch_incr += 1
            if epoch_incr >= max_Epoches:
                Iter_cdtn = True
                break
            else:
                continue

  if Iter_cdtn:
    print("Stopped at epoch -", current_epoch )
    print("Use the checkpoint at epoch - ", current_epoch - max_Epoches)

  inputed_Dataset.close()
  return train_losses, pdct_Loss_list

In [14]:
import math
class EasyQAMidFusionNetwork(nn.Module):
    def __init__(self, hyperparms=None):
        super(EasyQAMidFusionNetwork, self).__init__()
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(768, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.classifier = nn.Linear(256, 13)
        W = torch.Tensor(768, 768)
        self.W = nn.Parameter(W)
        self.relu_f = nn.ReLU()
        nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))

    def forward(self, image_emb, text_emb):
        x1 = image_emb
        Xv = torch.nn.functional.normalize(x1, p=2, dim=1)
        x2 = text_emb
        Xt = torch.nn.functional.normalize(x2, p=2, dim=1)
        Xvt = Xv * Xt
        Xvt = self.relu_f(torch.mm(Xvt, self.W.t()))
        Xvt = self.fc1(Xvt)
        Xvt = self.bn1(Xvt)
        Xvt = self.dropout(Xvt)
        Xvt = self.classifier(Xvt)
        return Xvt

In [15]:
torch.cuda.empty_cache()
model = EasyQAMidFusionNetwork()
model.to(device)

EasyQAMidFusionNetwork(
  (dropout): Dropout(p=0.3, inplace=False)
  (fc1): Linear(in_features=768, out_features=256, bias=True)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (classifier): Linear(in_features=256, out_features=13, bias=True)
  (relu_f): ReLU()
)

In [16]:
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
optimizer = AdamW(model.parameters(),lr=5e-5,weight_decay = 1e-5,eps=1e-8 )
epochs = 5
train_steps=20000
warm_steps = train_steps * 0.1
scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=warm_steps,num_training_steps=train_steps)

In [None]:
from matplotlib import pyplot as plt
# try:
!rm -rf /content/models
!mkdir /content/models
train_losses, val_losses =  train()
torch.cuda.empty_cache()
plt.plot(train_losses)
plt.plot(val_losses)
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/905 [00:00<?, ?it/s]


Epoch 1
Training Avg loss: 2.421151947448267
Train Accuracy: 0.2352148214717777
Validation loss: 2.0607976214775188
Value Accuracy: 0.41777270841974284

Model


Epoch 2:   0%|          | 0/905 [00:00<?, ?it/s]