In [None]:
# Code Overview
# Extracts text, vision, and multimodal embeddings from CLIP, CLIP-ITM, and BLIP2 models

In [None]:
"""
Python version: 3.10
Description: Performs Multimodal Authorship Attribution using CLIP training strategy
"""

# %% Importing Libraries
import os
import re
import sys
import argparse
import time
import datetime
import random
from pathlib import Path
from PIL import Image

import pandas as pd

from tqdm import tqdm
import numpy as np

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score, f1_score, classification_report

import torch
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning.loggers import WandbLogger

import lightning as L
import lightning.pytorch as pl
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.tuner.tuning import Tuner
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from transformers import AutoTokenizer, ViTImageProcessor

# Custom library
sys.path.append('../process/')
from utilities import map_images_with_text_for_clip_model, map_images_with_text

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Ask TorchDynamo to suppress errors raised while compiling a graph rather than
# propagating them; the intent is to fall back to eager execution in that case.
import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [None]:
import os

class Args:
    """Notebook-level configuration, standing in for command-line arguments.

    Every field has a default. Any of them can be replaced by passing it as a
    keyword argument, e.g. ``Args(model_type="CLIP", seed=7)``.

    Raises:
        TypeError: If a keyword does not name an existing configuration field.
    """

    def __init__(self, **overrides):
        # Reproducibility
        self.seed = 1111

        # Locations of the processed data, the raw images and the checkpoints.
        # save_dir is already absolute, so it is not joined with the working
        # directory (os.path.join would discard the cwd component anyway).
        self.data_dir = '/workspace/persistent/HTClipper/data/processed'
        self.image_dir = "/workspace/persistent/HTClipper/data/IMAGES"
        self.save_dir = "/workspace/persistent/HTClipper/models/grouped-and-masked/multimodal-baselines/pre-training/"
        self.model_dir_name = None

        # Model / data pairing
        self.pairing_mode = "non-associated"
        self.model_type = "BLIP2"  # Can be "CLIP", "BLIP2", or "CLIPITM"
        self.nb_negatives = 1
        self.train_data_percentage = 1.0
        self.augment_data = False
        self.nb_augmented_samples = 1

        # Optimisation
        self.batch_size = 32
        self.nb_epochs = 40
        self.patience = 3
        self.warmup_steps = 0
        self.grad_steps = 1
        self.learning_rate = 6e-4
        self.adam_epsilon = 1e-6
        self.min_delta_change = 0.01
        self.weight_decay = 0.01

        # Loss
        self.loss = 'NTXENT'
        self.temp = 0.5

        # Placeholder so that 'logged_entry_name' is a recognised field below.
        self.logged_entry_name = None

        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Unknown configuration field(s): {', '.join(unknown)}")
        vars(self).update(overrides)

        # Derive the run name from the (possibly overridden) seed unless the
        # caller supplied an explicit name.
        if "logged_entry_name" not in overrides:
            self.logged_entry_name = f"multimodal-latent-fusion-seed:{self.seed}"


# Instantiate the arguments
args = Args()

In [None]:
class CLIPDataset(Dataset):
    """Text-image pairs with randomly sampled negatives for contrastive training.

    Each row of ``df`` must provide 'TEXT', 'IMAGES' (a path readable by PIL) and
    'VENDOR' (a value convertible to an integer label). A 'region' column is
    optional; when it is absent ``city`` is used as the region of every row.

    Args:
        df: DataFrame holding the pairs.
        text_tokenizer: Callable returning a mapping with 'input_ids' and
            'attention_mask' tensors of shape [1, seq_len].
        image_processor: Callable returning a mapping with 'pixel_values'.
        num_negatives: Number of negative texts and negative images per item.
        pairing_mode: 'associated' (each negative text and image come from the
            same row) or 'non-associated' (texts and images are drawn
            independently).
        city: Region identifier used when ``df`` has no 'region' column.
    """

    def __init__(self, df, text_tokenizer, image_processor, num_negatives=5, pairing_mode='associated', city=None):
        self.df = df
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.num_negatives = num_negatives
        self.pairing_mode = pairing_mode
        self.city = city  # City is passed explicitly instead of relying on a 'region' column

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the tensors of the positive pair at ``idx`` plus its negatives.

        Raises:
            ValueError: If ``pairing_mode`` is not a supported value.
        """
        if self.pairing_mode == 'associated':
            sample_negatives = self._get_associated_negative_samples
        elif self.pairing_mode == 'non-associated':
            sample_negatives = self._get_non_associated_negative_samples
        else:
            raise ValueError("pairing_mode must be either 'associated' or 'non-associated'")

        # Select a positive text-image pair
        pos_row = self.df.iloc[idx]
        label = pos_row['VENDOR']  # Assuming 'VENDOR' is the label column
        # The 'region' column is optional: fall back to the explicitly passed city.
        pos_region = pos_row['region'] if 'region' in pos_row.index else self.city

        pos_input_ids, pos_attention_mask = self._tokenize(pos_row['TEXT'])
        pos_image = self._load_image(pos_row['IMAGES'])

        neg_texts, neg_image_paths = sample_negatives(pos_region)
        neg_tokens = [self._tokenize(neg_text) for neg_text in neg_texts]

        return {
            'pos_input_ids': pos_input_ids,
            'pos_attention_mask': pos_attention_mask,
            'pos_pixel_values': pos_image,
            'neg_input_ids': torch.stack([ids for ids, _ in neg_tokens]),
            'neg_attention_mask': torch.stack([mask for _, mask in neg_tokens]),
            'neg_pixel_values': torch.stack([self._load_image(path) for path in neg_image_paths]),
            'label': torch.tensor(label, dtype=torch.long)  # Add the label here
        }

    def _tokenize(self, text):
        """Tokenize one text and return (input_ids, attention_mask) without the batch axis."""
        encoded = self.text_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def _random_rows(self):
        """Draw ``num_negatives`` row positions (with replacement) from the torch RNG.

        The positions are returned as plain Python ints because pandas does not
        reliably accept a torch tensor as an ``iloc`` indexer.
        """
        return torch.randint(0, len(self.df), (self.num_negatives,)).tolist()

    def _get_associated_negative_samples(self, pos_region=None):
        """Sample negatives whose text and image come from the same row.

        ``pos_region`` is accepted for signature parity with ``BLIP2Dataset`` but
        is not used: negatives are drawn from the whole frame.
        """
        neg_rows = self.df.iloc[self._random_rows()]
        return neg_rows['TEXT'].values, neg_rows['IMAGES'].values

    def _get_non_associated_negative_samples(self, pos_region=None):
        """Sample negative texts and negative images independently of each other.

        ``pos_region`` is accepted for signature parity with ``BLIP2Dataset`` but
        is not used: negatives are drawn from the whole frame.
        """
        neg_texts = self.df.iloc[self._random_rows()]['TEXT'].values
        neg_image_paths = self.df.iloc[self._random_rows()]['IMAGES'].values
        return neg_texts, neg_image_paths

    def _load_image(self, image_path):
        """Open an image, force three channels and return its processed pixel tensor."""
        image = Image.open(image_path).convert('RGB')  # Ensure 3 channels
        image = self.image_processor(images=image, return_tensors="pt")['pixel_values'].squeeze(0)
        return image


class BLIP2Dataset(Dataset):
    """Text-image pairs with cross-region negatives and text-generation targets.

    ``df`` must provide 'TEXT', 'IMAGES' (a path readable by PIL) and 'region'.
    Negatives for an item are always drawn from one randomly chosen region other
    than the item's own, so ``df`` needs at least two distinct regions.

    Args:
        df: DataFrame holding the pairs.
        text_tokenizer: Tokenizer applied to the full text and to negative texts.
        t5_tokenizer: Tokenizer applied to the conditional and target text.
        image_processor: Callable returning a mapping with 'pixel_values'.
        num_negatives: Number of negative texts and negative images per item.
        pairing_mode: 'associated' (each negative text and image come from the
            same row) or 'non-associated' (texts and images are drawn
            independently).
    """

    def __init__(self, df, text_tokenizer, t5_tokenizer, image_processor, num_negatives=5, pairing_mode='non-associated'):
        self.df = df
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.num_negatives = num_negatives
        self.pairing_mode = pairing_mode
        self.t5_tokenizer = t5_tokenizer

        # Rows grouped by the 'region' column; used to pick negatives from a different region.
        self.region_groups = df.groupby('region')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the tensors of the positive pair at ``idx`` plus its negatives.

        The text is split on the first '[SEP]': the part before it becomes the
        conditional prompt and the part after it the generation target. Without
        a '[SEP]' the prompt is empty and the whole text is the target.

        Raises:
            ValueError: If ``pairing_mode`` is not a supported value.
            IndexError: If ``df`` holds no region other than the item's own.
        """
        if self.pairing_mode == 'associated':
            sample_negatives = self._get_associated_negative_samples
        elif self.pairing_mode == 'non-associated':
            sample_negatives = self._get_non_associated_negative_samples
        else:
            raise ValueError("pairing_mode must be either 'associated' or 'non-associated'")

        # Select a positive text-image pair
        pos_row = self.df.iloc[idx]
        full_text = pos_row['TEXT']

        # Split the text on [SEP] for text generation
        if '[SEP]' in full_text:
            conditional_text, target_text = (part.strip() for part in full_text.split('[SEP]', 1))
        else:
            conditional_text, target_text = '', full_text.strip()

        # Full text feeds the CLIP / ITM losses; prompt and target feed text generation.
        pos_input_ids, pos_attention_mask = self._tokenize(self.text_tokenizer, full_text)
        conditional_input_ids, conditional_attention_mask = self._tokenize(self.t5_tokenizer, conditional_text)
        target_input_ids, target_attention_mask = self._tokenize(self.t5_tokenizer, target_text)

        pos_image = self._load_image(pos_row['IMAGES'])

        neg_texts, neg_image_paths = sample_negatives(pos_row['region'])
        neg_tokens = [self._tokenize(self.text_tokenizer, neg_text) for neg_text in neg_texts]

        return {
            # For CLIP and ITM losses (full text)
            'pos_input_ids': pos_input_ids,
            'pos_attention_mask': pos_attention_mask,
            # For text generation loss (conditional text and target text)
            'conditional_input_ids': conditional_input_ids,
            'conditional_attention_mask': conditional_attention_mask,
            'target_input_ids': target_input_ids,
            'target_attention_mask': target_attention_mask,
            # Positive image
            'pos_pixel_values': pos_image,
            # Negative samples
            'neg_input_ids': torch.stack([ids for ids, _ in neg_tokens]),
            'neg_attention_mask': torch.stack([mask for _, mask in neg_tokens]),
            'neg_pixel_values': torch.stack([self._load_image(path) for path in neg_image_paths]),
        }

    @staticmethod
    def _tokenize(tokenizer, text):
        """Tokenize one text and return (input_ids, attention_mask) without the batch axis."""
        encoded = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def _other_region_frame(self, pos_region):
        """Return the rows of one randomly chosen region different from ``pos_region``."""
        other_regions = [region for region in self.region_groups.groups.keys() if region != pos_region]
        if not other_regions:
            raise IndexError(
                f"Cannot sample negatives: no region other than {pos_region!r} exists in the DataFrame"
            )
        return self.region_groups.get_group(random.choice(other_regions))

    def _random_rows(self, frame):
        """Draw ``num_negatives`` row positions of ``frame`` (with replacement).

        The positions are returned as plain Python ints because pandas does not
        reliably accept a torch tensor as an ``iloc`` indexer.
        """
        return torch.randint(0, len(frame), (self.num_negatives,)).tolist()

    def _get_associated_negative_samples(self, pos_region):
        """Sample negatives from another region; text and image come from the same row."""
        neg_df = self._other_region_frame(pos_region)
        neg_rows = neg_df.iloc[self._random_rows(neg_df)]
        return neg_rows['TEXT'].values, neg_rows['IMAGES'].values

    def _get_non_associated_negative_samples(self, pos_region):
        """Sample negatives from another region; texts and images are drawn independently."""
        neg_df = self._other_region_frame(pos_region)
        # Text positions are drawn before image positions, as before, so seeded runs consume the RNG in the same order.
        neg_texts = neg_df.iloc[self._random_rows(neg_df)]['TEXT'].values
        neg_image_paths = neg_df.iloc[self._random_rows(neg_df)]['IMAGES'].values
        return neg_texts, neg_image_paths

    def _load_image(self, image_path):
        """Open an image, force three channels and return its processed pixel tensor."""
        image = Image.open(image_path).convert('RGB')  # Ensure 3 channels
        image = self.image_processor(images=image, return_tensors="pt")['pixel_values'].squeeze(0)
        return image

In [None]:
# Seed every random number generator used in this notebook (torch CPU/CUDA,
# NumPy, Python's `random`) and request deterministic cuDNN kernels.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)
random.seed(args.seed)
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so setting it
# here presumably only affects child processes — confirm if hash determinism matters.
os.environ['PYTHONHASHSEED'] = str(args.seed)
# Set TOKENIZERS_PARALLELISM to false to disable parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Restrict this process to the first GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# Lightning's helper is called last with the same seed value.
seed_everything(args.seed)

# Set matrix multiplication precision.
# "high" trades a little float32 matmul precision for speed (see the
# torch.set_float32_matmul_precision documentation for the exact semantics).
torch.set_float32_matmul_precision("high")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fail fast on configuration values this notebook does not support.
assert args.loss in ["NTXENT"]
assert args.pairing_mode in ["associated", "non-associated"]
assert args.model_type in ["CLIP", "CLIPITM", "BLIP2"]

# Scheduler settings handed to the model constructors.
# NOTE(review): both values are hard-coded. warmup_steps (1023) is larger than
# num_training_steps (678) and differs from args.warmup_steps (0), so it is not
# "1/10th of the training data" of anything visible here — confirm these match
# the run that produced the checkpoints being loaded.
num_training_steps = 678
warmup_steps = 1023

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as pl
from transformers import AutoModel, ViTModel, get_linear_schedule_with_warmup

