In [1]:
#
# ============== A-to-Z Project Script (FINAL OPTIMIZED VERSION) ==============
# Project: Meta-Learning the Latent Manifold with Learnable-Interaction Neurons
# Course: Neural Networks
# Due Date: September 14th, 2025
#

import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from umap import UMAP
from gtda.homology import VietorisRipsPersistence
import persim
import scipy.linalg
from torchvision.models import inception_v3, Inception_V3_Weights
from scipy.stats import entropy
import warnings
from torch.optim.lr_scheduler import CosineAnnealingLR

warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.utils.deprecation")

# Reproducibility: one seed for torch (all devices) and NumPy's legacy global RNG.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Synchronous CUDA kernel launches so errors point at the failing call (debugging aid, slower).
# NOTE(review): only takes effect if CUDA has not been initialised yet in this kernel.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


In [2]:
# ==========================================================================================
# == 1. MODEL DEFINITIONS
# ==========================================================================================

class ChebyshevLayer(nn.Module):
    """Dense layer whose per-connection weight is a learnable Chebyshev polynomial of the input.

    For ``x`` of shape ``(batch, in_features)`` the output has shape ``(batch, out_features)``:

        y[b, o] = sum_i x[b, i] * sum_k coeffs[o, i, k] * T_k(x[b, i])

    where ``T_k`` is the k-th Chebyshev polynomial of the first kind (T_0 = 1, T_1 = x,
    T_k = 2 x T_{k-1} - T_{k-2}). The polynomials are bounded only on [-1, 1]; inputs
    outside that interval make the higher-order terms grow quickly.
    """

    def __init__(self, in_features, out_features, order=3):
        super().__init__()
        self.in_features, self.out_features, self.order = in_features, out_features, order
        # coeffs[o, i, k]: weight of T_k for the connection input i -> output o.
        self.coeffs = nn.Parameter(torch.empty(out_features, in_features, order + 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Near-zero init for every order except T_1, which gets Xavier init (starts ~quadratic in x)."""
        with torch.no_grad():
            self.coeffs[:, :, 0].uniform_(-1e-4, 1e-4)
            self.coeffs[:, :, 2:].uniform_(-1e-4, 1e-4)
            if self.order >= 1:
                t1_coeffs = torch.empty(self.out_features, self.in_features)
                nn.init.xavier_uniform_(t1_coeffs)
                self.coeffs.data[:, :, 1] = t1_coeffs

    def forward(self, x):
        # Chebyshev recurrence evaluated element-wise on x.
        polys = [torch.ones_like(x)]
        if self.order > 0:
            polys.append(x)
        for _ in range(2, self.order + 1):
            polys.append(2 * x * polys[-1] - polys[-2])
        cheby_polys = torch.stack(polys, dim=2)  # (batch, in_features, order + 1)

        # Fold the outer multiplication by x into the basis and contract over (i, k) in a single
        # einsum. This is mathematically identical to first materialising per-sample weights of
        # shape (batch, out_features, in_features), but needs only O(batch * in * order) memory.
        weighted_basis = cheby_polys * x.unsqueeze(-1)
        return torch.einsum('oik,bik->bo', self.coeffs, weighted_basis)

In [3]:
class BaselineVAE(nn.Module):
    """Fully connected VAE with one hidden layer in the encoder and one in the decoder.

    Args:
        input_dim: Number of features of a flattened input (784 for 28x28 images).
        hidden_dim: Width of the hidden layers.
        latent_dim: Dimensionality of the latent code z.

    ``forward`` returns ``(reconstruction, mu, log_var)``; the reconstruction has shape
    ``(batch, input_dim)`` with values in (0, 1) (sigmoid output).
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Sigmoid())

    def encode(self, x):
        """Return the posterior parameters ``(mu, log_var)``, each of shape (batch, latent_dim)."""
        # Flatten to (batch, input_dim); reshape (unlike view) also accepts non-contiguous inputs.
        h = self.encoder(x.reshape(-1, self.input_dim))
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (always stochastic, also in eval mode)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

In [4]:
class ChebyshevVAE(BaselineVAE):
    """Fully connected VAE whose dense layers are ChebyshevLayer modules.

    Replaces the encoder hidden layer, the mu / log_var heads and the decoder hidden layer of
    BaselineVAE with ChebyshevLayer; the decoder output layer stays linear + sigmoid.

    NOTE: BaselineVAE.__init__ still builds ``self.encoder`` and ``self.decoder``. They are never
    called here (encode/decode are overridden) but are kept so existing state_dict checkpoints load.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, cheby_order=3):
        super().__init__(input_dim, hidden_dim, latent_dim)
        # Recorded here so encode() does not depend on what the parent class stores.
        self.input_dim = input_dim
        self.encoder_cheby = ChebyshevLayer(input_dim, hidden_dim, order=cheby_order)
        self.fc_mu = ChebyshevLayer(hidden_dim, latent_dim, order=cheby_order)
        self.fc_log_var = ChebyshevLayer(hidden_dim, latent_dim, order=cheby_order)
        self.decoder_cheby1 = ChebyshevLayer(latent_dim, hidden_dim, order=cheby_order)
        self.decoder_final = nn.Sequential(nn.Linear(hidden_dim, input_dim), nn.Sigmoid())

    def encode(self, x):
        """Return ``(mu, log_var)`` for a batch that flattens to (batch, input_dim)."""
        x_flat = x.reshape(-1, self.input_dim)
        # Map inputs (assumed to be pixels in [0, 1]) to [-1, 1], where Chebyshev polynomials are bounded.
        model_input = (x_flat * 2) - 1
        h = torch.tanh(self.encoder_cheby(model_input))
        return self.fc_mu(h), self.fc_log_var(h)

    def decode(self, z):
        # NOTE(review): z is fed to the Chebyshev layer without squashing; for |z| > 1 the
        # higher-order polynomial terms grow quickly.
        h = torch.tanh(self.decoder_cheby1(z))
        return self.decoder_final(h)

In [5]:
class ConvChebyshevVAE(nn.Module):
    """Convolutional VAE with ChebyshevLayer bottlenecks on both sides of the latent space.

    Encoder: conv stack (img_channels -> 128 channels, 28x28 -> 7x7) -> flatten -> tanh ->
    ChebyshevLayer(6272 -> 512) -> linear mu / log_var heads.
    Decoder: ChebyshevLayer(latent_dim -> 512) -> linear (512 -> 6272) -> reshape to
    (128, 7, 7) -> transposed-conv stack -> sigmoid.

    The hard-coded 128 * 7 * 7 flat size assumes 28x28 inputs (two stride-2 convolutions).
    NOTE(review): the encoder ChebyshevLayer stores 512 * 6272 * (cheby_order + 1)
    coefficients (~19M for cheby_order=5), which dominates parameter count and memory.
    """
    def __init__(self, latent_dim=32, cheby_order=5, img_channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.cheby_order = cheby_order

        # Encoder: 5 conv layers (two with stride 2), each followed by BatchNorm and LeakyReLU
        self.encoder_conv = nn.Sequential(
            # Block 1
            nn.Conv2d(img_channels, 32, kernel_size=3, stride=1, padding=1),  # 28x28
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),  # 14x14
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 14x14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),  # 7x7
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 3 (single conv, no further downsampling)
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Flattened: 128 * 7 * 7 = 6272 (valid for 28x28 inputs only)
        flat_dim = 128 * 7 * 7
        self.encoder_cheby = ChebyshevLayer(flat_dim, 512, order=cheby_order)  # 6272 -> 512 hidden
        self.fc_mu = nn.Linear(512, latent_dim)  # Plain linear heads for mu / log_var
        self.fc_log_var = nn.Linear(512, latent_dim)

        # Decoder: mirrors the encoder (Chebyshev layer -> linear expand -> transposed convs).
        # No residual / skip connections from the encoder are implemented.
        self.decoder_cheby = ChebyshevLayer(latent_dim, 512, order=cheby_order)
        self.decoder_fc = nn.Linear(512, flat_dim)  # Expand before reshape
        self.decoder_conv = nn.Sequential(
            # Input: 128x7x7 (reshaped in decode())
            # Block 3 (reverse)
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),  # 7x7 -> 7x7
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # 14x14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # (No residual skip from the encoder; the decoder only sees z.)
            # Block 2 (reverse)
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1),  # 14x14
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 28x28
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            # Block 1 (reverse)
            nn.ConvTranspose2d(32, img_channels, kernel_size=3, stride=1, padding=1),  # 28x28
            nn.Sigmoid()  # Output [0,1]
        )

    def encode(self, x):
        """Return ``(mu, log_var)``, each of shape (batch, latent_dim)."""
        # x: (B, img_channels, 28, 28)
        h_conv = self.encoder_conv(x)
        h_flat = h_conv.view(h_conv.size(0), -1)  # (B, 6272)

        # Squash into (-1, 1), the interval on which Chebyshev polynomials stay bounded
        h_norm = torch.tanh(h_flat)  # Bounded normalization
        h_cheby = self.encoder_cheby(h_norm)

        mu = self.fc_mu(h_cheby)
        log_var = self.fc_log_var(h_cheby)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes (batch, latent_dim) to images (batch, img_channels, 28, 28) in [0, 1]."""
        # NOTE(review): z is passed to the Chebyshev layer as-is (not squashed into [-1, 1]).
        h_cheby = self.decoder_cheby(z)
        h_expanded = self.decoder_fc(h_cheby).view(-1, 128, 7, 7)  # Reshape
        return self.decoder_conv(h_expanded)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

In [6]:
# Loss functions, data loaders, evaluation metrics (FID / Inception Score) and
# visualization helpers used by main() below.
def vae_loss_function(recon_x, x, mu, log_var):
    """Negative ELBO summed over the batch: Bernoulli reconstruction (BCE) + KL(q(z|x) || N(0, I)).

    ``recon_x`` and ``x`` may have different shapes as long as they hold the same number of
    elements (e.g. (B, 784) from the dense VAEs vs (B, 1, 28, 28) images). Because the BCE is
    summed, both are simply flattened.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    BCE = nn.functional.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [7]:
def topological_loss(X_batch, Z_batch):
    """Sum over homology dimensions 0 and 1 of the Wasserstein distance between the
    Vietoris-Rips persistence diagrams of the input batch and of its latent codes.

    Both point clouds are detached and converted to NumPy, so the returned float32 scalar
    tensor (on ``X_batch.device``) carries NO gradient: adding it to a loss changes only the
    reported value, not the parameter updates.
    """
    input_cloud = X_batch.view(X_batch.shape[0], -1).detach().cpu().numpy()[None, :, :]
    latent_cloud = Z_batch.detach().cpu().numpy()[None, :, :]

    persistence = VietorisRipsPersistence(homology_dimensions=[0, 1])
    input_diagram = persistence.fit_transform(input_cloud)[0]
    latent_diagram = persistence.fit_transform(latent_cloud)[0]

    def _pairs(diagram, dim):
        # Rows are (birth, death, homology_dim); keep the (birth, death) pairs of one dimension.
        return diagram[diagram[:, 2] == dim][:, :2]

    total_distance = 0.0
    for dim in (0, 1):
        input_pairs, latent_pairs = _pairs(input_diagram, dim), _pairs(latent_diagram, dim)
        if input_pairs.shape[0] == 0 and latent_pairs.shape[0] == 0:
            continue  # nothing to compare in this dimension
        total_distance += persim.wasserstein(input_pairs, latent_pairs)
    return torch.tensor(total_distance, device=X_batch.device, dtype=torch.float32)

In [8]:
def disentanglement_loss(z_shared1, z_distinct1, z_shared2, z_distinct2):
    """Pull the "shared" latent halves of two views together and push the "distinct" halves apart.

    Returns ``(1 - mean cosine similarity of the shared parts) - mean L2 distance of the distinct parts``.
    NOTE(review): the distance term is unbounded below, so only other loss terms (e.g. the KL)
    keep the distinct halves from drifting apart indefinitely.
    """
    shared_alignment = nn.functional.cosine_similarity(z_shared1, z_shared2, dim=-1).mean()
    distinct_separation = nn.functional.pairwise_distance(z_distinct1, z_distinct2, p=2).mean()
    return (1 - shared_alignment) - distinct_separation

In [9]:
class PairedTransform:
    """Apply the same (typically random) transform twice to one sample, returning both views."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

In [10]:
def get_dataloaders(batch_size, use_paired_transforms=False):
    """Return ``(train_loader, test_loader)`` for FashionMNIST (downloaded to ./data if missing).

    With ``use_paired_transforms`` the training set yields ``((view1, view2), label)``, two
    independent random flip/rotation augmentations of each image; the test set is never augmented.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    if use_paired_transforms:
        augment = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomRotation(10), transforms.ToTensor()])
        train_transform = PairedTransform(augment)
    else:
        train_transform = to_tensor

    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=train_transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=to_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [11]:
def _frechet_distance(features_a, features_b, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two (n_samples, n_features) arrays."""
    mu1, sigma1 = np.mean(features_a, axis=0), np.cov(features_a, rowvar=False)
    mu2, sigma2 = np.mean(features_b, axis=0), np.cov(features_b, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular product: regularise both covariances slightly and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))


