In [1]:
import argparse
import os
import requests
import zipfile
import tarfile
import time
import json
import re
from typing import Optional
from os import listdir
from os.path import isfile, join
from tqdm.auto import tqdm
from PIL import Image
from transformers import ViltConfig
from tqdm.notebook import tqdm
import torch
from transformers import ViltProcessor
import numpy as np
from transformers import ViltForQuestionAnswering
from torch.utils.data import DataLoader

# TRAINING

In [2]:
# root = '/Users/samch/OneDrive/Desktop/FALL 2023/AC215/AC215-BiteSize/notebooks/vaq_files/file_names/val2014'
# path_questions = '/Users/samch/OneDrive/Desktop/FALL 2023/AC215/AC215-BiteSize/notebooks/vaq_files/datasets/v2_OpenEnded_mscoco_val2014_questions.json'
# path_annotations = '/Users/samch/OneDrive/Desktop/FALL 2023/AC215/AC215-BiteSize/notebooks/vaq_files/data_vqa/v2_mscoco_val2014_annotations.json'

In [3]:
root = '/Users/kimberlyllajarunaperalta/Documents/MLOps/vaq_files/data_vqa/val2014'
path_questions = '/Users/kimberlyllajarunaperalta/Documents/MLOps/vaq_files/data_vqa/v2_OpenEnded_mscoco_val2014_questions.json'
path_annotations = '/Users/kimberlyllajarunaperalta/Documents/MLOps/vaq_files/data_vqa/v2_mscoco_val2014_annotations.json'

In [4]:
def id_from_filename(filename: str) -> Optional[int]:
    match = filename_re.fullmatch(filename)
    if match is None:
        return None
    return int(match.group(1))

In [5]:
def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  encoding = processor.image_processor.pad(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  batch['token_type_ids'] = torch.stack(token_type_ids)
  batch['pixel_values'] = encoding['pixel_values']
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = torch.stack(labels)

  return batch

In [6]:
def get_score(count: int) -> float:
    return min(1.0, count / 3)

In [7]:
class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, questions, annotations, processor):
        self.questions = questions
        self.annotations = annotations
        self.processor = processor

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        # get image + text
        annotation = self.annotations[idx]
        questions = self.questions[idx]
        image = Image.open(id_to_filename[annotation['image_id']])
        text = questions['question']

        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        # remove batch dimension
        for k,v in encoding.items():
          encoding[k] = v.squeeze()
        # add labels
        labels = annotation['labels']
        scores = annotation['scores']
        # based on: https://github.com/dandelin/ViLT/blob/762fd3975c180db6fc88f577cf39549983fa373a/vilt/modules/objectives.py#L301
        targets = torch.zeros(len(config.id2label))
        for label, score in zip(labels, scores):
              targets[label] = score
        encoding["labels"] = targets

        return encoding

In [8]:
file_names = [f for f in tqdm(listdir(root)) if isfile(join(root, f))]
filename_re = re.compile(r".*(\d{12})\.((jpg)|(png))")

filename_to_id = {root + "/" + file: id_from_filename(file) for file in file_names}
id_to_filename = {v:k for k,v in filename_to_id.items()}

  0%|          | 0/40504 [00:00<?, ?it/s]

In [9]:
config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-mlm",
                                                 id2label=config.id2label,
                                                 label2id=config.label2id)
model.to(device)

Some weights of ViltForQuestionAnswering were not initialized from the model checkpoint at dandelin/vilt-b32-mlm and are newly initialized: ['classifier.3.bias', 'classifier.0.weight', 'classifier.1.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.1.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViltForQuestionAnswering(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=76

In [10]:
#Questions
data_questions = json.load(open(path_questions))
questions = data_questions['questions']

#Annotations
data_annotations = json.load(open(path_annotations))
annotations = data_annotations['annotations']

#Processor
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

#Dataset
dataset = VQADataset(questions=questions[:100],
                     annotations=annotations[:100],
                     processor=processor)

#Train dataloader
train_dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=4, shuffle=True)

In [11]:
for annotation in tqdm(annotations):
    answers = annotation['answers']
    answer_count = {}
    for answer in answers:
        answer_ = answer["answer"]
        answer_count[answer_] = answer_count.get(answer_, 0) + 1
    labels = []
    scores = []
    for answer in answer_count:
        if answer not in list(config.label2id.keys()):
            continue
        labels.append(config.label2id[answer])
        score = get_score(answer_count[answer])
        scores.append(score)
    annotation['labels'] = labels
    annotation['scores'] = scores

  0%|          | 0/214354 [00:00<?, ?it/s]

In [13]:
start_time = time.time()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(2):  # loop over the dataset multiple times, trainign only 10 epochs
   print(f"Epoch: {epoch}")
   for batch in tqdm(train_dataloader):
        # get the inputs;
        batch = {k:v.to(device) for k,v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()
        
execution_time = (time.time() - start_time) / 60.0
print("Training execution time (mins)", execution_time)
print("Training Job Complete")

Epoch: 0


  0%|          | 0/25 [00:00<?, ?it/s]

Loss: 956.8610229492188
Loss: 921.6065063476562
Loss: 888.2288208007812
Loss: 855.6438598632812
Loss: 822.6932373046875
Loss: 793.51806640625
Loss: 766.9517822265625
Loss: 739.0214233398438
Loss: 712.250244140625
Loss: 689.1754150390625
Loss: 663.0001220703125
Loss: 639.6845092773438
Loss: 617.6478881835938
Loss: 595.3203735351562
Loss: 574.157958984375
Loss: 554.617431640625
Loss: 535.9575805664062
Loss: 517.53857421875
Loss: 500.2691345214844
Loss: 482.7394714355469
Loss: 466.7906799316406
Loss: 450.4107971191406
Loss: 435.88555908203125
Loss: 421.92926025390625
Loss: 408.8621826171875
Epoch: 1


  0%|          | 0/25 [00:00<?, ?it/s]

Loss: 396.33319091796875
Loss: 381.93218994140625
Loss: 370.30377197265625
Loss: 360.6387634277344
Loss: 348.5657043457031
Loss: 336.47064208984375
Loss: 328.6730651855469
Loss: 316.94097900390625
Loss: 307.144775390625
Loss: 298.4525451660156
Loss: 289.9792785644531
Loss: 283.0616455078125
Loss: 274.0792236328125
Loss: 265.34515380859375
Loss: 259.7699890136719
Loss: 253.74591064453125
Loss: 245.09278869628906
Loss: 238.89111328125
Loss: 234.4183349609375
Loss: 228.06741333007812
Loss: 221.37376403808594
Loss: 217.68740844726562
Loss: 210.84178161621094
Loss: 205.9726104736328
Loss: 201.2008819580078
Training execution time (mins) 0.855648132165273
Training Job Complete


# Inference

In [14]:
example = dataset[0]
processor.decode(example['input_ids'])

'[CLS] where is he looking? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [15]:
# add batch dimension + move to GPU
example = {k: v.unsqueeze(0).to(device) for k,v in example.items()}
# forward pass
outputs = model(**example)

logits = outputs.logits
predicted_classes = torch.sigmoid(logits)

probs, classes = torch.topk(predicted_classes, 5)

for prob, class_idx in zip(probs.squeeze().tolist(), classes.squeeze().tolist()):
  print(prob, model.config.id2label[class_idx])

0.2725496292114258 no
0.23888686299324036 yes
0.23368898034095764 tie
0.20375598967075348 scarf
0.19615337252616882 kites
