In [1]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))  
])



class CTMRIDataset(Dataset):
    def __init__(self, root_ct, root_mri, transform=None):
        self.root_ct = root_ct
        self.root_mri = root_mri
        self.transform = transform

        self.ct_files = sorted([f for f in os.listdir(root_ct) if f.endswith('.png')])
        self.mri_files = sorted([f for f in os.listdir(root_mri) if f.endswith('.jpg')])

        # Ensure both domains have the same length
        self.length = min(len(self.ct_files), len(self.mri_files))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        ct_path = os.path.join(self.root_ct, self.ct_files[idx])
        mri_path = os.path.join(self.root_mri, self.mri_files[idx])

        ct_img = Image.open(ct_path).convert("L")
        mri_img = Image.open(mri_path).convert("L")

        if self.transform:
            ct_img = self.transform(ct_img)
            mri_img = self.transform(mri_img)

        return ct_img, mri_img


train_dataset = CTMRIDataset(
    root_ct="/kaggle/input/ct-to-mri-cgan/Dataset/images/trainA",
    root_mri="/kaggle/input/ct-to-mri-cgan/Dataset/images/trainB",
    transform=transform
)

test_dataset = CTMRIDataset(
    root_ct="/kaggle/input/ct-to-mri-cgan/Dataset/images/testA",
    root_mri="/kaggle/input/ct-to-mri-cgan/Dataset/images/testB",
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


In [None]:
import os
import itertools
from PIL import Image
import torch
import torch.nn as nn


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_size = 256
batch_size = 1
num_epochs = 20
lambda_cycle = 10
lambda_identity = 5


class ResidualBlock(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, 3),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, 3),
            nn.InstanceNorm2d(features)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_residuals=9):
        super().__init__()
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        ]
        for _ in range(num_residuals):
            model += [ResidualBlock(256)]
        model += [
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, 7),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.model = nn.Sequential(
            *block(in_channels, 64, normalize=False),
            *block(64, 128),
            *block(128, 256),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(256, 1, 4, padding=1)
        )

    def forward(self, x):
        return self.model(x)

G_AB = Generator().to(device)
G_AB.load_state_dict(torch.load("/kaggle/working/G_AB.pth"))
G_BA = Generator().to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)

criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))
for epoch in range(num_epochs):
    for i, (real_A, real_B) in enumerate(train_loader):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # --------------------------------------
        # 1. Train Generators G_AB and G_BA
        # --------------------------------------
        optimizer_G.zero_grad()

        # Identity loss
        same_B = G_AB(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * lambda_identity

        same_A = G_BA(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * lambda_identity

        # GAN loss
        fake_B = G_AB(real_A)
        pred_fake_B = D_B(fake_B)
        valid = torch.ones_like(pred_fake_B).to(device)
        loss_GAN_AB = criterion_GAN(pred_fake_B, valid)

        fake_A = G_BA(real_B)
        pred_fake_A = D_A(fake_A)
        loss_GAN_BA = criterion_GAN(pred_fake_A, valid)

        # Cycle loss
        recovered_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recovered_A, real_A) * lambda_cycle

        recovered_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recovered_B, real_B) * lambda_cycle

        # Total generator loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_AB + loss_GAN_BA + loss_cycle_A + loss_cycle_B
        loss_G.backward()
        optimizer_G.step()

        # --------------------------------------
        # 2. Train Discriminator A
        # --------------------------------------
        optimizer_D_A.zero_grad()

        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, valid)

        fake_A_detach = fake_A.detach()
        pred_fake = D_A(fake_A_detach)
        fake = torch.zeros_like(pred_fake).to(device)
        loss_D_fake = criterion_GAN(pred_fake, fake)

        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # --------------------------------------
        # 3. Train Discriminator B
        # --------------------------------------
        optimizer_D_B.zero_grad()

        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, valid)

        fake_B_detach = fake_B.detach()
        pred_fake = D_B(fake_B_detach)
        loss_D_fake = criterion_GAN(pred_fake, fake)

        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        # --------------------------
        # Logging
        # --------------------------
    print(
            f"Epoch [{epoch+1}/{num_epochs}] "
            f"Loss_G: {loss_G.item():.4f}, "
            f"Loss_D_A: {loss_D_A.item():.4f}, Loss_D_B: {loss_D_B.item():.4f}"
    )
    torch.save(G_AB.state_dict(), "/kaggle/working/G_AB.pth")
    torch.save(G_BA.state_dict(), "/kaggle/working/G_BA.pth")

  G_AB.load_state_dict(torch.load("/kaggle/working/G_AB.pth"))


Epoch [1/20] Loss_G: 1.7722, Loss_D_A: 0.1984, Loss_D_B: 0.2038
Epoch [2/20] Loss_G: 1.8269, Loss_D_A: 0.2880, Loss_D_B: 0.2458
Epoch [3/20] Loss_G: 1.4436, Loss_D_A: 0.2227, Loss_D_B: 0.2477
Epoch [4/20] Loss_G: 1.3493, Loss_D_A: 0.1614, Loss_D_B: 0.2137


In [None]:
torch.save(G_AB.state_dict(), "/kaggle/working/G_AB.pth")
torch.save(G_BA.state_dict(), "/kaggle/working/G_BA.pth")


In [None]:
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import mean_squared_error, mean_absolute_error

G_A2B = Generator().to(device)
G_A2B.load_state_dict(torch.load("/kaggle/working/G_AB.pth"))
G_A2B.to("cuda")
G_A2B.eval()
transform = transforms.Compose([
    transforms.Resize((256, 256)),  
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 
])

test_ct_path = "/kaggle/input/ct-to-mri-cgan/Dataset/images/testA"

psnr_list, ssim_list, mse_list, mae_list = [], [], [], []


def evaluate_metrics(real, generated):
    real_np = real.squeeze().cpu().numpy().transpose(1, 2, 0)
    generated_np = generated.squeeze().cpu().numpy().transpose(1, 2, 0)

    real_gray = np.mean(real_np, axis=2)
    gen_gray = np.mean(generated_np, axis=2)

    psnr_value = psnr(real_gray, gen_gray, data_range=1)
    ssim_value = ssim(real_gray, gen_gray, data_range=1)
    mse_value = mean_squared_error(real_gray.flatten(), gen_gray.flatten())
    mae_value = mean_absolute_error(real_gray.flatten(), gen_gray.flatten())

    return psnr_value, ssim_value, mse_value, mae_value


for filename in os.listdir(test_ct_path):
    if filename.endswith(".jpg") or filename.endswith(".png"):
        path = os.path.join(test_ct_path, filename)
        image = Image.open(path).convert("RGB") 
        image_tensor = transform(image).unsqueeze(0).to("cuda")

 
        with torch.no_grad():
            generated = G_A2B(image_tensor)
            
        psnr_val, ssim_val, mse_val, mae_val = evaluate_metrics(image_tensor, generated)
        psnr_list.append(psnr_val)
        ssim_list.append(ssim_val)
        mse_list.append(mse_val)
        mae_list.append(mae_val)


print(f"Average PSNR : {np.mean(psnr_list):.6f}")
print(f"Average SSIM : {np.mean(ssim_list):.6f}")
print(f"Average MSE  : {np.mean(mse_list):.6f}")
print(f"Average MAE  : {np.mean(mae_list):.6f}")
