In [None]:
from google.colab import drive

# Mount Google Drive so the GANPathology dataset and checkpoint folders under MyDrive are reachable.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


Pix2Pix: U-Net generator

In [None]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Downsampling block: 4x4 stride-2 convolution, optional BatchNorm, then an activation.

    With padding 1, the convolution halves even spatial dimensions.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        use_batchnorm: Insert ``BatchNorm2d`` after the convolution when True.
        use_leaky: Use ``LeakyReLU(0.2)`` when True, plain ``ReLU`` otherwise.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True, use_leaky=True):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        norm = [nn.BatchNorm2d(out_channels)] if use_batchnorm else []
        activation = nn.LeakyReLU(negative_slope=0.2) if use_leaky else nn.ReLU()
        # A single Sequential keeps state_dict keys as "block.<i>.*" so saved checkpoints still load.
        self.block = nn.Sequential(conv, *norm, activation)

    def forward(self, x):
        """Apply conv -> [BatchNorm] -> activation to ``x``."""
        return self.block(x)

class UpConvBlock(nn.Module):
    """Upsampling block: 4x4 stride-2 transposed convolution followed by a head.

    Intermediate blocks use BatchNorm + ReLU (+ Dropout(0.5) when requested).
    The final block uses only Tanh, so its output lies in [-1, 1]; ``dropout``
    is ignored when ``final`` is True.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        dropout: Append ``Dropout(0.5)`` to a non-final block.
        final: Build the output block (Tanh, no BatchNorm).
    """

    def __init__(self, in_channels, out_channels, dropout=False, final=False):
        super().__init__()
        upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        if final:
            head = [nn.Tanh()]  # Output range [-1, 1]
        else:
            head = [nn.BatchNorm2d(out_channels), nn.ReLU()]
            if dropout:
                head.append(nn.Dropout(0.5))
        # A single Sequential keeps state_dict keys as "block.<i>.*" so saved checkpoints still load.
        self.block = nn.Sequential(upconv, *head)

    def forward(self, x):
        """Upsample ``x`` by 2x and apply the block's head."""
        return self.block(x)

