<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/Vision_Language_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Download Dataset

In [1]:
import gdown
url = 'https://drive.google.com/uc?id=1AOuJXt9yWZfLwPoZsFfWFEgrcERGehOm'
gdown.download(url,'archive.zip',quiet=True)
!unzip -q archive.zip

Install Packages

In [2]:
!pip -q install transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m66.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m74.3 MB/s[0m eta [36m0:00:00[0m
[?25h

Prepare Dataloader<br>
    Label0 Label1<br>
Test: 250 251<br>
Train: 5450 3051

In [3]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
import pandas as pd
from PIL import Image
import os

class CustomDataset(Dataset):
    def __init__(self, data_root=None, transform=None, istrain=True):
        if istrain:
            self.data_json= pd.read_json(path_or_buf=os.path.join(data_root,'train.jsonl'), lines=True)
        else:
            self.data_json= pd.read_json(path_or_buf=os.path.join(data_root,'dev.jsonl'), lines=True)

        self.transform = transform
        self.img_root = os.path.join( data_root, 'img')

    def __len__(self):
      return len(self.data_json['id'])

    def __getitem__(self, i):
        img = Image.open(os.path.join( data_root, 'img', str("{:05d}".format(self.data_json['id'][i]))+'.png')).convert('RGB')
        text = self.data_json['text'][i]
        label = self.data_json['label'][i]
        if self.transform:
            img = self.transform(img)
        return img, text, label


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
        )
    ])

data_root = '/content/data'
dataset_train = CustomDataset(data_root=data_root, transform=transform, istrain=True)
dataset_test = CustomDataset(data_root=data_root, transform=transform, istrain=False)
print('Number of training samples:', len(dataset_train), 'Number of test samples:',len(dataset_test))

dataloader_train = DataLoader(dataset_train, batch_size=2, shuffle=True, num_workers=2)
dataloader_test = DataLoader(dataset_test, batch_size=4, shuffle=False, num_workers=2)
print('Number of training loader:', len(dataloader_train), 'Number of test loader:',len(dataloader_test))

Number of training samples: 8500 Number of test samples: 500
Number of training loader: 4250 Number of test loader: 125


Training V1: (ResNet18 and without VQA pretrained weights)

In [None]:
import sys
import os
import argparse
import torch
from torch import nn
from torchvision.models import resnet18
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_args():
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--lr', default=0.00001, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=20, type=int, help='batch size')
    parser.add_argument('--test_batch_size', default=40, type=int, help='batch size')
    parser.add_argument('--num_epoch', default=30, type=int, help='epoch number')
    parser.add_argument('--num_classes', type=int, default=2, help='number classes')

    if 'ipykernel' in sys.modules:
        args = parser.parse_args([])
    else:
        args = parser.parse_args()

    return args


class visual_feat_extractor(nn.Module):
    def __init__(self, ):
        super(visual_feat_extractor, self).__init__()
        self.model_visual_feat = resnet18(pretrained=True)#b 1000
        self.model_visual_feat.avgpool = nn.Identity()
        self.model_visual_feat.fc = nn.Identity()
        self.model_visual_feat.to(device)
        self.model_visual_feat.eval()

    def forward(self, img):
        visual_embeds = self.model_visual_feat(img).view(-1, 49, 512) #b 49 512
        return visual_embeds

class VisualBERT_VQA(nn.Module):
    def __init__(self, num_labels=2):
        super(VisualBERT_VQA, self).__init__()
        self.config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.config.visual_embedding_dim = 512
        self.visualbert = VisualBertModel(config=self.config)
        self.cls = nn.Linear(768, num_labels)

    def forward(self, inputs):
        last_hidden_state = self.visualbert(**inputs).last_hidden_state #[1, 56, 768]

        # Get the index of the last text token
        index_to_gather = inputs['attention_mask'].sum(1) - 2  # as in original
        index_to_gather = (
            index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, last_hidden_state.size(-1))
        ) # [b c hw]=[1, 1, 768]
        pooled_output = torch.gather(last_hidden_state, 1, index_to_gather) # [1, 1, 768]
        logits = self.cls(pooled_output).squeeze(1)
        return logits

