## 0. Imports

In [1]:
import json
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm.notebook import tqdm

from dataset import LoadDataset
from model import Generator, Discriminator
from train import train_generator, train_discriminator, evaluate
from settings import DEVICE, MEAN, STD, BATCH_SIZE, LATENT_SIZE, EPOCHS, LOSS_FN, OPTIMIZER, BETAS, D_LR, D_LR_DECAY, G_LR, G_LR_DECAY, CORES, LOAD_MODELS

# Cap PyTorch's intra-op CPU parallelism at 4 threads for this kernel.
torch.set_num_threads(4)

import autoreload
%load_ext autoreload
%autoreload 2

## 1. Dataset

In [2]:
# Training split. MEAN and STD are handed straight to LoadDataset
# (presumably the normalisation statistics -- confirm in dataset.py).
ds = LoadDataset('data/train', MEAN, STD)
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)
# Ask the loader for its batch count instead of dividing by hand:
# len(ds)/BATCH_SIZE is a float and undercounts whenever the last batch is
# partial, whereas len(loader) is the integer number of iterations per epoch.
num_batches = len(loader)

# Validation split, built the same way as the training split.
val_ds = LoadDataset('data/val', MEAN, STD)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)

# Display the number of training samples as the cell output.
len(ds)

92032

## 2. Models

In [3]:
# Instantiate both networks (generator first, then discriminator).
G = Generator()
D = Discriminator()

# Show the layer layout of each network.
print(G)
print(D)

# Total element count over every parameter tensor of each network.
d_param_count = sum(p.numel() for p in D.parameters())
g_param_count = sum(p.numel() for p in G.parameters())
print(f"Num parameters in Discriminator: {d_param_count}")
print(f"Num parameters in Generator: {g_param_count}")

