In [196]:
# pip install scikit-learn

In [197]:
# ==========================================
# Preparing image and text embeddings for zero-shot evaluation and contrastive learning
# ==========================================
# Author: Morvarid Rahbar
# Student ID: 4033624008
# ==========================================

In [198]:
#  Imports
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, hamming_loss, f1_score
import torch.nn as nn

In [199]:
#  Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [188]:
#  Google Drive Mount: the embedding files and label CSV loaded below live under this mount point
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [189]:
# Cell 2: Load Embeddings and Labels

# Paths to embeddings and CSV on the mounted Google Drive
image_embedding_path = "/content/drive/MyDrive/Embedings/Image_embeding/image_embeddings_LRF30%.pt"
text_embedding_path = "/content/drive/MyDrive/Embedings/Text_embeding/disease_text_embeddings.pt"
csv_path = "/content/drive/MyDrive/chexpert_data_v2/valid.csv"

# Load embeddings (dictionaries). map_location="cpu" lets files saved from a GPU session load on a
# CPU-only runtime; later cells move what they need to `device` explicitly.
# NOTE(review): torch.load unpickles its input -- only load files you created yourself.
image_embeddings = torch.load(image_embedding_path, map_location="cpu")  # dict: {image_path: tensor}
text_embeddings = torch.load(text_embedding_path, map_location="cpu")    # dict: {disease: tensor}

print(f"Loaded {len(image_embeddings)} image embeddings")
print(f"Loaded {len(text_embeddings)} text embeddings")

# Load CSV with labels
df = pd.read_csv(csv_path)

# List of diseases (consistent order; fixes the label-column order used by every later cell)
disease_columns = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
    'Pleural Effusion', 'Pneumonia', 'Pneumothorax'
]

# CheXpert CSVs encode labels as 1 (positive), 0 (negative), -1 (uncertain) and blank (not mentioned).
# Blank -> 0.0; uncertain -> UNCERTAIN_LABEL_VALUE (0.0 = "U-Zeros" policy) so that y_true stays
# binary -- the sklearn multilabel metrics in the evaluation cell reject a third value.
UNCERTAIN_LABEL_VALUE = 0.0
prefix = "CheXpert-v1.0-small/"

label_values = df[disease_columns].fillna(0.0).to_numpy(dtype="float32")  # [num_rows, num_classes]
label_values[label_values == -1.0] = UNCERTAIN_LABEL_VALUE

# Strip the dataset prefix so CSV paths match the keys of `image_embeddings`
label_paths = [path[len(prefix):] if path.startswith(prefix) else path for path in df["Path"]]

# Create label dict with paths as keys and float32 label tensors ([num_classes]) as values.
# If a path appears twice, the later row overwrites the earlier one.
label_dict = {path: torch.tensor(row) for path, row in zip(label_paths, label_values)}


Loaded 234 image embeddings
Loaded 7 text embeddings


In [190]:
# Cell 3: Prepare Image and Label Tensors
# Row i of img_tensor and of labels both belong to img_keys[i]; the order follows image_embeddings.
img_keys = list(image_embeddings.keys())

# Fail with an actionable message (instead of a bare KeyError on the first miss) when embedding
# keys and CSV paths do not line up, e.g. because of a path-prefix mismatch.
missing_label_keys = [k for k in img_keys if k not in label_dict]
if missing_label_keys:
    raise KeyError(
        f"{len(missing_label_keys)} of {len(img_keys)} image embeddings have no label row; "
        f"first few: {missing_label_keys[:3]}"
    )

img_tensor = torch.stack([image_embeddings[k] for k in img_keys])  # shape: [N, seq_len, img_emb_dim]
labels = torch.stack([label_dict[k] for k in img_keys])           # shape: [N, num_classes]
assert labels.shape == (len(img_keys), len(disease_columns)), labels.shape

print(f"Image tensor shape: {img_tensor.shape}")
print(f"Labels shape: {labels.shape}")


Image tensor shape: torch.Size([234, 197, 64])
Labels shape: torch.Size([234, 7])


In [191]:
# Cell 4: Stack the per-disease text embeddings into one [num_classes, txt_emb_dim] tensor.
# Rows follow `disease_columns`, so column j of any image-text score matrix refers to disease j.
text_rows = [text_embeddings[disease_name] for disease_name in disease_columns]
text_tensor = torch.stack(text_rows).to(device)
print(f"Text tensor shape: {text_tensor.shape}")


Text tensor shape: torch.Size([7, 768])


In [192]:
# Cell 5: Define projection dimension and projection layers
# Seed the RNG so the randomly initialised projection weights -- and every score derived from
# them -- are identical on each Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

dim = 512  # shared projection dimension for image and text embeddings

# NOTE(review): these layers are freshly initialised and are not trained before the similarity
# and evaluation cells use them, so the "zero-shot" scores there come from a random projection.
image_proj_layer = torch.nn.Linear(img_tensor.size(-1), dim).to(device)  # img_emb_dim -> dim
text_proj_layer = torch.nn.Linear(text_tensor.size(-1), dim).to(device)  # txt_emb_dim -> dim


In [193]:
# Cell 6: Project both modalities into the shared `dim`-dimensional space and L2-normalise them,
# so that dot products between rows are cosine similarities.
img_tensor = img_tensor.to(device)

