# imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import argparse
import sys
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score


In [2]:
# Adjust sys.argv to remove unwanted Jupyter arguments
sys.argv = sys.argv[:1]  # Keep only the script name, remove Jupyter's arguments

# Now proceed with argparse as usual
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=200, help="size of the batches")
parser.add_argument("--n_cpu", type=int, default=8, help="number of CPU threads for data loading")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=2000, help="interval between image sampling")

# Parse the arguments
opt = parser.parse_args()

img_shape = (opt.channels, opt.img_size, opt.img_size)
print("this is the image shape ", img_shape)
cuda = True if torch.cuda.is_available() else False

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if cuda:
    torch.cuda.manual_seed(42)


this is the image shape  (1, 28, 28)


In [3]:

# Data Loading Classes and Functions
class NumpyDataset(Dataset):
    def __init__(self, dataX, dataY=None):
        self.dataX = np.load(dataX)
        self.dataY = np.load(dataY) if dataY is not None else None

    def __len__(self):
        return len(self.dataX)

    def __getitem__(self, idx):
        data = torch.tensor(self.dataX[idx], dtype=torch.float32)
        label = (
            torch.tensor(self.dataY[idx], dtype=torch.long)
            if self.dataY is not None
            else None
        )
        return data, label


In [4]:




# # Define file paths
# dataX = "../../../data/mnist/trainX.npy"
# dataY = "../../../data/mnist/trainY.npy"
# devX = "../../../data/mnist/validX.npy"
# devY = "../../../data/mnist/validY.npy"
# testX = "../../../data/mnist/testX.npy"
# testY = "../../../data/mnist/testY.npy"

# # Create dataloaders
# train_dataset = NumpyDataset(dataX, dataY)
# train_loader = DataLoader(dataset=train_dataset, batch_size=200, shuffle=True)

# dev_dataset = NumpyDataset(devX, devY)
# dev_loader = DataLoader(dataset=dev_dataset, batch_size=200, shuffle=False)

# test_dataset = NumpyDataset(testX, testY)
# test_loader = DataLoader(dataset=test_dataset, batch_size=200, shuffle=False)


import os
import numpy as np
import torchvision
from torch.utils.data import DataLoader, random_split
# Assuming NumpyDataset is a custom class you have defined
# Define file paths
base_dir = "../../../data/mnist/"
os.makedirs(base_dir, exist_ok=True)

dataX = os.path.join(base_dir, "trainX.npy")
dataY = os.path.join(base_dir, "trainY.npy")
devX = os.path.join(base_dir, "validX.npy")
devY = os.path.join(base_dir, "validY.npy")
testX = os.path.join(base_dir, "testX.npy")
testY = os.path.join(base_dir, "testY.npy")

# Check if any MNIST files are missing
required_files = [dataX, dataY, devX, devY, testX, testY]
if not all(os.path.exists(f) for f in required_files):
    print("Downloading and processing MNIST dataset...")

    # Download raw MNIST data
    train_set = torchvision.datasets.MNIST(
        root='./data',
        train=True,
        download=True
    )
    test_set = torchvision.datasets.MNIST(
        root='./data',
        train=False,
        download=True
    )

    # Convert to numpy arrays
    train_images = train_set.data.numpy()
    train_labels = train_set.targets.numpy()
    test_images = test_set.data.numpy()
    test_labels = test_set.targets.numpy()

    # Shuffle training data
    shuffle_idx = np.random.permutation(len(train_images))
    train_images = train_images[shuffle_idx]
    train_labels = train_labels[shuffle_idx]

    # Split train into train/validation (50k/10k)
    train_images = train_images.astype(np.float32) / 255.0
    train_labels = train_labels.astype(np.int64)

    # Save processed data
    np.save(dataX, train_images[:50000])
    np.save(dataY, train_labels[:50000])
    np.save(devX, train_images[50000:])
    np.save(devY, train_labels[50000:])

    # Process and save test data
    test_images = test_images.astype(np.float32) / 255.0
    test_labels = test_labels.astype(np.int64)
    np.save(testX, test_images)
    np.save(testY, test_labels)

# Create dataloaders
train_dataset = NumpyDataset(dataX, dataY)
train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)

dev_dataset = NumpyDataset(devX, devY)
dev_loader = DataLoader(dev_dataset, batch_size=200, shuffle=False)

test_dataset = NumpyDataset(testX, testY)
test_loader = DataLoader(test_dataset, batch_size=200, shuffle=False)


