Initial tests

In [1]:
import pandas as pd
import transformers
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import torch.nn.functional as F

In [2]:
# will check and use CUDA if available, otherwise MPS, otherwise CPU
# mps is super fast mac thing
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using device: MPS (Apple)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using device: CUDA (GPU) - {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using device: CPU")

Using device: CUDA (GPU) - Tesla T4


In [3]:
# Install dependencies as needed:
# pip install kagglehub[pandas-datasets]
# For downloading entire datasets, install the kaggle API client:
!pip install kaggle

import kagglehub
from kagglehub import KaggleDatasetAdapter

# --------- Run once ---------
# 1. Go to your Kaggle profile: https://www.kaggle.com/<your-username>/account
# 2. Under 'API', click 'Create New API Token' to download `kaggle.json`.
# 3. Upload `kaggle.json` to your Colab environment (e.g., File -> Upload to session storage).
# 4. Run the following commands in a separate cell or uncomment and run them here:
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
# -----------------------------------------------------

import kaggle

# Download the entire dataset to a local directory and unzip it
dataset_name = "ashery/chexpert"
download_path = "./CheXpert-v1.0-small"
kaggle.api.dataset_download_files(dataset_name, path=download_path, unzip=True)

print(f"Dataset '{dataset_name}' downloaded to: {download_path}")
print("Files in the downloaded dataset:")
!ls -F {download_path}

Dataset URL: https://www.kaggle.com/datasets/ashery/chexpert
Dataset 'ashery/chexpert' downloaded to: ./CheXpert-v1.0-small
Files in the downloaded dataset:
train/	train.csv  valid/  valid.csv


In [4]:
# Load in data and then use only 50% for training and validation
train_full = pd.read_csv("/content/CheXpert-v1.0-small/train.csv")
val_full = pd.read_csv("/content/CheXpert-v1.0-small/valid.csv")
train_subset = train_full.sample(frac=0.5, random_state=42)
val_subset = val_full.sample(frac=0.5, random_state=42)

In [8]:
train_full.columns

Index(['Path', 'Sex', 'Age', 'Frontal/Lateral', 'AP/PA', 'No Finding',
       'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
       'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
       'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
       'Support Devices'],
      dtype='object')

In [5]:
train_full = train_full.fillna(0)
val_full = val_full.fillna(0)
train_subset = train_subset.fillna(0)
val_subset = val_subset.fillna(0)

Encoder Classes

### Text Encoder

In [11]:
class TextEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super(TextEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(
            "bert-base-uncased",
            output_hidden_states=True,
        )
        self.proj = nn.Linear(768, embed_dim)

    def forward(self, token_ids, attention_masks):
        outputs = self.bert(token_ids, attention_mask=attention_masks)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        embeddings = self.proj(cls_embeddings)
        # normalizing because we need to compare with image embeddings later
        # for the contrastive similarity
        embeddings = F.normalize(embeddings, p=2, dim=-1, eps=1e-6)
        return embeddings

### Image Encoder

In [6]:
# Use the torchvision's implementation of ResNeXt, but add FC layer to generate 512d embedding.
class VisionEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        resnet = models.resnext50_32x4d(pretrained=True)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        in_dim = resnet.fc.in_features
        self.proj = nn.Linear(in_dim, embed_dim)

    def forward(self, x):
        features = self.backbone(x)
        features = features.squeeze(-1).squeeze(-1)
        z = self.proj(features)
        # convert to unit vectors for cosine similarity later
        z = F.normalize(z, p=2, dim=-1, eps=1e-6)
        return z

### Dataset

In [13]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        df = df.reset_index(drop=True) # Reset index to ensure 0-based indexing
        # Text
        self.text_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.reports = df.apply(
            generate_report, axis=1
        )

        # Vision
        self.images = df["Path"]
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.reports)  # This could work or we could do another way

    def __getitem__(self, idx):
        # Text part
        report = self.reports[idx]
        encoder = self.text_tokenizer.encode_plus(
            report,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Vision part
        img_path = self.images[idx]
        image = Image.open(img_path).convert("RGB")
        img_tensor = self.transform(image)

        return {
            "token_ids": encoder["input_ids"],
            "attention_masks": encoder["attention_mask"],
            "image_tensor": img_tensor,
        }

Samples come in as dataframe and we put in dataset. For each sample, we can then return inputs for text encoder (token ids & attention masks) & for vision encoder (tensor of pixels). We pass relevant inputs to text encoder and vision encoder, get embeddings for that sample. Then blah blah

Training Loop

In [13]:
def generate_report(row):
    labels = row.iloc[4:]
    findings_list = list(labels[labels == 1].index)
    findings = ", ".join(findings_list)
    return f"X-Ray report findings: {findings}"

In [15]:
# Create train & val datasets
train_dataset = CustomDataset(train_subset)
val_dataset = CustomDataset(val_subset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [16]:
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

### Contrastive Loss

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
K: n is batch size, h,w is height width, c is number of channels (3 for RGB)

# T[n, l] - minibatch of aligned texts
K: n is batch size, l is sequence length (number of tokens)

K: aligned means that they are paired up

# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
this happens in vision_enc
K: get features from resnet

T_f = text_encoder(T) #[n, d_t]
this happens in text_enc
K: get features from bert

# joint multimodal embedding [n, d_e]
K: this is projecting to joint embedding space
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t) #K: scale by temperature


# symmetric loss function
labels = np.arange(n) #K: correct match is on the diagonal
# K: now we compute symmetric loss
loss_i = cross_entropy_loss(logits, labels, axis=0) #K: image to text
loss_t = cross_entropy_loss(logits, labels, axis=1) #K: text to image
# K: now average in both directions
loss = (loss_i + loss_t)/2

In [7]:
# I_f, T_f, I_e, T_e are all done in other places
def contrastive_loss(image_embeddings, text_embeddings, temperature=0.07):
    # we can make the t a fixed parameter but most CLIP reproductions make it fixed
    # so just keeping it as argument for now

    batch_size = image_embeddings.shape[0]  # called n in the pseudo code

    # we already have normalized embeddings, so we can skip to cosine similarity part of pseudo code
    # "scaled pairwise cosine similarities [n, n]"
    # dividing by temperature to scale the logits
    logits = torch.matmul(image_embeddings, text_embeddings.T) / temperature
    # now "symmetric loss function"
    labels = torch.arange(batch_size, device=image_embeddings.device)
    loss_img_to_text = F.cross_entropy(logits, labels)
    loss_text_to_img = F.cross_entropy(logits.T, labels)

    # average loss in both directions
    loss = (loss_img_to_text + loss_text_to_img) / 2

    return loss

In [2]:
from google.colab import drive
import os

drive.mount('/content/drive')

SAVE_PATH = "/content/drive/MyDrive/models/vision_text_new/"

os.makedirs(SAVE_PATH, exist_ok=True)


Mounted at /content/drive


In [8]:
def recall_at_k(image_embeds, text_embeds, k_values=[1, 5]):
    """
    Compute Recall@K for image-to-text and text-to-image retrieval
    """
    # Compute similarity matrix: [N_images, N_texts]
    similarity = torch.matmul(image_embeds, text_embeds.T)

    results = {}

    # Image-to-Text Recall@K
    for k in k_values:
        # For each image, get top-k most similar texts
        top_k_indices = similarity.topk(k, dim=1).indices  # [N, k]

        # Check if correct text (same index) is in top-k
        correct = torch.zeros(len(image_embeds))
        for i in range(len(image_embeds)):
            if i in top_k_indices[i]:
                correct[i] = 1

        recall = correct.mean().item()
        results[f"image_to_text_recall@{k}"] = recall

    # Text-to-Image Recall@K
    for k in k_values:
        # For each text, get top-k most similar images
        top_k_indices = similarity.T.topk(k, dim=1).indices  # [N, k]

        # Check if correct image (same index) is in top-k
        correct = torch.zeros(len(text_embeds))
        for i in range(len(text_embeds)):
            if i in top_k_indices[i]:
                correct[i] = 1

        recall = correct.mean().item()
        results[f"text_to_image_recall@{k}"] = recall

    return results

In [21]:
# =========================
# Setup
# =========================
temperature = 0.07
max_epoch_number = 5
test_freq = 500            # validation frequency
ckpt_freq = 500            # checkpoint frequency (iterations)

# Initialize models
text_enc = TextEncoder(embed_dim=512).to(device)
vision_enc = VisionEncoder(embed_dim=512).to(device)

# Optimizer
# Use differentiated learning rates (THIS IS CRITICAL!)
lr_bert = 2e-5  # Much smaller for pretrained BERT
lr_other = 1e-4  # Normal for linear layers and vision encoder

optimizer = torch.optim.Adam([
    {'params': text_enc.bert.parameters(), 'lr': lr_bert},
    {'params': text_enc.proj.parameters(), 'lr': lr_other},
    {'params': vision_enc.parameters(), 'lr': lr_other}
])

# Freeze BERT + ResNet weights
# for p in text_enc.bert.parameters():
#     p.requires_grad = False

# for p in vision_enc.backbone.parameters():
#     p.requires_grad = False

# # Only projection layers train
# proj_params = list(text_enc.proj.parameters()) + list(vision_enc.proj.parameters())
# optimizer = torch.optim.Adam(proj_params, lr=1e-3)  # you can afford a higher LR here


# =========================
# Mixed Precision Setup
# =========================
scaler = torch.cuda.amp.GradScaler()

iteration = 0
for epoch in range(max_epoch_number):
    text_enc.train()
    vision_enc.train()
    batch_losses = []

    for batch in train_loader:
        images = batch["image_tensor"].to(device)
        token_ids = batch["token_ids"].squeeze(1).to(device)
        attention_masks = batch["attention_masks"].squeeze(1).to(device)

        optimizer.zero_grad()

        # =========================
        # Forward with mixed precision
        # =========================
        with torch.cuda.amp.autocast():
            image_embeddings = vision_enc(images)
            text_embeddings = text_enc(token_ids, attention_masks)

            if torch.isnan(image_embeddings).any():
                print("NaN in image_embeddings"); break
            if torch.isnan(text_embeddings).any():
                print("NaN in text_embeddings"); break

            loss = contrastive_loss(image_embeddings, text_embeddings, temperature)

            if torch.isnan(loss).any():
                print("NaN in loss"); break

        # =========================
        # Backward with gradient scaling
        # =========================
        batch_loss_value = loss.item()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_losses.append(batch_loss_value)

        # =========================
        # Validation
        # =========================
        if iteration % test_freq == 0:
            text_enc.eval()
            vision_enc.eval()
            val_losses = []
            all_img_embeds = []
            all_text_embeds = []

            with torch.no_grad():
                for batch_val in val_loader:
                    val_images = batch_val["image_tensor"].to(device)
                    val_token_ids = batch_val["token_ids"].squeeze(1).to(device)
                    val_attention_masks = batch_val["attention_masks"].squeeze(1).to(device)

                    # Mixed precision for validation too
                    with torch.cuda.amp.autocast():
                        val_img_embeds = vision_enc(val_images)
                        val_txt_embeds = text_enc(val_token_ids, val_attention_masks)
                        val_loss = contrastive_loss(val_img_embeds, val_txt_embeds, temperature)

                    val_losses.append(val_loss.item())
                    all_img_embeds.append(val_img_embeds.cpu())
                    all_text_embeds.append(val_txt_embeds.cpu())

            all_img_embeds = torch.cat(all_img_embeds, dim=0)   # shape (N, d)
            all_text_embeds = torch.cat(all_text_embeds, dim=0)

            avg_val_loss = float(np.mean(val_losses))
            recall_results = recall_at_k(all_img_embeds, all_text_embeds)
            rounded_recall_results = {key: round(value, 5) for key, value in recall_results.items()}
            print(f"epoch:{epoch+1:2d} iter:{iteration:4d} val loss:{avg_val_loss:.3f}, recall@k: {rounded_recall_results}")

            text_enc.train()
            vision_enc.train()

        # =========================
        # Periodic checkpoint
        # =========================
        if iteration % ckpt_freq == 0:
            ckpt_path = os.path.join(SAVE_PATH, f"checkpoint_iter_{iteration}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "iteration": iteration,
                    "text_enc": text_enc.state_dict(),
                    "vision_enc": vision_enc.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),  # Save scaler state
                    "train_loss": float(batch_loss_value),
                },
                ckpt_path,
            )
            print(f"Checkpoint saved to {ckpt_path}")

        iteration += 1

    train_loss = float(np.mean(batch_losses))
    print(f"epoch:{epoch+1:2d} iter:{iteration:4d} train loss:{train_loss:.3f}\n")

