In [None]:
# One-cell script: Proto-Attention U-Net few-shot with HybridLoss (CE + Dice + Align),
# episodic training, tqdm.notebook, cosine LR, per-epoch loss prints, visualizations,
# end-of-training plots and per-class metrics. Images: .jpg, Masks: .bmp

import os, math, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import kornia  # for GPU-accelerated CLAHE

# --------------------- CLAHE Transform ---------------------
class Clahe(nn.Module):
    """Transform that applies Kornia's CLAHE to a single image tensor.

    Wraps ``kornia.enhance.equalize_clahe`` so it can be placed in a
    ``transforms.Compose`` pipeline after ``ToTensor``.

    Args:
        clip_limit: contrast clip limit forwarded to Kornia (stored as float).
        grid_size: tile grid forwarded to Kornia.
    """

    def __init__(self, clip_limit: float = 40.0, grid_size=(8, 8)):
        super().__init__()
        self.grid_size = grid_size
        self.clip_limit = float(clip_limit)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # Per the original note, img is (C, H, W) with values in [0, 1].
        # Kornia is called on a batch, so add a leading axis and drop it after.
        equalized = kornia.enhance.equalize_clahe(
            img.unsqueeze(0),
            clip_limit=self.clip_limit,
            grid_size=self.grid_size,
        )
        return torch.squeeze(equalized, 0)

    def __repr__(self):
        name = type(self).__name__
        return f"{name}(clip_limit={self.clip_limit}, grid_size={self.grid_size})"

# --------------------- Dataset ---------------------
class GlaucomaDataset(Dataset):
    """Image / segmentation-mask pairs read from two directories.

    Every ``.jpg`` file in ``img_dir`` is one sample; its mask is the file
    with the same stem and a ``.bmp`` extension in ``mask_dir``.
    ``__getitem__`` returns the *untransformed* PIL image together with a
    label tensor; callers apply ``sup_tf`` / ``qry_tf`` themselves.
    """

    def __init__(self, img_dir, mask_dir, img_size=128,
                 clahe_clip=40.0, clahe_grid=(8,8)):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_size = img_size
        jpg_names = [name for name in os.listdir(img_dir)
                     if name.lower().endswith('.jpg')]
        self.samples = sorted(jpg_names)

        # Resize -> tensor -> CLAHE -> mean/std normalisation.
        pipeline = [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            Clahe(clip_limit=clahe_clip, grid_size=clahe_grid),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ]
        # Support and query use the same (deterministic) pipeline.
        self.sup_tf = transforms.Compose(pipeline)
        self.qry_tf = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        name = self.samples[idx]
        stem = name.rsplit('.', 1)[0]
        image = Image.open(os.path.join(self.img_dir, name)).convert('RGB')

        mask_path = os.path.join(self.mask_dir, stem + '.bmp')
        side = self.img_size
        grey = Image.open(mask_path).convert('L').resize((side, side), Image.NEAREST)
        raw = np.array(grey)

        if raw.max() > 2:
            # Grey-level encoding: 128 -> class 1, 255 -> class 2, rest -> 0.
            labels = np.zeros_like(raw, dtype=int)
            labels[raw == 128] = 1
            labels[raw == 255] = 2
        else:
            # Values are already small integers; use them as class ids.
            labels = raw
        return image, torch.from_numpy(labels).long()

# -------------------- Model --------------------
class ConvBlock(nn.Module):
    """Two consecutive 3x3 conv -> BatchNorm -> ReLU stages.

    Padding of 1 keeps the spatial size; the first stage maps ``in_ch`` to
    ``out_ch`` channels and the second keeps ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are kept so state_dict keys are unchanged.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)

class AttentionGate(nn.Module):
    """Additive attention gate that rescales skip features by a 1-channel map.

    The gating signal ``g`` (``F_g`` channels) and the skip features ``x``
    (``F_l`` channels) are each projected to ``F_int`` channels by a 1x1
    conv + BatchNorm, summed, passed through ReLU, and reduced to a single
    sigmoid channel that multiplies ``x``.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(channels):
            return nn.Sequential(nn.Conv2d(channels, F_int, 1), nn.BatchNorm2d(F_int))

        self.W_g = project(F_g)
        self.W_x = project(F_l)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

    def forward(self, x, g):
        gate_feat = self.W_g(g)
        skip_feat = self.W_x(x)
        attention = self.psi(F.relu(gate_feat + skip_feat, inplace=True))
        return x * attention

