In [1]:
import os

import numpy as np
from torchvision import transforms as T

from discriminators import Discriminator32
from generators import Generator32
from gan import RaGAN

from cifar10 import CIFAR10, get_train_loader
from utils import set_seed
from train import train_gan

In [2]:
# --- Reproducibility and data --------------------------------------------
SEED = 42            # passed to set_seed() at the top of every experiment cell
BATCH_SIZE = 128     # mini-batch size handed to get_train_loader()
NUM_SAMPLES = 100    # how many fixed latent vectors are drawn and passed to train_gan as `noises`

# --- Model architecture --------------------------------------------------
LATENT_DIM = 100     # length of each latent vector fed to the generator
IN_CHANNELS = 3      # discriminator input channels (CIFAR-10 images are RGB)
OUT_CHANNELS = 3     # generator output channels (RGB)
BASE = 64            # presumably the base channel width of both networks — confirm in generators/discriminators

# --- Training schedule ---------------------------------------------------
# Judging by the logged output, train_gan reports progress in chunks of
# NUM_EPOCHS and saves a sample image after each chunk, until TOTAL_EPOCHS
# is reached — confirm in train.py.
NUM_EPOCHS = 5
TOTAL_EPOCHS = 10

In [3]:
# RaGAN experiment with loss_type "ra" (presumably the relativistic average
# BCE loss — see gan.RaGAN); outputs are prefixed with FILENAME.
LOSS_TYPE = "ra"
FILENAME = f"cifar10_ragan-{LOSS_TYPE}"

# Dataset location: set the CIFAR10_ROOT environment variable instead of
# editing this machine-specific default path.
DATA_ROOT = os.environ.get("CIFAR10_ROOT", "/mnt/d/datasets/cifar10")

# Re-seed so this cell gives the same run when executed on its own
# (presumably set_seed seeds numpy and torch — confirm in utils).
set_seed(SEED)
# ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
dataset = CIFAR10(root_dir=DATA_ROOT, split="train",
                  transform=T.Compose([T.ToTensor(), T.Normalize(mean=[0.5]*3, std=[0.5]*3)]))
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)
# Fixed latent vectors handed to train_gan; presumably reused for every saved sample image.
noises = np.random.normal(size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = RaGAN(discriminator, generator, loss_type=LOSS_TYPE)

# Keep this run's history under its own name so a later experiment cell that
# reassigns `history` does not discard it; `history` still refers to the most
# recent run for any downstream cells.
history_ra = train_gan(gan, train_loader, num_epochs=NUM_EPOCHS, total_epochs=TOTAL_EPOCHS,
                       noises=noises, filename=FILENAME)
history = history_ra

[  1/5] d_loss:0.407, real_loss:0.414, fake_loss:0.399, g_loss:3.174                                                   
[  2/5] d_loss:0.367, real_loss:0.373, fake_loss:0.360, g_loss:2.983                                                  
[  3/5] d_loss:0.364, real_loss:0.368, fake_loss:0.360, g_loss:2.709                                                  
[  4/5] d_loss:0.358, real_loss:0.360, fake_loss:0.356, g_loss:2.618                                                   
[  5/5] d_loss:0.367, real_loss:0.368, fake_loss:0.366, g_loss:2.506                                                  
>> ./outputs/cifar10_ragan-ra_epoch005.png is saved.

[  1/5] d_loss:0.370, real_loss:0.371, fake_loss:0.369, g_loss:2.540                                                  
[  2/5] d_loss:0.376, real_loss:0.377, fake_loss:0.376, g_loss:2.590                                                   
[  3/5] d_loss:0.372, real_loss:0.373, fake_loss:0.372, g_loss:2.517                                          

In [4]:
# RaGAN experiment with loss_type "rals" (presumably the relativistic average
# least-squares loss — see gan.RaGAN); outputs are prefixed with FILENAME.
LOSS_TYPE = "rals"
FILENAME = f"cifar10_ragan-{LOSS_TYPE}"

# Dataset location: set the CIFAR10_ROOT environment variable instead of
# editing this machine-specific default path.
DATA_ROOT = os.environ.get("CIFAR10_ROOT", "/mnt/d/datasets/cifar10")

# Re-seed so this run starts from the same state as any other experiment cell
# (presumably set_seed seeds numpy and torch — confirm in utils).
set_seed(SEED)
# ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
dataset = CIFAR10(root_dir=DATA_ROOT, split="train",
                  transform=T.Compose([T.ToTensor(), T.Normalize(mean=[0.5]*3, std=[0.5]*3)]))
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)
# Fixed latent vectors handed to train_gan; presumably reused for every saved sample image.
noises = np.random.normal(size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = RaGAN(discriminator, generator, loss_type=LOSS_TYPE)

# Keep this run's history under its own name instead of silently replacing an
# earlier experiment's `history`; `history` still refers to the most recent run
# for any downstream cells.
history_rals = train_gan(gan, train_loader, num_epochs=NUM_EPOCHS, total_epochs=TOTAL_EPOCHS,
                         noises=noises, filename=FILENAME)
history = history_rals

[  1/5] d_loss:0.091, real_loss:0.084, fake_loss:0.097, g_loss:3.999                                                  
[  2/5] d_loss:0.033, real_loss:0.032, fake_loss:0.035, g_loss:3.855                                                  
[  3/5] d_loss:0.025, real_loss:0.022, fake_loss:0.027, g_loss:3.842                                                  
[  4/5] d_loss:0.024, real_loss:0.021, fake_loss:0.026, g_loss:3.882                                                  
[  5/5] d_loss:0.023, real_loss:0.021, fake_loss:0.025, g_loss:3.862                                                  
>> ./outputs/cifar10_ragan-rals_epoch005.png is saved.

[  1/5] d_loss:0.023, real_loss:0.021, fake_loss:0.024, g_loss:3.852                                                  
[  2/5] d_loss:0.021, real_loss:0.019, fake_loss:0.022, g_loss:3.866                                                  
[  3/5] d_loss:0.019, real_loss:0.017, fake_loss:0.020, g_loss:3.859                                           