In [None]:
# Code Overview
# Extracts text, vision, and multmodal embeddings from the ViLT and VisualBERT models

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# %% Importing Libraries
import os
import re
import sys
import argparse
import time
import datetime
import random
from pathlib import Path
from PIL import Image

import pandas as pd
import numpy as np

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score, f1_score, classification_report

import torch
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning.loggers import WandbLogger

import lightning as L
import lightning.pytorch as pl
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

from transformers import ViltProcessor, ViltModel
from transformers import VisualBertModel, VisualBertConfig, BertTokenizer, ViTModel, ViTFeatureExtractor

# Custom library
sys.path.append('../process/')
from utilities import map_images_with_text, augment_image_training_data
from loadData import ViLTMultimodalDataset, MultimodalDataset

sys.path.append('../architectures/')
from viltLayer import ViLTClassifier
from visualBERTLayer import VisualBERTClassifier

import warnings
warnings.filterwarnings('ignore')

# Suppress TorchDynamo errors and fall back to eager execution
import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [2]:
from tqdm import tqdm

In [3]:
import argparse
import os

class Args:
    """Encapsulates arguments into an object for script and notebook compatibility."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def parse_arguments():
    parser = argparse.ArgumentParser(description="Trains a multimodal classifier (VisualBERT or ViLT) for Multimodal Authorship tasks on Backpage advertisements.")

    # Common arguments
    parser.add_argument('--logged_entry_name', type=str, default="multimodal-latent-fusion-seed:1111", help="Logged entry name visible on weights and biases")
    parser.add_argument('--data_dir', type=str, default='/workspace/persistent/HTClipper/data/processed', help="Data directory")
    parser.add_argument('--city', type=str, default='chicago', help="Demography of data, can be only between chicago, atlanta, houston, dallas, detroit, ny, sf or all")
    parser.add_argument('--batch_size', type=int, default=32, help="Batch Size")
    parser.add_argument('--nb_epochs', type=int, default=40, help="Number of Epochs")
    parser.add_argument('--patience', type=int, default=3, help="Patience for Early Stopping")
    parser.add_argument('--seed', type=int, default=1111, help='Random seed value')
    parser.add_argument('--warmup_steps', type=int, default=0, help="Warmup proportion")
    parser.add_argument('--grad_steps', type=int, default=1, help="Gradient accumulating step")
    parser.add_argument('--learning_rate', type=float, default=6e-4, help="Learning rate")
    parser.add_argument('--train_data_percentage', type=float, default=1.0, help="Percentage of training data to be used")
    parser.add_argument('--adam_epsilon', type=float, default=1e-6, help="Epsilon value for Adam optimizer")
    parser.add_argument('--min_delta_change', type=float, default=0.01, help="Minimum change in delta in validation loss for Early Stopping")
    parser.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay")
    parser.add_argument('--augment_data', type=bool, default=False, help='Enables data augmentation')
    parser.add_argument('--nb_augmented_samples', type=int, default=1, help='Number of augmented samples to be generated')

    # Model-specific arguments
    parser.add_argument('--model_type', type=str, choices=['visualbert', 'vilt'], default='visualbert', help="Choose the model type: visualbert or vilt")
    parser.add_argument('--save_dir', type=str, default=None, help="Directory for models to be saved")
    parser.add_argument('--model_dir_name', type=str, default=None, help="Save the model with the folder name as mentioned.")
    
    args = parser.parse_args()

    # Dynamically set the default save_dir based on model type
    if args.save_dir is None:
        args.save_dir = os.path.join(
            os.getcwd(),
            f"/workspace/persistent/HTClipper/models/grouped-and-masked/multimodal-baselines/classification/{args.model_type}/"
        )

    return args

# Use this in a script
# args = parse_arguments()

# Use this in a Jupyter Notebook
args_dict = {
    "logged_entry_name": "multimodal-latent-fusion-seed:1111",
    "data_dir": "/workspace/persistent/HTClipper/data/processed",
    "city": "south",
    "batch_size": 32,
    "nb_epochs": 40,
    "patience": 3,
    "seed": 1111,
    "warmup_steps": 0,
    "grad_steps": 1,
    "learning_rate": 6e-4,
    "train_data_percentage": 1.0,
    "adam_epsilon": 1e-6,
    "min_delta_change": 0.01,
    "weight_decay": 0.01,
    "augment_data": False,
    "nb_augmented_samples": 1,
    "model_type": "visualbert",  # Change to 'vilt' if needed
    "save_dir": "/workspace/persistent/HTClipper/models/grouped-and-masked/multimodal-baselines/classification/visualbert/",
    "model_dir_name": None,
}

args = Args(**args_dict)  # Convert the dictionary to an object

In [4]:
# Set matrix multiplication precision
# This setting offers a balance between precision and performance. It’s typically a good starting point for mixed precision training
#  with FP16.
torch.set_float32_matmul_precision("high")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

assert args.city in ["chicago", "atlanta", "dallas", "detroit", "houston", "sf", "ny", "all", "midwest", "northeast", "south", "west"]

# Creating directories
if args.model_dir_name == None:
    directory = os.path.join(args.save_dir, args.city, "seed:" + str(args.seed), "lr-" + str(args.learning_rate))
else:
    directory = os.path.join(args.save_dir, args.city, "seed:" + str(args.seed), "lr-" + str(args.learning_rate), args.model_dir_name)
Path(directory).mkdir(parents=True, exist_ok=True)
Path(args.save_dir).mkdir(parents=True, exist_ok=True)

# %% Load your DataFrame
data_dir = os.path.join(args.data_dir, args.city + ".csv")
args.image_dir = os.path.join("/workspace/persistent/HTClipper/data/IMAGES", args.city, "image", "image")
df = pd.read_csv(data_dir)

# Encode the labels
label_encoder = LabelEncoder()
df['VENDOR'] = label_encoder.fit_transform(df['VENDOR'])

# Identify and keep vendors with at least 2 instances
class_counts = df['VENDOR'].value_counts()
valid_classes = class_counts[class_counts >= 2].index
df_filtered = df[df['VENDOR'].isin(valid_classes)]

# Re-encode labels after filtering
df_filtered['VENDOR'] = label_encoder.fit_transform(df_filtered['VENDOR'])

df_filtered = df_filtered[["TEXT", "IMAGES", "VENDOR"]].drop_duplicates()

# Split the data into train, validation, and test sets without mapping images to text yet
train_df, test_df = train_test_split(
    df_filtered, test_size=0.2, random_state=args.seed, stratify=df_filtered['VENDOR'], shuffle=True)

# Adjust the validation split size based on the number of unique vendors
min_val_size = len(df_filtered['VENDOR'].unique()) / len(train_df)
val_size = max(0.05, min_val_size)  # Choose a larger value if needed, e.g., 0.05 or 5%

train_df, val_df = train_test_split(
    train_df, test_size=val_size, random_state=args.seed, stratify=train_df['VENDOR']
)

# Apply map_images_with_text separately to avoid overlap of text-image pairs across splits
train_df = map_images_with_text(train_df).drop_duplicates()
val_df = map_images_with_text(val_df).drop_duplicates()
test_df = map_images_with_text(test_df).drop_duplicates()

# Replacing all the numbers in the training dataset with the letter "N"
train_df['TEXT'] = train_df['TEXT'].apply(lambda x: re.sub(r'\d', 'N', str(x)))

In [32]:
"""
Python version: 3.10
Description: Contains the architectural implementation of visualBERT based multimodal classifier trained with concatenation based
            fusion techniques.