Downloading and processing MNIST dataset...


100%|██████████| 9.91M/9.91M [00:02<00:00, 4.57MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 133kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.08MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.58MB/s]


In [5]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dims, latent_dim):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
        )
        self.mu_layer = nn.Linear(hidden_dims[1], latent_dim)
        self.logvar_layer = nn.Linear(hidden_dims[1], latent_dim)
        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in [*self.model, self.mu_layer, self.logvar_layer]:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=sigma)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        h = self.model(x)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, input_dim, hidden_dims):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], input_dim),
            nn.Sigmoid()
        )
        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=sigma)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, z):
        x_recon = self.model(z)
        return x_recon.view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self, latent_dim, hidden_dims):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], 1),
            nn.Sigmoid()
        )
        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=sigma)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, z):
        return self.model(z)

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

class GANAE(nn.Module):
    def __init__(self, input_dim, hidden_dims, latent_dim, l2_lambda=1e-3):
        super(GANAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dims, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim, hidden_dims)
        self.discriminator = Discriminator(latent_dim, hidden_dims)
        self.l2_lambda = l2_lambda

    def compute_l2_penalty(self):
        l2_penalty = 0
        for param in self.decoder.parameters():
            if param.requires_grad:
                l2_penalty += torch.sum(param**2)
        return self.l2_lambda * l2_penalty

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        real_or_fake = torch.sigmoid(self.discriminator(z))
        return x_recon, real_or_fake, mu, logvar


In [6]:
# --- Loss Functions ---
def bce_loss(pred, target):
    return F.binary_cross_entropy(pred, target,reduction="sum")

def gan_ae_loss(recon_x, x, real_or_fake, real_label=1, fake_label=0):
    # Reconstruction loss
    recon_loss =bce_loss(recon_x, x)

    # Discriminator loss (binary cross entropy)
    d_loss_real = bce_loss(real_or_fake, torch.full_like(real_or_fake, real_label))
    d_loss_fake = bce_loss(real_or_fake, torch.full_like(real_or_fake, fake_label))

    return recon_loss + d_loss_real + d_loss_fake



In [None]:
# --- Training Loop ---

# Initialize model with proper parameters
input_dim = 784  # 28x28 for MNIST
hidden_dims = [360, 360]  # Example dimensions
latent_dim = 40  # Example latent dimension
ganae = GANAE(input_dim=input_dim, hidden_dims=hidden_dims, latent_dim=latent_dim)

# Separate optimizers for generator (encoder+decoder) and discriminator
gen_optimizer = torch.optim.Adam(list(ganae.encoder.parameters()) +
                               list(ganae.decoder.parameters()), lr=0.0002)
disc_optimizer = torch.optim.Adam(ganae.discriminator.parameters(), lr=0.00002)

# Loss functions
reconstruction_loss = nn.BCELoss(reduction='sum')
adversarial_loss = nn.BCELoss()


def rescale_gradients(model, max_norm=5.0):
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)



def compute_loss(x_recon, x, real_or_fake, mu, logvar):
    # Flatten the input and reconstruction for BCE loss
    x_recon_flat = x_recon.view(-1, input_dim)
    x_flat = x.view(-1, input_dim)

    # Reconstruction loss
    recon_loss = reconstruction_loss(x_recon_flat, x_flat)

    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Adversarial loss
    real_labels = torch.ones(real_or_fake.size()).to(x.device)
    fake_labels = torch.zeros(real_or_fake.size()).to(x.device)

    # Generator tries to fool discriminator
    gen_loss = adversarial_loss(real_or_fake, real_labels)

    # Discriminator loss
    disc_loss = adversarial_loss(real_or_fake, fake_labels)


    return recon_loss, kl_loss, gen_loss, disc_loss

