In [None]:
import torch, torchaudio, torchvision.transforms as transforms, matplotlib.pyplot as plt, torch.nn as nn, torch.optim as optim, numpy as np, os
from torchvision.models import vgg16, VGG16_Weights
from torch.utils.data import DataLoader, TensorDataset
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import sys
sys.path.append("../")
from ad_utils import *

# Report how many CUDA devices PyTorch can see, then choose the compute device.
print(torch.cuda.device_count())
cuda0 = torch.device("cuda:0")
cuda1 = torch.device("cuda:1")
# Prefer the second GPU (the original choice), but fall back to the first GPU
# or the CPU so the notebook still runs on machines with fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = cuda1
elif torch.cuda.is_available():
    device = cuda0
else:
    device = torch.device("cpu")
print(torch.cuda.get_device_name(device) if device.type == "cuda" else "No GPU available")

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects from
# the file; only load these arrays from a trusted source.
data = np.load("../../hvcm/RFQ.npy", allow_pickle=True)
label = np.load("../../hvcm/RFQ_labels.npy", allow_pickle=True)
label = label[:, 1]  # Assuming the second column is the label
label = (label == "Fault").astype(int)  # Convert to binary labels (1 = "Fault")
print(data.shape, label.shape)

# Standardise along the last axis: all leading axes are flattened into rows,
# scaled column-wise, then the original shape is restored.
# NOTE(review): the scaler is fitted on the complete array, i.e. before the
# normal/faulty and train/test splits below, so faulty and held-out samples
# influence the scaling statistics. Consider fitting on the training split only.
scaler = StandardScaler()
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)

# Separate samples by class.
normal_data = data[label == 0]
faulty_data = data[label == 1]

normal_label = label[label == 0]
faulty_label = label[label == 1]

# Only normal samples are split into train/test here; faulty samples stay in faulty_data.
X_train, X_test, y_train, y_test = train_test_split(normal_data, normal_label, test_size=0.2, random_state=42)


In [None]:
# ===============================
# TRUE MAD-GAN IMPLEMENTATION
# ===============================

import torch.nn.functional as F

class MADGANEncoder(nn.Module):
    """Encoder for reconstruction-based anomaly detection.

    Pipeline: 2-layer bidirectional LSTM -> multi-head self-attention over the
    LSTM outputs -> mean pooling over the time axis -> MLP projection to the
    latent space.

    Args:
        num_features: Size of the last (feature) axis of the input.
        hidden_dim: LSTM hidden size per direction. The attention embedding
            width is ``2 * hidden_dim`` and is split across 8 heads.
        latent_dim: Size of the returned latent vector.
        seq_len: Stored on the module; ``forward`` itself does not read it.
    """
    def __init__(self, num_features=14, hidden_dim=128, latent_dim=64, seq_len=4500):
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.seq_len = seq_len

        # LSTM for temporal modeling (key MAD-GAN component)
        self.lstm = nn.LSTM(num_features, hidden_dim, num_layers=2,
                           batch_first=True, dropout=0.2, bidirectional=True)

        # Attention mechanism for important feature selection
        self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=8,
                                             dropout=0.1, batch_first=True)

        # Latent space projection
        self.to_latent = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, latent_dim)
        )

    def forward(self, x):
        """Encode a batch of sequences to latent vectors.

        Args:
            x: Batch-first tensor whose last axis has ``num_features`` entries
               (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor of shape [batch, latent_dim].
        """
        # LSTM encoding (final hidden/cell states are not needed)
        lstm_out, _ = self.lstm(x)

        # Self-attention over time. The attention weights are never used, so
        # ask PyTorch not to build them: this avoids materialising a
        # [seq_len x seq_len] weight matrix and lets the optimised
        # scaled-dot-product kernel be selected.
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out, need_weights=False)

        # Global average pooling over time dimension -> [batch, hidden_dim*2]
        pooled = attn_out.mean(dim=1)

        # Project to latent space
        return self.to_latent(pooled)

class MADGANDecoder(nn.Module):
    """Turn a latent vector into a sequence of ``seq_len`` feature vectors.

    The latent code is projected to ``hidden_dim``, tiled along a new time
    axis, passed through a 2-layer LSTM and finally mapped per time step to
    ``num_features`` values squashed by ``tanh``.
    """
    def __init__(self, latent_dim=64, hidden_dim=128, num_features=14, seq_len=4500):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.num_features = num_features
        self.seq_len = seq_len

        # Latent -> hidden expansion
        expansion_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        ]
        self.latent_to_hidden = nn.Sequential(*expansion_layers)

        # Recurrent decoder over the tiled hidden vector
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
        )

        # Per-time-step projection back to feature space
        half_dim = hidden_dim // 2
        self.to_output = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(half_dim, num_features),
            nn.Tanh(),
        )

    def forward(self, latent):
        """Decode ``latent`` ([batch, latent_dim]) into [batch, seq_len, num_features]."""
        seed = self.latent_to_hidden(latent)                  # [batch, hidden_dim]
        tiled = seed[:, None, :].repeat(1, self.seq_len, 1)   # [batch, seq_len, hidden_dim]
        decoded, _state = self.lstm(tiled)
        return self.to_output(decoded)                        # [batch, seq_len, num_features]

class MADGANGenerator(nn.Module):
    """Noise-to-sequence generator: a thin wrapper around ``MADGANDecoder``."""
    def __init__(self, noise_dim=100, hidden_dim=128, num_features=14, seq_len=4500):
        super().__init__()
        self.noise_dim, self.seq_len = noise_dim, seq_len

        # The noise vector takes the place of the decoder's latent input.
        self.decoder = MADGANDecoder(
            latent_dim=noise_dim,
            hidden_dim=hidden_dim,
            num_features=num_features,
            seq_len=seq_len,
        )

    def forward(self, z):
        """Decode the noise batch ``z`` into generated sequences."""
        generated = self.decoder(z)
        return generated

class MADGANDiscriminator(nn.Module):
    """Discriminator with temporal modeling.

    Pipeline: 2-layer bidirectional LSTM -> multi-head self-attention ->
    mean pooling over time -> MLP producing one raw logit per sequence
    (no sigmoid is applied here).

    Args:
        num_features: Size of the last (feature) axis of the input.
        hidden_dim: LSTM hidden size per direction; attention width is
            ``2 * hidden_dim`` split across 4 heads.
        seq_len: Accepted for signature parity with the other components;
            not used by this module.
    """
    def __init__(self, num_features=14, hidden_dim=128, seq_len=4500):
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim

        # LSTM for temporal analysis
        self.lstm = nn.LSTM(num_features, hidden_dim, num_layers=2,
                           batch_first=True, dropout=0.2, bidirectional=True)

        # Attention mechanism
        self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=4,
                                             dropout=0.1, batch_first=True)

        # Classification layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x):
        """Score a batch of sequences.

        Args:
            x: Batch-first tensor whose last axis has ``num_features`` entries.

        Returns:
            Raw logits of shape [batch, 1].
        """
        # LSTM processing
        lstm_out, _ = self.lstm(x)

        # Self-attention; the weights are discarded, so skip computing them
        # (avoids a [seq_len x seq_len] matrix and enables the fused kernel).
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out, need_weights=False)

        # Global average pooling over time
        pooled = attn_out.mean(dim=1)

        # Classification
        return self.classifier(pooled)

