In [68]:
import torch
import numpy as np
import random
import os
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.datasets import MNIST
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader 
from torch import nn, optim
from torch.autograd import Variable
from torchvision.utils import save_image


Latent_noise_dim = 50 # (z) 
batch_size = 64
learning_rate = 0.0001
epochs = 100
optimizer = optim.Adam

## Data

In [69]:
# Normalization: x -> (x - mean) / std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # single channel, so one value
])

train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=False
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)
image, label = train_dataset[10]
print(image.min(), image.max())

tensor(-1.) tensor(0.9922)


## Prior On Noise

In [70]:
def p_z(z):
    return torch.normal(0, 1, size=(z, Latent_noise_dim))

## Generator and Discriminator Network/function

In [71]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        
        self.img_shape = img_shape
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            # *block(256, 512),
            *block(256, 512),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh()
        )
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img
    
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


In [72]:
def discriminator_loss(D, real_imgs, fake_imgs):

    real_preds = D(real_imgs)
    fake_preds = D(fake_imgs)

    eps = 1e-8
    real_loss = -torch.mean(torch.log(real_preds + eps))
    fake_loss = -torch.mean(torch.log(1 - fake_preds + eps))

    return real_loss + fake_loss


def generator_loss(D, fake_imgs, non_saturating=True):
    fake_preds = D(fake_imgs)
    eps = 1e-8

    if non_saturating:
        g_loss = -torch.mean(torch.log(fake_preds + eps))
    else:
        g_loss = torch.mean(torch.log(1 - fake_preds + eps))

    return g_loss

## Training Loop

