In [130]:
import torch
print(f"Is CUDA available? {torch.cuda.is_available()}")
# If available, this will show the version PyTorch was built with
print(f"PyTorch CUDA version: {torch.version.cuda}")

Is CUDA available? True
PyTorch CUDA version: 12.1


In [131]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.utils.deprecation")

import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from umap import UMAP
from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import PairwiseDistance


# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [132]:
class ChebyshevLayer(nn.Module):
    """
    An implementation of the Adaptive Neuron using Chebyshev Polynomials.
    The neuron's weight is a function of the input, modeled by a polynomial expansion.
    The output is y = sum(x_i * w_i(x_i)), where w_i(x_i) is the adaptive weight.
    """
    def __init__(self, in_features, out_features, order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.order = order
        
        # Learnable coefficients for the Chebyshev polynomials
        # Shape: (out_features, in_features, order + 1)
        self.coeffs = nn.Parameter(torch.empty(out_features, in_features, order + 1))
        nn.init.xavier_uniform_(self.coeffs)

    # In your project.py file, inside the ChebyshevLayer class

    def forward(self, x):
        # IMPORTANT: Input x must be scaled to the [-1, 1] range.
        
        # --- FIX STARTS HERE ---
        # Avoid in-place modification by building a list of polynomial tensors
        cheby_poly_list = []
        
        # T_0(x) = 1
        t0 = torch.ones_like(x)
        cheby_poly_list.append(t0)
        
        if self.order > 0:
            # T_1(x) = x
            t1 = x
            cheby_poly_list.append(t1)
        
        # Recursively compute higher-order polynomials
        for k in range(2, self.order + 1):
            tk = 2 * x * cheby_poly_list[k - 1] - cheby_poly_list[k - 2]
            cheby_poly_list.append(tk)

        # Stack the list of tensors into a single tensor along a new dimension
        # Shape becomes: (batch_size, in_features, order + 1)
        cheby_polys = torch.stack(cheby_poly_list, dim=2)
        # --- FIX ENDS HERE ---

        # The rest of the function remains the same
        adaptive_weights = torch.einsum('oik,bik->boi', self.coeffs, cheby_polys)
        output = torch.einsum('bi,boi->bo', x, adaptive_weights)
        
        return output

In [133]:
class BaselineVAE(nn.Module):
    """A standard VAE with a simple MLP encoder and decoder."""
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid() # To output pixel values between 0 and 1
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

In [134]:
class ChebyshevVAE(BaselineVAE):
    """
    A VAE using ChebyshevLayers.
    CORRECTED VERSION: Uses tanh activation to ensure inputs to subsequent
    ChebyshevLayers are in the stable [-1, 1] range.
    """
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, cheby_order=3):
        super().__init__(input_dim, hidden_dim, latent_dim)
        # Override encoder and decoder with ChebyshevLayers
        self.encoder_cheby = ChebyshevLayer(input_dim, hidden_dim, order=cheby_order)
        self.fc_mu = ChebyshevLayer(hidden_dim, latent_dim, order=cheby_order)
        self.fc_log_var = ChebyshevLayer(hidden_dim, latent_dim, order=cheby_order)
        
        self.decoder_cheby1 = ChebyshevLayer(latent_dim, hidden_dim, order=cheby_order)
        # The final layer remains linear with sigmoid for reconstruction.
        self.decoder_final = nn.Sequential(nn.Linear(hidden_dim, input_dim), nn.Sigmoid())

    def encode(self, x):
        # Apply tanh to constrain the output of the first layer to [-1, 1]
        # before passing it to the next Chebyshev layers.
        h = torch.tanh(self.encoder_cheby(x))
        return self.fc_mu(h), self.fc_log_var(h)

    def decode(self, z):
        # Also apply tanh here for consistency
        h = torch.tanh(self.decoder_cheby1(z))
        return self.decoder_final(h)

In [135]:
def vae_loss_function(recon_x, x, mu, log_var):
    """Standard VAE loss with shape and device checks."""
    # Ensure same device
    assert recon_x.device == x.device, f"Device mismatch: recon_x on {recon_x.device}, x on {x.device}"
    
    # Reverse normalization for x if needed (for paired transforms)
    x = x * 0.5 + 0.5  # Transform from [-1, 1] to [0, 1]
    x = x.view(-1, 784)
    recon_x = recon_x.view(-1, 784)
    
    # Check shapes
    assert recon_x.shape == x.shape, f"Shape mismatch: recon_x {recon_x.shape}, x {x.shape}"
    
    # Check value ranges
    assert (recon_x >= 0).all() and (recon_x <= 1).all(), "recon_x values out of [0, 1] range"
    assert (x >= 0).all() and (x <= 1).all(), "x values out of [0, 1] range"
    
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [136]:
import persim
from gtda.homology import VietorisRipsPersistence
import numpy as np
import torch

