# imports

In [78]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import argparse
import sys
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import logging


In [79]:
# Adjust sys.argv to remove unwanted Jupyter arguments
sys.argv = sys.argv[:1]  # Keep only the script name, remove Jupyter's arguments

# Now proceed with argparse as usual
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=200, help="size of the batches")
parser.add_argument("--n_cpu", type=int, default=8, help="number of CPU threads for data loading")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=2000, help="interval between image sampling")

# Parse the arguments
opt = parser.parse_args()

img_shape = (opt.channels, opt.img_size, opt.img_size)
print("this is the image shape ", img_shape)
cuda = True if torch.cuda.is_available() else False

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]  
)

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if cuda:
    torch.cuda.manual_seed(42)


this is the image shape  (1, 28, 28)


In [80]:

# Data Loading Classes and Functions
class NumpyDataset(Dataset):
    def __init__(self, dataX, dataY=None):
        self.dataX = np.load(dataX)
        self.dataY = np.load(dataY) if dataY is not None else None

    def __len__(self):
        return len(self.dataX)

    def __getitem__(self, idx):
        data = torch.tensor(self.dataX[idx], dtype=torch.float32)
        label = (
            torch.tensor(self.dataY[idx], dtype=torch.long)
            if self.dataY is not None
            else None
        )
        return data, label


In [81]:
# Define file paths
dataX = "../../../data/mnist/trainX.npy"
dataY = "../../../data/mnist/trainY.npy"
devX = "../../../data/mnist/validX.npy"
devY = "../../../data/mnist/validY.npy"
testX = "../../../data/mnist/testX.npy"
testY = "../../../data/mnist/testY.npy"

# Create dataloaders
train_dataset = NumpyDataset(dataX, dataY)
train_loader = DataLoader(dataset=train_dataset, batch_size=200, shuffle=True)

dev_dataset = NumpyDataset(devX, devY)
dev_loader = DataLoader(dataset=dev_dataset, batch_size=200, shuffle=False)

test_dataset = NumpyDataset(testX, testY)
test_loader = DataLoader(dataset=test_dataset, batch_size=200, shuffle=False)



In [90]:

# --- Encoder (Same as VAE) ---
class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=360, latent_dim=20):
        super(Encoder, self).__init__()
        
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)  # Hidden layer 2 to mean
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # Hidden layer 2 to log-variance
        self._init_weights()

    def _init_weights(self, sigma=0.02):
        for layer in [self.fc1, self.fc2, self.fc_mu, self.fc_logvar]:
            nn.init.normal_(layer.weight, mean=0.0, std=sigma)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


# --- Decoder (Generator) ---
class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=360, output_dim=784):
        super(Decoder, self).__init__()
        
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        self._init_weights()

    def _init_weights(self, sigma=0.02):
        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.weight, mean=0.0, std=sigma)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, z):
        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = torch.sigmoid(self.fc3(z))  # Sigmoid activation in the output layer
        return z


# --- Discriminator (GAN) ---
class Discriminator(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=360):
        super(Discriminator, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)  # Output a single value for real/fake classification

        self._init_weights()

    def _init_weights(self, sigma=0.02):
        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.weight, mean=0.0, std=sigma)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # Sigmoid activation for binary output (real/fake)
        return x


