In [7]:
# ==========================================
#Preparing image and text embeddings for zero-shot and contrastive learning __V2
# ==========================================
# Author: Morvarid Rahbar
# Student ID: 4033624008
# ==========================================

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report, hamming_loss, accuracy_score
from google.colab import drive
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Device:", device)


 Device: cpu


In [9]:
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
df = pd.read_csv("/content/drive/MyDrive/chexpert_data_v2/valid.csv")

image_embedding_path = "/content/drive/MyDrive/Embedings/Image_embeding/image_embeddings_LRF30%.pt"
text_embedding_path = "/content/drive/MyDrive/Embedings/Text_embeding/disease_text_embeddings.pt"

image_embeddings = torch.load(image_embedding_path)  # dict: key = img_path, value = [197, 64]
text_embeddings = torch.load(text_embedding_path)    # dict: key = class name, value = [768]



In [11]:
disease_columns = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation',
    'Edema', 'Pleural Effusion', 'Pneumonia', 'Pneumothorax'
]

label_dict = {}
prefix = "CheXpert-v1.0-small/"

for i, row in df.iterrows():
    path = row["Path"]
    if path.startswith(prefix):
        path = path[len(prefix):]

    label = []
    for disease in disease_columns:
        val = row[disease]
        label.append(0.0 if pd.isna(val) else float(val))

    label_dict[path] = torch.tensor(label)


In [12]:
# --- ÿ≥ÿßÿÆÿ™ ÿ™ŸÜÿ≥Ÿàÿ± ÿßŸÖÿ®ÿØ€åŸÜ⁄Ø‚ÄåŸáÿß Ÿà ŸÑ€åÿ®ŸÑ‚ÄåŸáÿß ---
img_keys = list(image_embeddings.keys())
img_tensor = torch.stack([image_embeddings[k] for k in img_keys])         # [N, 197, 64]
img_tensor_pooled = img_tensor.max(dim=1).values  # [N, 64]

labels_list = []
valid_indices = []
for idx, k in enumerate(img_keys):
    label = label_dict[k]
    if (label == -1).all():

        continue
    else:
        label = torch.where(label == -1, torch.tensor(1.0), label)
        labels_list.append(label)
        valid_indices.append(idx)

labels = torch.stack(labels_list)  # [M, 7]
img_tensor_pooled = img_tensor_pooled[valid_indices]  # [M, 64]


text_tensor = torch.stack([text_embeddings[d] for d in disease_columns])  # [7, 768]


img_tensor_pooled = img_tensor_pooled.to(device)
labels = labels.to(device)
text_tensor = text_tensor.to(device)

In [13]:
# Project image and text embeddings to shared 256-dim space
image_proj = nn.Linear(64, 256).to(device)
text_proj = nn.Linear(768, 256).to(device)

# Forward projection
img_proj = F.normalize(image_proj(img_tensor_pooled), dim=1)     # [234, 256]
txt_proj = F.normalize(text_proj(text_tensor), dim=1)     # [7, 256]

# Similarity: [234, 7]
similarity = img_proj @ txt_proj.T


In [14]:
# Probabilities
probs = torch.sigmoid(similarity)  # [234, 7]

# Threshold = 0.5
threshold = 0.8
preds = (probs > threshold).float()

# Accuracy (Exact Match)
exact_match = (preds == labels).all(dim=1).float().mean()
print(f" Exact Match Accuracy: {exact_match.item():.4f}")

# Sample-wise Accuracy
sample_accuracy = (preds == labels).float().mean()
print(f" Sample-wise Accuracy: {sample_accuracy.item():.4f}")

# Per-label Accuracy
per_label_acc = (preds == labels).float().mean(dim=0)
for i, disease in enumerate(disease_columns):
    print(f"{disease}: {per_label_acc[i].item():.4f}")

# Macro Accuracy
macro_accuracy = per_label_acc.mean()
print(f" Macro (Mean Per-Label) Accuracy: {macro_accuracy.item():.4f}")


 Exact Match Accuracy: 0.3889
 Sample-wise Accuracy: 0.8114
Atelectasis: 0.6581
Cardiomegaly: 0.7094
Consolidation: 0.8590
Edema: 0.8077
Pleural Effusion: 0.7137
Pneumonia: 0.9658
Pneumothorax: 0.9658
 Macro (Mean Per-Label) Accuracy: 0.8114


