# imports

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import argparse
import sys
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score


In [7]:
# Adjust sys.argv to remove unwanted Jupyter arguments
sys.argv = sys.argv[:1]  # Keep only the script name, remove Jupyter's arguments

# Now proceed with argparse as usual
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=200, help="size of the batches")
parser.add_argument("--n_cpu", type=int, default=8, help="number of CPU threads for data loading")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=2000, help="interval between image sampling")

# Parse the arguments
opt = parser.parse_args()

img_shape = (opt.channels, opt.img_size, opt.img_size)
print("this is the image shape ", img_shape)
cuda = True if torch.cuda.is_available() else False

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if cuda:
    torch.cuda.manual_seed(42)


this is the image shape  (1, 28, 28)


In [8]:

# Data Loading Classes and Functions
class NumpyDataset(Dataset):
    def __init__(self, dataX, dataY=None):
        self.dataX = np.load(dataX)
        self.dataY = np.load(dataY) if dataY is not None else None

    def __len__(self):
        return len(self.dataX)

    def __getitem__(self, idx):
        data = torch.tensor(self.dataX[idx], dtype=torch.float32)
        label = (
            torch.tensor(self.dataY[idx], dtype=torch.long)
            if self.dataY is not None
            else None
        )
        return data, label


In [9]:
# Define file paths
dataX = "../../../data/mnist/trainX.npy"
dataY = "../../../data/mnist/trainY.npy"
devX = "../../../data/mnist/validX.npy"
devY = "../../../data/mnist/validY.npy"
testX = "../../../data/mnist/testX.npy"
testY = "../../../data/mnist/testY.npy"

# Create dataloaders
train_dataset = NumpyDataset(dataX, dataY)
train_loader = DataLoader(dataset=train_dataset, batch_size=200, shuffle=True)

dev_dataset = NumpyDataset(devX, devY)
dev_loader = DataLoader(dataset=dev_dataset, batch_size=200, shuffle=False)

test_dataset = NumpyDataset(testX, testY)
test_loader = DataLoader(dataset=test_dataset, batch_size=200, shuffle=False)



In [52]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dims, latent_dim):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
        )
        self.mu_layer = nn.Linear(hidden_dims[1], latent_dim)
        self.logvar_layer = nn.Linear(hidden_dims[1], latent_dim)
        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in [*self.model, self.mu_layer, self.logvar_layer]:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=sigma)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        h = self.model(x)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, input_dim, hidden_dims):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], input_dim),
            nn.Sigmoid()
        )
        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=sigma)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, z):
        x_recon = self.model(z)
        return x_recon.view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self, latent_dim, hidden_dims):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dims[0]),
            nn.ReLU(),
            nn.Linear(hidden_dims[0], hidden_dims[1]),
            nn.ReLU(),
            nn.Linear(hidden_dims[1], 1),
            nn.Sigmoid()
        )
        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in self.model:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=sigma)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, z):
        return self.model(z)

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

class GANAE(nn.Module):
    def __init__(self, input_dim, hidden_dims, latent_dim, l2_lambda=1e-3):
        super(GANAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dims, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim, hidden_dims)
        self.discriminator = Discriminator(latent_dim, hidden_dims)
        self.l2_lambda = l2_lambda

    def compute_l2_penalty(self):
        l2_penalty = 0
        for param in self.decoder.parameters():
            if param.requires_grad:
                l2_penalty += torch.sum(param**2)
        return self.l2_lambda * l2_penalty

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        real_or_fake = torch.sigmoid(self.discriminator(z))
        return x_recon, real_or_fake, mu, logvar


In [53]:
# --- Loss Functions ---
def bce_loss(pred, target):
    return F.binary_cross_entropy(pred, target,reduction="sum")

def gan_ae_loss(recon_x, x, real_or_fake, real_label=1, fake_label=0):
    # Reconstruction loss
    recon_loss =bce_loss(recon_x, x)

    # Discriminator loss (binary cross entropy)
    d_loss_real = bce_loss(real_or_fake, torch.full_like(real_or_fake, real_label))
    d_loss_fake = bce_loss(real_or_fake, torch.full_like(real_or_fake, fake_label))

    return recon_loss + d_loss_real + d_loss_fake



