In [1]:
import os

import numpy as np
from torchvision import transforms as T

from discriminators import Discriminator32
from generators import Generator32
from gan import WGAN

from cifar10 import CIFAR10, get_train_loader
from utils import set_seed
from train import train_gan

In [2]:
import gc, torch

def initialize(seed=42):
    """Reset memory and RNG state before starting a fresh experiment.

    Runs the Python garbage collector, empties the CUDA caching allocator
    when a CUDA device is available, and then forwards ``seed`` to
    ``set_seed``.

    Args:
        seed: Value handed to ``set_seed``. Defaults to 42.
    """
    gc.collect()

    # Only touch the CUDA allocator when a device is actually present.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()

    set_seed(seed)

In [3]:
# --- Reproducibility / data loading ---
SEED = 42          # passed to initialize() at the top of every experiment cell
BATCH_SIZE = 128   # batch size handed to get_train_loader
NUM_SAMPLES = 100  # number of latent vectors drawn for the `noises` array

# --- Network dimensions ---
LATENT_DIM = 100                # size of each latent vector fed to the generator
IN_CHANNELS = OUT_CHANNELS = 3  # discriminator input / generator output channels
BASE = 64                       # `base` argument of both networks

# --- Training schedule ---
# Both values are forwarded to train_gan; judging by the logs, training appears
# to run in rounds of NUM_EPOCHS until TOTAL_EPOCHS is reached -- TODO confirm.
NUM_EPOCHS, TOTAL_EPOCHS = 5, 10

In [4]:
# Experiment: WGAN with gradient penalty enabled and one_sided=True.
# Re-seed and free cached memory first so this run does not inherit state
# from whichever cell was executed before it.
initialize(seed=SEED)
LOSS_TYPE = "gp_one-sided"
FILENAME = f"cifar10_wgan-{LOSS_TYPE}"  # prefix of the files written by train_gan

# The dataset root can be overridden with the CIFAR10_ROOT environment variable
# so the notebook is not tied to one machine's mount point; the original path
# remains the fallback. Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5
# per channel, which maps inputs in [0, 1] to [-1, 1].
dataset = CIFAR10(root_dir=os.environ.get("CIFAR10_ROOT", "/mnt/d/datasets/cifar10"),
                  split="train",
                  transform=T.Compose([T.ToTensor(), T.Normalize(mean=[0.5]*3, std=[0.5]*3)]))
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)

# Latent vectors drawn once per experiment and handed to train_gan; presumably
# used to render the sample image saved during training -- TODO confirm.
noises = np.random.normal(size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = WGAN(discriminator, generator, use_gp=True, one_sided=True)
history = train_gan(gan, train_loader, num_epochs=NUM_EPOCHS, total_epochs=TOTAL_EPOCHS,
                    noises=noises, filename=FILENAME)

[  1/5] d_loss:-15.399, real_loss:-2.541, fake_loss:-12.858, g_loss:13.173, gp:0.273                                                  
[  2/5] d_loss:-7.456, real_loss:-2.606, fake_loss:-4.850, g_loss:5.091, gp:0.168                                                  
[  3/5] d_loss:-6.573, real_loss:-1.480, fake_loss:-5.092, g_loss:5.139, gp:0.094                                                  
[  4/5] d_loss:-43.401, real_loss:-36.333, fake_loss:-7.068, g_loss:7.150, gp:1.206                                                  
[  5/5] d_loss:-5.856, real_loss:2.679, fake_loss:-8.535, g_loss:8.607, gp:0.193                                                  
>> ./outputs/cifar10_wgan-gp_one-sided_epoch005.png is saved.

[  1/5] d_loss:-6.919, real_loss:-0.680, fake_loss:-6.239, g_loss:6.287, gp:0.267                                                  
[  2/5] d_loss:-7.372, real_loss:-2.642, fake_loss:-4.731, g_loss:4.759, gp:0.398                                                  
[  3/5] d

In [4]:
# Experiment: WGAN with the gradient penalty disabled (use_gp=False, one_sided=False).
# Re-seed and free cached memory first so this run does not inherit state
# from whichever cell was executed before it.
initialize(seed=SEED)
LOSS_TYPE = "default"
FILENAME = f"cifar10_wgan-{LOSS_TYPE}"  # prefix of the files written by train_gan

# The dataset root can be overridden with the CIFAR10_ROOT environment variable
# so the notebook is not tied to one machine's mount point; the original path
# remains the fallback. Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5
# per channel, which maps inputs in [0, 1] to [-1, 1].
dataset = CIFAR10(root_dir=os.environ.get("CIFAR10_ROOT", "/mnt/d/datasets/cifar10"),
                  split="train",
                  transform=T.Compose([T.ToTensor(), T.Normalize(mean=[0.5]*3, std=[0.5]*3)]))
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)

