In [2]:
from datasets import load_dataset, load_from_disk
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer,AutoImageProcessor, ResNetModel
from tqdm import tqdm
import torch.nn.functional as F

In [22]:
batch_size = 32
epochs = 5

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



In [5]:
img_txt_ds = load_from_disk('image_text_dataset')

In [6]:
processor_resnet = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
resnet = ResNetModel.from_pretrained("microsoft/resnet-50")
resnet.fc = nn.Identity()
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
minilm = AutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
teacher_model = CLIPModel.from_pretrained("fine_tuned_clip").to(device)
clip_processor = CLIPProcessor.from_pretrained("fine_tuned_clip_processor")


In [6]:
len(img_txt_ds)

82783

In [7]:
img_txt_test = load_from_disk('image_text_test')

In [8]:
from torch.utils.data import Dataset,DataLoader

In [9]:
class CustomImageTextDataset(Dataset):
    def __init__(self, dataset, clip_processor, tokenizer_miniLM, resnet_processor, max_length=50):
        """
        Args:
            dataset (list of dict): The dataset, should contain "image", "caption", and "label".
            clip_processor (Processor): The CLIP processor for image-text preprocessing.
            tokenizer_miniLM (Tokenizer): The tokenizer for MiniLM text inputs.
            other_model_processor (Processor): The processor for the other model's image inputs (e.g., ResNet).
            max_length (int): The maximum length for text truncation.
        """
        self.dataset = dataset
        self.clip_processor = clip_processor  
        self.tokenizer_miniLM = tokenizer_miniLM  
        self.resnet_processor = resnet_processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]

        clip_processed = self.clip_processor(
            text=example["caption"], 
            images=example["image"], 
            padding="max_length", 
            truncation=True, 
            return_tensors="pt",
            max_length = self.max_length
        )

        miniLM_text_inputs = self.tokenizer_miniLM(
            example["caption"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        label = torch.tensor(example["label"]).float()

        image = example["image"]
        image = image.convert("RGB")
        resnet_image = self.resnet_processor(image, return_tensors="pt")["pixel_values"].squeeze(0)


        return {
            "clip_image": clip_processed["pixel_values"].squeeze(0),
            "clip_input_ids": clip_processed["input_ids"].squeeze(0),  
            "clip_attention_mask": clip_processed["attention_mask"].squeeze(0),  
            "miniLM_input_ids": miniLM_text_inputs["input_ids"].squeeze(0),
            "miniLM_attention_mask": miniLM_text_inputs["attention_mask"].squeeze(0),  
            "resnet_image": resnet_image,  
            "label": label
        }


In [10]:
train_dataset = CustomImageTextDataset(img_txt_ds, clip_processor,tokenizer,processor_resnet)

In [11]:
test_dataset = CustomImageTextDataset(img_txt_test, clip_processor,tokenizer,processor_resnet)

In [12]:
train_dataloader = DataLoader(train_dataset, batch_size=32,shuffle = True) 

In [13]:
from torch.utils.data import Subset
N = 100
small_dataset = Subset(train_dataset, range(320))
small_dataloader = DataLoader(small_dataset, batch_size=batch_size, shuffle=True)


In [13]:
test_dataloader =  DataLoader(test_dataset, batch_size=32,shuffle = True)

In [14]:
class ImageTextClassifier(nn.Module):
    def __init__(self, image_dim, text_dim, hidden_dim=256):
        super(ImageTextClassifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(image_dim + text_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            #nn.Sigmoid() 
        )

    def forward(self, image_features, text_features):
        combined_features = torch.cat([image_features, text_features], dim=1)
        return self.fc(combined_features)

In [15]:
minilm.to(device)
resnet.to(device)

ResNetModel(
  (embedder): ResNetEmbeddings(
    (embedder): ResNetConvLayer(
      (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
    )
    (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (encoder): ResNetEncoder(
    (stages): ModuleList(
      (0): ResNetStage(
        (layers): Sequential(
          (0): ResNetBottleNeckLayer(
            (shortcut): ResNetShortCut(
              (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (layer): Sequential(
              (0): ResNetConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalizatio

In [16]:
classifier = ImageTextClassifier(image_dim=2048, text_dim=384)  # ResNet output: 2048, MiniLM output: 384
classifier.to(device)

ImageTextClassifier(
  (fc): Sequential(
    (0): Linear(in_features=2432, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [17]:
class ImageTextModel(nn.Module):
    def __init__(self, resnet, minilm, classifier):
        super(ImageTextModel, self).__init__()
        self.resnet = resnet
        self.minilm = minilm
        self.classifier = classifier

    def forward(self, images, text_inputs, attention_masks):
        #ResNet image features
        image_features = self.resnet(images).pooler_output.flatten(start_dim=1)
        
        #MiniLM text features
        text_features = self.minilm(input_ids=text_inputs, attention_mask=attention_masks).pooler_output
        
        # Combine features and classify
        return self.classifier(image_features, text_features)


In [18]:
student_model = ImageTextModel(resnet, minilm, classifier)

In [19]:
student_model.to(device)

ImageTextModel(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBottleNeckLayer(
              (shortcut): ResNetShortCut(
                (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64, 64, kernel_s

In [20]:
teacher_model.eval()

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [21]:
for batch in train_dataloader:
    print(batch.keys())
    break

dict_keys(['clip_image', 'clip_input_ids', 'clip_attention_mask', 'miniLM_input_ids', 'miniLM_attention_mask', 'resnet_image', 'label'])


In [23]:
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-5)
ce_loss = nn.BCEWithLogitsLoss()
for epoch in range(epochs):
    epoch_loss = 0
    total_classification_loss = 0
    total_kl_loss = 0
    correct_predictions = 0
    total_samples = 0
    
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        inputs = {
            "input_ids": batch["clip_input_ids"].to(device),
            "attention_mask": batch["clip_attention_mask"].to(device),
            "pixel_values": batch["clip_image"].to(device),
        }
        res_images = batch["resnet_image"].to(device)
        stud_inputs = batch["miniLM_input_ids"].to(device)  
        stud_attention_masks = batch["miniLM_attention_mask"].to(device)

        labels = batch["label"].to(device).float()

        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits_per_image
            teacher_outputs = torch.diag(teacher_logits)
            #print(teacher_outputs)

        
        image_features = student_model.resnet(res_images)
        image_features = image_features.pooler_output.flatten(start_dim=1)
        text_features = student_model.minilm(input_ids=stud_inputs, attention_mask=stud_attention_masks).pooler_output
        student_outputs = student_model.classifier(image_features, text_features)

        #logit distillation loss
        temperature = 3.0
        teacher_probs = torch.sigmoid(teacher_outputs / temperature)
        student_probs = torch.sigmoid(student_outputs.squeeze() / temperature)

        outputs= torch.sigmoid(student_outputs.squeeze())


        predictions = (outputs > 0.5).float()  
        correct_predictions += (predictions == labels).sum().item()
        total_samples += labels.size(0)
        
        logit_distillation_loss = F.binary_cross_entropy(student_probs, teacher_probs) * (temperature ** 2)


        # classification loss
        classification_loss = ce_loss(student_outputs.squeeze(), labels)

        alpha = 0.5 
        total_loss = alpha * logit_distillation_loss + (1 - alpha) * classification_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        accuracy = correct_predictions / total_samples * 100
        epoch_loss += total_loss.item()
        total_classification_loss += classification_loss.item()
        total_kl_loss += logit_distillation_loss.item()

    print(f"Epoch {epoch + 1}/{epochs}, Epoch Loss: {epoch_loss / len(train_dataloader):.4f},Classification Loss: {total_classification_loss:.4f}, KD Loss: {total_kl_loss:.4f},Accuracy: {accuracy:.2f}%")


Epoch 1/5: 100%|████████████████████████████████████████████████████████████████████| 2587/2587 [20:07<00:00,  2.14it/s]


Epoch 1/5, Epoch Loss: 3.2787,Classification Loss: 1606.9914, KD Loss: 15357.0982,Accuracy: 64.08%


Epoch 2/5: 100%|████████████████████████████████████████████████████████████████████| 2587/2587 [20:20<00:00,  2.12it/s]


Epoch 2/5, Epoch Loss: 3.0278,Classification Loss: 1286.9043, KD Loss: 14379.0292,Accuracy: 76.98%


Epoch 3/5: 100%|████████████████████████████████████████████████████████████████████| 2587/2587 [20:20<00:00,  2.12it/s]


Epoch 3/5, Epoch Loss: 2.9218,Classification Loss: 1142.4317, KD Loss: 13975.0391,Accuracy: 80.80%


Epoch 4/5: 100%|████████████████████████████████████████████████████████████████████| 2587/2587 [20:20<00:00,  2.12it/s]


Epoch 4/5, Epoch Loss: 2.8444,Classification Loss: 1048.3767, KD Loss: 13668.5894,Accuracy: 82.96%


Epoch 5/5: 100%|████████████████████████████████████████████████████████████████████| 2587/2587 [20:20<00:00,  2.12it/s]

Epoch 5/5, Epoch Loss: 2.7478,Classification Loss: 939.8831, KD Loss: 13277.2020,Accuracy: 85.56%





In [23]:
student_model.load_state_dict(torch.load("distillation_model.pth",weights_only= True))


<All keys matched successfully>

In [25]:
torch.save(student_model.state_dict(), 'distillation_model.pth')

In [24]:
student_model.eval()
with torch.no_grad():
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    all_labels = []
    all_predictions = []

    for batch in tqdm(test_dataloader):
        res_images = batch["resnet_image"].to(device)
        stud_inputs = batch["miniLM_input_ids"].to(device)  
        stud_attention_masks = batch["miniLM_attention_mask"].to(device)
        labels = batch["label"].to(device).float()
        
        image_features = student_model.resnet(res_images)
        image_features = image_features.pooler_output.flatten(start_dim=1)

        text_features = student_model.minilm(input_ids=stud_inputs, attention_mask=stud_attention_masks).pooler_output

        outputs = student_model.classifier(image_features, text_features)
        outputs= torch.sigmoid(outputs.squeeze())

        predictions = (outputs > 0.5).float()  
        correct_predictions += (predictions == labels).sum().item()
        total_samples += labels.size(0)
        
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predictions.cpu().numpy())
        
    accuracy = correct_predictions / total_samples * 100
    precision = precision_score(all_labels, all_predictions)
    recall = recall_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions)

    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Precision: {precision *100:.2f}%")
    print(f"Recall: {recall *100:.2f}%")
    print(f"F1 Score: {f1 * 100:.2f}%")

100%|█████████████████████████████████████████████████████████████████████████████████| 157/157 [00:53<00:00,  2.92it/s]

Accuracy: 85.56%
Precision: 80.26%
Recall: 94.32%
F1 Score: 86.72%



