In [1]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

# Configuration based on paper recommendations
SEED = 60
IMG_SIZE = 256
BATCH_SIZE = 8
LATENT_DIM = 256

# Modified Configuration
LAMBDA_GP = 5.0  # Reduced from 10
LAMBDA_PIXEL = 3.0  # Reduced from 5.0
LAMBDA_PEARSON = 5.0  # Increased from 2.0
LR = 0.0001  # Reduced from 0.0002
N_CRITIC = 3  # Reduced from 5
BASE_CHANNELS = 128  # Increased capacity
BETAS = (0.5, 0.9)  # More conservative momentum
EPOCHS = 1000

# Set seed for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Enhanced Dataset Class
class MedicalDataset(Dataset):
    def __init__(self, root_dir, img_size=256):
        self.image_paths = []
        self.labels = []
        self.img_size = img_size
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.Lambda(lambda x: 2 * x - 1)  # Normalize to [-1, 1]
        ])

        # Load data with binary labels (0/1)
        for label in ['0', '1']:
            folder = os.path.join(root_dir, label)
            if os.path.exists(folder):
                for fname in os.listdir(folder):
                    if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                        self.image_paths.append(os.path.join(folder, fname))
                        self.labels.append(int(label))

    def __getitem__(self, idx):
        img = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.img_size, self.img_size))
        img = img.astype(np.float32) / 255.0  # Normalize to [0, 1]
        img = torch.tensor(img).unsqueeze(0)  # Add channel dimension
        img = self.transform(img)
        return img, torch.tensor(self.labels[idx], dtype=torch.float32)

    def __len__(self):
        return len(self.image_paths)

# Residual Block for Generator (StyleGAN-inspired)
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2d(channels, channels, 3, padding=1)),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(channels, channels, 3, padding=1)),
            nn.BatchNorm2d(channels)
        )
        
    def forward(self, x):
        return F.leaky_relu(x + self.conv(x), 0.2)

# Enhanced Generator with Residual Connections
class MedicalGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.init_size = IMG_SIZE // 32
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(LATENT_DIM + 1, BASE_CHANNELS * 32 * self.init_size ** 2)),
            nn.BatchNorm1d(BASE_CHANNELS * 32 * self.init_size ** 2),
            nn.LeakyReLU(0.2)
        )
        
        self.blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*32, BASE_CHANNELS*16, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*16),
            nn.LeakyReLU(0.2),
            ResidualBlock(BASE_CHANNELS*16),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*16, BASE_CHANNELS*8, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*8),
            nn.LeakyReLU(0.2),
            ResidualBlock(BASE_CHANNELS*8),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*8, BASE_CHANNELS*4, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*4),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*4, BASE_CHANNELS*2, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*2),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*2, BASE_CHANNELS, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS, 1, 3, padding=1)),
            nn.Tanh()
        )

    def forward(self, z, labels):
        labels = labels.view(-1, 1)
        z = torch.cat([z, labels], 1)
        x = self.fc(z).view(-1, BASE_CHANNELS*32, self.init_size, self.init_size)
        return self.blocks(x)

# Enhanced Discriminator with Spectral Norm
class MedicalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Conv2d(2, BASE_CHANNELS, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS, BASE_CHANNELS*2, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*2),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*2, BASE_CHANNELS*4, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*4),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*4, BASE_CHANNELS*8, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*8),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*8, BASE_CHANNELS*16, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*16),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*16, 1, 4, 1, 0))
        )

    def forward(self, img, labels):
        labels = labels.view(-1, 1, 1, 1).expand(-1, 1, IMG_SIZE, IMG_SIZE)
        x = torch.cat([img, labels], 1)
        return self.model(x).view(-1)

