In [1]:
import torch 
import torch.nn as nn
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter

In [2]:
class Generator(nn.Module):
    """DCGAN-style generator: upsamples a latent tensor into an image.

    Given an input of shape (N, z_dim, 1, 1), the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64. Each step follows the ConvTranspose2d
    size rule (dilation 1, no output padding):

        out = (in - 1) * stride - 2 * padding + kernel_size

    The final Tanh keeps every output value inside [-1, 1].

    Args:
        z_dim: channel count of the latent input.
        out_chan: channel count of the generated image.
        gen_features: base width; the hidden blocks use 8x, 4x, 2x and 1x
            of this value.
    """

    def __init__(self, z_dim=100, out_chan=1, gen_features=128):
        super().__init__()
        widths = [gen_features * mult for mult in (8, 4, 2, 1)]

        # Projection block: stride 1, no padding (1x1 -> 4x4 for a 1x1 latent).
        layers = [self._block(z_dim, widths[0], kernel_size=4, stride=1, padding=0)]
        # Three stride-2 blocks, each halving the channel width.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(self._block(c_in, c_out, kernel_size=4, stride=2, padding=1))
        # Output layer: last stride-2 upsampling, then squash with Tanh.
        layers.append(
            nn.ConvTranspose2d(widths[-1], out_chan, kernel_size=4, stride=2, padding=1)
        )
        layers.append(nn.Tanh())

        self.gen = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size=4, stride=1, padding=1):
        """Build ConvTranspose2d (bias-free) -> BatchNorm2d -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, z_dim):
        """Run the latent batch through the layer stack.

        Despite its name, ``z_dim`` is the latent tensor itself, not its
        dimensionality; the parameter name is kept for caller compatibility.
        """
        latent = z_dim
        return self.gen(latent)

In [3]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator: maps an image batch to sigmoid scores.

    Every convolution uses a 4x4 kernel with stride 2. For 64x64 inputs the
    spatial size shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1, so the result has
    shape (N, 1, 1, 1) with values in [0, 1].

    The layer stack is stored under the attribute name ``gen``; the name is
    kept as-is so existing ``state_dict`` keys continue to match.

    Args:
        im_channels: channel count of the input images.
        disc_features: base width; successive blocks use 1x, 2x, 4x and 8x
            of this value.
    """

    def __init__(self, im_channels=1, disc_features=128):
        super().__init__()
        widths = [disc_features * mult for mult in (1, 2, 4, 8)]

        # Stem: plain convolution (with bias, no BatchNorm) + LeakyReLU.
        layers = [
            nn.Conv2d(im_channels, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three stride-2 blocks, each doubling the channel width.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(self._block(c_in, c_out, kernel_size=4, stride=2, padding=1))
        # Head: unpadded 4x4 convolution to one channel, then Sigmoid.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        self.gen = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size=4, stride=1, padding=1):
        """Build Conv2d (bias-free) -> BatchNorm2d -> LeakyReLU(0.2)."""
        downsample = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(downsample, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, img):
        """Return the per-image score produced by the layer stack."""
        scores = self.gen(img)
        return scores

In [13]:
# --- Hyperparameters --------------------------------------------------------
n_epoch = 10
batch_size = 128
lr = 3e-4
beta1 = 0.5
beta2 = 0.99

# --- Model / data dimensions ------------------------------------------------
z_dim = 100        # channel count of the latent noise fed to the generator
img_channels = 3   # used for both networks and for the Normalize statistics
img_shape = 64     # every image is resized to img_shape x img_shape
gen_features = 128
disc_features = 128

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Networks, optimizers, loss ---------------------------------------------
# The generator is built before the discriminator (same order as always, so
# weight initialisation consumes the RNG identically).
gen = Generator(z_dim=z_dim, out_chan=img_channels, gen_features=gen_features)
gen = gen.to(device)
disc = Discriminator(im_channels=img_channels, disc_features=disc_features)
disc = disc.to(device)

gen_optim = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta1, beta2))
disc_optim = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta1, beta2))

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# --- Data pipeline ----------------------------------------------------------
# ToTensor -> Resize -> Normalize with mean 0.5 / std 0.5 per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_shape, img_shape)),
    transforms.Normalize([0.5] * img_channels, [0.5] * img_channels),
])

