# 🩺 Postdoctoral Technical Challenge — PneumoniaMNIST
**AlfaisalX: Cognitive Robotics and Autonomous Agents**  
**MedX Research Unit, Alfaisal University, Riyadh, Saudi Arabia**

---

This notebook is **100% self-contained** — all code is defined inline; no repository clone or local imports are needed.
Run cells top to bottom on Colab free tier (T4 GPU recommended).

| Task | Method | Key Tech |
|------|---------|----------|
| **Task 1** | CNN Classification + full evaluation | EfficientNet-B0, Focal Loss, AdamW |
| **Task 2** | Medical Report Generation | MedGemma-4B-IT VLM |
| **Task 3** | Semantic Image Retrieval | BioMedCLIP + FAISS |

## ⚙️ 0. Install & Setup

In [None]:
# Third-party packages used by the notebook: dataset (medmnist), CNN backbone
# (timm), CLIP-style embeddings (open-clip-torch), vector index (faiss-cpu),
# plotting (seaborn) and metrics (scikit-learn).
!pip install -q medmnist timm open-clip-torch faiss-cpu seaborn scikit-learn
# Hugging Face libraries used by the Task 2 report generator.
!pip install -q transformers accelerate
print('‚úì Dependencies installed')

In [None]:
import os, sys, json, time, warnings, copy
import numpy as np
import torch
import torch.nn as nn
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
warnings.filterwarnings('ignore')

# Seed both RNGs used by the notebook (NumPy sampling, PyTorch init/shuffle).
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when one is visible; everything else falls back to CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
print(f'Device: {device}')
if _has_cuda:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {_gpu_props.total_memory/1e9:.1f} GB')

# One output folder per task; exist_ok makes re-running the cell harmless.
for task_name in ('task1', 'task2', 'task3'):
    os.makedirs(f'outputs/{task_name}', exist_ok=True)

# Index in this list == integer class label used everywhere below.
CLASS_NAMES = ['Normal', 'Pneumonia']
print('‚úì Setup complete')

---
## 📋 Task 1: CNN Classification

### 1.1 Dataset

In [None]:
from medmnist import PneumoniaMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image as PILImage

# Input resolution fed to the network and the mini-batch size for all loaders.
IMAGE_SIZE, BATCH_SIZE = 224, 32

# Medical X-ray augmentation rules:
#  ✓ Horizontal flip  — mirror anatomy is valid
#  ✓ Small rotation   — patient positioning variation
#  ✗ Vertical flip    — NEVER: produces anatomically invalid images
train_tfm = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    # Maps [0, 1] pixel values to [-1, 1]; the single mean/std broadcasts
    # over the three (identical) RGB channels.
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# Deterministic pipeline for validation/test: resize + the same normalization.
eval_tfm = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# Display pipeline: NOT normalized, so tensors stay in [0, 1] for imshow/PIL.
vis_tfm = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

train_ds = PneumoniaMNIST(split='train', transform=train_tfm, download=True, as_rgb=True)
val_ds   = PneumoniaMNIST(split='val',   transform=eval_tfm,  download=True, as_rgb=True)
test_ds  = PneumoniaMNIST(split='test',  transform=eval_tfm,  download=True, as_rgb=True)
# Same test split as test_ds (same order), only the transform differs.
vis_ds   = PneumoniaMNIST(split='test',  transform=vis_tfm,   download=True, as_rgb=True)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Read the labels straight from the dataset's label array. Indexing the
# dataset (train_ds[i]) would decode, resize and randomly augment every
# training image just to obtain its label.
train_labels = train_ds.labels.reshape(-1).tolist()
n0, n1 = train_labels.count(0), train_labels.count(1)
print(f'Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}')
print(f'Class ‚Äî Normal: {n0} ({100*n0/len(train_labels):.1f}%) | Pneumonia: {n1} ({100*n1/len(train_labels):.1f}%)')

In [None]:
# Dataset visualization: one row of up to 8 test images per class.
# Labels come from the dataset's label array, so no image has to be decoded
# and transformed just to look up its class.
test_labels_list = test_ds.labels.reshape(-1).tolist()
fig, axes = plt.subplots(2, 8, figsize=(18, 5))
fig.suptitle('PneumoniaMNIST Samples', fontsize=14, fontweight='bold')
for cls_idx, cls_name in enumerate(CLASS_NAMES):
    # First 8 test indices whose label equals this class.
    samples = [i for i, l in enumerate(test_labels_list) if l == cls_idx][:8]
    for j, idx in enumerate(samples):
        # vis_ds shares the test split's ordering but is not normalized,
        # so the tensor can be shown directly.
        img, _ = vis_ds[idx]
        axes[cls_idx, j].imshow(img.permute(1,2,0).numpy()[:,:,0], cmap='gray')
        axes[cls_idx, j].set_title(cls_name, fontsize=8,
            color='steelblue' if cls_idx==0 else 'tomato')
        axes[cls_idx, j].axis('off')
plt.tight_layout()
plt.savefig('outputs/task1/dataset_samples.png', dpi=130, bbox_inches='tight')
plt.show()

### 1.2 Model: EfficientNet-B0 + Focal Loss
- **EfficientNet-B0**: 5.3M params, compound scaling (depth/width/resolution), ImageNet pretrained
- **Focal Loss** `γ=2.0`: Down-weights easy examples, which matters given the class imbalance (~74% of images are pneumonia)
- **AdamW + cosine warmup**: Weight decay prevents overfitting on small dataset