In [87]:
# --- Training Loop ---

# Initialize model with proper parameters
input_dim = 784  # 28x28 for MNIST
hidden_dims = [360, 360]  # Example dimensions
latent_dim = 20  # Example latent dimension
ganae = GANAE(input_dim=input_dim, hidden_dims=hidden_dims, latent_dim=latent_dim)

# Separate optimizers for generator (encoder+decoder) and discriminator
gen_optimizer = torch.optim.Adam(list(ganae.encoder.parameters()) + 
                               list(ganae.decoder.parameters()), lr=0.006)
disc_optimizer = torch.optim.Adam(ganae.discriminator.parameters(), lr=0.05)

# Loss functions
reconstruction_loss = nn.BCELoss(reduction='sum')
adversarial_loss = nn.BCELoss()


def rescale_gradients(model, max_norm=5.0):
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)



def compute_loss(x_recon, x, real_or_fake, mu, logvar):
    # Flatten the input and reconstruction for BCE loss
    x_recon_flat = x_recon.view(-1, input_dim)
    x_flat = x.view(-1, input_dim)
    
    # Reconstruction loss
    recon_loss = reconstruction_loss(x_recon_flat, x_flat)
    
    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Adversarial loss
    real_labels = torch.ones(real_or_fake.size()).to(x.device)
    fake_labels = torch.zeros(real_or_fake.size()).to(x.device)
    
    # Generator tries to fool discriminator
    gen_loss = adversarial_loss(real_or_fake, real_labels)
    
    # Discriminator loss
    disc_loss = adversarial_loss(real_or_fake, fake_labels)
    
    return recon_loss, kl_loss, gen_loss, disc_loss

# Training loop
n_epochs = opt.n_epochs
for epoch in range(n_epochs):
    total_recon_loss = 0
    total_kl_loss = 0
    total_gen_loss = 0
    total_disc_loss = 0
    
    for batch_idx, (data, _) in enumerate(train_loader):
        batch_size = data.size(0)
        data = data.to('cpu')
        data = (data > 0.5).float()
        
        # Train generator (encoder + decoder)
        gen_optimizer.zero_grad()
        x_recon, real_or_fake, mu, logvar = ganae(data)
        recon_loss, kl_loss, gen_loss, _ = compute_loss(x_recon, data, real_or_fake, mu, logvar)
        
        # Add L2 regularization
        l2_penalty = ganae.compute_l2_penalty()
        
        # Total generator loss
        g_loss = recon_loss + kl_loss + gen_loss + l2_penalty
        g_loss.backward(retain_graph=True)
        rescale_gradients(ganae)
        gen_optimizer.step()
        
        # Train discriminator
        disc_optimizer.zero_grad()
        _, real_or_fake, _, _ = ganae(data)
        _, _, _, disc_loss = compute_loss(x_recon, data, real_or_fake, mu, logvar)
        disc_loss.backward()
        rescale_gradients(ganae)
        disc_optimizer.step()
        
        # Record losses
        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()
        total_gen_loss += gen_loss.item()
        total_disc_loss += disc_loss.item()
        
    # Average losses
    avg_recon_loss = total_recon_loss / len(train_loader.dataset)
    avg_kl_loss = total_kl_loss / len(train_loader.dataset)
    avg_gen_loss = total_gen_loss / len(train_loader.dataset)
    avg_disc_loss = total_disc_loss / len(train_loader.dataset)
    
    print(f"""Epoch [{epoch+1}/{n_epochs}]
    Reconstruction Loss: {avg_recon_loss:.4f}
    KL Loss: {avg_kl_loss:.4f}""")
    # print(f"Generator Loss: {avg_gen_loss:.4f}")
    # print(f"Discriminator Loss: {avg_disc_loss:.4f}")