Reference: https://arxiv.org/pdf/1908.03557 
"""

# %% Importing libraries
from sklearn.metrics import balanced_accuracy_score, f1_score, classification_report

import torch
from torch import nn
import torch.nn.functional as F

import lightning.pytorch as pl

from transformers import ViTModel, get_linear_schedule_with_warmup

class VisualBERTClassifier(pl.LightningModule):
    def __init__(self, visualbert_model, vit_model, learning_rate, num_classes, weight_decay, eps, warmup_steps, num_training_steps, max_seq_length=512, 
                max_visual_tokens=197):
        super(VisualBERTClassifier, self).__init__()
        self.visualbert_model = visualbert_model
        self.vit_model = vit_model
        self.classifier = nn.Linear(self.visualbert_model.config.hidden_size, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        # Max sequence length for transformers model
        self.max_seq_length = max_seq_length
        # Since we using the ViT-patch 16, max_visual_tokens = (Image Height/Patch Size) x (Image Height/Patch Size)
        # = (224/16) x (224/16) = 196
        self.max_visual_tokens = max_visual_tokens  # This is standard for ViT with 224x224 images
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps
        self.test_outputs = []

    def forward(self, input_ids, attention_mask, pixel_values):
        # Extract visual embeddings using ViT
        vit_outputs = self.vit_model(pixel_values)
        visual_embeds = vit_outputs.last_hidden_state

        # Ensure the visual_embeds shape matches expected shape by VisualBERT
        batch_size, num_visual_tokens, hidden_dim = visual_embeds.shape

        # Trim or pad visual embeddings
        if num_visual_tokens > self.max_visual_tokens:  # Trim if the size exceeds expected
            visual_embeds = visual_embeds[:, :self.max_visual_tokens, :]
        elif num_visual_tokens < self.max_visual_tokens:  # Pad if the size is less than expected
            padding = torch.zeros((batch_size, self.max_visual_tokens - num_visual_tokens, hidden_dim), device=visual_embeds.device)
            visual_embeds = torch.cat((visual_embeds, padding), dim=1)

        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long).to(input_ids.device)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float).to(input_ids.device)

        # Adjust input_ids and attention_mask to ensure the total length is within the limit
        total_length = input_ids.size(1) + self.max_visual_tokens
        if total_length > self.max_seq_length:
            excess_length = total_length - self.max_seq_length
            input_ids = input_ids[:, :-excess_length]
            attention_mask = attention_mask[:, :-excess_length]

        # Concatenate text and visual embeddings
        text_embeds = self.visualbert_model.embeddings.word_embeddings(input_ids)
        token_type_embeddings = self.visualbert_model.embeddings.token_type_embeddings(
            torch.cat((torch.zeros_like(input_ids), visual_token_type_ids), dim=1))
        position_ids = torch.arange(text_embeds.size(1) + visual_embeds.size(1), dtype=torch.long, device=input_ids.device)
        position_embeddings = self.visualbert_model.embeddings.position_embeddings(position_ids)

        embeddings = torch.cat((text_embeds, visual_embeds), dim=1)
        embeddings += token_type_embeddings + position_embeddings
        embeddings = self.visualbert_model.embeddings.LayerNorm(embeddings)
        embeddings = self.visualbert_model.embeddings.dropout(embeddings)

        # Concatenate attention masks
        combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=1)
        combined_attention_mask = combined_attention_mask.unsqueeze(1).unsqueeze(2)

        encoder_outputs = self.visualbert_model.encoder(
            embeddings,
            attention_mask=combined_attention_mask,
            head_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True,
        )

        pooled_output = encoder_outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled_output)
        return logits

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['label']
        logits = self(input_ids, attention_mask, pixel_values)
        loss = self.criterion(logits, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['label']
        logits = self(input_ids, attention_mask, pixel_values)
        loss = self.criterion(logits, labels)
        preds = torch.argmax(logits, dim=1)

        acc = balanced_accuracy_score(labels.cpu(), preds.cpu())
        f1_weighted = f1_score(labels.cpu(), preds.cpu(), average='weighted')
        f1_micro = f1_score(labels.cpu(), preds.cpu(), average='micro')
        f1_macro = f1_score(labels.cpu(), preds.cpu(), average='macro')

        self.log('val_loss', loss)
        self.log('val_acc', acc)
        self.log('val_f1_weighted', f1_weighted)
        self.log('val_f1_micro', f1_micro)
        self.log('val_f1_macro', f1_macro)

        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['label']
        logits = self(input_ids, attention_mask, pixel_values)
        preds = torch.argmax(logits, dim=1)

        acc = balanced_accuracy_score(labels.cpu(), preds.cpu())
        f1_weighted = f1_score(labels.cpu(), preds.cpu(), average='weighted')
        f1_micro = f1_score(labels.cpu(), preds.cpu(), average='micro')
        f1_macro = f1_score(labels.cpu(), preds.cpu(), average='macro')

        self.test_outputs.append({"acc": acc, "f1_weighted": f1_weighted, "f1_micro": f1_micro, "f1_macro": f1_macro, "labels": labels.cpu(), "preds": preds.cpu()})

        return {"acc": acc, "f1_weighted": f1_weighted, "f1_micro": f1_micro, "f1_macro": f1_macro}

    def on_test_epoch_end(self):
        avg_acc = torch.tensor([x['acc'] for x in self.test_outputs]).mean()
        avg_f1_weighted = torch.tensor([x['f1_weighted'] for x in self.test_outputs]).mean()
        avg_f1_micro = torch.tensor([x['f1_micro'] for x in self.test_outputs]).mean()
        avg_f1_macro = torch.tensor([x['f1_macro'] for x in self.test_outputs]).mean()

        labels = torch.cat([x['labels'] for x in self.test_outputs])
        preds = torch.cat([x['preds'] for x in self.test_outputs])

        self.log('test_acc', avg_acc)
        self.log('test_f1_weighted', avg_f1_weighted)
        self.log('test_f1_micro', avg_f1_micro)
        self.log('test_f1_macro', avg_f1_macro)

    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay':self.weight_decay}, 
                                        {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': self.weight_decay}]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.eps)

        # We also use a scheduler that is supplied by transformers.
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]
    
    def extract_embeddings(self, input_ids=None, attention_mask=None, pixel_values=None):
        self.eval()
        with torch.no_grad():
            # Re-implement the forward method up to the point before classification to get the pooled output
            # Extract visual embeddings using ViT
            vit_outputs = self.vit_model(pixel_values)
            visual_embeds = vit_outputs.last_hidden_state

            # Ensure the visual_embeds shape matches expected shape by VisualBERT
            batch_size, num_visual_tokens, hidden_dim = visual_embeds.shape

            # Trim or pad visual embeddings
            if num_visual_tokens > self.max_visual_tokens:
                visual_embeds = visual_embeds[:, :self.max_visual_tokens, :]
            elif num_visual_tokens < self.max_visual_tokens:
                padding = torch.zeros((batch_size, self.max_visual_tokens - num_visual_tokens, hidden_dim), device=visual_embeds.device)
                visual_embeds = torch.cat((visual_embeds, padding), dim=1)

            visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long).to(input_ids.device)
            visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float).to(input_ids.device)

            # Adjust input_ids and attention_mask
            total_length = input_ids.size(1) + self.max_visual_tokens
            if total_length > self.max_seq_length:
                excess_length = total_length - self.max_seq_length
                input_ids = input_ids[:, :-excess_length]
                attention_mask = attention_mask[:, :-excess_length]

            # Concatenate text and visual embeddings
            text_embeds = self.visualbert_model.embeddings.word_embeddings(input_ids)
            token_type_embeddings = self.visualbert_model.embeddings.token_type_embeddings(
                torch.cat((torch.zeros_like(input_ids), visual_token_type_ids), dim=1))
            position_ids = torch.arange(text_embeds.size(1) + visual_embeds.size(1), dtype=torch.long, device=input_ids.device)
            position_embeddings = self.visualbert_model.embeddings.position_embeddings(position_ids)

            embeddings = torch.cat((text_embeds, visual_embeds), dim=1)
            embeddings += token_type_embeddings + position_embeddings
            embeddings = self.visualbert_model.embeddings.LayerNorm(embeddings)
            embeddings = self.visualbert_model.embeddings.dropout(embeddings)

            # Concatenate attention masks
            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=1)
            extended_attention_mask = combined_attention_mask.unsqueeze(1).unsqueeze(2)

            encoder_outputs = self.visualbert_model.encoder(
                embeddings,
                attention_mask=extended_attention_mask,
                head_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=True,
            )

            # Extract the pooled output (e.g., [CLS] token)
            pooled_output = encoder_outputs.last_hidden_state[:, 0]  # Shape: [batch_size, hidden_size]

            return pooled_output

In [33]:
"""
Python version: 3.10
Description: Contains the architectural implementation of ViLT based multimodal classifier trained with concatenation based
            fusion techniques.
