In [11]:
import pandas as pd
import numpy as np
from transformers import AutoModel, CLIPProcessor, AutoConfig, AutoTokenizer
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score
from PIL import Image

from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from lightning.pytorch.loggers import CSVLogger
import os
import torchvision.transforms as T
from torchvision.transforms import Compose

from datasets import Dataset, DatasetDict

# --- Reproducibility -------------------------------------------------------
SEED = 1234542
pl.seed_everything(SEED, workers=True)

# --- Data ------------------------------------------------------------------
# Read the pre-split CSVs. Missing values are replaced by an empty string in
# *every* column (not only the caption), so the text fields are always strings.
df_train = pd.read_csv('../../data/splitted/train.csv').fillna(value="")
df_validation = pd.read_csv('../../data/splitted/validation.csv').fillna(value="")
df_test = pd.read_csv('../../data/splitted/test.csv').fillna(value="")

# Human-readable name for each integer class id.
label_dict = {
    0: 'Movies', 1: 'Sports', 2: 'Music', 3: 'Opinion', 4: 'Media', 5: 'Art & Design',
    6: 'Theater', 7: 'Television', 8: 'Technology', 9: 'Economy', 10: 'Books', 11: 'Style',
    12: 'Travel', 13: 'Health', 14: 'Real Estate', 15: 'Dance', 16: 'Science', 17: 'Fashion',
    18: 'Well', 19: 'Food', 20: 'Your Money', 21: 'Education', 22: 'Automobiles',
    23: 'Global Business',
}

# Hugging Face view of the same three splits (used later for tokenization).
dataset = DatasetDict({
    'train': Dataset.from_pandas(df_train),
    'validation': Dataset.from_pandas(df_validation),
    'test': Dataset.from_pandas(df_test),
})

# Number of distinct label values observed in the training split.
NUM_CLASSES = df_train['labels'].nunique(dropna=False)

# --- Columns and paths -----------------------------------------------------
TEXT_CLIP = 'caption'        # column fed to CLIP's text encoder
TEXT_TRANSF = 'text_no_cap'  # column fed to the text transformer
MAX_LENGTH = 512             # truncation length for the text transformer tokenizer

TRAIN_IMAGES_PATH = '../../images/train'
VALIDATION_IMAGES_PATH = '../../images/validation'
TEST_IMAGES_PATH = '../../images/test'

Global seed set to 1234542


In [12]:
# Hub identifiers of the two pretrained backbones.
CLIP_NAME = 'openai/clip-vit-base-patch32'
TRANSFORMER_NAME = 'microsoft/deberta-base'

# Device string, chosen from CUDA availability.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# CLIP: model weights, configuration and the joint image/text processor.
clip_model = AutoModel.from_pretrained(CLIP_NAME)
clip_config = AutoConfig.from_pretrained(CLIP_NAME)
clip_processor = CLIPProcessor.from_pretrained(CLIP_NAME)

# Text transformer: tokenizer, configuration and encoder weights.
tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_NAME)
config = AutoConfig.from_pretrained(TRANSFORMER_NAME)
pretrained_model = AutoModel.from_pretrained(TRANSFORMER_NAME)

