In [None]:
import torch
import torch.optim as optim
from dataloaders import get_mnist_dataloaders, get_lsun_dataloader
from models import Generator, Discriminator
from training import Trainer

# Configuration: every knob a reader might tune lives here so the run is
# reproducible and easy to modify.
seed = 42               # controls weight init, DataLoader shuffling, latent sampling
batch_size = 64
img_size = [32, 32, 1]  # last entry is the channel count (1 for MNIST, matching the
                        # Discriminator's first Conv2d(1, ...) in the printed repr).
                        # NOTE(review): spatial size is 32, so the MNIST loader presumably
                        # resizes 28x28 digits to 32x32 -- confirm in dataloaders.py
latent_dim = 100        # size of the generator's input noise vector
dim = 16                # base channel width shared by generator and discriminator

# Seed before anything random happens (model init, DataLoader shuffling) so that
# Restart & Run All reproduces the same training trajectory on the same hardware.
torch.manual_seed(seed)

data_loader, _ = get_mnist_dataloaders(batch_size=batch_size)

generator = Generator(img_size=img_size, latent_dim=latent_dim, dim=dim)
discriminator = Discriminator(img_size=img_size, dim=dim)

print(generator)
print(discriminator)

# Initialize optimizers.
# NOTE(review): the WGAN-GP paper (Gulrajani et al., 2017) uses Adam with
# lr=1e-4, betas=(0.0, 0.9); the betas below are kept as in the original run.
lr = 1e-4
betas = (.9, .99)
G_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)
D_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

# Train model. The Trainer periodically prints the critic loss (D), gradient
# penalty (GP), gradient norm and generator loss (G), so expect a long log.
epochs = 200
trainer = Trainer(generator, discriminator, G_optimizer, D_optimizer,
                  use_cuda=torch.cuda.is_available())
trainer.train(data_loader, epochs, save_training_gif=True)

# Save model weights (state_dicts only; optimizer state is not saved).
# NOTE(review): if the Trainer moved the models to CUDA, these tensors are on the
# GPU -- reload on a CPU-only machine with torch.load(path, map_location='cpu').
name = 'mnist_model'
torch.save(trainer.G.state_dict(), './gen_' + name + '.pt')
torch.save(trainer.D.state_dict(), './dis_' + name + '.pt')


