# GAN: Generating MNIST Digits with a Fully Connected GAN (PyTorch)

In [11]:
# imports 
import os
import torch
import torchvision
import torch.nn as nn
from tqdm import tqdm
from torchvision import transforms
from torchvision.utils import save_image

# Report the library versions this notebook was executed with
for library in (torch, torchvision):
    print(f'{library.__name__} version is : {library.__version__}')

torch version is : 1.11.0+cpu
torchvision version is : 0.12.0+cpu


In [2]:
# Setting Hyper-parameters
seed = 42                # fixed RNG seed so repeated runs start from the same random state
epochs = 5               # number of full passes over the training set
latent_size = 64         # length of the random noise vector given to the generator
batch_size = 100         # images per mini-batch
hidden_size = 256        # width of the hidden layers of both networks
image_size = 28 * 28     # a 28x28 image flattened into a 784-element vector
sample_dir = 'samples'   # folder that receives the saved image grids

# Seed PyTorch's global random number generator before any weights or noise
# are drawn; without this every Restart & Run All gives different results.
torch.manual_seed(seed)

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# creating save images folder
# exist_ok=True makes the cell safe to re-run and removes the separate
# "check, then create" step (which could race and tested a different path
# expression than the one it created). A relative `sample_dir` is resolved
# against the current working directory, exactly as before.
os.makedirs(sample_dir, exist_ok=True)

In [5]:
# Preprocessing applied to every image: convert it to a tensor, then
# normalize each value as (x - 0.5) / 0.5 using mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(preprocessing_steps)

In [13]:
# Get the MNIST dataset (train and test splits)
# Both splits live under the same root directory.
mnist_root = '../basics/mnist'

# download=True fetches the files only when they are not already present under
# `mnist_root`, so the notebook also runs on a fresh checkout instead of
# failing because the dataset is missing.
train_dataset = torchvision.datasets.MNIST(
    root=mnist_root,
    train=True,
    download=True,
    transform=transform,
)

test_dataset = torchvision.datasets.MNIST(
    root=mnist_root,
    train=False,
    download=True,
    transform=transform,
)

In [14]:
# Define dataloading pipelines
# Training batches are reshuffled; the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [8]:
# Define the discriminator: flattened image -> single value in (0, 1)
discriminator_layers = [
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),  # one output unit, squashed by the sigmoid below
    nn.Sigmoid(),
]
D = nn.Sequential(*discriminator_layers)

# Define the generator: latent vector -> flattened image with values in (-1, 1)
generator_layers = [
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh(),  # tanh bounds every output value to (-1, 1)
]
G = nn.Sequential(*generator_layers)

In [9]:
# Move both networks onto the selected device
D, G = D.to(device), G.to(device)

In [18]:
# Define the loss and optimizer
# Binary cross-entropy is the criterion; each network gets its own Adam
# optimizer, both using the same learning rate.
learning_rate = 0.0002
loss = nn.BCELoss()
d_optimizer, g_optimizer = (
    torch.optim.Adam(network.parameters(), lr=learning_rate)
    for network in (D, G)
)

def denorm(x):
    """Rescale values from the [-1, 1] range back to [0, 1].

    Computes ``(x + 1) / 2`` and clamps the result so that anything falling
    outside [0, 1] is pinned to the nearest bound.
    """
    return ((x + 1) / 2).clamp(0, 1)

def reset_grad():
    """Zero the gradients tracked by both the discriminator and generator optimizers."""
    for optimizer in (d_optimizer, g_optimizer):
        optimizer.zero_grad()

In [20]:
# Starting the training
# Keep a handle on the progress bar so per-epoch metrics can be attached to it
progress = tqdm(range(epochs))
for epoch in progress:
    for images, _ in train_loader:
        # Use the real length of this batch instead of `batch_size`, so a
        # smaller final batch is still flattened and labelled correctly
        current_batch = images.size(0)

        # Flatten every image into a vector
        images = images.reshape(current_batch, -1).to(device)

        # Targets for BCELoss: 1 = real, 0 = fake (allocated directly on the device)
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        # ---------------- Train the discriminator ----------------
        # BCELoss on real images
        outputs = D(images)
        d_loss_real = loss(outputs, real_labels)
        real_score = outputs

        # BCELoss on fake images
        z = torch.randn(current_batch, latent_size, device=device)  # latent noise
        # detach(): this step only updates D, so do not backpropagate into G
        fake_images = G(z).detach()
        outputs = D(fake_images)
        d_loss_fake = loss(outputs, fake_labels)
        fake_score = outputs

        # Backpropagation and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ------------------ Train the generator ------------------
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = loss(outputs, real_labels)

        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

    # Show the metrics of the last batch of this epoch on the progress bar
    progress.set_postfix({
        'd_loss': d_loss.item(),
        'g_loss': g_loss.item(),
        'D(x)': real_score.mean().item(),
        'D(G(z))': fake_score.mean().item(),
    })

    # Save real images (once, after the first epoch) for visual comparison
    if epoch == 0:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save sampled images; detach() because the grid does not need the autograd graph
    fake_images = fake_images.detach().reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

100%|██████████| 5/5 [02:57<00:00, 35.41s/it]