# --- GAN-AE Model ---
class GANAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=360, latent_dim=20,l2_lambda=1e-3):
        super(GANAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        self.discriminator = Discriminator(input_dim, hidden_dim)
        self.l2_lambda = l2_lambda
    def compute_l2_penalty(self):
        l2_penalty = 0
        for param in self.decoder.parameters():
            if param.requires_grad:
                l2_penalty += torch.sum(param**2)
        return self.l2_lambda * l2_penalty
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar




In [106]:
# --- Loss Functions ---
def bce_loss(pred, target):
    return F.binary_cross_entropy(pred, target,reduction="sum")

def gan_ae_loss(recon_x, x, real_or_fake, real_label=1, fake_label=0):
    # Reconstruction loss
    recon_loss =bce_loss(recon_x, x)

    # Discriminator loss (binary cross entropy)
    d_loss_real = bce_loss(real_or_fake, torch.full_like(real_or_fake, real_label))
    d_loss_fake = bce_loss(real_or_fake, torch.full_like(real_or_fake, fake_label))

    return recon_loss + d_loss_real + d_loss_fake

def autoencoder_gradients(recon_x, x):
    """
    Equation (4) - Autoencoder gradients
    """
    return F.binary_cross_entropy(recon_x, x, reduction='sum')

def generator_gradients(discriminator_output):
    """
    Equation (5) - Generator gradients
    """
    return -torch.mean(torch.log(discriminator_output + 1e-8))

def discriminator_gradients(discriminator_output_real, discriminator_output_fake):
    """
    Equation (6) - Discriminator gradients
    """
    return -torch.mean(torch.log(discriminator_output_real + 1e-8) + 
                      torch.log(1 - discriminator_output_fake + 1e-8))

In [107]:

num_epochs=50# Training loop
ganae = GANAE()
optimizer_gen = torch.optim.Adam(ganae.parameters(), lr=0.002)
optimizer_disc = torch.optim.Adam(ganae.discriminator.parameters(), lr=0.002)
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        batch_size = data.size(0)
        data = data.view(batch_size, -1)
        
        # Step 1: Train Autoencoder
        optimizer_gen.zero_grad()
        
        # Get encoder outputs
        mu, logvar = ganae.encoder(data)
        z = ganae.reparameterize(mu, logvar)
        
        # Generate reconstruction
        recon_x = ganae.decoder(z)
        
        # Get discriminator predictions
        disc_output = ganae.discriminator(recon_x)
        
        # Calculate autoencoder gradients (eq. 4)
        ae_loss = autoencoder_gradients(recon_x, data)
        
        # Calculate generator gradients (eq. 5)
        gen_loss = generator_gradients(disc_output)
        
        # Combined loss
        total_gen_loss = ae_loss + gen_loss
        total_gen_loss.backward()
        optimizer_gen.step()
        
        # Step 2: Train Discriminator
        optimizer_disc.zero_grad()
        
        # Generate new samples for discriminator training
        z_fake = ganae.reparameterize(mu, logvar)
        fake_samples = ganae.decoder(z_fake)
        
        # Get discriminator predictions for real and fake samples
        d_real = ganae.discriminator(data)
        d_fake = ganae.discriminator(fake_samples.detach())
        
        # Calculate discriminator loss (eq. 6)
        d_loss = discriminator_gradients(d_real, d_fake)
        
        loss = gan_ae_loss(recon_x, data, disc_output)
        
        d_loss.backward()
        optimizer_disc.step()
        total_loss += loss.item()
        
        # Logging
        # if batch_idx % 100 == 0:
        #     print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(train_loader)}] "
        #           f"AE Loss: {ae_loss.item():.4f} "
        #           f"Gen Loss: {gen_loss.item():.4f} "
        #           f"Disc Loss: {d_loss.item():.4f}")
    print(f"Epoch {epoch} Loss {total_loss / len(train_loader)}")

Epoch 0 Loss 39297.0950625
Epoch 1 Loss 26154.9494765625
Epoch 2 Loss 21748.662234375
Epoch 3 Loss 19286.81178125
Epoch 4 Loss 17801.5094765625
Epoch 5 Loss 16676.43815234375
Epoch 6 Loss 15857.13468359375
Epoch 7 Loss 15368.74139453125
Epoch 8 Loss 15008.49491796875
Epoch 9 Loss 14723.666
Epoch 10 Loss 14492.8149296875
Epoch 11 Loss 14321.977890625
Epoch 12 Loss 14163.19466015625
Epoch 13 Loss 14015.21926953125
Epoch 14 Loss 13891.94984375
Epoch 15 Loss 13806.196984375
Epoch 16 Loss 13708.88112890625
Epoch 17 Loss 13650.2826796875
Epoch 18 Loss 13601.386671875
Epoch 19 Loss 13541.22569921875
Epoch 20 Loss 13486.61441015625
Epoch 21 Loss 13445.2815078125
Epoch 22 Loss 13424.01158984375
Epoch 23 Loss 13367.80201953125
Epoch 24 Loss 13348.44544921875
Epoch 25 Loss 13297.1557421875
Epoch 26 Loss 13296.330359375
Epoch 27 Loss 13251.82384375
Epoch 28 Loss 13231.8799609375
Epoch 29 Loss 13211.09778515625
Epoch 30 Loss 13188.4995390625
Epoch 31 Loss 13171.61262109375
Epoch 32 Loss 13153.49954

In [98]:
# Calculate error
def calculate_error(model, loader):
    model.eval()
    total_bce = 0.0
    total_mse = 0.0
    total_samples = 0
    with torch.no_grad():
        for data, _ in loader:
            data = data.view(data.size(0), -1)
            recon_data, _ = model(data)
            recon_data = recon_data.view(data.size(0), -1)
            bce = F.binary_cross_entropy(recon_data, data, reduction='sum')
            total_bce += bce.item()
            mse = F.mse_loss(recon_data, data, reduction='sum')
            total_mse += mse.item()
            total_samples += data.size(0)

    avg_bce = total_bce / total_samples
    avg_mse = total_mse / total_samples
    print(f"BCE: {avg_bce:.4f}, MSE: {avg_mse:.4f}")
    return avg_bce, avg_mse

calculate_error(ganae, test_loader)

ValueError: too many values to unpack (expected 2)

In [99]:
def evaluate_classification_gan_ae(model, loader):
    """
    Evaluate classification error using the trained GAN-AE model and a logistic regression classifier.

    Parameters:
    - model: Trained GAN-AE model with an encoder.
    - loader: DataLoader providing batches of (data, labels).

    Returns:
    - err: Classification error percentage.
    """
    model.eval()
    latents, labels = [], []
    
    with torch.no_grad():
        for data, label in loader:
            # Flatten data and move to the appropriate device
            data = data.view(data.size(0), -1).to(next(model.parameters()).device)
            label = label.to(next(model.parameters()).device)  # Ensure labels are also on the same device

            if label is not None:  # Collect latent representations and labels
                z = model.encoder(data)
                latents.append(z.cpu().numpy())
                labels.append(label.cpu().numpy())
    
    if not labels:  # Handle case where no labels are provided
        print("No labels provided for classification.")
        return None

    # Stack latents and labels into arrays
    latents = np.vstack(latents)
    labels = np.hstack(labels).reshape(-1)
    print(f"labels shape: {labels.shape}")
    print(f"latents shape: {latents.shape}")
    # Ensure consistent lengths of latents and labels
    min_len = min(len(latents), len(labels))
    latents = latents[:min_len]
    labels = labels[:min_len]

    # Split latent data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(latents, labels, test_size=0.3, random_state=42)

    # Train logistic regression on training split
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)

    # Evaluate on test split
    predictions = clf.predict(X_test)
    err = 100 * (1 - accuracy_score(y_test, predictions))
    print(f"Classification Error: {err:.2f}%")
    
    return err