class TrueMADGAN(nn.Module):
    """Complete MAD-GAN system for anomaly detection.

    Bundles four sub-modules: an encoder/decoder pair used for reconstruction,
    a generator that decodes random noise, and a discriminator that returns a
    raw logit per sequence.
    """
    def __init__(self, num_features=14, hidden_dim=128, latent_dim=64,
                 noise_dim=100, seq_len=4500):
        super().__init__()

        # Core components
        self.encoder = MADGANEncoder(num_features, hidden_dim, latent_dim, seq_len)
        self.decoder = MADGANDecoder(latent_dim, hidden_dim, num_features, seq_len)
        self.generator = MADGANGenerator(noise_dim, hidden_dim, num_features, seq_len)
        self.discriminator = MADGANDiscriminator(num_features, hidden_dim, seq_len)

        self.latent_dim = latent_dim
        self.noise_dim = noise_dim

    def encode_decode(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            (reconstructed, latent): the decoder output and the latent code.
        """
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed, latent

    def generate(self, batch_size, device):
        """Generate ``batch_size`` fake samples from standard-normal noise on ``device``."""
        z = torch.randn(batch_size, self.noise_dim, device=device)
        return self.generator(z)

    def compute_anomaly_score(self, x, lambda_rec=0.1):
        """
        Compute MAD-GAN anomaly score

        Side effect: switches the module to eval mode and leaves it there.

        Args:
            x: Input sequences [batch, seq_len, features]
            lambda_rec: Weight for reconstruction loss

        Returns:
            anomaly_scores: Combined anomaly scores, shape [batch]
                (``(1 - disc_scores) + lambda_rec * normalised rec error``).
            disc_scores: Sigmoid of the discriminator logits, shape [batch].
                The training loop in this file labels real data ~0.9 and
                generated/reconstructed data ~0.1, so this is P(real).
            rec_errors: Un-normalised per-sample MSE, shape [batch].

        Note:
            The reconstruction error is min-max normalised *within the batch
            passed in*, so ``anomaly_scores`` depend on batch composition and
            are not directly comparable across separate calls (for a single
            sample the normalised term is always 0).
        """
        self.eval()
        with torch.no_grad():
            # Reconstruct input
            reconstructed, _ = self.encode_decode(x)

            # Discriminator probability that the sample is real.
            # reshape(-1) (rather than squeeze()) keeps a batch of one as a
            # 1-D tensor so results can be concatenated across batches.
            disc_logits = self.discriminator(x)
            disc_probs = torch.sigmoid(disc_logits).reshape(-1)

            # Reconstruction error averaged over time and features
            rec_error = F.mse_loss(reconstructed, x, reduction='none').mean(dim=(1, 2))

            # Batch-relative min-max normalisation (epsilon avoids 0/0)
            rec_span = rec_error.max() - rec_error.min()
            rec_error_norm = (rec_error - rec_error.min()) / (rec_span + 1e-8)

            # Combined anomaly score (higher = more anomalous):
            # low P(real) + high reconstruction error = anomaly
            anomaly_scores = (1 - disc_probs) + lambda_rec * rec_error_norm

        return anomaly_scores, disc_probs, rec_error

    def get_model_parameters(self):
        """Return parameter counts per component plus the overall total."""
        return {
            'encoder': sum(p.numel() for p in self.encoder.parameters()),
            'decoder': sum(p.numel() for p in self.decoder.parameters()),
            'generator': sum(p.numel() for p in self.generator.parameters()),
            'discriminator': sum(p.numel() for p in self.discriminator.parameters()),
            'total': sum(p.numel() for p in self.parameters())
        }

# Build the full-size MAD-GAN and print a summary of it
_RULE = "=" * 60
print(_RULE, "INITIALIZING TRUE MAD-GAN", _RULE, sep="\n")

# Data dimensions are read from the training array; widths are fixed here
seq_len, num_features = X_train.shape[1:3]  # 4500, 14
hidden_dim, latent_dim, noise_dim = 128, 64, 100

# Instantiate the model, then move it to the selected device
true_madgan = TrueMADGAN(
    num_features=num_features,
    hidden_dim=hidden_dim,
    latent_dim=latent_dim,
    noise_dim=noise_dim,
    seq_len=seq_len,
)
true_madgan = true_madgan.to(device)

# Per-component parameter counts
params = true_madgan.get_model_parameters()
print("\n".join([
    f"Model Components:",
    f"  Encoder parameters: {params['encoder']:,}",
    f"  Decoder parameters: {params['decoder']:,}",
    f"  Generator parameters: {params['generator']:,}",
    f"  Discriminator parameters: {params['discriminator']:,}",
    f"  Total parameters: {params['total']:,}",
]))

# Dataset summary
print("\n".join([
    f"\nData Information:",
    f"  Training samples: {len(X_train)}",
    f"  Sequence length: {seq_len}",
    f"  Features: {num_features}",
    f"  Device: {device}",
]))

print(_RULE, "TRUE MAD-GAN READY FOR TRAINING", _RULE, sep="\n")

In [None]:
# ===============================
# TRUE MAD-GAN TRAINING FUNCTION
# ===============================

def train_true_madgan(madgan_model, train_data, device, epochs=100, batch_size=16,
                     lr_enc=0.0001, lr_dec=0.0001, lr_gen=0.0002, lr_disc=0.0001,
                     lambda_rec=0.1, lambda_enc=0.1):
    """
    Train the true MAD-GAN with all components

    Every batch updates the discriminator on real, generated and reconstructed
    samples. Every second batch additionally updates the generator
    (adversarial loss) and the encoder+decoder (weighted reconstruction +
    adversarial loss). Gradients are clipped to norm 1.0 before each step.

    Args:
        madgan_model: TrueMADGAN instance (must expose ``encoder``, ``decoder``,
            ``generator``, ``discriminator``, ``generate`` and ``encode_decode``)
        train_data: Training data (normal samples only), array-like convertible
            to a float32 tensor; iterated along its first axis
        device: Device the batches and label tensors are placed on
        epochs: Number of passes over the data
        batch_size: Mini-batch size (incomplete final batches are dropped)
        lr_enc, lr_dec, lr_gen, lr_disc: Adam learning rates per component
        lambda_rec: Weight for reconstruction loss
        lambda_enc: Weight for encoder adversarial loss

    Returns:
        (madgan_model, history) where ``history`` maps 'disc_loss', 'gen_loss',
        'enc_loss', 'dec_loss' and 'rec_loss' to per-epoch mean losses.
    """
    print("🚀 Starting TRUE MAD-GAN Training...")

    # Separate optimizers for each component
    optimizer_enc = optim.Adam(madgan_model.encoder.parameters(), lr=lr_enc, betas=(0.5, 0.999))
    optimizer_dec = optim.Adam(madgan_model.decoder.parameters(), lr=lr_dec, betas=(0.5, 0.999))
    optimizer_gen = optim.Adam(madgan_model.generator.parameters(), lr=lr_gen, betas=(0.5, 0.999))
    optimizer_disc = optim.Adam(madgan_model.discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.999))

    # Loss functions
    adversarial_loss = nn.BCEWithLogitsLoss()
    reconstruction_loss = nn.MSELoss()

    # Data loader. as_tensor avoids an extra copy (and the copy-construct
    # warning) when train_data is already a float32 array or tensor.
    dataset = TensorDataset(torch.as_tensor(train_data, dtype=torch.float32))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Training history
    history = {
        'disc_loss': [], 'gen_loss': [], 'enc_loss': [], 'dec_loss': [], 'rec_loss': []
    }

    print(f"Training Configuration:")
    print(f"  Epochs: {epochs}, Batch size: {batch_size}")
    print(f"  Learning rates - Enc: {lr_enc}, Dec: {lr_dec}, Gen: {lr_gen}, Disc: {lr_disc}")
    print(f"  Loss weights - Reconstruction: {lambda_rec}, Encoder: {lambda_enc}")
    print(f"  Total batches per epoch: {len(dataloader)}")

    for epoch in range(epochs):
        epoch_losses = {'disc': [], 'gen': [], 'enc': [], 'dec': [], 'rec': []}

        madgan_model.train()

        for batch_idx, (real_data,) in enumerate(dataloader):
            real_data = real_data.to(device)
            current_batch_size = real_data.size(0)

            # Labels for adversarial training (two-sided label smoothing)
            real_labels = torch.ones(current_batch_size, 1, device=device) * 0.9
            fake_labels = torch.zeros(current_batch_size, 1, device=device) + 0.1

            # =====================================
            # 1. Train Discriminator
            # =====================================
            optimizer_disc.zero_grad()

            # Real data
            real_pred = madgan_model.discriminator(real_data)
            d_real_loss = adversarial_loss(real_pred, real_labels)

            # Fake data from generator. Only the discriminator is updated in
            # this step, so produce the sample without building an autograd
            # graph through the generator (previously built, then detached).
            with torch.no_grad():
                fake_data = madgan_model.generate(current_batch_size, device)
            fake_pred = madgan_model.discriminator(fake_data)
            d_fake_loss = adversarial_loss(fake_pred, fake_labels)

            # Reconstructed data, likewise without an encoder/decoder graph
            with torch.no_grad():
                reconstructed, _ = madgan_model.encode_decode(real_data)
            recon_pred = madgan_model.discriminator(reconstructed)
            d_recon_loss = adversarial_loss(recon_pred, fake_labels)

            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss + d_recon_loss) / 3
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(madgan_model.discriminator.parameters(), 1.0)
            optimizer_disc.step()

            epoch_losses['disc'].append(d_loss.item())

            # =====================================
            # 2. Train Generator
            # =====================================
            if batch_idx % 2 == 0:  # Train generator every 2 batches
                optimizer_gen.zero_grad()

                fake_data = madgan_model.generate(current_batch_size, device)
                fake_pred = madgan_model.discriminator(fake_data)
                g_loss = adversarial_loss(fake_pred, real_labels)

                # This also deposits gradients on the discriminator; they are
                # cleared by optimizer_disc.zero_grad() at the next batch.
                g_loss.backward()
                torch.nn.utils.clip_grad_norm_(madgan_model.generator.parameters(), 1.0)
                optimizer_gen.step()

                epoch_losses['gen'].append(g_loss.item())

            # =====================================
            # 3. Train Encoder + Decoder (Reconstruction)
            # =====================================
            if batch_idx % 2 == 0:  # Train reconstruction every 2 batches
                optimizer_enc.zero_grad()
                optimizer_dec.zero_grad()

                # Reconstruction
                reconstructed, latent = madgan_model.encode_decode(real_data)

                # Reconstruction loss
                rec_loss = reconstruction_loss(reconstructed, real_data)

                # Encoder adversarial loss (fool discriminator with reconstructions)
                recon_pred = madgan_model.discriminator(reconstructed)
                enc_adv_loss = adversarial_loss(recon_pred, real_labels)

                # Combined encoder/decoder loss
                total_rec_loss = lambda_rec * rec_loss + lambda_enc * enc_adv_loss
                total_rec_loss.backward()

                torch.nn.utils.clip_grad_norm_(madgan_model.encoder.parameters(), 1.0)
                torch.nn.utils.clip_grad_norm_(madgan_model.decoder.parameters(), 1.0)
                optimizer_enc.step()
                optimizer_dec.step()

                epoch_losses['enc'].append(enc_adv_loss.item())
                epoch_losses['dec'].append(total_rec_loss.item())
                epoch_losses['rec'].append(rec_loss.item())

            # Memory cleanup
            if batch_idx % 10 == 0:
                torch.cuda.empty_cache()

        # Calculate epoch averages (0 when a loss was never recorded)
        avg_losses = {k: np.mean(v) if v else 0 for k, v in epoch_losses.items()}

        # Store history
        for k, v in avg_losses.items():
            history[f'{k}_loss'].append(v)

        # Print progress
        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1:3d}/{epochs} | "
                  f"D: {avg_losses['disc']:.4f} | "
                  f"G: {avg_losses['gen']:.4f} | "
                  f"Rec: {avg_losses['rec']:.4f} | "
                  f"E: {avg_losses['enc']:.4f}")

        # Memory cleanup
        torch.cuda.empty_cache()

    print("✅ TRUE MAD-GAN Training Completed!")
    return madgan_model, history

In [None]:
# ===============================
# MEMORY-OPTIMIZED MAD-GAN
# ===============================

import torch.utils.checkpoint as checkpoint
import gc

class MemoryEfficientMADGANEncoder(nn.Module):
    """Memory-optimized encoder with chunked processing.

    The input sequence is cut into windows of ``chunk_size`` time steps. Each
    window is encoded independently (single-layer LSTM -> self-attention ->
    mean over time) under gradient checkpointing; the per-window vectors are
    averaged and projected to the latent space.

    Note:
        LSTM state is not carried between windows and attention only sees
        positions inside the same window.

    Args:
        num_features: Size of the last (feature) axis of the input.
        hidden_dim: LSTM hidden size and attention width (split over 4 heads).
        latent_dim: Size of the returned latent vector.
        seq_len: Stored on the module; ``forward`` uses the actual input length.
    """
    def __init__(self, num_features=14, hidden_dim=64, latent_dim=32, seq_len=4500):
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.seq_len = seq_len
        self.chunk_size = 500  # Process in smaller chunks

        # Smaller LSTM. No ``dropout`` argument: nn.LSTM applies dropout only
        # between stacked layers, so with num_layers=1 it had no effect and
        # merely triggered a UserWarning.
        self.lstm = nn.LSTM(num_features, hidden_dim, num_layers=1,
                           batch_first=True)

        # Simpler attention
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4,
                                             dropout=0.1, batch_first=True)

        # Compact latent projection
        self.to_latent = nn.Sequential(
            nn.Linear(hidden_dim, latent_dim),
            nn.Tanh()
        )

    def _encode_chunk(self, chunk):
        """Encode one window to a [batch, hidden_dim] vector."""
        lstm_out, _ = self.lstm(chunk)
        # Attention weights are unused, so do not materialise them.
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out, need_weights=False)
        return attn_out.mean(dim=1)

    def forward(self, x):
        """Memory-efficient forward pass with chunking.

        Args:
            x: Batch-first tensor [batch, time, num_features].

        Returns:
            Tensor of shape [batch, latent_dim].
        """
        # Chunk over the length actually received rather than the configured
        # seq_len, so shorter inputs do not yield empty windows and longer
        # inputs are not silently truncated.
        total_len = x.size(1)

        # Gradient checkpointing: activations inside each window are
        # recomputed during backward instead of being stored.
        chunk_outputs = [
            checkpoint.checkpoint(
                self._encode_chunk,
                x[:, start:start + self.chunk_size, :],
                use_reentrant=False,
            )
            for start in range(0, total_len, self.chunk_size)
        ]

        # Average the per-window summaries, then project to the latent space
        combined = torch.stack(chunk_outputs, dim=1).mean(dim=1)
        return self.to_latent(combined)

class MemoryEfficientMADGANDecoder(nn.Module):
    """Memory-optimized decoder.

    The latent code is projected to ``hidden_dim`` and fed, repeated at every
    time step, to a single-layer LSTM. Decoding runs in windows of
    ``chunk_size`` steps under gradient checkpointing, and the LSTM state is
    carried from one window to the next so the result equals decoding the
    whole sequence in one pass.

    Args:
        latent_dim: Size of the latent input vector.
        hidden_dim: LSTM hidden size.
        num_features: Number of output features per time step.
        seq_len: Number of time steps to generate.
    """
    def __init__(self, latent_dim=32, hidden_dim=64, num_features=14, seq_len=4500):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.num_features = num_features
        self.seq_len = seq_len
        self.chunk_size = 500

        # Compact expansion
        self.latent_to_hidden = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Smaller LSTM. No ``dropout`` argument: nn.LSTM applies dropout only
        # between stacked layers, so with num_layers=1 it had no effect and
        # merely triggered a UserWarning.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=1,
                           batch_first=True)

        # Simple output projection
        self.to_output = nn.Sequential(
            nn.Linear(hidden_dim, num_features),
            nn.Tanh()
        )

    def _decode_chunk(self, hidden_chunk, h_prev, c_prev):
        """Decode one window, continuing from (h_prev, c_prev) when given."""
        state = None if h_prev is None else (h_prev, c_prev)
        lstm_out, (h_next, c_next) = self.lstm(hidden_chunk, state)
        return self.to_output(lstm_out), h_next, c_next

    def forward(self, latent):
        """Memory-efficient chunked decoding.

        Args:
            latent: Tensor [batch, latent_dim].

        Returns:
            Tensor [batch, seq_len, num_features] with values in [-1, 1] (tanh).
        """
        hidden = self.latent_to_hidden(latent)

        outputs = []
        # Recurrent state handed from window to window. Previously every
        # window restarted from a zero state with identical input, which made
        # all windows produce the same output (a sequence periodic in
        # chunk_size).
        h_state = c_state = None

        for start in range(0, self.seq_len, self.chunk_size):
            chunk_len = min(self.chunk_size, self.seq_len - start)

            # Same hidden vector at every step of this window
            hidden_chunk = hidden.unsqueeze(1).repeat(1, chunk_len, 1)

            # Checkpointed so window activations are recomputed in backward
            output_chunk, h_state, c_state = checkpoint.checkpoint(
                self._decode_chunk, hidden_chunk, h_state, c_state,
                use_reentrant=False,
            )
            outputs.append(output_chunk)

        return torch.cat(outputs, dim=1)

class MemoryEfficientMADGANDiscriminator(nn.Module):
    """Memory-optimized discriminator.

    The input is cut into windows of ``chunk_size`` time steps; each window is
    summarised by the time-mean of a single-layer LSTM's outputs (under
    gradient checkpointing). The window summaries are averaged and passed to a
    small MLP that returns one raw logit per sequence.

    Note:
        LSTM state is not carried between windows; each window starts from
        the default (zero) state.

    Args:
        num_features: Size of the last (feature) axis of the input.
        hidden_dim: LSTM hidden size.
        seq_len: Accepted for signature parity; ``forward`` uses the actual
            input length.
    """
    def __init__(self, num_features=14, hidden_dim=64, seq_len=4500):
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.chunk_size = 500

        # Smaller LSTM. No ``dropout`` argument: nn.LSTM applies dropout only
        # between stacked layers, so with num_layers=1 it had no effect and
        # merely triggered a UserWarning.
        self.lstm = nn.LSTM(num_features, hidden_dim, num_layers=1,
                           batch_first=True)

        # Simple classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, 1)
        )

    def _chunk_feature(self, chunk):
        """Summarise one window as the time-mean of the LSTM outputs."""
        lstm_out, _ = self.lstm(chunk)
        return lstm_out.mean(dim=1)

    def forward(self, x):
        """Memory-efficient chunked discrimination.

        Args:
            x: Batch-first tensor [batch, time, num_features].

        Returns:
            Raw logits of shape [batch, 1].
        """
        total_len = x.size(1)

        # Checkpointed so window activations are recomputed in backward
        chunk_features = [
            checkpoint.checkpoint(
                self._chunk_feature,
                x[:, start:start + self.chunk_size, :],
                use_reentrant=False,
            )
            for start in range(0, total_len, self.chunk_size)
        ]

        # Combine features: unweighted mean over windows
        combined = torch.stack(chunk_features, dim=1).mean(dim=1)
        return self.classifier(combined)

class MemoryOptimizedMADGAN(nn.Module):
    """Memory-optimized complete MAD-GAN system.

    Holds a chunked encoder, decoder and discriminator. There is no separate
    generator network: noise is mapped to the latent space by
    ``generator_projection`` and then decoded with the shared ``decoder``.
    """
    def __init__(self, num_features=14, hidden_dim=64, latent_dim=32,
                 noise_dim=64, seq_len=4500):
        super().__init__()

        # Smaller components
        self.encoder = MemoryEfficientMADGANEncoder(num_features, hidden_dim, latent_dim, seq_len)
        self.decoder = MemoryEfficientMADGANDecoder(latent_dim, hidden_dim, num_features, seq_len)
        self.discriminator = MemoryEfficientMADGANDiscriminator(num_features, hidden_dim, seq_len)

        # Generator uses decoder: noise -> latent projection only
        self.generator_projection = nn.Sequential(
            nn.Linear(noise_dim, latent_dim),
            nn.Tanh()
        )

        self.latent_dim = latent_dim
        self.noise_dim = noise_dim

    def encode_decode(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            (reconstructed, latent): the decoder output and the latent code.
        """
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed, latent

    def generate(self, batch_size, device):
        """Generate ``batch_size`` fake samples from standard-normal noise on ``device``."""
        z = torch.randn(batch_size, self.noise_dim, device=device)
        latent = self.generator_projection(z)
        return self.decoder(latent)

    def compute_anomaly_score(self, x, lambda_rec=0.1):
        """Memory-efficient anomaly scoring.

        Batches with more than four samples are split in half recursively and
        the partial results concatenated. Side effect: switches the module to
        eval mode and leaves it there.

        Args:
            x: Input sequences [batch, seq_len, features].
            lambda_rec: Weight of the reconstruction error in the score.

        Returns:
            anomaly_scores: ``(1 - disc_scores) + lambda_rec * rec_errors``, shape [batch].
            disc_scores: Sigmoid of the discriminator logits, shape [batch].
            rec_errors: Per-sample MSE over time and features, shape [batch].
        """
        self.eval()
        with torch.no_grad():
            # Process in smaller batches to save memory
            batch_size = x.size(0)
            if batch_size > 4:  # Split large batches
                mid = batch_size // 2
                scores1, disc1, rec1 = self.compute_anomaly_score(x[:mid], lambda_rec)
                scores2, disc2, rec2 = self.compute_anomaly_score(x[mid:], lambda_rec)

                scores = torch.cat([scores1, scores2])
                disc_scores = torch.cat([disc1, disc2])
                rec_errors = torch.cat([rec1, rec2])

                return scores, disc_scores, rec_errors

            # Reconstruct
            reconstructed, _ = self.encode_decode(x)

            # Discriminate. reshape(-1) (rather than squeeze()) keeps a batch
            # of one as a 1-D tensor, so all three outputs share shape [batch]
            # and can be concatenated across calls.
            disc_logits = self.discriminator(x)
            disc_probs = torch.sigmoid(disc_logits).reshape(-1)

            # Reconstruction error averaged over time and features
            rec_error = F.mse_loss(reconstructed, x, reduction='none').mean(dim=(1, 2))

            # Combined score (un-normalised reconstruction error)
            anomaly_scores = (1 - disc_probs) + lambda_rec * rec_error

        return anomaly_scores, disc_probs, rec_error

    def get_model_parameters(self):
        """Return parameter counts per component plus the overall total."""
        return {
            'encoder': sum(p.numel() for p in self.encoder.parameters()),
            'decoder': sum(p.numel() for p in self.decoder.parameters()),
            'discriminator': sum(p.numel() for p in self.discriminator.parameters()),
            'generator_proj': sum(p.numel() for p in self.generator_projection.parameters()),
            'total': sum(p.numel() for p in self.parameters())
        }

