In [None]:
!wget https://huggingface.co/spaces/CVPR/VizWiz-CLIP-VQA/raw/main/data/annotations/class_mapping.csv

In [None]:
!pip install -q transformers

In [1]:
import json
import os
import csv
import numpy as np
import random
import re
from typing import Optional
from tqdm import tqdm
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
from PIL import Image
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torch import optim, nn
from transformers import ViltProcessor, ViltForQuestionAnswering

In [2]:
# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using", device)

Using cpu


In [3]:
# Input files. Despite the `_dir` suffix these are file paths, not directories.
train_json_dir = "/kaggle/input/vizwiz-2023-edition/Annotations/train.json"
val_json_dir = "/kaggle/input/vizwiz-2023-edition/Annotations/val.json"
classmapping_dir = "class_mapping.csv"  # answer -> class id CSV fetched by the wget cell

# Training hyper-parameters.
batch_size = 32  # batch size of the training DataLoader
num_epoch = 25  # upper bound on epochs; early stopping may end training sooner
learning_rate = 5e-4  # initial AdamW learning rate
es_patience = 8  # patience handed to EarlyStopping (epochs without val-loss improvement)
lr_patience = 3  # patience handed to ReduceLROnPlateau
model_save_path = "ckpt_vilt.pth"  # where EarlyStopping writes the best state_dict

In [4]:
# Build the answer <-> class-id lookups from the mapping CSV
# (first column: answer text, second column: integer class id).
# newline="" is what the csv module expects so quoted fields containing line
# breaks are parsed correctly; the encoding is explicit so the result does not
# depend on the platform default.
with open(classmapping_dir, "r", newline="", encoding="utf-8") as f:
    reader = csv.reader(f, skipinitialspace=True)
    next(reader, None)  # Skip the header
    # `if row` drops blank lines (e.g. a trailing newline), which would make
    # dict() fail with "dictionary update sequence element has length 0".
    class_mapping = {row[0]: row[1] for row in reader if row}

# The lookups do not need the file any more, so they are derived after it is closed.
label2id = {answer: int(class_id) for answer, class_id in class_mapping.items()}
id2label = {class_id: answer for answer, class_id in label2id.items()}

In [5]:
def get_score(count: int) -> float:
    """Turn the number of annotators who gave an answer into a soft score.

    The score is ``count / 3``, capped at ``1.0``.
    """
    fraction = count / 3
    return fraction if fraction < 1.0 else 1.0

def add_label_score(annotations):
    """Attach class ids and soft scores to every annotation, in place.

    For each annotation the distinct answer strings are counted. Every answer
    that exists in the module-level ``label2id`` vocabulary contributes its
    class id to ``annotation["labels"]`` and ``get_score(count)`` to
    ``annotation["scores"]`` (same position in both lists). Answers outside the
    vocabulary are ignored.

    Args:
        annotations: iterable of dicts. Each may hold an ``"answers"`` list of
            dicts with an ``"answer"`` string; a missing ``"answers"`` key is
            treated as "no answers" and yields empty labels/scores.

    Returns:
        None. The annotation dicts are mutated.
    """
    for annotation in tqdm(annotations):
        answer_count = {}
        for entry in annotation.get("answers", ()):
            answer = entry["answer"]
            answer_count[answer] = answer_count.get(answer, 0) + 1

        labels = []
        scores = []
        for answer_word, count in answer_count.items():
            # Test membership on the dict itself: building list(label2id.keys())
            # for every candidate answer costs a full scan of the vocabulary.
            if answer_word in label2id:
                labels.append(label2id[answer_word])
                scores.append(get_score(count))
        annotation["labels"] = labels
        annotation["scores"] = scores

In [None]:
def ShearX(img, v):  # [-0.3, 0.3]
    """Shear `img` along the x axis by `v`; the sign of `v` is flipped at random."""
    assert -0.3 <= v <= 0.3
    # One draw decides the shear direction.
    shear = -v if random.random() > 0.5 else v
    coefficients = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coefficients)


