<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Vision_Language_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Download Dataset

In [1]:
import gdown
url = 'https://drive.google.com/uc?id=1AOuJXt9yWZfLwPoZsFfWFEgrcERGehOm'
gdown.download(url,'archive.zip',quiet=True)
!unzip -q archive.zip

Install Packages

In [2]:
!pip -q install transformers

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.4/7.4 MB 19.6 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 268.8/268.8 kB 27.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 45.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 46.1 MB/s eta 0:00:00

Prepare Dataloader

In [6]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
import pandas as pd
from PIL import Image
import os

class CustomDataset(Dataset):
    """Image + text + label dataset read from a JSON-lines annotation file.

    ``data_root`` must contain ``train.jsonl`` / ``dev.jsonl`` (one JSON object
    per line with at least ``id``, ``text`` and ``label`` fields) and an ``img/``
    folder holding ``<id zero-padded to 5 digits>.png`` images.

    Args:
        data_root: Dataset directory (required despite the ``None`` default).
        transform: Optional callable applied to the RGB ``PIL.Image``.
        istrain: Read ``train.jsonl`` when True, ``dev.jsonl`` otherwise.

    Raises:
        ValueError: if ``data_root`` is None.
    """

    def __init__(self, data_root=None, transform=None, istrain=True):
        if data_root is None:
            raise ValueError('data_root must point to the dataset directory')
        split_file = 'train.jsonl' if istrain else 'dev.jsonl'
        self.data_json = pd.read_json(path_or_buf=os.path.join(data_root, split_file), lines=True)
        self.transform = transform
        self.img_root = os.path.join(data_root, 'img')

    def __len__(self):
        return len(self.data_json['id'])

    def _img_path(self, i):
        """Path of the image for row ``i``: ``<img_root>/<id:05d>.png``."""
        # Built from self.img_root (not a notebook-global) so the dataset only
        # depends on the data_root it was constructed with.
        image_id = int(self.data_json['id'].iloc[i])
        return os.path.join(self.img_root, '{:05d}.png'.format(image_id))

    def __getitem__(self, i):
        """Return ``(image, text, label)`` for row ``i`` (positional index)."""
        # convert() returns a new image, so the file handle can be closed here.
        with Image.open(self._img_path(i)) as raw_img:
            img = raw_img.convert('RGB')
        text = self.data_json['text'].iloc[i]
        label = self.data_json['label'].iloc[i]
        if self.transform:
            img = self.transform(img)
        return img, text, label


# ImageNet channel statistics: the ResNet-18 feature extractor is loaded with
# ImageNet-pretrained weights, so inputs are normalised the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

data_root = '/content/data'
dataset_train = CustomDataset(data_root=data_root, transform=transform, istrain=True)
dataset_test = CustomDataset(data_root=data_root, transform=transform, istrain=False)
print('Number of training samples:', len(dataset_train), 'Number of test samples:',len(dataset_test))

# Quick-look loaders; the training cell rebuilds both with the batch sizes from get_args().
# Training batches are shuffled so each pass sees a different sample order.
dataloader_train = DataLoader(dataset_train, batch_size=2, shuffle=True, num_workers=2)
dataloader_test = DataLoader(dataset_test, batch_size=4, shuffle=False, num_workers=2)

Number of training samples: 8500 Number of test samples: 500


Training

In [None]:
import sys
import os
import argparse
import torch
from torch import nn
from torchvision.models import resnet18
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def get_args(argv=None):
    """Parse the training hyper-parameters.

    Args:
        argv: Optional explicit list of argument strings. When None, arguments
            are read from ``sys.argv`` — except inside a Jupyter/IPython kernel
            (``ipykernel`` loaded), where the kernel's launcher arguments would
            confuse argparse, so the defaults are used instead.

    Returns:
        argparse.Namespace with ``lr``, ``batch_size``, ``test_batch_size``,
        ``num_epoch`` and ``num_classes``.
    """
    parser = argparse.ArgumentParser(description='VisualBERT Vision-Language Classification Training')
    parser.add_argument('--lr', default=0.00001, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=20, type=int, help='training batch size')
    parser.add_argument('--test_batch_size', default=40, type=int, help='evaluation batch size')
    parser.add_argument('--num_epoch', default=2, type=int, help='number of training epochs')
    parser.add_argument('--num_classes', type=int, default=2, help='number of output classes')

    if argv is None and 'ipykernel' in sys.modules:
        argv = []
    # parse_args(None) falls back to sys.argv[1:].
    return parser.parse_args(argv)


