In [1]:
# ddcgan_fusion_fixed.py
# Full updated codebase (single-file). Replace your old script with this or copy into a notebook cell.

# %%
# Cell 1: Import libraries & set seed
import os
import random
import time
from pathlib import Path

import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import kornia

# reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

print("Libraries loaded. PyTorch:", torch.__version__, "CUDA:", torch.cuda.is_available())

# %%
# Cell 2: Models (Encoder/Decoder/Generator/Discriminators)
print("Defining models...")

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride, kernel=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=kernel//2),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
    def forward(self, x):
        return self.model(x)

class Encoder(nn.Module):
    def __init__(self, in_channels=2, base_channels=48):
        super().__init__()
        self.conv1 = Block(in_channels, base_channels, 1)      # preserve size
        self.conv2 = Block(base_channels, base_channels, 2)   # downsample
        self.conv3 = Block(base_channels, base_channels, 1)
        self.conv4 = Block(base_channels, base_channels, 2)   # downsample
        self.conv5 = Block(base_channels, base_channels, 1)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x  # shape: (B, C, H/4, W/4)

class Decoder(nn.Module):
    def __init__(self, in_channels=48, out_channels=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.ConvTranspose2d(64, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),

            nn.ConvTranspose2d(32, 16, 4, 2, 1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(16, out_channels, 3, 1, 1),
            nn.Tanh()
        )
    def forward(self, x):
        return self.model(x)

class Generator(nn.Module):
    def __init__(self, in_channels=2, out_channels=1, base_channels=48):
        super().__init__()
        self.encoder = Encoder(in_channels=in_channels, base_channels=base_channels)
        self.decoder = Decoder(in_channels=base_channels, out_channels=out_channels)
    def forward(self, x):
        feat = self.encoder(x)
        out = self.decoder(feat)
        return out, feat  # return both fused image and encoder features

class Discriminator(nn.Module):
    # image-level discriminator
    def __init__(self, in_channels=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1)
        )
    def forward(self, x):
        return self.model(x)

