In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
import torch.nn as nn
import torch.nn.functional as F
import itertools
import matplotlib.pyplot as plt
from PIL import Image
import os

# --- Training configuration --------------------------------------------------
BATCH_SIZE = 5  # samples per mini-batch for both data loaders

# NOTE(review): lr, beta1, beta2, decay_epoch and display_epoch are not
# referenced anywhere in the visible code (the optimizers and schedulers
# hard-code their own values) — confirm whether they should be wired in.
lr = 0.0001
beta1 = 0.5
beta2 = 0.996
n_epoches = 120  # number of passes over the training loader
decay_epoch = 40
display_epoch = 20

# --- Data locations ----------------------------------------------------------
MONET_IMAGES_PATH = "/kaggle/input/augmented-alzheimer-mri-dataset/AugmentedAlzheimerDataset/MildDemented"
TEST_IMAGES_PATH = "/kaggle/input/augmented-alzheimer-mri-dataset/AugmentedAlzheimerDataset/ModerateDemented"

# --- Compute device ----------------------------------------------------------
_USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if _USE_CUDA else "cpu")
# Float tensor type matching the selected device.
Tensor = torch.cuda.FloatTensor if _USE_CUDA else torch.FloatTensor


# --- Preprocessing -----------------------------------------------------------
# Every image is resized to a fixed 256x256, randomly mirrored, converted to a
# tensor, and normalized per channel with mean 0.5 and std 0.5.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)

transforms_dataset = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Custom Dataset
class ImageDataset(Dataset):
    """Index-aligned pairs of RGB images read from two directories.

    The file names of each directory are sorted. The training split takes the
    first 250 entries of each directory. The test split takes every entry from
    index 250 onward for ``directory_x`` and entries 250 through 300 for
    ``directory_y``.

    Args:
        directory_x: Directory containing the domain-X images.
        directory_y: Directory containing the domain-Y images.
        test: When true, select the test split instead of the training split.
        transforms: Optional callable applied independently to each loaded
            image.
    """

    def __init__(self, directory_x, directory_y, test=False, transforms=None):
        self.transforms = transforms

        if test:
            x_range, y_range = slice(250, None), slice(250, 301)
        else:
            x_range = y_range = slice(None, 250)

        self.monet_images_X = self._sorted_paths(directory_x, x_range)
        self.test_images_Y = self._sorted_paths(directory_y, y_range)

    @staticmethod
    def _sorted_paths(directory, selection):
        """Return full paths for the ``selection`` slice of the sorted listing."""
        names = sorted(os.listdir(directory))[selection]
        return [os.path.join(directory, name) for name in names]

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB image and release the file handle."""
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the with-block exits.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        # The two lists can differ in length (the test split slices them
        # differently, and a directory may hold fewer files than requested),
        # so only indices valid for BOTH lists are exposed.
        return min(len(self.monet_images_X), len(self.test_images_Y))

    def __getitem__(self, index):
        """Return the ``(x_img, y_img)`` pair at ``index``.

        Raises:
            IndexError: If ``index`` is out of range for either image list.
        """
        x_img = self._load_rgb(self.monet_images_X[index])
        y_img = self._load_rgb(self.test_images_Y[index])

        if self.transforms is not None:
            x_img = self.transforms(x_img)
            y_img = self.transforms(y_img)

        return x_img, y_img

# Custom Collate Function to Handle Variable-Sized Images
def collate_fn(batch):
    """Zero-pad every image in ``batch`` to a common size and stack them.

    Each image is indexed as ``shape[1]`` = height and ``shape[2]`` = width
    (i.e. a channels-first tensor is assumed). Padding is added on the bottom
    and right edges only.

    Args:
        batch: Sequence of ``(img_X, img_Y)`` tensor pairs.

    Returns:
        A tuple ``(real_X, real_Y)`` of stacked tensors, each with a leading
        batch dimension.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    # The target size must cover BOTH domains: sizing from X alone would give
    # a larger Y image a negative pad amount, which crops it silently.
    max_height = max(max(img_X.shape[1], img_Y.shape[1]) for img_X, img_Y in batch)
    max_width = max(max(img_X.shape[2], img_Y.shape[2]) for img_X, img_Y in batch)

    def _pad_to_target(img):
        # Pad spec is (left, right, top, bottom) for the last two dimensions.
        return torch.nn.functional.pad(
            img, (0, max_width - img.shape[2], 0, max_height - img.shape[1])
        )

    real_X = torch.stack([_pad_to_target(img_X) for img_X, _ in batch])
    real_Y = torch.stack([_pad_to_target(img_Y) for _, img_Y in batch])
    return real_X, real_Y