def train_epoch(model_vqa, text_model, img_model, dataloader_train, criterion, optimizer):
    model_vqa.train()
    loss_all = 0
    for batch_idx, (imgs, texts, targets) in enumerate(dataloader_train):
        imgs, targets = imgs.to(device), targets.to(device)
        inputs = text_model(texts, return_tensors="pt", padding="max_length", max_length=20, truncation=True).to(device)#b 20
        with torch.no_grad():
            img_embed = img_model(imgs)#b 512

        visual_token_type_ids = torch.ones(img_embed.shape[:-1], dtype=torch.long).to(device)
        visual_attention_mask = torch.ones(img_embed.shape[:-1], dtype=torch.float).to(device)
        inputs.update({
                "visual_embeds": img_embed,
                "visual_token_type_ids": visual_token_type_ids,
                "visual_attention_mask": visual_attention_mask,
            })

        logits = model_vqa(inputs)
        loss = criterion(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_all += loss.item()

    return loss_all/len(dataloader_train.dataset)

def test_epoch(model, text_model, img_model, dataloader_test, criterion):
    model.eval()
    test_sample_size = len(dataloader_test.dataset)
    correct = 0
    loss_all = 0

    with torch.no_grad():
        for batch_idx, (imgs, texts, targets) in enumerate(dataloader_test):
            imgs, targets = imgs.to(device), targets.to(device)
            inputs = text_model(texts, return_tensors="pt", padding="max_length", max_length=20, truncation=True).to(device)
            with torch.no_grad():
                img_embed = img_model(imgs)

            visual_token_type_ids = torch.ones(img_embed.shape[:-1], dtype=torch.long).to(device)
            visual_attention_mask = torch.ones(img_embed.shape[:-1], dtype=torch.float).to(device)
            inputs.update({
                    "visual_embeds": img_embed,
                    "visual_token_type_ids": visual_token_type_ids,
                    "visual_attention_mask": visual_attention_mask,
                })
            logits = model(inputs)# b C logits/Prob. 40 2 1 img [0.1 0.9]
            loss = criterion(logits, targets)
            loss_all += loss.item()
            _, predicted = logits.max(1)
            correct += predicted.eq(targets).sum().item()

        print('Correct:', correct, 'Total Sample:', test_sample_size)

    return correct / test_sample_size, loss_all / test_sample_size

def main():
    args = get_args()
    visual_embeds_model = visual_feat_extractor().to(device)
    visual_embeds_model.eval()
    bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    model_vqa = VisualBERT_VQA(num_labels=args.num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model_vqa.parameters(), lr=args.lr)

    data_root = '/content/data'
    dataset_train = CustomDataset(data_root=data_root, transform=transform, istrain=True)
    dataset_test = CustomDataset(data_root=data_root, transform=transform, istrain=False)
    print('Number of training samples:', len(dataset_train), 'Number of test samples:',len(dataset_test))
    dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=2)
    dataloader_test = DataLoader(dataset_test, batch_size=args.test_batch_size, shuffle=False, num_workers=2)
    best_epoch, best_acc = 0.0, 0

    for epoch in range(args.num_epoch):
        train_loss = train_epoch(model_vqa, bert_tokenizer, visual_embeds_model, dataloader_train, criterion, optimizer)
        accuracy, test_loss = test_epoch(model_vqa, bert_tokenizer, visual_embeds_model, dataloader_test, criterion)
        if accuracy > best_acc:
            best_acc = accuracy
            best_epoch = epoch
            torch.save(model_vqa.state_dict(), 'best_model.pth.tar')

        print('epoch: {}/{}  current:[train loss: {:.4f} test loss:{:.4f} test acc: {:.4f}]  best epoch: {}  best test acc: {:.4f}'.format(
                    epoch, args.num_epoch, train_loss, test_loss, accuracy, best_epoch, best_acc))

main()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 267MB/s]


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/631 [00:00<?, ?B/s]

