<a href="https://colab.research.google.com/github/imraunav/MNIST_GAN/blob/main/Training_WGANGP_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import os
from time import time

In [2]:
from tqdm import tqdm

In [3]:
# Seed PyTorch's default (CPU) generator and the CUDA generator with one fixed value.
seed = 32
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
print("Seed set!")

Seed set!


In [4]:
# Colab-only: `google.colab` is provided by the Colab runtime, so this cell
# will not run in a plain Jupyter environment.
from google.colab import drive
# Mount Google Drive under /content/drive. The checkpoint directory used by the
# training cell ("drive/MyDrive/...") is a relative path, so it presumably relies
# on the working directory being /content -- TODO confirm.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [5]:
# Preprocessing applied to every MNIST image:
#   1. convert to a tensor,
#   2. normalise with mean 0.5 / std 0.5,
#   3. resize to 32x32 (the spatial size the generator produces and the critic consumes).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
    transforms.Resize(size=(32, 32), max_size=None, antialias=True),
])

# Build the train split first, then the test split, sharing the same transform.
train_ds, test_ds = (
    MNIST("./MNIST", download=True, train=is_train, transform=transform)
    for is_train in (True, False)
)

# The GAN is trained on both splits concatenated into a single dataset.
mnist_ds = ConcatDataset([train_ds, test_ds])

In [6]:
# Mini-batch size; also reused by the training cell when sampling noise for the generator step.
batch_size = 128
# Shuffled loader over the concatenated dataset. drop_last is not set, so the final
# batch of an epoch may contain fewer than `batch_size` samples.
dataloader = DataLoader(mnist_ds, batch_size=batch_size, shuffle=True)

In [7]:
def gradient_penalty(real, fake, critic=None):
    """Return the WGAN-GP gradient penalty for a batch of real/fake samples.

    Random points are drawn on the straight lines between each real sample and
    the corresponding fake sample, the critic is evaluated there, and the
    squared deviation of the per-sample gradient 2-norm from 1 is averaged
    over the batch.

    Args:
        real: batch of real samples, first dimension is the batch size.
        fake: batch of generated samples with the same shape as ``real``.
            It may or may not carry an autograd graph; it is detached here so
            the penalty only ever produces gradients for the critic.
        critic: callable mapping a batch to critic scores. Defaults to the
            notebook-level ``discriminator`` when omitted, which keeps existing
            two-argument calls working.

    Returns:
        A scalar tensor. It is differentiable with respect to the critic's
        parameters (``create_graph=True``), so it can be added to the critic loss.
    """
    if critic is None:
        critic = discriminator  # fall back to the model defined in the notebook

    m = real.shape[0]
    # One mixing coefficient per sample, broadcast over all remaining dimensions.
    # Created directly on the input's device/dtype instead of guessing from
    # torch.cuda.is_available(), so CPU tensors work on a CUDA machine too.
    epsilon = torch.rand(m, *([1] * (real.dim() - 1)), device=real.device, dtype=real.dtype)

    # Detach so the interpolate is a leaf: no graph back into the generator, and
    # autograd.grad works even when `fake` was already detached by the caller.
    interpolated_img = (epsilon * real + (1 - epsilon) * fake).detach().requires_grad_(True)
    interpolated_out = critic(interpolated_img)

    grads = torch.autograd.grad(
        outputs=interpolated_out,
        inputs=interpolated_img,
        grad_outputs=torch.ones_like(interpolated_out),
        create_graph=True,   # keep the graph so the penalty itself can be backpropagated
        retain_graph=True,
    )[0]

    grads = grads.reshape(m, -1)
    grad_penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return grad_penalty

In [8]:
latent_dim = 100  # length of the noise vector consumed by the generator


class Generator(torch.nn.Module):
    """Transposed-convolution generator.

    Maps a ``(batch, latent_dim)`` noise tensor to a ``(batch, channels, 32, 32)``
    image squashed into [-1, 1] by a final Tanh.

    Args:
        channels: number of channels in the generated image.
    """

    def __init__(self, channels=1):
        super().__init__()
        # (out_channels, stride, padding) of each upsampling stage:
        #   latent x1x1 -> 1024x4x4 -> 512x8x8 -> 256x16x16
        stages = [(1024, 1, 0), (512, 2, 1), (256, 2, 1)]

        layers = []
        in_channels = latent_dim
        for out_channels, stride, padding in stages:
            layers.extend([
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                   kernel_size=4, stride=stride, padding=padding),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(True),
            ])
            in_channels = out_channels

        # Last stage: 256x16x16 -> Cx32x32, no normalisation or ReLU.
        layers.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=channels,
                                         kernel_size=4, stride=2, padding=1))

        self.main_module = nn.Sequential(*layers)
        self.output = nn.Tanh()

    def forward(self, x):
        """Reshape the noise to (batch, features, 1, 1) and decode it into an image."""
        batch, features = x.size(0), x.size(1)
        decoded = self.main_module(x.view(batch, features, 1, 1))
        return self.output(decoded)


