In [1]:
# Cell 1: install required packages
!pip install -q sentence-transformers transformers pillow tqdm pandas scikit-learn


In [13]:
# Cell 2: imports & config
import os
import json
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sentence_transformers import SentenceTransformer

from sklearn.metrics import roc_auc_score

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# ---------------------------
# EDIT THESE PATHS TO MATCH YOUR FILES IF NEEDED
# ---------------------------
IMG_DIR = "/content/images"  # full/global images
ROI_DIR = "/content/roi_images"  # if you have ROI crops (optional)
GLOBAL_CSV = "/content/lesion_metadata_with_global_labels.csv"   # should contain columns: image_id,global_text
LOCAL_JSON = "/content/local_labels_final.csv"   # mapping image_id -> {"roi_imgs":[...],"roi_texts":[...]} (optional)
TEACHER_GLOBAL_NPZ = "/content/biomedclip_global_embeddings.npz"  # teacher fused global embeddings (fused_embeds, image_ids)
TEACHER_LOCAL_NPY = "/content/biomedclip_local_embeddings.npz"  # object-array of per-image arrays of local teacher embeddings (optional)
TEST72_CSV = "/content/eval_sample.csv"  # test split (72). If not present notebook will create it.
OUT_DIR = "/content/output"
os.makedirs(OUT_DIR, exist_ok=True)

print("Paths set. Make sure the files exist at these locations.")


Device: cuda
Paths set. Make sure the files exist at these locations.


In [14]:
# Cell 3: check that critical files exist (warn but don't crash)
def warn_path(p):
    if not os.path.exists(p):
        print("MISSING:", p)
    else:
        print("FOUND:", p)

warn_path(IMG_DIR)
warn_path(GLOBAL_CSV)
warn_path(TEACHER_GLOBAL_NPZ)
# local files optional
warn_path(LOCAL_JSON)
warn_path(TEACHER_LOCAL_NPY)
warn_path(TEST72_CSV)

print("If any required file is missing, upload it to the path above or edit the path variables in Cell 2.")


FOUND: /content/images
FOUND: /content/lesion_metadata_with_global_labels.csv
FOUND: /content/biomedclip_global_embeddings.npz
FOUND: /content/local_labels_final.csv
FOUND: /content/biomedclip_local_embeddings.npz
FOUND: /content/eval_sample.csv
If any required file is missing, upload it to the path above or edit the path variables in Cell 2.


In [15]:
import numpy as np
np.load("biomedclip_global_embeddings.npz").files


['image_embeds', 'text_embeds', 'image_ids']

In [16]:

np.load("biomedclip_local_embeddings.npz").files

['image_embeds', 'text_embeds', 'image_ids']

In [18]:
# Cell 1
from sentence_transformers import SentenceTransformer
import torch, numpy as np, pandas as pd, os
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

student = SentenceTransformer("sentence-transformers/clip-ViT-B-32", device=device)
print("Student model loaded. Embedding dim:", student.get_sentence_embedding_dimension())


Device: cuda
Student model loaded. Embedding dim: None


In [44]:
# df1=pd.read_csv("/content/eval_sample.csv")
# df1["image_id"]+=".jpg"
# df1.to_csv("/content/eval_sample.csv")

In [100]:
# Cell 2
def encode_images_student(student_model, pil_images, batch_size=16):
    if len(pil_images) == 0:
        return torch.zeros((0, student_model.get_sentence_embedding_dimension()), device=device)
    emb = student_model.encode(pil_images, batch_size=batch_size, convert_to_tensor=True, device=device)
    return emb  # already on device

def encode_texts_student(student_model, texts, batch_size=64):
    if len(texts) == 0:
        return torch.zeros((0, student_model.get_sentence_embedding_dimension()), device=device)
    emb = student_model.encode(texts, batch_size=batch_size, convert_to_tensor=True, device=device)
    return emb


In [101]:
# Cell 3 - set your path variables (these were already set earlier in your environment)
# If you already have these variables defined, this will overwrite - adjust paths if needed.
IMG_DIR = "/content/images"
GLOBAL_CSV = "/content/lesion_metadata_with_global_labels.csv"
TEST72_CSV = "/content/eval_sample.csv"
OUT_DIR = "/content/output"
os.makedirs(OUT_DIR, exist_ok=True)

