In [None]:
import torch, torchaudio, torchvision.transforms as transforms, matplotlib.pyplot as plt, torch.nn as nn, torch.optim as optim, numpy as np
from torchvision.models import vgg16, VGG16_Weights
from torch.utils.data import DataLoader, TensorDataset, Dataset
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import sys
sys.path.append("../")
from ad_utils import *
import warnings
warnings.filterwarnings('ignore')

# Device: keep the original preference for cuda:1, but fall back to cuda:0 or
# CPU so the notebook also runs on single-GPU / CPU-only machines.
cuda0 = torch.device("cuda:0")
cuda1 = torch.device("cuda:1")
if torch.cuda.device_count() > 1:
    device = cuda1
elif torch.cuda.is_available():
    device = cuda0
else:
    device = torch.device("cpu")
print(torch.cuda.get_device_name(device) if torch.cuda.is_available() else "No GPU available")

# Seed NumPy and PyTorch so splits, GAN training and sampling are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): allow_pickle=True can execute code embedded in the file; only
# load these arrays from a trusted source.
data = np.load("../../hvcm/RFQ.npy", allow_pickle=True)
label = np.load("../../hvcm/RFQ_labels.npy", allow_pickle=True)
label = label[:, 1]  # Assuming the second column is the label
label = (label == "Fault").astype(int)  # Convert to binary labels
print(data.shape, label.shape)

normal_label = label[label == 0]
faulty_label = label[label == 1]

# Split the raw windows first. train_test_split's partition depends only on the
# number of rows and random_state, so the split matches the previous version.
X_train_normal, X_test_normal, y_train_normal, y_test_normal = train_test_split(
    data[label == 0], normal_label, test_size=0.2, random_state=SEED, shuffle=True)
X_train_faulty, X_test_faulty, y_train_faulty, y_test_faulty = train_test_split(
    data[label == 1], faulty_label, test_size=0.2, random_state=SEED, shuffle=True)

# Fit the per-feature scaler on TRAINING windows only. Fitting it on the full
# array before splitting leaked test-set statistics into preprocessing.
n_features = data.shape[-1]
scaler = StandardScaler()
scaler.fit(np.concatenate([X_train_normal, X_train_faulty]).reshape(-1, n_features))


def _standardize(arr):
    """Apply the train-fitted scaler along the last axis, preserving arr's shape."""
    return scaler.transform(arr.reshape(-1, n_features)).reshape(arr.shape)


X_train_normal, X_test_normal = _standardize(X_train_normal), _standardize(X_test_normal)
X_train_faulty, X_test_faulty = _standardize(X_train_faulty), _standardize(X_test_faulty)

# Keep the full-array names available (now scaled with train statistics only).
data = _standardize(data)
normal_data = data[label == 0]
faulty_data = data[label == 1]

# Few-Shot GAN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Attention mechanism for Few-Shot learning
class AttentionBlock(nn.Module):
    """Per-time-step sigmoid gate shared across all channels.

    A 1x1-conv bottleneck (in_channels -> attention_dim -> 1) followed by a
    sigmoid yields one weight in (0, 1) per time step; the input is multiplied
    by that weight, broadcast over the channel axis.

    Args:
        in_channels: number of channels of the input map.
        attention_dim: width of the hidden 1x1 projection.
    """

    def __init__(self, in_channels, attention_dim=64):
        super().__init__()
        layers = [
            nn.Conv1d(in_channels, attention_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(attention_dim, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*layers)

    def forward(self, x):
        # gate has a single channel, so it broadcasts over x's channels
        gate = self.attention(x)
        return x * gate

# Enhanced Few-Shot Generator with Attention and Residual Connections
class EnhancedFewShotGenerator(nn.Module):
    """Transposed-convolution generator for multichannel 1-D sequences.

    A latent vector is projected (Linear + BatchNorm + ReLU) to a
    (256, init_seq_len) map with init_seq_len = max(1, seq_len // 32). Five
    stride-2 ConvTranspose1d stages then upsample it x32. Blocks 1-2 end with
    an AttentionBlock gate, and the last stage maps to `channels` with tanh. If
    the upsampled length differs from `seq_len` (e.g. 140 * 32 = 4480 for 4500),
    the output is linearly interpolated to exactly `seq_len`.

    The stages are applied strictly in sequence: there are no residual or skip
    connections in this model.

    Args:
        latent_dim: size of the input noise vector z.
        channels: number of output channels.
        seq_len: exact output sequence length.

    Shapes:
        forward(z): z is (batch, latent_dim) -> (batch, channels, seq_len).
        batch must be > 1 in training mode (BatchNorm1d on the projection).
    """

    def __init__(self, latent_dim=100, channels=14, seq_len=4500):
        super(EnhancedFewShotGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.seq_len = seq_len

        # Five x2 stages upsample by 32. Clamp to >= 1: for seq_len < 32 the
        # seed map would otherwise have zero length and the conv stack fails.
        # The interpolation in forward() restores the exact length.
        self.init_seq_len = max(1, seq_len // 32)

        # Initial projection with batch normalization
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 256 * self.init_seq_len),
            nn.BatchNorm1d(256 * self.init_seq_len),
            nn.ReLU(inplace=True)
        )

        # Progressive x2 upsampling (kernel 4, stride 2, padding 1 doubles length)
        self.conv_blocks = nn.ModuleList([
            # Block 1: 256 -> 128 channels, gated by attention
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(128),
                nn.ReLU(inplace=True),
                AttentionBlock(128)
            ),
            # Block 2: 128 -> 64 channels, gated by attention
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(64),
                nn.ReLU(inplace=True),
                AttentionBlock(64)
            ),
            # Block 3: 64 -> 32 channels
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(32),
                nn.ReLU(inplace=True),
            ),
            # Block 4: 32 -> 16 channels
            nn.Sequential(
                nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(16),
                nn.ReLU(inplace=True),
            ),
            # Final block: 16 -> channels, bounded to [-1, 1]
            nn.Sequential(
                nn.ConvTranspose1d(16, channels, kernel_size=4, stride=2, padding=1),
                nn.Tanh()
            )
        ])

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 256, self.init_seq_len)

        # Sequential upsampling (no skip connections)
        for block in self.conv_blocks:
            out = block(out)

        # Ensure exact sequence length through interpolation if needed
        if out.size(2) != self.seq_len:
            out = nn.functional.interpolate(out, size=self.seq_len, mode='linear', align_corners=False)

        return out  # Shape: (batch, channels, seq_len)