class UpBlock(nn.Module):
    """Decoder stage: upsample, gate the skip connection, concatenate, convolve.

    ``x`` is upsampled 2x by a stride-2 transposed convolution to ``out_ch``
    channels; the skip tensor is gated by ``AttentionGate`` using the
    upsampled features; both are concatenated on the channel axis and fused
    by a ``ConvBlock``.
    """

    def __init__(self, up_in, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(up_in, out_ch, kernel_size=2, stride=2)
        self.att = AttentionGate(out_ch, skip_ch, skip_ch // 2)
        self.conv = ConvBlock(out_ch + skip_ch, out_ch)

    def forward(self, x, skip):
        upsampled = self.up(x)
        gated_skip = self.att(skip, upsampled)
        merged = torch.cat([upsampled, gated_skip], dim=1)
        return self.conv(merged)

class ProtoAttentionUNet(nn.Module):
    """Prototype-conditioned attention U-Net for few-shot segmentation.

    Support images and masks yield one L2-normalised prototype per class
    (masked average of projected encoder features). Query features are
    compared to the prototypes by cosine similarity; the similarity map is
    lifted to ``feat_dim`` channels and decoded with attention-gated skip
    connections from the query encoder.

    ``forward`` returns ``(logits, prototypes)`` where ``prototypes`` has one
    row per class.
    """

    def __init__(self, in_ch=3, num_classes=3, feat_dim=256):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.proj = nn.Conv2d(256, feat_dim, 1)
        self.sim_conv = nn.Conv2d(num_classes, feat_dim, 1)
        self.up3 = UpBlock(feat_dim, 128, 128)
        self.up2 = UpBlock(128, 64, 64)
        self.final = nn.Conv2d(64, num_classes, 1)

    def _encode(self, imgs):
        """Run the shared encoder; return both skip tensors and the projected features."""
        level1 = self.enc1(imgs)
        level2 = self.enc2(self.pool(level1))
        level3 = self.enc3(self.pool(level2))
        return level1, level2, self.proj(level3)

    def forward(self, s_imgs, s_masks, q_imgs):
        n_cls = self.final.out_channels

        # --- Support branch: per-class prototypes ---
        _, _, s_proj = self._encode(s_imgs)
        # Bring the label masks down to the feature resolution (nearest keeps ids intact).
        m_ds = F.interpolate(s_masks.unsqueeze(1).float(),
                             size=s_proj.shape[2:], mode='nearest').squeeze(1).long()
        class_means = []
        for cls in range(n_cls):
            region = (m_ds == cls).unsqueeze(1).float()
            # Masked average over batch and space; the epsilon avoids 0/0 for absent classes.
            class_means.append((s_proj * region).sum((0, 2, 3)) / (region.sum() + 1e-6))
        prototypes = F.normalize(torch.stack(class_means), dim=1)

        # --- Query branch: cosine similarity to each prototype ---
        q1, q2, q_proj = self._encode(q_imgs)
        B, D, h, w = q_proj.shape
        q_unit = F.normalize(q_proj.view(B, D, -1).permute(0, 2, 1), dim=2)
        sim = torch.matmul(q_unit, prototypes.t())
        sim_map = sim.permute(0, 2, 1).view(B, n_cls, h, w)
        sim_feat = self.sim_conv(sim_map)

        # --- Decoder with attention-gated skips ---
        x3 = self.up3(sim_feat, q2)
        x2 = self.up2(x3, q1)
        return self.final(x2), prototypes

# --------------------- Losses ---------------------
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from logits.

    ``preds`` are logits indexed as (batch, class, H, W) and ``targets`` are
    integer class maps. Dice is computed per sample and per class on the
    softmax probabilities, averaged, and returned as ``1 - mean_dice``.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, preds, targets):
        num_classes = preds.shape[1]
        probs = F.softmax(preds, 1)
        onehot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()

        spatial = (2, 3)
        overlap = torch.sum(probs * onehot, dim=spatial)
        total = torch.sum(probs, dim=spatial) + torch.sum(onehot, dim=spatial)
        # eps in both numerator and denominator makes an empty class score ~1.
        per_class = (2 * overlap + self.eps) / (total + self.eps)
        return 1 - per_class.mean()

class PrototypeAlignmentLoss(nn.Module):
    """Penalise prototypes whose pairwise cosine similarity departs from identity.

    Builds the full cosine-similarity matrix between the rows of
    ``prototypes`` and returns its mean squared error against the identity
    matrix, which pushes different class prototypes towards orthogonality.
    """

    def forward(self, prototypes):
        rows = prototypes.unsqueeze(1)
        cols = prototypes.unsqueeze(0)
        pairwise = F.cosine_similarity(rows, cols, dim=-1)
        identity = torch.eye(pairwise.size(0), device=pairwise.device)
        return F.mse_loss(pairwise, identity)

class HybridLoss(nn.Module):
    """Weighted cross-entropy + Dice + prototype-alignment loss.

    Args:
        class_weights: per-class weights for the cross-entropy term. Any
            sequence of numbers is accepted; it is converted to a floating
            point tensor so integer weights work as well.
        align_w: multiplier applied to the prototype-alignment term.

    ``forward`` returns ``(total_loss, (ce, dice, align))`` where the tuple
    holds the three unweighted components as Python floats for logging.
    """

    def __init__(self, class_weights=(0.1, 0.43, 0.47), align_w=0.3):
        super().__init__()
        # Immutable default (a list default would be shared between instances)
        # and an explicit float dtype: CrossEntropyLoss needs the weight dtype
        # to match the logits, which an int or float64 input would not.
        w = torch.tensor(class_weights, dtype=torch.get_default_dtype())
        self.register_buffer('ce_weight', w)
        self.ce    = nn.CrossEntropyLoss(weight=self.ce_weight)
        self.dice  = DiceLoss()
        self.align = PrototypeAlignmentLoss()
        self.align_w = align_w

    def forward(self, preds, targets, prototypes):
        l_ce = self.ce(preds, targets)
        l_d  = self.dice(preds, targets)
        l_a  = self.align(prototypes)
        loss = l_ce + l_d + self.align_w * l_a
        return loss, (l_ce.item(), l_d.item(), l_a.item())

# ---------------- Episodic Trainer ----------------
class EpisodicTrainer:
    """Episodic few-shot training loop for ``ProtoAttentionUNet``.

    Each episode draws ``num_support + num_query`` distinct samples from a
    ``GlaucomaDataset``, builds prototypes from the support set and optimises
    ``HybridLoss`` on the query predictions. After the last epoch the loss and
    Dice curves are plotted and per-class Dice / IoU are printed.

    Args:
        img_dir: directory containing the ``.jpg`` images.
        mask_dir: directory containing the ``.bmp`` masks.
        img_size: side length images and masks are resized to.
        num_support: support images per episode.
        num_query: query images per training episode.
        episodes_per_epoch: optimisation steps per epoch.
        epochs: number of epochs; also ``T_max`` of the cosine LR schedule.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        device: torch device; defaults to CUDA when available, else CPU.
    """

    def __init__(self, img_dir, mask_dir,
                 img_size=128, num_support=3, num_query=5,
                 episodes_per_epoch=200, epochs=30,
                 lr=1e-3, wd=1e-5, device=None):
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.ds     = GlaucomaDataset(img_dir, mask_dir, img_size)
        self.ns, self.nq      = num_support, num_query
        self.ep_per, self.epochs = episodes_per_epoch, epochs

        self.model = ProtoAttentionUNet().to(self.device)
        self.opt   = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=wd)
        self.sched = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt,
                                                                 T_max=epochs,
                                                                 eta_min=1e-5)
        self.crit  = HybridLoss().to(self.device)

        # Per-epoch averages of total loss and Dice score, appended by train().
        self.h_loss, self.h_score = [], []

    def _to_batch(self, items, tf):
        """Stack (PIL image, label tensor) pairs into device tensors using transform ``tf``."""
        imgs = torch.stack([tf(img) for img, _ in items]).to(self.device)
        masks = torch.stack([lbl for _, lbl in items]).to(self.device)
        return imgs, masks

    def _sample_episode(self, num_query):
        """Draw a disjoint support/query split; return ``(s_x, s_m, q_x, q_m)`` on device."""
        picked = np.random.choice(len(self.ds), self.ns + num_query, replace=False)
        support = [self.ds[i] for i in picked[:self.ns]]
        query = [self.ds[i] for i in picked[self.ns:]]
        s_x, s_m = self._to_batch(support, self.ds.sup_tf)
        q_x, q_m = self._to_batch(query, self.ds.qry_tf)
        return s_x, s_m, q_x, q_m

    def train(self):
        """Run all epochs, then plot the training curves and print evaluation metrics."""
        for ep in range(1, self.epochs+1):
            self.model.train()
            tot_ce = tot_d = tot_a = tot_l = tot_sc = 0.0

            for _ in tqdm(range(self.ep_per), desc=f"Epoch {ep}", leave=False):
                s_x, s_m, q_x, q_m = self._sample_episode(self.nq)

                preds, protos = self.model(s_x, s_m, q_x)
                loss, (l_ce, l_d, l_a) = self.crit(preds, q_m, protos)

                self.opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()

                tot_ce += l_ce
                tot_d  += l_d
                tot_a  += l_a
                tot_l  += loss.item()
                # Dice score is the complement of the Dice loss the criterion
                # already computed on these predictions; no second pass needed.
                tot_sc += 1.0 - l_d

            self.sched.step()

            avg_ce = tot_ce / self.ep_per
            avg_d  = tot_d  / self.ep_per
            avg_a  = tot_a  / self.ep_per
            avg_l  = tot_l  / self.ep_per
            avg_sc = tot_sc / self.ep_per

            self.h_loss.append(avg_l)
            self.h_score.append(avg_sc)

            print(f"Epoch {ep}/{self.epochs} | CE={avg_ce:.4f} | DiceLoss={avg_d:.4f} | "
                  f"Align={avg_a:.4f} | Total={avg_l:.4f} | DiceScore={avg_sc:.4f}")

            # Every second epoch, show the first query of the last episode.
            if ep % 2 == 0:
                self._visualize(q_x[0], q_m[0], preds[0].detach())

        self._plot_curves()
        self._eval_metrics()

    def _visualize(self, img, gt, logit):
        """Show a normalised query image, its ground-truth mask and the argmax prediction."""
        pred = logit.argmax(0).cpu().numpy()
        # Undo the mean/std normalisation applied by the dataset transform.
        im   = img.cpu().permute(1,2,0).numpy() * [0.229,0.224,0.225] + [0.485,0.456,0.406]
        # De-normalised values can fall slightly outside [0, 1]; clip them
        # explicitly instead of relying on imshow to do it with a warning.
        im   = np.clip(im, 0.0, 1.0)
        fig, ax = plt.subplots(1,3,figsize=(12,4))
        ax[0].imshow(im);            ax[0].set_title("Query"); ax[0].axis('off')
        ax[1].imshow(gt.cpu(),cmap='gray'); ax[1].set_title("GT");    ax[1].axis('off')
        ax[2].imshow(pred,cmap='gray');     ax[2].set_title("Pred");  ax[2].axis('off')
        plt.show()

    def _plot_curves(self):
        """Plot the per-epoch average loss and Dice score collected so far."""
        # Use the recorded history length, not self.epochs, so the axes still
        # match when train() has been called more than once.
        epochs = range(1, len(self.h_loss) + 1)
        plt.figure(figsize=(12,4))
        plt.subplot(1,2,1); plt.plot(epochs, self.h_loss,'-o');       plt.title("Loss");        plt.xlabel("Epoch")
        plt.subplot(1,2,2); plt.plot(epochs, self.h_score,'-o');      plt.title("Dice Score");  plt.xlabel("Epoch")
        plt.show()

    def _eval_metrics(self, eval_eps=100):
        """Print per-class Dice and IoU over ``eval_eps`` single-query episodes.

        Episodes are sampled from ``self.ds``, i.e. the same dataset used for
        training. The model is left in eval mode afterwards.
        """
        C = self.model.final.out_channels
        tp = np.zeros(C); fp = np.zeros(C); fn = np.zeros(C)
        self.model.eval()
        with torch.no_grad():
            for _ in range(eval_eps):
                s_x, s_m, q_x, q_m = self._sample_episode(1)

                logits, _ = self.model(s_x, s_m, q_x)
                pred = logits.argmax(1).squeeze(0).cpu().numpy()
                gt   = q_m[0].cpu().numpy()
                for c in range(C):
                    tp[c] += ((pred==c)&(gt==c)).sum()
                    fp[c] += ((pred==c)&(gt!=c)).sum()
                    fn[c] += ((pred!=c)&(gt==c)).sum()

        print("Per-class Dice & IoU:")
        for c in range(C):
            dice = 2*tp[c]/(2*tp[c]+fp[c]+fn[c]+1e-6)
            iou  = tp[c]/(tp[c]+fp[c]+fn[c]+1e-6)
            print(f" Class {c}: Dice={dice:.4f}, IoU={iou:.4f}")

    def save(self, path='model_final.pth'):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)
        print("Saved:", path)

