## 0. Imports

In [1]:
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import numpy as np
import torch
import time
import json
from tqdm.notebook import tqdm

from dataset import LoadDataset
from MSGModel import Generator, Discriminator, DownSampleImage
from trainMSG import train_generator, train_discriminator, evaluate
from settings import DEVICE, MEAN, STD, BATCH_SIZE, LATENT_SIZE, EPOCHS, TRAINNING_DS_SIZE, RELOAD_DS_PER_EPOCH, LOSS_FN, OPTIMIZER, BETAS, D_LR, D_LR_DECAY, G_LR, G_LR_DECAY, CORES, LOAD_MODELS

# torch.set_num_threads(4)

%load_ext autoreload
%autoreload 2

## 1. Dataset

In [2]:
def _shuffled_loader(dataset):
    """Wrap `dataset` in a DataLoader using the shared batch size and worker count.

    shuffle=True makes the DataLoader draw a new random order every time it is iterated.
    """
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)


# Training split; `length` is forwarded to LoadDataset (presumably sets/limits the
# number of samples - TODO confirm in dataset.py).
ds = LoadDataset('../data/train', MEAN, STD, length=TRAINNING_DS_SIZE)
loader = _shuffled_loader(ds)

# Validation split, built without a `length` argument.
# NOTE(review): validation does not normally need shuffling; kept as-is to preserve behavior.
val_ds = LoadDataset('../data/val', MEAN, STD)
val_loader = _shuffled_loader(val_ds)

print(f"Dataset size: {len(ds)}, loader size: {len(loader)}")
print(f"Validation dataset size: {len(val_ds)}, validation loader size: {len(val_loader)}")

Dataset size: 92032, loader size: 11504
Validation dataset size: 192, validation loader size: 24


## 2. Models

In [3]:
# Instantiate in the same order as before (G, then D, then the downsampler) so any
# RNG-driven weight initialisation consumes the random stream identically.
G = Generator()
D = Discriminator()
downsampler = DownSampleImage()


def _count_params(module):
    """Return the total number of scalar entries across all parameters of `module`."""
    return sum(param.numel() for param in module.parameters())


print(f"Num parameters in Discriminator: {_count_params(D)}")
print(f"Num parameters in Generator: {_count_params(G)}")

Num parameters in Discriminator: 23137153
Num parameters in Generator: 23053781


## 3. Training

In [4]:
# Per-epoch validation metrics. On resume they are refilled from log.json so the
# history covers the whole run, not just the current session.
d_losses = []
g_losses = []
r_scores = []
f_scores = []
d_optimizer = OPTIMIZER(D.parameters(), D_LR, betas=BETAS)
g_optimizer = OPTIMIZER(G.parameters(), G_LR, betas=BETAS)

# Try to resume from saved checkpoints. On failure, report why, fall back to a fresh run
# and clear LOAD_MODELS so the resume-only bookkeeping below (log replay, LR decay) is skipped.
if LOAD_MODELS:
    try:
        # Read both files before modifying either model, so a missing/unreadable file is
        # detected up front. map_location='cpu' lets GPU-saved checkpoints load on any
        # machine; the models are moved to DEVICE right after this block.
        d_state = torch.load('models/discriminator.pth', map_location='cpu')
        g_state = torch.load('models/generator.pth', map_location='cpu')
        D.load_state_dict(d_state)
        # NOTE(review): if this second load_state_dict raises (e.g. shape mismatch), D keeps
        # the weights restored above while training restarts "fresh".
        G.load_state_dict(g_state)
        print("Model loaded")
    except Exception as err:  # best-effort resume; no longer swallows KeyboardInterrupt
        LOAD_MODELS = False
        print(f"Could not load saved models: {err!r}")
        print("Trainning new model")
else:
    print("Trainning new model")

D.to(DEVICE)
G.to(DEVICE)

if not LOAD_MODELS:
    # Fresh run: reset the epoch log.
    with open("log.json", "w") as fh:
        json.dump({"epochs": 0}, fh)
    start_epochs = 0
