In [None]:
CGAN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
import os
import csv
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import StepLR

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# NOTE(review): every training cell in this notebook uses this same checkpoint
# directory and loss log, so one experiment can resume from (or overwrite)
# another experiment's files.
checkpoint_dir = '/home/jovyan/FluxGAN/plots/checkpoint'
loss_log_file = '/home/jovyan/FluxGAN/plots/loss_log.csv'
checkpoint_interval = 1000   # save every N epochs
num_epochs = 30001           # epochs run: range(start_epoch, num_epochs)
batch_size = 512
noise_dim = 100              # length of the generator's latent vector

# Prefer the GPU and let cuDNN benchmark/auto-tune its kernels.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = True

# Output locations; the CSV header is written only when the file is new.
os.makedirs(checkpoint_dir, exist_ok=True)
if not os.path.exists(loss_log_file):
    with open(loss_log_file, 'w') as f:
        f.write('Epoch,D Loss,G Loss\n')

# Three numeric features, each min-max scaled independently.
# NOTE(review): MinMaxScaler() maps to [0, 1], but the generator in this cell
# ends in Tanh (range (-1, 1)), so it can emit negative values real data never
# has. Consider feature_range=(-1, 1) with matching inverse scaling downstream.
data = pd.read_csv('./flux_burnup_dataset.csv')
X = data[['Enrichment (%)', 'Flux', 'Burnup']].values
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# Raw per-column extrema, stored in checkpoints for inference-time unscaling.
data_min, data_max = scaler.data_min_, scaler.data_max_

