In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AdamW
import numpy as np

# Dataset Class for Custom Image-Caption Pairs
class CustomImageTextDataset(Dataset):
    """Image-caption pairs listed in a tab-separated caption file.

    Each line of ``caption_file`` must hold exactly two tab-separated fields,
    an image name and its caption. The image is looked up as
    ``os.path.join(image_dir, image_name + image_ext)``. Lines that do not
    split into exactly two fields, or whose image file does not exist, are
    skipped silently.

    Args:
        caption_file: Path to the caption file (read as UTF-8).
        image_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded RGB PIL image.
        image_ext: Suffix appended to each image name (default ``".jpg"``).
    """

    def __init__(self, caption_file, image_dir, transform=None, image_ext=".jpg"):
        self.image_dir = image_dir
        self.transform = transform
        self.data = []  # list of (image_path, caption) tuples

        # Read image-caption pairs from file. An explicit encoding keeps the
        # parse independent of the platform's default locale encoding.
        with open(caption_file, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split("\t")
                if len(parts) != 2:
                    continue  # malformed line: not exactly "<name><TAB><caption>"
                image_name, caption = parts
                image_path = os.path.join(image_dir, image_name + image_ext)
                if os.path.exists(image_path):
                    self.data.append((image_path, caption))

    def __len__(self):
        """Return the number of usable image-caption pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, caption)`` pair at position ``idx``.

        The image is loaded as an RGB PIL image and, if a ``transform`` was
        supplied, passed through it before being returned.
        """
        image_path, caption = self.data[idx]
        # Use a context manager so the file handle is released immediately;
        # convert() loads the pixels and returns an independent image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, caption

# Collate Function to Combine Batches
def collate_fn(batch):
    """Transpose a batch of ``(image, caption)`` samples into two lists.

    Images are kept as a plain list (not stacked into one tensor) so the
    CLIP processor can do the final preprocessing itself.
    """
    image_list, caption_list = (list(column) for column in zip(*batch))
    return image_list, caption_list

# Image Transformations
# Only resizing happens here; the dataset therefore yields PIL images.
# ToTensor() must NOT be added: it already scales pixels to [0, 1], and
# CLIPProcessor then rescales by 1/255 a second time ("trying to rescale
# already rescaled images"), so the model would see almost-black inputs.
# Handing PIL images to the processor lets it rescale and normalise once.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
])

# Create Datasets and DataLoaders
# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataset = CustomImageTextDataset(
    caption_file="data/train/radiology/captions.txt",
    image_dir="data/train/radiology/images",
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)

val_dataset = CustomImageTextDataset(
    caption_file="data/validation/radiology/captions.txt",
    image_dir="data/validation/radiology/images",
    transform=transform,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
)

# Load CLIP Model and Processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_checkpoint = "openai/clip-vit-base-patch32"
clip_processor = CLIPProcessor.from_pretrained(clip_checkpoint)
clip_model = CLIPModel.from_pretrained(clip_checkpoint).to(device)

# Set model to training mode (the epoch loop also re-enables it each epoch)
clip_model.train()

# Define the optimizer and loss function.
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to transformers.AdamW's defaults (1e-6, 0.0) so the
# optimisation matches the previous setup; torch's own defaults differ
# (eps=1e-8, weight_decay=1e-2).
optimizer = torch.optim.AdamW(clip_model.parameters(), lr=5e-6, eps=1e-6, weight_decay=0.0)
criterion = nn.CrossEntropyLoss()

# Fine-tuning Loop
num_epochs = 5


def _clip_contrastive_loss(outputs, batch_size):
    """Return CLIP's symmetric contrastive loss for one batch.

    The i-th image and the i-th caption are the positive pair, so the
    cross-entropy target for row i of both similarity matrices is i.

    Args:
        outputs: CLIPModel forward output exposing ``logits_per_image`` and
            ``logits_per_text``.
        batch_size: Number of image-caption pairs in the batch.

    Returns:
        Mean of the image-to-text and text-to-image cross-entropy losses.
    """
    logits_per_image = outputs.logits_per_image  # Image-to-text similarity
    logits_per_text = outputs.logits_per_text    # Text-to-image similarity
    # arange over ints already yields int64 targets; create them on the logits' device.
    ground_truth = torch.arange(batch_size, device=logits_per_image.device)
    loss_i2t = criterion(logits_per_image, ground_truth)
    loss_t2i = criterion(logits_per_text, ground_truth)
    return (loss_i2t + loss_t2i) / 2


for epoch in range(num_epochs):
    print(f"Epoch [{epoch+1}/{num_epochs}]")

    # Training Phase
    clip_model.train()
    total_loss = 0.0
    for images, captions in tqdm(train_loader, desc="Training"):
        # Tokenise captions and preprocess images in one processor call
        inputs = clip_processor(text=captions, images=images, return_tensors="pt", padding=True, truncation=True).to(device)

        outputs = clip_model(**inputs)
        loss = _clip_contrastive_loss(outputs, len(images))
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches
    print(f"Training Loss: {total_loss / max(len(train_loader), 1):.4f}")

    # Validation Phase
    clip_model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for images, captions in tqdm(val_loader, desc="Validation"):
            inputs = clip_processor(text=captions, images=images, return_tensors="pt", padding=True, truncation=True).to(device)
            outputs = clip_model(**inputs)
            total_val_loss += _clip_contrastive_loss(outputs, len(images)).item()

    print(f"Validation Loss: {total_val_loss / max(len(val_loader), 1):.4f}")

# Save the fine-tuned CLIP model
# The model weights and the processor used for text/image preprocessing are
# written to the same directory, so both can be reloaded with
# from_pretrained() from one path.
save_dir = "clip_finetuned"
for artifact in (clip_model, clip_processor):
    artifact.save_pretrained(save_dir)
print("Fine-tuned CLIP model saved.")



Epoch [1/5]


Training:   0%|          | 0/16354 [00:00<?, ?it/s]It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
