<a href="https://colab.research.google.com/github/imraunav/MNIST_GAN/blob/main/Training_GAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import os
from time import time

In [7]:
# Seed PyTorch's random number generators so that later random draws in this
# notebook are repeatable from run to run.
seed = 32
for apply_seed in (torch.cuda.manual_seed, torch.manual_seed):
    apply_seed(seed)
print("Seed set!")

Seed set!


In [24]:
# Mount Google Drive inside the Colab VM; the checkpoint directory used later
# in this notebook lives under this mount point.
from google.colab import drive

drive.mount(
    "/content/drive",
    force_remount=True,
)

Mounted at /content/drive


In [25]:
# Images are only converted to tensors; no further preprocessing is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Fetch the train split first, then the test split.
train_ds, test_ds = (
    MNIST("./MNIST", download=True, train=is_train_split, transform=transform)
    for is_train_split in (True, False)
)

# The labels are not used for training the GAN, so both splits are pooled
# into a single dataset.
mnist_ds = ConcatDataset([train_ds, test_ds])

In [26]:
# Shuffled mini-batches over the pooled MNIST dataset.
batch_size = 100
dataloader = DataLoader(
    dataset=mnist_ds,
    shuffle=True,
    batch_size=batch_size,
)

In [27]:
latent_dim = 128