Number of training samples: 8500 Number of test samples: 500
Correct: 259 Total Sample: 500
epoch: 0/30  current:[train loss: 0.0333 test loss:0.0182 test acc: 0.5180]  best epoch: 0  best test acc: 0.5180
Correct: 250 Total Sample: 500
epoch: 1/30  current:[train loss: 0.0325 test loss:0.0230 test acc: 0.5000]  best epoch: 0  best test acc: 0.5180
Correct: 266 Total Sample: 500
epoch: 2/30  current:[train loss: 0.0319 test loss:0.0183 test acc: 0.5320]  best epoch: 2  best test acc: 0.5320
Correct: 253 Total Sample: 500
epoch: 3/30  current:[train loss: 0.0312 test loss:0.0199 test acc: 0.5060]  best epoch: 2  best test acc: 0.5320
Correct: 252 Total Sample: 500
epoch: 4/30  current:[train loss: 0.0303 test loss:0.0212 test acc: 0.5040]  best epoch: 2  best test acc: 0.5320
Correct: 252 Total Sample: 500
epoch: 5/30  current:[train loss: 0.0290 test loss:0.0237 test acc: 0.5040]  best epoch: 2  best test acc: 0.5320
Correct: 254 Total Sample: 500
epoch: 6/30  current:[train loss: 0.02

Training V2: (ResNet101 and with VQA pretrained weights)

In [None]:
import sys
import os
import argparse
import torch
from torch import nn
from torchvision.models import resnet18, resnet101
from transformers import VisualBertModel, VisualBertConfig, BertTokenizerFast

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_args():
    parser = argparse.ArgumentParser(description='CIFAR-10H Training')
    parser.add_argument('--lr', default=0.00001, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=20, type=int, help='batch size')
    parser.add_argument('--test_batch_size', default=40, type=int, help='batch size')
    parser.add_argument('--num_epoch', default=30, type=int, help='epoch number')
    parser.add_argument('--num_classes', type=int, default=2, help='number classes')

    if 'ipykernel' in sys.modules:
        args = parser.parse_args([])
    else:
        args = parser.parse_args()

    return args


class visual_feat_extractor(nn.Module):
    def __init__(self, ):
        super(visual_feat_extractor, self).__init__()
        self.model_visual_feat = resnet101(pretrained=True)#b 1000
        self.model_visual_feat.avgpool = nn.Identity()
        self.model_visual_feat.fc = nn.Identity()
        self.model_visual_feat.to(device)
        self.model_visual_feat.eval()

    def forward(self, img):
        visual_embeds = self.model_visual_feat(img).view(-1, 49, 2048) #b 49 2048
        return visual_embeds

class VisualBERT_VQA(nn.Module):
    def __init__(self, num_labels=2):
        super(VisualBERT_VQA, self).__init__()
        self.config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        # self.config.visual_embedding_dim = 512
        self.visualbert = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        self.cls = nn.Linear(768, num_labels)

    def forward(self, inputs):
        last_hidden_state = self.visualbert(**inputs).last_hidden_state #[1, 56, 768]

        # Get the index of the last text token
        index_to_gather = inputs['attention_mask'].sum(1) - 2  # as in original
        index_to_gather = (
            index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, last_hidden_state.size(-1))
        ) # [b c hw]=[1, 1, 768]
        pooled_output = torch.gather(last_hidden_state, 1, index_to_gather) # [1, 1, 768]
        logits = self.cls(pooled_output).squeeze(1)
        return logits

