In [2]:
from datasets import load_dataset, load_from_disk, Dataset
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer,AutoImageProcessor, ResNetModel
from tqdm import tqdm

In [3]:
# Image branch: the ResNet-50 backbone and its matching image preprocessor.
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
resnet = ResNetModel.from_pretrained("microsoft/resnet-50")
# The former `resnet.fc = nn.Identity()` line was removed. Replacing `fc` is the
# torchvision way of stripping the classification head; here it only registered
# an unused, parameter-free submodule on the Hugging Face ResNetModel. The image
# features used by this notebook are read from `pooler_output` instead.

# Text branch: the MiniLM encoder and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
minilm = AutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")


In [4]:
# Image/caption/label dataset that is wrapped as the training set further down.
img_txt_ds = load_from_disk("image_text_dataset")

In [5]:
# Image/caption/label dataset that is wrapped as the test set further down.
img_txt_test = load_from_disk("image_text_test")

In [6]:
def preprocess(example):
    '''
    Convert one raw image/caption/label record into tensors for the model.

    Uses the module-level `processor` (image preprocessing) and `tokenizer`
    (caption tokenization).

    Args:
        example (dict-like): A single record with the keys:
                        - "image": an object exposing `.convert("RGB")` (e.g. a PIL Image).
                        - "caption": the caption text.
                        - "label": a numeric label (e.g. 0 or 1).
    Returns:
        dict: The preprocessed record:
              - "image": the processor's "pixel_values" after `squeeze(0)`.
              - "input_ids": the tokenizer's "input_ids" after `squeeze(0)`.
              - "attention_mask": the tokenizer's "attention_mask" after `squeeze(0)`.
              - "label": the label as a float tensor.
    '''
    # Force three channels before handing the image to the processor.
    rgb_image = example["image"].convert("RGB")
    pixel_values = processor(rgb_image, return_tensors="pt")["pixel_values"]

    # Captions are padded/truncated to a fixed length of 128 tokens.
    encoded_caption = tokenizer(
        example["caption"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    label_tensor = torch.tensor(example["label"]).float()

    # squeeze(0) drops the size-1 leading dimension added by return_tensors="pt".
    return {
        "image": pixel_values.squeeze(0),
        "input_ids": encoded_caption["input_ids"].squeeze(0),
        "attention_mask": encoded_caption["attention_mask"].squeeze(0),
        "label": label_tensor,
    }


In [7]:
class CustomImageTextDataset(Dataset):
    """Map-style dataset that preprocesses image/caption/label records lazily.

    Each item is preprocessed on access: the image is converted to RGB and run
    through `processor`, the caption is tokenized to a fixed `max_length`, and
    the label is turned into a float tensor.

    Args:
        dataset: Indexable collection of records with "image", "caption" and
            "label" keys.
        tokenizer: Callable used to tokenize the caption.
        processor: Callable used to preprocess the image.
        max_length: Fixed token length captions are padded/truncated to.
    """

    def __init__(self, dataset, tokenizer, processor, max_length=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the preprocessed record at `idx` as a dict of tensors."""
        record = self.dataset[idx]

        # Image: force three channels, then apply the image processor.
        rgb_image = record["image"].convert("RGB")
        pixel_values = self.processor(rgb_image, return_tensors="pt")["pixel_values"]

        # Caption: pad/truncate to a fixed length so items can be batched.
        encoded_caption = self.tokenizer(
            record["caption"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Label: stored as float, matching what the BCE-style loss expects.
        label_tensor = torch.tensor(record["label"]).float()

        # squeeze(0) drops the size-1 leading dimension added by return_tensors="pt".
        return {
            "image": pixel_values.squeeze(0),
            "input_ids": encoded_caption["input_ids"].squeeze(0),
            "attention_mask": encoded_caption["attention_mask"].squeeze(0),
            "label": label_tensor,
        }


In [8]:
# Wrap the training records so they are preprocessed lazily, one item at a time.
train_dataset = CustomImageTextDataset(
    dataset=img_txt_ds,
    tokenizer=tokenizer,
    processor=processor,
)

In [9]:
# Wrap the test records with the same tokenizer and image processor as training.
test_dataset = CustomImageTextDataset(
    dataset=img_txt_test,
    tokenizer=tokenizer,
    processor=processor,
)

In [9]:
# Sanity check: preprocess the first training record and display the result.
train_dataset[0]

{'image': tensor([[[-1.7240, -2.0152, -2.0152,  ..., -1.4843, -1.1589, -1.5014],
          [-1.9124, -1.9638, -1.8782,  ..., -0.7479, -0.6109, -0.2513],
          [-1.8953, -2.0323, -1.9295,  ...,  0.0741,  0.4337,  0.2282],
          ...,
          [-1.7754, -1.7412, -1.6898,  ..., -1.0390, -1.6384, -1.2103],
          [-1.5185, -1.5870, -1.6898,  ..., -1.3302, -1.5528, -1.7069],
          [-1.2959, -1.6042, -1.4843,  ..., -1.2274, -1.1247, -1.2959]],
 
         [[-1.5980, -1.9657, -1.9657,  ..., -1.3354, -0.9503, -1.3004],
          [-1.8256, -1.9132, -1.8606,  ..., -0.5301, -0.3725, -0.0049],
          [-1.7906, -1.9657, -1.8782,  ...,  0.3627,  0.7129,  0.4853],
          ...,
          [-1.7556, -1.8431, -1.8081,  ..., -0.9678, -1.5630, -1.1429],
          [-1.5980, -1.6681, -1.7731,  ..., -1.3004, -1.4930, -1.6856],
          [-1.4755, -1.6681, -1.5455,  ..., -1.2479, -1.1429, -1.3179]],
 
         [[-1.5256, -1.6999, -1.6476,  ..., -0.8633, -0.4450, -0.7761],
          [-1.6302,

In [10]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

In [11]:
# Evaluation loader: shuffling is disabled so batches come in a fixed,
# reproducible order. The metrics are aggregated over the whole set, so sample
# order does not affect them.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
)

In [12]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Training hyperparameters.
# NOTE(review): the DataLoader cells pass 32 literally rather than reading
# `batch_size`, so changing it here has no effect on them.
batch_size = 32
epochs = 10
learning_rate = 1e-4


In [13]:
# Move both pretrained encoders to the selected device. The value of the last
# expression (the ResNet module) is what the cell displays.
minilm.to(device)
resnet.to(device)

ResNetModel(
  (embedder): ResNetEmbeddings(
    (embedder): ResNetConvLayer(
      (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (encoder): ResNetEncoder(
    (stages): ModuleList(
      (0): ResNetStage(
        (layers): Sequential(
          (0): ResNetBottleNeckLayer(
            (shortcut): ResNetShortCut(
              (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (layer): Sequential(
              (0): ResNetConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalizatio

In [14]:
class ImageTextClassifier(nn.Module):
    """Fusion head that classifies concatenated image and text features.

    The head is a two-layer MLP (Linear -> ReLU -> Dropout(0.3) -> Linear) that
    emits a single raw logit per sample; no sigmoid is applied here.

    Args:
        image_dim: Size of the image feature vector.
        text_dim: Size of the text feature vector.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, image_dim, text_dim, hidden_dim=256):
        super().__init__()
        fused_dim = image_dim + text_dim
        head_layers = [
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, image_features, text_features):
        """Concatenate both feature sets along dim 1 and return the logits."""
        fused_features = torch.cat((image_features, text_features), dim=1)
        return self.fc(fused_features)

In [15]:
# Build the fusion head and move it to the same device as the encoders.
classifier = ImageTextClassifier(image_dim=2048, text_dim=384)  # ResNet output: 2048, MiniLM output: 384
classifier.to(device)

ImageTextClassifier(
  (fc): Sequential(
    (0): Linear(in_features=2432, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [16]:
# Binary cross-entropy on raw logits (the classifier head applies no sigmoid).
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer updates both encoders and the head at the same learning rate.
optimizer = torch.optim.Adam(
    [*resnet.parameters(), *minilm.parameters(), *classifier.parameters()],
    lr=learning_rate,
)

In [44]:
# Put both encoders and the head in training mode before running the training
# loop. The value of the last expression (the classifier) is what the cell displays.
resnet.train()
minilm.train()
classifier.train()

ImageTextClassifier(
  (fc): Sequential(
    (0): Linear(in_features=2432, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [47]:
# Fine-tune both encoders and the fusion head end to end.
for epoch in range(epochs):
    # Select training mode here rather than relying on a separate cell having
    # been run: if this loop is re-run after evaluation, the modules would
    # otherwise still be in eval mode (dropout off, frozen batch-norm stats).
    resnet.train()
    minilm.train()
    classifier.train()

    total_loss = 0.0
    for batch in tqdm(train_dataloader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        text_inputs = batch["input_ids"].to(device)  # This is the tokenized text
        attention_masks = batch["attention_mask"].to(device)

        # Extract features from ResNet (image features); flatten the pooled
        # output to one feature vector per sample.
        image_features = resnet(images)
        image_features = image_features.pooler_output.flatten(start_dim=1)

        # Extract features from MiniLM (text features)
        text_features = minilm(input_ids=text_inputs, attention_mask=attention_masks).pooler_output

        # Forward pass: the head returns one logit per sample with a trailing
        # dimension of size 1, removed so the shape matches `labels`.
        outputs = classifier(image_features, text_features)
        loss = criterion(outputs.squeeze(1), labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_dataloader):.4f}")


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [15:00<00:00,  2.87it/s]


Epoch 1/10, Loss: 0.6937


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [16:29<00:00,  2.61it/s]


Epoch 2/10, Loss: 0.6925


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [16:35<00:00,  2.60it/s]


Epoch 3/10, Loss: 0.6797


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [16:36<00:00,  2.60it/s]


Epoch 4/10, Loss: 0.6079


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [16:35<00:00,  2.60it/s]


Epoch 5/10, Loss: 0.4707


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [16:36<00:00,  2.60it/s]


Epoch 6/10, Loss: 0.3340


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [16:36<00:00,  2.60it/s]


Epoch 7/10, Loss: 0.2425


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [15:08<00:00,  2.85it/s]


Epoch 8/10, Loss: 0.1848


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [14:00<00:00,  3.08it/s]


Epoch 9/10, Loss: 0.1460


100%|███████████████████████████████████████████████████████████████████████████████| 2587/2587 [14:02<00:00,  3.07it/s]

Epoch 10/10, Loss: 0.1228





In [78]:
class ImageTextModel(nn.Module):
    """Bundle the image encoder, text encoder and classification head.

    Wrapping the three parts in one module lets them be saved and restored
    through a single state_dict (keys prefixed `resnet.`, `minilm.` and
    `classifier.`).

    Args:
        resnet: Image encoder whose output exposes `pooler_output`.
        minilm: Text encoder whose output exposes `pooler_output`.
        classifier: Head called as `classifier(image_features, text_features)`.
    """

    def __init__(self, resnet, minilm, classifier):
        super().__init__()
        self.resnet = resnet
        self.minilm = minilm
        self.classifier = classifier

    def forward(self, images, text_inputs, attention_masks):
        """Return the classifier output for a batch of images and tokenized captions."""
        # Image branch: flatten the pooled output to one vector per sample.
        image_output = self.resnet(images)
        image_features = image_output.pooler_output.flatten(start_dim=1)

        # Text branch: use the encoder's pooled output.
        text_output = self.minilm(input_ids=text_inputs, attention_mask=attention_masks)
        text_features = text_output.pooler_output

        # Fuse both feature sets and classify.
        return self.classifier(image_features, text_features)


In [79]:
# Wrap the two encoders and the head in one module; the wrapper holds
# references to the same module objects, not copies.
combined_model = ImageTextModel(
    resnet=resnet,
    minilm=minilm,
    classifier=classifier,
)

In [57]:
# Persist only the weights (state_dict) of the combined model, covering both
# encoders and the classification head.
torch.save(combined_model.state_dict(), 'baseline_model.pth')


In [80]:
# Restore the saved weights into the combined model.
# map_location=device remaps the stored tensors to the current device, so a
# checkpoint written on a GPU can also be loaded on a CPU-only machine.
# weights_only=True restricts unpickling to tensors and plain containers.
combined_model.load_state_dict(
    torch.load("baseline_model.pth", map_location=device, weights_only=True)
)


<All keys matched successfully>

In [81]:
# Switch the combined model to evaluation mode. Because it wraps the same
# resnet, minilm and classifier objects, those modules are switched as well.
combined_model.eval()

ImageTextModel(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBottleNeckLayer(
              (shortcut): ResNetShortCut(
                (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64, 64, kernel_s

In [85]:
# Evaluate the fine-tuned model on the test set.
# Select evaluation mode here so this cell does not depend on another cell
# having been run first (disables dropout, uses stored batch-norm statistics).
resnet.eval()
minilm.eval()
classifier.eval()

with torch.no_grad():
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    all_labels = []
    all_predictions = []

    for batch in tqdm(test_dataloader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        text_inputs = batch["input_ids"].to(device)
        attention_masks = batch["attention_mask"].to(device)

        # Extract features from ResNet (image features)
        image_features = resnet(images)
        image_features = image_features.pooler_output.flatten(start_dim=1)

        # Extract features from MiniLM (text features)
        text_features = minilm(input_ids=text_inputs, attention_mask=attention_masks).pooler_output

        # Forward pass. squeeze(1) removes only the trailing size-1 dimension;
        # a bare squeeze() would also drop the batch dimension when the final
        # batch holds a single sample.
        logits = classifier(image_features, text_features).squeeze(1)

        # `criterion` is BCEWithLogitsLoss, which applies the sigmoid itself,
        # so it must receive the raw logits rather than probabilities.
        loss = criterion(logits, labels)
        total_loss += loss.item()

        # Convert logits to probabilities only for thresholding.
        outputs = torch.sigmoid(logits)
        predictions = (outputs > 0.5).float()  # binary decision at 0.5
        correct_predictions += (predictions == labels).sum().item()
        total_samples += labels.size(0)

        # Store the true labels and predictions for later metric calculation
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predictions.cpu().numpy())

    # Compute the metrics
    accuracy = correct_predictions / total_samples * 100
    precision = precision_score(all_labels, all_predictions)
    recall = recall_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions)

    # Print evaluation results (loss is the mean per-batch loss, as in training)
    print(f"Loss: {total_loss / len(test_dataloader):.4f}")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Precision: {precision *100:.2f}%")
    print(f"Recall: {recall *100:.2f}%")
    print(f"F1 Score: {f1 * 100:.2f}%")

100%|█████████████████████████████████████████████████████████████████████████████████| 157/157 [00:37<00:00,  4.21it/s]

Accuracy: 49.22%
Precision: 49.27%
Recall: 52.88%
F1 Score: 51.01%