class CLIPModel(pl.LightningModule):
    """CLIP-style dual encoder trained with a contrastive loss on explicit negatives.

    A text encoder and a ViT image encoder each produce one L2-normalized vector
    per input. Training contrasts each positive text-image pair against the
    negative texts and negative images supplied with the batch.

    Args:
        weight_decay: Weight decay applied to all parameters except biases and
            parameters whose name contains 'LayerNorm.weight'.
        eps: Epsilon of the AdamW optimizer.
        warmup_steps: Number of linear warm-up steps of the LR schedule.
        num_training_steps: Total number of steps of the LR schedule.
        text_model_name: Hugging Face identifier of the text encoder.
        image_model_name: Hugging Face identifier of the ViT image encoder.
        learning_rate: Peak learning rate.
        num_negatives: Nominal number of negatives per positive pair. The step
            methods read the actual number from the batch shape.
        temperature: Temperature dividing the cosine similarities.
    """

    def __init__(self, weight_decay, eps, warmup_steps, num_training_steps, text_model_name='johngiorgi/declutr-small', 
                image_model_name='google/vit-base-patch16-224', learning_rate=0.00001, num_negatives=5, temperature=0.5):
        super().__init__()
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)
        self.learning_rate = learning_rate
        self.num_negatives = num_negatives
        self.temperature = temperature
        self.weight_decay = weight_decay
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps

        # Store outputs for validation and testing
        self.validation_outputs = []
        self.test_outputs = []

    # ------------------------------------------------------------------ encoders
    def _pool_text(self, last_hidden_state, input_ids):
        """Select one hidden state per row of ``last_hidden_state``.

        A row that contains the text model's EOS token is represented by the
        hidden state at its last EOS position; any other row (or every row when
        the model defines no EOS token) is represented by the final position.
        Exactly one vector is returned per row, so the output always stays
        aligned with the batch.
        """
        batch_size, seq_len = input_ids.shape
        token_positions = torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=input_ids.device)

        eos_token_id = self.text_model.config.eos_token_id
        if eos_token_id is not None:
            eos_mask = input_ids.eq(eos_token_id)
            column_index = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
            # Non-EOS columns contribute 0, so the row-wise max is the last EOS column.
            last_eos = (column_index * eos_mask).max(dim=1).values
            token_positions = torch.where(eos_mask.any(dim=1), last_eos, token_positions)

        row_index = torch.arange(batch_size, device=input_ids.device)
        return last_hidden_state[row_index, token_positions]

    def _embed_text(self, input_ids, attention_mask):
        """Encode text and return L2-normalized embeddings, one per row."""
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._pool_text(text_outputs.last_hidden_state, input_ids)
        return F.normalize(pooled, p=2, dim=-1)

    def _embed_image(self, pixel_values):
        """Encode images and return the L2-normalized CLS-token embeddings."""
        image_outputs = self.image_model(pixel_values=pixel_values)
        return F.normalize(image_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return ``(text_embeddings, image_embeddings)``, both L2-normalized."""
        return self._embed_text(input_ids, attention_mask), self._embed_image(pixel_values)

    # ---------------------------------------------------------------------- loss
    def compute_loss(self, pos_text_embeddings, pos_image_embeddings, neg_text_embeddings, neg_image_embeddings):
        """Contrastive loss of positive pairs against the supplied negatives.

        Args:
            pos_text_embeddings: [batch, dim] embeddings of the positive texts.
            pos_image_embeddings: [batch, dim] embeddings of the positive images.
            neg_text_embeddings: [batch, num_negatives, dim] negative texts.
            neg_image_embeddings: [batch, num_negatives, dim] negative images.

        Returns:
            Scalar tensor: batch mean of the text-to-image plus image-to-text terms.
        """
        # Normalize embeddings
        pos_text_embeddings = F.normalize(pos_text_embeddings, p=2, dim=1)
        pos_image_embeddings = F.normalize(pos_image_embeddings, p=2, dim=1)
        neg_text_embeddings = F.normalize(neg_text_embeddings, p=2, dim=2)  # Normalized over the last dimension
        neg_image_embeddings = F.normalize(neg_image_embeddings, p=2, dim=2)  # Normalized over the last dimension

        # Positive pairs similarity
        pos_sim = torch.exp(torch.sum(pos_text_embeddings * pos_image_embeddings, dim=-1) / self.temperature)

        # Negative text vs. positive image: dot product over the embedding axis.
        neg_sim_text_image = torch.exp(torch.einsum('bkd,bd->bk', neg_text_embeddings, pos_image_embeddings) / self.temperature)
        # Positive text vs. negative image. Both operands must share the embedding
        # subscript; with distinct subscripts einsum would multiply two independent
        # sums instead of taking a dot product.
        neg_sim_image_text = torch.exp(torch.einsum('bd,bkd->bk', pos_text_embeddings, neg_image_embeddings) / self.temperature)

        # Text-to-image: the positive against all negative texts at once.
        denominator_text_image = pos_sim + neg_sim_text_image.sum(dim=1)
        loss_text_image = -torch.log(pos_sim / denominator_text_image)

        # Image-to-text: the positive against each negative image, summed over negatives.
        denominator_image_text = pos_sim.unsqueeze(1) + neg_sim_image_text
        loss_image_text = -torch.log(pos_sim.unsqueeze(1) / denominator_image_text).sum(dim=1)

        # Combine both losses
        return (loss_text_image + loss_image_text).mean()

    def _contrastive_step(self, batch):
        """Embed the positives and negatives of ``batch`` and return the loss.

        Negative tensors arrive as [batch, num_negatives, ...]; they are flattened
        for the encoders and folded back afterwards. The number of negatives is
        read from the batch itself so it cannot disagree with the dataset.
        """
        pos_text_embeddings, pos_image_embeddings = self(
            batch['pos_input_ids'], batch['pos_attention_mask'], batch['pos_pixel_values']
        )

        neg_input_ids = batch['neg_input_ids']
        neg_attention_mask = batch['neg_attention_mask']
        neg_pixel_values = batch['neg_pixel_values']
        batch_size, num_negatives = neg_input_ids.shape[0], neg_input_ids.shape[1]

        neg_text_embeddings, neg_image_embeddings = self(
            neg_input_ids.reshape(-1, neg_input_ids.shape[-1]),
            neg_attention_mask.reshape(-1, neg_attention_mask.shape[-1]),
            neg_pixel_values.reshape(-1, *neg_pixel_values.shape[-3:]),
        )

        # Reshape negative embeddings
        neg_text_embeddings = neg_text_embeddings.reshape(batch_size, num_negatives, -1)
        neg_image_embeddings = neg_image_embeddings.reshape(batch_size, num_negatives, -1)

        return self.compute_loss(pos_text_embeddings, pos_image_embeddings, neg_text_embeddings, neg_image_embeddings)

    # --------------------------------------------------------- lightning hooks
    def training_step(self, batch, batch_idx):
        loss = self._contrastive_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._contrastive_step(batch)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._contrastive_step(batch)
        # Detached copy: only the value is needed for the epoch average.
        self.test_outputs.append(loss.detach())
        self.log('test_loss', loss)
        return loss

    def on_test_epoch_end(self):
        if self.test_outputs:
            avg_loss = torch.stack(self.test_outputs).mean()
            self.log('avg_test_loss', avg_loss)
        self.test_outputs.clear()

    def configure_optimizers(self):
        """AdamW with decoupled weight decay and a linear warm-up/decay schedule stepped every batch."""
        no_decay = ['bias', 'LayerNorm.weight']
        decayed, not_decayed = [], []
        for name, parameter in self.named_parameters():
            (not_decayed if any(nd in name for nd in no_decay) else decayed).append(parameter)

        optimizer_grouped_parameters = [
            {'params': decayed, 'weight_decay': self.weight_decay},
            {'params': not_decayed, 'weight_decay': 0.0}
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.eps)

        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]

    # ---------------------------------------------------------------- inference
    def get_embeddings(self, input_ids=None, attention_mask=None, pixel_values=None, embedding_type='text'):
        """Generate text, image, or multimodal embeddings for inference.

        Args:
            input_ids (torch.Tensor, optional): Tokenized input text IDs of shape [batch_size, seq_len].
            attention_mask (torch.Tensor, optional): Attention mask for the input text of shape [batch_size, seq_len].
            pixel_values (torch.Tensor, optional): Preprocessed image tensor of shape [batch_size, channels, height, width].
            embedding_type (str): Specify 'text', 'image', or 'multimodal' to generate the respective embeddings.

        Returns:
            torch.Tensor: Normalized embeddings. 'multimodal' is the re-normalized
            mean of the text and image embeddings.

        Raises:
            ValueError: If ``embedding_type`` is unknown or a required input is missing.
        """
        if embedding_type == 'text':
            if input_ids is None or attention_mask is None:
                raise ValueError("input_ids and attention_mask are required for text embeddings.")
            return self._embed_text(input_ids, attention_mask)

        if embedding_type == 'image':
            if pixel_values is None:
                raise ValueError("pixel_values are required for image embeddings.")
            return self._embed_image(pixel_values)

        if embedding_type == 'multimodal':
            if input_ids is None or attention_mask is None or pixel_values is None:
                raise ValueError("input_ids, attention_mask, and pixel_values are required for multimodal embeddings.")
            text_embeddings = self._embed_text(input_ids, attention_mask)
            image_embeddings = self._embed_image(pixel_values)
            # Compute multimodal embeddings by averaging
            multimodal_embeddings = (text_embeddings + image_embeddings) / 2
            return F.normalize(multimodal_embeddings, p=2, dim=-1)

        raise ValueError("Invalid embedding_type. Choose 'text', 'image', or 'multimodal'.")


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as pl
from transformers import AutoModel, ViTModel, get_linear_schedule_with_warmup
from transformers import BertConfig, BertModel

class QFormer(nn.Module):
    """Learnable query tokens refined by BERT layers that cross-attend to image features.

    A fixed set of `num_queries` query vectors is passed through the layers of a
    randomly initialised `BertModel`. `forward` iterates over
    `self.bert.encoder.layer` directly, feeding the image features as
    `encoder_hidden_states` on the layers selected by `cross_attention_frequency`.
    """

    def __init__(
        self,
        num_queries=32,
        d_model=768,
        num_attention_heads=12,
        num_hidden_layers=12,  # Default depth; CLIPITMModel also passes 12 explicitly
        intermediate_size=3072,
        cross_attention_frequency=1,  # Set to 1 to apply cross-attention in every layer
        dropout=0.1,
    ):
        """Create the query embeddings and the BERT stack.

        Args:
            num_queries: Number of learnable query vectors.
            d_model: Hidden size of the queries and of the BERT layers.
            num_attention_heads: Attention heads per BERT layer.
            num_hidden_layers: Number of BERT layers.
            intermediate_size: Feed-forward size of each BERT layer.
            cross_attention_frequency: Layer i receives the image features when
                `i % cross_attention_frequency == 0`.
            dropout: Used for both hidden and attention-probability dropout.
        """
        super(QFormer, self).__init__()

        # Learnable query embeddings (similar to Blip-2's Q-Former)
        self.query_embeddings = nn.Parameter(torch.randn(1, num_queries, d_model))

        # Configuration for a Transformer with cross-attention
        config = BertConfig(
            hidden_size=d_model,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            hidden_act="gelu",
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
            is_decoder=True,  # Enable cross-attention by setting is_decoder to True
            add_cross_attention=True,
        )

        # Initialize BertModel with the above configuration (random weights, not pretrained)
        self.bert = BertModel(config)

        # Cross-attention frequency
        self.cross_attention_frequency = cross_attention_frequency

    def forward(self, image_embeddings):
        """Refine the query embeddings against the given image features.

        Args:
            image_embeddings: Output of the vision model (e.g., ViT), of shape
                (batch_size, seq_len_image, d_model).

        Returns:
            The final query hidden states, taken from `outputs[0]` of the last layer.
        """
        batch_size, seq_len_image, d_model = image_embeddings.shape

        # Expand query embeddings to match the batch size
        query_embeddings = self.query_embeddings.expand(batch_size, -1, -1)  # [batch_size, num_queries, d_model]

        # Create attention masks
        # Self-attention mask for queries (decoder input): all ones, nothing is masked out
        attention_mask = torch.ones(batch_size, query_embeddings.size(1), device=query_embeddings.device)  # [batch_size, num_queries]

        # Encoder attention mask for image embeddings (encoder input)
        # NOTE(review): this 2-D mask is not used below; the 3-D cross_attention_mask is what reaches the layers.
        encoder_attention_mask = torch.ones(batch_size, seq_len_image, device=image_embeddings.device)  # [batch_size, seq_len_image]

        # Get extended attention masks using BERT's utility function
        # NOTE(review): the config sets is_decoder=True, so this helper presumably builds a causal
        # mask over the queries rather than a [batch_size, 1, 1, num_queries] one — confirm the intended behaviour.
        extended_attention_mask = self.bert.get_extended_attention_mask(
            attention_mask, attention_mask.shape, image_embeddings.device
        )

        # Repeat the self-attention mask for each attention head (along dim 1)
        extended_attention_mask = extended_attention_mask.repeat(1, self.bert.config.num_attention_heads, 1, 1)

        # Create cross-attention mask: [batch_size, num_queries, seq_len_image], all ones
        cross_attention_mask = torch.ones(batch_size, query_embeddings.size(1), seq_len_image, device=image_embeddings.device)

        # Get extended cross-attention mask
        # presumably [batch_size, 1, num_queries, seq_len_image] for a 3-D input — TODO confirm
        encoder_extended_cross_attention_mask = self.bert.get_extended_attention_mask(
            cross_attention_mask, cross_attention_mask.shape, image_embeddings.device
        )

        # Repeat the cross-attention mask for each attention head (along dim 1)
        encoder_extended_cross_attention_mask = encoder_extended_cross_attention_mask.repeat(1, self.bert.config.num_attention_heads, 1, 1)

        # Initialize hidden states with the (batch-expanded) queries
        hidden_states = query_embeddings  # [batch_size, num_queries, d_model]

        # Iterate through each BERT layer
        for i, layer_module in enumerate(self.bert.encoder.layer):
            if i % self.cross_attention_frequency == 0:
                # Self-attention over the queries plus cross-attention to the image features
                outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,  # Self-attention mask for queries
                    encoder_hidden_states=image_embeddings,
                    encoder_attention_mask=encoder_extended_cross_attention_mask,  # Cross-attention mask for image embeddings
                )
            else:
                # No image features are passed on this layer (encoder_hidden_states=None)
                outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,  # Self-attention mask for queries
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                )

            # Update hidden states with the layer's first output
            hidden_states = outputs[0]

        return hidden_states  # Final query embeddings
        

class CLIPITMModel(pl.LightningModule):
    """
    Text/image model trained with a contrastive ("CLIP") loss plus an image-text
    matching (ITM) loss.

    Text is represented by the first-token hidden state of ``text_model``. Images
    are encoded by ``image_model`` and then summarised by a ``QFormer`` that
    produces one embedding per learnable query.

    Args:
        weight_decay (float): Weight decay for parameters that are not excluded in
            ``configure_optimizers``.
        eps (float): Epsilon passed to AdamW.
        warmup_steps (int): Warm-up steps for the linear learning-rate schedule.
        num_training_steps (int): Total steps for the linear learning-rate schedule.
        text_model_name (str): Checkpoint name loaded with ``AutoModel``.
        image_model_name (str): Checkpoint name loaded with ``ViTModel``.
        learning_rate (float): AdamW learning rate.
        num_negatives (int): Stored on the module. The number of negatives actually
            used is read from the shape of ``neg_pixel_values``.
        temperature (float): Divisor applied to every similarity score.
        num_query_tokens (int): Number of Q-Former queries.
        qformer_hidden_size (int): Hidden size of the Q-Former.
        cross_attention_frequency (int): Forwarded to ``QFormer``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        weight_decay,
        eps,
        warmup_steps,
        num_training_steps,
        text_model_name='johngiorgi/declutr-small',
        image_model_name='google/vit-base-patch16-224',
        learning_rate=0.00001,
        num_negatives=5,
        temperature=0.5,
        num_query_tokens=32,
        qformer_hidden_size=768,
        cross_attention_frequency=1,
        **kwargs
    ):
        super(CLIPITMModel, self).__init__()
        # Encoders (registration order is kept stable so parameter order does not change).
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)
        self.qformer = QFormer(
            num_queries=num_query_tokens,
            d_model=qformer_hidden_size,
            num_attention_heads=12,
            num_hidden_layers=12,
            cross_attention_frequency=cross_attention_frequency,
        )

        # Loss and optimisation settings.
        self.itm_criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.num_negatives = num_negatives
        self.temperature = temperature
        self.weight_decay = weight_decay
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps

    def forward(self, input_ids, attention_mask, pixel_values, neg_pixel_values=None):
        """
        Encode a batch of positive text/image pairs and, optionally, negative images.

        Returns:
            tuple: ``(text_embeddings, query_embeddings, neg_image_embeddings)`` where
            ``text_embeddings`` is the L2-normalised first-token text embedding,
            ``query_embeddings`` holds the L2-normalised Q-Former outputs (one per
            query), and ``neg_image_embeddings`` is ``None`` when ``neg_pixel_values``
            is ``None``.
        """
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_embeddings = F.normalize(text_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)

        image_outputs = self.image_model(pixel_values=pixel_values)
        image_embeddings = image_outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        query_embeddings = self.qformer(image_embeddings)
        query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)

        neg_image_embeddings = None
        if neg_pixel_values is not None:
            neg_image_embeddings = self._process_negative_images(neg_pixel_values)

        return text_embeddings, query_embeddings, neg_image_embeddings

    def _process_negative_images(self, neg_pixel_values):
        """
        Encode negative images given as a 5-D tensor whose first two axes are
        (batch_size, num_negatives).

        Negatives bypass the Q-Former: the image model's hidden states are
        mean-pooled over the sequence axis and L2-normalised.

        Returns:
            torch.Tensor: (batch_size, num_negatives, hidden_size).
        """
        batch_size, num_negatives, _, _, _ = neg_pixel_values.shape
        flat_pixels = neg_pixel_values.view(-1, *neg_pixel_values.shape[2:])  # (batch_size * num_negatives, ...)
        neg_outputs = self.image_model(pixel_values=flat_pixels)
        neg_embeddings = neg_outputs.last_hidden_state.mean(dim=1)
        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)
        return neg_embeddings.view(batch_size, num_negatives, -1)

    @staticmethod
    def _require_negatives(neg_image_embeddings):
        """Fail with a clear message when a loss is requested without negative images."""
        if neg_image_embeddings is None:
            raise ValueError(
                "neg_image_embeddings is None: the CLIP and ITM losses require negative "
                "images (provide 'neg_pixel_values' in the batch)."
            )

    def compute_clip_loss(self, pos_text_embeddings, query_embeddings, neg_image_embeddings):
        """
        Cross-entropy over ``[positive, negatives...]`` similarities with target index 0.

        The positive score is the maximum, over queries, of the text/query similarity.

        Raises:
            ValueError: If ``neg_image_embeddings`` is ``None``.
        """
        self._require_negatives(neg_image_embeddings)

        # (batch_size, num_queries) -> best-matching query per sample, (batch_size, 1)
        pos_sim = torch.einsum('bqd,bd->bq', query_embeddings, pos_text_embeddings) / self.temperature
        pos_sim = pos_sim.max(dim=1, keepdim=True).values

        # (batch_size, num_negatives)
        neg_sim = torch.einsum('bd,bnd->bn', pos_text_embeddings, neg_image_embeddings) / self.temperature

        logits = torch.cat([pos_sim, neg_sim], dim=1)  # (batch_size, 1 + num_negatives)
        # The positive always sits in column 0; build the targets directly on the logits' device.
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return self.itm_criterion(logits, labels)

    def compute_itm_loss(self, pos_text_embeddings, query_embeddings, neg_image_embeddings):
        """
        Binary cross-entropy (with logits) that labels the positive pair 1 and every
        text/negative-image pair 0.

        Raises:
            ValueError: If ``neg_image_embeddings`` is ``None``.
        """
        self._require_negatives(neg_image_embeddings)

        pos_sim = torch.einsum('bqd,bd->bq', query_embeddings, pos_text_embeddings) / self.temperature
        pos_scores = pos_sim.max(dim=1).values  # (batch_size,)

        neg_sim = torch.einsum('bd,bnd->bn', pos_text_embeddings, neg_image_embeddings) / self.temperature
        neg_scores = neg_sim.view(-1)  # (batch_size * num_negatives,)

        scores = torch.cat([pos_scores, neg_scores], dim=0)
        labels = torch.cat([
            torch.ones(pos_scores.size(0), device=scores.device),
            torch.zeros(neg_scores.size(0), device=scores.device)
        ], dim=0)
        return F.binary_cross_entropy_with_logits(scores, labels)

    def _shared_step(self, batch, loss_name):
        """Run the forward pass, sum the CLIP and ITM losses, and log the total under ``loss_name``."""
        pos_text_embeddings, query_embeddings, neg_image_embeddings = self(
            batch['pos_input_ids'],
            batch['pos_attention_mask'],
            batch['pos_pixel_values'],
            neg_pixel_values=batch.get('neg_pixel_values'),
        )
        clip_loss = self.compute_clip_loss(pos_text_embeddings, query_embeddings, neg_image_embeddings)
        itm_loss = self.compute_itm_loss(pos_text_embeddings, query_embeddings, neg_image_embeddings)

        total_loss = clip_loss + itm_loss
        self.log(loss_name, total_loss)
        return total_loss

    def training_step(self, batch, batch_idx):
        """Return the summed CLIP + ITM loss, logged as ``train_loss``."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the summed CLIP + ITM loss, logged as ``val_loss``."""
        return self._shared_step(batch, 'val_loss')

    def test_step(self, batch, batch_idx):
        """Return the summed CLIP + ITM loss, logged as ``test_loss``."""
        return self._shared_step(batch, 'test_loss')

    def configure_optimizers(self):
        """
        Build AdamW with two parameter groups (with / without weight decay) and a
        linear warm-up schedule that is stepped every optimiser step.
        """
        # NOTE(review): this is a case-sensitive substring test, so norm layers that are
        # not literally named ``LayerNorm`` still receive weight decay - confirm intended.
        no_decay = ('bias', 'LayerNorm.weight')
        decay_params, no_decay_params = [], []
        for name, param in self.named_parameters():
            if any(marker in name for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': self.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.eps)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.num_training_steps
        )
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    def on_validation_epoch_end(self):
        """Re-log the trainer's current ``val_loss`` metric as ``avg_val_loss`` when it exists."""
        if 'val_loss' in self.trainer.callback_metrics:
            avg_val_loss = self.trainer.callback_metrics['val_loss'].mean()
            self.log('avg_val_loss', avg_val_loss)

    def on_test_epoch_end(self):
        """Re-log the trainer's current ``test_loss`` metric as ``avg_test_loss`` when it exists."""
        if 'test_loss' in self.trainer.callback_metrics:
            avg_test_loss = self.trainer.callback_metrics['test_loss'].mean()
            self.log('avg_test_loss', avg_test_loss)

    def _text_cls_embedding(self, input_ids, attention_mask):
        """Return the un-normalised first-token hidden state of the text model."""
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return text_outputs.last_hidden_state[:, 0, :]

    def _pooled_query_embedding(self, pixel_values):
        """Return the un-normalised mean over queries of the Q-Former output for the given images."""
        image_embeddings = self.image_model(pixel_values=pixel_values).last_hidden_state
        return self.qformer(image_embeddings).mean(dim=1)

    def get_embeddings(self, input_ids=None, attention_mask=None, pixel_values=None, embedding_type='multimodal'):
        """
        Generate text, image, or multimodal embeddings for inference.

        Args:
            input_ids (torch.Tensor, optional): Tokenized input text IDs.
            attention_mask (torch.Tensor, optional): Attention mask for the input text.
            pixel_values (torch.Tensor, optional): Preprocessed image tensor.
            embedding_type (str): 'text', 'image', or 'multimodal'.

        Returns:
            torch.Tensor: The requested L2-normalised embeddings. 'multimodal' is the
            element-wise mean of the text embedding and the query-averaged image
            embedding (both taken before normalisation), normalised afterwards.

        Raises:
            ValueError: If an input required by ``embedding_type`` is missing or
                ``embedding_type`` is not one of the supported values.
        """
        if embedding_type == 'text':
            if input_ids is None or attention_mask is None:
                raise ValueError("input_ids and attention_mask are required for text embeddings.")
            return F.normalize(self._text_cls_embedding(input_ids, attention_mask), p=2, dim=-1)

        if embedding_type == 'image':
            if pixel_values is None:
                raise ValueError("pixel_values are required for image embeddings.")
            return F.normalize(self._pooled_query_embedding(pixel_values), p=2, dim=-1)

        if embedding_type == 'multimodal':
            if input_ids is None or attention_mask is None or pixel_values is None:
                raise ValueError("input_ids, attention_mask, and pixel_values are required for multimodal embeddings.")
            text_embeddings = self._text_cls_embedding(input_ids, attention_mask)
            query_embeddings = self._pooled_query_embedding(pixel_values)
            # Fuse by averaging (not concatenating), so both modalities must share a hidden size.
            # NOTE(review): the two inputs are averaged before normalisation, so the
            # modality with the larger norm dominates the fused vector - confirm intended.
            multimodal_embeddings = (text_embeddings + query_embeddings) / 2
            return F.normalize(multimodal_embeddings, p=2, dim=-1)

        raise ValueError("Invalid embedding_type. Choose 'text', 'image', or 'multimodal'.")

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning.pytorch as pl