# Latent vectors drawn once per experiment and handed to train_gan; presumably
# used to render the sample image saved during training -- TODO confirm.
noises = np.random.normal(size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = WGAN(discriminator, generator, use_gp=False, one_sided=False)
history = train_gan(gan, train_loader, num_epochs=NUM_EPOCHS, total_epochs=TOTAL_EPOCHS,
                    noises=noises, filename=FILENAME)

[  1/5] d_loss:-0.501, real_loss:-0.255, fake_loss:-0.245, g_loss:0.248                                                  
[  2/5] d_loss:-0.464, real_loss:-0.237, fake_loss:-0.227, g_loss:0.230                                                  
[  3/5] d_loss:-0.347, real_loss:-0.179, fake_loss:-0.168, g_loss:0.171                                                  
[  4/5] d_loss:-0.308, real_loss:-0.158, fake_loss:-0.150, g_loss:0.152                                                  
[  5/5] d_loss:-0.281, real_loss:-0.144, fake_loss:-0.137, g_loss:0.136                                                  
>> ./outputs/cifar10_wgan-default_epoch005.png is saved.

[  1/5] d_loss:-0.268, real_loss:-0.138, fake_loss:-0.130, g_loss:0.130                                                  
[  2/5] d_loss:-0.263, real_loss:-0.134, fake_loss:-0.129, g_loss:0.126                                                  
[  3/5] d_loss:-0.244, real_loss:-0.124, fake_loss:-0.121, g_loss:0.117                 

In [5]:
# Experiment: WGAN with gradient penalty enabled and one_sided=False
# (presumably the standard two-sided penalty -- TODO confirm against WGAN).
# Re-seed and free cached memory first so this run does not inherit state
# from whichever cell was executed before it.
initialize(seed=SEED)
LOSS_TYPE = "gp"
FILENAME = f"cifar10_wgan-{LOSS_TYPE}"  # prefix of the files written by train_gan

# The dataset root can be overridden with the CIFAR10_ROOT environment variable
# so the notebook is not tied to one machine's mount point; the original path
# remains the fallback. Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5
# per channel, which maps inputs in [0, 1] to [-1, 1].
dataset = CIFAR10(root_dir=os.environ.get("CIFAR10_ROOT", "/mnt/d/datasets/cifar10"),
                  split="train",
                  transform=T.Compose([T.ToTensor(), T.Normalize(mean=[0.5]*3, std=[0.5]*3)]))
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)

# Latent vectors drawn once per experiment and handed to train_gan; presumably
# used to render the sample image saved during training -- TODO confirm.
noises = np.random.normal(size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = WGAN(discriminator, generator, use_gp=True, one_sided=False)
history = train_gan(gan, train_loader, num_epochs=NUM_EPOCHS, total_epochs=TOTAL_EPOCHS,
                    noises=noises, filename=FILENAME)

[  1/5] d_loss:-15.457, real_loss:-2.429, fake_loss:-13.028, g_loss:13.346, gp:0.270                                                  
[  2/5] d_loss:-7.344, real_loss:-2.331, fake_loss:-5.012, g_loss:5.149, gp:0.178                                                  
[  3/5] d_loss:-5.385, real_loss:-0.324, fake_loss:-5.061, g_loss:5.085, gp:0.083                                                  
[  4/5] d_loss:-50.365, real_loss:-45.089, fake_loss:-5.276, g_loss:5.374, gp:1.952                                                  
[  5/5] d_loss:-5.712, real_loss:0.597, fake_loss:-6.309, g_loss:6.322, gp:0.217                                                  
>> ./outputs/cifar10_wgan-gp_epoch005.png is saved.

[  1/5] d_loss:-6.306, real_loss:-3.400, fake_loss:-2.906, g_loss:2.856, gp:0.265                                                  
[  2/5] d_loss:-6.822, real_loss:-4.049, fake_loss:-2.772, g_loss:2.787, gp:0.372                                                  
[  3/5] d_loss:-6.4