# =========================
# Final checkpoint
# =========================
final_path = os.path.join("/content/final_checkpoint.pt")
torch.save(
    {
        "epoch": epoch,
        "iteration": iteration,
        "text_enc": text_enc.state_dict(),
        "vision_enc": vision_enc.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),  # Save scaler state
        "train_loss": train_loss,
    },
    final_path,
)
print(f"Final checkpoint saved to {final_path}")

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]



Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth


100%|██████████| 95.8M/95.8M [00:00<00:00, 203MB/s]
  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


epoch: 1 iter:   0 val loss:2.649, recall@k: {'image_to_text_recall@1': 0.00855, 'image_to_text_recall@5': 0.04274, 'text_to_image_recall@1': 0.00855, 'text_to_image_recall@5': 0.05128}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_new/checkpoint_iter_0.pt
epoch: 1 iter: 500 val loss:2.260, recall@k: {'image_to_text_recall@1': 0.04274, 'image_to_text_recall@5': 0.09402, 'text_to_image_recall@1': 0.03419, 'text_to_image_recall@5': 0.12821}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_new/checkpoint_iter_500.pt
epoch: 1 iter:1000 val loss:2.198, recall@k: {'image_to_text_recall@1': 0.04274, 'image_to_text_recall@5': 0.1453, 'text_to_image_recall@1': 0.03419, 'text_to_image_recall@5': 0.21368}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_new/checkpoint_iter_1000.pt
epoch: 1 iter:1500 val loss:2.113, recall@k: {'image_to_text_recall@1': 0.02564, 'image_to_text_recall@5': 0.16239, 'text_to_image_recall@1': 0.02564, 'text_to_image_recall@5': 

In [22]:
final_path = os.path.join(SAVE_PATH, "final_checkpoint.pt")

torch.save(
    {
        "epoch": epoch,
        "iteration": iteration,
        "text_enc": text_enc.state_dict(),
        "vision_enc": vision_enc.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),  # Save scaler state
        "train_loss": train_loss,
    },
    final_path,
)
print(f"Final checkpoint saved to {final_path}")

