## Generative Adversarial Networks


In [1]:
# %pip install -q pillow
# %pip install -q numpy
# %pip install -q scipy
# %pip install -q matplotlib
# %pip install -q torchinfo
# %pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

In [2]:
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lr_scheduler

from torchvision.transforms import v2
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torch.utils.data import sampler

from torchinfo import summary
from tqdm import tqdm

In [None]:
# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    # Allow TensorFloat32 on matmul and convolutions (CUDA only).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # torch.set_float32_matmul_precision("medium")
else:
    device = torch.device("cpu")

print(f"Available device: {device.type}")
# Tensors and modules created without an explicit device now land on `device`.
torch.set_default_device(device)

In [4]:
# Input pipeline applied to each dataset sample: resize to 28x28, wrap as an
# image tensor, then convert to float32 with value rescaling (scale=True).
preprocess_fn = v2.Compose(
    transforms=[
        v2.Resize(size=(28, 28)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
    ]
)

# Display pipeline: convert an image tensor back to a PIL image for plotting.
postprocess_fn = v2.Compose(transforms=[v2.ToPILImage()])

In [None]:
# Directory the FashionMNIST files are read from / downloaded into.
DATA_PATH = "../datasets"

# Only NUM_TRAIN and MINIBATCH_SIZE are used in this cell; the remaining
# constants are merely defined here.
NUM_TRAIN, NUM_VAL, NUM_TEST = 60000, 5000, 5000
MINIBATCH_SIZE = 96
EPOCHS = 20

# Training split, with `preprocess_fn` applied to every image on access.
fashion_train = datasets.FashionMNIST(
    root=DATA_PATH,
    train=True,
    transform=preprocess_fn,
    download=True,
)

# Mini-batches drawn in random order from the first NUM_TRAIN sample indices.
train_dataloader = DataLoader(
    dataset=fashion_train,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
    batch_size=MINIBATCH_SIZE,
)

# Notebook display: (number of samples, number of mini-batches).
len(fashion_train), len(train_dataloader)

### Using Linear Layers


In [6]:
class Discriminator(nn.Module):
    """Fully connected discriminator for flattened 784-value inputs.

    Three hidden ``Linear`` layers (1024, 512 and 256 units), each followed
    by ``LeakyReLU(0.2)`` and ``Dropout(0.3)``, feed a one-unit ``Linear``
    layer with a ``Sigmoid``, so each input row yields one score between
    0 and 1.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        widths = (784, 1024, 512, 256)

        layers = []
        # Hidden stack: one Linear -> LeakyReLU -> Dropout triple per width step.
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend(
                (nn.Linear(n_in, n_out), nn.LeakyReLU(0.2), nn.Dropout(0.3))
            )
        # Scoring head.
        layers.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scores produced by the layer stack for ``x``."""
        return self.model(x)

In [7]:
# Build the fully connected discriminator and place it on the selected device.
discriminator = Discriminator().to(device)

In [8]:
class Generator(nn.Module):
    """Fully connected generator mapping 100-value inputs to 784 values.

    Three ``Linear`` + ``LeakyReLU(0.2)`` stages widen the input through
    256, 512 and 1024 units; a final ``Linear`` + ``Tanh`` produces 784
    outputs in the range [-1, 1].
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        widths = (100, 256, 512, 1024)

        layers = []
        # Widening stack: one Linear -> LeakyReLU pair per width step.
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2)]
        # Output head.
        layers += [nn.Linear(widths[-1], 784), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the generated rows for the input ``x``."""
        return self.model(x)

In [9]:
# Build the fully connected generator and place it on the selected device.
generator = Generator().to(device)

In [10]:
def discrimator_train_step(model, real_data, fake_data, loss_fn, optimizer):
    """Run one optimizer update of the discriminator on a real and a fake batch.

    Args:
        model: Discriminator being trained; called on both batches.
        real_data: Batch of real samples, scored against a target of ones.
        fake_data: Batch of generated samples, scored against a target of
            zeros. The caller is expected to pass it detached from the
            generator's graph.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar.
        optimizer: Optimizer holding the discriminator's parameters.

    Returns:
        A detached scalar tensor: real-batch loss plus fake-batch loss.
    """
    optimizer.zero_grad()

    # Targets are derived from the prediction itself, so they always match
    # its shape, dtype and device without relying on any global state.
    real_pred = model(real_data)
    real_error = loss_fn(real_pred, torch.ones_like(real_pred))
    real_error.backward()

    fake_pred = model(fake_data)
    fake_error = loss_fn(fake_pred, torch.zeros_like(fake_pred))
    fake_error.backward()

    # Gradients of both halves have accumulated; apply them in one step.
    optimizer.step()

    # Detach so callers that store the value for logging do not keep the
    # autograd history of every step alive.
    return (real_error + fake_error).detach()

In [11]:
def generator_train_step(model, fake_data, real_data, loss_fn, optimizer):
    """Run one optimizer update of the generator through the discriminator.

    Args:
        model: Discriminator used to score the generated batch.
        fake_data: Generated batch, still attached to the generator's graph
            so that gradients reach the generator's parameters.
        real_data: Kept for call compatibility; it is no longer read because
            the target is sized from the prediction.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar.
        optimizer: Optimizer holding the generator's parameters.

    Returns:
        A detached scalar tensor with the generator loss.
    """
    optimizer.zero_grad()

    prediction = model(fake_data)
    # The generator is rewarded when its samples are scored as real (ones).
    # Sizing the target from `prediction` keeps the two shapes in agreement
    # even if the real and fake batches differ in length.
    error = loss_fn(prediction, torch.ones_like(prediction))
    error.backward()

    optimizer.step()

    # Detach so the value can be logged without retaining autograd history.
    return error.detach()

In [12]:
# One Adam optimizer per network, both with learning rate 2e-4. train() reads
# these module-level names directly.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4)

# Binary cross-entropy on probabilities; train() passes it to both step functions.
loss = nn.BCELoss()
# Number of epochs handed to train() in the cell below.
epochs = 35

In [13]:
def train(generator, discrimator, dataloader, epochs=200):
    """Alternate discriminator and generator updates over ``dataloader``.

    Uses the module-level ``loss``, ``d_optimizer``, ``g_optimizer`` and
    ``device``. Each epoch prints the mean discriminator loss, the mean
    generator loss and the current learning rate of both optimizers.

    Args:
        generator: Network mapping ``(N, 100)`` noise to flattened images.
        discrimator: Network scoring flattened images.
        dataloader: Iterable of ``(images, labels)`` batches; labels are
            ignored and images are flattened to one row per sample.
        epochs: Number of passes over ``dataloader``.
    """
    # Both learning rates are scaled linearly from 0.99x to 0.66x of their
    # base value over the first 17 scheduler steps (one step per epoch).
    d_scheduler = lr_scheduler.LinearLR(
        d_optimizer, start_factor=0.99, end_factor=0.66, total_iters=17
    )

    g_scheduler = lr_scheduler.LinearLR(
        g_optimizer, start_factor=0.99, end_factor=0.66, total_iters=17
    )

    for epoch in range(epochs):
        # Row 0: discriminator loss per batch; row 1: generator loss per batch.
        loss_a = torch.zeros(2, len(dataloader), device=device)
        for i, (images, _) in enumerate(tqdm(dataloader)):
            real_data = images.view(len(images), -1).to(device)
            batch_size = len(real_data)

            # Discriminator step: the generator is not updated here, so its
            # forward pass does not need an autograd graph.
            with torch.no_grad():
                fake_data = generator(torch.randn(batch_size, 100, device=device))

            d_loss = discrimator_train_step(
                discrimator, real_data, fake_data, loss, d_optimizer
            )

            # Generator step: fresh noise, graph kept so gradients can flow
            # back through the discriminator into the generator.
            fake_data = generator(torch.randn(batch_size, 100, device=device))

            g_loss = generator_train_step(
                discrimator, fake_data, real_data, loss, g_optimizer
            )

            # Store plain values; keeping graph-attached losses would chain
            # autograd history across the whole epoch.
            loss_a[0, i] = d_loss.detach()
            loss_a[1, i] = g_loss.detach()

        d_scheduler.step()
        g_scheduler.step()

        print(
            f"epoch={epoch + 1} d_loss={loss_a[0].mean().item():.4f} g_loss={loss_a[1].mean().item():.4f} d_lr={d_optimizer.param_groups[0]['lr']:.7f} g_lr={g_optimizer.param_groups[0]['lr']:.7f}",
            end="\n\n",
        )

In [14]:
# Train the fully connected GAN for `epochs` passes over the training loader.
train(generator, discriminator, train_dataloader, epochs=epochs)

In [15]:
def imshow(tensor, title=None):
    """Show an image tensor with matplotlib.

    Args:
        tensor: Image tensor to display; it is copied, never modified.
        title: Optional figure title; omitted when ``None``.
    """
    # Work on a CPU copy and drop a leading size-1 dimension if there is one
    # before handing the tensor to the PIL conversion pipeline.
    picture = postprocess_fn(tensor.cpu().clone().squeeze(0))

    plt.imshow(picture)
    plt.axis("off")
    if title is not None:
        plt.title(title)

    # Short pause so the figure gets a chance to refresh.
    plt.pause(0.001)

In [None]:
generator.eval()
# 64 latent vectors of length 100, created directly on the target device.
z = torch.randn(64, 100, device=device)

# Sampling needs no gradients; no_grad() replaces the legacy `.data` access
# and avoids building an autograd graph at all.
with torch.no_grad():
    sample_images = generator(z).cpu().view(64, 1, 28, 28)

# Tile the 64 images 8 per row; normalize=True rescales values for display.
grid = make_grid(sample_images, nrow=8, normalize=True)

imshow(grid)

### Using Convolutional Layers


In [None]:
# Input pipeline for the convolutional models: 64x64 images as float32
# tensors with value rescaling (scale=True).
preprocess_fn = v2.Compose(
    [
        v2.Resize((64, 64)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

postprocess_fn = v2.Compose(
    [
        v2.ToPILImage(),
    ]
)

# Rebinding the name `preprocess_fn` does not alter the transform object the
# existing dataset already holds, so the loader would keep yielding 28x28
# images, which are too small for ConvDiscriminator's final 4x4 convolution.
# Rebuild the dataset and loader so they use the 64x64 pipeline.
fashion_train = datasets.FashionMNIST(
    DATA_PATH, train=True, download=True, transform=preprocess_fn
)

train_dataloader = DataLoader(
    fashion_train,
    batch_size=MINIBATCH_SIZE,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
)

In [17]:
class ConvDiscriminator(nn.Module):
    """Convolutional discriminator for single-channel images.

    Four stride-2 4x4 convolutions (64, 128, 256, 512 channels) each halve
    the spatial size; all but the first are followed by ``BatchNorm2d``, and
    all use ``LeakyReLU(0.2)``. A final 4x4 convolution without padding and a
    ``Sigmoid`` reduce a 64x64 input to one score per sample.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        base = 64
        # Entry stage: no batch normalisation on the raw image.
        layers = [
            nn.Conv2d(1, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three downsampling stages, each doubling the channel count.
        channels = base
        for _ in range(3):
            layers += [
                nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels *= 2

        # Scoring head: collapse the remaining 4x4 map to a single value.
        layers += [nn.Conv2d(channels, 1, 4, 1, 0, bias=False), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the scores produced by the convolutional stack for ``x``."""
        return self.model(x)

In [None]:
# torchinfo summary of a fresh ConvDiscriminator for an input of size (96, 1, 64, 64).
summary(ConvDiscriminator(), input_size=(96, 1, 64, 64))

In [19]:
class ConvGenerator(nn.Module):
    """Transposed-convolution generator producing single-channel images.

    The input has one channel (the first layer is ``ConvTranspose2d(1, ...)``).
    A stride-1 stage expands a 1x1 input to 4x4 with 512 channels, three
    stride-2 stages double the spatial size while halving the channels
    (256, 128, 64), and a last stride-2 layer with ``Tanh`` yields one
    channel at 64x64 with values in [-1, 1].
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # (in_channels, out_channels, stride, padding) of each normalised stage.
        stages = (
            (1, 64 * 8, 1, 0),
            (64 * 8, 64 * 4, 2, 1),
            (64 * 4, 64 * 2, 2, 1),
            (64 * 2, 64, 2, 1),
        )

        layers = []
        for c_in, c_out, stride, padding in stages:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Output head: back to a single channel.
        layers += [nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the images generated from ``input``."""
        return self.model(input)

In [None]:
# torchinfo summary of a fresh ConvGenerator for an input of size (96, 1, 1, 1).
summary(ConvGenerator(), input_size=(96, 1, 1, 1))

In [21]:
# Build the convolutional discriminator/generator pair on the selected device.
conv_discriminator = ConvDiscriminator().to(device)
conv_generator = ConvGenerator().to(device)

In [22]:
# These rebind the module-level `loss`, `d_optimizer`, `g_optimizer` and
# `epochs` used by the linear section; conv_train() reads the first three
# as globals.
loss = nn.BCELoss()

# Adam with learning rate 2e-4 and betas (0.5, 0.999) for both conv networks.
d_optimizer = torch.optim.Adam(
    conv_discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)
)
g_optimizer = torch.optim.Adam(conv_generator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Number of epochs handed to conv_train() below.
epochs = 10

In [23]:
def discrimator_train_step(model, real_data, fake_data, loss_fn, optimizer):
    """Run one optimizer update of the discriminator on a real and a fake batch.

    Predictions are flattened to one dimension before the loss, so a model
    emitting ``(N, 1, 1, 1)`` scores works for every batch size, including 1.

    Args:
        model: Discriminator being trained; called on both batches.
        real_data: Batch of real samples, scored against a target of ones.
        fake_data: Batch of generated samples, scored against a target of
            zeros. The caller is expected to pass it detached from the
            generator's graph.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar.
        optimizer: Optimizer holding the discriminator's parameters.

    Returns:
        A detached scalar tensor: real-batch loss plus fake-batch loss.
    """
    optimizer.zero_grad()

    # reshape(-1) always keeps one dimension; squeeze() would turn a batch of
    # one into a 0-d tensor that no longer matches a length-1 target.
    real_pred = model(real_data).reshape(-1)
    real_error = loss_fn(real_pred, torch.ones_like(real_pred))
    real_error.backward()

    fake_pred = model(fake_data).reshape(-1)
    fake_error = loss_fn(fake_pred, torch.zeros_like(fake_pred))
    fake_error.backward()

    # Gradients of both halves have accumulated; apply them in one step.
    optimizer.step()

    # Detach so callers that store the value for logging do not keep the
    # autograd history of every step alive.
    return (real_error + fake_error).detach()

In [24]:
def generator_train_step(model, fake_data, real_data, loss_fn, optimizer):
    """Run one optimizer update of the generator through the discriminator.

    Predictions are flattened to one dimension before the loss, so a model
    emitting ``(N, 1, 1, 1)`` scores works for every batch size, including 1.

    Args:
        model: Discriminator used to score the generated batch.
        fake_data: Generated batch, still attached to the generator's graph
            so that gradients reach the generator's parameters.
        real_data: Kept for call compatibility; it is no longer read because
            the target is sized from the prediction.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar.
        optimizer: Optimizer holding the generator's parameters.

    Returns:
        A detached scalar tensor with the generator loss.
    """
    optimizer.zero_grad()

    # reshape(-1) always keeps one dimension; squeeze() would turn a batch of
    # one into a 0-d tensor that no longer matches a length-1 target.
    prediction = model(fake_data).reshape(-1)
    # The generator is rewarded when its samples are scored as real (ones).
    error = loss_fn(prediction, torch.ones_like(prediction))
    error.backward()

    optimizer.step()

    # Detach so the value can be logged without retaining autograd history.
    return error.detach()

In [25]:
def conv_train(generator, discrimator, dataloader, epochs=200):
    """Alternate discriminator and generator updates for the conv models.

    Uses the module-level ``loss``, ``d_optimizer``, ``g_optimizer`` and
    ``device``. Each epoch prints the mean discriminator and generator loss.

    Args:
        generator: Network mapping ``(N, 1, 1, 1)`` noise to images.
        discrimator: Network scoring image batches.
        dataloader: Iterable of ``(images, labels)`` batches; labels are
            ignored and images are passed on without reshaping.
        epochs: Number of passes over ``dataloader``.
    """
    for epoch in range(epochs):
        # Row 0: discriminator loss per batch; row 1: generator loss per batch.
        loss_a = torch.zeros(2, len(dataloader), device=device)
        for i, (images, _) in enumerate(tqdm(dataloader)):
            real_data = images.to(device)
            batch_size = len(real_data)

            # Discriminator step: the generator is not updated here, so its
            # forward pass does not need an autograd graph.
            with torch.no_grad():
                fake_data = generator(
                    torch.randn(batch_size, 1, 1, 1, device=device)
                )

            d_loss = discrimator_train_step(
                discrimator, real_data, fake_data, loss, d_optimizer
            )

            # Generator step: fresh noise, placed on the same device as the
            # first draw, with the graph kept so gradients reach the generator.
            fake_data = generator(torch.randn(batch_size, 1, 1, 1, device=device))

            g_loss = generator_train_step(
                discrimator, fake_data, real_data, loss, g_optimizer
            )

            # Store plain values; keeping graph-attached losses would chain
            # autograd history across the whole epoch.
            loss_a[0, i] = d_loss.detach()
            loss_a[1, i] = g_loss.detach()

        print(
            f"epoch={epoch + 1} d_loss={loss_a[0].mean().item():.4f} g_loss={loss_a[1].mean().item():.4f}",
            end="\n\n",
        )

In [None]:
# Train the convolutional GAN for `epochs` passes over the training loader.
conv_train(conv_generator, conv_discriminator, train_dataloader, epochs=epochs)

In [None]:
# Evaluation mode: the generator's BatchNorm layers use their running statistics.
conv_generator.eval()
# 64 noise inputs shaped (64, 1, 1, 1), matching the generator's single input channel.
z = torch.randn(64, 1, 1, 1).to(device)
# Drop the autograd history and move the generated images to the CPU for plotting.
sample_images = conv_generator(z).detach().cpu()

# Tile the images 8 per row; normalize=True rescales values for display.
grid = make_grid(sample_images, nrow=8, normalize=True)

imshow(grid)