def calculate_fid(model, test_loader, device, latent_dim, num_samples=5000):
    """FID between up to ``num_samples`` test images and ``num_samples`` images decoded from z ~ N(0, I).

    Images (assumed single-channel 28x28 in [0, 1]) are upsampled to 299x299, replicated to
    3 channels and rescaled to [-1, 1]. That is the input range torchvision's pretrained
    Inception-v3 expects when ``transform_input=False``. The 2048-d pooled features are compared.

    Returns:
        float FID (lower is better).
    """
    print("Calculating FID score...")
    weights = Inception_V3_Weights.DEFAULT
    inception_model = inception_v3(weights=weights, transform_input=False).to(device)
    inception_model.fc = nn.Identity()  # expose pooled features instead of class logits
    inception_model.eval()
    # Sampling must use running statistics for BatchNorm layers (e.g. in ConvChebyshevVAE).
    model.eval()

    def _features(images):
        # (B, 1, 28, 28) in [0, 1] -> (B, 3, 299, 299) in [-1, 1] -> (B, 2048) NumPy features.
        resized = nn.functional.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
        rgb = resized.repeat(1, 3, 1, 1) * 2.0 - 1.0
        return inception_model(rgb).cpu().numpy()

    with torch.no_grad():
        real_chunks, n_real = [], 0
        for data, _ in test_loader:
            if n_real >= num_samples:
                break
            feats = _features(data.to(device))
            real_chunks.append(feats)
            n_real += feats.shape[0]
        real_features = np.concatenate(real_chunks, axis=0)[:num_samples]

        generated_chunks = []
        bs_gen = 100
        for start in range(0, num_samples, bs_gen):
            n = min(bs_gen, num_samples - start)  # never generate more than num_samples in total
            z = torch.randn(n, latent_dim).to(device)
            samples = model.decode(z).view(-1, 1, 28, 28)
            generated_chunks.append(_features(samples))
        generated_features = np.concatenate(generated_chunks, axis=0)

    return _frechet_distance(real_features, generated_features)

