## GAN: MNIST

In [64]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

In [65]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Data

In [66]:
import os
import gzip

def load_mnist_images(data_dir, filename):
    data_path = os.path.join(data_dir, filename)
    with gzip.open(data_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    return data.reshape(-1, 28, 28)

def load_mnist_labels(data_dir, filename):
    data_path = os.path.join(data_dir, filename)
    with gzip.open(data_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=8)
    return data

# data_dir = r"D:\datasets\mnist_11M"   ## windows
data_dir = "/mnt/d/datasets/mnist_11M"  ## wsl

x_train = load_mnist_images(data_dir, "train-images-idx3-ubyte.gz")
y_train = load_mnist_labels(data_dir, "train-labels-idx1-ubyte.gz")
x_test = load_mnist_images(data_dir, "t10k-images-idx3-ubyte.gz")
y_test = load_mnist_labels(data_dir, "t10k-labels-idx1-ubyte.gz")

print(f">> Train images: {x_train.shape}, {x_train.dtype}")
print(f">> Train labels: {y_train.shape}, {y_train.dtype}")
print(f">> Test images:  {x_test.shape}, {x_test.dtype}")
print(f">> Test labels:  {y_test.shape}, {y_test.dtype}")

>> Train images: (60000, 28, 28), uint8
>> Train labels: (60000,), uint8
>> Test images:  (10000, 28, 28), uint8
>> Test labels:  (10000,), uint8


In [67]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

class Dataset(torch.utils.data.Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images.reshape(-1, 28, 28, 1)
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label).long()
        return image, label

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_loader = DataLoader(Dataset(x_train, y_train, transform=transform), 
                          batch_size=64, shuffle=True)
test_loader = DataLoader(Dataset(x_test, y_test, transform=transform), 
                         batch_size=64, shuffle=True)

x, y = next(iter(train_loader))
print(f">> x: {x.shape}, {x.dtype}, min={x.min()}, max={x.max()}")
print(f">> y: {y.shape}, {y.dtype}, min={y.min()}, max={y.max()}")

>> x: torch.Size([64, 1, 28, 28]), torch.float32, min=-1.0, max=1.0
>> y: torch.Size([64]), torch.int64, min=0, max=9


### Modeling

In [85]:
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels=1, bias=False):
        super().__init__()
        self.embedding = nn.Embedding(10, 10)
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(in_channels + 10, 512, 4, 1, 0, bias=bias),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, 1, 1, 2, bias=bias),)

    def forward(self, noises, labels):
        labels_embedding = self.embedding(labels).unsqueeze(2).unsqueeze(3)
        print(noises.shape, labels_embedding.shape)
        x = torch.cat([noises, labels_embedding], dim=1)
        x = self.generator(x)
        return torch.tanh(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, bias=False):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, out_channels, 4, 2, 1, bias=bias),
            nn.Flatten(),)

    def forward(self, x):
        x = self.discriminator(x)
        return torch.sigmoid(x)

### Training

In [None]:
import sys
from tqdm import tqdm
from torchvision.utils import save_image

## Hyperparameters
set_seed(42)
n_epochs = 5
learning_rate = 2e-4
step_size = 1
noise_size = 64

n_outputs = 100
output_name = "cgan_cnn_4layer"

## Modeling
modelG = Generator(in_channels=noise_size, out_channels=1).to(device)
modelD = Discriminator(in_channels=1, out_channels=1).to(device)

loss_fn = nn.BCELoss()
optimizerD = optim.Adam(modelD.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizerG = optim.Adam(modelG.parameters(), lr=learning_rate, betas=(0.5, 0.999))

## Training
fixed_noises = torch.randn(n_outputs, noise_size, 1, 1).to(device)
fixed_labels = torch.randint(0, 10, (n_outputs,)).to(device)
output_dir = '/mnt/d/github/lectures-1/test/mnist/output_cgan'
output_images = modelG(fixed_noises, fixed_labels)
save_image(output_images, os.path.join(output_dir, f"{output_name}_0.png"), nrow=10)

torch.Size([100, 64, 1, 1]) torch.Size([100, 10, 1, 1])


In [None]:
for epoch in range(1, n_epochs + 1):
    with tqdm(train_loader, leave=False, file=sys.stdout, dynamic_ncols=True, ascii=True) as pbar:
        train_loss_r, train_loss_f, train_loss_g = 0, 0, 0
        for i, (real_images, labels) in enumerate(pbar):
            batch_size = len(real_images)
            real_labels = torch.ones((batch_size, 1)).to(device)
            fake_labels = torch.zeros((batch_size, 1)).to(device)
            noise = torch.randn(batch_size, noise_size, 1, 1).to(device)
            
            labels = labels.to(device)
            real_images = real_images.to(device)
            fake_images = modelG(noise)

            ## Training Discriminator
            pred_r = modelD(real_images)
            loss_r = loss_fn(pred_r, real_labels)
            loss_r.backward()

            pred_f = modelD(fake_images.detach())
            loss_f = loss_fn(pred_f, fake_labels)
            loss_f.backward()

            optimizerD.step()
            optimizerD.zero_grad()

            # Training Generator
            pred_g = modelD(fake_images)
            loss_g = loss_fn(pred_g, real_labels)
            loss_g.backward()

            optimizerG.step()
            optimizerG.zero_grad()
            
            train_loss_r += loss_r.item()
            train_loss_f += loss_f.item()
            train_loss_g += loss_g.item()

            desc = f"[{epoch:3d}/{n_epochs}] loss_r: {train_loss_r/(i + 1):.2e} " \
                   f"loss_f: {train_loss_f/(i + 1):.2e} loss_g: {train_loss_g/(i + 1):.2e}"

            if i % 10 == 0:
                pbar.set_description(desc)

        if epoch % step_size == 0:
            print(desc)
            output_images = modelG(fixed_noises)
            save_image(output_images, os.path.join(output_dir, f"{output_name}_{epoch}.png"), nrow=10)

[  1/5] loss_r: 2.29e-01 loss_f: 9.90e-01 loss_g: 5.64e-01                                                    
                                                                                                              

KeyboardInterrupt: 