if __name__ == '__main__':
    # Training data locations: .jpg images and .bmp masks.
    IMG_DIR = '/kaggle/input/refuge2/REFUGE2/train/images'
    MSK_DIR = '/kaggle/input/refuge2/REFUGE2/train/mask'

    # Hyper-parameters for this run, collected in one place.
    hparams = dict(
        img_size=128,
        num_support=3,
        num_query=5,
        episodes_per_epoch=50,
        epochs=30,
        lr=1e-3,
        wd=1e-5,
    )
    trainer = EpisodicTrainer(IMG_DIR, MSK_DIR, **hparams)
    trainer.train()
    trainer.save()


In [None]:
import os, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms
import kornia

# re‐use your Clahe module definition from training
class Clahe(torch.nn.Module):
    """Single-image CLAHE transform backed by ``kornia.enhance.equalize_clahe``."""

    def __init__(self, clip_limit: float = 40.0, grid_size=(8, 8)):
        super().__init__()
        self.clip_limit, self.grid_size = float(clip_limit), grid_size

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # Add a leading batch axis for Kornia, then remove it from the result.
        options = {"clip_limit": self.clip_limit, "grid_size": self.grid_size}
        equalized = kornia.enhance.equalize_clahe(img.unsqueeze(0), **options)
        return equalized.squeeze(0)

# --- Paths & params ---
# Evaluation on the REFUGE2 test split with the weights saved by the training cell.
IMG_DIR     = '/kaggle/input/refuge2/REFUGE2/test/images'
MSK_DIR     = '/kaggle/input/refuge2/REFUGE2/test/mask'
MODEL_PATH  = '/kaggle/working/model_final.pth'
IMG_SIZE    = 128   # side length images and masks are resized to
NUM_SUPPORT = 3     # support images per episode
NUM_EPISODES= 50    # number of metric episodes
DEVICE      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Deterministic test-time transform (must match train) ---
# Same sequence as the training pipeline: resize, tensor conversion, CLAHE,
# then mean/std normalisation.
tf_test = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    Clahe(clip_limit=40.0, grid_size=(8,8)),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# --- Mask loader (REFUGE encoding: 0=BG,128=OD,255=OC) ---
def load_mask(path):
    """Read a mask file and return an int label array of IMG_SIZE x IMG_SIZE.

    The mask is converted to greyscale and resized with nearest-neighbour
    sampling. Grey level 128 becomes class 1, 255 becomes class 2, and every
    other value becomes class 0.
    """
    grey = Image.open(path).convert('L').resize((IMG_SIZE, IMG_SIZE), Image.NEAREST)
    pixels = np.array(grey)
    labels = np.where(pixels == 128, 1, np.where(pixels == 255, 2, 0))
    return labels.astype(int)

# --- Visualization helper ---
def color_mask(mask):
    """Turn a 2-D label array into an (H, W, 3) float RGB image for display.

    Label 0 (BG) is yellow, 1 (OD) is green, 2 (OC) is red; any other label
    is left black.
    """
    palette = {
        0: [1, 1, 0],   # BG -> yellow
        1: [0, 1, 0],   # OD -> green
        2: [1, 0, 0],   # OC -> red
    }
    rows, cols = mask.shape
    rgb = np.zeros((rows, cols, 3), float)
    for label, colour in palette.items():
        rgb[mask == label] = colour
    return rgb

# --- Gather files ---
# Sorted full paths of every image in IMG_DIR with a recognised extension.
img_files = sorted([
    os.path.join(IMG_DIR, f) for f in os.listdir(IMG_DIR)
    if f.lower().endswith(('.jpg','.png','.tif'))
])
# Maps a mask's file stem to its full path. Images are matched to masks by
# stem further down, so an image without a same-stem mask raises KeyError there.
mask_map = {
    os.path.splitext(f)[0]: os.path.join(MSK_DIR, f)
    for f in os.listdir(MSK_DIR)
    if f.lower().endswith(('.bmp','.png','.jpg'))
}