In [12]:
def calculate_inception_score(model, device, latent_dim, num_samples=5000, splits=10):
    """Inception Score of ``num_samples`` images decoded from z ~ N(0, I).

    Decoded images (assumed 1x28x28 in [0, 1]) are upsampled to 299x299, replicated to 3 channels
    and rescaled to [-1, 1] (the range the pretrained network expects with transform_input=False).
    The score is exp(E_x[KL(p(y|x) || p(y))]) per split.
    NOTE(review): the classifier is ImageNet-trained, so IS on FashionMNIST/MNIST is only a rough
    relative indicator.

    Returns:
        (mean, std) of the per-split scores.
    Raises:
        ValueError: if ``num_samples // splits`` is zero.
    """
    from scipy.special import rel_entr

    split_size = num_samples // splits
    if split_size == 0:
        raise ValueError(f"num_samples ({num_samples}) must be at least splits ({splits})")

    print("Calculating Inception Score (IS)...")
    weights = Inception_V3_Weights.DEFAULT
    inception_model = inception_v3(weights=weights, transform_input=False).to(device)
    inception_model.eval()
    # Sampling must use running statistics for BatchNorm layers (e.g. in ConvChebyshevVAE).
    model.eval()

    all_preds = []
    bs = 100
    with torch.no_grad():
        for start in range(0, num_samples, bs):
            n = min(bs, num_samples - start)
            z = torch.randn(n, latent_dim).to(device)
            samples = model.decode(z).view(-1, 1, 28, 28)
            samples_resized = nn.functional.interpolate(samples, size=(299, 299), mode='bilinear', align_corners=False)
            samples_rgb = samples_resized.repeat(1, 3, 1, 1) * 2.0 - 1.0
            logits = inception_model(samples_rgb)
            all_preds.append(nn.functional.softmax(logits, dim=1).cpu().numpy())
    all_preds = np.concatenate(all_preds, axis=0)

    scores = []
    for i in range(splits):
        part = all_preds[i * split_size:(i + 1) * split_size, :]
        marginal = np.mean(part, axis=0, keepdims=True)
        # rel_entr treats 0 * log(0 / q) as 0, avoiding the NaNs np.log gives for underflowed probabilities.
        mean_kl = np.mean(np.sum(rel_entr(part, marginal), axis=1))
        scores.append(np.exp(mean_kl))
    return np.mean(scores), np.std(scores)

