In [6]:
import numpy as np
from torchvision import transforms as T

from discriminators import Discriminator32
from generators import Generator32
from gan import GAN

from cifar10 import CIFAR10, get_train_loader
from utils import set_seed
from train import train_gan

In [7]:
# --- Reproducibility and data loading -------------------------------------
SEED: int = 42                  # passed to set_seed() at the start of every run cell
BATCH_SIZE: int = 128           # batch size for get_train_loader()
NUM_SAMPLES: int = 100          # number of fixed latent vectors drawn per run cell

# --- Network shape (shared by generator and discriminator) ----------------
LATENT_DIM: int = 100           # latent vector length; noises are (NUM_SAMPLES, LATENT_DIM, 1, 1)
IN_CHANNELS: int = 3            # discriminator input channels (CIFAR-10 images are RGB)
OUT_CHANNELS: int = 3           # generator output channels
BASE: int = 64                  # forwarded as `base=` to both Discriminator32 and Generator32

# --- Epoch bookkeeping (forwarded to train_gan) ---------------------------
# NOTE(review): exact meaning of num_epochs vs total_epochs lives in train.py — confirm there.
NUM_EPOCHS: int = 5
TOTAL_EPOCHS: int = 10

In [9]:
# Run: GAN trained with loss_type="bce"; artifacts are written under the FILENAME prefix.
# NOTE(review): this cell is identical to the "mse"/"hinge" cells except for LOSS_TYPE —
# consider a single helper function called in a loop over loss types.
LOSS_TYPE = "bce"
FILENAME = f"cifar10_gan-{LOSS_TYPE}"

# Re-seed first so each loss-type run starts from the same RNG state;
# the construction order below is kept fixed because it consumes that state.
set_seed(SEED)

# ToTensor, then Normalize computes (x - 0.5) / 0.5 per channel
# (maps [0, 1] -> [-1, 1] assuming ToTensor yields [0, 1] — TODO confirm for this dataset's item type).
dataset = CIFAR10(
    root_dir="/mnt/d/datasets/cifar10",
    split="train",
    transform=T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]),
)
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)

# Standard-normal latent vectors handed to train_gan
# (presumably reused to render the saved sample grids — verify in train.py).
noises = np.random.normal(0.0, 1.0, size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = GAN(discriminator, generator, loss_type=LOSS_TYPE)

history = train_gan(
    gan,
    train_loader,
    num_epochs=NUM_EPOCHS,
    total_epochs=TOTAL_EPOCHS,
    noises=noises,
    filename=FILENAME,
)

[  1/5] d_loss:0.426, real_loss:0.228, fake_loss:0.198, g_loss:4.694                                                  
[  2/5] d_loss:0.443, real_loss:0.236, fake_loss:0.207, g_loss:3.981                                                   
[  3/5] d_loss:0.506, real_loss:0.261, fake_loss:0.245, g_loss:3.110                                                  
[  4/5] d_loss:0.450, real_loss:0.223, fake_loss:0.227, g_loss:3.104                                                  
[  5/5] d_loss:0.609, real_loss:0.310, fake_loss:0.299, g_loss:2.849                                                   
>> ./outputs/cifar10_gan-bce_epoch005.png is saved.

[  1/5] d_loss:0.688, real_loss:0.345, fake_loss:0.343, g_loss:2.219                                                  
[  2/5] d_loss:0.752, real_loss:0.380, fake_loss:0.372, g_loss:2.266                                                  
[  3/5] d_loss:0.724, real_loss:0.360, fake_loss:0.363, g_loss:2.371                                            

In [10]:
# Run: GAN trained with loss_type="mse"; artifacts are written under the FILENAME prefix.
# NOTE(review): this cell is identical to the "bce"/"hinge" cells except for LOSS_TYPE —
# consider a single helper function called in a loop over loss types.
LOSS_TYPE = "mse"
FILENAME = f"cifar10_gan-{LOSS_TYPE}"

# Re-seed first so each loss-type run starts from the same RNG state;
# the construction order below is kept fixed because it consumes that state.
set_seed(SEED)

# ToTensor, then Normalize computes (x - 0.5) / 0.5 per channel
# (maps [0, 1] -> [-1, 1] assuming ToTensor yields [0, 1] — TODO confirm for this dataset's item type).
dataset = CIFAR10(
    root_dir="/mnt/d/datasets/cifar10",
    split="train",
    transform=T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]),
)
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)