# --- Load model ---
# ProtoAttentionUNet is not defined in this cell; it must already be in scope
# (presumably from running the training cell first).
model = ProtoAttentionUNet(in_ch=3, num_classes=3, feat_dim=256).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# --- Metric functions ---
def dice_scores(pred, gt, eps=1e-6):
    """Return per-class Dice for classes 0, 1, 2 as a length-3 array.

    ``pred`` and ``gt`` are label arrays of equal shape. ``eps`` is added to
    numerator and denominator, so a class absent from both arrays scores 1.
    """
    def one_class(cls):
        in_pred = pred == cls
        in_gt = gt == cls
        overlap = np.logical_and(in_pred, in_gt).sum()
        return (2 * overlap + eps) / (in_pred.sum() + in_gt.sum() + eps)

    return np.array([one_class(cls) for cls in range(3)])

# --- Confusion matrix & accumulator ---
# conf_mat[i, j] counts pixels with ground-truth class i and predicted class j,
# summed over all episodes. all_dices collects one per-class Dice vector per episode.
conf_mat = np.zeros((3,3), int)
all_dices = []

# --- Few-shot episodes ---
for _ in tqdm(range(NUM_EPISODES), desc='Meta-Test REFUGE2'):
    # sample support + query: NUM_SUPPORT + 1 distinct indices, the last one is the query
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT + 1)
    sup_idxs, qry_idx = idxs[:NUM_SUPPORT], idxs[-1]

    # support batch: transformed images plus their label masks, matched by file stem
    s_imgs, s_msks = [], []
    for i in sup_idxs:
        img = Image.open(img_files[i]).convert('RGB')
        key = os.path.splitext(os.path.basename(img_files[i]))[0]
        mask = load_mask(mask_map[key])
        s_imgs.append(tf_test(img))
        s_msks.append(torch.from_numpy(mask).long())
    s_imgs = torch.stack(s_imgs).to(DEVICE)
    s_msks = torch.stack(s_msks).to(DEVICE)

    # query: the original-size image goes straight into tf_test
    q_pil = Image.open(img_files[qry_idx]).convert('RGB')
    key = os.path.splitext(os.path.basename(img_files[qry_idx]))[0]
    q_gt = load_mask(mask_map[key])
    q_tensor = tf_test(q_pil).unsqueeze(0).to(DEVICE)

    # inference: argmax over the class axis gives the predicted label map
    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
    pred = logits.argmax(1).squeeze(0).cpu().numpy()

    # record metrics
    all_dices.append(dice_scores(pred, q_gt))
    for i in range(3):
        for j in range(3):
            conf_mat[i, j] += int(((q_gt == i) & (pred == j)).sum())

# --- Results ---
# Stack the per-episode Dice vectors and average over episodes (axis 0),
# leaving one mean Dice value per class.
all_dices = np.stack(all_dices)
mean_d = all_dices.mean(axis=0)
print(f"Mean per-class Dice: BG={mean_d[0]:.4f}, OD={mean_d[1]:.4f}, OC={mean_d[2]:.4f}")
print(f"Overall Mean Dice: {mean_d.mean():.4f}\n")
print("Confusion Matrix (rows=GT, cols=Pred):\n", conf_mat)

# --- Plot confusion matrix ---
plt.figure(figsize=(5,5))
plt.imshow(conf_mat, cmap='Blues')
plt.title("Confusion Matrix")
plt.xlabel("Predicted"); plt.ylabel("Ground Truth")
# Write each pixel count into its cell; text x is the column (pred), y the row (GT).
for i in range(3):
    for j in range(3):
        plt.text(j, i, conf_mat[i, j], ha='center', va='center', color='red')
plt.xticks([0,1,2], ['BG','OD','OC'])
plt.yticks([0,1,2], ['BG','OD','OC'])
plt.show()

# --- Visualize a few episodes ---
# Five fresh episodes, drawn independently of the metric episodes above.
for _ in range(5):
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT + 1)
    sup_idxs, qry_idx = idxs[:NUM_SUPPORT], idxs[-1]

    # Support batch, built the same way as in the metric loop.
    s_imgs, s_msks = [], []
    for i in sup_idxs:
        img = Image.open(img_files[i]).convert('RGB')
        key = os.path.splitext(os.path.basename(img_files[i]))[0]
        mask = load_mask(mask_map[key])
        s_imgs.append(tf_test(img))
        s_msks.append(torch.from_numpy(mask).long())
    s_imgs = torch.stack(s_imgs).to(DEVICE)
    s_msks = torch.stack(s_msks).to(DEVICE)

    # Feed the model the original-size image, exactly as the metric loop does.
    # Resizing with PIL first would resample the input a second, different way
    # before tf_test, so the displayed prediction would not match what was scored.
    q_full = Image.open(img_files[qry_idx]).convert('RGB')
    key = os.path.splitext(os.path.basename(img_files[qry_idx]))[0]
    q_gt = load_mask(mask_map[key])
    q_tensor = tf_test(q_full).unsqueeze(0).to(DEVICE)
    # Down-sized copy used only for display next to the masks.
    q_pil = q_full.resize((IMG_SIZE,IMG_SIZE))

    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
    pred = logits.argmax(1).squeeze(0).cpu().numpy()

    fig, ax = plt.subplots(1, 3, figsize=(12,4))
    ax[0].imshow(q_pil);           ax[0].set_title("Query Image"); ax[0].axis('off')
    ax[1].imshow(color_mask(q_gt)); ax[1].set_title("GT Mask");     ax[1].axis('off')
    ax[2].imshow(color_mask(pred)); ax[2].set_title("Pred Mask");   ax[2].axis('off')
    plt.show()


In [None]:
import os, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms
import kornia  # GPU‐accelerated image ops

# re‐use your Clahe module
class Clahe(torch.nn.Module):
    """CLAHE transform for one image tensor, delegating to Kornia."""

    def __init__(self, clip_limit: float = 40.0, grid_size=(8, 8)):
        super().__init__()
        self.grid_size = grid_size
        self.clip_limit = float(clip_limit)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # Kornia is called on a batch: wrap the image in one, unwrap the result.
        as_batch = img.unsqueeze(0)
        result = kornia.enhance.equalize_clahe(
            as_batch, clip_limit=self.clip_limit, grid_size=self.grid_size
        )
        return result.squeeze(0)

# --- Color mask for visualization ---
def color_mask(mask):
    """Colour a 2-D label array for display: 0 yellow, 1 green, 2 red, others black.

    Returns an (H, W, 3) float array.
    """
    colours = (
        [1, 1, 0],   # BG -> yellow
        [0, 1, 0],   # OD -> green
        [1, 0, 0],   # OC -> red
    )
    rgb = np.zeros(mask.shape + (3,), dtype=float)
    for label, colour in enumerate(colours):
        rgb[mask == label] = colour
    return rgb