def topological_loss(X_batch, Z_batch):
    """
    Calculates the topological loss using persistence diagrams.
    NEW APPROACH: Uses the `persim` library for a more robust distance calculation
    that correctly handles diagrams with different numbers of points.
    """
    # 1. Prepare data and reshape for giotto-tda
    X_np = X_batch.view(X_batch.shape[0], -1).detach().cpu().numpy()
    Z_np = Z_batch.detach().cpu().numpy()
    
    X_point_cloud = X_np[None, :, :]
    Z_point_cloud = Z_np[None, :, :]
    
    # 2. Compute persistence diagrams using giotto-tda (this part works fine)
    vrp = VietorisRipsPersistence(homology_dimensions=[0, 1])
    X_diag = vrp.fit_transform(X_point_cloud)[0]
    Z_diag = vrp.fit_transform(Z_point_cloud)[0]
    
    # 3. Compute distance using persim, separately for each homology dimension
    total_distance = 0.0
    homology_dims = [0, 1] # H0 (components) and H1 (holes)

    for dim in homology_dims:
        # Filter the diagrams to get points for the current dimension
        X_diag_dim = X_diag[X_diag[:, 2] == dim]
        Z_diag_dim = Z_diag[Z_diag[:, 2] == dim]
        
        # persim expects (n_points, 2) arrays (birth, death), so we slice
        X_birth_death = X_diag_dim[:, :2]
        Z_birth_death = Z_diag_dim[:, :2]
        
        # Handle the case where one diagram has no features for a dimension
        if X_birth_death.shape[0] == 0 and Z_birth_death.shape[0] == 0:
            distance_dim = 0.0
        else:
            distance_dim = persim.wasserstein(X_birth_death, Z_birth_death)
        
        total_distance += distance_dim
        
    return torch.tensor(total_distance, device=X_batch.device, dtype=torch.float32)

In [137]:
def disentanglement_loss(z_shared1, z_distinct1, z_shared2, z_distinct2):
    """
    Self-supervised loss for disentanglement.
    Your Idea 3: Self-Supervised Learning for Disentangled Dimensionality Reduction.
    """
    # 1. Force shared representations to be similar (cosine similarity)
    loss_shared = 1 - nn.functional.cosine_similarity(z_shared1, z_shared2, dim=-1).mean()
    
    # 2. Force distinct representations to be dissimilar (simple contrastive loss)
    # This is a simplified version. A more robust implementation would use InfoNCE loss.
    pdist = nn.PairwiseDistance(p=2)
    loss_distinct = -pdist(z_distinct1, z_distinct2).mean() # Push them apart
    
    return loss_shared + loss_distinct

