In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so edits to project .py files (e.g. mnist_gan) are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
from matplotlib import pyplot as plt
import numpy as np
from copy import deepcopy
from torchvision.utils import save_image
from tqdm import tqdm
# Directory that receives the sample grids written by save_image in the
# training loop; exist_ok=True makes re-running this cell a no-op.
os.makedirs("images", exist_ok=True)

In [None]:
from mnist_gan import get_root_path

In [None]:
# --- Paths -------------------------------------------------------------
PROJECT_PATH = get_root_path()
DATA_PATH = os.path.join(PROJECT_PATH, "data")
# Idempotent directory creation: succeeds silently when the directory already
# exists (so re-running the cell prints nothing), creates missing parents, and
# still raises on real problems such as permission errors instead of hiding them.
os.makedirs(DATA_PATH, exist_ok=True)

# --- Training hyper-parameters ------------------------------------------
BATCH_SIZE = 64
# NOTE(review): the training loop visible in this notebook iterates over
# range(100) and does not read EPOCH -- confirm which value is intended.
EPOCH = 10
LATENT_FEATURES = 100  # size of the noise vector fed to the generator

# --- Model dimensions -----------------------------------------------------
_OUT_FEATURES = 1    # discriminator emits a single logit per sample
_MNIST_SHAPE = 784   # flattened 28 x 28 image

# --- Optimizer learning rates ---------------------------------------------
GEN_LR = 2e-4
DIS_LR = 1e-4

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Normalisation constants kept at module level; the transform pipeline below
# does not use them.
mean, std = (0.1307,), (0.3081,)

# ToTensor converts each image to a float tensor; the lambda then rescales it
# with x*2 - 1 so real samples share the tanh output range of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x*2 - 1.),
])

# Both splits live under the same root and are downloaded on first use.
mnist_root = os.path.join(DATA_PATH, "mnist_data")
train_dataset = datasets.MNIST(root=mnist_root, train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=mnist_root, train=False, transform=transform, download=True)

# Only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Number of mini-batches per pass over the training set.
num_batches = len(train_loader)

In [None]:
# Pull one (shuffled) training batch and preview a randomly chosen digit.
imgs, labels = next(iter(train_loader))
i = np.random.choice(len(imgs))

fig, ax = plt.subplots(figsize=(4, 4))
# The dataset transform rescales pixels with x*2 - 1, so the inverse is
# (x + 1) / 2 -- not the mean/std de-normalisation of the disabled Normalize step.
# Fixing vmin/vmax shows true intensities instead of per-image auto-scaling.
ax.imshow((imgs[i].squeeze(dim=0) + 1.) / 2., cmap="gray", vmin=0., vmax=1.)
# .item() turns the 0-d label tensor into a plain int so the title reads "5"
# rather than "tensor(5)".
ax.set_title(str(labels[i].item()))
plt.show()

In [None]:
class Generator(nn.Module):
    """Fully connected generator: latent vector -> tanh-squashed output vector.

    Three hidden linear layers of width 256, 512 and 1024, each followed by a
    leaky ReLU, and a final linear layer whose output is passed through tanh,
    so every output value lies in [-1, 1].

    Args:
        in_features: size of the input (latent) vector.
        out_features: size of the generated output vector.
        negative_slope: slope of the leaky ReLU for negative inputs.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        negative_slope: float=.2):
        super().__init__()
        base_width = 256
        # Layers are created in fc1..fc4 order so parameter initialisation
        # consumes the RNG in the same sequence as before.
        self.fc1 = nn.Linear(in_features, base_width)
        self.fc2 = nn.Linear(base_width, base_width * 2)
        self.fc3 = nn.Linear(base_width * 2, base_width * 4)
        self.fc4 = nn.Linear(base_width * 4, out_features)
        self.negative_slope = negative_slope

    def forward(self, x):
        """Map a batch of input vectors to outputs in [-1, 1]."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden_layer(x), self.negative_slope)
        return self.fc4(x).tanh()
            
    