def ShearY(img, v):  # [-0.3, 0.3]
    """Shear `img` along the y axis by `v`; the sign of `v` is flipped at random."""
    assert -0.3 <= v <= 0.3
    # One draw decides the shear direction.
    shear = -v if random.random() > 0.5 else v
    coefficients = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coefficients)


def TranslateX(img, v):  # fraction of the image width: [-0.45, 0.45]
    """Translate `img` horizontally by the fraction `v` of its width, random direction."""
    assert -0.45 <= v <= 0.45
    signed_fraction = -v if random.random() > 0.5 else v
    # Convert the relative shift into an offset measured against the image width.
    offset = signed_fraction * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, offset, 0, 1, 0))


def TranslateXabs(img, v):  # absolute offset, v >= 0 (not a fraction of the width)
    """Translate `img` horizontally by the absolute offset `v`, random direction."""
    assert 0 <= v
    offset = -v if random.random() > 0.5 else v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, offset, 0, 1, 0))


def TranslateY(img, v):  # fraction of the image height: [-0.45, 0.45]
    """Translate `img` vertically by the fraction `v` of its height, random direction."""
    assert -0.45 <= v <= 0.45
    signed_fraction = -v if random.random() > 0.5 else v
    # Convert the relative shift into an offset measured against the image height.
    offset = signed_fraction * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, offset))


def TranslateYabs(img, v):  # absolute offset, v >= 0 (not a fraction of the height)
    """Translate `img` vertically by the absolute offset `v`, random direction."""
    assert 0 <= v
    offset = -v if random.random() > 0.5 else v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, offset))


def Rotate(img, v):  # [-30, 30]
    """Rotate `img` by `v` (passed to ``img.rotate``), with the direction chosen at random."""
    assert -30 <= v <= 30
    angle = -v if random.random() > 0.5 else v
    return img.rotate(angle)


def AutoContrast(img, _):
    """Return ``PIL.ImageOps.autocontrast(img)``; the magnitude argument is ignored."""
    result = PIL.ImageOps.autocontrast(img)
    return result


def Equalize(img, _):
    """Return ``PIL.ImageOps.equalize(img)``; the magnitude argument is ignored."""
    result = PIL.ImageOps.equalize(img)
    return result


def Flip(img, _):  # not from the paper
    """Return ``PIL.ImageOps.mirror(img)``; the magnitude argument is ignored."""
    mirrored = PIL.ImageOps.mirror(img)
    return mirrored


def Solarize(img, v):  # [0, 256]
    """Solarize `img` with threshold `v` via ``PIL.ImageOps.solarize``."""
    assert 0 <= v <= 256
    threshold = v
    return PIL.ImageOps.solarize(img, threshold)


def SolarizeAdd(img, addition=0, threshold=128):
    """Add `addition` to every pixel value (clipped to 0..255), then solarize at `threshold`."""
    # Work in int32 so the addition cannot wrap around before clipping.
    shifted = np.array(img).astype(np.int32) + addition
    clipped = np.clip(shifted, 0, 255).astype(np.uint8)
    return PIL.ImageOps.solarize(Image.fromarray(clipped), threshold)


def Posterize(img, v):  # [4, 8]
    """Posterize `img` to ``int(v)`` bits, using at least 1 bit."""
    bits = max(1, int(v))
    return PIL.ImageOps.posterize(img, bits)


def Contrast(img, v):  # [0.1,1.9]
    """Apply ``PIL.ImageEnhance.Contrast`` to `img` with enhancement factor `v`."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(v)


def Color(img, v):  # [0.1,1.9]
    """Apply ``PIL.ImageEnhance.Color`` to `img` with enhancement factor `v`."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(v)


def Brightness(img, v):  # [0.1,1.9]
    """Apply ``PIL.ImageEnhance.Brightness`` to `img` with enhancement factor `v`."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(v)


def Sharpness(img, v):  # [0.1,1.9]
    """Apply ``PIL.ImageEnhance.Sharpness`` to `img` with enhancement factor `v`."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(v)


