In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
from torchvision.transforms.functional import resize
from torch.nn.functional import interpolate
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import os
from torchvision.utils import save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [35]:
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),   # Accepts 3-channel RGB input
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(128 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(128 * 8 * 8, latent_dim)

    def forward(self, x):
        x = self.conv(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


In [36]:
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

In [37]:
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 128 * 8 * 8)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # (B, 64, 16, 16)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),    # (B, 3, 32, 32)
            nn.Sigmoid()  # Ensures output in [0,1]
        )

    def forward(self, z):
        x = self.fc(z).view(-1, 128, 8, 8)
        return self.deconv(x)


In [38]:
# Cell 3: Visualize Reconstructions
def visualize(epoch, x, x_hat, n=6):
    x = x[:n].cpu().detach().numpy()
    x_hat = x_hat[:n].cpu().detach().numpy()
    fig, axes = plt.subplots(2, n, figsize=(n*1.2, 3))
    for i in range(n):
        axes[0, i].imshow(x[i][0], cmap='gray')
        axes[0, i].axis('off')
        axes[1, i].imshow(x_hat[i][0], cmap='gray')
        axes[1, i].axis('off')
    plt.suptitle(f"Epoch {epoch+1}: Top - Original | Bottom - Reconstructed")
    plt.show()


In [39]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),     # (B, 3, 32, 32) -> (B, 64, 16, 16)
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),   # (B, 64, 16, 16) -> (B, 128, 8, 8)
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)


In [40]:
def train(model_components, dataloader, optimizers, epochs=10, save_dir='./eval_images'):
    encoder, decoder, discriminator = model_components
    opt_vae, opt_disc = optimizers
    vae_losses, d_losses = [], []

    real_dir = os.path.join(save_dir, 'real')
    fake_dir = os.path.join(save_dir, 'fake')
    os.makedirs(real_dir, exist_ok=True)
    os.makedirs(fake_dir, exist_ok=True)

    for epoch in range(epochs):
        for i, (x, _) in enumerate(dataloader):
            x = x.to(device)
            batch_size = x.size(0)

            mu, logvar = encoder(x)
            z = reparameterize(mu, logvar)
            x_hat = decoder(z)

            recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum') / batch_size
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
            d_fake = discriminator(x_hat)
            gan_loss = torch.mean(torch.log(d_fake + 1e-8))
            vae_loss = recon_loss + kl_div - 0.001 * gan_loss

            opt_vae.zero_grad()
            vae_loss.backward()
            opt_vae.step()

            d_real = discriminator(x)
            d_fake = discriminator(x_hat.detach())
            d_loss = -torch.mean(torch.log(d_real + 1e-8) + torch.log(1 - d_fake + 1e-8))

            opt_disc.zero_grad()
            d_loss.backward()
            opt_disc.step()

            vae_losses.append(vae_loss.item())
            d_losses.append(d_loss.item())

            # Save images from first batch each epoch
            # Save real inputs
            for idx in range(batch_size):
                save_image(x[idx].cpu(), os.path.join(real_dir, f'epoch{epoch+1}_img{idx+1}.png'), normalize=True)
            # Save generated outputs
            for idx in range(batch_size):
                save_image(x_hat[idx].cpu(), os.path.join(fake_dir, f'epoch{epoch+1}_img{idx+1}.png'), normalize=True)

        print(f"Epoch [{epoch+1}/{epochs}] | VAE Loss: {vae_loss.item():.4f} | D Loss: {d_loss.item():.4f}")
        visualize(epoch, x, x_hat)

    return vae_losses, d_losses

In [None]:
# Cell 5: Initialize and Run Training
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

latent_dim = 20
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(latent_dim).to(device)
discriminator = Discriminator().to(device)

opt_vae = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
opt_disc = optim.Adam(discriminator.parameters(), lr=1e-4)

vae_losses, d_losses = train((encoder, decoder, discriminator), dataloader, (opt_vae, opt_disc), epochs=100)


In [None]:
# Assuming 10 epochs
epochs = [i for i in range(1, len(vae_losses) + 1)]

# Ensure losses are on CPU and converted to numpy (if they are tensors)
vae_losses_np = [loss.detach().cpu().item() if torch.is_tensor(loss) else loss for loss in vae_losses]
d_losses_np = [loss.detach().cpu().item() if torch.is_tensor(loss) else loss for loss in d_losses]

# Plotting
plt.figure(figsize=(10, 5))
plt.plot(epochs, vae_losses_np, label='VAE Loss')
plt.plot(epochs, d_losses_np, label='Discriminator Loss')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Curves")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Cell 1: Imports
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score,
    recall_score, f1_score, roc_curve, auc
)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Cell 2: Evaluation Function
def evaluate_discriminator(discriminator, dataloader, encoder, decoder, phase='Train'):
    discriminator.eval()
    encoder.eval()
    decoder.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for x, _ in dataloader:
            x = x.to(device)
            batch_size = x.size(0)

            # Generate fake images
            mu, logvar = encoder(x)
            z = reparameterize(mu, logvar)
            x_fake = decoder(z)

            # Discriminator outputs
            real_scores = discriminator(x).view(-1)
            fake_scores = discriminator(x_fake).view(-1)

            y_true.extend([1] * batch_size + [0] * batch_size)  # 1 for real, 0 for fake
            y_pred.extend(real_scores.cpu().numpy().tolist() + fake_scores.cpu().numpy().tolist())

    # Threshold at 0.5 to get binary predictions
    y_pred_binary = [1 if p >= 0.5 else 0 for p in y_pred]

    # Metrics
    cm = confusion_matrix(y_true, y_pred_binary)
    acc = accuracy_score(y_true, y_pred_binary)
    prec = precision_score(y_true, y_pred_binary)
    rec = recall_score(y_true, y_pred_binary)
    f1 = f1_score(y_true, y_pred_binary)
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    print(f"\n=== {phase} Set Evaluation ===")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall:    {rec:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print(f"AUC:       {roc_auc:.4f}")
    print("Confusion Matrix:")
    print(cm)

    # Plot Confusion Matrix
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f"{phase} - Confusion Matrix")
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

    # Plot ROC Curve
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"{phase} - ROC Curve")
    plt.legend(loc='lower right')
    plt.grid()
    plt.show()

    discriminator.train()
    encoder.train()
    decoder.train()