class visual_feat_extractor(nn.Module):
    """ImageNet-pretrained ResNet-18 that turns images into spatial feature tokens.

    The average-pool and fc heads are replaced by ``nn.Identity``, so the
    backbone returns the last conv feature map flattened per sample; ``forward``
    reshapes it into one 512-d token per spatial location (49 tokens for a
    224x224 input).
    """

    # Channel count of ResNet-18's final conv stage.
    feat_dim = 512

    def __init__(self):
        super(visual_feat_extractor, self).__init__()
        self.model_visual_feat = resnet18(pretrained=True)
        # Drop pooling/classifier: the backbone then returns flatten(feature_map, 1),
        # i.e. (B, 512 * H * W) in channel-major order.
        self.model_visual_feat.avgpool = nn.Identity()
        self.model_visual_feat.fc = nn.Identity()
        self.model_visual_feat.to(device)
        self.model_visual_feat.eval()

    @staticmethod
    def _to_tokens(flat_feats, batch_size, feat_dim=512):
        """Reshape channel-major flattened features (B, C*H*W) into tokens (B, H*W, C).

        A plain ``view(B, H*W, C)`` would be wrong here: the flat layout is
        ``c * H*W + position``, so channels must be split off first and then
        swapped with the spatial axis.
        """
        return flat_feats.reshape(batch_size, feat_dim, -1).transpose(1, 2)

    def forward(self, img):
        """Return visual tokens of shape (B, H*W, 512) for an image batch ``img``."""
        flat_feats = self.model_visual_feat(img)
        return self._to_tokens(flat_feats, img.size(0), self.feat_dim)

class VisualBERT_VQA(nn.Module):
    """VisualBERT encoder with a linear classification head on the last text token.

    Args:
        num_labels: Number of output classes.
        pretrained: When False (default, original behaviour) only the *config* of
            ``uclanlp/visualbert-vqa-coco-pre`` is used and the encoder weights are
            randomly initialised. When True the checkpoint weights are loaded too;
            layers whose shape no longer matches (e.g. the visual projection,
            because ``visual_embedding_dim`` is changed below) are re-initialised.
    """

    def __init__(self, num_labels=2, pretrained=False):
        super(VisualBERT_VQA, self).__init__()
        self.config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        # Visual tokens come from ResNet-18, whose feature vectors are 512-d.
        self.config.visual_embedding_dim = 512
        if pretrained:
            self.visualbert = VisualBertModel.from_pretrained(
                "uclanlp/visualbert-vqa-coco-pre",
                config=self.config,
                ignore_mismatched_sizes=True,
            )
        else:
            self.visualbert = VisualBertModel(config=self.config)
        self.cls = nn.Linear(self.config.hidden_size, num_labels)

    @staticmethod
    def _gather_last_text_token(last_hidden_state, attention_mask):
        """Select the hidden state of the last real text token per sample.

        Assumes the text layout ``[CLS] t1 .. tn [SEP] [PAD] ...`` with right
        padding, so ``attention_mask.sum(1) - 2`` is the index of ``tn``.
        Returns a tensor of shape (batch, 1, hidden).
        """
        # clamp guards degenerate masks (fewer than 2 ones) from a negative index.
        index_to_gather = (attention_mask.sum(1) - 2).clamp(min=0)
        index_to_gather = index_to_gather.view(-1, 1, 1).expand(-1, 1, last_hidden_state.size(-1))
        return torch.gather(last_hidden_state, 1, index_to_gather)

    def forward(self, inputs):
        """Return logits (batch, num_labels) for tokenizer output extended with visual inputs."""
        # (batch, text tokens + visual tokens, hidden_size)
        last_hidden_state = self.visualbert(**inputs).last_hidden_state
        pooled_output = self._gather_last_text_token(last_hidden_state, inputs['attention_mask'])
        logits = self.cls(pooled_output).squeeze(1)
        return logits

