In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
import torch
import random
import numpy as np

# Seed every RNG library imported above so runs are reproducible.
# For a fresh, non-reproducible run use: seed = random.randint(1, 10000)
seed = 1
random.seed(seed)
# numpy is imported above and was previously left unseeded.
np.random.seed(seed)
# Trailing `;` suppresses the returned Generator repr in the cell output.
torch.manual_seed(seed);

<torch._C.Generator at 0x7fcaa03ff6f0>

## Model configurations

In [2]:
# ---- Data -------------------------------------------------------------
IMG_DIR = "../training_images/"
BATCH_SIZE, DLOADER_WORKERS = 8, 0

# ---- Model ------------------------------------------------------------
# Final image size is 2**DEPTH (64 for DEPTH = 6).
DEPTH = 6
# Size of the generator's input latent space.
LATENT_SIZE = 512

# ---- Optimisers (Adam) -------------------------------------------------
# Generator and discriminator get different learning rates; betas are shared.
G_LR, D_LR = 1e-4, 3e-4
BETA1, BETA2 = 0, 0.999

# Free PyTorch's cached-but-unused CUDA memory (useful when re-running
# cells in the same kernel).
torch.cuda.empty_cache()

## Setting up data and model

In [3]:
from dataloader import create_dataloader

# Build the training-image loader. DEPTH is forwarded too
# (presumably to size images to 2**DEPTH — confirm in dataloader.py).
dataloader = create_dataloader(
    IMG_DIR,
    BATCH_SIZE,
    DLOADER_WORKERS,
    DEPTH,
)

In [4]:
from MSG_GAN import MSG_GAN
from loss import StandardGAN  # used by the training cell further down

model = MSG_GAN(DEPTH, LATENT_SIZE)

# One Adam optimiser per network: separate learning rates, same betas.
g_optim = torch.optim.Adam(
    model.gen.parameters(), lr=G_LR, betas=[BETA1, BETA2]
)
d_optim = torch.optim.Adam(
    model.dis.parameters(), lr=D_LR, betas=[BETA1, BETA2]
)

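**Resume from a saved checkpoint**

Set `START_EPOCH` to the epoch of a saved checkpoint. The cells below load the generator, discriminator, both optimisers and the shadow generator from `{START_EPOCH}_save.pt`. Training then continues from epoch `START_EPOCH + 1`.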

In [5]:
# START_EPOCH picks the checkpoint file `{START_EPOCH}_save.pt` to resume from;
# the training call starts at START_EPOCH + 1 and passes num_epochs=TOTAL_EPOCH.
START_EPOCH, TOTAL_EPOCH = 0, 200

In [19]:
# Resume from `{START_EPOCH}_save.pt` when it exists; otherwise keep the freshly
# initialised model and optimisers so Restart & Run All does not crash.
checkpoint_path = f"{START_EPOCH}_save.pt"
if os.path.exists(checkpoint_path):
    # NOTE(review): torch.load unpickles the file — only load checkpoints you created.
    # No map_location is given, so tensors go back to the device they were saved from.
    checkpoint = torch.load(checkpoint_path)
    model.gen.load_state_dict(checkpoint["gen"])
    model.dis.load_state_dict(checkpoint["dis"])
    g_optim.load_state_dict(checkpoint["gen_optim"])
    d_optim.load_state_dict(checkpoint["dis_optim"])
    model.gen_shadow.load_state_dict(checkpoint["gen_shadow"])
    print(f"Resumed from {checkpoint_path}")
else:
    print(f"No checkpoint found at {checkpoint_path}; using freshly initialised weights.")

<All keys matched successfully>

In [6]:
# NOTE(review): real_label=0.9 looks like one-sided label smoothing for real
# samples — confirm in loss.StandardGAN.
loss = StandardGAN(model.dis, real_label=0.9)

# Continue from the epoch after START_EPOCH. checkpoint_factor / feedback_factor
# are forwarded to MSG_GAN.train (the output log prints a progress line every 40 batches).
model.train(
    dataloader,
    g_optim,
    d_optim,
    loss,
    start=START_EPOCH + 1,
    num_epochs=TOTAL_EPOCH,
    checkpoint_factor=10,
    feedback_factor=40,
)


Epoch: 1
Elapsed [0:00:19.595205] batch: 40  d_loss: 0.178612  g_loss: 3.818853
Elapsed [0:00:39.407569] batch: 80  d_loss: 0.298625  g_loss: 3.387247
Elapsed [0:00:57.112386] batch: 120  d_loss: 0.205006  g_loss: 3.442460
Elapsed [0:01:16.354863] batch: 160  d_loss: 0.191966  g_loss: 3.814625
Elapsed [0:01:35.355829] batch: 200  d_loss: 0.172132  g_loss: 4.739140
Elapsed [0:01:52.868155] batch: 240  d_loss: 0.192985  g_loss: 4.728215
Elapsed [0:02:10.616658] batch: 280  d_loss: 0.168022  g_loss: 4.488097
Elapsed [0:02:29.348290] batch: 320  d_loss: 0.204518  g_loss: 4.982589
Elapsed [0:02:48.044660] batch: 360  d_loss: 0.163147  g_loss: 7.086243
Time taken for epoch: 179.969 secs

Epoch: 2
Elapsed [0:03:17.908164] batch: 40  d_loss: 0.181761  g_loss: 4.062128
Elapsed [0:03:36.516981] batch: 80  d_loss: 0.174177  g_loss: 4.045695
Elapsed [0:03:55.032966] batch: 120  d_loss: 0.165330  g_loss: 5.662098
Elapsed [0:04:14.302273] batch: 160  d_loss: 0.166445  g_loss: 6.174897
Elapsed [0:04

Time taken for epoch: 176.345 secs

Epoch: 13
Elapsed [0:35:53.715354] batch: 40  d_loss: 0.162797  g_loss: 7.676990
Elapsed [0:36:11.972507] batch: 80  d_loss: 0.165485  g_loss: 7.296337
Elapsed [0:36:30.308765] batch: 120  d_loss: 0.183822  g_loss: 4.741402
Elapsed [0:36:48.264463] batch: 160  d_loss: 0.163877  g_loss: 5.861598
Elapsed [0:37:07.019266] batch: 200  d_loss: 0.182817  g_loss: 4.316624
Elapsed [0:37:24.731441] batch: 240  d_loss: 0.167585  g_loss: 6.419403
Elapsed [0:37:42.361466] batch: 280  d_loss: 0.169408  g_loss: 6.461421
Elapsed [0:38:00.950055] batch: 320  d_loss: 0.171612  g_loss: 7.024603
Elapsed [0:38:18.742972] batch: 360  d_loss: 0.171033  g_loss: 8.216120
Time taken for epoch: 176.539 secs

Epoch: 14
Elapsed [0:38:49.615621] batch: 40  d_loss: 0.166928  g_loss: 5.463449
Elapsed [0:39:08.481566] batch: 80  d_loss: 0.163448  g_loss: 6.128669
Elapsed [0:39:26.173232] batch: 120  d_loss: 0.164568  g_loss: 6.133421
Elapsed [0:39:44.507865] batch: 160  d_loss: 0.1

Elapsed [1:10:37.378432] batch: 360  d_loss: 0.165129  g_loss: 6.688696
Time taken for epoch: 175.399 secs

Epoch: 25
Elapsed [1:11:07.048115] batch: 40  d_loss: 0.162712  g_loss: 7.631296
Elapsed [1:11:25.442029] batch: 80  d_loss: 0.175850  g_loss: 3.868346
Elapsed [1:11:42.798753] batch: 120  d_loss: 0.165451  g_loss: 5.828510
Elapsed [1:12:01.033247] batch: 160  d_loss: 0.177265  g_loss: 3.640032
Elapsed [1:12:19.298242] batch: 200  d_loss: 0.169161  g_loss: 5.257630
Elapsed [1:12:38.065598] batch: 240  d_loss: 0.165456  g_loss: 4.858663
Elapsed [1:12:56.222813] batch: 280  d_loss: 0.163527  g_loss: 8.455005
Elapsed [1:13:14.410621] batch: 320  d_loss: 0.179534  g_loss: 4.260782
Elapsed [1:13:32.855277] batch: 360  d_loss: 0.169455  g_loss: 5.770987
Time taken for epoch: 175.945 secs

Epoch: 26
Elapsed [1:14:03.370144] batch: 40  d_loss: 0.196267  g_loss: 5.055945
Elapsed [1:14:21.688664] batch: 80  d_loss: 0.168362  g_loss: 4.554281
Elapsed [1:14:40.631537] batch: 120  d_loss: 0.1

Elapsed [1:46:51.551844] batch: 320  d_loss: 0.169373  g_loss: 5.090394
Elapsed [1:47:10.487081] batch: 360  d_loss: 0.162780  g_loss: 6.912550
Time taken for epoch: 183.218 secs

Epoch: 37
Elapsed [1:47:43.524505] batch: 40  d_loss: 0.178771  g_loss: 4.127937
Elapsed [1:48:01.716327] batch: 80  d_loss: 0.172513  g_loss: 5.249418
Elapsed [1:48:21.025154] batch: 120  d_loss: 0.165748  g_loss: 4.832312
Elapsed [1:48:39.452142] batch: 160  d_loss: 0.208927  g_loss: 3.696122
Elapsed [1:48:57.542092] batch: 200  d_loss: 0.166981  g_loss: 4.601999
Elapsed [1:49:16.534099] batch: 240  d_loss: 0.164717  g_loss: 5.037337
Elapsed [1:49:34.925991] batch: 280  d_loss: 0.163720  g_loss: 6.656817
Elapsed [1:49:52.533510] batch: 320  d_loss: 0.236678  g_loss: 3.729093
Elapsed [1:50:10.236388] batch: 360  d_loss: 0.165465  g_loss: 4.876777
Time taken for epoch: 178.242 secs

Epoch: 38
Elapsed [1:50:40.497849] batch: 40  d_loss: 0.167457  g_loss: 4.745253
Elapsed [1:50:58.552006] batch: 80  d_loss: 0.1

Elapsed [2:21:54.200478] batch: 280  d_loss: 0.179562  g_loss: 3.724975
Elapsed [2:22:12.407834] batch: 320  d_loss: 0.167207  g_loss: 4.704003
Elapsed [2:22:30.868299] batch: 360  d_loss: 0.165103  g_loss: 6.068410
Time taken for epoch: 176.495 secs

Epoch: 49
Elapsed [2:23:01.205429] batch: 40  d_loss: 0.166463  g_loss: 5.356653
Elapsed [2:23:19.839576] batch: 80  d_loss: 0.163506  g_loss: 8.344935
Elapsed [2:23:38.044491] batch: 120  d_loss: 0.167701  g_loss: 4.907684
Elapsed [2:23:55.424609] batch: 160  d_loss: 0.164004  g_loss: 5.528056
Elapsed [2:24:13.257064] batch: 200  d_loss: 0.163340  g_loss: 6.137459
Elapsed [2:24:31.663581] batch: 240  d_loss: 0.183387  g_loss: 6.500065
Elapsed [2:24:50.156848] batch: 280  d_loss: 0.163546  g_loss: 5.653981
Elapsed [2:25:08.839988] batch: 320  d_loss: 0.163641  g_loss: 5.962052
Elapsed [2:25:26.264488] batch: 360  d_loss: 0.163171  g_loss: 6.231571
Time taken for epoch: 175.879 secs

Epoch: 50
Elapsed [2:25:57.730338] batch: 40  d_loss: 0.

Elapsed [2:56:46.747296] batch: 240  d_loss: 0.233542  g_loss: 3.166710
Elapsed [2:57:04.615453] batch: 280  d_loss: 0.164116  g_loss: 5.263212
Elapsed [2:57:22.489645] batch: 320  d_loss: 0.164101  g_loss: 5.731289
Elapsed [2:57:40.454905] batch: 360  d_loss: 0.165701  g_loss: 4.686019
Time taken for epoch: 175.121 secs

Epoch: 61
Elapsed [2:58:15.178889] batch: 40  d_loss: 0.169819  g_loss: 6.556722
Elapsed [2:58:33.341947] batch: 80  d_loss: 0.201070  g_loss: 5.598119
Elapsed [2:58:52.021929] batch: 120  d_loss: 0.192641  g_loss: 3.629164
Elapsed [2:59:10.102599] batch: 160  d_loss: 0.197505  g_loss: 3.989800
Elapsed [2:59:28.253284] batch: 200  d_loss: 0.169635  g_loss: 5.509024
Elapsed [2:59:46.333289] batch: 240  d_loss: 0.163211  g_loss: 7.211070
Elapsed [3:00:04.623969] batch: 280  d_loss: 0.163403  g_loss: 6.906624
Elapsed [3:00:22.542720] batch: 320  d_loss: 0.167490  g_loss: 7.091652
Elapsed [3:00:40.567391] batch: 360  d_loss: 0.165355  g_loss: 5.558391
Time taken for epoch

Elapsed [3:31:32.420704] batch: 200  d_loss: 0.163163  g_loss: 6.178625
Elapsed [3:31:50.475881] batch: 240  d_loss: 0.201301  g_loss: 5.836020
Elapsed [3:32:09.877189] batch: 280  d_loss: 0.186203  g_loss: 3.385997
Elapsed [3:32:28.230055] batch: 320  d_loss: 0.167802  g_loss: 4.813329
Elapsed [3:32:46.106406] batch: 360  d_loss: 0.197114  g_loss: 3.994420
Time taken for epoch: 175.059 secs

Epoch: 73
Elapsed [3:33:16.426260] batch: 40  d_loss: 0.178640  g_loss: 4.034286
Elapsed [3:33:34.239053] batch: 80  d_loss: 0.164269  g_loss: 5.730256
Elapsed [3:33:52.236186] batch: 120  d_loss: 0.164031  g_loss: 5.778656
Elapsed [3:34:10.414613] batch: 160  d_loss: 0.164073  g_loss: 5.604185
Elapsed [3:34:29.183065] batch: 200  d_loss: 0.168527  g_loss: 8.590828
Elapsed [3:34:47.802456] batch: 240  d_loss: 0.162733  g_loss: 7.995696
Elapsed [3:35:05.913580] batch: 280  d_loss: 0.171855  g_loss: 4.442997
Elapsed [3:35:23.462626] batch: 320  d_loss: 0.165977  g_loss: 4.633571
Elapsed [3:35:41.868

KeyboardInterrupt: 