# Importing libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from typing import List, Callable, Optional, Dict
from itertools import cycle
from dataclasses import dataclass
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import shutil
import pandas as pd

# Bayesian UNet

## Bayesian UNet Parts

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages.

    The spatial size is preserved (padding=1); only the channel count changes,
    from ``in_channels`` to ``out_channels`` in the first stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 maps in_channels -> out_channels, stage 2 keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.double_conv(x)

class EncoderBloc(nn.Module):
    """Encoder stage: DoubleConv, then 2x2 max-pooling followed by channel dropout."""

    def __init__(self, in_channels, out_channels, dropout_prob=0.3):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout2d(p=dropout_prob)

    def forward(self, x):
        """Return ``(downsampled, skip)``.

        ``skip`` is the full-resolution conv output (before pooling/dropout);
        ``downsampled`` is the pooled, dropout-regularised tensor.
        """
        skip = self.conv(x)
        downsampled = self.dropout(self.pool(skip))
        return downsampled, skip

class DecoderBloc(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # After concatenation with the skip tensor the DoubleConv receives
        # `in_channels` channels.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, align it with ``skip`` if needed, and fuse the two."""
        upsampled = self.up(x)
        skip_hw = skip.shape[2:]
        if upsampled.shape[2:] != skip_hw:
            # Spatial sizes can differ after upsampling; resize to the skip map.
            upsampled = F.interpolate(
                upsampled, size=skip_hw, mode='bilinear', align_corners=True
            )
        fused = torch.cat([skip, upsampled], dim=1)
        return self.conv(fused)


## Bayesian UNet Implementation

In [None]:
class BayesianUNet(nn.Module):
    """U-Net with dropout in every encoder stage.

    Four encoder stages (64 -> 512 channels), a 1024-channel bottleneck, four
    decoder stages with skip connections, and a final 1x1 convolution that
    produces ``num_classes`` logit channels.
    """

    def __init__(self, in_channels, num_classes, dropout_prob):
        super().__init__()

        # Contracting path
        self.encoder_bloc_1 = EncoderBloc(in_channels, 64, dropout_prob)
        self.encoder_bloc_2 = EncoderBloc(64, 128, dropout_prob)
        self.encoder_bloc_3 = EncoderBloc(128, 256, dropout_prob)
        self.encoder_bloc_4 = EncoderBloc(256, 512, dropout_prob)

        # Deepest feature extractor
        self.bottle_neck = DoubleConv(512, 1024)

        # Expanding path
        self.decoder_bloc_1 = DecoderBloc(1024, 512)
        self.decoder_bloc_2 = DecoderBloc(512, 256)
        self.decoder_bloc_3 = DecoderBloc(256, 128)
        self.decoder_bloc_4 = DecoderBloc(128, 64)

        # Per-pixel classifier
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw class logits with the channel count set by ``num_classes``."""
        # Contracting path: keep each stage's full-resolution output as a skip.
        down, skip_1 = self.encoder_bloc_1(x)
        down, skip_2 = self.encoder_bloc_2(down)
        down, skip_3 = self.encoder_bloc_3(down)
        down, skip_4 = self.encoder_bloc_4(down)

        features = self.bottle_neck(down)

        # Expanding path: consume the skips deepest-first.
        features = self.decoder_bloc_1(features, skip_4)
        features = self.decoder_bloc_2(features, skip_3)
        features = self.decoder_bloc_3(features, skip_2)
        features = self.decoder_bloc_4(features, skip_1)

        return self.out(features)

## Testing Bayesian UNet

In [None]:
# Smoke test: push one random RGB image through the network and show the
# shape of the returned logits.
model = BayesianUNet(in_channels=3, num_classes=5, dropout_prob=0.5)
x = torch.randn(1, 3, 256, 256)
pred = model(x)
print(pred.shape)

# Augmentation

In [None]:
def add_gaussian_noise(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0
) -> torch.Tensor:
    """Add N(mean, std²) noise to ``tensor`` and clamp the result to [0, 1]."""
    noisy = tensor + (torch.randn_like(tensor) * std + mean)
    return torch.clamp(noisy, min=0.0, max=1.0)


def apply_geometric(x, aug_dict, flip_axis, rot_axis0, rot_axis1):
    """Apply the flip and 90° rotations described by ``aug_dict``.

    ``aug_dict["flip"]`` toggles a flip along ``flip_axis``; ``aug_dict["rot"]``
    is the number of quarter turns in the (``rot_axis0``, ``rot_axis1``) plane.
    The flip is applied first, then the rotation.
    """
    flipped = torch.flip(x, [flip_axis]) if aug_dict["flip"] else x
    if not aug_dict["rot"]:
        return flipped
    return torch.rot90(flipped, aug_dict["rot"], dims=(rot_axis0, rot_axis1))

def apply_photometric(x, aug_dict):
    """Apply colour jitter, Gaussian blur and additive Gaussian noise to ``x``.

    Keys read from ``aug_dict``: ``"jitter"`` (strength shared by brightness,
    contrast, saturation and hue; skipped when <= 0), ``"blur"`` (square kernel
    size, also the upper sigma bound; skipped when <= 1), ``"noise_mean"`` and
    ``"noise_std"`` (always applied).
    """
    strength = aug_dict["jitter"]
    if strength > 0:
        x = T.ColorJitter(
            brightness=strength,
            contrast=strength,
            saturation=strength,
            hue=strength
        )(x)

    kernel = aug_dict["blur"]
    if kernel > 1:
        x = T.GaussianBlur(
            kernel_size=(kernel, kernel),
            sigma=(0.1, float(kernel))
        )(x)

    return add_gaussian_noise(x, aug_dict["noise_mean"], aug_dict["noise_std"])

def augment_image(
    x: torch.Tensor
) -> (torch.Tensor, dict):
    """Randomly augment one image and report the geometric parameters used.

    A random flip and a random number of quarter turns are drawn, then fixed
    photometric settings (jitter, blur, noise) are applied. Only the geometric
    dict is returned so the same transform can be replayed on a mask or undone
    on a prediction.
    """
    # Dict values are evaluated in order: the flip is drawn before the rotation.
    geometric = {
        "flip": bool(random.randint(0, 1)),
        "rot": random.randint(0, 3),
    }
    photometric = {
        "jitter": 0.3,
        "blur": 3,
        "noise_mean": 0.0,
        "noise_std": 0.1,
    }

    # Axes follow the call below: flip on axis 2, rotate in the (1, 2) plane.
    warped = apply_geometric(x, geometric, flip_axis=2, rot_axis0=1, rot_axis1=2)
    return apply_photometric(warped, photometric), geometric

def augment_mask(
    x: torch.tensor,
    aug_dict: dict
):
    """Replay the geometric part of an image augmentation on its mask.

    The mask is flipped on axis 1 and rotated in the (0, 1) plane, i.e. the
    image axes used by ``augment_image`` shifted down by one.
    """
    return apply_geometric(x, aug_dict, flip_axis=1, rot_axis0=0, rot_axis1=1)

def reverse_augmentations(
    x: torch.Tensor,
    aug_dict: dict,
    flip_axis: int = 1,
    rot_axis0: int = 0,
    rot_axis1: int = 1
):
    """Invert a flip-then-rotate augmentation described by ``aug_dict``.

    The operations are undone in the opposite order to how they were applied:
    first the rotation (by turning the remaining quarter turns of a full
    circle), then the flip (which is its own inverse).
    """
    quarter_turns = aug_dict["rot"]
    if quarter_turns:
        remaining = (4 - quarter_turns) % 4
        x = torch.rot90(x, k=remaining, dims=(rot_axis0, rot_axis1))

    if aug_dict["flip"]:
        x = torch.flip(x, dims=[flip_axis])

    return x

# Data Loading

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import random
from PIL import Image, ImageOps
import numpy as np

# RGB colour used in the annotation masks -> integer class index.
Classes = {
    (51, 221, 255): 0,  # ICM
    (250, 50, 83): 1,   # TE
    (61, 245, 61): 2,   # ZP
    (255, 245, 61): 3,  # BL
    (0, 0, 0): 4,       # background
}

TARGET_SIZE = (256, 256)  # (W, H)


# Label Encoding
def mask_encoding(arr):
    """Convert an RGB annotation mask into a 2-D map of class indices.

    Args:
        arr: array of shape (H, W, 3) whose pixels are compared for exact
            equality against the RGB keys of ``Classes``.

    Returns:
        ``np.uint8`` array of shape (H, W). Pixels whose colour is not listed
        in ``Classes`` are assigned the background index.
    """
    h, w, _ = arr.shape
    # Start from background rather than zeros: index 0 is a real foreground
    # class (ICM), so an unknown colour must not silently fall into it.
    background_idx = Classes.get((0, 0, 0), 0)
    class_mask = np.full((h, w), background_idx, dtype=np.uint8)
    for rgb, idx in Classes.items():
        class_mask[np.all(arr == rgb, axis=-1)] = idx
    return class_mask


class BlastocystDataset(Dataset):
    """Image/mask pairs read from two directories.

    Every file in ``image_dir`` is treated as an image; its mask is looked up
    in ``mask_dir`` under the same base name with a ``.png`` extension.
    """

    def __init__(self, image_dir, mask_dir, seed=None, augment=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sorted so that indices map to files deterministically.
        self.image_filenames = sorted(os.listdir(image_dir))
        self.seed = seed
        self.augment = augment

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, mask, filename)`` for sample ``idx``.

        The image is a float tensor in [0, 1] laid out channels-first; the mask
        is a long tensor of class indices. Both are padded to ``TARGET_SIZE``
        with black borders.
        """
        img_name = self.image_filenames[idx]
        stem = os.path.splitext(img_name)[0]

        # --- image: bilinear resampling, scaled to [0, 1], HWC -> CHW ---
        picture = Image.open(os.path.join(self.image_dir, img_name)).convert("RGB")
        picture = ImageOps.pad(picture, TARGET_SIZE, method=Image.BILINEAR, color=(0, 0, 0))
        img = torch.from_numpy(np.array(picture, np.float32) / 255.0).permute(2, 0, 1)

        # --- mask: nearest resampling keeps the label colours exact ---
        mask_rgb = Image.open(os.path.join(self.mask_dir, stem + ".png")).convert("RGB")
        mask_rgb = ImageOps.pad(mask_rgb, TARGET_SIZE, method=Image.NEAREST, color=(0, 0, 0))
        mask = torch.from_numpy(mask_encoding(np.array(mask_rgb))).long()

        if self.augment:
            # The mask replays the geometric transform drawn for the image.
            img, aug_dict = augment_image(img)
            mask = augment_mask(mask, aug_dict)

        return img, mask, img_name

class UnlabeledBlastocystDataset(Dataset):
    """Images without masks, read from a single directory."""

    def __init__(self, image_dir):
        self.image_dir = image_dir
        # Sorted so that indices map to files deterministically.
        self.image_filenames = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, filename)``; the image is a CHW float tensor in [0, 1]."""
        img_name = self.image_filenames[idx]

        picture = Image.open(os.path.join(self.image_dir, img_name)).convert("RGB")
        picture = ImageOps.pad(picture, TARGET_SIZE, method=Image.BILINEAR, color=(0, 0, 0))
        scaled = np.array(picture, np.float32) / 255.0
        img = torch.from_numpy(scaled).permute(2, 0, 1)

        return img, img_name




def seed_worker(worker_id):
    """Seed NumPy and ``random`` from torch's initial seed (DataLoader worker hook).

    ``worker_id`` is accepted because the ``worker_init_fn`` hook passes it,
    but it is not used: the seed comes from ``torch.initial_seed()`` reduced to
    32 bits.
    """
    seed32 = torch.initial_seed() % (1 << 32)
    for seeder in (np.random.seed, random.seed):
        seeder(seed32)


def get_loaders_active(
        labeled_img_dir,
        labeled_mask_dir,
        unlabeled_img_dir,
        test_img_dir,
        test_mask_dir,
        batch_size,
        seed=None,
        augment=False,
        generator=None,
        num_workers=4,
        pin_memory=True,
):
    """Build the labeled, unlabeled and test DataLoaders for active learning.

    Args:
        labeled_img_dir, labeled_mask_dir: directories of the labeled pool.
        unlabeled_img_dir: directory of the unlabeled pool (images only).
        test_img_dir, test_mask_dir: directories of the held-out test split.
        batch_size: batch size shared by all three loaders.
        seed: stored on the labeled and test datasets.
        augment: whether the labeled dataset applies random augmentation
            (the test dataset never does).
        generator: ``torch.Generator`` driving the labeled loader's shuffling.
        num_workers, pin_memory: forwarded to every DataLoader.

    Returns:
        ``(labeled_loader, unlabeled_loader, test_loader)``.
    """
    # 1. Labeled Dataset (with masks)
    labeled_ds = BlastocystDataset(
        image_dir=labeled_img_dir,
        mask_dir=labeled_mask_dir,
        seed=seed,
        augment=augment
    )

    # 2. Unlabeled Dataset (images only)
    unlabeled_ds = UnlabeledBlastocystDataset(
        image_dir=unlabeled_img_dir
    )

    # 3. Test Dataset (with masks)
    test_ds = BlastocystDataset(
        image_dir=test_img_dir,
        mask_dir=test_mask_dir,
        seed=seed,
        augment=False
    )

    # Options every loader shares.
    common = dict(
        batch_size=batch_size,
        worker_init_fn=seed_worker,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    labeled_loader = DataLoader(
        labeled_ds,
        shuffle=True,
        generator=generator,
        # Dropping the ragged last batch helps batch normalization, but when
        # the labeled pool is smaller than one batch it would leave the loader
        # empty, so only drop when at least one full batch exists.
        drop_last=len(labeled_ds) >= batch_size,
        **common
    )

    unlabeled_loader = DataLoader(
        unlabeled_ds,
        shuffle=False,  # Important for sample tracking
        **common
    )

    test_loader = DataLoader(
        test_ds,
        shuffle=False,
        **common
    )

    return labeled_loader, unlabeled_loader, test_loader

## Visualizing Image and Mask after augmentation 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# --- discrete colormap: one colour per class, ordered by class index ---
sorted_classes = sorted(Classes.items(), key=lambda item: item[1])
colors = [np.array(rgb) / 255.0 for rgb, _ in sorted_classes]
cmap = ListedColormap(colors)
# Bin edges sit at half-integers so that each integer label gets its own colour.
norm = BoundaryNorm(
    boundaries=np.arange(-0.5, len(colors) + 0.5, 1),
    ncolors=len(colors),
)

# --- dataset with augmentation switched on ---
labeled_img_dir  = "/kaggle/input/dataset/images"
labeled_mask_dir = "/kaggle/input/dataset/masks"

ds = BlastocystDataset(
    image_dir=labeled_img_dir,
    mask_dir=labeled_mask_dir,
    augment=True
)

# --- first sample (tensors) ---
img, mask, name = ds[0]
print("Sample:", name, "| Image:", tuple(img.shape), "| Mask:", tuple(mask.shape))

# --- numpy views for matplotlib (image goes CHW -> HWC) ---
img_np  = img.permute(1, 2, 0).cpu().numpy()
mask_np = mask.cpu().numpy()

# --- side-by-side plot: image on the left, colour-coded mask on the right ---
plt.figure(figsize=(10,5))
_panels = (
    ("Image", img_np, {}),
    ("Mask", mask_np, {"cmap": cmap, "norm": norm, "interpolation": "nearest"}),
)
for _position, (_title, _pixels, _imshow_kwargs) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _position)
    plt.imshow(_pixels, **_imshow_kwargs)
    plt.title(_title)
    plt.axis("off")

plt.show()

# TAAL Core

## Jensen–Shannon Divergence

In [None]:
def tensor_entropy(probs: torch.Tensor, dim: int, eps: float=1e-8) -> torch.Tensor:
    """Shannon entropy (natural log) of ``probs`` summed along ``dim``.

    ``eps`` is added inside the logarithm so that zero probabilities do not
    produce ``-inf``.
    """
    weighted_log = probs * torch.log(probs + eps)
    return -weighted_log.sum(dim=dim)

def entropy_of_average_batch(prob_maps: torch.Tensor) -> torch.Tensor:
    """Entropy of the mean prediction.

    ``prob_maps`` is [K, B, C, H, W]; the K predictions are averaged first and
    the entropy is then taken over the class axis, giving [B, H, W].
    """
    mean_probs = torch.mean(prob_maps, dim=0)  # [B, C, H, W]
    return tensor_entropy(mean_probs, dim=1)

def average_entropy_batch(prob_maps: torch.Tensor) -> torch.Tensor:
    """Mean of the individual prediction entropies.

    ``prob_maps`` is [K, B, C, H, W]; the entropy over the class axis is
    computed for each of the K predictions ([K, B, H, W]) and then averaged
    over K, giving [B, H, W].
    """
    per_prediction = tensor_entropy(prob_maps, dim=2)
    return torch.mean(per_prediction, dim=0)

def JSD_batch(
    logits_list: List[torch.Tensor],
    alpha: float = 0.5
) -> torch.Tensor:
    """Weighted Jensen–Shannon divergence between K predictions, per image.

    Args:
        logits_list: K tensors of logits, each [B, C, H, W].
        alpha: weight of the entropy-of-the-mean term; the mean-of-entropies
            term is weighted by ``1 - alpha``.

    Returns:
        Tensor [B]: ``alpha * H(mean) - (1 - alpha) * mean(H)`` averaged over
        all pixels of each image.
    """
    # Class probabilities for every prediction, stacked to [K, B, C, H, W].
    softmaxed = [F.softmax(logits, dim=1) for logits in logits_list]
    prob_maps = torch.stack(softmaxed, dim=0)

    # Both entropy maps are [B, H, W].
    entropy_of_mean = entropy_of_average_batch(prob_maps)
    mean_of_entropies = average_entropy_batch(prob_maps)

    jsd_map = alpha * entropy_of_mean - (1.0 - alpha) * mean_of_entropies

    # Spatial average -> one scalar per batch element.
    return jsd_map.mean(dim=[1,2])

## Augment Predict Reverse

In [None]:
def augment_predict_reverse(
    model: torch.nn.Module,
    images: torch.Tensor,        # [B, C_in, H, W] (can be on CPU or GPU)
    K: int = 3,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    for_training: bool = True     # True → train consistency (grads ON), False → eval/AL scoring (grads OFF)
) -> torch.Tensor:
    """Predict on K random augmentations of a batch and map the logits back.

    For each of the K passes every image gets its own random augmentation, the
    model is run on the augmented batch, and the geometric part of each
    augmentation is undone on the corresponding logits.

    Args:
        model: network to run; it is moved to ``device``.
        images: batch of shape [B, C_in, H, W].
        K: number of augmentation passes.
        device: device used for the forward passes.
        for_training: if True the model is put in train mode and gradients are
            recorded; otherwise eval mode with gradients disabled.

    Returns:
        Tensor [B, K, C_out, H, W] of geometry-restored logits.
    """
    B, C_in, H, W = images.shape
    model = model.to(device)

    if for_training:
        model.train()
    else:
        model.eval()

    logits_per_pass = []
    with torch.set_grad_enabled(for_training):
        for _ in range(K):
            augmented, geometry = [], []

            # Augment on CPU, sample by sample, then ship the batch in one go.
            # NOTE(review): stacking rotated samples presumably relies on
            # H == W — confirm against the data pipeline.
            for img in images:
                img_aug, params = augment_image(img.detach().to("cpu"))
                augmented.append(img_aug)
                geometry.append(params)

            batch = torch.stack(augmented, dim=0).to(device, non_blocking=True)  # [B,C_in,H,W]
            logits = model(batch)                                                # [B,C_out,H,W]

            # Undo each sample's own flip/rotation on its logits (CHW layout).
            restored = [
                reverse_augmentations(
                    sample_logits,
                    sample_params,
                    flip_axis=2,
                    rot_axis0=1,
                    rot_axis1=2
                )
                for sample_logits, sample_params in zip(logits, geometry)
            ]
            logits_per_pass.append(torch.stack(restored, dim=0))  # [B,C_out,H,W]

    stacked = torch.stack(logits_per_pass, dim=0)        # [K,B,C_out,H,W]
    return stacked.permute(1,0,2,3,4).contiguous()       # [B,K,C_out,H,W]


## JSD Consistency Loss

In [None]:
def jsd_consistency_batch(
    model: torch.nn.Module,
    imgs_U: torch.Tensor,   # [B,C,H,W]
    K: int = 3,
    alpha: float = 0.5,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> torch.Tensor:
    """Scalar JSD consistency loss over K augmented views of an unlabeled batch.

    The predictions come from ``augment_predict_reverse`` in training mode
    (gradients enabled); the per-image divergences are averaged over the batch.
    """
    # Augmentation happens on CPU, so hand over a CPU copy when given CUDA data.
    cpu_batch = imgs_U.detach().cpu() if imgs_U.is_cuda else imgs_U

    logits_tta = augment_predict_reverse(
        model, cpu_batch, K=K, device=device, for_training=True
    )  # [B,K,C,H,W]

    # Split the K axis into K tensors of shape [B,C,H,W].
    per_view_logits = list(torch.unbind(logits_tta, dim=1))
    return JSD_batch(per_view_logits, alpha=alpha).mean()

## JSD Score TTA 

In [None]:
@torch.no_grad()
def jsd_score_tta(
    model: torch.nn.Module,
    loader,                      # DataLoader (unlabeled_loader)
    K: int = 3,
    alpha: float = 0.5,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> dict:
    """Score every image of ``loader`` by the JSD across K augmented predictions.

    Runs in eval mode without gradients. ``loader`` must yield
    ``(imgs, names)`` batches.

    Returns:
        dict mapping each filename to its scalar score.
    """
    model.eval()
    score_dict = {}

    for imgs, names in loader:         # imgs [B,C,H,W], names: one string per image
        cpu_imgs = imgs.detach().cpu() if imgs.is_cuda else imgs

        logits_tta = augment_predict_reverse(
            model, cpu_imgs, K=K, device=device, for_training=False
        )  # [B,K,C,H,W]

        per_view_logits = list(torch.unbind(logits_tta, dim=1))  # K × [B,C,H,W]
        scores = JSD_batch(per_view_logits, alpha=alpha)         # [B]

        score_dict.update(zip(names, scores.tolist()))

    return score_dict

# Training Core

## Loss

In [None]:
class TverskyLossNoBG(nn.Module):
    """Tversky loss averaged over every class except the background.

    ``alpha`` weights false positives and ``beta`` weights false negatives;
    ``bg_idx`` is the class index left out of the average.
    """

    def __init__(self, alpha, beta, smooth=1e-6, bg_idx=4):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth
        self.bg_idx = bg_idx

    def forward(self, logits, target):
        """Return the scalar loss for ``logits`` [B,C,H,W] and integer ``target`` [B,H,W]."""
        _, C, _, _ = logits.shape
        probs = F.softmax(logits, dim=1)
        # Labels are clamped into [0, C-1] before one-hot encoding to [B,C,H,W].
        one_hot = F.one_hot(target.clamp(0, C - 1), C).permute(0, 3, 1, 2).float()

        # Soft confusion counts per class, pooled over batch and space.
        pooled = (0, 2, 3)
        true_pos = (probs * one_hot).sum(pooled)
        false_pos = (probs * (1 - one_hot)).sum(pooled)
        false_neg = ((1 - probs) * one_hot).sum(pooled)

        # Boolean selector that drops the background class.
        keep = torch.ones(C, dtype=torch.bool, device=logits.device)
        keep[self.bg_idx] = False
        true_pos, false_pos, false_neg = true_pos[keep], false_pos[keep], false_neg[keep]

        tversky = (true_pos + self.smooth) / (
            true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth
        )
        return (1 - tversky).mean()

## Metrics

In [None]:
def pixel_accuracy(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Fraction of positions where ``preds`` equals ``targets`` (0-dim tensor)."""
    matches = torch.eq(preds, targets).sum()
    return matches.float() / preds.numel()