class Discriminator(nn.Module):
    """Fully connected discriminator producing unconstrained logits.

    Three hidden linear layers of width 1024, 512 and 256, each followed by a
    leaky ReLU and dropout, and a final linear layer. No sigmoid is applied:
    the output is a raw logit.

    Args:
        in_features: size of each flattened input sample.
        out_features: number of logits emitted per sample.
        negative_slope: slope of the leaky ReLU for negative inputs.
        dropout: dropout probability applied after every hidden activation.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        negative_slope: float=.2,
        dropout: float=.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=1024)
        self.fc2 = nn.Linear(in_features=self.fc1.out_features, out_features=self.fc1.out_features//2)
        self.fc3 = nn.Linear(in_features=self.fc2.out_features, out_features=self.fc2.out_features//2)
        # Honour the requested output width instead of hard-coding 1.
        self.fc4 = nn.Linear(in_features=self.fc3.out_features, out_features=out_features)
        self.negative_slope = negative_slope
        self.dropout = dropout

    def _hidden(self, layer, x):
        """Apply one hidden stage: linear -> leaky ReLU -> dropout."""
        x = F.leaky_relu(layer(x), negative_slope=self.negative_slope)
        # F.dropout defaults to training=True; tie it to the module mode so
        # .eval() really disables dropout and inference is deterministic.
        return F.dropout(x, p=self.dropout, training=self.training)

    def forward(self, x):
        """Return one row of logits per input sample."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = self._hidden(layer, x)
        return self.fc4(x)

In [None]:
# Build both networks and move them to the selected device.
# G maps LATENT_FEATURES-dimensional noise to a flattened image of _MNIST_SHAPE values;
# D takes a flattened image of _MNIST_SHAPE values as input.
G = Generator(in_features=LATENT_FEATURES, out_features=_MNIST_SHAPE).to(device)
D = Discriminator(in_features=_MNIST_SHAPE, out_features=_OUT_FEATURES).to(device)

In [None]:
# Display the generator's module repr (layer summary) as the cell output.
G

In [None]:
# Display the discriminator's module repr (layer summary) as the cell output.
D

In [None]:
# Sanity check: a batch of generated samples must have the same shape as a
# flattened batch of real images.
G.eval()
latent_probe = torch.randn(size=(BATCH_SIZE, LATENT_FEATURES), device=device)
generated_shape = G(latent_probe).shape
expected_shape = imgs.flatten(start_dim=1).shape
assert generated_shape == expected_shape

In [None]:
# Sanity check: feeding generated samples through the discriminator must give
# one row of _OUT_FEATURES logits per sample.
G.eval()
D.eval()
latent_probe = torch.randn(size=(BATCH_SIZE, LATENT_FEATURES), device=device)
critic_out = D(G(latent_probe))
assert critic_out.shape == (BATCH_SIZE, _OUT_FEATURES)

In [None]:
# One Adam optimizer per network, each with its own learning rate
# (GEN_LR for the generator, DIS_LR for the discriminator); all other Adam
# settings are left at their defaults.
G_optim = Adam(G.parameters(), lr=GEN_LR)
D_optim = Adam(D.parameters(), lr=DIS_LR)

In [None]:
def D_loss(
    y_hat_real: torch.Tensor,
    y_hat_fake: torch.Tensor,
    epsilon: float=1e-9) -> torch.Tensor:
    """Discriminator loss: -mean(log sigmoid(real) + log(1 - sigmoid(fake))).

    y_hat_real: torch.Tensor of shape (n, 1)
        discriminator logits (unconstrained floats) for real samples

    y_hat_fake: torch.Tensor of shape (n, 1)
        discriminator logits (unconstrained floats) for generated samples

    epsilon: float
        unused; kept only so existing callers that pass it keep working.
        The log-sigmoid formulation needs no clamping constant.

    Returns a 0-d tensor suitable for .backward().
    """
    # log(sigmoid(x)) computed directly in log-space; the previous
    # sigmoid -> log(. + epsilon) route saturated for large |x|, capping the
    # loss at -log(epsilon) and zeroing the gradient.
    log_p_real = F.logsigmoid(y_hat_real)
    # log(1 - sigmoid(x)) == log(sigmoid(-x))
    log_p_not_fake = F.logsigmoid(-y_hat_fake)
    return -torch.mean(log_p_real + log_p_not_fake)