In [73]:
D = Discriminator((1, 28, 28))
G = Generator(Latent_noise_dim, (1, 28, 28))
optimizer_D = optimizer(D.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_G = optimizer(G.parameters(), lr=learning_rate, betas=(0.5, 0.999))

total_params = sum(p.numel() for p in G.parameters())
total_params_D = sum(p.numel() for p in D.parameters())
print(f"Total parameters in G: {total_params}")
print(f"Total parameters in D: {total_params_D}")

Total parameters in G: 574864
Total parameters in D: 533505


In [74]:
# d_losses, g_losses = [], []

# plt.ion()  # interactive mode
# fig, ax = plt.subplots()

# for epoch in range(epochs):
#     for i, (imgs, _) in enumerate(train_loader):
        
#         batch_size = imgs.size(0)
#         real_imgs = imgs.type(torch.FloatTensor)
        
#         z = p_z(batch_size)
#         fake_imgs = G(z).detach()   
#         d_loss = discriminator_loss(D, real_imgs, fake_imgs)

#         optimizer_D.zero_grad()
#         d_loss.backward()
#         optimizer_D.step()

#         # --- Train G ---
#         z = p_z(batch_size)
#         fake_imgs = G(z)
#         g_loss = generator_loss(D, fake_imgs, non_saturating=True)

#         optimizer_G.zero_grad()
#         g_loss.backward()
#         optimizer_G.step()
    
#     # Save losses
#     d_losses.append(d_loss.item())
#     g_losses.append(g_loss.item())

#     # Print losses
#     print(f"Epoch [{epoch+1}/{epochs}]  D_loss: {d_loss.item():.4f}  G_loss: {g_loss.item():.4f}")

#     # Update live plot
#     # ax.clear()
#     # ax.plot(d_losses, label="D loss")
#     # ax.plot(g_losses, label="G loss")
#     # ax.legend()
#     # ax.set_xlabel("Epochs")
#     # ax.set_ylabel("Loss")
#     # plt.pause(0.01)

#     # Save samples
#     # if epoch % 100 == 0:
#     save_image(fake_imgs.data[:25], f"./images/{epoch}.png", nrow=5, normalize=True)

# # plt.ioff()
# # plt.show()


# plt.plot(d_losses, label="D loss")
# plt.plot(g_losses, label="G loss")
# plt.xlabel("Epochs")
# plt.ylabel("Loss")
# plt.legend()
# plt.show()



In [75]:
import matplotlib
matplotlib.use('TkAgg')  # Use an interactive backend; change to 'Qt5Agg' or others if needed
import matplotlib.pyplot as plt
from IPython.display import clear_output  # Optional: For Jupyter environments; comment out if not in Jupyter


d_losses, g_losses = [], []  # Lists to store losses (last batch per epoch; for averages, accumulate inside inner loop)

plt.ion()  # Interactive mode on
fig, ax = plt.subplots()


for epoch in range(epochs):
    epoch_d_loss = 0.0
    epoch_g_loss = 0.0
    num_batches = 0
    
    for i, (imgs, _) in enumerate(train_loader):
        batch_size = imgs.size(0)
        real_imgs = imgs.view(batch_size, -1).type(torch.FloatTensor)  # Add .to(device) if using GPU
        
        # Train D
        z = p_z(batch_size)  # e.g., torch.randn(batch_size, latent_dim)
        fake_imgs = G(z).detach()
        d_loss = discriminator_loss(D, real_imgs, fake_imgs)
        
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train G
        z = p_z(batch_size)
        fake_imgs = G(z)
        g_loss = generator_loss(D, fake_imgs, non_saturating=True)
        
        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()
        
        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()
        num_batches += 1
    
    avg_d_loss = epoch_d_loss / num_batches
    avg_g_loss = epoch_g_loss / num_batches
    d_losses.append(avg_d_loss)
    g_losses.append(avg_g_loss)

    print(f"Epoch [{epoch+1}/{epochs}]  D_loss: {avg_d_loss:.4f}  G_loss: {avg_g_loss:.4f}")

    ax.clear()
    ax.plot(d_losses, label="D loss", color="blue")
    ax.plot(g_losses, label="G loss", color="orange")
    ax.legend()
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    fig.canvas.draw()  # Explicit redraw
    fig.canvas.flush_events()  # Flush events
    plt.pause(0.1)  # Pause for visibility (adjust as needed)

    # Alternative for Jupyter: Use clear_output for live updates
    # clear_output(wait=True)
    # plt.figure()
    # plt.plot(d_losses, label="D loss", color="blue")
    # plt.plot(g_losses, label="G loss", color="orange")
    # plt.xlabel("Epochs")
    # plt.ylabel("Loss")
    # plt.legend()
    # plt.show()

    # Save samples (uncomment and adjust frequency)
    if (epoch + 1) % 10 == 0 or epoch == 0:
        with torch.no_grad():
            z = p_z(25)  
            fake_imgs = G(z).view(25, 1, 28, 28)  # Reshape for save_image (assuming MNIST)
            save_image(fake_imgs, f"./images/epoch_{epoch+1}.png", nrow=5, normalize=True)

plt.ioff()
plt.figure()  # New figure to avoid overriding the interactive one
plt.plot(d_losses, label="D loss", color="blue")
plt.plot(g_losses, label="G loss", color="orange")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

Epoch [1/100]  D_loss: 0.7938  G_loss: 1.3375
Epoch [2/100]  D_loss: 0.7299  G_loss: 1.8334
Epoch [3/100]  D_loss: 0.6699  G_loss: 2.0184
Epoch [4/100]  D_loss: 0.6427  G_loss: 2.1498
Epoch [5/100]  D_loss: 0.6353  G_loss: 2.2568
Epoch [6/100]  D_loss: 0.6127  G_loss: 2.3324
Epoch [7/100]  D_loss: 0.6140  G_loss: 2.3782
Epoch [8/100]  D_loss: 0.5638  G_loss: 2.4750
Epoch [9/100]  D_loss: 0.5923  G_loss: 2.4516
Epoch [10/100]  D_loss: 0.5277  G_loss: 2.6514
Epoch [11/100]  D_loss: 0.5216  G_loss: 2.7197
Epoch [12/100]  D_loss: 0.5600  G_loss: 2.6822
Epoch [13/100]  D_loss: 0.5850  G_loss: 2.5306
Epoch [14/100]  D_loss: 0.6041  G_loss: 2.3700
Epoch [15/100]  D_loss: 0.6465  G_loss: 2.3291
Epoch [16/100]  D_loss: 0.6342  G_loss: 2.2908
Epoch [17/100]  D_loss: 0.6898  G_loss: 2.1911
Epoch [18/100]  D_loss: 0.7554  G_loss: 1.9582
Epoch [19/100]  D_loss: 0.7571  G_loss: 1.9175
Epoch [20/100]  D_loss: 0.7765  G_loss: 1.8549
Epoch [21/100]  D_loss: 0.8003  G_loss: 1.8154
Epoch [22/100]  D_loss