# Load or create test72
if not os.path.exists(TEST72_CSV):
    df_all = pd.read_csv(GLOBAL_CSV)
    np.random.seed(42)
    idx = np.random.choice(len(df_all), 72, replace=False)
    df_test72 = df_all.iloc[idx].reset_index(drop=True)
    df_test72.to_csv(TEST72_CSV, index=False)
    print("Created test72:", TEST72_CSV)
else:
    df_test72 = pd.read_csv(TEST72_CSV)
    print("Loaded existing test72:", TEST72_CSV)

# Encode baseline embeddings for test72
test_imgs = []
test_texts = []
for _, r in df_test72.iterrows():
    p = os.path.join(IMG_DIR, str(r['image_id']))
    test_imgs.append(Image.open(p).convert("RGB"))
    test_texts.append(str(r['global_label']))

print("Encoding test72 images and texts...")
img_embs = encode_images_student(student, test_imgs, batch_size=16)  # tensor
txt_embs = encode_texts_student(student, test_texts, batch_size=64)

img_np = img_embs.cpu().numpy()
txt_np = txt_embs.cpu().numpy()
sim = img_np @ txt_np.T

np.savez(os.path.join(OUT_DIR, "baseline_test72_student.npz"),
         image_embeds=img_np, text_embeds=txt_np, sim=sim,
         image_ids=df_test72['image_id'].values, texts=df_test72['global_label'].values)
print("Saved baseline artifacts to", os.path.join(OUT_DIR, "baseline_test72_student.npz"))

# quick unsupervised checks
pos = np.diag(sim)
neg = sim[~np.eye(len(sim), dtype=bool)]
print("Avg pos sim:", pos.mean(), "Avg neg sim:", neg.mean())

def recall_at_k(sim_matrix, k):
    ranks = np.argsort(-sim_matrix, axis=1)
    hits = sum(i in ranks[i, :k] for i in range(sim_matrix.shape[0]))
    return hits / sim_matrix.shape[0]

for k in (1,5,10):
    print(f"Recall@{k}:", recall_at_k(sim, k))


Loaded existing test72: /content/eval_sample.csv
Encoding test72 images and texts...
Saved baseline artifacts to /content/output/baseline_test72_student.npz
Avg pos sim: 26.11878 Avg neg sim: 26.099163
Recall@1: 0.0
Recall@5: 0.041666666666666664
Recall@10: 0.1527777777777778


In [102]:
# Cell 4
TEACHER_GLOBAL_NPZ = "/content/biomedclip_global_embeddings.npz"   # adjust if different
TEACHER_LOCAL_NPZ  = "/content/biomedclip_local_embeddings.npz"    # adjust if different (you said you have both NPZ)

tg = np.load(TEACHER_GLOBAL_NPZ, allow_pickle=True)
teacher_global_embeds = tg['image_embeds']  # shape (N, D)
teacher_global_text_embeds = tg.get('text_embeds', None)
teacher_global_ids = [str(x) for x in tg['image_ids']]

print("Teacher global loaded:", teacher_global_embeds.shape)

# Local teacher NPZ
if os.path.exists(TEACHER_LOCAL_NPZ):
    tl = np.load(TEACHER_LOCAL_NPZ, allow_pickle=True)
    teacher_local_embeds = tl['image_embeds']   # shape (N_local, D) OR (N, D) depending on how saved
    teacher_local_text_embeds = tl.get('text_embeds', None)
    teacher_local_ids = [str(x) for x in tl['image_ids']]
    print("Teacher local loaded:", teacher_local_embeds.shape)
else:
    teacher_local_embeds = None
    teacher_local_text_embeds = None
    teacher_local_ids = None
    print("No teacher local NPZ found at", TEACHER_LOCAL_NPZ)


Teacher global loaded: (472, 512)
Teacher local loaded: (472, 512)


In [103]:
# ===============================
# CELL: NORMALIZE IDS + DEBUG + SAFE MERGE
# ===============================

import numpy as np
import pandas as pd

print("\n=== STARTING ID NORMALIZATION ===")

# -------------------------
# 1. CLEAN TRAIN DF
# -------------------------
df_all['image_id'] = df_all['image_id'].astype(str)
# Fill NaN labels safely
df_all['global_labels'] = df_all['global_label'].astype(str).fillna("")

# Remove .jpg/.png
df_all['image_id_norm'] = (
    df_all['image_id']
    .str.replace(".jpg", "", regex=False)
    .str.replace(".jpeg", "", regex=False)
    .str.replace(".png", "", regex=False)
)

