## GAN: CIFAR10

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Data

In [3]:
import os
import numpy as np
import pickle

def unpickle(filename):
    # tar -zxvf cifar-10-python.tar.gz
    with open(filename, 'rb') as f:
        data = pickle.load(f, encoding='bytes')

    x = np.array(data[b'data']).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    y = np.array(data[b'labels'])
    return x, y


def load_cifar10(data_dir):
    batch_files = [os.path.join(data_dir, f"data_batch_{i+1}") for i in range(5)]
    test_file = os.path.join(data_dir, "test_batch")

    images, labels = [], []
    for filename in batch_files:
        x, y = unpickle(filename)
        images.append(x)
        labels.append(y)

    x_train = np.concatenate(images, axis=0)
    y_train = np.concatenate(labels, axis=0)

    x_test, y_test = unpickle(test_file)
    return (x_train, y_train), (x_test, y_test)

# data_dir = r"D:\datasets\cifar10_178M\cifar-10-batches-py"    ## windows
data_dir = "/mnt/d/datasets/cifar10_178M/cifar-10-batches-py"   ## wsl
(x_train, y_train), (x_test, y_test) = load_cifar10(data_dir)

print(f">> Train images: {x_train.shape}, {x_train.dtype}")
print(f">> Train labels: {y_train.shape}, {y_train.dtype}")
print(f">> Test images:  {x_test.shape}, {x_test.dtype}")
print(f">> Test labels:  {y_test.shape}, {y_test.dtype}")

>> Train images: (50000, 32, 32, 3), uint8
>> Train labels: (50000,), int64
>> Test images:  (10000, 32, 32, 3), uint8
>> Test labels:  (10000,), int64


In [4]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class CIFAR10(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label).long()
        return image, label

transform_train = transforms.Compose([
    # transforms.ToPILImage(),
    # transforms.RandomHorizontalFlip(0.3),
    # transforms.RandomVerticalFlip(0.3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = CIFAR10(x_train, y_train, transform=transform_train)
test_dataset = CIFAR10(x_test, y_test, transform=transform_test)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

x, y = next(iter(train_loader))
print(f">> x: {x.shape}, {x.dtype}, min={x.min()}, max={x.max()}")
print(f">> y: {y.shape}, {y.dtype}, min={y.min()}, max={y.max()}")

>> x: torch.Size([32, 3, 32, 32]), torch.float32, min=-1.0, max=1.0
>> y: torch.Size([32]), torch.int64, min=0, max=9


### Modeling

In [11]:
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels=3, n_classes=10):
        super().__init__()
        self.n_classes = n_classes
        self.embedding_dim = 32
        self.embedding = nn.Embedding(self.n_classes, self.embedding_dim)
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(in_channels + self.embedding_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1, bias=False),
        )

    def forward(self, noises, labels):
        labels_embedding = self.embedding(labels).view(-1, self.embedding_dim, 1, 1)
        inputs = torch.cat([noises, labels_embedding], dim=1)
        images = self.generator(inputs)
        return torch.tanh(images)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, n_classes=10):
        super().__init__()
        self.n_classes = n_classes
        self.embedding_dim = 32
        self.embedding = nn.Sequential(
            nn.Embedding(self.n_classes, self.embedding_dim),
            nn.Linear(self.embedding_dim, 1*32*32),
        )
        self.discriminator = nn.Sequential(
            nn.Conv2d(in_channels + 1, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, out_channels, 4, 1, 0, bias=False),
            nn.Flatten(),
        )

    def forward(self, images, labels):
        labels_embedding = self.embedding(labels).view(-1, 1, 32, 32)
        inputs = torch.cat([images, labels_embedding], dim=1)
        outputs = self.discriminator(inputs)
        return torch.sigmoid(outputs)

### Training

In [14]:
import sys
from tqdm import tqdm
from torchvision.utils import save_image

## Hyperparameters
set_seed(42)
n_epochs = 5
learning_rate = 2e-4
noise_size = 100
step_size = 1

n_classes = 10
n_outputs = 100
output_name = "cifar10_cgan"

## Modeling
modelG = Generator(in_channels=noise_size, out_channels=3, n_classes=10).to(device)
modelD = Discriminator(in_channels=3, out_channels=1, n_classes=10).to(device)

