# Imports

In [1]:
import torch
from torch import Tensor
from torchvision.utils import save_image

from models.type import Type as ModelType
from data.type import Type as DataType
from results.results import Results
from train.params import Params
from train.builder import build_params

import ssl

from train.train_wgan import TrainWGan

# NOTE(review): this replaces the stdlib's default HTTPS context factory with one that
# performs NO certificate verification, for the rest of this kernel's lifetime. Presumably a
# workaround so the dataset download succeeds on a machine with a broken CA store — TODO
# confirm. It is a security risk (MITM on every stdlib HTTPS fetch); prefer fixing the local
# certificate store and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context



# Train WGAN-GP on Fashion-MNIST

In [2]:
# Hyperparameters for WGAN-GP training on Fashion-MNIST, converted to `Params` by build_params.
# Keys are listed in the same order as before, so the dict's iteration order is unchanged.
args: dict = {
    "epochs": 1,  # number of epochs of training
    "batch_size": 64,  # size of the batches
    "lr": 0.0002,  # adam: learning rate
    "b1": 0.5,  # adam: decay of first order momentum of gradient
    "b2": 0.999,  # adam: decay of second order momentum of gradient
    "latent_dim": 100,  # dimensionality of the latent space
    "critic": 5,  # number of training steps for discriminator per iter
    "gradient_penalty_lambda": 10,  # loss weight for gradient penalty
    "save_generated_image_every": 50,  # interval batches between saving image
}

# Seed torch's RNGs (all devices) so a Restart & Run All reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)

params: Params = build_params(args, ModelType.WGAN_GP, DataType.FASHION_MNIST)

In [3]:
# `results` collects what the trainer reports under the "wgan" tag; later cells read
# its loss dicts and final generator.
results: Results = Results("wgan")

# The trainer pushes every logged loss back through this callback.
on_loss_updated = results.loss_updated_callback
train: TrainWGan = TrainWGan(params, on_loss_updated)
train.run()