# Training loop
n_epochs = opt.n_epochs
for epoch in range(n_epochs):
    total_recon_loss = 0
    total_kl_loss = 0
    total_gen_loss = 0
    total_disc_loss = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        batch_size = data.size(0)
        data = data.to('cpu')
        data = (data > 0.5).float()

        # Train generator (encoder + decoder)
        gen_optimizer.zero_grad()
        x_recon, real_or_fake, mu, logvar = ganae(data)
        recon_loss, kl_loss, gen_loss, _ = compute_loss(x_recon, data, real_or_fake, mu, logvar)

        # Add L2 regularization
        l2_penalty = ganae.compute_l2_penalty()

        # Total generator loss
        g_loss = recon_loss + kl_loss + gen_loss + l2_penalty
        # Reduce KL impact (scale it down)
        # g_loss = recon_loss + 0.01 * kl_loss + gen_loss + l2_penalty

        g_loss.backward(retain_graph=True)
        rescale_gradients(ganae)
        gen_optimizer.step()

        # Train discriminator
        # disc_optimizer.zero_grad()
        # _, real_or_fake, _, _ = ganae(data)
        # _, _, _, disc_loss = compute_loss(x_recon, data, real_or_fake, mu, logvar)
        # disc_loss.backward()
        # rescale_gradients(ganae)
        # disc_optimizer.step()
        n_critic = 5  # Number of times to train discriminator per generator step

        for _ in range(n_critic):  # Train discriminator multiple times
            disc_optimizer.zero_grad()
            _, real_or_fake, _, _ = ganae(data)
            _, _, _, disc_loss = compute_loss(x_recon, data, real_or_fake, mu, logvar)
            disc_loss.backward()
            rescale_gradients(ganae)
            disc_optimizer.step()


        # Record losses
        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()
        total_gen_loss += gen_loss.item()
        total_disc_loss += disc_loss.item()

    # Average losses
    avg_recon_loss = total_recon_loss / len(train_loader.dataset)
    avg_kl_loss = total_kl_loss / len(train_loader.dataset)
    avg_gen_loss = total_gen_loss / len(train_loader.dataset)
    avg_disc_loss = total_disc_loss / len(train_loader.dataset)

    print(f"""Epoch [{epoch+1}/{n_epochs}]
    Reconstruction Loss: {avg_recon_loss:.4f}
    KL Loss: {avg_kl_loss:.4f}""")
    # print(f"Generator Loss: {avg_gen_loss:.4f}")
    # print(f"Discriminator Loss: {avg_disc_loss:.4f}")

Epoch [1/70]
    Reconstruction Loss: 236.9476
    KL Loss: 7.2343
Epoch [2/70]
    Reconstruction Loss: 155.4494
    KL Loss: 11.5451
Epoch [3/70]
    Reconstruction Loss: 133.3719
    KL Loss: 15.8849
Epoch [4/70]
    Reconstruction Loss: 120.9176
    KL Loss: 18.5318
Epoch [5/70]
    Reconstruction Loss: 112.0258
    KL Loss: 20.3965
Epoch [6/70]
    Reconstruction Loss: 104.4165
    KL Loss: 21.8173
Epoch [7/70]
    Reconstruction Loss: 98.6005
    KL Loss: 22.7700
Epoch [8/70]
    Reconstruction Loss: 94.1075
    KL Loss: 23.4457
Epoch [9/70]
    Reconstruction Loss: 90.2136
    KL Loss: 24.0164
Epoch [10/70]
    Reconstruction Loss: 86.8194
    KL Loss: 24.4989
Epoch [11/70]
    Reconstruction Loss: 83.7665
    KL Loss: 24.9587
Epoch [12/70]
    Reconstruction Loss: 81.0415
    KL Loss: 25.4325
Epoch [13/70]
    Reconstruction Loss: 78.6642
    KL Loss: 25.8932
Epoch [14/70]
    Reconstruction Loss: 76.6772
    KL Loss: 26.2744
Epoch [15/70]
    Reconstruction Loss: 74.9225
    K

In [8]:
# Calculate error
def calculate_error(model, loader):
    model.eval()
    total_bce = 0.0
    total_mse = 0.0
    total_samples = 0
    with torch.no_grad():
        for data, _ in loader:
            data = data.view(data.size(0), -1)
            model_output = model(data)  # Capture all outputs

            recon_data = model_output[0] if isinstance(model_output, tuple) else model_output

            #recon_data, _ = model(data)
            recon_data = recon_data.view(data.size(0), -1)
            bce = F.binary_cross_entropy(recon_data, data, reduction='sum')
            total_bce += bce.item()
            mse = F.mse_loss(recon_data, data, reduction='sum')
            total_mse += mse.item()
            total_samples += data.size(0)

    avg_bce = total_bce / total_samples
    avg_mse = total_mse / total_samples
    print(f"BCE: {avg_bce:.4f}, MSE: {avg_mse:.4f}")
    return avg_bce, avg_mse

calculate_error(ganae, test_loader)

BCE: 80.3978, MSE: 9.6159


(80.397849609375, 9.615925988769531)