class UNetGeneratorOptimized(nn.Module):
    """U-Net generator with 8 downsampling and 8 upsampling stages joined by skip connections.

    Each encoder stage halves H and W and each decoder stage doubles them, so
    input height and width must be divisible by 2**8 = 256 for the skip
    concatenations to line up. The last decoder stage ends in Tanh, so the
    output lies in [-1, 1].

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Base width; deeper stages use 2x, 4x and 8x this value.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f1, f2, f4, f8 = features, features * 2, features * 4, features * 8

        # Encoder (attribute names and order are kept so checkpoints keep loading)
        self.enc1 = ConvBlock(in_channels, f1, use_batchnorm=False)
        self.enc2 = ConvBlock(f1, f2)
        self.enc3 = ConvBlock(f2, f4)
        self.enc4 = ConvBlock(f4, f8)
        self.enc5 = ConvBlock(f8, f8)
        self.enc6 = ConvBlock(f8, f8)
        self.enc7 = ConvBlock(f8, f8)
        self.enc8 = ConvBlock(f8, f8, use_batchnorm=False)  # bottleneck

        # Decoder: every stage after the first receives [previous output, skip]
        # concatenated on the channel axis, hence the doubled input widths.
        self.dec1 = UpConvBlock(f8, f8, dropout=True)
        self.dec2 = UpConvBlock(2 * f8, f8, dropout=True)
        self.dec3 = UpConvBlock(2 * f8, f8, dropout=True)
        self.dec4 = UpConvBlock(2 * f8, f8)
        self.dec5 = UpConvBlock(2 * f8, f4)
        self.dec6 = UpConvBlock(2 * f4, f2)
        self.dec7 = UpConvBlock(2 * f2, f1)
        self.dec8 = UpConvBlock(2 * f1, out_channels, final=True)

    def forward(self, x):
        """Encode ``x`` down to the bottleneck, then decode with mirrored skip connections."""
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4,
                    self.enc5, self.enc6, self.enc7, self.enc8)
        decoders = (self.dec1, self.dec2, self.dec3, self.dec4,
                    self.dec5, self.dec6, self.dec7, self.dec8)

        # skips ends up holding [e1, ..., e8]; popping yields them deepest-first.
        skips = []
        h = x
        for encoder in encoders:
            h = encoder(h)
            skips.append(h)

        # The bottleneck output e8 feeds dec1 on its own; dec2..dec8 pair with e7..e1.
        h = decoders[0](skips.pop())
        for decoder in decoders[1:]:
            h = decoder(torch.cat([h, skips.pop()], dim=1))
        return h


PairedImageDataset with data augmentation, and the PatchGAN discriminator


In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class PairedImageDataset(Dataset):
    """Dataset of aligned (input, target) image pairs for image-to-image translation.

    Images are paired by position after sorting each directory's file paths,
    so ``input_dir`` and ``output_dir`` must hold the same number of images
    whose names sort in the same order.

    Args:
        input_dir: Directory with the source images (``.jpg`` / ``.png``).
        output_dir: Directory with the corresponding target images.
        augment: If True, apply a random horizontal flip, a random rotation of
            up to +/-10 degrees and a random color jitter. The SAME random
            parameters are used for an input and its target, so the pair stays
            pixel-aligned. (Previously each image drew its own parameters,
            which misaligned the pairs.)

    Each item is an ``(input, target)`` tuple of 3-channel float tensors
    resized to 1024x1024 and normalized from [0, 1] to [-1, 1].

    ``self.transform`` is the deterministic pipeline (resize -> tensor ->
    normalize) and applies no augmentation.

    Raises:
        ValueError: If the two directories contain a different number of images.
    """

    def __init__(self, input_dir, output_dir, augment=False):
        self.input_paths = self._list_images(input_dir)
        self.output_paths = self._list_images(output_dir)
        if len(self.input_paths) != len(self.output_paths):
            raise ValueError(
                f"Found {len(self.input_paths)} images in {input_dir!r} but "
                f"{len(self.output_paths)} in {output_dir!r}; inputs and targets "
                "are paired by sorted position, so the counts must match."
            )

        self.augment = augment
        self.resize = transforms.Resize((1024, 1024))
        self.max_rotation = 10  # degrees, same range as RandomRotation(degrees=10)
        self.color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05)
        self.to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),  # [0,1] → [-1,1]
        ])
        self.transform = transforms.Compose([self.resize, self.to_normalized_tensor])

    @staticmethod
    def _list_images(directory):
        """Return the sorted full paths of files in ``directory`` ending in 'jpg' or 'png'."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(('jpg', 'png'))
        )

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB image and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _augment_pair(self, x, y):
        """Apply one random flip/rotation/color-jitter draw identically to both PIL images."""
        import torch
        from torchvision.transforms import functional as TF

        if torch.rand(1).item() < 0.5:
            x, y = TF.hflip(x), TF.hflip(y)

        angle = transforms.RandomRotation.get_params([-self.max_rotation, self.max_rotation])
        x, y = TF.rotate(x, angle), TF.rotate(y, angle)

        # NOTE(review): jittering the target also changes the colors the generator
        # is asked to reproduce; consider jittering the input only.
        jitter = self.color_jitter
        order, brightness, contrast, saturation, hue = jitter.get_params(
            jitter.brightness, jitter.contrast, jitter.saturation, jitter.hue
        )
        adjustments = {
            0: (TF.adjust_brightness, brightness),
            1: (TF.adjust_contrast, contrast),
            2: (TF.adjust_saturation, saturation),
            3: (TF.adjust_hue, hue),
        }
        for fn_id in order.tolist():
            adjust, factor = adjustments[fn_id]
            if factor is not None:
                x, y = adjust(x, factor), adjust(y, factor)
        return x, y

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        x = self.resize(self._load_rgb(self.input_paths[idx]))
        y = self.resize(self._load_rgb(self.output_paths[idx]))
        if self.augment:
            x, y = self._augment_pair(x, y)
        return self.to_normalized_tensor(x), self.to_normalized_tensor(y)

