## 0. Imports

In [91]:
import json
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm.notebook import tqdm

from dataset import LoadDataset
from MSGModel import Generator, Discriminator, DownSampleImage
from trainMSG import train_generator, train_discriminator, evaluate
from settings import DEVICE, MEAN, STD, BATCH_SIZE, LATENT_SIZE, EPOCHS, LOSS_FN, OPTIMIZER, BETAS, D_LR, D_LR_DECAY, G_LR, G_LR_DECAY, CORES, LOAD_MODELS

# Cap the number of threads PyTorch uses for CPU-side parallelism at 4.
torch.set_num_threads(4)

# NOTE(review): the explicit import looks redundant, since `%load_ext` loads the
# extension by name — confirm before removing.
import autoreload
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# (dataset, MSGModel, trainMSG, settings) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 1. Dataset

In [92]:
# Training split. `length=8192` presumably limits how many samples the dataset
# exposes (len(ds) reports 8192) — confirm against LoadDataset.
ds = LoadDataset('../data/train', MEAN, STD, length=8192)
loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)

# Ask the loader for the batch count instead of dividing by hand:
# len(ds)/BATCH_SIZE yields a float and is wrong whenever the dataset size is
# not an exact multiple of the batch size (the last, partial batch is dropped
# from the count). len(loader) is the number of batches actually iterated.
num_batches = len(loader)

# Validation split, consumed by `evaluate` at the end of every epoch.
val_ds = LoadDataset('../data/val', MEAN, STD)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)

# Displayed as (number of training samples, number of validation batches).
len(ds), len(val_loader)

(8192, 12)

## 2. Models

