In [2]:
import os
import time
import random
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import h5py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import OneCycleLR

from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchvision.models import vgg16, VGG16_Weights

In [3]:
# ────────────────────────────────────────────────────────────────────────────────
# CONFIGURATION
# ────────────────────────────────────────────────────────────────────────────────
class Config:
    """Central hyper-parameters and environment settings for the notebook.

    Everything is a class attribute so other cells read it as ``Config.NAME``
    without instantiating the class.
    """

    # Dataset root. The LODOPAB_DATA_PATH environment variable overrides the
    # default so the notebook can run on another machine without editing code.
    DATA_PATH    = os.environ.get(
        "LODOPAB_DATA_PATH",
        "/Users/imamahasan/MyData/Code/lightcyclegan_ld_v3",  # <-- your root
    )
    SAVE_DIR     = "./checkpoints"
    BATCH_SIZE   = 4
    EPOCHS       = 100
    LR           = 2e-4
    WEIGHT_DECAY = 1e-4
    PATIENCE     = 10   # epochs without improvement before early stopping

    # Loss weights
    L_ADV   = 1.0
    L_PERP  = 1.0
    L_SSIM  = 5.0
    L_CYCLE = 10.0
    L_EDGE  = 0.2
    L_FFL   = 0.5

    # Dimensions (original LoDoPaB sizes)
    SINO_SHAPE = (1000, 513)
    IMG_SHAPE  = (362, 362)

    DEVICE      = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # The notebook imports autocast/GradScaler from torch.cuda.amp, which only
    # take effect on CUDA. Enabling AMP on CPU or MPS merely produces warnings
    # and a misleading "AMP: True" log line, so tie the flag to the device type.
    USE_AMP     = DEVICE.type == "cuda"
    MAX_LR      = LR   # peak learning rate handed to the LR scheduler
    NUM_WORKERS = 0

# Make sure the checkpoint directory exists before anything tries to write to it.
os.makedirs(Config.SAVE_DIR, exist_ok=True)

In [4]:
# ────────────────────────────────────────────────────────────────────────────────
# DATASET
# ────────────────────────────────────────────────────────────────────────────────
class LoDoPaBSinogramDataset(Dataset):
    """Paired (sinogram, ground-truth image) slices stored in ``.hdf5`` files.

    The i-th sinogram file (sorted by path) is paired with the i-th ground-truth
    file, and every slice index present in both files' ``'data'`` dataset
    becomes one sample.

    Args:
        sino_dir: directory containing the sinogram ``.hdf5`` files.
        gt_dir:   directory containing the ground-truth ``.hdf5`` files.

    Each item is ``(sino_t, img_t)``: float32 tensors of shape ``[1, H, W]``,
    individually min-max scaled to ``[0, 1]`` and resized to
    ``Config.SINO_SHAPE`` / ``Config.IMG_SHAPE`` when the stored shape differs.
    """

    def __init__(self, sino_dir, gt_dir):
        self.sino_files = sorted(
            os.path.join(sino_dir, f)
            for f in os.listdir(sino_dir) if f.endswith(".hdf5")
        )
        self.gt_files = sorted(
            os.path.join(gt_dir, f)
            for f in os.listdir(gt_dir) if f.endswith(".hdf5")
        )
        assert self.sino_files, f"No sinogram files in {sino_dir}"
        assert self.gt_files,   f"No GT files in {gt_dir}"
        assert len(self.sino_files) == len(self.gt_files), "File count mismatch"

        # Flat index of (file index, slice index) pairs over all file pairs.
        self.indices = []
        for idx, (sf, gf) in enumerate(zip(self.sino_files, self.gt_files)):
            with h5py.File(sf, 'r') as fs, h5py.File(gf, 'r') as fg:
                n = min(len(fs['data']), len(fg['data']))
            assert n > 0, f"No slices in {sf}"
            self.indices += [(idx, i) for i in range(n)]
        print(f"Loaded {len(self.indices)} slices from {len(self.sino_files)} files")

        # Resizers operate directly on float tensors. The previous pipeline went
        # through an 8-bit PIL image, which quantised every slice to 256 levels.
        self.tf_sino = T.Resize(Config.SINO_SHAPE, antialias=True)
        self.tf_img  = T.Resize(Config.IMG_SHAPE, antialias=True)

    def __len__(self):
        return len(self.indices)

    @staticmethod
    def _to_unit_tensor(arr, resize, size):
        """Min-max scale a 2-D array to [0, 1] and return a float32 [1, H, W] tensor."""
        arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
        # Adds the single channel dimension that ToTensor used to provide.
        tensor = torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32)).unsqueeze(0)
        # Only interpolate when the stored slice is not already at the target size.
        if tuple(tensor.shape[-2:]) != tuple(size):
            tensor = resize(tensor)
        return tensor

    def __getitem__(self, idx):
        file_idx, slice_idx = self.indices[idx]
        sf, gf = self.sino_files[file_idx], self.gt_files[file_idx]
        with h5py.File(sf, 'r') as fs:
            sino = fs['data'][slice_idx].astype(np.float32)
        with h5py.File(gf, 'r') as fg:
            img  = fg['data'][slice_idx].astype(np.float32)

        sino_t = self._to_unit_tensor(sino, self.tf_sino, Config.SINO_SHAPE)
        img_t  = self._to_unit_tensor(img,  self.tf_img,  Config.IMG_SHAPE)

        # Both tensors are already [1, H, W]; the DataLoader adds the batch
        # dimension, so **do not** unsqueeze any further here.
        return sino_t, img_t

