In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize
from torchvision.io import read_image, ImageReadMode
import torch.nn.functional as F
import json
import pandas as pd
import random
from pathlib import Path
import cv2
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoImageProcessor, AutoModelForMaskedLM
from tqdm.auto import tqdm
from sklearn.metrics import f1_score, classification_report
from PIL import Image
import os
import gc
import time
import math

2024-06-06 15:07:00.148042: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-06 15:07:00.148152: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-06 15:07:00.442184: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7e7c4c30f3a0>

In [3]:
# Config File

class CFG:
    save_models = True
    init_weights = True
    tokeniser_model = 'csebuetnlp/banglabert'
    text_model = 'csebuetnlp/banglabert'
    image_model = 'google/vit-base-patch16-224-in21k'
    text_model_config = None
    image_model_config = None
    images_base_path = Path('/kaggle/input/flickr8k/Images')
    images_base_path_test = Path('/kaggle/input/flickr8k/Images')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    debug = False
    print_freq = 300
    apex = True # for faster training
    epochs = 15
    learning_rate = 2e-5  # for adam optimizer
    eps = 1e-6
    betas = (0.9, 0.999)  # for adam optimizer
    batch_size = 64
    max_len = 512
    weight_decay = 0.01  # for adam optimizer regulaization parameter
    gradient_accumulation_steps = 1
    max_grad_norm = 1000
    seed = 42
    train = True
    num_class = 2  # Number of class in your dataset
    
    frozen_lm = False # True or False 
    fusion_mode = "co_attention" # "merged_attention", "co_attention"
    no_fusion_encoder = 2
    num_heads = 4
    
    mlp_hidden_size = 256
    mlp_hidden_layers = 0
    mlp_dropout = 0.1
    mlp_grad_clip = 1.0
    mlp_init_range = 0.2
    mlp_attn_dim = 256
    
    pretraining_obj = "itm+mlm+ucl"
    #"itm", "mlm", "mcl", "ucl", "itm+mlm", "mcl+ucl", "itm+mlm+mcl", "itm+mlm+mcl+ucl"

In [4]:
CFG.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
CFG.text_model_config = AutoConfig.from_pretrained(CFG.text_model) if not 'M-CLIP' in CFG.text_model else None
CFG.image_model_config = AutoConfig.from_pretrained(CFG.image_model) if not 'M-CLIP' in CFG.text_model else None

config.json:   0%|          | 0.00/586 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [6]:
CFG.tokenizer = AutoTokenizer.from_pretrained(CFG.tokeniser_model, padding=True, truncation=True)
CFG.processor = AutoImageProcessor.from_pretrained(CFG.image_model)
CFG.text_model = AutoModelForMaskedLM.from_pretrained(CFG.text_model).to(CFG.device)
CFG.image_model = AutoModel.from_pretrained(CFG.image_model).to(CFG.device)