Final checkpoint saved to /content/drive/MyDrive/models/vision_text_new/final_checkpoint.pt


Continue Training on Full Dataset, With Adjusted LRs

In [23]:
train_full = train_full.drop("AP/PA", axis=1)
val_full = val_full.drop("AP/PA", axis=1)
train_full = train_full.fillna(0)
val_full = val_full.fillna(0)

train_dataset_full = CustomDataset(train_full)
val_dataset_full = CustomDataset(val_full)

train_loader_full = DataLoader(train_dataset_full, batch_size=16, shuffle=False)
val_loader_full = DataLoader(val_dataset_full, batch_size=16, shuffle=False)

In [28]:
# Initialize models with best weights from initial training
text_enc_full = TextEncoder(embed_dim=512).to(device)
vision_enc_full = VisionEncoder(embed_dim=512).to(device)

# Load the best weights
best_path = "/content/drive/MyDrive/models/vision_text_new/checkpoint_iter_25000.pt"
state_dict_text = torch.load(best_path)["text_enc"]
missing, unexpected = text_enc_full.load_state_dict(state_dict_text, strict=True)

state_dict_vision = torch.load(best_path)["vision_enc"]
vision_enc_full.load_state_dict(state_dict_vision, strict=True)
# Also option to do final weights from last checkpoint (end of 5 epochs)