def SamplePairing(imgs):  # [0, 0.4]
    """Build an op that blends its input with a randomly picked entry of `imgs`.

    The returned callable takes ``(img1, v)``: it draws an index with
    ``np.random.choice``, converts ``imgs[index]`` with ``PIL.Image.fromarray``
    and returns ``PIL.Image.blend(img1, partner, v)``.
    """
    def f(img1, v):
        picked = np.random.choice(len(imgs))
        partner = PIL.Image.fromarray(imgs[picked])
        blended = PIL.Image.blend(img1, partner, v)
        return blended

    return f


def Identity(img, v):
    """No-op augmentation: hand back `img` unchanged and ignore `v`."""
    unchanged = img
    return unchanged


def augment_list():  # 12 operations and their ranges
    """Return the RandAugment op table as a fresh list of ``(op, min_value, max_value)`` tuples."""
    # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
    # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
    return [
        (AutoContrast, 0, 1),
        (Equalize, 0, 1),
        (Rotate, 0, 30),
        (Posterize, 0, 4),
        (Color, 0.1, 1.9),
        (Contrast, 0.1, 1.9),
        (Brightness, 0.1, 1.9),
        (Sharpness, 0.1, 1.9),
        (ShearX, 0.0, 0.3),
        (ShearY, 0.0, 0.3),
        (TranslateXabs, 0.0, 100),
        (TranslateYabs, 0.0, 100),
    ]


class RandAugment:
    """Callable that applies `n` randomly drawn ops from ``augment_list()`` to an image.

    Ops are drawn with replacement. The magnitude `m` is read on a 0-30 scale and
    mapped linearly onto each op's ``(min, max)`` range.
    """

    def __init__(self, n, m):
        self.n = n
        self.m = m  # [0, 30]
        self.augment_list = augment_list()

    def __call__(self, img):
        for transform, low, high in random.choices(self.augment_list, k=self.n):
            # Position of m on the 0-30 scale, projected into [low, high].
            magnitude = (float(self.m) / 30) * float(high - low) + low
            img = transform(img, magnitude)

        return img

In [6]:
class VQADataset(torch.utils.data.Dataset):
    """VizWiz VQA samples encoded for ViLT, with soft multi-label answer targets.

    Args:
        annotations: list of dicts holding ``"image"`` (file name), ``"question"``,
            ``"labels"`` (class ids) and ``"scores"`` (soft scores, same length).
        subset: split name; also the name of the (doubly nested) image folder.
        processor: callable ``processor(image, text, ...)`` returning tensors
            with a leading batch dimension of 1.
        augment: whether to apply RandAugment to the images. ``None`` (default)
            enables it only when ``subset == "train"``, so validation metrics are
            computed on undistorted images.
        image_root: directory that contains the ``<subset>/<subset>/`` image folders.
    """

    def __init__(self, annotations, subset, processor, augment=None,
                 image_root="/kaggle/input/vizwiz-2023-edition"):
        self.annotations = annotations
        self.processor = processor
        self.subset = subset
        self.image_root = image_root
        if augment is None:
            augment = subset == "train"
        # Built once here instead of on every __getitem__ call.
        self.augmenter = RandAugment(n=2, m=9) if augment else None

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        # get image + text
        annotation = self.annotations[idx]
        image_path = os.path.join(self.image_root, self.subset, self.subset, annotation["image"])
        # The context manager closes the file handle; convert() returns a loaded
        # copy, and RGB gives every sample the same 3-channel layout.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.augmenter is not None:
            image = self.augmenter(image)
        text = annotation['question']

        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        # remove batch dimension (dim 0 only, so other size-1 dims are kept)
        for key in list(encoding.keys()):
            encoding[key] = encoding[key].squeeze(0)
        # add labels
        labels = annotation["labels"]
        scores = annotation["scores"]
        # based on: https://github.com/dandelin/ViLT/blob/762fd3975c180db6fc88f577cf39549983fa373a/vilt/modules/objectives.py#L301
        # Dense target vector: the soft score at each annotated class id, 0 elsewhere.
        targets = torch.zeros(len(label2id))
        for label, score in zip(labels, scores):
            targets[label] = score
        encoding["labels"] = targets

        return encoding