# download=False: the CIFAR-10 files must already be present under root=".".
data = datasets.CIFAR10(root=".", transform=transform, download=False)
loader = DataLoader(data, batch_size=batch_size, shuffle=True)

In [14]:
# Adversarial training loop.
#
# Each iteration performs one discriminator update followed by one generator
# update. On the first batch of every epoch the current losses are printed and
# a 4x4 grid of generated / real images is written to TensorBoard.
step = 0
n_preview = 16  # images per logged grid (rendered with nrow=4)
fake_writer = SummaryWriter("runs/fake")
real_writer = SummaryWriter("runs/real")

try:
    for epoch in range(n_epoch):
        for idx, (real, _) in enumerate(loader):
            real = real.to(device)
            # Size the noise from the real batch, not from batch_size: the
            # last batch of an epoch may be smaller (shuffle=True, no
            # drop_last), and real/fake batches should stay balanced.
            cur_batch = real.size(0)

            # ---- train discriminator ----
            z = torch.randn(cur_batch, z_dim, 1, 1, device=device)
            fake = gen(z)
            # detach(): the discriminator step must not backprop into gen.
            fake_pred = disc(fake.detach())
            fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
            real_pred = disc(real)
            real_loss = criterion(real_pred, torch.ones_like(real_pred))
            disc_loss = (fake_loss + real_loss) / 2

            disc.zero_grad()
            disc_loss.backward()
            disc_optim.step()

            # ---- train generator ----
            # Fresh noise, scored by the just-updated discriminator; the
            # generator is rewarded when its samples are labelled "real".
            z2 = torch.randn(cur_batch, z_dim, 1, 1, device=device)
            fake2 = gen(z2)
            fake2_pred = disc(fake2)
            gen_loss = criterion(fake2_pred, torch.ones_like(fake2_pred))

            gen.zero_grad()
            gen_loss.backward()
            gen_optim.step()

            if idx == 0:
                print(
                    f"Epoch {epoch} / {n_epoch}\n"
                    f"Loss Discriminator {disc_loss:.4f}, Loss Generator {gen_loss:.4f}"
                )

                with torch.no_grad():
                    # Sample from the generator *after* this iteration's
                    # update and log those samples (previously the freshly
                    # generated batch was discarded and the stale training
                    # batch `fake` was logged instead).
                    z_test = torch.randn(n_preview, z_dim, 1, 1, device=device)
                    fake_test = gen(z_test)

                    fake_grid = make_grid(fake_test, nrow=4, normalize=True)
                    real_grid = make_grid(real[:n_preview], nrow=4, normalize=True)

                    fake_writer.add_image(
                        "Fake Images", fake_grid, global_step=step
                    )
                    real_writer.add_image(
                        "Real Images", real_grid, global_step=step
                    )

                step += 1
finally:
    # Flush and release the event files even if training is interrupted
    # (e.g. the cell is stopped from the notebook UI).
    fake_writer.close()
    real_writer.close()
        

Epoch 0 / 10
Loss Discriminator 0.7957, Loss Generator 8.9928
Epoch 1 / 10
Loss Discriminator 0.5142, Loss Generator 2.2372
Epoch 2 / 10
Loss Discriminator 0.1474, Loss Generator 3.4059
Epoch 3 / 10
Loss Discriminator 0.1815, Loss Generator 5.3358
Epoch 4 / 10
Loss Discriminator 0.9499, Loss Generator 3.1359
Epoch 5 / 10
Loss Discriminator 0.1925, Loss Generator 3.7925
Epoch 6 / 10
Loss Discriminator 0.2505, Loss Generator 4.2355
Epoch 7 / 10
Loss Discriminator 0.5210, Loss Generator 2.1843
Epoch 8 / 10
Loss Discriminator 0.0535, Loss Generator 4.3568
Epoch 9 / 10
Loss Discriminator 0.0017, Loss Generator 5.5643
