In [1]:
# ==========================================
# Preparing image and text embeddings for zero-shot and contrastive learning (v2)
# ==========================================
# Author: Morvarid Rahbar
# Student ID: 4033624008
# ==========================================

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, classification_report, hamming_loss, accuracy_score
from google.colab import drive
import torch.optim as optim

# Prefer the GPU when one is visible to PyTorch; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(" Device:", device)


 Device: cpu


In [3]:
# Mount Google Drive at /content/drive; the CSV and embedding files read
# below are all referenced through this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Label table; the cells below read its "Path" column and the disease columns.
df = pd.read_csv("/content/drive/MyDrive/chexpert_data_v2/train.csv")

image_embedding_path = "/content/drive/MyDrive/image_embeddings_LRF30_fulltrainversion%.pt"
text_embedding_path = "/content/drive/MyDrive/Embedings/Text_embeding/disease_text_embeddings.pt"

# NOTE(security): torch.load unpickles the file, so only load embedding files
# that you produced yourself or otherwise trust.
# map_location="cpu" makes loading independent of the device the tensors were
# saved from (a file saved from CUDA would otherwise fail on a CPU-only
# runtime). The tensors are moved to `device` later, after pooling.
image_embeddings = torch.load(image_embedding_path, map_location="cpu")  # dict: key = img_path, value = [197, 64] (per original note)
text_embeddings = torch.load(text_embedding_path, map_location="cpu")    # dict: key = class name, value = [768] (per original note)



In [5]:
# The seven findings used as classification targets, in column order.
disease_columns = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation',
    'Edema', 'Pleural Effusion', 'Pneumonia', 'Pneumothorax'
]

# Paths in the CSV may carry this dataset-root prefix; it is stripped so the
# resulting keys can be looked up with the keys of `image_embeddings`.
prefix = "CheXpert-v1.0-small/"

relative_paths = [
    p[len(prefix):] if p.startswith(prefix) else p
    for p in df["Path"]
]

# Missing entries become 0.0; every other value (including -1) is kept as-is.
label_rows = df[disease_columns].fillna(0.0).astype(float).values.tolist()

# Map each (prefix-free) image path to its 7-element float label tensor.
label_dict = {
    path: torch.tensor(values)
    for path, values in zip(relative_paths, label_rows)
}


In [6]:
# --- Build the embedding and label tensors ---
img_keys = list(image_embeddings)
# Stack the per-image embeddings along a new leading dimension, then max-pool
# over dim 1 (the token dimension, assuming [N, 197, 64] inputs as noted where
# the embeddings are loaded).
img_tensor = torch.stack([image_embeddings[key] for key in img_keys])
img_tensor_pooled = torch.max(img_tensor, dim=1).values

valid_indices = []
labels_list = []
for position, key in enumerate(img_keys):
    raw_label = label_dict[key]
    # Drop images whose every label is -1.
    if bool((raw_label == -1).all()):
        continue
    valid_indices.append(position)
    # Remaining -1 entries are mapped to 0.0 (treated as negative).
    labels_list.append(torch.where(raw_label == -1, torch.tensor(0.0), raw_label))

# Keep only the retained images, and move everything used downstream to `device`.
labels = torch.stack(labels_list).to(device)                      # [M, 7]
img_tensor_pooled = img_tensor_pooled[valid_indices].to(device)   # [M, 64] per the original note

# One text embedding per disease, in the same order as `disease_columns`.
text_tensor = torch.stack([text_embeddings[name] for name in disease_columns]).to(device)

### Similarity-Based Classification

In [7]:
# Projection heads that map pooled image embeddings (64-d) and text embeddings
# (768-d) into a shared 256-d space. They are created on `device` because the
# input tensors were already moved there; leaving them on the CPU would raise
# a device-mismatch error on a GPU runtime.
image_proj = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.LayerNorm(128),
    nn.Dropout(0.3),
    nn.Linear(128, 256),
    nn.LayerNorm(256)
).to(device)