def clear_gpu_memory():
    """Run the garbage collector, then release cached CUDA memory if a GPU is available.

    Collection happens first so that tensors kept alive only by unreachable
    Python objects are freed before the CUDA caching allocator is emptied;
    in the opposite order those blocks would stay cached.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Build the reduced-size MAD-GAN and compare it with the full model
_RULE = "=" * 60
print(_RULE, "INITIALIZING MEMORY-OPTIMIZED MAD-GAN", _RULE, sep="\n")

# Free cached GPU memory before allocating the new model
clear_gpu_memory()

# Data dimensions come from the training array; widths are smaller than the full model's
seq_len, num_features = X_train.shape[1:3]  # 4500, 14
hidden_dim, latent_dim, noise_dim = 64, 32, 64  # Reduced from 128 / 64 / 100

# Instantiate the model, then move it to the selected device
memory_madgan = MemoryOptimizedMADGAN(
    num_features=num_features,
    hidden_dim=hidden_dim,
    latent_dim=latent_dim,
    noise_dim=noise_dim,
    seq_len=seq_len,
)
memory_madgan = memory_madgan.to(device)

# Per-component parameter counts
params = memory_madgan.get_model_parameters()
print("\n".join([
    f"Memory-Optimized Model Components:",
    f"  Encoder parameters: {params['encoder']:,}",
    f"  Decoder parameters: {params['decoder']:,}",
    f"  Discriminator parameters: {params['discriminator']:,}",
    f"  Generator projection parameters: {params['generator_proj']:,}",
    f"  Total parameters: {params['total']:,}",
]))

# Size relative to the full model built earlier (true_madgan)
print(f"\nParameter Reduction:")
original_total = sum(p.numel() for p in true_madgan.parameters())
reduction = 100 * (1 - params['total'] / original_total)
print("\n".join([
    f"  Original model: {original_total:,} parameters",
    f"  Optimized model: {params['total']:,} parameters",
    f"  Reduction: {reduction:.1f}%",
]))

# Current GPU memory footprint, reported in GiB-sized units
print(f"\nMemory Usage Check:")
if torch.cuda.is_available():
    _BYTES_PER_GB = 1024**3
    memory_allocated = torch.cuda.memory_allocated(device) / _BYTES_PER_GB
    memory_reserved = torch.cuda.memory_reserved(device) / _BYTES_PER_GB
    print(f"  Allocated: {memory_allocated:.2f} GB",
          f"  Reserved: {memory_reserved:.2f} GB",
          sep="\n")

print(_RULE, "MEMORY-OPTIMIZED MAD-GAN READY", _RULE, sep="\n")

In [None]:
# ===============================
# MEMORY-EFFICIENT TRAINING FUNCTION
# ===============================

def _set_requires_grad(params, flag):
    """Switch gradient tracking on or off for every parameter in ``params``."""
    for param in params:
        param.requires_grad_(flag)


def train_memory_efficient_madgan(madgan_model, train_data, device, 
                                 epochs=50, batch_size=4, accumulation_steps=4,
                                 lr_enc=0.0001, lr_dec=0.0001, lr_gen=0.0002, lr_disc=0.0001,
                                 lambda_rec=0.1, lambda_enc=0.1):
    """
    Train a MAD-GAN style model with small mini-batches and frequent GPU cleanup.

    The discriminator accumulates gradients over ``accumulation_steps``
    mini-batches before each optimizer step. The generator projection and the
    encoder/decoder are each updated from a single mini-batch, on a sparser
    schedule (every ``2 * accumulation_steps`` and every ``accumulation_steps``
    batches respectively), and step immediately after their backward pass.

    Args:
        madgan_model: Module exposing ``encoder``, ``decoder``,
            ``generator_projection`` and ``discriminator`` sub-modules plus
            ``generate(batch_size, device)`` and ``encode_decode(x)`` (the
            latter returning ``(reconstruction, latent)``).
        train_data: Array-like accepted by ``torch.tensor``; the first axis is
            the sample axis.
        device: Device the mini-batches and labels are moved to.
        epochs: Number of passes over ``train_data``.
        batch_size: Mini-batch size; incomplete final batches are dropped.
        accumulation_steps: Accumulate discriminator gradients over this many
            mini-batches before stepping. Must be >= 1.
        lr_enc, lr_dec, lr_gen, lr_disc: Adam learning rates per component.
        lambda_rec: Weight of the MSE reconstruction term.
        lambda_enc: Weight of the encoder's adversarial term.

    Returns:
        ``(madgan_model, history)`` where ``history`` maps ``'disc_loss'``,
        ``'gen_loss'``, ``'enc_loss'``, ``'dec_loss'`` and ``'rec_loss'`` to
        lists holding one mean value per epoch.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    print("🚀 Starting Memory-Efficient MAD-GAN Training...")
    
    # Use gradient accumulation to simulate larger batch sizes
    effective_batch_size = batch_size * accumulation_steps
    
    # One optimizer per component so each can be stepped on its own schedule
    optimizer_enc = optim.Adam(madgan_model.encoder.parameters(), lr=lr_enc, betas=(0.5, 0.999))
    optimizer_dec = optim.Adam(madgan_model.decoder.parameters(), lr=lr_dec, betas=(0.5, 0.999))
    optimizer_gen = optim.Adam(madgan_model.generator_projection.parameters(), lr=lr_gen, betas=(0.5, 0.999))
    optimizer_disc = optim.Adam(madgan_model.discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.999))
    
    # Discriminator parameters that are trainable right now. They are frozen
    # while the generator / encoder / decoder back-propagate through the
    # discriminator, so those passes cannot leak gradients into the
    # discriminator's accumulation window.
    disc_params = [p for p in madgan_model.discriminator.parameters() if p.requires_grad]
    
    # Loss functions
    adversarial_loss = nn.BCEWithLogitsLoss()
    reconstruction_loss = nn.MSELoss()
    
    # Data loader with small batch size
    dataset = TensorDataset(torch.tensor(train_data, dtype=torch.float32))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    num_batches = len(dataloader)
    
    # Training history
    history = {
        'disc_loss': [], 'gen_loss': [], 'enc_loss': [], 'dec_loss': [], 'rec_loss': []
    }
    
    print("Training Configuration:")
    print(f"  Epochs: {epochs}")
    print(f"  Mini-batch size: {batch_size}")
    print(f"  Accumulation steps: {accumulation_steps}")
    print(f"  Effective batch size: {effective_batch_size}")
    print(f"  Learning rates - Enc: {lr_enc}, Dec: {lr_dec}, Gen: {lr_gen}, Disc: {lr_disc}")
    print(f"  Total batches per epoch: {num_batches}")
    
    for epoch in range(epochs):
        epoch_losses = {'disc': [], 'gen': [], 'enc': [], 'dec': [], 'rec': []}
        
        # Zero gradients at start of epoch
        optimizer_disc.zero_grad()
        optimizer_gen.zero_grad()
        optimizer_enc.zero_grad()
        optimizer_dec.zero_grad()
        
        madgan_model.train()
        
        for batch_idx, (real_data,) in enumerate(dataloader):
            real_data = real_data.to(device)
            current_batch_size = real_data.size(0)
            
            # Smoothed labels (0.9 for real, 0.1 for fake)
            real_labels = torch.full((current_batch_size, 1), 0.9, device=device)
            fake_labels = torch.full((current_batch_size, 1), 0.1, device=device)
            
            # =====================================
            # Train Discriminator (with accumulation)
            # =====================================
            
            # Real data
            real_pred = madgan_model.discriminator(real_data)
            d_real_loss = adversarial_loss(real_pred, real_labels) / accumulation_steps
            d_real_loss.backward()
            
            # Fake data from generator (no graph needed on the generator side)
            with torch.no_grad():
                fake_data = madgan_model.generate(current_batch_size, device)
            fake_pred = madgan_model.discriminator(fake_data)
            d_fake_loss = adversarial_loss(fake_pred, fake_labels) / accumulation_steps
            d_fake_loss.backward()
            
            # Reconstructed data is treated as fake as well
            with torch.no_grad():
                reconstructed, _ = madgan_model.encode_decode(real_data)
            recon_pred = madgan_model.discriminator(reconstructed)
            d_recon_loss = adversarial_loss(recon_pred, fake_labels) / accumulation_steps
            d_recon_loss.backward()
            
            # Step at the end of every accumulation window, and also on the
            # last batch so a trailing partial window is not thrown away.
            window_complete = (batch_idx + 1) % accumulation_steps == 0
            if window_complete or batch_idx + 1 == num_batches:
                torch.nn.utils.clip_grad_norm_(madgan_model.discriminator.parameters(), 1.0)
                optimizer_disc.step()
                optimizer_disc.zero_grad()
            
            # Logged value undoes the 1/accumulation_steps scaling
            d_loss_total = (d_real_loss.detach() + d_fake_loss.detach() + d_recon_loss.detach())
            epoch_losses['disc'].append(d_loss_total.item() * accumulation_steps)
            
            # Clear intermediate tensors
            del real_pred, fake_data, fake_pred, reconstructed, recon_pred
            
            # =====================================
            # Train Generator (less frequently)
            # =====================================
            if batch_idx % (accumulation_steps * 2) == 0:
                _set_requires_grad(disc_params, False)
                try:
                    # Drop anything left on the generator from other phases
                    optimizer_gen.zero_grad()
                    fake_data = madgan_model.generate(current_batch_size, device)
                    fake_pred = madgan_model.discriminator(fake_data)
                    g_loss = adversarial_loss(fake_pred, real_labels)
                    g_loss.backward()
                finally:
                    # Always restore, so a failure here (e.g. OOM) does not
                    # leave the discriminator frozen for a later retry.
                    _set_requires_grad(disc_params, True)
                
                # Single-batch update: step right away. Gating this on the
                # discriminator's window boundary would mean it never runs
                # for accumulation_steps > 1.
                torch.nn.utils.clip_grad_norm_(madgan_model.generator_projection.parameters(), 1.0)
                optimizer_gen.step()
                optimizer_gen.zero_grad()
                
                epoch_losses['gen'].append(g_loss.item())
                del fake_data, fake_pred
            
            # =====================================
            # Train Encoder + Decoder (Reconstruction)
            # =====================================
            if batch_idx % accumulation_steps == 0:
                _set_requires_grad(disc_params, False)
                try:
                    # Discard gradients the generator pass may have left on
                    # modules shared with generate()
                    optimizer_enc.zero_grad()
                    optimizer_dec.zero_grad()
                    
                    # Reconstruction
                    reconstructed, latent = madgan_model.encode_decode(real_data)
                    
                    # Reconstruction loss
                    rec_loss = reconstruction_loss(reconstructed, real_data)
                    
                    # Encoder adversarial loss: reconstructions should look real
                    recon_pred = madgan_model.discriminator(reconstructed)
                    enc_adv_loss = adversarial_loss(recon_pred, real_labels)
                    
                    # Combined loss
                    total_rec_loss = lambda_rec * rec_loss + lambda_enc * enc_adv_loss
                    total_rec_loss.backward()
                finally:
                    _set_requires_grad(disc_params, True)
                
                # Single-batch update: step immediately (see generator note)
                torch.nn.utils.clip_grad_norm_(madgan_model.encoder.parameters(), 1.0)
                torch.nn.utils.clip_grad_norm_(madgan_model.decoder.parameters(), 1.0)
                optimizer_enc.step()
                optimizer_dec.step()
                optimizer_enc.zero_grad()
                optimizer_dec.zero_grad()
                
                epoch_losses['enc'].append(enc_adv_loss.item())
                epoch_losses['dec'].append(total_rec_loss.item())
                epoch_losses['rec'].append(rec_loss.item())
                
                del reconstructed, latent, recon_pred
            
            # Aggressive memory cleanup
            if batch_idx % 5 == 0:
                clear_gpu_memory()
            
            # Memory monitoring
            if batch_idx % 20 == 0 and torch.cuda.is_available():
                memory_used = torch.cuda.memory_allocated(device) / 1024**3
                if memory_used > 6.0:  # Warning if using more than 6GB
                    print(f"    WARNING: High memory usage: {memory_used:.2f}GB at batch {batch_idx}")
                    clear_gpu_memory()
        
        # Calculate epoch averages (0 when a component was never updated)
        avg_losses = {k: np.mean(v) if v else 0 for k, v in epoch_losses.items()}
        
        # Store history
        for k, v in avg_losses.items():
            history[f'{k}_loss'].append(v)
        
        # Print progress every 5 epochs and on the final epoch
        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1:3d}/{epochs} | "
                  f"D: {avg_losses['disc']:.4f} | "
                  f"G: {avg_losses['gen']:.4f} | "
                  f"Rec: {avg_losses['rec']:.4f} | "
                  f"E: {avg_losses['enc']:.4f}")
            
            # Memory status
            if torch.cuda.is_available():
                memory_used = torch.cuda.memory_allocated(device) / 1024**3
                print(f"         Memory: {memory_used:.2f}GB")
        
        # Aggressive cleanup at end of epoch
        clear_gpu_memory()
    
    print("✅ Memory-Efficient Training Completed!")
    return madgan_model, history