def train(model_vqa, text_model, img_model, dataloader_train, criterion, optimizer, max_length=20):
    """Run one training epoch of the VisualBERT classifier.

    Args:
        model_vqa: Classifier being optimised; called with the tokenizer output
            extended by the visual inputs.
        text_model: Tokenizer callable (``BertTokenizerFast`` in this notebook).
        img_model: Visual feature extractor; run under ``torch.no_grad()`` so it
            receives no gradient updates.
        dataloader_train: Yields ``(imgs, texts, targets)`` batches.
        criterion: Loss computed on ``(logits, targets)``.
        optimizer: Optimiser over ``model_vqa``'s parameters.
        max_length: Fixed text length the tokenizer pads/truncates to.

    Returns:
        Mean batch loss over the epoch (0.0 if the loader is empty).

    Uses the notebook-global ``device``.
    """
    model_vqa.train()
    total_loss, num_batches = 0.0, 0
    for imgs, texts, targets in dataloader_train:
        imgs, targets = imgs.to(device), targets.to(device)
        inputs = text_model(texts, return_tensors="pt", padding="max_length",
                            max_length=max_length, truncation=True).to(device)
        with torch.no_grad():
            img_embed = img_model(imgs)

        # Every visual token gets segment id 1 and is fully attended.
        visual_shape = img_embed.shape[:-1]
        inputs.update({
            "visual_embeds": img_embed,
            "visual_token_type_ids": torch.ones(visual_shape, dtype=torch.long, device=device),
            "visual_attention_mask": torch.ones(visual_shape, dtype=torch.float, device=device),
        })

        # Clear the previous step's gradients; without this they accumulate
        # across every batch of the epoch.
        optimizer.zero_grad()
        logits = model_vqa(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

def test(model, text_model, img_model, dataloader_test, max_length=20):
    """Return the classification accuracy of ``model`` on ``dataloader_test``.

    Args:
        model: Classifier to evaluate (switched to eval mode).
        text_model: Tokenizer callable (``BertTokenizerFast`` in this notebook).
        img_model: Visual feature extractor.
        dataloader_test: Yields ``(imgs, texts, targets)`` batches.
        max_length: Fixed text length the tokenizer pads/truncates to.

    Returns:
        Fraction of correctly classified samples (0.0 for an empty dataset).

    Uses the notebook-global ``device``.
    """
    model.eval()
    correct = 0

    with torch.no_grad():
        for imgs, texts, targets in dataloader_test:
            imgs, targets = imgs.to(device), targets.to(device)
            inputs = text_model(texts, return_tensors="pt", padding="max_length",
                                max_length=max_length, truncation=True).to(device)
            img_embed = img_model(imgs)

            # Every visual token gets segment id 1 and is fully attended.
            visual_shape = img_embed.shape[:-1]
            inputs.update({
                "visual_embeds": img_embed,
                "visual_token_type_ids": torch.ones(visual_shape, dtype=torch.long, device=device),
                "visual_attention_mask": torch.ones(visual_shape, dtype=torch.float, device=device),
            })
            # Evaluate the model that was passed in, not the notebook-global model_vqa.
            logits = model(inputs)
            _, predicted = logits.max(1)
            correct += predicted.eq(targets).sum().item()

    num_samples = len(dataloader_test.dataset)
    return correct / num_samples if num_samples else 0.0


args = get_args()
# Image backbone: only ever called under torch.no_grad() in train()/test().
visual_embeds_model = visual_feat_extractor().to(device)
visual_embeds_model.eval()
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

model_vqa = VisualBERT_VQA(num_labels=args.num_classes).to(device)

criterion = nn.CrossEntropyLoss()
# Only the VisualBERT classifier's parameters are optimised.
optimizer = torch.optim.Adam(model_vqa.parameters(), lr=args.lr)
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=2)
dataloader_test = DataLoader(dataset_test, batch_size=args.test_batch_size, shuffle=False, num_workers=2)
# best_epoch is an epoch index (int), best_acc an accuracy (float). Starting
# best_acc below any reachable accuracy guarantees a checkpoint after epoch 0.
best_epoch, best_acc = 0, float('-inf')

for epoch in range(args.num_epoch):
    train(model_vqa, bert_tokenizer, visual_embeds_model, dataloader_train, criterion, optimizer)
    accuracy = test(model_vqa, bert_tokenizer, visual_embeds_model, dataloader_test)
    if accuracy > best_acc:
        best_acc = accuracy
        best_epoch = epoch
        torch.save(model_vqa.state_dict(), 'best_model.pth.tar')

    print('epoch: {}  acc: {:.4f}  best epoch: {}  best acc: {:.4f}'.format(
                epoch, accuracy, best_epoch, best_acc))





In [65]:
# Inspect the loop state left by the training cell:
# (last epoch, its accuracy, best epoch, best accuracy).
(epoch, accuracy, best_epoch, best_acc)

(0, 0.5, 0, 0.5)

In [6]:
def encoding_saver(json_file, save_name, processor=None):
    """Encode every (image, text, label) row with a ViLT processor and save the result.

    Args:
        json_file: pandas DataFrame with ``img``, ``text`` and ``label`` columns.
            ``img`` values are opened as-is, so they must be valid paths relative
            to the current working directory.
        save_name: Output path passed to ``np.save``.
        processor: Optional pre-built processor. Defaults to
            ``ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")``,
            loaded once per call rather than once per row.

    Returns:
        1-D numpy object array of dicts (processor outputs plus ``labels``).
        Because it holds Python objects, ``np.save`` pickles it; reading it back
        needs ``np.load(save_name, allow_pickle=True)`` — only load trusted files.
    """
    import numpy as np

    if processor is None:
        from transformers import ViltProcessor
        processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

    dataset = []
    for img_path, text, label in zip(json_file['img'], json_file['text'], json_file['label']):
        # convert() returns a new image, so the file handle can be closed here.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        encoding = processor(img, text, return_tensors="pt")
        dataset.append({**encoding, 'labels': label})
    dataset = np.array(dataset)
    np.save(save_name, dataset)
    print("Completed ")
    return dataset