In [None]:
import torch, torchaudio, torchvision.transforms as transforms, matplotlib.pyplot as plt, torch.nn as nn, torch.optim as optim, numpy as np, os
from torchvision.models import vgg16, VGG16_Weights
from torch.utils.data import DataLoader, TensorDataset
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import sys
sys.path.append("../")
from ad_utils import *

print(torch.cuda.device_count())

# Seed the RNGs used by the notebook so GAN training and sampling are repeatable
# (the sklearn splits below pass random_state explicitly).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA generators

cuda0 = torch.device("cuda:0")
cuda1 = torch.device("cuda:1")
# Prefer the second GPU as originally configured, but fall back to the first GPU or the CPU
# so the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = cuda1
elif torch.cuda.is_available():
    device = cuda0
else:
    device = torch.device("cpu")
print(torch.cuda.get_device_name(device) if device.type == "cuda" else "No GPU available")

# NOTE: allow_pickle=True can execute arbitrary pickled objects on load; only use trusted files.
data = np.load("../../hvcm/RFQ.npy", allow_pickle=True)
label = np.load("../../hvcm/RFQ_labels.npy", allow_pickle=True)
label = label[:, 1]  # Assuming the second column is the label — TODO confirm against the dataset
label = (label == "Fault").astype(int)  # 1 = "Fault", 0 = every other value
print(data.shape, label.shape)

normal_data = data[label == 0]
faulty_data = data[label == 1]

normal_label = label[label == 0]
faulty_label = label[label == 1]

# Split BEFORE scaling: fitting the scaler on every sample would leak test (and fault)
# statistics into training. train_test_split's shuffled partition depends only on the
# sample count and random_state, so the same samples land in each split.
X_train_normal, X_test_normal, y_train_normal, y_test_normal = train_test_split(normal_data, normal_label, test_size=0.2, random_state=42, shuffle=True)
X_train_faulty, X_test_faulty, y_train_faulty, y_test_faulty = train_test_split(faulty_data, faulty_label, test_size=0.2, random_state=42, shuffle=True)

# Per-feature standardization along the last axis, fitted on normal training samples only.
scaler = StandardScaler()
scaler.fit(X_train_normal.reshape(-1, X_train_normal.shape[-1]))


def standardize(arr):
    """Apply the fitted `scaler` along the last axis of `arr`, preserving its shape."""
    return scaler.transform(arr.reshape(-1, arr.shape[-1])).reshape(arr.shape)


data = standardize(data)
normal_data = data[label == 0]
faulty_data = data[label == 1]
X_train_normal, X_test_normal = standardize(X_train_normal), standardize(X_test_normal)
X_train_faulty, X_test_faulty = standardize(X_train_faulty), standardize(X_test_faulty)

In [None]:
# ===============================
# MADGAN ARCHITECTURE FOR MULTIVARIATE TIME SERIES ANOMALY DETECTION
# ===============================

class MemoryEfficientMADGAN(nn.Module):
    """MLP-based GAN with an encoder, used for anomaly scoring on flat feature vectors.

    Three sub-networks operate on vectors of length ``input_dim``:

    * ``generator``: latent code (``latent_dim``) -> sample (``input_dim``)
    * ``discriminator``: sample -> probability in (0, 1) that the sample is real
    * ``encoder``: sample -> latent code, so ``generator(encoder(x))`` reconstructs ``x``

    Args:
        input_dim: Length of each flattened input sample.
        latent_dim: Size of the latent code.
        sequence_length: Stored as an attribute only (defaults to ``input_dim``); the layers
            do not use it.
        bounded_output: If True, the generator ends in ``Tanh`` and can only emit values in
            (-1, 1). Defaults to False: standardized (z-scored) data is unbounded, and a Tanh
            head cannot reproduce values with |x| >= 1, which inflates reconstruction errors
            of perfectly normal samples and biases generated data toward the interval.

    Note:
        The generator contains ``BatchNorm1d``, so in train mode it needs batches of at
        least two samples.
    """

    def __init__(self, input_dim, latent_dim=64, sequence_length=None, bounded_output=False):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.sequence_length = sequence_length or input_dim

        # Generator: plain MLP with dropout and batch norm (no residual connections).
        generator_layers = [
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, input_dim),
        ]
        if bounded_output:
            # Parameter-free and appended last, so state-dict keys match either way.
            generator_layers.append(nn.Tanh())
        self.generator = nn.Sequential(*generator_layers)

        # Lightweight discriminator; the final Linear + Sigmoid yields a "real" probability.
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Encoder mapping samples back into the latent space for reconstruction scoring.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, latent_dim)
        )

    def forward(self, x):
        """Return discriminator probabilities of shape (batch, 1) for samples ``x``."""
        return self.discriminator(x)

    def generate(self, z):
        """Map latent codes ``z`` (batch, latent_dim) to samples (batch, input_dim)."""
        return self.generator(z)

    def encode(self, x):
        """Map samples ``x`` (batch, input_dim) to latent codes (batch, latent_dim)."""
        return self.encoder(x)