loss_fn = nn.BCELoss()
optimizerD = optim.Adam(modelD.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizerG = optim.Adam(modelG.parameters(), lr=learning_rate, betas=(0.5, 0.999))

## Training
fixed_noises = torch.randn(n_outputs, noise_size, 1, 1).to(device)
fixed_labels = torch.arange(n_classes).repeat(n_outputs // n_classes, 1).view(-1).to(device)
output_dir = '/mnt/d/github/lectures-1/test/cifar10/output_cgan'
output_path = os.path.join(output_dir, f"{output_name}_0.png")
output_images = modelG(fixed_noises, fixed_labels)
save_image(output_images, output_path, nrow=10, normalize=True)

In [17]:
for epoch in range(1, n_epochs + 1):
    with tqdm(train_loader, leave=False, file=sys.stdout, dynamic_ncols=True, ascii=True) as pbar:
        train_loss_r, train_loss_f, train_loss_g = 0, 0, 0
        for i, (real_images, labels) in enumerate(pbar):
            batch_size = len(real_images)
            real_labels = torch.ones((batch_size, 1)).to(device)
            fake_labels = torch.zeros((batch_size, 1)).to(device)
            noises = torch.randn(batch_size, noise_size, 1, 1).to(device)
            labels = labels.to(device)
            real_images = real_images.to(device)
            fake_images = modelG(noises, labels)

            ## Training Discriminator
            pred_r = modelD(real_images, labels)
            loss_r = loss_fn(pred_r, real_labels)
            loss_r.backward()

            pred_f = modelD(fake_images.detach(), labels)
            loss_f = loss_fn(pred_f, fake_labels)
            loss_f.backward()

            optimizerD.step()
            optimizerD.zero_grad()

            # Training Generator
            pred_g = modelD(fake_images, labels)
            loss_g = loss_fn(pred_g, real_labels)
            loss_g.backward()

            optimizerG.step()
            optimizerG.zero_grad()
            
            train_loss_r += loss_r.item()
            train_loss_f += loss_f.item()
            train_loss_g += loss_g.item()

            desc = f"[{epoch:3d}/{n_epochs}] loss_r: {train_loss_r/(i + 1):.2e} " \
                   f"loss_f: {train_loss_f/(i + 1):.2e} loss_g: {train_loss_g/(i + 1):.2e}"

            if i % 10 == 0:
                pbar.set_description(desc)

        if epoch % step_size == 0:
            print(desc)
            output_images = modelG(fixed_noises, fixed_labels)
            output_path = os.path.join(output_dir, f"{output_name}_{epoch}.png")
            save_image(output_images, output_path, nrow=10, normalize=True)

[  1/5] loss_r: 2.76e-01 loss_f: 1.06e+00 loss_g: 5.32e-01                                                     
[  2/5] loss_r: 1.54e-01 loss_f: 9.50e-01 loss_g: 6.31e-01                                                     
[  3/5] loss_r: 1.31e-01 loss_f: 9.23e-01 loss_g: 6.48e-01                                                     
[  4/5] loss_r: 1.24e-01 loss_f: 9.15e-01 loss_g: 6.55e-01                                                     
[  5/5] loss_r: 1.24e-01 loss_f: 9.15e-01 loss_g: 6.57e-01                                                     


### Evaluation

In [35]:
# def denormalize(img, mean=0.5, std=0.5):
#     normalize = transforms.Normalize([-mean/std], [1/std])
#     res = normalize(img)
#     res = torch.clamp(res, 0, 1)
#     return res

# # set_seed(111)
# fixed_noises = torch.randn(n_outputs, latent_dim, 1, 1).to(device)
# with torch.no_grad():
#     output_images = modelG(fixed_noises)

# images = denormalize(output_images)
# images = images.cpu().detach().permute(0, 2, 3, 1).squeeze()
# rows = [np.concatenate(images[i*10:(i+1)*10], axis=1) for i in range(10)]
# grid = np.concatenate(rows, axis=0)

# fig, ax = plt.subplots(figsize=(5, 5))
# ax.imshow(grid, cmap="gray_r")
# ax.set_axis_off()
# fig.tight_layout()
# plt.show()

# images.shape