In [138]:
class PairedTransform:
    """A transform that returns two different augmented views of the same image."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return self.transform(x), self.transform(x)

def get_dataloaders(batch_size, use_paired_transforms=False):
    # Only use ToTensor() to scale data to the [0, 1] range.
    base_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    if use_paired_transforms:
        # Augmentations for disentanglement loss
        paired_transform = PairedTransform(
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(10),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])
        )
        train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=paired_transform)
    else:
        train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=base_transform)

    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=base_transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader

In [139]:
# ==========================================================================================
# == 4. VISUALIZATION UTILITIES
# ==========================================================================================
def save_reconstructions(model, test_loader, device, save_path):
    model.eval()
    with torch.no_grad():
        data, _ = next(iter(test_loader))
        data = data.to(device)
        recon, _, _ = model(data)
        
        comparison = torch.cat([data.view(-1, 1, 28, 28)[:8], 
                                recon.view(-1, 1, 28, 28)[:8]])
        
        # De-normalize from [-1, 1] to [0, 1]
        comparison = comparison * 0.5 + 0.5
        
        from torchvision.utils import save_image
        save_image(comparison.cpu(), save_path, nrow=8)

In [140]:
def save_generated_samples(model, device, save_path, num_samples=64):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, model.latent_dim).to(device)
        samples = model.decode(z).cpu()

        # De-normalize
        samples = samples * 0.5 + 0.5
        
        from torchvision.utils import save_image
        save_image(samples.view(num_samples, 1, 28, 28), save_path)

In [141]:
def save_latent_space(model, test_loader, device, save_path):
    model.eval()
    all_z = []
    all_labels = []
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.to(device)
            mu, _ = model.encode(data.view(-1, 784))
            all_z.append(mu.cpu().numpy())
            all_labels.append(labels.numpy())

    all_z = np.concatenate(all_z, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Use UMAP for dimensionality reduction
    reducer = UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    embedding = reducer.fit_transform(all_z)

    plt.figure(figsize=(12, 10))
    plt.scatter(embedding[:, 0], embedding[:, 1], c=all_labels, cmap='Spectral', s=5)
    plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))
    plt.title('UMAP Projection of the Latent Space')
    plt.savefig(save_path)
    plt.close()

In [142]:
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create results directory
    results_dir = os.path.join("results", args.run_name)
    os.makedirs(results_dir, exist_ok=True)

    # --- Data Loading ---
    use_paired = args.model_type in ['B', 'D']
    train_loader, test_loader = get_dataloaders(args.batch_size, use_paired_transforms=use_paired)

    # --- Model Selection ---
    if args.model_type in ['A', 'B']:
        model = BaselineVAE(latent_dim=args.latent_dim).to(device)
        print("Initialized Baseline VAE (Model A/B)")
    elif args.model_type in ['C', 'D']:
        model = ChebyshevVAE(latent_dim=args.latent_dim, cheby_order=args.cheby_order).to(device)
        print(f"Initialized Chebyshev-VAE (Model C/D) with order {args.cheby_order}")
    else:
        raise ValueError("Invalid model type specified.")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    
    # --- Training Loop ---
    train_losses, test_losses = [], []
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_train_loss = 0
        for batch_idx, data in enumerate(train_loader):
            optimizer.zero_grad()
            
            if use_paired:
                # Paired data for disentanglement
                (data1, data2), _ = data
                data1, data2 = data1.to(device), data2.to(device)
                
                # Verify input shapes
                assert data1.shape == data2.shape, f"Shape mismatch: data1 {data1.shape}, data2 {data2.shape}"
                
                # Forward pass on first view
                recon, mu, log_var = model(data1)
                
                # --- Calculate Hybrid Loss (Model B or D) ---
                loss_v = vae_loss_function(recon, data1, mu, log_var)
                
                # Calculate topological loss only periodically to save time
                if batch_idx % 20 == 0: # Calculate every 20 batches
                    z = model.reparameterize(mu, log_var)
                    loss_t = topological_loss(data1, z)
                else:
                    # On other batches, set it to zero so it doesn't affect the gradient
                    loss_t = torch.tensor(0.0, device=data1.device) 
                
                # Disentanglement Loss (this one is fast, can run every time)
                mu1, _ = model.encode(data1.view(-1, 784))
                mu2, _ = model.encode(data2.view(-1, 784))
                z_shared1, z_distinct1 = torch.chunk(mu1, 2, dim=1)
                z_shared2, z_distinct2 = torch.chunk(mu2, 2, dim=1)
                loss_d = disentanglement_loss(z_shared1, z_distinct1, z_shared2, z_distinct2)
                
                loss = loss_v + args.gamma * loss_t + args.delta * loss_d
            else:
                # Standard training (Model A or C)
                data, _ = data
                data = data.to(device)
                recon, mu, log_var = model(data)
                loss = vae_loss_function(recon, data, mu, log_var)

            loss.backward()
            total_train_loss += loss.item()
            optimizer.step()
        
        avg_train_loss = total_train_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        # Testing
        model.eval()
        total_test_loss = 0
        with torch.no_grad():
            for data, _ in test_loader:
                data = data.to(device)
                recon, mu, log_var = model(data)
                total_test_loss += vae_loss_function(recon, data, mu, log_var).item()
        
        avg_test_loss = total_test_loss / len(test_loader.dataset)
        test_losses.append(avg_test_loss)

        print(f'====> Epoch: {epoch} Average train loss: {avg_train_loss:.4f} | Average test loss: {avg_test_loss:.4f}')

    print("Training finished.")

    # --- Save Artifacts ---
    print("Saving model and generating visualizations...")
    # Save model checkpoint
    torch.save(model.state_dict(), os.path.join(results_dir, "model.pth"))

    # Save loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.title('Model Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(results_dir, 'loss_curve.png'))
    plt.close()

    # Save visualizations
    save_reconstructions(model, test_loader, device, os.path.join(results_dir, 'reconstructions.png'))
    save_generated_samples(model, device, os.path.join(results_dir, 'generated_samples.png'))
    save_latent_space(model, test_loader, device, os.path.join(results_dir, 'latent_space.png'))

    print(f"Results saved to {results_dir}")

In [144]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

if __name__ == '__main__':
    class Args:
        model_type = 'D'
        batch_size = 32  # Reduced for performance
        epochs = 20
        lr = 1e-3
        latent_dim = 20
        cheby_order = 3
        gamma = 0.1
        delta = 1.0
        run_name = 'ModelD_Final_Run'
    
    args = Args()
    main(args)

Using device: cuda
Initialized Chebyshev-VAE (Model C/D) with order 3
====> Epoch: 1 Average train loss: 320.1696 | Average test loss: 835.1429
====> Epoch: 2 Average train loss: 373.1273 | Average test loss: 863.2622
====> Epoch: 3 Average train loss: 390.0425 | Average test loss: 807.6828
====> Epoch: 4 Average train loss: 399.4748 | Average test loss: 1142.4739
====> Epoch: 5 Average train loss: 448.0072 | Average test loss: 826.0711
====> Epoch: 6 Average train loss: 469.6598 | Average test loss: 851.1818
====> Epoch: 7 Average train loss: 441.8610 | Average test loss: 1119.5208
====> Epoch: 8 Average train loss: 413.6027 | Average test loss: 876.8319
====> Epoch: 9 Average train loss: 474.1185 | Average test loss: 1117.7382
====> Epoch: 10 Average train loss: 423.2540 | Average test loss: 978.8857
====> Epoch: 11 Average train loss: 473.8053 | Average test loss: 1048.2882
====> Epoch: 12 Average train loss: 414.3477 | Average test loss: 1034.1915
====> Epoch: 13 Average train loss

  warn(


Results saved to results\ModelD_Final_Run