In [9]:
def evaluate_classification_gan_ae(model, train_loader, test_loader):
    model.eval()
    train_latents, train_labels = [], []
    test_latents, test_labels = [], []
    device = next(model.parameters()).device

    # Extract latent representations from training data
    with torch.no_grad():
        for data, label in train_loader:
            data = (data > 0.5).float().view(-1, 784).to(device)
            mu, logvar = model.encoder(data)
            z = reparameterize(mu, logvar)
            train_latents.append(z.cpu().numpy())
            train_labels.append(label.cpu().numpy())

    # Extract latent representations from test data
    with torch.no_grad():
        for data, label in test_loader:
            data = (data > 0.5).float().view(-1, 784).to(device)
            mu, logvar = model.encoder(data)
            z = reparameterize(mu, logvar)
            test_latents.append(z.cpu().numpy())
            test_labels.append(label.cpu().numpy())

    # Stack latent representations and labels
    X_train = np.vstack(train_latents)
    y_train = np.hstack(train_labels)
    X_test = np.vstack(test_latents)
    y_test = np.hstack(test_labels)

    # Train logistic regression on training latents
    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(X_train, y_train)

    # Evaluate on test latents
    y_pred = classifier.predict(X_test)
    error_percentage = 100 * (1 - accuracy_score(y_test, y_pred))

    print(f"Classification Error: {error_percentage:.2f}%")
    return error_percentage



In [12]:
evaluate_classification_gan_ae(ganae,train_loader,test_loader)

Classification Error: 14.03%


14.029999999999998

In [10]:
def evaluate_masked_mse(model, loader):
    """
    Evaluate the Masked Mean Squared Error (M-MSE) for the GAN-AE model.

    Parameters:
    - model: Trained GAN-AE model.
    - loader: DataLoader with test data.

    Returns:
    - avg_mse: Average MSE over masked regions, normalized by masked elements.
    """
    model.eval()
    total_mse = 0.0
    total_masked_elements = 0
    device = next(model.parameters()).device

    with torch.no_grad():
        for data, _ in loader:
            data = (data > 0.5).float().view(-1, 784).to(device)  # Binarize and flatten

            # Create mask: mask first half of the image (392/784 pixels)
            mask = torch.ones_like(data, dtype=torch.bool)
            mask[:, :392] = 0  # Mask left half

            # Apply mask to input
            masked_data = data * mask.float()

            # Encode and sample from latent space
            mu, logvar = model.encoder(masked_data)
            z = reparameterize(mu, logvar)

            # Decode to reconstruct
            reconstructed = model.decoder(z).view(-1, 784)

            # Compute MSE on masked region only
            masked_error = F.mse_loss(reconstructed[~mask], data[~mask], reduction='sum')
            total_mse += masked_error.item()
            total_masked_elements += (~mask).sum().item()

    # Normalize by total number of masked elements
    avg_mse = total_mse / total_masked_elements
    print(f"Average Masked MSE: {avg_mse:.4f}")
    return avg_mse




In [13]:
avg_mse = evaluate_masked_mse(ganae, test_loader)
print(avg_mse)

Average Masked MSE: 0.1125
0.1125103900520169


In [11]:
import torch
import torch.nn.functional as F

def monte_carlo_log_px(model, data_loader, num_samples=1000, device='cpu'):
    model.eval()
    total_log_px = 0
    num_images = 0

    with torch.no_grad():
        for batch in data_loader:
            x_real = batch[0].to(device)
            batch_size = x_real.size(0)

            z_samples = torch.randn(num_samples, latent_dim).to(device)

            x_recon = model.decoder(z_samples)

            for i in range(batch_size):
                log_px = compute_log_px(x_real[i], x_recon)
                total_log_px += log_px.item()
                num_images += 1

            if num_images >= num_samples:
                break

    average_log_px = total_log_px / num_images
    return average_log_px

def compute_log_px(x_real, x_recon):
    x_real_expanded = x_real.unsqueeze(0).expand(x_recon.size(0), -1, -1, -1)

    log_px = -F.binary_cross_entropy(x_recon, x_real_expanded, reduction='none')
    log_px = torch.logsumexp(log_px.sum(dim=[1,2,3]), dim=0) - torch.log(torch.tensor(x_recon.size(0)))
    return log_px

log_px_estimate = monte_carlo_log_px(ganae, test_loader, num_samples=1000, device='cpu')
print(f"Estimated log p(x): {log_px_estimate}")

Estimated log p(x): -207.98837602615356