# Settings shared by the training and test loaders.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=3, collate_fn=collate_fn)

# Training split: reshuffled every epoch.
_train_dataset = ImageDataset(
    directory_x=MONET_IMAGES_PATH,
    directory_y=TEST_IMAGES_PATH,
    test=False,
    transforms=transforms_dataset,
)
train_loader = DataLoader(_train_dataset, shuffle=True, **_loader_options)

# Test split: iterated in a fixed order.
_test_dataset = ImageDataset(
    directory_x=MONET_IMAGES_PATH,
    directory_y=TEST_IMAGES_PATH,
    test=True,
    transforms=transforms_dataset,
)
test_loader = DataLoader(_test_dataset, shuffle=False, **_loader_options)

# Function to display images after each epoch
def sample_images(real_X, real_Y):
    """Translate one batch in both directions and plot it in a 2x2 figure.

    The top row shows ``real_X`` next to ``G_XY(real_X)``; the bottom row
    shows ``real_Y`` next to ``G_YX(real_Y)``. Every panel is an image grid
    with one column per sample in ``real_X``.

    Args:
        real_X: Batch of domain-X images.
        real_Y: Batch of domain-Y images.
    """
    generators = (G_XY, G_YX)
    # Remember each generator's mode so that sampling in the middle of
    # training does not leave the networks stuck in eval mode afterwards.
    previous_modes = [net.training for net in generators]
    for net in generators:
        net.eval()

    try:
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            real_X = real_X.to(device).type(Tensor)
            fake_Y = G_XY(real_X)

            real_Y = real_Y.to(device).type(Tensor)
            fake_X = G_YX(real_Y)
    finally:
        for net, was_training in zip(generators, previous_modes):
            net.train(was_training)

    ncols = real_X.size(0)
    # Row-major panel order: (0, 0), (0, 1), (1, 0), (1, 1).
    panels = (
        (real_X, "Real Images from Domain X"),
        (fake_Y, "Generated Images to Domain Y"),
        (real_Y, "Real Images from Domain Y"),
        (fake_X, "Generated Images to Domain X"),
    )

    fig, axs = plt.subplots(2, 2, figsize=(8, 8))
    for ax, (images, title) in zip(axs.flat, panels):
        grid = make_grid(images, nrow=ncols, normalize=True)
        # make_grid returns channels-first; imshow expects channels-last.
        ax.imshow(grid.permute(1, 2, 0).cpu())
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

class Generator(nn.Module):
    """Two-layer convolutional image-to-image network.

    A 7x7 convolution lifts the 3 input channels to 64, followed by ReLU; a
    second 7x7 convolution maps back to 3 channels, followed by Tanh. Both
    convolutions use stride 1 and padding 3, so the spatial size is unchanged.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        translated = self.main(x)
        return translated

# Instantiate Generators
G_XY = Generator().to(device)  # Maps X → Y
G_YX = Generator().to(device)  # Maps Y → X

# A single Adam optimizer updates the parameters of both generators.
# NOTE(review): lr and betas are hard-coded here rather than taken from the
# `lr` / `beta1` / `beta2` constants defined at the top of the file — confirm
# which values are intended.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_XY.parameters(), G_YX.parameters()), lr=0.0002, betas=(0.5, 0.999)
)

# Halve the generator learning rate every 10 scheduler steps.
lr_scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.5)

# The discriminator optimizers and schedulers are deliberately NOT created
# here: they are built over D_X / D_Y parameters once those networks exist.
# Building them over generator parameters would make a "discriminator" step
# update the generators instead.

def compute_loss(fake_Y, fake_X, real_X, real_Y):
    """Return the sum of the two mean-squared reconstruction errors.

    Args:
        fake_Y: Generated domain-Y images, compared against ``real_Y``.
        fake_X: Generated domain-X images, compared against ``real_X``.
        real_X: Reference domain-X images.
        real_Y: Reference domain-Y images.

    Returns:
        Scalar tensor ``mean((fake_Y - real_Y)^2) + mean((fake_X - real_X)^2)``.
    """
    y_error = (fake_Y - real_Y).pow(2).mean()
    x_error = (fake_X - real_X).pow(2).mean()
    return y_error + x_error

# Loss criteria
adversarial_loss = nn.MSELoss()  # squared-error (L2) criterion for the adversarial term
cycle_loss = nn.L1Loss()  # absolute-error (L1) criterion for cycle consistency


# Discriminator objective
def compute_discriminator_loss(D, real_images, fake_images):
    """Return the averaged real/fake adversarial loss for discriminator ``D``.

    Args:
        D: Callable that scores a batch of images.
        real_images: Batch whose scores are pushed toward 1.
        fake_images: Generated batch whose scores are pushed toward 0.

    Returns:
        Scalar tensor: the mean of the real-side and fake-side losses.
    """
    scores_on_real = D(real_images)
    # Cut the graph here so no gradient flows back into the generator.
    scores_on_fake = D(fake_images.detach())

    target_real = torch.ones_like(scores_on_real)
    target_fake = torch.zeros_like(scores_on_fake)

    loss_on_real = adversarial_loss(scores_on_real, target_real)
    loss_on_fake = adversarial_loss(scores_on_fake, target_fake)
    return (loss_on_real + loss_on_fake) / 2

# Define the Discriminators for domains X and Y
class Discriminator(nn.Module):
    """Convolutional network that scores image patches in the range (0, 1).

    Three stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 channels), each
    followed by LeakyReLU(0.2) and, from the second one on, preceded by batch
    normalization of the convolution output. A final stride-1 4x4 convolution
    reduces to one channel and a sigmoid squashes the scores.
    """

    def __init__(self):
        super().__init__()

        # First stage has no batch normalization.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Remaining downsampling stages: conv -> batch norm -> activation.
        for in_channels, out_channels in ((64, 128), (128, 256)):
            layers.extend(
                [
                    nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )

        # Single-channel score map squashed by a sigmoid.
        layers.extend(
            [
                nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1),
                nn.Sigmoid(),
            ]
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map produced for input batch ``x``."""
        scores = self.model(x)
        return scores

