In [None]:
"""
fusion_pipeline.py

A three-stage medical image fusion pipeline for multi-modal medical imaging (e.g., CT, MRI, PET).

Stage 1: Hybrid CNN + Transformer Autoencoder (unsupervised per modality)
Stage 2: Ensemble Feature Fusion (concat + conv / cross-attention / gated-weighted)
Stage 3: Post-Fusion Reconstruction & Enhancement (U-Net style decoder + optional enhancement)

Designed for small datasets: provides per-modality autoencoder pretraining, augmentation,
mixed precision, and early stopping. Transfer learning via adapter layers and k-fold CV
are not implemented in this file and are left to the caller.

Author: ChatGPT (GPT-5 Thinking)
Date: 2025-08-23
License: MIT
"""
from __future__ import annotations
import math
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
# ==========================
# Utilities
# ==========================

def seed_all(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value handed to every random number generator.
    """
    # Same seeding order as before: stdlib, NumPy, torch CPU, torch CUDA.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def num_params(model: nn.Module) -> int:
    """Return the total number of scalar parameters in ``model`` (trainable or not)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:

# ==========================
# Positional Encoding (2D -> sequence)
# ==========================
class SinusoidalPositionalEncoding2D(nn.Module):
    """Fixed 2D sine/cosine positional encoding for a flattened feature map.

    The channel axis is split into four equal groups laid out as
    ``[sin(y), cos(y), sin(x), cos(x)]``. The result has shape ``[B, H*W, C]``
    and is meant to be added to a token sequence of the same shape.
    """
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        assert channels % 4 == 0, "channels must be divisible by 4"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the shape and device of x ([B, C, H, W]) are used, not its values.
        batch, _, height, width = x.shape
        dev = x.device

        # Integer row / column index of every spatial location, each [H, W].
        rows = torch.arange(height, device=dev)[:, None].expand(height, width)
        cols = torch.arange(width, device=dev)[None, :].expand(height, width)

        # One frequency divisor per channel of a quarter-group.
        idx = torch.arange(self.channels // 4, device=dev).float()
        divisor = 10000 ** (2 * (idx // 2) / (self.channels // 2))

        angle_y = rows.unsqueeze(-1) / divisor  # [H, W, C/4]
        angle_x = cols.unsqueeze(-1) / divisor  # [H, W, C/4]
        parts = (torch.sin(angle_y), torch.cos(angle_y), torch.sin(angle_x), torch.cos(angle_x))
        table = torch.cat(parts, dim=-1).view(height * width, self.channels)

        # Same table for every sample in the batch: [B, H*W, C].
        return table.unsqueeze(0).repeat(batch, 1, 1)

In [None]:
# ==========================
# Core building blocks
# ==========================
class ConvBNAct(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and an optional in-place ReLU.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        k: Kernel size.
        s: Stride.
        p: Padding.
        act: When False the activation is replaced by an identity.
    """
    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        if act:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        normed = self.bn(self.conv(x))
        return self.act(normed)


class ResidualBlock(nn.Module):
    """Two 3x3 ConvBNAct layers with an identity shortcut and a final ReLU.

    The channel count and spatial size are unchanged.
    """
    def __init__(self, ch):
        super().__init__()
        layers = [
            ConvBNAct(ch, ch, 3, 1, 1),
            # No activation here: the ReLU is applied after the shortcut is added.
            ConvBNAct(ch, ch, 3, 1, 1, act=False),
        ]
        self.block = nn.Sequential(*layers)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.block(x)
        return self.act(residual + x)


class PatchEmbed(nn.Module):
    """Project non-overlapping patches to ``embed_dim`` with a strided convolution.

    ``forward`` returns a 3-tuple: the projected map ``[B, D, H', W']``, the same
    values flattened to tokens ``[B, H'*W', D]``, and the grid size ``(H', W')``.
    """
    def __init__(self, in_ch: int, embed_dim: int, patch_size: int = 4):
        super().__init__()
        # kernel == stride, so each output location sees exactly one patch.
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        fmap = self.proj(x)
        _, _, grid_h, grid_w = fmap.shape
        # Collapse the spatial axes, then move channels last to get a sequence.
        tokens = fmap.flatten(2).transpose(1, 2)
        return fmap, tokens, (grid_h, grid_w)


class TransformerEncoder(nn.Module):
    """Stack of ``depth`` batch-first ``nn.TransformerEncoderLayer`` blocks with GELU.

    Args:
        embed_dim: Token width (``d_model``).
        depth: Number of encoder layers.
        nheads: Attention heads per layer.
        mlp_ratio: Feed-forward width as a multiple of ``embed_dim``.
        dropout: Dropout probability inside each layer.
    """
    def __init__(self, embed_dim: int, depth: int = 4, nheads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        layer_cfg = dict(
            d_model=embed_dim,
            nhead=nheads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(**layer_cfg), num_layers=depth)

    def forward(self, x_seq):
        # x_seq: [B, N, D] in, [B, N, D] out.
        encoded = self.encoder(x_seq)
        return encoded


class SimpleSelfAttention2D(nn.Module):
    """Lightweight channel attention (squeeze-and-excitation style) for 2D feature maps.

    Each channel is globally average-pooled to a single value, passed through a
    two-layer 1x1-conv bottleneck ending in a sigmoid, and the resulting
    per-channel gate in (0, 1) rescales the input. The gate has no spatial
    extent, so despite the class name this is channel attention rather than
    spatial attention.

    Args:
        ch: Number of input (and output) channels.
        reduction: Bottleneck ratio; the hidden width is ``ch // reduction``,
            clamped to at least 1 so small ``ch`` values remain usable.
    """
    def __init__(self, ch, reduction=8):
        super().__init__()
        # ch // reduction is 0 when ch < reduction, which would create an empty
        # bottleneck; keep at least one hidden channel.
        hidden = max(1, ch // reduction)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(ch, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, ch, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # w: [B, C, 1, 1] gate, broadcast over H and W.
        w = self.mlp(self.avg(x))
        return x * w


In [None]:
# ==========================
# Stage 1: Hybrid Encoder & Autoencoder
# ==========================
class HybridEncoder(nn.Module):
    """CNN stem followed by a transformer over the lowest-resolution feature map.

    ``forward`` returns a dict with the three CNN stage outputs (``'s1'`` at
    full resolution, ``'s2'`` after one stride-2 conv, ``'s3'`` after two) and
    ``'token'``, the transformer output reshaped back to a 2D map with
    ``embed_dim`` channels. ``out_channels`` maps the same keys to channel counts.
    """
    def __init__(self, in_ch: int = 1, base_ch: int = 32, embed_dim: int = 256, trans_depth: int = 4, heads: int = 8):
        super().__init__()
        c1, c2, c3 = base_ch, base_ch * 2, base_ch * 4

        # Convolutional stem: each `down*` layer halves the spatial size.
        self.s1 = nn.Sequential(ConvBNAct(in_ch, c1, 3, 1, 1), ResidualBlock(c1))
        self.down1 = ConvBNAct(c1, c2, 3, 2, 1)
        self.s2 = ResidualBlock(c2)
        self.down2 = ConvBNAct(c2, c3, 3, 2, 1)
        self.s3 = ResidualBlock(c3)

        # patch_size=1 turns every location of s3 into one token (no further downsampling).
        self.patch = PatchEmbed(c3, embed_dim, patch_size=1)
        self.pos = SinusoidalPositionalEncoding2D(embed_dim)
        self.trans = TransformerEncoder(embed_dim, depth=trans_depth, nheads=heads)

        self.out_channels = dict(s1=c1, s2=c2, s3=c3, token=embed_dim)

    def forward(self, x):
        stage1 = self.s1(x)
        stage2 = self.s2(self.down1(stage1))
        stage3 = self.s3(self.down2(stage2))

        fmap, seq, (grid_h, grid_w) = self.patch(stage3)
        encoded = self.trans(seq + self.pos(fmap))  # [B, H*W, D]

        # Channels first again, then unfold the token axis into (H, W).
        batch = x.size(0)
        token_map = encoded.transpose(1, 2).view(batch, -1, grid_h, grid_w)
        return dict(s1=stage1, s2=stage2, s3=stage3, token=token_map)


class HybridAutoencoder(nn.Module):
    """HybridEncoder plus a U-Net style decoder that reconstructs the input.

    ``forward`` returns ``(reconstruction, encoder_features)``. The decoder
    upsamples the token map twice and concatenates the encoder's ``'s2'`` and
    ``'s1'`` maps as skip connections; ``'s3'`` is not used by the decoder.
    """
    def __init__(self, in_ch=1, base_ch=32, embed_dim=256, trans_depth=4, heads=8):
        super().__init__()
        self.encoder = HybridEncoder(in_ch, base_ch, embed_dim, trans_depth, heads)
        widths = self.encoder.out_channels
        tok_ch, mid_ch, top_ch = widths['token'], widths['s2'], widths['s1']

        # Each stage: 2x transposed-conv upsample, then fuse with the skip
        # (hence the doubled input width of the first conv).
        self.up2 = nn.ConvTranspose2d(tok_ch, mid_ch, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(ConvBNAct(mid_ch + mid_ch, mid_ch), ResidualBlock(mid_ch))
        self.up1 = nn.ConvTranspose2d(mid_ch, top_ch, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(ConvBNAct(top_ch + top_ch, top_ch), ResidualBlock(top_ch))
        self.out = nn.Conv2d(top_ch, in_ch, kernel_size=1)

    def forward(self, x):
        feats = self.encoder(x)

        mid = self.up2(feats['token'])
        mid = self.dec2(torch.cat((mid, feats['s2']), dim=1))

        top = self.up1(mid)
        top = self.dec1(torch.cat((top, feats['s1']), dim=1))

        return self.out(top), feats

In [None]:

# ==========================
# Stage 2: Ensemble Feature Fusion
# ==========================
class GatedWeightedFusion(nn.Module):
    """Per-channel convex combination of several equally shaped feature maps.

    A learnable ``[n_inputs, ch]`` logit table is softmax-normalised across the
    input axis, so for every channel the input weights sum to one.
    """
    def __init__(self, ch: int, n_inputs: int = 2):
        super().__init__()
        # Small random logits -> near-uniform weights at initialisation.
        self.gates = nn.Parameter(torch.empty(n_inputs, ch).normal_(mean=0.0, std=0.02))

    def forward(self, features: List[torch.Tensor]):
        # Each entry of `features` is [B, C, H, W].
        assert len(features) == self.gates.shape[0]
        per_input = F.softmax(self.gates, dim=0)  # [n, C]
        return sum(
            feat * channel_w.view(1, -1, 1, 1)
            for feat, channel_w in zip(features, per_input)
        )


class CrossAttentionFusion(nn.Module):
    """Multi-head cross-attention where map ``a`` queries map ``b``.

    Queries come from ``a``; keys and values come from ``b``. Every spatial
    location is one token. The output is a 2D map with the same shape as ``a``.
    """
    def __init__(self, ch: int, heads: int = 8):
        super().__init__()
        self.heads = heads
        self.scale = (ch // heads) ** -0.5
        self.q_proj = nn.Conv2d(ch, ch, 1)
        self.k_proj = nn.Conv2d(ch, ch, 1)
        self.v_proj = nn.Conv2d(ch, ch, 1)
        self.out = nn.Conv2d(ch, ch, 1)

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        # a, b: [B, C, H, W]
        query = self.q_proj(a)
        key = self.k_proj(b)
        value = self.v_proj(b)

        bsz, c, h, w = query.shape
        head_dim, n_tok = c // self.heads, h * w
        # Split channels into heads and flatten space: [B, heads, head_dim, tokens].
        query, key, value = (t.view(bsz, self.heads, head_dim, n_tok) for t in (query, key, value))

        # scores[b, h, q, k]: similarity of query token q with key token k.
        scores = torch.einsum('bhcq,bhck->bhqk', query, key) * self.scale
        probs = F.softmax(scores, dim=-1)
        # Weighted sum of values per query token, kept channels-first.
        context = torch.einsum('bhqk,bhck->bhcq', probs, value)
        return self.out(context.reshape(bsz, c, h, w))


class FusionBlock(nn.Module):
    """Fuse several equally shaped feature maps with one selectable strategy.

    Args:
        ch_in: Channel count of every input map; all entries must be equal.
        mode: One of
            ``"concat_conv"`` - concatenate on channels, then conv + residual
            block + channel attention;
            ``"gated"`` - learned per-channel weighted sum, then a residual block;
            ``"cross_attn"`` - first input attends to the second (exactly two
            inputs), then residual block + channel attention.

    Raises:
        ValueError: If ``mode`` is not one of the names above.
    """
    def __init__(self, ch_in: List[int], mode: str = "concat_conv"):
        super().__init__()
        assert len(set(ch_in)) == 1, "All input feature maps must have same channel count"
        ch = ch_in[0]
        self.mode = mode
        if mode == "concat_conv":
            self.block = nn.Sequential(
                ConvBNAct(ch * len(ch_in), ch),
                ResidualBlock(ch),
                SimpleSelfAttention2D(ch)
            )
        elif mode == "gated":
            self.block = nn.Sequential(GatedWeightedFusion(ch, n_inputs=len(ch_in)), ResidualBlock(ch))
        elif mode == "cross_attn":
            assert len(ch_in) == 2, "cross_attn currently supports exactly two inputs"
            self.cross = CrossAttentionFusion(ch)
            self.post = nn.Sequential(ResidualBlock(ch), SimpleSelfAttention2D(ch))
        else:
            raise ValueError("Unknown fusion mode")

    def forward(self, feats: List[torch.Tensor]):
        if self.mode == "concat_conv":
            return self.block(torch.cat(feats, dim=1))
        if self.mode == "gated":
            # Previously only block[0] ran, so the ResidualBlock built in
            # __init__ was never applied. Run both stages: the gate consumes
            # the list, the residual block refines the fused map.
            gate, refine = self.block[0], self.block[1]
            return refine(gate(feats))
        if self.mode == "cross_attn":
            return self.post(self.cross(feats[0], feats[1]))
        # Only reachable if `mode` was changed after construction; fail loudly
        # rather than return None.
        raise ValueError("Unknown fusion mode")



In [None]:
# ==========================
# Stage 3: Decoder + Enhancement
# ==========================
class UnsharpMask(nn.Module):
    """Fixed unsharp-mask sharpening: ``x + amount * (x - box_blur(x))``.

    The box-blur kernel is stored in ``self.blur.weight`` (so existing
    checkpoints keep their ``blur.weight`` key) but is frozen: an optimiser or
    weight decay must not turn the blur into an arbitrary learned filter.
    The blur is applied to every channel independently, so inputs with any
    number of channels are accepted.

    Args:
        kernel_size: Side length of the square box filter; must be a positive
            odd integer so the reflection-padded output matches the input size.
        amount: Strength of the sharpening term.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, kernel_size=5, amount=0.5):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError("kernel_size must be a positive odd integer")
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.blur = nn.Conv2d(1, 1, kernel_size, padding=0, bias=False, groups=1)
        with torch.no_grad():
            # simple normalized blur kernel
            k = torch.ones(1, 1, kernel_size, kernel_size) / (kernel_size ** 2)
            self.blur.weight.copy_(k)
        # Keep the kernel constant during training.
        self.blur.weight.requires_grad_(False)
        self.amount = amount

    def forward(self, x):
        b, c, h, w = x.shape
        # Fold channels into the batch axis so the single-channel kernel blurs
        # each channel on its own (a no-op reshape when c == 1).
        flat = x.reshape(b * c, 1, h, w)
        blurred = self.blur(self.pad(flat)).reshape(b, c, h, w)
        return x + self.amount * (x - blurred)


class FusionDecoder(nn.Module):
    """U-Net style decoder that turns a fused token map into an image.

    Args:
        ch: Channel counts keyed by ``'token'``, ``'s2'`` and ``'s1'``.
        out_ch: Number of output image channels.

    ``forward`` upsamples ``fused_feats['token']`` twice, concatenating
    ``skip['s2']`` and ``skip['s1']`` on the way up, projects to ``out_ch``
    channels and finally applies a fixed-size unsharp mask.
    """
    def __init__(self, ch: Dict[str, int], out_ch: int = 1):
        super().__init__()
        tok_ch, mid_ch, top_ch = ch['token'], ch['s2'], ch['s1']

        self.up2 = nn.ConvTranspose2d(tok_ch, mid_ch, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(ConvBNAct(mid_ch + mid_ch, mid_ch), ResidualBlock(mid_ch))
        self.up1 = nn.ConvTranspose2d(mid_ch, top_ch, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(ConvBNAct(top_ch + top_ch, top_ch), ResidualBlock(top_ch))
        self.out = nn.Conv2d(top_ch, out_ch, kernel_size=1)
        self.enhance = UnsharpMask(kernel_size=7, amount=0.3)

    def forward(self, fused_feats: Dict[str, torch.Tensor], skip: Dict[str, torch.Tensor]):
        mid = self.up2(fused_feats['token'])
        mid = self.dec2(torch.cat((mid, skip['s2']), dim=1))

        top = self.up1(mid)
        top = self.dec1(torch.cat((top, skip['s1']), dim=1))

        # Project to image space, then sharpen.
        return self.enhance(self.out(top))


class FullFusionModel(nn.Module):
    """Encoders per modality -> fusion on deepest token map -> decoder to fused image.
    Optionally, auxiliary decoders reconstruct individual modalities to stabilize training.

    Args:
        modalities: Ordered modality names; defaults to ``["ct", "mri"]``. The
            first modality supplies the skip connections for the fused decoder.
        in_ch_map: Input channels per modality; defaults to 1 for each.
        base_ch, embed_dim, trans_depth, heads: Forwarded to every ``HybridEncoder``.
        fusion_mode: Forwarded to ``FusionBlock``.
        aux_decoders: When True, one extra ``FusionDecoder`` per modality is
            built and its output is returned from ``forward``.

    Raises:
        ValueError: If ``modalities`` is empty.
    """
    def __init__(self, modalities: Optional[List[str]] = None, in_ch_map: Optional[Dict[str, int]] = None,
                 base_ch=32, embed_dim=256, trans_depth=4, heads=8, fusion_mode="concat_conv", aux_decoders=True):
        super().__init__()
        # Copy so neither a shared default nor the caller's list is aliased.
        modalities = ["ct", "mri"] if modalities is None else list(modalities)
        if not modalities:
            raise ValueError("modalities must contain at least one entry")
        self.modalities = modalities
        if in_ch_map is None:
            in_ch_map = {m: 1 for m in modalities}
        self.encoders = nn.ModuleDict({m: HybridEncoder(in_ch_map[m], base_ch, embed_dim, trans_depth, heads) for m in modalities})
        # All encoders share base_ch / embed_dim, so any one describes the channel layout.
        ch = next(iter(self.encoders.values())).out_channels

        self.fusion = FusionBlock([ch['token'] for _ in modalities], mode=fusion_mode)
        self.decoder = FusionDecoder(ch, out_ch=1)
        self.aux_decoders = nn.ModuleDict() if aux_decoders else None
        if aux_decoders:
            for m in modalities:
                self.aux_decoders[m] = FusionDecoder(ch, out_ch=in_ch_map[m])

    def forward(self, inputs: Dict[str, torch.Tensor]):
        """Fuse one batch.

        Args:
            inputs: dict {modality: tensor [B, C, H, W]} with an entry for every
                configured modality.

        Returns:
            ``(fused_img, enc_feats, aux_outs)`` where ``enc_feats`` maps each
            modality to its encoder feature dict and ``aux_outs`` maps each
            modality to its auxiliary reconstruction (empty when auxiliary
            decoders are disabled).

        Raises:
            KeyError: If ``inputs`` lacks one of the configured modalities.
        """
        missing = [m for m in self.modalities if m not in inputs]
        if missing:
            raise KeyError(f"inputs is missing modalities: {missing}")

        enc_feats = {m: self.encoders[m](inputs[m]) for m in self.modalities}
        # collect deepest tokens for fusion
        tokens = [enc_feats[m]['token'] for m in self.modalities]
        fused_feats = {"token": self.fusion(tokens)}
        # use skip connections from first modality as spatial scaffold
        scaffold = enc_feats[self.modalities[0]]
        fused_img = self.decoder(fused_feats, skip=scaffold)

        aux_outs = {}
        if self.aux_decoders is not None:
            for m in self.modalities:
                # Same fused tokens, but each modality's own skip connections.
                aux_outs[m] = self.aux_decoders[m](fused_feats, skip=enc_feats[m])
        return fused_img, enc_feats, aux_outs

In [None]:
# ==========================
# Losses
# ==========================
class SSIMLoss(nn.Module):
    """``1 - mean SSIM`` computed with a uniform (box) window.

    The window weights live in ``self.avg.weight`` (so the ``avg.weight``
    checkpoint key is unchanged) but are frozen: a loss function should not
    carry trainable state or collect gradients for its averaging kernel.
    Channels are processed independently, so inputs may have any channel count.

    Args:
        window_size: Side length of the square averaging window.
        C1, C2: SSIM stabilisation constants (the defaults assume a data range of 1).
    """
    def __init__(self, window_size: int = 11, C1: float = 0.01**2, C2: float = 0.03**2):
        super().__init__()
        self.window_size = window_size
        self.C1 = C1
        self.C2 = C2
        self.pad = nn.ReflectionPad2d(window_size // 2)
        self.avg = nn.Conv2d(1, 1, window_size, bias=False, groups=1)
        with torch.no_grad():
            self.avg.weight.fill_(1.0 / (window_size ** 2))
        # Fixed window: exclude it from optimisation and gradient computation.
        self.avg.weight.requires_grad_(False)

    def _local_mean(self, t: torch.Tensor) -> torch.Tensor:
        """Box-filtered local mean of a ``[N, 1, H, W]`` tensor (same spatial size)."""
        return self.avg(self.pad(t))

    def forward(self, x, y):
        # Fold channels into the batch axis so the 1-channel window is applied
        # per channel: [B, C, H, W] -> [B*C, 1, H, W].
        x = x.flatten(0, 1).unsqueeze(1)
        y = y.flatten(0, 1).unsqueeze(1)

        mu_x = self._local_mean(x)
        mu_y = self._local_mean(y)
        # Local (co)variances via E[ab] - E[a]E[b].
        var_x = self._local_mean(x * x) - mu_x * mu_x
        var_y = self._local_mean(y * y) - mu_y * mu_y
        cov_xy = self._local_mean(x * y) - mu_x * mu_y

        numerator = (2 * mu_x * mu_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mu_x ** 2 + mu_y ** 2 + self.C1) * (var_x + var_y + self.C2)
        return 1 - (numerator / denominator).mean()


class EdgeLoss(nn.Module):
    """L1 distance between the Sobel gradient magnitudes of two single-channel images.

    The two 3x3 Sobel kernels are registered as buffers ``kx`` and ``ky``.
    """
    def __init__(self):
        super().__init__()
        sobel_x = [[1, 0, -1], [2, 0, -2], [1, 0, -1]]
        sobel_y = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]
        for name, taps in (('kx', sobel_x), ('ky', sobel_y)):
            self.register_buffer(name, torch.tensor(taps, dtype=torch.float32).view(1, 1, 3, 3))

    def _magnitude(self, img):
        """Sobel gradient magnitude of ``img`` ([B, 1, H, W]), zero-padded at the border."""
        along_kx = F.conv2d(img, self.kx, padding=1)
        along_ky = F.conv2d(img, self.ky, padding=1)
        # The epsilon keeps sqrt differentiable where the gradient is exactly zero.
        return torch.sqrt(along_kx**2 + along_ky**2 + 1e-6)

    def forward(self, x, y):
        edges_x = self._magnitude(x)
        edges_y = self._magnitude(y)
        return F.l1_loss(edges_x, edges_y)


@dataclass
class LossWeights:
    """Scalar multipliers for the individual terms summed by ``FusionLoss``."""
    # L1 distance between the fused image and the saliency-weighted pseudo-target.
    recon_l1: float = 1.0
    # SSIM loss between the fused image and the pseudo-target.
    recon_ssim: float = 1.0
    # Edge (Sobel magnitude) loss between the fused image and the pseudo-target.
    edge: float = 0.5
    # Mean auxiliary reconstruction loss (L1 + SSIM per modality).
    aux_recon: float = 0.5
    # Gradient-consistency term computed against the max input gradient magnitude.
    modality_consistency: float = 0.2


class FusionLoss(nn.Module):
    """No-reference fused supervision + reconstruction heads to each modality.
    - Encourage fused image to be structurally similar to max-information of inputs: use SSIM/L1 vs. per-modality inputs aggregated by saliency weights.
    - Auxiliary decoders reconstruct each input modality (stabilizes learning on small datasets).
    - Consistency: fused low-level gradients close to max of input gradients.

    Args:
        weights: Multipliers for the individual loss terms.

    ``forward`` returns ``(loss, metrics)`` where ``metrics`` holds the Python
    floats ``'l1'``, ``'ssim'`` (similarity, i.e. ``1 - ssim_loss``) and ``'edge'``.
    """
    def __init__(self, weights: LossWeights):
        super().__init__()
        self.w = weights
        self.ssim = SSIMLoss()
        self.edge = EdgeLoss()

    def _grad_magnitude(self, x: torch.Tensor) -> torch.Tensor:
        """Sobel gradient magnitude of a single-channel batch ``[B, 1, H, W]``.

        Reuses the float Sobel buffers of ``self.edge``. The previous inline
        kernels were built from integer literals (int64) and could not be
        convolved with floating-point images.
        """
        x = x.float()
        kx = self.edge.kx.to(device=x.device, dtype=x.dtype)
        ky = self.edge.ky.to(device=x.device, dtype=x.dtype)
        gx = F.conv2d(x, kx, padding=1)
        gy = F.conv2d(x, ky, padding=1)
        return torch.sqrt(gx**2 + gy**2 + 1e-6)

    def forward(self, fused: torch.Tensor, inputs: Dict[str, torch.Tensor], aux_outs: Dict[str, torch.Tensor]):
        modal_imgs = list(inputs.values())
        if not modal_imgs:
            raise ValueError("inputs must contain at least one modality")

        # Targets are constants: build them without tracking gradients.
        with torch.no_grad():
            # Gradient magnitude of each modality acts as a per-pixel saliency.
            mags = torch.stack([self._grad_magnitude(x) for x in modal_imgs], dim=0)  # [M, B, 1, H, W]
            # Pseudo-target: per-pixel softmax-weighted blend of the inputs.
            sal_w = torch.softmax(mags, dim=0)
            pseudo = (sal_w * torch.stack(modal_imgs, dim=0)).sum(dim=0)
            # Strongest edge response across modalities at every pixel.
            grad_max = mags.max(dim=0).values

        l1 = F.l1_loss(fused, pseudo)
        ssim = self.ssim(fused, pseudo)
        edge = self.edge(fused, pseudo)
        loss = self.w.recon_l1 * l1 + self.w.recon_ssim * ssim + self.w.edge * edge

        if aux_outs:
            aux_loss = 0
            for m, out in aux_outs.items():
                aux_loss = aux_loss + F.l1_loss(out, inputs[m]) + self.ssim(out, inputs[m])
            loss = loss + self.w.aux_recon * aux_loss / max(1, len(aux_outs))

        # Consistency: fused gradient close to elementwise max of input gradients.
        # grad_max is already a gradient magnitude, so compare it directly with
        # the fused image's magnitude. Passing it through EdgeLoss would take
        # the gradient of a gradient.
        grad_cons = F.l1_loss(self._grad_magnitude(fused), grad_max)
        loss = loss + self.w.modality_consistency * grad_cons
        metrics = {
            'l1': l1.item(), 'ssim': 1 - ssim.item(), 'edge': edge.item()
        }
        return loss, metrics




In [None]:
# ==========================
# Dataset skeleton (expects pre-registered pairs)
# ==========================
class PairedFusionDataset(Dataset):
    """Dataset of co-registered multi-modal image tuples.

    Each item is a dict ``{modality: tensor [1, H, W]}`` with every image
    min-max normalised to ``[0, 1]`` independently.
    """
    def __init__(self, pairs: List[Dict[str, str]], modalities: List[str], transform=None):
        """
        pairs: list of dicts mapping modality -> file path per sample, e.g.,
            {'ct': 'path/ct.png', 'mri': 'path/mri.png'}
        modalities: keys to read from every entry of ``pairs``.
        transform: optional callable applied to each modality tensor. Its
            randomness must come from the torch RNG for the modalities to
            receive identical parameters.
        Images must be co-registered and same size. Implement your own loader.
        """
        super().__init__()
        self.pairs = pairs
        self.modalities = modalities
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def _load_img(self, path: str) -> np.ndarray:
        # Placeholder: user should replace with actual medical image loader (nibabel for NIfTI, SimpleITK/DICOM, etc.)
        from PIL import Image
        img = Image.open(path).convert('L')
        return np.array(img, dtype=np.float32)

    def __getitem__(self, idx):
        item = self.pairs[idx]
        data = {}
        for m in self.modalities:
            arr = self._load_img(item[m])
            arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-6)  # min-max
            data[m] = torch.from_numpy(arr)[None, ...]  # [1,H,W]
        if self.transform:
            # Draw the seed from torch's RNG, which DataLoader seeds differently
            # in every worker; NumPy's global RNG may be duplicated across workers.
            seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
            for m in self.modalities:
                # Re-seed before EVERY modality so each one gets the same random
                # parameters. Seeding once before the loop let the RNG advance
                # and misaligned the modalities. fork_rng restores the CPU RNG
                # state afterwards, so the rest of the program is unaffected.
                with torch.random.fork_rng(devices=[]):
                    torch.manual_seed(seed)
                    data[m] = self.transform(data[m])
        return data

In [None]:
# ==========================
# Training routines
# ==========================
@dataclass
class TrainConfig:
    """Hyper-parameters consumed by ``train_fusion``."""
    # Modality names; selects batch entries and configures the fusion model.
    modalities: List[str]
    # NOTE(review): not read by train_fusion in this file; loaders are built by the caller.
    batch_size: int = 4
    # AdamW learning rate.
    lr: float = 1e-4
    # NOTE(review): not read anywhere in this file; presumably meant for stage-1 pretraining.
    epochs_stage1: int = 30
    # Maximum number of epochs for the fusion training loop.
    epochs_stage23: int = 60
    # Passed to FullFusionModel as its fusion_mode.
    fusion_mode: str = "concat_conv"
    # Enables autocast and the gradient scaler.
    amp: bool = True
    # Number of batches per optimizer step; the loss is divided by this value.
    grad_accum: int = 1
    # AdamW weight decay.
    weight_decay: float = 1e-4


def pretrain_autoencoders(dataloader: DataLoader, modalities: List[str], device: str = 'cuda',
                          epochs: int = 10, lr: float = 1e-4, weight_decay: float = 1e-4) -> Dict[str, HybridAutoencoder]:
    """Stage 1: train one reconstruction autoencoder per modality.

    Every modality gets its own ``HybridAutoencoder`` and AdamW optimizer and
    is trained on ``0.5 * L1 + 0.5 * SSIM`` between reconstruction and input.

    Args:
        dataloader: Yields dicts ``{modality: tensor}``.
        modalities: Modality keys to train on.
        device: Device for models, loss and data.
        epochs: Number of passes over ``dataloader`` (previously fixed at 10).
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        Dict mapping each modality to the ``encoder`` sub-module of its trained
        autoencoder.
    """
    # Mixed precision only makes sense on CUDA; on CPU run in full precision.
    use_amp = str(device).startswith('cuda') and torch.cuda.is_available()

    models = {m: HybridAutoencoder(in_ch=1).to(device) for m in modalities}
    opts = {m: torch.optim.AdamW(models[m].parameters(), lr=lr, weight_decay=weight_decay) for m in modalities}
    # One scaler per optimizer: fp16 gradients underflow without loss scaling.
    scalers = {m: torch.cuda.amp.GradScaler(enabled=use_amp) for m in modalities}
    # The loss owns a conv kernel, so it must sit on the same device as the data.
    ssim = SSIMLoss().to(device)

    for epoch in range(epochs):
        for batch in dataloader:
            for m in modalities:
                x = batch[m].to(device)
                opts[m].zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    y, _ = models[m](x)
                    loss = 0.5 * F.l1_loss(y, x) + 0.5 * ssim(y, x)
                scalers[m].scale(loss).backward()
                scalers[m].step(opts[m])
                scalers[m].update()
        # (Optional) add validation & early stopping
    encoders = {m: models[m].encoder for m in modalities}
    return encoders


def train_fusion(train_loader: DataLoader, val_loader: Optional[DataLoader], cfg: TrainConfig, device: str = 'cuda') -> nn.Module:
    """Stages 2-3: train the full fusion model with optional early stopping.

    Args:
        train_loader: Yields dicts ``{modality: tensor}`` for training.
        val_loader: Optional validation loader. When given, the best model (by
            mean validation loss) is checkpointed to ``best_fusion.pt`` and
            training stops after 10 epochs without improvement.
        cfg: Training hyper-parameters.
        device: Device for model, loss and data.

    Returns:
        The trained model, restored to the best checkpoint written during this
        call when one exists.
    """
    ckpt_path = 'best_fusion.pt'
    # Mixed precision is only meaningful on CUDA.
    use_amp = bool(cfg.amp) and str(device).startswith('cuda') and torch.cuda.is_available()
    accum = max(1, cfg.grad_accum)

    model = FullFusionModel(modalities=cfg.modalities, fusion_mode=cfg.fusion_mode).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    # The criterion holds conv kernels/buffers and must live on the data's device.
    crit = FusionLoss(LossWeights()).to(device)

    best_val = float('inf')
    patience, patience_left = 10, 10
    saved_ckpt = False  # only reload a checkpoint written by THIS run

    for epoch in range(cfg.epochs_stage23):
        model.train()
        running, n_batches = 0.0, 0
        pending = False  # True while accumulated gradients await an optimizer step
        opt.zero_grad(set_to_none=True)
        for i, batch in enumerate(train_loader):
            inputs = {m: batch[m].to(device) for m in cfg.modalities}
            with torch.cuda.amp.autocast(enabled=use_amp):
                fused, _, aux = model(inputs)
                loss, metrics = crit(fused, inputs, aux)
            scaler.scale(loss / accum).backward()
            pending = True
            if (i + 1) % accum == 0:
                scaler.step(opt)
                scaler.update()
                # Clear gradients only AFTER the step; clearing every iteration
                # would discard the accumulated micro-batches.
                opt.zero_grad(set_to_none=True)
                pending = False
            running += loss.item()
            n_batches += 1
        if pending:
            # Flush the final, incomplete accumulation window of the epoch.
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
        train_loss = running / max(1, n_batches)

        # Validation (proxy: loss vs pseudo-target)
        val_loss = 0
        stop_early = False
        if val_loader is not None:
            model.eval()
            with torch.no_grad():
                for batch in val_loader:
                    inputs = {m: batch[m].to(device) for m in cfg.modalities}
                    fused, _, aux = model(inputs)
                    loss, _ = crit(fused, inputs, aux)
                    val_loss += loss.item()
            val_loss /= max(1, len(val_loader))
            if val_loss < best_val:
                best_val = val_loss
                patience_left = patience
                torch.save(model.state_dict(), ckpt_path)
                saved_ckpt = True
            else:
                patience_left -= 1
                stop_early = patience_left == 0
        print(f"Epoch {epoch+1}: train loss {train_loss:.4f} val {val_loss:.4f}")
        if stop_early:
            break

    # Do not pick up a stale checkpoint left behind by an earlier run.
    if saved_ckpt:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    return model

In [None]:
# ==========================
# Evaluation metrics (no-reference & reference-to-input)
# ==========================
@torch.no_grad()
def entropy(img: torch.Tensor) -> float:
    """Shannon entropy (natural log) of the 256-bin intensity histogram of ``img``.

    The histogram spans [0, 1] and is pooled over the whole tensor.
    """
    counts = torch.histc(img, bins=256, min=0.0, max=1.0)
    probs = counts / counts.sum()
    # Drop empty bins so log() is only evaluated on positive probabilities.
    occupied = torch.masked_select(probs, probs > 0)
    shannon = -torch.sum(occupied * torch.log(occupied))
    return float(shannon.cpu())


@torch.no_grad()
def spatial_frequency(img: torch.Tensor) -> float:
    """Spatial frequency ``sqrt(RF^2 + CF^2)`` of an image batch ``[B, C, H, W]``.

    ``RF^2`` / ``CF^2`` are the sums of squared horizontal / vertical
    neighbour differences divided by the total number of pixels. Only
    differences between real neighbouring pixels are used; the image is not
    zero-padded, so a constant image has spatial frequency 0.
    """
    img = img.float()
    n_pix = max(1, img.numel())
    # Differences between adjacent columns / rows (empty if that axis has length 1).
    dx = img[..., :, 1:] - img[..., :, :-1]
    dy = img[..., 1:, :] - img[..., :-1, :]
    row_sq = (dx**2).sum() / n_pix
    col_sq = (dy**2).sum() / n_pix
    return float(torch.sqrt(row_sq + col_sq).cpu())


@torch.no_grad()
def mutual_information(img: torch.Tensor, ref: torch.Tensor) -> float:
    """Mean per-sample mutual information (in nats) between ``img`` and ``ref``.

    Both tensors are clamped to [0, 1]. For every sample a 64x64 joint
    histogram gives the joint distribution ``pxy``, and
    ``MI = sum pxy * log(pxy / (px * py))`` is taken over the non-empty cells.

    Returns:
        The MI averaged over the batch, or 0.0 for an empty batch.
    """
    mi_vals = []
    for i in range(img.shape[0]):
        x = img[i].clamp(0, 1).flatten().float().cpu().numpy()
        y = ref[i].clamp(0, 1).flatten().float().cpu().numpy()
        h2d, _, _ = np.histogram2d(x, y, bins=64, range=[[0, 1], [0, 1]])
        total = h2d.sum()
        if total == 0:
            mi_vals.append(0.0)
            continue
        pxy = h2d / total
        px = pxy.sum(axis=1, keepdims=True)  # [64, 1] marginal of img
        py = pxy.sum(axis=0, keepdims=True)  # [1, 64] marginal of ref
        # Independent-case joint px*py at every cell, indexed with the SAME
        # mask as pxy so each joint probability meets its own marginals.
        indep = px * py
        nz = pxy > 0
        mi = float((pxy[nz] * np.log(pxy[nz] / indep[nz])).sum())
        mi_vals.append(mi)
    return float(np.mean(mi_vals)) if mi_vals else 0.0


@torch.no_grad()
def evaluate_batch(fused: torch.Tensor, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Compute no-reference and reference-to-input metrics for one fused batch.

    Args:
        fused: Fused images.
        inputs: Source images per modality.

    Returns:
        Dict with ``'entropy'`` and ``'spatial_freq'`` of the fused image plus
        ``'ssim_<m>'`` and ``'mi_<m>'`` for every modality ``m``.
    """
    # Work in float32: the SSIM kernel is float32 and fused may be half
    # precision when it was produced under autocast.
    fused = fused.float()
    metrics = {
        'entropy': entropy(fused),
        'spatial_freq': spatial_frequency(fused),
    }
    # reference-to-input SSIM/MI per modality
    # SSIMLoss owns a conv kernel; put it on the same device as the data.
    ssim = SSIMLoss().to(fused.device)
    for m, x in inputs.items():
        x = x.to(device=fused.device, dtype=fused.dtype)
        metrics[f'ssim_{m}'] = float(1 - ssim(fused, x).item())
        metrics[f'mi_{m}'] = mutual_information(fused, x)
    return metrics

In [None]:




# ==========================
# Example usage (pseudo CLI)
# ==========================
if __name__ == "__main__":
    seed_all(123)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Example: replace with your dataset listing
    pairs = [
        # {'ct': 'samples/ct_000.png', 'mri': 'samples/mri_000.png'},
    ]
    modalities = ['ct', 'mri']

    common_tf = transforms.Compose([
        # Intensity & geometric aug for small data
        transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(size=(256, 256), scale=(0.9, 1.0)),
    ])

    ds = PairedFusionDataset(pairs, modalities, transform=common_tf)
    # Branch instead of calling exit(): in a notebook, exit() shuts down the kernel.
    if len(ds) == 0:
        print("Populate 'pairs' with your image paths to run training.")
    elif len(ds) < 2:
        # A single sample cannot be split into non-empty train and validation sets.
        print("Need at least two samples to create a train/validation split.")
    else:
        cfg = TrainConfig(modalities=modalities, batch_size=2, fusion_mode="concat_conv", epochs_stage23=50)

        n_val = max(1, int(0.1 * len(ds)))
        train_ds, val_ds = torch.utils.data.random_split(ds, [len(ds) - n_val, n_val])

        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_ds, batch_size=cfg.batch_size)

        model = train_fusion(train_loader, val_loader, cfg, device=device)

        # Inference on one batch
        model.eval()  # make sure BatchNorm/Dropout run in inference mode
        batch = next(iter(val_loader))
        with torch.no_grad():
            fused, _, _ = model({m: batch[m].to(device) for m in modalities})
        print("Model params:", num_params(model)/1e6, "M")