# Pearson Correlation Loss Function
def pearson_correlation_loss(x, y):
    """
    Computes Pearson correlation coefficient loss between two batches of images
    """
    # Flatten spatial dimensions
    x_flat = x.view(x.size(0), -1)
    y_flat = y.view(y.size(0), -1)
    
    # Compute means
    x_mean = torch.mean(x_flat, dim=1, keepdim=True)
    y_mean = torch.mean(y_flat, dim=1, keepdim=True)
    
    # Compute covariance and variances
    covariance = torch.mean((x_flat - x_mean) * (y_flat - y_mean), dim=1)
    x_var = torch.var(x_flat, dim=1)
    y_var = torch.var(y_flat, dim=1)
    
    # Compute Pearson correlation
    pearson = covariance / (torch.sqrt(x_var) * torch.sqrt(y_var) + 1e-8)
    
    # Return loss as 1 - average correlation
    return 1 - torch.mean(pearson)

# Complete GAN Training System
class StableGAN:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Initialize networks
        self.netG = MedicalGenerator().to(self.device)
        self.netD = MedicalDiscriminator().to(self.device)
        
        # Optimizers
        self.optimG = optim.Adam(self.netG.parameters(), lr=LR, betas=BETAS)
        self.optimD = optim.Adam(self.netD.parameters(), lr=LR, betas=BETAS)
        
        # Data loading
        self.dataset = MedicalDataset("medical_data")
        self.loader = DataLoader(self.dataset, BATCH_SIZE, shuffle=True, num_workers=4)
        
        # Tracking
        self.metrics = {'d_loss': [], 'g_loss': [], 'pearson': []}
        self.fixed_z = torch.randn(2, LATENT_DIM, device=self.device)  # For consistent sampling

    def compute_gradient_penalty(self, real_samples, fake_samples, labels):
        """Calculates the gradient penalty loss for WGAN-GP"""
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=self.device)
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = self.netD(interpolates, labels)
        
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def train_epoch(self, epoch):
        self.netG.train()
        self.netD.train()
        
        for batch_idx, (real_imgs, labels) in enumerate(self.loader):
            real_imgs = real_imgs.to(self.device)
            labels = labels.to(self.device)
            batch_size = real_imgs.size(0)
            
            # --- Discriminator Update ---
            self.optimD.zero_grad()
            
            # Generate fake images
            z = torch.randn(batch_size, LATENT_DIM, device=self.device)
            fake_imgs = self.netG(z, labels)
            
            # Compute discriminator losses
            d_real = self.netD(real_imgs, labels)
            d_fake = self.netD(fake_imgs.detach(), labels)
            
            # Gradient penalty
            gp = self.compute_gradient_penalty(real_imgs, fake_imgs, labels)
            
            # WGAN-GP loss
            d_loss = -torch.mean(d_real) + torch.mean(d_fake) + LAMBDA_GP * gp
            d_loss.backward()
            self.optimD.step()
            
            # --- Generator Update ---
            if batch_idx % N_CRITIC == 0:
                self.optimG.zero_grad()
                
                # Generate new fake images
                z = torch.randn(batch_size, LATENT_DIM, device=self.device)
                fake_imgs = self.netG(z, labels)
                
                # Adversarial loss
                g_adv = -torch.mean(self.netD(fake_imgs, labels))
                
                # Pixel-wise L1 loss
                px_loss = F.l1_loss(fake_imgs, real_imgs)
                
                # Pearson correlation loss
                pearson_loss = pearson_correlation_loss(fake_imgs, real_imgs)
                
                # Total generator loss
                g_total = g_adv + LAMBDA_PIXEL*px_loss + LAMBDA_PEARSON*pearson_loss
                g_total.backward()
                self.optimG.step()
                
                # Update metrics
                with torch.no_grad():
                    self.metrics['g_loss'].append(g_total.item())
                    self.metrics['pearson'].append(1 - pearson_loss.item())  # Store actual correlation
                    
            # Progress reporting
            if batch_idx % 25 == 0:
                print(f"Epoch [{epoch}/{EPOCHS}] Batch [{batch_idx}/{len(self.loader)}] "
                      f"D_loss: {d_loss.item():.4f} G_loss: {g_total.item():.4f} "
                      f"Pearson: {1 - pearson_loss.item():.4f}")

    def save_samples(self, epoch):
        """Saves generated samples for both labels"""
        self.netG.eval()
        os.makedirs("samples", exist_ok=True)
        
        with torch.no_grad():
            for label in [0, 1]:
                # Generate images for current label
                labels = torch.full((2,), label, device=self.device, dtype=torch.float32)
                gen_imgs = self.netG(self.fixed_z, labels)
                
                # Process and save images
                for i in range(2):
                    img = gen_imgs[i].cpu().squeeze().numpy()
                    img = (img + 1) * 127.5  # Convert from [-1,1] to [0,255]
                    img = img.astype(np.uint8)
                    cv2.imwrite(f"samples/epoch_{epoch}_label_{label}_sample_{i}.png", img)
        
        self.netG.train()

    def train(self):
        """Main training loop"""
        for epoch in range(1, EPOCHS + 1):
            self.train_epoch(epoch)
            
            # Save checkpoints
            
            # Save samples
            if epoch % 200 == 0:
                self.save_samples(epoch)
                print(f"Saved samples for epoch {epoch}")
                
        # Final save
        torch.save({
            'generator': self.netG.state_dict(),
            'discriminator': self.netD.state_dict(),
            'metrics': self.metrics
        }, "gan_final.pth")

