### Task 2.1

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader

import os

# Device configuration: train on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_size = 64        # length of the noise vector fed to the generator
hidden_size = 256       # width of the hidden layers in both networks
image_size = 784        # a 28x28 image flattened to a vector
num_epochs = 100
batch_size = 100
learning_rate = 0.0002
sample_dir = './gan_images'  # where real/generated sample grids are written

# MNIST dataset. Normalize with mean 0.5 / std 0.5 computes (x - 0.5) / 0.5,
# which puts pixels in the same [-1, 1] range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

mnist = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)

# Discriminator network: flattened image -> probability that it is real.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid()
).to(device)

# Generator network: noise vector -> flattened image in [-1, 1] (Tanh).
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
).to(device)

# Loss function and optimizers (one optimizer per network).
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)


def reset_grad():
    """Zero the accumulated gradients of both the discriminator and generator."""
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1] so it can be saved as an image.

    The result is detached from the autograd graph and clamped to [0, 1].
    """
    return ((x.detach() + 1) / 2).clamp(0, 1)


# Create the sample directory once, before training starts.
os.makedirs(sample_dir, exist_ok=True)

# Training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Use the actual size of this batch instead of the configured
        # batch_size so a smaller final batch does not break the reshape
        # or mismatch the label tensors.
        current_batch = images.size(0)
        images = images.reshape(current_batch, -1).to(device)

        # Targets for the BCE loss: 1 for "real", 0 for "fake".
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        # ---- Discriminator training ----
        # Real images should be classified as real.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Generated images should be classified as fake. The fake batch is
        # detached because only D is updated in this step; there is no need
        # to backpropagate into G.
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator training ----
        # G is rewarded when D labels its samples as real, so the fake
        # batch is scored against real_labels.
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)

        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize. reset_grad() also clears the gradients this
        # backward pass leaves on D, which is not stepped here.
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))

    # Save one grid of real images as a reference, after the first epoch only.
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        utils.save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save the generator's latest samples after every epoch so progress can
    # be compared across epochs (the file name carries the epoch number).
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    utils.save_image(denorm(fake_images),
                     os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints
torch.save(G.state_dict(), './gan_generator.pth')
torch.save(D.state_dict(), './gan_discriminator.pth')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|████████████████████████████| 9912422/9912422 [00:03<00:00, 2999933.42it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 28881/28881 [00:00<00:00, 226880.28it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|████████████████████████████| 1648877/1648877 [00:00<00:00, 2187655.24it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 1711945.43it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [0/100], Step [200/600], d_loss: 0.0481, g_loss: 4.4430, D(x): 0.99, D(G(z)): 0.04
Epoch [0/100], Step [400/600], d_loss: 0.1597, g_loss: 5.6715, D(x): 0.93, D(G(z)): 0.07
Epoch [0/100], Step [600/600], d_loss: 0.0442, g_loss: 5.3766, D(x): 0.98, D(G(z)): 0.03
Epoch [1/100], Step [200/600], d_loss: 0.0715, g_loss: 3.9779, D(x): 0.98, D(G(z)): 0.05
Epoch [1/100], Step [400/600], d_loss: 0.0904, g_loss: 4.5770, D(x): 0.95, D(G(z)): 0.03
Epoch [1/100], Step [600/600], d_loss: 0.1722, g_loss: 4.1302, D(x): 0.92, D(G(z)): 0.05
Epoch [2/100], Step [200/600], d_loss: 0.1638, g_loss: 5.1826, D(x): 0.92, D(G(z)): 0.05
Epoch [2/100], Step [400/600], d_loss: 0.1052, g_loss: 3.8618, D(x): 0.97, D(G(z)): 0.06
Epoch [2/100], Step [600/600], d_loss: 0.4449, g_loss: 3.1442, D(x): 0.86, D(G(z)): 0.20
Epoch [3/100], Step [200/600], d_loss: 0.3947, g_loss: 2.8401, D(x): 0.86, D(G(z)): 0.13
Epoch [3/100], Step [400/600], d_lo

Epoch [30/100], Step [400/600], d_loss: 0.3112, g_loss: 4.6060, D(x): 0.94, D(G(z)): 0.14
Epoch [30/100], Step [600/600], d_loss: 0.4899, g_loss: 4.0179, D(x): 0.87, D(G(z)): 0.13
Epoch [31/100], Step [200/600], d_loss: 0.2800, g_loss: 3.2951, D(x): 0.89, D(G(z)): 0.05
Epoch [31/100], Step [400/600], d_loss: 0.4534, g_loss: 3.0576, D(x): 0.92, D(G(z)): 0.17
Epoch [31/100], Step [600/600], d_loss: 0.4844, g_loss: 4.8579, D(x): 0.90, D(G(z)): 0.16
Epoch [32/100], Step [200/600], d_loss: 0.3589, g_loss: 3.7555, D(x): 0.92, D(G(z)): 0.18
Epoch [32/100], Step [400/600], d_loss: 0.5010, g_loss: 3.3608, D(x): 0.89, D(G(z)): 0.21
Epoch [32/100], Step [600/600], d_loss: 0.4247, g_loss: 4.0754, D(x): 0.87, D(G(z)): 0.11
Epoch [33/100], Step [200/600], d_loss: 0.5202, g_loss: 4.1145, D(x): 0.80, D(G(z)): 0.08
Epoch [33/100], Step [400/600], d_loss: 0.2732, g_loss: 3.4368, D(x): 0.94, D(G(z)): 0.14
Epoch [33/100], Step [600/600], d_loss: 0.2552, g_loss: 4.0396, D(x): 0.93, D(G(z)): 0.10
Epoch [34/

Epoch [61/100], Step [200/600], d_loss: 0.7831, g_loss: 1.9867, D(x): 0.79, D(G(z)): 0.27
Epoch [61/100], Step [400/600], d_loss: 0.5152, g_loss: 2.1943, D(x): 0.87, D(G(z)): 0.21
Epoch [61/100], Step [600/600], d_loss: 0.7426, g_loss: 2.0463, D(x): 0.78, D(G(z)): 0.20
Epoch [62/100], Step [200/600], d_loss: 0.6889, g_loss: 2.0327, D(x): 0.75, D(G(z)): 0.18
Epoch [62/100], Step [400/600], d_loss: 0.6913, g_loss: 2.3141, D(x): 0.72, D(G(z)): 0.15
Epoch [62/100], Step [600/600], d_loss: 0.6942, g_loss: 2.2723, D(x): 0.82, D(G(z)): 0.28
Epoch [63/100], Step [200/600], d_loss: 0.2940, g_loss: 3.8293, D(x): 0.88, D(G(z)): 0.10
Epoch [63/100], Step [400/600], d_loss: 0.4717, g_loss: 2.8286, D(x): 0.84, D(G(z)): 0.17
Epoch [63/100], Step [600/600], d_loss: 0.5606, g_loss: 2.6182, D(x): 0.84, D(G(z)): 0.21
Epoch [64/100], Step [200/600], d_loss: 0.6040, g_loss: 2.3600, D(x): 0.86, D(G(z)): 0.27
Epoch [64/100], Step [400/600], d_loss: 0.6813, g_loss: 2.2488, D(x): 0.75, D(G(z)): 0.16
Epoch [64/

Epoch [91/100], Step [600/600], d_loss: 1.0600, g_loss: 2.2009, D(x): 0.69, D(G(z)): 0.27
Epoch [92/100], Step [200/600], d_loss: 0.8171, g_loss: 2.7028, D(x): 0.68, D(G(z)): 0.16
Epoch [92/100], Step [400/600], d_loss: 0.8279, g_loss: 2.1755, D(x): 0.74, D(G(z)): 0.24
Epoch [92/100], Step [600/600], d_loss: 0.8592, g_loss: 2.2217, D(x): 0.65, D(G(z)): 0.17
Epoch [93/100], Step [200/600], d_loss: 0.7624, g_loss: 1.9484, D(x): 0.79, D(G(z)): 0.30
Epoch [93/100], Step [400/600], d_loss: 0.8328, g_loss: 1.8992, D(x): 0.70, D(G(z)): 0.21
Epoch [93/100], Step [600/600], d_loss: 0.7521, g_loss: 1.8613, D(x): 0.78, D(G(z)): 0.29
Epoch [94/100], Step [200/600], d_loss: 0.7335, g_loss: 2.0091, D(x): 0.72, D(G(z)): 0.20
Epoch [94/100], Step [400/600], d_loss: 0.6924, g_loss: 1.4274, D(x): 0.85, D(G(z)): 0.28
Epoch [94/100], Step [600/600], d_loss: 0.6807, g_loss: 1.7823, D(x): 0.80, D(G(z)): 0.24
Epoch [95/100], Step [200/600], d_loss: 0.9312, g_loss: 1.8717, D(x): 0.71, D(G(z)): 0.30
Epoch [95/

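d_loss: 0.8115 is the discriminator loss: the sum of the binary cross-entropy on the real batch and on the fake batch. It is a loss, not an accuracy. Lower values mean the discriminator separates real from fake images more confidently; a discriminator that outputs 0.5 for every image would score 2·ln 2 ≈ 1.386.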

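g_loss: 1.5439 is the generator loss: the binary cross-entropy of D(G(z)) against the "real" label, i.e. the mean of -log D(G(z)). A lower value means the generator is fooling the discriminator more often; it is not a direct measure of image quality.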

D(x): 0.79 is the probability assigned to real images being real.

D(G(z)): 0.29 is the probability assigned to fake images being real.

### Task 2.2

In [3]:
num_epochs = 100  # same number of epochs as the Task 2.1 run

# NOTE(review): D and G are not re-created in this cell, so training resumes
# from whatever weights they currently hold rather than starting fresh.
# The criterion below is nn.BCELoss, while the output file names say
# "logistic" -- confirm which objective this task was meant to use and
# whether the networks should be re-initialised for a fair comparison.

# Loss function and optimizers (fresh Adam state for both networks).
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)


def reset_grad():
    """Zero the accumulated gradients of both the discriminator and generator."""
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1] so it can be saved as an image.

    The result is detached from the autograd graph and clamped to [0, 1].
    """
    return ((x.detach() + 1) / 2).clamp(0, 1)


# Make sure the output directory exists even if this cell is run on its own.
os.makedirs('./gan_images', exist_ok=True)

# Training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Use the actual size of this batch instead of the configured
        # batch_size so a smaller final batch does not break the reshape
        # or mismatch the label tensors.
        current_batch = images.size(0)
        images = images.reshape(current_batch, -1).to(device)

        # Targets for the BCE loss: 1 for "real", 0 for "fake".
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        # ---- Discriminator training ----
        # Real images should be classified as real.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Generated images should be classified as fake. The fake batch is
        # detached because only D is updated in this step.
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator training ----
        # G is rewarded when D labels its samples as real, so the fake
        # batch is scored against real_labels.
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)

        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize. reset_grad() also clears the gradients this
        # backward pass leaves on D, which is not stepped here.
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))

    # Save one grid of real images as a reference, after the first epoch only.
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        utils.save_image(denorm(images), './gan_images/real_images_logistic.png')

    # Save the generator's latest samples after every epoch so progress can
    # be compared across epochs (the file name carries the epoch number).
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    utils.save_image(denorm(fake_images),
                     './gan_images/fake_images_logistic-{}.png'.format(epoch+1))

# Save the model checkpoints
torch.save(G.state_dict(), './gan_generator_logistic.pth')
torch.save(D.state_dict(), './gan_discriminator_logistic.pth')


Epoch [0/100], Step [200/600], d_loss: 0.7688, g_loss: 1.6820, D(x): 0.80, D(G(z)): 0.31
Epoch [0/100], Step [400/600], d_loss: 0.9130, g_loss: 1.8409, D(x): 0.80, D(G(z)): 0.35
Epoch [0/100], Step [600/600], d_loss: 0.7252, g_loss: 2.0741, D(x): 0.73, D(G(z)): 0.22
Epoch [1/100], Step [200/600], d_loss: 0.9458, g_loss: 2.0372, D(x): 0.79, D(G(z)): 0.37
Epoch [1/100], Step [400/600], d_loss: 0.8088, g_loss: 1.5339, D(x): 0.69, D(G(z)): 0.21
Epoch [1/100], Step [600/600], d_loss: 1.0222, g_loss: 1.5385, D(x): 0.68, D(G(z)): 0.32
Epoch [2/100], Step [200/600], d_loss: 0.9073, g_loss: 1.8366, D(x): 0.74, D(G(z)): 0.31
Epoch [2/100], Step [400/600], d_loss: 0.8170, g_loss: 2.1253, D(x): 0.74, D(G(z)): 0.25
Epoch [2/100], Step [600/600], d_loss: 0.8946, g_loss: 2.0536, D(x): 0.83, D(G(z)): 0.36
Epoch [3/100], Step [200/600], d_loss: 1.1057, g_loss: 1.7813, D(x): 0.68, D(G(z)): 0.31
Epoch [3/100], Step [400/600], d_loss: 0.8417, g_loss: 2.1136, D(x): 0.74, D(G(z)): 0.26
Epoch [3/100], Step [

Epoch [30/100], Step [600/600], d_loss: 0.8619, g_loss: 1.4409, D(x): 0.73, D(G(z)): 0.30
Epoch [31/100], Step [200/600], d_loss: 0.8543, g_loss: 1.4394, D(x): 0.74, D(G(z)): 0.31
Epoch [31/100], Step [400/600], d_loss: 0.8961, g_loss: 1.6423, D(x): 0.69, D(G(z)): 0.26
Epoch [31/100], Step [600/600], d_loss: 0.9563, g_loss: 1.6637, D(x): 0.77, D(G(z)): 0.38
Epoch [32/100], Step [200/600], d_loss: 0.9746, g_loss: 1.7165, D(x): 0.76, D(G(z)): 0.39
Epoch [32/100], Step [400/600], d_loss: 0.7978, g_loss: 1.5978, D(x): 0.75, D(G(z)): 0.28
Epoch [32/100], Step [600/600], d_loss: 0.9420, g_loss: 1.4732, D(x): 0.68, D(G(z)): 0.28
Epoch [33/100], Step [200/600], d_loss: 0.8002, g_loss: 1.5392, D(x): 0.75, D(G(z)): 0.29
Epoch [33/100], Step [400/600], d_loss: 0.9846, g_loss: 2.0429, D(x): 0.59, D(G(z)): 0.22
Epoch [33/100], Step [600/600], d_loss: 0.7704, g_loss: 1.5918, D(x): 0.72, D(G(z)): 0.27
Epoch [34/100], Step [200/600], d_loss: 1.1233, g_loss: 1.5342, D(x): 0.58, D(G(z)): 0.26
Epoch [34/

Epoch [61/100], Step [400/600], d_loss: 0.9861, g_loss: 1.3172, D(x): 0.65, D(G(z)): 0.28
Epoch [61/100], Step [600/600], d_loss: 1.0333, g_loss: 1.5434, D(x): 0.74, D(G(z)): 0.41
Epoch [62/100], Step [200/600], d_loss: 0.9590, g_loss: 1.6353, D(x): 0.68, D(G(z)): 0.30
Epoch [62/100], Step [400/600], d_loss: 1.0005, g_loss: 1.4060, D(x): 0.62, D(G(z)): 0.25
Epoch [62/100], Step [600/600], d_loss: 1.0026, g_loss: 1.3337, D(x): 0.70, D(G(z)): 0.35
Epoch [63/100], Step [200/600], d_loss: 1.0545, g_loss: 1.2920, D(x): 0.59, D(G(z)): 0.24
Epoch [63/100], Step [400/600], d_loss: 0.9754, g_loss: 1.4810, D(x): 0.68, D(G(z)): 0.33
Epoch [63/100], Step [600/600], d_loss: 0.9235, g_loss: 1.5499, D(x): 0.69, D(G(z)): 0.30
Epoch [64/100], Step [200/600], d_loss: 0.9178, g_loss: 1.4232, D(x): 0.77, D(G(z)): 0.37
Epoch [64/100], Step [400/600], d_loss: 1.0763, g_loss: 1.8163, D(x): 0.65, D(G(z)): 0.31
Epoch [64/100], Step [600/600], d_loss: 0.9100, g_loss: 1.7336, D(x): 0.66, D(G(z)): 0.23
Epoch [65/

Epoch [92/100], Step [200/600], d_loss: 0.8334, g_loss: 1.6184, D(x): 0.66, D(G(z)): 0.25
Epoch [92/100], Step [400/600], d_loss: 1.0136, g_loss: 1.4062, D(x): 0.62, D(G(z)): 0.27
Epoch [92/100], Step [600/600], d_loss: 1.0291, g_loss: 1.6854, D(x): 0.69, D(G(z)): 0.36
Epoch [93/100], Step [200/600], d_loss: 0.8547, g_loss: 1.6099, D(x): 0.74, D(G(z)): 0.30
Epoch [93/100], Step [400/600], d_loss: 1.0625, g_loss: 1.4299, D(x): 0.62, D(G(z)): 0.27
Epoch [93/100], Step [600/600], d_loss: 0.9319, g_loss: 1.5135, D(x): 0.70, D(G(z)): 0.34
Epoch [94/100], Step [200/600], d_loss: 1.0453, g_loss: 1.3829, D(x): 0.66, D(G(z)): 0.34
Epoch [94/100], Step [400/600], d_loss: 0.9918, g_loss: 1.6047, D(x): 0.70, D(G(z)): 0.35
Epoch [94/100], Step [600/600], d_loss: 1.0346, g_loss: 1.4798, D(x): 0.73, D(G(z)): 0.38
Epoch [95/100], Step [200/600], d_loss: 0.9322, g_loss: 1.4987, D(x): 0.67, D(G(z)): 0.30
Epoch [95/100], Step [400/600], d_loss: 0.9445, g_loss: 1.3476, D(x): 0.66, D(G(z)): 0.26
Epoch [95/

#### Original GAN Code result:
Epoch [99/100], Step [600/600], d_loss: 0.8115, g_loss: 1.5439, D(x): 0.79, D(G(z)): 0.29

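d_loss: 0.8115 is the discriminator loss: the sum of the binary cross-entropy on the real batch and on the fake batch. It is a loss, not an accuracy. Lower values mean the discriminator separates real from fake images more confidently; a discriminator that outputs 0.5 for every image would score 2·ln 2 ≈ 1.386.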

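g_loss: 1.5439 is the generator loss: the binary cross-entropy of D(G(z)) against the "real" label, i.e. the mean of -log D(G(z)). A lower value means the generator is fooling the discriminator more often; it is not a direct measure of image quality.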

D(x): 0.79 is the probability assigned to real images being real.

D(G(z)): 0.29 is the probability assigned to fake images being real.
                
#### With BCE Loss result:
Epoch [99/100], Step [600/600], d_loss: 0.7678, g_loss: 2.0430, D(x): 0.71, D(G(z)): 0.24