class Generator(nn.Module):
    """Fully connected generator: latent vector -> 1x28x28 image.

    Two hidden Linear -> BatchNorm1d -> LeakyReLU stages are followed by a
    linear projection to 784 values, which are squashed by a sigmoid and
    reshaped into an image.

    Args:
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, hidden_dim=1024):
        super().__init__()
        # Layers are created in the same order they are applied, so the
        # indices inside `self.net` (and hence the state-dict keys) are stable.
        stack = [
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 784),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, noise):
        """Map a [B, latent_dim] noise batch to a [B, 1, 28, 28] image batch."""
        pixels = torch.sigmoid(self.net(noise))  # [B, 784]
        return pixels.view(pixels.size(0), 1, 28, 28)  # [B, C*H*W] -> [B, C, H, W]

class Discriminator(nn.Module):
    """Fully connected discriminator: image batch -> one raw score per image.

    Two Linear -> Sigmoid -> Dropout(p=0.7) stages are followed by a linear
    layer with a single output. No activation is applied to that output.
    """

    def __init__(self):
        super().__init__()

        def hidden_stage(n_in, n_out):
            # One hidden stage of the network, in application order.
            return [nn.Linear(n_in, n_out), nn.Sigmoid(), nn.Dropout(p=0.7)]

        self.net = nn.Sequential(
            *hidden_stage(784, 256),
            *hidden_stage(256, 256),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        """Flatten each image and return the network output for it."""
        flat = img.view(img.size(0), -1)  # [B, C, H, W] -> [B, C*H*W]
        return self.net(flat)

In [28]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ", device)

Device:  cuda


In [29]:
# Instantiate both networks (generator first) and move them to the selected device.
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [30]:
# noise = torch.rand((32, latent_dim), device=device)

# img = generator(noise)
# print(img.shape)

# score = discriminator(img)
# print(score.shape)

In [31]:
# dummy_latent = torch.rand((64, latent_dim), device=device)

# dummy_img = generator(dummy_latent)

In [32]:
# dummy_img.shape

In [33]:
# from IPython.display import Image
# progress_path = "drive/MyDrive/GAN_MNIST/progress"
# os.makedirs(progress_path, exist_ok=True)
# save_image(dummy_img, progress_path+"/dummy.png", nrow=8)

# Image(os.path.join(progress_path, 'dummy.png'))

In [34]:
# Show the layer layout of both networks, one after the other.
print(generator, discriminator, sep="\n")

Generator(
  (net): Sequential(
    (0): Linear(in_features=128, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=1024, out_features=784, bias=True)
  )
)
Discriminator(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): Sigmoid()
    (2): Dropout(p=0.7, inplace=False)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): Sigmoid()
    (5): Dropout(p=0.7, inplace=False)
    (6): Linear(in_features=256, out_features=1, bias=True)
  )
)


In [35]:
# One Adam optimizer per network, both with the library's default settings.
gen_opt, disc_opt = (
    torch.optim.Adam(network.parameters())
    for network in (generator, discriminator)
)

# The discriminator's last layer has no activation, so the loss that applies
# the sigmoid internally is used for the adversarial objective.
adv_crit = nn.BCEWithLogitsLoss().to(device)

In [36]:
def train_gen(real_batch):
    """Run one optimisation step on the generator.

    A fresh batch of fakes is scored by the discriminator and the generator is
    updated so that those scores move towards the "real" label.

    Args:
        real_batch: Tensor whose first dimension sets how many fakes are
            drawn; its values are not read here.

    Returns:
        The generator loss for this step as a Python float.
    """
    generator.train()
    discriminator.eval()

    n_samples = real_batch.size(0)
    latent = torch.randn((n_samples, latent_dim), device=device)

    # The discriminator call is deliberately NOT wrapped in no_grad: the
    # gradient has to flow through it back into the generator.
    fake_scores = discriminator(generator(latent))

    # generator target is 1 for fake
    wanted = torch.ones_like(fake_scores, device=device)
    g_loss = adv_crit(fake_scores, wanted)

    gen_opt.zero_grad()
    g_loss.backward()
    gen_opt.step()

    return g_loss.item()

def train_disc(real_batch):
    """Run one optimisation step on the discriminator.

    The discriminator sees the real batch followed by an equally sized batch
    of generated images and is trained to output 1 for the former and 0 for
    the latter.

    Args:
        real_batch: Batch of real images, already on `device`.

    Returns:
        The discriminator loss for this step as a Python float.
    """
    generator.eval()
    discriminator.train()

    n_real = real_batch.size(0)
    latent = torch.randn((n_real, latent_dim), device=device)

    # The generator is not being updated here, so its forward pass needs no graph.
    with torch.no_grad():
        fakes = generator(latent)

    logits = discriminator(torch.cat([real_batch, fakes], dim=0))

    # First half (real) -> 1, second half (fake) -> 0.
    labels = torch.zeros((2 * n_real, 1), device=device)
    labels[:n_real] = 1

    d_loss = adv_crit(logits, labels)

    disc_opt.zero_grad()
    d_loss.backward()
    disc_opt.step()

    return d_loss.item()

In [37]:
def checkpoint(epoch, losses, CKPT_PATH, CKPT="checkpoint.pt"):
    """Save the training state and a grid of generator samples.

    Writes one checkpoint file (overwritten on every call) holding both
    networks, both optimizers, the epoch index and the loss history, then
    renders the module-level `sample_latent` batch through the generator and
    stores the result as `progress/epoch_<epoch+1>.png`.

    Args:
        epoch: Index of the epoch that just finished, as produced by the
            training loop's `range`; file names and the log line use `epoch+1`.
        losses: Loss history to store alongside the weights.
        CKPT_PATH: Directory that receives the checkpoint file and the
            `progress` sub-directory; created if missing.
        CKPT: File name of the checkpoint inside `CKPT_PATH`.
    """
    progress_path = os.path.join(CKPT_PATH, "progress")

    state = {
        "epoch" : epoch,
        "generator" : generator.state_dict(),
        "discriminator" : discriminator.state_dict(),
        "gen_opt" : gen_opt.state_dict(),
        "disc_opt" : disc_opt.state_dict(),
        "losses" : losses,
    }

    os.makedirs(CKPT_PATH, exist_ok=True)
    torch.save(state, os.path.join(CKPT_PATH, CKPT))

    # The generator is left in eval mode afterwards; train_gen / train_disc
    # each set the mode they need before doing any work.
    generator.eval()
    # Pure inference: skip autograd bookkeeping so no graph is built and held
    # for an image that is only written to disk.
    with torch.no_grad():
        gen_batch = generator(sample_latent)

    os.makedirs(progress_path, exist_ok=True)
    save_image(gen_batch, os.path.join(progress_path, f"epoch_{epoch+1}.png"))

    print(f"! Checkpoint saved at {epoch+1}.")

In [38]:
def load_checkpoint(CKPT_PATH, CKPT="checkpoint.pt"):
    """Restore both networks and both optimizers from a saved checkpoint.

    Args:
        CKPT_PATH: Directory containing the checkpoint file.
        CKPT: File name of the checkpoint inside `CKPT_PATH`.

    Returns:
        A tuple `(epoch, losses)` with the epoch index and the loss history
        exactly as they were stored in the checkpoint.
    """
    # map_location makes the stored tensors land on the device this session
    # runs on, so a checkpoint written on a GPU runtime can still be opened
    # when only a CPU is available.
    state = torch.load(os.path.join(CKPT_PATH, CKPT), map_location=device)
    generator.load_state_dict(state["generator"])
    discriminator.load_state_dict(state["discriminator"])

    gen_opt.load_state_dict(state["gen_opt"])
    disc_opt.load_state_dict(state["disc_opt"])

    return state["epoch"], state["losses"]

In [39]:
MAX_EPOCH = 500
CKPT_PATH = "drive/MyDrive/GAN_MNIST"
CKPT_EVERY = 10
# Fixed latent batch rendered at every checkpoint to visualise progress.
# Drawn from a standard normal, the same distribution train_gen / train_disc
# sample their latents from, so the progress images reflect what the
# generator is actually trained on.
sample_latent = torch.randn((64, latent_dim), device=device)

losses = {
    "gen_loss" : [],
    "disc_loss" : [],
}

start = 0
# Load checkpoint if found
if os.path.exists(os.path.join(CKPT_PATH, "checkpoint.pt")):
    print("Checkpoint found")
    print('Loading checkpoint...')
    last_epoch, losses = load_checkpoint(CKPT_PATH)
    # The stored epoch was completed before it was saved and its mean losses
    # are already in `losses`, so continue with the following epoch instead of
    # repeating it (which would also log its losses twice).
    start = last_epoch + 1

tic = time()
for epoch in range(start, MAX_EPOCH):
    print(f"Epoch {epoch+1}", end=" ")
    running_losses = {
        "gen_loss" : [],
        "disc_loss" : [],
    }
    # some regularization for the discriminator for the initial stages:
    # uniform noise added to the real images, shrinking as the epoch grows.
    noise_scale = 1 / (epoch + 0.1)
    for real_batch, _ in dataloader:
        real_batch = real_batch.to(device)
        real_batch = real_batch + noise_scale * torch.rand_like(real_batch)

        disc_loss = train_disc(real_batch)
        gen_loss = train_gen(real_batch)

        running_losses["disc_loss"].append(disc_loss)
        running_losses["gen_loss"].append(gen_loss)

    # Store the mean of this epoch's per-batch losses.
    for loss_name in ['gen_loss', "disc_loss"]:
        epoch_loss = running_losses[loss_name]
        losses[loss_name].append(sum(epoch_loss)/len(epoch_loss))

    print(f"| G_loss : {losses['gen_loss'][-1]:.4f}, D_loss : {losses['disc_loss'][-1]:.4f}")

    if epoch % CKPT_EVERY == 0:
        checkpoint(epoch, losses, CKPT_PATH)
toc = time()
# Only write the final checkpoint when at least one epoch ran in this session;
# otherwise `epoch` was never assigned by the loop above.
if start < MAX_EPOCH:
    checkpoint(epoch, losses, CKPT_PATH)
print("Training done!")
print(f"Time taken to train: {(toc-tic)/60:.3f} mins")

Epoch 1 

KeyboardInterrupt: 

In [None]:
from matplotlib import pyplot as plt

# Per-epoch mean loss of each network, drawn on one shared set of axes.
fig, ax = plt.subplots()
ax.plot(losses["gen_loss"], label="Generator")
ax.plot(losses["disc_loss"], label="Discriminator")

ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

plt.show()