[Epoch 1/1] [Batch 1/938] [Discriminator loss: 7.965496] [Generator loss: 0.020921]
[Epoch 1/1] [Batch 6/938] [Discriminator loss: 4.205560] [Generator loss: 0.005333]
[Epoch 1/1] [Batch 11/938] [Discriminator loss: -4.479178] [Generator loss: -0.056411]
[Epoch 1/1] [Batch 16/938] [Discriminator loss: -14.374399] [Generator loss: -0.213896]
[Epoch 1/1] [Batch 21/938] [Discriminator loss: -20.077557] [Generator loss: -0.463803]
[Epoch 1/1] [Batch 26/938] [Discriminator loss: -23.303551] [Generator loss: -0.658140]
[Epoch 1/1] [Batch 31/938] [Discriminator loss: -26.219278] [Generator loss: -0.872584]
[Epoch 1/1] [Batch 36/938] [Discriminator loss: -25.330536] [Generator loss: -1.083971]
[Epoch 1/1] [Batch 41/938] [Discriminator loss: -24.415670] [Generator loss: -1.297791]
[Epoch 1/1] [Batch 46/938] [Discriminator loss: -25.419647] [Generator loss: -1.563311]
[Epoch 1/1] [Batch 51/938] [Discriminator loss: -25.345894] [Generator loss: -1.838504]
[Epoch 1/1] [Batch 56/938] [Discriminator

[Epoch 1/1] [Batch 441/938] [Discriminator loss: -5.031691] [Generator loss: -2.401665]
[Epoch 1/1] [Batch 446/938] [Discriminator loss: -5.067956] [Generator loss: -2.723453]
[Epoch 1/1] [Batch 451/938] [Discriminator loss: -5.020389] [Generator loss: -1.595452]
[Epoch 1/1] [Batch 456/938] [Discriminator loss: -5.884916] [Generator loss: -1.715223]
[Epoch 1/1] [Batch 461/938] [Discriminator loss: -5.421385] [Generator loss: -1.258875]
[Epoch 1/1] [Batch 466/938] [Discriminator loss: -4.142063] [Generator loss: -4.470121]
[Epoch 1/1] [Batch 471/938] [Discriminator loss: -4.629319] [Generator loss: -4.465357]
[Epoch 1/1] [Batch 476/938] [Discriminator loss: -5.577743] [Generator loss: -3.048803]
[Epoch 1/1] [Batch 481/938] [Discriminator loss: -4.820692] [Generator loss: -5.577889]
[Epoch 1/1] [Batch 486/938] [Discriminator loss: -4.966998] [Generator loss: -4.819818]
[Epoch 1/1] [Batch 491/938] [Discriminator loss: -4.531331] [Generator loss: -4.382471]
[Epoch 1/1] [Batch 496/938] [Dis

[Epoch 1/1] [Batch 846/938] [Discriminator loss: -4.375093] [Generator loss: -1.649229]
[Epoch 1/1] [Batch 851/938] [Discriminator loss: -4.230699] [Generator loss: -0.990210]
[Epoch 1/1] [Batch 856/938] [Discriminator loss: -4.348237] [Generator loss: -0.632761]
[Epoch 1/1] [Batch 861/938] [Discriminator loss: -4.332053] [Generator loss: 0.184667]
[Epoch 1/1] [Batch 866/938] [Discriminator loss: -5.035052] [Generator loss: 2.041342]
[Epoch 1/1] [Batch 871/938] [Discriminator loss: -4.179540] [Generator loss: -0.856347]
[Epoch 1/1] [Batch 876/938] [Discriminator loss: -4.913606] [Generator loss: 1.628767]
[Epoch 1/1] [Batch 881/938] [Discriminator loss: -4.620762] [Generator loss: 0.040448]
[Epoch 1/1] [Batch 886/938] [Discriminator loss: -4.622988] [Generator loss: 0.255912]
[Epoch 1/1] [Batch 891/938] [Discriminator loss: -4.235046] [Generator loss: 0.629328]
[Epoch 1/1] [Batch 896/938] [Discriminator loss: -4.483991] [Generator loss: 1.916887]
[Epoch 1/1] [Batch 901/938] [Discrimina

In [5]:
# Pull the recorded training artifacts out of `results` for plotting and sampling below.
# NOTE(review): the loss containers are presumably {logged batch step: loss} dicts — confirm
# against Results.
generator_losses: dict = results.generator_losses
discriminator_losses: dict = results.discriminator_losses
last_generator: torch.nn.Module = results.last_generator



In [6]:
import matplotlib.pyplot as plt

# Plot both loss curves against the batch step at which each loss was logged.
# NOTE(review): assumes generator_losses / discriminator_losses are {step: loss} dicts
# (as the printed generator-loss dict suggests) — confirm against Results.
fig, ax = plt.subplots()
ax.set_title("WGAN_FASHION_MNIST")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")

for losses, color, label in (
    (generator_losses, "r", "Generator"),
    (discriminator_losses, "b", "Discriminator"),
):
    # Sort by step so the line is drawn left-to-right; an empty dict yields an empty line.
    steps, values = zip(*sorted(losses.items())) if losses else ((), ())
    ax.plot(steps, values, color, label=label)

ax.legend(loc="best")
plt.show()

{1: 0.02092120796442032, 6: 0.005333012901246548, 11: -0.05641145259141922, 16: -0.2138962298631668, 21: -0.46380338072776794, 26: -0.6581395864486694, 31: -0.8725836277008057, 36: -1.0839707851409912, 41: -1.29779052734375, 46: -1.5633106231689453, 51: -1.8385038375854492, 56: -2.116753101348877, 61: -2.3850696086883545, 66: -2.7478296756744385, 71: -3.15753173828125, 76: -3.591050386428833, 81: -3.997032642364502, 86: -4.581509113311768, 91: -5.021251678466797, 96: -5.476865768432617, 101: -6.199121952056885, 106: -6.495793342590332, 111: -7.617455959320068, 116: -7.995965957641602, 121: -9.020689010620117, 126: -9.48087215423584, 131: -9.792229652404785, 136: -10.776042938232422, 141: -11.275318145751953, 146: -11.763315200805664, 151: -12.502201080322266, 156: -11.45547103881836, 161: -12.25210952758789, 166: -11.561017990112305, 171: -11.795769691467285, 176: -11.001579284667969, 181: -11.628725051879883, 186: -12.783065795898438, 191: -13.074605941772461, 196: -12.407998085021973

In [None]:
from pathlib import Path

# Sample a few images from the trained generator and save each one as results/wgan_<i>.png.
NUM_SAMPLES = 2
SAMPLE_SEED = 0
OUTPUT_DIR = Path("results")

# Put the noise on the device holding the generator's weights instead of guessing from
# torch.cuda.is_available(); a dedicated seeded RNG makes the "fixed" noise actually fixed.
first_param = next(last_generator.parameters(), None)
generator_device = first_param.device if first_param is not None else torch.device("cpu")
noise_rng = torch.Generator().manual_seed(SAMPLE_SEED)
fixed_noise: Tensor = torch.randn(
    NUM_SAMPLES, args["latent_dim"], generator=noise_rng
).to(generator_device)

# eval() switches layers such as batch-norm/dropout to inference behaviour; no_grad avoids
# building an autograd graph for a pure sampling pass.
last_generator.eval()
with torch.no_grad():
    generated_image = last_generator(fixed_noise)

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
for index, image in enumerate(generated_image, start=1):
    save_image(image, str(OUTPUT_DIR / f"wgan_{index}.png"), normalize=True)