# --- Paths & params ---
# NOTE(review): despite the ORIGA_ prefix, these two paths point at the
# REFUGE train Images/Masks directories - confirm the naming is intentional.
ORIGA_IMG_DIR  = '/kaggle/input/glaucoma-datasets/REFUGE/train/Images'
ORIGA_MSK_DIR  = '/kaggle/input/glaucoma-datasets/REFUGE/train/Masks'
MODEL_PATH     = '/kaggle/working/model_final.pth'
IMG_SIZE       = 128   # side length images and masks are resized to
NUM_SUPPORT    = 3     # support images per episode
NUM_EPISODES   = 50    # number of metric episodes
DEVICE         = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Transforms (with CLAHE) ---
# Resize, tensor conversion, CLAHE, then mean/std normalisation.
tf_img = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    Clahe(clip_limit=40.0, grid_size=(8,8)),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# --- Load mask (0,1,2 encoding) ---
def load_mask(path):
    """Load a mask as an IMG_SIZE x IMG_SIZE int array, keeping pixel values as labels.

    The file is converted to greyscale and resized with nearest-neighbour
    sampling; no value remapping is done.
    """
    target_size = (IMG_SIZE, IMG_SIZE)
    grey = Image.open(path).convert('L').resize(target_size, Image.NEAREST)
    return np.array(grey, dtype=int)

# --- Gather file lists ---
# Sorted full paths of every .jpg image.
img_files = sorted([
    os.path.join(ORIGA_IMG_DIR,f)
    for f in os.listdir(ORIGA_IMG_DIR) if f.lower().endswith('.jpg')
])
# Maps each .png mask's file stem to its full path. Images are matched to masks
# by stem further down, so an image without a same-stem mask raises KeyError there.
mask_map = {
    os.path.splitext(f)[0]: os.path.join(ORIGA_MSK_DIR,f)
    for f in os.listdir(ORIGA_MSK_DIR) if f.lower().endswith('.png')
}

# --- Load model ---
# ProtoAttentionUNet is not defined in this cell; it must already be in scope
# (presumably from running the training cell first).
model = ProtoAttentionUNet(in_ch=3, num_classes=3, feat_dim=256).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# --- Dice metric ---
def dice_score(pred, gt, eps=1e-6):
    """Per-class Dice for labels 0, 1 and 2, returned as a length-3 array.

    ``eps`` is added to numerator and denominator, so a class missing from
    both ``pred`` and ``gt`` scores 1 instead of dividing by zero.
    """
    per_class = []
    for cls in range(3):
        pred_hit, gt_hit = (pred == cls), (gt == cls)
        overlap = np.logical_and(pred_hit, gt_hit).sum()
        denom = pred_hit.sum() + gt_hit.sum()
        per_class.append((2 * overlap + eps) / (denom + eps))
    return np.array(per_class)

# --- Confusion matrix accumulator ---
# conf_mat[i, j] counts pixels with ground-truth class i and predicted class j,
# summed over all episodes. all_scores collects one per-class Dice vector per episode.
conf_mat = np.zeros((3,3), dtype=int)
all_scores = []

# --- Meta-test episodes ---
for _ in tqdm(range(NUM_EPISODES), desc='Meta-Test'):
    # sample support + query: NUM_SUPPORT + 1 distinct indices, the last one is the query
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT+1)
    sup_idxs, qry_idx = idxs[:NUM_SUPPORT], idxs[-1]

    # build support tensors: transformed images plus label masks, matched by file stem
    s_imgs, s_msks = [], []
    for i in sup_idxs:
        img_pil = Image.open(img_files[i]).convert('RGB')
        key = os.path.splitext(os.path.basename(img_files[i]))[0]
        mask_arr = load_mask(mask_map[key])
        s_imgs.append(tf_img(img_pil))
        s_msks.append(torch.from_numpy(mask_arr).long())
    s_imgs = torch.stack(s_imgs).to(DEVICE)
    s_msks = torch.stack(s_msks).to(DEVICE)

    # load query: the original-size image goes straight into tf_img
    q_pil = Image.open(img_files[qry_idx]).convert('RGB')
    key = os.path.splitext(os.path.basename(img_files[qry_idx]))[0]
    q_gt = load_mask(mask_map[key])
    q_tensor = tf_img(q_pil).unsqueeze(0).to(DEVICE)

    # inference + channel-swap fix
    # The logit channels are reordered so that channels 0 and 2 trade places
    # before the argmax. NOTE(review): presumably this compensates for a class
    # order in these masks that differs from training - confirm; the support
    # masks passed to the model are not remapped.
    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
        logits = logits[:, [2,1,0], :, :]
        pred = logits.argmax(1).squeeze(0).cpu().numpy()

    # record dice
    sc = dice_score(pred, q_gt)
    all_scores.append(sc)

    # update confusion
    for i in range(3):
        for j in range(3):
            conf_mat[i,j] += int(((q_gt==i) & (pred==j)).sum())

# --- Report results ---
# Stack the per-episode Dice vectors and average over episodes (axis 0),
# leaving one mean Dice value per class.
all_scores = np.stack(all_scores)
mean_dice = all_scores.mean(axis=0)
print("Mean per-class Dice:")
print(f"  BG: {mean_dice[0]:.4f}, OD: {mean_dice[1]:.4f}, OC: {mean_dice[2]:.4f}")
print(f"Overall Mean Dice: {mean_dice.mean():.4f}\n")

print("Confusion Matrix (rows=GT, cols=Pred):")
print(conf_mat)

# --- Plot confusion matrix ---
plt.figure(figsize=(5,5))
plt.imshow(conf_mat, cmap='Blues')
plt.title("Confusion Matrix"); plt.xlabel("Predicted"); plt.ylabel("Ground Truth")
# Write each pixel count into its cell; text x is the column (pred), y the row (GT).
for i in range(3):
    for j in range(3):
        plt.text(j, i, conf_mat[i,j], ha='center', va='center', color='red')
plt.xticks([0,1,2], ['BG','OD','OC'])
plt.yticks([0,1,2], ['BG','OD','OC'])
plt.show()

# --- Show a few example episodes ---
# Five fresh episodes, drawn independently of the metric episodes above.
for _ in range(5):
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT+1)
    sup_idxs, qry_idx = idxs[:NUM_SUPPORT], idxs[-1]

    # Support batch, built the same way as in the metric loop.
    s_imgs, s_msks = [], []
    for i in sup_idxs:
        img_pil = Image.open(img_files[i]).convert('RGB')
        key = os.path.splitext(os.path.basename(img_files[i]))[0]
        mask_arr = load_mask(mask_map[key])
        s_imgs.append(tf_img(img_pil))
        s_msks.append(torch.from_numpy(mask_arr).long())
    s_imgs = torch.stack(s_imgs).to(DEVICE)
    s_msks = torch.stack(s_msks).to(DEVICE)

    # Feed the model the original-size image, exactly as the metric loop does.
    # Resizing with PIL first would resample the input a second, different way
    # before tf_img, so the displayed prediction would not match what was scored.
    q_full = Image.open(img_files[qry_idx]).convert('RGB')
    key = os.path.splitext(os.path.basename(img_files[qry_idx]))[0]
    q_gt = load_mask(mask_map[key])
    q_tensor = tf_img(q_full).unsqueeze(0).to(DEVICE)
    # Down-sized copy used only for display next to the masks.
    q_pil = q_full.resize((IMG_SIZE,IMG_SIZE))

    # Same channel reordering (0 <-> 2) as in the metric loop.
    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
        logits = logits[:, [2,1,0], :, :]
        pred = logits.argmax(1).squeeze(0).cpu().numpy()

    fig, ax = plt.subplots(1,3,figsize=(12,4))
    ax[0].imshow(q_pil);                   ax[0].set_title("Query");   ax[0].axis('off')
    ax[1].imshow(color_mask(q_gt));        ax[1].set_title("GT Mask"); ax[1].axis('off')
    ax[2].imshow(color_mask(pred));        ax[2].set_title("Pred Mask");ax[2].axis('off')
    plt.show()


