# imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import argparse
import sys
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score


In [2]:
# Adjust sys.argv to remove unwanted Jupyter arguments
sys.argv = sys.argv[:1]  # Keep only the script name, remove Jupyter's arguments

# Now proceed with argparse as usual
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=200, help="size of the batches")
parser.add_argument("--n_cpu", type=int, default=8, help="number of CPU threads for data loading")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=2000, help="interval between image sampling")

# Parse the arguments
opt = parser.parse_args()

img_shape = (opt.channels, opt.img_size, opt.img_size)
print("this is the image shape ", img_shape)
cuda = True if torch.cuda.is_available() else False

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if cuda:
    torch.cuda.manual_seed(42)


this is the image shape  (1, 28, 28)


In [3]:

# Data Loading Classes and Functions
class NumpyDataset(Dataset):
    def __init__(self, dataX, dataY=None):
        self.dataX = np.load(dataX)
        self.dataY = np.load(dataY) if dataY is not None else None

    def __len__(self):
        return len(self.dataX)

    def __getitem__(self, idx):
        data = torch.tensor(self.dataX[idx], dtype=torch.float32)
        label = (
            torch.tensor(self.dataY[idx], dtype=torch.long)
            if self.dataY is not None
            else None
        )
        return data, label


In [4]:
# Define file paths
dataX = "../../../data/mnist/trainX.npy"
dataY = "../../../data/mnist/trainY.npy"
devX = "../../../data/mnist/validX.npy"
devY = "../../../data/mnist/validY.npy"
testX = "../../../data/mnist/testX.npy"
testY = "../../../data/mnist/testY.npy"

# Create dataloaders
train_dataset = NumpyDataset(dataX, dataY)
train_loader = DataLoader(dataset=train_dataset, batch_size=200, shuffle=True)

dev_dataset = NumpyDataset(devX, devY)
dev_loader = DataLoader(dataset=dev_dataset, batch_size=200, shuffle=False)

test_dataset = NumpyDataset(testX, testY)
test_loader = DataLoader(dataset=test_dataset, batch_size=200, shuffle=False)



In [5]:

# --- Encoder (Same as VAE) ---
class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=360, latent_dim=20):
        super(Encoder, self).__init__()
        
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, latent_dim)

        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.weight, mean=0.0, std=sigma)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        z = self.fc3(x)  # Latent space representation
        return z


# --- Decoder (Generator) ---
class Decoder(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=360, output_dim=784):
        super(Decoder, self).__init__()
        
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.weight, mean=0.0, std=sigma)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, z):
        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = torch.sigmoid(self.fc3(z))  # Sigmoid activation in the output layer
        return z


# --- Discriminator (GAN) ---
class Discriminator(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=360):
        super(Discriminator, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)  # Output a single value for real/fake classification

        self._init_weights()

    def _init_weights(self, sigma=0.05):
        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.weight, mean=0.0, std=sigma)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # Sigmoid activation for binary output (real/fake)
        return x


# --- GAN-AE Model ---
class GANAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=360, latent_dim=20):
        super(GANAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        self.discriminator = Discriminator(input_dim, hidden_dim)

    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        real_or_fake = self.discriminator(x_recon)
        return x_recon, real_or_fake




In [6]:
# --- Loss Functions ---
def bce_loss(pred, target):
    return F.binary_cross_entropy(pred, target,reduction="sum")

def gan_ae_loss(recon_x, x, real_or_fake, real_label=1, fake_label=0):
    # Reconstruction loss
    recon_loss =bce_loss(recon_x, x)

    # Discriminator loss (binary cross entropy)
    d_loss_real = bce_loss(real_or_fake, torch.full_like(real_or_fake, real_label))
    d_loss_fake = bce_loss(real_or_fake, torch.full_like(real_or_fake, fake_label))

    return recon_loss + d_loss_real + d_loss_fake



In [7]:
# --- Training Loop ---
ganae = GANAE(input_dim=784)
optimizer = torch.optim.Adam(ganae.parameters(), lr=0.0002)

for epoch in range(opt.n_epochs):
    total_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()

        recon_data, real_or_fake = ganae(data)
        loss = gan_ae_loss(recon_data, data, real_or_fake)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() # * data.size(0)

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader.dataset):.4f}")