In [None]:
# Cell 3: Evaluate on Train and Test Sets
test_dataset = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

evaluate_discriminator(discriminator, dataloader, encoder, decoder, phase='Train')
evaluate_discriminator(discriminator, test_loader, encoder, decoder, phase='Test')


In [None]:
# Reparameterization trick
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

In [41]:
import os
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from PIL import Image

# ----------------- CONFIG -----------------
latent_dim = 100
batch_size = 64
epochs = 10
save_dir = "eval_images"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------- DATA -------------------
transform = transforms.Compose([
    transforms.ToTensor()  # range [0, 1]
])

dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ----------------- MODELS -----------------
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(latent_dim).to(device)
discriminator = Discriminator().to(device)

# ----------------- OPTIMIZERS -----------------
opt_vae = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
opt_disc = optim.Adam(discriminator.parameters(), lr=1e-4)

# ----------------- TRAINING -----------------
os.makedirs(f"{save_dir}/real", exist_ok=True)
os.makedirs(f"{save_dir}/fake", exist_ok=True)
start_time = time.time()

image_id = 0  # for saving FID/IS image sets

for epoch in range(epochs):
    for i, (x, _) in enumerate(train_loader):
        x = x.to(device)
        batch_size = x.size(0)

        # === VAE Forward ===
        mu, logvar = encoder(x)
        z = reparameterize(mu, logvar)
        x_hat = decoder(z)

        recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum') / batch_size
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
        d_fake = discriminator(x_hat)
        gan_loss = torch.mean(torch.log(d_fake + 1e-8))

        vae_loss = recon_loss + kl_div - 0.001 * gan_loss

        opt_vae.zero_grad()
        vae_loss.backward()
        opt_vae.step()

        # === Discriminator ===
        d_real = discriminator(x)
        d_fake = discriminator(x_hat.detach())
        d_loss = -torch.mean(torch.log(d_real + 1e-8) + torch.log(1 - d_fake + 1e-8))

        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()

        # Save real and fake images (first N only)
        if epoch == epochs - 1 and image_id < 1000:
            for idx in range(min(batch_size, 1000 - image_id)):
                save_image(x[idx], f"{save_dir}/real/{image_id:04d}.png")
                save_image(x_hat[idx], f"{save_dir}/fake/{image_id:04d}.png")
                image_id += 1

    print(f"Epoch [{epoch+1}/{epochs}] | VAE Loss: {vae_loss.item():.4f} | D Loss: {d_loss.item():.4f}")

# ----------------- TIMER -----------------
end_time = time.time()
print(f"\nTotal Training Time: {(end_time - start_time)/60:.2f} minutes")

# ----------------- FID & IS Evaluation -----------------
def preprocess_fid_image(img_path):
    img = Image.open(img_path).convert("RGB").resize((299, 299))
    img_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes()))
    img_tensor = img_tensor.view(img.size[1], img.size[0], 3).permute(2, 0, 1)
    return img_tensor.unsqueeze(0)

def compute_fid_and_is(real_dir, fake_dir):
    fid = FrechetInceptionDistance(feature=2048).to(device)
    is_model = InceptionScore().to(device)

    for path in sorted(os.listdir(real_dir)):
        tensor = preprocess_fid_image(os.path.join(real_dir, path)).to(device)
        fid.update(tensor, real=True)

    for path in sorted(os.listdir(fake_dir)):
        tensor = preprocess_fid_image(os.path.join(fake_dir, path)).to(device)
        fid.update(tensor, real=False)
        is_model.update(tensor)

    fid_score = fid.compute().item()
    is_score, is_std = is_model.compute()
    return fid_score, is_score.item(), is_std.item()

fid, is_score, is_std = compute_fid_and_is(f"{save_dir}/real", f"{save_dir}/fake")
print(f"\nFID Score: {fid:.4f}")
print(f"Inception Score: {is_score:.4f} ± {is_std:.4f}")


Files already downloaded and verified
Epoch [1/10] | VAE Loss: 1806.4388 | D Loss: 0.4879
Epoch [2/10] | VAE Loss: 1864.8618 | D Loss: 0.4830
Epoch [3/10] | VAE Loss: 1881.4722 | D Loss: 0.1904
Epoch [4/10] | VAE Loss: 1791.9443 | D Loss: 0.0390
Epoch [5/10] | VAE Loss: 1767.8518 | D Loss: 0.0206
Epoch [6/10] | VAE Loss: 1883.4524 | D Loss: 0.0478
Epoch [7/10] | VAE Loss: 1736.4465 | D Loss: 0.0317
Epoch [8/10] | VAE Loss: 1767.6279 | D Loss: 0.0793
Epoch [9/10] | VAE Loss: 1878.8243 | D Loss: 0.0931
Epoch [10/10] | VAE Loss: 1806.6317 | D Loss: 0.0265

Total Training Time: 1.70 minutes


  img_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes()))



FID Score: 71.2428
Inception Score: 3.0169 ± 0.1106
