In [1]:
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import sys
sys.path.append("../")
from ad_utils import *
import warnings
warnings.filterwarnings('ignore')

cuda0 = torch.device("cuda:0")
cuda1 = torch.device("cuda:1")
device = cuda1
print(torch.cuda.get_device_name(device) if torch.cuda.is_available() else "No GPU available")

data = np.load("../../hvcm/RFQ.npy", allow_pickle=True)
label = np.load("../../hvcm/RFQ_labels.npy", allow_pickle=True)
label = label[:, 1]  # Assuming the second column is the label
label = (label == "Fault").astype(int)  # Convert to binary labels
print(data.shape, label.shape)

scaler = StandardScaler()
data = scaler.fit_transform(data.reshape(-1, data.shape[-1])).reshape(data.shape)

normal_data = data[label == 0]
faulty_data = data[label == 1]

normal_label = label[label == 0]
faulty_label = label[label == 1]

X_train_normal, X_test_normal, y_train_normal, y_test_normal = train_test_split(normal_data, normal_label, test_size=0.2, random_state=42, shuffle=True)
X_train_faulty, X_test_faulty, y_train_faulty, y_test_faulty = train_test_split(faulty_data, faulty_label, test_size=0.2, random_state=42, shuffle=True)

NVIDIA A30
(872, 4500, 14) (872,)


# Few-Shot GAN

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Attention mechanism for Few-Shot learning
class AttentionBlock(nn.Module):
    def __init__(self, in_channels, attention_dim=64):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(in_channels, attention_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(attention_dim, 1, kernel_size=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

# Enhanced Few-Shot Generator with Attention and Residual Connections
class EnhancedFewShotGenerator(nn.Module):
    def __init__(self, latent_dim=100, channels=14, seq_len=4500):
        super(EnhancedFewShotGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.seq_len = seq_len
        
        # Start with a smaller sequence length and upsample
        self.init_seq_len = seq_len // 32  # More aggressive downsampling
        
        # Enhanced initial projection with batch normalization
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 256 * self.init_seq_len),
            nn.BatchNorm1d(256 * self.init_seq_len),
            nn.ReLU(inplace=True)
        )
        
        # Progressive upsampling with attention and residual connections
        self.conv_blocks = nn.ModuleList([
            # Block 1: 256 -> 128 channels
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(128),
                nn.ReLU(inplace=True),
                AttentionBlock(128)
            ),
            # Block 2: 128 -> 64 channels  
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(64),
                nn.ReLU(inplace=True),
                AttentionBlock(64)
            ),
            # Block 3: 64 -> 32 channels
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(32),
                nn.ReLU(inplace=True),
            ),
            # Block 4: 32 -> 16 channels
            nn.Sequential(
                nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm1d(16),
                nn.ReLU(inplace=True),
            ),
            # Final block: 16 -> channels
            nn.Sequential(
                nn.ConvTranspose1d(16, channels, kernel_size=4, stride=2, padding=1),
                nn.Tanh()
            )
        ])

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 256, self.init_seq_len)
        
        # Progressive upsampling with skip connections
        for i, block in enumerate(self.conv_blocks):
            out = block(out)
        
        # Ensure exact sequence length through interpolation if needed
        if out.size(2) != self.seq_len:
            out = nn.functional.interpolate(out, size=self.seq_len, mode='linear', align_corners=False)
        
        return out  # Shape: (batch, 14, 4500)

