In [1]:
from datasets import load_dataset, load_from_disk, Dataset
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer,AutoImageProcessor, ResNetModel
from tqdm import tqdm

In [2]:
# Pretrained image encoder (ResNet-50) together with its matching image preprocessor.
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
# ResNetModel is the headless backbone: it has no `fc` classification layer to
# strip, and its `pooler_output` is what the rest of the notebook consumes.
# (The torchvision-style `resnet.fc = nn.Identity()` only attached an unused
# module here, so it has been dropped.)
resnet = ResNetModel.from_pretrained("microsoft/resnet-50")

# Pretrained text encoder (MiniLM) together with its matching tokenizer.
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
minilm = AutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")




In [3]:
# Load the dataset saved at 'image_text_dataset'; it is wrapped as `train_dataset` below.
img_txt_ds = load_from_disk('image_text_dataset')

In [20]:
# Load the dataset saved at 'image_text_test'; it is wrapped as `test_dataset` and used for evaluation.
img_txt_test = load_from_disk('image_text_test')

In [4]:
def preprocess(example):
    """Turn one raw record into model-ready tensors.

    Reads the "image", "caption" and "label" fields of ``example`` and relies
    on the notebook-level ``processor`` and ``tokenizer`` objects.

    Returns:
        dict with keys "image", "input_ids", "attention_mask" and "label";
        the leading size-1 batch dimension added by the processor/tokenizer
        is squeezed away, and the label is a float tensor.
    """
    rgb_image = example["image"].convert("RGB")
    pixel_values = processor(rgb_image, return_tensors="pt")["pixel_values"]

    # Captions are padded/truncated to a fixed length of 128 tokens.
    encoding = tokenizer(
        example["caption"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    return {
        "image": pixel_values.squeeze(0),
        "input_ids": encoding["input_ids"].squeeze(0),
        "attention_mask": encoding["attention_mask"].squeeze(0),
        "label": torch.tensor(example["label"]).float(),
    }


In [5]:
class CustomImageTextDataset(Dataset):
    """Map-style dataset that preprocesses image/caption/label records lazily.

    Each record of the wrapped ``dataset`` must expose "image", "caption" and
    "label" entries. Items are converted on access using the supplied
    ``processor`` (images) and ``tokenizer`` (captions).
    """

    def __init__(self, dataset, tokenizer, processor, max_length=128):
        # max_length is the fixed token length captions are padded/truncated to.
        self.dataset, self.max_length = dataset, max_length
        self.tokenizer, self.processor = tokenizer, processor

    def __len__(self):
        return len(self.dataset)

    def _encode_image(self, image):
        """Convert to RGB, run the image processor and drop the batch dim."""
        pixel_values = self.processor(image.convert("RGB"), return_tensors="pt")["pixel_values"]
        return pixel_values.squeeze(0)

    def _encode_text(self, caption):
        """Tokenize a caption to exactly ``max_length`` tokens (pad/truncate)."""
        return self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return a dict with "image", "input_ids", "attention_mask", "label"."""
        record = self.dataset[idx]

        pixels = self._encode_image(record["image"])
        encoding = self._encode_text(record["caption"])
        target = torch.tensor(record["label"]).float()

        return {
            "image": pixels,
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": target,
        }


In [6]:
# Wrap the training records so items are preprocessed on access (default max_length=128).
train_dataset = CustomImageTextDataset(img_txt_ds, tokenizer, processor)

In [21]:
# Wrap the test records with the same tokenizer/processor as the training set.
test_dataset = CustomImageTextDataset(img_txt_test, tokenizer, processor)

In [7]:
# Sanity check: display the first preprocessed training example (a dict of tensors).
train_dataset[0]

{'image': tensor([[[ 1.4269,  1.3755,  1.3413,  ...,  1.4612,  1.4783,  1.4954],
          [ 1.3927,  1.3413,  1.3070,  ...,  1.5125,  1.5125,  1.4783],
          [ 1.3755,  1.3242,  1.3070,  ...,  1.5125,  1.5297,  1.5297],
          ...,
          [-0.7479, -0.6623, -0.5253,  ..., -0.6452, -0.5767, -0.6281],
          [-0.6281, -0.5424, -0.5424,  ..., -0.8164, -0.7137, -0.6794],
          [-0.7308, -0.6452, -0.5767,  ..., -1.0562, -1.0390, -0.9534]],
 
         [[ 1.5707,  1.5182,  1.4832,  ...,  1.6232,  1.6408,  1.6057],
          [ 1.5357,  1.4832,  1.4482,  ...,  1.6758,  1.6933,  1.6408],
          [ 1.5182,  1.4657,  1.4482,  ...,  1.6933,  1.7108,  1.6758],
          ...,
          [ 1.0455,  1.0805,  1.1681,  ...,  0.8880,  0.9230,  1.0105],
          [ 1.1506,  1.1681,  1.1331,  ...,  0.7479,  0.8004,  0.9580],
          [ 1.0455,  1.0980,  1.1155,  ...,  0.8704,  0.8529,  0.9230]],
 
         [[ 1.7163,  1.6640,  1.6291,  ...,  1.8557,  1.8557,  1.8034],
          [ 1.6814,

In [7]:
# Shuffled mini-batches of 32 for training.
# NOTE(review): the literal 32 duplicates `batch_size`, which is only defined in a later cell.
train_dataloader = DataLoader(train_dataset, batch_size=32,shuffle = True)

In [23]:
# Evaluation batches of 32. Shuffling is disabled: the reported loss/accuracy
# do not depend on sample order, and a fixed order keeps evaluation reproducible.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [8]:
# Training hyperparameters.
batch_size, epochs = 32, 10
learning_rate = 1e-4

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [9]:
# Move both pretrained encoders to the selected device. Module.to() returns the
# module itself, so the last line's value (the ResNet repr) is what the cell displays.
minilm.to(device)
resnet.to(device)

ResNetModel(
  (embedder): ResNetEmbeddings(
    (embedder): ResNetConvLayer(
      (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (encoder): ResNetEncoder(
    (stages): ModuleList(
      (0): ResNetStage(
        (layers): Sequential(
          (0): ResNetBottleNeckLayer(
            (shortcut): ResNetShortCut(
              (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (layer): Sequential(
              (0): ResNetConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalizatio

In [10]:
class ImageTextClassifier(nn.Module):
    """Binary classification head over concatenated image and text features.

    Args:
        image_dim: width of the image feature vector.
        text_dim: width of the text feature vector.
        hidden_dim: width of the single hidden layer (default 256).

    ``forward`` returns a tensor of shape (batch, 1) holding sigmoid
    probabilities, so it pairs with ``nn.BCELoss``.
    """

    def __init__(self, image_dim, text_dim, hidden_dim=256):
        super().__init__()
        fused_dim = image_dim + text_dim
        # Layer order (and therefore the fc.<index> state-dict keys) is fixed:
        # Linear -> ReLU -> Dropout(0.3) -> Linear -> Sigmoid.
        head_layers = [
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, image_features, text_features):
        # Concatenate along the feature axis, then score with the MLP head.
        fused = torch.cat((image_features, text_features), dim=1)
        return self.fc(fused)

In [11]:
# Build the classification head and move it to the same device as the encoders.
classifier = ImageTextClassifier(image_dim=2048, text_dim=384)  # ResNet output: 2048, MiniLM output: 384
classifier.to(device)

ImageTextClassifier(
  (fc): Sequential(
    (0): Linear(in_features=2432, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
    (4): Sigmoid()
  )
)

In [12]:
# Binary cross-entropy on probabilities (the classifier already ends in a Sigmoid).
criterion = nn.BCELoss()

# A single Adam optimizer updates both encoders and the head end to end,
# all with the same learning rate.
optimizer = torch.optim.Adam(
    [
        param
        for module in (resnet, minilm, classifier)
        for param in module.parameters()
    ],
    lr=learning_rate,
)

In [13]:
# Put all three modules in training mode before the training loop runs.
# Module.train() returns the module, so the cell displays the classifier repr.
resnet.train()
minilm.train()
classifier.train()

ImageTextClassifier(
  (fc): Sequential(
    (0): Linear(in_features=2432, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
    (4): Sigmoid()
  )
)

In [15]:
# End-to-end fine-tuning: both encoders and the classification head are updated
# by the shared optimizer; the mean per-batch loss is printed after every epoch.
for epoch in range(epochs):
    # Re-assert training mode each epoch so the cell also behaves correctly when
    # it is re-run after an evaluation cell has switched the modules to eval().
    resnet.train()
    minilm.train()
    classifier.train()

    total_loss = 0
    for batch in tqdm(train_dataloader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        text_inputs = batch["input_ids"].to(device)  # This is the tokenized text
        attention_masks = batch["attention_mask"].to(device)

        # Extract features from ResNet (image features); flatten the pooled
        # map to one feature vector per sample.
        image_features = resnet(images)
        image_features = image_features.pooler_output.flatten(start_dim=1)

        # Extract features from MiniLM (text features)
        text_features = minilm(input_ids=text_inputs, attention_mask=attention_masks).pooler_output

        # Forward pass
        outputs = classifier(image_features, text_features)
        # squeeze(-1) removes only the trailing size-1 dimension: (B, 1) -> (B,).
        # A bare squeeze() would also collapse a final batch of size 1 to a 0-d
        # tensor, which BCELoss rejects because it no longer matches labels (1,).
        loss = criterion(outputs.squeeze(-1), labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_dataloader):.4f}")


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [15:43<00:00,  2.74it/s]


Epoch 1/10, Loss: 0.6937


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [16:47<00:00,  2.57it/s]


Epoch 2/10, Loss: 0.6925


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [17:38<00:00,  2.45it/s]


Epoch 3/10, Loss: 0.6765


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [16:19<00:00,  2.64it/s]


Epoch 4/10, Loss: 0.5937


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [14:17<00:00,  3.02it/s]


Epoch 5/10, Loss: 0.4475


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [14:40<00:00,  2.94it/s]


Epoch 6/10, Loss: 0.3145


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [14:28<00:00,  2.98it/s]


Epoch 7/10, Loss: 0.2246


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [14:27<00:00,  2.98it/s]


Epoch 8/10, Loss: 0.1704


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [14:28<00:00,  2.98it/s]


Epoch 9/10, Loss: 0.1397


100%|█████████████████████████████████████████████████████████████████████| 2587/2587 [14:28<00:00,  2.98it/s]

Epoch 10/10, Loss: 0.1162





In [17]:
class ImageTextModel(nn.Module):
    """Bundle the image encoder, text encoder and classifier head in one module.

    The sub-modules are stored by reference under the attribute names
    ``resnet``, ``minilm`` and ``classifier`` (these names prefix the
    state-dict keys).
    """

    def __init__(self, resnet, minilm, classifier):
        super().__init__()
        self.resnet, self.minilm, self.classifier = resnet, minilm, classifier

    def forward(self, images, text_inputs, attention_masks):
        """Encode both modalities and return the classifier's output."""
        # Text branch: pooled output of the language model.
        text_out = self.minilm(input_ids=text_inputs, attention_mask=attention_masks)

        # Image branch: pooled output flattened to one vector per sample.
        image_out = self.resnet(images)
        image_vec = image_out.pooler_output.flatten(start_dim=1)

        return self.classifier(image_vec, text_out.pooler_output)


In [18]:
# Wrap the encoders and head in one module; the wrapper stores the same module objects, not copies.
combined_model = ImageTextModel(resnet, minilm, classifier)

In [19]:
# Persist the weights of all three sub-modules in one state dict (keys prefixed resnet./minilm./classifier.).
torch.save(combined_model.state_dict(), 'baseline_model.pth')


In [24]:
# Switch to eval mode; combined_model holds the same module objects, so this also
# puts resnet, minilm and classifier into eval mode for the evaluation loop.
combined_model.eval()

ImageTextModel(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBottleNeckLayer(
              (shortcut): ResNetShortCut(
                (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64, 64, kernel_s

In [25]:
# Evaluate on the test set: mean per-batch BCE loss and accuracy at a 0.5 threshold.
# Eval mode is set explicitly so the cell does not depend on an earlier cell
# having called .eval() on the shared modules.
resnet.eval()
minilm.eval()
classifier.eval()

with torch.no_grad():
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for batch in tqdm(test_dataloader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        text_inputs = batch["input_ids"].to(device)
        attention_masks = batch["attention_mask"].to(device)

        # Extract features from ResNet (image features)
        image_features = resnet(images)
        image_features = image_features.pooler_output.flatten(start_dim=1)

        # Extract features from MiniLM (text features)
        text_features = minilm(input_ids=text_inputs, attention_mask=attention_masks).pooler_output

        # Forward pass
        outputs = classifier(image_features, text_features)
        # squeeze(-1) keeps the batch dimension: (B, 1) -> (B,). A bare squeeze()
        # would turn a final batch of size 1 into a 0-d tensor that no longer
        # matches labels of shape (1,).
        probabilities = outputs.squeeze(-1)

        # Compute the loss
        loss = criterion(probabilities, labels)
        total_loss += loss.item()

        # Compute accuracy (if classification task)
        predictions = (probabilities > 0.5).float()  # assuming binary classification (0 or 1)
        correct_predictions += (predictions == labels).sum().item()
        total_samples += labels.size(0)

    # Print evaluation results
    avg_loss = total_loss / len(test_dataloader)
    accuracy = correct_predictions / total_samples * 100
    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.2f}%")

100%|███████████████████████████████████████████████████████████████████████| 157/157 [02:52<00:00,  1.10s/it]

Test Loss: 1.8247
Test Accuracy: 49.70%