Some weights of the model checkpoint at microsoft/deberta-base were not used when initializing DebertaModel: ['lm_predictions.lm_head.bias', 'lm_predictions.lm_head.LayerNorm.bias', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.dense.weight']
- This IS expected if you are initializing DebertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
def tokenize(batch):
    """Add transformer token ids and attention mask to a dataset record.

    The text in column ``TEXT_TRANSF`` is encoded with the module-level
    ``tokenizer``, truncated to ``MAX_LENGTH`` tokens. The result is stored
    under the ``input_ids`` and ``attention_mask`` keys of ``batch``, which is
    modified in place and returned.
    """
    encoded = tokenizer(batch[TEXT_TRANSF], truncation=True, max_length=MAX_LENGTH)
    for field in ('input_ids', 'attention_mask'):
        batch[field] = encoded[field]
    return batch

# Tokenize every split (one record at a time).
dataset = dataset.map(tokenize)

# Drop the raw text/metadata columns now that the token ids are stored; the
# same column list applies to all three splits.
_RAW_COLUMNS = ['headline', 'abstract', 'caption', 'image_url', 'article_url',
                'image_id', 'body', 'full_text', 'text_no_cap', 'labels_text']
for _split in ('train', 'validation', 'test'):
    dataset[_split] = dataset[_split].remove_columns(_RAW_COLUMNS)

100%|##########| 48180/48180 [02:47<00:00, 287.13ex/s]
100%|##########| 6022/6022 [00:20<00:00, 291.76ex/s]
100%|##########| 6023/6023 [00:20<00:00, 291.52ex/s]


In [14]:
class CustomMultimodalDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one (image, caption, token ids, mask, label) sample.

    Args:
        df: DataFrame with at least the columns ``labels_text``, ``image_id``,
            ``labels`` and the caption column named by ``TEXT_CLIP``.
        img_dir: Root image folder; images are looked up at
            ``<img_dir>/<labels_text>/<image_id>.jpg``.
        ds: Tokenized dataset whose rows provide ``input_ids`` and
            ``attention_mask``. It is indexed with the same position as ``df``,
            so both must hold the same rows in the same order.
    """

    def __init__(self, df, img_dir, ds):
        self.df = df
        self.img_dir = img_dir
        self.ds = ds

    def __len__(self):
        """Return the number of samples (rows of the DataFrame)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, caption, text_input_ids, text_attention_mask, label)`` for row ``idx``."""
        label_text = self.df['labels_text'].iloc[idx]
        img_path = os.path.join(self.img_dir, label_text, self.df['image_id'].iloc[idx])
        img_path = img_path + '.jpg'

        # Image.open is lazy and keeps the file open; convert() forces the pixel
        # data to be read and returns an independent RGB image, so the file
        # handle can be released right away instead of lingering per sample.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        caption = self.df[TEXT_CLIP].iloc[idx]

        # Fetch the tokenized row once instead of once per field.
        tokenized = self.ds[idx]
        text_input_ids = tokenized['input_ids']
        text_attention_mask = tokenized['attention_mask']

        label = self.df['labels'].iloc[idx]
        return image, caption, text_input_ids, text_attention_mask, label
    
# One dataset per split: DataFrame, image root folder and tokenized split.
train_dataset = CustomMultimodalDataset(
    df=df_train, img_dir=TRAIN_IMAGES_PATH, ds=dataset['train'])
validation_dataset = CustomMultimodalDataset(
    df=df_validation, img_dir=VALIDATION_IMAGES_PATH, ds=dataset['validation'])
test_dataset = CustomMultimodalDataset(
    df=df_test, img_dir=TEST_IMAGES_PATH, ds=dataset['test'])

In [15]:
class MultimodalCollator:
    """Collate samples into one batch for the CLIP + text-transformer model.

    Each sample is ``(image, caption, text_input_ids, text_attention_mask, label)``.
    Images are optionally augmented (training split only), images and captions
    are encoded by the CLIP processor, and the transformer token ids / masks
    are right-padded to the longest sequence in the batch.

    Args:
        processor: CLIP processor used for the images and captions.
        augment_mode: ``'hard'`` (RandAugment) or ``'soft'`` (light perspective +
            horizontal flip); any other value disables augmentation.
        split: Augmentation is applied only when this equals ``'train'``.
        max_length: Truncation length for the captions (77 for CLIP, 40 for ViLT).
        text_pad_token_id: Id used to pad the transformer ``input_ids``. When
            ``None`` (default) the module-level ``tokenizer``'s pad token id is
            used, rather than a hard-coded constant.

    Raises:
        ValueError: If no pad id is given and the tokenizer does not define one.
    """

    HARD_IMG_AUGMENTER = T.RandAugment(num_ops=6, magnitude=9)
    SOFT_IMG_AUGMENTER = Compose([T.RandomPerspective(.1, p=.5),
                                  T.RandomHorizontalFlip(p=.5),
                                ])

    def __init__(self, processor=clip_processor, augment_mode='hard', split='train', max_length=77,
                 text_pad_token_id=None):
        self.processor = processor
        self.split = split
        self.max_length = max_length
        self.augment_mode = augment_mode

        # The previous hard-coded pad id (1) follows the RoBERTa vocabulary and is
        # not guaranteed to be the pad token of the tokenizer in use, so it is
        # taken from the tokenizer itself.
        if text_pad_token_id is None:
            text_pad_token_id = tokenizer.pad_token_id
        if text_pad_token_id is None:
            raise ValueError("tokenizer has no pad token; pass text_pad_token_id explicitly")
        self.text_pad_token_id = text_pad_token_id

    def __call__(self, batch):
        """Build the model inputs for a list of samples.

        Returns:
            The processor encoding, extended with ``text_input_ids``,
            ``text_attention_masks`` and ``labels`` tensors.
        """
        images, captions, text_input_ids, text_attention_masks, labels = list(zip(*batch))

        # Augment only during training.
        if self.split == 'train' and self.augment_mode == 'hard':
            images = [self.HARD_IMG_AUGMENTER(img) for img in images]
        elif self.split == 'train' and self.augment_mode == 'soft':
            images = [self.SOFT_IMG_AUGMENTER(img) for img in images]

        # Right-pad the transformer inputs to the longest sequence in the batch;
        # padded positions get attention mask 0.
        longest = max(len(ids) for ids in text_input_ids)
        padded_text_input_ids = [ids + [self.text_pad_token_id] * (longest - len(ids))
                                 for ids in text_input_ids]
        padded_text_attention_masks = [masks + [0] * (longest - len(masks))
                                       for masks in text_attention_masks]

        encoding = self.processor(images=images,
                                  text=list(captions),
                                  padding=True,
                                  max_length=self.max_length,
                                  truncation=True,
                                  return_tensors='pt')
        encoding['text_input_ids'] = torch.tensor(padded_text_input_ids)
        encoding['text_attention_masks'] = torch.tensor(padded_text_attention_masks)
        encoding['labels'] = torch.tensor(labels)
        return encoding

In [16]:
BATCH_SIZE = 8

# One collator per split; only the 'train' collator applies image augmentation.
collator_train, collator_val, collator_test = (
    MultimodalCollator(split=split_name) for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles its samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collator_train,
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collator_val,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collator_test,
)

In [18]:
class MultimodalClassifier(pl.LightningModule):
    """Multi-output classifier combining a frozen CLIP model with a text transformer.

    Three sets of logits are produced for every batch:

    * ``output_1`` - transformer-only head,
    * ``output_2`` - fusion head over the concatenated transformer and CLIP embeddings,
    * ``output_3`` - CLIP-only head.

    The training loss is the sum of the three cross-entropy losses; accuracy
    and macro-F1 are reported for the fusion head only.

    Args:
        clip_model: CLIP model exposing ``get_text_features`` / ``get_image_features``.
            Its parameters are frozen.
        text_transformer: Encoder whose output provides ``last_hidden_state``.
        lr_transformer: Learning rate for the text transformer parameters.
        lr_heads: Learning rate for all projection and output layers.
    """

    def __init__(self, clip_model=clip_model, text_transformer= pretrained_model,  lr_transformer=2e-5, lr_heads=2e-3):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.lr_transformer = lr_transformer
        self.lr_heads = lr_heads

        # Training metrics are averaged over the per-batch values (mean of means).
        self.train_loss = []
        self.train_accs = []
        self.train_f1s = []

        # Validation metrics are computed once over all predictions of the epoch
        # for greater precision.
        self.val_loss = []
        self.all_val_y_true = []
        self.all_val_y_pred = []

        self.text_transformer = text_transformer
        self.clip_model = clip_model
        # Freeze CLIP model parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # CLIP branch: concatenated text+image features -> 512 -> classes.
        self.clip_fc1 = nn.Linear(1024, 512)
        self.clip_activation1 = nn.GELU()
        # Input size must match the 512-d output of clip_fc1 (was 256).
        self.output_clip = nn.Linear(512, NUM_CLASSES)

        # Transformer branch: first-token hidden state -> 512 -> classes.
        self.transformer_fc1 = nn.Linear(768, 512)
        self.transformer_activation1 = nn.GELU()
        self.output_transformer = nn.Linear(512, NUM_CLASSES)

        # Idea to try: concatenate the raw CLIP output (512) with a projection of
        # the transformer output (768 projected to 512).

        # Fusion branch: [transformer_embed ; clip_embed] (1024) -> 512 -> classes.
        self.fusion_fc1 = nn.Linear(1024, 512)
        self.fusion_activation1 = nn.GELU()
        self.fusion_output = nn.Linear(512, NUM_CLASSES)

    def compute_outputs(self, input_ids, attention_mask, pixel_values, text_input_ids, text_attention_masks):
        """Run both backbones and return ``(transformer_logits, fusion_logits, clip_logits)``.

        ``input_ids`` / ``attention_mask`` / ``pixel_values`` feed CLIP, while
        ``text_input_ids`` / ``text_attention_masks`` feed the text transformer.
        """
        # CLIP embedding: text and image features concatenated on the last axis.
        out_text = self.clip_model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        out_image = self.clip_model.get_image_features(pixel_values=pixel_values)
        combined_embed = torch.cat((out_text, out_image), dim=-1)
        clip_embed = self.clip_activation1(self.clip_fc1(combined_embed))
        output_3 = self.output_clip(clip_embed)

        # Transformer embedding: hidden state of the first token ([CLS] for DeBERTa).
        outputs = self.text_transformer(input_ids=text_input_ids, attention_mask=text_attention_masks)
        cls_state = outputs['last_hidden_state'][:, 0]
        transformer_embed = self.transformer_activation1(self.transformer_fc1(cls_state))
        # Use the transformer head here (the CLIP head was applied by mistake before).
        output_1 = self.output_transformer(transformer_embed)

        # Fuse both embeddings and classify.
        fusion_embed = torch.cat((transformer_embed, clip_embed), dim=-1)
        fused = self.fusion_activation1(self.fusion_fc1(fusion_embed))
        output_2 = self.fusion_output(fused)

        return output_1, output_2, output_3

    def forward(self, batch):
        """Unpack a collated batch and return the three logit tensors."""
        return self.compute_outputs(
            batch['input_ids'],
            batch['attention_mask'],
            batch['pixel_values'],
            batch['text_input_ids'],
            batch['text_attention_masks'],
        )

    def _shared_step(self, batch):
        """Return ``(total_loss, fusion_predictions, labels)`` for one batch."""
        labels = batch['labels']
        output_1, output_2, output_3 = self.forward(batch)
        # The optimized loss is the sum over the three heads.
        total_loss = (self.criterion(output_1, labels)
                      + self.criterion(output_2, labels)
                      + self.criterion(output_3, labels))
        # Predictions come from output 2 (fusion) only.
        preds_2 = torch.argmax(output_2, dim=-1)
        return total_loss, preds_2, labels

    def training_step(self, batch, batch_idx):
        """Compute the summed loss and record per-batch training metrics."""
        total_loss, preds_2, labels = self._shared_step(batch)

        y_true, y_pred = labels.tolist(), preds_2.tolist()
        acc = accuracy_score(y_true=y_true, y_pred=y_pred)
        f1 = f1_score(y_true=y_true, y_pred=y_pred, average='macro')

        # Store a detached copy so the autograd graph of every batch is not kept
        # alive until the end of the epoch.
        self.train_loss.append(total_loss.detach())
        self.train_accs.append(acc)
        self.train_f1s.append(f1)

        return total_loss

    def on_train_epoch_end(self):
        """Log the epoch means of the per-batch training metrics and reset them."""
        mean_loss = sum(self.train_loss) / len(self.train_loss)
        mean_acc = sum(self.train_accs) / len(self.train_accs)
        mean_f1 = sum(self.train_f1s) / len(self.train_f1s)

        self.log("train_loss", mean_loss)
        self.log("train_acc", mean_acc)
        self.log("train_f1", mean_f1)

        self.train_loss = []
        self.train_accs = []
        self.train_f1s = []

    def validation_step(self, batch, batch_idx):
        """Compute the summed loss and accumulate labels/predictions for the epoch."""
        total_loss, preds_2, labels = self._shared_step(batch)

        # Accumulate into the *validation* buffers (the training lists were used
        # before, leaving these empty and polluting the training metrics).
        self.val_loss.append(total_loss.detach())
        self.all_val_y_true.extend(labels.tolist())
        self.all_val_y_pred.extend(preds_2.tolist())

        return total_loss

    def on_validation_epoch_end(self):
        """Log mean loss plus accuracy / macro-F1 over all validation predictions, then reset."""
        mean_loss = sum(self.val_loss) / len(self.val_loss)

        acc = accuracy_score(y_true=self.all_val_y_true, y_pred=self.all_val_y_pred)
        f1 = f1_score(y_true=self.all_val_y_true, y_pred=self.all_val_y_pred, average='macro')

        self.log("val_loss", mean_loss)
        self.log("val_acc", acc)
        self.log("val_f1", f1)

        self.val_loss = []
        self.all_val_y_true = []
        self.all_val_y_pred = []

    def configure_optimizers(self):
        """AdamW with a lower learning rate for the transformer, plus ReduceLROnPlateau on ``val_loss``."""
        optimizer = optim.AdamW([
            {'params': self.text_transformer.parameters(), 'lr': self.lr_transformer},
            {'params': self.clip_fc1.parameters()},
            # The two single-modality heads contribute to the loss, so they must
            # be optimized as well (they were missing and stayed at random init).
            {'params': self.output_clip.parameters()},
            {'params': self.transformer_fc1.parameters()},
            {'params': self.output_transformer.parameters()},
            {'params': self.fusion_fc1.parameters()},
            {'params': self.fusion_output.parameters()},
        ], lr=self.lr_heads, amsgrad=True, weight_decay=0.01)

        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.1, patience=5)
        return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val_loss",
                    },
                }

In [19]:
experiment_name = 'MultiOutput'

# Callbacks: keep the checkpoint with the best validation macro-F1, track the
# learning rate once per epoch, and stop after 15 epochs without improvement.
checkpoint_callback = ModelCheckpoint(
    dirpath='../../model_ckpts/Multimodal/CLIP+Transformer',
    filename=experiment_name,
    monitor='val_f1',
    mode='max',
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
early_stopping = EarlyStopping(monitor='val_f1', patience=15, mode='max')

# CSV logger for the run metrics.
logger = CSVLogger(save_dir="../../logs/Multimodal/CLIP+Transformer", name=experiment_name)

my_model = MultimodalClassifier()
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0],
    deterministic=True,
    max_epochs=60,
    logger=logger,
    precision='16-mixed',
    accumulate_grad_batches=2,
    callbacks=[lr_monitor, early_stopping, checkpoint_callback],
)
trainer.fit(
    model=my_model,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loader,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                    | Type             | Params
-------------------------------------------------------------
0 | criterion               | CrossEntropyLoss | 0     
1 | text_transformer        | DebertaModel     | 138 M 
2 | clip_model              | CLIPModel        | 151 M 
3 | clip_fc1                | Linear           | 524 K 
4 | clip_activation1        | GELU             | 

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   7%|6         | 421/6023 [08:18<1:50:34,  1.18s/it, v_num=0]