In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard



In [5]:
# returns value bw 0 to 1, 0 means fake img, 1 means real img
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim,128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.disc(x)
        
# returns image generated
class Generator(nn.Module):
    def __init__(self, noise_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(noise_dim,256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh() # because we will normalize input, so it will make sense only,
#                     if we normalize output
        )
        
    def forward(self, img):
        return self.gen(img)
        


In [6]:
# hper-parameters
lr_rate = 3e-2
noise_dim = 256
img_dim = 28*28*1 # 784
batch_size = 64
num_epochs = 5

disc = Discriminator(img_dim)
gen = Generator(noise_dim, img_dim)
fixed_noise = torch.randn((batch_size,noise_dim))
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,))]
)
dataset = datasets.MNIST(root="pytorch_tutorials/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# optimiser for discriminator
opt_disc = optim.Adam(disc.parameters(), lr=lr_rate)
# optimiser for generator
opt_gen = optim.Adam(gen.parameters(), lr=lr_rate)
criterion = nn.BCELoss()



In [7]:
for epoch in range(num_epochs):
    for batch_idx, (real_img,_) in enumerate(loader):
#         real_img = real_img.view(-1,784)
#         print(real_img.shape)
#         real_img = torch.cat(real_img, dim=0)  # Concatenate the list of tensors
        
#         # Convert RGB images to grayscale if necessary
#         if real_img.shape[1] > 1:
#             real_img = torch.mean(real_img, dim=1, keepdim=True)
        
        real_img = real_img.view(-1, 784)
        batch_size = real_img.shape[0]
        
        # Train Discriminator: loss = max log(D(real_img)) + log( 1 - D(G(noise)))
        noise = torch.randn(batch_size,noise_dim)
        G_noise = gen(noise) # = G(noise)
        d_real = disc(real_img).view(-1) # = D(real_img)
        loss1 = criterion(d_real, torch.ones_like(d_real)) # = log(D(real_img))
        d_g_noise = disc(G_noise).view(-1)
        loss2 = criterion(d_g_noise, torch.zeros_like(d_g_noise))
        lossD = (loss1+loss2)/2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()
        
        # Train Generator: loss = min log(1 - D(G(noise))) or max log(D(G(noise))
        d_g_noise = disc(G_noise).view(-1)
        lossG = criterion(d_g_noise, torch.ones_like(d_g_noise))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

Epoch [0/5] Batch 0/938                       Loss D: 0.7543, loss G: 1.1597
Epoch [1/5] Batch 0/938                       Loss D: 50.0000, loss G: 0.0000
Epoch [2/5] Batch 0/938                       Loss D: 50.0000, loss G: 0.0000
Epoch [3/5] Batch 0/938                       Loss D: 50.0000, loss G: 0.0000
Epoch [4/5] Batch 0/938                       Loss D: 50.0000, loss G: 0.0000