In [None]:
import os, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms
import kornia  # GPU‐accelerated image ops

# re‐use your Clahe module
class Clahe(torch.nn.Module):
    """Applies Kornia's CLAHE to one image tensor; usable inside ``transforms.Compose``."""

    def __init__(self, clip_limit: float = 40.0, grid_size=(8, 8)):
        super().__init__()
        self.clip_limit = float(clip_limit)
        self.grid_size = grid_size

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        clahe = kornia.enhance.equalize_clahe
        # Leading batch axis added for the call and removed afterwards.
        out = clahe(img.unsqueeze(0), clip_limit=self.clip_limit, grid_size=self.grid_size)
        return out.squeeze(0)

# --- Color mask for visualization ---
def color_mask(mask):
    """Render a 2-D label array as an (H, W, 3) float RGB image.

    BG (0) is yellow, OD (1) is green, OC (2) is red; other labels stay black.
    """
    height, width = mask.shape
    rgb = np.zeros((height, width, 3), dtype=float)
    for label, colour in ((0, [1, 1, 0]), (1, [0, 1, 0]), (2, [1, 0, 0])):
        rgb[mask == label] = colour
    return rgb

# --- Paths & params ---
# Evaluation on the ORIGA images/masks with the weights saved by the training cell.
ORIGA_IMG_DIR  = '/kaggle/input/glaucoma-datasets/ORIGA/Images'
ORIGA_MSK_DIR  = '/kaggle/input/glaucoma-datasets/ORIGA/Masks'
MODEL_PATH     = '/kaggle/working/model_final.pth'
IMG_SIZE       = 128   # side length images and masks are resized to
NUM_SUPPORT    = 3     # support images per episode
NUM_EPISODES   = 100   # number of metric episodes
DEVICE         = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Transforms (with CLAHE) ---
# Resize, tensor conversion, CLAHE, then mean/std normalisation.
tf_img = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    Clahe(clip_limit=40.0, grid_size=(8,8)),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# --- Load mask (0,1,2 encoding) ---
def load_mask(path):
    """Return the mask at ``path`` as an int array of IMG_SIZE x IMG_SIZE.

    Greyscale conversion followed by a nearest-neighbour resize; the pixel
    values are used as class labels unchanged.
    """
    resized = (
        Image.open(path)
        .convert('L')
        .resize((IMG_SIZE, IMG_SIZE), Image.NEAREST)
    )
    return np.array(resized, dtype=int)

# --- Gather file lists ---
# Sorted full paths of every .jpg image.
img_files = sorted([
    os.path.join(ORIGA_IMG_DIR,f)
    for f in os.listdir(ORIGA_IMG_DIR) if f.lower().endswith('.jpg')
])
# Maps each .png mask's file stem to its full path. Images are matched to masks
# by stem further down, so an image without a same-stem mask raises KeyError there.
mask_map = {
    os.path.splitext(f)[0]: os.path.join(ORIGA_MSK_DIR,f)
    for f in os.listdir(ORIGA_MSK_DIR) if f.lower().endswith('.png')
}

# --- Load model ---
# ProtoAttentionUNet is not defined in this cell; it must already be in scope
# (presumably from running the training cell first).
model = ProtoAttentionUNet(in_ch=3, num_classes=3, feat_dim=256).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# --- Dice metric ---
def dice_score(pred, gt, eps=1e-6):
    """Compute Dice for each of the classes 0, 1, 2 and return them as an array.

    ``eps`` smooths numerator and denominator so a class absent from both
    inputs yields 1 rather than a division by zero.
    """
    def dice_for(cls):
        p_mask = pred == cls
        g_mask = gt == cls
        both = (p_mask & g_mask).sum()
        return (2 * both + eps) / (p_mask.sum() + g_mask.sum() + eps)

    return np.array(list(map(dice_for, range(3))))

# --- Confusion matrix accumulator ---
# conf_mat[i, j] counts pixels with ground-truth class i and predicted class j,
# summed over all episodes. all_scores collects one per-class Dice vector per episode.
conf_mat = np.zeros((3,3), dtype=int)
all_scores = []

# --- Meta-test episodes ---
for _ in tqdm(range(NUM_EPISODES), desc='Meta-Test'):
    # sample support + query: NUM_SUPPORT + 1 distinct indices, the last one is the query
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT+1)
    sup_idxs, qry_idx = idxs[:NUM_SUPPORT], idxs[-1]

    # build support tensors: transformed images plus label masks, matched by file stem
    s_imgs, s_msks = [], []
    for i in sup_idxs:
        img_pil = Image.open(img_files[i]).convert('RGB')
        key = os.path.splitext(os.path.basename(img_files[i]))[0]
        mask_arr = load_mask(mask_map[key])
        s_imgs.append(tf_img(img_pil))
        s_msks.append(torch.from_numpy(mask_arr).long())
    s_imgs = torch.stack(s_imgs).to(DEVICE)
    s_msks = torch.stack(s_msks).to(DEVICE)

    # load query: the original-size image goes straight into tf_img
    q_pil = Image.open(img_files[qry_idx]).convert('RGB')
    key = os.path.splitext(os.path.basename(img_files[qry_idx]))[0]
    q_gt = load_mask(mask_map[key])
    q_tensor = tf_img(q_pil).unsqueeze(0).to(DEVICE)

    # inference + channel-swap fix
    # The logit channels are reordered so that channels 0 and 2 trade places
    # before the argmax. NOTE(review): presumably this compensates for a class
    # order in these masks that differs from training - confirm; the support
    # masks passed to the model are not remapped.
    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
        logits = logits[:, [2,1,0], :, :]
        pred = logits.argmax(1).squeeze(0).cpu().numpy()

    # record dice
    sc = dice_score(pred, q_gt)
    all_scores.append(sc)

    # update confusion
    for i in range(3):
        for j in range(3):
            conf_mat[i,j] += int(((q_gt==i) & (pred==j)).sum())

# --- Report results ---
# Stack the per-episode Dice vectors and average over episodes (axis 0),
# leaving one mean Dice value per class.
all_scores = np.stack(all_scores)
mean_dice = all_scores.mean(axis=0)
print("Mean per-class Dice:")
print(f"  BG: {mean_dice[0]:.4f}, OD: {mean_dice[1]:.4f}, OC: {mean_dice[2]:.4f}")
print(f"Overall Mean Dice: {mean_dice.mean():.4f}\n")