tokenizer_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/528k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of ElectraForMaskedLM were not initialized from the model checkpoint at csebuetnlp/banglabert and are newly initialized: ['generator_lm_head.bias', 'generator_predictions.LayerNorm.bias', 'generator_predictions.LayerNorm.weight', 'generator_predictions.dense.bias', 'generator_predictions.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [7]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(CFG.seed)

In [8]:
df = pd.read_csv('/kaggle/input/bancap/BAN-Cap_captiondata.csv')
df.head()

Unnamed: 0,caption_id,english_caption,bengali_caption
0,1000268201_693b08cb0e.jpg#0,A child in a pink dress is climbing up a set o...,একটি গোলাপী জামা পরা বাচ্চা মেয়ে একটি বাড়ির প্...
1,1000268201_693b08cb0e.jpg#1,A girl going into a wooden building .,একটি মেয়ে শিশু একটি কাঠের বাড়িতে ঢুকছে
2,1000268201_693b08cb0e.jpg#2,A little girl climbing into a wooden playhouse .,একটি বাচ্চা তার কাঠের খেলাঘরে উঠছে ।
3,1000268201_693b08cb0e.jpg#3,A little girl climbing the stairs to her playh...,ছোট মেয়েটি তার খেলার ঘরের সিড়ি বেয়ে উঠছে
4,1000268201_693b08cb0e.jpg#4,A little girl in a pink dress going into a woo...,গোলাপি জামা পড়া ছোট একটি মেয়ে একটি কাঠের তৈরি...


In [9]:
df[['image', 'caption_number']] = df['caption_id'].str.extract(r'(.+)#(\d+)$')
df.head()

Unnamed: 0,caption_id,english_caption,bengali_caption,image,caption_number
0,1000268201_693b08cb0e.jpg#0,A child in a pink dress is climbing up a set o...,একটি গোলাপী জামা পরা বাচ্চা মেয়ে একটি বাড়ির প্...,1000268201_693b08cb0e.jpg,0
1,1000268201_693b08cb0e.jpg#1,A girl going into a wooden building .,একটি মেয়ে শিশু একটি কাঠের বাড়িতে ঢুকছে,1000268201_693b08cb0e.jpg,1
2,1000268201_693b08cb0e.jpg#2,A little girl climbing into a wooden playhouse .,একটি বাচ্চা তার কাঠের খেলাঘরে উঠছে ।,1000268201_693b08cb0e.jpg,2
3,1000268201_693b08cb0e.jpg#3,A little girl climbing the stairs to her playh...,ছোট মেয়েটি তার খেলার ঘরের সিড়ি বেয়ে উঠছে,1000268201_693b08cb0e.jpg,3
4,1000268201_693b08cb0e.jpg#4,A little girl in a pink dress going into a woo...,গোলাপি জামা পড়া ছোট একটি মেয়ে একটি কাঠের তৈরি...,1000268201_693b08cb0e.jpg,4


In [10]:
if CFG.debug:
    print(df.shape)
    df = df.sample(frac=0.05)
    print(df.shape)
    CFG.print_freq = 50
    CFG.epochs = 2
    

In [11]:
class BanCapDataset(Dataset):
    def __init__(self, features, img_dir, processor, tokenizer, img_transform=None, caption_transform=None, target_transform=None):
        self.features = features
        self.img_dir = img_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.img_transform = img_transform
        self.caption_transform = caption_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        img_path = str(self.img_dir.joinpath(self.features['image'].iloc[idx]))
        image = read_image(img_path, mode=ImageReadMode.RGB).to(device=CFG.device)
        caption = self.features['bengali_caption'].iloc[idx]
        identity = self.features['image'].iloc[idx]
        
        if self.img_transform:
            image = self.img_transform(image)
        if self.caption_transform:
            caption = self.caption_transform(caption)
        if self.target_transform:
            label = self.target_transform(label)
            
        processed_img = self.processor(images=image, return_tensors="pt")
        image = processed_img['pixel_values']
        
        processed_txt = self.tokenizer.encode_plus(
            caption,
            padding='longest',
            truncation=True,
            return_tensors='pt'
        )
        seq = processed_txt['input_ids']
        mask = processed_txt['attention_mask']
        
        return identity, image, seq, mask

In [12]:
class Collator(object):
    def __init__(self, test=False):
        self.test = test
    def __call__(self, batch):
        ids, images, seqs, masks = zip(*batch)

        seqs = [seq.squeeze(dim=0) for seq in seqs]
        masks = [mask.squeeze(dim=0) for mask in masks]
        images = [image.squeeze(dim=0) for image in images]

        seqs = nn.utils.rnn.pad_sequence(seqs, batch_first=True)
        masks = nn.utils.rnn.pad_sequence(masks, batch_first=True)

        images = torch.stack(images)
        
        return ids, images, seqs, masks

In [13]:
resizer = Resize((224, 224), antialias=True)

def resize_images(img_tensor):
    return resizer(img_tensor)

In [14]:
def dataframe_train_test_split(df, target_label=None, seed=CFG.seed, test_size=0.2, split_labels=True):
    train = df.sample(frac=(1.0 - test_size), random_state=seed)
    test = df.drop(train.index).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    
    train.reset_index(drop=True, inplace=True)

    if split_labels:
        return train.drop(columns=target_label), test.drop(columns=target_label), train[target_label], test[target_label]
    else:
        return train, test

# if CFG.debug:
#     train_df = df.head(350)
#     val_df = df.tail(55)
# else:
train_df, val_df = dataframe_train_test_split(df, test_size=0.2, seed=CFG.seed, split_labels=False)
print(train_df.shape)
print(val_df.shape)

(32364, 5)
(8091, 5)


In [15]:
train_dataset = BanCapDataset(train_df, CFG.images_base_path, 
                              CFG.processor, CFG.tokenizer, 
                              img_transform=resize_images)

val_dataset = BanCapDataset(val_df, CFG.images_base_path, 
                            CFG.processor, CFG.tokenizer, 
                            img_transform=resize_images)

print(len(train_dataset))
print(len(val_dataset))

32364
8091


In [16]:
class ITMHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.head = nn.Linear(hidden_size, 2, dtype=torch.float16)
        
    def forward(self, embedding):
        x = self.head(embedding)
        
        return x

In [17]:

import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, input_dim):
        super(CrossAttention, self).__init__()
        self.query_linear = nn.Linear(input_dim, input_dim)
        self.key_linear = nn.Linear(input_dim, input_dim)
        self.value_linear = nn.Linear(input_dim, input_dim)

    def forward(self, query, key, value):
        query_proj = self.query_linear(query)
        key_proj = self.key_linear(key)
        value_proj = self.value_linear(value)

        # Calculate attention scores
        scores = torch.matmul(query_proj, key_proj.transpose(-2, -1))
        attention_weights = F.softmax(scores, dim=-1)

        # Apply attention weights to values
        output = torch.matmul(attention_weights, value_proj)
        return output
    

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadCrossAttention, self).__init__()
        assert input_dim % num_heads == 0, "input_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        
        # Linear transformations for queries, keys, and values for each head
        self.query_linear = nn.Linear(input_dim, input_dim)
        self.key_linear = nn.Linear(input_dim, input_dim)
        self.value_linear = nn.Linear(input_dim, input_dim)
        
        # Final linear transformation after concatenating heads
        self.out_linear = nn.Linear(input_dim, input_dim)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        
        # Apply linear transformations for queries, keys, and values for each head
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        
        # Calculate attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)

        # Apply attention weights to values
        output = torch.matmul(attention_weights, value)  # (batch_size, num_heads, seq_len, head_dim)
        
        # Concatenate heads and perform final linear transformation
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)  # (batch_size, seq_len, input_dim)
        output = self.out_linear(output)
        return output
    
    
class CrossAttentionLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads, dropout):
        super(CrossAttentionLayer, self).__init__()
        self.self_attention = nn.MultiheadAttention(input_dim, num_heads)
        self.cross_attention = MultiHeadCrossAttention(input_dim, num_heads)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory):
        # Self-attention
        x, _ = self.self_attention(x, x, x)

        # Cross-attention
        x = self.cross_attention(x, memory, memory)

        x = self.linear2(F.relu(self.linear1(x)))
        x = self.dropout(x)
        return x

class CrossAttentionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, dropout = 0.05):
        super(CrossAttentionEncoder, self).__init__()
        self.layers = nn.ModuleList([CrossAttentionLayer(input_dim, hidden_dim, num_heads, dropout) 
                                     for _ in range(num_layers)])

    def forward(self, x, memory):
        for layer in self.layers:
            x = layer(x, memory)
        return x
    
    
    
hudai = CrossAttentionEncoder(768, 768, 2, 2)
x = torch.rand(32, 10, 768)
memory = torch.rand(32, 20, 768)

output = hudai(x, memory)
output.shape

# multihead_attn = nn.MultiheadAttention(768, 2)
# cross = MultiHeadCrossAttention(768, 2)
# attn_output, attn_output_weights = multihead_attn(x, x, x)
# attn_output = cross(x, memory, memory)
# attn_output.shape

torch.Size([32, 10, 768])

In [18]:

class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()
        
    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings

In [19]:

class BanCAP_Pretraining(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.itm = ITMHead(self.cfg.text_model_config.hidden_size)
        self.text_model = self.cfg.text_model
        self.image_model = self.cfg.image_model
        
        self.hidden_size = 768
        self.v_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.l_proj = torch.nn.Linear(self.hidden_size, self.hidden_size)
        
        if self.cfg.fusion_mode == "merged_attention":
            self.fusion_encoder = nn.TransformerEncoderLayer(d_model=self.hidden_size, 
                                                        nhead=self.cfg.num_heads)
            self.fusion_layers = nn.TransformerEncoder(self.fusion_encoder, 
                                                  num_layers=self.cfg.no_fusion_encoder)
        
        elif self.cfg.fusion_mode == "co_attention":
            self.pool = MeanPooling()
            self.visual_cross_encoder = CrossAttentionEncoder(self.hidden_size, self.hidden_size,
                                                        self.cfg.no_fusion_encoder,
                                                        self.cfg.num_heads)
            
            self.language_cross_encoder = CrossAttentionEncoder(self.hidden_size, self.hidden_size,
                                                        self.cfg.no_fusion_encoder,
                                                        self.cfg.num_heads)
        
        
        self.text_mlp = nn.Linear(self.hidden_size, self.hidden_size)
        self.vision_mlp = nn.Linear(self.hidden_size, self.hidden_size)
        
        if self.cfg.frozen_lm:
            self._frozen_lm()
            
    def _frozen_lm(self):
        for param in self.text_model.parameters():
            param.requires_grad = False
            
        for param in self.image_model.parameters():
            param.requires_grad = False
            

    def forward(self, input_ids, attention_mask, image):
        
        text_output = self.text_model(input_ids=input_ids, 
                                      attention_mask=attention_mask,
                                      output_hidden_states=True)
        image_output = self.image_model(image)

        visual_features = image_output.last_hidden_state
        text_features = text_output.hidden_states[-1]
        
        visual_embeds = visual_features[:, 0, :]
        text_embeds = text_features[:, 0, :]
        
        visual_poj = self.v_proj(visual_features)
        language_proj = self.l_proj(text_features)
  
        if self.cfg.fusion_mode == "co_attention":
            vision_final = self.visual_cross_encoder(visual_poj, language_proj)
            language_final = self.language_cross_encoder(language_proj, visual_poj)
            
            vision_final = torch.mean(vision_final, dim = 1)
            language_final = self.pool(language_final, attention_mask)
            
            vision_final = self.vision_mlp(vision_final)
            language_final = self.text_mlp(language_final)
            
        elif self.cfg.fusion_mode == "merged_attention":
            merged_embed = torch.cat((language_proj, visual_poj), dim = 1)
            merged_attention_features = self.fusion_layers(merged_embed)
            
            text_embed_len = attention_mask.size(1)
            language_final = merged_attention_features[:, :text_embed_len, : ]
            vision_final = merged_attention_features[:, text_embed_len:, : ]
            
            vision_final = self.vision_mlp(vision_final)
            language_final = self.text_mlp(language_final)
        
        return language_final, vision_final


In [20]:
# For now, only consider similarity between embeds and not features
def itm_loss_func(model, text_embeds, image_embeds):
    batch_size = text_embeds.shape[0]
    
    # Dims are (bs, hidden_size)
    # Do we need to call no_grad here?
    with torch.no_grad():
        sim_i2t = image_embeds @ text_embeds.t()  # No temp yet
        sim_t2i = text_embeds @ image_embeds.t()
        
        weights_i2t = F.softmax(sim_i2t[:, :batch_size], dim=1) + 1e-4
        weights_t2i = F.softmax(sim_t2i[:, :batch_size], dim=1) + 1e-4
    
        weights_i2t.fill_diagonal_(0)
        weights_t2i.fill_diagonal_(0)
        
    # Fake images for valid text
    image_embeds_neg = []
    for b in range(batch_size):
        neg_idx = torch.multinomial(weights_t2i[b], 1).item()
        image_embeds_neg.append(image_embeds[neg_idx])
    image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
    
    # Fake text for valid images
    text_embeds_neg = []
    for b in range(batch_size):
        neg_idx = torch.multinomial(weights_i2t[b], 1).item()
        text_embeds_neg.append(text_embeds[neg_idx])
    text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
    
    # The layout will be valid text-image pairs, valid text but fake image pairs and valid image but fake text pairs
    image_embeds_all = torch.cat([image_embeds, image_embeds_neg, image_embeds], dim=0)
    text_embeds_all = torch.cat([text_embeds, text_embeds, text_embeds_neg], dim=0)
    
    # 1 for every valid pair from the original batch
    # First zeroed batch_size entries for t2i and second zeroed batch_size entries for i2t
    itm_labels = torch.cat([torch.ones(batch_size, dtype=torch.long), torch.zeros(2 * batch_size, dtype=torch.long)], dim=0).to(CFG.device)
    itm_features = torch.stack([image_embeds_all, text_embeds_all], dim=1)
    itm_features = torch.mean(itm_features, dim=1)
 
    itm_preds = model.itm(itm_features)
    
    itm_loss = F.cross_entropy(itm_preds, itm_labels)
    
    return itm_loss

In [21]:
def mlm_loss_func(text_model, old_seq, mask):
    seq = old_seq.clone() 
    mlm_labels = seq.clone()
    
    mask_idx = torch.where(torch.rand(mlm_labels.shape) < .15)
    seq[mask_idx] = CFG.tokenizer.mask_token_id
    seq[mask == CFG.tokenizer.pad_token_id] = CFG.tokenizer.pad_token_id
    
    output = text_model(input_ids=seq, attention_mask=mask, labels=mlm_labels)
    mlm_loss = output.loss
    
    return mlm_loss

In [22]:
def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

def clip_loss_func(model, text_embeds, image_embeds):        
    batch_size = text_embeds.shape[0]
    
    logits = text_embeds @ image_embeds.T
    sim_i2i = image_embeds @ image_embeds.T
    sim_t2t = text_embeds @ text_embeds.T
    
    targets = F.softmax((sim_i2i + sim_t2t) / 2, dim=-1)
    text_loss = cross_entropy(logits, targets)
    image_loss = cross_entropy(logits.T, targets.T)
    
    both_loss = (text_loss + image_loss) / 2.0
    clip_loss = both_loss.mean()
    
    return clip_loss

In [23]:
def ucl_loss_func(model, text_embeds, image_embeds):
    with torch.no_grad():
        sim_i2t = image_embeds @ text_embeds.t() # only used for finding the size
        sim_targets = torch.zeros(sim_i2t.size())
        sim_targets.fill_diagonal_(1)

    sim_i2i = image_embeds @ image_embeds.T
    sim_t2t = text_embeds @ text_embeds.T
    sim_targets = sim_targets.to(CFG.device)

    loss_i2i = -torch.sum(F.log_softmax(sim_i2i, dim=1) * sim_targets, dim=1).mean()
    loss_t2t = -torch.sum(F.log_softmax(sim_t2t, dim=1) * sim_targets, dim=1).mean()
    
    ucl_loss = (loss_i2i + loss_t2t) / 2
    
    return ucl_loss

In [24]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))


