In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Seed torch's global RNG so weight initialisation and sampled noise are
# reproducible. manual_seed returns a torch.Generator; the trailing ';'
# stops Jupyter from echoing that object as cell output.
torch.manual_seed(0);  # Set for testing purposes, please do not change!

<torch._C.Generator at 0x7fcb1abd2a30>

**Helper Functions**

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a grid with 5 per row.

    Args:
        image_tensor: batch of images, possibly flattened; it is detached,
            moved to the CPU and reshaped to ``(-1, *size)`` before plotting.
        num_images: how many images from the front of the batch to show.
        size: per-image shape ``(channels, height, width)``.
    """
    images = image_tensor.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # make_grid yields channels-first; imshow wants channels-last.
    # squeeze() drops any singleton dimensions that remain.
    channels_last = grid.permute(1, 2, 0).squeeze()
    plt.imshow(channels_last)
    plt.show()

def get_noise(n_samples, z_dim=10, device='cpu'):
    """Sample latent vectors from a standard normal distribution.

    Args:
        n_samples: number of noise vectors (rows) to draw.
        z_dim: length of each noise vector.
        device: device on which the tensor is allocated.

    Returns:
        Tensor of shape ``(n_samples, z_dim)``.
    """
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

**Generator**

In [None]:
def get_generator_block(input_dim, output_dim):
    """Build one generator stage: Linear -> BatchNorm1d -> in-place ReLU.

    Args:
        input_dim: number of input features of the linear layer.
        output_dim: number of output features; also the feature count
            normalised by the BatchNorm1d layer.

    Returns:
        An ``nn.Sequential`` containing the three layers in that order.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

class Generator(nn.Module):
    """Fully connected GAN generator mapping noise vectors to flattened images.

    Four ``get_generator_block`` stages widen the features from ``z_dim`` to
    ``hidden_dim * 8``. A final Linear layer projects to ``im_dim``, and a
    Sigmoid squashes every output value into (0, 1).

    Args:
        z_dim: length of each input noise vector.
        im_dim: length of each flattened output image (28 * 28 by default).
        hidden_dim: width of the first hidden stage; each later stage doubles it.
    """

    def __init__(self, z_dim=10, im_dim=28 * 28, hidden_dim=128):
        super().__init__()
        # Feature widths along the stack: z_dim, h, 2h, 4h, 8h.
        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        stages = [
            get_generator_block(width_in, width_out)
            for width_in, width_out in zip(widths, widths[1:])
        ]
        # Stages are built in the same order as the layer stack, so seeded
        # initialisation is unaffected. The attribute is named ``gen`` because
        # the training loop indexes ``gen.gen[0][0]``.
        self.gen = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], im_dim),
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Map a noise batch (last dimension ``z_dim``) to flattened images."""
        return self.gen(noise)

    def get_gen(self):
        """Return the underlying ``nn.Sequential`` layer stack."""
        return self.gen

In [None]:
# Sanity check: build a default Generator (z_dim=10) and print its layer stack.
# NOTE: constructing the Linear layers draws from torch's seeded global RNG,
# so this cell influences the initial weights of models built in later cells.
a = Generator().get_gen()
print(a)

Sequential(
  (0): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (2): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (3): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (4): Linear(in_features=1024, out_features=784, bias=True)
  (5): Sigmoid()
)


**Discriminator**

In [None]:
def get_discriminator_block(input_dim, output_dim):
    """Build one discriminator stage: Linear followed by LeakyReLU(0.2).

    Args:
        input_dim: number of input features of the linear layer.
        output_dim: number of output features.

    Returns:
        An ``nn.Sequential`` of the two layers.
    """
    linear = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(0.2)
    return nn.Sequential(linear, activation)

class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring flattened images.

    Three ``get_discriminator_block`` stages narrow the features from
    ``im_dim`` through ``hidden_dim * 4`` and ``hidden_dim * 2`` down to
    ``hidden_dim``. A final Linear layer outputs one raw score (logit) per
    image. No sigmoid is applied, which suits ``nn.BCEWithLogitsLoss``.

    Args:
        im_dim: length of each flattened input image (28 * 28 by default).
        hidden_dim: width of the narrowest hidden stage.
    """

    def __init__(self, im_dim=28 * 28, hidden_dim=128):
        super().__init__()
        # Feature widths along the stack: im_dim, 4h, 2h, h.
        widths = [im_dim] + [hidden_dim * factor for factor in (4, 2, 1)]
        stages = [
            get_discriminator_block(width_in, width_out)
            for width_in, width_out in zip(widths, widths[1:])
        ]
        self.disc = nn.Sequential(*stages, nn.Linear(widths[-1], 1))

    def forward(self, image):
        """Return one logit per image; the last dimension must equal ``im_dim``."""
        return self.disc(image)

    def get_disc(self):
        """Return the underlying ``nn.Sequential`` layer stack."""
        return self.disc

