In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import sklearn


class PatchEmbedding(nn.Module):
    """
    Turn an image batch into a sequence of patch tokens.

    1) Split the input (B, in_ch, H, W) into non-overlapping patches of size (patch_size × patch_size).
    2) Flatten each patch to a vector of length (in_ch * patch_size^2).
    3) Project that to an embedding of dimension embed_dim via a learnable linear layer.

    Steps 2) and 3) are implemented as a single Conv2d with kernel_size=stride=patch_size:
    on non-overlapping windows this is the same as one shared linear map applied to every
    flattened patch.
    """
    def __init__(self, in_ch=1, embed_dim=256, patch_size=16, img_size=128):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size  # number of patches along H (and W)
        self.num_patches = self.grid_size * self.grid_size

        # (B, in_ch, H, W) -> (B, embed_dim, H/ps, W/ps)
        self.proj = nn.Conv2d(
            in_channels=in_ch,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Args:
            x: (B, in_ch, H, W) image batch.

        Returns:
            (B, N, embed_dim) tokens with N = (H // patch_size) * (W // patch_size),
            ordered row-major over the patch grid (token i*grid_w + j is patch (i, j)).
        """
        x = self.proj(x)          # (B, embed_dim, grid_h, grid_w)
        x = x.flatten(2)          # (B, embed_dim, grid_h*grid_w)
        return x.transpose(1, 2)  # (B, N, embed_dim)


class TransformerEncoderBlock(nn.Module):
    """
    One pre-norm Transformer encoder block with multi-head self-attention and an MLP:

        x = x + Dropout(MHSA(LayerNorm(x)))
        x = x + MLP(LayerNorm(x))     # MLP = Linear -> GELU -> Dropout -> Linear -> Dropout

    Args:
        embed_dim: token dimension (nn.MultiheadAttention requires it to be divisible by num_heads).
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of embed_dim.
        qkv_bias: whether the attention's input/output projections have a bias.
        drop: dropout applied to the attention output and inside the MLP.
        attn_drop: dropout applied to the attention weights.
    """
    def __init__(self, embed_dim=256, num_heads=8, mlp_ratio=4.0, qkv_bias=True, drop=0.1, attn_drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=qkv_bias,
            dropout=attn_drop,
            batch_first=True,
        )
        self.drop1 = nn.Dropout(drop)

        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden, embed_dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        """
        Args:
            x: (B, N, embed_dim) tokens.

        Returns:
            (B, N, embed_dim) tokens.
        """
        h = self.norm1(x)
        # The attention map is discarded, so don't ask for it: need_weights=False lets
        # nn.MultiheadAttention use its optimized scaled_dot_product_attention path.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop1(attn_out)

        # The MLP already ends with Dropout, so its output is added directly.
        x = x + self.mlp(self.norm2(x))
        return x


class TransformerEncoder(nn.Module):
    """
    Backbone made of `num_layers` TransformerEncoderBlock layers applied in sequence,
    followed by one final LayerNorm.

    Both input and output are token tensors of shape (B, num_patches, embed_dim).
    """
    def __init__(self,
                 num_layers=6,
                 embed_dim=256,
                 num_heads=8,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 drop_rate=0.1,
                 attn_drop_rate=0.0):
        super().__init__()
        # Every layer is built with identical hyper-parameters.
        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
        )
        stack = []
        for _ in range(num_layers):
            stack.append(TransformerEncoderBlock(**block_kwargs))
        self.layers = nn.ModuleList(stack)
        # Normalisation of the final encoder output.
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run all layers, then the final LayerNorm; shape is preserved."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return self.norm(out)


class SimpleSegTransformer(nn.Module):
    """
    A "from-scratch" transformer segmentation network:
      1) PatchEmbedding → tokens of shape (B, N, E)
      2) Add learnable positional embeddings (1, N, E)
      3) TransformerEncoder backbone (`depth` layers)
      4) A simple decoder: reshape tokens back to a (grid_size × grid_size) map, then one
         ConvTranspose2d(kernel 2, stride 2) + BatchNorm + ReLU stage per entry of
         `decoder_channels` (each stage doubles the resolution), a bilinear resize to
         img_size if the result is not already that size, a 1×1 conv and a sigmoid.
    No pretrained weights—everything is initialized randomly.
    """
    def __init__(self,
                 img_size=128,
                 patch_size=16,
                 in_ch=1,
                 embed_dim=256,
                 depth=6,
                 num_heads=8,
                 mlp_ratio=4.0,
                 drop_rate=0.1,
                 attn_drop_rate=0.0,
                 decoder_channels=(256, 128, 64),
                 out_ch=1):
        """
        Args:
            img_size: input image height = width (square inputs only).
            patch_size: size of each patch (e.g. 16 → 8×8 patches for 128×128).
            in_ch: number of input channels (1 for grayscale).
            embed_dim: token dimension inside the Transformer.
            depth: number of Transformer layers.
            num_heads: number of attention heads.
            mlp_ratio: expansion factor for the MLP inside each Transformer block.
            drop_rate, attn_drop_rate: dropout rates.
            decoder_channels: sequence of output channel counts, one per ×2 upsampling stage.
                The decoder produces a map of side grid_size * 2**len(decoder_channels);
                it is bilinearly resized to img_size when that differs.
            out_ch: number of output channels (1 for binary mask).
        """
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size             # e.g. 128//16 = 8 → 8×8 = 64 patches
        self.num_patches = self.grid_size * self.grid_size  # 64
        self.embed_dim = embed_dim

        # 1) Patch embedding
        self.patch_embed = PatchEmbedding(
            in_ch=in_ch,
            embed_dim=embed_dim,
            patch_size=patch_size,
            img_size=img_size
        )

        # 2) Position embeddings: one per patch
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # 3) Transformer encoder
        self.encoder = TransformerEncoder(
            num_layers=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate
        )

        # 4) Decoder: each stage maps (B, C_in, h, w) → (B, C_out, 2h, 2w)
        #    (ConvTranspose2d with kernel 2, stride 2, no padding doubles the size exactly).
        dec_blocks = []
        curr_ch = embed_dim
        for out_ch_dec in decoder_channels:
            dec_blocks.append(nn.Sequential(
                nn.ConvTranspose2d(curr_ch, out_ch_dec, kernel_size=2, stride=2),
                nn.BatchNorm2d(out_ch_dec),
                nn.ReLU(inplace=True),
            ))
            curr_ch = out_ch_dec
        self.decoder_blocks = nn.ModuleList(dec_blocks)

        # Final 1×1 conv to map to mask channel(s):
        self.head = nn.Conv2d(curr_ch, out_ch, kernel_size=1)

        # Re-initialize the patch-embedding projection and the output head
        # (decoder and encoder keep PyTorch's default initialization).
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal weights and zero biases for the patch projection and the 1×1 head."""
        nn.init.kaiming_normal_(self.patch_embed.proj.weight, mode="fan_out", nonlinearity="relu")
        if self.patch_embed.proj.bias is not None:
            nn.init.zeros_(self.patch_embed.proj.bias)
        nn.init.zeros_(self.head.bias)
        nn.init.kaiming_normal_(self.head.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        """
        Args:
            x: (B, in_ch, img_size, img_size), e.g. (B, 1, 128, 128).

        Returns:
            (B, out_ch, img_size, img_size) mask probabilities in [0, 1] (after sigmoid).
        """
        B = x.shape[0]

        # 1) Patch embedding → (B, num_patches, embed_dim), plus position embeddings
        #    (pos_embed is broadcast across the batch).
        x_tokens = self.patch_embed(x) + self.pos_embed

        # 2) Transformer encoder → (B, num_patches, embed_dim)
        x_enc = self.encoder(x_tokens)

        # 3) Tokens → (B, embed_dim, grid_size, grid_size); PatchEmbedding emits tokens
        #    row-major over the grid, so a plain reshape restores the spatial layout.
        H = W = self.grid_size
        x_grid = x_enc.transpose(1, 2).reshape(B, self.embed_dim, H, W)

        # 4) Decoder: successive ×2 upsampling stages
        for dec in self.decoder_blocks:
            x_grid = dec(x_grid)

        # 5) Resize to the input resolution only when the decoder did not already reach it
        #    (a same-size bilinear resize is an identity op).
        out_side = self.patch_size * self.grid_size   # == img_size
        if tuple(x_grid.shape[-2:]) != (out_side, out_side):
            x_grid = F.interpolate(x_grid, size=(out_side, out_side),
                                   mode="bilinear", align_corners=False)

        # 6) 1×1 conv → logits, then sigmoid for the binary mask
        mask_logits = self.head(x_grid)
        return torch.sigmoid(mask_logits)


In [None]:
class DiceLoss(nn.Module):
    """
    Soft Dice loss: 1 - (2·Σ p·t + smooth) / (Σ p + Σ t + smooth).

    The sums run over every element of the batch at once, so this is one global Dice
    score for the batch rather than a mean of per-sample scores.
    """
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # added to numerator and denominator to avoid 0/0

    def forward(self, preds, targets):
        """
        Args:
            preds:   Tensor (B,1,H,W) of probabilities (e.g. after Sigmoid).
            targets: Tensor with the same number of elements, binary {0,1}.

        Returns:
            Scalar tensor in [0, 1]; 0 means perfect overlap.
        """
        # reshape (not view) so non-contiguous inputs (transposed/permuted/sliced) also work.
        p_flat = preds.reshape(-1)
        t_flat = targets.reshape(-1)
        intersection = (p_flat * t_flat).sum()
        dice_coeff = (2. * intersection + self.smooth) / (p_flat.sum() + t_flat.sum() + self.smooth)
        return 1 - dice_coeff

class FocalTverskyLoss(nn.Module):
    """
    Focal Tversky loss: (1 - TI) ** gamma with
    TI = (TP + eps) / (TP + alpha·FN + beta·FP + eps) and beta = 1 - alpha.

    Larger alpha penalises false negatives more than false positives (favouring recall).
    Counts are soft: TP = Σ p·t, FP = Σ p·(1-t), FN = Σ (1-p)·t over the whole batch.
    """
    def __init__(self, alpha=0.9, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha, self.gamma, self.eps = alpha, gamma, eps
        self.beta = 1 - alpha  # Ensure alpha + beta = 1

    def forward(self, preds, targets):
        """
        Args:
            preds:   probabilities in [0, 1], any shape.
            targets: binary {0,1} tensor with the same number of elements.

        Returns:
            Scalar loss tensor.
        """
        # reshape (not view) so non-contiguous inputs are accepted.
        preds = preds.reshape(-1)
        targets = targets.reshape(-1)
        TP = (preds * targets).sum()
        FP = (preds * (1 - targets)).sum()
        FN = ((1 - preds) * targets).sum()
        tversky = (TP + self.eps) / (TP + self.alpha * FN + self.beta * FP + self.eps)
        # For valid inputs TI <= 1; clamping guards a fractional gamma against NaN from
        # tiny negative round-off in (1 - TI).
        return torch.pow((1 - tversky).clamp_min(0.0), self.gamma)

class ComboLossTF(nn.Module):
    """
    Weighted sum of three segmentation losses:

        loss = bce_weight * BCE + dice_weight * Dice + focal_twersky_weight * FocalTversky

    `preds` must already be probabilities in [0, 1]: nn.BCELoss is used, not BCEWithLogitsLoss.

    Args:
        bce_weight, dice_weight, focal_twersky_weight: mixing weights of the three terms.
        ft_alpha: FocalTverskyLoss alpha (weight on false negatives; beta = 1 - alpha).
        ft_gamma: FocalTverskyLoss focusing exponent.
    """
    def __init__(self, bce_weight=0.33, dice_weight=0.33, focal_twersky_weight=0.33,
                 ft_alpha=0.99, ft_gamma=3.1):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss(smooth=1e-6)
        self.FW = FocalTverskyLoss(alpha=ft_alpha, gamma=ft_gamma)
        self.bw, self.dw, self.fw = bce_weight, dice_weight, focal_twersky_weight

    def forward(self, preds, targets):
        """preds, targets: both (B,1,H,W); returns the weighted scalar loss."""
        l_bce = self.bce(preds, targets)
        l_dice = self.dice(preds, targets)
        l_focal_tversky = self.FW(preds, targets)
        return self.bw * l_bce + self.dw * l_dice + self.fw * l_focal_tversky

In [None]:
def sigzi(x, axis=None):
    """
    Robust (outlier-insensitive) estimate of the standard deviation of `x` from its IQR.

    Returns 0.741 * (Q75 - Q25). The factor 0.741 ≈ 1/1.349 is the ratio between the
    standard deviation and the interquartile range of a normal distribution.

    Args:
        x: array-like of any shape, e.g. (P, H, W), (H, W) or (N, C, H, W).
        axis: axis (or axes) along which to compute; None uses the flattened array.

    Returns:
        A NumPy scalar when axis is None, otherwise an array with `axis` reduced.
    """
    # One percentile call for both quartiles instead of two passes over the data.
    q25, q75 = np.percentile(x, [25, 75], axis=axis)
    return 0.741 * (q75 - q25)

def split_stack(arr, nrows, ncols):
    """
    Cut every 2D panel of a stack into non-overlapping tiles of shape (nrows, ncols).

    `nrows` / `ncols` are the tile height / width (not the number of tiles). Panels whose
    height or width is not a multiple of the tile size are zero-padded at the bottom/right
    first.

    Args:
        arr: ndarray of shape (P, H, W).
        nrows: tile height.
        ncols: tile width.

    Returns:
        ndarray of shape (P * ceil(H/nrows) * ceil(W/ncols), nrows, ncols). Tiles are
        ordered panel by panel, and row-major over the tile grid within each panel.
    """
    n_panels, height, width = arr.shape
    extra_h = (-height) % nrows
    extra_w = (-width) % ncols
    if extra_h or extra_w:
        arr = np.pad(arr, ((0, 0), (0, extra_h), (0, extra_w)),
                     mode='constant', constant_values=0)

    tiles_down = arr.shape[1] // nrows
    tiles_across = arr.shape[2] // ncols
    # (P, tiles_down, nrows, tiles_across, ncols) -> (P, tiles_down, tiles_across, nrows, ncols)
    grid = arr.reshape(n_panels, tiles_down, nrows, tiles_across, ncols)
    grid = grid.transpose(0, 1, 3, 2, 4)
    return grid.reshape(n_panels * tiles_down * tiles_across, nrows, ncols)

def build_datasets(npz_file, tile_size=128, clip_sigma=5.0):
    """
    Load an .npz archive, normalize the images, tile everything and return PyTorch tensors.

      - Reads arrays 'x' (images) and 'y' (masks), each of shape (P, H, W).
      - Divides x by its robust scale `sigzi(x)` (computed over the whole array) and
        clips the result to [-clip_sigma, clip_sigma].
      - Splits every panel into (tile_size × tile_size) tiles via `split_stack`
        (zero-padding panels whose sides are not multiples of tile_size).
      - Adds a channel dimension.

    Args:
        npz_file: path to an .npz file containing 'x' and 'y'.
        tile_size: side length of the square tiles.
        clip_sigma: symmetric clipping bound applied after normalization (default 5).

    Returns:
        (x_tiles, y_tiles): float32 tensors of shape (N, 1, tile_size, tile_size).

    Raises:
        ValueError: if the robust scale of x is zero or not finite (normalization would
            divide by zero / propagate NaN).
    """
    # The context manager closes the file handle held by the NpzFile.
    with np.load(npz_file) as data:
        x = data['x']  # shape (P, H, W)
        y = data['y']

    scale = sigzi(x)
    if not np.isfinite(scale) or scale <= 0:
        raise ValueError(f"build_datasets: robust scale of 'x' is {scale!r}; cannot normalize")
    x = np.clip(x / scale, -clip_sigma, clip_sigma)

    # Split into tiles (tile_size × tile_size)
    x_tiles = split_stack(x, tile_size, tile_size)  # (N_tiles, tile_size, tile_size)
    y_tiles = split_stack(y, tile_size, tile_size)

    # Convert to FloatTensor and add channel dimension → (N, 1, tile_size, tile_size)
    x_tiles = torch.from_numpy(x_tiles).float().unsqueeze(1)
    y_tiles = torch.from_numpy(y_tiles).float().unsqueeze(1)

    return x_tiles, y_tiles

def reshape_masks(masks, new_size):
    """
    Resize binary masks (0/1) to `new_size`.

    Uses bilinear interpolation (align_corners=False) followed by ``ceil``, so every output
    pixel that receives any positive interpolated value becomes 1. The result is
    therefore a slightly dilated mask, not a nearest-neighbour copy.

    Args:
        masks: a Tensor of shape (N, 1, H, W) or (N, H, W), of any dtype, or a numpy
            array of shape (N, H, W) or (N, 1, H, W).
        new_size: tuple (new_H, new_W).

    Returns:
        Tensor of shape (N, 1, new_H, new_W) with values in {0, 1}. numpy and non-float
        tensor inputs yield float32; floating tensors keep their dtype.
    """
    if isinstance(masks, np.ndarray):
        m = torch.from_numpy(masks).float()
    else:
        # Bilinear F.interpolate only supports floating dtypes (int/bool masks would error).
        m = masks if masks.is_floating_point() else masks.float()
    if m.ndim == 3:
        m = m.unsqueeze(1)  # (N, H, W) → (N, 1, H, W)
    m_resized = F.interpolate(m, size=new_size, mode='bilinear', align_corners=False)
    m_resized = torch.ceil(m_resized)
    # ceil of values slightly above 1 (float round-off) gives 2; clamp back to {0, 1}.
    return m_resized.clamp(0, 1)

def split_train_val(x_tiles, y_tiles, train_frac=0.8, seed=42):
    """
    Randomly partition paired tensors into a training and a validation TensorDataset.

    The permutation comes from a dedicated torch.Generator seeded with `seed`, so the
    split is reproducible and does not touch the global RNG. The first
    int(train_frac * N) permuted indices go to training, the rest to validation
    (80% / 20% by default).
    """
    total = x_tiles.shape[0]
    rng = torch.Generator().manual_seed(seed)
    order = torch.randperm(total, generator=rng)
    n_train = int(train_frac * total)
    train_sel, val_sel = order[:n_train], order[n_train:]
    train_set = TensorDataset(x_tiles[train_sel], y_tiles[train_sel])
    val_set = TensorDataset(x_tiles[val_sel], y_tiles[val_sel])
    return train_set, val_set

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

def make_balanced_sampler(train_ds: TensorDataset):
    """
    Build a sampler that draws tiles with and without foreground equally often.

    Given train_ds.tensors = (x_tensor, y_tensor), where
      - x_tensor:  (N, C, H, W) float
      - y_tensor:  (N, 1, H, W) or (N, H, W) masks
    a tile is "positive" when ANY mask pixel is > 0 and "negative" otherwise. Each tile
    gets weight 1 / (number of tiles in its class), so both classes carry the same total
    probability mass.

    Returns:
        WeightedRandomSampler that draws N indices per epoch with replacement. If the set
        contains only positives or only negatives, a warning is printed and the weights
        are uniform.

    Raises:
        ValueError: if the dataset is empty (a sampler with zero samples is invalid).

    Usage:
        sampler = make_balanced_sampler(train_ds)
        loader  = DataLoader(train_ds, batch_size=32, sampler=sampler, num_workers=4)
    """
    _, ys = train_ds.tensors

    # Ensure ys is shape (N,1,H,W)
    if ys.ndim == 3:
        ys = ys.unsqueeze(1)
    assert ys.ndim == 4 and ys.shape[1] == 1

    N = ys.shape[0]
    if N == 0:
        raise ValueError("make_balanced_sampler: the training set is empty")

    # flatten(1) copies when needed, so non-contiguous masks work (view() would raise).
    is_pos = (ys > 0).flatten(1).any(dim=1)   # (N,) bool
    n_pos = int(is_pos.sum().item())
    n_neg = N - n_pos

    if n_pos == 0 or n_neg == 0:
        # no balancing possible
        print("Warning: no positives or no negatives in the training set; sampler will be unbalanced.")
        weights = torch.ones(N)
    else:
        weights = torch.empty(N, dtype=torch.double)
        weights[is_pos] = 1.0 / n_pos
        weights[~is_pos] = 1.0 / n_neg

    # N draws with replacement per epoch: rare-class tiles are revisited more often.
    return WeightedRandomSampler(weights=weights, num_samples=N, replacement=True)


In [None]:
# Load the archive, normalize and tile it into 256×256 tiles (the model below uses
# img_size=256), then make a seeded 80/20 train/validation split.
npz_file = "../DATA/train.npz"
x_tiles, y_tiles = build_datasets(npz_file=npz_file, tile_size=256)
train_ds, val_ds = split_train_val(x_tiles=x_tiles, y_tiles=y_tiles, train_frac=0.8, seed=42)
# The split datasets hold index-copied tensors, so the full tile stacks can be released.
del x_tiles
del y_tiles

In [None]:
import torch
from torch.utils.data import DataLoader
import sklearn.metrics

# 1) Instantiate the model (no pretrained weights).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleSegTransformer(
    img_size=256,       # must match tile_size used in build_datasets
    patch_size=64,      # 256 / 64 → 4×4 = 16 patch tokens
    in_ch=1,
    embed_dim=32,
    depth=3,            # 3 Transformer layers
    num_heads=4,
    mlp_ratio=4.0,
    drop_rate=0.5,
    attn_drop_rate=0.5,
    decoder_channels=[256, 128, 64],   # three ×2 stages: 4 → 32, then bilinear resize to 256
    out_ch=1
).to(device)

# 2) Optimizer + scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# CosineAnnealingWarmRestarts is stepped with a fractional epoch inside the batch loop
# (as in its documentation), so T_0 / T_mult count epochs: restarts after 10, 30, 70 epochs.
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# 3) Loss function
criterion = ComboLossTF(bce_weight=0.25, dice_weight=0.25, focal_twersky_weight=0.5)

# 4) DataLoaders (to oversample tiles with foreground, pass
#    sampler=make_balanced_sampler(train_ds) instead of shuffle=True).
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

# Set to True to also report pixel-level ROC-AUC on the validation set. This keeps every
# validation prediction in host memory for the whole epoch, so it is off by default.
COMPUTE_AUC = False


def _confusion_counts(probs, masks, threshold=0.5):
    """Pixel-level (TP, FP, FN) as Python floats for probabilities vs. binary masks."""
    pred_bin = (probs > threshold).float()
    tp_ = (pred_bin * masks).sum().item()
    fp_ = (pred_bin * (1 - masks)).sum().item()
    fn_ = ((1 - pred_bin) * masks).sum().item()
    return tp_, fp_, fn_


def _prec_rec_f1(tp_, fp_, fn_, eps=1e-8):
    """Precision, recall and F1 from accumulated counts (eps avoids division by zero)."""
    prec_ = tp_ / (tp_ + fp_ + eps)
    rec_ = tp_ / (tp_ + fn_ + eps)
    f1_ = 2 * prec_ * rec_ / (prec_ + rec_ + eps)
    return prec_, rec_, f1_


# 5) Training loop
epochs = 100
num_batches = len(train_loader)
for epoch in range(1, epochs + 1):
    # ——— Training ———
    model.train()
    running_loss = 0.0
    tp = fp = fn = 0
    for batch_idx, (imgs, masks) in enumerate(train_loader):
        # pin_memory=True makes these host→device copies asynchronous.
        imgs = imgs.to(device, non_blocking=True)     # (B, 1, 256, 256)
        masks = masks.to(device, non_blocking=True)
        if masks.ndim == 3:                            # accept (B, H, W) masks too
            masks = masks.unsqueeze(1)

        preds = model(imgs)                            # probabilities, (B, 1, 256, 256)
        loss = criterion(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Fractional epoch (epoch is 1-based here) so the restart period is measured in epochs.
        sched.step(epoch - 1 + batch_idx / num_batches)

        running_loss += loss.item() * imgs.size(0)

        with torch.no_grad():
            b_tp, b_fp, b_fn = _confusion_counts(preds, masks)
        tp += b_tp
        fp += b_fp
        fn += b_fn

        prec, rec, f1 = _prec_rec_f1(tp, fp, fn)
        print(f"\rEpoch {epoch:03d}  Batch {batch_idx+1:03d}/{num_batches:03d}  "
              f"Batch Loss: {loss.item():.4f}  | Train F1: {f1:.4f}  | Prec: {prec:.4f}  | Rec: {rec:.4f}", end="")

    train_loss = running_loss / len(train_loader.dataset)

    # ——— Validation ———
    model.eval()
    val_loss = 0.0
    tp = fp = fn = 0
    all_masks = []
    all_preds = []
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            if masks.ndim == 3:
                masks = masks.unsqueeze(1)

            preds = model(imgs)
            loss = criterion(preds, masks)
            val_loss += loss.item() * imgs.size(0)

            b_tp, b_fp, b_fn = _confusion_counts(preds, masks)
            tp += b_tp
            fp += b_fp
            fn += b_fn

            if COMPUTE_AUC:
                all_masks.append(masks.cpu())
                all_preds.append(preds.cpu())

    val_loss = val_loss / len(val_loader.dataset)
    prec, rec, f1_val = _prec_rec_f1(tp, fp, fn)

    auc_msg = ""
    if COMPUTE_AUC:
        # roc_auc_score raises ValueError if the validation masks contain only one class.
        y_true = torch.cat(all_masks, dim=0).numpy().ravel()
        y_score = torch.cat(all_preds, dim=0).numpy().ravel()
        auc_val = sklearn.metrics.roc_auc_score(y_true, y_score)
        auc_msg = f"| Val AUC: {auc_val:.4f}"

    print(f"\rEpoch {epoch:03d}  "
          f"Train Loss: {train_loss:.4f}  | Val Loss: {val_loss:.4f}  "
          f"| Val F1: {f1_val:.4f}  | Val Prec: {prec:.4f}  | Val Rec: {rec:.4f} " + auc_msg)
