In [None]:
# ==========================================
# Preparing image and text embeddings for zero-shot and contrastive learning (V2)
# ==========================================
# Author: Morvarid Rahbar
# Student ID: 4033624008
# ==========================================

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report, hamming_loss, accuracy_score
from google.colab import drive
import torch.optim as optim

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(" Device:", device)


‚úÖ Device: cpu


In [None]:
# Mount Google Drive at /content/drive; the CSV and embedding files read below
# are addressed through paths under /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Label table: the cells below read its "Path" column and one column per disease.
df = pd.read_csv("/content/drive/MyDrive/chexpert_data_v2/valid.csv")

image_embedding_path = "/content/drive/MyDrive/Embedings/Image_embeding/image_embeddings_LRF30%.pt"
text_embedding_path = "/content/drive/MyDrive/Embedings/Text_embeding/disease_text_embeddings.pt"

# NOTE(security): torch.load unpickles the file and can execute arbitrary code,
# so only load embedding files that come from a trusted source.
# map_location="cpu" lets files that were saved from a GPU load on a CPU-only
# runtime; the tensors are moved to `device` explicitly once they are stacked.
image_embeddings = torch.load(image_embedding_path, map_location="cpu")  # dict: key = img_path, value = [197, 64]
text_embeddings = torch.load(text_embedding_path, map_location="cpu")    # dict: key = class name, value = [768]



In [None]:
# The seven target findings; the order here fixes the column order of every
# label and score matrix built later in the notebook.
disease_columns = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation',
    'Edema', 'Pleural Effusion', 'Pneumonia', 'Pneumothorax'
]

# Leading path component that is stripped from the CSV paths.
# presumably needed so the keys match those of `image_embeddings` — confirm.
prefix = "CheXpert-v1.0-small/"

# Missing labels are treated as negatives (0.0); everything else is cast to float.
label_rows = df[disease_columns].fillna(0.0).astype(float).values.tolist()

# Map each (prefix-stripped) image path to its 7-element label vector.
label_dict = {}
for raw_path, row_labels in zip(df["Path"], label_rows):
    key = raw_path[len(prefix):] if raw_path.startswith(prefix) else raw_path
    label_dict[key] = torch.tensor(row_labels)


In [None]:
# --- Build the embedding and label tensors ---
img_keys = list(image_embeddings)
img_tensor = torch.stack([image_embeddings[key] for key in img_keys])   # [N, 197, 64]
# Max-pool over dim 1 to get one vector per image.
img_tensor_pooled = img_tensor.max(dim=1)[0]                             # [N, 64]

# Keep only images whose label vector is not entirely -1, and map the
# remaining -1 entries to positive (1.0).
# NOTE(review): -1 is presumably the dataset's "uncertain" marker — confirm.
valid_indices = []
labels_list = []
for position, key in enumerate(img_keys):
    raw_label = label_dict[key]
    is_minus_one = raw_label == -1
    if is_minus_one.all():
        continue
    labels_list.append(torch.where(is_minus_one, torch.tensor(1.0), raw_label))
    valid_indices.append(position)

# Stack the per-class text embeddings in `disease_columns` order.
text_tensor = torch.stack([text_embeddings[name] for name in disease_columns])  # [7, 768]

# Align images with the retained labels and move everything to the compute device.
labels = torch.stack(labels_list).to(device)                    # [M, 7]
img_tensor_pooled = img_tensor_pooled[valid_indices].to(device)  # [M, 64]
text_tensor = text_tensor.to(device)

In [None]:
# Seed the RNG so the randomly initialised projection heads, and every number
# derived from them, are reproducible after Restart Kernel -> Run All.
torch.manual_seed(42)

# Project image and text embeddings to shared 256-dim space
image_proj = nn.Linear(64, 256).to(device)
text_proj = nn.Linear(768, 256).to(device)

# Forward projection with the *untrained* heads. This is a pure evaluation
# pass, so no autograd graph is recorded.
with torch.no_grad():
    img_proj = F.normalize(image_proj(img_tensor_pooled), dim=1)   # [M, 256], unit-norm rows
    txt_proj = F.normalize(text_proj(text_tensor), dim=1)          # [7, 256], unit-norm rows

    # Cosine similarity between every image and every class text: [M, 7], values in [-1, 1]
    similarity = img_proj @ txt_proj.T


In [None]:
# Probabilities
# `similarity` is a cosine similarity in [-1, 1], so the sigmoid output is
# confined to roughly [0.269, 0.731].
probs = torch.sigmoid(similarity)  # [M, 7]