# Models for CGAN
class Generator(nn.Module):
    """Conditional generator: (noise, 3-value condition) -> 3 features.

    ``self.net`` layout (indices matter for checkpoint compatibility):
    Linear(noise_dim + 3, 256) -> LeakyReLU(0.2) -> LayerNorm(256)
    -> Linear(256, 128) -> LeakyReLU(0.2) -> LayerNorm(128)
    -> Linear(128, 3) -> Tanh.
    """

    def __init__(self, noise_dim=100):
        super().__init__()
        widths = (noise_dim + 3, 256, 128)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.LayerNorm(n_out)))
        # Tanh bounds every output component to (-1, 1).
        layers.extend((nn.Linear(widths[-1], 3), nn.Tanh()))
        self.net = nn.Sequential(*layers)
        self.apply(self.init_weights)

    def forward(self, z, conditions):
        """Concatenate noise and conditions on the feature axis, then run the MLP."""
        return self.net(torch.cat((z, conditions), dim=1))

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class Discriminator(nn.Module):
    """Conditional discriminator: scores a (sample, condition) pair with one raw logit.

    Sample and condition are concatenated; their combined width must be 6.
    ``self.net`` layout: Linear(6, 256) -> LeakyReLU -> Dropout(0.3) -> LayerNorm
    -> Linear(256, 128) -> LeakyReLU -> Dropout(0.3) -> LayerNorm -> Linear(128, 1).
    No sigmoid: pair it with a logits-based loss.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for n_in, n_out in ((6, 256), (256, 128)):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.LayerNorm(n_out)]
        layers.append(nn.Linear(128, 1))
        self.net = nn.Sequential(*layers)
        self.apply(self.init_weights)

    def forward(self, x, conditions):
        """Return one logit per row of ``[x, conditions]``."""
        return self.net(torch.cat((x, conditions), dim=1))

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Build both networks directly on the target device (G first, then D).
generator = Generator(noise_dim).to(device)
discriminator = Discriminator().to(device)

# Adam for both players. G's learning rate (0.002) is 40x D's (0.00005),
# deliberately tilting the game toward the generator.
optimizer_G = optim.Adam(generator.parameters(), lr=0.002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.5, 0.999))

# Halve both learning rates every 1000 scheduler steps.
scheduler_G, scheduler_D = (
    StepLR(opt, step_size=1000, gamma=0.5) for opt in (optimizer_G, optimizer_D)
)

# BCE on raw logits (the sigmoid is folded into the loss).
adversarial_loss = nn.BCEWithLogitsLoss()

# NOTE(review): this rebinds `scaler` (until now the MinMaxScaler) to an AMP
# GradScaler that this cell's training loop never uses; any later code calling
# scaler.inverse_transform will fail.
scaler = GradScaler()

# Checkpoint functions
def save_checkpoint(epoch):
    """Save models, optimizers, LR schedulers and data-scaling stats.

    The target is ``checkpoint_dir/checkpoint_<epoch>.tar``. The data is first
    written to ``<target>.tmp`` and then atomically renamed. An interrupted save
    (e.g. a KeyboardInterrupt mid-write) therefore cannot leave a truncated
    ``.tar`` that would be picked as the newest checkpoint on the next run.
    Scheduler states are stored so a resumed run continues the LR decay where
    it stopped.

    Args:
        epoch: epoch number recorded in the file and used in its name.
    """
    path = os.path.join(checkpoint_dir, f'checkpoint_{epoch}.tar')
    tmp_path = path + '.tmp'  # does not end in '.tar', so never picked up on load
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'scheduler_G_state_dict': scheduler_G.state_dict(),
        'scheduler_D_state_dict': scheduler_D.state_dict(),
        'data_min': data_min,
        'data_max': data_max,
    }, tmp_path)
    os.replace(tmp_path, path)  # atomic on the same filesystem
    print(f"[Checkpoint] Saved at epoch {epoch}")

def load_checkpoint():
    """Restore the newest checkpoint in ``checkpoint_dir`` if it loads cleanly.

    Only files named ``checkpoint_<int>.tar`` are considered; other files are
    ignored instead of crashing the epoch parsing. The newest one restores the
    generator, discriminator and both optimizers. The LR schedulers are also
    restored when the checkpoint contains their state (older ones do not).

    If any piece fails to load (e.g. the file came from a model with a
    different architecture), every object is rolled back to its pre-load
    state. This matters because a strict ``load_state_dict`` copies the
    matching tensors before raising, so without the rollback "Starting fresh"
    would really mean "half-loaded".

    Returns:
        int: epoch to resume from (saved epoch + 1), or 0 when starting fresh.
    """
    import copy

    def _epoch_of(filename):
        # 'checkpoint_1000.tar' -> 1000; any other name -> None
        stem, ext = os.path.splitext(filename)
        prefix, _, number = stem.rpartition('_')
        if ext == '.tar' and prefix == 'checkpoint' and number.isdigit():
            return int(number)
        return None

    candidates = []
    for filename in os.listdir(checkpoint_dir):
        saved_epoch = _epoch_of(filename)
        if saved_epoch is not None:
            candidates.append((saved_epoch, filename))
    if not candidates:
        print("[Checkpoint] No checkpoint found. Starting fresh.")
        return 0

    _, latest = max(candidates)
    path = os.path.join(checkpoint_dir, latest)

    required = {
        'generator_state_dict': generator,
        'discriminator_state_dict': discriminator,
        'optimizer_G_state_dict': optimizer_G,
        'optimizer_D_state_dict': optimizer_D,
    }
    optional = {
        'scheduler_G_state_dict': scheduler_G,
        'scheduler_D_state_dict': scheduler_D,
    }
    everything = {**required, **optional}
    # Snapshot current state so a failed load can be undone completely.
    backup = {key: copy.deepcopy(obj.state_dict()) for key, obj in everything.items()}

    try:
        # Checkpoints are pickles that also hold numpy arrays (data_min/data_max).
        # Newer PyTorch releases default torch.load to weights_only=True, which
        # rejects them, so opt out explicitly. Only load files you wrote yourself.
        try:
            checkpoint = torch.load(path, map_location=device, weights_only=False)
        except TypeError:  # older PyTorch without the weights_only argument
            checkpoint = torch.load(path, map_location=device)
        for key, obj in required.items():
            obj.load_state_dict(checkpoint[key])
        for key, obj in optional.items():
            if key in checkpoint:
                obj.load_state_dict(checkpoint[key])
        print(f"[Checkpoint] Loaded from epoch {checkpoint['epoch']}")
        return checkpoint['epoch'] + 1
    except Exception as e:
        for key, obj in everything.items():
            obj.load_state_dict(backup[key])
        print(f"[Checkpoint] Error loading: {str(e)}. Starting fresh.")
        return 0

# Load checkpoint if exists
start_epoch = load_checkpoint()

# Scaled features as a single-tensor dataset; batches are moved to `device`
# inside the loop (pin_memory speeds up those host-to-GPU copies).
dataset = TensorDataset(torch.tensor(X_scaled, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# NOTE(review): both networks are conditioned on the real sample itself
# (conditions == real_data). D can therefore learn "x == conditions", and G is
# pushed toward copying its condition. If only some columns are meant to be
# conditions, slice them out here — TODO confirm the intended conditioning.

# Training loop
for epoch in range(start_epoch, num_epochs):
    for (batch,) in dataloader:
        real_data = batch.to(device, non_blocking=True)
        current_batch_size = real_data.size(0)

        # ---- Discriminator step (two-sided label smoothing: 0.9 / 0.1) ----
        optimizer_D.zero_grad(set_to_none=True)

        real_labels = torch.full((current_batch_size, 1), 0.9, device=device)
        real_output = discriminator(real_data, real_data)
        d_loss_real = adversarial_loss(real_output, real_labels)

        z = torch.randn(current_batch_size, noise_dim, device=device)
        fake_data = generator(z, real_data)
        fake_labels = torch.full((current_batch_size, 1), 0.1, device=device)
        # detach(): the D update must not push gradients into G.
        fake_output = discriminator(fake_data.detach(), real_data)
        d_loss_fake = adversarial_loss(fake_output, fake_labels)

        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step (every batch) ----
        optimizer_G.zero_grad(set_to_none=True)
        gen_labels = torch.ones(current_batch_size, 1, device=device)
        # fake_data must NOT be detached here. A detached tensor gives g_loss no
        # path back to the generator's parameters, so their grads would stay
        # None and optimizer_G.step() would never change them.
        g_output = discriminator(fake_data, real_data)
        g_loss = adversarial_loss(g_output, gen_labels)
        g_loss.backward()  # also leaves grads on D; cleared by D's zero_grad next batch
        optimizer_G.step()

    # LR schedulers advance once per epoch.
    scheduler_G.step()
    scheduler_D.step()

    # Reported losses are those of the epoch's last batch.
    print(f"Epoch [{epoch}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    with open(loss_log_file, 'a') as f:
        f.write(f'{epoch},{d_loss.item()},{g_loss.item()}\n')

    if epoch % checkpoint_interval == 0 and epoch > 0:
        save_checkpoint(epoch)

# Final save
save_checkpoint(num_epochs - 1)


In [None]:
The cell above gives the best results so far.

In [None]:
The next cell tries to improve on it, with no major gain.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
import os
import csv
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import StepLR

# ---------------------------------------------------------------------------
# Run configuration (shorter run, larger batches, frequent checkpoints)
# ---------------------------------------------------------------------------
# NOTE(review): every training cell in this notebook uses this same checkpoint
# directory and loss log, so one experiment can resume from (or overwrite)
# another experiment's files.
checkpoint_dir = '/home/jovyan/FluxGAN/plots/checkpoint'
loss_log_file = '/home/jovyan/FluxGAN/plots/loss_log.csv'
checkpoint_interval = 10     # save every N epochs
num_epochs = 1001            # epochs run: range(start_epoch, num_epochs)
batch_size = 1024
noise_dim = 100              # length of the generator's latent vector

# Prefer the GPU and let cuDNN benchmark/auto-tune its kernels.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = True

# Output locations; the CSV header is written only when the file is new.
os.makedirs(checkpoint_dir, exist_ok=True)
if not os.path.exists(loss_log_file):
    with open(loss_log_file, 'w') as f:
        f.write('Epoch,D Loss,G Loss\n')

# Three numeric features, each min-max scaled independently.
# NOTE(review): MinMaxScaler() maps to [0, 1], but the generator in this cell
# ends in Tanh (range (-1, 1)), so it can emit negative values real data never
# has. Consider feature_range=(-1, 1) with matching inverse scaling downstream.
data = pd.read_csv('./flux_burnup_dataset.csv')
X = data[['Enrichment (%)', 'Flux', 'Burnup']].values
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# Raw per-column extrema, stored in checkpoints for inference-time unscaling.
data_min, data_max = scaler.data_min_, scaler.data_max_

# Models for CGAN
class Generator(nn.Module):
    """Conditional generator: (noise, 3-value condition) -> 3 features.

    ``self.net`` layout (indices matter for checkpoint compatibility):
    Linear(noise_dim + 3, 256) -> LeakyReLU(0.2) -> LayerNorm(256)
    -> Linear(256, 128) -> LeakyReLU(0.2) -> LayerNorm(128)
    -> Linear(128, 3) -> Tanh.
    """

    def __init__(self, noise_dim=100):
        super().__init__()
        widths = (noise_dim + 3, 256, 128)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.LayerNorm(n_out)))
        # Tanh bounds every output component to (-1, 1).
        layers.extend((nn.Linear(widths[-1], 3), nn.Tanh()))
        self.net = nn.Sequential(*layers)
        self.apply(self.init_weights)

    def forward(self, z, conditions):
        """Concatenate noise and conditions on the feature axis, then run the MLP."""
        return self.net(torch.cat((z, conditions), dim=1))

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class Discriminator(nn.Module):
    """Conditional discriminator: scores a (sample, condition) pair with one raw logit.

    Sample and condition are concatenated; their combined width must be 6.
    ``self.net`` layout: Linear(6, 256) -> LeakyReLU -> Dropout(0.3) -> LayerNorm
    -> Linear(256, 128) -> LeakyReLU -> Dropout(0.3) -> LayerNorm -> Linear(128, 1).
    No sigmoid: pair it with a logits-based loss.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for n_in, n_out in ((6, 256), (256, 128)):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.LayerNorm(n_out)]
        layers.append(nn.Linear(128, 1))
        self.net = nn.Sequential(*layers)
        self.apply(self.init_weights)

    def forward(self, x, conditions):
        """Return one logit per row of ``[x, conditions]``."""
        return self.net(torch.cat((x, conditions), dim=1))

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Build both networks directly on the target device (G first, then D).
generator = Generator(noise_dim).to(device)
discriminator = Discriminator().to(device)

