In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transformations
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.datasets as datasets
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights
import matplotlib.pyplot as plt
from utils import gradient_penalty

In [10]:
# hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # change according to data
NOISE_DIM = 100
NUM_EPOCHS = 10
FEATURES_DISC = 16
FEATURES_GEN = 16
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10 # wgan-gradient penalty replaces clipping
WORKERS = 6

In [11]:
transforms = transformations.Compose(
    [
        transformations.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transformations.RandomHorizontalFlip(p=0.5),
        transformations.ToTensor(),
        transformations.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range (CHANNELS_IMG)])
    ]
)

dataset = datasets.ImageFolder('./dataset', transform=transforms)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)

gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(critic)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))


fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

In [12]:
'''gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN)
gen.load_state_dict(torch.load("./models/Generator"))

critic = Discriminator(CHANNELS_IMG, FEATURES_DISC)
critic.load_state_dict(torch.load("./models/Critic"))'''


'gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN)\ngen.load_state_dict(torch.load("./models/Generator"))\n\ncritic = Discriminator(CHANNELS_IMG, FEATURES_DISC)\ncritic.load_state_dict(torch.load("./models/Critic"))'

In [13]:
gen.train()
critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(128, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [14]:
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()


        # Train generator
        output = critic(fake).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # print losses occasionally and print to tensorboard

        if batch_idx % 80 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G; {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take up to 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/10] Batch 80/329                   Loss D: 6.4264, loss G; -0.3543
Epoch [0/10] Batch 160/329                   Loss D: 6.8600, loss G; -0.4636
Epoch [0/10] Batch 240/329                   Loss D: 5.2296, loss G; -0.3732
Epoch [0/10] Batch 320/329                   Loss D: 4.8994, loss G; -0.4283
Epoch [1/10] Batch 80/329                   Loss D: 3.4745, loss G; -0.5447
Epoch [1/10] Batch 160/329                   Loss D: 2.4740, loss G; -0.3435
Epoch [1/10] Batch 240/329                   Loss D: 1.9236, loss G; -0.2051
Epoch [1/10] Batch 320/329                   Loss D: 2.2354, loss G; -0.5325
Epoch [2/10] Batch 80/329                   Loss D: 1.5609, loss G; -0.3605
Epoch [2/10] Batch 160/329                   Loss D: 2.8049, loss G; -0.1636
Epoch [2/10] Batch 240/329                   Loss D: 2.2533, loss G; -0.4540
Epoch [2/10] Batch 320/329                   Loss D: 10.0522, loss G; -1.0000
Epoch [3/10] Batch 80/329                   Loss D: 9.8545, loss G; -1.0000
Ep

In [None]:
torch.save(gen.state_dict(), "./models/Generator", _use_new_zipfile_serialization=False)
torch.save(critic.state_dict(), "./models/Critic", _use_new_zipfile_serialization=False)