In [93]:
def _count_parameters(model):
    """Return the total number of scalar elements across all of `model`'s parameters."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


# Instantiate the generator, the discriminator and the image down-sampling helper
# (in that order, matching the original construction order).
G = Generator()
D = Discriminator()
downsampler = DownSampleImage()

# Uncomment to inspect the full architectures:
# print(G)
# print(D)
print(f"Num parameters in Discriminator: {_count_parameters(D)}")
print(f"Num parameters in Generator: {_count_parameters(G)}")

Num parameters in Discriminator: 23137153
Num parameters in Generator: 23053781


## 3. Training

In [94]:
# Per-epoch validation history. Repopulated from log.json when resuming a run.
d_losses = []
g_losses = []
r_scores = []
f_scores = []
d_optimizer = OPTIMIZER(D.parameters(), D_LR, betas=BETAS)
g_optimizer = OPTIMIZER(G.parameters(), G_LR, betas=BETAS)

# `resume` is True only when checkpoints were actually restored. It (not
# LOAD_MODELS) drives the bookkeeping below, so a failed load can no longer
# leave freshly initialised weights paired with a resumed epoch counter and
# an already-decayed learning rate.
resume = False
if LOAD_MODELS:
    try:
        # Read from the same paths the training loop writes to ('models/...').
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        d_state = torch.load('models/discriminator.pth', map_location=DEVICE)
        g_state = torch.load('models/generator.pth', map_location=DEVICE)
        # Both files are deserialised before either model is touched, so a
        # missing/corrupt file cannot leave only one network restored.
        D.load_state_dict(d_state)
        G.load_state_dict(g_state)
    except Exception as err:
        # Best effort, as before: fall back to a fresh run, but say why.
        print(f"Could not load checkpoints: {err!r}")
        print("Trainning new model")
    else:
        resume = True
        print("Model loaded")
else:
    print("Trainning new model")

D.to(DEVICE)
G.to(DEVICE)

if not resume:
    # Fresh run: reset the on-disk log.
    with open("log.json", "w") as fh:
        json.dump({"epochs": 0}, fh)
    start_epochs = 0
else:
    # Resumed run: rebuild the in-memory history from the log.
    with open("log.json", "r") as fh:
        log = json.load(fh)
    start_epochs = log['epochs']
    for e in range(1, start_epochs+1):
        entry = log[str(e)]  # JSON object keys are always strings
        d_losses.append(entry['d_loss'])
        g_losses.append(entry['g_loss'])
        r_scores.append(entry['r_score'])
        f_scores.append(entry['f_score'])

    # Re-apply the linear decay already consumed by the completed epochs,
    # never letting the learning rate drop below zero.
    d_group = d_optimizer.param_groups[0]
    g_group = g_optimizer.param_groups[0]
    d_group['lr'] = max(d_group['lr'] - D_LR_DECAY * start_epochs, 0.0)
    g_group['lr'] = max(g_group['lr'] - G_LR_DECAY * start_epochs, 0.0)

    # Save a preview from the restored generator. No gradients are needed for
    # this forward pass, so skip building the autograd graph.
    with torch.no_grad():
        latent = torch.randn(BATCH_SIZE, *LATENT_SIZE).to(DEVICE)
        # G(...)[0] is the first of the generator's outputs; * 0.5 + 0.5
        # presumably maps it from [-1, 1] back to [0, 1] — confirm against MEAN/STD.
        fake_images = G(latent)[0] * 0.5 + 0.5
    save_image(fake_images, 'current.png')

print(f"Start at epochs {start_epochs}")

Trainning new model
Start at epochs 0


In [None]:
# Create the checkpoint and sample directories up front; otherwise the first
# torch.save / save_image fails only after a whole epoch has been trained.
os.makedirs('models', exist_ok=True)
os.makedirs('output', exist_ok=True)

for epochs in range(start_epochs, EPOCHS):
    start = time.time()

    # One pass over the training set: a discriminator step followed by a
    # generator step on every batch.
    for batch_images in tqdm(loader):
        batch_images = batch_images.to(DEVICE)
        d_loss = train_discriminator(G, D, downsampler, batch_images, LOSS_FN, d_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)
        g_loss = train_generator(G, D, LOSS_FN, g_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)

    # Epoch metrics come from the validation loader and replace the last
    # per-batch training losses assigned above.
    d_loss, g_loss, r_score, f_score, fake_images = evaluate(G, D, downsampler, LOSS_FN, val_loader, LATENT_SIZE, BATCH_SIZE, DEVICE)
    d_losses.append(d_loss)
    g_losses.append(g_loss)
    r_scores.append(r_score)
    f_scores.append(f_score)

    # Checkpoint both networks after every epoch.
    torch.save(G.state_dict(), 'models/generator.pth')
    torch.save(D.state_dict(), 'models/discriminator.pth')

    # Rounded copies are used for the console line and the JSON log only;
    # the in-memory history above keeps full precision.
    d_loss = round(d_loss,5)
    g_loss = round(g_loss,5)
    r_score = round(r_score,5)
    f_score = round(f_score,5)

    # Linear learning-rate decay, clamped so it can never become negative.
    d_group = d_optimizer.param_groups[0]
    g_group = g_optimizer.param_groups[0]
    d_group['lr'] = max(d_group['lr'] - D_LR_DECAY, 0.0)
    g_group['lr'] = max(g_group['lr'] - G_LR_DECAY, 0.0)

    elapsed = round(time.time()-start, 2)
    print(f"Epoch #{epochs+1}: d_loss={d_loss}, g_loss={g_loss}, r_score={r_score}, f_score={f_score} elapsed={elapsed} s")

    # Append this epoch to the on-disk log (read-modify-write).
    with open("log.json", "r") as fh:
        log = json.load(fh)
    log["epochs"] += 1
    # Use a string key: JSON object keys are strings, and the resume code looks
    # entries up with str(e). An int key next to an existing "N" key would
    # serialise as a duplicate entry.
    log[str(epochs+1)] = {"d_loss": d_loss, "g_loss": g_loss, "r_score": r_score, "f_score": f_score, "elapsed": elapsed}

    with open("log.json", "w") as fh:
        json.dump(log, fh)

    # Keep one image of generated samples per epoch.
    save_image(fake_images, f'output/{epochs+1}.png')

  0%|          | 0/512 [00:00<?, ?it/s]