else:
    # Resume: replay the per-epoch history written by the training loop.
    with open("log.json", "r") as fh:
        log = json.load(fh)
    start_epochs = log['epochs']
    for e in range(1, start_epochs + 1):
        entry = log[str(e)]  # JSON object keys are always strings
        d_losses.append(entry['d_loss'])
        g_losses.append(entry['g_loss'])
        r_scores.append(entry['r_score'])
        f_scores.append(entry['f_score'])

    # Re-apply the linear per-epoch LR decay the completed epochs already applied,
    # clamped at 0 so a long run cannot end up with a negative learning rate.
    d_optimizer.param_groups[0]['lr'] = max(0.0, d_optimizer.param_groups[0]['lr'] - D_LR_DECAY * start_epochs)
    g_optimizer.param_groups[0]['lr'] = max(0.0, g_optimizer.param_groups[0]['lr'] - G_LR_DECAY * start_epochs)

    # Save a preview of the restored generator; no_grad avoids building an autograd graph.
    with torch.no_grad():
        latent = torch.randn(BATCH_SIZE, *LATENT_SIZE).to(DEVICE)
        # NOTE(review): G appears to return a sequence of outputs; [0] is assumed to be the
        # image batch to preview, and `* 0.5 + 0.5` presumably maps [-1, 1] to [0, 1] - confirm.
        fake_images = G(latent)[0] * 0.5 + 0.5
    save_image(fake_images, 'current.png')

print(f"Start at epochs {start_epochs}")

Model loaded
Start at epochs 39