In [100]:
evaluate_classification_gan_ae(ganae,test_loader)

AttributeError: 'tuple' object has no attribute 'cpu'

In [None]:
def evaluate_masked_mse(model, loader):
    """
    Evaluate the Masked Mean Squared Error (M-MSE) for a given model on a dataset.

    Parameters:
    - model: The trained model to evaluate.
    - loader: DataLoader providing batches of test data.

    Returns:
    - avg_mse: The average masked MSE over the entire dataset.
    """
    model.eval()  # Set the model to evaluation mode
    total_mse = 0.0  # Accumulate total masked MSE
    total_samples = 0  # Total number of samples processed

    with torch.no_grad():  # Disable gradient calculations for evaluation
        for data, _ in loader:
            # Flatten images into vectors (batch_size, D)
            data = data.view(data.size(0), -1)  # Shape: (batch_size, D)

            # Create a binary mask: Half of the columns are masked
            mask = torch.ones_like(data, dtype=torch.bool)  # Shape: (batch_size, D)
            mask[:, : data.size(1) // 2] = 0  # Mask the first half of the columns

            # Apply the mask to the input data
            masked_data = data * mask.float()  # Zero out the masked part
            
            # Forward pass through the model to get the reconstructed images
            output = model(masked_data)  # Assumes model outputs reconstructed data
            if isinstance(output, tuple):  # Extract reconstructed images if output is a tuple
                reconstructed = output[0]
            else:
                reconstructed = output

            # Ensure proper shape of reconstructed data
            reconstructed = reconstructed.view(data.size(0), -1)  # Shape: (batch_size, D)

            # Calculate the M-MSE for the current batch
            mse_batch = masked_mse(data, reconstructed, mask)  # Masked MSE for this batch
            
            # Accumulate results
            total_mse += mse_batch.item() * data.size(0)  # Multiply by batch size
            total_samples += data.size(0)  # Update total sample count

    # Compute the final average M-MSE across the dataset
    avg_mse = total_mse / total_samples
    print(f"Average Masked MSE: {avg_mse:.4f}")

    return avg_mse


# Helper function for batch-level Masked MSE
def masked_mse(x, x_hat, mask):
    """
    Compute the Masked Mean Squared Error (M-MSE) for a single batch.
    
    Parameters:
    - x: Original images (batch_size, D), where D is the number of pixels in each image.
    - x_hat: Reconstructed images (batch_size, D).
    - mask: Binary mask (batch_size, D), where 1 indicates unmasked pixels and 0 indicates masked pixels.

    Returns:
    - m_mse: Masked mean squared error for the batch.
    """
    # Compute squared error
    error = (x - x_hat) ** 2
    # Apply mask to focus only on masked-out regions
    masked_error = error * (1 - mask.float())  

    # Compute normalized MSE for the masked regions
    batch_mse = masked_error.sum(dim=1) / (1 - mask.float()).sum(dim=1)

    # Return average masked MSE for the batch
    return batch_mse.sum()


In [None]:
avg_mse = evaluate_masked_mse(ganae, test_loader)
print(avg_mse)

Average Masked MSE: 19.4237
19.423651809692384


In [None]:
import os
import sys
# Add the parent directory of GAN_Models to sys.path
sys.path.append(os.path.abspath("../../"))  # Adjust the path as necessary

from density.fit_gmm import fit_gmm
from density.eval_logpx import evaluate_logpx
def evaluate_density_model(model,train_loader, test_loader):
    gmm = fit_gmm(model=model,data_loader=train_loader,latent_dim=20)

    log_px = evaluate_logpx(model=model,gmm=gmm,data_loader=test_loader,latent_dim=20)

    print(f"Log-likelihood: {log_px:.4f}")
    return log_px

In [None]:
evaluate_density_model(ganae,test_loader)

TypeError: evaluate_density_model() missing 1 required positional argument: 'test_loader'

In [None]:

import time
from utils.calc_perc_error import evaluate_perc_err
def evaluate_model(model, train_loader, test_loader, latent_dim, n_components=75, num_samples=5000):
    """
    Evaluate the model on various metrics including BCE loss, classification error, GMM fitting, 
    Monte Carlo log-likelihood, and inference time.

    Args:
        model: The trained model to evaluate.
        train_loader: DataLoader providing the training dataset.
        test_loader: DataLoader providing the test dataset.
        latent_dim: Dimension of the latent space.
        n_components: Number of GMM components for fitting.
        num_samples: Number of Monte Carlo samples for log-likelihood estimation.

    Returns:
        results: A dictionary containing evaluation metrics and total inference time.
    """
    logging.info("Starting model evaluation...")
    inference_start_time = time.time()

    results = {}
    model.eval()

    logging.info("Calculating Binary Cross-Entropy (BCE) loss...")
    bce_losses = []

   
    # logging.info("Evaluating classification error (%Err)...")
    # err = evaluate_perc_err(model, train_loader, test_loader)
    # results['Classification_Error'] = err
    # logging.info(f"Classification error: {err:.4f}%")

    logging.info("Fitting GMM on latent space...")
    gmm = fit_gmm(train_loader, model, latent_dim=latent_dim, n_components=n_components)
    logging.info("Finished fitting GMM.")

    logging.info("Evaluating Monte Carlo log-likelihood...")
    test_logpx = evaluate_logpx(test_loader, model, gmm, latent_dim=latent_dim, num_samples=num_samples)
    results['Monte_Carlo_Log_Likelihood'] = test_logpx
    logging.info(f"Monte Carlo log-likelihood: {test_logpx:.4f}")

    total_inference_time = time.time() - inference_start_time
    results['Total_Inference_Time'] = total_inference_time
    logging.info(f"Total inference time: {total_inference_time:.2f} seconds")

    return results

In [None]:
evaluate_model(train_loader=train_loader,test_loader=test_loader,latent_dim=20,model=ganae,n_components=75,num_samples=5000)

2025-01-20 12:24:19,386 - INFO - Starting model evaluation...
2025-01-20 12:24:19,388 - INFO - Calculating Binary Cross-Entropy (BCE) loss...
2025-01-20 12:24:19,389 - INFO - Fitting GMM on latent space...


Extracting latent vectors from dataLoader using the provided model...
Collected latent data shape: torch.Size([50000, 20])


2025-01-20 12:27:29,410 - INFO - Finished fitting GMM.
2025-01-20 12:27:29,411 - INFO - Evaluating Monte Carlo log-likelihood...


Saving GMM to file: gmm.pkl


2025-01-20 12:42:27,152 - INFO - Monte Carlo log-likelihood: -276.2071
2025-01-20 12:42:27,153 - INFO - Total inference time: 1087.77 seconds


{'Monte_Carlo_Log_Likelihood': -276.20710555114744,
 'Total_Inference_Time': 1087.7655725479126}