In [7]:
# Pretrained ViLT processor: VQADataset uses it to encode each (image, question)
# pair, and collate_fn uses its image_processor to pad image batches.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")

In [8]:
# Training split: read the annotations, attach label ids / soft scores in place,
# then wrap them in a dataset.
with open(train_json_dir, "r") as f:
    train_data = json.load(f)
add_label_score(train_data)
train_dataset = VQADataset(annotations=train_data, subset="train", processor=processor)

# Validation split: same steps.
with open(val_json_dir, "r") as f:
    val_data = json.load(f)
add_label_score(val_data)
val_dataset = VQADataset(annotations=val_data, subset="val", processor=processor)

100%|██████████| 20523/20523 [00:14<00:00, 1370.07it/s]
100%|██████████| 4319/4319 [00:03<00:00, 1369.47it/s]


In [17]:
# Load the pretrained ViLT checkpoint and hand it the answer vocabulary mappings
# built from class_mapping.csv.
# NOTE(review): the recorded cell output mentions "dandelin/vilt-b32-finetuned-vqa",
# not the checkpoint named here, so that output looks stale; re-run to confirm.
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-mlm",
                                                 id2label=id2label,
                                                 label2id=label2id)

# Being the last expression of the cell, this also echoes the module tree in the output.
model.to(device)

