In [11]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.optim.lr_scheduler import LambdaLR
import matplotlib.pyplot as plt

## Dataset Setup

In [2]:
root_dir = "/ix1/qgu/ngl18/VSCCset1VirtualStaining/"
source_files = sorted(list(glob.iglob(f"{root_dir}/**/he/**.npy")))
target_files = sorted(list(glob.iglob(f"{root_dir}/**/ihc/**.npy")))

In [3]:
from data import VirtualStainingDataset
joint_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Affine(scale=(0.98, 1.02), translate_percent=0.05, rotate=(-5, 5), p=0.7),
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=5, p=0.2),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.7),
    A.RandomGamma(p=0.3),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2()
], additional_targets={'target': 'image'}) 

virtual_staining_dataset = VirtualStainingDataset(source_files, target_files, transform=joint_transform)
virtual_staining_dataloader = DataLoader(virtual_staining_dataset, batch_size=64, shuffle=True)

### Model Setup

In [4]:
from models import Generator, PatchDiscriminator

# Setup for device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G = Generator(3).to(device)
D = PatchDiscriminator(3).to(device)

In [5]:
from perceptual import VGGPerceptualLoss
from pytorch_msssim import ssim
import torch.nn.functional as F

# Model Training Parameters

# Loss function
adversarial_loss = nn.MSELoss()  # Adversarial loss (mean squared error)
pixelwise_loss = nn.L1Loss()  # Pixel-wise loss (L1 loss)
percep_loss_fn = VGGPerceptualLoss().to(device)

# Learning Rates
lr_G = 0.0002 
lr_D = 0.0001

# Optimizers
optimizer_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0.5, 0.999))

total_steps = 150000
lambda_stop = int(0.3 * total_steps)
start_decay = 75000


# === Scheduler: Linear decay after 25k steps to 50k ===
def linear_decay_lambda(step, start=start_decay, end=total_steps):
    if step < start:
        return 1.0
    else:
        return max(0.0, 1.0 - (step - start) / (end - start))

# === Training Params ===
log_interval = 5000
curr_step = 0

scheduler_G = LambdaLR(optimizer_G, lr_lambda=linear_decay_lambda)



In [8]:
def get_lambda_ssim(curr_step, start=1, end=5.0, total_steps=50000):
    progress = min(curr_step / total_steps, 1.0)
    return start + progress * (end - start)

def get_lambda_percep(curr_step, start=1, end=3.0, total_steps=50000):
    progress = min(curr_step / total_steps, 1.0)
    return start + progress * (end - start)

def decay_lambda_l1(curr_step, start=10, end=2, total_steps=50000):
    progress = min(curr_step / total_steps, 1.0)
    return start - progress * (start - end)

### Model Training

In [None]:
from itertools import cycle

cyclic_loader = cycle(virtual_staining_dataloader)

for real_images, target_images in cyclic_loader:
    if curr_step >= total_steps:
        break

    # === Move to device ===
    real_images = real_images.to(device)
    target_images = target_images.to(device)

    # === Discriminator ===
    optimizer_D.zero_grad()
    fake_images = G(real_images)

    real_preds = D(target_images)
    fake_preds = D(fake_images.detach())
    real_loss = adversarial_loss(real_preds, torch.ones_like(real_preds))
    fake_loss = adversarial_loss(fake_preds, torch.zeros_like(fake_preds))
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    optimizer_D.step()

    # === Generator ===
    optimizer_G.zero_grad()
    fake_preds = D(fake_images)
    g_adv = adversarial_loss(fake_preds, torch.ones_like(fake_preds))
    g_l1 = pixelwise_loss(fake_images, target_images)
    g_ssim = 1 - ssim(fake_images, target_images, data_range=2.0, size_average=True)
    g_percep = percep_loss_fn(fake_images, target_images)
    g_loss = g_adv + decay_lambda_l1(curr_step, total_steps=lambda_stop) * g_l1 + get_lambda_ssim(curr_step, total_steps=lambda_stop) * g_ssim + get_lambda_percep(curr_step, total_steps=lambda_stop) * g_percep
    g_loss.backward()
    optimizer_G.step()
    scheduler_G.step()

    # === Logging & Saving ===
    if curr_step % log_interval == 0 or (curr_step + 1 >= total_steps):
        lr = scheduler_G.get_last_lr()[0]
        print(f"[Step {curr_step}] D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}, LR: {lr:.6f}")

        # === Visualization ===
        D.eval()
        with torch.no_grad():
            he_sample = real_images[:4]
            ihc_sample = target_images[:4]
            ihc_pred = G(he_sample)
            def to_img(x): return (x * 0.5 + 0.5).clamp(0, 1)

            fig, axs = plt.subplots(3, 4, figsize=(12, 9))
            for j in range(4):
                axs[0, j].imshow(to_img(he_sample[j].permute(1, 2, 0)).cpu())
                axs[0, j].set_title("H&E")
                axs[1, j].imshow(to_img(ihc_sample[j].permute(1, 2, 0)).cpu())
                axs[1, j].set_title(f"Real IHC")
                axs[2, j].imshow(to_img(ihc_pred[j].permute(1, 2, 0)).cpu())
                axs[2, j].set_title(f"Virtual IHC")
                for row in axs: row[j].axis('off')
            plt.tight_layout()
            plt.savefig(f"/ix1/qgu/ngl18/VirtualStaining/outputs/vscc/step_{curr_step}.png")
            plt.close()

        # === Save Checkpoint ===
        torch.save({
            "step": curr_step,
            "generator": G.state_dict(),
            "discriminator": D.state_dict(),
            "optimizer_G": optimizer_G.state_dict(),
            "optimizer_D": optimizer_D.state_dict(),
            "scheduler_G": scheduler_G.state_dict(),
        }, f"/ix1/qgu/ngl18/VirtualStaining/ckpts/vscc/step_{curr_step}.pth")

    curr_step += 1



[Step 0] D Loss: 0.4288, G Loss: 9.0413, LR: 0.000200