# Dry-run the model once to see how much GPU memory a forward pass needs
def test_memory_usage():
    """Run each model stage on a two-sample batch and print allocated GPU memory.

    Reads the notebook globals ``X_train``, ``device`` and ``memory_madgan``.
    Returns True when the allocation measured right after anomaly scoring is
    below 8 GB, False otherwise.
    """
    def allocated_gb():
        # Bytes currently allocated on `device`, expressed in GiB
        return torch.cuda.memory_allocated(device) / 1024**3

    print("🔍 Testing Memory Usage...")

    # Tiny probe batch, moved to the device before the baseline reading
    test_batch_size = 2
    probe_batch = torch.tensor(X_train[:test_batch_size], dtype=torch.float32).to(device)

    clear_gpu_memory()
    print(f"Initial memory: {allocated_gb():.2f}GB")

    # Inference only: eval mode and no autograd graph
    memory_madgan.eval()
    with torch.no_grad():
        # Every result stays bound to a local so it is still alive when the
        # following readings are taken.
        encoded = memory_madgan.encoder(probe_batch)
        print(f"After encoding: {allocated_gb():.2f}GB")

        decoded = memory_madgan.decoder(encoded)
        print(f"After decoding: {allocated_gb():.2f}GB")

        disc_logits = memory_madgan.discriminator(probe_batch)
        print(f"After discrimination: {allocated_gb():.2f}GB")

        synthetic = memory_madgan.generate(test_batch_size, device)
        print(f"After generation: {allocated_gb():.2f}GB")

        scores, _, _ = memory_madgan.compute_anomaly_score(probe_batch)
        memory_after_scoring = allocated_gb()
        print(f"After anomaly scoring: {memory_after_scoring:.2f}GB")

    clear_gpu_memory()
    print(f"After cleanup: {allocated_gb():.2f}GB")

    print("✅ Memory test completed successfully!")

    # Verdict is based on the reading taken right after anomaly scoring
    return memory_after_scoring < 8.0