# One discriminator per image domain.
D_X = Discriminator().to(device)
D_Y = Discriminator().to(device)

# Adam and step-decay settings shared by both discriminators.
_D_ADAM_SETTINGS = {"lr": 0.0002, "betas": (0.5, 0.999)}
_D_STEP_SETTINGS = {"step_size": 10, "gamma": 0.5}

# Each discriminator gets its own optimizer over its own parameters.
optimizer_D_X = torch.optim.Adam(D_X.parameters(), **_D_ADAM_SETTINGS)
optimizer_D_Y = torch.optim.Adam(D_Y.parameters(), **_D_ADAM_SETTINGS)

# Halve each discriminator learning rate every 10 scheduler steps.
lr_scheduler_D_X = torch.optim.lr_scheduler.StepLR(optimizer_D_X, **_D_STEP_SETTINGS)
lr_scheduler_D_Y = torch.optim.lr_scheduler.StepLR(optimizer_D_Y, **_D_STEP_SETTINGS)



# Training loop
#
# NOTE(review): the generator objective used below (compute_loss) is a direct
# pixel-wise MSE between index-paired images from the two directories. The
# discriminator outputs, adversarial_loss and cycle_loss never contribute to
# loss_G — confirm whether an adversarial / cycle-consistency objective was
# intended.

# Fetch the preview batch once. Building a fresh test_loader iterator every
# epoch would restart its worker processes just to read the same first batch.
test_real_X, test_real_Y = next(iter(test_loader))
test_real_X, test_real_Y = test_real_X.to(device), test_real_Y.to(device)

for epoch in range(n_epoches):
    # sample_images() puts the generators in eval mode while rendering; make
    # sure they are in training mode before any optimization step.
    G_XY.train()
    G_YX.train()

    for real_X, real_Y in train_loader:
        # Move tensors to the same device as the models.
        real_X, real_Y = real_X.to(device), real_Y.to(device)

        # ---- Generator update ----
        optimizer_G.zero_grad()

        fake_Y = G_XY(real_X)
        fake_X = G_YX(real_Y)

        loss_G = compute_loss(fake_Y, fake_X, real_X, real_Y)
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator updates ----
        # compute_discriminator_loss detaches the fakes, so these backward
        # passes touch only the discriminators' parameters.
        optimizer_D_X.zero_grad()
        optimizer_D_Y.zero_grad()

        loss_D_X = compute_discriminator_loss(D_X, real_X, fake_X)
        loss_D_Y = compute_discriminator_loss(D_Y, real_Y, fake_Y)

        loss_D_X.backward()
        optimizer_D_X.step()

        loss_D_Y.backward()
        optimizer_D_Y.step()

    # Learning-rate schedules advance once per epoch.
    lr_scheduler_G.step()
    lr_scheduler_D_X.step()
    lr_scheduler_D_Y.step()

    # The reported losses are those of the last mini-batch of the epoch.
    print(f'Epoch {epoch + 1}/{n_epoches}, Generator Loss: {loss_G.item()}, Discriminator Loss: {(loss_D_X + loss_D_Y).item()}')

    # Visual check on the fixed preview batch after every epoch.
    sample_images(test_real_X, test_real_Y)
