In [9]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
torch.manual_seed(0)

<torch._C.Generator at 0x7f137abb24f0>

In [14]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Arrange the first ``num_images`` images of a batch into one grid tensor.

    Despite its name, this function displays nothing: it returns the grid so
    the caller can log or plot it.

    Args:
        image_tensor: tensor holding a batch of images; it is detached, moved
            to the CPU and viewed as ``(-1, *size)``.
        num_images: maximum number of images taken from the front of the batch.
        size: per-image shape used for the reshape.

    Returns:
        The tensor produced by ``make_grid`` with five images per row.
    """
    images = image_tensor.detach().cpu()
    images = images.view(-1, *size)
    selected = images[:num_images]
    return make_grid(selected, nrow=5)

In [5]:
def get_generator_block(input_dim, output_dim):
    """Build one hidden block of the generator.

    Args:
        input_dim: number of features entering the block.
        output_dim: number of features leaving the block.

    Returns:
        An ``nn.Sequential`` that applies, in order, a linear layer,
        1-D batch normalisation and an in-place ReLU.
    """
    layers = [
        nn.Linear(in_features=input_dim, out_features=output_dim),
        nn.BatchNorm1d(num_features=output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flattened images.

    Args:
        z_dim: dimension of the input noise vector, a scalar.
        im_dim: dimension of a flattened output image (784 = 28 x 28 by default).
        hidden_dim: base width of the hidden layers, a scalar; the four hidden
            blocks output ``hidden_dim``, ``2 * hidden_dim``, ``4 * hidden_dim``
            and ``8 * hidden_dim`` features respectively.
    """

    def __init__(self, z_dim = 10, im_dim = 784, hidden_dim = 128):
        super().__init__()
        # Feature widths at the boundaries of the four hidden blocks.
        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        stages = [
            get_generator_block(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        # Output head: project to the image dimension, then squash with a sigmoid.
        stages.append(nn.Linear(widths[-1], im_dim))
        stages.append(nn.Sigmoid())
        self.gen = nn.Sequential(*stages)

    def forward(self, noise):
        """Pass ``noise`` through the network and return the generated output."""
        return self.gen(noise)

    def get_model_gennerator(self):
        """Return the underlying ``nn.Sequential`` model.

        The method name (including its spelling) is part of the existing interface.
        """
        return self.gen
    

## Noise

In [6]:
def get_noise(n, z_dim, device='cpu'):
    """Sample a batch of noise vectors from a standard normal distribution.

    Args:
        n: the number of samples to generate.
        z_dim: the dimension of each noise vector, a scalar.
        device: device on which the tensor is created.

    Returns:
        A tensor of shape ``(n, z_dim)`` filled by ``torch.randn``.
    """
    shape = (n, z_dim)
    return torch.randn(shape, device=device)

## Discriminator

In [7]:
def get_dis_block(input_dim, output_dim):
    """Build one hidden block of the discriminator.

    Args:
        input_dim: number of features entering the block.
        output_dim: number of features leaving the block.

    Returns:
        An ``nn.Sequential`` that applies a linear layer followed by a
        LeakyReLU with negative slope 0.2.
    """
    layers = [
        nn.Linear(in_features=input_dim, out_features=output_dim),
        nn.LeakyReLU(negative_slope=0.2),
    ]
    return nn.Sequential(*layers)
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    The last layer is a plain linear layer with a single output, so the
    network returns one raw score per input row (no sigmoid is applied here).

    Args:
        im_dim: dimension of a flattened input image (784 = 28 x 28 by default).
        hidden_dim: base width of the hidden layers; the three hidden blocks
            output ``4 * hidden_dim``, ``2 * hidden_dim`` and ``hidden_dim``
            features respectively.
    """

    def __init__(self, im_dim = 784, hidden_dim = 128):
        super().__init__()
        # Feature widths at the boundaries of the three hidden blocks.
        widths = (im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        stages = [
            get_dis_block(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        stages.append(nn.Linear(hidden_dim, 1))
        self.model = nn.Sequential(*stages)

    def forward(self, image):
        """Pass ``image`` through the network and return its score."""
        return self.model(image)

    def get_model_discriminator(self):
        """Return the underlying ``nn.Sequential`` model."""
        return self.model

## Training

In [8]:
# Loss function: binary cross-entropy computed directly on raw scores (logits).
criterion = nn.BCEWithLogitsLoss()
# Training hyper-parameters.
epochs = 20
z_dim = 64  # dimension of the noise vector
display_step = 500
batch_size = 128
lr = 0.00002
# MNIST images converted to tensors by ToTensor().
# download=True fetches the dataset only when it is not already present under
# '.', so this cell also works on a fresh checkout instead of raising.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Separate TensorBoard writers for the generated and the real image grids.
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")

In [11]:
# Generator and its Adam optimiser (created first, before the discriminator).
generator = Generator(z_dim=z_dim)
generator = generator.to(device)
gen_optimazation = torch.optim.Adam(params=generator.parameters(), lr=lr)

# Discriminator and its Adam optimiser; both networks share the same learning rate.
discriminator = Discriminator()
discriminator = discriminator.to(device=device)
disc_optimazation = torch.optim.Adam(params=discriminator.parameters(), lr=lr)

In [12]:
## Functions to calculate the discriminator's loss and the generator's loss
def get_gen_loss(generator, discriminator, criterion, num_images, z_dim, device):
    """Compute the generator loss on a freshly generated batch.

    ``num_images`` noise vectors are sampled, turned into fake samples by
    ``generator`` and scored by ``discriminator``. ``criterion`` then compares
    those scores with an all-ones target, so the loss is low when the
    discriminator rates the fakes as real.

    Args:
        generator: callable mapping noise to fake samples.
        discriminator: callable mapping samples to scores.
        criterion: loss callable taking ``(prediction, target)``.
        num_images: number of fake samples to generate.
        z_dim: dimension of each noise vector.
        device: device on which the noise is created.

    Returns:
        The value returned by ``criterion``.
    """
    noise = get_noise(num_images, z_dim, device)
    fake_scores = discriminator(generator(noise))
    real_labels = torch.ones_like(fake_scores)
    return criterion(fake_scores, real_labels)

def get_dis_loss(generator, discriminator, criterion, real, num_images, z_dim, device):
    """Compute the discriminator loss on one fake batch and one real batch.

    The fake batch is detached before it is scored, so back-propagating this
    loss does not reach the generator. Fake scores are compared with an
    all-zeros target and real scores with an all-ones target; the result is
    the average of the two losses.

    Args:
        generator: callable mapping noise to fake samples.
        discriminator: callable mapping samples to scores.
        criterion: loss callable taking ``(prediction, target)``.
        real: batch of real samples passed to ``discriminator``.
        num_images: number of fake samples to generate.
        z_dim: dimension of each noise vector.
        device: device on which the noise is created.

    Returns:
        ``(fake_loss + real_loss) / 2``.
    """
    noise = get_noise(num_images, z_dim=z_dim, device=device)
    fake_images = generator(noise).detach()

    # Fake half: the discriminator should output "fake" (target 0).
    fake_scores = discriminator(fake_images)
    loss_on_fake = criterion(fake_scores, torch.zeros_like(fake_scores))

    # Real half: the discriminator should output "real" (target 1).
    real_scores = discriminator(real)
    loss_on_real = criterion(real_scores, torch.ones_like(real_scores))

    return (loss_on_fake + loss_on_real) / 2

In [17]:
# Number of image-logging steps so far; incremented once per epoch (at batch 0)
# and used as the TensorBoard global_step.
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True  # Whether the generator should be tested
gen_loss = False
error = False  # set to True if any runtime sanity check fails

for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        cur_batch_size = len(real)
        # Flatten every image to a vector and move the batch to the training device.
        real = real.view(cur_batch_size, -1).to(device)

        ## Update discriminator
        disc_optimazation.zero_grad()
        disc_loss = get_dis_loss(generator, discriminator, criterion, real,
                                 num_images=cur_batch_size, z_dim=z_dim, device=device)
        # Compute gradients, then apply the optimiser step.
        disc_loss.backward()
        disc_optimazation.step()

        if test_generator:
            # Snapshot of the first generator layer, compared after the generator step.
            old_generator_weights = generator.gen[0][0].weight.detach().clone()

        ## Update generator
        gen_optimazation.zero_grad()
        gen_loss = get_gen_loss(generator, discriminator, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_optimazation.step()

        ## Check that the generator's weights actually changed
        if test_generator:
            try:
                assert lr > 0.0000002 or (generator.gen[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                assert torch.any(generator.gen[0][0].weight.detach().clone() != old_generator_weights)
            except AssertionError:
                # Only failed checks are reported here; any other exception
                # (including KeyboardInterrupt) propagates instead of being hidden.
                error = True
                print("Runtime tests have failed")

        # NOTE(review): these accumulators are reset at batch 0 of every epoch
        # below and are never reported, and display_step does not match the
        # reset interval -- confirm whether they are still needed.
        mean_discriminator_loss += disc_loss.item() / display_step
        mean_generator_loss += gen_loss.item() / display_step

        if batch_idx == 0:
            # Adjacent f-strings keep the message on one line without leaking
            # source indentation into the printed output.
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss disc: {disc_loss:.4f}, loss gen: {gen_loss:.4f}"
            )

            # Log one grid of generated images and one grid of real images.
            with torch.no_grad():
                fake_noise = get_noise(cur_batch_size, z_dim, device)
                fake = generator(fake_noise)
                img_grid_fake = show_tensor_images(fake)
                img_grid_real = show_tensor_images(real)
                mean_generator_loss = 0
                mean_discriminator_loss = 0

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=cur_step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=cur_step
                )
                cur_step += 1

# Make sure every logged image is written to disk once training has finished.
writer_fake.flush()
writer_real.flush()
        


Epoch [0/20] Batch 0/469                       Loss disc: 0.1400, loss gen: 2.4649
Epoch [1/20] Batch 0/469                       Loss disc: 0.1423, loss gen: 2.7555
Epoch [2/20] Batch 0/469                       Loss disc: 0.1740, loss gen: 2.5036
Epoch [3/20] Batch 0/469                       Loss disc: 0.1115, loss gen: 2.6395
Epoch [4/20] Batch 0/469                       Loss disc: 0.0769, loss gen: 4.1252
Epoch [5/20] Batch 0/469                       Loss disc: 0.0666, loss gen: 3.6659
Epoch [6/20] Batch 0/469                       Loss disc: 0.1209, loss gen: 3.2423
Epoch [7/20] Batch 0/469                       Loss disc: 0.1093, loss gen: 3.5754
Epoch [8/20] Batch 0/469                       Loss disc: 0.0796, loss gen: 4.1256
Epoch [9/20] Batch 0/469                       Loss disc: 0.0955, loss gen: 4.2219
Epoch [10/20] Batch 0/469                       Loss disc: 0.0647, loss gen: 3.6646
Epoch [11/20] Batch 0/469                       Loss disc: 0.0492, loss gen: 3.7429
Ep