<a href="https://colab.research.google.com/github/kristopherpaul/ML_InductionAssignment_IITGN/blob/main/4_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **4 Implementation**

5. **Generative Adversarial Network** (*Hard*)\
GANs consist of two parts, a generator and a discriminator, which are trained together. The generator creates\
fake data (trying to mimic the real data distribution), and the discriminator tries to tell the difference between\
real and fake data. The training process involves alternately optimizing the generator and discriminator: the\
discriminator is trained to maximize its ability to classify real vs. fake, and the generator is trained to maximize\
the discriminator’s error on the fake data.

In [1]:
# Importing required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# Generator network
class Generator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Generator, self).__init__()
        self.layer1 = nn.Linear(input_size, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.layer3(x)
        return x

# Discriminator network
class Discriminator(nn.Module):
    def __init__(self, input_size, output_size):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Linear(input_size, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.sigmoid(self.layer3(x))
        return x

In [3]:
# Loading dataset and setting hyperparameters
input_size = 100
output_size = 784

batch_size = 128
num_epochs = 50
learning_rate = 0.001

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

generator = Generator(input_size, output_size)
discriminator = Discriminator(output_size, 1)

criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 150147567.32it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 33233386.51it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 35533349.76it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8718777.47it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
# Training loop
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)
        images = images.reshape(batch_size, -1)

        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        disc_optimizer.zero_grad()

        real_outputs = discriminator(images)
        real_loss = criterion(real_outputs, real_labels)

        noise = torch.randn(batch_size, input_size)
        fake_images = generator(noise)

        fake_outputs = discriminator(fake_images)
        fake_loss = criterion(fake_outputs, fake_labels)

        disc_loss = real_loss + fake_loss

        disc_loss.backward()
        disc_optimizer.step()

        gen_optimizer.zero_grad()

        noise = torch.randn(batch_size, input_size)
        fake_images = generator(noise)

        outputs = discriminator(fake_images)
        gen_loss = criterion(outputs, real_labels)

        gen_loss.backward()
        gen_optimizer.step()

        if (i + 1) % 400 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], "
                  f"Discriminator Loss: {disc_loss.item():.4f}, "
                  f"Generator Loss: {gen_loss.item():.4f}")

Epoch [1/50], Discriminator Loss: 1.7748, Generator Loss: 6.6779
Epoch [2/50], Discriminator Loss: 0.3485, Generator Loss: 3.9011
Epoch [3/50], Discriminator Loss: 0.6889, Generator Loss: 1.2477
Epoch [4/50], Discriminator Loss: 1.8770, Generator Loss: 0.4926
Epoch [5/50], Discriminator Loss: 0.8229, Generator Loss: 2.0163
Epoch [6/50], Discriminator Loss: 0.7264, Generator Loss: 3.1769
Epoch [7/50], Discriminator Loss: 1.5940, Generator Loss: 1.3521
Epoch [8/50], Discriminator Loss: 1.2722, Generator Loss: 0.7500
Epoch [9/50], Discriminator Loss: 1.9881, Generator Loss: 0.6117
Epoch [10/50], Discriminator Loss: 2.2096, Generator Loss: 2.2488
Epoch [11/50], Discriminator Loss: 1.6231, Generator Loss: 0.6037
Epoch [12/50], Discriminator Loss: 1.5974, Generator Loss: 1.5309
Epoch [13/50], Discriminator Loss: 0.8987, Generator Loss: 1.4746
Epoch [14/50], Discriminator Loss: 1.0130, Generator Loss: 1.0620
Epoch [15/50], Discriminator Loss: 1.5123, Generator Loss: 0.5193
Epoch [16/50], Disc