def macro_dice(preds: torch.Tensor,
               targets: torch.Tensor,
               num_classes: int,
               eps: float = 1e-6) -> torch.Tensor:
    """Dice score averaged with equal weight over all ``num_classes`` classes.

    ``preds`` and ``targets`` are integer label maps of shape (N, H, W). A
    class absent from both gets a score of ``eps / eps`` = 1.
    """
    def channels_first_one_hot(labels):
        # (N, H, W) -> one-hot (N, H, W, C) -> (N, C, H, W)
        return F.one_hot(labels, num_classes).permute(0, 3, 1, 2).float()

    pred_oh = channels_first_one_hot(preds)
    target_oh = channels_first_one_hot(targets)

    # Pool over batch and spatial axes so that one value per class remains.
    pooled = (0, 2, 3)
    overlap = (pred_oh * target_oh).sum(pooled)               # (C,)
    total = pred_oh.sum(pooled) + target_oh.sum(pooled)       # (C,)

    per_class = (2 * overlap + eps) / (total + eps)           # (C,)
    return per_class.mean()

## Sigmoid Rampup

In [None]:
def sigmoid_rampup(current: float, rampup_length: int, max_value: float = 1.0) -> float:
    """Ramp a weight up to ``max_value`` along ``exp(-5 * (1 - t)**2)``.

    ``t`` is ``current / rampup_length`` with ``current`` clipped to
    ``[0, rampup_length]``. A non-positive ``rampup_length`` disables the ramp
    and returns ``max_value`` immediately.
    """
    if rampup_length <= 0:
        return max_value
    progress = np.clip(current, 0.0, rampup_length) / rampup_length
    remaining = 1.0 - progress
    return max_value * np.exp(-5.0 * remaining * remaining)