In [None]:
# Sanity check: build a default Discriminator and print its layer stack.
# NOTE: constructing the Linear layers draws from torch's seeded global RNG,
# so this cell influences the initial weights of models built in later cells.
b = Discriminator().get_disc()
print(b)

Sequential(
  (0): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (1): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (2): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (3): Linear(in_features=128, out_features=1, bias=True)
)


**Initialize Components**

In [None]:
# --- Hyperparameters ---
z_dim = 64          # length of each noise vector fed to the generator
batch_size = 128    # images per DataLoader batch / training step
n_epochs = 200      # number of passes over the dataset
display_step = 500  # report averaged losses and show samples this often
lr = 1e-5           # Adam learning rate

# The discriminator emits raw logits, so pair it with the combined
# sigmoid + binary-cross-entropy loss.
criterion = nn.BCEWithLogitsLoss()

# Stream MNIST in shuffled mini-batches. ToTensor converts each image into a
# float tensor with values in [0, 1], matching the generator's Sigmoid range.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (a hard-coded 'cuda' fails there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Build both networks on the chosen device. The generator is constructed
# before the discriminator, so their seeded weight initialisation is
# drawn in a fixed order.
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)

# One Adam optimizer per network, both using the same learning rate.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

**Loss Functions**

In [None]:
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    """Compute the discriminator loss on one batch of fakes and one of reals.

    Fakes are scored against target 0 and reals against target 1. The two
    criterion values are averaged.

    Args:
        gen: generator network, called on a freshly sampled noise batch.
        disc: discriminator network returning one logit per image.
        criterion: loss taking ``(predictions, targets)``.
        real: batch of real images, in the layout ``disc`` expects.
        num_images: number of fake images to generate.
        z_dim: length of each noise vector.
        device: device on which the noise is sampled.

    Returns:
        The mean of the fake and real loss terms.
    """
    noise = get_noise(num_images, z_dim, device)
    # Detach the fakes so this loss back-propagates only into ``disc`` and
    # the generator's parameters receive no gradient from it.
    fake_logits = disc(gen(noise).detach())
    fake_term = criterion(fake_logits, torch.zeros_like(fake_logits))

    real_logits = disc(real.detach())
    real_term = criterion(real_logits, torch.ones_like(real_logits))

    return (fake_term + real_term) / 2

def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    """Compute the generator loss for one freshly generated batch.

    The discriminator's logits on the fakes are compared with the "real"
    target (ones). The loss is therefore low when ``gen`` fools ``disc``.

    Args:
        gen: generator network.
        disc: discriminator network returning one logit per image.
        criterion: loss taking ``(predictions, targets)``.
        num_images: number of fake images to generate.
        z_dim: length of each noise vector.
        device: device on which the noise is sampled.

    Returns:
        The criterion value for the fakes against all-ones targets.
    """
    # Nothing is detached here, so gradients flow back through disc into gen.
    fake_logits = disc(gen(get_noise(num_images, z_dim, device)))
    return criterion(fake_logits, torch.ones_like(fake_logits))

**Training**

In [None]:
# Training-loop state: global step counter and running loss averages.
cur_step = 0
mean_generator_loss = mean_discriminator_loss = 0

# Flags for the runtime self-checks performed inside the training loop.
test_generator = True  # Whether the generator should be tested
gen_loss, error = False, False

In [None]:
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten each real image into a vector (28 * 28 values for MNIST),
        # as the fully connected discriminator expects.
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###

        disc_opt.zero_grad()  # clear gradients left over from the previous step

        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        # Each graph is back-propagated exactly once (the fakes are detached
        # inside get_disc_loss), so retaining it would only waste memory.
        disc_loss.backward()

        disc_opt.step()

        # Runtime self-check, part 1: snapshot the first generator layer's
        # weights so we can verify below that the generator step changed them.
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        ### Update generator ###

        gen_opt.zero_grad()

        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()

        gen_opt.step()

        ### Check generator update ###

        # Runtime self-check, part 2: the generator received gradients and
        # its first-layer weights actually moved during this step.
        if test_generator:
            try:
                assert lr > 0.0000002 or (gen.gen[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                assert torch.any(gen.gen[0][0].weight.detach().clone() != old_generator_weights)
            except Exception:
                # Catch assertion failures and e.g. a missing .grad, but let
                # KeyboardInterrupt/SystemExit through so training can be stopped.
                error = True
                print("Runtime tests have failed")

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        # NOTE(review): the first report also includes step 0, i.e.
        # display_step + 1 batches divided by display_step.
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            # Samples are only displayed, so skip building an autograd graph.
            with torch.no_grad():
                fake_noise = get_noise(cur_batch_size, z_dim, device=device)
                fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1