from transformers import AutoModel, ViTModel, get_linear_schedule_with_warmup
from transformers import BertConfig, BertModel
from transformers import T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput


class QFormer(nn.Module):
    """
    Q-Former-style module: a set of learnable query vectors refined by a stack of
    BERT layers that self-attend over the queries and, on selected layers,
    cross-attend to image features.

    Args:
        num_queries (int): Number of learnable query vectors.
        d_model (int): Hidden size of the queries and of the BERT layers.
        num_attention_heads (int): Attention heads per BERT layer.
        num_hidden_layers (int): Number of BERT layers in the stack.
        intermediate_size (int): Feed-forward size inside each BERT layer.
        cross_attention_frequency (int): Layer ``i`` receives the image features when
            ``i % cross_attention_frequency == 0``; 1 means every layer.
        dropout (float): Dropout for both hidden states and attention probabilities.

    Raises:
        ValueError: If ``cross_attention_frequency`` is smaller than 1.
    """

    def __init__(
        self,
        num_queries=32,
        d_model=768,
        num_attention_heads=12,
        num_hidden_layers=12,
        intermediate_size=3072,
        cross_attention_frequency=1,
        dropout=0.1,
    ):
        # Validate before building anything: a frequency of 0 would otherwise only fail
        # with ZeroDivisionError at the first forward pass.
        if cross_attention_frequency < 1:
            raise ValueError(
                f"cross_attention_frequency must be a positive integer, got {cross_attention_frequency!r}."
            )

        super(QFormer, self).__init__()

        # Learnable query embeddings, shared across the batch.
        self.query_embeddings = nn.Parameter(torch.randn(1, num_queries, d_model))

        # A BERT configured as a decoder so that its layers carry cross-attention weights.
        config = BertConfig(
            hidden_size=d_model,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            hidden_act="gelu",
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
            is_decoder=True,
            add_cross_attention=True,
        )
        # Only ``self.bert.encoder.layer`` (plus the mask helper and config) is used in forward().
        self.bert = BertModel(config)

        self.cross_attention_frequency = cross_attention_frequency

    def forward(self, image_embeddings):
        """
        Refine the learnable queries against a batch of image features.

        Args:
            image_embeddings (torch.Tensor): Output of the vision model, shaped
                (batch_size, seq_len_image, d_model).

        Returns:
            torch.Tensor: Final query states, shaped (batch_size, num_queries, d_model).
        """
        batch_size, seq_len_image, _ = image_embeddings.shape
        device = image_embeddings.device
        num_heads = self.bert.config.num_attention_heads

        # Broadcast the single set of queries over the batch.
        query_embeddings = self.query_embeddings.expand(batch_size, -1, -1)
        num_queries = query_embeddings.size(1)

        # Self-attention mask over the queries (nothing is padded, so all ones).
        attention_mask = torch.ones(batch_size, num_queries, device=query_embeddings.device)
        # NOTE(review): the config sets is_decoder=True, and Hugging Face documents that
        # get_extended_attention_mask adds a causal mask for 2-D inputs of decoder models.
        # The queries may therefore attend causally and the mask may be
        # [batch_size, 1, num_queries, num_queries] rather than [batch_size, 1, 1, num_queries]
        # - confirm against the installed transformers version.
        extended_attention_mask = self.bert.get_extended_attention_mask(
            attention_mask, attention_mask.shape, device
        )
        extended_attention_mask = extended_attention_mask.repeat(1, num_heads, 1, 1)

        # Cross-attention mask: every query may look at every image position.
        cross_attention_mask = torch.ones(batch_size, num_queries, seq_len_image, device=device)
        encoder_extended_cross_attention_mask = self.bert.get_extended_attention_mask(
            cross_attention_mask, cross_attention_mask.shape, device
        )
        encoder_extended_cross_attention_mask = encoder_extended_cross_attention_mask.repeat(
            1, num_heads, 1, 1
        )

        hidden_states = query_embeddings
        for i, layer_module in enumerate(self.bert.encoder.layer):
            if i % self.cross_attention_frequency == 0:
                # Self-attention over the queries plus cross-attention to the image features.
                outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    encoder_hidden_states=image_embeddings,
                    encoder_attention_mask=encoder_extended_cross_attention_mask,
                )
            else:
                # No image features are passed, so this layer only sees the queries.
                outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                )

            hidden_states = outputs[0]  # (batch_size, num_queries, d_model)

        return hidden_states

