In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import sys
sys.path.append("../")
from ad_utils import *
from sklearn.model_selection import train_test_split


# Device selection: keep the original preference for the second GPU, but fall
# back to the first GPU and then to the CPU so the notebook also runs on hosts
# with fewer than two CUDA devices.
cuda0 = torch.device("cuda:0")
cuda1 = torch.device("cuda:1")
if torch.cuda.device_count() > 1:
    device = cuda1
elif torch.cuda.is_available():
    device = cuda0
else:
    device = torch.device("cpu")
print(torch.cuda.get_device_name(device) if torch.cuda.is_available() else "No GPU available")

# NOTE(review): allow_pickle=True unpickles arbitrary objects while loading;
# only use it with .npy files from a trusted source.
data = np.load("../../hvcm/RFQ.npy", allow_pickle=True)
label = np.load("../../hvcm/RFQ_labels.npy", allow_pickle=True)
label = label[:, 1]  # Assuming the second column is the label
label = (label == "Fault").astype(int)  # Convert to binary labels
print(data.shape, label.shape)

# Standardise every feature (last axis) using statistics pooled over all
# samples and time steps, then restore the original array shape.
# NOTE(review): the scaler is fit before the train/test split and on both
# classes, so held-out and faulty samples influence the statistics - confirm
# this leakage is acceptable for the evaluation.
scaler = StandardScaler()
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)

# Split by class first so that each class gets its own 80/20 split.
normal_data = data[label == 0]
faulty_data = data[label == 1]

normal_label = label[label == 0]
faulty_label = label[label == 1]

X_train_normal, X_test_normal, y_train_normal, y_test_normal = train_test_split(normal_data, normal_label, test_size=0.2, random_state=42, shuffle=True)
X_train_faulty, X_test_faulty, y_train_faulty, y_test_faulty = train_test_split(faulty_data, faulty_label, test_size=0.2, random_state=42, shuffle=True)

# TimeGAN

In [None]:
import torch
import torch.nn as nn