# Adam for both players; G's learning rate (0.0002) is 4x D's (0.00005).
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.5, 0.999))

# Halve both learning rates every 1000 scheduler steps.
scheduler_G, scheduler_D = (
    StepLR(opt, step_size=1000, gamma=0.5) for opt in (optimizer_G, optimizer_D)
)

# BCE on raw logits (the sigmoid is folded into the loss).
adversarial_loss = nn.BCEWithLogitsLoss()

# NOTE(review): this rebinds `scaler` (until now the MinMaxScaler) to an AMP
# GradScaler that this cell's training loop never uses; any later code calling
# scaler.inverse_transform will fail.
scaler = GradScaler()

# Checkpoint functions
def save_checkpoint(epoch):
    """Save models, optimizers, LR schedulers and data-scaling stats.

    The target is ``checkpoint_dir/checkpoint_<epoch>.tar``. The data is first
    written to ``<target>.tmp`` and then atomically renamed. An interrupted save
    (e.g. a KeyboardInterrupt mid-write) therefore cannot leave a truncated
    ``.tar`` that would be picked as the newest checkpoint on the next run.
    Scheduler states are stored so a resumed run continues the LR decay where
    it stopped.

    Args:
        epoch: epoch number recorded in the file and used in its name.
    """
    path = os.path.join(checkpoint_dir, f'checkpoint_{epoch}.tar')
    tmp_path = path + '.tmp'  # does not end in '.tar', so never picked up on load
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'scheduler_G_state_dict': scheduler_G.state_dict(),
        'scheduler_D_state_dict': scheduler_D.state_dict(),
        'data_min': data_min,
        'data_max': data_max,
    }, tmp_path)
    os.replace(tmp_path, path)  # atomic on the same filesystem
    print(f"[Checkpoint] Saved at epoch {epoch}")

def load_checkpoint():
    """Restore the newest checkpoint in ``checkpoint_dir`` if it loads cleanly.

    Only files named ``checkpoint_<int>.tar`` are considered; other files are
    ignored instead of crashing the epoch parsing. The newest one restores the
    generator, discriminator and both optimizers. The LR schedulers are also
    restored when the checkpoint contains their state (older ones do not).

    If any piece fails to load (e.g. the file came from a model with a
    different architecture), every object is rolled back to its pre-load
    state. This matters because a strict ``load_state_dict`` copies the
    matching tensors before raising, so without the rollback "Starting fresh"
    would really mean "half-loaded".

    Returns:
        int: epoch to resume from (saved epoch + 1), or 0 when starting fresh.
    """
    import copy

    def _epoch_of(filename):
        # 'checkpoint_1000.tar' -> 1000; any other name -> None
        stem, ext = os.path.splitext(filename)
        prefix, _, number = stem.rpartition('_')
        if ext == '.tar' and prefix == 'checkpoint' and number.isdigit():
            return int(number)
        return None

    candidates = []
    for filename in os.listdir(checkpoint_dir):
        saved_epoch = _epoch_of(filename)
        if saved_epoch is not None:
            candidates.append((saved_epoch, filename))
    if not candidates:
        print("[Checkpoint] No checkpoint found. Starting fresh.")
        return 0

    _, latest = max(candidates)
    path = os.path.join(checkpoint_dir, latest)

    required = {
        'generator_state_dict': generator,
        'discriminator_state_dict': discriminator,
        'optimizer_G_state_dict': optimizer_G,
        'optimizer_D_state_dict': optimizer_D,
    }
    optional = {
        'scheduler_G_state_dict': scheduler_G,
        'scheduler_D_state_dict': scheduler_D,
    }
    everything = {**required, **optional}
    # Snapshot current state so a failed load can be undone completely.
    backup = {key: copy.deepcopy(obj.state_dict()) for key, obj in everything.items()}

    try:
        # Checkpoints are pickles that also hold numpy arrays (data_min/data_max).
        # Newer PyTorch releases default torch.load to weights_only=True, which
        # rejects them, so opt out explicitly. Only load files you wrote yourself.
        try:
            checkpoint = torch.load(path, map_location=device, weights_only=False)
        except TypeError:  # older PyTorch without the weights_only argument
            checkpoint = torch.load(path, map_location=device)
        for key, obj in required.items():
            obj.load_state_dict(checkpoint[key])
        for key, obj in optional.items():
            if key in checkpoint:
                obj.load_state_dict(checkpoint[key])
        print(f"[Checkpoint] Loaded from epoch {checkpoint['epoch']}")
        return checkpoint['epoch'] + 1
    except Exception as e:
        for key, obj in everything.items():
            obj.load_state_dict(backup[key])
        print(f"[Checkpoint] Error loading: {str(e)}. Starting fresh.")
        return 0