class BLIP2Model(pl.LightningModule):
    """
    Text/image model trained with three objectives: a contrastive ("CLIP") loss, an
    image-text matching (ITM) loss, and a T5 text-generation loss conditioned on the
    Q-Former queries.

    Text is represented by the first-token hidden state of ``text_model``. Images
    are encoded by ``image_model`` and summarised by a ``QFormer``. The Q-Former
    output is linearly projected to the T5 hidden size and handed to T5 in place of
    its own encoder output.

    Args:
        weight_decay (float): Weight decay for parameters that are not excluded in
            ``configure_optimizers``.
        eps (float): Epsilon passed to AdamW.
        warmup_steps (int): Warm-up steps for the linear learning-rate schedule.
        num_training_steps (int): Total steps for the linear learning-rate schedule.
        text_model_name (str): Checkpoint name loaded with ``AutoModel``.
        image_model_name (str): Checkpoint name loaded with ``ViTModel``.
        t5_model_name (str): Checkpoint name loaded with ``T5ForConditionalGeneration``.
        learning_rate (float): AdamW learning rate.
        num_negatives (int): Stored on the module. The number of negatives actually
            used is read from the shape of ``neg_pixel_values``.
        temperature (float): Divisor applied to every similarity score.
        num_query_tokens (int): Number of Q-Former queries.
        qformer_hidden_size (int): Hidden size of the Q-Former.
        cross_attention_frequency (int): Forwarded to ``QFormer``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        weight_decay,
        eps,
        warmup_steps,
        num_training_steps,
        text_model_name='johngiorgi/declutr-small',
        image_model_name='google/vit-base-patch16-224',
        t5_model_name='google/flan-t5-small',
        learning_rate=0.00001,
        num_negatives=5,
        temperature=0.5,
        num_query_tokens=32,
        qformer_hidden_size=768,
        cross_attention_frequency=1,
        **kwargs
    ):
        super(BLIP2Model, self).__init__()
        # Encoders (registration order is kept stable so parameter order does not change).
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)
        self.qformer = QFormer(
            num_queries=num_query_tokens,
            d_model=qformer_hidden_size,
            num_attention_heads=12,
            num_hidden_layers=12,
            cross_attention_frequency=cross_attention_frequency,
        )

        # Loss and optimisation settings.
        self.itm_criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.num_negatives = num_negatives
        self.temperature = temperature
        self.weight_decay = weight_decay
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps

        # T5 decoder and the projection from the Q-Former hidden size to T5's d_model.
        self.t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)
        self.query_proj_t5 = nn.Linear(qformer_hidden_size, self.t5_model.config.d_model)

    def forward(self, input_ids, attention_mask, pixel_values, pos_input_ids_t5, pos_attention_mask_t5, neg_pixel_values=None):
        """
        Encode a batch of positive text/image pairs (and optional negative images) and
        compute the T5 text-generation loss.

        Args:
            input_ids, attention_mask: Inputs for ``text_model``.
            pixel_values: Positive images for ``image_model``.
            pos_input_ids_t5: Target token ids for T5.
            pos_attention_mask_t5: Mask for ``pos_input_ids_t5``; positions equal to 0
                are excluded from the generation loss. May be ``None``, in which case
                every target position contributes. Presumably the T5 tokenizer's
                attention mask with the same shape as the ids - verify against the caller.
            neg_pixel_values: Optional negative images, 5-D with leading axes
                (batch_size, num_negatives).

        Returns:
            tuple: ``(text_embeddings, query_embeddings, neg_image_embeddings,
            text_generation_loss)``; ``neg_image_embeddings`` is ``None`` when no
            negatives were given.
        """
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        text_embeddings = F.normalize(text_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)

        image_outputs = self.image_model(pixel_values=pixel_values)
        image_embeddings = image_outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        query_embeddings = self.qformer(image_embeddings)
        query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)

        neg_image_embeddings = None
        if neg_pixel_values is not None:
            neg_image_embeddings = self._process_negative_images(neg_pixel_values)

        # The projected queries, shaped (batch_size, num_queries, d_model_t5), stand in
        # for T5's encoder output.
        # NOTE(review): the projection receives the L2-normalised queries. An earlier
        # comment said "unnormalized"; behaviour is left as it was - confirm which is intended.
        projected_query_embeddings = self.query_proj_t5(query_embeddings)
        encoder_outputs = BaseModelOutput(last_hidden_state=projected_query_embeddings)
        encoder_attention_mask = torch.ones(
            projected_query_embeddings.size()[:-1],
            dtype=torch.long,
            device=projected_query_embeddings.device,
        )

        # T5's loss ignores label positions set to -100; without this, padded target
        # positions would count towards the generation loss.
        labels = pos_input_ids_t5
        if pos_attention_mask_t5 is not None:
            labels = pos_input_ids_t5.masked_fill(pos_attention_mask_t5 == 0, -100)

        t5_outputs = self.t5_model(
            attention_mask=encoder_attention_mask,
            encoder_outputs=encoder_outputs,
            labels=labels,
            return_dict=True,
        )
        text_generation_loss = t5_outputs.loss

        return text_embeddings, query_embeddings, neg_image_embeddings, text_generation_loss

    def _process_negative_images(self, neg_pixel_values):
        """
        Encode negative images given as a 5-D tensor whose first two axes are
        (batch_size, num_negatives).

        Negatives bypass the Q-Former: the image model's hidden states are
        mean-pooled over the sequence axis and L2-normalised.

        Returns:
            torch.Tensor: (batch_size, num_negatives, hidden_size).
        """
        batch_size, num_negatives, _, _, _ = neg_pixel_values.shape
        flat_pixels = neg_pixel_values.view(-1, *neg_pixel_values.shape[2:])  # (batch_size * num_negatives, ...)
        neg_outputs = self.image_model(pixel_values=flat_pixels)
        neg_embeddings = neg_outputs.last_hidden_state.mean(dim=1)
        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=-1)
        return neg_embeddings.view(batch_size, num_negatives, -1)

    @staticmethod
    def _require_negatives(neg_image_embeddings):
        """Fail with a clear message when a loss is requested without negative images."""
        if neg_image_embeddings is None:
            raise ValueError(
                "neg_image_embeddings is None: the CLIP and ITM losses require negative "
                "images (provide 'neg_pixel_values' in the batch)."
            )

    def compute_clip_loss(self, pos_text_embeddings, query_embeddings, neg_image_embeddings):
        """
        Cross-entropy over ``[positive, negatives...]`` similarities with target index 0.

        The positive score is the maximum, over queries, of the text/query similarity.

        Raises:
            ValueError: If ``neg_image_embeddings`` is ``None``.
        """
        self._require_negatives(neg_image_embeddings)

        # (batch_size, num_queries) -> best-matching query per sample, (batch_size, 1)
        pos_sim = torch.einsum('bqd,bd->bq', query_embeddings, pos_text_embeddings) / self.temperature
        pos_sim = pos_sim.max(dim=1, keepdim=True).values

        # (batch_size, num_negatives)
        neg_sim = torch.einsum('bd,bnd->bn', pos_text_embeddings, neg_image_embeddings) / self.temperature

        logits = torch.cat([pos_sim, neg_sim], dim=1)  # (batch_size, 1 + num_negatives)
        # The positive always sits in column 0; build the targets directly on the logits' device.
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return self.itm_criterion(logits, labels)

    def compute_itm_loss(self, pos_text_embeddings, query_embeddings, neg_image_embeddings):
        """
        Binary cross-entropy (with logits) that labels the positive pair 1 and every
        text/negative-image pair 0.

        Raises:
            ValueError: If ``neg_image_embeddings`` is ``None``.
        """
        self._require_negatives(neg_image_embeddings)

        pos_sim = torch.einsum('bqd,bd->bq', query_embeddings, pos_text_embeddings) / self.temperature
        pos_scores = pos_sim.max(dim=1).values  # (batch_size,)

        neg_sim = torch.einsum('bd,bnd->bn', pos_text_embeddings, neg_image_embeddings) / self.temperature
        neg_scores = neg_sim.view(-1)  # (batch_size * num_negatives,)

        scores = torch.cat([pos_scores, neg_scores], dim=0)
        labels = torch.cat([
            torch.ones(pos_scores.size(0), device=scores.device),
            torch.zeros(neg_scores.size(0), device=scores.device)
        ], dim=0)
        return F.binary_cross_entropy_with_logits(scores, labels)

    def _shared_step(self, batch, stage):
        """
        Run the forward pass, sum the three losses and log them for ``stage``.

        ``stage`` is 'train', 'val' or 'test'. Training keeps the component keys
        ``clip_loss`` / ``itm_loss`` / ``text_generation_loss``. Validation and test
        prefix them with the stage name so they no longer share keys with the training
        metrics. The total is logged as ``{stage}_loss``.
        """
        pos_text_embeddings, query_embeddings, neg_image_embeddings, text_generation_loss = self(
            batch['pos_input_ids'],
            batch['pos_attention_mask'],
            batch['pos_pixel_values'],
            batch['pos_input_ids_t5'],
            batch['pos_attention_mask_t5'],
            neg_pixel_values=batch.get('neg_pixel_values'),
        )
        clip_loss = self.compute_clip_loss(pos_text_embeddings, query_embeddings, neg_image_embeddings)
        itm_loss = self.compute_itm_loss(pos_text_embeddings, query_embeddings, neg_image_embeddings)
        total_loss = clip_loss + itm_loss + text_generation_loss

        prefix = '' if stage == 'train' else f'{stage}_'
        self.log(f'{prefix}clip_loss', clip_loss)
        self.log(f'{prefix}itm_loss', itm_loss)
        self.log(f'{prefix}text_generation_loss', text_generation_loss)
        self.log(f'{stage}_loss', total_loss)
        return total_loss

    def training_step(self, batch, batch_idx):
        """Return the summed CLIP + ITM + generation loss, logged as ``train_loss``."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the summed CLIP + ITM + generation loss, logged as ``val_loss``."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Return the summed CLIP + ITM + generation loss, logged as ``test_loss``."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """
        Build AdamW with two parameter groups (with / without weight decay) and a
        linear warm-up schedule that is stepped every optimiser step.
        """
        # NOTE(review): this is a case-sensitive substring test, so norm layers that are
        # not literally named ``LayerNorm`` still receive weight decay - confirm intended.
        no_decay = ('bias', 'LayerNorm.weight')
        decay_params, no_decay_params = [], []
        for name, param in self.named_parameters():
            if any(marker in name for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': self.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.eps)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.num_training_steps
        )
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    def on_validation_epoch_end(self):
        """Re-log the trainer's current ``val_loss`` metric as ``avg_val_loss`` when it exists."""
        if 'val_loss' in self.trainer.callback_metrics:
            avg_val_loss = self.trainer.callback_metrics['val_loss'].mean()
            self.log('avg_val_loss', avg_val_loss)

    def on_test_epoch_end(self):
        """Re-log the trainer's current ``test_loss`` metric as ``avg_test_loss`` when it exists."""
        if 'test_loss' in self.trainer.callback_metrics:
            avg_test_loss = self.trainer.callback_metrics['test_loss'].mean()
            self.log('avg_test_loss', avg_test_loss)

    def _text_cls_embedding(self, input_ids, attention_mask):
        """Return the un-normalised first-token hidden state of the text model."""
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return text_outputs.last_hidden_state[:, 0, :]

    def _pooled_query_embedding(self, pixel_values):
        """Return the un-normalised mean over queries of the Q-Former output for the given images."""
        image_embeddings = self.image_model(pixel_values=pixel_values).last_hidden_state
        return self.qformer(image_embeddings).mean(dim=1)

    def get_embeddings(self, input_ids=None, attention_mask=None, pixel_values=None, embedding_type='multimodal'):
        """
        Generate text, image, or multimodal embeddings for inference.

        The T5 branch is not involved here.

        Args:
            input_ids (torch.Tensor, optional): Tokenized input text IDs.
            attention_mask (torch.Tensor, optional): Attention mask for the input text.
            pixel_values (torch.Tensor, optional): Preprocessed image tensor.
            embedding_type (str): 'text', 'image', or 'multimodal'.

        Returns:
            torch.Tensor: The requested L2-normalised embeddings. 'multimodal' is the
            element-wise mean of the text embedding and the query-averaged image
            embedding (both taken before normalisation), normalised afterwards.

        Raises:
            ValueError: If an input required by ``embedding_type`` is missing or
                ``embedding_type`` is not one of the supported values.
        """
        if embedding_type == 'text':
            if input_ids is None or attention_mask is None:
                raise ValueError("input_ids and attention_mask are required for text embeddings.")
            return F.normalize(self._text_cls_embedding(input_ids, attention_mask), p=2, dim=-1)

        if embedding_type == 'image':
            if pixel_values is None:
                raise ValueError("pixel_values are required for image embeddings.")
            return F.normalize(self._pooled_query_embedding(pixel_values), p=2, dim=-1)

        if embedding_type == 'multimodal':
            if input_ids is None or attention_mask is None or pixel_values is None:
                raise ValueError("input_ids, attention_mask, and pixel_values are required for multimodal embeddings.")
            text_embeddings = self._text_cls_embedding(input_ids, attention_mask)
            query_embeddings = self._pooled_query_embedding(pixel_values)
            # Fuse by averaging (not concatenating), so both modalities must share a hidden size.
            # NOTE(review): the two inputs are averaged before normalisation, so the
            # modality with the larger norm dominates the fused vector - confirm intended.
            multimodal_embeddings = (text_embeddings + query_embeddings) / 2
            return F.normalize(multimodal_embeddings, p=2, dim=-1)

        raise ValueError("Invalid embedding_type. Choose 'text', 'image', or 'multimodal'.")

In [4]:

from transformers import AutoModel, ViTModel, get_linear_schedule_with_warmup
from transformers import BertConfig, BertModel
from transformers import T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# import bitsandbytes as bnb
# from deepspeed.ops.adam import DeepSpeedCPUAdam

class QFormer(nn.Module):
    """Query transformer that condenses image features into a fixed set of query vectors.

    A bank of ``num_queries`` learnable embeddings is pushed through the layers
    of a BERT model configured with cross-attention. Layers whose index is a
    multiple of ``cross_attention_frequency`` (layer 0 included) additionally
    attend to the image features; the remaining layers only self-attend.

    Only ``self.bert.encoder.layer`` is used: the queries are fed straight to
    the layers, so BERT's own embedding module is bypassed.

    Args:
        num_queries (int): Number of learnable query vectors.
        d_model (int): Hidden size of the queries and of the BERT stack.
        num_attention_heads (int): Attention heads per layer.
        num_hidden_layers (int): Number of BERT layers (default 12).
        intermediate_size (int): Feed-forward width of each layer.
        cross_attention_frequency (int): Apply cross-attention in every
            ``cross_attention_frequency``-th layer; 1 means every layer.
        dropout (float): Dropout used for hidden states and attention probabilities.

    Raises:
        ValueError: If ``cross_attention_frequency`` is smaller than 1.
    """

    def __init__(
        self,
        num_queries=32,
        d_model=768,
        num_attention_heads=12,
        num_hidden_layers=12,
        intermediate_size=3072,
        cross_attention_frequency=1,  # 1 => cross-attention in every layer
        dropout=0.1,
    ):
        # forward() evaluates ``layer_index % cross_attention_frequency``; reject
        # values that would only fail later (0 -> ZeroDivisionError) up front.
        if cross_attention_frequency < 1:
            raise ValueError(
                f"cross_attention_frequency must be >= 1, got {cross_attention_frequency}."
            )

        super(QFormer, self).__init__()

        # Learnable query embeddings (similar to BLIP-2's Q-Former)
        self.query_embeddings = nn.Parameter(torch.randn(1, num_queries, d_model))

        # BERT stack with cross-attention enabled.
        # NOTE(review): with is_decoder=True, BERT's mask helper may build a causal
        # self-attention mask for the queries (library-version dependent) — confirm
        # this is intended before changing it; trained checkpoints depend on it.
        config = BertConfig(
            hidden_size=d_model,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            hidden_act="gelu",
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
            is_decoder=True,  # required so the layers get cross-attention blocks
            add_cross_attention=True,
        )
        self.bert = BertModel(config)

        self.cross_attention_frequency = cross_attention_frequency

    def forward(self, image_embeddings):
        """Run the learnable queries through the layer stack.

        Args:
            image_embeddings (torch.Tensor): Vision-model output of shape
                (batch_size, seq_len_image, d_model).

        Returns:
            torch.Tensor: Final query states of shape
            (batch_size, num_queries, d_model).
        """
        batch_size, seq_len_image, _ = image_embeddings.shape
        device = image_embeddings.device
        num_heads = self.bert.config.num_attention_heads

        # Share the single learned query bank across the batch.
        hidden_states = self.query_embeddings.expand(batch_size, -1, -1)
        num_queries = hidden_states.size(1)

        # All queries are valid, so the self-attention mask starts as all ones; it
        # is converted to BERT's extended form and repeated once per head.
        self_attention_mask = torch.ones(batch_size, num_queries, device=hidden_states.device)
        self_attention_mask = self.bert.get_extended_attention_mask(
            self_attention_mask, self_attention_mask.shape, device
        )
        self_attention_mask = self_attention_mask.repeat(1, num_heads, 1, 1)

        # Every query may look at every image position.
        cross_attention_mask = torch.ones(batch_size, num_queries, seq_len_image, device=device)
        cross_attention_mask = self.bert.get_extended_attention_mask(
            cross_attention_mask, cross_attention_mask.shape, device
        )
        cross_attention_mask = cross_attention_mask.repeat(1, num_heads, 1, 1)

        for layer_index, layer_module in enumerate(self.bert.encoder.layer):
            # Without encoder states the layer performs self-attention only.
            attend_to_image = layer_index % self.cross_attention_frequency == 0
            outputs = layer_module(
                hidden_states,
                attention_mask=self_attention_mask,
                encoder_hidden_states=image_embeddings if attend_to_image else None,
                encoder_attention_mask=cross_attention_mask if attend_to_image else None,
            )
            hidden_states = outputs[0]  # (batch_size, num_queries, d_model)

        return hidden_states

