# Imports

In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import inception_v3

from gan_package.gan import GAN
from gan_package.vanillaGAN import VanillaGAN_Discriminator, VanillaGAN_Generator

# Create dataset

In [2]:
# Data configuration: square output resolution, mini-batch size and the ImageFolder root
# (ImageFolder expects class sub-directories under `root`).
image_size = 256
batch_size = 256
root = 'lsun/bedroom'

# Per-image preprocessing:
#   1. Resize with an int size matches the *shorter* edge to image_size (aspect ratio kept),
#   2. CenterCrop then cuts the central image_size x image_size square,
#   3. ToTensor gives a float (C, H, W) tensor in [0, 1],
#   4. Normalize with mean=std=0.5 maps each channel to [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = dset.ImageFolder(root=root, transform=preprocess)

# Shuffled mini-batches, decoded/transformed by two background worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
    

# Set parameters

In [None]:
# Optimizer learning rates.
# batch_size is set in the dataset cell, where the DataLoader consumes it; redefining it
# here would have no effect on training, so it is intentionally not repeated.
# NOTE(review): the discriminator rate is 10x the generator's. An over-strong discriminator
# can starve the generator of useful gradient — confirm this imbalance is intentional.
learning_rate_gen = 0.0005
learning_rate_dis = 0.005

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and (presumably)
# latent sampling inside GAN.train are reproducible under Restart & Run All.
seed = 42
torch.manual_seed(seed);

In [None]:
# Per-sample image shape, read from the first transformed sample: (C, H, W) as produced
# by ToTensor in the dataset cell.
img_shape = dataloader.dataset[0][0].shape
# Flattened image size as a plain Python int (torch.Size.numel). The previous
# torch.prod(torch.tensor(...)) produced a 0-d tensor, not an integer feature count.
n_out = img_shape.numel()
latent_dim = 100

# Pretrained Inception-v3 handed to the GAN wrapper (presumably for Inception-based
# sample-quality metrics — confirm in gan_package). It is put in eval mode because, in
# train mode, BatchNorm uses batch statistics, dropout is active and, with
# aux_logits=True, forward() returns an InceptionOutputs tuple instead of a logits tensor.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=Inception_V3_Weights.DEFAULT`; migrate once the pinned torchvision allows it.
inception_model = inception_v3(pretrained=True, transform_input=False, aux_logits=True).eval()

vanilla_generator = VanillaGAN_Generator(latent_dim=latent_dim, img_shape=img_shape, n_out=n_out)
vanilla_discriminator = VanillaGAN_Discriminator(img_shape=img_shape)
vanilla_gan = GAN(generator=vanilla_generator, discriminator=vanilla_discriminator, inception_model=inception_model)


# One Adam optimizer per network so each is stepped with its own learning rate.
# NOTE(review): Adam's default betas (0.9, 0.999) are used; many GAN recipes use beta1=0.5.
vanilla_generator_optimizer = optim.Adam(
    params=vanilla_generator.parameters(),
    lr=learning_rate_gen,
)
vanilla_discriminator_optimizer = optim.Adam(
    params=vanilla_discriminator.parameters(),
    lr=learning_rate_dis,
)

# Adversarial loss. NOTE(review): nn.BCELoss expects inputs that are already probabilities
# in [0, 1], so this assumes VanillaGAN_Discriminator ends in a sigmoid — confirm.
criterion = nn.BCELoss()

# Training loop

In [4]:
# Run adversarial training. The recorded output below shows ~13 s per iteration over
# 1185 iterations (presumably one epoch), so 20 epochs is a very long run on that hardware.
vanilla_gan.train(
    dataset=dataset,
    dataloader=dataloader,
    generator_optimizer=vanilla_generator_optimizer,
    discriminator_optimizer=vanilla_discriminator_optimizer,
    criterion=criterion,
    num_epochs=20,
)

  3%|▎         | 41/1185 [10:19<4:15:19, 13.39s/it]

In [None]:
# Persist both networks. The target directory is created first because the save helpers
# presumably write straight to these paths (e.g. via torch.save), which does not create
# missing parent directories — confirm against gan_package.
os.makedirs('models', exist_ok=True)
vanilla_gan.save_generator('models/vanilla_generator.pt')
vanilla_gan.save_discriminator('models/vanilla_discriminator.pt')