class Embedder(nn.Module):
    """Encode raw sequences into latent sequences.

    Pipeline: Conv1d over time (kernel 3) -> bidirectional LSTM ->
    multi-head self-attention -> linear projection back to ``hidden_dim``.

    Args:
        input_dim: number of features per time step.
        hidden_dim: width of the latent representation. The attention layer
            uses ``hidden_dim * 2`` channels split over 4 heads.
        num_layers: number of stacked LSTM layers (inter-layer dropout of 0.2
            is enabled only when there is more than one layer).
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.input_dim, self.hidden_dim, self.num_layers = input_dim, hidden_dim, num_layers

        conv_channels = hidden_dim // 2
        recurrent_width = hidden_dim * 2  # forward + backward LSTM states

        # Local pattern extraction; padding=1 keeps the sequence length.
        self.conv1d = nn.Conv1d(input_dim, conv_channels, kernel_size=3, padding=1)

        # Temporal modelling in both directions.
        self.lstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2 if num_layers > 1 else 0,
        )

        # Self-attention across time steps of the LSTM output.
        self.attention = nn.MultiheadAttention(
            embed_dim=recurrent_width,
            num_heads=4,
            dropout=0.1,
            batch_first=True,
        )

        # Collapse the bidirectional width to the latent width.
        self.projection = nn.Sequential(
            nn.Linear(recurrent_width, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(0.2),
        )

    def forward(self, X):
        """Return the latent sequence for ``X``.

        Args:
            X: tensor unpacked as ``(batch, time, features)``; any other rank
                fails at the shape unpacking below.

        Returns:
            Tensor with ``hidden_dim`` as its last dimension.
        """
        _batch, _steps, _features = X.shape  # rejects inputs that are not 3-D

        # Conv1d wants channels before time, so swap the last two axes around it.
        local_patterns = self.conv1d(X.permute(0, 2, 1)).permute(0, 2, 1)
        recurrent = self.lstm(local_patterns)[0]
        attended = self.attention(recurrent, recurrent, recurrent)[0]
        return self.projection(attended)

class Recovery(nn.Module):
    """Decode latent sequences back to feature space.

    Pipeline: bidirectional LSTM -> two Linear/LayerNorm/LeakyReLU/Dropout
    stages (``2 * hidden_dim -> hidden_dim -> hidden_dim // 2``) -> linear
    output layer, optionally squashed with tanh.

    The earlier version carried a "residual" branch guarded by a shape-equality
    check. Every stage changes the feature width, so that branch could never
    execute; it has been removed without changing the computed output.

    Args:
        hidden_dim: width of the latent representation fed to the LSTM.
        output_dim: number of features per reconstructed time step.
        num_layers: number of stacked LSTM layers (inter-layer dropout of 0.2
            is enabled only when there is more than one layer).
        bounded_output: when True (default, the original behaviour) the output
            goes through tanh and is confined to [-1, 1]. Pass False to get an
            unbounded linear output, e.g. when targets are z-scored and can
            exceed that range. The state-dict layout is the same either way.
    """

    def __init__(self, hidden_dim, output_dim, num_layers, bounded_output=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        # Bidirectional LSTM over the latent sequence.
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.2 if num_layers > 1 else 0
        )

        # Progressive width reduction: 2h -> h -> h // 2.
        self.recovery_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.1)
            ),
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.1)
            )
        ])

        # Final reconstruction layer. nn.Identity keeps the Sequential indices
        # (and therefore checkpoint keys) identical when tanh is disabled.
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Tanh() if bounded_output else nn.Identity()
        )

    def forward(self, H):
        """Reconstruct feature sequences from latent sequences.

        Args:
            H: latent tensor whose last dimension is ``hidden_dim``.

        Returns:
            Tensor with ``output_dim`` as its last dimension.
        """
        x, _ = self.lstm(H)
        for layer in self.recovery_layers:
            x = layer(x)
        return self.output_layer(x)

class Generator(nn.Module):
    """Turn noise sequences into synthetic latent sequences.

    Pipeline: per-step linear lift of the noise -> bidirectional LSTM ->
    Conv1d over time (kernel 3, length-preserving) -> two-layer MLP.

    Args:
        z_dim: number of noise features per time step.
        hidden_dim: width of the produced latent representation.
        num_layers: number of stacked LSTM layers (inter-layer dropout of 0.2
            is enabled only when there is more than one layer).
    """

    def __init__(self, z_dim, hidden_dim, num_layers):
        super().__init__()
        self.z_dim, self.hidden_dim, self.num_layers = z_dim, hidden_dim, num_layers

        # Lift the noise into the latent width.
        self.noise_transform = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(0.2),
        )

        # Sequence modelling in both directions.
        inter_layer_dropout = 0.2 if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )

        # Mixes neighbouring time steps and halves the bidirectional width.
        self.temporal_conv = nn.Conv1d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1)

        # Final refinement MLP (no activation on the output).
        self.generation_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, Z):
        """Map a noise tensor (last dimension ``z_dim``) to a latent sequence.

        Returns:
            Tensor with ``hidden_dim`` as its last dimension.
        """
        lifted = self.noise_transform(Z)
        recurrent = self.lstm(lifted)[0]
        # Conv1d wants channels before time, so swap the last two axes around it.
        smoothed = self.temporal_conv(recurrent.permute(0, 2, 1)).permute(0, 2, 1)
        return self.generation_layers(smoothed)

class Supervisor(nn.Module):
    """Sequence-to-sequence predictor operating in latent space.

    Pipeline: unidirectional LSTM -> 2-head self-attention -> two-layer MLP.
    Input and output both have ``hidden_dim`` features per time step.

    Args:
        hidden_dim: width of the latent representation.
        num_layers: number of stacked LSTM layers (inter-layer dropout of 0.2
            is enabled only when there is more than one layer).
    """

    def __init__(self, hidden_dim, num_layers):
        super().__init__()
        self.hidden_dim, self.num_layers = hidden_dim, num_layers

        inter_layer_dropout = 0.2 if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        # Self-attention across time steps of the LSTM output.
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=2,
            dropout=0.1,
            batch_first=True,
        )

        # Refinement MLP (no activation on the output).
        self.prediction_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, H):
        """Return the supervisor's latent prediction for the sequence ``H``."""
        recurrent = self.lstm(H)[0]
        attended = self.temporal_attention(recurrent, recurrent, recurrent)[0]
        return self.prediction_head(attended)

class Discriminator(nn.Module):
    """Per-time-step real/fake critic operating in latent space.

    Pipeline: three parallel Conv1d branches (kernels 3, 5, 7, all
    length-preserving) -> concatenation -> LSTM -> 2-head self-attention ->
    MLP producing one raw score (logit) per time step.

    Args:
        hidden_dim: width of the latent representation being judged.
        num_layers: the LSTM uses ``num_layers - 1`` stacked layers, but never
            fewer than one.
    """

    def __init__(self, hidden_dim, num_layers):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Multi-scale temporal analysis; odd kernels with padding k // 2 keep
        # the sequence length unchanged.
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=k, padding=k//2)
            for k in [3, 5, 7]  # Different temporal scales
        ])

        # Inter-layer dropout must follow the LSTM's own depth, not the
        # constructor argument: with num_layers == 2 the LSTM has one layer,
        # and requesting dropout there only triggers a PyTorch warning because
        # dropout is applied between layers.
        lstm_layers = num_layers - 1 if num_layers > 1 else 1
        self.lstm = nn.LSTM(
            input_size=hidden_dim * 3,  # Concatenated multi-scale features
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=0.2 if lstm_layers > 1 else 0
        )

        # Self-attention across time steps of the LSTM output.
        self.feature_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=2,
            dropout=0.1,
            batch_first=True
        )

        # Progressive reduction to a single score per time step.
        self.discriminator_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 4, 1)
        )

    def forward(self, H):
        """Score a latent sequence.

        Args:
            H: latent tensor whose last dimension is ``hidden_dim``.

        Returns:
            Tensor with one unnormalised score per time step (last dim 1).
        """
        # Conv1d wants channels before time.
        channels_first = H.transpose(1, 2)
        branches = [torch.relu(conv(channels_first)) for conv in self.conv_layers]

        # Stack the three scales along the channel axis, then go back to
        # time-major layout for the LSTM.
        H_multi = torch.cat(branches, dim=1).transpose(1, 2)

        H_lstm, _ = self.lstm(H_multi)
        H_att, _ = self.feature_attention(H_lstm, H_lstm, H_lstm)
        return self.discriminator_head(H_att)