class BLIP2ConditionalModel(pl.LightningModule):
    """BLIP-2 style Lightning module: text encoder + ViT + Q-Former + T5 generator.

    Each train/val/test step sums three losses:

    * a contrastive ("CLIP") loss ranking the positive image against negatives,
    * an image-text matching (ITM) loss over positive / negative pairs,
    * the T5 text-generation loss, where the T5 decoder is conditioned on the
      encoded conditional text concatenated with the projected image queries.

    Args:
        weight_decay (float): Weight decay for parameters not matched by the
            no-decay name filter in ``configure_optimizers``.
        eps (float): AdamW epsilon.
        warmup_steps (int): Warm-up steps of the linear schedule.
        num_training_steps (int): Total steps of the linear schedule.
        text_model_name (str): Hugging Face id of the text encoder.
        image_model_name (str): Hugging Face id of the ViT encoder.
        t5_model_name (str): Hugging Face id of the T5 model.
        learning_rate (float): AdamW learning rate.
        num_negatives (int): Stored on the module; not read by the methods here.
        temperature (float): Divisor applied to similarity scores.
        num_query_tokens (int): Number of Q-Former queries.
        qformer_hidden_size (int): Hidden size of the Q-Former.
        cross_attention_frequency (int): Forwarded to ``QFormer``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        weight_decay,
        eps,
        warmup_steps,
        num_training_steps,
        text_model_name='johngiorgi/declutr-small',
        image_model_name='google/vit-base-patch16-224',
        t5_model_name='google/flan-t5-small',
        learning_rate=0.00001,
        num_negatives=5,
        temperature=0.5,
        num_query_tokens=32,
        qformer_hidden_size=768,
        cross_attention_frequency=1,
        **kwargs
    ):
        super(BLIP2ConditionalModel, self).__init__()
        # Text Model
        self.text_model = AutoModel.from_pretrained(text_model_name)
        # Vision Model
        self.image_model = ViTModel.from_pretrained(image_model_name)
        # Custom Q-Former
        self.qformer = QFormer(
            num_queries=num_query_tokens,
            d_model=qformer_hidden_size,
            num_attention_heads=12,
            num_hidden_layers=12,
            cross_attention_frequency=cross_attention_frequency,
        )

        # Loss components and optimisation hyper-parameters
        self.itm_criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.num_negatives = num_negatives
        self.temperature = temperature
        self.weight_decay = weight_decay
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps

        # T5 generator
        self.t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)

        # Make sure the token ids needed for label masking / decoding are set.
        if self.t5_model.config.pad_token_id is None:
            self.t5_model.config.pad_token_id = self.t5_model.config.eos_token_id

        if self.t5_model.config.decoder_start_token_id is None:
            self.t5_model.config.decoder_start_token_id = self.t5_model.config.pad_token_id

        # Projects Q-Former states to the T5 hidden size.
        self.query_proj_t5 = nn.Linear(qformer_hidden_size, self.t5_model.config.d_model)

        # Remembered so generate_text() can lazily load a matching tokenizer.
        self.t5_model_name = t5_model_name
        self._t5_tokenizer = None

    # ------------------------------------------------------------------
    # Encoding helpers
    # ------------------------------------------------------------------
    def _text_cls(self, input_ids, attention_mask):
        """Return the un-normalised first-token hidden state of the text model."""
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def _query_states(self, pixel_values):
        """Return un-normalised Q-Former states, (batch, num_queries, hidden), for images."""
        image_embeddings = self.image_model(pixel_values=pixel_values).last_hidden_state
        return self.qformer(image_embeddings)

    def _build_t5_encoder_inputs(self, conditional_input_ids, conditional_attention_mask, query_embeddings=None):
        """Encode the conditional text with T5 and append projected image queries.

        Args:
            conditional_input_ids (torch.Tensor): Token ids of the conditional text.
            conditional_attention_mask (torch.Tensor): Matching attention mask.
            query_embeddings (torch.Tensor, optional): Q-Former states to project
                with ``query_proj_t5`` and append along the sequence axis.

        Returns:
            tuple: (encoder states, attention mask) to hand to the T5 decoder.
        """
        conditional_states = self.t5_model.encoder(
            input_ids=conditional_input_ids,
            attention_mask=conditional_attention_mask,
            return_dict=True,
        ).last_hidden_state  # (batch_size, seq_len_conditional, hidden_size)

        if query_embeddings is None:
            return conditional_states, conditional_attention_mask

        projected_queries = self.query_proj_t5(query_embeddings)
        combined_states = torch.cat([conditional_states, projected_queries], dim=1)

        # Every query position is valid, so its mask is all ones.
        query_mask = torch.ones(
            projected_queries.size()[:-1],
            dtype=torch.long,
            device=conditional_attention_mask.device,
        )
        combined_mask = torch.cat([conditional_attention_mask, query_mask], dim=1)
        return combined_states, combined_mask

    def _get_t5_tokenizer(self):
        """Lazily load (and cache) the tokenizer matching ``t5_model_name``."""
        if self._t5_tokenizer is None:
            self._t5_tokenizer = AutoTokenizer.from_pretrained(self.t5_model_name)
        return self._t5_tokenizer

    # ------------------------------------------------------------------
    # Forward / losses
    # ------------------------------------------------------------------
    def forward(self, input_ids, attention_mask, pixel_values, conditional_input_ids, conditional_attention_mask, target_input_ids,
                target_attention_mask, neg_pixel_values=None):
        """Compute embeddings for the contrastive/ITM losses and the generation loss.

        Args:
            input_ids, attention_mask: Tokenised full text for the text encoder.
            pixel_values: Positive images.
            conditional_input_ids, conditional_attention_mask: T5 encoder input.
            target_input_ids: T5 target token ids (padding is masked out of the loss).
            target_attention_mask: Accepted for interface symmetry; not used here.
            neg_pixel_values (optional): Negative images, shaped
                (batch_size, num_negatives, C, H, W).

        Returns:
            tuple: (normalised text embeddings, normalised query embeddings,
            negative image embeddings or None, text-generation loss).
        """
        text_embeddings = F.normalize(self._text_cls(input_ids, attention_mask), p=2, dim=-1)
        query_embeddings = F.normalize(self._query_states(pixel_values), p=2, dim=-1)

        neg_image_embeddings = None
        if neg_pixel_values is not None:
            neg_image_embeddings = self._process_negative_images(neg_pixel_values)

        combined_encoder_embeddings, combined_attention_mask = self._build_t5_encoder_inputs(
            conditional_input_ids, conditional_attention_mask, query_embeddings
        )

        # Mask padding tokens out of the generation loss.
        labels = target_input_ids.clone()
        labels[labels == self.t5_model.config.pad_token_id] = -100

        # The states are already encoded, so they bypass the T5 encoder.
        t5_outputs = self.t5_model(
            encoder_outputs=BaseModelOutput(last_hidden_state=combined_encoder_embeddings),
            attention_mask=combined_attention_mask,
            labels=labels,
            return_dict=True,
        )

        return text_embeddings, query_embeddings, neg_image_embeddings, t5_outputs.loss

    def _process_negative_images(self, neg_pixel_values):
        """Embed negative images as normalised, mean-pooled ViT states.

        Unlike positives, negatives are not passed through the Q-Former.

        Args:
            neg_pixel_values (torch.Tensor): (batch_size, num_negatives, C, H, W).

        Returns:
            torch.Tensor: (batch_size, num_negatives, d_model).
        """
        batch_size, num_negatives, _, _, _ = neg_pixel_values.shape
        flat_pixels = neg_pixel_values.view(-1, *neg_pixel_values.shape[2:])  # (batch_size * num_negatives, C, H, W)
        pooled = self.image_model(pixel_values=flat_pixels).last_hidden_state.mean(dim=1)
        pooled = F.normalize(pooled, p=2, dim=-1)
        return pooled.view(batch_size, num_negatives, -1)

    def compute_clip_loss(self, pos_text_embeddings, query_embeddings, neg_image_embeddings):
        """Cross-entropy over [positive, negatives] similarities (positive is class 0).

        The positive score is the maximum, over the queries, of the text/query
        similarity. ``neg_image_embeddings`` is required (must not be None).
        """
        # Text vs. every query of the positive image; keep the best-matching query.
        pos_sim = torch.einsum('bqd,bd->bq', query_embeddings, pos_text_embeddings) / self.temperature
        pos_sim = pos_sim.max(dim=1, keepdim=True).values  # (batch_size, 1)

        # Text vs. each negative image.
        neg_sim = torch.einsum('bd,bnd->bn', pos_text_embeddings, neg_image_embeddings) / self.temperature

        logits = torch.cat([pos_sim, neg_sim], dim=1)  # (batch_size, 1 + num_negatives)
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
        return self.itm_criterion(logits, labels)

    def compute_itm_loss(self, pos_text_embeddings, query_embeddings, neg_image_embeddings):
        """Binary cross-entropy: positive pairs labelled 1, negative pairs labelled 0.

        ``neg_image_embeddings`` is required (must not be None).
        """
        pos_sim = torch.einsum('bqd,bd->bq', query_embeddings, pos_text_embeddings) / self.temperature
        pos_scores = pos_sim.max(dim=1).values  # (batch_size,)

        neg_sim = torch.einsum('bd,bnd->bn', pos_text_embeddings, neg_image_embeddings) / self.temperature
        neg_scores = neg_sim.view(-1)  # (batch_size * num_negatives,)

        scores = torch.cat([pos_scores, neg_scores], dim=0)
        labels = torch.cat([
            torch.ones(pos_scores.size(0), device=scores.device),
            torch.zeros(neg_scores.size(0), device=scores.device)
        ], dim=0)

        return F.binary_cross_entropy_with_logits(scores, labels)

    # ------------------------------------------------------------------
    # Lightning steps
    # ------------------------------------------------------------------
    def _shared_step(self, batch, stage):
        """Run one batch, log ``{stage}_*`` losses and return their sum."""
        pos_text_embeddings, query_embeddings, neg_image_embeddings, text_generation_loss = self(
            input_ids=batch['pos_input_ids'],
            attention_mask=batch['pos_attention_mask'],
            pixel_values=batch['pos_pixel_values'],
            conditional_input_ids=batch['conditional_input_ids'],
            conditional_attention_mask=batch['conditional_attention_mask'],
            target_input_ids=batch['target_input_ids'],
            target_attention_mask=batch['target_attention_mask'],
            neg_pixel_values=batch.get('neg_pixel_values')
        )

        clip_loss = self.compute_clip_loss(pos_text_embeddings, query_embeddings, neg_image_embeddings)
        itm_loss = self.compute_itm_loss(pos_text_embeddings, query_embeddings, neg_image_embeddings)

        total_loss = clip_loss + itm_loss + text_generation_loss
        self.log(f'{stage}_clip_loss', clip_loss)
        self.log(f'{stage}_itm_loss', itm_loss)
        self.log(f'{stage}_text_generation_loss', text_generation_loss)
        self.log(f'{stage}_loss', total_loss)
        return total_loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: summed training loss for one batch."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Lightning hook: summed validation loss for one batch."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Lightning hook: summed test loss for one batch."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """AdamW with a linear warm-up schedule stepped every optimisation step.

        Parameters whose name contains 'bias' or 'LayerNorm.weight' get no
        weight decay; all others use ``self.weight_decay``.
        """
        no_decay = ('bias', 'LayerNorm.weight')
        decay_params, no_decay_params = [], []
        for name, param in self.named_parameters():
            if any(marker in name for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': self.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.eps)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.num_training_steps
        )
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    def on_validation_epoch_end(self):
        """Re-log the tracked validation loss as ``avg_val_loss`` if available."""
        if 'val_loss' in self.trainer.callback_metrics:
            avg_val_loss = self.trainer.callback_metrics['val_loss'].mean()
            self.log('avg_val_loss', avg_val_loss)

    def on_test_epoch_end(self):
        """Re-log the tracked test loss as ``avg_test_loss`` if available."""
        if 'test_loss' in self.trainer.callback_metrics:
            avg_test_loss = self.trainer.callback_metrics['test_loss'].mean()
            self.log('avg_test_loss', avg_test_loss)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def get_embeddings(self, input_ids=None, attention_mask=None, pixel_values=None, embedding_type='multimodal'):
        """
        Generate text, image, or multimodal embeddings for inference.

        Args:
            input_ids (torch.Tensor, optional): Tokenized input text IDs.
            attention_mask (torch.Tensor, optional): Attention mask for the input text.
            pixel_values (torch.Tensor, optional): Preprocessed image tensor.
            embedding_type (str): 'text', 'image', or 'multimodal'.

        Returns:
            torch.Tensor: L2-normalised embeddings of the requested type. The
            image embedding is the mean over the Q-Former queries; the multimodal
            embedding is the average of the (un-normalised) text and image vectors,
            normalised afterwards.

        Raises:
            ValueError: If ``embedding_type`` is unknown or a required input is missing.
        """
        if embedding_type == 'text':
            if input_ids is None or attention_mask is None:
                raise ValueError("input_ids and attention_mask are required for text embeddings.")
            return F.normalize(self._text_cls(input_ids, attention_mask), p=2, dim=-1)

        elif embedding_type == 'image':
            if pixel_values is None:
                raise ValueError("pixel_values are required for image embeddings.")
            pooled_queries = self._query_states(pixel_values).mean(dim=1)
            return F.normalize(pooled_queries, p=2, dim=-1)

        elif embedding_type == 'multimodal':
            if input_ids is None or attention_mask is None or pixel_values is None:
                raise ValueError("input_ids, attention_mask, and pixel_values are required for multimodal embeddings.")
            text_embeddings = self._text_cls(input_ids, attention_mask)
            pooled_queries = self._query_states(pixel_values).mean(dim=1)
            # Fuse the two modalities by averaging (not concatenation), then normalise.
            multimodal_embeddings = (text_embeddings + pooled_queries) / 2
            return F.normalize(multimodal_embeddings, p=2, dim=-1)

        else:
            raise ValueError("Invalid embedding_type. Choose 'text', 'image', or 'multimodal'.")

    def generate_text(self, conditional_input_ids, conditional_attention_mask, pixel_values=None, max_length=50, tokenizer=None):
        """
        Generate text based on input text and/or image.

        The T5 decoder receives the same kind of encoder states as in
        ``forward``: encoded conditional text, optionally followed by the
        projected, L2-normalised image queries. The module's train/eval mode is
        not changed here; call ``eval()`` beforehand for deterministic output.

        Args:
            conditional_input_ids (torch.Tensor): Tokenized conditional input text IDs of shape [batch_size, seq_len].
            conditional_attention_mask (torch.Tensor): Attention mask for the conditional input text of shape [batch_size, seq_len].
            pixel_values (torch.Tensor, optional): Image tensor of shape [batch_size, channels, height, width].
            max_length (int, optional): Maximum length of generated text. Default is 50.
            tokenizer (optional): Tokenizer used to decode the generated ids. When
                omitted, the tokenizer for ``t5_model_name`` is loaded once and cached.

        Returns:
            list: A list of generated text strings.
        """
        with torch.no_grad():
            query_embeddings = None
            if pixel_values is not None:
                # Match training: queries are normalised before the T5 projection.
                query_embeddings = F.normalize(self._query_states(pixel_values), p=2, dim=-1)

            encoder_states, encoder_attention_mask = self._build_t5_encoder_inputs(
                conditional_input_ids, conditional_attention_mask, query_embeddings
            )

            # Hand the states over as encoder outputs so that generate() does not
            # push already-encoded states through the T5 encoder a second time.
            generated_ids = self.t5_model.generate(
                encoder_outputs=BaseModelOutput(last_hidden_state=encoder_states),
                attention_mask=encoder_attention_mask,
                max_length=max_length,
            )

        if tokenizer is None:
            tokenizer = self._get_t5_tokenizer()

        generated_text = [
            tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for ids in generated_ids
        ]
        return generated_text

In [5]:
class CLIPCheckpointModel(pl.LightningModule):
    """Lightning wrapper around a pre-trained CLIP model for contrastive fine-tuning.

    Images are embedded with ``model.get_image_features``; texts of arbitrary
    length are embedded with a sliding window over the text tower and averaged.

    Args:
        model: A pre-trained CLIP model exposing ``get_image_features``,
            ``text_model``, ``logit_scale`` and ``config``.
        learning_rate (float): AdamW learning rate.
    """

    def __init__(self, model, learning_rate=1e-4):
        super().__init__()
        self.model = model  # This should be a pre-trained CLIPModel
        self.learning_rate = learning_rate
        self.clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
        # Kept for backward compatibility only. Inputs are now moved to
        # ``self.device`` (where the module actually lives) instead of this guess,
        # which is wrong when the module sits on CPU on a CUDA-capable machine.
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _embed_images(self, images):
        """Return un-normalised CLIP image features for a list of images."""
        image_inputs = self.clip_processor(images=images, return_tensors='pt').to(self.device)
        return self.model.get_image_features(**image_inputs)

    def forward(self, images, texts):
        """Return L2-normalised (image_embeddings, text_embeddings) for a batch."""
        image_embeddings = self._embed_images(images)
        text_embeddings = self.compute_text_embeddings(texts)

        image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings = F.normalize(text_embeddings, p=2, dim=-1)

        return image_embeddings, text_embeddings

    def compute_text_embeddings(self, texts):
        """Embed each text by averaging sliding-window text-tower outputs.

        Every text is tokenised without truncation or special tokens, split
        into windows of 77 tokens with a stride of 50 (the last window is padded),
        and each window's hidden state at position 0 is averaged over windows.
        The text tower runs under ``torch.no_grad()``, so no gradient flows
        through these embeddings. A text with no tokens yields a zero vector.

        Args:
            texts: Iterable of strings.

        Returns:
            torch.Tensor: (num_texts, hidden_size), un-normalised.
        """
        tokenizer = self.clip_processor.tokenizer
        pad_token_id = tokenizer.pad_token_id
        window_size = 77
        stride = 50

        text_embeddings = []
        for text in texts:
            tokens = tokenizer(
                text,
                return_tensors='pt',
                truncation=False,
                add_special_tokens=False
            )['input_ids'].squeeze(0)

            window_embeddings = []
            for start in range(0, tokens.size(0), stride):
                # ``start`` is always < len(tokens), so the window is never empty.
                window_tokens = tokens[start:start + window_size]

                missing = window_size - window_tokens.size(0)
                if missing > 0:
                    padding = torch.full((missing,), pad_token_id)
                    window_tokens = torch.cat([window_tokens, padding])

                attention_mask = (window_tokens != pad_token_id).long()

                window_tokens = window_tokens.unsqueeze(0).to(self.device)
                attention_mask = attention_mask.unsqueeze(0).to(self.device)

                with torch.no_grad():
                    outputs = self.model.text_model(
                        input_ids=window_tokens,
                        attention_mask=attention_mask
                    )
                # NOTE(review): position 0 is used as the window summary; CLIP's own
                # pooled output is taken elsewhere in the sequence — confirm this
                # matches how the checkpoint was trained before changing it.
                window_embeddings.append(outputs.last_hidden_state[:, 0, :])

            if window_embeddings:
                aggregated_embedding = torch.cat(window_embeddings, dim=0).mean(dim=0)
            else:
                # Composite CLIP configs keep the text width under ``text_config``;
                # fall back to it when no top-level ``hidden_size`` exists.
                hidden_size = getattr(self.model.config, 'hidden_size', None)
                if hidden_size is None:
                    hidden_size = self.model.config.text_config.hidden_size
                aggregated_embedding = torch.zeros(hidden_size).to(self.device)

            text_embeddings.append(aggregated_embedding)

        return torch.stack(text_embeddings)

    def compute_loss(self, image_embeddings, text_embeddings):
        """
        Compute the symmetric contrastive loss for the image and text embeddings.

        Matching pairs share the same batch index; the loss is the mean of the
        image-to-text and text-to-image cross-entropies.
        """
        logits_per_image = image_embeddings @ text_embeddings.t() * self.model.logit_scale.exp()
        logits_per_text = logits_per_image.t()

        batch_size = image_embeddings.size(0)
        labels = torch.arange(batch_size).to(image_embeddings.device)

        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)
        return (loss_i + loss_t) / 2

    def _shared_step(self, batch, stage):
        """Compute the contrastive loss for a batch and log it as ``{stage}_loss``."""
        images = batch['images']
        texts = batch['texts']
        image_embeddings, text_embeddings = self(images, texts)
        loss = self.compute_loss(image_embeddings, text_embeddings)
        self.log(f'{stage}_loss', loss, batch_size=len(images))
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: contrastive training loss for one batch."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Lightning hook: contrastive validation loss for one batch."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Lightning hook: contrastive test loss for one batch."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """AdamW over all parameters with ``self.learning_rate``."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def generate_embeddings(self, images=None, texts=None):
        """
        Generate embeddings for the given images or texts.

        The wrapped model is switched to evaluation mode (and left there)
        whenever at least one input is provided.

        Parameters:
        - images: List of images to generate embeddings for.
        - texts: List of texts to generate embeddings for.

        Returns:
        - Dictionary with 'image_embeddings' and 'text_embeddings' as keys, depending on the inputs.
        """
        results = {}

        if images is not None:
            self.model.eval()
            with torch.no_grad():
                image_embeddings = self._embed_images(images)
                results['image_embeddings'] = F.normalize(image_embeddings, p=2, dim=-1)

        if texts is not None:
            self.model.eval()
            with torch.no_grad():
                text_embeddings = self.compute_text_embeddings(texts)
                results['text_embeddings'] = F.normalize(text_embeddings, p=2, dim=-1)

        return results