if __name__ == "__main__":
    gan = StableGAN()
    gan.train()

AttributeError: Can't pickle local object 'MedicalDataset.__init__.<locals>.<lambda>'

In [1]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from time import time
from datetime import datetime
from sklearn.metrics import confusion_matrix
import json

# Configuration based on paper recommendations
SEED = 60
IMG_SIZE = 256
BATCH_SIZE = 8
LATENT_DIM = 256

# Modified Configuration
LAMBDA_GP = 5.0  # Reduced from 10
LAMBDA_PIXEL = 3.0  # Reduced from 5.0
LAMBDA_PEARSON = 5.0  # Increased from 2.0
LR = 0.0001  # Reduced from 0.0002
N_CRITIC = 3  # Reduced from 5
BASE_CHANNELS = 128  # Increased capacity
BETAS = (0.5, 0.9)  # More conservative momentum
EPOCHS = 1000

# Set seed for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Define normalize function outside of the transform to make it picklable
def normalize_image(x):
    return 2 * x - 1  # Normalize to [-1, 1]

# Enhanced Dataset Class
class MedicalDataset(Dataset):
    def __init__(self, root_dir, img_size=256):
        self.image_paths = []
        self.labels = []
        self.img_size = img_size
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.Lambda(normalize_image)  # Using named function instead of lambda
        ])

        # Load data with binary labels (0/1)
        for label in ['0', '1']:
            folder = os.path.join(root_dir, label)
            if os.path.exists(folder):
                for fname in os.listdir(folder):
                    if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                        self.image_paths.append(os.path.join(folder, fname))
                        self.labels.append(int(label))

    def __getitem__(self, idx):
        img = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.img_size, self.img_size))
        img = img.astype(np.float32) / 255.0  # Normalize to [0, 1]
        img = torch.tensor(img).unsqueeze(0)  # Add channel dimension
        img = self.transform(img)
        return img, torch.tensor(self.labels[idx], dtype=torch.float32)

    def __len__(self):
        return len(self.image_paths)

# Residual Block for Generator (StyleGAN-inspired)
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2d(channels, channels, 3, padding=1)),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(channels, channels, 3, padding=1)),
            nn.BatchNorm2d(channels)
        )
        
    def forward(self, x):
        return F.leaky_relu(x + self.conv(x), 0.2)

# Enhanced Generator with Residual Connections
class MedicalGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.init_size = IMG_SIZE // 32
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(LATENT_DIM + 1, BASE_CHANNELS * 32 * self.init_size ** 2)),
            nn.BatchNorm1d(BASE_CHANNELS * 32 * self.init_size ** 2),
            nn.LeakyReLU(0.2)
        )
        
        self.blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*32, BASE_CHANNELS*16, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*16),
            nn.LeakyReLU(0.2),
            ResidualBlock(BASE_CHANNELS*16),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*16, BASE_CHANNELS*8, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*8),
            nn.LeakyReLU(0.2),
            ResidualBlock(BASE_CHANNELS*8),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*8, BASE_CHANNELS*4, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*4),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*4, BASE_CHANNELS*2, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS*2),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*2, BASE_CHANNELS, 3, padding=1)),
            nn.BatchNorm2d(BASE_CHANNELS),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS, 1, 3, padding=1)),
            nn.Tanh()
        )

    def forward(self, z, labels):
        labels = labels.view(-1, 1)
        z = torch.cat([z, labels], 1)
        x = self.fc(z).view(-1, BASE_CHANNELS*32, self.init_size, self.init_size)
        return self.blocks(x)