SAVE_PATH_FULL = "/content/drive/MyDrive/models/vision_text_full_data/"

os.makedirs(SAVE_PATH_FULL, exist_ok=True)



In [29]:
# =========================
# Setup
# =========================
temperature = 0.07
max_epoch_number = 3
test_freq = 1000            # validation frequency
ckpt_freq = 1000            # checkpoint frequency (iterations)


# Optimizer
# Use even further differentiated learning rates this time
proj_lr = 1e-3      # new layers, small + need to move fast
vision_lr = 1e-4    # 1e-4 if you’re okay being a bit bolder
bert_lr = 2e-5      # already in the usual BERT fine-tuning range

proj_parameters = list(text_enc_full.proj.parameters()) + list(vision_enc_full.proj.parameters())

optimizer = torch.optim.Adam([
    {'params': text_enc_full.bert.parameters(), 'lr': bert_lr},
    {'params': proj_parameters, 'lr': proj_lr},
    {'params': vision_enc_full.backbone.parameters(), 'lr': vision_lr}
])

# Freeze BERT + ResNet weights
# for p in text_enc.bert.parameters():
#     p.requires_grad = False

# for p in vision_enc.backbone.parameters():
#     p.requires_grad = False

# # Only projection layers train
# proj_params = list(text_enc.proj.parameters()) + list(vision_enc.proj.parameters())
# optimizer = torch.optim.Adam(proj_params, lr=1e-3)  # you can afford a higher LR here