Reference: https://arxiv.org/abs/2102.03334
"""

# %% Importing libraries
from sklearn.metrics import balanced_accuracy_score, f1_score, classification_report

import torch
from torch import nn
import torch.nn.functional as F

import lightning.pytorch as pl

from transformers import ViTModel, get_linear_schedule_with_warmup

# %% Model Definition for ViLT
class ViLTClassifier(pl.LightningModule):
    def __init__(self, vilt_model, learning_rate, num_classes, weight_decay, eps, warmup_steps, num_training_steps):
        super(ViLTClassifier, self).__init__()
        self.vilt_model = vilt_model
        self.classifier = nn.Linear(self.vilt_model.config.hidden_size, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps
        self.test_outputs = [] 

    def forward(self, input_ids, attention_mask, pixel_values):
        outputs = self.vilt_model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
        pooled_output = outputs.pooler_output  # ViLT provides a pooled output directly
        logits = self.classifier(pooled_output)
        return logits

    def training_step(self, batch, batch_idx):
        logits = self(batch['input_ids'], batch['attention_mask'], batch['pixel_values'])
        loss = self.criterion(logits, batch['label'])
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['label']
        logits = self(input_ids, attention_mask, pixel_values)
        loss = self.criterion(logits, labels)
        preds = torch.argmax(logits, dim=1)

        acc = balanced_accuracy_score(labels.cpu(), preds.cpu())
        f1_weighted = f1_score(labels.cpu(), preds.cpu(), average='weighted')
        f1_micro = f1_score(labels.cpu(), preds.cpu(), average='micro')
        f1_macro = f1_score(labels.cpu(), preds.cpu(), average='macro')

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_f1_weighted', f1_weighted, prog_bar=True)
        self.log('val_f1_micro', f1_micro, prog_bar=True)
        self.log('val_f1_macro', f1_macro, prog_bar=True)
        
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        pixel_values = batch['pixel_values']
        labels = batch['label']
        logits = self(input_ids, attention_mask, pixel_values)
        preds = torch.argmax(logits, dim=1)

        acc = balanced_accuracy_score(labels.cpu(), preds.cpu())
        f1_weighted = f1_score(labels.cpu(), preds.cpu(), average='weighted')
        f1_micro = f1_score(labels.cpu(), preds.cpu(), average='micro')
        f1_macro = f1_score(labels.cpu(), preds.cpu(), average='macro')

        self.test_outputs.append({"acc": acc, "f1_weighted": f1_weighted, "f1_micro": f1_micro, "f1_macro": f1_macro, "labels": labels.cpu(), "preds": preds.cpu()})

        return {"acc": acc, "f1_weighted": f1_weighted, "f1_micro": f1_micro, "f1_macro": f1_macro}

    def on_test_epoch_end(self):
        avg_acc = torch.tensor([x['acc'] for x in self.test_outputs]).mean()
        avg_f1_weighted = torch.tensor([x['f1_weighted'] for x in self.test_outputs]).mean()
        avg_f1_micro = torch.tensor([x['f1_micro'] for x in self.test_outputs]).mean()
        avg_f1_macro = torch.tensor([x['f1_macro'] for x in self.test_outputs]).mean()

        labels = torch.cat([x['labels'] for x in self.test_outputs])
        preds = torch.cat([x['preds'] for x in self.test_outputs])

        self.log('test_acc', avg_acc)
        self.log('test_f1_weighted', avg_f1_weighted)
        self.log('test_f1_micro', avg_f1_micro)
        self.log('test_f1_macro', avg_f1_macro)

    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay':self.weight_decay}, 
                                        {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': self.weight_decay}]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.eps)

        # We also use a scheduler that is supplied by transformers.
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]
    
    def extract_embeddings(self, input_ids=None, attention_mask=None, pixel_values=None):
        self.eval()
        with torch.no_grad():
            # Use the standard ViLT forward method with both text and image inputs
            outputs = self.vilt_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
            )

            # Extract the pooled output, which is the multimodal embedding before classification
            pooled_output = outputs.pooler_output  # Shape: [batch_size, hidden_size]

            return pooled_output

In [34]:
class MultimodalDataset(Dataset):
    def __init__(self, dataframe, text_tokenizer, image_processor, label_encoder, image_dir, augment=False, image_size=(224, 224)):
        self.dataframe = dataframe
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.label_encoder = label_encoder
        self.augment = augment
        self.augmentation_pipelines = get_augmentation_pipeline() if augment else None
        self.image_size = image_size
        self.image_dir = image_dir
        
        # Remove rows with missing image files
        self.dataframe = self.dataframe[self.dataframe['IMAGES'].apply(lambda x: os.path.exists(os.path.join(self.image_dir, x)))]
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        text = row['TEXT']
        image_path = row['IMAGES']
        label = row['VENDOR']
        
        text_inputs = self.text_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        
        try:
            image = Image.open(os.path.join(self.image_dir, image_path)).convert('RGB')
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a default image or handle the error as needed
            image = Image.new('RGB', self.image_size, (255, 255, 255))

        # Resize the image to a consistent size
        image = image.resize(self.image_size)
        
        image_array = np.array(image)

        # Ensure the image array has the correct dimensions (H, W, C)
        if image_array.shape[-1] != 3:
            image_array = np.stack((image_array,) * 3, axis=-1)
        
        if self.augment and 'AUGMENT' in row and row['AUGMENT'] >= 0:
            augment_idx = row['AUGMENT']
            augmented = self.augmentation_pipelines[augment_idx](image=image_array)
            image_array = augmented['image']
        
        # Ensure image dimensions are (C, H, W)
        if image_array.shape[-1] == 3:
            image_array = np.transpose(image_array, (2, 0, 1))
        
        image_tensor = torch.tensor(image_array, dtype=torch.float)
        image_tensor = self.image_processor(images=image_tensor, return_tensors="pt")['pixel_values'].squeeze(0)

        
        input_ids = text_inputs['input_ids'].squeeze(0)
        attention_mask = text_inputs['attention_mask'].squeeze(0)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': image_tensor,
            'label': torch.tensor(label, dtype=torch.long)
        }

In [35]:
def load_trained_model(model_name):

    if model_name == "visualBERT":
        text_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        image_processor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224', size=224, do_resize=True)
        vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224')

        visualbert_config = VisualBertConfig.from_pretrained('uclanlp/visualbert-vcr')
        visualbert_model = VisualBertModel.from_pretrained('uclanlp/visualbert-vcr', config=visualbert_config)

        # Create the datasets and dataloaders
        train_dataset = MultimodalDataset(train_df, text_tokenizer, image_processor, label_encoder, image_dir=args.image_dir, augment=args.augment_data)
        val_dataset = MultimodalDataset(val_df, text_tokenizer, image_processor, label_encoder, image_dir=args.image_dir, augment=False)
        test_dataset = MultimodalDataset(test_df, text_tokenizer, image_processor, label_encoder, image_dir=args.image_dir, augment=False)

        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
        val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
        test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

        num_training_steps = args.nb_epochs * len(train_dataloader)
        # Setting the warmup steps to 1/10th the size of training data
        warmup_steps = int(0.1 * num_training_steps)

        # %% Loading the model
        model = VisualBERTClassifier(visualbert_model=visualbert_model, vit_model=vit_model, learning_rate=args.learning_rate, 
                                    num_classes=len(label_encoder.classes_), weight_decay=args.weight_decay, eps=args.adam_epsilon, 
                                    warmup_steps=warmup_steps, num_training_steps=num_training_steps)

        # Load the checkpoint
        checkpoint = torch.load("/workspace/persistent/HTClipper/models/grouped-and-masked/multimodal-baselines/classification/visualBERT/south/seed:1111/lr-0.0001/final_model.ckpt")

        # Load the state dictionary into the model
        model.load_state_dict(checkpoint['state_dict'])

    else:

        # %% Load the processor and model for ViLT
        vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        vilt_model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

        # Define train, validation, and test datasets and loaders
        train_dataset = ViLTMultimodalDataset(train_df, vilt_processor, label_encoder, image_dir=args.image_dir, augment=args.augment_data)
        val_dataset = ViLTMultimodalDataset(val_df, vilt_processor, label_encoder, image_dir=args.image_dir, augment=False)
        test_dataset = ViLTMultimodalDataset(test_df, vilt_processor, label_encoder, image_dir=args.image_dir, augment=False)

        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
        val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
        test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

        num_training_steps = args.nb_epochs * len(train_dataloader)
        # Setting the warmup steps to 1/10th the size of training data
        warmup_steps = int(0.1 * num_training_steps)

        model = ViLTClassifier(vilt_model=vilt_model, learning_rate=args.learning_rate, num_classes=len(label_encoder.classes_), weight_decay=args.weight_decay,
                            eps=args.adam_epsilon, warmup_steps=warmup_steps, num_training_steps=num_training_steps)

        # Load the checkpoint
        checkpoint = torch.load("/workspace/persistent/HTClipper/models/grouped-and-masked/multimodal-baselines/classification/vilt/south/seed:1111/lr-0.0001/final_model.ckpt")

        # Load the state dictionary into the model
        model.load_state_dict(checkpoint['state_dict'])
        
    return model

In [36]:
def process_dataset_for_ClassifierModel(region_name, data_dir, model, model_name, filter_by="vendor", batch_size=32):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    assert filter_by in ["vendor", "id"]
    
    image_dir = os.path.join("/workspace/persistent/HTClipper/data/IMAGES", region_name, "image", "image")

    # Load the dataset
    df = pd.read_csv(os.path.join(data_dir, f"{region_name}.csv"))
    df['region'] = region_name

    # Encode the labels
    label_encoder = LabelEncoder()
    df['VENDOR'] = label_encoder.fit_transform(df['VENDOR'])

    # Identify and keep vendors with at least 2 instances
    class_counts = df['VENDOR'].value_counts()
    valid_classes = class_counts[class_counts >= 2].index
    df_filtered = df[df['VENDOR'].isin(valid_classes)]

    # Re-encode labels after filtering
    df_filtered['VENDOR'] = label_encoder.fit_transform(df_filtered['VENDOR'])

    df_filtered = df_filtered[["TEXT", "IMAGES", "VENDOR", "region"]].drop_duplicates()

    # Dynamically adjust test_size based on the number of classes
    min_test_size = len(df_filtered['VENDOR'].unique()) / len(df_filtered)
    test_size = max(0.2, min_test_size)  # Ensure the test size is at least 20% or large enough to include all classes

    train_df, test_df = train_test_split(
        df_filtered, test_size=test_size, random_state=1111, stratify=df_filtered['VENDOR'], shuffle=True
    )

    # Apply map_images_with_text_fn to avoid overlap of text-image pairs across splits
    train_df = map_images_with_text(train_df).drop_duplicates()
    test_df = map_images_with_text(test_df).drop_duplicates()
    
    if model_name == "visualBERT":
        text_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        image_processor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224', size=224, do_resize=True)
        vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # Create the datasets and dataloaders
        train_dataset = MultimodalDataset(train_df, text_tokenizer, image_processor, label_encoder, image_dir=image_dir, augment=False)
        test_dataset = MultimodalDataset(test_df, text_tokenizer, image_processor, label_encoder, image_dir=image_dir, augment=False)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    else:

        # Load the processor and model for ViLT
        vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        vilt_model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

        # Define train and test datasets and loaders
        train_dataset = ViLTMultimodalDataset(train_df, vilt_processor, label_encoder, image_dir=image_dir, augment=False)
        test_dataset = ViLTMultimodalDataset(test_df, vilt_processor, label_encoder, image_dir=image_dir, augment=False)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Helper function to fetch embeddings
    def get_embeddings(dataloader):
        multimodal_embeddings = []
        labels = []

        model.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            for batch in tqdm(dataloader, desc='Fetching Embeddings', leave=False):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                pixel_values = batch['pixel_values'].to(device)
                batch_labels = batch['label']

                # Extract multimodal embeddings
                embeddings = model.extract_embeddings(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values
                )
                multimodal_embeddings.append(embeddings.cpu().numpy())

                # Append labels
                labels.extend(batch_labels.cpu().numpy())

        # Concatenate embeddings across all batches
        multimodal_embeddings = np.concatenate(multimodal_embeddings)
        labels = np.array(labels)
        return multimodal_embeddings, labels

    # Get embeddings for train and test sets
    train_multimodal_embeddings, train_labels = get_embeddings(train_dataloader)
    test_multimodal_embeddings, test_labels = get_embeddings(test_dataloader)

    # Create output directory
    output_dir = os.path.join(
        "/workspace/persistent/HTClipper/models/pickled/embeddings/grouped-and-masked/multimodal_baselines/E2E", model_name)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Save embeddings and labels
    np.save(os.path.join(output_dir, f'train_multimodal_embeddings_{region_name}_{filter_by}.npy'), train_multimodal_embeddings)
    np.save(os.path.join(output_dir, f'train_labels_{region_name}_{filter_by}.npy'), train_labels)

    np.save(os.path.join(output_dir, f'test_multimodal_embeddings_{region_name}_{filter_by}.npy'), test_multimodal_embeddings)
    np.save(os.path.join(output_dir, f'test_labels_{region_name}_{filter_by}.npy'), test_labels)

In [37]:
for model_name in ["visualBERT", "ViLT"]:
    
    model = None
    print(f"model_name:{model_name}")
    model = load_trained_model(model_name)
    model.eval()

    for region in ["south", "midwest", "west", "northeast"]:
        print(f"-----------------------------------------------{region}-------------------------------------------------------")
        process_dataset_for_ClassifierModel(region_name=region, data_dir=args.data_dir, model=model.to(device), model_name=model_name, filter_by="vendor", batch_size=32)

model_name:visualBERT


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


-----------------------------------------------south-------------------------------------------------------


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                        

-----------------------------------------------midwest-------------------------------------------------------


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                      

-----------------------------------------------west-------------------------------------------------------


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                      

-----------------------------------------------northeast-------------------------------------------------------


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                      

model_name:ViLT
-----------------------------------------------south-------------------------------------------------------


                                                                        

-----------------------------------------------midwest-------------------------------------------------------


                                                                      

-----------------------------------------------west-------------------------------------------------------


                                                                      

-----------------------------------------------northeast-------------------------------------------------------


                                                                      