In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch
from torch import nn

In [None]:
# Choose the compute backend in order of preference: CUDA, then Apple MPS,
# and finally the CPU when neither accelerator reports as available.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
# Display the backend string chosen above ('cuda', 'mps' or 'cpu').
device

In [None]:
# Discriminator: a fully connected network that takes a flattened 28x28 input
# (784 values) and produces one sigmoid-squashed score per sample.
# Three hidden stages (Linear -> ReLU -> Dropout(0.3)) of widths 1024, 512, 256
# are followed by a single-unit Linear head and a Sigmoid.
# Layers are created in the same order and end up at the same Sequential
# indices as if they had been listed out one by one.
D = nn.Sequential(
    *(
        layer
        for fan_in, fan_out in ((28 * 28, 1024), (1024, 512), (512, 256))
        for layer in (nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3))
    ),
    nn.Linear(256, 1),
    nn.Sigmoid(),
).to(device)

In [None]:
# Generator: a fully connected network that maps a 100-dimensional noise
# vector to 784 values (one flattened 28x28 image). The final Tanh keeps the
# outputs in [-1, 1].
# Three hidden stages (Linear -> ReLU) of widths 256, 512, 1024 precede the
# output projection; layer order and Sequential indices match a hand-written
# listing of the same layers.
G = nn.Sequential(
    *(
        layer
        for fan_in, fan_out in ((100, 256), (256, 512), (512, 1024))
        for layer in (nn.Linear(fan_in, fan_out), nn.ReLU())
    ),
    nn.Linear(1024, 28 * 28),
    nn.Tanh(),
).to(device)

In [None]:
# Learning rate shared by both optimizers.
lr = 1e-4

# Binary cross-entropy on the discriminator's sigmoid output.
loss_fn = nn.BCELoss()

# One Adam optimizer per network so each can be stepped independently.
optimD = torch.optim.Adam(params=D.parameters(), lr=lr)
optimG = torch.optim.Adam(params=G.parameters(), lr=lr)

In [None]:
def see_output():
    """Draw 32 samples from the generator ``G`` and show them in a 4x8 grid.

    Standard-normal noise of shape (32, 100) is pushed through ``G``. Each
    output is rescaled from the Tanh range [-1, 1] to [0, 1], reshaped to
    28x28 and rendered with ``plt.imshow``. Axis ticks are hidden.

    Returns:
        None. The function only displays a matplotlib figure.
    """
    noise = torch.randn(32, 100).to(device)
    # Pure inference: disable autograd for the forward pass rather than
    # building a graph and detaching the result afterwards.
    with torch.no_grad():
        fake_samples = G(noise).cpu()
    plt.figure(dpi=100, figsize=(20, 10))
    for i in range(32):
        plt.subplot(4, 8, i + 1)
        # Undo the [-1, 1] scaling so pixel values sit in [0, 1].
        img = (fake_samples[i] / 2 + .5).reshape(28, 28)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [None]:
# Visualise samples from G as it currently stands; when the notebook is run
# top to bottom this happens before any training cell has executed.
see_output()

In [None]:
import torchvision 
import torchvision.transforms as T
# Per-image preprocessing: convert to a tensor, then apply (x - 0.5) / 0.5.
# For inputs in [0, 1] this yields values in [-1, 1], matching the Tanh
# output range of the generator.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[.5], std=[.5]),
])

# Both splits share every setting except the `train` flag; the training split
# is constructed first, then the test split. Data is stored under (and, if
# missing, downloaded to) the current directory.
train_set, test_set = (
    torchvision.datasets.FashionMNIST(
        root='.',
        train=is_train,
        download=True,
        transform=transform,
    )
    for is_train in (True, False)
)

In [None]:
# Mini-batch size; also read by the training helpers when sampling noise.
batch_size = 64

# Shuffled mini-batch iterators over the training and test splits.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
def train_D_on_real(real_samples):
    """Run one discriminator update on a batch of real images.

    The batch is flattened to rows of 784 values, scored by ``D``, and
    compared against an all-ones target with ``loss_fn``. The resulting
    gradients are applied through ``optimD``.

    Args:
        real_samples: batch of images; reshaped here to ``(-1, 784)``.
            Presumably already on the same device as ``D`` (the caller is
            expected to move it).

    Returns:
        The loss tensor for this batch.
    """
    flat_batch = real_samples.reshape(-1, 28 ** 2)
    optimD.zero_grad()
    prediction = D(flat_batch)
    target = torch.ones_like(prediction, dtype=torch.float32)
    loss = loss_fn(prediction, target)
    loss.backward()
    optimD.step()
    return loss

In [None]:
def train_D_on_fake():
    """Run one discriminator update on a freshly generated fake batch.

    Samples ``batch_size`` noise vectors of length 100, turns them into fake
    images with ``G``, scores them with ``D`` and compares the scores against
    an all-zeros target with ``loss_fn``. Only ``D`` is updated (through
    ``optimD``).

    Returns:
        The loss tensor for this batch.
    """
    optimD.zero_grad()
    noise = torch.randn(batch_size, 100).to(device)
    # G is not being trained in this step, so generate the fakes without
    # autograd tracking. This stops loss.backward() from propagating through
    # the whole generator and depositing unused gradients in its parameters.
    # The gradients D receives are unchanged.
    with torch.no_grad():
        fake_samples = G(noise)
    output = D(fake_samples)
    loss = loss_fn(output, torch.zeros_like(output, dtype=torch.float32))
    loss.backward()
    optimD.step()
    return loss

In [None]:
def train_G():
    """Run one generator update.

    Samples ``batch_size`` noise vectors of length 100, generates images with
    ``G`` and scores them with ``D``. The scores are compared against an
    all-ones target, i.e. the generator is rewarded when ``D`` rates its
    samples as real. Only ``G`` is stepped (through ``optimG``).

    Returns:
        The loss tensor for this batch.
    """
    latent = torch.randn(batch_size, 100).to(device)
    optimG.zero_grad()
    verdict = D(G(latent))
    wanted = torch.ones_like(verdict, dtype=torch.float32)
    loss = loss_fn(verdict, wanted)
    loss.backward()
    optimG.step()
    return loss

In [None]:
# Adversarial training: for every real batch, update D on the real samples,
# update D on a generated batch, then update G once. Every 10th epoch the
# mean per-batch losses are printed and a grid of generated samples is shown.
for epoch in range(50):
    # Running sums kept as plain Python floats so that no autograd graph is
    # retained across iterations.
    dloss = 0.0
    gloss = 0.0
    n = 0  # number of batches processed this epoch
    for n, (real_samples, _) in enumerate(train_loader, start=1):
        dloss += train_D_on_real(real_samples.to(device)).item()
        dloss += train_D_on_fake().item()
        # Accumulate (not overwrite) so the reported value is an epoch mean.
        gloss += train_G().item()
    # Divide by the actual batch count; guard against an empty loader.
    gloss /= max(n, 1)
    dloss /= max(n, 1)
    if epoch % 10 == 9:
        print(f"at epoch {epoch+1}, dloss: {dloss}, gloss {gloss}")
        see_output()