text_proj = nn.Sequential(
    nn.Linear(768, 128),
    nn.ReLU(),
    nn.LayerNorm(128),
    nn.Dropout(0.3),
    nn.Linear(128, 256),
    nn.LayerNorm(256)
).to(device)

# Temperature that sharpens the cosine similarities before the sigmoid.
temperature = nn.Parameter(torch.tensor(0.07, device=device))

# Baseline scoring with the untrained heads. eval() disables Dropout so the
# scores are deterministic for a given initialisation, and no_grad() avoids
# building an autograd graph over the whole dataset.
image_proj.eval()
text_proj.eval()
with torch.no_grad():
    # Apply projection and L2-normalise each row.
    img_proj = F.normalize(image_proj(img_tensor_pooled), dim=1)  # [M, 256]
    txt_proj = F.normalize(text_proj(text_tensor), dim=1)         # [7, 256]

    # Temperature-scaled cosine similarity of every image to every disease: [M, 7]
    similarity = (img_proj @ txt_proj.T) / temperature

# Switch back to train mode so that a later training loop that does not call
# .train() itself still trains with Dropout enabled.
image_proj.train()
text_proj.train()



In [8]:
# Turn the similarity logits into per-label probabilities.
probs = torch.sigmoid(similarity)

# Fixed decision threshold for this first look at the scores.
threshold = 0.5
preds = (probs > threshold).float()

# Element-wise correctness mask, reused by every metric below.
correct = preds == labels

# Exact match: a sample counts only if all of its labels are right.
exact_match = correct.all(dim=1).float().mean()
print(f" Exact Match Accuracy: {exact_match.item():.4f}")

# Mean over every (sample, label) entry.
sample_accuracy = correct.float().mean()
print(f" Sample-wise Accuracy: {sample_accuracy.item():.4f}")

# Accuracy of each label column.
per_label_acc = correct.float().mean(dim=0)
for disease, acc in zip(disease_columns, per_label_acc):
    print(f"{disease}: {acc.item():.4f}")

# Unweighted mean of the per-label accuracies.
macro_accuracy = per_label_acc.mean()
print(f" Macro (Mean Per-Label) Accuracy: {macro_accuracy.item():.4f}")


 Exact Match Accuracy: 0.0023
 Sample-wise Accuracy: 0.2624
Atelectasis: 0.3923
Cardiomegaly: 0.2177
Consolidation: 0.1383
Edema: 0.3060
Pleural Effusion: 0.5149
Pneumonia: 0.1520
Pneumothorax: 0.1156
 Macro (Mean Per-Label) Accuracy: 0.2624


In [9]:
# Move targets and probabilities to NumPy for scikit-learn.
y_true = labels.cpu().numpy()
y_prob = probs.detach().cpu().numpy()

# Sweep one global threshold and keep the first value that strictly improves
# macro-F1. NOTE(review): the threshold is tuned on the same samples that are
# reported on below, so the report is optimistic.
best_f1, best_thresh = 0.0, 0.5
for t in np.arange(0.1, 0.9, 0.05):
    candidate_f1 = f1_score(y_true, (y_prob > t).astype(int), average='macro', zero_division=0)
    if candidate_f1 > best_f1:
        best_f1, best_thresh = candidate_f1, t