class Discriminator(torch.nn.Module):
    """Convolutional critic for WGAN-GP.

    Takes a ``(batch, channels, 32, 32)`` image and returns an unbounded score of
    shape ``(batch, 1, 1, 1)``; no sigmoid is applied because the output is not a
    probability.

    Args:
        channels: number of channels in the input image.
    """

    def __init__(self, channels=1):
        super().__init__()
        # Batch normalisation is avoided in the critic: the gradient penalty is
        # defined on the critic's gradient with respect to each input
        # independently, not the entire batch. Per-instance normalisation
        # (nn.InstanceNorm2d) is used in its place.
        #
        # Each stage halves the spatial size:
        #   Cx32x32 -> 256x16x16 -> 512x8x8 -> 1024x4x4
        blocks = []
        in_channels = channels
        for out_channels in (256, 512, 1024):
            blocks += [
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            in_channels = out_channels
        self.main_module = nn.Sequential(*blocks)

        # 1024x4x4 -> 1x1x1 critic score.
        self.output = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=4, stride=1, padding=0),
        )

    def forward(self, x):
        """Return the critic score for each image in the batch."""
        features = self.main_module(x)
        return self.output(features)

    def feature_extraction(self, x):
        """Return the convolutional features flattened to vectors of length 16384 (1024*4*4)."""
        features = self.main_module(x)
        return features.view(-1, 1024 * 4 * 4)


# class Generator(nn.Module):
#     def __init__(self, hidden_dim=512):
#         super().__init__()
#         self.net = nn.Sequential(
#                 nn.Linear(latent_dim, hidden_dim),
#                 nn.BatchNorm1d(hidden_dim),
#                 nn.LeakyReLU(),

#                 nn.Linear(hidden_dim, hidden_dim),
#                 nn.BatchNorm1d(hidden_dim),
#                 nn.LeakyReLU(),

#                 nn.Linear(hidden_dim, 784),
#         )
#     def forward(self, noise):
#         batch_size = noise.size(0)
#         return self.net(noise).view(batch_size, 1, 28, 28).sigmoid() # [B, C*H*W] -> [B, C, H, W]

# class Discriminator(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.net = nn.Sequential(
#                 nn.Linear(784, 256),
#                 nn.ReLU(),
#                 nn.Dropout(p=0.3),

#                 nn.Linear(256, 256),
#                 nn.ReLU(),
#                 nn.Dropout(p=0.3),

#                 nn.Linear(256, 1),
#         )
#     def forward(self, img):
#         batch_size = img.size(0)
#         img = img.view(batch_size, -1)  # [B, C, H, W] -> [B, C*H*W]
#         return self.net(img)

In [9]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device: ", device)

Device:  cuda


In [10]:
# Instantiate both networks with their default single-channel configuration and
# move them to the selected device. These two module-level names are read as
# globals by gradient_penalty, checkpoint and load_checkpoint.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [11]:
# noise = torch.rand((32, latent_dim), device=device)

# img = generator(noise)
# print(img.shape)

# score = discriminator(img)
# print(score.shape)

In [12]:
# dummy_latent = torch.rand((64, latent_dim), device=device)

# dummy_img = generator(dummy_latent)

In [13]:
# dummy_img.shape

In [14]:
# from IPython.display import Image
# progress_path = "drive/MyDrive/GAN_MNIST/progress"
# os.makedirs(progress_path, exist_ok=True)
# save_image(dummy_img, progress_path+"/dummy.png", nrow=8)

# Image(os.path.join(progress_path, 'dummy.png'))

In [15]:
# Show the layer structure of both networks as a quick architecture check.
print(generator)
print(discriminator)

Generator(
  (main_module): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (output): Tanh()
)
Discriminator(
  (main_module): Sequential(
    (0): Conv2d(1, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.2, 

In [16]:
# One Adam optimizer per network, configured identically: lr=1e-4, betas=(0.0, 0.9)
# and a small weight decay of 2e-5. Both are saved and restored by the checkpoint helpers.
gen_opt = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.0, 0.9), weight_decay=2e-5)
disc_opt = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.0, 0.9), weight_decay=2e-5)

# adv_crit = nn.BCEWithLogitsLoss().to(device)

In [17]:
# def train_gen(real_batch, real_labels):

#     generator.train()
#     discriminator.eval()