In [5]:
# ────────────────────────────────────────────────────────────────────────────────
# MODEL COMPONENTS
# ────────────────────────────────────────────────────────────────────────────────
class GhostModule(nn.Module):
    """Ghost-style convolution: a primary conv plus cheap depth-wise branches.

    ``primary`` produces ``ceil(oup / ratio)`` channels. The depth-wise
    convolutions in ``dw_convs`` are applied to that output and their results
    are concatenated, in order, until ``oup`` channels are available. The
    result goes through batch normalisation.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        ratio: ``oup`` is roughly ``ratio`` times the primary channel count.
        primary_kernel: kernel size of the primary convolution.
        dw_kernels: kernel sizes of the depth-wise branches.

    Raises:
        ValueError: if the depth-wise branches cannot supply enough channels
            to reach ``oup`` (such a module would fail in ``forward``).
    """

    def __init__(self, inp, oup, ratio=2, primary_kernel=1, dw_kernels=(3,5)):
        super().__init__()
        dw_kernels = tuple(dw_kernels)
        self.init_c  = math.ceil(oup/ratio)
        self.cheap_c = self.init_c*(ratio-1)

        # Fail fast instead of hitting a BatchNorm channel mismatch in forward().
        needed = min(self.cheap_c, oup - self.init_c)
        if len(dw_kernels) * self.init_c < needed:
            raise ValueError(
                f"dw_kernels provides {len(dw_kernels) * self.init_c} cheap channels "
                f"but {needed} are required for oup={oup}, ratio={ratio}"
            )

        self.primary = nn.Conv2d(inp, self.init_c, primary_kernel,
                                 padding=primary_kernel//2, bias=False)
        self.dw_convs = nn.ModuleList([
            nn.Conv2d(self.init_c, self.init_c, k,
                      padding=k//2, groups=self.init_c, bias=False)
            for k in dw_kernels
        ])
        self.bn = nn.BatchNorm2d(oup)

    def forward(self, x):
        x1 = self.primary(x)
        parts = [x1]
        # Channels still missing after the primary conv. Branches are evaluated
        # only while some are missing, so a branch whose output would be sliced
        # away (the 5x5 one with the default arguments) is never computed.
        remaining = min(self.cheap_c, self.bn.num_features - self.init_c)
        for dw in self.dw_convs:
            if remaining <= 0:
                break
            cheap = dw(x1)[:, :remaining]
            parts.append(cheap)
            remaining -= cheap.size(1)
        return self.bn(torch.cat(parts, dim=1))

class CondConv(nn.Module):
    """Input-conditioned mixture of ``exp`` parallel 3x3 convolutions.

    A small gating head (global average pool -> 1x1 conv -> softmax over the
    expert axis) yields one weight per expert and per sample; the module
    returns the weighted sum of the expert outputs.
    """

    def __init__(self, inp, oup, exp=2):
        super().__init__()
        expert_layers = []
        for _ in range(exp):
            expert_layers.append(nn.Conv2d(inp, oup, 3, padding=1, bias=False))
        self.experts = nn.ModuleList(expert_layers)

        gating_head = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(inp, exp, 1),
            nn.Softmax(dim=1),
        ]
        self.attn = nn.Sequential(*gating_head)

    def forward(self, x):
        gate = self.attn(x)  # one softmax weight per expert, shape [B, exp, 1, 1]
        mixed = 0
        for idx, expert in enumerate(self.experts):
            mixed = mixed + gate[:, idx:idx + 1] * expert(x)
        return mixed

class ECABlock(nn.Module):
    """Channel attention computed by a 1-D convolution across channel means.

    The kernel size is derived from the channel count ``c`` as
    ``int(|(log2(c) + b) / gamma|)``, bumped to the next odd number if even.
    """

    def __init__(self, c, gamma=2, b=1):
        super().__init__()
        raw_size = int(abs((math.log2(c) + b) / gamma))
        kernel = raw_size + 1 if raw_size % 2 == 0 else raw_size
        self.conv1d = nn.Conv1d(1, 1, kernel, padding=kernel // 2, bias=False)

    def forward(self, x):
        # Per-channel spatial mean, laid out as a length-C sequence: [B, 1, C].
        descriptor = x.mean(dim=(2, 3))[:, None, :]
        gate = torch.sigmoid(self.conv1d(descriptor))
        # Back to [B, C, 1, 1] so the gate broadcasts over the spatial dims.
        gate = gate.squeeze(1)[:, :, None, None]
        return x * gate

class MixStyle(nn.Module):
    """Randomly mixes per-sample feature statistics within a batch.

    In training mode, with probability ``p``, each sample's per-channel mean
    and standard deviation are blended with those of another randomly chosen
    sample in the batch, using weights drawn from ``Beta(alpha, alpha)``. In
    eval mode, or when the random draw skips mixing, the input is returned
    unchanged.

    Args:
        p: probability of applying the mixing on a forward pass.
        alpha: concentration parameter of the Beta distribution.
        eps: small constant added to the variance before the square root.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        super().__init__()
        self.p = p
        self.eps = eps
        self.beta = torch.distributions.Beta(alpha, alpha)

    def forward(self, x):
        if not self.training or random.random() > self.p:
            return x
        B = x.size(0)
        mu = x.mean([2,3], keepdim=True)
        # eps goes inside the square root: sqrt(var) has an infinite derivative
        # at var == 0, so a constant feature map would otherwise yield NaN grads.
        sig = (x.var([2,3], keepdim=True) + self.eps).sqrt()
        xn = (x - mu)/sig
        # Statistics of a randomly chosen partner sample for every batch item.
        perm = torch.randperm(B)
        mu2, sig2 = mu[perm], sig[perm]
        lam = self.beta.sample((B,1,1,1)).to(x.device)
        mu_m  = mu*lam + mu2*(1-lam)
        sig_m = sig*lam + sig2*(1-lam)
        return xn*sig_m + mu_m

class SinogramEncoder(nn.Module):
    """Convolutional stack mapping a 1-channel input to a 1-channel map of size ``Config.IMG_SHAPE``.

    Five 3x3 conv + ReLU stages (two of them with stride 2) are followed by a
    3x3 conv down to one channel and a bilinear upsample to the target size.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) of each 3x3 conv followed by a ReLU.
        conv_relu_specs = [
            (1, 32, 1),
            (32, 64, 2),
            (64, 128, 2),
            (128, 64, 1),
            (64, 32, 1),
        ]
        layers = []
        for c_in, c_out, stride in conv_relu_specs:
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(32, 1, 3, padding=1))
        # Upsample back to target image size
        layers.append(nn.Upsample(size=Config.IMG_SHAPE, mode='bilinear', align_corners=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Generator(nn.Module):
    """Encoder-decoder generator with additive skip connections.

    Two CondConv + ECA stages, each followed by average pooling, feed a pair
    of GhostModule blocks and an ECA bridge. Two upsample + GhostModule stages
    then add the (cropped) encoder features back in. The result passes through
    a 3x3 conv and tanh, and is resized to ``Config.IMG_SHAPE``.
    """

    def __init__(self):
        super().__init__()
        self.ms     = MixStyle()
        self.init   = GhostModule(1, 16)
        self.d1     = CondConv(16, 32, exp=2)
        self.a1     = ECABlock(32)
        self.d2     = CondConv(32, 64, exp=3)
        self.a2     = ECABlock(64)
        self.pool   = nn.AvgPool2d(2)
        self.res    = nn.Sequential(GhostModule(64, 64), GhostModule(64, 64))
        self.bridge = ECABlock(64)
        self.u1     = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            GhostModule(64, 64),
        )
        self.u2     = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            GhostModule(64, 32),
        )
        self.final  = nn.Conv2d(32, 1, 3, padding=1)
        self.output_resize = nn.Upsample(size=Config.IMG_SHAPE, mode='bilinear',
                                         align_corners=False)
        self.act    = nn.ReLU(inplace=True)

    def _crop_to_match(self, tensor1, tensor2):
        """Return tensor1 cropped (from the top-left) to the smaller of the two spatial sizes."""
        height1, width1 = tensor1.shape[2:]
        height2, width2 = tensor2.shape[2:]
        return tensor1[:, :, :min(height1, height2), :min(width1, width2)]

    def forward(self, x):
        x = self.ms(x)

        # Encoder path
        enc0 = self.act(self.init(x))
        enc1 = self.act(self.a1(self.d1(enc0)))
        enc2 = self.act(self.a2(self.d2(self.pool(enc1))))

        # Bottleneck
        bottleneck = self.bridge(self.res(self.pool(enc2)))

        # Decoder path: the skip tensors are cropped to the upsampled size
        # before the addition.
        up1  = self.u1(bottleneck)
        dec1 = self.act(up1 + self._crop_to_match(enc2, up1))
        up2  = self.u2(dec1)
        dec2 = self.act(up2 + self._crop_to_match(enc1, up2))

        return self.output_resize(torch.tanh(self.final(dec2)))