# Run the dry-run; the training cell is gated on this flag
memory_ok = test_memory_usage()

In [None]:
# ===============================
# EXECUTE MEMORY-EFFICIENT TRAINING
# ===============================
#
# Flow: train (with one smaller-batch retry on CUDA OOM), then generate
# synthetic data, plot the loss history and run a short anomaly-score check.
# The post-training steps run after whichever attempt succeeded, so
# `memory_generated_data` exists for the following cells in both cases.

# Set by whichever training attempt finishes; gates the post-training steps.
training_succeeded = False

if memory_ok:
    print("✅ Memory test passed. Starting training...")
    
    # Clear all memory before training
    clear_gpu_memory()
    
    # Settings shared by the primary and the fallback attempt
    common_training_kwargs = dict(
        madgan_model=memory_madgan,
        train_data=X_train,
        device=device,
        epochs=200,
        lr_enc=0.0001,
        lr_dec=0.0001,
        lr_gen=0.0002,
        lr_disc=0.0001,
        lambda_rec=0.1,
        lambda_enc=0.1,
    )
    
    # Primary attempt: mini-batches of 2, effective batch size 16
    retry_with_smaller_batches = False
    try:
        trained_memory_madgan, memory_training_history = train_memory_efficient_madgan(
            batch_size=2,
            accumulation_steps=8,
            **common_training_kwargs
        )
        training_succeeded = True
        print("🎉 Training completed successfully!")
        
    except RuntimeError as e:
        if "out of memory" in str(e):
            print("❌ Still running out of memory. Trying even smaller settings...")
            # Only record the decision here. While the except clause is
            # active the exception keeps the failed call's stack frames (and
            # their tensors) alive, so the retry must run after this block.
            retry_with_smaller_batches = True
        else:
            print(f"❌ Training failed with error: {e}")
    
    # Ultra-conservative fallback, run outside the except clause so the
    # memory of the failed attempt can actually be released first.
    if retry_with_smaller_batches:
        clear_gpu_memory()
        try:
            trained_memory_madgan, memory_training_history = train_memory_efficient_madgan(
                batch_size=1,           # Minimum batch size
                accumulation_steps=16,  # Larger accumulation
                **common_training_kwargs
            )
            training_succeeded = True
            print("✅ Ultra-conservative training completed!")
            
        except Exception as e2:
            print(f"❌ Training failed even with minimal settings: {e2}")
            print("💡 Suggestions:")
            print("   - Use CPU training (slower but will work)")
            print("   - Reduce sequence length (truncate time series)")
            print("   - Use even smaller model dimensions")
    
    # Post-training steps. Errors raised here are no longer reported as
    # training failures and no longer trigger a retrain; they propagate.
    if training_succeeded:
        # Generate synthetic data with memory management
        print("\n🔄 Generating synthetic data...")
        
        trained_memory_madgan.eval()
        generated_samples = []
        generation_batch_size = 16  # Small batches keep peak GPU memory low
        
        num_samples_to_generate = len(X_train)  # One synthetic sample per training sample
        
        with torch.no_grad():
            batch_starts = range(0, num_samples_to_generate, generation_batch_size)
            for generation_batch_no, generation_start in enumerate(batch_starts):
                current_batch = min(generation_batch_size, num_samples_to_generate - generation_start)
                fake_data = trained_memory_madgan.generate(current_batch, device)
                generated_samples.append(fake_data.cpu().numpy())
                
                # Drop the GPU copy; release cached blocks every 20 batches
                del fake_data
                if generation_batch_no % 20 == 0:
                    clear_gpu_memory()
        
        # Combine generated samples
        memory_generated_data = np.concatenate(generated_samples, axis=0)
        
        print(f"✅ Generated {len(memory_generated_data)} samples")
        print(f"Generated data shape: {memory_generated_data.shape}")
        
        # Plot training history: one panel per loss, then a combined panel
        plt.figure(figsize=(12, 8))
        
        # (history key, legend label, line colour or None for default, title)
        loss_panels = [
            ('disc_loss', 'Discriminator', None, 'Discriminator Loss'),
            ('gen_loss', 'Generator', 'orange', 'Generator Loss'),
            ('rec_loss', 'Reconstruction', 'green', 'Reconstruction Loss'),
            ('enc_loss', 'Encoder Adv', 'red', 'Encoder Adversarial Loss'),
            ('dec_loss', 'Decoder Total', 'purple', 'Decoder Total Loss'),
        ]
        for panel_no, (panel_key, panel_label, panel_color, panel_title) in enumerate(loss_panels, start=1):
            plt.subplot(2, 3, panel_no)
            if memory_training_history[panel_key]:
                panel_style = {'color': panel_color} if panel_color else {}
                plt.plot(memory_training_history[panel_key], label=panel_label, alpha=0.8, **panel_style)
            plt.title(panel_title)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True, alpha=0.3)
        
        plt.subplot(2, 3, 6)
        # Plot key losses only to avoid clutter
        for panel_key, panel_label in (('disc_loss', 'Discriminator'), ('rec_loss', 'Reconstruction')):
            if memory_training_history[panel_key]:
                plt.plot(memory_training_history[panel_key], label=panel_label, alpha=0.7)
        plt.title('Key Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print final statistics
        print("\n📊 Final Training Statistics:")
        for panel_key, panel_label in (('disc_loss', 'Discriminator'),
                                       ('gen_loss', 'Generator'),
                                       ('rec_loss', 'Reconstruction')):
            if memory_training_history[panel_key]:
                print(f"   Final {panel_label} Loss: {memory_training_history[panel_key][-1]:.4f}")
        
        # Final memory check
        final_memory = torch.cuda.memory_allocated(device) / 1024**3
        print(f"   Final GPU Memory Usage: {final_memory:.2f}GB")
        
        # Test anomaly detection on a few samples
        print("\n🔍 Testing Anomaly Detection...")
        test_samples = torch.tensor(X_train[:5], dtype=torch.float32).to(device)
        # Inference only: no autograd graph, same as in the memory dry-run
        with torch.no_grad():
            anomaly_scores, disc_scores, rec_errors = trained_memory_madgan.compute_anomaly_score(test_samples)
        
        print("Normal samples anomaly scores:")
        for i, (ascore, dscore, rerror) in enumerate(zip(anomaly_scores, disc_scores, rec_errors)):
            print(f"   Sample {i+1}: Anomaly={ascore:.4f}, Disc={dscore:.4f}, RecErr={rerror:.4f}")
        
        clear_gpu_memory()
            
else:
    print("❌ Memory test failed. GPU memory insufficient.")
    print("💡 Recommendations:")
    print("   1. Reduce sequence length (currently 4500 timesteps)")
    print("   2. Use CPU training instead of GPU")
    print("   3. Process data in smaller chunks")
    print("   4. Use simpler model architecture")

# Final cleanup
clear_gpu_memory()
print("\n🧹 Final memory cleanup completed.")


In [None]:
# ===============================
# FID SCORE EVALUATION
# ===============================

# Quick check of the FID helper on the first 100 real and generated samples
print("Testing simplified FID calculation...")

test_real = X_train[:100]
test_generated = memory_generated_data[:100]

print(f"Test real data shape: {test_real.shape}")
print(f"Test generated data shape: {test_generated.shape}")

# calculate_fid_score is not defined in this cell; a None result is handled
# as a failure below
fid_score = calculate_fid_score(
    real_data=test_real,
    fake_data=test_generated,
    device=device,
    sample_rate=1000,
)

if fid_score is None:
    print("❌ FID calculation failed. Please check the error messages above.")
else:
    print(f"\n🎉 SUCCESS! FID Score: {fid_score:.4f}")

    # Map the score to a label: first bucket whose upper bound exceeds the
    # score wins; anything at or above 100 is "Very Poor"
    quality = next(
        (
            grade
            for upper_bound, grade in (
                (10, "Excellent"),
                (25, "Good"),
                (50, "Fair"),
                (100, "Poor"),
            )
            if fid_score < upper_bound
        ),
        "Very Poor",
    )

    print(f"Quality Assessment: {quality}")

In [None]:
# Cross-validation experiment over the normal training data, the faulty data
# and the generated samples; the helper is defined outside this cell.
run_comprehensive_cross_validation_experiment(
    X_train,
    faulty_data,
    device,
    memory_generated_data,
    epochs=200,
    batch_size=32,
)