#     batch_size = real_batch.size(0)
#     z = torch.normal(0, 1, (batch_size, latent_dim), device=device)
#     fake_batch = generator(z, real_labels)
#     # with torch.no_grad():
#     #     score = discriminator(fake_batch)
#     score = discriminator(fake_batch, real_labels)
#     targets = torch.ones_like(score, device=device) # generator target is 1 for fake
#     loss = adv_crit(score, targets)

#     gen_opt.zero_grad()
#     loss.backward()
#     gen_opt.step()

#     return loss.item()

# def train_disc(real_batch, real_labels):

#     generator.eval()
#     discriminator.train()

#     batch_size = real_batch.size(0)
#     z = torch.normal(0, 1, (batch_size, latent_dim), device=device)
#     with torch.no_grad():
#         fake_batch = generator(z, real_labels)

#     all_batch = torch.cat([real_batch, fake_batch], dim=0)
#     all_labels = torch.cat([real_labels, real_labels], dim=0)
#     score = discriminator(all_batch, all_labels)
#     targets = torch.cat(
#         [
#             torch.ones((batch_size, 1), device=device),
#             torch.zeros((batch_size, 1), device=device),
#         ]
#     )

#     loss = adv_crit(score, targets)

#     disc_opt.zero_grad()
#     loss.backward()
#     disc_opt.step()

#     return loss.item()

In [18]:
def checkpoint(epoch, losses, CKPT_PATH, CKPT="checkpoint.pt"):
    """Save the training state and a grid of sample images.

    Writes ``CKPT_PATH/CKPT`` containing the epoch index, both model state
    dicts, both optimizer state dicts and the loss history, then renders the
    notebook-level ``sample_latent`` through the generator and stores the grid
    as ``CKPT_PATH/progress/epoch_{epoch+1}.png``.

    Args:
        epoch: zero-based index of the epoch that has just finished.
        losses: dict of per-epoch loss lists to store alongside the weights.
        CKPT_PATH: directory that receives the checkpoint and the ``progress`` folder.
        CKPT: file name of the checkpoint inside ``CKPT_PATH``.

    Side effects:
        Leaves ``generator`` in eval mode; callers that continue training must
        switch it back with ``generator.train()``.
    """
    progress_path = os.path.join(CKPT_PATH, "progress")

    state = {
        "epoch" : epoch,
        "generator" : generator.state_dict(),
        "discriminator" : discriminator.state_dict(),
        "gen_opt" : gen_opt.state_dict(),
        "disc_opt" : disc_opt.state_dict(),
        "losses" : losses,
    }

    os.makedirs(CKPT_PATH, exist_ok=True)
    torch.save(state, os.path.join(CKPT_PATH, CKPT))

    generator.eval()
    # Sampling is inference only: skip building an autograd graph.
    with torch.no_grad():
        gen_batch = generator(sample_latent)

    os.makedirs(progress_path, exist_ok=True)
    # The generator ends in Tanh, so pixels are in [-1, 1], while save_image
    # clamps to [0, 1]. Rescale first so the lower half of the range is not
    # written out as pure black.
    save_image((gen_batch + 1) / 2, os.path.join(progress_path, f"epoch_{epoch+1}.png"))

    print(f"! Checkpoint saved at {epoch+1}.")

In [19]:
def load_checkpoint(CKPT_PATH, CKPT="checkpoint.pt"):
    """Restore models, optimizers and loss history from a saved checkpoint.

    Loads ``CKPT_PATH/CKPT`` and copies its contents into the notebook-level
    ``generator``, ``discriminator``, ``gen_opt`` and ``disc_opt``.

    Args:
        CKPT_PATH: directory containing the checkpoint file.
        CKPT: file name of the checkpoint inside ``CKPT_PATH``.

    Returns:
        ``(epoch, losses)`` exactly as stored by ``checkpoint``: the zero-based
        index of the last finished epoch and the loss-history dict.
    """
    # Map stored tensors onto whatever device the live model is on, so a
    # checkpoint written on a GPU runtime can be restored on a CPU-only one.
    target_device = next(generator.parameters()).device
    state = torch.load(os.path.join(CKPT_PATH, CKPT), map_location=target_device)

    generator.load_state_dict(state["generator"])
    discriminator.load_state_dict(state["discriminator"])

    gen_opt.load_state_dict(state["gen_opt"])
    disc_opt.load_state_dict(state["disc_opt"])

    return state["epoch"], state["losses"]

In [None]:
# ---- Training configuration -------------------------------------------------
MAX_EPOCH = 500                           # total number of epochs to train for
CKPT_PATH = "drive/MyDrive/WGANGP_MNIST"  # checkpoint directory (on the mounted Drive)
CKPT_EVERY = 10                           # save a checkpoint every this many epochs