Generator(
  (latent_to_features): Sequential(
    (0): Linear(in_features=100, out_features=512, bias=True)
    (1): ReLU()
  )
  (features_to_image): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
    (8): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ConvTranspose2d(16, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): Sigmoid()
  )
)
Discriminator(
  (image_to_features): Sequential(
    (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0

Iteration 51
D: -0.46091228723526
GP: 0.22784389555454254
Gradient norm: 0.9631915092468262
G: -0.24656520783901215
Iteration 101
D: -0.5449245572090149
GP: 0.13112932443618774
Gradient norm: 0.9659456014633179
G: -0.28101977705955505
Iteration 151
D: -0.4861961007118225
GP: 0.0906035453081131
Gradient norm: 0.9776062369346619
G: -0.3764142394065857
Iteration 201
D: -0.4936192035675049
GP: 0.11959351599216461
Gradient norm: 0.9540016651153564
G: -0.3213968873023987
Iteration 251
D: -0.4904814064502716
GP: 0.16620740294456482
Gradient norm: 0.9820190668106079
G: -0.27358895540237427
Iteration 301
D: -0.5459770560264587
GP: 0.10605807602405548
Gradient norm: 0.9853087663650513
G: -0.27185991406440735
Iteration 351
D: -0.5149109959602356
GP: 0.12189965695142746
Gradient norm: 0.9965879917144775
G: -0.2972521185874939
Iteration 401
D: -0.4174409806728363
GP: 0.1861962378025055
Gradient norm: 0.9527084231376648
G: -0.3398931920528412
Iteration 451
D: -0.5035995841026306
GP: 0.16803312301635

Iteration 701
D: -0.578984260559082
GP: 0.15073874592781067
Gradient norm: 0.9782195091247559
G: -0.25260084867477417
Iteration 751
D: -0.6231685280799866
GP: 0.07808607071638107
Gradient norm: 0.9929487705230713
G: -0.2366362363100052
Iteration 801
D: -0.5936532020568848
GP: 0.10331276804208755
Gradient norm: 0.9803120493888855
G: -0.2515721917152405
Iteration 851
D: -0.642730176448822
GP: 0.06384813040494919
Gradient norm: 0.9836370944976807
G: -0.23518916964530945
Iteration 901
D: -0.6157644987106323
GP: 0.07101911306381226
Gradient norm: 0.9829713106155396
G: -0.24559181928634644

Epoch 8
Iteration 1
D: -0.6185596585273743
GP: 0.08245342969894409
Gradient norm: 0.987671434879303
G: -0.234270378947258
Iteration 51
D: -0.6342903971672058
GP: 0.0867370218038559
Gradient norm: 0.9761833548545837
G: -0.18264903128147125
Iteration 101
D: -0.6388489603996277
GP: 0.08525640517473221
Gradient norm: 0.9787094593048096
G: -0.2290719598531723
Iteration 151
D: -0.6197888851165771
GP: 0.09858179

Iteration 401
D: -0.5182872414588928
GP: 0.15178775787353516
Gradient norm: 0.9907987117767334
G: -0.20868128538131714
Iteration 451
D: -0.5255070924758911
GP: 0.1607261747121811
Gradient norm: 0.9916197061538696
G: -0.19312241673469543
Iteration 501
D: -0.5482699275016785
GP: 0.10178733617067337
Gradient norm: 0.967477560043335
G: -0.21647609770298004
Iteration 551
D: -0.6119486093521118
GP: 0.06988351047039032
Gradient norm: 0.9830538630485535
G: -0.21354267001152039
Iteration 601
D: -0.5392169952392578
GP: 0.1606328785419464
Gradient norm: 0.9798941612243652
G: -0.19601164758205414
Iteration 651
D: -0.5547767877578735
GP: 0.09619452059268951
Gradient norm: 0.9814688563346863
G: -0.2262158840894699
Iteration 701
D: -0.49543529748916626
GP: 0.12983514368534088
Gradient norm: 0.9999894499778748
G: -0.23693573474884033
Iteration 751
D: -0.4677678644657135
GP: 0.160396009683609
Gradient norm: 1.022005319595337
G: -0.21485485136508942
Iteration 801
D: -0.5610417723655701
GP: 0.09453587979

Iteration 51
D: -0.49245619773864746
GP: 0.09011518210172653
Gradient norm: 0.9874296188354492
G: -0.2677153944969177
Iteration 101
D: -0.46909600496292114
GP: 0.11035440862178802
Gradient norm: 0.9934139251708984
G: -0.25860458612442017
Iteration 151
D: -0.4255180358886719
GP: 0.09997950494289398
Gradient norm: 0.9934889078140259
G: -0.2847839295864105
Iteration 201
D: -0.35319340229034424
GP: 0.19968809187412262
Gradient norm: 0.9735634326934814
G: -0.27792614698410034
Iteration 251
D: -0.43004775047302246
GP: 0.12465334683656693
Gradient norm: 0.9872424006462097
G: -0.24570894241333008
Iteration 301
D: -0.4647567868232727
GP: 0.10763149708509445
Gradient norm: 0.9909951686859131
G: -0.23846633732318878
Iteration 351
D: -0.48731017112731934
GP: 0.06818687915802002
Gradient norm: 0.9881497025489807
G: -0.24409231543540955
Iteration 401
D: -0.3771359920501709
GP: 0.1326991468667984
Gradient norm: 0.9795471429824829
G: -0.2721167206764221
Iteration 451
D: -0.5477291941642761
GP: 0.05822

Iteration 651
D: -0.39770326018333435
GP: 0.1267571747303009
Gradient norm: 1.0292088985443115
G: -0.29377835988998413
Iteration 701
D: -0.4108220934867859
GP: 0.1346387267112732
Gradient norm: 0.9927284717559814
G: -0.2683728337287903
Iteration 751
D: -0.445003867149353
GP: 0.08836120367050171
Gradient norm: 0.9839912056922913
G: -0.2580106854438782
Iteration 801
D: -0.39857715368270874
GP: 0.10895084589719772
Gradient norm: 0.9909239411354065
G: -0.289470374584198
Iteration 851
D: -0.32268983125686646
GP: 0.10990212857723236
Gradient norm: 0.9995492100715637
G: -0.30696433782577515
Iteration 901
D: -0.24709895253181458
GP: 0.13778600096702576
Gradient norm: 0.9992586970329285
G: -0.3617578148841858

Epoch 19
Iteration 1
D: -0.3490709662437439
GP: 0.09901802986860275
Gradient norm: 0.9923567771911621
G: -0.3762279748916626
Iteration 51
D: -0.35976552963256836
GP: 0.08665986359119415
Gradient norm: 0.9814763069152832
G: -0.35038167238235474
Iteration 101
D: -0.29936569929122925
GP: 0.1

Iteration 301
D: -0.6271286606788635
GP: 0.04709985479712486
Gradient norm: 1.006144404411316
G: -0.18522030115127563
Iteration 351
D: -0.5349909067153931
GP: 0.11777064204216003
Gradient norm: 1.0023088455200195
G: -0.1632196009159088
Iteration 401
D: -0.5976855754852295
GP: 0.06543231010437012
Gradient norm: 0.9765343070030212
G: -0.17959919571876526
Iteration 451
D: -0.470124751329422
GP: 0.16995736956596375
Gradient norm: 0.9744197130203247
G: -0.17076340317726135
Iteration 501
D: -0.4370311498641968
GP: 0.2297501564025879
Gradient norm: 0.9663738012313843
G: -0.1840614676475525
Iteration 551
D: -0.5397310853004456
GP: 0.10084227472543716
Gradient norm: 0.9966753721237183
G: -0.19115284085273743
Iteration 601
D: -0.5837901830673218
GP: 0.08365274965763092
Gradient norm: 0.9783172607421875
G: -0.19442319869995117
Iteration 651
D: -0.5878666043281555
GP: 0.09453821927309036
Gradient norm: 1.0010547637939453
G: -0.12544818222522736
Iteration 701
D: -0.5029377341270447
GP: 0.1986621618


Epoch 26
Iteration 1
D: -0.5590047836303711
GP: 0.07414397597312927
Gradient norm: 1.0162395238876343
G: -0.21745219826698303
Iteration 51
D: -0.4585934281349182
GP: 0.06469233334064484
Gradient norm: 1.004767894744873
G: -0.19606931507587433
Iteration 101
D: -0.5922865867614746
GP: 0.06546764820814133
Gradient norm: 0.9910632967948914
G: -0.13275645673274994
Iteration 151
D: -0.49547579884529114
GP: 0.05852697044610977
Gradient norm: 0.9949573278427124
G: -0.13722041249275208
Iteration 201
D: -0.607144832611084
GP: 0.04135992377996445
Gradient norm: 0.9870659112930298
G: -0.16026221215724945
Iteration 251
D: -0.4462604820728302
GP: 0.18496069312095642
Gradient norm: 0.9761145114898682
G: -0.16900509595870972
Iteration 301
D: -0.5408263206481934
GP: 0.08465752005577087
Gradient norm: 0.994654655456543
G: -0.1725156009197235
Iteration 351
D: -0.5249221324920654
GP: 0.13166628777980804
Gradient norm: 0.9765559434890747
G: -0.1834334433078766
Iteration 401
D: -0.5509406328201294
GP: 0.09

Iteration 601
D: -0.6753531098365784
GP: 0.05289150029420853
Gradient norm: 0.9685598611831665
G: -0.10179758816957474
Iteration 651
D: -0.6569645404815674
GP: 0.05059681832790375
Gradient norm: 1.0043922662734985
G: -0.11966053396463394
Iteration 701
D: -0.6889615058898926
GP: 0.04546763375401497
Gradient norm: 0.9705058336257935
G: -0.1138979122042656
Iteration 751
D: -0.6210055351257324
GP: 0.05198679864406586
Gradient norm: 0.9906966686248779
G: -0.1352582573890686
Iteration 801
D: -0.6356440186500549
GP: 0.14311403036117554
Gradient norm: 0.9506105184555054
G: -0.11474782228469849
Iteration 851
D: -0.648846447467804
GP: 0.08781552314758301
Gradient norm: 1.0149503946304321
G: -0.13228581845760345
Iteration 901
D: -0.6299367547035217
GP: 0.07665179669857025
Gradient norm: 1.0103572607040405
G: -0.10970078408718109

Epoch 30
Iteration 1
D: -0.6619446873664856
GP: 0.08130960166454315
Gradient norm: 0.9799031019210815
G: -0.11602427065372467
Iteration 51
D: -0.5625734925270081
GP: 0.1

Iteration 251
D: -0.5616677403450012
GP: 0.0702548399567604
Gradient norm: 1.0346964597702026
G: -0.14066985249519348
Iteration 301
D: -0.5966842770576477
GP: 0.041062891483306885
Gradient norm: 1.0152549743652344
G: -0.18586933612823486
Iteration 351
D: -0.6696077585220337
GP: 0.03851722180843353
Gradient norm: 1.008086085319519
G: -0.155299574136734
Iteration 401
D: -0.6729373931884766
GP: 0.06714960932731628
Gradient norm: 1.0047831535339355
G: -0.11351944506168365
Iteration 451
D: -0.741543173789978
GP: 0.02838870882987976
Gradient norm: 1.0051066875457764
G: -0.09657926112413406
Iteration 501
D: -0.7751218676567078
GP: 0.03227967396378517
Gradient norm: 1.013367772102356
G: -0.10460555553436279
Iteration 551
D: -0.6849365234375
GP: 0.14825360476970673
Gradient norm: 0.9732481241226196
G: -0.09165562689304352
Iteration 601
D: -0.7257167100906372
GP: 0.04216282069683075
Gradient norm: 1.0108778476715088
G: -0.09862074255943298
Iteration 651
D: -0.59677654504776
GP: 0.148747026920318

In [6]:
# NOTE(review): leftover kernel sanity-check cell; consider deleting it before sharing.
greeting = "hi"
print(greeting)

hi