In [25]:
def train_loop(model, optimizer, train_dataloader, epoch):
    all_image_embeds = []
    all_text_embeds = []
    model.train()
    train_losses = AverageMeter()
    start = end = time.time()
    
    for step, (_, image, seq, mask) in enumerate(tqdm(train_dataloader)):
        
        train_image = image.to(CFG.device)
        train_seq = seq.to(CFG.device)
        train_mask = mask.to(CFG.device)
        
        batch_size = train_image.shape[0]

        with torch.cuda.amp.autocast(enabled=CFG.apex):
            text_embeds, image_embeds = model(train_seq, train_mask, train_image)
        
        losses = []
        if 'itm' in CFG.pretraining_obj:
            itm = itm_loss_func(model.module, text_embeds, image_embeds)
            losses.append(itm)
        if 'mlm' in CFG.pretraining_obj:
            mlm = mlm_loss_func(model.module.text_model, train_seq, train_mask)  # Access submodule when using DataParallel
            losses.append(mlm)
        if 'mcl' in CFG.pretraining_obj:
            mcl = clip_loss_func(model.module, text_embeds, image_embeds)
            losses.append(mcl)
        if 'ucl' in CFG.pretraining_obj:
            ucl = ucl_loss_func(model.module, text_embeds, image_embeds)
            losses.append(ucl)

        if not losses:
            raise ValueError("No loss functions specified or invalid CFG.pretraining_obj")

        loss = sum(losses)
        
        optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), CFG.mlp_grad_clip)
        optimizer.step()
        
        train_losses.update(loss.item(), batch_size)
            
        end = time.time()

        if step % CFG.print_freq == 0 or step == (len(train_dataloader) - 1):
            print(f'Epoch: [{epoch + 1}][{step}/{len(train_dataloader)}] '
                  f'Elapsed {timeSince(start, float(step + 1) / len(train_dataloader)):s} '
                  f'Loss: {train_losses.val:.4f} ({train_losses.avg:.4f}) ')
        
        if step % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()

    return train_losses.avg, all_text_embeds, all_image_embeds