print("Confusion Matrix (rows=GT, cols=Pred):")
print(conf_mat)

# --- Plot confusion matrix ---
plt.figure(figsize=(5,5))
plt.imshow(conf_mat, cmap='Blues')
plt.title("Confusion Matrix"); plt.xlabel("Predicted"); plt.ylabel("Ground Truth")
# Write each pixel count into its cell; text x is the column (pred), y the row (GT).
for i in range(3):
    for j in range(3):
        plt.text(j, i, conf_mat[i,j], ha='center', va='center', color='red')
plt.xticks([0,1,2], ['BG','OD','OC'])
plt.yticks([0,1,2], ['BG','OD','OC'])
plt.show()

# --- Show a few example episodes ---
# Five fresh episodes, drawn independently of the metric episodes above.
for _ in range(5):
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT+1)
    sup_idxs, qry_idx = idxs[:NUM_SUPPORT], idxs[-1]

    # Support batch, built the same way as in the metric loop.
    s_imgs, s_msks = [], []
    for i in sup_idxs:
        img_pil = Image.open(img_files[i]).convert('RGB')
        key = os.path.splitext(os.path.basename(img_files[i]))[0]
        mask_arr = load_mask(mask_map[key])
        s_imgs.append(tf_img(img_pil))
        s_msks.append(torch.from_numpy(mask_arr).long())
    s_imgs = torch.stack(s_imgs).to(DEVICE)
    s_msks = torch.stack(s_msks).to(DEVICE)

    # Feed the model the original-size image, exactly as the metric loop does.
    # Resizing with PIL first would resample the input a second, different way
    # before tf_img, so the displayed prediction would not match what was scored.
    q_full = Image.open(img_files[qry_idx]).convert('RGB')
    key = os.path.splitext(os.path.basename(img_files[qry_idx]))[0]
    q_gt = load_mask(mask_map[key])
    q_tensor = tf_img(q_full).unsqueeze(0).to(DEVICE)
    # Down-sized copy used only for display next to the masks.
    q_pil = q_full.resize((IMG_SIZE,IMG_SIZE))

    # Same channel reordering (0 <-> 2) as in the metric loop.
    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
        logits = logits[:, [2,1,0], :, :]
        pred = logits.argmax(1).squeeze(0).cpu().numpy()

    fig, ax = plt.subplots(1,3,figsize=(12,4))
    ax[0].imshow(q_pil);                   ax[0].set_title("Query");   ax[0].axis('off')
    ax[1].imshow(color_mask(q_gt));        ax[1].set_title("GT Mask"); ax[1].axis('off')
    ax[2].imshow(color_mask(pred));        ax[2].set_title("Pred Mask");ax[2].axis('off')
    plt.show()


In [None]:
import os, random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torchvision import transforms
import kornia  # GPU‐accelerated image ops

# re‐use your Clahe module
class Clahe(torch.nn.Module):
    """Image-level CLAHE built on ``kornia.enhance.equalize_clahe``."""

    def __init__(self, clip_limit: float = 40.0, grid_size=(8, 8)):
        super().__init__()
        limit = float(clip_limit)
        self.clip_limit = limit
        self.grid_size = grid_size

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # Kornia receives a batch of one; the batch axis is removed on return.
        batched = torch.unsqueeze(img, 0)
        equalized = kornia.enhance.equalize_clahe(
            batched,
            clip_limit=self.clip_limit,
            grid_size=self.grid_size,
        )
        return torch.squeeze(equalized, 0)

# --- Color mask for visualization ---
def color_mask(mask):
    """Turn a 2-D label mask into an (H, W, 3) float RGB image for plotting.

    Labels 0, 1 and 2 are painted yellow, green and red respectively; any
    other label value is left black (all zeros).
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=float)
    palette = (
        (0, (1, 1, 0)),   # BG  → yellow
        (1, (0, 1, 0)),   # OD  → green
        (2, (1, 0, 0)),   # OC  → red
    )
    for label, rgb in palette:
        canvas[mask == label] = rgb
    return canvas

# --- Paths & params ---
# NOTE(review): the constants are named ORIGA_* but both paths point at the
# G1020 folders of the dataset — confirm which dataset is intended.
ORIGA_IMG_DIR  = '/kaggle/input/glaucoma-datasets/G1020/Images'  # directory listed for .jpg images
ORIGA_MSK_DIR  = '/kaggle/input/glaucoma-datasets/G1020/Masks'   # directory listed for .png masks
MODEL_PATH     = '/kaggle/working/model_final.pth'               # weights file handed to torch.load
IMG_SIZE       = 128   # side length (pixels) that images and masks are resized to
NUM_SUPPORT    = 3     # number of support images sampled per episode
NUM_EPISODES   = 100   # number of meta-test episodes that are scored
DEVICE         = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # GPU when available, else CPU

# --- Transforms (with CLAHE) ---
# Preprocessing applied to every support and query image in this cell:
# resize to IMG_SIZE x IMG_SIZE, convert to a tensor, apply CLAHE, then
# normalise each channel.
tf_img = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    # placed after ToTensor because Clahe.forward works on tensors
    Clahe(clip_limit=40.0, grid_size=(8,8)),
    # mean/std look like the usual ImageNet statistics — presumably the same
    # values used at training time; confirm against the training pipeline
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# --- Load mask (0,1,2 encoding) ---
def load_mask(path):
    """Load a segmentation mask as an (IMG_SIZE, IMG_SIZE) integer array.

    The file is read as 8-bit grayscale and resized with nearest-neighbour
    sampling so that no new (blended) label values are introduced.

    Args:
        path: location of the mask image on disk.

    Returns:
        np.ndarray of dtype int holding the stored pixel values.
        NOTE(review): values are returned exactly as stored — nothing here
        remaps them to the 0/1/2 encoding the callers assume; confirm the
        mask files already use it.
    """
    # The context manager releases the file handle deterministically; the
    # converted copy is independent of the closed source image.
    with Image.open(path) as src:
        gray = src.convert('L')
    gray = gray.resize((IMG_SIZE, IMG_SIZE), Image.NEAREST)
    return np.array(gray, dtype=int)

# --- Gather file lists ---
# mask_map: file stem -> mask path, for every .png in the mask directory.
mask_map = {
    os.path.splitext(f)[0]: os.path.join(ORIGA_MSK_DIR, f)
    for f in os.listdir(ORIGA_MSK_DIR) if f.lower().endswith('.png')
}

# img_files: sorted paths of the .jpg images that have a mask with the same
# stem. Images without a mask are dropped here; otherwise the episode loops
# would fail with a KeyError whenever such an image happened to be sampled.
jpg_names = [f for f in os.listdir(ORIGA_IMG_DIR) if f.lower().endswith('.jpg')]
img_files = sorted(
    os.path.join(ORIGA_IMG_DIR, f)
    for f in jpg_names
    if os.path.splitext(f)[0] in mask_map
)
if len(img_files) < len(jpg_names):
    print(f"Skipping {len(jpg_names) - len(img_files)} image(s) without a matching mask")

# Every episode samples NUM_SUPPORT support images plus one query image.
if len(img_files) < NUM_SUPPORT + 1:
    raise RuntimeError(
        f"Need at least {NUM_SUPPORT + 1} image/mask pairs, found {len(img_files)}"
    )

# --- Load model ---
# Instantiate the architecture, move it to DEVICE, restore the saved
# parameters from MODEL_PATH, and switch the network to evaluation mode.
model = ProtoAttentionUNet(in_ch=3, num_classes=3, feat_dim=256)
model = model.to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model = model.eval()

# --- Dice metric ---
def dice_score(pred, gt, eps=1e-6, num_classes=3):
    """Compute the per-class Dice coefficient between two label maps.

    Args:
        pred: array-like of predicted class labels.
        gt: array-like of ground-truth class labels, same shape as ``pred``.
        eps: smoothing term added to numerator and denominator; a class that
            is absent from both maps therefore scores 1.0 instead of 0/0.
        num_classes: labels ``0 .. num_classes - 1`` are scored. Defaults to
            3, the value that was previously hard-coded.

    Returns:
        np.ndarray of shape ``(num_classes,)`` with one Dice value per class.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    scores = []
    for c in range(num_classes):
        p = pred == c
        g = gt == c
        inter = np.logical_and(p, g).sum()
        # Dice denominator is |P| + |G| (sum of the two areas, not their union).
        total = p.sum() + g.sum()
        scores.append((2 * inter + eps) / (total + eps))
    return np.array(scores)