Epoch [1/50]
    Reconstruction Loss: 156.2635
    KL Loss: 8.3078
Epoch [2/50]
    Reconstruction Loss: 109.5181
    KL Loss: 13.0183
Epoch [3/50]
    Reconstruction Loss: 100.5738
    KL Loss: 13.9653
Epoch [4/50]
    Reconstruction Loss: 96.4678
    KL Loss: 14.4452
Epoch [5/50]
    Reconstruction Loss: 92.9108
    KL Loss: 15.1979
Epoch [6/50]
    Reconstruction Loss: 89.4688
    KL Loss: 15.9049
Epoch [7/50]
    Reconstruction Loss: 86.9907
    KL Loss: 16.4014
Epoch [8/50]
    Reconstruction Loss: 84.7705
    KL Loss: 16.8140
Epoch [9/50]
    Reconstruction Loss: 83.2108
    KL Loss: 17.1745
Epoch [10/50]
    Reconstruction Loss: 81.9077
    KL Loss: 17.4760
Epoch [11/50]
    Reconstruction Loss: 80.9458
    KL Loss: 17.6600
Epoch [12/50]
    Reconstruction Loss: 79.9866
    KL Loss: 17.8113
Epoch [13/50]
    Reconstruction Loss: 79.6533
    KL Loss: 17.9222
Epoch [14/50]
    Reconstruction Loss: 79.0498
    KL Loss: 18.0131
Epoch [15/50]
    Reconstruction Loss: 78.7225
    KL L

In [89]:
# Calculate error
def calculate_error(model, loader):
    model.eval()
    total_bce = 0.0
    total_mse = 0.0
    total_samples = 0
    with torch.no_grad():
        for data, _ in loader:
            data = data.view(data.size(0), -1)
            recon_data, _ = model(data)
            recon_data = recon_data.view(data.size(0), -1)
            bce = F.binary_cross_entropy(recon_data, data, reduction='sum')
            total_bce += bce.item()
            mse = F.mse_loss(recon_data, data, reduction='sum')
            total_mse += mse.item()
            total_samples += data.size(0)

    avg_bce = total_bce / total_samples
    avg_mse = total_mse / total_samples
    print(f"BCE: {avg_bce:.4f}, MSE: {avg_mse:.4f}")
    return avg_bce, avg_mse

calculate_error(ganae, test_loader)

ValueError: too many values to unpack (expected 2)

In [103]:

def evaluate_classification_gan_ae(model, loader):
    """
    Evaluate classification error using the trained GAN-AE model and a logistic regression classifier.

    Parameters:
    - model: Trained GAN-AE model with an encoder.
    - loader: DataLoader providing batches of (data, labels).

    Returns:
    - err: Classification error percentage.
    """
    model.eval()
    latents, labels = [], []
    
    with torch.no_grad():
        for data, label in loader:
            # Flatten data and move to the appropriate device
            data = data.view(data.size(0), -1).to(next(model.parameters()).device)
            label = label.to(next(model.parameters()).device)
            
            if label is not None:
                mu, logvar = model.encoder(data)
                z = reparameterize(mu, logvar)
                latents.append(z.cpu().numpy())
                labels.append(label.cpu().numpy())
    
    if not labels:
        print("No labels provided for classification.")
        return None

    # Stack latents and labels into arrays
    latents = np.vstack(latents)
    labels = np.hstack(labels).reshape(-1)
    print(f"labels shape: {labels.shape}")
    print(f"latents shape: {latents.shape}")
    
    # Ensure consistent lengths of latents and labels
    min_len = min(len(latents), len(labels))
    latents = latents[:min_len]
    labels = labels[:min_len]

    # Split latent data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(latents, labels, test_size=0.3, random_state=42)

    # Train logistic regression on training split
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)

    # Evaluate on test split
    predictions = clf.predict(X_test)
    err = 100 * (1 - accuracy_score(y_test, predictions))
    print(f"Classification Error: {err:.2f}%")
    
    return err


In [104]:
evaluate_classification_gan_ae(ganae,test_loader)