In [5]:
for epochs in range(start_epochs, EPOCHS):
    start = time.time()
    # Optionally rebuild the training set every epoch (with `length=` this presumably
    # draws a new subset of the training data - TODO confirm in LoadDataset).
    if RELOAD_DS_PER_EPOCH:
        ds = LoadDataset('../data/train', MEAN, STD, length=TRAINNING_DS_SIZE)
        loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=CORES)

    # One discriminator step followed by one generator step per batch. Per-batch training
    # losses are not kept; the reported epoch metrics come from the validation pass below.
    for batch_images in tqdm(loader):
        batch_images = batch_images.to(DEVICE)
        train_discriminator(G, D, downsampler, batch_images, LOSS_FN, d_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)
        train_generator(G, D, LOSS_FN, g_optimizer, LATENT_SIZE, BATCH_SIZE, DEVICE)

    d_loss, g_loss, r_score, f_score, fake_images = evaluate(G, D, downsampler, LOSS_FN, val_loader, LATENT_SIZE, BATCH_SIZE, DEVICE)
    d_losses.append(d_loss)
    g_losses.append(g_loss)
    r_scores.append(r_score)
    f_scores.append(f_score)

    # Checkpoint every epoch, overwriting the previous checkpoint.
    torch.save(G.state_dict(), 'models/generator.pth')
    torch.save(D.state_dict(), 'models/discriminator.pth')

    d_loss = round(d_loss, 5)
    g_loss = round(g_loss, 5)
    r_score = round(r_score, 5)
    f_score = round(f_score, 5)

    # Linear LR decay, clamped at 0 so the learning rate can never turn negative.
    d_optimizer.param_groups[0]['lr'] = max(0.0, d_optimizer.param_groups[0]['lr'] - D_LR_DECAY)
    g_optimizer.param_groups[0]['lr'] = max(0.0, g_optimizer.param_groups[0]['lr'] - G_LR_DECAY)

    elapsed = round(time.time() - start, 2)
    print(f"Epoch #{epochs+1}: d_loss={d_loss}, g_loss={g_loss}, r_score={r_score}, f_score={f_score} elapsed={elapsed} s")

    # Record this epoch in log.json. Use string keys: JSON stores object keys as strings and
    # the resume step reads entries back with str(e). An int key could coexist with an equal
    # string key in the loaded dict and be dumped twice. The epoch counter is set from the
    # loop index rather than incremented, so it stays consistent with the entry keys.
    with open("log.json", "r") as fh:
        log = json.load(fh)
    log["epochs"] = epochs + 1
    log[str(epochs + 1)] = {"d_loss": d_loss, "g_loss": g_loss, "r_score": r_score, "f_score": f_score, "elapsed": elapsed}

    with open("log.json", "w") as fh:
        json.dump(log, fh)

    save_image(fake_images, f'output/{epochs+1}.png')

  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #40: d_loss=0.71647, g_loss=2.69979, r_score=0.67862, f_score=0.12022 elapsed=4998.07 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #41: d_loss=0.13717, g_loss=8.22196, r_score=0.87877, f_score=0.00144 elapsed=4985.12 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #42: d_loss=0.20721, g_loss=6.27129, r_score=0.82446, f_score=0.01245 elapsed=4973.17 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #43: d_loss=0.11716, g_loss=5.03163, r_score=0.90849, f_score=0.02072 elapsed=4973.17 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #44: d_loss=0.45308, g_loss=3.58495, r_score=0.81205, f_score=0.16802 elapsed=4968.67 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #45: d_loss=0.13566, g_loss=7.96275, r_score=0.87735, f_score=0.00448 elapsed=4974.47 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #46: d_loss=0.10742, g_loss=15.19667, r_score=0.89825, f_score=0.0 elapsed=4972.24 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #47: d_loss=0.10788, g_loss=10.7719, r_score=0.89789, f_score=4e-05 elapsed=4976.21 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #48: d_loss=0.17486, g_loss=7.24793, r_score=0.84455, f_score=0.00466 elapsed=4966.53 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #49: d_loss=0.19596, g_loss=15.64569, r_score=0.82478, f_score=1e-05 elapsed=4963.1 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #50: d_loss=0.21178, g_loss=7.89322, r_score=0.82858, f_score=0.00465 elapsed=4969.67 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #51: d_loss=0.14009, g_loss=6.43774, r_score=0.88645, f_score=0.01865 elapsed=4973.59 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #52: d_loss=0.10502, g_loss=12.52121, r_score=0.90099, f_score=0.00039 elapsed=4989.46 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #53: d_loss=0.10091, g_loss=15.93653, r_score=0.90418, f_score=1e-05 elapsed=4991.46 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #54: d_loss=0.11572, g_loss=6.69797, r_score=0.89538, f_score=0.00466 elapsed=4994.38 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #55: d_loss=0.11542, g_loss=8.6412, r_score=0.89181, f_score=0.00084 elapsed=4994.66 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #56: d_loss=0.10749, g_loss=10.29727, r_score=0.89833, f_score=0.0001 elapsed=4993.5 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #57: d_loss=0.11004, g_loss=11.82135, r_score=0.89593, f_score=2e-05 elapsed=4991.3 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #58: d_loss=0.09879, g_loss=8.61689, r_score=0.90653, f_score=0.00052 elapsed=4991.13 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #59: d_loss=0.09512, g_loss=7.8976, r_score=0.91215, f_score=0.00293 elapsed=4983.7 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #60: d_loss=0.13617, g_loss=10.63563, r_score=0.87309, f_score=8e-05 elapsed=4975.28 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #61: d_loss=0.1044, g_loss=11.71237, r_score=0.90146, f_score=0.00029 elapsed=4965.45 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #62: d_loss=0.10496, g_loss=5.56778, r_score=0.91115, f_score=0.01149 elapsed=4975.96 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #63: d_loss=0.13672, g_loss=13.18072, r_score=0.87295, f_score=0.0001 elapsed=4971.32 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #64: d_loss=0.11238, g_loss=12.72839, r_score=0.89486, f_score=0.00083 elapsed=4961.87 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #65: d_loss=0.11833, g_loss=7.24915, r_score=0.89103, f_score=0.00271 elapsed=4965.69 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #66: d_loss=0.06948, g_loss=12.62508, r_score=0.93301, f_score=1e-05 elapsed=4961.96 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #67: d_loss=0.13792, g_loss=9.0433, r_score=0.87235, f_score=0.00047 elapsed=4967.16 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #68: d_loss=1.20351, g_loss=1.76654, r_score=0.66188, f_score=0.49299 elapsed=4963.79 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #69: d_loss=0.10182, g_loss=20.10543, r_score=0.90347, f_score=0.0 elapsed=4960.2 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #70: d_loss=0.11854, g_loss=7.24891, r_score=0.89, f_score=0.00153 elapsed=4959.62 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #71: d_loss=0.33986, g_loss=8.98604, r_score=0.73745, f_score=0.00032 elapsed=4964.71 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #72: d_loss=0.32175, g_loss=3.27144, r_score=0.83202, f_score=0.11682 elapsed=4969.8 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Epoch #73: d_loss=0.30318, g_loss=10.8112, r_score=0.81948, f_score=7e-05 elapsed=4968.3 s


  0%|          | 0/11504 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x000001DBADFA0E50>
Traceback (most recent call last):
  File "C:\Users\Jerry\miniconda3\envs\jupyter\lib\site-packages\torch\utils\data\dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "C:\Users\Jerry\miniconda3\envs\jupyter\lib\site-packages\torch\utils\data\dataloader.py", line 1474, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "C:\Users\Jerry\miniconda3\envs\jupyter\lib\multiprocessing\process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "C:\Users\Jerry\miniconda3\envs\jupyter\lib\multiprocessing\popen_spawn_win32.py", line 108, in wait
    res = _winapi.WaitForSingleObject(int(self._handle), msecs)
KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
latent = 