In [26]:
def validation_loop(model, valid_dataloader, epoch):
    model.eval()
    validation_losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    
    for step, (_, image, seq, mask) in enumerate(tqdm(valid_dataloader)):
        
        image = image.to(device=CFG.device)
        seq = seq.to(device=CFG.device)
        mask = mask.to(device=CFG.device)
        
        batch_size = image.shape[0]

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=CFG.apex):
                text_embeds, image_embeds = model(seq, mask, image)
            
        losses = []
        if 'itm' in CFG.pretraining_obj:
            itm = itm_loss_func(model.module, text_embeds, image_embeds)
            losses.append(itm)
        if 'mlm' in CFG.pretraining_obj:
            mlm = mlm_loss_func(model.module.text_model, seq, mask)  # Access submodule when using DataParallel
            losses.append(mlm)
        if 'mcl' in CFG.pretraining_obj:
            mcl = clip_loss_func(model.module, text_embeds, image_embeds)
            losses.append(mcl)
        if 'ucl' in CFG.pretraining_obj:
            ucl = ucl_loss_func(model.module, text_embeds, image_embeds)
            losses.append(ucl)

        if not losses:
            raise ValueError("No loss functions specified or invalid CFG.pretraining_obj")

        loss = sum(losses)
        
        validation_losses.update(loss.item(), batch_size)
            
        end = time.time()

        if step % CFG.print_freq == 0 or step == (len(valid_dataloader) - 1):
            print(f'Epoch: [{epoch + 1}][{step}/{len(valid_dataloader)}] '
                  f'Elapsed {timeSince(start, float(step + 1) / len(valid_dataloader)):s} '
                  f'Loss: {validation_losses.val:.4f} ({validation_losses.avg:.4f})')
        
        if step % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()
        
        
    return validation_losses.avg

In [27]:
# ====================================================
# Logger File
# ====================================================

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report

OUTPUT_DIR = "./"
def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

CFG.model = CFG.fusion_mode + "__" + CFG.pretraining_obj
print(CFG.model)

co_attention__itm+mlm+ucl


In [28]:
collate = Collator()
train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, collate_fn=collate)
valid_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, collate_fn=collate)

model = BanCAP_Pretraining(CFG).to(CFG.device)
model = nn.DataParallel(model)

optim = torch.optim.AdamW(model.parameters(), lr=CFG.learning_rate, eps=CFG.eps, betas=CFG.betas)

best_loss = float('inf')