In [6]:
class CLIPCLSModel(pl.LightningModule):
    """
    Dual-encoder contrastive model that pools both encoders at sequence position 0.

    Text is encoded with `AutoModel` and images with `ViTModel`; in both cases the
    hidden state at position 0 (the CLS token) is L2-normalized and used as the
    embedding. Training uses positive text/image pairs plus `num_negatives`
    explicit negatives per sample.
    """

    def __init__(self, weight_decay, eps, warmup_steps, num_training_steps, text_model_name='johngiorgi/declutr-small',
                image_model_name='google/vit-base-patch16-224', learning_rate=0.00001, num_negatives=5, temperature=0.5):
        """
        Args:
        weight_decay: Weight decay applied to every parameter except biases and LayerNorm weights.
        eps: Epsilon passed to AdamW.
        warmup_steps: Number of warm-up steps for the linear schedule.
        num_training_steps: Total number of steps for the linear schedule.
        text_model_name: Checkpoint name for the text encoder.
        image_model_name: Checkpoint name for the image encoder.
        learning_rate: Learning rate passed to AdamW.
        num_negatives: Number of negatives per sample expected in each batch.
        temperature: Temperature dividing the similarities in the loss.
        """
        super(CLIPCLSModel, self).__init__()
        self.text_model = AutoModel.from_pretrained(text_model_name)
        self.image_model = ViTModel.from_pretrained(image_model_name)
        self.learning_rate = learning_rate
        self.num_negatives = num_negatives
        self.temperature = temperature
        self.weight_decay = weight_decay
        self.eps = eps
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps

        # Store outputs for validation and testing
        self.validation_outputs = []
        self.test_outputs = []

    def _encode_text(self, input_ids, attention_mask):
        """Return the L2-normalized CLS (position 0) hidden state of the text encoder."""
        text_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return F.normalize(text_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)

    def _encode_image(self, pixel_values):
        """Return the L2-normalized CLS (position 0) hidden state of the image encoder."""
        image_outputs = self.image_model(pixel_values=pixel_values)
        return F.normalize(image_outputs.last_hidden_state[:, 0, :], p=2, dim=-1)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Encode a batch of texts and images; returns (text_embeddings, image_embeddings)."""
        text_embeddings = self._encode_text(input_ids, attention_mask)
        image_embeddings = self._encode_image(pixel_values)
        return text_embeddings, image_embeddings

    def compute_loss(self, pos_text_embeddings, pos_image_embeddings, neg_text_embeddings, neg_image_embeddings):
        """
        Contrastive loss over one positive pair and explicit negatives per sample.

        Args:
        pos_text_embeddings: Positive text embeddings, indexed (batch, dim).
        pos_image_embeddings: Positive image embeddings, indexed (batch, dim).
        neg_text_embeddings: Negative text embeddings, indexed (batch, negatives, dim).
        neg_image_embeddings: Negative image embeddings, indexed (batch, negatives, dim).

        Returns:
        Scalar loss: mean over the batch of the text-to-image and image-to-text terms.
        """
        # Normalize embeddings (harmless if the caller already normalized them)
        pos_text_embeddings = F.normalize(pos_text_embeddings, p=2, dim=1)
        pos_image_embeddings = F.normalize(pos_image_embeddings, p=2, dim=1)
        neg_text_embeddings = F.normalize(neg_text_embeddings, p=2, dim=2)  # Normalized over the last dimension
        neg_image_embeddings = F.normalize(neg_image_embeddings, p=2, dim=2)  # Normalized over the last dimension

        # Positive pairs similarity
        pos_sim = torch.exp(torch.sum(pos_text_embeddings * pos_image_embeddings, dim=-1) / self.temperature)

        # Negative pairs similarity (negative texts vs. positive image)
        neg_sim_text_image = torch.exp(torch.einsum('bij,bj->bi', neg_text_embeddings, pos_image_embeddings) / self.temperature)
        # Negative pairs similarity (positive text vs. negative images).
        # The contraction must share the feature index: the previous 'bi,bkj->bk'
        # summed each operand independently and multiplied the two sums, which is
        # not a dot product.
        neg_sim_image_text = torch.exp(torch.einsum('bj,bkj->bk', pos_text_embeddings, neg_image_embeddings) / self.temperature)

        # Calculate the loss for text-to-image
        denominator_text_image = pos_sim + neg_sim_text_image.sum(dim=1)
        loss_text_image = -torch.log(pos_sim / denominator_text_image)

        # Calculate the loss for image-to-text (one pairwise term per negative, summed)
        denominator_image_text = pos_sim.unsqueeze(1) + neg_sim_image_text
        loss_image_text = -torch.log(pos_sim.unsqueeze(1) / denominator_image_text).sum(dim=1)

        # Combine both losses
        return (loss_text_image + loss_image_text).mean()

    def _shared_step(self, batch):
        """Encode positives and flattened negatives of a batch and return the loss."""
        pos_text_embeddings, pos_image_embeddings = self(
            batch['pos_input_ids'], batch['pos_attention_mask'], batch['pos_pixel_values']
        )

        neg_input_ids = batch['neg_input_ids']
        neg_attention_mask = batch['neg_attention_mask']
        neg_pixel_values = batch['neg_pixel_values']

        # Fold the negatives axis into the batch axis so the encoders see a flat batch.
        neg_text_embeddings, neg_image_embeddings = self(
            neg_input_ids.view(-1, neg_input_ids.shape[-1]),
            neg_attention_mask.view(-1, neg_attention_mask.shape[-1]),
            neg_pixel_values.view(-1, *neg_pixel_values.shape[-3:]),
        )

        # Restore the (batch, negatives, dim) layout expected by compute_loss.
        batch_size = neg_input_ids.shape[0]
        neg_text_embeddings = neg_text_embeddings.view(batch_size, self.num_negatives, -1)
        neg_image_embeddings = neg_image_embeddings.view(batch_size, self.num_negatives, -1)

        return self.compute_loss(pos_text_embeddings, pos_image_embeddings, neg_text_embeddings, neg_image_embeddings)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch, keeping it for the epoch average."""
        loss = self._shared_step(batch)
        # Detach so the stored value never keeps an autograd graph alive.
        self.test_outputs.append(loss.detach())
        self.log('test_loss', loss)
        return loss

    def on_test_epoch_end(self):
        """Log the mean of the collected test losses and reset the buffer."""
        if self.test_outputs:
            avg_loss = torch.stack(self.test_outputs).mean()
            self.log('avg_test_loss', avg_loss)
        self.test_outputs.clear()

    def configure_optimizers(self):
        """Build AdamW (no weight decay on biases / LayerNorm weights) with a linear warm-up schedule stepped every batch."""
        no_decay = ['bias', 'LayerNorm.weight']
        decayed, not_decayed = [], []
        for name, param in self.named_parameters():
            target = not_decayed if any(nd in name for nd in no_decay) else decayed
            target.append(param)

        optimizer_grouped_parameters = [
            {'params': decayed, 'weight_decay': self.weight_decay},
            {'params': not_decayed, 'weight_decay': 0.0}
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.eps)

        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]

    def get_embeddings(self, input_ids=None, attention_mask=None, pixel_values=None, embedding_type='text'):
        """
        Generate text or image embeddings for inference.

        Both branches use the same CLS (position 0) pooling as `forward`, so
        the embeddings match the representation optimized during training.
        (The text branch previously pooled at the EOS token, which this model
        never trains on.)

        Args:
        input_ids: Tokenized input text (for text embeddings).
        attention_mask: Attention mask for input text (for text embeddings).
        pixel_values: Preprocessed image (for image embeddings).
        embedding_type: Specify 'text' or 'image' to generate the respective embeddings.

        Returns:
        Normalized embeddings.

        Raises:
        ValueError: If the inputs required for `embedding_type` are missing,
            or `embedding_type` is neither 'text' nor 'image'.
        """
        if embedding_type == 'text':
            if input_ids is None or attention_mask is None:
                raise ValueError("input_ids and attention_mask are required for text embeddings.")
            return self._encode_text(input_ids, attention_mask)

        elif embedding_type == 'image':
            if pixel_values is None:
                raise ValueError("pixel_values are required for image embeddings.")
            return self._encode_image(pixel_values)

        else:
            raise ValueError("Invalid embedding_type. Choose either 'text' or 'image'.")

In [97]:
# Select which architecture the model-initialization cell below instantiates.
args.model_type = "BLIP2"

In [98]:
# Initialize the model selected by args.model_type.
# A single if/elif/else chain is used so that an unknown model type fails loudly
# instead of silently leaving `model` undefined (or stale from an earlier run).

def _contrastive_hparams():
    """Optimization / loss hyper-parameters shared by every model built from `args`."""
    return dict(
        learning_rate=args.learning_rate,
        num_negatives=args.nb_negatives,
        weight_decay=args.weight_decay,
        eps=args.adam_epsilon,
        warmup_steps=warmup_steps,
        num_training_steps=1000,
        temperature=args.temp,
    )

# Encoder checkpoints for the small and "Big" variants
_SMALL_BACKBONES = dict(
    text_model_name='johngiorgi/declutr-small',
    image_model_name='google/vit-base-patch16-224',
)
_BIG_BACKBONES = dict(
    text_model_name='johngiorgi/declutr-base',
    image_model_name='openai/clip-vit-base-patch32',
)
# Extra arguments taken by the BLIP2-style models
_QFORMER_KWARGS = dict(
    t5_model_name='google/flan-t5-small',
    num_query_tokens=32,
    qformer_hidden_size=768,
    cross_attention_frequency=1,
)

if args.model_type == "CLIP":
    sys.path.append('../architectures/')
    # from CLIPLayer import CLIPModel
    model = CLIPModel(**_SMALL_BACKBONES, **_contrastive_hparams())

elif args.model_type == "CLIPCLS":
    sys.path.append('../architectures/')
    # from CLIPLayer import CLIPModel
    model = CLIPCLSModel(**_SMALL_BACKBONES, **_contrastive_hparams())

elif args.model_type == "BigCLIP":
    model = CLIPModel(**_BIG_BACKBONES, **_contrastive_hparams())

elif args.model_type == "BigCLIPITM":
    model = CLIPITMModel(**_BIG_BACKBONES, **_contrastive_hparams())

elif args.model_type == "CLIPCheckpoint":
    # Import the Hugging Face class under an alias: importing it as `CLIPModel`
    # would shadow the CLIPModel used by the "CLIP"/"BigCLIP" branches above and
    # break them on any later re-run of this cell.
    from transformers import CLIPProcessor, CLIPModel as HFCLIPModel
    clip_model_name = 'openai/clip-vit-base-patch32'
    clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
    model = CLIPCheckpointModel(HFCLIPModel.from_pretrained(clip_model_name))