# =========================
# Mixed Precision Setup
# =========================
scaler = torch.cuda.amp.GradScaler()

iteration = 0
for epoch in range(max_epoch_number):
    text_enc_full.train()
    vision_enc_full.train()
    batch_losses = []

    for batch in train_loader_full:
        images = batch["image_tensor"].to(device)
        token_ids = batch["token_ids"].squeeze(1).to(device)
        attention_masks = batch["attention_masks"].squeeze(1).to(device)

        optimizer.zero_grad()

        # =========================
        # Forward with mixed precision
        # =========================
        with torch.cuda.amp.autocast():
            image_embeddings = vision_enc_full(images)
            text_embeddings = text_enc_full(token_ids, attention_masks)

            if torch.isnan(image_embeddings).any():
                print("NaN in image_embeddings"); break
            if torch.isnan(text_embeddings).any():
                print("NaN in text_embeddings"); break

            loss = contrastive_loss(image_embeddings, text_embeddings, temperature)

            if torch.isnan(loss).any():
                print("NaN in loss"); break

        # =========================
        # Backward with gradient scaling
        # =========================
        batch_loss_value = loss.item()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_losses.append(batch_loss_value)

        # =========================
        # Validation
        # =========================
        if iteration % test_freq == 0:
            text_enc_full.eval()
            vision_enc_full.eval()
            val_losses = []
            all_img_embeds = []
            all_text_embeds = []

            with torch.no_grad():
                for batch_val in val_loader_full:
                    val_images = batch_val["image_tensor"].to(device)
                    val_token_ids = batch_val["token_ids"].squeeze(1).to(device)
                    val_attention_masks = batch_val["attention_masks"].squeeze(1).to(device)

                    # Mixed precision for validation too
                    with torch.cuda.amp.autocast():
                        val_img_embeds = vision_enc_full(val_images)
                        val_txt_embeds = text_enc_full(val_token_ids, val_attention_masks)
                        val_loss = contrastive_loss(val_img_embeds, val_txt_embeds, temperature)

                    val_losses.append(val_loss.item())
                    all_img_embeds.append(val_img_embeds.cpu())
                    all_text_embeds.append(val_txt_embeds.cpu())

            all_img_embeds = torch.cat(all_img_embeds, dim=0)   # shape (N, d)
            all_text_embeds = torch.cat(all_text_embeds, dim=0)

            avg_val_loss = float(np.mean(val_losses))
            recall_results = recall_at_k(all_img_embeds, all_text_embeds)
            rounded_recall_results = {key: round(value, 5) for key, value in recall_results.items()}
            print(f"epoch:{epoch+1:2d} iter:{iteration:4d} val loss:{avg_val_loss:.3f}, recall@k: {rounded_recall_results}")

            text_enc_full.train()
            vision_enc_full.train()

        # =========================
        # Periodic checkpoint
        # =========================
        if iteration % ckpt_freq == 0:
            ckpt_path = os.path.join(SAVE_PATH_FULL, f"checkpoint_iter_{iteration}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "iteration": iteration,
                    "text_enc": text_enc_full.state_dict(),
                    "vision_enc": vision_enc_full.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),  # Save scaler state
                    "train_loss": float(batch_loss_value),
                },
                ckpt_path,
            )
            print(f"Checkpoint saved to {ckpt_path}")

        iteration += 1

    train_loss = float(np.mean(batch_losses))
    print(f"epoch:{epoch+1:2d} iter:{iteration:4d} train loss:{train_loss:.3f}\n")