# Threshold = 0.5, i.e. predict positive when the cosine similarity is > 0.
# The earlier value of 0.8 lies above sigmoid(1) ~= 0.731 and could never be
# exceeded, so every prediction was 0 and the accuracies below merely
# reported the fraction of negative labels.
threshold = 0.5
preds = (probs > threshold).float()

# Accuracy (Exact Match): all 7 labels of a sample must be correct.
exact_match = (preds == labels).all(dim=1).float().mean()
print(f" Exact Match Accuracy: {exact_match.item():.4f}")

# Sample-wise Accuracy: mean over every (sample, label) entry, so it is
# numerically identical to the macro per-label accuracy printed last.
sample_accuracy = (preds == labels).float().mean()
print(f" Sample-wise Accuracy: {sample_accuracy.item():.4f}")

# Per-label Accuracy
per_label_acc = (preds == labels).float().mean(dim=0)
for i, disease in enumerate(disease_columns):
    print(f"{disease}: {per_label_acc[i].item():.4f}")

# Macro Accuracy
macro_accuracy = per_label_acc.mean()
print(f" Macro (Mean Per-Label) Accuracy: {macro_accuracy.item():.4f}")


‚úÖ Exact Match Accuracy: 0.3889
‚úÖ Sample-wise Accuracy: 0.8114
Atelectasis: 0.6581
Cardiomegaly: 0.7094
Consolidation: 0.8590
Edema: 0.8077
Pleural Effusion: 0.7137
Pneumonia: 0.9658
Pneumothorax: 0.9658
‚úÖ Macro (Mean Per-Label) Accuracy: 0.8114


In [None]:
# scikit-learn works on NumPy arrays, so bring targets and scores to the host.
y_true = labels.cpu().numpy()
y_prob = probs.detach().cpu().numpy()

# Threshold Tuning: score every candidate cut-off by macro F1 and keep the
# first one that reaches the highest value. If no candidate scores above 0,
# fall back to the defaults (threshold 0.5, F1 0.0).
candidate_thresholds = np.arange(0.1, 0.9, 0.05)
macro_f1_scores = [
    f1_score(y_true, (y_prob > cand).astype(int), average='macro', zero_division=0)
    for cand in candidate_thresholds
]
top = int(np.argmax(macro_f1_scores))  # argmax returns the first maximum
if macro_f1_scores[top] > 0.0:
    best_f1, best_thresh = macro_f1_scores[top], candidate_thresholds[top]
else:
    best_f1, best_thresh = 0.0, 0.5