# Memory-efficient training function
def train_madgan_memory_efficient(model, normal_data, epochs=100, batch_size=32, lr=0.0002, target_device=None):
    """Adversarially train ``model`` on normal samples only.

    Each mini-batch performs two updates:

    1. Discriminator: BCE on real samples (target 1) vs. detached generated samples (target 0).
    2. Generator + encoder jointly, minimizing
       ``adversarial + 10 * reconstruction + 5 * feature_matching`` where reconstruction is the
       MSE of ``generator(encoder(x))`` against ``x`` and feature matching aligns the batch-mean
       hidden discriminator features of real and generated samples.

    Gradients of every sub-network are clipped to norm 1.0 before stepping.

    Args:
        model: Module exposing ``generator``, ``encoder``, ``latent_dim`` and a
            ``discriminator`` that is an ``nn.Sequential`` ending in ``Linear, Sigmoid``.
        normal_data: Array-like of shape (n_samples, input_dim) with normal samples only.
        epochs: Number of passes over ``normal_data``.
        batch_size: Mini-batch size.
        lr: Adam learning rate for all three sub-networks.
        target_device: Device to train on; defaults to the notebook-level ``device``.

    Returns:
        The same ``model`` instance, trained in place and left in train mode.
    """
    train_device = torch.device(target_device if target_device is not None else device)
    model.to(train_device)

    # Separate Adam optimizers with the usual GAN betas.
    optimizer_G = optim.Adam(model.generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(model.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_E = optim.Adam(model.encoder.parameters(), lr=lr, betas=(0.5, 0.999))

    criterion = nn.BCELoss()
    mse_loss = nn.MSELoss()

    # Hidden representation of the discriminator (everything before the final Linear + Sigmoid).
    # Slicing an nn.Sequential shares the underlying modules, so no weights are copied.
    discriminator_features = model.discriminator[:-2]

    # Stream mini-batches instead of moving the whole dataset to the device at once.
    dataset = TensorDataset(torch.as_tensor(normal_data, dtype=torch.float32))
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=(train_device.type == "cuda"),  # pinning only helps host->GPU copies
    )

    model.train()

    for epoch in range(epochs):
        epoch_d_loss = 0.0
        epoch_g_loss = 0.0
        num_batches = 0

        for (batch_data,) in dataloader:
            batch_size_actual = batch_data.size(0)
            if batch_size_actual < 2:
                # The generator's BatchNorm1d cannot compute batch statistics from a single
                # sample in train mode and raises; skip the leftover singleton batch.
                continue
            batch_data = batch_data.to(train_device, non_blocking=True)

            real_labels = torch.ones(batch_size_actual, 1, device=train_device)
            fake_labels = torch.zeros(batch_size_actual, 1, device=train_device)

            # --- Discriminator update ---
            optimizer_D.zero_grad()
            d_loss_real = criterion(model.discriminator(batch_data), real_labels)

            z = torch.randn(batch_size_actual, model.latent_dim, device=train_device)
            fake_data = model.generator(z).detach()  # keep generator out of the D step
            d_loss_fake = criterion(model.discriminator(fake_data), fake_labels)

            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.discriminator.parameters(), 1.0)
            optimizer_D.step()

            # --- Generator + encoder update ---
            optimizer_G.zero_grad()
            optimizer_E.zero_grad()

            # Adversarial loss: make D classify fresh generated samples as real.
            z = torch.randn(batch_size_actual, model.latent_dim, device=train_device)
            fake_data = model.generator(z)
            g_loss_adv = criterion(model.discriminator(fake_data), real_labels)

            # Reconstruction loss through encoder -> generator.
            reconstructed = model.generator(model.encoder(batch_data))
            reconstruction_loss = mse_loss(reconstructed, batch_data)

            # Feature matching on batch-mean hidden discriminator features (real side is a
            # fixed target, so no gradient is needed for it).
            with torch.no_grad():
                real_features = discriminator_features(batch_data).mean(dim=0)
            fake_features = discriminator_features(fake_data).mean(dim=0)
            feature_loss = mse_loss(fake_features, real_features)

            g_loss = g_loss_adv + 10 * reconstruction_loss + 5 * feature_loss
            g_loss.backward()

            torch.nn.utils.clip_grad_norm_(model.generator.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(model.encoder.parameters(), 1.0)

            optimizer_G.step()
            optimizer_E.step()

            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            num_batches += 1

        # Log every 20 epochs and always the final one; skip if no batch was trainable.
        if num_batches and (epoch % 20 == 0 or epoch == epochs - 1):
            print(f"Epoch [{epoch}/{epochs}] D_Loss: {epoch_d_loss/num_batches:.4f}, G_Loss: {epoch_g_loss/num_batches:.4f}")

    return model

# Anomaly detection function
def detect_anomalies_madgan(model, test_data, threshold_percentile=95, threshold=None, batch_size=1024, target_device=None):
    """Score samples with a trained MADGAN and flag those above a threshold.

    Per-sample score = mean squared reconstruction error of ``generator(encoder(x))``
    + ``(1 - discriminator(x))``. Higher means more anomalous. The two terms are on different
    scales: the reconstruction error is unbounded while ``1 - D(x)`` lies in [0, 1].

    Args:
        model: Module exposing ``encoder``, ``generator`` and ``discriminator`` (the latter
            returning shape (batch, 1)).
        test_data: Array-like of shape (n_samples, ...); any axes after the first are
            flattened to match the model's flat input.
        threshold_percentile: Percentile of the scores of ``test_data`` used as the threshold
            when ``threshold`` is None.
        threshold: Optional fixed score threshold (e.g. calibrated on held-out normal data).
            When None, the threshold is derived from ``test_data`` itself, which by
            construction flags roughly ``100 - threshold_percentile`` percent of the samples.
        batch_size: Samples scored per forward pass; bounds device memory. The model is put
            in eval mode, so scores do not depend on it beyond floating-point rounding.
        target_device: Device to score on; defaults to the notebook-level ``device``.

    Returns:
        ``(predictions, anomaly_scores, threshold)``: int array of 0/1 flags
        (1 = score > threshold), float32 array of scores, and the threshold used
        (NaN when ``test_data`` is empty and no threshold was given).
    """
    score_device = torch.device(target_device if target_device is not None else device)
    model.eval()

    samples = np.ascontiguousarray(test_data, dtype=np.float32)
    if samples.ndim > 2:
        # Explicit product (instead of -1) keeps this valid for zero-sample input.
        samples = samples.reshape(samples.shape[0], int(np.prod(samples.shape[1:])))

    score_chunks = []
    with torch.no_grad():
        for start in range(0, len(samples), batch_size):
            batch = torch.from_numpy(samples[start:start + batch_size]).to(score_device)

            # Reconstruction-based anomaly score
            reconstructed = model.generator(model.encoder(batch))
            reconstruction_errors = torch.mean((batch - reconstructed) ** 2, dim=1)

            # Discriminator-based anomaly score; squeeze only the singleton output column
            discriminator_scores = model.discriminator(batch).squeeze(1)

            score_chunks.append((reconstruction_errors + (1 - discriminator_scores)).cpu().numpy())

    anomaly_scores = np.concatenate(score_chunks) if score_chunks else np.empty(0, dtype=np.float32)

    if threshold is None:
        # Dynamic threshold from the scored data itself (np.percentile rejects empty input).
        threshold = np.percentile(anomaly_scores, threshold_percentile) if anomaly_scores.size else np.nan
    predictions = (anomaly_scores > threshold).astype(int)

    return predictions, anomaly_scores, threshold

# Initialize and train the model
print("Initializing Memory-Efficient MADGAN...")

# The MLP-based MADGAN consumes flat vectors, so collapse every non-sample axis
# (e.g. (n, time, channels) -> (n, time * channels)); 2-D input is used as-is.
if X_train_normal.ndim > 2:
    print(f"Original data shape: {X_train_normal.shape}")
    X_train_flattened = X_train_normal.reshape(X_train_normal.shape[0], -1)
    print(f"Flattened data shape: {X_train_flattened.shape}")
else:
    X_train_flattened = X_train_normal
input_dim = X_train_flattened.shape[1]

print(f"Input dimension: {input_dim}")
madgan_model = MemoryEfficientMADGAN(input_dim=input_dim, latent_dim=64)

print(f"Training MADGAN on {X_train_flattened.shape[0]} normal samples...")
trained_madgan = train_madgan_memory_efficient(
    madgan_model,
    X_train_flattened,
    epochs=200,
    batch_size=32,
    lr=0.0002,
)

# Generate synthetic normal data for downstream tasks: one sample per real training sample.
print("Generating synthetic normal data...")
trained_madgan.eval()
with torch.no_grad():
    num_samples = len(X_train_normal)
    z = torch.randn(num_samples, trained_madgan.latent_dim, device=device)
    memory_generated_data = trained_madgan.generator(z).cpu().numpy()

# Restore the per-sample shape of the real data (e.g. (n, 4500, 14)); using shape[1:] works
# for 2-D as well as higher-rank data.
memory_generated_data = memory_generated_data.reshape((-1,) + X_train_normal.shape[1:])

print(f"Generated data shape: {memory_generated_data.shape}")
torch.cuda.empty_cache()  # Release cached GPU blocks held by the allocator

In [None]:
# ===============================
# FID SCORE EVALUATION
# ===============================

# Compare a small slice of real training samples with the same number of generated samples.
print("Testing simplified FID calculation...")

FID_SUBSET_SIZE = 100  # at most this many samples are taken from each side
test_real = X_train_normal[:FID_SUBSET_SIZE]
test_generated = memory_generated_data[:FID_SUBSET_SIZE]

print(f"Test real data shape: {test_real.shape}")
print(f"Test generated data shape: {test_generated.shape}")

# calculate_fid_score comes from ad_utils; a None result is treated as failure below —
# presumably it returns None on error (confirm against ad_utils).
fid_score = calculate_fid_score(
    real_data=test_real,
    fake_data=test_generated,
    device=device,
    sample_rate=1000,
)

# Exclusive upper bounds for each quality label; anything not below the last bound is
# "Very Poor". These hard-coded cut-offs look borrowed from image-FID conventions —
# NOTE(review): confirm they are meaningful for this feature extractor and dataset.
FID_QUALITY_BANDS = (
    (10, "Excellent"),
    (25, "Good"),
    (50, "Fair"),
    (100, "Poor"),
)

if fid_score is None:
    print("❌ FID calculation failed. Please check the error messages above.")
else:
    print(f"\n🎉 SUCCESS! FID Score: {fid_score:.4f}")
    quality = next(
        (band_name for upper_bound, band_name in FID_QUALITY_BANDS if fid_score < upper_bound),
        "Very Poor",
    )
    print(f"Quality Assessment: {quality}")

In [None]:
# Cross-validated evaluation using the held-out normal/faulty splits together with the
# MADGAN-generated samples. run_comprehensive_cross_validation_experiment comes from
# ad_utils; how it combines these inputs is not visible in this notebook.
CV_EPOCHS = 200
CV_BATCH_SIZE = 32

run_comprehensive_cross_validation_experiment(
    X_test_normal,
    X_test_faulty,
    device,
    memory_generated_data,
    epochs=CV_EPOCHS,
    batch_size=CV_BATCH_SIZE,
)