labels shape: (100000,)
latents shape: (10000, 20)
Classification Error: 10.53%


10.533333333333328

In [105]:
def evaluate_masked_mse(model, loader):
    """
    Evaluate the Masked Mean Squared Error (M-MSE) for a given model on a dataset.

    Parameters:
    - model: The trained model to evaluate.
    - loader: DataLoader providing batches of test data.

    Returns:
    - avg_mse: The average masked MSE over the entire dataset.
    """
    model.eval()  # Set the model to evaluation mode
    total_mse = 0.0  # Accumulate total masked MSE
    total_samples = 0  # Total number of samples processed

    with torch.no_grad():  # Disable gradient calculations for evaluation
        for data, _ in loader:
            # Flatten images into vectors (batch_size, D)
            data = data.view(data.size(0), -1)  # Shape: (batch_size, D)

            # Create a binary mask: Half of the columns are masked
            mask = torch.ones_like(data, dtype=torch.bool)  # Shape: (batch_size, D)
            mask[:, : data.size(1) // 2] = 0  # Mask the first half of the columns


            masked_data = data * mask.float()  # Zero out the masked part
            
            output = model(masked_data)  # Assumes model outputs reconstructed data
            if isinstance(output, tuple):  # Extract reconstructed images if output is a tuple
                reconstructed = output[0]
            else:
                reconstructed = output
            reconstructed = reconstructed.view(data.size(0), -1)  # Shape: (batch_size, D)

            mse_batch = masked_mse(data, reconstructed, mask)  # Masked MSE for this batch
            
            # Accumulate results
            total_mse += mse_batch.item() * data.size(0)  # Multiply by batch size
            total_samples += data.size(0)  # Update total sample count

    # Compute the final average M-MSE across the dataset
    avg_mse = total_mse / total_samples
    print(f"Average Masked MSE: {avg_mse:.4f}")

    return avg_mse


# Helper function for batch-level Masked MSE
def masked_mse(x, x_hat, mask):
    """
    Compute the Masked Mean Squared Error (M-MSE) for a single batch.
    
    Parameters:
    - x: Original images (batch_size, D), where D is the number of pixels in each image.
    - x_hat: Reconstructed images (batch_size, D).
    - mask: Binary mask (batch_size, D), where 1 indicates unmasked pixels and 0 indicates masked pixels.

    Returns:
    - m_mse: Masked mean squared error for the batch.
    """
    # Compute squared error
    error = (x - x_hat) ** 2
    # Apply mask to focus only on masked-out regions
    masked_error = error * (1 - mask.float())  

    # Compute normalized MSE for the masked regions
    batch_mse = masked_error.sum(dim=1) / (1 - mask.float()).sum(dim=1)

    # Return average masked MSE for the batch
    return batch_mse.sum()


In [106]:
avg_mse = evaluate_masked_mse(ganae, test_loader)
print(avg_mse)

Average Masked MSE: 13.4537
13.453733501434327


In [114]:
import os
import sys
# Add the parent directory of GAN_Models to sys.path
sys.path.append(os.path.abspath("../../"))  # Adjust the path as necessary

from density.fit_gmm import fit_gmm
from density.eval_logpx import evaluate_logpx
def evaluate_density_model(model,train_loader, test_loader):
    gmm = fit_gmm(model=model,data_loader=train_loader,latent_dim=20)

    log_px = evaluate_logpx(model=model,gmm=gmm,data_loader=test_loader,latent_dim=20)

    print(f"Log-likelihood: {log_px:.4f}")
    return log_px


evaluate_density_model(model=ganae,train_loader=train_loader,test_loader=test_loader)

Extracting latent vectors from dataLoader using the provided model...


TypeError: expected Tensor as element 0 in argument 0, but got tuple

In [111]:
evaluate_density_model(model=ganae,train_loader=train_loader,test_loader=test_loader)

Extracting latent vectors from dataLoader using the provided model...


TypeError: expected Tensor as element 0 in argument 0, but got tuple