# =========================
# Final checkpoint
# =========================
final_path = "/content/final_checkpoint_full_data.pt"
torch.save(
    {
        "epoch": epoch,
        "iteration": iteration,
        "text_enc": text_enc_full.state_dict(),
        "vision_enc": vision_enc_full.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),  # Save scaler state
        "train_loss": train_loss,
    },
    final_path,
)
print(f"Final checkpoint saved to {final_path}")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


epoch: 1 iter:   0 val loss:2.141, recall@k: {'image_to_text_recall@1': 0.04274, 'image_to_text_recall@5': 0.11538, 'text_to_image_recall@1': 0.05128, 'text_to_image_recall@5': 0.16239}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_full_data/checkpoint_iter_0.pt
epoch: 1 iter:1000 val loss:2.028, recall@k: {'image_to_text_recall@1': 0.05556, 'image_to_text_recall@5': 0.17094, 'text_to_image_recall@1': 0.04274, 'text_to_image_recall@5': 0.18803}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_full_data/checkpoint_iter_1000.pt
epoch: 1 iter:2000 val loss:1.890, recall@k: {'image_to_text_recall@1': 0.03846, 'image_to_text_recall@5': 0.20085, 'text_to_image_recall@1': 0.05128, 'text_to_image_recall@5': 0.21368}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_full_data/checkpoint_iter_2000.pt
epoch: 1 iter:3000 val loss:1.961, recall@k: {'image_to_text_recall@1': 0.05128, 'image_to_text_recall@5': 0.1453, 'text_to_image_recall@1': 0.05556, 'text_t

In [30]:
final_path_f = os.path.join(SAVE_PATH_FULL, "final_checkpoint.pt")

torch.save(
    {
        "epoch": epoch,
        "iteration": iteration,
        "text_enc": text_enc_full.state_dict(),
        "vision_enc": vision_enc_full.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),  # Save scaler state
        "train_loss": train_loss,
    },
    final_path_f,
)
print(f"Final checkpoint saved to {final_path_f}")

Final checkpoint saved to /content/drive/MyDrive/models/vision_text_full_data/final_checkpoint.pt


Try with Clinical Bert as Text Encoder Backbone

In [9]:
from google.colab import drive
import os

drive.mount('/content/drive')

SAVE_PATH = "/content/drive/MyDrive/models/vision_text_clinical_full/"

os.makedirs(SAVE_PATH, exist_ok=True)

Mounted at /content/drive


In [10]:
def generate_report_updated(row):
    labels = row.iloc[5:]  # Skip Path, Sex, Age, Frontal/Lateral, AP/PA

    # Separate findings by certainty
    positive_findings = list(labels[labels == 1.0].index)
    uncertain_findings = list(labels[labels == -1.0].index)

    # Build the report
    report_parts = []

    # Add patient demographics for context
    age = int(row['Age']) if pd.notna(row['Age']) else None
    sex = row['Sex'].lower() if pd.notna(row['Sex']) else None
    view = row['Frontal/Lateral'].lower() if pd.notna(row['Frontal/Lateral']) else None

    # Start with view type
    if view:
        report_parts.append(f"{view.capitalize()} chest radiograph")
    else:
        report_parts.append("Chest radiograph")

    # Add demographics
    demo = []
    if age:
        demo.append(f"{age}-year-old")
    if sex:
        demo.append(sex)
    if demo:
        report_parts.append(f"of {' '.join(demo)} patient")

    # Add findings
    if len(positive_findings) == 0 and len(uncertain_findings) == 0:
        report_parts.append("demonstrates no acute cardiopulmonary abnormality")
    else:
        findings_text = []

        # Definite findings
        if positive_findings:
            findings_clean = [f.lower().replace('_', ' ') for f in positive_findings]
            findings_text.append("shows " + ", ".join(findings_clean))

        # Uncertain findings (optional - you might want to treat these differently)
        if uncertain_findings:
            uncertain_clean = [f.lower().replace('_', ' ') for f in uncertain_findings]
            findings_text.append("possible " + ", ".join(uncertain_clean))

        report_parts.append(". ".join(findings_text))

    return " ".join(report_parts) + "."