# --- Confusion matrix accumulator ---
# conf_mat[g, p] counts pixels whose ground-truth label is g and whose
# predicted label is p; all_scores collects one per-class Dice vector per episode.
all_scores = []
conf_mat = np.zeros((3,3), dtype=int)

# --- Meta-test episodes ---
for _ in tqdm(range(NUM_EPISODES), desc='Meta-Test'):
    # draw NUM_SUPPORT + 1 distinct images; the last one drawn is the query
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT+1)
    qry_idx = idxs[-1]
    sup_idxs = idxs[:NUM_SUPPORT]

    # support set: transformed images and their integer masks
    sup_img_list, sup_msk_list = [], []
    for sup_i in sup_idxs:
        sup_path = img_files[sup_i]
        sup_rgb = Image.open(sup_path).convert('RGB')
        sup_stem = os.path.splitext(os.path.basename(sup_path))[0]
        sup_mask = load_mask(mask_map[sup_stem])
        sup_img_list.append(tf_img(sup_rgb))
        sup_msk_list.append(torch.from_numpy(sup_mask).long())
    s_imgs = torch.stack(sup_img_list).to(DEVICE)
    s_msks = torch.stack(sup_msk_list).to(DEVICE)

    # query image (batch of one) and its ground-truth mask
    qry_path = img_files[qry_idx]
    q_pil = Image.open(qry_path).convert('RGB')
    qry_stem = os.path.splitext(os.path.basename(qry_path))[0]
    q_gt = load_mask(mask_map[qry_stem])
    q_tensor = tf_img(q_pil).unsqueeze(0).to(DEVICE)

    # forward pass without gradients
    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
        # "channel-swap fix": channels are reordered to [2, 1, 0] before argmax.
        # NOTE(review): the reason the model's channel order is reversed
        # relative to the mask labels is not visible here — confirm.
        logits = logits[:, [2,1,0], :, :]
        pred = logits.argmax(1).squeeze(0).cpu().numpy()

    # per-class Dice for this episode
    all_scores.append(dice_score(pred, q_gt))

    # accumulate pixel counts into the confusion matrix
    for gt_cls in range(3):
        gt_hit = (q_gt == gt_cls)
        for pred_cls in range(3):
            conf_mat[gt_cls, pred_cls] += int((gt_hit & (pred == pred_cls)).sum())

# --- Report results ---
# Stack the per-episode Dice vectors into one (episodes, classes) array and
# average over episodes to get one Dice value per class.
all_scores = np.stack(all_scores)
mean_dice = np.mean(all_scores, axis=0)
print("Mean per-class Dice:")
print(f"  BG: {mean_dice[0]:.4f}, OD: {mean_dice[1]:.4f}, OC: {mean_dice[2]:.4f}")
print(f"Overall Mean Dice: {mean_dice.mean():.4f}\n")

print("Confusion Matrix (rows=GT, cols=Pred):")
print(conf_mat)

# --- Plot confusion matrix ---
plt.figure(figsize=(5,5))
plt.imshow(conf_mat, cmap='Blues')
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Ground Truth")
# write each count into its cell: x is the predicted class (column),
# y is the ground-truth class (row)
for (gt_cls, pred_cls), count in np.ndenumerate(conf_mat):
    plt.text(pred_cls, gt_cls, count, ha='center', va='center', color='red')
class_ticks, class_names = [0,1,2], ['BG','OD','OC']
plt.xticks(class_ticks, class_names)
plt.yticks(class_ticks, class_names)
plt.show()

# --- Show a few example episodes ---
# Each iteration samples a new random episode (these are not the episodes
# scored above), runs the model, and plots query / ground truth / prediction.
for _ in range(5):
    idxs = random.sample(range(len(img_files)), NUM_SUPPORT+1)
    sup_idxs, qry_idx = idxs[:NUM_SUPPORT], idxs[-1]

    # support set: transformed images and their integer masks
    s_imgs, s_msks = [], []
    for i in sup_idxs:
        img_pil = Image.open(img_files[i]).convert('RGB')
        key = os.path.splitext(os.path.basename(img_files[i]))[0]
        mask_arr = load_mask(mask_map[key])
        s_imgs.append(tf_img(img_pil))
        s_msks.append(torch.from_numpy(mask_arr).long())
    s_imgs = torch.stack(s_imgs).to(DEVICE)
    s_msks = torch.stack(s_msks).to(DEVICE)

    q_raw = Image.open(img_files[qry_idx]).convert('RGB')
    key = os.path.splitext(os.path.basename(img_files[qry_idx]))[0]
    q_gt = load_mask(mask_map[key])
    # Feed the un-resized image to tf_img so the query goes through the same
    # single Resize as the support images and the scored meta-test queries.
    # Shrinking it first with PIL's default resampling filter would give the
    # model a differently interpolated input than the one it is evaluated on.
    q_tensor = tf_img(q_raw).unsqueeze(0).to(DEVICE)
    # Down-sized copy used only for display next to the IMG_SIZE masks.
    q_pil = q_raw.resize((IMG_SIZE,IMG_SIZE))

    with torch.no_grad():
        logits, _ = model(s_imgs, s_msks, q_tensor)
        # Same channel reordering as in the meta-test loop.
        # NOTE(review): why the channel order must be reversed is not visible
        # here — confirm against the training label convention.
        logits = logits[:, [2,1,0], :, :]
        pred = logits.argmax(1).squeeze(0).cpu().numpy()

    fig, ax = plt.subplots(1,3,figsize=(12,4))
    panels = (
        (q_pil, "Query"),
        (color_mask(q_gt), "GT Mask"),
        (color_mask(pred), "Pred Mask"),
    )
    for axis, (image, title) in zip(ax, panels):
        axis.imshow(image)
        axis.set_title(title)
        axis.axis('off')
    plt.show()