In [None]:
def G_loss(
    y_hat_fake: torch.Tensor,
    epsilon: float=1e-9) -> torch.Tensor:
    """Generator loss: -mean(log sigmoid(y_hat_fake)).

    y_hat_fake: torch.Tensor
        discriminator logits (unconstrained floats) for generated samples

    epsilon: float
        unused; kept only so existing callers that pass it keep working.

    Returns a 0-d tensor suitable for .backward().
    """
    # Computing log(sigmoid(x)) in log-space keeps a usable gradient when the
    # discriminator confidently rejects the fakes (very negative logits); the
    # previous sigmoid -> log(. + epsilon) route gave a zero gradient there.
    return -torch.mean(F.logsigmoid(y_hat_fake))

In [None]:
# Draw a single latent vector and show what the generator currently produces
# for it, reshaped to a 28 x 28 image.
G.eval()
Z = torch.randn(size=(1, LATENT_FEATURES), device=device)
X_fake = G(Z)
sample_img = X_fake.detach().cpu().view(28, 28)
plt.imshow(sample_img, cmap="gray")

In [None]:
# Per-batch loss history: "D" collects discriminator losses, "G" generator
# losses. The training loop appends one float to each list per mini-batch.
losses = {"D":[], "G":[]}

In [None]:
# Alternating GAN training: for every mini-batch, one discriminator update
# followed by one generator update.
with torch.set_grad_enabled(True):
    D.train()
    G.train()
    # NOTE(review): the epoch count is hard-coded to 100 here while the config
    # cell defines EPOCH = 10 -- confirm which is intended before changing it.
    t = tqdm(range(100), leave=False)
    for i in t:
        for j, (X_real, _) in enumerate(train_loader):

            # ================= Train Discriminator =======================
            X_real = X_real.flatten(start_dim=1).to(device)
            Y_hat_real = D(X_real)

            Z = torch.randn(size=(len(X_real), LATENT_FEATURES), device=device)
            # The discriminator step must not backpropagate into G: generate
            # the fakes without building G's graph. The D update is unchanged;
            # this only drops a wasted backward pass through the generator.
            with torch.no_grad():
                X_fake = G(Z)
            Y_hat_fake = D(X_fake)

            D_optim.zero_grad()
            d_loss = D_loss(y_hat_real=Y_hat_real, y_hat_fake=Y_hat_fake)
            d_loss.backward()
            D_optim.step()
            losses["D"].append(d_loss.item())

            # ==================== Train Generator =========================
            # Fresh noise, and this time the graph through G is kept so the
            # generator receives gradients via D.
            Z = torch.randn(size=(len(X_real), LATENT_FEATURES), device=device)
            X_fake = G(Z)
            Y_hat_fake = D(X_fake)

            # Gradients this backward pass leaves on D's parameters are cleared
            # by D_optim.zero_grad() in the next iteration.
            G_optim.zero_grad()
            g_loss = G_loss(y_hat_fake=Y_hat_fake)
            g_loss.backward()
            G_optim.step()
            losses["G"].append(g_loss.item())

            t.set_description(f"E: {i}, B: {j}, D: {losses['D'][-1]:.3f}, G: {losses['G'][-1]:.3f}")

            # Every 200 batches, save a 5-per-row grid of up to 25 generated digits.
            batches_done = i * num_batches + j
            if batches_done % 200 == 0:
                # detach() replaces the deprecated .data; view(-1, ...) also
                # copes with a batch that holds fewer than 25 samples.
                samples = X_fake.detach()[:25].view(-1, 1, 28, 28)
                save_image(samples, "images/%d.png" % batches_done, nrow=5, normalize=True)