print(f"üîß Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

# Final Predictions with best threshold
y_pred = np.greater(y_prob, best_thresh).astype(int)

print("\nüîç Classification Report:")
print(classification_report(y_true, y_pred, target_names=disease_columns, zero_division=0))

hl = hamming_loss(y_true, y_pred)
print(f" Hamming Loss: {hl:.4f}")


üîß Best Threshold: 0.10 ‚Üí F1: 0.3010

üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.34      1.00      0.51        80
    Cardiomegaly       0.29      1.00      0.45        68
   Consolidation       0.14      1.00      0.25        33
           Edema       0.19      1.00      0.32        45
Pleural Effusion       0.29      1.00      0.45        67
       Pneumonia       0.03      1.00      0.07         8
    Pneumothorax       0.03      1.00      0.07         8

       micro avg       0.19      1.00      0.32       309
       macro avg       0.19      1.00      0.30       309
    weighted avg       0.26      1.00      0.40       309
     samples avg       0.19      0.61      0.27       309

‚ùå Hamming Loss: 0.8114


### Contrastive learning

In [None]:
# Jointly optimise both projection heads with a multi-label BCE objective on
# the image-to-class cosine similarities.
trainable_params = [*image_proj.parameters(), *text_proj.parameters()]
optimizer = optim.Adam(trainable_params, lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

num_epochs = 100
batch_size = 64

dataset_size = img_tensor_pooled.size(0)

for epoch in range(num_epochs):
    # Fresh shuffle of the sample order every epoch.
    perm = torch.randperm(dataset_size)
    epoch_loss = 0
    for start in range(0, dataset_size, batch_size):
        batch_idx = perm[start:start + batch_size]
        optimizer.zero_grad()

        # Text embeddings are re-projected each step because text_proj is being trained too.
        img_proj_norm = F.normalize(image_proj(img_tensor_pooled[batch_idx]), dim=1)  # [B, 256]
        txt_proj = text_proj(text_tensor)                                             # [7, 256]
        txt_proj_norm = F.normalize(txt_proj, dim=1)

        # NOTE(review): both operands are unit-normalised, so these "logits" are
        # cosine similarities bounded to [-1, 1]; the sigmoid inside the loss can
        # therefore only reach about [0.27, 0.73]. A temperature / logit scale
        # (applied identically at evaluation time) may be needed — confirm.
        similarity = img_proj_norm @ txt_proj_norm.T                                  # [B, 7]

        loss = criterion(similarity, labels[batch_idx])
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is exact for a short last batch.
        epoch_loss += loss.item() * batch_idx.numel()

    avg_loss = epoch_loss / dataset_size
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

Epoch 1/100 - Loss: 0.4722
Epoch 2/100 - Loss: 0.4732
Epoch 3/100 - Loss: 0.4705
Epoch 4/100 - Loss: 0.4698
Epoch 5/100 - Loss: 0.4699
Epoch 6/100 - Loss: 0.4683
Epoch 7/100 - Loss: 0.4685
Epoch 8/100 - Loss: 0.4679
Epoch 9/100 - Loss: 0.4680
Epoch 10/100 - Loss: 0.4680
Epoch 11/100 - Loss: 0.4678
Epoch 12/100 - Loss: 0.4676
Epoch 13/100 - Loss: 0.4678
Epoch 14/100 - Loss: 0.4675
Epoch 15/100 - Loss: 0.4683
Epoch 16/100 - Loss: 0.4687
Epoch 17/100 - Loss: 0.4684
Epoch 18/100 - Loss: 0.4678
Epoch 19/100 - Loss: 0.4678
Epoch 20/100 - Loss: 0.4677
Epoch 21/100 - Loss: 0.4672
Epoch 22/100 - Loss: 0.4681
Epoch 23/100 - Loss: 0.4682
Epoch 24/100 - Loss: 0.4672
Epoch 25/100 - Loss: 0.4678
Epoch 26/100 - Loss: 0.4673
Epoch 27/100 - Loss: 0.4672
Epoch 28/100 - Loss: 0.4671
Epoch 29/100 - Loss: 0.4671
Epoch 30/100 - Loss: 0.4669
Epoch 31/100 - Loss: 0.4669
Epoch 32/100 - Loss: 0.4669
Epoch 33/100 - Loss: 0.4668
Epoch 34/100 - Loss: 0.4670
Epoch 35/100 - Loss: 0.4667
Epoch 36/100 - Loss: 0.4669
E

In [None]:

# Switch both projection heads to inference mode.
for head in (image_proj, text_proj):
    head.eval()

# Score every retained image against every class text with the trained heads.
with torch.no_grad():
    img_proj = image_proj(img_tensor_pooled)
    txt_proj = text_proj(text_tensor)

    img_proj_norm = F.normalize(img_proj, dim=1)
    txt_proj_norm = F.normalize(txt_proj, dim=1)

    similarity = img_proj_norm @ txt_proj_norm.T
    probs = torch.sigmoid(similarity).cpu().numpy()  # [M, 7]

labels_np = labels.cpu().numpy()

# Pick the cut-off with the highest macro F1 (first one wins on ties). If no
# candidate scores above 0, fall back to the defaults (threshold 0.5, F1 0.0).
candidate_thresholds = np.arange(0.1, 0.9, 0.05)
macro_f1_scores = [
    f1_score(labels_np, (probs > cand).astype(int), average='macro', zero_division=0)
    for cand in candidate_thresholds
]
top = int(np.argmax(macro_f1_scores))  # argmax returns the first maximum
if macro_f1_scores[top] > 0.0:
    best_f1, best_thresh = macro_f1_scores[top], candidate_thresholds[top]
else:
    best_f1, best_thresh = 0.0, 0.5

print(f" Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

üîß Best Threshold: 0.35 ‚Üí F1: 0.5064


In [None]:

# Binarise the post-training scores at the tuned threshold.
y_pred = np.greater(probs, best_thresh).astype(int)

# Per-class precision / recall / F1 plus the averaged rows.
print("\nüîç Classification Report:")
print(
    classification_report(
        labels_np,
        y_pred,
        target_names=disease_columns,
        zero_division=0,
    )
)

# Fraction of individual (sample, label) entries that were predicted wrongly.
hl = hamming_loss(labels_np, y_pred)
print(f" Hamming Loss: {hl:.4f}")



üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.71      0.84      0.77        80
    Cardiomegaly       0.60      0.59      0.59        68
   Consolidation       0.51      0.88      0.64        33
           Edema       0.48      0.53      0.51        45
Pleural Effusion       0.70      0.81      0.75        67
       Pneumonia       0.15      0.62      0.24         8
    Pneumothorax       0.03      0.12      0.05         8

       micro avg       0.53      0.71      0.61       309
       macro avg       0.45      0.63      0.51       309
    weighted avg       0.59      0.71      0.64       309
     samples avg       0.23      0.34      0.26       309

‚ùå Hamming Loss: 0.1722
