In [34]:
# Execute the sibling notebooks in this kernel's namespace. They are expected to
# define `Config`, `Generator` and `Discriminator`, which the cells below rely on
# (none of these names is defined anywhere else in this notebook).
%run Config.ipynb
%run Generator.ipynb
%run Discriminator.ipynb

In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
from torch.utils.tensorboard import SummaryWriter

In [36]:
# Build both networks from the shared Config hyperparameters and place them on
# Config.device. The discriminator is still constructed first (same order as
# before), since module construction may draw from the global RNG.
disc = Discriminator(Config.image_dim).to(device=Config.device)
gen = Generator(Config.z_dim, Config.image_dim).to(device=Config.device)

In [37]:
# Fixed latent vectors, reused at every logging step so generated samples can be
# compared over the course of training. Drawn from the same standard normal the
# training loop uses (torch.randn) -- torch.rand would give uniform [0, 1) inputs
# the generator never sees during training, making the logged grids misleading.
fixed_noise = torch.randn(Config.batch_size, Config.z_dim, device=Config.device)

In [38]:
transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [39]:
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=Config.batch_size, shuffle=True)

In [40]:
# One Adam optimizer per network; each steps only the parameters of its own model.
# Adam's default betas (0.9, 0.999) are used here -- GAN recipes often lower
# beta1 to 0.5, which may be worth trying if training is unstable.
opt_disc = optim.Adam(params=disc.parameters(), lr=Config.lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=Config.lr)

In [41]:
# Binary cross-entropy on probabilities. NOTE(review): nn.BCELoss expects the
# discriminator output to already lie in (0, 1) (e.g. a final sigmoid) -- TODO
# confirm in Discriminator.ipynb; otherwise nn.BCEWithLogitsLoss is the stable choice.
criterion = nn.BCELoss()

# Separate writers so TensorBoard shows generated and real grids as two runs under logs/.
writer_fake = SummaryWriter(log_dir="logs/fake")
writer_real = SummaryWriter(log_dir="logs/real")

In [29]:
# Train the GAN: per batch, one discriminator update followed by one generator update.
for epoch in range(Config.num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each 1x28x28 MNIST image into a 784-vector for the fully-connected models.
        real = real.view(-1, 784).to(Config.device)
        batch_size = real.shape[0]

        # --- Discriminator step: minimize BCE(D(x), 1) + BCE(D(G(z)), 0),
        # i.e. maximize log D(x) + log(1 - D(G(z))).
        # Noise is sized from the actual batch (the last batch of an epoch can be
        # smaller than Config.batch_size) so real and fake batches always match.
        noise = torch.randn(batch_size, Config.z_dim, device=Config.device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator loss must not backprop into the generator.
        # This also keeps the generator graph intact for the G step below, so
        # retain_graph=True is no longer needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # --- Generator step: minimize BCE(D(G(z)), 1), i.e. maximize log D(G(z))
        # (the non-saturating generator loss), using the freshly updated discriminator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Report and log once per epoch (first batch only). Rendering and writing
        # image grids on every batch floods TensorBoard and slows training badly.
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{Config.num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=Config.step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=Config.step
                )
                Config.step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.7037, loss G: 0.6848
Epoch [1/50] Batch 0/1875                       Loss D: 0.3704, loss G: 1.2918


KeyboardInterrupt: 

In [45]:
%load_ext tensorboard
%tensorboard --logdir logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 2892), started 0:49:51 ago. (Use '!kill 2892' to kill it.)