Generator(
  (upsampling): Sequential(
    (0): ConvTranspose2d(128, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): 

## 3. Training

In [4]:
# Per-epoch validation history; refilled from log.json when resuming.
d_losses = []
g_losses = []
r_scores = []
f_scores = []

# Decide whether this run resumes from disk. Both checkpoints are read before
# either network is modified, so a missing file can never leave D restored
# while G is still freshly initialised.
resumed = False
if LOAD_MODELS:
    try:
        # The networks still live on the CPU here, so load onto the CPU; this
        # also lets a checkpoint written on a GPU be opened on a CPU-only host.
        d_state = torch.load('discriminator.pth', map_location='cpu')
        g_state = torch.load('generator.pth', map_location='cpu')
    except FileNotFoundError:
        # Only a missing checkpoint falls back to a fresh run. Any other
        # failure (corrupt file, architecture mismatch) is raised instead of
        # being silently swallowed.
        print("Trainning new model")
    else:
        D.load_state_dict(d_state)
        G.load_state_dict(g_state)
        resumed = True
        print("Model loaded")
else:
    print("Trainning new model")

D.to(DEVICE)
G.to(DEVICE)

# Build the optimizers once the networks sit on their final device.
d_optimizer = OPTIMIZER(D.parameters(), D_LR, betas=BETAS)
g_optimizer = OPTIMIZER(G.parameters(), G_LR, betas=BETAS)

if not resumed:
    # Fresh run: start an empty log so the epoch counter matches the weights.
    with open("log.json", "w") as fh:
        json.dump({"epochs": 0}, fh)
    start_epochs = 0
else:
    with open("log.json", "r") as fh:
        log = json.load(fh)
    start_epochs = log['epochs']
    # JSON object keys are strings, hence str(e) for the per-epoch entries.
    for e in range(1, start_epochs+1):
        entry = log[str(e)]
        d_losses.append(entry['d_loss'])
        g_losses.append(entry['g_loss'])
        r_scores.append(entry['r_score'])
        f_scores.append(entry['f_score'])

    # Re-apply the linear per-epoch learning-rate decay that the training
    # loop would already have performed for the completed epochs.
    d_optimizer.param_groups[0]['lr'] -= D_LR_DECAY * start_epochs
    g_optimizer.param_groups[0]['lr'] -= G_LR_DECAY * start_epochs

    # Write a preview grid from the restored generator. No gradients are
    # needed for this forward pass, so skip building the autograd graph.
    with torch.no_grad():
        latent = torch.randn(BATCH_SIZE, *LATENT_SIZE).to(DEVICE)
        # x*0.5 + 0.5 rescales the output for saving
        # (assumes G emits values in [-1, 1] -- TODO confirm in model.py).
        fake_images = G(latent)*0.5 + 0.5
    save_image(fake_images, 'current.png')
print(f"Start at epochs {start_epochs}")

Model loaded
Start at epochs 1


In [None]:
# Create the sample-image directory up front so that save_image cannot fail
# at the end of an epoch (the logged epochs take several hours each).
os.makedirs('output', exist_ok=True)

for epochs in range(start_epochs, EPOCHS):
    start = time.time()

    # One pass over the training set: a discriminator step followed by a
    # generator step for every batch. The values returned per batch are not
    # used; the reported metrics come from evaluate() below.
    for batch_images in tqdm(loader):
        batch_images = batch_images.to(DEVICE)
        train_discriminator(G, D, batch_images, LOSS_FN, d_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)
        train_generator(G, D, LOSS_FN, g_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)

    # Metrics on the validation loader, plus a batch of generated images.
    d_loss, g_loss, r_score, f_score, fake_images = evaluate(G, D, LOSS_FN, val_loader, LATENT_SIZE, BATCH_SIZE, DEVICE)
    d_losses.append(d_loss)
    g_losses.append(g_loss)
    r_scores.append(r_score)
    f_scores.append(f_score)

    # Checkpoint both networks after every epoch.
    torch.save(G.state_dict(), 'generator.pth')
    torch.save(D.state_dict(), 'discriminator.pth')

    # Rounded copies are used for printing and logging only; the history
    # lists above keep the unrounded values.
    d_loss = round(d_loss,5)
    g_loss = round(g_loss,5)
    r_score = round(r_score,5)
    f_score = round(f_score,5)

    # Linear learning-rate decay, applied once per epoch.
    d_optimizer.param_groups[0]['lr'] -= D_LR_DECAY
    g_optimizer.param_groups[0]['lr'] -= G_LR_DECAY

    elapsed = round(time.time()-start, 2)
    print(f"Epoch #{epochs+1}: d_loss={d_loss}, g_loss={g_loss}, r_score={r_score}, f_score={f_score} elapsed={elapsed} s")

    # Read-modify-write the JSON log so a resumed run can rebuild its history.
    with open("log.json", "r") as fh:
        log = json.load(fh)
    log["epochs"] += 1
    # Use a string key: json.load returns string keys and the resume code
    # looks entries up with str(e), so an int key here could end up next to
    # an existing "N" entry and be dumped as a duplicate.
    log[str(epochs+1)] = {"d_loss": d_loss, "g_loss": g_loss, "r_score": r_score, "f_score": f_score, "elapsed": elapsed}

    with open("log.json", "w") as fh:
        json.dump(log, fh)

    save_image(fake_images, f'output/{epochs+1}.png')

  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #2: d_loss=0.14574, g_loss=5.39642, r_score=0.882, f_score=0.00624 elapsed=16830.96 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #3: d_loss=0.31925, g_loss=5.39696, r_score=0.74758, f_score=0.00563 elapsed=16730.28 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #4: d_loss=0.11097, g_loss=4.53547, r_score=0.92214, f_score=0.02386 elapsed=16637.74 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #5: d_loss=0.05343, g_loss=5.15192, r_score=0.95684, f_score=0.00772 elapsed=16648.47 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #6: d_loss=1.25529, g_loss=0.86844, r_score=0.99187, f_score=0.5579 elapsed=16633.59 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #7: d_loss=0.17687, g_loss=4.26376, r_score=0.85782, f_score=0.01514 elapsed=20220.14 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #8: d_loss=0.11768, g_loss=2.9622, r_score=0.94877, f_score=0.06111 elapsed=36104.61 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #9: d_loss=0.22829, g_loss=3.92042, r_score=0.83027, f_score=0.02208 elapsed=16373.28 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #10: d_loss=0.18367, g_loss=3.74004, r_score=0.86768, f_score=0.03178 elapsed=16395.93 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #11: d_loss=0.06465, g_loss=7.42611, r_score=0.94013, f_score=0.00174 elapsed=22015.74 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #12: d_loss=0.04757, g_loss=4.9611, r_score=0.96313, f_score=0.00961 elapsed=17094.22 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #13: d_loss=0.1957, g_loss=11.52215, r_score=0.83408, f_score=0.00013 elapsed=16666.67 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #14: d_loss=0.20518, g_loss=3.02202, r_score=0.8677, f_score=0.0547 elapsed=16590.43 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #15: d_loss=0.07689, g_loss=4.07848, r_score=0.94592, f_score=0.02045 elapsed=16601.97 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #16: d_loss=0.13673, g_loss=7.0118, r_score=0.87781, f_score=0.00117 elapsed=61700.25 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #17: d_loss=0.11142, g_loss=5.87, r_score=0.89909, f_score=0.00327 elapsed=16296.42 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #18: d_loss=0.13313, g_loss=6.15621, r_score=0.88101, f_score=0.00278 elapsed=16675.59 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #19: d_loss=0.11412, g_loss=7.42271, r_score=0.89396, f_score=0.00092 elapsed=17830.54 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #20: d_loss=0.18493, g_loss=4.3946, r_score=0.86958, f_score=0.01635 elapsed=17921.5 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #21: d_loss=0.1092, g_loss=5.15798, r_score=0.90436, f_score=0.0066 elapsed=17497.86 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #22: d_loss=0.09605, g_loss=5.57523, r_score=0.91334, f_score=0.00432 elapsed=16454.51 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #23: d_loss=0.09718, g_loss=4.89593, r_score=0.91647, f_score=0.00865 elapsed=16928.7 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #24: d_loss=0.11783, g_loss=5.31263, r_score=0.89589, f_score=0.00561 elapsed=19584.2 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #25: d_loss=0.09328, g_loss=12.53481, r_score=0.91258, f_score=0.0 elapsed=19099.09 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #26: d_loss=0.12785, g_loss=6.82708, r_score=0.88402, f_score=0.00129 elapsed=18944.99 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #27: d_loss=0.13145, g_loss=7.42654, r_score=0.97923, f_score=0.09823 elapsed=18313.39 s


  0%|          | 0/1438 [00:00<?, ?it/s]

Epoch #28: d_loss=0.07824, g_loss=3.25569, r_score=0.9644, f_score=0.04086 elapsed=19181.97 s


  0%|          | 0/1438 [00:00<?, ?it/s]