In [11]:
from transformers import AutoTokenizer, AutoModel

MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

In [12]:
class ClinicalTextEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super(ClinicalTextEncoder, self).__init__()
        self.bert = AutoModel.from_pretrained(
            MODEL_NAME
        )
        self.proj = nn.Linear(768, embed_dim)

    def forward(self, token_ids, attention_masks):
        outputs = self.bert(token_ids, attention_mask=attention_masks)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        embeddings = self.proj(cls_embeddings)
        # normalizing because we need to compare with image embeddings later
        # for the contrastive similarity
        embeddings = F.normalize(embeddings, p=2, dim=-1, eps=1e-6)
        return embeddings

In [13]:

class ClinicalCustomDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        df = df.reset_index(drop=True) # Reset index to ensure 0-based indexing
        # Text
        self.text_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.reports = df.apply(
            generate_report_updated, axis=1
        )  # Maybe move generate report to inside the dataset later

        # Vision
        self.images = df["Path"]
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
        )

    def __len__(self):
        return len(self.reports)  # This could work or we could do another way

    def __getitem__(self, idx):
        # Text part
        report = self.reports[idx]
        encoder = self.text_tokenizer.encode_plus(
            report,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Vision part
        img_path = self.images[idx]
        image = Image.open(img_path).convert("RGB")
        img_tensor = self.transform(image)

        return {
            "token_ids": encoder["input_ids"],
            "attention_masks": encoder["attention_mask"],
            "image_tensor": img_tensor,
        }  # Is vision part correct?

In [28]:
train_dataset_clinical = ClinicalCustomDataset(train_subset)
val_dataset_clinical = ClinicalCustomDataset(val_full) # Use full validation data since it's not that big anyway

train_loader_clinical = DataLoader(train_dataset_clinical, batch_size=32, shuffle=True)
val_loader_clinical = DataLoader(val_dataset_clinical, batch_size=32, shuffle=False)

In [29]:
SAVE_PATH_CLINICAL = "/content/drive/MyDrive/models/vision_text_clinical/"

os.makedirs(SAVE_PATH_CLINICAL, exist_ok=True)

In [30]:
# =========================
# Setup
# =========================
temperature = 0.2
max_epoch_number = 3
test_freq = 1000            # validation frequency
ckpt_freq = 1000            # checkpoint frequency (iterations)

# Initialize models
text_enc = ClinicalTextEncoder(embed_dim=512).to(device)
vision_enc = VisionEncoder(embed_dim=512).to(device)

# Optimizer
# Use even further differentiated learning rates this time
proj_lr = 1e-3      # new layers, small + need to move fast
vision_lr = 5e-5    # 1e-4 if you’re okay being a bit bolder
bert_lr = 3e-5      # already in the usual BERT fine-tuning range

proj_parameters = list(text_enc.proj.parameters()) + list(vision_enc.proj.parameters())

optimizer = torch.optim.Adam([
    {'params': text_enc.bert.parameters(), 'lr': bert_lr},
    {'params': proj_parameters, 'lr': proj_lr},
    {'params': vision_enc.backbone.parameters(), 'lr': vision_lr}
])

# =========================
# Mixed Precision Setup
# =========================
scaler = torch.cuda.amp.GradScaler()

iteration = 0
for epoch in range(max_epoch_number):
    text_enc.train()
    vision_enc.train()
    batch_losses = []

    for batch in train_loader_clinical:
        images = batch["image_tensor"].to(device)
        token_ids = batch["token_ids"].squeeze(1).to(device)
        attention_masks = batch["attention_masks"].squeeze(1).to(device)

        optimizer.zero_grad()

        # =========================
        # Forward with mixed precision
        # =========================
        with torch.cuda.amp.autocast():
            image_embeddings = vision_enc(images)
            text_embeddings = text_enc(token_ids, attention_masks)

            if torch.isnan(image_embeddings).any():
                print("NaN in image_embeddings"); break
            if torch.isnan(text_embeddings).any():
                print("NaN in text_embeddings"); break

            loss = contrastive_loss(image_embeddings, text_embeddings, temperature)

            if torch.isnan(loss).any():
                print("NaN in loss"); break

        # =========================
        # Backward with gradient scaling
        # =========================
        batch_loss_value = loss.item()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_losses.append(batch_loss_value)

        # =========================
        # Validation
        # =========================
        if iteration % test_freq == 0:
            text_enc.eval()
            vision_enc.eval()
            val_losses = []
            all_img_embeds = []
            all_text_embeds = []

            with torch.no_grad():
                for batch_val in val_loader_clinical:
                    val_images = batch_val["image_tensor"].to(device)
                    val_token_ids = batch_val["token_ids"].squeeze(1).to(device)
                    val_attention_masks = batch_val["attention_masks"].squeeze(1).to(device)

                    # Mixed precision for validation too
                    with torch.cuda.amp.autocast():
                        val_img_embeds = vision_enc(val_images)
                        val_txt_embeds = text_enc(val_token_ids, val_attention_masks)
                        val_loss = contrastive_loss(val_img_embeds, val_txt_embeds, temperature)

                    val_losses.append(val_loss.item())
                    all_img_embeds.append(val_img_embeds.cpu())
                    all_text_embeds.append(val_txt_embeds.cpu())

            all_img_embeds = torch.cat(all_img_embeds, dim=0)   # shape (N, d)
            all_text_embeds = torch.cat(all_text_embeds, dim=0)

            avg_val_loss = float(np.mean(val_losses))
            recall_results = recall_at_k(all_img_embeds, all_text_embeds)
            rounded_recall_results = {key: round(value, 5) for key, value in recall_results.items()}
            print(f"epoch:{epoch+1:2d} iter:{iteration:4d} val loss:{avg_val_loss:.3f}, recall@k: {rounded_recall_results}")

            text_enc.train()
            vision_enc.train()

        # =========================
        # Periodic checkpoint
        # =========================
        if iteration % ckpt_freq == 0:
            ckpt_path = os.path.join(SAVE_PATH_CLINICAL, f"checkpoint_iter_{iteration}.pt")
            torch.save(
                {
                    "epoch": epoch,
                    "iteration": iteration,
                    "text_enc": text_enc.state_dict(),
                    "vision_enc": vision_enc.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scaler": scaler.state_dict(),  # Save scaler state
                    "train_loss": float(batch_loss_value),
                },
                ckpt_path,
            )
            print(f"Checkpoint saved to {ckpt_path}")

        iteration += 1

    train_loss = float(np.mean(batch_losses))
    print(f"epoch:{epoch+1:2d} iter:{iteration:4d} train loss:{train_loss:.3f}\n")

# =========================
# Final checkpoint
# =========================
final_path = os.path.join(SAVE_PATH_CLINICAL, "final_checkpoint_clinical.pt")
torch.save(
    {
        "epoch": epoch,
        "iteration": iteration,
        "text_enc": text_enc.state_dict(),
        "vision_enc": vision_enc.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),  # Save scaler state
        "train_loss": train_loss,
    },
    final_path,
)
print(f"Final checkpoint saved to {final_path}")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


epoch: 1 iter:   0 val loss:3.313, recall@k: {'image_to_text_recall@1': 0.00427, 'image_to_text_recall@5': 0.02991, 'text_to_image_recall@1': 0.00427, 'text_to_image_recall@5': 0.03419}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_clinical/checkpoint_iter_0.pt
epoch: 1 iter:1000 val loss:1.707, recall@k: {'image_to_text_recall@1': 0.12821, 'image_to_text_recall@5': 0.4188, 'text_to_image_recall@1': 0.08547, 'text_to_image_recall@5': 0.39744}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_clinical/checkpoint_iter_1000.pt
epoch: 1 iter:2000 val loss:1.528, recall@k: {'image_to_text_recall@1': 0.15385, 'image_to_text_recall@5': 0.50427, 'text_to_image_recall@1': 0.15385, 'text_to_image_recall@5': 0.48718}
Checkpoint saved to /content/drive/MyDrive/models/vision_text_clinical/checkpoint_iter_2000.pt
epoch: 1 iter:3000 val loss:1.490, recall@k: {'image_to_text_recall@1': 0.17521, 'image_to_text_recall@5': 0.49573, 'text_to_image_recall@1': 0.1453, 'text_to_im