def train_epoch(model_vqa, text_model, img_model, dataloader_train, criterion, optimizer):
    model_vqa.train()
    loss_all = 0
    for batch_idx, (imgs, texts, targets) in enumerate(dataloader_train):
        imgs, targets = imgs.to(device), targets.to(device)
        inputs = text_model(texts, return_tensors="pt", padding="max_length", max_length=20, truncation=True).to(device)#b 20
        with torch.no_grad():
            img_embed = img_model(imgs)#b 512

        visual_token_type_ids = torch.ones(img_embed.shape[:-1], dtype=torch.long).to(device)
        visual_attention_mask = torch.ones(img_embed.shape[:-1], dtype=torch.float).to(device)
        inputs.update({
                "visual_embeds": img_embed,
                "visual_token_type_ids": visual_token_type_ids,
                "visual_attention_mask": visual_attention_mask,
            })
        optimizer.zero_grad()
        logits = model_vqa(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        loss_all += loss.item()

    return loss_all/len(dataloader_train.dataset)

def test_epoch(model, text_model, img_model, dataloader_test, criterion):
    model.eval()
    test_sample_size = len(dataloader_test.dataset)
    correct = 0
    loss_all = 0

    with torch.no_grad():
        for batch_idx, (imgs, texts, targets) in enumerate(dataloader_test):
            imgs, targets = imgs.to(device), targets.to(device)
            inputs = text_model(texts, return_tensors="pt", padding="max_length", max_length=20, truncation=True).to(device)
            with torch.no_grad():
                img_embed = img_model(imgs)

            visual_token_type_ids = torch.ones(img_embed.shape[:-1], dtype=torch.long).to(device)
            visual_attention_mask = torch.ones(img_embed.shape[:-1], dtype=torch.float).to(device)
            inputs.update({
                    "visual_embeds": img_embed,
                    "visual_token_type_ids": visual_token_type_ids,
                    "visual_attention_mask": visual_attention_mask,
                })
            logits = model(inputs)# b C logits/Prob. 40 2 1 img [0.1 0.9]
            loss = criterion(logits, targets)
            loss_all += loss.item()
            _, predicted = logits.max(1)
            #targets 40
            correct += predicted.eq(targets).sum().item()

        print('Correct:', correct, 'Total Sample:', test_sample_size)

    return correct / test_sample_size, loss_all / test_sample_size

def main():
    args = get_args()
    visual_embeds_model = visual_feat_extractor().to(device)
    visual_embeds_model.eval()
    bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    model_vqa = VisualBERT_VQA(num_labels=args.num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model_vqa.parameters(), lr=args.lr)

    data_root = '/content/data'
    dataset_train = CustomDataset(data_root=data_root, transform=transform, istrain=True)
    dataset_test = CustomDataset(data_root=data_root, transform=transform, istrain=False)
    print('Number of training samples:', len(dataset_train), 'Number of test samples:',len(dataset_test))
    dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=2)
    dataloader_test = DataLoader(dataset_test, batch_size=args.test_batch_size, shuffle=False, num_workers=2)
    best_epoch, best_acc = 0.0, 0

    for epoch in range(args.num_epoch):
        train_loss = train_epoch(model_vqa, bert_tokenizer, visual_embeds_model, dataloader_train, criterion, optimizer)
        accuracy, test_loss = test_epoch(model_vqa, bert_tokenizer, visual_embeds_model, dataloader_test, criterion)
        if accuracy > best_acc:
            best_acc = accuracy
            best_epoch = epoch
            torch.save(model_vqa.state_dict(), 'best_model.pth.tar')

        print('epoch: {}/{}  current:[train loss: {:.4f} test loss:{:.4f} acc: {:.4f}]  best epoch: {}  best acc: {:.4f}'.format(
                    epoch, args.num_epoch-1, train_loss, test_loss, accuracy, best_epoch, best_acc))

main()



Number of training samples: 8500 Number of test samples: 500
Correct: 274 Total Sample: 500
epoch: 0/29  current:[train loss: 0.0309 test loss:0.0192 acc: 0.5480]  best epoch: 0  best acc: 0.5480
Correct: 263 Total Sample: 500
epoch: 1/29  current:[train loss: 0.0270 test loss:0.0229 acc: 0.5260]  best epoch: 0  best acc: 0.5480
Correct: 266 Total Sample: 500
epoch: 2/29  current:[train loss: 0.0238 test loss:0.0237 acc: 0.5320]  best epoch: 0  best acc: 0.5480
Correct: 262 Total Sample: 500
epoch: 3/29  current:[train loss: 0.0210 test loss:0.0240 acc: 0.5240]  best epoch: 0  best acc: 0.5480
Correct: 266 Total Sample: 500
epoch: 4/29  current:[train loss: 0.0183 test loss:0.0249 acc: 0.5320]  best epoch: 0  best acc: 0.5480
Correct: 268 Total Sample: 500
epoch: 5/29  current:[train loss: 0.0161 test loss:0.0278 acc: 0.5360]  best epoch: 0  best acc: 0.5480
Correct: 268 Total Sample: 500
epoch: 6/29  current:[train loss: 0.0149 test loss:0.0308 acc: 0.5360]  best epoch: 0  best acc: 0