In [13]:
def save_reconstructions(model, test_loader, device, save_path):
    """Save one image grid: the first 8 test images (top row) above their reconstructions (bottom row)."""
    from torchvision.utils import save_image

    model.eval()
    with torch.no_grad():
        originals, _ = next(iter(test_loader))
        originals = originals.to(device)
        reconstructions, _, _ = model(originals)
        top_row = originals.view(-1, 1, 28, 28)[:8]
        bottom_row = reconstructions.view(-1, 1, 28, 28)[:8]
        save_image(torch.cat([top_row, bottom_row]).cpu(), save_path, nrow=8)

In [14]:
def save_generated_samples(model, device, save_path, num_samples=64):
    """Decode ``num_samples`` draws from the N(0, I) prior and save them as a single image grid."""
    from torchvision.utils import save_image

    model.eval()
    with torch.no_grad():
        prior_draws = torch.randn(num_samples, model.latent_dim).to(device)
        images = model.decode(prior_draws).cpu()
        save_image(images.view(num_samples, 1, 28, 28), save_path)

In [15]:
def save_latent_space(model, test_loader, device, save_path):
    """Save a 2-D UMAP scatter of the posterior means of the whole test set, coloured by label."""
    model.eval()
    latent_batches, label_batches = [], []
    with torch.no_grad():
        for images, targets in test_loader:
            posterior_mean, _ = model.encode(images.to(device))
            latent_batches.append(posterior_mean.cpu().numpy())
            label_batches.append(targets.numpy())
    latents = np.concatenate(latent_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)

    projection = UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42).fit_transform(latents)

    fig = plt.figure(figsize=(12, 10))
    plt.scatter(projection[:, 0], projection[:, 1], c=labels, cmap='Spectral', s=5)
    colorbar = plt.colorbar(boundaries=np.arange(11) - 0.5)  # one discrete bin per class 0..9
    colorbar.set_ticks(np.arange(10))
    plt.title('UMAP Projection of the Latent Space')
    plt.savefig(save_path)
    plt.close(fig)