elif args.model_type == "CLIPITM":
    sys.path.append('../architectures/')
    # from CLIPITMLayer import CLIPITMModel
    model = CLIPITMModel(**_SMALL_BACKBONES, **_contrastive_hparams())

elif args.model_type == "BLIP2":
    sys.path.append('../architectures/')
    # from BLIP2Layer import BLIP2Model
    model = BLIP2Model(**_SMALL_BACKBONES, **_QFORMER_KWARGS, **_contrastive_hparams())

elif args.model_type == "ConditionalBLIP":
    sys.path.append('../architectures/')
    # from BLIP2Layer import BLIP2Model
    model = BLIP2ConditionalModel(**_SMALL_BACKBONES, **_QFORMER_KWARGS, **_contrastive_hparams())

else:
    raise ValueError(f"Unsupported model_type: {args.model_type!r}")

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [99]:
# Load the checkpoint
# map_location="cpu" makes the load independent of the device the checkpoint was
# saved from (e.g. a specific GPU index); the model is moved to `device` below.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(
    "/workspace/persistent/HTClipper/models/grouped-and-masked/multimodal-baselines/pre-training/BLIP2/non-associated/seed:1111/lr-0.0001/NTXENT/0.1/negatives-5/final_model.ckpt",
    map_location="cpu",
)

# Load the state dictionary into the model
model.load_state_dict(checkpoint['state_dict'])

# Set the model to evaluation mode
model.eval()

# Move the model to the desired device
model = model.to(device)

# Helper functions

In [7]:
from tqdm import tqdm

In [8]:
# Function to map images with text for CLIP model
def map_images_with_text_for_clip_model(df, img_dir, filter_by="vendor"):
    """
    Expand each ad row into one row per image file that exists on disk.

    Parameters:
    - df: DataFrame with 'TEXT', 'IMAGES' ('|'-separated file names), 'region'
      and the label column ('VENDOR' or 'ID').
    - img_dir: Root image directory; files are looked up under
      <img_dir>/<region>/image/image/<file name>.
    - filter_by: "vendor" to label rows with 'VENDOR'; "id" (or its alias "ids")
      to label rows with 'ID'.

    Returns:
    - DataFrame with columns 'TEXT', 'IMAGES' (full path), 'VENDOR' (the chosen
      label) and 'region'. The columns are present even when no image is found.

    Raises:
    - ValueError: If `filter_by` is not one of the supported values.
    """
    label_columns = {"vendor": "VENDOR", "id": "ID", "ids": "ID"}
    if filter_by not in label_columns:
        raise ValueError(f"filter_by must be one of {sorted(label_columns)}, got {filter_by!r}")
    label_column = label_columns[filter_by]

    new_rows = []
    for _, row in df.iterrows():
        text = row['TEXT']
        label = row[label_column]
        region = row['region']

        # Create a new entry for each image
        for image in str(row['IMAGES']).split('|'):
            full_image_path = os.path.join(img_dir, region, "image", "image", image)

            # isfile (not exists): an empty segment in IMAGES resolves to the image
            # directory itself, which exists but is not a loadable image.
            if os.path.isfile(full_image_path):
                new_rows.append({
                    'TEXT': text,
                    'IMAGES': full_image_path,  # Store the full image path
                    'VENDOR': label,
                    'region': region
                })

    # Fix the columns so an empty result still has the expected schema
    return pd.DataFrame(new_rows, columns=['TEXT', 'IMAGES', 'VENDOR', 'region'])

In [9]:
# Function to map images with text for BLIP2 model
def map_images_with_text_for_blip2_model(df, img_dir, filter_by="vendor"):
    """
    Build one (text, image path, label, region) row per image file found on disk.

    Parameters:
    - df: DataFrame with 'TEXT', 'IMAGES' ('|'-separated file names), 'region'
      and the label column ('VENDOR' or 'ID').
    - img_dir: Root image directory; files are resolved as
      <img_dir>/<region>/image/image/<file name>.
    - filter_by: "vendor" uses 'VENDOR' as the label; "id" (alias "ids") uses 'ID'.

    Returns:
    - DataFrame with columns 'TEXT', 'IMAGES' (full path), 'VENDOR' (the chosen
      label) and 'region'; the columns exist even for an empty result.

    Raises:
    - ValueError: If `filter_by` is not a supported value (previously this
      surfaced as an unbound-variable error inside the loop).
    """
    if filter_by == "vendor":
        label_column = 'VENDOR'
    elif filter_by in ("id", "ids"):
        label_column = 'ID'
    else:
        raise ValueError(f"filter_by must be 'vendor', 'id' or 'ids', got {filter_by!r}")

    output_columns = ['TEXT', 'IMAGES', 'VENDOR', 'region']
    new_rows = []

    for _, row in df.iterrows():
        region = row['region']
        text = row['TEXT']
        label = row[label_column]

        candidate_paths = (
            os.path.join(img_dir, region, "image", "image", image)
            for image in str(row['IMAGES']).split('|')
        )
        # Keep regular files only: an empty file name resolves to the image
        # directory, which exists but cannot be opened as an image.
        new_rows.extend(
            {'TEXT': text, 'IMAGES': path, 'VENDOR': label, 'region': region}
            for path in candidate_paths
            if os.path.isfile(path)
        )

    return pd.DataFrame(new_rows, columns=output_columns)

In [10]:
# Inference Dataset
class InferenceDataset(Dataset):
    """
    Dataset yielding tokenized text, a preprocessed image and the label for each row.

    Each row of `df` must provide 'TEXT', 'IMAGES' (path to an image file) and
    'VENDOR' (returned as the label).
    """

    def __init__(self, df, text_tokenizer, image_processor, max_length=512):
        """
        Args:
        df: DataFrame with 'TEXT', 'IMAGES' and 'VENDOR' columns; its index is reset.
        text_tokenizer: Callable tokenizer returning 'input_ids' and 'attention_mask'.
        image_processor: Callable image processor returning 'pixel_values'.
        max_length: Length the text is padded / truncated to (default 512).
        """
        self.df = df.reset_index(drop=True)
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask', 'pixel_values' and 'label' for row `idx`."""
        row = self.df.iloc[idx]
        text = str(row['TEXT'])
        image_path = row['IMAGES']
        vendor = row['VENDOR']

        # Tokenize the text; squeeze(0) drops the leading axis added by return_tensors="pt"
        text_inputs = self.text_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
        input_ids = text_inputs['input_ids'].squeeze(0)
        attention_mask = text_inputs['attention_mask'].squeeze(0)

        # Load the image inside a context manager so the file handle is released
        # immediately instead of whenever the image object is garbage-collected.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert('RGB')  # Ensure 3 channels
        pixel_values = self.image_processor(images=rgb_image, return_tensors="pt")['pixel_values'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
            'label': vendor
        }

In [11]:
# Custom Inference Dataset for BLIP2
class BLIP2InferenceDataset(Dataset):
    """
    Per-row inference dataset: tokenized text, preprocessed image and label.

    Rows of `df` must provide 'TEXT', 'IMAGES' (image file path) and 'VENDOR'.
    """

    def __init__(self, df, text_tokenizer, image_processor, max_length=512):
        """
        Args:
        df: DataFrame with 'TEXT', 'IMAGES' and 'VENDOR' columns; its index is reset.
        text_tokenizer: Callable tokenizer returning 'input_ids' and 'attention_mask'.
        image_processor: Callable image processor returning 'pixel_values'.
        max_length: Length the text is padded / truncated to (default 512).
        """
        self.df = df.reset_index(drop=True)
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.df)

    def _load_pixel_values(self, image_path):
        """Open the image as RGB, close its file handle, and return the processed pixel tensor."""
        # The context manager releases the file handle right after conversion.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert('RGB')  # Ensure 3 channels
        processed = self.image_processor(images=rgb_image, return_tensors="pt")
        return processed['pixel_values'].squeeze(0)

    def __getitem__(self, idx):
        """Return a dict with 'input_ids', 'attention_mask', 'pixel_values' and 'label' for row `idx`."""
        row = self.df.iloc[idx]

        # Tokenize the text; squeeze(0) drops the leading axis added by return_tensors="pt"
        text_inputs = self.text_tokenizer(
            str(row['TEXT']),
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        return {
            'input_ids': text_inputs['input_ids'].squeeze(0),
            'attention_mask': text_inputs['attention_mask'].squeeze(0),
            'pixel_values': self._load_pixel_values(row['IMAGES']),
            'label': row['VENDOR']
        }

In [105]:
class BLIP2ConditionalInferenceDataset(Dataset):
    """
    Inference dataset that additionally splits each text into a prompt and a target.

    The full text is tokenized with `text_tokenizer`. The text is also split on
    the first '[SEP]' into a conditional part (before) and a target part (after),
    each tokenized with `t5_tokenizer`. Rows of `df` must provide 'TEXT',
    'IMAGES' (image file path) and 'VENDOR'.
    """

    def __init__(self, df, text_tokenizer, t5_tokenizer, image_processor, max_length=512):
        """
        Args:
        df: DataFrame with 'TEXT', 'IMAGES' and 'VENDOR' columns; its index is reset.
        text_tokenizer: Callable tokenizer applied to the full text.
        t5_tokenizer: Callable tokenizer applied to the conditional and target texts.
        image_processor: Callable image processor returning 'pixel_values'.
        max_length: Length every text is padded / truncated to (default 512).
        """
        self.df = df.reset_index(drop=True)
        self.text_tokenizer = text_tokenizer
        self.t5_tokenizer = t5_tokenizer
        self.image_processor = image_processor
        self.max_length = max_length

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.df)

    def _tokenize(self, tokenizer, text):
        """Tokenize `text` to fixed length and return (input_ids, attention_mask) without the leading axis."""
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """
        Return the tensors for row `idx`.

        Keys: 'input_ids', 'attention_mask' (full text), 'pixel_values',
        'conditional_input_ids', 'conditional_attention_mask',
        'target_input_ids', 'target_attention_mask' and 'label'.
        """
        row = self.df.iloc[idx]
        full_text = str(row['TEXT'])
        image_path = row['IMAGES']
        vendor = row['VENDOR']

        # Split the text on the first [SEP] for text generation
        if '[SEP]' in full_text:
            conditional_text, target_text = full_text.split('[SEP]', 1)
            conditional_text = conditional_text.strip()
            target_text = target_text.strip()
        else:
            # Without a separator there is no prompt: the whole text is the target
            conditional_text = ''
            target_text = full_text.strip()

        # Full text, then prompt, then target (same order as the tokenizer calls always had)
        input_ids, attention_mask = self._tokenize(self.text_tokenizer, full_text)
        conditional_input_ids, conditional_attention_mask = self._tokenize(self.t5_tokenizer, conditional_text)
        target_input_ids, target_attention_mask = self._tokenize(self.t5_tokenizer, target_text)

        # Load the image inside a context manager so the file handle is released
        # immediately instead of whenever the image object is garbage-collected.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert('RGB')  # Ensure 3 channels
        pixel_values = self.image_processor(images=rgb_image, return_tensors="pt")['pixel_values'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'pixel_values': pixel_values,
            'conditional_input_ids': conditional_input_ids,
            'conditional_attention_mask': conditional_attention_mask,
            'target_input_ids': target_input_ids,
            'target_attention_mask': target_attention_mask,
            'label': vendor
        }

In [106]:
# Load the pretrained text tokenizer, image processor and T5 tokenizer;
# per the original note, these are the ones used in the further experiments.
text_tokenizer = AutoTokenizer.from_pretrained('johngiorgi/declutr-small')
image_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

# Alternative checkpoints (declutr-base / clip-vit-base-patch32), currently disabled:
# text_tokenizer = AutoTokenizer.from_pretrained('johngiorgi/declutr-base')
# image_processor = ViTImageProcessor.from_pretrained('openai/clip-vit-base-patch32')

# Extract embeddings from unique text and image ads

In [211]:
from tqdm import tqdm
from PIL import Image

from tqdm import tqdm
from PIL import Image

def process_dataset_for_CLIPModel(region_name, data_dir, image_dir, model, text_tokenizer, image_processor, filter_by="vendor", batch_size=32):
    """
    Extract, split and save text / image embeddings for one region with a CLIP model.

    Reads <data_dir>/<region_name>.csv, expands it to one row per existing image,
    embeds every unique text and every unique image with `model.get_embeddings`,
    splits texts and images independently 80/20 (random_state=1111) and writes the
    arrays as .npy files under the "CLIP" output directory.

    Parameters:
    - region_name: Name of the region; selects the CSV and tags the output files.
    - data_dir: Directory containing the region CSV.
    - image_dir: Root directory of the images.
    - model: Model exposing get_embeddings(..., embedding_type='text'|'image').
    - text_tokenizer: Tokenizer applied to each unique text.
    - image_processor: Image processor applied to each unique image.
    - filter_by: "vendor" to label by vendor; "id" or "ids" to label by ad id.
    - batch_size: Unused; kept for interface compatibility.

    Returns:
    - Tuple (train_text_embeddings, train_image_embeddings, train_text_labels,
      train_image_labels, test_text_embeddings, test_image_embeddings,
      test_text_labels, test_image_labels).

    Notes:
    - Uses the notebook-level `device` for the model inputs.
    - Assumes get_embeddings returns one embedding row per call - TODO confirm.
    """
    # The old assert only accepted "ids", while the row-mapping helper keys on
    # "id", so the id path could never run. Accept both spellings and hand the
    # helper the spelling it understands.
    assert filter_by in ["vendor", "id", "ids"]
    by_vendor = filter_by == "vendor"
    mapper_filter = "vendor" if by_vendor else "id"
    file_suffix = "vendors" if by_vendor else "ids"

    # Load the dataset
    df = pd.read_csv(os.path.join(data_dir, f"{region_name}.csv"))
    df['region'] = region_name
    df_filtered = map_images_with_text_for_clip_model(df, img_dir=image_dir, filter_by=mapper_filter).drop_duplicates()

    # First label seen for each text / image, built once instead of re-filtering
    # the whole frame for every item (which was quadratic).
    first_text_rows = df_filtered.drop_duplicates(subset='TEXT')
    text_to_label = dict(zip(first_text_rows['TEXT'].values, first_text_rows['VENDOR'].values))
    first_image_rows = df_filtered.drop_duplicates(subset='IMAGES')
    image_to_label = dict(zip(first_image_rows['IMAGES'].values, first_image_rows['VENDOR'].values))

    # Get unique text embeddings
    unique_texts = df_filtered['TEXT'].unique()
    text_embeddings = {}
    text_labels = []

    # Get unique images and their embeddings
    unique_images = df_filtered['IMAGES'].unique()
    image_embeddings = {}
    image_labels = []
    seen_embeddings = set()  # To track unique embeddings

    # Inference only: no autograd graph is needed for any of the forward passes.
    with torch.no_grad():
        for text in tqdm(unique_texts, desc="Extracting Text Embeddings"):
            inputs = text_tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
            text_embed = model.get_embeddings(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], embedding_type='text')
            text_embeddings[text] = text_embed.detach().cpu().numpy()
            text_labels.append(text_to_label[text])  # Get vendor for the text

        for image_path in tqdm(unique_images, desc="Extracting Image Embeddings"):
            # Load the image; the context manager closes the file handle right away
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")  # Convert to RGB format
            image_tensor = image_processor(images=image, return_tensors="pt")['pixel_values'].to(device)  # Preprocess the image
            image_embed = model.get_embeddings(pixel_values=image_tensor, embedding_type='image').detach().cpu().numpy()

            # Convert the embedding to a tuple to make it hashable for the set;
            # different files with an identical embedding are stored only once.
            embedding_tuple = tuple(image_embed.flatten())
            if embedding_tuple not in seen_embeddings:
                seen_embeddings.add(embedding_tuple)
                image_embeddings[image_path] = image_embed
                image_labels.append(image_to_label[image_path])  # Get vendor for the image

    # Train-test split (texts and images are split independently)
    train_text_embeddings, test_text_embeddings, train_text_labels, test_text_labels = train_test_split(
        list(text_embeddings.values()), text_labels, test_size=0.2, random_state=1111
    )
    train_image_embeddings, test_image_embeddings, train_image_labels, test_image_labels = train_test_split(
        list(image_embeddings.values()), image_labels, test_size=0.2, random_state=1111
    )

    output_dir = os.path.join("/workspace/persistent/HTClipper/models/pickled/embeddings/grouped-and-masked/trained_declutr_vit/", "CLIP")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Ensure embeddings are shaped (num_samples, embedding_dim). An explicit reshape
    # keeps the array 2-D even for a one-sample split, which a bare squeeze()
    # would collapse to 1-D.
    train_text_embeddings = np.array(train_text_embeddings).reshape(len(train_text_embeddings), -1)
    test_text_embeddings = np.array(test_text_embeddings).reshape(len(test_text_embeddings), -1)
    train_image_embeddings = np.array(train_image_embeddings).reshape(len(train_image_embeddings), -1)
    test_image_embeddings = np.array(test_image_embeddings).reshape(len(test_image_embeddings), -1)

    arrays_to_save = {
        'train_text_embeddings': train_text_embeddings,
        'train_image_embeddings': train_image_embeddings,
        'train_text_labels': train_text_labels,
        'train_image_labels': train_image_labels,
        'test_text_embeddings': test_text_embeddings,
        'test_image_embeddings': test_image_embeddings,
        'test_text_labels': test_text_labels,
        'test_image_labels': test_image_labels,
    }
    for stem, values in arrays_to_save.items():
        np.save(os.path.join(output_dir, f'{stem}_{region_name}_{file_suffix}.npy'), np.array(values))

    print(f"Processed region: {region_name}")
    print(f"Number of training samples: {len(train_text_labels)}")
    print(f"Number of testing samples: {len(test_text_labels)}\n")

    return train_text_embeddings, train_image_embeddings, train_text_labels, train_image_labels, test_text_embeddings, test_image_embeddings, test_text_labels, test_image_labels