Epoch 1, Loss: 234.2925
Epoch 2, Loss: 137.9755
Epoch 3, Loss: 113.5640
Epoch 4, Loss: 101.9898
Epoch 5, Loss: 96.4930
Epoch 6, Loss: 92.3891
Epoch 7, Loss: 88.9045
Epoch 8, Loss: 86.3177
Epoch 9, Loss: 84.3004
Epoch 10, Loss: 82.5445
Epoch 11, Loss: 81.1324
Epoch 12, Loss: 80.0676
Epoch 13, Loss: 79.2167
Epoch 14, Loss: 78.4626
Epoch 15, Loss: 77.7974
Epoch 16, Loss: 77.2049
Epoch 17, Loss: 76.6436
Epoch 18, Loss: 76.1616
Epoch 19, Loss: 75.7200
Epoch 20, Loss: 75.2922
Epoch 21, Loss: 74.9039
Epoch 22, Loss: 74.5392
Epoch 23, Loss: 74.2301
Epoch 24, Loss: 73.9024
Epoch 25, Loss: 73.6023
Epoch 26, Loss: 73.3210
Epoch 27, Loss: 73.0784
Epoch 28, Loss: 72.8321
Epoch 29, Loss: 72.6034
Epoch 30, Loss: 72.3800
Epoch 31, Loss: 72.1748
Epoch 32, Loss: 71.9646
Epoch 33, Loss: 71.7872
Epoch 34, Loss: 71.6073
Epoch 35, Loss: 71.4411
Epoch 36, Loss: 71.2781
Epoch 37, Loss: 71.1312
Epoch 38, Loss: 70.9605
Epoch 39, Loss: 70.8261
Epoch 40, Loss: 70.6874
Epoch 41, Loss: 70.5599
Epoch 42, Loss: 70.42

In [8]:
# Calculate error
def calculate_error(model, loader):
    model.eval()
    total_bce = 0.0
    total_mse = 0.0
    total_samples = 0
    with torch.no_grad():
        for data, _ in loader:
            data = data.view(data.size(0), -1)
            recon_data, _ = model(data)
            recon_data = recon_data.view(data.size(0), -1)
            bce = F.binary_cross_entropy(recon_data, data, reduction='sum')
            total_bce += bce.item()
            mse = F.mse_loss(recon_data, data, reduction='sum')
            total_mse += mse.item()
            total_samples += data.size(0)

    avg_bce = total_bce / total_samples
    avg_mse = total_mse / total_samples
    print(f"BCE: {avg_bce:.4f}, MSE: {avg_mse:.4f}")
    return avg_bce, avg_mse

calculate_error(ganae, test_loader)

BCE: 68.3799, MSE: 6.3320


(68.37991318359374, 6.331951708984375)

In [9]:
def evaluate_classification_gan_ae(model, loader):
    """
    Evaluate classification error using the trained GAN-AE model and a logistic regression classifier.

    Parameters:
    - model: Trained GAN-AE model with an encoder.
    - loader: DataLoader providing batches of (data, labels).

    Returns:
    - err: Classification error percentage.
    """
    model.eval()
    latents, labels = [], []
    
    with torch.no_grad():
        for data, label in loader:
            # Flatten data and move to the appropriate device
            data = data.view(data.size(0), -1).to(next(model.parameters()).device)
            label = label.to(next(model.parameters()).device)  # Ensure labels are also on the same device

            if label is not None:  # Collect latent representations and labels
                z = model.encoder(data)
                latents.append(z.cpu().numpy())
                labels.append(label.cpu().numpy())
    
    if not labels:  # Handle case where no labels are provided
        print("No labels provided for classification.")
        return None

    # Stack latents and labels into arrays
    latents = np.vstack(latents)
    labels = np.hstack(labels).reshape(-1)
    print(f"labels shape: {labels.shape}")
    print(f"latents shape: {latents.shape}")
    # Ensure consistent lengths of latents and labels
    min_len = min(len(latents), len(labels))
    latents = latents[:min_len]
    labels = labels[:min_len]

    # Split latent data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(latents, labels, test_size=0.3, random_state=42)

    # Train logistic regression on training split
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)

    # Evaluate on test split
    predictions = clf.predict(X_test)
    err = 100 * (1 - accuracy_score(y_test, predictions))
    print(f"Classification Error: {err:.2f}%")
    
    return err