# Enhanced Few-Shot Discriminator with Multi-Scale Features
class EnhancedFewShotDiscriminator(nn.Module):
    def __init__(self, channels=14, seq_len=4500):
        super(EnhancedFewShotDiscriminator, self).__init__()
        
        # Multi-scale feature extraction
        self.scale1_conv = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv1d(channels, 32, kernel_size=3, stride=1, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.utils.spectral_norm(nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        self.scale2_conv = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv1d(channels, 32, kernel_size=7, stride=1, padding=3)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.utils.spectral_norm(nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Feature fusion
        self.fusion = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
        )
        
        # Further processing
        self.model = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            
            nn.utils.spectral_norm(nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.utils.spectral_norm(nn.Linear(256, 1)),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        # Multi-scale feature extraction
        scale1_features = self.scale1_conv(x)
        scale2_features = self.scale2_conv(x)
        
        # Concatenate multi-scale features
        multi_scale_features = torch.cat([scale1_features, scale2_features], dim=1)
        
        # Fusion and classification
        fused_features = self.fusion(multi_scale_features)
        output = self.model(fused_features)
        
        return output

class FewShot1DDataset(Dataset):
    def __init__(self, data, labels=None, transform=None):
        """
        Args:
            data: numpy array of shape [n_samples, 4500, 14]
            labels: numpy array of shape [n_samples] (optional)
        """
        # Transpose to (n_samples, 14, 4500) for Conv1d
        self.data = torch.tensor(data.transpose(0, 2, 1), dtype=torch.float32)
        
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]  # Shape: (14, 4500)
        
        if self.transform:
            sample = self.transform(sample)
            
        if self.labels is not None:
            label = torch.tensor(self.labels[idx], dtype=torch.long)
            return sample, label
        return sample, 0

# Enhanced training function with few-shot learning capabilities
def train_enhanced_few_shot_gan(normal_data, device, epochs=100, batch_size=32, lr_g=1e-4, lr_d=2e-4):
    """
    Enhanced Few-Shot GAN training with improved stability and attention mechanisms
    """
    print(f"Training Enhanced Few-Shot GAN on data shape: {normal_data.shape}")
    
    # Data loading
    dataset = FewShot1DDataset(normal_data)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Initialize enhanced models
    latent_dim = 100
    num_channels = normal_data.shape[-1]
    seq_length = normal_data.shape[1]
    
    generator = EnhancedFewShotGenerator(latent_dim=latent_dim, channels=num_channels, seq_len=seq_length).to(device)
    discriminator = EnhancedFewShotDiscriminator(channels=num_channels, seq_len=seq_length).to(device)

    # Xavier/He initialization for better stability
    def init_weights(m):
        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
            nn.init.xavier_normal_(m.weight, gain=0.02)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)

    generator.apply(init_weights)
    discriminator.apply(init_weights)

    # Optimizers with different learning rates for stability
    optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))

    # Learning rate schedulers
    scheduler_G = optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, patience=20, factor=0.8)
    scheduler_D = optim.lr_scheduler.ReduceLROnPlateau(optimizer_D, patience=20, factor=0.8)

    # Enhanced loss function with label smoothing
    def adversarial_loss_smooth(pred, target_is_real, smoothing=0.1):
        if target_is_real:
            target = torch.ones_like(pred) * (1.0 - smoothing) + smoothing * torch.rand_like(pred)
        else:
            target = torch.zeros_like(pred) + smoothing * torch.rand_like(pred)
        return nn.BCELoss()(pred, target)

    # Training history
    g_losses, d_losses = [], []
    
    print("Starting Enhanced Few-Shot GAN training...")
    print(f"Generator LR: {lr_g}, Discriminator LR: {lr_d}")

    for epoch in range(epochs):
        epoch_g_loss, epoch_d_loss = 0, 0
        
        for i, (real_samples, _) in enumerate(dataloader):
            real_samples = real_samples.to(device)  # Shape: (batch, 14, 4500)
            batch_size_actual = real_samples.size(0)
            
            # Add noise to real samples for robustness
            noisy_real = real_samples + 0.05 * torch.randn_like(real_samples)
            
            # ========================
            # Train Discriminator
            # ========================
            optimizer_D.zero_grad()
            
            # Real samples
            real_pred = discriminator(noisy_real)
            d_real_loss = adversarial_loss_smooth(real_pred, True)
            
            # Fake samples
            z = torch.randn(batch_size_actual, latent_dim, device=device)
            fake_samples = generator(z).detach()
            fake_pred = discriminator(fake_samples)
            d_fake_loss = adversarial_loss_smooth(fake_pred, False)
            
            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 0.5)
            optimizer_D.step()
            
            # ========================
            # Train Generator
            # ========================
            optimizer_G.zero_grad()
            
            # Generate fake samples
            z = torch.randn(batch_size_actual, latent_dim, device=device)
            fake_samples = generator(z)
            fake_pred = discriminator(fake_samples)
            
            # Generator loss (want discriminator to classify fake as real)
            g_loss = adversarial_loss_smooth(fake_pred, True)
            g_loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 0.5)
            optimizer_G.step()
            
            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()
            
            if i % 50 == 0:
                print(f"[Epoch {epoch+1}/{epochs}] [Batch {i}/{len(dataloader)}] "
                      f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        # Store average losses per epoch
        avg_g_loss = epoch_g_loss / len(dataloader)
        avg_d_loss = epoch_d_loss / len(dataloader)
        g_losses.append(avg_g_loss)
        d_losses.append(avg_d_loss)

        # Update learning rates
        scheduler_G.step(avg_g_loss)
        scheduler_D.step(avg_d_loss)

        # Enhanced stability monitoring
        if epoch % 10 == 0:
            monitor_gan_stability(g_losses, d_losses, window=10)
    
    print("Enhanced Few-Shot GAN training completed!")
    return generator, discriminator, g_losses, d_losses

def generate_samples(generator, num_samples, latent_dim=100):
    """
    Generate samples using the trained generator
    Returns data in shape (num_samples, 4500, 14)
    """
    device = next(generator.parameters()).device
    generator.eval()
    
    batch_size = 16
    all_samples = []
    
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            end = min(start + batch_size, num_samples)
            current_batch_size = end - start
            
            z = torch.randn(current_batch_size, latent_dim, device=device)
            batch_samples = generator(z)  # Shape: (batch, 14, 4500)
            all_samples.append(batch_samples.cpu())
    
    generated_data = torch.cat(all_samples, dim=0).numpy()
    
    # Transpose back to (n_samples, 4500, 14)
    generated_data = generated_data.transpose(0, 2, 1)
    
    return generated_data

# Enhanced monitoring with more detailed analysis
def monitor_gan_stability(g_losses, d_losses, window=10):
    """
    Enhanced GAN training stability monitoring
    """
    if len(g_losses) < window:
        return
    
    # Recent losses
    recent_g = np.mean(g_losses[-window:])
    recent_d = np.mean(d_losses[-window:])
    
    # Loss ratio (should be roughly balanced)
    ratio = recent_g / (recent_d + 1e-8)
    
    # Loss variance (should be stable, not oscillating wildly)
    g_var = np.var(g_losses[-window:])
    d_var = np.var(d_losses[-window:])
    
    # Loss trend analysis
    if len(g_losses) >= window * 2:
        g_trend = np.mean(g_losses[-window:]) - np.mean(g_losses[-window*2:-window])
        d_trend = np.mean(d_losses[-window:]) - np.mean(d_losses[-window*2:-window])
    else:
        g_trend = d_trend = 0
    
    print(f"G/D Ratio: {ratio:.3f} | G_var: {g_var:.4f} | D_var: {d_var:.4f}")
    print(f"G_trend: {g_trend:+.4f} | D_trend: {d_trend:+.4f}")
    
    # Enhanced stability warnings
    if ratio > 5:
        print("⚠️  Generator significantly overpowering Discriminator")
        print("   💡 Consider: Lower G learning rate or train D more frequently")
    elif ratio < 0.2:
        print("⚠️  Discriminator significantly overpowering Generator")
        print("   💡 Consider: Lower D learning rate or add noise to real data")
    elif g_var > 1.0 or d_var > 1.0:
        print("⚠️  High variance - unstable training detected")
        print("   💡 Consider: Lower learning rates or gradient clipping")
    elif abs(g_trend) > 0.5 or abs(d_trend) > 0.5:
        print("⚠️  Significant loss trends detected")
        print("   💡 Consider: Learning rate scheduling or early stopping")
    else:
        print("✅ Training appears stable and balanced")

# Enhanced visualization
def plot_enhanced_training_curves(g_losses, d_losses):
    """
    Plot enhanced training curves with additional analysis
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss curves
    axes[0, 0].plot(g_losses, label='Generator Loss', alpha=0.7)
    axes[0, 0].plot(d_losses, label='Discriminator Loss', alpha=0.7)
    axes[0, 0].set_title('Training Losses')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Loss ratio
    ratios = [g/(d+1e-8) for g, d in zip(g_losses, d_losses)]
    axes[0, 1].plot(ratios, color='green', alpha=0.7)
    axes[0, 1].axhline(y=1, color='red', linestyle='--', alpha=0.5, label='Ideal Ratio')
    axes[0, 1].set_title('Generator/Discriminator Loss Ratio')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('G_loss / D_loss')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Moving averages
    window = 10
    if len(g_losses) >= window:
        g_ma = pd.Series(g_losses).rolling(window=window).mean()
        d_ma = pd.Series(d_losses).rolling(window=window).mean()
        
        axes[1, 0].plot(g_ma, label=f'Generator MA({window})', alpha=0.7)
        axes[1, 0].plot(d_ma, label=f'Discriminator MA({window})', alpha=0.7)
        axes[1, 0].set_title('Moving Average Losses')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Loss')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
    
    # Loss variance
    if len(g_losses) >= window:
        g_var = pd.Series(g_losses).rolling(window=window).var()
        d_var = pd.Series(d_losses).rolling(window=window).var()
        
        axes[1, 1].plot(g_var, label=f'Generator Var({window})', alpha=0.7)
        axes[1, 1].plot(d_var, label=f'Discriminator Var({window})', alpha=0.7)
        axes[1, 1].set_title('Loss Variance (Stability Indicator)')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Variance')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Few-Shot GAN Training

In [3]:

class SelfAttention1D(nn.Module):
    """Self-attention module optimized for time series"""
    def __init__(self, in_channels):
        super(SelfAttention1D, self).__init__()
        self.in_channels = in_channels
        self.query = nn.Conv1d(in_channels, in_channels // 4, 1)  # Reduced for efficiency
        self.key = nn.Conv1d(in_channels, in_channels // 4, 1)
        self.value = nn.Conv1d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        batch_size, channels, length = x.size()
        
        # Subsample for long sequences to reduce computation
        if length > 1000:
            step = length // 500
            x_sub = x[:, :, ::step]
        else:
            x_sub = x
            
        sub_length = x_sub.size(2)
        
        # Generate query, key, value on subsampled data
        q = self.query(x_sub).view(batch_size, -1, sub_length).permute(0, 2, 1)
        k = self.key(x_sub).view(batch_size, -1, sub_length)
        v = self.value(x_sub).view(batch_size, -1, sub_length)
        
        # Attention calculation
        attention = torch.bmm(q, k)
        attention = self.softmax(attention)
        
        # Apply attention to values
        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, sub_length)
        
        # Interpolate back to original length if subsampled
        if length > 1000:
            out = nn.functional.interpolate(out, size=length, mode='linear', align_corners=False)
        
        return self.gamma * out + x

class ResidualBlock1D(nn.Module):
    """Optimized residual block for sensor data"""
    def __init__(self, channels):
        super(ResidualBlock1D, self).__init__()
        self.conv1 = nn.Conv1d(channels, channels, 5, padding=2)  # Larger kernel for temporal patterns
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, 5, padding=2)
        self.bn2 = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)

class ImprovedFewShotGenerator(nn.Module):
    """Enhanced generator optimized for sensor time series"""
    def __init__(self, latent_dim=100, output_channels=14, target_length=4500):
        super(ImprovedFewShotGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.output_channels = output_channels
        self.target_length = target_length
        
        # Optimized initial projection for 4500 length
        self.initial_length = 141  # 141 * 32 = 4512 ≈ 4500
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 512 * self.initial_length),
            nn.BatchNorm1d(512 * self.initial_length),
            nn.ReLU(True)
        )
        
        # Progressive upsampling optimized for sensor data
        self.upsample_blocks = nn.ModuleList([
            # 141 -> 282
            nn.Sequential(
                nn.ConvTranspose1d(512, 256, 4, stride=2, padding=1),
                nn.BatchNorm1d(256),
                nn.ReLU(True),
                ResidualBlock1D(256)
            ),
            # 282 -> 564  
            nn.Sequential(
                nn.ConvTranspose1d(256, 128, 4, stride=2, padding=1),
                nn.BatchNorm1d(128),
                nn.ReLU(True),
                ResidualBlock1D(128)
            ),
            # 564 -> 1128
            nn.Sequential(
                nn.ConvTranspose1d(128, 64, 4, stride=2, padding=1),
                nn.BatchNorm1d(64),
                nn.ReLU(True),
                ResidualBlock1D(64)
            ),
            # 1128 -> 2256
            nn.Sequential(
                nn.ConvTranspose1d(64, 32, 4, stride=2, padding=1),
                nn.BatchNorm1d(32),
                nn.ReLU(True),
                ResidualBlock1D(32)
            ),
            # 2256 -> 4512
            nn.Sequential(
                nn.ConvTranspose1d(32, 16, 4, stride=2, padding=1),
                nn.BatchNorm1d(16),
                nn.ReLU(True),
            ),
        ])
        
        # Attention for long-range temporal dependencies
        self.attention = SelfAttention1D(16)
        
        # Final layers for sensor-specific patterns
        self.final_layers = nn.Sequential(
            nn.Conv1d(16, output_channels, 7, padding=3),  # Larger kernel for temporal smoothness
            nn.Tanh()
        )
        
    def forward(self, z):
        # Project latent vector
        x = self.fc(z)
        x = x.view(z.size(0), 512, self.initial_length)
        
        # Progressive upsampling
        for upsample_block in self.upsample_blocks:
            x = upsample_block(x)
        
        # Apply attention for temporal coherence
        x = self.attention(x)
        
        # Final transformation
        x = self.final_layers(x)
        
        # Precise length adjustment for 4500
        current_length = x.size(-1)
        if current_length > self.target_length:
            start_idx = (current_length - self.target_length) // 2
            x = x[:, :, start_idx:start_idx + self.target_length]
        elif current_length < self.target_length:
            pad_left = (self.target_length - current_length) // 2
            pad_right = self.target_length - current_length - pad_left
            x = F.pad(x, (pad_left, pad_right), mode='reflect')
        
        return x

class ImprovedFewShotDiscriminator(nn.Module):
    """Enhanced discriminator optimized for sensor time series"""
    def __init__(self, input_channels=14):
        super(ImprovedFewShotDiscriminator, self).__init__()
        
        # Multi-scale feature extraction for different temporal patterns
        self.conv_blocks = nn.ModuleList([
            # 4500 -> 2250
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv1d(input_channels, 64, 8, stride=2, padding=3)),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.1)
            ),
            # 2250 -> 1125
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv1d(64, 128, 8, stride=2, padding=3)),
                nn.BatchNorm1d(128),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.1)
            ),
            # 1125 -> 562
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv1d(128, 256, 8, stride=2, padding=3)),
                nn.BatchNorm1d(256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.2)
            ),
            # 562 -> 281
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv1d(256, 512, 8, stride=2, padding=3)),
                nn.BatchNorm1d(512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.2)
            ),
            # 281 -> 140
            nn.Sequential(
                nn.utils.spectral_norm(nn.Conv1d(512, 512, 8, stride=2, padding=3)),
                nn.BatchNorm1d(512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3)
            )
        ])
        
        # Optimized attention for sensor patterns
        self.attention = SelfAttention1D(512)
        
        # Enhanced classifier for better discrimination
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.utils.spectral_norm(nn.Linear(256, 64)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.utils.spectral_norm(nn.Linear(64, 1)),
            nn.Sigmoid()
        )
    
    def forward(self, x, return_features=False):
        features = []
        
        # Progressive feature extraction
        for conv_block in self.conv_blocks:
            x = conv_block(x)
            if return_features:
                features.append(x)
        
        # Apply attention for important temporal patterns
        x = self.attention(x)
        if return_features:
            features.append(x)
        
        # Final classification
        output = self.classifier(x)
        
        if return_features:
            return output, features
        return output

def feature_matching_loss(real_features, fake_features):
    """Enhanced feature matching loss with temporal weighting"""
    loss = 0
    weights = [1.0, 1.5, 2.0, 2.5, 3.0, 1.0]  # Higher weight for middle layers
    
    for i, (real_feat, fake_feat) in enumerate(zip(real_features, fake_features)):
        weight = weights[i] if i < len(weights) else 1.0
        # Use both mean and variance matching for better distribution alignment
        mean_loss = nn.MSELoss()(fake_feat.mean(0), real_feat.mean(0))
        var_loss = nn.MSELoss()(fake_feat.var(0), real_feat.var(0))
        loss += weight * (mean_loss + 0.5 * var_loss)
    return loss

def gradient_penalty(discriminator, real_samples, fake_samples, device, lambda_gp=10.0):
    """Optimized gradient penalty for sensor data"""
    batch_size = real_samples.size(0)
    alpha = torch.rand(batch_size, 1, 1).expand_as(real_samples).to(device)
    
    interpolated = alpha * real_samples + (1 - alpha) * fake_samples
    interpolated.requires_grad_(True)
    
    d_interpolated = discriminator(interpolated)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return lambda_gp * gradient_penalty

def train_improved_few_shot_gan(X_train, epochs=200, batch_size=16, latent_dim=128, 
                               lr_g=0.0002, lr_d=0.0001, device='cuda'):  # Reduced learning rates
    """
    Optimized training for sensor time series with better FID score
    """
    print("🚀 Starting Optimized Few-Shot GAN Training for Sensor Data")
    print("=" * 60)
    
    # Data loading with optimized batch size
    dataset = FewShot1DDataset(X_train)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    # Initialize optimized models
    generator = ImprovedFewShotGenerator(latent_dim=latent_dim).to(device)
    discriminator = ImprovedFewShotDiscriminator().to(device)
    
    # Balanced optimizers for sensor data - FIXED learning rates
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
    
    # Loss function
    adversarial_loss = nn.BCELoss()
    
    # Training tracking
    g_losses, d_losses = [], []
    feature_losses, gp_losses = [], []
    
    print(f"📊 Optimized Configuration:")
    print(f"   • Latent Dimension: {latent_dim}")
    print(f"   • Batch Size: {batch_size}")
    print(f"   • Learning Rates: G={lr_g}, D={lr_d}")
    print("=" * 60)
    
    for epoch in range(epochs):
        epoch_g_loss, epoch_d_loss = 0, 0
        epoch_feature_loss, epoch_gp_loss = 0, 0
        
        for i, (real_samples, _) in enumerate(dataloader):
            real_samples = real_samples.to(device)
            batch_size_current = real_samples.size(0)
            
            # ========================
            # Balanced Discriminator Training - REDUCED FREQUENCY
            # ========================
            if i % 2 == 0:  # Train discriminator every 2 iterations
                optimizer_D.zero_grad()
                
                # Real samples with STRONGER label smoothing
                real_labels = torch.ones(batch_size_current, 1, device=device) * (0.8 + 0.15 * torch.rand(batch_size_current, 1, device=device))
                real_pred = discriminator(real_samples)
                real_loss = adversarial_loss(real_pred, real_labels)
                
                # Fake samples
                z = torch.randn(batch_size_current, latent_dim, device=device)
                fake_samples = generator(z).detach()
                fake_labels = torch.zeros(batch_size_current, 1, device=device) + 0.2 * torch.rand(batch_size_current, 1, device=device)
                fake_pred = discriminator(fake_samples)
                fake_loss = adversarial_loss(fake_pred, fake_labels)
                
                # REDUCED gradient penalty
                gp = gradient_penalty(discriminator, real_samples, fake_samples, device, lambda_gp=2.0)  # Reduced from 5.0
                
                # Total discriminator loss
                d_loss = real_loss + fake_loss + gp
                d_loss.backward()
                
                # Stronger gradient clipping
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 0.5)  # Reduced from 1.0
                optimizer_D.step()
                
                epoch_d_loss += (real_loss.item() + fake_loss.item()) / 2
                epoch_gp_loss += gp.item()
            
            # ========================
            # Enhanced Generator Training - EVERY ITERATION
            # ========================
            optimizer_G.zero_grad()
            
            # Generate fake samples
            z = torch.randn(batch_size_current, latent_dim, device=device)
            fake_samples = generator(z)
            
            # Get predictions and features
            fake_pred, fake_features = discriminator(fake_samples, return_features=True)
            _, real_features = discriminator(real_samples, return_features=True)
            
            # Adversarial loss
            valid_labels = torch.ones(batch_size_current, 1, device=device)
            adv_loss = adversarial_loss(fake_pred, valid_labels)
            
            # Enhanced feature matching loss
            fm_loss = feature_matching_loss(real_features, fake_features)
            
            # REDUCED feature matching weight
            g_loss = adv_loss + 5.0 * fm_loss  # Reduced from 15.0
            g_loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 0.5)  # Reduced from 1.0
            optimizer_G.step()
            
            epoch_g_loss += adv_loss.item()
            epoch_feature_loss += fm_loss.item()
            
            if i % 50 == 0:
                print(f"[Epoch {epoch+1:3d}/{epochs}] [Batch {i:3d}/{len(dataloader)}] "
                      f"[D: {d_loss.item() if i % 2 == 0 else 'skipped':.4f}] [G: {adv_loss.item():.4f}] "
                      f"[FM: {fm_loss.item():.4f}] [GP: {gp.item() if i % 2 == 0 else 'skipped':.4f}]")
        
        # Store average losses per epoch
        g_losses.append(epoch_g_loss / len(dataloader))
        d_losses.append(epoch_d_loss / (len(dataloader) // 2))  # Adjusted for reduced D training
        feature_losses.append(epoch_feature_loss / len(dataloader))
        gp_losses.append(epoch_gp_loss / (len(dataloader) // 2))  # Adjusted for reduced D training
        
        # Enhanced monitoring every 20 epochs
        if epoch % 20 == 0:
            print(f"\n{'='*15} Epoch {epoch+1} Summary {'='*15}")
            print(f"📈 Avg Generator Loss: {g_losses[-1]:.4f}")
            print(f"📉 Avg Discriminator Loss: {d_losses[-1]:.4f}")
            print(f"🎯 Avg Feature Matching Loss: {feature_losses[-1]:.4f}")
            print(f"⚖️  Avg Gradient Penalty: {gp_losses[-1]:.4f}")
            
            # Generate samples for quality check
            with torch.no_grad():
                z_test = torch.randn(32, latent_dim, device=device)
                test_samples = generator(z_test)
                
                # Basic quality metrics
                sample_mean = test_samples.mean().item()
                sample_std = test_samples.std().item()
                sample_range = (test_samples.max() - test_samples.min()).item()
                
                print(f"📊 Generated Sample Stats:")
                print(f"   • Mean: {sample_mean:.4f}")
                print(f"   • Std: {sample_std:.4f}")
                print(f"   • Range: {sample_range:.4f}")
            
            print("=" * 50)
    
    return generator, discriminator, g_losses, d_losses, feature_losses, gp_losses


# Generate and Combine

In [4]:
# Train with optimized parameters for better FID score
generator, discriminator, g_loss, d_loss, feature_loss, gp_loss = train_improved_few_shot_gan(
    X_train_normal, 
    epochs=200, 
    batch_size=16,      # Reduced from 32 for better stability
    latent_dim=128,
    lr_g=0.0002,        # Reduced from 0.0003
    lr_d=0.0001,        # Keep same
    device=device
)

# Generate samples with correct function
generated_data = generate_samples(generator, len(X_train_normal), latent_dim=128)

print(f"✅ Generated data shape: {generated_data.shape}")
print(f"📊 Original data shape: {X_train_normal.shape}")

normal_combine = np.concatenate((X_train_normal, generated_data), axis=0)

🚀 Starting Optimized Few-Shot GAN Training for Sensor Data


📊 Optimized Configuration:
   • Latent Dimension: 128
   • Batch Size: 16
   • Learning Rates: G=0.0002, D=0.0001


[Epoch   1/200] [Batch   0/35] [D: 3.3825] [G: 0.6329] [FM: 3.9737] [GP: 1.9989]



📈 Avg Generator Loss: 0.7122
📉 Avg Discriminator Loss: 0.7190
🎯 Avg Feature Matching Loss: 5.3338
⚖️  Avg Gradient Penalty: 2.1168
📊 Generated Sample Stats:
   • Mean: 0.0761
   • Std: 0.3420
   • Range: 1.9663


[Epoch   2/200] [Batch   0/35] [D: 3.2986] [G: 0.7805] [FM: 7.9759] [GP: 1.9994]


[Epoch   3/200] [Batch   0/35] [D: 3.1545] [G: 0.9054] [FM: 8.5570] [GP: 1.9994]


[Epoch   4/200] [Batch   0/35] [D: 3.0658] [G: 1.0528] [FM: 4.9982] [GP: 1.9993]


[Epoch   5/200] [Batch   0/35] [D: 2.9538] [G: 1.1926] [FM: 4.5491] [GP: 1.9989]


[Epoch   6/200] [Batch   0/35] [D: 2.8018] [G: 1.5233] [FM: 4.8653] [GP: 1.9985]


[Epoch   7/200] [Batch   0/35] [D: 2.8969] [G: 2.0165] [FM: 6.5069] [GP: 1.9965]


[Epoch   8/200] [Batch   0/35] [D: 2.7374] [G: 2.1600] [FM: 8.3356] [GP: 1.9990]


[Epoch   9/200] [Batch   0/35] [D: 2.7511] [G: 2.1751] [FM: 5.6983] [GP: 1.9963]


[Epoch  10/200] [Batch   0/35] [D: 2.8138] [G: 1.4767] [FM: 3.7622] [GP: 1.9978]


[Epoch  11/200] [Batch   0/35] [D: 3.1799] [G: 2.3301] [FM: 4.5370] [GP: 1.9982]


[Epoch  12/200] [Batch   0/35] [D: 2.6692] [G: 2.6820] [FM: 4.0254] [GP: 1.9971]


[Epoch  13/200] [Batch   0/35] [D: 2.7259] [G: 2.4412] [FM: 3.8727] [GP: 1.9941]


[Epoch  14/200] [Batch   0/35] [D: 2.7350] [G: 2.3367] [FM: 3.7981] [GP: 1.9941]


[Epoch  15/200] [Batch   0/35] [D: 2.7530] [G: 1.9884] [FM: 2.9949] [GP: 1.9934]


[Epoch  16/200] [Batch   0/35] [D: 2.7451] [G: 2.0411] [FM: 2.5865] [GP: 1.9926]


[Epoch  17/200] [Batch   0/35] [D: 2.9100] [G: 2.4745] [FM: 2.8392] [GP: 1.9989]


[Epoch  18/200] [Batch   0/35] [D: 2.7829] [G: 2.2714] [FM: 2.5161] [GP: 1.9908]


[Epoch  19/200] [Batch   0/35] [D: 2.7546] [G: 2.2839] [FM: 3.1424] [GP: 1.9949]


[Epoch  20/200] [Batch   0/35] [D: 2.7158] [G: 2.4093] [FM: 2.9032] [GP: 1.9892]


[Epoch  21/200] [Batch   0/35] [D: 2.8895] [G: 2.4305] [FM: 2.5675] [GP: 1.9980]



📈 Avg Generator Loss: 2.2263
📉 Avg Discriminator Loss: 0.4264
🎯 Avg Feature Matching Loss: 2.8749
⚖️  Avg Gradient Penalty: 2.1094
📊 Generated Sample Stats:
   • Mean: -0.0319
   • Std: 0.2687
   • Range: 1.9998


[Epoch  22/200] [Batch   0/35] [D: 2.8493] [G: 2.1545] [FM: 2.5864] [GP: 1.9980]


[Epoch  23/200] [Batch   0/35] [D: 2.7004] [G: 2.5201] [FM: 2.2569] [GP: 1.9916]


[Epoch  24/200] [Batch   0/35] [D: 2.7451] [G: 2.4573] [FM: 2.4742] [GP: 1.9984]


[Epoch  25/200] [Batch   0/35] [D: 2.8378] [G: 2.1584] [FM: 2.5580] [GP: 1.9944]


[Epoch  26/200] [Batch   0/35] [D: 2.7080] [G: 2.3763] [FM: 3.0824] [GP: 1.9945]


[Epoch  27/200] [Batch   0/35] [D: 2.7163] [G: 2.2912] [FM: 3.2712] [GP: 1.9896]


[Epoch  28/200] [Batch   0/35] [D: 2.6900] [G: 2.2921] [FM: 2.9568] [GP: 1.9910]


[Epoch  29/200] [Batch   0/35] [D: 2.6098] [G: 2.3373] [FM: 2.8272] [GP: 1.9789]


[Epoch  30/200] [Batch   0/35] [D: 3.0364] [G: 2.1372] [FM: 3.3234] [GP: 1.9981]


[Epoch  31/200] [Batch   0/35] [D: 2.6969] [G: 2.5263] [FM: 3.7672] [GP: 1.9992]


[Epoch  32/200] [Batch   0/35] [D: 2.6857] [G: 2.3851] [FM: 3.0983] [GP: 1.9963]


[Epoch  33/200] [Batch   0/35] [D: 2.7180] [G: 2.3212] [FM: 2.9408] [GP: 1.9924]


[Epoch  34/200] [Batch   0/35] [D: 2.7795] [G: 2.2383] [FM: 2.7322] [GP: 1.9986]


[Epoch  35/200] [Batch   0/35] [D: 3.5558] [G: 2.2054] [FM: 3.0212] [GP: 1.9985]


[Epoch  36/200] [Batch   0/35] [D: 2.7139] [G: 2.2159] [FM: 3.4460] [GP: 1.9946]


[Epoch  37/200] [Batch   0/35] [D: 2.6388] [G: 2.2536] [FM: 2.9584] [GP: 1.9929]


[Epoch  38/200] [Batch   0/35] [D: 3.0183] [G: 2.5885] [FM: 3.2406] [GP: 1.9902]


[Epoch  39/200] [Batch   0/35] [D: 2.7119] [G: 2.1084] [FM: 3.4980] [GP: 1.9916]


[Epoch  40/200] [Batch   0/35] [D: 2.6892] [G: 2.3770] [FM: 3.0883] [GP: 1.9899]


[Epoch  41/200] [Batch   0/35] [D: 2.7852] [G: 2.4382] [FM: 3.4470] [GP: 1.9926]



📈 Avg Generator Loss: 2.2304
📉 Avg Discriminator Loss: 0.3856
🎯 Avg Feature Matching Loss: 3.1753
⚖️  Avg Gradient Penalty: 2.1111
📊 Generated Sample Stats:
   • Mean: 0.0122
   • Std: 0.3298
   • Range: 1.9994


[Epoch  42/200] [Batch   0/35] [D: 2.6247] [G: 1.9883] [FM: 3.4049] [GP: 1.9963]


[Epoch  43/200] [Batch   0/35] [D: 2.6634] [G: 2.4997] [FM: 3.2700] [GP: 1.9825]


[Epoch  44/200] [Batch   0/35] [D: 2.7139] [G: 2.3652] [FM: 4.2476] [GP: 1.9993]


[Epoch  45/200] [Batch   0/35] [D: 2.7129] [G: 2.2704] [FM: 3.4882] [GP: 1.9948]


[Epoch  46/200] [Batch   0/35] [D: 2.6840] [G: 2.3254] [FM: 3.4738] [GP: 1.9916]


[Epoch  47/200] [Batch   0/35] [D: 2.6957] [G: 2.3707] [FM: 3.0750] [GP: 1.9908]


[Epoch  48/200] [Batch   0/35] [D: 2.9654] [G: 2.0074] [FM: 3.0905] [GP: 1.9979]


[Epoch  49/200] [Batch   0/35] [D: 2.8136] [G: 2.4060] [FM: 2.4107] [GP: 1.9961]


[Epoch  50/200] [Batch   0/35] [D: 2.7422] [G: 2.0861] [FM: 2.7275] [GP: 1.9896]


[Epoch  51/200] [Batch   0/35] [D: 2.7852] [G: 2.2865] [FM: 3.5614] [GP: 1.9990]


[Epoch  52/200] [Batch   0/35] [D: 2.6625] [G: 2.4203] [FM: 2.7609] [GP: 1.9918]


[Epoch  53/200] [Batch   0/35] [D: 2.7342] [G: 2.2110] [FM: 2.7889] [GP: 1.9919]


[Epoch  54/200] [Batch   0/35] [D: 2.7252] [G: 2.1682] [FM: 2.3870] [GP: 1.9975]


[Epoch  55/200] [Batch   0/35] [D: 2.6987] [G: 2.1857] [FM: 2.8774] [GP: 1.9975]


[Epoch  56/200] [Batch   0/35] [D: 2.6998] [G: 2.2683] [FM: 2.0927] [GP: 1.9973]


[Epoch  57/200] [Batch   0/35] [D: 2.9295] [G: 2.6946] [FM: 2.0798] [GP: 1.9975]


[Epoch  58/200] [Batch   0/35] [D: 2.6905] [G: 2.4014] [FM: 2.7537] [GP: 1.9938]


[Epoch  59/200] [Batch   0/35] [D: 2.7100] [G: 2.1043] [FM: 3.0827] [GP: 1.9961]


[Epoch  60/200] [Batch   0/35] [D: 2.7485] [G: 2.2140] [FM: 2.7886] [GP: 1.9947]


[Epoch  61/200] [Batch   0/35] [D: 2.7943] [G: 2.0643] [FM: 2.1287] [GP: 1.9985]



📈 Avg Generator Loss: 2.2670
📉 Avg Discriminator Loss: 0.3774
🎯 Avg Feature Matching Loss: 2.9869
⚖️  Avg Gradient Penalty: 2.1118
📊 Generated Sample Stats:
   • Mean: 0.0142
   • Std: 0.3653
   • Range: 1.9999


[Epoch  62/200] [Batch   0/35] [D: 2.6387] [G: 2.4038] [FM: 3.6838] [GP: 1.9962]


[Epoch  63/200] [Batch   0/35] [D: 2.7611] [G: 2.3676] [FM: 3.0348] [GP: 1.9982]


[Epoch  64/200] [Batch   0/35] [D: 2.6878] [G: 2.2958] [FM: 3.5725] [GP: 1.9908]


[Epoch  65/200] [Batch   0/35] [D: 2.6498] [G: 2.4375] [FM: 3.5736] [GP: 1.9977]


[Epoch  66/200] [Batch   0/35] [D: 2.7097] [G: 2.1760] [FM: 3.2291] [GP: 1.9928]


[Epoch  67/200] [Batch   0/35] [D: 2.7563] [G: 2.0851] [FM: 3.6008] [GP: 1.9909]


[Epoch  68/200] [Batch   0/35] [D: 2.8417] [G: 2.2182] [FM: 3.2947] [GP: 1.9917]


[Epoch  69/200] [Batch   0/35] [D: 2.7338] [G: 2.2791] [FM: 2.9606] [GP: 1.9917]


[Epoch  70/200] [Batch   0/35] [D: 2.7116] [G: 2.2901] [FM: 2.6315] [GP: 1.9923]


[Epoch  71/200] [Batch   0/35] [D: 2.8001] [G: 2.4410] [FM: 3.1814] [GP: 1.9901]


[Epoch  72/200] [Batch   0/35] [D: 2.6955] [G: 2.4043] [FM: 2.6522] [GP: 1.9973]


[Epoch  73/200] [Batch   0/35] [D: 2.7138] [G: 2.4166] [FM: 2.2350] [GP: 1.9889]


[Epoch  74/200] [Batch   0/35] [D: 2.7405] [G: 2.3580] [FM: 2.8986] [GP: 1.9947]


[Epoch  75/200] [Batch   0/35] [D: 2.7102] [G: 2.3008] [FM: 3.2666] [GP: 1.9944]


[Epoch  76/200] [Batch   0/35] [D: 2.7149] [G: 2.4562] [FM: 3.4864] [GP: 1.9873]


[Epoch  77/200] [Batch   0/35] [D: 2.7181] [G: 2.4777] [FM: 3.4231] [GP: 1.9952]


[Epoch  78/200] [Batch   0/35] [D: 2.6758] [G: 2.4087] [FM: 2.4607] [GP: 1.9921]


[Epoch  79/200] [Batch   0/35] [D: 2.7115] [G: 2.1484] [FM: 2.7364] [GP: 1.9924]


[Epoch  80/200] [Batch   0/35] [D: 2.7048] [G: 2.2926] [FM: 3.8026] [GP: 1.9965]


[Epoch  81/200] [Batch   0/35] [D: 2.7037] [G: 2.2665] [FM: 3.1665] [GP: 1.9900]



📈 Avg Generator Loss: 2.2221
📉 Avg Discriminator Loss: 0.4032
🎯 Avg Feature Matching Loss: 3.1223
⚖️  Avg Gradient Penalty: 2.1101
📊 Generated Sample Stats:
   • Mean: -0.0274
   • Std: 0.4009
   • Range: 1.9999


[Epoch  82/200] [Batch   0/35] [D: 2.7750] [G: 2.2954] [FM: 2.8284] [GP: 1.9910]


[Epoch  83/200] [Batch   0/35] [D: 2.6501] [G: 2.2487] [FM: 3.6929] [GP: 1.9975]


[Epoch  84/200] [Batch   0/35] [D: 2.7497] [G: 2.2037] [FM: 3.3416] [GP: 1.9989]


[Epoch  85/200] [Batch   0/35] [D: 2.8812] [G: 2.2442] [FM: 3.3152] [GP: 1.9983]


[Epoch  86/200] [Batch   0/35] [D: 2.6987] [G: 2.3668] [FM: 3.3302] [GP: 1.9925]


[Epoch  87/200] [Batch   0/35] [D: 3.1075] [G: 2.3133] [FM: 3.1692] [GP: 1.9978]


[Epoch  88/200] [Batch   0/35] [D: 2.6789] [G: 1.8960] [FM: 3.4609] [GP: 1.9910]


[Epoch  89/200] [Batch   0/35] [D: 2.7558] [G: 2.4183] [FM: 3.7058] [GP: 1.9986]


[Epoch  90/200] [Batch   0/35] [D: 2.6838] [G: 2.4895] [FM: 3.2768] [GP: 1.9881]


[Epoch  91/200] [Batch   0/35] [D: 2.7095] [G: 2.2192] [FM: 5.0208] [GP: 1.9970]


[Epoch  92/200] [Batch   0/35] [D: 2.7418] [G: 2.3924] [FM: 4.0850] [GP: 1.9974]


[Epoch  93/200] [Batch   0/35] [D: 2.7493] [G: 2.5443] [FM: 2.6801] [GP: 1.9942]


[Epoch  94/200] [Batch   0/35] [D: 2.6891] [G: 2.1939] [FM: 3.8593] [GP: 1.9937]


[Epoch  95/200] [Batch   0/35] [D: 2.6739] [G: 2.1819] [FM: 3.8537] [GP: 1.9974]


[Epoch  96/200] [Batch   0/35] [D: 2.6885] [G: 2.0400] [FM: 5.3249] [GP: 1.9990]


[Epoch  97/200] [Batch   0/35] [D: 2.7612] [G: 2.5997] [FM: 3.8486] [GP: 1.9969]


[Epoch  98/200] [Batch   0/35] [D: 2.7052] [G: 2.3560] [FM: 4.2299] [GP: 1.9976]


[Epoch  99/200] [Batch   0/35] [D: 2.7265] [G: 2.2994] [FM: 3.8375] [GP: 1.9946]


[Epoch 100/200] [Batch   0/35] [D: 3.0070] [G: 2.1187] [FM: 5.1350] [GP: 1.9915]


[Epoch 101/200] [Batch   0/35] [D: 2.6992] [G: 2.3188] [FM: 3.7636] [GP: 1.9938]



📈 Avg Generator Loss: 2.2601
📉 Avg Discriminator Loss: 0.3777
🎯 Avg Feature Matching Loss: 3.8635
⚖️  Avg Gradient Penalty: 2.1133
📊 Generated Sample Stats:
   • Mean: 0.0265
   • Std: 0.3849
   • Range: 1.9999


[Epoch 102/200] [Batch   0/35] [D: 2.6830] [G: 2.4186] [FM: 3.7713] [GP: 1.9981]


[Epoch 103/200] [Batch   0/35] [D: 3.0662] [G: 2.4537] [FM: 4.5586] [GP: 1.9985]


[Epoch 104/200] [Batch   0/35] [D: 2.7364] [G: 2.1884] [FM: 4.2334] [GP: 1.9964]


[Epoch 105/200] [Batch   0/35] [D: 2.8051] [G: 2.6534] [FM: 4.3444] [GP: 1.9983]


[Epoch 106/200] [Batch   0/35] [D: 2.6824] [G: 2.6930] [FM: 2.2041] [GP: 1.9917]


[Epoch 107/200] [Batch   0/35] [D: 2.7847] [G: 2.0495] [FM: 3.4991] [GP: 1.9920]


[Epoch 108/200] [Batch   0/35] [D: 2.7016] [G: 2.0732] [FM: 3.6561] [GP: 1.9898]


[Epoch 109/200] [Batch   0/35] [D: 2.8115] [G: 2.1751] [FM: 3.8513] [GP: 1.9945]


[Epoch 110/200] [Batch   0/35] [D: 2.7201] [G: 2.2862] [FM: 4.2062] [GP: 1.9986]


[Epoch 111/200] [Batch   0/35] [D: 2.7150] [G: 2.4149] [FM: 3.0039] [GP: 1.9960]


[Epoch 112/200] [Batch   0/35] [D: 2.7312] [G: 2.3560] [FM: 3.9893] [GP: 1.9947]


[Epoch 113/200] [Batch   0/35] [D: 2.6798] [G: 2.2858] [FM: 3.3356] [GP: 1.9985]


[Epoch 114/200] [Batch   0/35] [D: 2.7069] [G: 1.9548] [FM: 4.4964] [GP: 1.9920]


[Epoch 115/200] [Batch   0/35] [D: 2.8657] [G: 2.1801] [FM: 4.6582] [GP: 1.9924]


[Epoch 116/200] [Batch   0/35] [D: 2.7327] [G: 1.9576] [FM: 4.0840] [GP: 1.9981]


[Epoch 117/200] [Batch   0/35] [D: 2.8639] [G: 2.2405] [FM: 4.2763] [GP: 1.9989]


[Epoch 118/200] [Batch   0/35] [D: 2.7234] [G: 2.3776] [FM: 4.1355] [GP: 1.9983]


[Epoch 119/200] [Batch   0/35] [D: 2.7326] [G: 2.1908] [FM: 2.8370] [GP: 1.9977]


[Epoch 120/200] [Batch   0/35] [D: 2.7131] [G: 2.4973] [FM: 3.3694] [GP: 1.9946]


[Epoch 121/200] [Batch   0/35] [D: 2.7214] [G: 2.2792] [FM: 3.1587] [GP: 1.9879]



📈 Avg Generator Loss: 2.3248
📉 Avg Discriminator Loss: 0.4001
🎯 Avg Feature Matching Loss: 3.3161
⚖️  Avg Gradient Penalty: 2.1131
📊 Generated Sample Stats:
   • Mean: -0.0015
   • Std: 0.3888
   • Range: 2.0000


[Epoch 122/200] [Batch   0/35] [D: 2.7208] [G: 2.2761] [FM: 3.2115] [GP: 1.9974]


[Epoch 123/200] [Batch   0/35] [D: 2.7372] [G: 2.2926] [FM: 3.8065] [GP: 1.9898]


[Epoch 124/200] [Batch   0/35] [D: 2.7825] [G: 2.1719] [FM: 3.0598] [GP: 1.9989]


[Epoch 125/200] [Batch   0/35] [D: 2.6702] [G: 2.1271] [FM: 3.3433] [GP: 1.9906]


[Epoch 126/200] [Batch   0/35] [D: 2.7071] [G: 2.3735] [FM: 4.8888] [GP: 1.9990]


[Epoch 127/200] [Batch   0/35] [D: 2.6889] [G: 2.3878] [FM: 4.6703] [GP: 1.9911]


[Epoch 128/200] [Batch   0/35] [D: 2.6922] [G: 2.3717] [FM: 3.9571] [GP: 1.9926]


[Epoch 129/200] [Batch   0/35] [D: 2.7736] [G: 2.0853] [FM: 2.7631] [GP: 1.9831]


[Epoch 130/200] [Batch   0/35] [D: 2.6964] [G: 2.0750] [FM: 4.1935] [GP: 1.9848]


[Epoch 131/200] [Batch   0/35] [D: 2.7400] [G: 2.1829] [FM: 3.7818] [GP: 1.9928]


[Epoch 132/200] [Batch   0/35] [D: 2.6258] [G: 2.3028] [FM: 3.6746] [GP: 1.9911]


[Epoch 133/200] [Batch   0/35] [D: 2.7314] [G: 2.1717] [FM: 4.4121] [GP: 1.9986]


[Epoch 134/200] [Batch   0/35] [D: 2.7009] [G: 2.1920] [FM: 2.8886] [GP: 1.9964]


[Epoch 135/200] [Batch   0/35] [D: 2.7762] [G: 2.4133] [FM: 3.1368] [GP: 1.9984]


[Epoch 136/200] [Batch   0/35] [D: 2.6721] [G: 2.2947] [FM: 5.2851] [GP: 1.9982]


[Epoch 137/200] [Batch   0/35] [D: 2.7723] [G: 2.2329] [FM: 4.2340] [GP: 1.9985]


[Epoch 138/200] [Batch   0/35] [D: 2.6834] [G: 2.3135] [FM: 3.7136] [GP: 1.9919]


[Epoch 139/200] [Batch   0/35] [D: 2.7621] [G: 2.3538] [FM: 4.7616] [GP: 1.9988]


[Epoch 140/200] [Batch   0/35] [D: 2.9321] [G: 2.3197] [FM: 3.3324] [GP: 1.9940]


[Epoch 141/200] [Batch   0/35] [D: 2.7324] [G: 2.4023] [FM: 4.6630] [GP: 1.9995]



📈 Avg Generator Loss: 2.1904
📉 Avg Discriminator Loss: 0.3898
🎯 Avg Feature Matching Loss: 3.7627
⚖️  Avg Gradient Penalty: 2.1128
📊 Generated Sample Stats:
   • Mean: 0.0011
   • Std: 0.4144
   • Range: 2.0000


[Epoch 142/200] [Batch   0/35] [D: 2.7386] [G: 2.1745] [FM: 3.6050] [GP: 1.9957]


[Epoch 143/200] [Batch   0/35] [D: 2.6795] [G: 2.1941] [FM: 4.4074] [GP: 1.9932]


[Epoch 144/200] [Batch   0/35] [D: 2.7776] [G: 2.2714] [FM: 3.8751] [GP: 1.9978]


[Epoch 145/200] [Batch   0/35] [D: 2.7204] [G: 2.4014] [FM: 4.8228] [GP: 1.9851]


[Epoch 146/200] [Batch   0/35] [D: 2.7867] [G: 2.1695] [FM: 3.9131] [GP: 1.9922]


[Epoch 147/200] [Batch   0/35] [D: 2.7617] [G: 2.2981] [FM: 3.9737] [GP: 1.9961]


[Epoch 148/200] [Batch   0/35] [D: 3.0064] [G: 2.3855] [FM: 3.0736] [GP: 1.9988]


[Epoch 149/200] [Batch   0/35] [D: 2.6316] [G: 2.1570] [FM: 3.6829] [GP: 1.9925]


[Epoch 150/200] [Batch   0/35] [D: 2.7333] [G: 2.1827] [FM: 3.0430] [GP: 1.9898]


[Epoch 151/200] [Batch   0/35] [D: 2.7831] [G: 2.2289] [FM: 3.6311] [GP: 1.9940]


[Epoch 152/200] [Batch   0/35] [D: 2.6977] [G: 2.3358] [FM: 3.6176] [GP: 1.9885]


[Epoch 153/200] [Batch   0/35] [D: 2.7344] [G: 2.2497] [FM: 5.4837] [GP: 1.9989]


[Epoch 154/200] [Batch   0/35] [D: 2.7076] [G: 2.2989] [FM: 3.3865] [GP: 1.9870]


[Epoch 155/200] [Batch   0/35] [D: 2.7221] [G: 2.4640] [FM: 5.0106] [GP: 1.9959]


[Epoch 156/200] [Batch   0/35] [D: 2.7577] [G: 2.2039] [FM: 3.7936] [GP: 1.9936]


[Epoch 157/200] [Batch   0/35] [D: 2.7375] [G: 2.3522] [FM: 5.5713] [GP: 1.9891]


[Epoch 158/200] [Batch   0/35] [D: 2.7586] [G: 2.3968] [FM: 4.1489] [GP: 1.9869]


[Epoch 159/200] [Batch   0/35] [D: 2.7411] [G: 2.4650] [FM: 4.0833] [GP: 1.9972]


[Epoch 160/200] [Batch   0/35] [D: 2.7386] [G: 2.2423] [FM: 3.6460] [GP: 1.9944]


[Epoch 161/200] [Batch   0/35] [D: 2.6932] [G: 2.3217] [FM: 4.5475] [GP: 1.9963]



📈 Avg Generator Loss: 2.3176
📉 Avg Discriminator Loss: 0.3779
🎯 Avg Feature Matching Loss: 4.6619
⚖️  Avg Gradient Penalty: 2.1136
📊 Generated Sample Stats:
   • Mean: -0.0120
   • Std: 0.3898
   • Range: 1.9995


[Epoch 162/200] [Batch   0/35] [D: 2.6986] [G: 2.2939] [FM: 5.3471] [GP: 1.9990]


[Epoch 163/200] [Batch   0/35] [D: 2.7658] [G: 2.2095] [FM: 3.6444] [GP: 1.9991]


[Epoch 164/200] [Batch   0/35] [D: 2.8001] [G: 2.2237] [FM: 3.9249] [GP: 1.9984]


[Epoch 165/200] [Batch   0/35] [D: 2.7395] [G: 2.2309] [FM: 3.4608] [GP: 1.9954]


[Epoch 166/200] [Batch   0/35] [D: 2.7294] [G: 2.3213] [FM: 4.7094] [GP: 1.9990]


[Epoch 167/200] [Batch   0/35] [D: 2.6990] [G: 2.2714] [FM: 3.1309] [GP: 1.9907]


[Epoch 168/200] [Batch   0/35] [D: 2.7731] [G: 2.2047] [FM: 4.1028] [GP: 1.9988]


[Epoch 169/200] [Batch   0/35] [D: 2.7346] [G: 2.1461] [FM: 3.5836] [GP: 1.9965]


[Epoch 170/200] [Batch   0/35] [D: 2.7140] [G: 2.2549] [FM: 4.4626] [GP: 1.9981]


[Epoch 171/200] [Batch   0/35] [D: 2.7305] [G: 2.3472] [FM: 5.6257] [GP: 1.9959]


[Epoch 172/200] [Batch   0/35] [D: 2.7225] [G: 2.3296] [FM: 5.2217] [GP: 1.9945]


[Epoch 173/200] [Batch   0/35] [D: 2.6590] [G: 2.3998] [FM: 3.8723] [GP: 1.9952]


[Epoch 174/200] [Batch   0/35] [D: 2.7120] [G: 2.3268] [FM: 4.4978] [GP: 1.9898]


[Epoch 175/200] [Batch   0/35] [D: 2.7017] [G: 2.1598] [FM: 3.4721] [GP: 1.9974]


[Epoch 176/200] [Batch   0/35] [D: 2.7504] [G: 2.2175] [FM: 4.2368] [GP: 1.9883]


[Epoch 177/200] [Batch   0/35] [D: 2.6843] [G: 2.2812] [FM: 4.8363] [GP: 1.9956]


[Epoch 178/200] [Batch   0/35] [D: 2.7144] [G: 2.2902] [FM: 4.0480] [GP: 1.9975]


[Epoch 179/200] [Batch   0/35] [D: 2.6565] [G: 2.2193] [FM: 4.6675] [GP: 1.9859]


[Epoch 180/200] [Batch   0/35] [D: 2.6974] [G: 2.3910] [FM: 4.4958] [GP: 1.9904]


[Epoch 181/200] [Batch   0/35] [D: 2.7382] [G: 2.2580] [FM: 5.3930] [GP: 1.9973]



📈 Avg Generator Loss: 2.2863
📉 Avg Discriminator Loss: 0.3902
🎯 Avg Feature Matching Loss: 4.4508
⚖️  Avg Gradient Penalty: 2.1129
📊 Generated Sample Stats:
   • Mean: 0.0035
   • Std: 0.4118
   • Range: 1.9999


[Epoch 182/200] [Batch   0/35] [D: 2.7441] [G: 2.2520] [FM: 3.8900] [GP: 1.9985]


[Epoch 183/200] [Batch   0/35] [D: 2.7227] [G: 2.2235] [FM: 4.3666] [GP: 1.9959]


[Epoch 184/200] [Batch   0/35] [D: 2.7396] [G: 2.1903] [FM: 4.5416] [GP: 1.9977]


[Epoch 185/200] [Batch   0/35] [D: 2.6680] [G: 2.3767] [FM: 3.4445] [GP: 1.9900]


[Epoch 186/200] [Batch   0/35] [D: 2.7828] [G: 2.4320] [FM: 4.1014] [GP: 1.9969]


[Epoch 187/200] [Batch   0/35] [D: 2.6831] [G: 2.2673] [FM: 3.8106] [GP: 1.9898]


[Epoch 188/200] [Batch   0/35] [D: 2.6389] [G: 2.5620] [FM: 4.6901] [GP: 1.9969]


[Epoch 189/200] [Batch   0/35] [D: 2.7318] [G: 2.4190] [FM: 4.4285] [GP: 1.9976]


[Epoch 190/200] [Batch   0/35] [D: 2.6998] [G: 2.5630] [FM: 2.4715] [GP: 1.9985]


[Epoch 191/200] [Batch   0/35] [D: 2.7513] [G: 2.1595] [FM: 3.3282] [GP: 1.9965]


[Epoch 192/200] [Batch   0/35] [D: 2.7261] [G: 2.0831] [FM: 3.1040] [GP: 1.9951]


[Epoch 193/200] [Batch   0/35] [D: 2.7846] [G: 2.2151] [FM: 3.8590] [GP: 1.9992]


[Epoch 194/200] [Batch   0/35] [D: 3.0544] [G: 2.2634] [FM: 3.0034] [GP: 1.9990]


[Epoch 195/200] [Batch   0/35] [D: 2.6410] [G: 2.4290] [FM: 4.5837] [GP: 1.9994]


[Epoch 196/200] [Batch   0/35] [D: 2.7605] [G: 2.3257] [FM: 4.7828] [GP: 1.9957]


[Epoch 197/200] [Batch   0/35] [D: 2.6994] [G: 2.2838] [FM: 2.6765] [GP: 1.9872]


[Epoch 198/200] [Batch   0/35] [D: 2.7324] [G: 2.1999] [FM: 3.1639] [GP: 1.9899]


[Epoch 199/200] [Batch   0/35] [D: 2.7298] [G: 2.4099] [FM: 3.6510] [GP: 1.9914]


[Epoch 200/200] [Batch   0/35] [D: 2.7092] [G: 2.3408] [FM: 3.8412] [GP: 1.9955]


✅ Generated data shape: (552, 4500, 14)
📊 Original data shape: (552, 4500, 14)


In [5]:
# ===============================
# FID SCORE EVALUATION
# ===============================

# # Test the simplified FID calculation
# print("Testing simplified FID calculation...")

# # Use smaller subsets for testing
# test_real = X_train_normal[:100]  # Use 100 samples for testing
# test_generated = generated_data[:100]

# print(f"Test real data shape: {test_real.shape}")
# print(f"Test generated data shape: {test_generated.shape}")

# # Calculate FID score
# fid_score = calculate_fid_score(
#     real_data=test_real,
#     fake_data=test_generated,
#     device=device,
#     sample_rate=1000,
# )

# if fid_score is not None:
#     print(f"\n🎉 SUCCESS! FID Score: {fid_score:.4f}")
    
#     # Interpret the score
#     if fid_score < 10:
#         quality = "Excellent"
#     elif fid_score < 25:
#         quality = "Good"
#     elif fid_score < 50:
#         quality = "Fair"
#     elif fid_score < 100:
#         quality = "Poor"
#     else:
#         quality = "Very Poor"
    
#     print(f"Quality Assessment: {quality}")
# else:
#     print("❌ FID calculation failed. Please check the error messages above.")

In [6]:
run_pipeline_with_cv(normal_combine, X_test_normal, X_test_faulty, 
                     device=device, batch_size=64, num_epochs=20)

Extracting features from all samples...



Starting 5-fold cross-validation...

Fold 1/5
------------------------------
Train normal samples: 110
Test samples: 35 (28 normal, 7 anomaly)


  Epoch 1/20, Loss: 0.289003


  Epoch 6/20, Loss: 0.115891


  Epoch 11/20, Loss: 0.071878


  Epoch 16/20, Loss: 0.068603


  Epoch 20/20, Loss: 0.052292


Results: Acc=0.9143, Prec=1.0000, Rec=0.5714, F1=0.7273
Optimal threshold: 0.039887

Fold 2/5
------------------------------
Train normal samples: 110
Test samples: 35 (28 normal, 7 anomaly)


  Epoch 1/20, Loss: 0.289330


  Epoch 6/20, Loss: 0.111455


  Epoch 11/20, Loss: 0.070788


  Epoch 16/20, Loss: 0.067106


  Epoch 20/20, Loss: 0.047838


Results: Acc=0.8571, Prec=1.0000, Rec=0.2857, F1=0.4444
Optimal threshold: 0.036746

Fold 3/5
------------------------------
Train normal samples: 110
Test samples: 35 (28 normal, 7 anomaly)


  Epoch 1/20, Loss: 0.289345


  Epoch 6/20, Loss: 0.109353


  Epoch 11/20, Loss: 0.071064


  Epoch 16/20, Loss: 0.061023


  Epoch 20/20, Loss: 0.041490


Results: Acc=0.9429, Prec=0.8571, Rec=0.8571, F1=0.8571
Optimal threshold: 0.028564

Fold 4/5
------------------------------
Train normal samples: 111
Test samples: 35 (27 normal, 8 anomaly)


  Epoch 1/20, Loss: 0.289061


  Epoch 6/20, Loss: 0.114926


  Epoch 11/20, Loss: 0.071152


  Epoch 16/20, Loss: 0.069969


  Epoch 20/20, Loss: 0.068915


Results: Acc=0.9429, Prec=1.0000, Rec=0.7500, F1=0.8571
Optimal threshold: 0.053738

Fold 5/5
------------------------------
Train normal samples: 111
Test samples: 35 (27 normal, 8 anomaly)


  Epoch 1/20, Loss: 0.289209


  Epoch 6/20, Loss: 0.109296


  Epoch 11/20, Loss: 0.071024


  Epoch 16/20, Loss: 0.069820


  Epoch 20/20, Loss: 0.066777


Results: Acc=0.8857, Prec=1.0000, Rec=0.5000, F1=0.6667
Optimal threshold: 0.051545

CROSS-VALIDATION RESULTS SUMMARY

FOLD-BY-FOLD RESULTS:
--------------------------------------------------------------------------------
Fold   Accuracy   Precision   Recall   F1-Score  Threshold   
--------------------------------------------------------------------------------
1      0.9143     1.0000      0.5714   0.7273    0.039887    
2      0.8571     1.0000      0.2857   0.4444    0.036746    
3      0.9429     0.8571      0.8571   0.8571    0.028564    
4      0.9429     1.0000      0.7500   0.8571    0.053738    
5      0.8857     1.0000      0.5000   0.6667    0.051545    

STATISTICAL SUMMARY:
--------------------------------------------------------------------------------
Metric       Mean     Std      Min      Max      Median  
--------------------------------------------------------------------------------
Accuracy     0.9086   0.0333   0.8571   0.9429   0.9143  
Precision    0.9714   0.0

{'fold_results': [{'fold': 1,
   'train_samples': 110,
   'test_samples': 35,
   'final_train_loss': 0.05229172855615616,
   'optimal_threshold': 0.0398870822456148,
   'accuracy': 0.9142857142857143,
   'precision': 1.0,
   'recall': 0.5714285714285714,
   'f1': 0.7272727272727273},
  {'fold': 2,
   'train_samples': 110,
   'test_samples': 35,
   'final_train_loss': 0.0478384343907237,
   'optimal_threshold': 0.036745875057849015,
   'accuracy': 0.8571428571428571,
   'precision': 1.0,
   'recall': 0.2857142857142857,
   'f1': 0.4444444444444445},
  {'fold': 3,
   'train_samples': 110,
   'test_samples': 35,
   'final_train_loss': 0.04148967862129212,
   'optimal_threshold': 0.028564112505527456,
   'accuracy': 0.9428571428571428,
   'precision': 0.8571428571428571,
   'recall': 0.8571428571428571,
   'f1': 0.8571428571428571},
  {'fold': 4,
   'train_samples': 111,
   'test_samples': 35,
   'final_train_loss': 0.06891531608998776,
   'optimal_threshold': 0.053738499681154885,
   'acc