def process_dataset_for_CLIPITMModel(region_name, data_dir, image_dir, model, text_tokenizer, image_processor, filter_by="vendor", batch_size=32):
    """
    Extract, split and save text / image embeddings for one region with a CLIP-ITM model.

    Reads <data_dir>/<region_name>.csv, expands it to one row per existing image,
    embeds every unique text and every unique image with `model.get_embeddings`,
    splits texts and images independently 80/20 (random_state=1111) and writes the
    arrays as .npy files under the "CLIPITM" output directory.

    Parameters:
    - region_name: Name of the region; selects the CSV and tags the output files.
    - data_dir: Directory containing the region CSV.
    - image_dir: Root directory of the images.
    - model: Model exposing get_embeddings(..., embedding_type='text'|'image').
    - text_tokenizer: Tokenizer applied to each unique text.
    - image_processor: Image processor applied to each unique image.
    - filter_by: "vendor" to label by vendor; "id" or "ids" to label by ad id.
    - batch_size: Unused; kept for interface compatibility.

    Returns:
    - Tuple (train_text_embeddings, train_image_embeddings, train_text_labels,
      train_image_labels, test_text_embeddings, test_image_embeddings,
      test_text_labels, test_image_labels).

    Notes:
    - Uses the notebook-level `device` for the model inputs.
    - Assumes get_embeddings returns one embedding row per call - TODO confirm.
    """
    # Validate up front (same check as the CLIP variant) and translate the "ids"
    # spelling into the "id" value the row-mapping helper understands; any other
    # value used to fail deep inside the helper with an unbound variable.
    assert filter_by in ["vendor", "id", "ids"]
    if filter_by == "vendor":
        mapper_filter, file_suffix = "vendor", "vendors"
    else:
        mapper_filter, file_suffix = "id", "ids"

    # Load the dataset
    df = pd.read_csv(os.path.join(data_dir, f"{region_name}.csv"))
    df['region'] = region_name
    df_filtered = map_images_with_text_for_clip_model(df, img_dir=image_dir, filter_by=mapper_filter).drop_duplicates()

    def first_label_by(column):
        """Map each distinct value of `column` to the label of its first row."""
        firsts = df_filtered.drop_duplicates(subset=column)
        return dict(zip(firsts[column].values, firsts['VENDOR'].values))

    # Built once; replaces a per-item scan of the whole frame (quadratic).
    text_to_label = first_label_by('TEXT')
    image_to_label = first_label_by('IMAGES')

    text_embeddings, text_labels = {}, []
    image_embeddings, image_labels = {}, []
    seen_embeddings = set()  # To track unique embeddings

    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        # Extract text embeddings with tqdm progress bar
        for text in tqdm(df_filtered['TEXT'].unique(), desc="Extracting Text Embeddings"):
            inputs = text_tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
            text_embed = model.get_embeddings(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], embedding_type='text')
            text_embeddings[text] = text_embed.detach().cpu().numpy()
            text_labels.append(text_to_label[text])  # Get vendor for the text

        # Extract image embeddings with tqdm progress bar
        for image_path in tqdm(df_filtered['IMAGES'].unique(), desc="Extracting Image Embeddings"):
            # The context manager closes the image file as soon as it is converted
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")  # Convert to RGB format
            image_tensor = image_processor(images=image, return_tensors="pt")['pixel_values'].to(device)  # Preprocess the image
            image_embed = model.get_embeddings(pixel_values=image_tensor, embedding_type='image').detach().cpu().numpy()

            # Hashable fingerprint of the embedding; an image whose embedding was
            # already seen is skipped so duplicates are stored only once.
            embedding_tuple = tuple(image_embed.flatten())
            if embedding_tuple in seen_embeddings:
                continue
            seen_embeddings.add(embedding_tuple)
            image_embeddings[image_path] = image_embed
            image_labels.append(image_to_label[image_path])  # Get vendor for the image

    # Train-test split (texts and images are split independently)
    train_text_embeddings, test_text_embeddings, train_text_labels, test_text_labels = train_test_split(
        list(text_embeddings.values()), text_labels, test_size=0.2, random_state=1111
    )
    train_image_embeddings, test_image_embeddings, train_image_labels, test_image_labels = train_test_split(
        list(image_embeddings.values()), image_labels, test_size=0.2, random_state=1111
    )

    output_dir = os.path.join("/workspace/persistent/HTClipper/models/pickled/embeddings/grouped-and-masked/trained_declutr_vit/", "CLIPITM")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Shape every split as (num_samples, embedding_dim). Reshaping explicitly keeps
    # a one-sample split 2-D, where a bare squeeze() would drop the sample axis.
    train_text_embeddings = np.array(train_text_embeddings).reshape(len(train_text_embeddings), -1)
    test_text_embeddings = np.array(test_text_embeddings).reshape(len(test_text_embeddings), -1)
    train_image_embeddings = np.array(train_image_embeddings).reshape(len(train_image_embeddings), -1)
    test_image_embeddings = np.array(test_image_embeddings).reshape(len(test_image_embeddings), -1)

    outputs = [
        ('train_text_embeddings', train_text_embeddings),
        ('train_image_embeddings', train_image_embeddings),
        ('train_text_labels', np.array(train_text_labels)),
        ('train_image_labels', np.array(train_image_labels)),
        ('test_text_embeddings', test_text_embeddings),
        ('test_image_embeddings', test_image_embeddings),
        ('test_text_labels', np.array(test_text_labels)),
        ('test_image_labels', np.array(test_image_labels)),
    ]
    for stem, array in outputs:
        np.save(os.path.join(output_dir, f'{stem}_{region_name}_{file_suffix}.npy'), array)

    print(f"Processed region: {region_name}")
    print(f"Number of training samples: {len(train_text_labels)}")
    print(f"Number of testing samples: {len(test_text_labels)}\n")

    return train_text_embeddings, train_image_embeddings, train_text_labels, train_image_labels, test_text_embeddings, test_image_embeddings, test_text_labels, test_image_labels

def process_dataset_for_BLIP2Model(region_name, data_dir, image_dir, model, text_tokenizer, image_processor, filter_by="vendor", batch_size=32):
    """
    Extract text and image embeddings for one region, split them into
    train/test sets, and save the results as .npy files.

    The region CSV is read from ``data_dir``, passed through
    ``map_images_with_text_for_blip2_model`` and de-duplicated. One embedding is
    computed per unique ``TEXT`` value and per unique ``IMAGES`` path; images
    whose embedding is identical to one already seen are dropped. Text and image
    embeddings are split independently (80/20, ``random_state=1111``), so the
    returned text and image arrays are not row-aligned with each other.

    Args:
        region_name: Name of the region; ``<region_name>.csv`` is read from ``data_dir``.
        data_dir: Directory containing the region CSV files.
        image_dir: Image directory forwarded to ``map_images_with_text_for_blip2_model``.
        model: Model exposing ``get_embeddings(..., embedding_type='text'|'image')``.
        text_tokenizer: Tokenizer called on each raw text string.
        image_processor: Image processor called on each PIL image.
        filter_by: Forwarded to ``map_images_with_text_for_blip2_model``. ``"vendor"``
            saves files with a ``_vendors`` suffix; any other value uses ``_ids``.
        batch_size: Currently unused (samples are embedded one at a time); kept
            for interface compatibility.

    Returns:
        Tuple ``(train_text_embeddings, train_image_embeddings, train_text_labels,
        train_image_labels, test_text_embeddings, test_image_embeddings,
        test_text_labels, test_image_labels)``. Embeddings are numpy arrays,
        labels are lists.

    Notes:
        - Tensors are moved to the module-level ``device`` defined elsewhere in the notebook.
        - Labels are always read from the ``VENDOR`` column, whatever ``filter_by`` is.
        - The output directory is hard-coded below.
    """
    # Load the dataset
    df = pd.read_csv(os.path.join(data_dir, f"{region_name}.csv"))
    df['region'] = region_name
    df_filtered = map_images_with_text_for_blip2_model(df, img_dir=image_dir, filter_by=filter_by).drop_duplicates()

    # Build the "first VENDOR seen for this key" lookups once, instead of
    # re-scanning the whole frame with a boolean mask for every unique item.
    first_text_rows = df_filtered.drop_duplicates(subset='TEXT')
    vendor_by_text = dict(zip(first_text_rows['TEXT'].values, first_text_rows['VENDOR'].values))
    first_image_rows = df_filtered.drop_duplicates(subset='IMAGES')
    vendor_by_image = dict(zip(first_image_rows['IMAGES'].values, first_image_rows['VENDOR'].values))

    unique_texts = df_filtered['TEXT'].unique()
    unique_images = df_filtered['IMAGES'].unique()

    text_embeddings = {}
    text_labels = []
    image_embeddings = {}
    image_labels = []
    seen_embeddings = set()  # Embeddings already stored, used to drop duplicate images

    # Pure inference: every result is detached right away, so skip autograd bookkeeping.
    with torch.no_grad():
        # Extract text embeddings with tqdm progress bar
        for text in tqdm(unique_texts, desc="Extracting Text Embeddings"):
            inputs = text_tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
            text_embed = model.get_embeddings(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], embedding_type='text')
            text_embeddings[text] = text_embed.detach().cpu().numpy()
            text_labels.append(vendor_by_text[text])

        # Extract image embeddings with tqdm progress bar
        for image_path in tqdm(unique_images, desc="Extracting Image Embeddings"):
            # convert() loads the pixel data, so the file can be closed right after
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = image_processor(images=image, return_tensors="pt")['pixel_values'].to(device)
            image_embed = model.get_embeddings(pixel_values=image_tensor, embedding_type='image')
            image_array = image_embed.detach().cpu().numpy()

            # A tuple of the flattened values is hashable, so it can live in the set
            embedding_tuple = tuple(image_array.flatten())
            if embedding_tuple not in seen_embeddings:
                seen_embeddings.add(embedding_tuple)
                image_embeddings[image_path] = image_array
                image_labels.append(vendor_by_image[image_path])

    # Train-test split (text and image sets are split independently)
    train_text_embeddings, test_text_embeddings, train_text_labels, test_text_labels = train_test_split(
        list(text_embeddings.values()), text_labels, test_size=0.2, random_state=1111
    )
    train_image_embeddings, test_image_embeddings, train_image_labels, test_image_labels = train_test_split(
        list(image_embeddings.values()), image_labels, test_size=0.2, random_state=1111
    )

    output_dir = os.path.join("/workspace/persistent/HTClipper/models/pickled/embeddings/grouped-and-masked/trained_declutr_vit/", "BLIP2")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Ensure embeddings are in the desired shape (n_samples, embedding_dim).
    # atleast_2d keeps a split holding a single sample from collapsing to 1-D.
    train_text_embeddings = np.atleast_2d(np.array(train_text_embeddings).squeeze())
    test_text_embeddings = np.atleast_2d(np.array(test_text_embeddings).squeeze())
    train_image_embeddings = np.atleast_2d(np.array(train_image_embeddings).squeeze())
    test_image_embeddings = np.atleast_2d(np.array(test_image_embeddings).squeeze())

    # The "vendor" and non-"vendor" outputs differ only in the filename suffix.
    suffix = "vendors" if filter_by == "vendor" else "ids"
    arrays_to_save = {
        'train_text_embeddings': train_text_embeddings,
        'train_image_embeddings': train_image_embeddings,
        'train_text_labels': np.array(train_text_labels),
        'train_image_labels': np.array(train_image_labels),
        'test_text_embeddings': test_text_embeddings,
        'test_image_embeddings': test_image_embeddings,
        'test_text_labels': np.array(test_text_labels),
        'test_image_labels': np.array(test_image_labels),
    }
    for file_prefix, values in arrays_to_save.items():
        np.save(os.path.join(output_dir, f'{file_prefix}_{region_name}_{suffix}.npy'), values)

    print(f"Processed region: {region_name}")
    print(f"Number of training samples: {len(train_text_labels)}")
    print(f"Number of testing samples: {len(test_text_labels)}\n")

    return train_text_embeddings, train_image_embeddings, train_text_labels, train_image_labels, test_text_embeddings, test_image_embeddings, test_text_labels, test_image_labels

In [None]:
# List of regions to process
regions = ['south', 'midwest', 'west', 'northeast']

# Process each region with filter_by="vendor". The function saves its results to
# .npy files, so the returned arrays are deliberately not bound here; unpacking
# them into `_` would also rebind IPython's last-output variable.
for region in regions:
    print("-"*50 + region + "-"*50)
    process_dataset_for_BLIP2Model(
        region_name=region,
        data_dir=args.data_dir,
        image_dir=args.image_dir,
        model=model,
        text_tokenizer=text_tokenizer,
        image_processor=image_processor,
        filter_by="vendor",
        batch_size=32
    )

--------------------------------------------------south--------------------------------------------------


Extracting Text Embeddings: 100%|██████████| 13677/13677 [02:05<00:00, 109.31it/s]
Extracting Image Embeddings:  16%|█▌        | 10529/65544 [09:36<52:20, 17.52it/s]  

In [None]:
# List of regions to process
regions = ['south', 'midwest', 'west', 'northeast']

# Process each region with filter_by="ids". The function saves its results to
# .npy files, so the returned arrays are deliberately not bound here; unpacking
# them into `_` would also rebind IPython's last-output variable.
for region in regions:
    print("-"*50 + region + "-"*50)
    process_dataset_for_BLIP2Model(
        region_name=region,
        data_dir=args.data_dir,
        image_dir=args.image_dir,
        model=model,
        text_tokenizer=text_tokenizer,
        image_processor=image_processor,
        filter_by="ids",
        batch_size=32
    )