## 0. Imports

In [8]:
import json
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm.notebook import tqdm

from dataset import LoadDataset
from MSGModel import Generator, Discriminator, DownSampleImage
from trainMSG import train_generator, train_discriminator, evaluate
from settings import DEVICE, MEAN, STD, BATCH_SIZE, LATENT_SIZE, EPOCHS, TRAINNING_DS_SIZE, RELOAD_DS_PER_EPOCH, LOSS_FN, OPTIMIZER, BETAS, D_LR, D_LR_DECAY, G_LR, G_LR_DECAY, CORES, LOAD_MODELS

# torch.set_num_threads(4)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 1. Dataset

In [9]:
# Training data, shuffled every epoch.
# `length=TRAINNING_DS_SIZE` presumably caps how many images LoadDataset exposes — TODO confirm in dataset.py.
ds = LoadDataset('../data/train', MEAN, STD, length=TRAINNING_DS_SIZE)
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)

# Validation data is only evaluated, never fitted, so shuffling buys nothing;
# a fixed order removes one source of run-to-run variation in the per-epoch metrics.
val_ds = LoadDataset('../data/val', MEAN, STD)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=CORES)

print(f"Dataset size: {len(ds)}, loader size: {len(loader)}")
print(f"Validation dataset size: {len(val_ds)}, validation loader size: {len(val_loader)}")

Dataset size: 92032, loader size: 5752
Validation dataset size: 192, validation loader size: 12


## 2. Models

In [10]:
# Build the networks in the same order as before (G, then D, then the downsampler),
# so the random weight initialisation consumes the RNG identically.
G = Generator()
D = Discriminator()
# NOTE(review): downsampler is never moved to DEVICE in this notebook; that is only
# safe if DownSampleImage holds no parameters/buffers — confirm in MSGModel.py.
downsampler = DownSampleImage()

# Report the trainable-size of each network (discriminator first, then generator).
for net_label, net in (("Discriminator", D), ("Generator", G)):
    n_params = sum(p.numel() for p in net.parameters())
    print(f"Num parameters in {net_label}: {n_params}")

Num parameters in Discriminator: 23137153
Num parameters in Generator: 23053781


## 3. Training

In [11]:
# Per-epoch validation history; restored from log.json when resuming a run.
d_losses = []
g_losses = []
r_scores = []
f_scores = []

# Must match the paths the training loop saves to. Loading from the root-level
# 'discriminator.pth'/'generator.pth' meant a resume never found its own checkpoints.
D_CHECKPOINT = 'models/discriminator.pth'
G_CHECKPOINT = 'models/generator.pth'

models_loaded = False
if LOAD_MODELS:
    try:
        # Read both files before touching either model, so a missing generator
        # checkpoint cannot leave D loaded while G stays freshly initialised.
        # map_location lets a checkpoint written on one device be restored on another.
        d_state = torch.load(D_CHECKPOINT, map_location=DEVICE)
        g_state = torch.load(G_CHECKPOINT, map_location=DEVICE)
        D.load_state_dict(d_state)
        G.load_state_dict(g_state)
        models_loaded = True
        print("Model loaded")
    except FileNotFoundError as err:
        # Only a missing checkpoint falls back to fresh training. Other errors (e.g. a
        # state_dict that no longer matches the architecture) propagate rather than
        # silently continuing with a partially loaded model.
        print(f"Checkpoint not found: {err.filename}")
if not models_loaded:
    print("Trainning new model")

# Move the networks first, then build their optimizers (PyTorch's recommended order).
D.to(DEVICE)
G.to(DEVICE)
d_optimizer = OPTIMIZER(D.parameters(), D_LR, betas=BETAS)
g_optimizer = OPTIMIZER(G.parameters(), G_LR, betas=BETAS)

if not models_loaded:
    # Fresh run: reset the epoch log so entries from an earlier run are not resumed
    # against newly initialised weights.
    with open("log.json", "w") as fh:
        json.dump({"epochs": 0}, fh)
    start_epochs = 0