class PatchDiscriminator(nn.Module):
    """Conditional PatchGAN discriminator scoring (input, candidate) image pairs.

    The two images are concatenated on the channel axis and passed through
    three 4x4 stride-2 convolutions (BatchNorm on all but the first, LeakyReLU(0.2)
    throughout) and a final 4x4 stride-1 convolution producing one raw logit per patch.

    Args:
        in_channels: Channels of the concatenated pair (3 + 3 for two RGB images).
        features: Width of the first convolution; later layers use 2x and 4x.
    """

    def __init__(self, in_channels=6, features=64):
        super().__init__()
        widths = [features, features * 2, features * 4]
        layers = [nn.Conv2d(in_channels, widths[0], 4, 2, 1), nn.LeakyReLU(0.2)]
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2)]
        layers.append(nn.Conv2d(widths[-1], 1, 4, 1, 1))
        # Same layer indices as before ("model.<i>.*"), so saved checkpoints still load.
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Return the patch logit map for the channel-wise concatenation of ``x`` and ``y``."""
        pair = torch.cat((x, y), dim=1)
        return self.model(pair)

In [None]:
import csv
import os

import matplotlib.pyplot as plt
import torch
import torch.multiprocessing
import torchvision.utils as vutils
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.multiprocessing.set_start_method('spawn', force=True)

# === Configuration ===
ROOT_DIR = "/content/drive/MyDrive/GANPathology"
DATASET_DIR = os.path.join(ROOT_DIR, "dataset")
CHECKPOINT_DIR = os.path.join(ROOT_DIR, "checkpoints")
FAKE_OUTPUT_DIR = os.path.join(ROOT_DIR, "output_fake")
LOSS_LOG_PATH = os.path.join(ROOT_DIR, "loss_log.csv")

NUM_EPOCHS = 500
LR = 2e-4
BETAS = (0.5, 0.999)
LAMBDA_L1 = 100  # weight of the L1 reconstruction term relative to the adversarial term
LOG_EVERY = 10   # print training losses every N steps
SEED = 42


def _train_one_epoch(G, D, loader, optimizer_G, optimizer_D, criterion_GAN, criterion_L1,
                     device, epoch, num_epochs):
    """Run one training epoch over ``loader``; return mean (G_L1, G_GAN, D) losses per step."""
    G.train()
    D.train()
    total_G_L1 = total_G_GAN = total_D = 0.0

    for i, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        fake_y = G(x)

        # Discriminator: real pairs -> 1, generated pairs -> 0. detach() stops this
        # loss from back-propagating into G.
        D_real = D(x, y)
        D_fake = D(x, fake_y.detach())
        loss_D = 0.5 * (
            criterion_GAN(D_real, torch.ones_like(D_real)) +
            criterion_GAN(D_fake, torch.zeros_like(D_fake))
        )
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # Generator: make the just-updated D label fakes as real while staying close to y in L1.
        D_fake = D(x, fake_y)
        loss_G_GAN = criterion_GAN(D_fake, torch.ones_like(D_fake))
        loss_G_L1 = criterion_L1(fake_y, y)
        loss_G = loss_G_GAN + LAMBDA_L1 * loss_G_L1

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        total_G_GAN += loss_G_GAN.item()
        total_G_L1 += loss_G_L1.item()
        total_D += loss_D.item()

        if i % LOG_EVERY == 0:
            print(f"[Train] Epoch [{epoch}/{num_epochs}] Step [{i}/{len(loader)}] "
                  f"Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

    n_steps = len(loader)
    return total_G_L1 / n_steps, total_G_GAN / n_steps, total_D / n_steps


def _show_comparison(input_img, fake_img, target_img, epoch):
    """Display Input | Generated | Target side by side, mapping [-1, 1] back to [0, 1]."""
    panels = [((img.cpu() + 1) / 2).clamp(0, 1) for img in (input_img, fake_img, target_img)]
    grid_img = vutils.make_grid(torch.stack(panels), nrow=3)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.imshow(to_pil_image(grid_img))
    ax.set_title(f"Epoch {epoch}: Input | Generated | Target")
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch


def _validate(G, loader, criterion_L1, device, epoch):
    """Return G's mean per-batch L1 loss on ``loader``; show a comparison for the first batch."""
    G.eval()
    total_val_L1 = 0.0
    with torch.no_grad():
        for j, (x_val, y_val) in enumerate(loader):
            x_val, y_val = x_val.to(device), y_val.to(device)
            fake_val = G(x_val)
            total_val_L1 += criterion_L1(fake_val, y_val).item()
            if j == 0:
                _show_comparison(x_val[0], fake_val[0], y_val[0], epoch)
    return total_val_L1 / len(loader)