In [16]:
# ==========================================================================================
# == MAIN TRAINING & EVALUATION SCRIPT
# ==========================================================================================
def _build_vae(args, device):
    """Instantiate the VAE selected by ``args.model_type`` and move it to ``device``.

    'A'/'B' -> BaselineVAE, 'C'/'D' -> ChebyshevVAE, 'E' -> ConvChebyshevVAE.
    Raises ValueError for any other model_type (previously this surfaced as a NameError).
    """
    if args.model_type in ('A', 'B'):
        model = BaselineVAE(latent_dim=args.latent_dim)
    elif args.model_type in ('C', 'D'):
        model = ChebyshevVAE(latent_dim=args.latent_dim, cheby_order=args.cheby_order)
    elif args.model_type == 'E':
        model = ConvChebyshevVAE(latent_dim=args.latent_dim, cheby_order=args.cheby_order)
    else:
        raise ValueError(f"Unknown model_type {args.model_type!r}; expected one of 'A'-'E'.")
    return model.to(device)


def _paired_batch_loss(model, batch, device, gamma, delta, batch_idx):
    """Loss for one batch ``((view1, view2), labels)`` of paired augmentations.

    VAE loss on view1, plus ``gamma`` * topological loss (every 20th batch) and ``delta`` *
    disentanglement loss between the posterior means of the two views.
    """
    (view1, view2), _ = batch
    view1, view2 = view1.to(device), view2.to(device)
    recon, mu, log_var = model(view1)
    loss = vae_loss_function(recon, view1, mu, log_var)

    if gamma > 0 and batch_idx % 20 == 0:
        # topological_loss works on detached NumPy copies: this term changes the reported loss
        # value but contributes no gradient.
        z = model.reparameterize(mu, log_var)
        loss = loss + gamma * topological_loss(view1, z)

    if delta > 0:
        # model(view1) already computed view1's posterior mean via model.encode; reuse it instead
        # of running the encoder a second time on the same batch.
        mu2, _ = model.encode(view2)
        z_shared1, z_distinct1 = torch.chunk(mu, 2, dim=1)
        z_shared2, z_distinct2 = torch.chunk(mu2, 2, dim=1)
        loss = loss + delta * disentanglement_loss(z_shared1, z_distinct1, z_shared2, z_distinct2)
    return loss


def _evaluate_vae(model, test_loader, device):
    """Average per-example VAE loss over ``test_loader`` (switches the model to eval mode)."""
    model.eval()
    total_test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = model(data)
            total_test_loss += vae_loss_function(recon, data, mu, log_var).item()
    return total_test_loss / len(test_loader.dataset)