else:
    with open("log.json", "r") as fh:
        log = json.load(fh)
    start_epochs = log['epochs']
    # JSON object keys are strings, so per-epoch entries are looked up by str(epoch).
    for e in range(1, start_epochs + 1):
        entry = log[str(e)]
        d_losses.append(entry['d_loss'])
        g_losses.append(entry['g_loss'])
        r_scores.append(entry['r_score'])
        f_scores.append(entry['f_score'])

    # Replay the linear per-epoch LR decay for the epochs already completed, clamped at
    # zero so a long run never ends up with a negative learning rate.
    # NOTE(review): optimizer state is not checkpointed, so it restarts from scratch on resume.
    for group in d_optimizer.param_groups:
        group['lr'] = max(0.0, group['lr'] - D_LR_DECAY * start_epochs)
    for group in g_optimizer.param_groups:
        group['lr'] = max(0.0, group['lr'] - G_LR_DECAY * start_epochs)

    # Save a preview from the restored generator; no_grad avoids building an autograd graph.
    # `* 0.5 + 0.5` presumably maps [-1, 1] outputs to [0, 1] — TODO confirm the generator's output range.
    with torch.no_grad():
        latent = torch.randn(BATCH_SIZE, *LATENT_SIZE).to(DEVICE)
        fake_images = G(latent)[0] * 0.5 + 0.5
    save_image(fake_images, 'current.png')

print(f"Start at epochs {start_epochs}")

Trainning new model
Start at epochs 0


In [None]:
# Checkpoints and sample images are written into these directories; create them up
# front so a missing folder cannot crash the run only after a full epoch of training.
os.makedirs('models', exist_ok=True)
os.makedirs('output', exist_ok=True)