# NOTE(review): Z_DIM, CHANNELS and LOAD_MODEL are not read anywhere in this cell.
# In particular a checkpoint is resumed whenever one exists, regardless of LOAD_MODEL.
Z_DIM = latent_dim
CHANNELS = 1
N_CRITIC = 5              # one generator step per N_CRITIC critic steps
GRADIENT_PENALTY = 10     # weight of the gradient-penalty term
DRIFT_PENALTY = 0.0001    # weight of the E[D(x)^2] term that keeps critic scores near zero
LOAD_MODEL = False

# Fixed noise reused at every checkpoint so the progress images are comparable.
sample_latent = torch.normal(0, 1, (64, latent_dim), device=device)

losses = {
    "gen_loss" : [],
    "disc_loss" : [],
}

start = 0
# Load checkpoint if found
if os.path.exists(CKPT_PATH+"/checkpoint.pt"):
    print("Checkpoint found")
    print('Loading checkpoint...')
    last_epoch, losses = load_checkpoint(CKPT_PATH)
    checkpoint(last_epoch, losses, CKPT_PATH)
    # The stored index is the last epoch that *finished*, so resume with the
    # next one; otherwise that epoch is trained and logged a second time.
    start = last_epoch + 1

tic = time()
total_iter = 0
max_iter = len(dataloader)

for epoch in range(start, MAX_EPOCH):
    generator.train()
    discriminator.train()

    print(f"Epoch {epoch+1}", end=" ")
    running_losses = {
        "gen_loss" : [],
        "disc_loss" : [],
    }

    for real_batch, _ in tqdm(dataloader):
        total_iter += 1

        # ---- critic step (every iteration) ----
        x_real = real_batch.to(device)

        z_fake = torch.randn(x_real.size(0), latent_dim, device=device)
        x_fake = generator(z_fake)

        fake_out = discriminator(x_fake.detach())
        real_out = discriminator(x_real)
        x_out = torch.cat([real_out, fake_out])

        # Wasserstein estimate + gradient penalty + small drift penalty.
        # The drift term is a weighted mean of squared scores. It was previously
        # raised to the power 0.0001, which turns it into a near-constant (~1)
        # whose gradient blows up as the scores approach zero.
        d_loss = (
            -(real_out.mean() - fake_out.mean())
            + gradient_penalty(x_real, x_fake) * GRADIENT_PENALTY
            + (x_out ** 2).mean() * DRIFT_PENALTY
        )

        disc_opt.zero_grad()
        d_loss.backward()
        disc_opt.step()

        running_losses["disc_loss"].append(d_loss.item())

        # ---- generator step (every N_CRITIC iterations) ----
        if total_iter % N_CRITIC == 0:
            z_fake = torch.randn(batch_size, latent_dim, device=device)
            x_fake = generator(z_fake)

            fake_out = discriminator(x_fake)
            g_loss = - fake_out.mean()

            gen_opt.zero_grad()
            g_loss.backward()
            gen_opt.step()
            running_losses["gen_loss"].append(g_loss.item())

    for loss_name in ['gen_loss', "disc_loss"]:
        epoch_loss = running_losses[loss_name]
        # An epoch shorter than N_CRITIC iterations may contain no generator step.
        epoch_mean = sum(epoch_loss)/len(epoch_loss) if epoch_loss else float("nan")
        losses[loss_name].append(epoch_mean)

    print(f"| G_loss : {losses['gen_loss'][-1]:.4f}, D_loss : {losses['disc_loss'][-1]:.4f}")

    if (epoch+1) % CKPT_EVERY == 0:
        checkpoint(epoch, losses, CKPT_PATH)

toc = time()
# `epoch` only exists if the loop body ran at least once (e.g. not when resuming
# a run that had already reached MAX_EPOCH).
if start < MAX_EPOCH:
    checkpoint(epoch, losses, CKPT_PATH)
print("Training done!")
print(f"Time taken to train: {(toc-tic)/60:.3f} mins")

Checkpoint labels to sample: tensor([4, 0, 9, 9, 4, 3, 9, 5, 8, 2, 2, 1, 4, 8, 6, 9, 9, 7, 4, 4, 6, 9, 6, 9,
        8, 3, 6, 1, 9, 3, 6, 6, 6, 6, 3, 1, 5, 7, 4, 2, 7, 9, 1, 7, 0, 5, 7, 5,
        5, 6, 5, 4, 6, 0, 3, 1, 6, 7, 1, 5, 6, 9, 2, 0], device='cuda:0')
Epoch 1 

 26%|██▌       | 142/547 [01:31<04:39,  1.45it/s]

In [None]:
from matplotlib import pyplot as plt

# Per-epoch mean losses of both networks on a single set of axes.
fig, ax = plt.subplots()
ax.plot(losses["gen_loss"], label="Generator")
ax.plot(losses["disc_loss"], label="Discriminator")

ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

plt.show()