## Setting up environment and importing libraries

In [1]:
# install required libraries
!pip install transformers timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.20.1-py3-none-any.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 6.7 MB/s 
[?25hCollecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[K     |████████████████████████████████| 431 kB 47.7 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 36.7 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 8.0 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 58.4 MB/s 
Installing collected packages:

In [2]:
import json
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, classification_report
from transformers import AutoModel, AutoTokenizer, get_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW, lr_scheduler
from tqdm.notebook import tqdm, trange
from time import perf_counter
from PIL import Image

## Data loading and common configs

In [20]:
HOME_FOLDER = '/content/drive/MyDrive/KDD/'
DATA_FOLDER = HOME_FOLDER + 'webvision_data/'
IMAGE_FOLDER = DATA_FOLDER + 'images/'
RESULTS_FOLDER = HOME_FOLDER + 'results/'
os.makedirs(RESULTS_FOLDER, exist_ok=True)

In [21]:
import pandas as pd

df_train = pd.read_csv(DATA_FOLDER + 'train.csv')
df_test = pd.read_csv(DATA_FOLDER + 'test.csv')

# workaround to remove errorneous paths
df_train = df_train[df_train['img_path'] != '#NAME?']
df_test = df_test[df_test['img_path'] != '#NAME?']

In [22]:
label_to_id = {lab:i for i, lab in enumerate(df_train['label'].unique())}
id_to_label = {v:k for k,v in label_to_id.items()}

In [23]:
label_to_id

{'bookshop': 4,
 'breakwater': 2,
 'chiton': 0,
 'coil': 8,
 'confectionery': 7,
 'gar': 5,
 'gasmask': 6,
 'polecat': 1,
 'seashore': 9,
 'streetcar': 3}

In [24]:
num_out_labels = len(label_to_id)

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
# set random seeds for repeatability
seed = 19
torch.manual_seed(seed)
# random.seed(self.seed)
# np.random.seed(self.seed)

<torch._C.Generator at 0x7f625fe2d410>

In [27]:
## training parameters to be used for all models ##
num_train_epochs = 5
batch_size = 16
learning_rate = 2.0e-5
weight_decay = 1.0
warmup_steps = 0
max_seq_length = 200

## BERT

### Dataset

In [11]:
class TextDataset(Dataset):
    def __init__(self, df, label_to_id, train, text_field="text", label_field="label"):
        self.df = df.reset_index()
        self.label_to_id = label_to_id
        self.text_field = text_field
        self.label_field = label_field
                          
    def __getitem__(self, index):
        text = str(self.df.at[index, self.text_field])
        label = self.label_to_id[self.df.at[index, self.label_field]]
        
        return text, label

    def __len__(self):
        return self.df.shape[0]

### Model

In [12]:
class BertModel(nn.Module):

    def __init__(self, num_labels, text_pretrained='bert-base-uncased'):
        super().__init__()

        self.num_labels = num_labels
        self.bert_model = AutoModel.from_pretrained(text_pretrained)
        self.classifier = nn.Linear(
            self.bert_model.config.hidden_size, num_labels)
        
    
    def forward(self, text):
        output = self.bert_model(text.input_ids, attention_mask=text.attention_mask, return_dict=True)
        logits = self.classifier(output.pooler_output)
        return logits

In [None]:
# intialise the model
bert_model = BertModel(num_labels=num_out_labels, text_pretrained='bert-base-uncased')
bert_model.to(device)

In [14]:
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

### Training

In [15]:
## training loop

train_dataset = TextDataset(df=df_train, label_to_id=label_to_id, train=True, text_field='text', label_field='label')
train_sampler = RandomSampler(train_dataset)        
train_dataloader = DataLoader(dataset=train_dataset,
                    batch_size=batch_size, 
                    sampler=train_sampler)


t_total = len(train_dataloader) * num_train_epochs


optimizer = AdamW(bert_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)

tr_loss = 0.0

criterion = nn.CrossEntropyLoss()

bert_model.train()


start = perf_counter()
for epoch_num in trange(num_train_epochs, desc='Epochs'):
    epoch_total_loss = 0

    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc='Batch'):        
        b_text, b_labels = batch            
        b_inputs = bert_tokenizer(
            list(b_text), truncation=True, max_length=max_seq_length,
            return_tensors="pt", padding=True
        )
        
        b_labels = b_labels.to(device)
        b_inputs = b_inputs.to(device)

        bert_model.zero_grad()        
        b_logits = bert_model(text=b_inputs)
        
        loss = criterion(b_logits, b_labels)

        epoch_total_loss += loss.item()

        # Perform a backward pass to calculate the gradients
        loss.backward()


        optimizer.step()
        scheduler.step()
        
    tr_loss += epoch_total_loss
    avg_loss = epoch_total_loss/len(train_dataloader)


    print('epoch =', epoch_num)
    print('    epoch_loss =', epoch_total_loss)
    print('    avg_epoch_loss =', avg_loss)
    print('    learning rate =', optimizer.param_groups[0]["lr"])

end = perf_counter()
bert_training_time = end- start
print('Training completed in ', bert_training_time, 'seconds')

Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 0
    epoch_loss = 109.5923444032669
    avg_epoch_loss = 2.191846888065338
    learning rate = 1.6000000000000003e-05


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 1
    epoch_loss = 80.37162786722183
    avg_epoch_loss = 1.6074325573444366
    learning rate = 1.2e-05


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 2
    epoch_loss = 53.82781982421875
    avg_epoch_loss = 1.076556396484375
    learning rate = 8.000000000000001e-06


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 3
    epoch_loss = 39.42004483938217
    avg_epoch_loss = 0.7884008967876435
    learning rate = 4.000000000000001e-06


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 4
    epoch_loss = 33.34722635149956
    avg_epoch_loss = 0.6669445270299912
    learning rate = 0.0
Training completed in  35.55705618999946 seconds


### Testing

In [16]:
# testing loop

bert_prediction_results = []

test_dataset = TextDataset(df=df_test, label_to_id=label_to_id, train=False, text_field='text', label_field='label')
test_sampler = SequentialSampler(test_dataset)        
test_dataloader = DataLoader(dataset=test_dataset,
                            batch_size=batch_size, 
                            sampler=test_sampler)


for batch in tqdm(test_dataloader):
  bert_model.eval()

  b_text, b_labels = batch

  b_inputs = bert_tokenizer(list(b_text), truncation=True, max_length=max_seq_length, return_tensors="pt", padding=True)

  b_labels = b_labels.to(device)
  b_inputs = b_inputs.to(device)

  with torch.no_grad():
      b_logits = bert_model(text=b_inputs)
      b_logits = b_logits.detach().cpu()



  bert_prediction_results += torch.argmax(b_logits, dim=-1).tolist()

bert_prediction_labels = [id_to_label[p] for p in bert_prediction_results]

  0%|          | 0/13 [00:00<?, ?it/s]

In [17]:
bert_class_report = classification_report(df_test['label'], bert_prediction_labels, output_dict=True)
bert_class_report['training_time (seconds)'] = bert_training_time

with open(RESULTS_FOLDER + 'bert_class_report.json', 'w') as f:
  json.dump(bert_class_report, f)

print(bert_class_report['accuracy'])

0.8232323232323232


## BERT + ResNet-50

### Dataset

In [18]:
class ResNetDataset(Dataset):
    def __init__(self, df, label_to_id, train=False, text_field="text", label_field="label", image_path_field="img_path"):
        self.df = df.reset_index()
        self.label_to_id = label_to_id
        self.train = train
        self.text_field = text_field
        self.label_field = label_field
        self.image_path_field = image_path_field

        # ResNet-50 settings
        self.img_size = 224
        self.mean, self.std = (
            0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)


        self.train_transform_func = transforms.Compose(
                [transforms.RandomResizedCrop(self.img_size, scale=(0.5, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(self.mean, self.std)
                    ])

        self.eval_transform_func = transforms.Compose(
                [transforms.Resize(256),
                    transforms.CenterCrop(self.img_size),
                    transforms.ToTensor(),
                    transforms.Normalize(self.mean, self.std)
                    ])

                    
    def __getitem__(self, index):
        text = str(self.df.at[index, self.text_field])
        label = self.label_to_id[self.df.at[index, self.label_field]]
        img_path = IMAGE_FOLDER + self.df.at[index, self.image_path_field]

        
        image = Image.open(img_path)
        if self.train:
          img = self.train_transform_func(image)
        else:
          img = self.eval_transform_func(image)

        return text, label, img

    def __len__(self):
        return self.df.shape[0]

### Model

In [19]:
# extract layers of resnet-50 to build a new model

import torch.nn as nn
from torchvision.models.resnet import resnet50

class ResNetFeatureModel(nn.Module):
    def __init__(self, output_layer):
        super().__init__()
        self.output_layer = output_layer
        self.pretrained = resnet50(pretrained=True)
        self.children_list = []
        for n,c in self.pretrained.named_children():
            self.children_list.append(c)
            if n == self.output_layer:
                break

        self.net = nn.Sequential(*self.children_list)
        self.pretrained = None
        
    def forward(self,x):
        x = self.net(x)
        x = torch.flatten(x, 1)
        return x

In [20]:
# uncomment to see all the layers of resnet50
# resnet50(pretrained=True)

In [21]:
# last output layer name for resnet is named 'layer4', dim 2048*7*7
# last layer name before fc is named 'avgpool', dim 2048*1*1 -> needs to be flattened
# reference: https://medium.com/the-owl/extracting-features-from-an-intermediate-layer-of-a-pretrained-model-in-pytorch-c00589bda32b

class BertResNetModel(nn.Module):
    def __init__(self, num_labels, text_pretrained='bert-base-uncased'):
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(text_pretrained)
        self.image_model = ResNetFeatureModel(output_layer='avgpool')
        self.image_hidden_size = 2048
        
        self.classifier = nn.Linear(self.bert_model.config.hidden_size + self.image_hidden_size, num_labels)

    def forward(self, text, image):
        text_output = self.bert_model(**text)
        # text_feature = text_output[0][:, 0, :]
        text_feature = text_output.pooler_output
        img_feature = self.image_model(image)
        features = torch.cat((text_feature, img_feature), 1)

        logits = self.classifier(features)

        return logits

In [None]:
resnet_model = BertResNetModel(num_labels=num_out_labels, text_pretrained='bert-base-uncased')
resnet_model.to(device)

In [23]:
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

### Training

In [24]:
## training loop


train_dataset = ResNetDataset(df=df_train, label_to_id=label_to_id, train=True, text_field='text', label_field='label', image_path_field='img_path')
train_sampler = RandomSampler(train_dataset)        
train_dataloader = DataLoader(dataset=train_dataset,
                    batch_size=batch_size, 
                    sampler=train_sampler)


t_total = len(train_dataloader) * num_train_epochs


optimizer = AdamW(resnet_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)

tr_loss = 0.0

criterion = nn.CrossEntropyLoss()

resnet_model.train()

start = perf_counter()
for epoch_num in trange(num_train_epochs, desc='Epochs'):
    epoch_total_loss = 0

    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc='Batch'):        
        b_text, b_labels, b_imgs = batch            
        b_inputs = bert_tokenizer(
            list(b_text), truncation=True, max_length=max_seq_length,
            return_tensors="pt", padding=True
        )
        
        b_labels = b_labels.to(device)
        b_imgs = b_imgs.to(device)
        b_inputs = b_inputs.to(device)

        resnet_model.zero_grad()        
        b_logits = resnet_model(text=b_inputs, image=b_imgs)
        
        loss = criterion(b_logits, b_labels)

        epoch_total_loss += loss.item()

        # Perform a backward pass to calculate the gradients
        loss.backward()


        optimizer.step()
        scheduler.step()
        
    tr_loss += epoch_total_loss
    avg_loss = epoch_total_loss/len(train_dataloader)


    print('epoch =', epoch_num)
    print('    epoch_loss =', epoch_total_loss)
    print('    avg_epoch_loss =', avg_loss)
    print('    learning rate =', optimizer.param_groups[0]["lr"])
end = perf_counter()
resnet_training_time = end- start
print('Training completed in ', resnet_training_time, 'seconds')

Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 0
    epoch_loss = 102.82620513439178
    avg_epoch_loss = 2.0565241026878356
    learning rate = 1.6000000000000003e-05


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 1
    epoch_loss = 72.51278483867645
    avg_epoch_loss = 1.450255696773529
    learning rate = 1.2e-05


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 2
    epoch_loss = 51.52484309673309
    avg_epoch_loss = 1.030496861934662
    learning rate = 8.000000000000001e-06


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 3
    epoch_loss = 38.98902493715286
    avg_epoch_loss = 0.7797804987430572
    learning rate = 4.000000000000001e-06


Batch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch = 4
    epoch_loss = 35.056549072265625
    avg_epoch_loss = 0.7011309814453125
    learning rate = 0.0
Training completed in  374.2873686040002 seconds


### Testing

In [25]:
# testing loop

resnet_prediction_results = []

test_dataset = ResNetDataset(df=df_test, label_to_id=label_to_id, train=False, text_field='text', label_field='label', image_path_field='img_path')
test_sampler = SequentialSampler(test_dataset)        
test_dataloader = DataLoader(dataset=test_dataset,
                            batch_size=batch_size, 
                            sampler=test_sampler)


for batch in tqdm(test_dataloader):
  resnet_model.eval()

  b_text, b_labels, b_imgs = batch

  b_inputs = bert_tokenizer(list(b_text), truncation=True, max_length=max_seq_length, return_tensors="pt", padding=True)

  b_labels = b_labels.to(device)
  b_imgs = b_imgs.to(device)
  b_inputs = b_inputs.to(device)

  with torch.no_grad():
      b_logits = resnet_model(text=b_inputs, image=b_imgs)
      b_logits = b_logits.detach().cpu()

  resnet_prediction_results += torch.argmax(b_logits, dim=-1).tolist()

resnet_prediction_labels = [id_to_label[p] for p in resnet_prediction_results]

  0%|          | 0/13 [00:00<?, ?it/s]

In [26]:
resnet_class_report = classification_report(df_test['label'], resnet_prediction_labels, output_dict=True)
resnet_class_report['training_time (seconds)'] = resnet_training_time

with open(RESULTS_FOLDER + 'resnet_class_report.json', 'w') as f:
  json.dump(resnet_class_report, f)

print(resnet_class_report['accuracy'])

0.8080808080808081


## ALBEF

### ALBEF-specific setup

In [28]:
ALBEF_FOLDER = HOME_FOLDER + 'ALBEF/'
os.makedirs(ALBEF_FOLDER, exist_ok=True)

In [29]:
# download pre-trained ALBEF model and required ALBEF files from ALBEF's official repo (only need to do this once to save it in your gdrive)
!wget https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF_4M.pth -O $ALBEF_FOLDER/ALBEF.pth
!wget https://raw.githubusercontent.com/salesforce/ALBEF/main/models/vit.py -O $ALBEF_FOLDER/vit.py
!wget https://raw.githubusercontent.com/salesforce/ALBEF/main/models/tokenization_bert.py -O $ALBEF_FOLDER/tokenization_bert.py
!wget https://raw.githubusercontent.com/salesforce/ALBEF/main/models/xbert.py -O $ALBEF_FOLDER/xbert.py
!wget https://raw.githubusercontent.com/salesforce/ALBEF/main/configs/config_bert.json -O $ALBEF_FOLDER/config_bert.json

--2022-06-22 04:07:22--  https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF_4M.pth
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.2.128, 74.125.137.128, 2607:f8b0:4023:c03::80, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.2.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3493862061 (3.3G) [application/octet-stream]
Saving to: ‘/content/drive/MyDrive/KDD/ALBEF//ALBEF.pth’


2022-06-22 04:08:44 (40.9 MB/s) - ‘/content/drive/MyDrive/KDD/ALBEF//ALBEF.pth’ saved [3493862061/3493862061]

--2022-06-22 04:08:44--  https://raw.githubusercontent.com/salesforce/ALBEF/main/models/vit.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8558 (8.4K) [text/plain]
Saving to: ‘/conte

In [30]:
# replace all occurrences of tokenizer_class with processor_class in xbert.py to make it compatible with newer transformers version
# if you don't do this step, you will need to install transformers==4.8.1 as specified by the requirements in the ALBEF repo

!sed -i 's/tokenizer_class/processor_class/g' $ALBEF_FOLDER/xbert.py

In [31]:
# add path to downloaded ALBEF files
import sys
sys.path.append(ALBEF_FOLDER)

#import libraries required for ALBEF
from vit import VisionTransformer
from xbert import BertConfig as AlbefBertConfig, BertModel as AlbefBertModel
from functools import partial

### Dataset

In [32]:
class AlbefDataset(Dataset):
    def __init__(self, df, label_to_id, train=False, text_field="text", label_field="label", image_path_field="img_path"):
        self.df = df.reset_index()
        self.label_to_id = label_to_id
        self.train = train
        self.text_field = text_field
        self.label_field = label_field
        self.image_path_field = image_path_field

        # ALBEF settings
        self.img_size = 256
        self.mean, self.std = (
            0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)


        self.train_transform_func = transforms.Compose(
                [transforms.RandomResizedCrop(self.img_size, scale=(0.5, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(self.mean, self.std)
                    ])

        self.eval_transform_func = transforms.Compose(
                [transforms.Resize(256),
                    transforms.CenterCrop(self.img_size),
                    transforms.ToTensor(),
                    transforms.Normalize(self.mean, self.std)
                    ])

                    
    def __getitem__(self, index):
        text = str(self.df.at[index, self.text_field])
        label = self.label_to_id[self.df.at[index, self.label_field]]
        img_path = IMAGE_FOLDER + self.df.at[index, self.image_path_field]

        
        image = Image.open(img_path)
        if self.train:
          img = self.train_transform_func(image)
        else:
          img = self.eval_transform_func(image)

        return text, label, img

    def __len__(self):
        return self.df.shape[0]

### Model

In [33]:
class AlbefModel(nn.Module):

    def __init__(self, bert_config, num_labels, text_pretrained='bert-base-uncased'):
        super().__init__()

        self.num_labels = num_labels
        self.bert_model = AlbefBertModel.from_pretrained(
            text_pretrained, config=bert_config, add_pooling_layer=False)

        self.image_model = VisionTransformer(
            img_size=256, patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        self.classifier = nn.Linear(
            self.bert_model.config.hidden_size, num_labels)
        
    
    def forward(self, text, image):
        image_embeds = self.image_model(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
        output = self.bert_model(text.input_ids, attention_mask=text.attention_mask,
                                   encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True
                                   )
        logits = self.classifier(output.pooler_output)
        return logits

In [34]:
# intialise ALBEF model

albef_bert_config_fp = os.path.join(ALBEF_FOLDER, 'config_bert.json')
albef_bert_config = AlbefBertConfig.from_json_file(albef_bert_config_fp)

albef_model_fp = os.path.join(ALBEF_FOLDER, 'ALBEF.pth')
albef_model = AlbefModel(bert_config=albef_bert_config, num_labels=num_out_labels)

albef_checkpoint = torch.load(albef_model_fp, map_location=device)
albef_state_dict = albef_checkpoint['model']

for key in list(albef_state_dict.keys()):
    if 'bert' in key:
        encoder_key = key.replace('bert.', '')
        albef_state_dict[encoder_key] = albef_state_dict[key]
        del albef_state_dict[key]

msg = albef_model.load_state_dict(albef_state_dict, strict=False)
print("ALBEF checkpoint loaded from ", albef_model_fp)
print(msg)

ALBEF checkpoint loaded from  /content/drive/MyDrive/KDD/ALBEF/ALBEF.pth
_IncompatibleKeys(missing_keys=['bert_model.embeddings.position_ids', 'bert_model.embeddings.word_embeddings.weight', 'bert_model.embeddings.position_embeddings.weight', 'bert_model.embeddings.token_type_embeddings.weight', 'bert_model.embeddings.LayerNorm.weight', 'bert_model.embeddings.LayerNorm.bias', 'bert_model.encoder.layer.0.attention.self.query.weight', 'bert_model.encoder.layer.0.attention.self.query.bias', 'bert_model.encoder.layer.0.attention.self.key.weight', 'bert_model.encoder.layer.0.attention.self.key.bias', 'bert_model.encoder.layer.0.attention.self.value.weight', 'bert_model.encoder.layer.0.attention.self.value.bias', 'bert_model.encoder.layer.0.attention.output.dense.weight', 'bert_model.encoder.layer.0.attention.output.dense.bias', 'bert_model.encoder.layer.0.attention.output.LayerNorm.weight', 'bert_model.encoder.layer.0.attention.output.LayerNorm.bias', 'bert_model.encoder.layer.0.intermediat

In [35]:
albef_model.to(device)

AlbefModel(
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [36]:
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

### Training

In [None]:
## training loop


train_dataset = AlbefDataset(df=df_train, label_to_id=label_to_id, train=True, text_field='text', label_field='label', image_path_field='img_path')
train_sampler = RandomSampler(train_dataset)        
train_dataloader = DataLoader(dataset=train_dataset,
                    batch_size=batch_size, 
                    sampler=train_sampler)


t_total = len(train_dataloader) * num_train_epochs


optimizer = AdamW(albef_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)

tr_loss = 0.0

criterion = nn.CrossEntropyLoss()

albef_model.train()

start = perf_counter()
for epoch_num in trange(num_train_epochs, desc='Epochs'):
    epoch_total_loss = 0

    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc='Batch'):        
        b_text, b_labels, b_imgs = batch            
        b_inputs = bert_tokenizer(
            list(b_text), truncation=True, max_length=max_seq_length,
            return_tensors="pt", padding=True
        )
        
        b_labels = b_labels.to(device)
        b_imgs = b_imgs.to(device)
        b_inputs = b_inputs.to(device)

        albef_model.zero_grad()        
        b_logits = albef_model(text=b_inputs, image=b_imgs)
        
        loss = criterion(b_logits, b_labels)

        epoch_total_loss += loss.item()

        # Perform a backward pass to calculate the gradients
        loss.backward()


        optimizer.step()
        scheduler.step()
        
    tr_loss += epoch_total_loss
    avg_loss = epoch_total_loss/len(train_dataloader)


    print('epoch =', epoch_num)
    print('    epoch_loss =', epoch_total_loss)
    print('    avg_epoch_loss =', avg_loss)
    print('    learning rate =', optimizer.param_groups[0]["lr"])
end = perf_counter()
albef_training_time = end- start
print('Training completed in ', albef_training_time, 'seconds')

Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Batch:   0%|          | 0/50 [00:00<?, ?it/s]

### Testing

In [None]:
# testing loop

albef_prediction_results = []

test_dataset = AlbefDataset(df=df_test, label_to_id=label_to_id, train=False, text_field='text', label_field='label', image_path_field='img_path')
test_sampler = SequentialSampler(test_dataset)        
test_dataloader = DataLoader(dataset=test_dataset,
                            batch_size=batch_size, 
                            sampler=test_sampler)


for batch in tqdm(test_dataloader):
  albef_model.eval()

  b_text, b_labels, b_imgs = batch

  b_inputs = bert_tokenizer(list(b_text), truncation=True, max_length=max_seq_length, return_tensors="pt", padding=True)

  b_labels = b_labels.to(device)
  b_imgs = b_imgs.to(device)
  b_inputs = b_inputs.to(device)

  with torch.no_grad():
      b_logits = albef_model(text=b_inputs, image=b_imgs)
      b_logits = b_logits.detach().cpu()



  albef_prediction_results += torch.argmax(b_logits, dim=-1).tolist()

albef_prediction_labels = [id_to_label[p] for p in albef_prediction_results]


In [None]:
albef_class_report = classification_report(df_test['label'], albef_prediction_labels, output_dict=True)
albef_class_report['training_time (seconds)'] = albef_training_time

with open(RESULTS_FOLDER + 'albef_class_report.json', 'w') as f:
  json.dump(albef_class_report, f)

print(albef_class_report['accuracy'])