class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a single-channel logit map.

    Three GhostModule stages, the first two each followed by 2x average
    pooling, end in a 3x3 convolution down to one channel.
    """

    def __init__(self):
        super().__init__()
        stages = [
            GhostModule(1, 8),
            nn.AvgPool2d(2),
            GhostModule(8, 16),
            nn.AvgPool2d(2),
            GhostModule(16, 32),
            nn.Conv2d(32, 1, 3, padding=1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


In [6]:

# ────────────────────────────────────────────────────────────────────────────────
# LOSSES
# ────────────────────────────────────────────────────────────────────────────────
def edge_loss(pred, tgt):
    """Mean absolute difference between Sobel-filtered ``pred`` and ``tgt``.

    A single 3x3 Sobel kernel (rows ``[1, 2, 1]``, ``[0, 0, 0]``,
    ``[-1, -2, -1]``) is applied to every channel independently with zero
    padding of 1, and the filtered tensors are compared with an L1 loss.

    Args:
        pred: tensor of shape ``[B, C, H, W]``.
        tgt:  tensor with the same shape as ``pred``.

    Returns:
        Scalar tensor holding the loss.
    """
    channels = pred.shape[1]
    k = torch.tensor([[1,2,1],[0,0,0],[-1,-2,-1]],
                     device=pred.device, dtype=pred.dtype).view(1,1,3,3)
    # One copy of the kernel per channel, applied depth-wise (groups=channels).
    # For single-channel input this is exactly the plain convolution.
    k = k.repeat(channels, 1, 1, 1)
    return F.l1_loss(F.conv2d(pred, k, padding=1, groups=channels),
                     F.conv2d(tgt,  k, padding=1, groups=channels))

def focal_frequency_loss(pred, tgt, chi=1.0):
    """Frequency-domain L1 loss weighted by the target's log power spectrum.

    Both inputs are transformed with an orthonormal 2-D real FFT over their
    last two dimensions. Each frequency bin is weighted by
    ``|log(|FFT(tgt)|**2 + 1e-8)| ** chi`` and the weighted spectra are
    compared with an L1 loss.

    Args:
        pred: predicted tensor.
        tgt:  target tensor with the same shape as ``pred``.
        chi:  exponent applied to the frequency weights.

    Returns:
        Scalar tensor on ``pred``'s device.
    """
    dev = pred.device
    on_mps = dev.type == "mps"
    if on_mps:
        # CPU fallback: run the FFT on the CPU in float32.
        pred, tgt = pred.float().cpu(), tgt.float().cpu()
    else:
        # Half-precision FFT support is limited, so promote reduced-precision
        # inputs to float32. float32/float64 inputs pass through untouched.
        low_precision = (torch.float16, torch.bfloat16)
        if pred.dtype in low_precision:
            pred = pred.float()
        if tgt.dtype in low_precision:
            tgt = tgt.float()

    Yp = torch.fft.rfft2(pred, norm='ortho')
    Yt = torch.fft.rfft2(tgt, norm='ortho')
    # The log is negative wherever the power is below one. The L1 norm already
    # ignores that sign, so taking the magnitude leaves integer ``chi`` results
    # unchanged while avoiding NaN for fractional ``chi``.
    wf = torch.log(torch.abs(Yt)**2 + 1e-8).abs() ** chi
    loss = F.l1_loss(Yp*wf, Yt*wf)
    return loss.to(dev) if on_mps else loss

In [None]:
# ────────────────────────────────────────────────────────────────────────────────
# TRAINING
# ────────────────────────────────────────────────────────────────────────────────
def _evaluate(encoder, generator, loader, ssim_fn, lpips_fn):
    """Return mean (SSIM, PSNR, LPIPS) of encoder -> generator over ``loader``.

    Runs both networks in eval mode without gradients and averages the
    per-batch metric values.
    """
    encoder.eval(); generator.eval()
    tot_ssim, tot_psnr, tot_lpips, n_batches = 0.0, 0.0, 0.0, 0
    with torch.no_grad():
        for sino, gt in loader:
            sino, gt = sino.to(Config.DEVICE), gt.to(Config.DEVICE)
            # The generator ends in tanh ([-1, 1]). Clamp to [0, 1] so the
            # prediction matches the data_range=1.0 assumed by SSIM and PSNR.
            pred = generator(encoder(sino)).float().clamp(0.0, 1.0)
            tot_ssim += ssim_fn(pred, gt).item()
            tot_psnr += psnr(pred, gt, data_range=1.0).item()
            # LPIPS with default settings expects 3-channel inputs in [-1, 1].
            tot_lpips += lpips_fn(pred.repeat(1,3,1,1) * 2 - 1,
                                  gt.repeat(1,3,1,1) * 2 - 1).item()
            n_batches += 1
    denom = max(n_batches, 1)
    return tot_ssim / denom, tot_psnr / denom, tot_lpips / denom


def train_and_evaluate():
    """Train encoder + generator against the discriminator; validate every epoch.

    Each epoch runs one pass over the training split (a discriminator step
    followed by a generator step per batch) and then one pass over the
    validation split. Validation SSIM drives best-checkpoint saving and early
    stopping (``Config.PATIENCE`` epochs without improvement). When training
    ends, the final weights and a loss/metric plot are written to
    ``Config.SAVE_DIR``.

    Returns:
        dict with per-epoch lists under the keys ``"lossD"``, ``"lossG"``
        (training averages) and ``"ssim"``, ``"psnr"``, ``"lpips"``
        (validation averages).
    """
    train_ds = LoDoPaBSinogramDataset(
        os.path.join(Config.DATA_PATH,"observation_train"),
        os.path.join(Config.DATA_PATH,"ground_truth_train")
    )
    val_ds = LoDoPaBSinogramDataset(
        os.path.join(Config.DATA_PATH,"observation_validation"),
        os.path.join(Config.DATA_PATH,"ground_truth_validation")
    )

    print(f"Device: {Config.DEVICE}  AMP: {Config.USE_AMP}")
    print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

    # Pinned host memory only helps host-to-CUDA copies; elsewhere it just warns.
    pin = Config.DEVICE.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS, pin_memory=pin)
    val_loader   = DataLoader(val_ds,   batch_size=Config.BATCH_SIZE, shuffle=False,
                              num_workers=Config.NUM_WORKERS, pin_memory=pin)

    encoder = SinogramEncoder().to(Config.DEVICE)
    generator = Generator().to(Config.DEVICE)
    discriminator = Discriminator().to(Config.DEVICE)

    optG = optim.Adam(list(encoder.parameters())+list(generator.parameters()),
                      lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    optD = optim.Adam(discriminator.parameters(),
                      lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    # Only the generator-side optimizer follows the one-cycle schedule.
    scheduler = OneCycleLR(optG, max_lr=Config.MAX_LR,
                           total_steps=Config.EPOCHS*len(train_loader), pct_start=0.1)

    # With enabled=False the scaler is a pass-through (scale() returns its
    # input, step() calls optimizer.step(), update() does nothing), so a single
    # code path serves both the AMP and the non-AMP case.
    scaler   = GradScaler(enabled=Config.USE_AMP)
    ssim_fn  = StructuralSimilarityIndexMeasure(data_range=1.0).to(Config.DEVICE)
    lpips_fn = LPIPS().to(Config.DEVICE)
    # Frozen VGG16 feature extractor used for the perceptual loss.
    vgg_feats= vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16].to(Config.DEVICE).eval()
    for p in vgg_feats.parameters(): p.requires_grad=False

    history = {"lossD":[],"lossG":[],"ssim":[],"psnr":[],"lpips":[]}
    best_ssim, patience = 0.0, 0

    for epoch in range(1, Config.EPOCHS+1):
        t0 = time.time()
        encoder.train(); generator.train(); discriminator.train()
        sumD, sumG, n_batches = 0.0, 0.0, 0

        for sino, gt in tqdm(train_loader, desc=f"Epoch {epoch}"):
            sino, gt = sino.to(Config.DEVICE), gt.to(Config.DEVICE)

            # D step
            with autocast(enabled=Config.USE_AMP):
                # The fake batch is only an input to D here, so skip building
                # the generator's autograd graph.
                with torch.no_grad():
                    fake = generator(encoder(sino))
                rD, fD = discriminator(gt), discriminator(fake)
                lossD = 0.5 * (
                    F.binary_cross_entropy_with_logits(rD, torch.ones_like(rD)) +
                    F.binary_cross_entropy_with_logits(fD, torch.zeros_like(fD))
                )
            optD.zero_grad()
            scaler.scale(lossD).backward()
            scaler.step(optD)

            # G step
            with autocast(enabled=Config.USE_AMP):
                out    = generator(encoder(sino))
                d_out  = discriminator(out)
                adv    = F.binary_cross_entropy_with_logits(
                    d_out, torch.ones_like(d_out)) * Config.L_ADV
                perp   = F.l1_loss(
                    vgg_feats(out.repeat(1,3,1,1)),
                    vgg_feats(gt.repeat(1,3,1,1))
                ) * Config.L_PERP
                ssim_l = (1 - ssim_fn(out, gt)) * Config.L_SSIM
                # "Cycle" term: feed the reconstruction through encoder and
                # generator again and compare the result with the target.
                cycle  = F.l1_loss(generator(encoder(out)), gt) * Config.L_CYCLE
                edge_l = edge_loss(out, gt) * Config.L_EDGE
                ffl    = focal_frequency_loss(out, gt) * Config.L_FFL
                lossG  = adv + perp + ssim_l + cycle + edge_l + ffl
            optG.zero_grad()
            scaler.scale(lossG).backward()
            scaler.step(optG)
            # One scaler serves both optimizers; update once per iteration.
            scaler.update()
            scheduler.step()

            sumD += lossD.item()
            sumG += lossG.item()
            n_batches += 1

        avgD = sumD / n_batches
        avgG = sumG / n_batches

        # Metrics come from the held-out validation split, not training batches.
        avg_ssim, avg_psnr, avg_lpips = _evaluate(
            encoder, generator, val_loader, ssim_fn, lpips_fn)
        epoch_time = time.time() - t0

        # print summary
        print(f"Epoch {epoch:3d} | "
              f"D: {avgD:.4f}  G: {avgG:.4f}  "
              f"Val SSIM: {avg_ssim:.4f}  Val PSNR: {avg_psnr:.2f}  "
              f"Val LPIPS: {avg_lpips:.4f}  Time: {epoch_time:.2f}s")

        # Record history before any early-stop break so the last epoch is kept.
        history["lossD"].append(avgD)
        history["lossG"].append(avgG)
        history["ssim"].append(avg_ssim)
        history["psnr"].append(avg_psnr)
        history["lpips"].append(avg_lpips)

        # early stopping on validation SSIM
        if avg_ssim > best_ssim:
            best_ssim = avg_ssim
            patience = 0
            # NOTE(review): written to the working directory rather than
            # Config.SAVE_DIR; path kept as-is so existing tooling still finds it.
            torch.save({
                'encoder': encoder.state_dict(),
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
            }, 'best_model.pth')
        else:
            patience += 1
            if patience >= Config.PATIENCE:
                print("Early stopping triggered")
                break

    # Final save & plot
    torch.save(encoder.state_dict(), os.path.join(Config.SAVE_DIR,"final_enc.pth"))
    torch.save(generator.state_dict(),os.path.join(Config.SAVE_DIR,"final_gen.pth"))
    torch.save(discriminator.state_dict(),os.path.join(Config.SAVE_DIR,"final_dis.pth"))

    epochs_axis = list(range(1, len(history["lossD"])+1))
    fig, (ax_loss, ax_ssim) = plt.subplots(1, 2, figsize=(12,4))

    ax_loss.plot(epochs_axis, history["lossD"], label="D Loss")
    ax_loss.plot(epochs_axis, history["lossG"], label="G Loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.legend(); ax_loss.grid(True); ax_loss.set_title("Train Losses")

    # SSIM and PSNR live on very different scales, so give PSNR its own y-axis.
    ssim_line = ax_ssim.plot(epochs_axis, history["ssim"], label="SSIM")
    ax_psnr = ax_ssim.twinx()
    psnr_line = ax_psnr.plot(epochs_axis, history["psnr"], color="tab:orange", label="PSNR")
    ax_ssim.set_xlabel("Epoch")
    ax_ssim.set_ylabel("SSIM")
    ax_psnr.set_ylabel("PSNR")
    ax_ssim.legend(handles=ssim_line + psnr_line)
    ax_ssim.grid(True); ax_ssim.set_title("Val Metrics")

    fig.tight_layout()
    fig.savefig(os.path.join(Config.SAVE_DIR,"metrics.png"))
    plt.show()

    return history

# Entry point: report the selected device, then run the full training procedure.
if __name__ == "__main__":
    print(f"Using device: {Config.DEVICE}")
    train_and_evaluate()


Using device: mps
Loaded 35820 slices from 280 files
Loaded 3522 slices from 28 files
Device: mps  AMP: False
Train samples: 35820, Val samples: 3522


  scaler   = GradScaler(enabled=Config.USE_AMP)
  self.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
  with autocast(enabled=Config.USE_AMP):
  with autocast(enabled=Config.USE_AMP):
Epoch 1:   4%|▍         | 383/8955 [02:26<50:23,  2.84it/s] 