# Standard-normal latent vectors handed to train_gan
# (presumably reused to render the saved sample grids — verify in train.py).
noises = np.random.normal(0.0, 1.0, size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = GAN(discriminator, generator, loss_type=LOSS_TYPE)

history = train_gan(
    gan,
    train_loader,
    num_epochs=NUM_EPOCHS,
    total_epochs=TOTAL_EPOCHS,
    noises=noises,
    filename=FILENAME,
)

[  1/5] d_loss:0.721, real_loss:0.169, fake_loss:0.552, g_loss:2.791                                                  
[  2/5] d_loss:0.231, real_loss:0.135, fake_loss:0.096, g_loss:0.794                                                  
[  3/5] d_loss:0.305, real_loss:0.168, fake_loss:0.137, g_loss:0.734                                                  
[  4/5] d_loss:0.288, real_loss:0.156, fake_loss:0.132, g_loss:0.812                                                  
[  5/5] d_loss:0.340, real_loss:0.177, fake_loss:0.164, g_loss:0.768                                                  
>> ./outputs/cifar10_gan-mse_epoch005.png is saved.

[  1/5] d_loss:0.324, real_loss:0.169, fake_loss:0.155, g_loss:0.775                                                  
[  2/5] d_loss:0.342, real_loss:0.174, fake_loss:0.168, g_loss:0.728                                                  
[  3/5] d_loss:0.285, real_loss:0.146, fake_loss:0.138, g_loss:0.726                                              

In [11]:
# Run: GAN trained with loss_type="hinge"; artifacts are written under the FILENAME prefix.
# NOTE(review): this cell is identical to the "bce"/"mse" cells except for LOSS_TYPE —
# consider a single helper function called in a loop over loss types.
LOSS_TYPE = "hinge"
FILENAME = f"cifar10_gan-{LOSS_TYPE}"

# Re-seed first so each loss-type run starts from the same RNG state;
# the construction order below is kept fixed because it consumes that state.
set_seed(SEED)

# ToTensor, then Normalize computes (x - 0.5) / 0.5 per channel
# (maps [0, 1] -> [-1, 1] assuming ToTensor yields [0, 1] — TODO confirm for this dataset's item type).
dataset = CIFAR10(
    root_dir="/mnt/d/datasets/cifar10",
    split="train",
    transform=T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]),
)
train_loader = get_train_loader(dataset, batch_size=BATCH_SIZE)

# Standard-normal latent vectors handed to train_gan
# (presumably reused to render the saved sample grids — verify in train.py).
noises = np.random.normal(0.0, 1.0, size=(NUM_SAMPLES, LATENT_DIM, 1, 1))

discriminator = Discriminator32(in_channels=IN_CHANNELS, base=BASE)
generator = Generator32(latent_dim=LATENT_DIM, out_channels=OUT_CHANNELS, base=BASE)
gan = GAN(discriminator, generator, loss_type=LOSS_TYPE)

history = train_gan(
    gan,
    train_loader,
    num_epochs=NUM_EPOCHS,
    total_epochs=TOTAL_EPOCHS,
    noises=noises,
    filename=FILENAME,
)

[  1/5] d_loss:0.556, real_loss:0.308, fake_loss:0.247, g_loss:3.355                                                  
[  2/5] d_loss:0.464, real_loss:0.251, fake_loss:0.213, g_loss:2.418                                                  
[  3/5] d_loss:0.565, real_loss:0.294, fake_loss:0.271, g_loss:1.833                                                  
[  4/5] d_loss:0.444, real_loss:0.227, fake_loss:0.218, g_loss:1.848                                                  
[  5/5] d_loss:0.742, real_loss:0.377, fake_loss:0.365, g_loss:1.624                                                  
>> ./outputs/cifar10_gan-hinge_epoch005.png is saved.

[  1/5] d_loss:0.801, real_loss:0.398, fake_loss:0.403, g_loss:1.420                                                  
[  2/5] d_loss:0.812, real_loss:0.403, fake_loss:0.408, g_loss:1.423                                                  
[  3/5] d_loss:0.824, real_loss:0.415, fake_loss:0.409, g_loss:1.390                                            