print("Sample normalized train_df IDs:", df_all['image_id_norm'].head().tolist())


# -------------------------
# 2. CLEAN TEACHER GLOBAL IDS
# -------------------------
teacher_ids_raw = list(teacher_global_ids)  # original list
teacher_ids_norm = []

valid_teacher_idx = []

for i, tid in enumerate(teacher_ids_raw):
    tid = str(tid)
    # Remove extensions if any
    tid = tid.replace(".jpg", "").replace(".jpeg", "").replace(".png", "")

    # Remove stray whitespace
    tid = tid.strip()

    # ISIC IDs should start with ISIC_
    if not tid.startswith("ISIC_"):
        continue

    # Many ISIC IDs have length 12 (e.g., ISIC_0027139)
    # Some NPZs truncate last 1â€“2 digits (e.g., ISIC_00288)
    # We DO NOT DROP THESE â€” instead we allow len >= 8
    if len(tid) >= 8:
        teacher_ids_norm.append(tid)
        valid_teacher_idx.append(i)

teacher_global_embeds = teacher_global_embeds[valid_teacher_idx]
if teacher_local_embeds is not None:
    teacher_local_embeds = teacher_local_embeds[valid_teacher_idx]

print("Teacher embeddings kept:", len(teacher_ids_norm))
print("Sample normalized teacher IDs:", teacher_ids_norm[:10])


# -------------------------
# 3. BUILD ID â†’ INDEX MAP
# -------------------------
id_to_idx = {tid: i for i, tid in enumerate(teacher_ids_norm)}

print("Mapping size:", len(id_to_idx))


# -------------------------
# 4. DEBUG: Check how many training IDs match teachers
# -------------------------
train_ids = df_all['image_id_norm'].tolist()

missing = [imgid for imgid in train_ids if imgid not in id_to_idx]

print("\n=== MATCHING REPORT ===")
print("Total train images:", len(train_ids))
print("Train images found in teacher NPZ:", len(train_ids) - len(missing))
print("Train images NOT found:", len(missing))

# Show first few missing for debugging
print("Missing examples:", missing[:10])

# We'll use this mapping inside Dataset later
TEACHER_ID_TO_IDX = id_to_idx
TEACHER_IDS_NORM = teacher_ids_norm



=== STARTING ID NORMALIZATION ===
Sample normalized train_df IDs: ['ISIC_0027828', 'ISIC_0029161', 'ISIC_0025819', 'ISIC_0027960', 'ISIC_0025140']
Teacher embeddings kept: 472
Sample normalized teacher IDs: ['ISIC_0027828', 'ISIC_0029161', 'ISIC_0025819', 'ISIC_0027960', 'ISIC_0025140', 'ISIC_0024635', 'ISIC_0025063', 'ISIC_0027957', 'ISIC_0025136', 'ISIC_0027652']
Mapping size: 472

=== MATCHING REPORT ===
Total train images: 472
Train images found in teacher NPZ: 472
Train images NOT found: 0
Missing examples: []


In [104]:
# Cell 5
LOCAL_CSV = "/content/local_labels_final.csv"  # local labels per image; columns: image_id, local_text, (optional) roi_fname
ROI_DIR = "/content/roi_images"          # if roi images exist; else we'll reuse global images as ROI placeholders

# load global CSV and exclude test72
df_all = pd.read_csv(GLOBAL_CSV)
test_ids = set(df_test72['image_id'].astype(str).tolist())
train_df = df_all[~df_all['image_id'].astype(str).isin(test_ids)].reset_index(drop=True)
print("Train rows:", len(train_df))

# load local CSV if exists
local_map = {}
if os.path.exists(LOCAL_CSV):
    df_local = pd.read_csv(LOCAL_CSV)
    # Expect each row to have image_id and local_text (and optional roi_fname)
    for _, r in df_local.iterrows():
        imgid = str(r['image_id'])
        text = str(r.get('local_text', ""))
        roi = r.get('roi_fname', None)
        local_map[imgid] = {"roi_imgs": [roi] if (roi and os.path.exists(os.path.join(ROI_DIR, roi))) else [], "roi_texts":[text]}
    print("Loaded local CSV entries:", len(local_map))
else:
    print("Local CSV not found at", LOCAL_CSV, "- local captions will be empty.")


Train rows: 400
Loaded local CSV entries: 472


In [105]:
# ===========================
# CELL 6 â€” DEFINITELY WORKING
# ===========================

from torch.utils.data import Dataset, DataLoader