# Enhanced Few-Shot Discriminator with Multi-Scale Features
class EnhancedFewShotDiscriminator(nn.Module):
    """Two-branch, spectrally-normalised 1-D discriminator.

    Two branches with receptive fields 3 and 7 each map `channels` -> 64
    features and halve the length (floor). Their outputs are concatenated to
    128 channels and fused by a width-3 conv. Two more stride-2 convs (to 256,
    then 512) follow, then global average pooling and a 512 -> 256 -> 1 MLP
    ending in a sigmoid.

    Args:
        channels: number of input channels.
        seq_len: accepted for symmetry with the generator; not used.

    Shapes:
        forward(x): x is (batch, channels, length) -> (batch, 1) in (0, 1).
    """

    def __init__(self, channels=14, seq_len=4500):
        super(EnhancedFewShotDiscriminator, self).__init__()
        sn = nn.utils.spectral_norm

        def branch(kernel_size, padding):
            # length-preserving conv at this receptive field, then x1/2 downsample
            return nn.Sequential(
                sn(nn.Conv1d(channels, 32, kernel_size=kernel_size, stride=1, padding=padding)),
                nn.LeakyReLU(0.2, inplace=True),
                sn(nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1)),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.scale1_conv = branch(3, 1)
        self.scale2_conv = branch(7, 3)

        # Mix the concatenated 64 + 64 branch features
        self.fusion = nn.Sequential(
            sn(nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
        )

        # Two more downsampling stages, pooling and the MLP head
        self.model = nn.Sequential(
            sn(nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),

            sn(nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),

            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            sn(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            sn(nn.Linear(256, 1)),
            nn.Sigmoid()
        )

    def forward(self, x):
        branches = (self.scale1_conv(x), self.scale2_conv(x))
        fused = self.fusion(torch.cat(branches, dim=1))
        return self.model(fused)

class FewShot1DDataset(Dataset):
    """Serve a (n_samples, length, channels) array as channel-first float tensors.

    Args:
        data: numpy array of shape [n_samples, 4500, 14]; its last two axes are
            swapped so each item is (channels, length), as Conv1d expects.
        labels: optional per-sample labels indexable by position.
        transform: optional callable applied to each (channels, length) sample.

    Each item is (sample, label): label is a torch.long tensor when labels were
    given, otherwise the int placeholder 0.
    """

    def __init__(self, data, labels=None, transform=None):
        channel_first = data.transpose(0, 2, 1)
        self.data = torch.tensor(channel_first, dtype=torch.float32)
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        item = self.data[idx]  # (channels, length)
        item = self.transform(item) if self.transform else item
        target = 0 if self.labels is None else torch.tensor(self.labels[idx], dtype=torch.long)
        return item, target

# Enhanced training function with few-shot learning capabilities
def train_enhanced_few_shot_gan(normal_data, device, epochs=100, batch_size=32, lr_g=1e-4, lr_d=2e-4):
    """
    Train EnhancedFewShotGenerator / EnhancedFewShotDiscriminator on normal data.

    Training recipe:
      - BCE loss with noisy label smoothing: real targets are in [0.9, 1.0) and
        fake targets in [0, 0.1).
      - Gaussian noise (std 0.05) added to real inputs.
      - Gradient-norm clipping at 0.5.
      - Adam(betas=(0.5, 0.999)) with ReduceLROnPlateau on each network's mean
        epoch loss.
      - monitor_gan_stability prints a report every 10 epochs.

    Args:
        normal_data: array of shape (n_samples, seq_len, channels), as expected
            by FewShot1DDataset.
        device: torch device to train on.
        epochs: number of passes over the data.
        batch_size: mini-batch size; must be >= 2 because the generator's first
            BatchNorm1d cannot normalise a single sample in training mode.
        lr_g, lr_d: Adam learning rates for generator / discriminator.

    Returns:
        (generator, discriminator, g_losses, d_losses); the loss lists hold one
        mean value per epoch.

    Raises:
        ValueError: if batch_size < 2 or normal_data has fewer than 2 samples.
    """
    print(f"Training Enhanced Few-Shot GAN on data shape: {normal_data.shape}")

    if batch_size < 2:
        raise ValueError(f"batch_size must be >= 2 (BatchNorm needs >1 sample per batch); got {batch_size}")

    # Data loading. A trailing batch of exactly one sample would make the
    # generator's BatchNorm1d raise in training mode, so drop only that batch.
    dataset = FewShot1DDataset(normal_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                            drop_last=(len(dataset) % batch_size == 1))
    if len(dataloader) == 0:
        raise ValueError(f"Need at least 2 training samples; got {len(dataset)}")

    # Model shapes come from the (n_samples, seq_len, channels) input
    latent_dim = 100
    num_channels = normal_data.shape[-1]
    seq_length = normal_data.shape[1]

    generator = EnhancedFewShotGenerator(latent_dim=latent_dim, channels=num_channels, seq_len=seq_length).to(device)
    discriminator = EnhancedFewShotDiscriminator(channels=num_channels, seq_len=seq_length).to(device)

    # Small-gain Xavier init for conv/linear weights; N(1, 0.02) scale for BatchNorm
    def init_weights(m):
        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
            nn.init.xavier_normal_(m.weight, gain=0.02)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)

    generator.apply(init_weights)
    discriminator.apply(init_weights)

    # Optimizers with different learning rates for stability
    optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))

    # Learning rate schedulers driven by the mean epoch loss of each network
    scheduler_G = optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, patience=20, factor=0.8)
    scheduler_D = optim.lr_scheduler.ReduceLROnPlateau(optimizer_D, patience=20, factor=0.8)

    # BCE with randomised label smoothing
    def adversarial_loss_smooth(pred, target_is_real, smoothing=0.1):
        if target_is_real:
            target = torch.ones_like(pred) * (1.0 - smoothing) + smoothing * torch.rand_like(pred)
        else:
            target = torch.zeros_like(pred) + smoothing * torch.rand_like(pred)
        return nn.BCELoss()(pred, target)

    g_losses, d_losses = [], []

    print("Starting Enhanced Few-Shot GAN training...")
    print(f"Generator LR: {lr_g}, Discriminator LR: {lr_d}")

    for epoch in range(epochs):
        epoch_g_loss, epoch_d_loss = 0, 0

        for i, (real_samples, _) in enumerate(dataloader):
            real_samples = real_samples.to(device)  # Shape: (batch, channels, seq_len)
            batch_size_actual = real_samples.size(0)

            # Add noise to real samples for robustness
            noisy_real = real_samples + 0.05 * torch.randn_like(real_samples)

            # ========================
            # Train Discriminator
            # ========================
            optimizer_D.zero_grad()

            real_pred = discriminator(noisy_real)
            d_real_loss = adversarial_loss_smooth(real_pred, True)

            # Fakes are only D inputs here, so skip building the generator graph
            z = torch.randn(batch_size_actual, latent_dim, device=device)
            with torch.no_grad():
                fake_samples = generator(z)
            fake_pred = discriminator(fake_samples)
            d_fake_loss = adversarial_loss_smooth(fake_pred, False)

            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()

            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 0.5)
            optimizer_D.step()

            # ========================
            # Train Generator
            # ========================
            optimizer_G.zero_grad()

            z = torch.randn(batch_size_actual, latent_dim, device=device)
            fake_samples = generator(z)
            fake_pred = discriminator(fake_samples)

            # Generator wants the discriminator to classify fakes as real
            g_loss = adversarial_loss_smooth(fake_pred, True)
            g_loss.backward()

            torch.nn.utils.clip_grad_norm_(generator.parameters(), 0.5)
            optimizer_G.step()

            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()

            if i % 50 == 0:
                print(f"[Epoch {epoch+1}/{epochs}] [Batch {i}/{len(dataloader)}] "
                      f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        # Store average losses per epoch
        avg_g_loss = epoch_g_loss / len(dataloader)
        avg_d_loss = epoch_d_loss / len(dataloader)
        g_losses.append(avg_g_loss)
        d_losses.append(avg_d_loss)

        scheduler_G.step(avg_g_loss)
        scheduler_D.step(avg_d_loss)

        if epoch % 10 == 0:
            monitor_gan_stability(g_losses, d_losses, window=10)

    print("Enhanced Few-Shot GAN training completed!")
    return generator, discriminator, g_losses, d_losses

def generate_samples(generator, num_samples, latent_dim=100, batch_size=16):
    """
    Draw `num_samples` synthetic sequences from a trained generator.

    The generator runs in eval mode (BatchNorm uses running statistics) under
    torch.no_grad(), in chunks of `batch_size`. Its previous train/eval mode is
    restored afterwards, so a later training run is not silently left in eval
    mode.

    Args:
        generator: module mapping (batch, latent_dim) noise to a 3-D
            (batch, channels, length) tensor.
        num_samples: number of samples to generate; must be >= 1.
        latent_dim: noise size; must match the generator's input.
        batch_size: samples per forward pass (bounds peak memory).

    Returns:
        numpy array of shape (num_samples, length, channels),
        e.g. (num_samples, 4500, 14).

    Raises:
        ValueError: if num_samples < 1 or batch_size < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1; got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1; got {batch_size}")

    device = next(generator.parameters()).device
    was_training = generator.training
    generator.eval()

    all_samples = []
    try:
        with torch.no_grad():
            for start in range(0, num_samples, batch_size):
                current_batch_size = min(batch_size, num_samples - start)
                z = torch.randn(current_batch_size, latent_dim, device=device)
                all_samples.append(generator(z).cpu())  # (batch, channels, length)
    finally:
        # Restore the caller's mode even if generation fails
        generator.train(was_training)

    generated_data = torch.cat(all_samples, dim=0).numpy()

    # Transpose back to (n_samples, length, channels)
    return generated_data.transpose(0, 2, 1)

# Enhanced monitoring with more detailed analysis
def monitor_gan_stability(g_losses, d_losses, window=10):
    """
    Print a short health report for the last `window` entries of GAN losses.

    Reports the G/D mean-loss ratio and each network's loss variance. Once
    2 * window entries exist, it also reports the change in mean loss against
    the preceding window. It then prints at most one warning, checked in this
    order: ratio > 5, ratio < 0.2, either variance > 1.0, either |trend| > 0.5.
    If none applies, it prints a "stable" message. With fewer than `window`
    entries it prints nothing and returns None.
    """
    if len(g_losses) < window:
        return

    g_recent = g_losses[-window:]
    d_recent = d_losses[-window:]
    g_mean = np.mean(g_recent)
    d_mean = np.mean(d_recent)

    ratio = g_mean / (d_mean + 1e-8)  # roughly 1 when balanced
    g_var = np.var(g_recent)
    d_var = np.var(d_recent)

    g_trend = d_trend = 0
    if len(g_losses) >= 2 * window:
        g_trend = g_mean - np.mean(g_losses[-2 * window:-window])
        d_trend = d_mean - np.mean(d_losses[-2 * window:-window])

    print(f"G/D Ratio: {ratio:.3f} | G_var: {g_var:.4f} | D_var: {d_var:.4f}")
    print(f"G_trend: {g_trend:+.4f} | D_trend: {d_trend:+.4f}")

    # (condition, message lines) in priority order; first match wins
    diagnostics = [
        (ratio > 5, ("⚠️  Generator significantly overpowering Discriminator",
                     "   💡 Consider: Lower G learning rate or train D more frequently")),
        (ratio < 0.2, ("⚠️  Discriminator significantly overpowering Generator",
                       "   💡 Consider: Lower D learning rate or add noise to real data")),
        (g_var > 1.0 or d_var > 1.0, ("⚠️  High variance - unstable training detected",
                                      "   💡 Consider: Lower learning rates or gradient clipping")),
        (abs(g_trend) > 0.5 or abs(d_trend) > 0.5, ("⚠️  Significant loss trends detected",
                                                    "   💡 Consider: Learning rate scheduling or early stopping")),
    ]
    for triggered, lines in diagnostics:
        if triggered:
            for line in lines:
                print(line)
            break
    else:
        print("✅ Training appears stable and balanced")

# Enhanced visualization
def plot_enhanced_training_curves(g_losses, d_losses):
    """
    Show a 2x2 dashboard of per-epoch GAN losses.

    Panels:
      - Top-left: raw generator and discriminator losses.
      - Top-right: G/D loss ratio with a y=1 reference line.
      - Bottom-left and bottom-right (only when at least 10 epochs exist):
        10-epoch rolling mean and rolling variance.

    With fewer than 10 epochs the bottom panels stay empty.
    """
    window = 10
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_loss, ax_ratio = axes[0]
    ax_ma, ax_var = axes[1]

    def finish(ax, title, ylabel):
        # Shared decoration for every panel
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    ax_loss.plot(g_losses, label='Generator Loss', alpha=0.7)
    ax_loss.plot(d_losses, label='Discriminator Loss', alpha=0.7)
    finish(ax_loss, 'Training Losses', 'Loss')

    ratios = [g / (d + 1e-8) for g, d in zip(g_losses, d_losses)]
    ax_ratio.plot(ratios, color='green', alpha=0.7)
    ax_ratio.axhline(y=1, color='red', linestyle='--', alpha=0.5, label='Ideal Ratio')
    finish(ax_ratio, 'Generator/Discriminator Loss Ratio', 'G_loss / D_loss')

    if len(g_losses) >= window:
        g_roll = pd.Series(g_losses).rolling(window=window)
        d_roll = pd.Series(d_losses).rolling(window=window)

        ax_ma.plot(g_roll.mean(), label=f'Generator MA({window})', alpha=0.7)
        ax_ma.plot(d_roll.mean(), label=f'Discriminator MA({window})', alpha=0.7)
        finish(ax_ma, 'Moving Average Losses', 'Loss')

        ax_var.plot(g_roll.var(), label=f'Generator Var({window})', alpha=0.7)
        ax_var.plot(d_roll.var(), label=f'Discriminator Var({window})', alpha=0.7)
        finish(ax_var, 'Loss Variance (Stability Indicator)', 'Variance')

    plt.tight_layout()
    plt.show()

# Few-Shot GAN Training

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention1D(nn.Module):
    """Self-attention over the length axis of a (batch, channels, length) map.

    Query and key are 1x1 projections to max(1, channels // 8) channels; value
    keeps all channels. The (length x length) score map is softmax-normalised
    over keys. The attended values are blended into the input as
    `gamma * attended + x`, with gamma initialised to 0 so the module starts
    as the identity.

    Note: the attention map costs batch * length**2 memory, which dominates on
    long sequences.
    """
    def __init__(self, in_channels):
        super(SelfAttention1D, self).__init__()
        self.in_channels = in_channels
        # With in_channels < 8, `in_channels // 8` was 0: empty query/key
        # projections give all-zero scores and a uniform (non-adaptive)
        # attention map. Keep at least one projection channel.
        qk_channels = max(1, in_channels // 8)
        self.query = nn.Conv1d(in_channels, qk_channels, 1)
        self.key = nn.Conv1d(in_channels, qk_channels, 1)
        self.value = nn.Conv1d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, channels, length = x.size()

        q = self.query(x).view(batch_size, -1, length).permute(0, 2, 1)  # (B, L, C')
        k = self.key(x).view(batch_size, -1, length)                      # (B, C', L)
        v = self.value(x).view(batch_size, -1, length)                    # (B, C, L)

        # Row i holds query position i's weights over all key positions
        attention = self.softmax(torch.bmm(q, k))                         # (B, L, L)

        # out[:, :, i] = sum_j v[:, :, j] * attention[:, i, j]
        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, length)

        # Residual connection with learnable weight
        return self.gamma * out + x

class ResidualBlock1D(nn.Module):
    """Shape-preserving residual unit: relu(x + BN(conv(relu(BN(conv(x)))))).

    Both convolutions use kernel 3 / padding 1 with equal in/out channels, so
    the output has the same (batch, channels, length) shape as the input.
    """
    def __init__(self, channels):
        super(ResidualBlock1D, self).__init__()
        self.conv1 = nn.Conv1d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # identity shortcut, then the output activation
        return F.relu(hidden + x)

class ImprovedFewShotGenerator(nn.Module):
    """Generator with residual upsampling stages and self-attention.

    Pipeline:
      1. z -> Linear + BatchNorm + ReLU -> a (256, initial_length) seed map.
      2. Five stride-2 ConvTranspose1d stages, each followed by a
         ResidualBlock1D (x32 length).
      3. SelfAttention1D(64).
      4. One more stride-2 stage to 32 channels (x64 length in total), then a
         width-3 conv to `output_channels` and tanh.
      5. Centre-crop (or linear interpolation when too short) to exactly
         `target_length`.
      6. A final 1x1 conv mixes channels per time step.

    Because that last conv follows the tanh, outputs are not bounded to [-1, 1].
    With the defaults, 72 * 64 = 4608 steps are centre-cropped to 4500.

    Args:
        latent_dim: size of the input noise vector.
        output_channels: number of generated channels.
        target_length: exact output length.
        initial_length: length of the projected seed map (default 72, the
            previously hard-coded value). For other target lengths, roughly
            ceil(target_length / 64) avoids large crops or stretches.

    Note: the self-attention runs at length initial_length * 32 (2304 by
    default), i.e. on a (batch, 2304, 2304) attention map, which is the main
    memory cost.
    """
    def __init__(self, latent_dim=100, output_channels=14, target_length=4500, initial_length=72):
        super(ImprovedFewShotGenerator, self).__init__()
        if initial_length < 1:
            raise ValueError(f"initial_length must be >= 1; got {initial_length}")
        self.latent_dim = latent_dim
        self.output_channels = output_channels
        self.target_length = target_length

        # Seed map; the six x2 stages below upsample it by 64 in total
        self.initial_length = initial_length
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256 * self.initial_length),
            nn.BatchNorm1d(256 * self.initial_length),
            nn.ReLU(True)
        )

        # Five x2 upsampling stages (lengths shown for initial_length=72)
        self.upsample_blocks = nn.ModuleList([
            # 72 -> 144
            nn.Sequential(
                nn.ConvTranspose1d(256, 256, 4, stride=2, padding=1),
                nn.BatchNorm1d(256),
                nn.ReLU(True),
                ResidualBlock1D(256)
            ),
            # 144 -> 288
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 4, stride=2, padding=1),
                nn.BatchNorm1d(128),
                nn.ReLU(True),
                ResidualBlock1D(128)
            ),
            # 288 -> 576
            nn.Sequential(
                nn.ConvTranspose1d(128, 128, 4, stride=2, padding=1),
                nn.BatchNorm1d(128),
                nn.ReLU(True),
                ResidualBlock1D(128)
            ),
            # 576 -> 1152
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 4, stride=2, padding=1),
                nn.BatchNorm1d(64),
                nn.ReLU(True),
                ResidualBlock1D(64)
            ),
            # 1152 -> 2304
            nn.Sequential(
                nn.ConvTranspose1d(64, 64, 4, stride=2, padding=1),
                nn.BatchNorm1d(64),
                nn.ReLU(True),
                ResidualBlock1D(64)
            ),
        ])

        # Self-attention for long-range dependencies
        self.attention = SelfAttention1D(64)

        # Last x2 stage and projection to the output channels
        self.final_upsample = nn.Sequential(
            nn.ConvTranspose1d(64, 32, 4, stride=2, padding=1),  # 2304 -> 4608
            nn.BatchNorm1d(32),
            nn.ReLU(True),
            nn.Conv1d(32, output_channels, 3, padding=1),  # Channel adjustment
            nn.Tanh()
        )

        # 1x1 conv applied after the length fix: per-time-step channel mixing.
        # It does not change the length; the name is kept for checkpoint compatibility.
        self.length_adjuster = nn.Conv1d(output_channels, output_channels, 1)

    def forward(self, z):
        x = self.fc(z)
        x = x.view(z.size(0), 256, self.initial_length)

        for upsample_block in self.upsample_blocks:
            x = upsample_block(x)

        x = self.attention(x)
        x = self.final_upsample(x)

        # Adjust to exact target length
        current_length = x.size(-1)
        if current_length > self.target_length:
            # Centre-crop to target length
            start_idx = (current_length - self.target_length) // 2
            x = x[:, :, start_idx:start_idx + self.target_length]
        elif current_length < self.target_length:
            # Stretch rather than zero-pad, so no constant dead segments appear
            x = F.interpolate(x, size=self.target_length, mode='linear', align_corners=False)

        return self.length_adjuster(x)

class ImprovedFewShotDiscriminator(nn.Module):
    """Five stride-2 spectral-norm conv stages, self-attention and an MLP head.

    Channels go input_channels -> 32 -> 64 -> 128 -> 256 -> 512. Each stage
    halves the length (floor): 4500 -> 2250 -> 1125 -> 562 -> 281 -> 140.
    Stages 2-5 include BatchNorm, and dropout rises from 0.2 to 0.4. A
    SelfAttention1D(512) follows, then global average pooling and a
    512 -> 256 -> 1 spectrally-normalised MLP with a sigmoid output.

    forward(x, return_features=False) returns (batch, 1) probabilities. With
    return_features=True it returns (probabilities, features), where
    `features` lists the five stage outputs followed by the attention output.

    NOTE(review): BatchNorm makes each sample's score depend on its batch,
    which conflicts with a per-sample gradient penalty; confirm it is intended.
    """
    def __init__(self, input_channels=14):
        super(ImprovedFewShotDiscriminator, self).__init__()
        sn = nn.utils.spectral_norm

        # (in_channels, out_channels, use_batchnorm, dropout) for each stage
        stage_specs = [
            (input_channels, 32, False, 0.2),
            (32, 64, True, 0.2),
            (64, 128, True, 0.3),
            (128, 256, True, 0.3),
            (256, 512, True, 0.4),
        ]
        stages = []
        for c_in, c_out, use_bn, p_drop in stage_specs:
            layers = [sn(nn.Conv1d(c_in, c_out, 4, stride=2, padding=1))]
            if use_bn:
                layers.append(nn.BatchNorm1d(c_out))
            layers += [nn.LeakyReLU(0.2, inplace=True), nn.Dropout(p_drop)]
            stages.append(nn.Sequential(*layers))
        self.conv_blocks = nn.ModuleList(stages)

        # Self-attention over the downsampled feature map
        self.attention = SelfAttention1D(512)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            sn(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            sn(nn.Linear(256, 1)),
            nn.Sigmoid()
        )

    def forward(self, x, return_features=False):
        collected = []
        hidden = x
        for stage in self.conv_blocks:
            hidden = stage(hidden)
            if return_features:
                collected.append(hidden)

        hidden = self.attention(hidden)
        if return_features:
            collected.append(hidden)

        scores = self.classifier(hidden)
        return (scores, collected) if return_features else scores

def feature_matching_loss(real_features, fake_features):
    """Sum over layers of the MSE between batch-mean fake and real activations.

    Layers are paired positionally (zip stops at the shorter list). If there
    are no pairs, the int 0 is returned.
    """
    return sum(
        nn.functional.mse_loss(fake_feat.mean(0), real_feat.mean(0))
        for real_feat, fake_feat in zip(real_features, fake_features)
    )

def gradient_penalty(discriminator, real_samples, fake_samples, device):
    """WGAN-GP penalty: mean over samples of (||grad_x D(x_hat)||_2 - 1)^2.

    x_hat = alpha * real + (1 - alpha) * fake, with one alpha ~ U[0, 1) per
    sample, broadcast over all remaining dimensions. The gradient norm is
    taken over ALL non-batch dimensions of each sample. The previous
    `norm(2, dim=1)` on (batch, channels, length) gradients produced one norm
    per time step (over channels only), which is not the per-sample WGAN-GP
    constraint.

    Args:
        discriminator: callable mapping a batch to (batch, ...) scores.
        real_samples, fake_samples: tensors of identical shape (batch, ...).
            They are detached here, so the penalty never backpropagates into
            the generator.
        device: device on which alpha is sampled (should match the samples).

    Returns:
        Scalar tensor, differentiable w.r.t. the discriminator parameters
        (create_graph=True).
    """
    batch_size = real_samples.size(0)
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)

    interpolated = alpha * real_samples.detach() + (1 - alpha) * fake_samples.detach()
    interpolated.requires_grad_(True)

    d_interpolated = discriminator(interpolated)

    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]

    # One L2 norm per sample, over every non-batch dimension
    per_sample_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

def train_improved_few_shot_gan(X_train, epochs=200, batch_size=32, latent_dim=100,
                               lr_g=0.0001, lr_d=0.0002, device='cuda',
                               n_critic=2, lambda_gp=10.0, lambda_fm=10.0):
    """
    Train ImprovedFewShotGenerator / ImprovedFewShotDiscriminator on X_train.

    Per batch:
      - `n_critic` discriminator updates, each using BCE (real target 0.9,
        fake target 0) plus `lambda_gp` * gradient penalty.
      - One generator update using BCE (target 1) plus `lambda_fm` *
        feature-matching loss.
    Both networks use Adam(betas=(0.0, 0.9)) and gradient-norm clipping at 0.5.

    Args:
        X_train: array of shape (n_samples, seq_len, channels). The models'
            channel count and output length are taken from it.
        epochs: number of passes over X_train.
        batch_size: mini-batch size; must be >= 2 (the generator's first
            BatchNorm1d cannot normalise one sample in training mode).
        latent_dim: noise vector size.
        lr_g, lr_d: Adam learning rates.
        device: torch device.
        n_critic: discriminator steps per generator step (previously fixed at 2).
        lambda_gp: gradient-penalty weight (previously fixed at 10.0).
        lambda_fm: feature-matching weight (previously fixed at 10.0).

    Returns:
        (generator, discriminator, g_losses, d_losses, feature_losses, gp_losses),
        each list holding one mean per epoch. g_losses tracks only the
        adversarial part of the generator loss. d_losses tracks the BCE part
        (without the penalty) of the last critic step of each batch.

    Raises:
        ValueError: if batch_size < 2, n_critic < 1, or X_train has < 2 samples.
    """
    if batch_size < 2:
        raise ValueError(f"batch_size must be >= 2 (BatchNorm needs >1 sample per batch); got {batch_size}")
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1; got {n_critic}")

    print("🚀 Starting Enhanced Few-Shot GAN Training for Better FID Score")
    print("=" * 60)

    # Data loading. A trailing batch of exactly one sample would make the
    # generator's BatchNorm1d raise in training mode, so drop only that batch.
    dataset = FewShot1DDataset(X_train)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4,
                            drop_last=(len(dataset) % batch_size == 1))
    if len(dataloader) == 0:
        raise ValueError(f"Need at least 2 training samples; got {len(dataset)}")

    # Size the models from the data instead of relying on the 14 x 4500 defaults
    n_channels, seq_len = X_train.shape[-1], X_train.shape[1]
    generator = ImprovedFewShotGenerator(latent_dim=latent_dim, output_channels=n_channels,
                                         target_length=seq_len).to(device)
    discriminator = ImprovedFewShotDiscriminator(input_channels=n_channels).to(device)
    generator.train()
    discriminator.train()

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(0.0, 0.9))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.0, 0.9))

    adversarial_loss = nn.BCELoss()

    g_losses, d_losses = [], []
    feature_losses, gp_losses = [], []

    print(f"📊 Training Configuration:")
    print(f"   • Generator Parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"   • Discriminator Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
    print(f"   • Dataset Size: {len(dataset):,} samples")
    print(f"   • Batch Size: {batch_size}")
    print(f"   • Learning Rates: G={lr_g}, D={lr_d}")
    print("=" * 60)

    for epoch in range(epochs):
        epoch_g_loss, epoch_d_loss = 0, 0
        epoch_feature_loss, epoch_gp_loss = 0, 0

        for i, (real_samples, _) in enumerate(dataloader):
            real_samples = real_samples.to(device)
            batch_size_current = real_samples.size(0)

            # ========================
            # Train Discriminator (n_critic times per generator update)
            # ========================
            for _ in range(n_critic):
                optimizer_D.zero_grad()

                # One-sided label smoothing: real targets 0.9, fake targets stay 0
                real_labels = torch.full((batch_size_current, 1), 0.9, device=device)
                real_pred = discriminator(real_samples)
                real_loss = adversarial_loss(real_pred, real_labels)

                # Fakes are only D inputs here, so skip building the generator graph
                z = torch.randn(batch_size_current, latent_dim, device=device)
                with torch.no_grad():
                    fake_samples = generator(z)
                fake_labels = torch.zeros(batch_size_current, 1, device=device)
                fake_pred = discriminator(fake_samples)
                fake_loss = adversarial_loss(fake_pred, fake_labels)

                gp = gradient_penalty(discriminator, real_samples, fake_samples, device)

                d_loss = real_loss + fake_loss + lambda_gp * gp
                d_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 0.5)
                optimizer_D.step()

                epoch_gp_loss += gp.item()

            # ========================
            # Train Generator with Feature Matching
            # ========================
            optimizer_G.zero_grad()

            z = torch.randn(batch_size_current, latent_dim, device=device)
            fake_samples = generator(z)

            fake_pred, fake_features = discriminator(fake_samples, return_features=True)
            # Real features are fixed targets for the generator: no graph needed
            with torch.no_grad():
                _, real_features = discriminator(real_samples, return_features=True)

            valid_labels = torch.ones(batch_size_current, 1, device=device)
            adv_loss = adversarial_loss(fake_pred, valid_labels)

            fm_loss = feature_matching_loss(real_features, fake_features)

            g_loss = adv_loss + lambda_fm * fm_loss
            g_loss.backward()

            torch.nn.utils.clip_grad_norm_(generator.parameters(), 0.5)
            optimizer_G.step()

            epoch_g_loss += adv_loss.item()
            epoch_d_loss += (real_loss.item() + fake_loss.item()) / 2
            epoch_feature_loss += fm_loss.item()

            if i % 50 == 0:
                print(f"[Epoch {epoch+1:3d}/{epochs}] [Batch {i:3d}/{len(dataloader)}] "
                      f"[D: {d_loss.item():.4f}] [G: {adv_loss.item():.4f}] "
                      f"[FM: {fm_loss.item():.4f}] [GP: {gp.item():.4f}]")

        # Store average losses per epoch
        g_losses.append(epoch_g_loss / len(dataloader))
        d_losses.append(epoch_d_loss / len(dataloader))
        feature_losses.append(epoch_feature_loss / len(dataloader))
        gp_losses.append(epoch_gp_loss / (len(dataloader) * n_critic))

        if epoch % 20 == 0:
            print(f"\n{'='*15} Epoch {epoch+1} Summary {'='*15}")
            print(f"📈 Avg Generator Loss: {g_losses[-1]:.4f}")
            print(f"📉 Avg Discriminator Loss: {d_losses[-1]:.4f}")
            print(f"🎯 Avg Feature Matching Loss: {feature_losses[-1]:.4f}")
            print(f"⚖️  Avg Gradient Penalty: {gp_losses[-1]:.4f}")

            # Quick sample check. NOTE: the generator is still in train mode,
            # so its BatchNorm layers use (and update on) this probe batch.
            with torch.no_grad():
                z_test = torch.randn(32, latent_dim, device=device)
                test_samples = generator(z_test)

                sample_mean = test_samples.mean().item()
                sample_std = test_samples.std().item()
                sample_range = (test_samples.max() - test_samples.min()).item()

                print(f"📊 Generated Sample Stats:")
                print(f"   • Mean: {sample_mean:.4f}")
                print(f"   • Std: {sample_std:.4f}")
                print(f"   • Range: {sample_range:.4f}")

            print("=" * 50)

    return generator, discriminator, g_losses, d_losses, feature_losses, gp_losses

print("✅ Enhanced training function ready!")

# Train the improved Few-Shot GAN on the (scaled) normal training windows
print("🔄 Training Enhanced Few-Shot GAN with Improved Architecture...")
print("=" * 70)

# Full training run (200 epochs); lower `epochs` (e.g. ~50) for a quick comparison run
enhanced_generator, enhanced_discriminator, enh_g_losses, enh_d_losses, enh_fm_losses, enh_gp_losses = train_improved_few_shot_gan(
    X_train_normal, 
    epochs=200,  # full run; ~50 is enough for a quick comparison
    batch_size=32,
    latent_dim=100,
    # NOTE(review): 25x lr_d and 50x the function default (1e-4); confirm this is intentional
    lr_g=0.005,
    lr_d=0.0002,
    device=device
)

# Generate and Combine

In [None]:
# Draw one synthetic window per real normal training window
num_samples = X_train_normal.shape[0]
generated_data = generate_samples(enhanced_generator, num_samples)


In [None]:
# ===============================
# FID SCORE EVALUATION
# ===============================
# Compare the first 100 real training windows with the first 100 generated
# windows. calculate_fid_score is not defined in this notebook (presumably it
# comes from the ad_utils star import).

print("Testing simplified FID calculation...")

n_fid_samples = 100
test_real = X_train_normal[:n_fid_samples]
test_generated = generated_data[:n_fid_samples]

print(f"Test real data shape: {test_real.shape}")
print(f"Test generated data shape: {test_generated.shape}")

fid_score = calculate_fid_score(
    real_data=test_real,
    fake_data=test_generated,
    device=device,
    sample_rate=1000,
)

if fid_score is None:
    print("❌ FID calculation failed. Please check the error messages above.")
else:
    print(f"\n🎉 SUCCESS! FID Score: {fid_score:.4f}")

    # Heuristic buckets: (exclusive upper bound, label); anything else is "Very Poor".
    # NOTE(review): FID scale depends on the feature extractor, so these cut-offs are rough.
    fid_bands = [(10, "Excellent"), (25, "Good"), (50, "Fair"), (100, "Poor")]
    quality = next((name for bound, name in fid_bands if fid_score < bound), "Very Poor")

    print(f"Quality Assessment: {quality}")

In [None]:
# Downstream evaluation on the held-out normal / faulty test windows plus the
# GAN-generated samples. NOTE(review): run_comprehensive_cross_validation_experiment
# is not defined in this notebook (presumably from ad_utils); confirm how it uses generated_data.
run_comprehensive_cross_validation_experiment(X_test_normal, X_test_faulty, device, generated_data, epochs=200, batch_size=32)