In [None]:
import timm

class FocalLoss(nn.Module):
    """Focal loss: FL = -α · (1 - p_t)^γ · log(p_t), averaged over the batch.

    The (1 - p_t)^γ factor shrinks the contribution of samples the model
    already classifies confidently, so training focuses on the hard ones.

    Args:
        gamma: focusing exponent γ; 0 reduces the loss to α-scaled cross-entropy.
        alpha: scalar weight applied to every sample.
    """

    def __init__(self, gamma=2.0, alpha=1.0):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        # Unreduced cross-entropy so each sample can be re-weighted below.
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        """Return the mean focal loss for class logits `inputs` and integer `targets`."""
        per_sample_ce = self.ce(inputs, targets)
        # CE = -log(p_t), hence p_t = exp(-CE).
        p_t = torch.exp(-per_sample_ce)
        modulator = (1 - p_t) ** self.gamma
        weighted = self.alpha * modulator * per_sample_ce
        return weighted.mean()

# ImageNet-pretrained EfficientNet-B0 with a fresh 2-way classification head.
model = timm.create_model(
    'efficientnet_b0',
    pretrained=True,
    num_classes=2,
    in_chans=3,
)
model = model.to(device)
total = sum(param.numel() for param in model.parameters())
print(f'EfficientNet-B0 | Parameters: {total:,}')

### 1.3 Training

In [None]:
# Schedule length: total epochs, of which the first WARMUP ramp the LR up.
NUM_EPOCHS, WARMUP = 20, 3
# Peak learning rate and AdamW weight decay (same value for both).
LR = WD = 1e-4

optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=WD)
criterion = FocalLoss(gamma=2.0)

def lr_lambda(ep):
    """Return the LR multiplier for 0-based epoch index `ep`.

    Linear warm-up over the first WARMUP epochs (reaching 1.0 at the last
    warm-up epoch), then a half-cosine decay from 1.0 towards 0 over the
    remaining NUM_EPOCHS - WARMUP epochs.
    """
    if ep >= WARMUP:
        # max(..., 1) guards against division by zero when WARMUP >= NUM_EPOCHS.
        decay_span = max(NUM_EPOCHS - WARMUP, 1)
        return 0.5 * (1 + np.cos(np.pi * (ep - WARMUP) / decay_span))
    return (ep + 1) / WARMUP
# LambdaLR scales the optimizer's base LR by lr_lambda(<number of step() calls so far>).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def run_epoch(model, loader, opt=None):
    """Run one pass over `loader` and return `(mean_loss, accuracy)`.

    Args:
        model: network called as `model(imgs)` and returning class logits.
        loader: iterable of `(imgs, lbls)` batches.
        opt: optimizer. When given, the pass trains the model (gradients,
            clipping to norm 1.0, optimizer step); when None the model is put
            in eval mode and no gradients are computed.

    Returns:
        Tuple of sample-weighted mean loss and fraction of correct predictions.

    Uses the module-level `criterion` and `device`.
    """
    training = opt is not None
    model.train(training)  # train(False) is equivalent to eval()
    tot_loss, correct, total = 0., 0, 0
    with torch.set_grad_enabled(training):
        for imgs, lbls in loader:
            imgs = imgs.to(device)
            # reshape(-1) instead of squeeze(): squeeze() would collapse a
            # final batch holding a single sample (shape (1, 1)) to a 0-d
            # tensor, which breaks both the loss and lbls.size(0) below.
            lbls = lbls.reshape(-1).long().to(device)
            if training:
                opt.zero_grad()
            out = model(imgs)
            loss = criterion(out, lbls)
            if training:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            # criterion returns a batch mean; re-weight by batch size so the
            # epoch average is correct when the last batch is smaller.
            tot_loss += loss.item() * imgs.size(0)
            correct += (out.argmax(1) == lbls).sum().item()
            total += lbls.size(0)
    return tot_loss / total, correct / total

# Per-epoch metrics, plus the best checkpoint selected by validation accuracy.
history = {'train_loss':[],'val_loss':[],'train_acc':[],'val_acc':[]}
best_val_acc, best_epoch, best_state = 0., 0, None

for ep in range(1, NUM_EPOCHS+1):
    t0 = time.time()
    tr_l, tr_a = run_epoch(model, train_loader, optimizer)
    va_l, va_a = run_epoch(model, val_loader)
    scheduler.step()
    for k, v in zip(['train_loss','val_loss','train_acc','val_acc'],[tr_l,va_l,tr_a,va_a]):
        history[k].append(v)
    # Always snapshot the first epoch so best_state is never None, then only
    # on a strict improvement. The flag also drives the star marker, so ties
    # with the current best are not flagged as new bests.
    improved = best_state is None or va_a > best_val_acc
    if improved:
        best_val_acc, best_epoch = va_a, ep
        # deepcopy: state_dict() returns references to the live parameters.
        best_state = copy.deepcopy(model.state_dict())
    print(f'Ep {ep:02d}/{NUM_EPOCHS} | Train {tr_l:.4f}/{tr_a:.4f} | '
          f'Val {va_l:.4f}/{va_a:.4f} | {time.time()-t0:.1f}s{" ‚òÖ" if improved else ""}')