EMB_DIM = student.get_sentence_embedding_dimension()

class TrainDataset(Dataset):
    def __init__(self, df, local_map, teacher_global_embeds, teacher_local_embeds, img_dir, roi_dir, max_rois=1):
        self.df = df.reset_index(drop=True)
        self.local_map = local_map
        self.teacher_global = teacher_global_embeds
        self.teacher_local = teacher_local_embeds
        self.img_dir = img_dir
        self.roi_dir = roi_dir
        self.max_rois = max_rois

        # Build mapping: teacher ID â†’ teacher index
        self.id_to_idx = {}
        for i, iid in enumerate(teacher_global_ids_norm):
            key = str(iid)
            self.id_to_idx[key] = i
            key = key.replace(".jpg", "")  # remove extension if present
            key = key.replace(".png", "")
            self.id_to_idx[key] = i


    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Always use normalized id
        img_id_norm = str(row['image_id_norm'])

        t_idx = TEACHER_ID_TO_IDX.get(img_id_norm, None)



        # Load global image
        gpath = os.path.join(self.img_dir, img_id)
        gpil = Image.open(gpath).convert("RGB")
        gtext = str(row['global_label'])

        # Local (ROI)
        lm = self.local_map.get(img_id, {"roi_imgs": [], "roi_texts": []})
        roi_imgs = lm["roi_imgs"][:self.max_rois]
        roi_texts = lm["roi_texts"][:self.max_rois]

        roi_pils = []
        for r in roi_imgs:
            if r and os.path.exists(os.path.join(self.roi_dir, r)):
                roi_pils.append(Image.open(os.path.join(self.roi_dir, r)).convert("RGB"))
            else:
                roi_pils.append(gpil)

        # Pad to max_rois
        while len(roi_pils) < self.max_rois:
            roi_pils.append(gpil)
            roi_texts.append("")

        # Teacher embeddings
        img_id_norm = str(row['image_id_norm'])
        t_idx = TEACHER_ID_TO_IDX.get(img_id_norm, None)

        # If teacher not found, always return safe zeros
        if t_idx is None:
            tg = np.zeros((EMB_DIM,), dtype=np.float32)
            tl = np.zeros((self.max_rois, EMB_DIM), dtype=np.float32)
        else:
            tg = self.teacher_global[t_idx].astype(np.float32)

            if self.teacher_local is not None and len(self.teacher_local) == len(self.teacher_global):
                raw_local = self.teacher_local[t_idx].astype(np.float32)
                if raw_local.ndim == 1:
                    raw_local = np.expand_dims(raw_local, 0)
                tl = np.zeros((self.max_rois, EMB_DIM), dtype=np.float32)
                for i in range(min(self.max_rois, raw_local.shape[0])):
                    tl[i] = raw_local[i]
            else:
                tl = np.zeros((self.max_rois, EMB_DIM), dtype=np.float32)


        return {
            "global_img": gpil,
            "global_text": gtext,
            "roi_imgs": roi_pils,
            "roi_texts": roi_texts,
            "teacher_global": tg,
            "teacher_local": tl,
        }

# ðŸ”¥ THIS IS THE MOST IMPORTANT LINE
# num_workers MUST BE 0
train_dataset = TrainDataset(
    train_df, local_map,
    teacher_global_embeds,
    teacher_local_embeds,
    IMG_DIR, ROI_DIR,
    max_rois=1
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,    # <-- THIS FIXES EVERYTHING
    pin_memory=True
)

print("Train loader ready:", len(train_loader), "batches")


Train loader ready: 50 batches


In [106]:
print("Loader object:", train_loader)
print("num_workers:", train_loader.num_workers)


Loader object: <torch.utils.data.dataloader.DataLoader object at 0x7ea11efdc830>
num_workers: 0


In [107]:
# Cell 7
import torch.nn.functional as F

def info_nce_loss(img_feats, txt_feats, temp=0.07):
    logits = (img_feats @ txt_feats.T) / temp
    labels = torch.arange(img_feats.size(0), device=img_feats.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.T, labels)
    return 0.5 * (loss_i + loss_t)

def roi_contrastive_loss(roi_img_feats, roi_txt_feats, temp=0.07):
    # roi_img_feats, roi_txt_feats: B x R x D
    B, R, D = roi_img_feats.shape
    loss = 0.0
    for r in range(R):
        loss += info_nce_loss(roi_img_feats[:, r, :], roi_txt_feats[:, r, :], temp=temp)
    return loss / max(1, R)