def main(args):
    """Train one VAE variant on FashionMNIST, then save plots, samples and FID / IS metrics.

    ``args`` must provide: model_type ('A'-'E'), batch_size, epochs, lr, latent_dim,
    cheby_order, gamma, delta, grad_clip, warmup_epochs and run_name. Models B, D and E train
    on paired augmented views and, after ``warmup_epochs``, add ``gamma`` * topological and
    ``delta`` * disentanglement terms. All outputs go to ``opti_results_new/<run_name>/``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    results_dir = os.path.join("opti_results_new", args.run_name)
    os.makedirs(results_dir, exist_ok=True)

    use_paired = args.model_type in ['B', 'D', 'E']
    train_loader, test_loader = get_dataloaders(args.batch_size, use_paired_transforms=use_paired)

    model = _build_vae(args, device)
    print(f"Initialized Model {args.model_type}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    train_losses, test_losses = [], []
    avg_test_loss = float('nan')  # stays defined even when args.epochs == 0

    for epoch in range(1, args.epochs + 1):
        # Auxiliary losses are disabled during the first `warmup_epochs` epochs.
        past_warmup = epoch > args.warmup_epochs
        current_gamma = args.gamma if past_warmup else 0.0
        current_delta = args.delta if past_warmup else 0.0

        model.train()
        total_train_loss = 0
        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            if use_paired:
                loss = _paired_batch_loss(model, batch, device, current_gamma, current_delta, batch_idx)
            else:
                images, _ = batch
                images = images.to(device)
                recon, mu, log_var = model(images)
                loss = vae_loss_function(recon, images, mu, log_var)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_clip)
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        avg_test_loss = _evaluate_vae(model, test_loader, device)
        test_losses.append(avg_test_loss)
        print(f'====> Epoch: {epoch} Average train loss: {avg_train_loss:.4f} | Average test loss: {avg_test_loss:.4f}')

    print("Training finished.")
    torch.save(model.state_dict(), os.path.join(results_dir, "model.pth"))

    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.title('Model Loss')
    plt.legend()
    plt.savefig(os.path.join(results_dir, 'loss_curve.png'))
    plt.close()

    save_reconstructions(model, test_loader, device, os.path.join(results_dir, 'reconstructions.png'))
    save_generated_samples(model, device, os.path.join(results_dir, 'generated_samples.png'))
    save_latent_space(model, test_loader, device, os.path.join(results_dir, 'latent_space.png'))

    print(f"Results saved to {results_dir}")

    fid_score = calculate_fid(model, test_loader, device, args.latent_dim)
    is_mean, is_std = calculate_inception_score(model, device, args.latent_dim)
    print(f"====> Final FID Score: {fid_score:.4f}")
    print(f"====> Final Inception Score: {is_mean:.4f} ± {is_std:.4f}")
    with open(os.path.join(results_dir, "final_metrics.txt"), "w") as f:
        f.write(f"Final Test Loss: {avg_test_loss}\n")
        f.write(f"Final FID Score: {fid_score}\n")
        f.write(f"Final Inception Score: {is_mean} ± {is_std}\n")

In [None]:

if __name__ == '__main__':
    class Args:
        """Hyper-parameters shared by the five ablation runs; model_type / run_name are set per run."""
        model_type = 'A'
        batch_size = 64
        epochs = 20
        lr = 1e-4
        latent_dim = 20
        cheby_order = 3
        gamma = 0.05       # weight of the topological loss (paired models, after warm-up)
        delta = 0.5        # weight of the disentanglement loss (paired models, after warm-up)
        grad_clip = 1.0    # max global gradient norm
        warmup_epochs = 5  # epochs trained with gamma = delta = 0
        run_name = 'ModelA_Conv_Optimized'

    # A/B: BaselineVAE, C/D: ChebyshevVAE, E: ConvChebyshevVAE; B, D and E use paired views and
    # the auxiliary losses. Run names keep their "_Conv_Optimized" suffix so result folders are
    # unchanged, even though only model E is convolutional.
    # After the loop `args` holds the last (model E) configuration, which later cells read.
    for model_type in ('A', 'B', 'C', 'D', 'E'):
        args = Args()
        args.model_type = model_type
        args.run_name = f'Model{model_type}_Conv_Optimized'
        main(args)


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# ============================
# Dataset
# ============================
# NOTE: this comparison uses MNIST, unlike the FashionMNIST runs above. ToTensor() scales pixels to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

train_dataset, test_dataset = (
    datasets.MNIST(root="./data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Only the training loader is shuffled.
train_loader, test_loader = (
    DataLoader(split, batch_size=128, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)

# ============================
# Loss function
# ============================
def vae_loss(recon_x, x, mu, log_var, beta=1.0):
    """Beta-weighted negative ELBO.

    Returns ``(total, recon_loss, kl)`` as scalar tensors with ``total = recon_loss + beta * kl``:
    ``recon_loss`` is the summed BCE, and ``kl`` is KL(q(z|x) || N(0, I)) summed over batch and latent dims.
    A flat ``recon_x`` of shape (B, D) is compared with ``x`` flattened to (B, -1); otherwise
    ``recon_x`` and ``x`` must already have the same shape.
    """
    target = x.view(x.size(0), -1) if recon_x.dim() == 2 else x
    recon_loss = F.binary_cross_entropy(recon_x, target, reduction="sum")
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + beta * kl, recon_loss, kl


# ============================
# Training function
# ============================
def train_vae(model, optimizer, scheduler, epochs=20, kl_anneal=True, loader=None):
    """Train ``model`` with the beta-weighted VAE loss and return per-epoch averages.

    Args:
        model: VAE whose ``forward`` returns ``(recon, mu, log_var)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per EPOCH, so ``CosineAnnealingLR(T_max=epochs)``
            anneals once over the whole run; may be None.
        epochs: Number of passes over the training data.
        kl_anneal: If True, beta = 1 - 0.5 * |epoch - 10| / 10 clipped to [0.1, 1]
            (0.55 at epoch 1, 1.0 at epoch 10, 0.5 at epoch 20); otherwise beta = 1.
        loader: Training DataLoader; defaults to the notebook-level ``train_loader``.

    Returns:
        Dict with "loss", "recon" and "kl" lists: per-epoch sums divided by the number of
        training examples.
    """
    if loader is None:
        loader = train_loader
    n_train = len(loader.dataset)
    history = {"loss": [], "recon": [], "kl": []}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(1, epochs + 1):
        # beta depends only on the epoch, so compute it once per epoch.
        if kl_anneal:
            beta = 1.0 - 0.5 * abs((epoch - 10) / 10)  # peaks at epoch 10
            beta = max(0.1, min(1.0, beta))
        else:
            beta = 1.0

        model.train()
        train_loss = 0
        recon_total, kl_total = 0, 0
        for x, _ in loader:
            x = x.to(device)
            optimizer.zero_grad()
            recon, mu, log_var = model(x)
            loss, recon_loss, kl = vae_loss(recon, x, mu, log_var, beta=beta)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            recon_total += recon_loss.item()
            kl_total += kl.item()

        # Record the LR used during this epoch, then advance the schedule once per epoch.
        # Stepping per batch against T_max=epochs made the LR oscillate every 2*T_max batches.
        epoch_lr = optimizer.param_groups[0]['lr']
        if scheduler is not None:
            scheduler.step()

        history["loss"].append(train_loss / n_train)
        history["recon"].append(recon_total / n_train)
        history["kl"].append(kl_total / n_train)

        print(f"Epoch {epoch}: Loss={history['loss'][-1]:.4f}, "
              f"Recon={history['recon'][-1]:.4f}, KL={history['kl'][-1]:.4f}, Beta={beta:.2f}, LR={epoch_lr:.2e}")

    return history


# ============================
# Train and compare models
# ============================

# BaselineVAE vs ConvChebyshevVAE on MNIST, using the same optimiser and LR schedule for both.
COMPARISON_EPOCHS = 20


def _adamw_with_cosine(model, epochs):
    """AdamW (lr 5e-4, weight decay 1e-4) plus CosineAnnealingLR(T_max=epochs, eta_min=1e-6)."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
    return optimizer, CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)