# Restore the best checkpoint for evaluation and persist it with the history.
model.load_state_dict(best_state)
torch.save(best_state, 'outputs/task1/best_model.pth')
with open('outputs/task1/history.json','w') as f:
    json.dump({**history,'best_epoch':best_epoch,'best_val_acc':best_val_acc}, f)
print(f'\n‚úì Best val acc: {best_val_acc:.4f} at epoch {best_epoch}')

### 1.4 Evaluation

In [None]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, roc_curve, classification_report
)

# Training curves
# The x-axis is derived from the recorded history rather than NUM_EPOCHS, so
# the plot stays valid if NUM_EPOCHS was edited without re-running training.
eps = range(1, len(history['train_loss'])+1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(eps, history['train_loss'], label='Train', lw=2, color='#2196F3')
axes[0].plot(eps, history['val_loss'],   label='Val',   lw=2, color='#F44336')
axes[0].axvline(best_epoch, ls='--', color='gray', label=f'Best ep={best_epoch}')
axes[0].set(title='Loss', xlabel='Epoch', ylabel='Loss'); axes[0].legend(); axes[0].grid(alpha=0.3)
axes[1].plot(eps, [a*100 for a in history['train_acc']], label='Train', lw=2, color='#2196F3')
axes[1].plot(eps, [a*100 for a in history['val_acc']],   label='Val',   lw=2, color='#F44336')
axes[1].axvline(best_epoch, ls='--', color='gray', label=f'Best ep={best_epoch}')
axes[1].set(title='Accuracy', xlabel='Epoch', ylabel='Acc (%)'); axes[1].legend(); axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.savefig('outputs/task1/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Test predictions
model.eval()
y_true, y_pred, y_prob = [], [], []
with torch.no_grad():
    for imgs, lbls in test_loader:
        out = model(imgs.to(device))
        # Probability of class index 1 (Pneumonia), used for the ROC/AUC.
        probs = torch.softmax(out,1)[:,1].cpu().numpy()
        preds = out.argmax(1).cpu().numpy()
        # reshape(-1) keeps a 1-D array even for a batch of one sample;
        # squeeze() would give a 0-d array that list.extend cannot iterate.
        y_true.extend(lbls.reshape(-1).numpy())
        y_pred.extend(preds); y_prob.extend(probs)
y_true, y_pred, y_prob = map(np.array, [y_true, y_pred, y_prob])

# Cast to plain floats so the dict is JSON-serializable.
metrics = {
    'accuracy':  float(accuracy_score(y_true, y_pred)),
    'precision': float(precision_score(y_true, y_pred)),
    'recall':    float(recall_score(y_true, y_pred)),
    'f1':        float(f1_score(y_true, y_pred)),
    'auc':       float(roc_auc_score(y_true, y_prob)),
}
with open('outputs/task1/test_metrics.json','w') as f:
    json.dump(metrics, f, indent=2)

print('=== TEST SET METRICS ===')
for k, v in metrics.items(): print(f'  {k:10s}: {v:.4f}')
print(); print(classification_report(y_true, y_pred, target_names=CLASS_NAMES))

In [None]:
# Side-by-side figure: confusion matrix (left) and ROC curve (right).
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
ax_cm, ax_roc = axes

# Confusion matrix: rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    ax=ax_cm,
    annot_kws={'size':14},
)
ax_cm.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')

# ROC curve with the chance diagonal for reference.
fpr, tpr, _ = roc_curve(y_true, y_prob)
ax_roc.plot(fpr, tpr, lw=2, color='#2196F3', label=f'AUC={metrics["auc"]:.4f}')
ax_roc.plot([0,1],[0,1],'k--',lw=1.5)
ax_roc.fill_between(fpr,tpr,alpha=0.1,color='#2196F3')
ax_roc.set(xlabel='FPR',ylabel='TPR',title='ROC Curve')
ax_roc.legend()
ax_roc.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('outputs/task1/confusion_roc.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Failure cases — collect predictions for every image in vis_ds
vis_loader = DataLoader(vis_ds, batch_size=64, shuffle=False)
all_imgs_list, vis_labels, vis_preds, vis_probs = [], [], [], []
model.eval()
with torch.no_grad():
    for imgs, lbls in vis_loader:
        # vis_ds yields un-normalized [0, 1] tensors (kept for display), but
        # the classifier was trained and evaluated on Normalize(mean=0.5,
        # std=0.5) inputs. Apply the same mapping before the forward pass so
        # these predictions agree with the reported test metrics.
        out = model(((imgs - 0.5) / 0.5).to(device))
        prbs = torch.softmax(out,1)[:,1].cpu().numpy()
        prds = out.argmax(1).cpu().numpy()
        all_imgs_list.append(imgs)
        # reshape(-1) stays 1-D even for a batch containing a single sample.
        vis_labels.extend(lbls.reshape(-1).numpy())
        vis_preds.extend(prds); vis_probs.extend(prbs)
all_imgs_t   = torch.cat(all_imgs_list, 0)
vis_labels   = np.array(vis_labels)
vis_preds    = np.array(vis_preds)
vis_probs_np = np.array(vis_probs)

# Random sample of up to 16 misclassified test images.
fail_idx = np.where(vis_labels != vis_preds)[0]
np.random.shuffle(fail_idx); sel = fail_idx[:16]

fig, axes = plt.subplots(4, 4, figsize=(12,12))
fig.suptitle(f'Failure Cases ‚Äî {len(fail_idx)}/{len(vis_labels)} misclassified', fontweight='bold')
for i, ax in enumerate(axes.flat):
    if i >= len(sel): ax.axis('off'); continue
    idx = sel[i]
    img = all_imgs_t[idx].permute(1,2,0).numpy().clip(0,1)[:,:,0]
    true_l = CLASS_NAMES[vis_labels[idx]]; pred_l = CLASS_NAMES[vis_preds[idx]]
    # vis_probs_np holds P(class 1); report the probability of the predicted class.
    conf = vis_probs_np[idx] if vis_preds[idx]==1 else 1-vis_probs_np[idx]
    ax.imshow(img, cmap='gray')
    ax.set_title(f'True: {true_l}\nPred: {pred_l} ({conf:.2f})', fontsize=8, color='#C62828')
    ax.axis('off')
plt.tight_layout()
plt.savefig('outputs/task1/failure_cases.png', dpi=150, bbox_inches='tight')
plt.show()
print(f'‚úì {len(fail_idx)} failure cases ({100*len(fail_idx)/len(vis_labels):.1f}% error rate)')

---
## 🔬 Task 2: Medical Report Generation (MedGemma VLM)

In [None]:
# ── HuggingFace auth (required for MedGemma) ──────────────────────────────
# 1. Accept terms: https://huggingface.co/google/medgemma-4b-it
# 2. Paste your token below, or export HF_TOKEN in the environment beforehand.
# NOTE: do not commit or share the notebook with a real token pasted in.
HF_TOKEN = ''  # <-- paste your HF token here
if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN
# Report on the environment variable rather than the literal above: a token
# that is already exported must not trigger the "no token" warning.
if os.environ.get('HF_TOKEN'):
    print('‚úì HF_TOKEN set')
else:
    print('‚ö† No token ‚Äî will use structured mock reports')

In [None]:
# ── Three prompting strategies ─────────────────────────────────────────────
# Maps a strategy name to the instruction text sent to the VLM with the image.
PROMPTS = {
    'concise': (
        'You are a radiologist. In 2-3 sentences describe the key findings '
        'and state whether this chest X-ray is normal or shows pneumonia.'
    ),
    'structured': (
        'You are an expert radiologist. Provide a structured report:\n'
        '1. Lung Fields: Opacities, consolidation, infiltrates.\n'
        '2. Heart/Mediastinum: Size and contour.\n'
        '3. Impression: Normal or Pneumonia with brief justification.'
    ),
    'differential': (
        'As a radiologist, describe visible lung field features, '
        'give a differential diagnosis, then conclude with your '
        'primary diagnosis: normal chest or pneumonia.'
    ),
}

# Canned fallback reports keyed by (class name, strategy name). These are
# fixed texts selected by label, not model output, so they carry no
# information about the individual image.
MOCK = {
    ('Normal','concise'):
        'The chest radiograph shows clear bilateral lung fields with no consolidation or infiltrate. '
        'Cardiac silhouette is normal. Impression: Normal chest radiograph.',
    ('Normal','structured'):
        '1. Lung Fields: Clear bilaterally; no opacities, consolidation, or infiltrates. '
        'Costophrenic angles sharp.\n2. Heart/Mediastinum: Normal size; no widening.\n'
        '3. Impression: Normal chest radiograph ‚Äî no acute cardiopulmonary findings.',
    ('Normal','differential'):
        'Lung fields appear clear. Mild bronchovascular prominence noted but within limits. '
        'Differential: (1) Normal, (2) Mild bronchitis. '
        'Primary Diagnosis: Normal chest ‚Äî no evidence of pneumonia.',
    ('Pneumonia','concise'):
        'Increased right lower lobe opacity with air bronchograms consistent with consolidation. '
        'Impression: Right lower lobe pneumonia.',
    ('Pneumonia','structured'):
        '1. Lung Fields: Right lower lobe consolidation with air bronchograms; '
        'mild left perihilar haziness.\n2. Heart/Mediastinum: Normal.\n'
        '3. Impression: Bilateral pneumonia, right > left. Clinical correlation recommended.',
    ('Pneumonia','differential'):
        'Right lower lobe focal consolidation with air bronchograms; subtle left lower haziness. '
        'Differential: (1) CAP, (2) Aspiration pneumonitis. '
        'Primary: Bacterial pneumonia ‚Äî bilateral, right predominant.',
}

# ── Load MedGemma (best effort) ────────────────────────────────────────────
# Any failure (missing token, gated repo, download error, out of memory, ...)
# leaves USE_REAL_VLM False so report generation falls back to mock reports.
vlm_model = None
vlm_processor = None
USE_REAL_VLM = False
try:
    from transformers import AutoProcessor, AutoModelForImageTextToText

    _on_gpu = torch.cuda.is_available()
    # An empty string is turned into None so the library default is used.
    hf_token = os.environ.get('HF_TOKEN') or None
    vlm_processor = AutoProcessor.from_pretrained('google/medgemma-4b-it', token=hf_token)
    vlm_model = AutoModelForImageTextToText.from_pretrained(
        'google/medgemma-4b-it',
        token=hf_token,
        torch_dtype=torch.bfloat16 if _on_gpu else torch.float32,
        device_map='auto' if _on_gpu else None,
    )
    if not _on_gpu:
        # No device_map on CPU, so place the model explicitly.
        vlm_model = vlm_model.to(device)
    vlm_model.eval()
    USE_REAL_VLM = True
    print('‚úì MedGemma-4B loaded!')
except Exception as e:
    print(f'MedGemma unavailable ({type(e).__name__}) ‚Üí using mock reports')
    print('  To enable: accept terms at https://huggingface.co/google/medgemma-4b-it')
    print('  and set HF_TOKEN above.')

def tensor_to_pil(t):
    """Convert a channel-first image tensor to an RGB PIL image.

    Only the first channel is used: values are clipped to [0, 1], scaled to
    8-bit, and the resulting grayscale image is expanded to RGB.
    """
    first_channel = t.permute(1, 2, 0).numpy().clip(0, 1)[:, :, 0]
    as_uint8 = (first_channel * 255).astype(np.uint8)
    gray_image = PILImage.fromarray(as_uint8)
    return gray_image.convert('RGB')

def generate_report(pil_img, prompt, true_label, strategy):
    """Return a radiology-style report for one image.

    Args:
        pil_img: image passed to the VLM chat template.
        prompt: instruction text sent alongside the image.
        true_label: class name; used ONLY to pick a canned report in mock mode.
        strategy: prompt-strategy name; used ONLY to pick a canned report in
            mock mode.

    Returns:
        The generated text when USE_REAL_VLM is True, otherwise the matching
        MOCK entry (or a generic placeholder if no entry exists).
    """
    if not USE_REAL_VLM:
        return MOCK.get((true_label, strategy), f'[Mock] {true_label} ‚Äì {strategy} strategy.')

    msgs = [{'role':'user','content':[{'type':'image','image':pil_img},{'type':'text','text':prompt}]}]
    # Move to the model's device AND cast floating-point inputs (the pixel
    # values) to the model's dtype; integer token ids are left untouched.
    # Without the cast, float32 pixels can be fed to a bfloat16 model.
    inp = vlm_processor.apply_chat_template(
        msgs, add_generation_prompt=True,
        tokenize=True, return_dict=True, return_tensors='pt',
    ).to(vlm_model.device, dtype=vlm_model.dtype)
    with torch.no_grad():
        # Greedy decoding keeps the output deterministic for a given input.
        out = vlm_model.generate(**inp, max_new_tokens=300, do_sample=False)
    # generate() returns prompt + continuation; drop the prompt tokens.
    n = inp['input_ids'].shape[1]
    return vlm_processor.decode(out[0][n:], skip_special_tokens=True).strip()

print('‚úì Report generation ready')

In [None]:
# Select 5 Normal + 5 Pneumonia images for report generation
# Labels are read from the dataset's label array instead of indexing (and
# thereby decoding/resizing) every test image.
tl_arr = test_ds.labels.reshape(-1)
np.random.seed(42)
sel_idx = list(np.random.choice(np.where(tl_arr==0)[0], 5, replace=False)) + \
          list(np.random.choice(np.where(tl_arr==1)[0], 5, replace=False))

prompt_keys = list(PROMPTS.keys())
report_results = []

for i, idx in enumerate(sel_idx):
    img_t, lbl_t = vis_ds[idx]
    true_lbl = CLASS_NAMES[lbl_t.item()]
    # Cycle through the prompt strategies so each one is exercised.
    strat    = prompt_keys[i % len(prompt_keys)]
    t0 = time.time()
    report = generate_report(tensor_to_pil(img_t), PROMPTS[strat], true_lbl, strat)
    report_results.append({
        # int(): NumPy integers are not JSON-serializable.
        'index': int(idx), 'true_label': true_lbl,
        'cnn_pred': CLASS_NAMES[vis_preds[idx]],
        'prompt_strategy': strat, 'report': report,
        'generation_time_s': round(time.time()-t0, 2)
    })
    # Total comes from the selection itself rather than a hard-coded 10.
    print(f'[{i+1}/{len(sel_idx)}] idx={idx} True={true_lbl} CNN={CLASS_NAMES[vis_preds[idx]]} Strategy={strat}')

with open('outputs/task2/generated_reports.json','w') as f:
    json.dump(report_results, f, indent=2)
print('\n‚úì Saved generated_reports.json')

In [None]:
# Visualize reports: one row per report (image on the left, text on the right).
n = len(report_results)
# squeeze=False keeps `axes` two-dimensional even when there is a single
# report; otherwise axes[i] would be a lone Axes and the unpacking below fails.
fig, axes = plt.subplots(n, 2, figsize=(16, 4.5*n), squeeze=False)
fig.suptitle('Generated Radiology Reports ‚Äî MedGemma', fontsize=14, fontweight='bold')
for i, r in enumerate(report_results):
    ax_i, ax_t = axes[i]
    img_t, _ = vis_ds[r['index']]
    ax_i.imshow(img_t.permute(1,2,0).numpy().clip(0,1)[:,:,0], cmap='gray')
    # Green title when the CNN prediction agrees with the ground truth, red otherwise.
    match = r['true_label'] == r['cnn_pred']
    ax_i.set_title(f"True: {r['true_label']}  CNN: {r['cnn_pred']} {'‚úì' if match else '‚úó'}",
                   color='#2E7D32' if match else '#C62828', fontweight='bold')
    ax_i.axis('off')
    ax_t.axis('off')
    ax_t.text(0.02, 0.97, f"[{r['prompt_strategy']}]\n\n{r['report']}",
              transform=ax_t.transAxes, fontsize=8.5, va='top',
              bbox=dict(boxstyle='round,pad=0.5', facecolor='#FFFDE7', alpha=0.85))
plt.tight_layout()
plt.savefig('outputs/task2/reports_visualization.png', dpi=110, bbox_inches='tight')
plt.show()
print('‚úì Saved reports_visualization.png')

# Also dump the reports as plain text, stating whether they are real or mock.
print(f'\nReal VLM: {"Yes (MedGemma-4B)" if USE_REAL_VLM else "Mock reports"}')
for r in report_results:
    print(f'\n‚îÄ‚îÄ {r["true_label"]} | CNN:{r["cnn_pred"]} | {r["prompt_strategy"]} ‚îÄ‚îÄ')
    print(r['report'])

---
## 🔍 Task 3: Semantic Image Retrieval (BioMedCLIP + FAISS)

In [None]:
import open_clip, faiss

# Load best available medical embedding model
emb_model = emb_preprocess = emb_tokenizer = None
EMB_TYPE = None

# Candidates in order of preference: BiomedCLIP first, generic CLIP second.
for model_id, pretrained, typ in [
    ('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', None, 'biomedclip'),
    ('ViT-B-32', 'openai', 'clip'),
]:
    try:
        # Load into temporaries and commit only once BOTH the model and its
        # tokenizer are available. Assigning the globals directly could leave
        # emb_model set with EMB_TYPE still None when the tokenizer fails.
        cand_model, _, cand_preprocess = open_clip.create_model_and_transforms(
            model_id, pretrained=pretrained)
        cand_tokenizer = open_clip.get_tokenizer(model_id)
    except Exception as e:
        print(f'{typ} failed: {e}')
        continue
    emb_model, emb_preprocess, emb_tokenizer = cand_model, cand_preprocess, cand_tokenizer
    EMB_TYPE = typ
    print(f'‚úì Loaded {typ}')
    break

if emb_model is None:
    print('Using CNN (EfficientNet) features as fallback ‚Äî image-only search')
    EMB_TYPE = 'cnn'
else:
    emb_model = emb_model.to(device).eval()

In [None]:
# Extract embeddings
@torch.no_grad()
def extract_embeddings_clip(ds, model, preprocess, bs=64):
    """Embed every image of `ds` with a CLIP-style image encoder.

    Args:
        ds: indexable dataset returning `(image_tensor, label)`; each tensor is
            clipped to [0, 1] before conversion, so it should not be normalized.
        model: object exposing `encode_image(batch)`.
        preprocess: callable turning an RGB PIL image into a model input tensor.
        bs: number of images encoded per forward pass.

    Returns:
        `(embeddings, labels)`: a float array with one L2-normalized row per
        image, and the matching array of labels.
    """
    feature_chunks, labels = [], []
    n_items = len(ds)
    for start in tqdm(range(0, n_items, bs), desc='Embeddings'):
        stop = min(start + bs, n_items)
        pixel_batch = []
        for pos in range(start, stop):
            tensor_img, lbl = ds[pos]
            # First channel -> 8-bit grayscale -> RGB PIL, the input type
            # `preprocess` is applied to.
            gray = tensor_img.permute(1, 2, 0).numpy().clip(0, 1)[:, :, 0]
            pil = PILImage.fromarray((gray * 255).astype(np.uint8)).convert('RGB')
            pixel_batch.append(preprocess(pil))
            labels.append(lbl.item())
        feats = model.encode_image(torch.stack(pixel_batch).to(device)).float()
        # Unit-length rows, so an inner product equals cosine similarity.
        feats = feats / feats.norm(dim=-1, keepdim=True)
        feature_chunks.append(feats.cpu().numpy())
    return np.vstack(feature_chunks), np.array(labels)

@torch.no_grad()
def extract_embeddings_cnn(ds, cnn, bs=64):
    """Embed every image of `ds` with the Task 1 EfficientNet-B0 backbone.

    The backbone is rebuilt without a classification head and initialized
    from 'outputs/task1/best_model.pth' (classifier weights are dropped).

    Args:
        ds: dataset yielding `(image_tensor, label)`; tensors are fed to the
            network unchanged, so they must already be preprocessed the way
            the checkpoint expects.
        cnn: unused; kept for call compatibility (weights come from disk).
        bs: batch size for the forward passes.

    Returns:
        `(embeddings, labels)`: one L2-normalized feature row per image, and
        the matching array of labels.
    """
    import timm
    # num_classes=0 makes timm return pooled features instead of logits.
    fe = timm.create_model('efficientnet_b0', pretrained=False, num_classes=0)
    sd = {k:v for k,v in torch.load('outputs/task1/best_model.pth', map_location=device).items()
          if 'classifier' not in k}
    # strict=False: the head-less model has no classifier keys to match.
    fe.load_state_dict(sd, strict=False); fe = fe.to(device).eval()
    embeddings, labels = [], []
    loader = DataLoader(ds, batch_size=bs, shuffle=False)
    for imgs, lbls in tqdm(loader, desc='CNN embeddings'):
        e = fe(imgs.to(device)).float()
        # The epsilon avoids division by zero for an all-zero feature vector.
        e = e / (e.norm(dim=-1, keepdim=True) + 1e-8)
        # reshape(-1) stays 1-D for a final batch of one sample; squeeze()
        # would yield a 0-d array that list.extend cannot iterate.
        embeddings.append(e.cpu().numpy()); labels.extend(lbls.reshape(-1).numpy())
    return np.vstack(embeddings), np.array(labels)

if EMB_TYPE in ('biomedclip', 'clip'):
    # The CLIP path converts tensors back to PIL images, so it needs the
    # un-normalized [0, 1] view of the test split.
    embeddings, emb_labels = extract_embeddings_clip(vis_ds, emb_model, emb_preprocess)
else:
    # The EfficientNet checkpoint was trained on Normalize(0.5, 0.5) inputs,
    # so feed the normalized view of the same test split. test_ds and vis_ds
    # share samples and ordering, so indices stay aligned with vis_ds.
    embeddings, emb_labels = extract_embeddings_cnn(test_ds, model)

np.save('outputs/task3/embeddings.npy', embeddings)
np.save('outputs/task3/labels.npy', emb_labels)
with open('outputs/task3/embedding_info.json','w') as f:
    json.dump({'model_type': EMB_TYPE, 'embedding_dim': int(embeddings.shape[1]),
               'num_samples': len(emb_labels)}, f, indent=2)
print(f'‚úì Embeddings: {embeddings.shape} | Model: {EMB_TYPE}')

In [None]:
# Build the FAISS index. IndexFlatIP does exact inner-product search, which
# equals cosine similarity because the embeddings are L2-normalized.
index_vectors = embeddings.astype(np.float32)
d = index_vectors.shape[1]
faiss_index = faiss.IndexFlatIP(d)
faiss_index.add(index_vectors)
faiss.write_index(faiss_index, 'outputs/task3/faiss_index.bin')
print(f'‚úì FAISS index | {faiss_index.ntotal} vectors, dim={d}')

In [None]:
# Image-to-image search demo
def img2img(query_idx, k=5):
    """Return up to `k` `(index, score)` neighbours of the image at `query_idx`.

    The query itself is removed from the results, which is why k + 1
    candidates are requested from the index.
    """
    query_vec = embeddings[query_idx].reshape(1, -1).astype(np.float32)
    scores, idxs = faiss_index.search(query_vec, k + 1)
    hits = []
    for hit_idx, hit_score in zip(idxs[0], scores[0]):
        # -1 marks an empty slot (fewer results than requested).
        if hit_idx == -1 or hit_idx == query_idx:
            continue
        hits.append((int(hit_idx), float(hit_score)))
    return hits[:k]

def show_retrieval(q_idx, retrieved, title='', save=None):
    """Plot a query image followed by its retrieved neighbours in one row.

    Args:
        q_idx: index of the query image in vis_ds / emb_labels.
        retrieved: list of `(index, score)` pairs, best first.
        title: figure title.
        save: optional path; when given the figure is also written there.
    """
    k = len(retrieved)
    # squeeze=False keeps `axes` indexable even when nothing was retrieved
    # (k == 0), where subplots would otherwise return a single Axes.
    fig, axes = plt.subplots(1, k+1, figsize=(3*(k+1), 3.8), squeeze=False)
    axes = axes[0]
    fig.suptitle(title, fontweight='bold')

    def draw(ax, idx, lbl_str, color, bcolor=None):
        """Show image `idx` with a colored title and an optional colored border."""
        img = vis_ds[idx][0].permute(1,2,0).numpy().clip(0,1)[:,:,0]
        ax.imshow(img, cmap='gray'); ax.set_title(lbl_str, color=color, fontsize=8.5)
        if bcolor is None:
            ax.axis('off')
            return
        # axis('off') stops the spines from being drawn at all, so a border
        # could never appear. Hide only the ticks and style the spines.
        ax.set_xticks([]); ax.set_yticks([])
        for sp in ax.spines.values():
            sp.set_visible(True); sp.set_color(bcolor); sp.set_linewidth(3)

    draw(axes[0], q_idx, f'QUERY\n{CLASS_NAMES[emb_labels[q_idx]]}', 'blue', 'blue')
    for j, (ri, sc) in enumerate(retrieved):
        # A hit counts as correct when it shares the query's class label.
        match = emb_labels[ri] == emb_labels[q_idx]
        draw(axes[j+1], ri, f'#{j+1} {CLASS_NAMES[emb_labels[ri]]}\n{sc:.3f} {"‚úì" if match else "‚úó"}',
             '#2E7D32' if match else '#C62828')
    plt.tight_layout()
    if save: plt.savefig(save, dpi=130, bbox_inches='tight')
    plt.show()

# One demo query per class: the sixth test image carrying that label.
for cls_id, name in enumerate(['normal', 'pneumonia']):
    q_idx = int(np.where(emb_labels == cls_id)[0][5])
    query_class = CLASS_NAMES[emb_labels[q_idx]]
    retrieved = img2img(q_idx, k=5)
    show_retrieval(
        q_idx,
        retrieved,
        title=f'Image-to-Image: {query_class} Query',
        save=f'outputs/task3/retrieval_{name}.png',
    )
    # Fraction of retrieved images that share the query's label.
    same_class = [emb_labels[hit_idx] == emb_labels[q_idx] for hit_idx, _ in retrieved]
    precision = np.mean(same_class)
    print(f'Query={query_class} | P@5={precision:.2f}')

In [None]:
# Text-to-image search (needs a text encoder, i.e. a loaded tokenizer)
if emb_tokenizer is None:
    print('Text-to-image search unavailable (CNN fallback mode ‚Äî no text encoder)')
else:
    def txt2img(query_text, k=5):
        """Return up to `k` `(index, score)` pairs for the images closest to `query_text`."""
        with torch.no_grad():
            token_ids = emb_tokenizer([query_text]).to(device)
            text_vec = emb_model.encode_text(token_ids).float()
            # Same unit-length normalization as the indexed image embeddings.
            text_vec = text_vec / text_vec.norm(dim=-1, keepdim=True)
        query = text_vec.cpu().numpy().astype(np.float32)
        sims, hit_ids = faiss_index.search(query, k)
        results = []
        for hit_idx, sim in zip(hit_ids[0], sims[0]):
            if hit_idx != -1:  # -1 marks an empty result slot
                results.append((int(hit_idx), float(sim)))
        return results

    for q in ('bilateral lung consolidation pneumonia',
              'clear normal lung fields no abnormality'):
        res = txt2img(q, k=5)
        lbls = [CLASS_NAMES[emb_labels[hit_idx]] for hit_idx, _ in res]
        print(f'Text: "{q}"\n  ‚Üí Retrieved labels: {lbls}\n')

In [None]:
# Precision@k evaluation
K_VALS, N_QUERIES = [1, 5, 10], 200
np.random.seed(42)
# Sampling without replacement cannot draw more queries than there are
# indexed items, so cap the count (unchanged when the index is large enough).
n_queries = min(N_QUERIES, len(embeddings))
q_idxs = np.random.choice(len(embeddings), n_queries, replace=False)
pk_results = {k: [] for k in K_VALS}
max_k = max(K_VALS)
for qi in tqdm(q_idxs, desc='P@k eval'):
    q = embeddings[qi].reshape(1,-1).astype(np.float32)
    # One search at the largest k (+1 for the query itself) serves every k.
    _, ids = faiss_index.search(q, max_k+1)
    retrieved = [int(i) for i in ids[0] if i!=-1 and i!=qi]
    for k in K_VALS:
        top = retrieved[:k]
        if top: pk_results[k].append(np.mean([emb_labels[i]==emb_labels[qi] for i in top]))

precision_at_k = {f'P@{k}': float(np.mean(pk_results[k])) for k in K_VALS}
# Reference line: the share of the larger class among the indexed images.
baseline = float(max(np.mean(emb_labels==0), np.mean(emb_labels==1)))

with open('outputs/task3/precision_at_k.json','w') as f:
    json.dump(precision_at_k, f, indent=2)

fig, ax = plt.subplots(figsize=(7,4.5))
bars = ax.bar(list(precision_at_k.keys()), list(precision_at_k.values()),
              color='#1565C0', width=0.5, edgecolor='white', zorder=3)
ax.axhline(baseline, color='#EF5350', ls='--', lw=2, label=f'Baseline={baseline:.3f}')
ax.set(xlabel='k', ylabel='Precision@k', title=f'Retrieval Precision@k ‚Äî {EMB_TYPE.upper()}')
ax.set_ylim(0,1.05); ax.grid(axis='y',alpha=0.3,zorder=0); ax.legend()
for bar, v in zip(bars, precision_at_k.values()):
    ax.text(bar.get_x()+bar.get_width()/2, v+0.02, f'{v:.3f}', ha='center', fontweight='bold')
plt.tight_layout()
plt.savefig('outputs/task3/precision_at_k.png', dpi=150, bbox_inches='tight')
plt.show()

print('=== PRECISION@K ===')
for k, v in precision_at_k.items():
    # The '+' format flag prints the correct sign; a hard-coded '+' prefix
    # rendered negative deltas as '+-0.0123'.
    print(f'  {k}: {v:.4f}  ({v-baseline:+.4f} vs baseline)')

---\n## üìä Final Summary

In [None]:
# Final summary: re-read the artifacts written by each task and print them.
print('='*60)
print('  COMPLETE SYSTEM SUMMARY')
print('='*60)

print('\n‚îÄ‚îÄ Task 1: EfficientNet-B0 CNN ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ')
with open('outputs/task1/test_metrics.json') as f:
    m = json.load(f)
for k, v in m.items():
    print(f'  {k:10s}: {v:.4f}')

print('\n‚îÄ‚îÄ Task 2: MedGemma Report Generation ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ')
with open('outputs/task2/generated_reports.json') as f:
    rpts = json.load(f)
print(f'  Reports: {len(rpts)} | Real VLM: {USE_REAL_VLM}')
# sorted(): a set has no stable iteration order, so the joined list would
# otherwise change from run to run.
print(f'  Strategies: {", ".join(sorted(set(r["prompt_strategy"] for r in rpts)))}')

print('\n‚îÄ‚îÄ Task 3: BioMedCLIP + FAISS Retrieval ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ')
with open('outputs/task3/precision_at_k.json') as f:
    pk = json.load(f)
print(f'  Model: {EMB_TYPE}')
for k, v in pk.items():
    print(f'  {k}: {v:.4f}')

print('\n‚úì All three tasks complete!')