In [None]:
import torch.nn as nn
import torch

print(torch.__version__)

import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as dataloader
from torch.utils.tensorboard import SummaryWriter

print(torch.cuda.is_available())

class Discriminator(nn.Module):
    """Convolutional critic that reduces an image batch to one score per sample.

    The network is five 4x4, stride-2 convolutions.  The first four use
    padding 1 and halve the spatial size; the last uses no padding.  For
    64x64 inputs the output has shape (N, 1, 1, 1).  No activation follows
    the final convolution, so the scores are unbounded.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Base channel width; the normalised stages widen it to
            2x, 4x and 8x this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        # Layers are instantiated in network order, so parameter creation
        # happens in the same sequence as the layers appear in the model.
        widths = [features_d * scale for scale in (1, 2, 4, 8)]
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for narrow, wide in zip(widths, widths[1:]):
            layers.append(self._block(narrow, wide, 4, 2, 1))
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one Conv2d (no bias) -> InstanceNorm2d -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.InstanceNorm2d(out_channels)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run the image batch ``x`` through the critic and return its scores."""
        scores = self.disc(x)
        return scores



class Generator(nn.Module):
    """Transposed-convolution generator that turns latent noise into images.

    A (N, channels_noise, 1, 1) input is upsampled 1 -> 4 -> 8 -> 16 -> 32 -> 64
    pixels per side.  The closing Tanh bounds every output value to [-1, 1].

    Args:
        channels_noise: Number of channels of the latent input.
        channels_img: Number of channels of the produced images.
        features_g: Base channel width; hidden stages use 16x, 8x, 4x and 2x
            this value.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per stage.
        stage_specs = (
            (channels_noise, features_g * 16, 4, 1, 0),   # -> 4 x 4
            (features_g * 16, features_g * 8, 4, 2, 1),   # -> 8 x 8
            (features_g * 8, features_g * 4, 4, 2, 1),    # -> 16 x 16
            (features_g * 4, features_g * 2, 4, 2, 1),    # -> 32 x 32
        )
        stages = [self._block(*spec) for spec in stage_specs]
        # Final upsampling to 64 x 64; this layer keeps its bias and has no norm.
        to_image = nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        )
        self.net = nn.Sequential(*stages, to_image, nn.Tanh())

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU stage."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        return nn.Sequential(upsample, norm, nn.ReLU())

    def forward(self, x):
        """Map the latent batch ``x`` to an image batch."""
        images = self.net(x)
        return images

def initialize_weights(model):
    """Re-initialise conv, transposed-conv and batch-norm weights in place.

    Every ``weight`` tensor of an ``nn.Conv2d``, ``nn.ConvTranspose2d`` or
    ``nn.BatchNorm2d`` found in ``model.modules()`` is redrawn from a normal
    distribution with mean 0.0 and standard deviation 0.02.  Biases and all
    other module types are left untouched.

    Args:
        model: Module whose sub-modules are traversed.
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for m in model.modules():
        # A norm layer built with affine=False has weight=None; skip it.
        if isinstance(m, target_layers) and m.weight is not None:
            # normal_ is the supported in-place initialiser; the un-suffixed
            # nn.init.normal is only a deprecated alias for it.
            nn.init.normal_(m.weight.data, 0.0, 0.02)

from torchsummary import summary

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Hyperparameters shared by the rest of the script.
LEARNING_RATE = 1e-4  # Adam learning rate for both optimizers
BATCH_SIZE = 64  # mini-batch size of the data loader
IMG_SIZE = 64  # images are resized to IMG_SIZE x IMG_SIZE
CHANNELS_IMG = 3  # channels per image
Z_DIM = 100  # channels of the latent noise fed to the generator
NUM_CLASSES = 5  # NOTE(review): not referenced anywhere else in the visible code
FEATURES_CRITIC = 64  # base channel width of the Discriminator
FEATURES_GEN = 64  # base channel width of the Generator
CRITIC_ITERATIONS = 5  # only referenced by the commented-out critic loop below
NUM_EPOCHS = 15  # passes over the training set
LAMBDA_GP = 10  # weight of the gradient penalty in the critic loss

# Preprocessing applied to every dataset image: resize to IMG_SIZE x IMG_SIZE,
# convert to a tensor, then normalise each channel with mean 0.5 and std 0.5.
# NOTE: this assignment rebinds `transforms`, which up to this line referred to
# the torchvision.transforms module; afterwards the module is no longer
# reachable under that name (re-running this cell on its own will fail).
transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