# Load checkpoint if exists
start_epoch = load_checkpoint()

# Scaled features as a single-tensor dataset; batches are moved to `device`
# inside the loop (pin_memory speeds up those host-to-GPU copies).
dataset = TensorDataset(torch.tensor(X_scaled, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# NOTE(review): both networks are conditioned on the real sample itself
# (conditions == real_data). D can therefore learn "x == conditions", and G is
# pushed toward copying its condition. If only some columns are meant to be
# conditions, slice them out here — TODO confirm the intended conditioning.

# G is updated once every `d_steps_per_g_step` discriminator updates. Counting
# batches (not epochs) matches that intent. Because the counter starts at 0,
# the very first batch of a run also trains G, so g_loss always exists before
# it is logged, even when resuming at an odd epoch.
d_steps_per_g_step = 2
global_step = 0

# Training loop
for epoch in range(start_epoch, num_epochs):
    for (batch,) in dataloader:
        real_data = batch.to(device, non_blocking=True)
        current_batch_size = real_data.size(0)

        # ---- Discriminator step (two-sided label smoothing: 0.9 / 0.1) ----
        optimizer_D.zero_grad(set_to_none=True)

        real_labels = torch.full((current_batch_size, 1), 0.9, device=device)
        real_output = discriminator(real_data, real_data)
        d_loss_real = adversarial_loss(real_output, real_labels)

        z = torch.randn(current_batch_size, noise_dim, device=device)
        fake_data = generator(z, real_data)
        fake_labels = torch.full((current_batch_size, 1), 0.1, device=device)
        # detach(): the D update must not push gradients into G.
        fake_output = discriminator(fake_data.detach(), real_data)
        d_loss_fake = adversarial_loss(fake_output, fake_labels)

        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step (every d_steps_per_g_step batches) ----
        if global_step % d_steps_per_g_step == 0:
            optimizer_G.zero_grad(set_to_none=True)
            gen_labels = torch.ones(current_batch_size, 1, device=device)
            # fake_data must NOT be detached here. A detached tensor gives
            # g_loss no path back to the generator's parameters, so
            # optimizer_G.step() would never change them.
            g_output = discriminator(fake_data, real_data)
            g_loss = adversarial_loss(g_output, gen_labels)
            g_loss.backward()  # also leaves grads on D; cleared by D's zero_grad next batch
            optimizer_G.step()
        global_step += 1

    # LR schedulers advance once per epoch.
    scheduler_G.step()
    scheduler_D.step()

    # Reported losses are the most recent ones (g_loss may be from an earlier batch).
    print(f"Epoch [{epoch}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    with open(loss_log_file, 'a') as f:
        f.write(f'{epoch},{d_loss.item()},{g_loss.item()}\n')

    if epoch % checkpoint_interval == 0 and epoch > 0:
        save_checkpoint(epoch)

# Final save
save_checkpoint(num_epochs - 1)


In [None]:
Work in progress.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
import os
import csv
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import StepLR
import torch.nn.utils.spectral_norm as spectral_norm

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# NOTE(review): every training cell in this notebook uses this same checkpoint
# directory and loss log, so one experiment can resume from (or overwrite)
# another experiment's files.
checkpoint_dir = '/home/jovyan/FluxGAN/plots/checkpoint'
loss_log_file = '/home/jovyan/FluxGAN/plots/loss_log.csv'
checkpoint_interval = 1000   # save every N epochs
num_epochs = 30001           # epochs run: range(start_epoch, num_epochs)
batch_size = 512
noise_dim = 100              # length of the generator's latent vector

# Prefer the GPU and let cuDNN benchmark/auto-tune its kernels.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = True

# Output locations; the CSV header is written only when the file is new.
os.makedirs(checkpoint_dir, exist_ok=True)
if not os.path.exists(loss_log_file):
    with open(loss_log_file, 'w') as f:
        f.write('Epoch,D Loss,G Loss\n')

# Three numeric features, each min-max scaled independently.
# NOTE(review): MinMaxScaler() maps to [0, 1], but the generator in this cell
# ends in tanh (range (-1, 1)), so it can emit negative values real data never
# has. Consider feature_range=(-1, 1) with matching inverse scaling downstream.
data = pd.read_csv('./flux_burnup_dataset.csv')
X = data[['Enrichment (%)', 'Flux', 'Burnup']].values
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# Raw per-column extrema, stored in checkpoints for inference-time unscaling.
data_min, data_max = scaler.data_min_, scaler.data_max_

# Generator Model
class Generator(nn.Module):
    """Conditional MLP generator with one residual connection.

    fc1 lifts ``[noise, conditions]`` (width noise_dim + 3) to 256 features.
    fc2/fc3 form a 256 -> 128 -> 256 bottleneck whose output is added back
    onto fc1's activation. fc4 then projects to 3 features, squashed by tanh
    into (-1, 1). All activations are LeakyReLU(0.2).
    """

    def __init__(self, noise_dim=100):
        super().__init__()
        self.fc1 = nn.Linear(noise_dim + 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 256)
        self.fc4 = nn.Linear(256, 3)

        # Initialize weights
        self.apply(self.init_weights)

    def forward(self, z, conditions):
        # Functional activation: avoids constructing a new nn.LeakyReLU module
        # on every forward call.
        leaky = nn.functional.leaky_relu
        h = leaky(self.fc1(torch.cat([z, conditions], dim=1)), 0.2)
        # Out-of-place `+` (rather than `x += residual`) avoids mutating a
        # tensor that autograd may have saved for the backward pass.
        x = leaky(self.fc3(leaky(self.fc2(h), 0.2)), 0.2) + h
        return torch.tanh(self.fc4(x))

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

# Discriminator Model
class Discriminator(nn.Module):
    """Spectrally normalised conditional critic: (sample, condition) -> one score.

    Sample and condition are concatenated; their combined width must be 6.
    ``self.net`` layout: SN-Linear(6, 256) -> LeakyReLU -> Dropout(0.3)
    -> LayerNorm -> SN-Linear(256, 128) -> LeakyReLU -> Dropout(0.3)
    -> LayerNorm -> SN-Linear(128, 1). The output is unbounded (WGAN critic).

    NOTE(review): init_weights runs after spectral_norm has wrapped each Linear.
    At that point ``m.weight`` is a plain tensor attribute which, in the classic
    torch.nn.utils.spectral_norm implementation, shares storage with
    ``weight_orig``, so the Xavier init reaches the real parameter. Re-verify
    if migrating to torch.nn.utils.parametrizations.spectral_norm.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for n_in, n_out in ((6, 256), (256, 128)):
            layers += [
                spectral_norm(nn.Linear(n_in, n_out)),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
                nn.LayerNorm(n_out),
            ]
        layers.append(spectral_norm(nn.Linear(128, 1)))
        self.net = nn.Sequential(*layers)
        self.apply(self.init_weights)

    def forward(self, x, conditions):
        """Return one critic score per row of ``[x, conditions]``."""
        return self.net(torch.cat((x, conditions), dim=1))

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Build both networks directly on the target device (G first, then critic).
generator = Generator(noise_dim).to(device)
discriminator = Discriminator().to(device)

# RMSprop with default settings for both; the critic's LR is 4x smaller.
optimizer_G, optimizer_D = (
    optim.RMSprop(generator.parameters(), lr=0.0002),
    optim.RMSprop(discriminator.parameters(), lr=0.00005),
)

# Halve both learning rates every 1000 scheduler steps.
scheduler_G, scheduler_D = (
    StepLR(opt, step_size=1000, gamma=0.5) for opt in (optimizer_G, optimizer_D)
)

# Loss functions for WGAN
def wgan_discriminator_loss(real_output, fake_output):
    """Critic loss to minimise: mean(D(fake)) - mean(D(real))."""
    return -(real_output.mean() - fake_output.mean())

def wgan_generator_loss(fake_output):
    """Generator loss to minimise: -mean(D(fake))."""
    return fake_output.mean().neg()

# Gradient Penalty for WGAN-GP
def calc_gradient_penalty(discriminator, real_data, fake_data, conditions, batch_size, device):
    """Return the WGAN-GP penalty mean((||grad_x D(x_hat, c)||_2 - 1)^2).

    ``x_hat`` mixes ``real_data`` and ``fake_data`` row by row, using one uniform
    weight per row that is broadcast across the feature columns. Only the
    gradient with respect to ``x_hat`` is penalised, not the gradient with
    respect to ``conditions``. ``create_graph=True`` keeps the penalty itself
    differentiable so it can be back-propagated into the critic.

    Args:
        discriminator: callable ``(x, conditions) -> scores``.
        real_data, fake_data: sample batches with ``batch_size`` rows. The
            (batch_size, 1) mixing weights assume 2-D inputs — TODO confirm
            for other shapes.
        conditions: passed unchanged to ``discriminator``.
        batch_size: number of rows; used for the mixing weights and to flatten
            the per-sample gradients.
        device: where the mixing weights are created.
    """
    alpha = torch.rand(batch_size, 1, device=device)
    x_hat = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

    scores = discriminator(x_hat, conditions)

    (grad_x,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    per_sample_norm = grad_x.view(batch_size, -1).norm(2, dim=1)
    return (per_sample_norm - 1).pow(2).mean()

# Checkpoint functions
def save_checkpoint(epoch):
    """Save models, optimizers, LR schedulers and data-scaling stats.

    The target is ``checkpoint_dir/checkpoint_<epoch>.tar``. The data is first
    written to ``<target>.tmp`` and then atomically renamed. An interrupted save
    (e.g. a KeyboardInterrupt mid-write) therefore cannot leave a truncated
    ``.tar`` that would be picked as the newest checkpoint on the next run.
    Scheduler states are stored so a resumed run continues the LR decay where
    it stopped.

    Args:
        epoch: epoch number recorded in the file and used in its name.
    """
    path = os.path.join(checkpoint_dir, f'checkpoint_{epoch}.tar')
    tmp_path = path + '.tmp'  # does not end in '.tar', so never picked up on load
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'scheduler_G_state_dict': scheduler_G.state_dict(),
        'scheduler_D_state_dict': scheduler_D.state_dict(),
        'data_min': data_min,
        'data_max': data_max,
    }, tmp_path)
    os.replace(tmp_path, path)  # atomic on the same filesystem
    print(f"[Checkpoint] Saved at epoch {epoch}")

def load_checkpoint():
    """Restore the newest checkpoint in ``checkpoint_dir`` if it loads cleanly.

    Only files named ``checkpoint_<int>.tar`` are considered; other files are
    ignored instead of crashing the epoch parsing. The newest one restores the
    generator, discriminator and both optimizers. The LR schedulers are also
    restored when the checkpoint contains their state (older ones do not).

    If any piece fails to load (e.g. the file came from a model with a
    different architecture), every object is rolled back to its pre-load
    state. This matters because a strict ``load_state_dict`` copies the
    matching tensors before raising, so without the rollback "Starting fresh"
    would really mean "half-loaded".

    Returns:
        int: epoch to resume from (saved epoch + 1), or 0 when starting fresh.
    """
    import copy

    def _epoch_of(filename):
        # 'checkpoint_1000.tar' -> 1000; any other name -> None
        stem, ext = os.path.splitext(filename)
        prefix, _, number = stem.rpartition('_')
        if ext == '.tar' and prefix == 'checkpoint' and number.isdigit():
            return int(number)
        return None

    candidates = []
    for filename in os.listdir(checkpoint_dir):
        saved_epoch = _epoch_of(filename)
        if saved_epoch is not None:
            candidates.append((saved_epoch, filename))
    if not candidates:
        print("[Checkpoint] No checkpoint found. Starting fresh.")
        return 0

    _, latest = max(candidates)
    path = os.path.join(checkpoint_dir, latest)

    required = {
        'generator_state_dict': generator,
        'discriminator_state_dict': discriminator,
        'optimizer_G_state_dict': optimizer_G,
        'optimizer_D_state_dict': optimizer_D,
    }
    optional = {
        'scheduler_G_state_dict': scheduler_G,
        'scheduler_D_state_dict': scheduler_D,
    }
    everything = {**required, **optional}
    # Snapshot current state so a failed load can be undone completely.
    backup = {key: copy.deepcopy(obj.state_dict()) for key, obj in everything.items()}

    try:
        # Checkpoints are pickles that also hold numpy arrays (data_min/data_max).
        # Newer PyTorch releases default torch.load to weights_only=True, which
        # rejects them, so opt out explicitly. Only load files you wrote yourself.
        try:
            checkpoint = torch.load(path, map_location=device, weights_only=False)
        except TypeError:  # older PyTorch without the weights_only argument
            checkpoint = torch.load(path, map_location=device)
        for key, obj in required.items():
            obj.load_state_dict(checkpoint[key])
        for key, obj in optional.items():
            if key in checkpoint:
                obj.load_state_dict(checkpoint[key])
        print(f"[Checkpoint] Loaded from epoch {checkpoint['epoch']}")
        return checkpoint['epoch'] + 1
    except Exception as e:
        for key, obj in everything.items():
            obj.load_state_dict(backup[key])
        print(f"[Checkpoint] Error loading: {str(e)}. Starting fresh.")
        return 0

# Load checkpoint if exists
start_epoch = load_checkpoint()

# Scaled features as a single-tensor dataset; batches are moved to `device`
# inside the loop (pin_memory speeds up those host-to-GPU copies).
dataset = TensorDataset(torch.tensor(X_scaled, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# Weight of the gradient penalty added to the critic loss.
gp_weight = 10

# NOTE(review): both networks are conditioned on the real sample itself
# (conditions == real_data), so the critic can learn "x == conditions". If only
# some columns are meant to be conditions, slice them out here — TODO confirm.
# NOTE(review): the critic keeps Dropout active, which makes the gradient
# penalty stochastic; WGAN-GP is usually run without dropout in the critic.

# Training loop
for epoch in range(start_epoch, num_epochs):
    for (batch,) in dataloader:
        real_data = batch.to(device, non_blocking=True)
        current_batch_size = real_data.size(0)

        noise = torch.randn(current_batch_size, noise_dim, device=device)

        # ---- Critic step ----
        # The critic only ever uses these samples detached, so generate them
        # without building an autograd graph through G.
        with torch.no_grad():
            fake_data = generator(noise, real_data)

        optimizer_D.zero_grad()
        real_output = discriminator(real_data, real_data)
        fake_output = discriminator(fake_data, real_data)
        d_loss = wgan_discriminator_loss(real_output, fake_output)

        grad_penalty = calc_gradient_penalty(
            discriminator,
            real_data,
            fake_data,
            real_data,
            current_batch_size,
            device,
        )
        d_loss = d_loss + gp_weight * grad_penalty

        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step: same noise, re-run G with gradients enabled ----
        optimizer_G.zero_grad()
        fake_data = generator(noise, real_data)
        g_output = discriminator(fake_data, real_data)
        g_loss = wgan_generator_loss(g_output)

        g_loss.backward()  # also leaves grads on the critic; cleared at its next zero_grad
        optimizer_G.step()

    # LR schedulers advance once per epoch.
    scheduler_G.step()
    scheduler_D.step()

    # Reported losses are those of the epoch's last batch (d_loss includes the penalty).
    print(f"Epoch [{epoch}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    with open(loss_log_file, 'a') as f:
        f.write(f'{epoch},{d_loss.item()},{g_loss.item()}\n')

    # Checkpointing
    if epoch % checkpoint_interval == 0 and epoch > 0:
        save_checkpoint(epoch)

# Final save
save_checkpoint(num_epochs - 1)

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
import os
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import StepLR

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# NOTE(review): every training cell in this notebook uses this same checkpoint
# directory, so this unconditional model may try to resume from a conditional
# model's checkpoint (see the size-mismatch error in the captured output) and
# overwrite it. It also shares the loss log: if that file already exists with
# the 3-column header from another cell, the 5-column rows below are appended
# under the wrong header.
checkpoint_dir = '/home/jovyan/FluxGAN/plots/checkpoint'
loss_log_file = '/home/jovyan/FluxGAN/plots/loss_log.csv'
checkpoint_interval = 1000   # save every N epochs
num_epochs = 30001           # epochs run: range(start_epoch, num_epochs)
batch_size = 512
noise_dim = 100              # length of the generator's latent vector
label_flip_rate = 0.05       # fraction of D labels swapped per batch

# Prefer the GPU and let cuDNN benchmark/auto-tune its kernels.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = True

# Output locations; the CSV header is written only when the file is new.
os.makedirs(checkpoint_dir, exist_ok=True)
if not os.path.exists(loss_log_file):
    with open(loss_log_file, 'w') as f:
        f.write('Epoch,D Loss,G Loss,GenMean,GenStd\n')

# Three numeric features, each min-max scaled independently.
# NOTE(review): MinMaxScaler() maps to [0, 1], but the generator in this cell
# ends in Tanh (range (-1, 1)), so it can emit negative values real data never
# has. Consider feature_range=(-1, 1) with matching inverse scaling downstream.
data = pd.read_csv('./flux_burnup_dataset.csv')
X = data[['Enrichment (%)', 'Flux', 'Burnup']].values
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# Raw per-column extrema, stored in checkpoints for inference-time unscaling.
data_min, data_max = scaler.data_min_, scaler.data_max_

# Models for Vanilla GAN
class Generator(nn.Module):
    """Unconditional generator: noise vector -> 3 features in (-1, 1).

    Every Linear is spectrally normalised. ``self.net`` layout:
    SN-Linear(noise_dim, 256) -> LeakyReLU -> LayerNorm -> SN-Linear(256, 128)
    -> LeakyReLU -> LayerNorm -> SN-Linear(128, 3) -> Tanh.
    """

    def __init__(self, noise_dim=100):
        super().__init__()
        sn = nn.utils.spectral_norm
        blocks = []
        for n_in, n_out in ((noise_dim, 256), (256, 128)):
            blocks += [sn(nn.Linear(n_in, n_out)), nn.LeakyReLU(0.2), nn.LayerNorm(n_out)]
        blocks += [sn(nn.Linear(128, 3)), nn.Tanh()]
        self.net = nn.Sequential(*blocks)
        self.apply(self.init_weights)

    def forward(self, z):
        """Map a batch of noise vectors to generated samples."""
        return self.net(z)

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class Discriminator(nn.Module):
    """Unconditional discriminator: 3-feature sample -> one raw logit.

    Every Linear is spectrally normalised. ``self.net`` layout:
    SN-Linear(3, 256) -> LeakyReLU -> Dropout(0.1) -> LayerNorm
    -> SN-Linear(256, 128) -> LeakyReLU -> Dropout(0.1) -> LayerNorm
    -> SN-Linear(128, 1). No sigmoid: pair it with a logits-based loss.
    """

    def __init__(self):
        super().__init__()
        sn = nn.utils.spectral_norm
        blocks = []
        for n_in, n_out in ((3, 256), (256, 128)):
            blocks += [sn(nn.Linear(n_in, n_out)), nn.LeakyReLU(0.2), nn.Dropout(0.1), nn.LayerNorm(n_out)]
        blocks.append(sn(nn.Linear(128, 1)))
        self.net = nn.Sequential(*blocks)
        self.apply(self.init_weights)

    def forward(self, x):
        """Return one logit per row of ``x``."""
        return self.net(x)

    def init_weights(self, m):
        """Xavier-normal weights and zero biases for nn.Linear; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Build both networks directly on the target device (G first, then D).
generator = Generator(noise_dim).to(device)
discriminator = Discriminator().to(device)

# Adam with light weight decay for both; G's LR (0.0002) is 4x D's (0.00005).
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-5)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.00005, betas=(0.5, 0.999), weight_decay=1e-5)

# Gentler decay: multiply both LRs by 0.8 every 1000 scheduler steps.
scheduler_G, scheduler_D = (
    StepLR(opt, step_size=1000, gamma=0.8) for opt in (optimizer_G, optimizer_D)
)

# BCE on raw logits (the sigmoid is folded into the loss).
adversarial_loss = nn.BCEWithLogitsLoss()

# AMP gradient scaler used by this cell's training loop.
# NOTE(review): this rebinds `scaler`, which until now was the MinMaxScaler;
# any later code calling scaler.inverse_transform will fail.
scaler = GradScaler()

# Checkpoint functions
def save_checkpoint(epoch):
    """Save models, optimizers, LR schedulers and data-scaling stats.

    The target is ``checkpoint_dir/checkpoint_<epoch>.tar``. The data is first
    written to ``<target>.tmp`` and then atomically renamed. An interrupted save
    (e.g. a KeyboardInterrupt mid-write) therefore cannot leave a truncated
    ``.tar`` that would be picked as the newest checkpoint on the next run.
    Scheduler states are stored so a resumed run continues the LR decay where
    it stopped.

    Args:
        epoch: epoch number recorded in the file and used in its name.
    """
    path = os.path.join(checkpoint_dir, f'checkpoint_{epoch}.tar')
    tmp_path = path + '.tmp'  # does not end in '.tar', so never picked up on load
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'scheduler_G_state_dict': scheduler_G.state_dict(),
        'scheduler_D_state_dict': scheduler_D.state_dict(),
        'data_min': data_min,
        'data_max': data_max,
    }, tmp_path)
    os.replace(tmp_path, path)  # atomic on the same filesystem
    print(f"[Checkpoint] Saved at epoch {epoch}")