In [15]:
y_true = labels.cpu().numpy()
y_prob = probs.detach().cpu().numpy()

# Threshold Tuning
best_f1 = 0.0
best_thresh = 0.5
for t in np.arange(0.1, 0.9, 0.05):
    preds_t = (y_prob > t).astype(int)
    f1 = f1_score(y_true, preds_t, average='macro', zero_division=0)
    if f1 > best_f1:
        best_f1 = f1
        best_thresh = t

print(f"üîß Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

# Final Predictions with best threshold
y_pred = (y_prob > best_thresh).astype(int)

print("\nüîç Classification Report:")
print(classification_report(y_true, y_pred, target_names=disease_columns, zero_division=0))

hl = hamming_loss(y_true, y_pred)
print(f" Hamming Loss: {hl:.4f}")


üîß Best Threshold: 0.10 ‚Üí F1: 0.3010

üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.34      1.00      0.51        80
    Cardiomegaly       0.29      1.00      0.45        68
   Consolidation       0.14      1.00      0.25        33
           Edema       0.19      1.00      0.32        45
Pleural Effusion       0.29      1.00      0.45        67
       Pneumonia       0.03      1.00      0.07         8
    Pneumothorax       0.03      1.00      0.07         8

       micro avg       0.19      1.00      0.32       309
       macro avg       0.19      1.00      0.30       309
    weighted avg       0.26      1.00      0.40       309
     samples avg       0.19      0.61      0.27       309

 Hamming Loss: 0.8114


### BCW learning

In [16]:
optimizer = optim.Adam(list(image_proj.parameters()) + list(text_proj.parameters()), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

num_epochs = 100
batch_size = 64

dataset_size = img_tensor_pooled.size(0)

for epoch in range(num_epochs):
    perm = torch.randperm(dataset_size)
    epoch_loss = 0
    for i in range(0, dataset_size, batch_size):
        optimizer.zero_grad()
        indices = perm[i:i+batch_size]

        img_batch = img_tensor_pooled[indices]   # [B, 64]
        label_batch = labels[indices]             # [B, 7]

        img_proj_batch = image_proj(img_batch)   # [B, 256]
        txt_proj = text_proj(text_tensor)        # [7, 256]

        img_proj_norm = F.normalize(img_proj_batch, dim=1)
        txt_proj_norm = F.normalize(txt_proj, dim=1)

        similarity = torch.matmul(img_proj_norm, txt_proj_norm.T)  # [B, 7]

        loss = criterion(similarity, label_batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * img_batch.size(0)

    avg_loss = epoch_loss / dataset_size
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

Epoch 1/100 - Loss: 0.5909
Epoch 2/100 - Loss: 0.5023
Epoch 3/100 - Loss: 0.4983
Epoch 4/100 - Loss: 0.4963
Epoch 5/100 - Loss: 0.4946
Epoch 6/100 - Loss: 0.4929
Epoch 7/100 - Loss: 0.4913
Epoch 8/100 - Loss: 0.4898
Epoch 9/100 - Loss: 0.4884
Epoch 10/100 - Loss: 0.4870
Epoch 11/100 - Loss: 0.4858
Epoch 12/100 - Loss: 0.4852
Epoch 13/100 - Loss: 0.4841
Epoch 14/100 - Loss: 0.4834
Epoch 15/100 - Loss: 0.4829
Epoch 16/100 - Loss: 0.4823
Epoch 17/100 - Loss: 0.4817
Epoch 18/100 - Loss: 0.4811
Epoch 19/100 - Loss: 0.4807
Epoch 20/100 - Loss: 0.4802
Epoch 21/100 - Loss: 0.4797
Epoch 22/100 - Loss: 0.4792
Epoch 23/100 - Loss: 0.4787
Epoch 24/100 - Loss: 0.4783
Epoch 25/100 - Loss: 0.4778
Epoch 26/100 - Loss: 0.4774
Epoch 27/100 - Loss: 0.4769
Epoch 28/100 - Loss: 0.4766
Epoch 29/100 - Loss: 0.4760
Epoch 30/100 - Loss: 0.4756
Epoch 31/100 - Loss: 0.4752
Epoch 32/100 - Loss: 0.4747
Epoch 33/100 - Loss: 0.4743
Epoch 34/100 - Loss: 0.4737
Epoch 35/100 - Loss: 0.4734
Epoch 36/100 - Loss: 0.4729
E

In [17]:

image_proj.eval()
text_proj.eval()

with torch.no_grad():
    img_proj = image_proj(img_tensor_pooled)
    txt_proj = text_proj(text_tensor)

    img_proj_norm = F.normalize(img_proj, dim=1)
    txt_proj_norm = F.normalize(txt_proj, dim=1)

    similarity = torch.matmul(img_proj_norm, txt_proj_norm.T)
    probs = torch.sigmoid(similarity).cpu().numpy()  # [M, 7]

labels_np = labels.cpu().numpy()


best_f1 = 0.0
best_thresh = 0.5
for t in np.arange(0.1, 0.9, 0.05):
    preds_t = (probs > t).astype(int)
    f1 = f1_score(labels_np, preds_t, average='macro', zero_division=0)
    if f1 > best_f1:
        best_f1 = f1
        best_thresh = t

print(f" Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

 Best Threshold: 0.35 ‚Üí F1: 0.4928


In [18]:

y_pred = (probs > best_thresh).astype(int)

print("\nüîç Classification Report:")
print(classification_report(labels_np, y_pred, target_names=disease_columns, zero_division=0))

hl = hamming_loss(labels_np, y_pred)
print(f" Hamming Loss: {hl:.4f}")



üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.69      0.85      0.76        80
    Cardiomegaly       0.59      0.53      0.56        68
   Consolidation       0.51      0.85      0.64        33
           Edema       0.50      0.56      0.53        45
Pleural Effusion       0.67      0.87      0.75        67
       Pneumonia       0.10      0.50      0.17         8
    Pneumothorax       0.03      0.12      0.05         8

       micro avg       0.52      0.71      0.60       309
       macro avg       0.44      0.61      0.49       309
    weighted avg       0.58      0.71      0.63       309
     samples avg       0.23      0.35      0.26       309

 Hamming Loss: 0.1807


### Contrastive Learning

In [19]:
def contrastive_loss(image_features, text_features, temperature=0.2):
    """
    image_features: [B, D]
    text_features: [B, D]
    Returns:
        scalar contrastive loss between image and its positive label embedding
    """
    image_features = F.normalize(image_features, dim=1)
    text_features = F.normalize(text_features, dim=1)

    logits = torch.matmul(image_features, text_features.T) / temperature
    targets = torch.arange(image_features.size(0)).to(image_features.device)

    loss_i = F.cross_entropy(logits, targets)
    loss_t = F.cross_entropy(logits.T, targets)

    return (loss_i + loss_t) / 2


In [20]:
def get_positive_text_embeddings(labels_batch, text_tensor):
    """
    For each image in batch, select one of its positive labels randomly.
    Returns: [B, 768] ‚Üí input to text_proj
    """
    B = labels_batch.size(0)
    selected_texts = []

    for i in range(B):
        positive_indices = torch.nonzero(labels_batch[i]).squeeze(1)

        if len(positive_indices) == 0:
            idx = torch.tensor(0).to(labels_batch.device)
        else:
            idx = positive_indices[torch.randint(len(positive_indices), (1,)).item()]

        selected_texts.append(text_tensor[idx])

    return torch.stack(selected_texts)  # [B, 768]


In [21]:
from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(img_tensor_pooled, labels)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


In [22]:
image_proj = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.LayerNorm(128),
    nn.Dropout(0.3),
    nn.Linear(128, 256),
    nn.LayerNorm(256)
).to(device)

text_proj = nn.Sequential(
    nn.Linear(768, 128),
    nn.ReLU(),
    nn.LayerNorm(128),
    nn.Dropout(0.3),
    nn.Linear(128, 256),
    nn.LayerNorm(256)
).to(device)


In [23]:
optimizer = torch.optim.Adam(
    list(image_proj.parameters()) + list(text_proj.parameters()), lr=5e-4
)

num_epochs = 100

for epoch in range(num_epochs):
    image_proj.train()
    text_proj.train()
    total_loss = 0.0

    for img_batch, label_batch in train_loader:
        img_batch = img_batch.to(device)
        label_batch = label_batch.to(device)

        optimizer.zero_grad()

        # Forward: image ‚Üí projected
        img_proj = image_proj(img_batch)  # [B, 256]

        # Text embeddings for positive labels
        selected_txt_embed = get_positive_text_embeddings(label_batch, text_tensor)  # [B, 768]
        txt_proj = text_proj(selected_txt_embed)  # [B, 256]

        # Contrastive loss
        loss = contrastive_loss(img_proj, txt_proj)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * img_batch.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")


Epoch 1/100 - Loss: 4.0931
Epoch 2/100 - Loss: 4.0540
Epoch 3/100 - Loss: 4.0370
Epoch 4/100 - Loss: 3.9910
Epoch 5/100 - Loss: 3.9271
Epoch 6/100 - Loss: 3.9795
Epoch 7/100 - Loss: 3.9844
Epoch 8/100 - Loss: 3.9322
Epoch 9/100 - Loss: 3.9365
Epoch 10/100 - Loss: 3.9613
Epoch 11/100 - Loss: 3.9459
Epoch 12/100 - Loss: 3.9199
Epoch 13/100 - Loss: 3.9081
Epoch 14/100 - Loss: 3.9123
Epoch 15/100 - Loss: 3.9083
Epoch 16/100 - Loss: 3.9154
Epoch 17/100 - Loss: 3.9120
Epoch 18/100 - Loss: 3.8964
Epoch 19/100 - Loss: 3.8614
Epoch 20/100 - Loss: 3.8923
Epoch 21/100 - Loss: 3.8558
Epoch 22/100 - Loss: 3.8542
Epoch 23/100 - Loss: 3.8537
Epoch 24/100 - Loss: 3.8579
Epoch 25/100 - Loss: 3.8146
Epoch 26/100 - Loss: 3.8953
Epoch 27/100 - Loss: 3.8358
Epoch 28/100 - Loss: 3.7985
Epoch 29/100 - Loss: 3.8132
Epoch 30/100 - Loss: 3.8329
Epoch 31/100 - Loss: 3.8346
Epoch 32/100 - Loss: 3.8066
Epoch 33/100 - Loss: 3.8177
Epoch 34/100 - Loss: 3.7838
Epoch 35/100 - Loss: 3.8370
Epoch 36/100 - Loss: 3.8993
E

In [24]:
image_proj.eval()
text_proj.eval()

with torch.no_grad():
    img_proj_all = F.normalize(image_proj(img_tensor_pooled), dim=1)  # [M, 256]
    txt_proj_all = F.normalize(text_proj(text_tensor), dim=1)         # [7, 256]

    similarity = torch.matmul(img_proj_all, txt_proj_all.T)           # [M, 7]
    probs = torch.sigmoid(similarity).cpu().numpy()
    labels_np = labels.cpu().numpy()


In [25]:
from sklearn.metrics import f1_score, classification_report, hamming_loss

best_f1 = 0.0
best_thresh = 0.5

for t in np.arange(0.1, 0.9, 0.01):
    preds_t = (probs > t).astype(int)
    f1 = f1_score(labels_np, preds_t, average='macro', zero_division=0)
    if f1 > best_f1:
        best_f1 = f1
        best_thresh = t

print(f"‚úÖ Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

y_pred = (probs > best_thresh).astype(int)

print("\nüîç Classification Report:")
print(classification_report(labels_np, y_pred, target_names=disease_columns, zero_division=0))

hl = hamming_loss(labels_np, y_pred)
print(f"üîª Hamming Loss: {hl:.4f}")


‚úÖ Best Threshold: 0.57 ‚Üí F1: 0.4680

üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.14      0.16      0.15        80
    Cardiomegaly       0.69      0.56      0.62        68
   Consolidation       0.47      0.94      0.63        33
           Edema       0.55      0.84      0.67        45
Pleural Effusion       0.79      0.84      0.81        67
       Pneumonia       0.12      0.88      0.21         8
    Pneumothorax       0.11      0.75      0.19         8

       micro avg       0.41      0.61      0.49       309
       macro avg       0.41      0.71      0.47       309
    weighted avg       0.50      0.61      0.53       309
     samples avg       0.32      0.39      0.33       309

üîª Hamming Loss: 0.2424