# Fully connected baseline on flattened 784-d inputs.
baseline = BaselineVAE(latent_dim=32, input_dim=784)
optimizer_base, scheduler_base = _adamw_with_cosine(baseline, COMPARISON_EPOCHS)
print("\nTraining BaselineVAE...")
history_base = train_vae(baseline, optimizer_base, scheduler_base, epochs=COMPARISON_EPOCHS)

# Convolutional encoder/decoder with Chebyshev layers around the latent space.
opt_conv_cheby = ConvChebyshevVAE(latent_dim=32, cheby_order=5)
optimizer_cheby, scheduler_cheby = _adamw_with_cosine(opt_conv_cheby, COMPARISON_EPOCHS)
print("\nTraining Optimized ConvChebyshevVAE...")
history_cheby = train_vae(opt_conv_cheby, optimizer_cheby, scheduler_cheby, epochs=COMPARISON_EPOCHS)



In [None]:
#Visualization
# Output path for the training-curve figure below.
# NOTE(review): reads the global `args` left behind by the last `__main__` run above
# (run_name 'ModelE_Conv_Optimized'), so this cell fails on a fresh kernel if that cell was
# skipped. The curves plotted here come from the MNIST baseline-vs-ConvCheby comparison, not from that run.
save_path= os.path.join("opti_results_new_1", args.run_name)
def plot_training_curves(history_base, history_cheby, out_path=None):
    """Plot total loss, reconstruction loss and KL per epoch for both models, side by side.

    Args:
        history_base, history_cheby: Dicts with "loss", "recon" and "kl" per-epoch lists.
        out_path: File to save the figure to; defaults to the notebook-level ``save_path``.
            Matplotlib appends ".png" when the path has no extension. Missing parent
            directories are created.
    """
    path = save_path if out_path is None else out_path

    panels = [
        ("loss", "Total Loss (ELBO)", "Loss"),
        ("recon", "Reconstruction Loss", "BCE Loss"),
        ("kl", "KL Divergence", "KL"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    for ax, (key, title, ylabel) in zip(axes, panels):
        for history, label in ((history_base, "BaselineVAE"), (history_cheby, "ConvChebyshevVAE")):
            ax.plot(range(1, len(history[key]) + 1), history[key], label=label)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()

    # Save BEFORE showing: the inline backend closes figures after plt.show(), so saving
    # afterwards wrote an empty canvas.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(path, bbox_inches='tight')
    plt.show()
    plt.close(fig)


# Call it after training
plot_training_curves(history_base, history_cheby)


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# ============================
# This cell relies on the following definitions from earlier cells (run those cells first):
# class PairedTransform: ...
# class ChebyshevLayer(nn.Module): ...
# class BaselineVAE(nn.Module): ...
# class ChebyshevVAE(BaselineVAE): ...
# class ConvChebyshevVAE(nn.Module): ...
# def vae_loss_function(...): ...
# def topological_loss(...): ...
# def disentanglement_loss(...): ...
# ============================

# --- Device Configuration ---
# Use the GPU when one is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# --- Data Loaders ---
def get_dataloaders(batch_size, use_paired_transforms=False):
    """Return ``(train_loader, test_loader, test_dataset)`` for FashionMNIST (downloaded to ./data).

    With ``use_paired_transforms`` each training item is ``((view1, view2), label)``: two
    independent random flip/rotation augmentations of the same image.

    NOTE(review): this redefines the earlier ``get_dataloaders``, which returns a 2-tuple.
    Re-running ``main()`` after this cell fails when it unpacks the result.
    """
    if use_paired_transforms:
        augment = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
        ])
        train_transform = PairedTransform(augment)
    else:
        train_transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=train_transform)

    # The test split is never augmented.
    test_dataset = datasets.FashionMNIST(
        root='./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()])
    )

    return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
        test_dataset,
    )