def _save_checkpoint(G, D, prefix, epoch):
    """Save G and D state_dicts as '<prefix>generator_epoch_N.pth' / '<prefix>discriminator_epoch_N.pth'."""
    torch.save(G.state_dict(), os.path.join(CHECKPOINT_DIR, f"{prefix}generator_epoch_{epoch}.pth"))
    torch.save(D.state_dict(), os.path.join(CHECKPOINT_DIR, f"{prefix}discriminator_epoch_{epoch}.pth"))


if __name__ == "__main__":
    # === Create output directories ===
    os.makedirs(FAKE_OUTPUT_DIR, exist_ok=True)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    torch.manual_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # === Data paths ===
    input_train = os.path.join(DATASET_DIR, "Input", "training")
    output_train = os.path.join(DATASET_DIR, "output", "training")
    input_val = os.path.join(DATASET_DIR, "Input", "validation")
    output_val = os.path.join(DATASET_DIR, "output", "validation")

    # === Load datasets ===
    train_dataset = PairedImageDataset(input_train, output_train, augment=True)
    val_dataset = PairedImageDataset(input_val, output_val, augment=False)
    # Fail fast instead of hitting a ZeroDivisionError after a full epoch.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise RuntimeError(
            f"Empty dataset: {len(train_dataset)} training and {len(val_dataset)} "
            f"validation pairs found under {DATASET_DIR!r}."
        )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

    # === Model, loss, optimizers ===
    G = UNetGeneratorOptimized().to(device)
    D = PatchDiscriminator().to(device)

    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()

    optimizer_G = optim.Adam(G.parameters(), lr=LR, betas=BETAS)
    optimizer_D = optim.Adam(D.parameters(), lr=LR, betas=BETAS)

    # === Loss tracking ===
    with open(LOSS_LOG_PATH, mode='w', newline='') as f:
        csv.writer(f).writerow(["Epoch", "Train_G_L1", "Train_G_GAN", "Train_D", "Val_L1"])

    num_epochs = NUM_EPOCHS
    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        avg_G_L1, avg_G_GAN, avg_D = _train_one_epoch(
            G, D, train_loader, optimizer_G, optimizer_D, criterion_GAN, criterion_L1,
            device, epoch, num_epochs,
        )
        avg_val_L1 = _validate(G, val_loader, criterion_L1, device, epoch)

        print(f"[Validation] Epoch [{epoch}/{num_epochs}] Avg L1 Loss: {avg_val_L1:.4f}")

        # Append this epoch's averages so the run can be inspected/plotted afterwards.
        with open(LOSS_LOG_PATH, mode='a', newline='') as f:
            csv.writer(f).writerow([epoch, avg_G_L1, avg_G_GAN, avg_D, avg_val_L1])

        # Save per-epoch models
        _save_checkpoint(G, D, "", epoch)

        # Save best models
        if avg_val_L1 < best_val_loss:
            best_val_loss = avg_val_L1
            _save_checkpoint(G, D, "best_", epoch)
            print(f"✅ Saved best models at epoch {epoch} with val L1 loss: {avg_val_L1:.4f}")

    print("✅ Training complete.")