def test_classes():
    """Smoke-test the Discriminator and Generator and print their summaries.

    Builds small models (base width 8) on ``device``, checks that the critic
    maps an (8, 3, 64, 64) batch to (8, 1, 1, 1) and that the generator maps
    an (8, 100, 1, 1) noise batch to (8, 3, 64, 64), then prints a
    torchsummary table for each.

    Raises:
        AssertionError: If either output shape differs from the expected one.
    """
    N, in_channels, H, W = 8, 3, 64, 64
    noise_dim = 100

    x = torch.randn((N, in_channels, H, W)).to(device)
    disc = Discriminator(in_channels, 8).to(device)
    assert disc(x).shape == (N, 1, 1, 1), "Discriminator test failed"
    # summary() prints the table itself and returns nothing, so it is not
    # wrapped in print(); the input size is (C, H, W), not (C, H, H).
    summary(disc, input_size=(in_channels, H, W))

    gen = Generator(
        channels_noise=noise_dim, channels_img=in_channels, features_g=8
    ).to(device)
    z = torch.randn((N, noise_dim, 1, 1)).to(device)
    assert gen(z).shape == (N, in_channels, H, W), " Generator test failed"
    summary(gen, input_size=(noise_dim, 1, 1))


test_classes()

def gradient_penalty(critic, real, fake, device):
    """Return the gradient penalty of ``critic`` on real/fake interpolations.

    Each sample is mixed as ``real * eps + fake * (1 - eps)`` with one
    uniform ``eps`` in [0, 1) per sample.  The penalty is the batch mean of
    ``(||grad||_2 - 1) ** 2``, where ``grad`` is the gradient of the critic
    score with respect to the mixed sample, flattened per sample.

    Args:
        critic: Callable that maps an image batch to scores.
        real: Real image batch; its first dimension is the batch size.
        fake: Generated image batch with the same shape as ``real``.
        device: Device on which the mixing coefficients are created.

    Returns:
        Scalar tensor holding the penalty.  It is built with
        ``create_graph=True`` so it can be back-propagated into the critic.
    """
    num_samples = real.shape[0]
    # One coefficient per sample, created directly on the target device and
    # broadcast across channels and pixels instead of being materialised at
    # full image size and copied over.
    epsilon = torch.rand((num_samples, 1, 1, 1), device=device)
    interpolated_images = real * epsilon + fake * (1 - epsilon)
    # autograd.grad needs an input that requires grad.  That is only true
    # automatically when `real` or `fake` carries a graph; when both are
    # detached the interpolation is a leaf and must be flagged explicitly.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Critic scores on the mixed samples.
    mixed_scores = critic(interpolated_images)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten to one row per sample; reshape also copes with non-contiguous
    # gradients.
    gradient = gradient.reshape(num_samples, -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

def save_checkpoint(state, filename='celeba_wgan_gp'):
    """Announce the save on stdout, then write ``state`` with ``torch.save``.

    Args:
        state: Object to serialise.
        filename: Destination path; defaults to ``'celeba_wgan_gp'``.
    """
    print("-> saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(filename, gen, disc):
    """Restore generator and critic weights from a checkpoint file.

    The file is read once.  If it holds a mapping with ``"gen"`` and
    ``"disc"`` entries, each entry is loaded into the matching model.
    Otherwise the loaded object is applied to both models unchanged, which
    is the only behaviour this function had before.

    NOTE(review): nothing visible in this file writes a checkpoint, so the
    ``{"gen": ..., "disc": ...}`` layout is an assumption -- confirm it
    against the code that calls ``save_checkpoint``.

    Args:
        filename: Path of the checkpoint to read.
        gen: Generator module that receives the generator weights.
        disc: Critic module that receives the critic weights.
    """
    print("-> loading checkpoint")
    checkpoint = torch.load(filename)
    if isinstance(checkpoint, dict) and "gen" in checkpoint and "disc" in checkpoint:
        gen.load_state_dict(checkpoint["gen"])
        disc.load_state_dict(checkpoint["disc"])
        return
    # Legacy layout: one flat state dict shared by both models.
    gen.load_state_dict(checkpoint)
    disc.load_state_dict(checkpoint)

from torch.utils.data import DataLoader

# CelebA training split, stored under `data` (downloaded when missing) and
# preprocessed with the `transforms` pipeline; served in shuffled mini-batches.
celeba_dataset = datasets.CelebA(
    root='data', split='train', transform=transforms, download=True
)
celeba_loader = DataLoader(celeba_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Networks, moved to the selected device and re-initialised with N(0, 0.02) weights.
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(disc)

# One Adam optimizer per network, both with betas (0.0, 0.9).
opt_gen = optim.Adam(params=gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_disc = optim.Adam(params=disc.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Fixed latent batch, reused for every TensorBoard preview of generated images.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(log_dir="logs/WPGAN_CELEBA/2/real")
writer_fake = SummaryWriter(log_dir="logs/WPGAN_CELEBA/2/fake")

step = 0     # global_step used for the TensorBoard image grids
img_idx = 0

# Put both networks in training mode.
gen.train()
disc.train()

from tqdm import tqdm

from torchvision.transforms import ToPILImage

import numpy as np
# main training loop
#
# Schedule: the critic takes one step on every batch and the generator takes
# one step on every CRITIC_ITERATIONS-th batch.
#
# History lists: after each step the running mean of the current epoch's values
# is appended (critic loss and scores on every batch, generator loss on every
# generator update).
loss_g = []
loss_d = []
real_scores = []
fake_scores = []


for epoch in range(NUM_EPOCHS):
    loss_g_per_epoch = []
    loss_d_per_epoch = []
    real_scores_per_epoch = []
    fake_scores_per_epoch = []

    for batch_idx, (real, _) in enumerate(tqdm(celeba_loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        update_generator = batch_idx % CRITIC_ITERATIONS == 0

        # Train Discriminator
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
        fake = gen(noise)
        critical_real = disc(real).reshape(-1)
        critical_fake = disc(fake).reshape(-1)
        gp = gradient_penalty(disc, real, fake, device)
        mean_real = torch.mean(critical_real)
        mean_fake = torch.mean(critical_fake)
        loss_disc = -(mean_real - mean_fake) + LAMBDA_GP * gp

        # The logged scores are the same batch means that enter the loss, so
        # the critic does not need extra forward passes just for logging.
        real_scores_per_epoch.append(mean_real.item())
        fake_scores_per_epoch.append(mean_fake.item())
        loss_d_per_epoch.append(loss_disc.item())

        disc.zero_grad()
        # The generator step below back-propagates through `fake` again, so
        # the graph only has to survive on batches that perform that step.
        loss_disc.backward(retain_graph=update_generator)
        opt_disc.step()

        # Train Generator
        if update_generator:
            gen_fake = disc(fake).reshape(-1)
            loss_gen = -torch.mean(gen_fake)
            loss_g_per_epoch.append(loss_gen.item())
            loss_g.append(np.mean(loss_g_per_epoch))

            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

        if batch_idx % 100 == 0 and batch_idx != 0:
            # `loss_gen` is the value from the most recent generator update.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(celeba_loader)} "
                f"loss D: {loss_disc:.4f}, loss G {loss_gen:.4f}"
            )

            with torch.no_grad():
                # Preview the generator on the fixed latent batch.
                fake = gen(fixed_noise)

                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
        loss_d.append(np.mean(loss_d_per_epoch))
        real_scores.append(np.mean(real_scores_per_epoch))
        fake_scores.append(np.mean(fake_scores_per_epoch))



import matplotlib.pyplot as plt

def draw_image(gen_image):
    """Show the first image of a generated batch with matplotlib.

    Args:
        gen_image: Image batch laid out as (N, C, H, W); only index 0 is
            drawn.  Values are treated as lying in [-1, 1], which is the
            range of the generator's Tanh output.
    """
    # (C, H, W) -> (H, W, C), the channel-last layout imshow expects.
    first = gen_image[0].detach().cpu().permute(1, 2, 0)
    # imshow clips float RGB data to [0, 1], so map [-1, 1] onto that range
    # instead of letting every negative value be cut off to black.
    first = ((first + 1) / 2).clamp(0, 1)
    plt.imshow(first.numpy())

# Sample 64 latent vectors and display the first generated image.
noise = torch.randn(64, Z_DIM, 1, 1).to(device)
# Inference only: no autograd graph is needed just to plot the result.
with torch.no_grad():
    gen_image = gen(noise)

draw_image(gen_image.to(torch.device('cpu')))

In [None]:
%load_ext tensorboard

%tensorboard --logdir ./logs

# Critic loss history: the running epoch mean appended after every batch.
plt.plot(loss_d)
plt.show()

# Generator loss history: the running epoch mean appended after every generator update.
plt.plot(loss_g)