In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [17]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        #Input: N x img_channels x 64 x 64
        self.disc = nn.Sequential(
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            # 32 x 32 
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1), # 16x16
            self._block(features_d*2, features_d*4, 4, 2, 1), # 8x8
            self._block(features_d*4, features_d*8, 4, 2, 1), # 4x4
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0), # 1x1 
            nn.Sigmoid()
        )


    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.disc(x)


In [18]:
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        # Input: N x z_dim x 1 x 1
        self.gen = nn.Sequential(
            self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
            self._block(features_g*16, features_g * 8, 4, 2, 1), # 8 x 8
            self._block(features_g*8, features_g * 4, 4, 2, 1), # 16 x 16
            self._block(features_g*4, features_g * 2, 4, 2, 1), # 32 x 32
            nn.ConvTranspose2d(
                features_g*2,
                channels_img,
                4,2,1
            ), # 64 x 64
            nn.Tanh()
        )        
    
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.gen(x)


In [19]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [32]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
LEARNING_RATE_DISC=2e-5
LEARNING_RATE_GEN=2e-4
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

In [33]:
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
           [0.5 for _ in range(CHANNELS_IMG)],
           [0.5 for _ in range(CHANNELS_IMG)]
        )
    ]
)

dataset=datasets.MNIST(root="./dataset/",train=True,transform=transform,download=True)
dataset = datasets.ImageFolder(root="dataset/celeb_dataset/", transform=transform)
    
loader =  DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [34]:
gen  = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5,0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE_GEN, betas=(0.5,0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"./logs/real")
writer_fake = SummaryWriter(f"./logs/fake")

step = 0

In [35]:
gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [36]:
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
        fake = gen(noise)

        disc_real = disc(real).reshape(-1)
        disc_fake = disc(fake.detach()).reshape(-1)

        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        loss_disc = (loss_disc_fake + loss_disc_real)/2
        
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()



        if batch_idx % 10 == 0:
            print(
                    f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                        Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)

                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
            
            step += 1

Epoch [0/5] Batch 0/1583                         Loss D: 0.6975, loss G: 0.8050
Epoch [0/5] Batch 10/1583                         Loss D: 0.3541, loss G: 1.3884
Epoch [0/5] Batch 20/1583                         Loss D: 0.1975, loss G: 1.8628
Epoch [0/5] Batch 30/1583                         Loss D: 0.1214, loss G: 2.2718
Epoch [0/5] Batch 40/1583                         Loss D: 0.0795, loss G: 2.6491
Epoch [0/5] Batch 50/1583                         Loss D: 0.0536, loss G: 2.9630
Epoch [0/5] Batch 60/1583                         Loss D: 0.0402, loss G: 3.2496
Epoch [0/5] Batch 70/1583                         Loss D: 0.0307, loss G: 3.4999
Epoch [0/5] Batch 80/1583                         Loss D: 0.0237, loss G: 3.7374
Epoch [0/5] Batch 90/1583                         Loss D: 0.0193, loss G: 3.9466
Epoch [0/5] Batch 100/1583                         Loss D: 0.0147, loss G: 4.1452
Epoch [0/5] Batch 110/1583                         Loss D: 0.0130, loss G: 4.3287
Epoch [0/5] Batch 120/1583 

KeyboardInterrupt: 