def negative_caption_loss(roi_img_feats, roi_txt_feats, temp=0.07):
    B, R, D = roi_img_feats.shape
    flat_txt = roi_txt_feats.view(B*R, D)
    loss = 0.0
    for r in range(R):
        logits = (roi_img_feats[:, r, :] @ flat_txt.T) / temp  # B x (B*R)
        labels = torch.arange(B, device=roi_img_feats.device) * R + r
        loss += F.cross_entropy(logits, labels)
    return loss / max(1, R)

def distill_mse_loss(student_feats, teacher_feats):
    return F.mse_loss(student_feats, teacher_feats)


In [108]:
# Cell 8
student.train()
optimizer = torch.optim.AdamW(student.parameters(), lr=2e-5)
alpha, beta, gamma, lam = 1.0, 1.0, 1.0, 0.4

print("Starting dry-run (10 batches max) to validate pipeline...")
max_batches = 10
for bidx, batch in enumerate(train_loader):
    if bidx >= max_batches:
        break

    # global encodings
    pil_imgs = list(batch['global_img'])
    g_img_emb = encode_images_student(student, pil_imgs)
    g_txt_emb = encode_texts_student(student, list(batch['global_labels']))

    # ROI encodings (max_rois=1 here so flattening simple)
    B = g_img_emb.size(0)
    R = len(batch['roi_imgs'][0])
    flat_rois = []
    for i in range(B):
        flat_rois.extend(batch['roi_imgs'][i])
    roi_img_flat = encode_images_student(student, flat_rois).view(B, R, -1)
    flat_roi_texts = []
    for i in range(B):
        flat_roi_texts.extend(batch['roi_texts'][i])
    roi_txt_flat = encode_texts_student(student, flat_roi_texts).view(B, R, -1)

    # teacher arrays
    tg = torch.from_numpy(np.vstack(batch['teacher_global'])).to(device).float()
    tl = torch.from_numpy(np.stack(batch['teacher_local'])).to(device).float()  # B x R x D

    # normalize
    g_img_emb = F.normalize(g_img_emb, dim=-1)
    g_txt_emb = F.normalize(g_txt_emb, dim=-1)
    roi_img_flat = F.normalize(roi_img_flat, dim=-1)
    roi_txt_flat = F.normalize(roi_txt_flat, dim=-1)

    # compute losses
    Lg = info_nce_loss(g_img_emb, g_txt_emb)
    Lr = roi_contrastive_loss(roi_img_flat, roi_txt_flat)
    Ln = negative_caption_loss(roi_img_flat, roi_txt_flat)
    LdG = distill_mse_loss(g_img_emb, tg.to(device))
    LdR = distill_mse_loss(roi_img_flat.view(B*R, -1), tl.view(B*R, -1).to(device))
    loss = alpha*Lg + beta*Lr + gamma*Ln + lam*(LdG + LdR)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Batch {bidx+1} loss: {loss.item():.4f}")

print("Dry-run done. If no errors, proceed to full training.")


Starting dry-run (10 batches max) to validate pipeline...


KeyError: 'image_id_norm'

In [89]:
del train_dataset
del train_loader
torch.cuda.empty_cache()


In [109]:
missing = []
for i in train_df['image_id'].astype(str):
    if i not in teacher_global_ids:
        missing.append(i)
len(missing), missing[:10]


(400,
 ['ISIC_0029161.jpg',
  'ISIC_0025819.jpg',
  'ISIC_0027960.jpg',
  'ISIC_0025140.jpg',
  'ISIC_0024635.jpg',
  'ISIC_0025063.jpg',
  'ISIC_0027957.jpg',
  'ISIC_0025136.jpg',
  'ISIC_0027218.jpg',
  'ISIC_0027139.jpg'])

In [110]:
teacher_global_ids[:20]


['ISIC_0027828',
 'ISIC_0029161',
 'ISIC_0025819',
 'ISIC_0027960',
 'ISIC_0025140',
 'ISIC_0024635',
 'ISIC_0025063',
 'ISIC_0027957',
 'ISIC_0025136',
 'ISIC_0027652',
 'ISIC_0027218',
 'ISIC_0026298',
 'ISIC_0027139',
 'ISIC_0027739',
 'ISIC_0027781',
 'ISIC_0024890',
 'ISIC_0025838',
 'ISIC_0025485',
 'ISIC_0025842',
 'ISIC_0028856']