In [None]:
# Chunking helpers: split long sequences into fixed-size, overlapping windows
def chunk_sequences(data, chunk_size=100, overlap=10):
    """
    Split long sequences into fixed-size windows.

    Windows start every ``chunk_size - overlap`` steps; a trailing remainder
    shorter than ``chunk_size`` is dropped. Output is ordered sample by sample,
    and within a sample by increasing start position.

    Args:
        data: array of shape [n_samples, seq_len, features]
        chunk_size: length of each window (must be positive)
        overlap: number of steps shared by consecutive windows (must be
            smaller than ``chunk_size``)

    Returns:
        chunked_data: shape [n_chunks, chunk_size, features]; ``n_chunks`` is 0
        when the sequences are shorter than ``chunk_size``.

    Raises:
        ValueError: if ``data`` is not 3-D, ``chunk_size`` is not positive, or
            ``overlap >= chunk_size`` (the window would never advance).
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(f"data must be 3-D [n_samples, seq_len, features], got shape {data.shape}")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})")

    n_samples, seq_len, n_features = data.shape
    starts = range(0, seq_len - chunk_size + 1, chunk_size - overlap)
    if n_samples == 0 or len(starts) == 0:
        return np.empty((0, chunk_size, n_features), dtype=data.dtype)

    # Slice all samples at once per start position, then flatten the
    # (sample, window) axes; this keeps the sample-major ordering.
    windows = np.stack([data[:, start:start + chunk_size] for start in starts], axis=1)
    return windows.reshape(-1, chunk_size, n_features)

# Same windowing with larger defaults (longer context, more overlap)
def chunk_sequences_enhanced(data, chunk_size=150, overlap=20):
    """
    Chunk sequences with larger default windows and overlap.

    Thin wrapper around ``chunk_sequences``; only the defaults differ.

    Args:
        data: array of shape [n_samples, seq_len, features]
        chunk_size: length of each window
        overlap: number of steps shared by consecutive windows

    Returns:
        chunked_data: shape [n_chunks, chunk_size, features]
    """
    return chunk_sequences(data, chunk_size=chunk_size, overlap=overlap)

# Baseline loss functions (relative-error reconstruction, BCE-with-logits adversarial terms)
def embedding_loss(X, X_tilde):
    """
    Mean relative L1 reconstruction error.

    Each absolute error is divided by ``|X| + 1e-6``, so elements of ``X``
    close to zero contribute very large terms.
    """
    abs_error = (X - X_tilde).abs()
    scale = X.abs() + 1e-6
    return (abs_error / scale).mean()


def supervised_loss(H, H_hat_supervise):
    """
    Mean absolute error between ``H`` at step t+1 and ``H_hat_supervise`` at
    step t. Returns a zero tensor on ``H``'s device when the sequence has a
    single step (there is no "next step" to compare against).
    """
    if H.size(1) <= 1:
        return torch.tensor(0.0, device=H.device)
    next_step_target = H[:, 1:, :]
    prediction = H_hat_supervise[:, :-1, :]
    return (next_step_target - prediction).abs().mean()

def discriminator_loss(Y_real, Y_fake):
    """
    Binary cross-entropy on raw scores (logits): real scores are pushed
    towards label 1, fake scores towards label 0; the two means are summed.
    """
    bce_with_logits = nn.functional.binary_cross_entropy_with_logits
    loss_on_real = bce_with_logits(Y_real, torch.ones_like(Y_real))
    loss_on_fake = bce_with_logits(Y_fake, torch.zeros_like(Y_fake))
    return loss_on_real + loss_on_fake

def generator_loss(Y_fake, H, H_hat_supervise, X, X_hat, lambda_sup=1.0, lambda_recon=0.01):
    """
    Weighted sum of three generator objectives.

    Args:
        Y_fake: discriminator logits for generated samples.
        H, H_hat_supervise: inputs forwarded to ``supervised_loss``.
        X, X_hat: reference and generated data for the relative L1 term.
        lambda_sup: weight of the supervised term.
        lambda_recon: weight of the reconstruction term (small by default).

    Returns:
        Scalar tensor ``adv + lambda_sup * sup + lambda_recon * recon``.
    """
    # Adversarial term: the generator wants its samples labelled as real (1).
    adversarial = nn.BCEWithLogitsLoss()(Y_fake, torch.ones_like(Y_fake))

    # Next-step consistency in latent space.
    supervised = supervised_loss(H, H_hat_supervise)

    # Relative L1 error; the 1e-6 only guards against division by zero.
    reconstruction = ((X - X_hat).abs() / (X.abs() + 1e-6)).mean()

    return adversarial + lambda_sup * supervised + lambda_recon * reconstruction

# Enhanced loss functions for anomaly detection
def _feature_correlation(flat, eps=1e-12):
    """Pearson correlation between the columns of a (rows, features) matrix.

    Unlike ``torch.corrcoef`` this never produces NaN: the variance is clamped
    to ``eps`` before the square root, so a constant column gets correlation 0
    with everything (itself included) and gradients stay finite.
    """
    centered = flat - flat.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / max(flat.shape[0] - 1, 1)
    std = torch.sqrt(torch.diagonal(cov).clamp_min(eps))
    corr = cov / (std.unsqueeze(0) * std.unsqueeze(1))
    # Same [-1, 1] clipping that torch.corrcoef applies.
    return corr.clamp(-1.0, 1.0)


def embedding_loss_enhanced(X, X_tilde):
    """
    Multi-objective reconstruction loss.

    Combines, with fixed weights 0.4 / 0.3 / 0.2 / 0.1:
      * mean absolute error,
      * mean squared error,
      * mean magnitude of the difference between the FFTs along dim 1,
      * mean squared difference between the feature-correlation matrices
        (features = last dimension, rows = all other dimensions flattened).

    The correlation term is computed with a variance clamp so that a constant
    feature in either tensor yields a finite loss and finite gradients instead
    of NaN.

    Args:
        X: reference tensor; dim 1 is treated as time, the last dim as features.
        X_tilde: reconstruction with the same shape as ``X``.

    Returns:
        Scalar tensor.
    """
    # Reconstruction loss (L1 + L2 combination)
    l1_loss = torch.mean(torch.abs(X - X_tilde))
    l2_loss = torch.mean((X - X_tilde) ** 2)

    # Frequency domain loss for temporal patterns
    X_fft = torch.fft.fft(X, dim=1)
    X_tilde_fft = torch.fft.fft(X_tilde, dim=1)
    freq_loss = torch.mean(torch.abs(X_fft - X_tilde_fft))

    # Feature correlation preservation (NaN-safe)
    X_corr = _feature_correlation(X.reshape(-1, X.shape[-1]))
    X_tilde_corr = _feature_correlation(X_tilde.reshape(-1, X_tilde.shape[-1]))
    corr_loss = torch.mean((X_corr - X_tilde_corr) ** 2)

    total_loss = 0.4 * l1_loss + 0.3 * l2_loss + 0.2 * freq_loss + 0.1 * corr_loss
    return total_loss

def supervised_loss_enhanced(H, H_hat_supervise):
    """
    Next-step L1 loss plus a step-difference penalty.

    The base term compares ``H`` at step t+1 with ``H_hat_supervise`` at step
    t. The second term (weight 0.1) compares the step-to-step changes of the
    two sequences. Returns a zero tensor on ``H``'s device for sequences with a
    single step.
    """
    if H.size(1) > 1:
        target_next = H[:, 1:, :]
        base = (target_next - H_hat_supervise[:, :-1, :]).abs().mean()

        # First differences along time for both sequences.
        real_step = target_next - H[:, :-1, :]
        predicted_step = H_hat_supervise[:, 1:, :] - H_hat_supervise[:, :-1, :]
        smoothness = (real_step - predicted_step).abs().mean()

        return base + 0.1 * smoothness
    return torch.tensor(0.0, device=H.device)

def discriminator_loss_enhanced(Y_real, Y_fake):
    """
    Least-squares discriminator loss.

    Real scores are regressed towards 1 and fake scores towards 0; the result
    is the average of the two mean squared errors. No gradient penalty is
    applied here.
    """
    real_term = (Y_real - 1).pow(2).mean()
    fake_term = Y_fake.pow(2).mean()
    return (real_term + fake_term) / 2

def generator_loss_enhanced(Y_fake, H, H_hat_supervise, X, X_hat, 
                          lambda_sup=2.0, lambda_recon=0.1, lambda_diversity=0.05):
    """
    Weighted generator objective with a per-term breakdown.

    Terms:
      * adversarial - least-squares distance of ``Y_fake`` from 1;
      * supervised - ``supervised_loss_enhanced(H, H_hat_supervise)``;
      * reconstruction - ``embedding_loss_enhanced(X, X_hat)``;
      * diversity - ``exp(-mean pairwise L2 distance)`` between the flattened
        rows of ``H_hat_supervise``; zero when the batch has a single sample.

    Returns:
        ``(total, details)`` where ``total`` is a scalar tensor and ``details``
        maps 'adv', 'sup', 'recon', 'div' to Python floats of the unweighted
        terms.
    """
    adv_term = torch.mean((Y_fake - 1) ** 2)
    sup_term = supervised_loss_enhanced(H, H_hat_supervise)
    recon_term = embedding_loss_enhanced(X, X_hat)

    n_in_batch = Y_fake.shape[0]
    if n_in_batch <= 1:
        # Pairwise distances are undefined for a single sample.
        div_term = torch.tensor(0.0, device=Y_fake.device)
    else:
        # Small when samples are far apart, close to 1 when they collapse.
        flattened = H_hat_supervise.reshape(n_in_batch, -1)
        div_term = torch.exp(-torch.pdist(flattened, p=2).mean())

    weighted_sum = (adv_term +
                    lambda_sup * sup_term +
                    lambda_recon * recon_term +
                    lambda_diversity * div_term)

    breakdown = {
        'adv': adv_term.item(),
        'sup': sup_term.item(),
        'recon': recon_term.item(),
        'div': div_term.item(),
    }
    return weighted_sum, breakdown

# Quality assessment for generated samples
def assess_generation_quality(real_data, synthetic_data):
    """
    Compare simple statistics of real and synthetic sequences.

    Per-feature means and standard deviations are taken over the first two
    axes; the feature-correlation matrix is computed after flattening all axes
    except the last. Correlations that are undefined because a feature is
    constant are treated as 0, so a collapsed generator is penalised instead of
    turning every metric into NaN.

    Args:
        real_data: array shaped [n, time, features].
        synthetic_data: array with the same number of features.

    Returns:
        dict with 'mean_difference', 'std_difference',
        'correlation_difference' (mean absolute differences) and
        'quality_score' = 1 / (1 + sum of the three differences).
    """
    real_data = np.asarray(real_data)
    synthetic_data = np.asarray(synthetic_data)

    def _feature_corr(values):
        flat = values.reshape(-1, values.shape[-1])
        # A zero-variance feature makes corrcoef divide by zero; silence the
        # warning and map the resulting NaNs to "no correlation".
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.nan_to_num(np.corrcoef(flat.T), nan=0.0)

    real_mean = np.mean(real_data, axis=(0, 1))
    synth_mean = np.mean(synthetic_data, axis=(0, 1))

    real_std = np.std(real_data, axis=(0, 1))
    synth_std = np.std(synthetic_data, axis=(0, 1))

    # Statistical similarity
    mean_diff = np.mean(np.abs(real_mean - synth_mean))
    std_diff = np.mean(np.abs(real_std - synth_std))

    # Cross-feature correlation preservation
    corr_diff = np.mean(np.abs(_feature_corr(real_data) - _feature_corr(synthetic_data)))

    return {
        'mean_difference': mean_diff,
        'std_difference': std_diff,
        'correlation_difference': corr_diff,
        'quality_score': 1.0 / (1.0 + mean_diff + std_diff + corr_diff)
    }

In [None]:
# Enhanced training function optimized for anomaly detection
def train_timegan_enhanced(data, seq_len, batch_size, model_params, train_params):
    """
    Train the five TimeGAN networks on overlapping chunks of ``data``.

    Every batch updates the embedder/recovery pair on a reconstruction loss;
    every third batch additionally updates generator+supervisor and then the
    discriminator. All models live on the module-level ``device``.

    Args:
        data: array shaped [n_samples, full_seq_len, features]; it is split
            into windows of length ``seq_len`` with an overlap of 30 steps.
        seq_len: window length used for training and for the noise sequences.
        batch_size: mini-batch size; incomplete final batches are dropped.
        model_params: dict with keys 'input_dim', 'hidden_dim', 'num_layers',
            'z_dim'.
        train_params: dict with keys 'epochs', 'learning_rate'.

    Returns:
        dict with the trained modules ('embedder', 'recovery', 'generator',
        'supervisor', 'discriminator'), 'chunk_size', 'original_seq_len' and
        'history' (per-epoch average losses plus a quality score every 10
        epochs).

    Raises:
        ValueError: if chunking yields fewer windows than ``batch_size``, in
            which case no training step could run.
    """
    # Enhanced chunking with better parameters
    chunk_size = seq_len
    print(f"Enhanced chunking sequences into size {chunk_size}...")
    chunked_data = chunk_sequences_enhanced(data, chunk_size=chunk_size, overlap=30)
    print(f"Created {len(chunked_data)} enhanced chunks from {len(data)} original sequences")

    # drop_last=True discards a short final batch, so fewer chunks than one
    # batch would mean zero training steps and a division by zero when the
    # epoch averages are computed.
    if len(chunked_data) < batch_size:
        raise ValueError(
            f"Only {len(chunked_data)} chunks available but batch_size={batch_size}; "
            "reduce batch_size or seq_len."
        )

    # Model parameters
    input_dim = model_params['input_dim']
    hidden_dim = model_params['hidden_dim']
    num_layers = model_params['num_layers']
    z_dim = model_params['z_dim']

    # Training parameters
    epochs = train_params['epochs']
    learning_rate = train_params['learning_rate']

    # Create dataset and loader
    data_tensor = torch.tensor(chunked_data, dtype=torch.float32)
    dataset = TensorDataset(data_tensor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Initialize enhanced models
    embedder = Embedder(input_dim, hidden_dim, num_layers).to(device)
    recovery = Recovery(hidden_dim, input_dim, num_layers).to(device)
    generator = Generator(z_dim, hidden_dim, num_layers).to(device)
    supervisor = Supervisor(hidden_dim, num_layers).to(device)
    discriminator = Discriminator(hidden_dim, num_layers).to(device)

    # Enhanced weight initialization
    def enhanced_weights_init(m):
        if isinstance(m, nn.LSTM):
            for name, param in m.named_parameters():
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(param.data)
                elif 'weight_hh' in name:
                    nn.init.orthogonal_(param.data)
                elif 'bias' in name:
                    nn.init.constant_(param.data, 0)
                    # Set forget gate bias to 1. This matches both bias_ih and
                    # bias_hh, so the effective forget-gate bias is 2.
                    n = param.size(0)
                    param.data[n//4:n//2].fill_(1.)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')

    # Apply enhanced initialization
    for model in [embedder, recovery, generator, supervisor, discriminator]:
        model.apply(enhanced_weights_init)

    # Optimizers: the discriminator learns 10x slower than the other networks.
    e_optimizer = optim.AdamW(
        list(embedder.parameters()) + list(recovery.parameters()), 
        lr=learning_rate, betas=(0.5, 0.999), weight_decay=1e-5
    )
    g_optimizer = optim.AdamW(
        list(generator.parameters()) + list(supervisor.parameters()), 
        lr=learning_rate, betas=(0.5, 0.999), weight_decay=1e-5
    )
    d_optimizer = optim.AdamW(
        discriminator.parameters(), 
        lr=learning_rate * 0.1, betas=(0.5, 0.999), weight_decay=1e-5
    )

    # Learning rate schedulers
    e_scheduler = optim.lr_scheduler.CosineAnnealingLR(e_optimizer, T_max=epochs)
    g_scheduler = optim.lr_scheduler.CosineAnnealingLR(g_optimizer, T_max=epochs)
    d_scheduler = optim.lr_scheduler.CosineAnnealingLR(d_optimizer, T_max=epochs)

    print('Starting Enhanced TimeGAN Training for Anomaly Detection...')
    print(f'Model Parameters: Hidden={hidden_dim}, Layers={num_layers}, Z_dim={z_dim}')

    # Training history
    history = {
        'embedding_loss': [],
        'generator_loss': [],
        'discriminator_loss': [],
        'quality_scores': []
    }

    for epoch in range(epochs):
        epoch_e_loss = 0
        epoch_g_loss = 0
        epoch_d_loss = 0
        epoch_losses_detail = {'adv': 0, 'sup': 0, 'recon': 0, 'div': 0}
        g_d_steps = 0  # generator/discriminator updates performed this epoch

        for batch_idx, (X_mb,) in enumerate(dataloader):
            X_mb = X_mb.to(device)
            batch_size_actual = X_mb.shape[0]

            # Phase 1: Enhanced Embedding Training (every iteration)
            embedder.train()
            recovery.train()

            H = embedder(X_mb)
            X_tilde = recovery(H)

            e_loss = embedding_loss_enhanced(X_mb, X_tilde)

            e_optimizer.zero_grad()
            e_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(embedder.parameters()) + list(recovery.parameters()), 
                max_norm=1.0
            )
            e_optimizer.step()

            epoch_e_loss += e_loss.item()

            # Phase 2: Enhanced Generator and Discriminator Training
            if batch_idx % 3 == 0:  # Train G and D every 3 iterations for stability

                # Generator training with enhanced loss
                generator.train()
                supervisor.train()

                Z_mb = torch.randn(batch_size_actual, seq_len, z_dim).to(device)
                H_hat = generator(Z_mb)
                H_hat_supervise = supervisor(H_hat)
                X_hat = recovery(H_hat)

                # Get embeddings from real data
                with torch.no_grad():
                    H_real = embedder(X_mb)

                # Discriminator outputs
                Y_fake = discriminator(H_hat)

                # NOTE(review): H_real / X_mb come from real data while
                # H_hat_supervise / X_hat come from random noise, so the
                # supervised and reconstruction terms compare unpaired
                # sequences element-wise - confirm this pairing is intended.
                g_loss, loss_details = generator_loss_enhanced(
                    Y_fake, H_real, H_hat_supervise, X_mb, X_hat
                )

                g_optimizer.zero_grad()
                g_loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(generator.parameters()) + list(supervisor.parameters()), 
                    max_norm=1.0
                )
                g_optimizer.step()

                epoch_g_loss += g_loss.item()
                for key in loss_details:
                    epoch_losses_detail[key] += loss_details[key]

                # Enhanced Discriminator training
                discriminator.train()

                # Generate fresh samples for discriminator
                Z_mb_d = torch.randn(batch_size_actual, seq_len, z_dim).to(device)
                with torch.no_grad():
                    H_hat_d = generator(Z_mb_d)
                    H_real_d = embedder(X_mb)

                Y_fake_d = discriminator(H_hat_d.detach())
                Y_real_d = discriminator(H_real_d)

                d_loss = discriminator_loss_enhanced(Y_real_d, Y_fake_d)

                d_optimizer.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                d_optimizer.step()

                epoch_d_loss += d_loss.item()
                g_d_steps += 1

        # Update learning rates
        e_scheduler.step()
        g_scheduler.step()
        d_scheduler.step()

        # Calculate epoch averages. G/D run on batches 0, 3, 6, ... so the
        # divisor is the counted number of updates, not num_batches // 3.
        num_batches = len(dataloader)
        g_d_batches = max(1, g_d_steps)

        avg_e_loss = epoch_e_loss / num_batches
        avg_g_loss = epoch_g_loss / g_d_batches
        avg_d_loss = epoch_d_loss / g_d_batches

        # Store history
        history['embedding_loss'].append(avg_e_loss)
        history['generator_loss'].append(avg_g_loss)
        history['discriminator_loss'].append(avg_d_loss)

        # Quality assessment every 10 epochs
        if epoch % 10 == 0:
            with torch.no_grad():
                # Generate sample for quality assessment
                Z_sample = torch.randn(min(100, len(chunked_data)), seq_len, z_dim).to(device)
                # These switch back to train mode at the start of the next
                # epoch's batch loop.
                generator.eval()
                supervisor.eval()
                recovery.eval()

                H_sample = generator(Z_sample)
                H_sample = supervisor(H_sample)
                X_sample = recovery(H_sample)

                sample_real = chunked_data[:min(100, len(chunked_data))]
                sample_synth = X_sample.cpu().numpy()

                quality = assess_generation_quality(sample_real, sample_synth)
                history['quality_scores'].append(quality['quality_score'])

        # Enhanced progress reporting
        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f'Epoch {epoch+1}/{epochs}:')
            print(f'  Embedding Loss: {avg_e_loss:.4f}')
            print(f'  Generator Loss: {avg_g_loss:.4f} (Adv: {epoch_losses_detail["adv"]/g_d_batches:.3f}, '
                  f'Sup: {epoch_losses_detail["sup"]/g_d_batches:.3f}, '
                  f'Recon: {epoch_losses_detail["recon"]/g_d_batches:.3f})')
            print(f'  Discriminator Loss: {avg_d_loss:.4f}')

            # Training stability indicators: spread of the last 10 epoch losses
            if len(history['generator_loss']) > 10:
                g_stability = np.std(history['generator_loss'][-10:])
                d_stability = np.std(history['discriminator_loss'][-10:])

                if g_stability < 0.1 and d_stability < 0.1:
                    print(f'  ✅ Training highly stable')
                elif g_stability < 0.5 and d_stability < 0.5:
                    print(f'  🔄 Training moderately stable')
                else:
                    print(f'  ⚠️  Training showing variation')

            if len(history['quality_scores']) > 0:
                print(f'  Quality Score: {history["quality_scores"][-1]:.4f}')

    print('Enhanced TimeGAN training completed!')

    return {
        'embedder': embedder,
        'recovery': recovery,
        'generator': generator,
        'supervisor': supervisor,
        'discriminator': discriminator,
        'chunk_size': chunk_size,
        'original_seq_len': data.shape[1],
        'history': history
    }

# Enhanced generation function
def generate_timegan_samples_enhanced(model, n_samples, seq_len, z_dim, batch_size=64):
    """
    Generate synthetic sequences from a trained model bundle.

    Noise is drawn from a standard normal distribution and passed through
    generator -> supervisor -> recovery with gradients disabled. The three
    modules are switched to eval mode and left there.

    Args:
        model: dict holding the 'generator', 'supervisor' and 'recovery'
            modules (e.g. the dict returned by ``train_timegan_enhanced``).
        n_samples: number of sequences to generate (must be positive).
        seq_len: number of time steps per generated sequence.
        z_dim: noise features per time step.
        batch_size: how many sequences to generate per forward pass.

    Returns:
        numpy array whose first two dimensions are ``(n_samples, seq_len)``.

    Raises:
        ValueError: if ``n_samples`` or ``batch_size`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    generator = model['generator']
    supervisor = model['supervisor']
    recovery = model['recovery']

    # Run on whatever device the generator lives on instead of relying on a
    # module-level global that may not match the model.
    first_param = next(generator.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Set models to evaluation mode
    generator.eval()
    supervisor.eval()
    recovery.eval()

    generated_samples = []

    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            current_batch_size = min(batch_size, n_samples - i)

            Z = torch.randn(current_batch_size, seq_len, z_dim).to(run_device)

            H_hat = generator(Z)
            H_hat = supervisor(H_hat)
            X_hat = recovery(H_hat)

            generated_samples.append(X_hat.cpu().numpy())

    return np.concatenate(generated_samples, axis=0)

def reconstruct_full_sequences_enhanced(chunks, original_length=4500, chunk_size=150, overlap=30):
    """
    Stitch consecutive overlapping chunks back into full-length sequences.

    Chunks are consumed in order, `chunks_needed` at a time, where
    chunks_needed = (original_length - overlap) // (chunk_size - overlap) + 1.
    Chunk j of a group is written at offset j * (chunk_size - overlap) and is
    truncated at original_length. Samples covered by two chunks are combined
    by a weighted average: the weights ramp linearly between 0.5 and 1.0 over
    the `overlap` samples at each shared edge. Trailing chunks that do not
    fill a complete group are ignored.

    Args:
        chunks: Array of shape (n_chunks, chunk_size, n_features).
        original_length: Number of time steps in each reconstructed sequence.
        chunk_size: Number of time steps in each chunk.
        overlap: Number of time steps shared by neighbouring chunks;
            must satisfy 0 <= overlap < chunk_size.

    Returns:
        Array of shape (n_sequences, original_length, n_features). When there
        are too few chunks for even one sequence, an empty 1-D array.

    Raises:
        ValueError: If overlap is outside [0, chunk_size) or original_length
            is too short to need at least one chunk.
    """
    if overlap < 0 or overlap >= chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    step_size = chunk_size - overlap
    chunks_needed = (original_length - overlap) // step_size + 1
    if chunks_needed < 1:
        raise ValueError(
            f"original_length={original_length} is too short for overlap={overlap}"
        )

    chunks = np.asarray(chunks)
    n_full_sequences = len(chunks) // chunks_needed
    full_sequences = []

    for i in range(n_full_sequences):
        start_idx = i * chunks_needed
        # n_full_sequences is a floor division, so a full group always exists
        sequence_chunks = chunks[start_idx:start_idx + chunks_needed]

        # Weighted sum of chunk values and the matching sum of weights
        reconstructed = np.zeros((original_length, sequence_chunks.shape[2]))
        weights = np.zeros(original_length)

        for j, chunk in enumerate(sequence_chunks):
            pos = j * step_size
            end_pos = min(pos + chunk_size, original_length)
            chunk_len = end_pos - pos

            if chunk_len <= 0:
                # Chunk starts at or past the end of the sequence
                continue

            weight = np.ones(chunk_len)
            # With overlap == 0 there is nothing to blend; note that
            # weight[-0:] would select the WHOLE array, so the guard is required.
            if overlap > 0:
                if j > 0:  # Not the first chunk: fade in over the shared head
                    weight[:overlap] = np.linspace(0.5, 1.0, overlap)
                if j < len(sequence_chunks) - 1:  # Not the last chunk: fade out
                    weight[-overlap:] = np.linspace(1.0, 0.5, overlap)

            reconstructed[pos:end_pos] += chunk[:chunk_len] * weight[:, np.newaxis]
            weights[pos:end_pos] += weight

        # Normalize by accumulated weights
        weights[weights == 0] = 1  # Avoid division by zero on uncovered samples
        reconstructed = reconstructed / weights[:, np.newaxis]

        full_sequences.append(reconstructed)

    return np.array(full_sequences)

# Train and generate

In [None]:
# --- TimeGAN hyperparameters (chosen for the anomaly-detection use case) ---
chunk_size = 150              # time steps per window
seq_len = chunk_size          # the networks are fed one window at a time
batch_size = 32
input_dim = data.shape[2]     # feature count, read from axis 2 of `data`
hidden_dim = 64
num_layers = 3
z_dim = input_dim * 2         # noise vector is twice as wide as the input

# Architecture settings and optimisation settings are passed separately.
model_params = dict(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    z_dim=z_dim,
)
train_params = dict(epochs=100, learning_rate=0.0005)

# Echo the configuration so the cell output documents the run.
print(
    "Enhanced TimeGAN Configuration for Anomaly Detection:",
    f"Chunk Size: {chunk_size} (better temporal context)",
    f"Hidden Dimension: {hidden_dim} (enhanced representation)",
    f"Number of Layers: {num_layers} (deeper networks)",
    f"Latent Dimension: {z_dim} (richer noise space)",
    f"Batch Size: {batch_size} (optimized for stability)",
    f"Epochs: {train_params['epochs']} (sufficient convergence)",
    sep="\n",
)

# Banner, then fit the model on normal (non-faulty) training sequences only.
print(
    "\n" + "=" * 60,
    "STARTING ENHANCED TIMEGAN TRAINING",
    "=" * 60,
    sep="\n",
)

trained_model = train_timegan_enhanced(
    X_train_normal,
    seq_len,
    batch_size,
    model_params,
    train_params,
)

# Generate enhanced synthetic data
print("\n" + "="*60)
print("GENERATING ENHANCED SYNTHETIC DATA")
print("="*60)

# One synthetic full-length sequence per real normal training sequence
num_samples = len(X_train_normal)
n_full_sequences_desired = num_samples

# Stitching geometry, defined once so the chunk count computed here and the
# reconstruction call below cannot drift apart. The target length is read
# from the real data because the synthetic sequences are compared against
# X_train_normal time step by time step further down.
chunk_overlap = 30
full_seq_len = X_train_normal.shape[1]

# Must mirror the formula used inside reconstruct_full_sequences_enhanced
step_size = chunk_size - chunk_overlap
chunks_per_sequence = (full_seq_len - chunk_overlap) // step_size + 1
n_synthetic_chunks = n_full_sequences_desired * chunks_per_sequence

print(f"Generating {n_full_sequences_desired} full sequences:")
print(f"Chunks per sequence: {chunks_per_sequence}")
print(f"Total chunks needed: {n_synthetic_chunks}")

# Generate enhanced synthetic chunks
synthetic_chunks = generate_timegan_samples_enhanced(
    trained_model, n_synthetic_chunks, seq_len, z_dim
)
print(f"Generated {synthetic_chunks.shape} enhanced synthetic chunks")

# Reconstruct full sequences with enhanced method
synthetic_full = reconstruct_full_sequences_enhanced(
    synthetic_chunks,
    original_length=full_seq_len,
    chunk_size=chunk_size,
    overlap=chunk_overlap
)

print(f"Reconstructed {synthetic_full.shape} full enhanced synthetic sequences")

# Quality Assessment
print("\n" + "="*60)
print("QUALITY ASSESSMENT")
print("="*60)

# Compare real normal sequences with the stitched synthetic ones
quality_metrics = assess_generation_quality(X_train_normal, synthetic_full)
print("Enhanced TimeGAN Quality Metrics:")
print(f"✓ Mean Difference: {quality_metrics['mean_difference']:.6f}")
print(f"✓ Std Difference: {quality_metrics['std_difference']:.6f}")
print(f"✓ Correlation Difference: {quality_metrics['correlation_difference']:.6f}")
print(f"✓ Overall Quality Score: {quality_metrics['quality_score']:.6f}")

# Plot training history
def _finish_panel(title, xlabel, ylabel):
    """Apply the title, axis labels, legend and light grid shared by all panels."""
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True, alpha=0.3)


if 'history' in trained_model:
    history = trained_model['history']

    plt.figure(figsize=(15, 10))

    # Top row (panels 1-3): one loss curve per panel
    loss_panels = (
        ('embedding_loss', 'Embedding Loss', 'blue'),
        ('generator_loss', 'Generator Loss', 'red'),
        ('discriminator_loss', 'Discriminator Loss', 'green'),
    )
    for panel_no, (loss_key, loss_name, loss_color) in enumerate(loss_panels, start=1):
        plt.subplot(2, 3, panel_no)
        plt.plot(history[loss_key], label=loss_name, color=loss_color, alpha=0.7)
        _finish_panel(f'{loss_name} Over Time', 'Epoch', 'Loss')

    # Panel 4: quality scores, drawn only when some were recorded.
    # The x positions are every 10th epoch index, cut to the number of scores.
    if len(history['quality_scores']) > 0:
        n_scores = len(history['quality_scores'])
        score_epochs = range(0, len(history['embedding_loss']), 10)[:n_scores]
        plt.subplot(2, 3, 4)
        plt.plot(score_epochs, history['quality_scores'],
                 label='Quality Score', color='purple', alpha=0.7, marker='o')
        _finish_panel('Generation Quality Over Time', 'Epoch', 'Quality Score')

    # Panel 5: real vs synthetic trace for one sample and one feature,
    # limited to at most the first 500 time steps
    sample_idx = 0
    feature_idx = 0
    time_points = range(min(500, X_train_normal.shape[1]))

    plt.subplot(2, 3, 5)
    for series, series_label in ((X_train_normal, 'Real Data'),
                                 (synthetic_full, 'Synthetic Data')):
        plt.plot(time_points, series[sample_idx, time_points, feature_idx],
                 label=series_label, alpha=0.7, linewidth=2)
    _finish_panel(f'Real vs Synthetic (Feature {feature_idx})', 'Time', 'Value')

    # Panel 6: per-feature means (averaged over samples and time), side by side
    real_means = np.mean(X_train_normal, axis=(0,1))
    synth_means = np.mean(synthetic_full, axis=(0,1))
    x_pos = np.arange(len(real_means))
    width = 0.35

    plt.subplot(2, 3, 6)
    plt.bar(x_pos - width/2, real_means, width, label='Real Data', alpha=0.7)
    plt.bar(x_pos + width/2, synth_means, width, label='Synthetic Data', alpha=0.7)
    _finish_panel('Feature Means Comparison', 'Feature Index', 'Mean Value')

    plt.tight_layout()
    plt.show()

# Closing summary of the run
print(
    "\n✅ Enhanced TimeGAN training and generation completed successfully!",
    f"📊 Generated {len(synthetic_full)} high-quality synthetic sequences",
    f"🎯 Quality Score: {quality_metrics['quality_score']:.4f}",
    "🚀 Ready for enhanced anomaly detection pipeline!",
    sep="\n",
)

In [None]:
# ===============================
# FID SCORE EVALUATION
# ===============================

print("Testing simplified FID calculation...")

# Score only the first 100 real and the first 100 synthetic sequences
test_real = X_train_normal[:100]
test_generated = synthetic_full[:100]

print(f"Test real data shape: {test_real.shape}")
print(f"Test generated data shape: {test_generated.shape}")

# calculate_fid_score comes from the project's utility module;
# a None result is treated as a failed computation below.
fid_score = calculate_fid_score(
    real_data=test_real,
    fake_data=test_generated,
    device=device,
    sample_rate=1000,
)

if fid_score is None:
    print("❌ FID calculation failed. Please check the error messages above.")
else:
    print(f"\n🎉 SUCCESS! FID Score: {fid_score:.4f}")

    # Map the score to a label: the first exclusive upper bound that the
    # score falls under wins; anything at or above 100 is "Very Poor".
    quality = next(
        (
            label
            for upper_bound, label in (
                (10, "Excellent"),
                (25, "Good"),
                (50, "Fair"),
                (100, "Poor"),
            )
            if fid_score < upper_bound
        ),
        "Very Poor",
    )

    print(f"Quality Assessment: {quality}")

In [None]:
# Cross-validation experiment on the held-out normal and faulty test splits,
# with the stitched synthetic sequences passed in alongside them.
run_comprehensive_cross_validation_experiment(
    X_test_normal,
    X_test_faulty,
    device,
    synthetic_full,
    epochs=200,
    batch_size=32,
)