Some weights of ViltForQuestionAnswering were not initialized from the model checkpoint at dandelin/vilt-b32-finetuned-vqa and are newly initialized because the shapes did not match:
- classifier.3.weight: found shape torch.Size([3129, 1536]) in the checkpoint and torch.Size([5726, 1536]) in the model instantiated
- classifier.3.bias: found shape torch.Size([3129]) in the checkpoint and torch.Size([5726]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViltForQuestionAnswering(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=76

In [18]:
def collate_fn(batch):
    """Merge a list of per-sample encodings into a single batch dict.

    Text fields and labels are stacked as they are. Images are passed to
    ``processor.image_processor.pad``, which supplies both the batched
    ``pixel_values`` and the matching ``pixel_mask``.
    """
    field_names = ('input_ids', 'pixel_values', 'attention_mask', 'token_type_ids', 'labels')
    columns = {name: [item[name] for item in batch] for name in field_names}

    # create padded pixel values and corresponding pixel mask
    padded = processor.image_processor.pad(columns['pixel_values'], return_tensors="pt")

    return {
        'input_ids': torch.stack(columns['input_ids']),
        'attention_mask': torch.stack(columns['attention_mask']),
        'token_type_ids': torch.stack(columns['token_type_ids']),
        'pixel_values': padded['pixel_values'],
        'pixel_mask': padded['pixel_mask'],
        'labels': torch.stack(columns['labels']),
    }


# Training loader: shuffled batches of `batch_size`; the last incomplete batch is dropped.
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, 
                              shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
# Validation loader: one sample per batch, in dataset order.
val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=1, 
                            shuffle=False, num_workers=2, pin_memory=True)

In [19]:
# Early Stopping
# Stop training if validation loss does not improve
class EarlyStopping:
    """Track the best validation loss, checkpoint on improvement, flag when to stop.

    Calling the object with ``(epoch, model, validation_loss)`` either records a
    new best loss (improvement larger than ``min_delta``) and saves the model's
    ``state_dict`` to ``model_save_path``, or increments a counter. Once the
    counter reaches ``patience``, ``early_stop`` becomes True.
    """

    def __init__(self, patience, model_save_path, min_delta=0):
        # Configuration
        self.patience = patience
        self.min_delta = min_delta
        self.model_save_path = model_save_path
        # Running state
        self.counter = 0
        self.min_validation_loss = np.inf
        self.best_epoch = 0
        self.early_stop = False

    def __call__(self, epoch, model, validation_loss):
        improvement = self.min_validation_loss - validation_loss
        if not improvement > self.min_delta:
            # No sufficient improvement: count the epoch and maybe give up.
            self.counter += 1
            if self.counter >= self.patience:
                print("Early Stopping.")
                print(f"Save best model at epoch {self.best_epoch}")
                self.early_stop = True
            return

        # New best loss: reset the counter and checkpoint the weights.
        self.min_validation_loss = validation_loss
        self.counter = 0
        self.best_epoch = epoch
        torch.save(model.state_dict(), self.model_save_path)

In [20]:
# AdamW over all model parameters, starting from `learning_rate`.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Multiply the learning rate by 0.3 once the monitored value (the validation loss
# passed to lr_scheduler.step) has stopped improving for `lr_patience` epochs.
# NOTE(review): `verbose` is reported as deprecated by newer PyTorch releases; confirm
# against the installed version.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=lr_patience, verbose=True)

In [21]:
def train_model(data_loader, model, optimizer, device):
    """Run one training epoch and return the loss averaged over batches.

    Each batch dict is moved to `device` and fed to the model as keyword
    arguments; the model output's ``loss`` attribute is back-propagated and the
    optimizer takes one step per batch.
    """
    batch_count = len(data_loader)
    running_loss = 0

    model.train()
    for batch in data_loader:
        on_device = {name: value.to(device) for name, value in batch.items()}
        # forward pass
        loss = model(**on_device).loss
        # clear old gradients, back-propagate, update the weights
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Loss over batches
    return running_loss / batch_count


def val_model(data_loader, model, device):
    """Evaluate the model and return ``(val_acc, val_loss)`` as Python floats.

    Accuracy is the soft score that the target vector assigns to the predicted
    (arg-max) class, averaged over all samples. Loss is the model's own loss,
    averaged over batches.

    Args:
        data_loader: iterable of batch dicts. Each must contain ``"labels"``, a
            ``(batch, num_classes)`` tensor of soft target scores, plus the
            model inputs.
        model: callable whose output exposes ``loss`` and ``logits``.
        device: device the batch tensors are moved to.

    Raises:
        ZeroDivisionError: if `data_loader` yields no batches.
    """
    num_batches = len(data_loader)
    total_loss = 0.0
    total_score = 0.0
    num_samples = 0

    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)

            # Calculate loss
            total_loss += outputs.loss.item()

            # Get top prediction for each question
            preds = outputs.logits.argmax(-1)
            # Look up the target score of each predicted class directly.
            # Comparing preds of shape (B,) against top-k label indices of shape
            # (B, 10) only broadcasts correctly for B == 1, and topk(10) needs at
            # least 10 classes; gather() works for any batch size.
            picked = inputs["labels"].gather(-1, preds.unsqueeze(-1)).squeeze(-1)
            total_score += picked.sum().item()
            num_samples += preds.numel()

    # Accuracy over samples (equal to the per-batch mean when batch size is 1)
    val_acc = total_score / num_samples
    # Loss over batches
    val_loss = total_loss / num_batches

    return val_acc, val_loss

In [None]:
# Per-epoch [train_loss, val_loss] pairs, kept for plotting
all_losses = []

# Early stopping on the validation loss; it also saves the best weights to model_save_path
early_stopper = EarlyStopping(patience=es_patience, model_save_path=model_save_path)
for epoch in range(num_epoch):
    print(f"Epoch [{epoch}/{num_epoch-1}]")
    # One pass over the training data, then a full evaluation on the validation data
    train_loss = train_model(train_dataloader, model, optimizer, device)
    val_acc, val_loss = val_model(val_dataloader, model, device)
    
    all_losses.append([train_loss, val_loss])

    # Display
    print(f"Train loss: {train_loss:.5f} - Val loss: {val_loss:.5f}")
    print(f"Val accuracy: {val_acc:.5f}\n")
    
    # EarlyStopping: checkpoints on improvement, sets early_stop once patience runs out
    early_stopper(epoch, model, val_loss)
    if early_stopper.early_stop:
        break
    # Adjust learning rate (skipped on the epoch that triggers early stopping)
    lr_scheduler.step(val_loss)