class FeatureDiscriminator(nn.Module):
    # discriminator working on encoder feature maps
    def __init__(self, feat_channels=48):
        super().__init__()
        # small conv net to judge feature maps
        self.model = nn.Sequential(
            nn.Conv2d(feat_channels, feat_channels, 3, 1, 1),
            nn.BatchNorm2d(feat_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(feat_channels, 1)
        )
    def forward(self, feat):
        return self.model(feat)

print("Model definitions done.")

# %%
# Cell 3: Dataset
print("Defining CT-MRI Dataset class...")

class CTMRIDataset(Dataset):
    def __init__(self, root_dir, transform=None, img_size=(256,256)):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.img_size = img_size
        self.image_pairs = []

        ct_dir = self.root_dir / 'CT'
        mri_dir = self.root_dir / 'MRI'

        if ct_dir.exists() and mri_dir.exists():
            ct_files = sorted(list(ct_dir.glob('*.png')) + list(ct_dir.glob('*.jpg')) + list(ct_dir.glob('*.jpeg')))
            for ct in ct_files:
                mri = mri_dir / ct.name
                if mri.exists():
                    self.image_pairs.append((str(ct), str(mri)))
                else:
                    # try same-stem with different extension if needed
                    stems = list(mri_dir.glob(ct.stem + '.*'))
                    if stems:
                        self.image_pairs.append((str(ct), str(stems[0])))
        else:
            print("Warning: CT or MRI directory missing:", ct_dir, mri_dir)

        print(f"Dataset found {len(self.image_pairs)} pairs under {root_dir}")

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        ct_path, mri_path = self.image_pairs[idx]
        try:
            ct_img = Image.open(ct_path).convert('L')  # grayscale
            mri_img = Image.open(mri_path).convert('L')

            ct_img = ct_img.resize(self.img_size, Image.Resampling.BILINEAR)
            mri_img = mri_img.resize(self.img_size, Image.Resampling.BILINEAR)

            if self.transform:
                ct_t = self.transform(ct_img)
                mri_t = self.transform(mri_img)
            else:
                ct_t = transforms.ToTensor()(ct_img)
                mri_t = transforms.ToTensor()(mri_img)
                ct_t = ct_t * 2.0 - 1.0
                mri_t = mri_t * 2.0 - 1.0

            return ct_t, mri_t
        except Exception as e:
            print("Error loading pair:", ct_path, mri_path, e)
            dummy = torch.rand(1, self.img_size[0], self.img_size[1]) * 2 - 1
            return dummy, dummy

# %%
# Cell 4: Utilities and Trainer (major updates)
print("Defining trainer and utilities...")

# Paths and directories
RESULTS_DIR = "results/ddcgan_fusion_fixed"
SAMPLES_DIR = os.path.join(RESULTS_DIR, "samples")
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")
CHECKPOINTS_DIR = "checkpoints/intermediate/ddcgan_fusion_fixed"
FINAL_MODELS_DIR = "checkpoints/final/ddcgan_fusion_fixed"
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(FINAL_MODELS_DIR, exist_ok=True)

def weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def gradient_loss(gen_img, ref_img):
    # Expect images in [-1,1], shape (B,1,H,W)
    gen_grad = kornia.filters.sobel(gen_img)
    ref_grad = kornia.filters.sobel(ref_img)
    return torch.nn.functional.l1_loss(gen_grad, ref_grad)

class DDcGANTrainer:
    def __init__(self, dataset_path, batch_size=8, lr=2e-4, img_size=(256,256), device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.img_size = img_size

        print("Initializing trainer...", "device=", self.device)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # [-1,1]
        ])

        self.dataset = CTMRIDataset(dataset_path, transform=transform, img_size=img_size)
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

        # Models
        self.generator = Generator(in_channels=2, out_channels=1).to(self.device)
        self.disc_img = Discriminator(in_channels=1).to(self.device)       # image-level discriminator
        self.disc_feat = FeatureDiscriminator(feat_channels=48).to(self.device)  # feature-level

        self.generator.apply(weights_init)
        self.disc_img.apply(weights_init)
        self.disc_feat.apply(weights_init)

        # Opts
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
        self.d_img_optimizer = optim.Adam(self.disc_img.parameters(), lr=lr, betas=(0.5, 0.999))
        self.d_feat_optimizer = optim.Adam(self.disc_feat.parameters(), lr=lr, betas=(0.5, 0.999))

        # Loss functions
        self.adversarial_loss = nn.MSELoss()   # used for discriminator outputs
        self.l1_loss = nn.L1Loss()
        # use kornia SSIM
        self.ssim_fn = kornia.losses.SSIMLoss(window_size=11, reduction='mean')

        # weights (tuned to encourage both modalities)
        self.w_recon = 5.0
        self.w_grad = 5.0
        self.w_ssim = 2.0
        self.w_feat_adv = 1.0

        # history
        self.g_losses = []
        self.d_losses = []
        self.epoch_times = []

    def train_epoch(self, epoch):
        self.generator.train()
        self.disc_img.train()
        self.disc_feat.train()

        start = time.time()
        pbar = tqdm(self.dataloader, desc=f"Epoch {epoch}")
        avg_g_loss = 0.0
        avg_d_loss = 0.0
        n = 0

        for i, (ct_imgs, mri_imgs) in enumerate(pbar):
            n += 1
            ct_imgs = ct_imgs.to(self.device)   # (B,1,H,W)
            mri_imgs = mri_imgs.to(self.device)

            # build generator input by concatenating CT and MRI (channel dim)
            input_imgs = torch.cat([ct_imgs, mri_imgs], dim=1)  # (B,2,H,W)

            batch_size = ct_imgs.size(0)
            real_label = torch.ones(batch_size, 1, device=self.device)
            fake_label = torch.zeros(batch_size, 1, device=self.device)

            # ---------------------
            # Train image discriminator (real from CT/MRI mix, fake from generator)
            # ---------------------
            self.d_img_optimizer.zero_grad()

            # real fused samples: randomly choose CT or MRI — forces discriminator to accept both modality styles
            # this makes discriminator less biased to one modality
            if random.random() < 0.5:
                real_fused = ct_imgs
            else:
                real_fused = mri_imgs

            # generate fake fused images (detach -> for discriminator)
            with torch.no_grad():
                fake_fused, _ = self.generator(input_imgs)

            real_pred = self.disc_img(real_fused)
            fake_pred = self.disc_img(fake_fused.detach())

            d_real_loss = self.adversarial_loss(real_pred, real_label)
            d_fake_loss = self.adversarial_loss(fake_pred, fake_label)
            d_img_loss = 0.5 * (d_real_loss + d_fake_loss)
            d_img_loss.backward()
            self.d_img_optimizer.step()

            # ---------------------
            # Train feature discriminator (on encoder feature maps)
            # ---------------------
            self.d_feat_optimizer.zero_grad()
            # get encoder features for real (we compute encoder features for CT and MRI separately)
            with torch.no_grad():
                # encoder expects 2-channel input, create a "pseudo" paired input:
                ct_pair = torch.cat([ct_imgs, ct_imgs], dim=1)   # (B,2,H,W)
                mri_pair = torch.cat([mri_imgs, mri_imgs], dim=1)
                ct_feat = self.generator.encoder(ct_pair)   # (B,C,H',W')
                mri_feat = self.generator.encoder(mri_pair)
            # sample features: pick CT or MRI features randomly as "real" features
            real_feat = ct_feat if random.random() < 0.5 else mri_feat
            # fake features are features from generator on mixed input
            with torch.no_grad():
                _, fake_feat = self.generator(input_imgs)
            real_f_pred = self.disc_feat(real_feat)
            fake_f_pred = self.disc_feat(fake_feat.detach())
            d_feat_real_loss = self.adversarial_loss(real_f_pred, real_label)
            d_feat_fake_loss = self.adversarial_loss(fake_f_pred, fake_label)
            d_feat_loss = 0.5 * (d_feat_real_loss + d_feat_fake_loss)
            d_feat_loss.backward()
            self.d_feat_optimizer.step()

            d_total = d_img_loss + d_feat_loss

            # ---------------------
            # Train generator (adversarial + reconstruction + gradient + ssim + feat adv)
            # ---------------------
            self.g_optimizer.zero_grad()
            fused_imgs, fused_feat = self.generator(input_imgs)

            # adversarial (image-level)
            pred_img = self.disc_img(fused_imgs)
            g_adv_img = self.adversarial_loss(pred_img, real_label)

            # adversarial (feature-level) - try to fool the feature discriminator
            pred_feat = self.disc_feat(fused_feat)
            g_adv_feat = self.adversarial_loss(pred_feat, real_label)

            # reconstruction L1 to both CT and MRI (equal weight)
            rec_loss = self.l1_loss(fused_imgs, ct_imgs) + self.l1_loss(fused_imgs, mri_imgs)

            # gradient preservation for both modalities
            grad_loss = gradient_loss(fused_imgs, ct_imgs) + gradient_loss(fused_imgs, mri_imgs)

            # SSIM preservation for both
            ssim_loss = self.ssim_fn(fused_imgs, ct_imgs) + self.ssim_fn(fused_imgs, mri_imgs)

            g_loss = g_adv_img + self.w_feat_adv * g_adv_feat + \
                     self.w_recon * rec_loss + \
                     self.w_grad * grad_loss + \
                     self.w_ssim * ssim_loss

            g_loss.backward()
            self.g_optimizer.step()

            avg_g_loss += g_loss.item()
            avg_d_loss += d_total.item()

            pbar.set_postfix({
                'G': f'{g_loss.item():.4f}',
                'D': f'{d_total.item():.4f}',
                'Rec': f'{rec_loss.item():.4f}',
                'Grad': f'{grad_loss.item():.4f}',
                'SSIM': f'{ssim_loss.item():.4f}'
            })

            # save sample occasionally
            if i % 200 == 0:
                self.save_sample_images(ct_imgs, mri_imgs, fused_imgs, epoch, i)

        avg_g_loss = avg_g_loss / max(1, n)
        avg_d_loss = avg_d_loss / max(1, n)
        elapsed = time.time() - start
        self.g_losses.append(avg_g_loss)
        self.d_losses.append(avg_d_loss)
        self.epoch_times.append(elapsed)

        print(f"\nEpoch {epoch} done — G_loss: {avg_g_loss:.4f}, D_loss: {avg_d_loss:.4f}, time: {elapsed:.1f}s")
        return avg_g_loss, avg_d_loss

    def save_sample_images(self, ct_imgs, mri_imgs, fused_imgs, epoch, batch_idx):
        os.makedirs(SAMPLES_DIR, exist_ok=True)
        ct_sample = ct_imgs[0].cpu()
        mri_sample = mri_imgs[0].cpu()
        fused_sample = fused_imgs[0].cpu()

        # denormalize
        ct_np = ((ct_sample + 1.0) / 2.0).squeeze().numpy()
        mri_np = ((mri_sample + 1.0) / 2.0).squeeze().numpy()
        fused_np = ((fused_sample + 1.0) / 2.0).squeeze().numpy()

        fig, axes = plt.subplots(1,3, figsize=(12,4))
        axes[0].imshow(ct_np, cmap='gray'); axes[0].set_title('CT'); axes[0].axis('off')
        axes[1].imshow(mri_np, cmap='gray'); axes[1].set_title('MRI'); axes[1].axis('off')
        axes[2].imshow(fused_np, cmap='gray'); axes[2].set_title('Fused'); axes[2].axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(SAMPLES_DIR, f'epoch{epoch}_batch{batch_idx}.png'), dpi=150, bbox_inches='tight')
        plt.close()

    def save_model(self, epoch, path=CHECKPOINTS_DIR):
        os.makedirs(path, exist_ok=True)
        ckpt = {
            'epoch': epoch,
            'generator_state_dict': self.generator.state_dict(),
            'disc_img_state_dict': self.disc_img.state_dict(),
            'disc_feat_state_dict': self.disc_feat.state_dict(),
            'g_optimizer': self.g_optimizer.state_dict(),
            'd_img_optimizer': self.d_img_optimizer.state_dict(),
            'd_feat_optimizer': self.d_feat_optimizer.state_dict(),
            'g_losses': self.g_losses,
            'd_losses': self.d_losses,
            'epoch_times': self.epoch_times
        }
        torch.save(ckpt, os.path.join(path, f'ddcgan_fixed_epoch_{epoch}.pth'))
        print("Saved checkpoint at epoch", epoch)

    def load_model(self, ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.generator.load_state_dict(ckpt['generator_state_dict'])
        self.disc_img.load_state_dict(ckpt['disc_img_state_dict'])
        self.disc_feat.load_state_dict(ckpt['disc_feat_state_dict'])
        self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        self.d_img_optimizer.load_state_dict(ckpt['d_img_optimizer'])
        self.d_feat_optimizer.load_state_dict(ckpt['d_feat_optimizer'])
        self.g_losses = ckpt.get('g_losses', [])
        self.d_losses = ckpt.get('d_losses', [])
        print("Loaded checkpoint:", ckpt_path)
        return ckpt.get('epoch', 0)

    def train(self, num_epochs=50, save_interval=5):
        print("Starting training:", num_epochs, "epochs, dataset size:", len(self.dataset))
        for ep in range(1, num_epochs+1):
            g_loss, d_loss = self.train_epoch(ep)
            if ep % save_interval == 0:
                self.save_model(ep)
        # final save
        self.save_model(num_epochs, path=FINAL_MODELS_DIR)
        self.plot_training_metrics()

    def plot_training_metrics(self):
        os.makedirs(PLOTS_DIR, exist_ok=True)
        epochs = range(1, len(self.g_losses)+1)
        plt.figure(figsize=(10,4))
        plt.plot(epochs, self.g_losses, label='G_loss')
        plt.plot(epochs, self.d_losses, label='D_loss')
        plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.title('Training losses')
        plt.savefig(os.path.join(PLOTS_DIR, 'losses.png'), dpi=150, bbox_inches='tight')
        plt.close()

# %%
# Cell 5: Config & run (change DATASET_PATH as appropriate)
if __name__ == "__main__":
    DATASET_PATH = "../Dataset/train"   # <-- set to your dataset location
    BATCH_SIZE = 16
    LR = 2e-4
    NUM_EPOCHS = 3
    IMG_SIZE = (256,256)
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    if not os.path.exists(DATASET_PATH):
        print("Dataset path not found:", DATASET_PATH)
        print("Please set DATASET_PATH to your dataset folder containing CT/ and MRI/ subfolders.")
    else:
        trainer = DDcGANTrainer(
            dataset_path=DATASET_PATH,
            batch_size=BATCH_SIZE,
            lr=LR,
            img_size=IMG_SIZE,
            device=DEVICE
        )
        trainer.train(NUM_EPOCHS, save_interval=10)

    # demonstration of inference on a random pair (even if untrained)
    def demo_inference(checkpoint=None):
        device = torch.device(DEVICE if torch.cuda.is_available() else 'cpu')
        gen = Generator(in_channels=2).to(device)
        if checkpoint and os.path.exists(checkpoint):
            ck = torch.load(checkpoint, map_location=device)
            gen.load_state_dict(ck['generator_state_dict'])
            print("Loaded generator from", checkpoint)
        gen.eval()

        # pick a random pair from dataset or dummy
        if os.path.exists(DATASET_PATH):
            ds = CTMRIDataset(DATASET_PATH, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5],[0.5])
            ]), img_size=IMG_SIZE)
            if len(ds) > 0:
                ct, mri = ds[random.randint(0, len(ds)-1)]
                ct = ct.unsqueeze(0).to(device)
                mri = mri.unsqueeze(0).to(device)
            else:
                ct = torch.rand(1,1,IMG_SIZE[0],IMG_SIZE[1]).to(device)*2-1
                mri = torch.rand(1,1,IMG_SIZE[0],IMG_SIZE[1]).to(device)*2-1
        else:
            ct = torch.rand(1,1,IMG_SIZE[0],IMG_SIZE[1]).to(device)*2-1
            mri = torch.rand(1,1,IMG_SIZE[0],IMG_SIZE[1]).to(device)*2-1

        with torch.no_grad():
            fused, _ = gen(torch.cat([ct, mri], dim=1))
        # denorm
        def denorm(x):
            x = x.detach().cpu().squeeze()
            x = (x + 1.0) / 2.0
            x = x.clamp(0,1).numpy()
            return x
        ct_np = denorm(ct)
        mri_np = denorm(mri)
        fused_np = denorm(fused)

        fig, ax = plt.subplots(1,3, figsize=(12,4))
        ax[0].imshow(ct_np, cmap='gray'); ax[0].set_title('CT'); ax[0].axis('off')
        ax[1].imshow(mri_np, cmap='gray'); ax[1].set_title('MRI'); ax[1].axis('off')
        ax[2].imshow(fused_np, cmap='gray'); ax[2].set_title('Fused'); ax[2].axis('off')
        plt.tight_layout()
        demo_path = os.path.join(RESULTS_DIR, "demo_inference.png")
        plt.savefig(demo_path, dpi=150, bbox_inches='tight')
        plt.show()
        print("Saved demo:", demo_path)

    # run demo (uncomment to run)
    # demo_inference()


  return torch._C._cuda_getDeviceCount() > 0