for epochs in range(start_epochs, EPOCHS):
    start = time.time()
    # Defensive: NOTE(review) evaluate() may leave the networks in eval() mode — confirm in
    # trainMSG.py. Calling train() is a no-op when they are already in training mode.
    G.train()
    D.train()

    # Optionally rebuild the training set every epoch.
    if RELOAD_DS_PER_EPOCH:
        ds = LoadDataset('../data/train', MEAN, STD, length=TRAINNING_DS_SIZE)
        loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)

    # One discriminator step followed by one generator step per batch.
    for batch_images in tqdm(loader):
        batch_images = batch_images.to(DEVICE)
        d_loss = train_discriminator(G, D, downsampler, batch_images, LOSS_FN, d_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)
        g_loss = train_generator(G, D, LOSS_FN, g_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)

    # The recorded metrics come from evaluate() on the validation loader; they
    # overwrite the last batch's training losses.
    d_loss, g_loss, r_score, f_score, fake_images = evaluate(G, D, downsampler, LOSS_FN, val_loader, LATENT_SIZE, BATCH_SIZE, DEVICE)
    d_losses.append(d_loss)
    g_losses.append(g_loss)
    r_scores.append(r_score)
    f_scores.append(f_score)

    # Write to a temp file, then atomically swap it in, so an interrupt mid-save
    # cannot leave a truncated checkpoint behind.
    torch.save(G.state_dict(), 'models/generator.pth.tmp')
    os.replace('models/generator.pth.tmp', 'models/generator.pth')
    torch.save(D.state_dict(), 'models/discriminator.pth.tmp')
    os.replace('models/discriminator.pth.tmp', 'models/discriminator.pth')

    d_loss = round(d_loss, 5)
    g_loss = round(g_loss, 5)
    r_score = round(r_score, 5)
    f_score = round(f_score, 5)

    # Linear per-epoch LR decay, clamped at zero so the learning rate never turns negative.
    for group in d_optimizer.param_groups:
        group['lr'] = max(0.0, group['lr'] - D_LR_DECAY)
    for group in g_optimizer.param_groups:
        group['lr'] = max(0.0, group['lr'] - G_LR_DECAY)

    elapsed = round(time.time() - start, 2)
    print(f"Epoch #{epochs+1}: d_loss={d_loss}, g_loss={g_loss}, r_score={r_score}, f_score={f_score} elapsed={elapsed} s")

    with open("log.json", "r") as fh:
        log = json.load(fh)
    # Store the absolute epoch count and a string key. JSON object keys are strings,
    # and a resume reads the entries back with log[str(epoch)].
    log["epochs"] = epochs + 1
    log[str(epochs + 1)] = {"d_loss": d_loss, "g_loss": g_loss, "r_score": r_score, "f_score": f_score, "elapsed": elapsed}

    # Temp file + atomic replace, so an interrupted write cannot corrupt the log.
    with open("log.json.tmp", "w") as fh:
        json.dump(log, fh)
    os.replace("log.json.tmp", "log.json")

    # NOTE(review): unlike the preview in the setup cell, these samples are saved without
    # `* 0.5 + 0.5`; confirm evaluate() already returns images in [0, 1].
    save_image(fake_images, f'output/{epochs+1}.png')

  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #1: d_loss=0.90493, g_loss=1.23562, r_score=0.61404, f_score=0.30757 elapsed=4917.68 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #2: d_loss=0.52996, g_loss=2.26103, r_score=0.73851, f_score=0.16041 elapsed=4723.23 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #3: d_loss=0.51714, g_loss=2.61397, r_score=0.70404, f_score=0.10284 elapsed=4720.6 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #4: d_loss=0.81693, g_loss=1.22355, r_score=0.76473, f_score=0.3524 elapsed=4719.83 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #5: d_loss=1.12531, g_loss=0.88256, r_score=0.85052, f_score=0.55257 elapsed=4719.69 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #6: d_loss=0.4733, g_loss=4.38477, r_score=0.67755, f_score=0.02379 elapsed=4719.49 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #7: d_loss=0.41805, g_loss=5.38573, r_score=0.71475, f_score=0.00938 elapsed=4718.9 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #8: d_loss=0.43242, g_loss=1.82661, r_score=0.82836, f_score=0.20567 elapsed=4721.09 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #9: d_loss=0.64293, g_loss=2.65909, r_score=0.62707, f_score=0.12225 elapsed=4717.18 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #10: d_loss=0.57936, g_loss=4.22395, r_score=0.59272, f_score=0.03045 elapsed=4717.12 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #11: d_loss=0.44601, g_loss=2.36255, r_score=0.87635, f_score=0.20891 elapsed=4718.82 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #12: d_loss=0.37311, g_loss=2.66895, r_score=0.84434, f_score=0.13798 elapsed=4721.98 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #13: d_loss=0.15797, g_loss=3.89858, r_score=0.88927, f_score=0.0379 elapsed=4721.88 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #14: d_loss=0.48356, g_loss=4.23668, r_score=0.70253, f_score=0.03775 elapsed=4720.91 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #15: d_loss=0.2726, g_loss=4.56632, r_score=0.83431, f_score=0.07264 elapsed=4721.3 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #16: d_loss=0.10636, g_loss=7.8413, r_score=0.90049, f_score=0.00087 elapsed=4721.1 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #17: d_loss=0.15286, g_loss=7.15363, r_score=0.87086, f_score=0.00573 elapsed=4721.4 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #18: d_loss=0.22832, g_loss=3.70251, r_score=0.87577, f_score=0.08498 elapsed=4721.01 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #19: d_loss=0.18037, g_loss=4.31056, r_score=0.8862, f_score=0.05304 elapsed=4721.69 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #20: d_loss=0.45746, g_loss=5.78347, r_score=0.7395, f_score=0.02212 elapsed=4721.08 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #21: d_loss=0.31051, g_loss=7.83662, r_score=0.78708, f_score=0.00641 elapsed=4722.39 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #22: d_loss=0.19381, g_loss=3.86882, r_score=0.92045, f_score=0.09039 elapsed=4720.86 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #23: d_loss=0.13119, g_loss=6.27133, r_score=0.88281, f_score=0.00594 elapsed=4721.37 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #24: d_loss=0.53031, g_loss=2.44035, r_score=0.70831, f_score=0.1508 elapsed=4721.24 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #25: d_loss=0.25095, g_loss=6.34129, r_score=0.7983, f_score=0.01541 elapsed=4721.17 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #26: d_loss=0.25896, g_loss=9.00749, r_score=0.81259, f_score=0.00283 elapsed=4722.74 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #27: d_loss=0.12044, g_loss=8.19923, r_score=0.88832, f_score=0.00131 elapsed=4720.53 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #28: d_loss=0.14995, g_loss=5.66784, r_score=0.88807, f_score=0.02879 elapsed=4719.8 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #29: d_loss=0.10462, g_loss=6.95501, r_score=0.90525, f_score=0.00452 elapsed=4718.96 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #30: d_loss=0.10636, g_loss=8.76734, r_score=0.90026, f_score=0.00084 elapsed=4720.72 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #31: d_loss=0.18355, g_loss=12.15923, r_score=0.84293, f_score=2e-05 elapsed=4721.75 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #32: d_loss=0.11681, g_loss=7.24873, r_score=0.89437, f_score=0.00469 elapsed=4721.33 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #33: d_loss=0.16612, g_loss=13.71993, r_score=0.85082, f_score=0.0 elapsed=4721.03 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #34: d_loss=0.15211, g_loss=7.94121, r_score=0.8617, f_score=0.00119 elapsed=4720.96 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #35: d_loss=0.13, g_loss=6.06861, r_score=0.88469, f_score=0.00686 elapsed=4721.11 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #36: d_loss=0.11973, g_loss=8.07066, r_score=0.89048, f_score=0.00299 elapsed=4721.48 s


  0%|          | 0/5752 [00:00<?, ?it/s]

Epoch #37: d_loss=0.10175, g_loss=13.4245, r_score=0.90355, f_score=1e-05 elapsed=4721.18 s


  0%|          | 0/5752 [00:00<?, ?it/s]