print(f"üîß Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

# Binarise with the selected threshold and summarise.
y_pred = (y_prob > best_thresh).astype(int)

print("\nüîç Classification Report:")
print(classification_report(y_true, y_pred, target_names=disease_columns, zero_division=0))

hl = hamming_loss(y_true, y_pred)
print(f" Hamming Loss: {hl:.4f}")


üîß Best Threshold: 0.15 ‚Üí F1: 0.2488

üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.15      1.00      0.26      9934
    Cardiomegaly       0.12      1.00      0.21      8062
   Consolidation       0.07      1.00      0.12      4432
           Edema       0.23      1.00      0.38     15602
Pleural Effusion       0.38      1.00      0.56     25776
       Pneumonia       0.03      1.00      0.05      1792
    Pneumothorax       0.09      1.00      0.16      5797

       micro avg       0.15      1.00      0.26     71395
       macro avg       0.15      1.00      0.25     71395
    weighted avg       0.24      1.00      0.37     71395
     samples avg       0.15      0.67      0.24     71395

 Hamming Loss: 0.8476


### BCE-with-logits learning

In [10]:
# Train both projection heads jointly with a multi-label BCE objective on the
# image-to-disease cosine similarities.
optimizer = optim.Adam(list(image_proj.parameters()) + list(text_proj.parameters()), lr=5e-4)
criterion = nn.BCEWithLogitsLoss()

num_epochs = 100
batch_size = 64

dataset_size = img_tensor_pooled.size(0)

for epoch in range(num_epochs):
    # Set train mode explicitly so Dropout is active even when this cell is
    # re-run after an evaluation cell has switched the heads to eval().
    image_proj.train()
    text_proj.train()

    # Fresh shuffle each epoch, created on the data's device so indexing does
    # not need a host-to-device copy per batch.
    perm = torch.randperm(dataset_size, device=img_tensor_pooled.device)
    epoch_loss = 0.0
    for i in range(0, dataset_size, batch_size):
        optimizer.zero_grad()
        indices = perm[i:i+batch_size]

        img_batch = img_tensor_pooled[indices]   # [B, 64]
        label_batch = labels[indices]             # [B, 7]

        img_proj_batch = image_proj(img_batch)   # [B, 256]
        # Re-projected every step because text_proj's weights are being updated.
        txt_proj = text_proj(text_tensor)        # [7, 256]

        img_proj_norm = F.normalize(img_proj_batch, dim=1)
        txt_proj_norm = F.normalize(txt_proj, dim=1)

        # Cosine similarity, used directly as the logit.
        # NOTE(review): these logits are bounded to [-1, 1], so sigmoid can
        # only reach roughly [0.27, 0.73]; consider a temperature / logit
        # scale here and in the matching evaluation cell.
        similarity = torch.matmul(img_proj_norm, txt_proj_norm.T)  # [B, 7]

        loss = criterion(similarity, label_batch)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample.
        epoch_loss += loss.item() * img_batch.size(0)

    avg_loss = epoch_loss / dataset_size
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

Epoch 1/130 - Loss: 1.0843
Epoch 2/130 - Loss: 1.0677
Epoch 3/130 - Loss: 1.0649
Epoch 4/130 - Loss: 1.0657
Epoch 5/130 - Loss: 1.1025
Epoch 6/130 - Loss: 1.0901
Epoch 7/130 - Loss: 1.0803
Epoch 8/130 - Loss: 1.0608
Epoch 9/130 - Loss: 1.0568
Epoch 10/130 - Loss: 1.0594
Epoch 11/130 - Loss: 1.0547
Epoch 12/130 - Loss: 1.0558
Epoch 13/130 - Loss: 1.0542
Epoch 14/130 - Loss: 1.0544
Epoch 15/130 - Loss: 1.0532
Epoch 16/130 - Loss: 1.0540
Epoch 17/130 - Loss: 1.0522
Epoch 18/130 - Loss: 1.0521
Epoch 19/130 - Loss: 1.0519
Epoch 20/130 - Loss: 1.0514
Epoch 21/130 - Loss: 1.0508
Epoch 22/130 - Loss: 1.0505
Epoch 23/130 - Loss: 1.0503
Epoch 24/130 - Loss: 1.0502
Epoch 25/130 - Loss: 1.0500
Epoch 26/130 - Loss: 1.0498
Epoch 27/130 - Loss: 1.0493
Epoch 28/130 - Loss: 1.0489
Epoch 29/130 - Loss: 1.0497
Epoch 30/130 - Loss: 1.0490
Epoch 31/130 - Loss: 1.0497
Epoch 32/130 - Loss: 1.0490
Epoch 33/130 - Loss: 1.0488
Epoch 34/130 - Loss: 1.0490
Epoch 35/130 - Loss: 1.0489
Epoch 36/130 - Loss: 1.0485
E

In [11]:

# Inference mode: disables Dropout in both projection heads.
image_proj.eval()
text_proj.eval()

with torch.no_grad():
    img_proj = image_proj(img_tensor_pooled)
    txt_proj = text_proj(text_tensor)

    # Unit-normalise so the dot product below is a cosine similarity.
    img_proj_norm = F.normalize(img_proj, dim=1)
    txt_proj_norm = F.normalize(txt_proj, dim=1)

    similarity = img_proj_norm @ txt_proj_norm.T
    probs = similarity.sigmoid().cpu().numpy()  # one column per disease

labels_np = labels.cpu().numpy()

# Pick the global threshold that maximises macro-F1; ties keep the earlier
# (smaller) threshold. NOTE(review): evaluated on the training samples.
best_f1, best_thresh = 0.0, 0.5
for t in np.arange(0.1, 0.9, 0.05):
    candidate_f1 = f1_score(labels_np, (probs > t).astype(int), average='macro', zero_division=0)
    if candidate_f1 > best_f1:
        best_f1, best_thresh = candidate_f1, t

print(f" Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

 Best Threshold: 0.55 ‚Üí F1: 0.3751


In [12]:

# Binarise the probabilities with the tuned threshold.
y_pred = (probs > best_thresh).astype(int)

# Per-disease precision / recall / F1 followed by the averaged rows.
report = classification_report(labels_np, y_pred, target_names=disease_columns, zero_division=0)
print("\nüîç Classification Report:")
print(report)

# Fraction of individual label decisions that are wrong.
hl = hamming_loss(labels_np, y_pred)
print(f" Hamming Loss: {hl:.4f}")



üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.24      0.53      0.33      9934
    Cardiomegaly       0.36      0.75      0.48      8062
   Consolidation       0.12      0.67      0.20      4432
           Edema       0.48      0.67      0.56     15602
Pleural Effusion       0.69      0.72      0.70     25776
       Pneumonia       0.05      0.52      0.09      1792
    Pneumothorax       0.17      0.51      0.25      5797

       micro avg       0.32      0.66      0.43     71395
       macro avg       0.30      0.62      0.38     71395
    weighted avg       0.45      0.66      0.51     71395
     samples avg       0.24      0.42      0.29     71395

 Hamming Loss: 0.2660


In [13]:
# Quick visual check of the binarised prediction matrix (NumPy abbreviates
# large arrays when displaying them).
y_pred

array([[0, 1, 1, ..., 1, 0, 0],
       [1, 0, 1, ..., 1, 0, 0],
       [0, 0, 1, ..., 1, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [1, 0, 1, ..., 0, 1, 1],
       [0, 0, 0, ..., 0, 0, 0]])

### Contrastive learning


In [14]:
def contrastive_loss(image_features, text_features, temperature=0.2):
    """Symmetric cross-entropy loss over an image/text similarity matrix.

    Row ``i`` of ``image_features`` is treated as the match of row ``i`` of
    ``text_features``; every other row in the batch acts as a negative.

    Args:
        image_features: [B, D] image embeddings (normalised internally).
        text_features: [B, D] text embeddings (normalised internally).
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean of the image-to-text and text-to-image
        cross-entropy terms.

    NOTE(review): rows of ``text_features`` that come from the same class are
    still scored as negatives of each other, because the targets are purely
    positional.
    """
    img_unit = F.normalize(image_features, dim=1)
    txt_unit = F.normalize(text_features, dim=1)

    # [B, B] matrix of temperature-scaled cosine similarities.
    scaled_sim = (img_unit @ txt_unit.T) / temperature

    # The correct "class" for row i is column i (the diagonal).
    diag_targets = torch.arange(img_unit.size(0), device=img_unit.device)

    image_to_text = F.cross_entropy(scaled_sim, diag_targets)
    text_to_image = F.cross_entropy(scaled_sim.T, diag_targets)

    return 0.5 * (image_to_text + text_to_image)


In [15]:
def get_positive_text_embeddings(labels_batch, text_tensor):
    """Pick one positive-label text embedding per image, uniformly at random.

    Args:
        labels_batch: [B, C] multi-label matrix; any non-zero entry marks a
            positive label.
        text_tensor: [C, E] matrix with one text embedding per label, in the
            same column order as ``labels_batch``.

    Returns:
        [B, E] tensor whose row ``i`` is the embedding of a randomly chosen
        positive label of sample ``i``. A sample with no positive label falls
        back to label index 0. An empty batch yields an empty [0, E] tensor.

    The selection is done with a single ``torch.multinomial`` call instead of
    a per-row Python loop, which avoids one host synchronisation per sample.
    The sampling distribution is unchanged (uniform over a row's positives),
    but the individual random draws differ from the loop-based version.
    """
    if labels_batch.size(0) == 0:
        return text_tensor.new_empty((0,) + tuple(text_tensor.shape[1:]))

    # 1.0 where a label is positive; used as unnormalised sampling weights.
    weights = (labels_batch != 0).to(torch.float32)

    # multinomial needs a non-zero weight in every row: rows without any
    # positive label deterministically pick index 0.
    no_positive = weights.sum(dim=1) == 0
    weights[no_positive, 0] = 1.0

    chosen = torch.multinomial(weights, num_samples=1).squeeze(1)  # [B]
    return text_tensor[chosen]


In [16]:
from torch.utils.data import TensorDataset, DataLoader

# Pair each pooled image embedding with its label vector and serve them in
# reshuffled mini-batches of 64.
train_dataset = TensorDataset(img_tensor_pooled, labels)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)


In [17]:
def _projection_head(in_features):
    """Build a two-layer MLP that maps ``in_features`` inputs to 256 outputs."""
    return nn.Sequential(
        nn.Linear(in_features, 128),
        nn.ReLU(),
        nn.LayerNorm(128),
        nn.Dropout(0.3),
        nn.Linear(128, 256),
        nn.LayerNorm(256),
    )


# Fresh, untrained heads for the contrastive experiment; these replace any
# earlier `image_proj` / `text_proj` objects bound to the same names.
image_proj = _projection_head(64).to(device)
text_proj = _projection_head(768).to(device)


In [21]:
# One Adam optimiser over the parameters of both projection heads.
optimizer = torch.optim.Adam(
    [*image_proj.parameters(), *text_proj.parameters()],
    lr=5e-4,
)

num_epochs = 100

for epoch in range(num_epochs):
    # Enable Dropout for training.
    image_proj.train()
    text_proj.train()
    total_loss = 0.0

    for img_batch, label_batch in train_loader:
        img_batch, label_batch = img_batch.to(device), label_batch.to(device)

        optimizer.zero_grad()

        # Image side: pooled embedding -> shared space.
        img_proj = image_proj(img_batch)  # [B, 256]

        # Text side: one sampled positive-label embedding per image -> shared space.
        txt_proj = text_proj(get_positive_text_embeddings(label_batch, text_tensor))  # [B, 256]

        # Pull each image towards its sampled label text, push it from the rest.
        loss = contrastive_loss(img_proj, txt_proj)
        loss.backward()
        optimizer.step()

        # Accumulate a per-sample sum so the last (smaller) batch is weighted correctly.
        total_loss += loss.item() * img_batch.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")


Epoch 1/100 - Loss: 3.8717
Epoch 2/100 - Loss: 3.8716
Epoch 3/100 - Loss: 3.8710
Epoch 4/100 - Loss: 3.8716
Epoch 5/100 - Loss: 3.8721
Epoch 6/100 - Loss: 3.8692
Epoch 7/100 - Loss: 3.8688
Epoch 8/100 - Loss: 3.8709
Epoch 9/100 - Loss: 3.8702
Epoch 10/100 - Loss: 3.8712
Epoch 11/100 - Loss: 3.8686
Epoch 12/100 - Loss: 3.8683
Epoch 13/100 - Loss: 3.8702
Epoch 14/100 - Loss: 3.8700
Epoch 15/100 - Loss: 3.8696
Epoch 16/100 - Loss: 3.8696
Epoch 17/100 - Loss: 3.8691
Epoch 18/100 - Loss: 3.8699
Epoch 19/100 - Loss: 3.8697
Epoch 20/100 - Loss: 3.8678
Epoch 21/100 - Loss: 3.8674
Epoch 22/100 - Loss: 3.8683
Epoch 23/100 - Loss: 3.8666
Epoch 24/100 - Loss: 3.8655
Epoch 25/100 - Loss: 3.8676
Epoch 26/100 - Loss: 3.8665
Epoch 27/100 - Loss: 3.8662
Epoch 28/100 - Loss: 3.8653
Epoch 29/100 - Loss: 3.8671
Epoch 30/100 - Loss: 3.8676
Epoch 31/100 - Loss: 3.8682
Epoch 32/100 - Loss: 3.8671
Epoch 33/100 - Loss: 3.8652
Epoch 34/100 - Loss: 3.8653
Epoch 35/100 - Loss: 3.8644
Epoch 36/100 - Loss: 3.8633
E

In [19]:
# Inference mode: disables Dropout in both projection heads.
image_proj.eval()
text_proj.eval()

with torch.no_grad():
    # Unit-normalised projections of every image and of the 7 disease texts.
    img_proj_all = F.normalize(image_proj(img_tensor_pooled), dim=1)
    txt_proj_all = F.normalize(text_proj(text_tensor), dim=1)

    # Cosine similarity of each image to each disease, squashed by a sigmoid.
    similarity = img_proj_all @ txt_proj_all.T
    probs = similarity.sigmoid().cpu().numpy()
    labels_np = labels.cpu().numpy()


In [20]:
from sklearn.metrics import f1_score, classification_report, hamming_loss

# Sweep one global threshold in steps of 0.01 and keep the first value that
# strictly improves macro-F1. NOTE(review): tuned and reported on the same
# samples, so the numbers are optimistic.
best_f1, best_thresh = 0.0, 0.5

for t in np.arange(0.1, 0.9, 0.01):
    candidate_f1 = f1_score(labels_np, (probs > t).astype(int), average='macro', zero_division=0)
    if candidate_f1 > best_f1:
        best_f1, best_thresh = candidate_f1, t

print(f"‚úÖ Best Threshold: {best_thresh:.2f} ‚Üí F1: {best_f1:.4f}")

# Binarise with the selected threshold and summarise.
y_pred = (probs > best_thresh).astype(int)

report = classification_report(labels_np, y_pred, target_names=disease_columns, zero_division=0)
print("\nüîç Classification Report:")
print(report)

hl = hamming_loss(labels_np, y_pred)
print(f"üîª Hamming Loss: {hl:.4f}")


‚úÖ Best Threshold: 0.57 ‚Üí F1: 0.3516

üîç Classification Report:
                  precision    recall  f1-score   support

     Atelectasis       0.12      0.33      0.18      9934
    Cardiomegaly       0.33      0.71      0.45      8062
   Consolidation       0.12      0.44      0.19      4432
           Edema       0.49      0.77      0.60     15602
Pleural Effusion       0.74      0.73      0.73     25776
       Pneumonia       0.03      0.50      0.06      1792
    Pneumothorax       0.15      0.58      0.24      5797

       micro avg       0.29      0.65      0.40     71395
       macro avg       0.28      0.58      0.35     71395
    weighted avg       0.45      0.65      0.51     71395
     samples avg       0.33      0.45      0.35     71395

üîª Hamming Loss: 0.2946