def load_checkpoint():
    """Restore the newest checkpoint in ``checkpoint_dir`` if it loads cleanly.

    Only files named ``checkpoint_<int>.tar`` are considered; other files are
    ignored instead of crashing the epoch parsing. The newest one restores the
    generator, discriminator and both optimizers. The LR schedulers are also
    restored when the checkpoint contains their state (older ones do not).

    If any piece fails to load (e.g. the file came from a model with a
    different architecture), every object is rolled back to its pre-load
    state. This matters because a strict ``load_state_dict`` copies the
    matching tensors before raising, so without the rollback "Starting fresh"
    would really mean "half-loaded".

    Returns:
        int: epoch to resume from (saved epoch + 1), or 0 when starting fresh.
    """
    import copy

    def _epoch_of(filename):
        # 'checkpoint_1000.tar' -> 1000; any other name -> None
        stem, ext = os.path.splitext(filename)
        prefix, _, number = stem.rpartition('_')
        if ext == '.tar' and prefix == 'checkpoint' and number.isdigit():
            return int(number)
        return None

    candidates = []
    for filename in os.listdir(checkpoint_dir):
        saved_epoch = _epoch_of(filename)
        if saved_epoch is not None:
            candidates.append((saved_epoch, filename))
    if not candidates:
        print("[Checkpoint] No checkpoint found. Starting fresh.")
        return 0

    _, latest = max(candidates)
    path = os.path.join(checkpoint_dir, latest)

    required = {
        'generator_state_dict': generator,
        'discriminator_state_dict': discriminator,
        'optimizer_G_state_dict': optimizer_G,
        'optimizer_D_state_dict': optimizer_D,
    }
    optional = {
        'scheduler_G_state_dict': scheduler_G,
        'scheduler_D_state_dict': scheduler_D,
    }
    everything = {**required, **optional}
    # Snapshot current state so a failed load can be undone completely.
    backup = {key: copy.deepcopy(obj.state_dict()) for key, obj in everything.items()}

    try:
        # Checkpoints are pickles that also hold numpy arrays (data_min/data_max).
        # Newer PyTorch releases default torch.load to weights_only=True, which
        # rejects them, so opt out explicitly. Only load files you wrote yourself.
        try:
            checkpoint = torch.load(path, map_location=device, weights_only=False)
        except TypeError:  # older PyTorch without the weights_only argument
            checkpoint = torch.load(path, map_location=device)
        for key, obj in required.items():
            obj.load_state_dict(checkpoint[key])
        for key, obj in optional.items():
            if key in checkpoint:
                obj.load_state_dict(checkpoint[key])
        print(f"[Checkpoint] Loaded from epoch {checkpoint['epoch']}")
        return checkpoint['epoch'] + 1
    except Exception as e:
        for key, obj in everything.items():
            obj.load_state_dict(backup[key])
        print(f"[Checkpoint] Error loading: {str(e)}. Starting fresh.")
        return 0