In [10]:
evaluate_classification_gan_ae(ganae,test_loader)

labels shape: (100000,)
latents shape: (10000, 20)
Classification Error: 10.53%


10.533333333333328

In [11]:
def evaluate_masked_mse(model, loader):
    """
    Evaluate the Masked Mean Squared Error (M-MSE) for a given model on a dataset.

    Parameters:
    - model: The trained model to evaluate.
    - loader: DataLoader providing batches of test data.

    Returns:
    - avg_mse: The average masked MSE over the entire dataset.
    """
    model.eval()  # Set the model to evaluation mode
    total_mse = 0.0  # Accumulate total masked MSE
    total_samples = 0  # Total number of samples processed

    with torch.no_grad():  # Disable gradient calculations for evaluation
        for data, _ in loader:
            # Flatten images into vectors (batch_size, D)
            data = data.view(data.size(0), -1)  # Shape: (batch_size, D)

            # Create a binary mask: Half of the columns are masked
            mask = torch.ones_like(data, dtype=torch.bool)  # Shape: (batch_size, D)
            mask[:, : data.size(1) // 2] = 0  # Mask the first half of the columns

            # Apply the mask to the input data
            masked_data = data * mask.float()  # Zero out the masked part
            
            # Forward pass through the model to get the reconstructed images
            output = model(masked_data)  # Assumes model outputs reconstructed data
            if isinstance(output, tuple):  # Extract reconstructed images if output is a tuple
                reconstructed = output[0]
            else:
                reconstructed = output

            # Ensure proper shape of reconstructed data
            reconstructed = reconstructed.view(data.size(0), -1)  # Shape: (batch_size, D)

            # Calculate the M-MSE for the current batch
            mse_batch = masked_mse(data, reconstructed, mask)  # Masked MSE for this batch
            
            # Accumulate results
            total_mse += mse_batch.item() * data.size(0)  # Multiply by batch size
            total_samples += data.size(0)  # Update total sample count

    # Compute the final average M-MSE across the dataset
    avg_mse = total_mse / total_samples
    print(f"Average Masked MSE: {avg_mse:.4f}")

    return avg_mse


# Helper function for batch-level Masked MSE
def masked_mse(x, x_hat, mask):
    """
    Compute the Masked Mean Squared Error (M-MSE) for a single batch.
    
    Parameters:
    - x: Original images (batch_size, D), where D is the number of pixels in each image.
    - x_hat: Reconstructed images (batch_size, D).
    - mask: Binary mask (batch_size, D), where 1 indicates unmasked pixels and 0 indicates masked pixels.

    Returns:
    - m_mse: Masked mean squared error for the batch.
    """
    # Compute squared error
    error = (x - x_hat) ** 2
    # Apply mask to focus only on masked-out regions
    masked_error = error * (1 - mask.float())  

    # Compute normalized MSE for the masked regions
    batch_mse = masked_error.sum(dim=1) / (1 - mask.float()).sum(dim=1)

    # Return average masked MSE for the batch
    return batch_mse.sum()


In [12]:
avg_mse = evaluate_masked_mse(ganae, test_loader)
print(avg_mse)

Average Masked MSE: 19.4237
19.423651809692384