# --- Modified VAE Loss from your `main` function logic ---
def vae_loss(recon_x, x, mu, log_var):
    """Negative ELBO summed over the batch, plus its two terms as Python floats for logging.

    ``recon_x`` and ``x`` may have different shapes as long as they hold the same number of
    elements (e.g. (B, 784) from the dense VAEs vs (B, 1, 28, 28) images). Both are flattened,
    which is exact because the BCE is summed.

    Returns:
        ``(loss, bce, kld)``: ``loss`` is a scalar tensor for backprop; ``bce`` and ``kld`` are floats.

    NOTE(review): this shadows the earlier ``vae_loss(..., beta)``, which returns tensors.
    Re-running ``train_vae`` after this cell fails because of the different signature and return types.
    """
    BCE = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD, BCE.item(), KLD.item()

# --- Unified Training Function ---
def train_model(model, model_name, use_paired, epochs=20, batch_size=64, lr=1e-4):
    """Train ``model`` on FashionMNIST with the plain VAE loss; return ``(history, test_dataset)``.

    When ``use_paired`` is True the loader yields two augmented views per image. Only the first
    view is used, and the topological / disentanglement terms from ``main()`` are NOT applied.
    ``history`` maps "loss" / "recon" / "kl" to per-epoch averages per training example.
    """
    print(f"\n=== Training Model {model_name} ===")

    train_loader, _, test_dataset = get_dataloaders(batch_size=batch_size, use_paired_transforms=use_paired)

    # Conventional order: move the model first, then build the optimizer over its parameters.
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    n_train = len(train_loader.dataset)

    history = {"loss": [], "recon": [], "kl": []}

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, recon_total, kl_total = 0, 0, 0

        for batch in train_loader:
            optimizer.zero_grad()

            if use_paired:
                (x, _), _ = batch  # second augmented view is unused here
            else:
                x, _ = batch
            x = x.to(DEVICE)

            recon_x, mu, log_var = model(x)
            loss, recon_l, kl_l = vae_loss(recon_x, x, mu, log_var)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            recon_total += recon_l
            kl_total += kl_l

        avg_loss = train_loss / n_train
        avg_recon = recon_total / n_train
        avg_kl = kl_total / n_train

        history["loss"].append(avg_loss)
        history["recon"].append(avg_recon)
        history["kl"].append(avg_kl)

        if epoch % 5 == 0 or epoch == epochs:
            print(f"Epoch {epoch}: Avg Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}")

    return history, test_dataset


# --- Visualization Function ---
def show_reconstructions(model, model_name, dataset, n=8):
    """Show ``n`` random images from ``dataset`` (top row) above the model's reconstructions (bottom row).

    Reconstructions go through ``model(x)``, which samples z from the posterior, so they vary
    between calls.
    """
    model.eval()
    loader = DataLoader(dataset, batch_size=n, shuffle=True)
    x, _ = next(iter(loader))
    x = x.to(DEVICE)
    with torch.no_grad():
        recon, _, _ = model(x)

    print(f"\n{model_name} Reconstructions:")
    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(n * 2, 4), squeeze=False)
    for i in range(n):
        axes[0, i].imshow(x[i].cpu().squeeze(), cmap="gray")
        axes[0, i].set_title("Original")
        axes[0, i].axis("off")

        # Dense VAEs return flat (784,) rows, so reshape to 28x28 for display.
        axes[1, i].imshow(recon[i].cpu().view(28, 28), cmap="gray")
        axes[1, i].set_title("Recon")
        axes[1, i].axis("off")

    plt.show()

# ============================
# Train and Evaluate All Models
# ============================

# --- Model A: BaselineVAE (Standard Loss) ---
model_a = BaselineVAE(latent_dim=20, input_dim=784)
history_a, test_dataset_a = train_model(model_a, "A: BaselineVAE", use_paired=False)
show_reconstructions(model_a, "Model A", test_dataset_a)

# --- Model B: BaselineVAE (Advanced Loss Training) ---
model_b = BaselineVAE(latent_dim=20, input_dim=784)
history_b, test_dataset_b = train_model(model_b, "B: BaselineVAE + Advanced Loss", use_paired=True)
show_reconstructions(model_b, "Model B", test_dataset_b)

# --- Model C: ChebyshevVAE (Standard Loss) ---
model_c = ChebyshevVAE(latent_dim=20, input_dim=784, cheby_order=3)
history_c, test_dataset_c = train_model(model_c, "C: ChebyshevVAE", use_paired=False)
show_reconstructions(model_c, "Model C", test_dataset_c)

# --- Model D: ChebyshevVAE (Advanced Loss Training) ---
model_d = ChebyshevVAE(latent_dim=20, input_dim=784, cheby_order=3)
history_d, test_dataset_d = train_model(model_d, "D: ChebyshevVAE + Advanced Loss", use_paired=True)
show_reconstructions(model_d, "Model D", test_dataset_d)

# --- Model E: ConvChebyshevVAE (Advanced Loss Training) ---
# Note: ConvVAE needs a different vae_loss logic for shape, handled by a more robust function
model_e = ConvChebyshevVAE(latent_dim=20, cheby_order=3)
history_e, test_dataset_e = train_model(model_e, "E: ConvChebyshevVAE + Advanced Loss", use_paired=True)
show_reconstructions(model_e, "Model E", test_dataset_e)