for epoch in range(CFG.epochs):

    start_time = time.time()

    # train function 
    avg_train_loss, all_text_embeds, all_image_embeds = train_loop(model, optim, train_dataloader, epoch)

    # val function 
    avg_val_loss = validation_loop(model, valid_dataloader, epoch)

    elapsed = time.time() - start_time

    LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_train_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')

    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
        torch.save({'model': model.state_dict()},
                    OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_loss{best_loss:.4f}_best.pth")

    torch.cuda.empty_cache()
    gc.collect()

  0%|          | 0/506 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Epoch: [1][0/506] Elapsed 0m 9s (remain 81m 1s) Loss: 17.1748 (17.1748) 
Epoch: [1][300/506] Elapsed 18m 59s (remain 12m 55s) Loss: 3.1646 (6.6401) 
Epoch: [1][505/506] Elapsed 31m 10s (remain 0m 0s) Loss: 1.9983 (4.9665) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [1][0/127] Elapsed 0m 1s (remain 2m 39s) Loss: 2.0135 (2.0135)


Epoch 1 - avg_train_loss: 4.9665  avg_val_loss: 2.0339  time: 2023s
Epoch 1 - Save Best Loss: 2.0339 Model


Epoch: [1][126/127] Elapsed 2m 32s (remain 0m 0s) Loss: 1.4494 (2.0339)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [2][0/506] Elapsed 0m 3s (remain 31m 38s) Loss: 2.3163 (2.3163) 
Epoch: [2][300/506] Elapsed 17m 53s (remain 12m 11s) Loss: 1.7162 (2.0087) 
Epoch: [2][505/506] Elapsed 30m 0s (remain 0m 0s) Loss: 1.6871 (1.9392) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [2][0/127] Elapsed 0m 1s (remain 2m 42s) Loss: 1.6754 (1.6754)


Epoch 2 - avg_train_loss: 1.9392  avg_val_loss: 1.6221  time: 1951s
Epoch 2 - Save Best Loss: 1.6221 Model


Epoch: [2][126/127] Elapsed 2m 30s (remain 0m 0s) Loss: 1.2174 (1.6221)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [3][0/506] Elapsed 0m 3s (remain 31m 36s) Loss: 1.8622 (1.8622) 
Epoch: [3][300/506] Elapsed 17m 44s (remain 12m 5s) Loss: 1.7048 (1.7585) 
Epoch: [3][505/506] Elapsed 29m 52s (remain 0m 0s) Loss: 1.6071 (1.7281) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [3][0/127] Elapsed 0m 1s (remain 2m 40s) Loss: 1.6422 (1.6422)


Epoch 3 - avg_train_loss: 1.7281  avg_val_loss: 1.5854  time: 1944s
Epoch 3 - Save Best Loss: 1.5854 Model


Epoch: [3][126/127] Elapsed 2m 31s (remain 0m 0s) Loss: 1.3255 (1.5854)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [4][0/506] Elapsed 0m 3s (remain 31m 58s) Loss: 1.5769 (1.5769) 
Epoch: [4][300/506] Elapsed 17m 47s (remain 12m 6s) Loss: 1.5602 (1.6402) 
Epoch: [4][505/506] Elapsed 29m 55s (remain 0m 0s) Loss: 1.6347 (1.6194) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [4][0/127] Elapsed 0m 1s (remain 2m 43s) Loss: 1.4390 (1.4390)


Epoch 4 - avg_train_loss: 1.6194  avg_val_loss: 1.4700  time: 1946s
Epoch 4 - Save Best Loss: 1.4700 Model


Epoch: [4][126/127] Elapsed 2m 30s (remain 0m 0s) Loss: 1.1141 (1.4700)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [5][0/506] Elapsed 0m 3s (remain 31m 29s) Loss: 1.4921 (1.4921) 
Epoch: [5][300/506] Elapsed 17m 45s (remain 12m 5s) Loss: 1.6436 (1.5826) 
Epoch: [5][505/506] Elapsed 29m 46s (remain 0m 0s) Loss: 1.3248 (1.5699) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [5][0/127] Elapsed 0m 1s (remain 2m 36s) Loss: 1.9571 (1.9571)


Epoch 5 - avg_train_loss: 1.5699  avg_val_loss: 1.8042  time: 1934s


Epoch: [5][126/127] Elapsed 2m 27s (remain 0m 0s) Loss: 1.6368 (1.8042)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [6][0/506] Elapsed 0m 3s (remain 31m 17s) Loss: 1.5778 (1.5778) 
Epoch: [6][300/506] Elapsed 17m 42s (remain 12m 3s) Loss: 1.5032 (1.5352) 
Epoch: [6][505/506] Elapsed 29m 51s (remain 0m 0s) Loss: 1.3835 (1.5292) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [6][0/127] Elapsed 0m 1s (remain 2m 43s) Loss: 1.5163 (1.5163)


Epoch 6 - avg_train_loss: 1.5292  avg_val_loss: 1.6177  time: 1942s


Epoch: [6][126/127] Elapsed 2m 30s (remain 0m 0s) Loss: 1.4961 (1.6177)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [7][0/506] Elapsed 0m 3s (remain 29m 45s) Loss: 1.6579 (1.6579) 
Epoch: [7][300/506] Elapsed 17m 49s (remain 12m 8s) Loss: 1.6192 (1.4726) 
Epoch: [7][505/506] Elapsed 29m 59s (remain 0m 0s) Loss: 1.1277 (1.4384) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [7][0/127] Elapsed 0m 1s (remain 2m 45s) Loss: 1.2563 (1.2563)


Epoch 7 - avg_train_loss: 1.4384  avg_val_loss: 1.1862  time: 1951s
Epoch 7 - Save Best Loss: 1.1862 Model


Epoch: [7][126/127] Elapsed 2m 30s (remain 0m 0s) Loss: 1.0160 (1.1862)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [8][0/506] Elapsed 0m 3s (remain 32m 2s) Loss: 1.1850 (1.1850) 
Epoch: [8][300/506] Elapsed 17m 49s (remain 12m 8s) Loss: 1.1237 (1.1548) 
Epoch: [8][505/506] Elapsed 29m 57s (remain 0m 0s) Loss: 1.1414 (1.1270) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [8][0/127] Elapsed 0m 1s (remain 2m 42s) Loss: 1.0535 (1.0535)


Epoch 8 - avg_train_loss: 1.1270  avg_val_loss: 1.0739  time: 1947s
Epoch 8 - Save Best Loss: 1.0739 Model


Epoch: [8][126/127] Elapsed 2m 29s (remain 0m 0s) Loss: 0.9623 (1.0739)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [9][0/506] Elapsed 0m 3s (remain 31m 40s) Loss: 1.1908 (1.1908) 
Epoch: [9][300/506] Elapsed 17m 38s (remain 12m 0s) Loss: 1.1409 (1.0077) 
Epoch: [9][505/506] Elapsed 29m 39s (remain 0m 0s) Loss: 1.1243 (1.0066) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [9][0/127] Elapsed 0m 1s (remain 2m 43s) Loss: 1.2262 (1.2262)


Epoch 9 - avg_train_loss: 1.0066  avg_val_loss: 1.2582  time: 1929s


Epoch: [9][126/127] Elapsed 2m 28s (remain 0m 0s) Loss: 0.9986 (1.2582)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [10][0/506] Elapsed 0m 3s (remain 31m 46s) Loss: 0.9113 (0.9113) 
Epoch: [10][300/506] Elapsed 17m 45s (remain 12m 5s) Loss: 1.0265 (1.0396) 
Epoch: [10][505/506] Elapsed 29m 54s (remain 0m 0s) Loss: 0.9181 (1.0211) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [10][0/127] Elapsed 0m 1s (remain 2m 41s) Loss: 1.2200 (1.2200)


Epoch 10 - avg_train_loss: 1.0211  avg_val_loss: 1.2022  time: 1945s


Epoch: [10][126/127] Elapsed 2m 30s (remain 0m 0s) Loss: 0.9137 (1.2022)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [11][0/506] Elapsed 0m 3s (remain 29m 18s) Loss: 1.1308 (1.1308) 
Epoch: [11][300/506] Elapsed 17m 50s (remain 12m 9s) Loss: 1.0409 (1.0060) 
Epoch: [11][505/506] Elapsed 29m 59s (remain 0m 0s) Loss: 1.0472 (1.0040) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [11][0/127] Elapsed 0m 1s (remain 2m 53s) Loss: 1.3597 (1.3597)


Epoch 11 - avg_train_loss: 1.0040  avg_val_loss: 1.3803  time: 1952s


Epoch: [11][126/127] Elapsed 2m 32s (remain 0m 0s) Loss: 1.2262 (1.3803)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [12][0/506] Elapsed 0m 3s (remain 31m 52s) Loss: 1.0832 (1.0832) 
Epoch: [12][300/506] Elapsed 17m 44s (remain 12m 5s) Loss: 0.9267 (1.0055) 
Epoch: [12][505/506] Elapsed 29m 49s (remain 0m 0s) Loss: 1.0075 (1.0015) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [12][0/127] Elapsed 0m 1s (remain 2m 39s) Loss: 1.3489 (1.3489)


Epoch 12 - avg_train_loss: 1.0015  avg_val_loss: 1.3508  time: 1938s


Epoch: [12][126/127] Elapsed 2m 28s (remain 0m 0s) Loss: 1.1009 (1.3508)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [13][0/506] Elapsed 0m 3s (remain 31m 7s) Loss: 0.8523 (0.8523) 
Epoch: [13][300/506] Elapsed 17m 48s (remain 12m 7s) Loss: 1.0837 (0.9893) 
Epoch: [13][505/506] Elapsed 30m 2s (remain 0m 0s) Loss: 1.0681 (0.9860) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [13][0/127] Elapsed 0m 1s (remain 2m 40s) Loss: 1.4194 (1.4194)


Epoch 13 - avg_train_loss: 0.9860  avg_val_loss: 1.3645  time: 1953s


Epoch: [13][126/127] Elapsed 2m 30s (remain 0m 0s) Loss: 1.0723 (1.3645)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [14][0/506] Elapsed 0m 3s (remain 30m 11s) Loss: 0.8953 (0.8953) 
Epoch: [14][300/506] Elapsed 17m 55s (remain 12m 12s) Loss: 0.9351 (0.9759) 
Epoch: [14][505/506] Elapsed 30m 9s (remain 0m 0s) Loss: 1.0296 (0.9729) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [14][0/127] Elapsed 0m 1s (remain 2m 43s) Loss: 1.5826 (1.5826)


Epoch 14 - avg_train_loss: 0.9729  avg_val_loss: 1.5620  time: 1962s


Epoch: [14][126/127] Elapsed 2m 32s (remain 0m 0s) Loss: 1.1483 (1.5620)


  0%|          | 0/506 [00:00<?, ?it/s]

Epoch: [15][0/506] Elapsed 0m 3s (remain 31m 16s) Loss: 1.0913 (1.0913) 
Epoch: [15][300/506] Elapsed 17m 52s (remain 12m 10s) Loss: 0.9520 (0.9529) 
Epoch: [15][505/506] Elapsed 29m 59s (remain 0m 0s) Loss: 0.9694 (0.9536) 


  0%|          | 0/127 [00:00<?, ?it/s]

Epoch: [15][0/127] Elapsed 0m 1s (remain 2m 41s) Loss: 1.4852 (1.4852)


Epoch 15 - avg_train_loss: 0.9536  avg_val_loss: 1.5407  time: 1949s


Epoch: [15][126/127] Elapsed 2m 29s (remain 0m 0s) Loss: 1.2731 (1.5407)