Libraries loaded. PyTorch: 2.5.1 CUDA: False
Defining models...
Model definitions done.
Defining CT-MRI Dataset class...
Defining trainer and utilities...
Initializing trainer... device= cpu
Dataset found 3208 pairs under ../Dataset/train
Starting training: 3 epochs, dataset size: 3208


Epoch 1:   0%|          | 0/201 [00:06<?, ?it/s, G=10.3268, D=1.1427, Rec=1.0953, Grad=0.1309, SSIM=0.9966]


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [1]:
# ddcgan_fusion_improved.py
# Improved DDcGAN fusion code (CT + MRI) with balanced losses and feature-level discriminator
# Created: 2025-08-31

import os
import random
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import kornia

# ---------------------------
# Utilities
# ---------------------------

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed()


# ---------------------------
# Model blocks
# ---------------------------
class Block(nn.Module):
    def __init__(self, in_channels, filter_size, strides, kernel=3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, filter_size, kernel, strides, kernel // 2),
            nn.BatchNorm2d(filter_size),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.model(x)


class Encoder(nn.Module):
    def __init__(self, in_channels=2, out_channels=48, constant_feature_map=48):
        super().__init__()
        self.model = nn.Sequential(
            Block(in_channels, constant_feature_map, 1),
            Block(constant_feature_map, constant_feature_map, 2),
            Block(constant_feature_map, constant_feature_map, 1),
            Block(constant_feature_map, constant_feature_map, 2),
            Block(constant_feature_map, out_channels, 1),
        )

    def forward(self, x):
        return self.model(x)


class Decoder(nn.Module):
    def __init__(self, in_channels=48, out_channels=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, out_channels, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.model(x)


class Generator(nn.Module):
    def __init__(self, in_channels=2, out_channels=1, encoder_constant_features=48):
        super().__init__()
        self.encoder = Encoder(in_channels, encoder_constant_features, encoder_constant_features)
        self.decoder = Decoder(encoder_constant_features, out_channels)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, encoded


class Discriminator(nn.Module):
    def __init__(self, in_channels=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        return self.model(x)


class FeatureDiscriminator(nn.Module):
    """Discriminator that operates on encoder feature maps (smaller spatial dims)."""

    def __init__(self, in_channels=48):
        super().__init__()
        # small conv net to judge feature realism
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, 1),
            nn.BatchNorm2d(in_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_channels, 1),
        )

    def forward(self, x):
        return self.model(x)


# ---------------------------
# Dataset
# ---------------------------
class CTMRIDataset(Dataset):
    def __init__(self, root_dir, transform=None, img_size=(256, 256)):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.img_size = img_size
        self.image_pairs = []

        ct_dir = self.root_dir / "CT"
        mri_dir = self.root_dir / "MRI"

        if ct_dir.exists() and mri_dir.exists():
            ct_files = sorted(list(ct_dir.glob("*.png")))
            mri_files = sorted(list(mri_dir.glob("*.png")))
            for ct_file in ct_files:
                mri_file = mri_dir / ct_file.name
                if mri_file.exists():
                    self.image_pairs.append((str(ct_file), str(mri_file)))
                else:
                    # try other extensions if not found
                    for ext in (".jpg", ".jpeg", ".bmp", ".tif"):
                        alt = mri_dir / (ct_file.stem + ext)
                        if alt.exists():
                            self.image_pairs.append((str(ct_file), str(alt)))
                            break
        else:
            print(f"Warning: Dataset directories not found at {root_dir}. expected CT/ and MRI/")

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        ct_path, mri_path = self.image_pairs[idx]
        ct_img = Image.open(ct_path).convert("L")
        mri_img = Image.open(mri_path).convert("L")
        ct_img = ct_img.resize(self.img_size, Image.Resampling.BILINEAR)
        mri_img = mri_img.resize(self.img_size, Image.Resampling.BILINEAR)

        if self.transform:
            ct_tensor = self.transform(ct_img)
            mri_tensor = self.transform(mri_img)
        else:
            ct_tensor = transforms.ToTensor()(ct_img)
            mri_tensor = transforms.ToTensor()(mri_img)
            ct_tensor = ct_tensor * 2.0 - 1.0
            mri_tensor = mri_tensor * 2.0 - 1.0

        return ct_tensor, mri_tensor


# ---------------------------
# Training utilities
# ---------------------------
RESULTS_DIR = "results/ddcgan_fusion_improved"
SAMPLES_DIR = f"{RESULTS_DIR}/samples"
PLOTS_DIR = f"{RESULTS_DIR}/plots"
CHECKPOINTS_DIR = "checkpoints/ddcgan_fusion_improved"
FINAL_MODELS_DIR = "checkpoints/final/ddcgan_fusion_improved"


def weights_init(m):
    classname = m.__class__.__name__
    if "Conv" in classname:
        if hasattr(m, "weight") and m.weight is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if hasattr(m, "weight") and m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            nn.init.constant_(m.bias.data, 0)


def gradient_loss(pred, target):
    # Sobel gradients via kornia
    pred_grad = kornia.filters.sobel(pred)
    target_grad = kornia.filters.sobel(target)
    return torch.nn.functional.l1_loss(pred_grad, target_grad)


# ---------------------------
# Trainer class
# ---------------------------
class DDcGANTrainer:
    def __init__(
        self,
        dataset_path,
        batch_size=8,
        lr=0.0002,
        img_size=(256, 256),
        device="cuda",
        feature_disc_weight=1.0,
    ):
        self.device = device
        self.batch_size = batch_size
        self.img_size = img_size
        self.lr = lr

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        self.dataset = CTMRIDataset(dataset_path, transform=transform, img_size=img_size)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )

        # Models
        self.generator = Generator(in_channels=2, out_channels=1).to(self.device)
        self.discriminator1 = Discriminator(in_channels=1).to(self.device)
        self.discriminator2 = Discriminator(in_channels=1).to(self.device)
        self.feature_disc = FeatureDiscriminator(in_channels=48).to(self.device)

        # Initialize
        self.generator.apply(weights_init)
        self.discriminator1.apply(weights_init)
        self.discriminator2.apply(weights_init)
        self.feature_disc.apply(weights_init)

        # Optimizers
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
        self.d1_optimizer = optim.Adam(self.discriminator1.parameters(), lr=lr, betas=(0.5, 0.999))
        self.d2_optimizer = optim.Adam(self.discriminator2.parameters(), lr=lr, betas=(0.5, 0.999))
        self.fd_optimizer = optim.Adam(self.feature_disc.parameters(), lr=lr, betas=(0.5, 0.999))

        # Losses
        self.adversarial_loss = nn.MSELoss()
        self.reconstruction_loss = nn.L1Loss()
        self.ssim_loss_fn = kornia.losses.SSIMLoss(window_size=11, reduction="mean")

        # Weights (tune these)
        self.w_recon = 5.0
        self.w_grad = 5.0
        self.w_ssim = 2.0
        self.w_feature_adv = feature_disc_weight

        # tracking
        self.g_losses = []
        self.d_losses = []
        self.recon_losses = []

        print("Trainer initialized:")
        print(f" Dataset pairs: {len(self.dataset)}")
        print(f" Device: {self.device}")

    def ssim_loss(self, img1, img2):
        return self.ssim_loss_fn(img1, img2)

    def train_epoch(self, epoch):
        self.generator.train()
        self.discriminator1.train()
        self.discriminator2.train()
        self.feature_disc.train()

        epoch_g_loss = 0
        epoch_d_loss = 0
        epoch_recon_loss = 0

        pbar = tqdm(self.dataloader, desc=f"Epoch {epoch}")

        for i, (ct_imgs, mri_imgs) in enumerate(pbar):
            ct_imgs = ct_imgs.to(self.device)
            mri_imgs = mri_imgs.to(self.device)
            batch_size = ct_imgs.size(0)

            # Input for generator: concatenate CT and MRI
            input_imgs = torch.cat([ct_imgs, mri_imgs], dim=1)

            real_label = torch.ones(batch_size, 1).to(self.device)
            fake_label = torch.zeros(batch_size, 1).to(self.device)

            # ---------------------
            # Train Discriminators on image space
            # ---------------------
            with torch.no_grad():
                fake_fused, fake_feat = self.generator(input_imgs)

            # Choose real samples for image discriminators randomly from CT/MRI
            # This prevents the discriminator from always using CT as "real fused"
            if random.random() < 0.5:
                real_fused = ct_imgs
            else:
                real_fused = mri_imgs

            # D1 (global)
            self.d1_optimizer.zero_grad()
            real_pred1 = self.discriminator1(real_fused)
            fake_pred1 = self.discriminator1(fake_fused.detach())
            d1_loss = (self.adversarial_loss(real_pred1, real_label) + self.adversarial_loss(fake_pred1, fake_label)) / 2
            d1_loss.backward()
            self.d1_optimizer.step()

            # D2 (local)
            self.d2_optimizer.zero_grad()
            real_pred2 = self.discriminator2(real_fused)
            fake_pred2 = self.discriminator2(fake_fused.detach())
            d2_loss = (self.adversarial_loss(real_pred2, real_label) + self.adversarial_loss(fake_pred2, fake_label)) / 2
            d2_loss.backward()
            self.d2_optimizer.step()

            # Feature discriminator
            self.fd_optimizer.zero_grad()
            # For feature discriminator we use encoder outputs of "real" inputs
            # Build real features by encoding a mixed real input (either CT or MRI or concatenated)
            with torch.no_grad():
                # encode CT and MRI separately and also encode concatenated as real examples
                ct_encoded = self.generator.encoder(torch.cat([ct_imgs, ct_imgs], dim=1))
                mri_encoded = self.generator.encoder(torch.cat([mri_imgs, mri_imgs], dim=1))

            # Randomly choose real feature source
            if random.random() < 0.5:
                real_feat = ct_encoded
            else:
                real_feat = mri_encoded

            fake_feat_detach = fake_feat.detach()
            real_feat_pred = self.feature_disc(real_feat)
            fake_feat_pred = self.feature_disc(fake_feat_detach)
            fd_loss = (self.adversarial_loss(real_feat_pred, real_label) + self.adversarial_loss(fake_feat_pred, fake_label)) / 2
            fd_loss.backward()
            self.fd_optimizer.step()

            total_d_loss = d1_loss + d2_loss + fd_loss

            # ---------------------
            # Train Generator
            # ---------------------
            self.g_optimizer.zero_grad()
            fused_imgs, feat = self.generator(input_imgs)

            # Adversarial (image space)
            pred1 = self.discriminator1(fused_imgs)
            pred2 = self.discriminator2(fused_imgs)
            g_adv_loss = (self.adversarial_loss(pred1, real_label) + self.adversarial_loss(pred2, real_label)) / 2

            # Feature adversarial: fool feature discriminator
            feat_pred = self.feature_disc(feat)
            g_feat_adv_loss = self.adversarial_loss(feat_pred, real_label)

            # Reconstruction: with both modalities (L1)
            g_recon_loss = self.reconstruction_loss(fused_imgs, ct_imgs) + self.reconstruction_loss(fused_imgs, mri_imgs)

            # Gradient loss for both
            g_grad_loss = gradient_loss(fused_imgs, ct_imgs) + gradient_loss(fused_imgs, mri_imgs)

            # SSIM loss (both)
            g_ssim_loss = self.ssim_loss(fused_imgs, ct_imgs) + self.ssim_loss(fused_imgs, mri_imgs)

            # Total loss (balanced)
            g_loss = g_adv_loss + \
                     self.w_recon * g_recon_loss + \
                     self.w_grad * g_grad_loss + \
                     self.w_ssim * g_ssim_loss + \
                     self.w_feature_adv * g_feat_adv_loss

            g_loss.backward()
            self.g_optimizer.step()

            epoch_g_loss += g_loss.item()
            epoch_d_loss += total_d_loss.item()
            epoch_recon_loss += g_recon_loss.item()

            pbar.set_postfix({
                "G_Loss": f"{g_loss.item():.4f}",
                "D_Loss": f"{total_d_loss.item():.4f}",
                "Recon": f"{g_recon_loss.item():.4f}",
            })

            if i % 100 == 0:
                self.save_sample_images(ct_imgs, mri_imgs, fused_imgs, epoch, i)

        # epoch metrics
        avg_g_loss = epoch_g_loss / len(self.dataloader)
        avg_d_loss = epoch_d_loss / len(self.dataloader)
        avg_recon_loss = epoch_recon_loss / len(self.dataloader)

        self.g_losses.append(avg_g_loss)
        self.d_losses.append(avg_d_loss)
        self.recon_losses.append(avg_recon_loss)

        print(f"\nEpoch {epoch} summary: G {avg_g_loss:.4f}, D {avg_d_loss:.4f}, Recon {avg_recon_loss:.4f}")

        return avg_g_loss, avg_d_loss

    def save_sample_images(self, ct_imgs, mri_imgs, fused_imgs, epoch, batch_idx):
        os.makedirs(SAMPLES_DIR, exist_ok=True)
        ct_sample = ct_imgs[0].cpu()
        mri_sample = mri_imgs[0].cpu()
        fused_sample = fused_imgs[0].cpu()
        ct_np = ((ct_sample + 1) / 2).squeeze().numpy()
        mri_np = ((mri_sample + 1) / 2).squeeze().numpy()
        fused_np = ((fused_sample + 1) / 2).squeeze().numpy()

        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        axes[0].imshow(ct_np, cmap="gray")
        axes[0].set_title("CT")
        axes[0].axis("off")
        axes[1].imshow(mri_np, cmap="gray")
        axes[1].set_title("MRI")
        axes[1].axis("off")
        axes[2].imshow(fused_np, cmap="gray")
        axes[2].set_title("Fused")
        axes[2].axis("off")
        plt.tight_layout()
        plt.savefig(f"{SAMPLES_DIR}/epoch_{epoch}_batch_{batch_idx}.png", dpi=150, bbox_inches="tight")
        plt.close()

    def save_model(self, epoch, path=CHECKPOINTS_DIR):
        os.makedirs(path, exist_ok=True)
        checkpoint = {
            "epoch": epoch,
            "generator_state_dict": self.generator.state_dict(),
            "discriminator1_state_dict": self.discriminator1.state_dict(),
            "discriminator2_state_dict": self.discriminator2.state_dict(),
            "feature_disc_state_dict": self.feature_disc.state_dict(),
            "g_optimizer_state_dict": self.g_optimizer.state_dict(),
            "d1_optimizer_state_dict": self.d1_optimizer.state_dict(),
            "d2_optimizer_state_dict": self.d2_optimizer.state_dict(),
            "fd_optimizer_state_dict": self.fd_optimizer.state_dict(),
            "g_losses": self.g_losses,
            "d_losses": self.d_losses,
            "recon_losses": self.recon_losses,
        }
        torch.save(checkpoint, f"{path}/ddcgan_improved_epoch_{epoch}.pth")
        print(f"Saved checkpoint epoch {epoch} -> {path}")

    def load_model(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint["generator_state_dict"])\
            if "generator_state_dict" in checkpoint else None
        self.discriminator1.load_state_dict(checkpoint["discriminator1_state_dict"])\
            if "discriminator1_state_dict" in checkpoint else None
        self.discriminator2.load_state_dict(checkpoint["discriminator2_state_dict"])\
            if "discriminator2_state_dict" in checkpoint else None
        self.feature_disc.load_state_dict(checkpoint["feature_disc_state_dict"])\
            if "feature_disc_state_dict" in checkpoint else None
        self.g_losses = checkpoint.get("g_losses", [])
        self.d_losses = checkpoint.get("d_losses", [])
        self.recon_losses = checkpoint.get("recon_losses", [])
        return checkpoint.get("epoch", 0)

    def train(self, num_epochs=50, save_interval=5):
        print("Starting training...")
        start_time = time.time()
        for epoch in range(1, num_epochs + 1):
            g_loss, d_loss = self.train_epoch(epoch)
            print(f"Epoch [{epoch}/{num_epochs}] G: {g_loss:.4f} D: {d_loss:.4f}")
            if epoch % save_interval == 0:
                self.save_model(epoch)
        total_time = time.time() - start_time
        print(f"Training finished in {total_time:.2f}s")
        self.plot_training_metrics()

    def plot_training_metrics(self):
        os.makedirs(PLOTS_DIR, exist_ok=True)
        epochs = range(1, len(self.g_losses) + 1)
        plt.figure(figsize=(10, 4))
        plt.plot(epochs, self.g_losses, label="Generator Loss")
        plt.plot(epochs, self.d_losses, label="Discriminator Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.title("Losses")
        plt.savefig(f"{PLOTS_DIR}/losses.png", dpi=150, bbox_inches="tight")
        plt.close()


# ---------------------------
# Demo / Main
# ---------------------------
if __name__ == "__main__":
    # basic config - change these paths to suit your setup
    DATASET_PATH = "../Dataset/train"  # expects CT/ and MRI/ subfolders
    BATCH_SIZE = 8
    LR = 2e-4
    NUM_EPOCHS = 30
    IMG_SIZE = (256, 256)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(RESULTS_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    os.makedirs(PLOTS_DIR, exist_ok=True)
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    os.makedirs(FINAL_MODELS_DIR, exist_ok=True)

    trainer = DDcGANTrainer(
        dataset_path=DATASET_PATH,
        batch_size=BATCH_SIZE,
        lr=LR,
        img_size=IMG_SIZE,
        device=device,
        feature_disc_weight=1.0,
    )

    # quick sanity check: minimal dataset
    if len(trainer.dataset) == 0:
        print("No image pairs found in dataset. Exiting.")
    else:
        trainer.train(num_epochs=NUM_EPOCHS, save_interval=5)
        trainer.save_model(NUM_EPOCHS, path=FINAL_MODELS_DIR)

    print("Done.")


  return torch._C._cuda_getDeviceCount() > 0


Trainer initialized:
 Dataset pairs: 3208
 Device: cpu
Starting training...


Epoch 1:   0%|          | 0/401 [00:03<?, ?it/s, G_Loss=12.9812, D_Loss=1.7494, Recon=1.6208]


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.