# Enhanced Discriminator with Spectral Norm
class MedicalDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Conv2d(2, BASE_CHANNELS, 4, 2, 1)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS, BASE_CHANNELS*2, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*2),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*2, BASE_CHANNELS*4, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*4),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*4, BASE_CHANNELS*8, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*8),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*8, BASE_CHANNELS*16, 4, 2, 1)),
            nn.InstanceNorm2d(BASE_CHANNELS*16),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Conv2d(BASE_CHANNELS*16, 1, 4, 1, 0))
        )

    def forward(self, img, labels):
        labels = labels.view(-1, 1, 1, 1).expand(-1, 1, IMG_SIZE, IMG_SIZE)
        x = torch.cat([img, labels], 1)
        return self.model(x).view(-1)

# Classifier for evaluating generated images
class MedicalClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

# Structural Similarity Index (SSIM) for image quality assessment
def ssim(img1, img2, window_size=11, size_average=True):
    # Value range can be different from 0 to 255. So normalize images.
    if img1.max() > 128:
        img1 = img1 / 255.0
    if img2.max() > 128:
        img2 = img2 / 255.0
    
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    
    # Compute means
    mu1 = F.avg_pool2d(img1, window_size, stride=1, padding=window_size//2)
    mu2 = F.avg_pool2d(img2, window_size, stride=1, padding=window_size//2)
    
    # Compute squares of means
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    
    # Compute variances and covariance
    sigma1_sq = F.avg_pool2d(img1 * img1, window_size, stride=1, padding=window_size//2) - mu1_sq
    sigma2_sq = F.avg_pool2d(img2 * img2, window_size, stride=1, padding=window_size//2) - mu2_sq
    sigma12 = F.avg_pool2d(img1 * img2, window_size, stride=1, padding=window_size//2) - mu1_mu2
    
    # SSIM constants
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    
    # Compute SSIM
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    
    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

# Pearson Correlation Loss Function
def pearson_correlation_loss(x, y):
    """
    Computes Pearson correlation coefficient loss between two batches of images
    """
    # Flatten spatial dimensions
    x_flat = x.view(x.size(0), -1)
    y_flat = y.view(y.size(0), -1)
    
    # Compute means
    x_mean = torch.mean(x_flat, dim=1, keepdim=True)
    y_mean = torch.mean(y_flat, dim=1, keepdim=True)
    
    # Compute covariance and variances
    covariance = torch.mean((x_flat - x_mean) * (y_flat - y_mean), dim=1)
    x_var = torch.var(x_flat, dim=1)
    y_var = torch.var(y_flat, dim=1)
    
    # Compute Pearson correlation
    pearson = covariance / (torch.sqrt(x_var) * torch.sqrt(y_var) + 1e-8)
    
    # Return loss as 1 - average correlation
    return 1 - torch.mean(pearson)

# Compute Frechet Inception Distance (FID) approximation
def compute_fid_proxy(real_features, fake_features):
    """
    Simplified FID calculation since we don't have Inception network.
    This uses feature embeddings from our discriminator/classifier instead.
    """
    # Calculate mean and covariance for real and fake features
    mu1 = torch.mean(real_features, dim=0)
    sigma1 = torch.cov(real_features.T)
    
    mu2 = torch.mean(fake_features, dim=0)
    sigma2 = torch.cov(fake_features.T)
    
    # Calculate squared difference between means
    diff = mu1 - mu2
    mean_diff = torch.sum(diff * diff)
    
    # Calculate matrix sqrt - approximation for FID
    # This is a simplified version - true FID uses matrix sqrt
    covmean = torch.sqrt(sigma1 * sigma2 + 1e-8)
    
    # Calculate trace
    trace_term = torch.trace(sigma1 + sigma2 - 2 * covmean)
    
    # FID
    fid = mean_diff + trace_term
    return fid

# Complete GAN Training System
class StableGAN:
    def __init__(self, data_folder="medical_data"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # Initialize networks
        self.netG = MedicalGenerator().to(self.device)
        self.netD = MedicalDiscriminator().to(self.device)
        
        # Add classifier for quality evaluation
        self.classifier = MedicalClassifier().to(self.device)
        self.classifier_optim = optim.Adam(self.classifier.parameters(), lr=LR*0.5)
        self.classifier_criterion = nn.BCELoss()
        
        # Optimizers
        self.optimG = optim.Adam(self.netG.parameters(), lr=LR, betas=BETAS)
        self.optimD = optim.Adam(self.netD.parameters(), lr=LR, betas=BETAS)
        
        # Learning rate schedulers
        self.schedulerG = optim.lr_scheduler.ReduceLROnPlateau(self.optimG, 'min', factor=0.5, patience=30, verbose=True)
        self.schedulerD = optim.lr_scheduler.ReduceLROnPlateau(self.optimD, 'min', factor=0.5, patience=30, verbose=True)
        
        # Data loading - use num_workers=0 to avoid multiprocessing issues with lambda
        self.dataset = MedicalDataset(data_folder, img_size=IMG_SIZE)
        self.loader = DataLoader(self.dataset, BATCH_SIZE, shuffle=True, num_workers=0)
        
        # Tracking
        self.metrics = {
            'd_loss': [], 'g_loss': [], 'pearson': [], 'ssim': [], 
            'fid_proxy': [], 'classifier_acc': [], 'train_time': [],
            'inception_score_proxy': [], 'psnr': []
        }
        
        # Create fixed noise for consistent sampling
        self.fixed_z = torch.randn(8, LATENT_DIM, device=self.device)  # Increased samples
        self.fixed_labels_0 = torch.zeros(4, device=self.device)
        self.fixed_labels_1 = torch.ones(4, device=self.device)
        self.fixed_labels = torch.cat([self.fixed_labels_0, self.fixed_labels_1])
        
        # Create run directory
        self.run_dir = f"gan_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(f"{self.run_dir}/samples", exist_ok=True)
        os.makedirs(f"{self.run_dir}/checkpoints", exist_ok=True)

    def compute_gradient_penalty(self, real_samples, fake_samples, labels):
        """Calculates the gradient penalty loss for WGAN-GP"""
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=self.device)
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = self.netD(interpolates, labels)
        
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def train_classifier(self, real_imgs, fake_imgs, labels, steps=1):
        """Train the classifier on both real and fake images"""
        for _ in range(steps):
            self.classifier_optim.zero_grad()
            
            # Create a batch with both real and fake images
            all_imgs = torch.cat([real_imgs, fake_imgs.detach()], 0)
            all_labels = torch.cat([labels, labels], 0)
            
            # Forward pass
            pred = self.classifier(all_imgs).squeeze()
            loss = self.classifier_criterion(pred, all_labels)
            
            # Backward pass
            loss.backward()
            self.classifier_optim.step()
            
            # Calculate accuracy
            pred_binary = (pred > 0.5).float()
            acc = (pred_binary == all_labels).float().mean()
            
        return acc.item()

    def compute_psnr(self, real_imgs, fake_imgs):
        """Compute Peak Signal-to-Noise Ratio"""
        mse = F.mse_loss(fake_imgs, real_imgs)
        if mse == 0:
            return 100
        max_pixel = 2.0  # Since images are normalized to [-1, 1]
        psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
        return psnr.item()

    def compute_inception_score_proxy(self, fake_imgs, eps=1e-8):
        """Compute a proxy for inception score using our classifier"""
        self.classifier.eval()
        with torch.no_grad():
            preds = self.classifier(fake_imgs).squeeze()
            # Calculate proxy for inception score
            kl_div = preds * torch.log(preds + eps) + (1 - preds) * torch.log(1 - preds + eps)
            inception_score = torch.exp(torch.mean(kl_div))
        self.classifier.train()
        return inception_score.item()

    def extract_features(self, imgs):
        """Extract features from the classifier for FID calculation"""
        # Use the second-to-last layer of the classifier as features
        x = self.classifier.model[0](imgs)
        x = self.classifier.model[1](x)
        x = self.classifier.model[2](x)
        x = self.classifier.model[3](x)
        x = self.classifier.model[4](x)
        x = self.classifier.model[5](x)
        x = self.classifier.model[6](x)
        x = self.classifier.model[7](x)
        x = self.classifier.model[8](x)
        x = self.classifier.model[9](x)
        x = self.classifier.model[10](x)
        x = self.classifier.model[11](x)
        x = self.classifier.model[12](x)
        x = self.classifier.model[13](x)
        # Return flattened features
        return x.view(x.size(0), -1)

    def train_epoch(self, epoch):
        self.netG.train()
        self.netD.train()
        self.classifier.train()
        
        epoch_metrics = {
            'd_loss': [], 'g_loss': [], 'pearson': [], 'ssim': [], 
            'psnr': [], 'classifier_acc': []
        }
        
        epoch_start_time = time()
        
        for batch_idx, (real_imgs, labels) in enumerate(self.loader):
            real_imgs = real_imgs.to(self.device)
            labels = labels.to(self.device)
            batch_size = real_imgs.size(0)
            
            # --- Discriminator Update ---
            self.optimD.zero_grad()
            
            # Generate fake images
            z = torch.randn(batch_size, LATENT_DIM, device=self.device)
            fake_imgs = self.netG(z, labels)
            
            # Compute discriminator losses
            d_real = self.netD(real_imgs, labels)
            d_fake = self.netD(fake_imgs.detach(), labels)
            
            # Gradient penalty
            gp = self.compute_gradient_penalty(real_imgs, fake_imgs, labels)
            
            # WGAN-GP loss
            d_loss = -torch.mean(d_real) + torch.mean(d_fake) + LAMBDA_GP * gp
            d_loss.backward()
            self.optimD.step()
            
            epoch_metrics['d_loss'].append(d_loss.item())
            
            # --- Generator Update ---
            if batch_idx % N_CRITIC == 0:
                self.optimG.zero_grad()
                
                # Generate new fake images
                z = torch.randn(batch_size, LATENT_DIM, device=self.device)
                fake_imgs = self.netG(z, labels)
                
                # Adversarial loss
                g_adv = -torch.mean(self.netD(fake_imgs, labels))
                
                # Pixel-wise L1 loss
                px_loss = F.l1_loss(fake_imgs, real_imgs)
                
                # Pearson correlation loss
                pearson_loss = pearson_correlation_loss(fake_imgs, real_imgs)
                
                # Total generator loss
                g_total = g_adv + LAMBDA_PIXEL*px_loss + LAMBDA_PEARSON*pearson_loss
                g_total.backward()
                self.optimG.step()
                
                # Train classifier and compute additional metrics
                with torch.no_grad():
                    # SSIM metric
                    ssim_val = ssim(fake_imgs, real_imgs).item()
                    
                    # PSNR metric
                    psnr_val = self.compute_psnr(real_imgs, fake_imgs)
                    
                    # Store metrics
                    epoch_metrics['g_loss'].append(g_total.item())
                    epoch_metrics['pearson'].append(1 - pearson_loss.item())
                    epoch_metrics['ssim'].append(ssim_val)
                    epoch_metrics['psnr'].append(psnr_val)
                
                # Train classifier on this batch
                classifier_acc = self.train_classifier(real_imgs, fake_imgs, labels)
                epoch_metrics['classifier_acc'].append(classifier_acc)
                
            # Progress reporting (less frequent)
            if batch_idx % 10 == 0:
    # Handle the conditional formatting properly
               if batch_idx % N_CRITIC == 0:
                g_loss_str = f"{g_total.item():.4f}"
               else:
                g_loss_str = "N/A"
         
               print(f"Epoch [{epoch}/{EPOCHS}] Batch [{batch_idx}/{len(self.loader)}] "
            f"D_loss: {d_loss.item():.4f} G_loss: {g_loss_str}")
        
        # Compute epoch-level metrics
        epoch_time = time() - epoch_start_time
        
        # Compute FID proxy and Inception Score proxy once per epoch on fixed samples
        with torch.no_grad():
            fake_samples = self.netG(self.fixed_z, self.fixed_labels)
            real_samples, real_labels = next(iter(self.loader))
            real_samples = real_samples.to(self.device)
            
            # Extract features for FID
            fake_features = self.extract_features(fake_samples)
            real_features = self.extract_features(real_samples)
            
            # Compute FID proxy
            fid_proxy = compute_fid_proxy(real_features, fake_features).item()
            
            # Compute Inception Score proxy
            is_proxy = self.compute_inception_score_proxy(fake_samples)
        
        # Update overall metrics
        self.metrics['d_loss'].append(np.mean(epoch_metrics['d_loss']))
        self.metrics['g_loss'].append(np.mean(epoch_metrics['g_loss']))
        self.metrics['pearson'].append(np.mean(epoch_metrics['pearson']))
        self.metrics['ssim'].append(np.mean(epoch_metrics['ssim']))
        self.metrics['classifier_acc'].append(np.mean(epoch_metrics['classifier_acc']))
        self.metrics['psnr'].append(np.mean(epoch_metrics['psnr']))
        self.metrics['fid_proxy'].append(fid_proxy)
        self.metrics['inception_score_proxy'].append(is_proxy)
        self.metrics['train_time'].append(epoch_time)
        
        # Update learning rate schedulers
        self.schedulerG.step(self.metrics['g_loss'][-1])
        self.schedulerD.step(self.metrics['d_loss'][-1])
        
        # Print epoch summary
        print(f"\nEpoch {epoch} completed in {epoch_time:.2f}s | "
              f"FID: {fid_proxy:.4f} | IS: {is_proxy:.4f} | "
              f"SSIM: {self.metrics['ssim'][-1]:.4f} | "
              f"PSNR: {self.metrics['psnr'][-1]:.4f} | "
              f"Acc: {self.metrics['classifier_acc'][-1]:.4f}\n")
        
        # Save metrics after each epoch
        with open(f"{self.run_dir}/metrics.json", 'w') as f:
            json.dump(self.metrics, f)
        
        # Plot training curves
        if epoch % 10 == 0:
            self.plot_metrics(epoch)

    def plot_metrics(self, epoch):
        """Plot and save training metrics"""
        fig, axs = plt.subplots(3, 2, figsize=(12, 15))
        
        # Plot losses
        axs[0, 0].plot(self.metrics['d_loss'], label='D Loss')
        axs[0, 0].plot(self.metrics['g_loss'], label='G Loss')
        axs[0, 0].set_title('Losses')
        axs[0, 0].legend()
        
        # Plot image quality metrics
        axs[0, 1].plot(self.metrics['pearson'], label='Pearson')
        axs[0, 1].plot(self.metrics['ssim'], label='SSIM')
        axs[0, 1].set_title('Image Quality')
        axs[0, 1].legend()
        
        # Plot FID and Inception Score
        axs[1, 0].plot(self.metrics['fid_proxy'], label='FID Proxy')
        axs[1, 0].set_title('FID Proxy (lower is better)')
        axs[1, 1].plot(self.metrics['inception_score_proxy'], label='IS Proxy')
        axs[1, 1].set_title('Inception Score Proxy')
        
        # Plot PSNR and classifier accuracy
        axs[2, 0].plot(self.metrics['psnr'], label='PSNR')
        axs[2, 0].set_title('PSNR (higher is better)')
        axs[2, 1].plot(self.metrics['classifier_acc'], label='Classifier Acc')
        axs[2, 1].set_title('Classifier Accuracy')
        
        plt.tight_layout()
        plt.savefig(f"{self.run_dir}/metrics_epoch_{epoch}.png")
        plt.close()

    def save_samples(self, epoch):
        """Saves generated samples for both labels"""
        self.netG.eval()
        
        with torch.no_grad():
            # Generate images using fixed noise
            gen_imgs = self.netG(self.fixed_z, self.fixed_labels)
            
            # Create a grid of images
            fig, axs = plt.subplots(2, 4, figsize=(16, 8))
            for i in range(8):
                row = i // 4
                col = i % 4
                img = gen_imgs[i].cpu().squeeze().numpy()
                img = (img + 1) * 0.5  # Convert from [-1,1] to [0,1]
                axs[row, col].imshow(img, cmap='gray')
                axs[row, col].set_title(f"Label: {int(self.fixed_labels[i].item())}")
                axs[row, col].axis('off')
            
            plt.tight_layout()
            plt.savefig(f"{self.run_dir}/samples/epoch_{epoch}_grid.png")
            plt.close()
            
            # Also save individual images
            for i, (img, label) in enumerate(zip(gen_imgs, self.fixed_labels)):
                img = img.cpu().squeeze().numpy()
                img = (img + 1) * 127.5  # Convert from [-1,1] to [0,255]
                img = img.astype(np.uint8)
                cv2.imwrite(f"{self.run_dir}/samples/epoch_{epoch}_label_{int(label.item())}_sample_{i}.png", img)
        
        self.netG.train()

    def save_checkpoint(self, epoch):
        """Save model checkpoint"""
        torch.save({
            'epoch': epoch,
            'generator': self.netG.state_dict(),
            'discriminator': self.netD.state_dict(),
            'classifier': self.classifier.state_dict(),
            'optimG': self.optimG.state_dict(),
            'optimD': self.optimD.state_dict(),
            'schedulerG': self.schedulerG.state_dict(),
            'schedulerD': self.schedulerD.state_dict(),
            'metrics': self.metrics
        }, f"{self.run_dir}/checkpoints/checkpoint_epoch_{epoch}.pth")
    def train(self):
    
     for epoch in range(1, EPOCHS + 1):
        self.train_epoch(epoch)
        
        # Save checkpoints
        
        
        # Save samples
        if epoch % 200 == 0:
            self.save_samples(epoch)
            print(f"Saved samples for epoch {epoch}")
        
    # Final save
     torch.save({
        'generator': self.netG.state_dict(),
        'discriminator': self.netD.state_dict(),
        'metrics': self.metrics
     }, "gan_final_1.pth")

if __name__ == "__main__":
    gan = StableGAN()
    gan.train()


Using device: cuda




Epoch [1/1000] Batch [0/25] D_loss: 1036.8507 G_loss: 6.6730
Epoch [1/1000] Batch [10/25] D_loss: 0.6456 G_loss: N/A
Epoch [1/1000] Batch [20/25] D_loss: 0.2230 G_loss: N/A

Epoch 1 completed in 35.48s | FID: 0.2469 | IS: 0.5011 | SSIM: 0.0328 | PSNR: 8.7906 | Acc: 0.5319

Epoch [2/1000] Batch [0/25] D_loss: -0.4374 G_loss: 4.7767
Epoch [2/1000] Batch [10/25] D_loss: 0.3301 G_loss: N/A
Epoch [2/1000] Batch [20/25] D_loss: -0.4230 G_loss: N/A

Epoch 2 completed in 38.42s | FID: 0.2470 | IS: 0.5012 | SSIM: 0.0518 | PSNR: 11.4593 | Acc: 0.6042

Epoch [3/1000] Batch [0/25] D_loss: -0.1164 G_loss: 5.0328
Epoch [3/1000] Batch [10/25] D_loss: 0.4800 G_loss: N/A
Epoch [3/1000] Batch [20/25] D_loss: -0.6459 G_loss: N/A

Epoch 3 completed in 39.87s | FID: 0.6796 | IS: 0.5033 | SSIM: 0.0523 | PSNR: 11.1743 | Acc: 0.6069

Epoch [4/1000] Batch [0/25] D_loss: -0.3743 G_loss: 5.0841
Epoch [4/1000] Batch [10/25] D_loss: -0.9457 G_loss: N/A
Epoch [4/1000] Batch [20/25] D_loss: -1.0102 G_loss: N/A

Epoc

In [3]:
gan.plot_metrics(epoch=1000)