# Mean-pool over the sequence axis (axis 1) -> [N, img_emb_dim]
image_repr = img_tensor.mean(dim=1)

# Linear projection immediately followed by unit-norm scaling
image_proj = F.normalize(image_proj_layer(image_repr), dim=1)  # [N, dim]
text_proj = F.normalize(text_proj_layer(text_tensor), dim=1)   # [num_classes, dim]

print(f"Projected image embeddings shape: {image_proj.shape}")
print(f"Projected text embeddings shape: {text_proj.shape}")


Projected image embeddings shape: torch.Size([234, 512])
Projected text embeddings shape: torch.Size([7, 512])


In [194]:
# Cell 7: Compute similarity, probabilities, and predictions
# Rows of image_proj and text_proj are unit vectors, so this is cosine similarity in [-1, 1].
similarity = image_proj @ text_proj.T  # [N, num_classes]

# Unscaled, sigmoid maps [-1, 1] only onto ~[0.27, 0.73], so a 0.5 threshold is the same as
# "cosine > 0". LOGIT_SCALE (a CLIP-style temperature) widens that range; 1.0 keeps it unscaled.
LOGIT_SCALE = 1.0
probs = torch.sigmoid(LOGIT_SCALE * similarity)  # [N, num_classes]

# Thresholding predictions
threshold = 0.5
preds = (probs > threshold).float()

# Move labels and preds to CPU for sklearn (detach: the projections carry autograd history)
y_true = labels.cpu().numpy()
y_pred = preds.detach().cpu().numpy()


In [195]:
# Cell 8: Evaluation Metrics and Reports
# Multi-label evaluation: one binary decision per (image, disease) pair.

# Convert tensors to numpy arrays (detach from computation graph first)
y_true = labels.cpu().numpy()
y_pred_probs = probs.detach().cpu().numpy()

# Evaluate metrics at fixed threshold 0.5
threshold = 0.5
y_pred = (y_pred_probs > threshold).astype(int)

# Exact match (subset) accuracy: a sample counts only if all of its labels are correct
exact_match = np.all(y_true == y_pred, axis=1).mean()
print(f"Exact Match Accuracy: {exact_match:.4f}")

# Element-wise accuracy over all N x num_classes decisions (equals 1 - Hamming loss).
# This is not a per-sample metric, so it is labelled accordingly.
sample_accuracy = (y_true == y_pred).mean()
print(f"Element-wise (Label) Accuracy: {sample_accuracy:.4f}")

# Per-label accuracy (one value per disease column)
per_label_acc = (y_true == y_pred).mean(axis=0)
for disease, acc in zip(disease_columns, per_label_acc):
    print(f"{disease}: {acc:.4f}")

# Macro (mean per-label) accuracy; every label has N samples, so this equals the element-wise accuracy
macro_accuracy = per_label_acc.mean()
print(f"Macro (Mean Per-Label) Accuracy: {macro_accuracy:.4f}")

# Classification report
print("\n🔍 Classification Report (Macro & Per-label):")
print(classification_report(y_true, y_pred, target_names=disease_columns, zero_division=0))

# Hamming loss: fraction of wrong (image, disease) decisions
hl = hamming_loss(y_true, y_pred)
print(f"\nHamming Loss: {hl:.4f}")

# Find best threshold based on macro F1.
# NOTE: the threshold is chosen on the same data it is scored on, so best_f1 is optimistic;
# select it on a held-out split before reporting it. If `probs` is an unscaled sigmoid of a
# cosine similarity it lies in ~[0.27, 0.73], so grid points outside that range are degenerate.
threshold_grid = np.arange(0.1, 0.9, 0.05)  # 0.10, 0.15, ..., 0.85 (stop is exclusive)
best_threshold = float(threshold_grid[0])
best_f1 = -1.0  # below any achievable F1, so a tested grid value is always reported (never 0.00)
for t in threshold_grid:
    preds_t = (y_pred_probs > t).astype(int)
    f1 = f1_score(y_true, preds_t, average='macro', zero_division=0)
    if f1 > best_f1:
        best_f1 = f1
        best_threshold = t

print(f"\nBest threshold: {best_threshold:.2f} with F1-score: {best_f1:.4f}")

Exact Match Accuracy: 0.0000
Sample-wise Accuracy: 0.1825
Atelectasis: 0.3419
Cardiomegaly: 0.2778
Consolidation: 0.1410
Edema: 0.1709
Pleural Effusion: 0.2650
Pneumonia: 0.0427
Pneumothorax: 0.0385
Macro (Mean Per-Label) Accuracy: 0.1825

üîç Classification Report (Macro & Per-label):
                  precision    recall  f1-score   support

     Atelectasis       0.34      1.00      0.51        80
    Cardiomegaly       0.28      0.94      0.43        68
   Consolidation       0.12      0.85      0.22        33
           Edema       0.16      0.78      0.27        45
Pleural Effusion       0.27      0.93      0.42        67
       Pneumonia       0.03      0.88      0.06         8
    Pneumothorax       0.03      1.00      0.07         8

       micro avg       0.18      0.92      0.30       309
       macro avg       0.18      0.91      0.28       309
    weighted avg       0.25      0.92      0.38       309
     samples avg       0.19      0.58      0.26       309


Hamming Loss

### Contrastive Learning