# Load checkpoint if exists
start_epoch = load_checkpoint()

# Scaled features as a single-tensor dataset; batches are moved to `device`
# inside the loop (pin_memory speeds up those host-to-GPU copies).
dataset = TensorDataset(torch.tensor(X_scaled, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for (batch,) in dataloader:
        real_data = batch.to(device, non_blocking=True)
        current_batch_size = real_data.size(0)

        # ====== Train Discriminator ======
        optimizer_D.zero_grad(set_to_none=True)
        with autocast():
            # One-sided smoothing (real -> 0.9, fake -> 0), then swap the labels
            # of a random label_flip_rate fraction of rows as extra noise for D.
            real_labels = torch.full((current_batch_size, 1), 0.9, device=device)
            fake_labels = torch.zeros((current_batch_size, 1), device=device)
            n_flip = int(label_flip_rate * current_batch_size)
            if n_flip > 0:
                idx_flip = torch.randperm(current_batch_size)[:n_flip]
                real_labels[idx_flip] = 0
                fake_labels[idx_flip] = 1

            real_output = discriminator(real_data)
            d_loss_real = adversarial_loss(real_output, real_labels)

            z = torch.randn(current_batch_size, noise_dim, device=device)
            fake_data = generator(z)
            # detach(): the D update must not push gradients into G.
            fake_output = discriminator(fake_data.detach())
            d_loss_fake = adversarial_loss(fake_output, fake_labels)
            d_loss = (d_loss_real + d_loss_fake) / 2

        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)
        # No scaler.update() here: when one GradScaler serves several optimizers,
        # update() is called once per iteration after all of them have stepped.

        # ====== Train Generator (1:1 ratio) ======
        optimizer_G.zero_grad(set_to_none=True)
        with autocast():
            gen_labels = torch.ones(current_batch_size, 1, device=device)
            g_output = discriminator(fake_data)
            g_loss = adversarial_loss(g_output, gen_labels)
        scaler.scale(g_loss).backward()  # also leaves grads on D; cleared next batch
        scaler.step(optimizer_G)
        scaler.update()

    # LR schedulers advance once per epoch.
    scheduler_G.step()
    scheduler_D.step()

    # Monitor generated-sample statistics. Sample in eval mode: in train mode
    # each forward through a spectrally normalised layer runs a power-iteration
    # step that updates its u/v buffers (even under no_grad), so logging would
    # otherwise perturb the model being trained.
    generator.eval()
    with torch.no_grad():
        z_log = torch.randn(batch_size, noise_dim, device=device)
        gen_samples = generator(z_log).cpu().numpy()
        gen_mean, gen_std = gen_samples.mean(), gen_samples.std()
    generator.train()

    # Reported losses are those of the epoch's last batch.
    print(f"Epoch [{epoch}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} | GenMean: {gen_mean:.3f} | GenStd: {gen_std:.3f}")

    with open(loss_log_file, 'a') as f:
        f.write(f'{epoch},{d_loss.item()},{g_loss.item()},{gen_mean},{gen_std}\n')

    if epoch % checkpoint_interval == 0 and epoch > 0:
        save_checkpoint(epoch)

# Final save
save_checkpoint(num_epochs - 1)


[Checkpoint] Error loading: Error(s) in loading state_dict for Generator:
	size mismatch for net.0.weight_orig: copying a param with shape torch.Size([256, 103]) from checkpoint, the shape in current model is torch.Size([256, 100]).
	size mismatch for net.0.weight_v: copying a param with shape torch.Size([103]) from checkpoint, the shape in current model is torch.Size([100]).. Starting fresh.
Epoch [0/30001] | D Loss: 0.8873 | G Loss: 0.6894 | GenMean: 0.124 | GenStd: 0.498
Epoch [1/30001] | D Loss: 0.6413 | G Loss: 0.9067 | GenMean: 0.154 | GenStd: 0.506
Epoch [2/30001] | D Loss: 0.6079 | G Loss: 1.0397 | GenMean: 0.258 | GenStd: 0.507
Epoch [3/30001] | D Loss: 0.6692 | G Loss: 0.9263 | GenMean: 0.405 | GenStd: 0.451
Epoch [4/30001] | D Loss: 0.7189 | G Loss: 0.7875 | GenMean: 0.509 | GenStd: 0.387
Epoch [5/30001] | D Loss: 0.7095 | G Loss: 0.8129 | GenMean: 0.563 | GenStd: 0.341
Epoch [6/30001] | D Loss: 0.7105 | G Loss: 0.8179 | GenMean: 0.599 | GenStd: 0.322
Epoch [7/30001] | D Los

KeyboardInterrupt: 