## Train Eval

In [None]:
@dataclass(frozen=True)
class SSLConfig:
    """Immutable settings for the optional semi-supervised branch of training."""
    # Unsupervised loss; train_one_epoch calls it as
    # loss(model, imgs_U, K=..., alpha=..., device=...).
    loss: Callable
    # Forwarded to `loss` as its K keyword.
    K: int = 3
    # Forwarded to `loss` as its alpha keyword.
    alpha: float = 0.5
    # Master switch: train_one_epoch skips the SSL branch when this is False.
    use: bool = True


def train_one_epoch(
    labeled_loader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_sup_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str = "cuda",
    num_classes: int = 5,
    scaler: Optional[GradScaler] = None,
    scheduler=None,
    step_scheduler_on_batch: bool = False,
    # ---- SSL (optional) ------------------------------------
    unlabeled_loader=None,                  # pass only if using SSL
    ssl: Optional[SSLConfig] = None,        # config for the SSL loss
    ssl_weight: float = 0.0,                # λ (ramp per epoch outside)
) -> Dict[str, float]:
    """Run one training epoch over the labeled loader, optionally with an SSL term.

    Args:
        labeled_loader: yields ``(imgs, masks, ...)`` batches.
        model: network to train (put into train mode here).
        optimizer: optimizer stepped once per labeled batch.
        loss_sup_fn: supervised loss called as ``loss_sup_fn(logits, masks)``.
        device: device the batches are moved to.
        num_classes: class count used for the Dice metric.
        scaler: AMP gradient scaler; a fresh one is created when omitted.
        scheduler: optional LR scheduler.
        step_scheduler_on_batch: step the scheduler after every batch instead
            of once at the end of the epoch.
        unlabeled_loader: source of unlabeled batches for the SSL branch.
        ssl: SSL configuration; the branch runs only when it is given, enabled,
            has a callable loss, ``ssl_weight > 0`` and an unlabeled loader
            is supplied.
        ssl_weight: multiplier applied to the unsupervised loss.

    Returns:
        Dict with the epoch means ``"loss"``, ``"acc%"`` (0–100) and ``"dice"``.

    Raises:
        ValueError: if the labeled loader yields no batch, or if the SSL branch
            is active and the unlabeled loader yields no batch.
    """
    device = torch.device(device)
    scaler = scaler or GradScaler()
    model.train()

    use_ssl = (
        ssl is not None
        and ssl.use
        and ssl_weight > 0.0
        and (unlabeled_loader is not None)
        and callable(ssl.loss)
    )

    # Re-iterate the loader instead of itertools.cycle, which would keep a copy
    # of every batch in memory and replay the first pass without reshuffling.
    ul_iter = _endless_batches(unlabeled_loader) if use_ssl else None

    tot_loss = tot_acc = tot_dice = 0.0
    n_batches = 0

    pbar = tqdm(labeled_loader, desc="Train", leave=False)
    for imgs_L, masks_L, *_ in pbar:
        imgs_L  = imgs_L.float().to(device, non_blocking=True)
        masks_L = masks_L.long().to(device,  non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with autocast():
            # supervised branch
            logits_L = model(imgs_L)
            loss_sup = loss_sup_fn(logits_L, masks_L)

            # optional SSL branch
            loss_uns = torch.zeros((), device=device)
            if use_ssl:
                imgs_U, *_ = next(ul_iter)
                imgs_U = imgs_U.float().to(device, non_blocking=True)
                loss_uns = ssl.loss(model, imgs_U, K=ssl.K, alpha=ssl.alpha, device=device)

            loss = loss_sup + (ssl_weight * loss_uns)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if scheduler and step_scheduler_on_batch:
            scheduler.step()

        # metrics on the labeled stream
        with torch.no_grad():
            preds      = logits_L.argmax(1)
            batch_acc  = pixel_accuracy(preds, masks_L)
            batch_dice = macro_dice(preds, masks_L, num_classes)

        n_batches += 1
        tot_loss  += loss.item()
        tot_acc   += batch_acc.item()
        tot_dice  += batch_dice.item()

        pbar.set_postfix({
            "sup":  loss_sup.item(),
            "uns":  loss_uns.item() if use_ssl else 0.0,
            "λ":    ssl_weight if use_ssl else 0.0,
            "loss": tot_loss / n_batches,
            "dice": tot_dice / n_batches,
            "lr":   optimizer.param_groups[0]["lr"],
        })
    pbar.close()

    if n_batches == 0:
        raise ValueError("Labeled loader is empty – nothing to train on.")

    # Step epoch-based schedulers here (not ReduceLROnPlateau)
    if scheduler and not step_scheduler_on_batch and not isinstance(
        scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
    ):
        scheduler.step()

    return {
        "loss": tot_loss / n_batches,
        "acc%": (tot_acc / n_batches) * 100.0,
        "dice": tot_dice / n_batches,
    }


def _endless_batches(loader):
    """Yield batches from ``loader`` forever, restarting it each time it runs out."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Without this guard an empty loader would spin forever.
            raise ValueError("Unlabeled loader is empty – nothing to draw SSL batches from.")


@torch.no_grad()
def evaluate_loader(loader, model, device="cuda", num_classes=5):
    """Evaluate ``model`` on ``loader`` without gradients.

    Args:
        loader: yields ``(imgs, masks, ...)`` batches.
        model: network to evaluate (put into eval mode here).
        device: device the batches are moved to.
        num_classes: class count used for the Dice metric.

    Returns:
        ``(accuracy, dice)``: batch-averaged pixel accuracy as a fraction in
        [0, 1] and batch-averaged macro Dice.

    Raises:
        ValueError: if the loader yields no batch.
    """
    model.eval()  # inference mode
    device = torch.device(device)

    tot_acc, tot_dice, n_batches = 0.0, 0.0, 0

    pbar = tqdm(loader, desc="Eval", leave=False)
    for imgs, masks, *_ in pbar:
        imgs = imgs.float().to(device, non_blocking=True)
        masks = masks.long().to(device, non_blocking=True)

        logits = model(imgs)  # (N, C, H, W)
        preds = logits.argmax(1)  # (N, H, W)

        batch_acc = pixel_accuracy(preds, masks)
        batch_dice = macro_dice(preds, masks, num_classes)

        # accumulate
        tot_acc += batch_acc.item()
        tot_dice += batch_dice.item()
        n_batches += 1

        # The field is labelled as a percentage, so scale the running fraction
        # for display; the returned accuracy stays a fraction.
        pbar.set_postfix({"acc%": 100.0 * tot_acc / n_batches,
                          "dice": tot_dice / n_batches})
    pbar.close()

    if n_batches == 0:
        raise ValueError("Loader is empty – nothing to evaluate.")

    return tot_acc / n_batches, tot_dice / n_batches

# Active Learning

## Acquisition Functions

In [None]:
def random_score(model, imgs, **kwargs):
    """Random-acquisition baseline: one uniform [0, 1) score per image.

    ``model`` and any keyword arguments are accepted for signature
    compatibility with the other acquisition functions and ignored.
    """
    batch_size = imgs.size(0)
    return torch.rand(batch_size, device=imgs.device)  # Shape: (B,)

def entropy(model, imgs, T=8, num_classes=5):
    """Predictive-entropy acquisition score from ``T`` stochastic forward passes.

    Args:
        model: segmentation network; switched to train mode so dropout stays
            active. NOTE(review): ``model.train()`` changes the mode of every
            submodule, not only dropout — confirm that is acceptable during
            scoring.
        imgs: input batch (B, C_in, H, W).
        T: number of stochastic passes (must be >= 1).
        num_classes: kept for interface compatibility; the class count is
            taken from the model output.

    Returns:
        Tensor (B,): entropy of the mean class distribution, summed over all
        pixels of each image.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be at least 1.")

    model.train()  # Keep dropout ON for stochasticity
    probs_sum = None

    # Scoring never needs gradients; skipping them avoids building T graphs.
    with torch.no_grad():
        for _ in range(T):
            with torch.amp.autocast('cuda'):
                logits = model(imgs)           # Shape: (B, C, H, W)
                probs = F.softmax(logits, 1)   # Shape: (B, C, H, W)
            probs = probs.float()
            probs_sum = probs if probs_sum is None else probs_sum + probs

    probs_mean = probs_sum / T  # Shape: (B, C, H, W)
    # The epsilon keeps 0 * log(0) from turning into NaN for saturated pixels.
    ent = -(probs_mean * torch.log(probs_mean + 1e-12)).sum(dim=1)  # Shape: (B, H, W)
    return ent.sum(dim=(1, 2))  # Shape: (B,)


def BALD(model, imgs, T=8, num_classes=4):
    """BALD acquisition score (mutual information) from ``T`` stochastic passes.

    Args:
        model: segmentation network; switched to train mode so dropout stays
            active. NOTE(review): ``model.train()`` changes the mode of every
            submodule, not only dropout — confirm that is acceptable during
            scoring.
        imgs: input batch (B, C_in, H, W).
        T: number of stochastic passes (must be >= 1).
        num_classes: kept for interface compatibility; the class count is
            taken from the model output, so a value that disagrees with the
            network no longer causes a shape error.

    Returns:
        Tensor (B,): per-image mean over pixels of
        ``H[mean prediction] - mean(H[each prediction])``.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be at least 1.")

    model.train()
    probs_sum = None      # running sum of probabilities, (B, C, H, W)
    entropies_sum = None  # running sum of per-pass entropies, (B, H, W)

    with torch.no_grad():
        for _ in range(T):
            with torch.amp.autocast('cuda'):
                logits = model(imgs)              # Shape: (B, C, H, W)
                probs = F.softmax(logits, dim=1)  # Shape: (B, C, H, W)
            probs = probs.float()

            # H[y|x, Θ_t] = -Σ_c (p_c * log(p_c))
            entropy_t = -(probs * torch.log(probs + 1e-12)).sum(dim=1)  # Shape: (B, H, W)

            probs_sum = probs if probs_sum is None else probs_sum + probs
            entropies_sum = entropy_t if entropies_sum is None else entropies_sum + entropy_t

    # 1. Predictive entropy: H[y|x]
    probs_mean = probs_sum / T
    predictive_entropy = -(probs_mean * torch.log(probs_mean + 1e-12)).sum(dim=1)

    # 2. Expected entropy: E[H[y|x, Θ]]
    expected_entropy = entropies_sum / T

    # 3. I(y; Θ|x) = H[y|x] - E[H[y|x, Θ]]
    bald_map = predictive_entropy - expected_entropy  # Shape: (B, H, W)

    return bald_map.mean(dim=(1, 2))  # Shape: (B,)



def committee_kl_divergence(model, imgs, T=8, num_classes=4):
    """KL(deterministic prediction || MC-dropout posterior) per image.

    Args:
        model: segmentation network. It is put in train mode for the ``T``
            stochastic passes and left in eval mode afterwards.
            NOTE(review): ``model.train()`` changes the mode of every
            submodule, not only dropout — confirm that is acceptable during
            scoring.
        imgs: input batch (B, C_in, H, W).
        T: number of stochastic passes (must be >= 1).
        num_classes: kept for interface compatibility; the class count is
            taken from the model output, so a value that disagrees with the
            network no longer causes a shape error.

    Returns:
        Tensor (B,): KL divergence summed over classes and averaged over pixels.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be at least 1.")

    # 1) Monte Carlo posterior under dropout
    model.train()
    samples = []
    with torch.no_grad(), torch.amp.autocast('cuda'):
        for _ in range(T):
            logits = model(imgs)  # Shape: (B, C, H, W)
            samples.append(F.softmax(logits, dim=1).float())
    posterior = torch.stack(samples, dim=0).mean(dim=0)  # Shape: (B, C, H, W)

    # 2) Deterministic “standard” prediction
    model.eval()   # Turn OFF dropout here
    with torch.no_grad(), torch.amp.autocast('cuda'):
        logits = model(imgs)  # Shape: (B, C, H, W)
        standard_probs = F.softmax(logits, dim=1).float()

    # 3) Per-pixel KL(standard || posterior); clamping keeps the logs finite.
    eps = 1e-12
    P = torch.clamp(standard_probs, min=eps)
    Q = torch.clamp(posterior, min=eps)
    kl_map = P * (P.log() - Q.log())       # Shape: (B, C, H, W)
    kl_pixel = kl_map.sum(dim=1)           # Shape: (B, H, W)
    kl_score = kl_pixel.mean(dim=(1, 2))   # Shape: (B,)

    return kl_score

def committee_js_divergence(model, imgs, T=8, num_classes=4):
    """Jensen–Shannon divergence between the deterministic prediction and the
    MC-dropout posterior, per image.

    Args:
        model: segmentation network. It is put in train mode for the ``T``
            stochastic passes and left in eval mode afterwards.
            NOTE(review): ``model.train()`` changes the mode of every
            submodule, not only dropout — confirm that is acceptable during
            scoring.
        imgs: input batch (B, C_in, H, W).
        T: number of stochastic passes (must be >= 1).
        num_classes: kept for interface compatibility; the class count is
            taken from the model output, so a value that disagrees with the
            network no longer causes a shape error.

    Returns:
        Tensor (B,): ½·KL(p‖M) + ½·KL(Q‖M) summed over classes and averaged
        over pixels, where M is the even mixture of p and Q.

    Raises:
        ValueError: if ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be at least 1.")

    # 1) Monte Carlo posterior Q
    model.train()  # Keep dropout on
    samples = []
    with torch.no_grad(), torch.amp.autocast('cuda'):
        for _ in range(T):
            logits = model(imgs)  # Shape: (B, C, H, W)
            samples.append(F.softmax(logits, dim=1).float())
    Q = torch.stack(samples, dim=0).mean(dim=0)  # Shape: (B, C, H, W)

    # 2) Deterministic standard prediction p
    model.eval()  # Turn dropout off
    with torch.no_grad(), torch.amp.autocast('cuda'):
        logits = model(imgs)  # Shape: (B, C, H, W)
        p = F.softmax(logits, dim=1).float()

    # 3) Build mixture M and clamp for numerical stability
    eps = 1e-12
    p = torch.clamp(p, min=eps)
    Q = torch.clamp(Q, min=eps)
    M = torch.clamp(0.5 * (p + Q), min=eps)

    # 4) Compute ½ KL(p‖M) + ½ KL(Q‖M) per pixel
    kl_p_m = p * (p.log() - M.log())  # Shape: (B, C, H, W)
    kl_q_m = Q * (Q.log() - M.log())  # Shape: (B, C, H, W)
    js_map = 0.5 * (kl_p_m + kl_q_m).sum(dim=1)  # Shape: (B, H, W)
    js_score = js_map.mean(dim=(1, 2))           # Shape: (B,)

    return js_score

def taal_unweighted_score(model, U, device, K=3):
    """TAAL score with equal weighting (alpha = 0.5).

    ``U`` is handed to ``jsd_score_tta`` as its ``loader`` argument, and that
    function's ``{filename: score}`` dict is returned unchanged.
    """
    scores_by_name = jsd_score_tta(model, U, K=K, alpha=0.5, device=device)
    return scores_by_name

def taal_weighted_score(model, U, device, K=3):
    """TAAL score with alpha = 0.75 (more weight on the entropy of the mean).

    ``U`` is handed to ``jsd_score_tta`` as its ``loader`` argument, and that
    function's ``{filename: score}`` dict is returned unchanged.
    """
    scores_by_name = jsd_score_tta(model, U, K=K, alpha=0.75, device=device)
    return scores_by_name

## Active Learning Utils

In [None]:
def create_active_learning_pools(
        BASE_DIR,
        label_split_ratio=0.1,
        test_split_ratio=0.2,
        shuffle=True
):
    """Split ``BASE_DIR/images`` into test, labeled and unlabeled pools on disk.

    Files ending in "bmp" (case-insensitive) are split in that order: the first
    ``test_split_ratio`` share goes to the test set, the next
    ``label_split_ratio`` share to the labeled pool, the rest to the unlabeled
    pool. Each image is copied together with ``BASE_DIR/masks/<stem>.png``; a
    warning is printed when the mask is missing.

    Returns:
        dict of pool directory paths, keyed both by short name (e.g.
        ``"labeled_img"``) and, for all but the unlabeled masks, by the same
        name with a ``_dir`` suffix.
    """
    layout = {
        'labeled_img': ("Labeled_pool", 'labeled_images'),
        'labeled_mask': ("Labeled_pool", 'labeled_masks'),
        'unlabeled_img': ("Unlabeled_pool", 'unlabeled_images'),
        'unlabeled_mask': ("Unlabeled_pool", 'unlabeled_masks'),
        'test_img': ("Test", 'test_images'),
        'test_mask': ("Test", 'test_masks'),
    }
    dirs = {key: os.path.join(BASE_DIR, *parts) for key, parts in layout.items()}

    # "<name>_dir" aliases (there is intentionally none for 'unlabeled_mask').
    for key in ("labeled_img", "labeled_mask", "unlabeled_img", "test_img", "test_mask"):
        dirs[key + "_dir"] = dirs[key]

    for path in dirs.values():
        os.makedirs(path, exist_ok=True)

    # Candidate images, in sorted order unless shuffling is requested.
    img_dir = os.path.join(BASE_DIR, 'images')
    images = sorted(f for f in os.listdir(img_dir) if f.lower().endswith("bmp"))
    if shuffle:
        random.shuffle(images)

    n_test = int(len(images) * test_split_ratio)
    n_labeled = int(len(images) * label_split_ratio)
    labeled_end = n_test + n_labeled

    # (file names, image destination key, mask destination key)
    plan = (
        (images[:n_test], 'test_img', 'test_mask'),
        (images[n_test:labeled_end], 'labeled_img', 'labeled_mask'),
        (images[labeled_end:], 'unlabeled_img', 'unlabeled_mask'),
    )

    for file_list, img_key, mask_key in plan:
        for im in file_list:
            shutil.copy(os.path.join(img_dir, im), os.path.join(dirs[img_key], im))

            mask_file = f"{os.path.splitext(im)[0]}.png"
            src_mask = os.path.join(BASE_DIR, 'masks', mask_file)
            if os.path.exists(src_mask):
                shutil.copy(src_mask, os.path.join(dirs[mask_key], mask_file))
            else:
                print(f"Warning: Mask not found for {im} - {src_mask}")

    return dirs

def reset_data(base_dir):
    """Delete the three pool directories under ``base_dir`` if they exist.

    Removes ``Labeled_pool``, ``Unlabeled_pool`` and ``Test`` recursively;
    anything else in ``base_dir`` is left untouched.
    """
    for pool_name in ("Labeled_pool", "Unlabeled_pool", "Test"):
        pool_path = os.path.join(base_dir, pool_name)
        if os.path.exists(pool_path):
            shutil.rmtree(pool_path)


def move_images_with_dict(
        base_dir: str,
        labeled_dir: str,
        unlabeled_dir: str,
        score_dict: dict,
        num_to_move: int = 2
):
    """Move the highest-scoring images (and their masks) into the labeled pool.

    Args:
        base_dir: root directory containing both pools.
        labeled_dir: labeled pool directory name, relative to ``base_dir``.
        unlabeled_dir: unlabeled pool directory name, relative to ``base_dir``.
        score_dict: maps image filename to its uncertainty score.
        num_to_move: maximum number of images to move.

    Entries whose image file is missing are reported and skipped without
    counting towards ``num_to_move``; a missing mask is reported but the image
    still counts as moved.
    """
    unlabeled_root = os.path.join(base_dir, unlabeled_dir)
    labeled_root = os.path.join(base_dir, labeled_dir)

    # Highest score (most uncertain) first.
    ranked = sorted(score_dict.items(), key=lambda pair: pair[1], reverse=True)

    moved = 0
    for im, score in ranked:
        if moved >= num_to_move:
            break

        im_clean = im.strip()
        mask_name = os.path.splitext(im_clean)[0] + ".png"

        src_im = os.path.join(unlabeled_root, "unlabeled_images", im_clean)
        dst_im = os.path.join(labeled_root, "labeled_images", im_clean)
        src_msk = os.path.join(unlabeled_root, "unlabeled_masks", mask_name)
        dst_msk = os.path.join(labeled_root, "labeled_masks", mask_name)

        if not os.path.exists(src_im):
            print(f"[WARN] Image not found: {src_im}")
            continue

        # Copy-then-delete performs the move for the image...
        shutil.copy(src_im, dst_im)
        os.remove(src_im)
        print(f"[MOVE] IMAGE {im_clean} (Uncertainty: {score:.4f})")

        # ...and for its mask, when there is one.
        if not os.path.exists(src_msk):
            print(f"[WARN] Mask not found: {src_msk}")
        else:
            shutil.copy(src_msk, dst_msk)
            os.remove(src_msk)
            print(f"[MOVE]  MASK {mask_name}")

        moved += 1

    print(f"Moved {moved} most uncertain images from {unlabeled_dir} → {labeled_dir}.")

def score_unlabeled_pool(unlabeled_loader, model, score_fn, num_classes=5, device="cuda"):
    """Run ``score_fn`` over every batch of the loader and map name -> score.

    Each batch from ``unlabeled_loader`` is unpacked as ``(images, names)``.
    The images are sent to ``device`` and scored with
    ``score_fn(model, images, num_classes=num_classes)`` under
    ``torch.no_grad()``; the returned tensor is converted to a Python list.

    Returns:
        dict pairing the collected file names with the collected scores, in
        loader order.
    """
    all_scores = []
    all_names = []
    progress = tqdm(unlabeled_loader, desc="Scoring", leave=False)
    with torch.no_grad():
        for batch_imgs, batch_names in progress:
            batch_scores = score_fn(
                model, batch_imgs.to(device), num_classes=num_classes
            )
            all_scores += batch_scores.cpu().tolist()
            all_names += batch_names
    return dict(zip(all_names, all_scores))

## Active Learning Loop

# Benchmarking

In [None]:
# Registry mapping an acquisition-type name to its scoring callable.
# active_learning_loop lower-cases its `acquisition_type` argument and looks it
# up here, so keys must stay lowercase.
# Calling convention used by active_learning_loop:
#   * "taal" / "taal-unweighted" are invoked directly as
#     scorer(model, unlabeled_loader, device=...) while unlabeled data remains;
#   * every other entry is handed to score_unlabeled_pool as its `score_fn`.
ACQ_FUNCS = {
    "random":          random_score,
    "entropy":         entropy,
    "bald":            BALD,
    "kl-divergence":   committee_kl_divergence,
    "js-divergence":   committee_js_divergence,
    "taal-unweighted": taal_unweighted_score,
    "taal":            taal_weighted_score
}


def _median_frequency_class_weights(labeled_loader, num_classes, device):
    """Return median-frequency-balanced class weights for the labelled pool.

    Pixels per class are counted over one pass of ``labeled_loader`` (batches
    are unpacked as ``(images, masks, _)``). Each class that actually occurs
    gets ``median(freq) / freq``, and the result is rescaled so the weights of
    the occurring classes average to 1.

    Classes with zero pixels get weight 0 instead of the inf/NaN a plain
    division would produce; NaN weights would turn the weighted cross-entropy
    into NaN. When every class is present the result equals the plain
    ``median / freq`` normalised by its mean.

    Assumes masks hold integer class indices in ``[0, num_classes)`` — this is
    what ``torch.bincount`` and the in-place accumulation require.
    """
    class_counts = torch.zeros(num_classes, dtype=torch.float32, device=device)
    total_pixels = 0
    for _, masks, _ in labeled_loader:
        flat = masks.view(-1).to(device)
        class_counts += torch.bincount(flat, minlength=num_classes).float()
        total_pixels += flat.numel()

    present = class_counts > 0
    if not bool(present.any()):
        # Nothing was counted (empty loader): fall back to uniform weights.
        return torch.ones_like(class_counts)

    class_freqs = class_counts[present] / total_pixels
    weights = torch.zeros_like(class_counts)
    weights[present] = torch.median(class_freqs) / class_freqs
    return weights / weights[present].mean()


def active_learning_loop(
        BASE_DIR: str,
        LABEL_SPLIT_RATIO: float = .1,
        TEST_SPLIT_RATIO: float = .2,
        augment: bool = False,
        sample_size: int = 2,
        acquisition_type: str = "js-divergence",
        mc_runs: int = 8,
        dropout = 0.3,
        batch_size: int = 16,
        lr: float = 1e-3,
        seed: int | None = None,
        loop_iterations: int | None = None,  # set None to disable
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        # early-stopping inside each fine-tune
        patience: int = 15,
        min_delta: float = 1e-4,
        # SSL schedule for TAAL
        ssl_lambda_max: float = 1.0,
        ssl_ramp_epochs: int = 10,
):
    """Active learning loop that supports supervised (entropy/KL/JS/etc.) and TAAL SSL.

    Each round fine-tunes the same model on the current labelled pool with
    early stopping, reloads the best checkpoint, logs its Dice score, then
    moves the ``sample_size`` highest-scoring unlabelled images into the
    labelled pool. Once the unlabelled pool is empty, one last round is trained
    on the full data and the loop ends.

    Args:
        BASE_DIR: Dataset root. The existing pool folders under it are deleted
            and rebuilt at the start of every call.
        LABEL_SPLIT_RATIO: Forwarded to ``create_active_learning_pools``.
        TEST_SPLIT_RATIO: Forwarded to ``create_active_learning_pools``.
        augment: Forwarded to ``get_loaders_active``.
        sample_size: Images moved to the labelled pool per round (capped by
            the number of images still unlabelled).
        acquisition_type: Key of ``ACQ_FUNCS`` (case-insensitive).
        mc_runs: Accepted for API compatibility; not used by this function.
        dropout: Dropout probability passed to ``BayesianUNet``.
        batch_size: Forwarded to ``get_loaders_active``.
        lr: Adam learning rate.
        seed: When given, seeds ``random``, NumPy and torch (CPU, CUDA and the
            loader generator) and switches cuDNN to deterministic mode.
        loop_iterations: Maximum number of rounds; ``None`` means no cap.
        device: Device for the model and for the class-weight computation.
        patience: Epochs without improvement before a fine-tune stops.
        min_delta: Minimum Dice gain that counts as an improvement.
        ssl_lambda_max: Maximum SSL loss weight (TAAL methods only).
        ssl_ramp_epochs: Ramp length passed to ``sigmoid_rampup``.

    Returns:
        pandas.DataFrame with one row per round and the columns ``round``,
        ``fraction`` (labelled share of the training images) and
        ``dice_score``.

    Raises:
        KeyError: If ``acquisition_type`` is not a key of ``ACQ_FUNCS``.
    """
    # Resolve the scorer first so a typo in `acquisition_type` fails before
    # the pool folders are deleted below.
    acq = acquisition_type.lower()
    if acq not in ACQ_FUNCS:
        raise KeyError(
            f"Unknown acquisition_type {acquisition_type!r}; "
            f"expected one of {sorted(ACQ_FUNCS)}"
        )
    scorer = ACQ_FUNCS[acq]

    # ─────────────────── housekeeping ────────────────────────
    reset_data(BASE_DIR)

    g = torch.Generator()
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        g.manual_seed(seed)

    dirs = create_active_learning_pools(
        BASE_DIR, LABEL_SPLIT_RATIO, TEST_SPLIT_RATIO, shuffle=True
    )
    ckpt_dir = os.path.join(BASE_DIR, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    best_ckpt = os.path.join(ckpt_dir, "best_tmp.pt")

    # Single source of truth for the number of segmentation classes.
    num_classes = 5

    # ─────────────────── model built once ────────────────────
    # The model and optimizer persist across rounds, so every round fine-tunes
    # the previous round's best weights.
    model = BayesianUNet(in_channels=3, num_classes=num_classes, dropout_prob=dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    iteration = 0
    log: list[dict] = []
    total_train = (
        len(os.listdir(dirs["labeled_img"])) +  # currently labelled
        len(os.listdir(dirs["unlabeled_img"]))  # plus still un-labelled
    )
    train_on_full_data = False

    # TAAL methods use SSL; others are supervised
    SSL_METHODS = {"taal", "taal-unweighted"}

    alpha_map = {"taal": 0.75, "taal-unweighted": 0.5}

    # ─────────────────── big loop ────────────────────────────
    while True:
        if loop_iterations is not None and iteration >= loop_iterations:
            break

        n_unl = len(os.listdir(dirs["unlabeled_img"]))
        if n_unl == 0:
            if train_on_full_data:
                print("Finished Training on the whole dataset")
                break
            else:
                # First time the pool is empty: run one more round on the
                # full data, then stop on the next pass.
                print("Un-labelled pool exhausted")
                train_on_full_data = True

        use_ssl = (acq in SSL_METHODS) and (n_unl > 0)
        iteration += 1
        print(f"\n── Active Learning Iteration: {iteration} | Unlabelled pool size: {n_unl}")

        # loaders (named explicitly; `T` would shadow the torchvision alias)
        labeled_loader, unlabeled_loader, test_loader = get_loaders_active(
            dirs["labeled_img"], dirs["labeled_mask"],
            dirs["unlabeled_img"],
            dirs["test_img"], dirs["test_mask"],
            batch_size=batch_size,
            seed=seed,
            augment=augment,
            generator=g,
            num_workers=4, pin_memory=True
        )

        # ───── fine-tune with early stopping (no epoch cap) ───
        best_val, wait, epoch = -float("inf"), 0, 0

        # class weights for CE, recomputed each round because the labelled
        # pool grows
        weights = _median_frequency_class_weights(labeled_loader, num_classes, device)

        # supervised loss (CE + Tversky)
        tversky_fn = TverskyLossNoBG(0.3, 0.7, bg_idx=4).to(device)
        ce_loss = nn.CrossEntropyLoss(weight=weights)
        def combined_loss(logits, targets):
            return ce_loss(logits, targets) + tversky_fn(logits, targets)

        # SSL config for TAAL, otherwise None
        ssl_cfg = None
        if use_ssl:
            ssl_cfg = SSLConfig(loss=jsd_consistency_batch,
                                K=3,
                                alpha=alpha_map[acq],
                                use=True)

        while True:
            epoch += 1

            # ramp λ only for SSL methods
            if use_ssl:
                lambda_t = ssl_lambda_max * sigmoid_rampup(epoch - 1, ssl_ramp_epochs)
            else:
                lambda_t = 0.0

            # unified train-one-epoch call (SSL on/off based on args)
            train_one_epoch(
                labeled_loader=labeled_loader,
                model=model,
                optimizer=optimizer,
                loss_sup_fn=combined_loss,
                device=device,
                num_classes=num_classes,
                unlabeled_loader=unlabeled_loader if use_ssl and lambda_t > 0 else None,
                ssl=ssl_cfg,
                ssl_weight=lambda_t,
            )

            # Early stopping is driven by the Dice score on the test loader.
            model.eval()
            with torch.no_grad():
                _, val_dice = evaluate_loader(test_loader, model, device=device, num_classes=num_classes)
            model.train()
            print(f"    Epoch {epoch:03d} | val Dice {val_dice:.4f}")

            if val_dice > best_val + min_delta:
                best_val, wait = val_dice, 0
                torch.save(model.state_dict(), best_ckpt)
            else:
                wait += 1
                if wait >= patience:
                    print(f"    Early-stop after {epoch} epochs")
                    break

        model.load_state_dict(torch.load(best_ckpt, map_location=device))

        # evaluate & log (on the test loader)
        # NOTE(review): unlike the per-epoch evaluation above, this call is not
        # wrapped in model.eval()/torch.no_grad(); confirm evaluate_loader
        # handles that itself before changing the mode here, since the scorers
        # below run in whatever mode this leaves the model in.
        _, test_dice = evaluate_loader(test_loader, model, device=device, num_classes=num_classes)
        curr_labeled = len(os.listdir(dirs["labeled_img"]))
        frac = curr_labeled / total_train
        log.append({"round": iteration, "fraction": frac, "dice_score": test_dice})
        print(f"[Active Learning iteration: {iteration}]")
        print(f"   Validation Dice = {test_dice:.4f}")

        # acquisition
        if not train_on_full_data:
            if use_ssl:  # TAAL: TTA-based scorer
                score_dict = scorer(model, unlabeled_loader, device=device)
            else:       # supervised baselines go through the batch-scoring util
                score_dict = score_unlabeled_pool(
                    unlabeled_loader, model, scorer, num_classes=num_classes, device=device
                )
            move_images_with_dict(
                BASE_DIR, "Labeled_pool", "Unlabeled_pool",
                score_dict, num_to_move=min(sample_size, n_unl)
            )

    return pd.DataFrame(log)

In [None]:
%%bash
 # Copy the read-only Kaggle input dataset into the writable working directory.
 # Stop at the first failing command so a wrong DATASET_PATH is reported
 # instead of being masked by the exit status of a later command.
 set -euo pipefail

 # ← edit this to match your dataset's folder under /kaggle/input
 DATASET_PATH="/kaggle/input/dataset"   # e.g. /kaggle/input/embryo-images-and-masks

 # where you want to store a local copy
 WORKING_DATA="/kaggle/working/data"

 # create the directory
 mkdir -p "$WORKING_DATA"

 # copy images & masks
 cp -r "$DATASET_PATH/images" "$WORKING_DATA/"
 cp -r "$DATASET_PATH/masks"  "$WORKING_DATA/"

In [None]:
# Acquisition strategies to benchmark; each name must be a key of ACQ_FUNCS.
acquisition_funcs = ["taal"]

for acq in acquisition_funcs:
    print(f"Running acquisition={acq}")
    df = active_learning_loop(
        BASE_DIR="/kaggle/working/data",
        LABEL_SPLIT_RATIO=0.1,
        TEST_SPLIT_RATIO=0.2,
        augment=True,
        sample_size=2,
        acquisition_type=acq,
        mc_runs=8,
        dropout=0.3,
        batch_size=16,
        lr=1e-3,
        seed=1,
        loop_iterations=None,
        device="cuda" if torch.cuda.is_available() else "cpu",
        # early-stopping inside each fine-tune
        patience=15,
        min_delta=1e-4,
        # SSL schedule for TAAL
        ssl_lambda_max=1.0,
        ssl_ramp_epochs=10,
    )
    # Tag the rows with the strategy and write one CSV log per strategy
    # (written once; the original wrote the same file twice).
    df["method"] = acq
    output_path = f"/kaggle/working/{acq}_log.csv"
    df.to_csv(output_path, index=False)