In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution so edits to project .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import json
from copy import deepcopy

from torch import optim
from torch.utils.data import DataLoader

from data.datasets.random_dataset import RandomDataset
from data.datasets.golden_panels import GoldenPanelsDataset

from networks.plain_ssupervae import PlainSSuperVAE
from networks.ssupervae_contextual_attentional import SSuperVAEContextualAttentional
from training.ssupervae_contextual_attn_trainer import SSuperVAEContextualAttentionalTrainer
from training.vae_trainer import VAETrainer
from utils.config_utils import read_config, Config
from utils.plot_utils import *
from utils.logging_utils import *
from utils import pytorch_util as ptu
from utils.image_utils import *
from configs.base_config import *
from functional.losses.elbo import elbo

In [3]:
# Initialise logging through a project helper. The name arrives via one of the
# star imports above (presumably utils.logging_utils — confirm).
initiate_logger()

In [4]:
def save_best_loss_model(model_name, model, best_loss):
    """Report the current best loss and persist the model that achieved it.

    The file is written to ``<base_dir>playground/ssupervae/weights/<model_name>.pth``.
    The weights directory is created on demand so that the first best-loss
    callback cannot fail on a fresh checkout after a long epoch.

    Args:
        model_name: File stem used for the saved ``.pth`` file.
        model: Object handed to ``torch.save`` as-is (the whole object is
            serialised, not only a ``state_dict``).
        best_loss: Loss value to report; only printed, never stored.
    """
    print('[INFO] Current best loss: ' + str(best_loss))
    weights_dir = base_dir + 'playground/ssupervae/weights/'
    # exist_ok=True keeps repeated callbacks idempotent.
    os.makedirs(weights_dir, exist_ok=True)
    torch.save(model, weights_dir + model_name + ".pth")


def train(data_loader,
          config,
          panel_dim,
          model_name='plain_ssupervae',
          cont_epoch=-1,
          cont_model=None):
    """Build the contextual-attentional SSuperVAE, train it and plot its losses.

    Args:
        data_loader: Loader passed to the trainer as ``train_loader``
            (no test loader is used).
        config: Hyper-parameter object. Attributes read here: ``backbone``,
            ``latent_dim``, ``embed_dim``, ``seq_size``, ``decoder_channels``,
            ``image_dim``, ``lr``, ``beta_1``, ``beta_2``, ``weight_decay``,
            ``train_epochs`` and ``g_clip``. It is also forwarded to the
            trainer as ``config_disc``.
        panel_dim: Forwarded to the network as ``panel_img_size``.
        model_name: Name given to the trainer and used as the file stem of
            the best-loss weights.
        cont_epoch: When greater than -1, resume from the trainer checkpoint
            of that epoch.
        cont_model: Alternative checkpoint path to resume from; only consulted
            when ``cont_epoch`` is -1 or lower.

    Returns:
        The network instance after ``trainer.train_epochs`` has finished.
    """
    print("[INFO] Initiate training...")

    # creating model and training details
    net = SSuperVAEContextualAttentional(config.backbone,
                                         panel_img_size=panel_dim,
                                         latent_dim=config.latent_dim,
                                         embed_dim=config.embed_dim,
                                         seq_size=config.seq_size,
                                         decoder_channels=config.decoder_channels,
                                         gen_img_size=config.image_dim).to(ptu.device)

    criterion = elbo

    # Both optimizers share the same Adam hyper-parameters.
    adam_kwargs = dict(lr=config.lr,
                       betas=(config.beta_1, config.beta_2),
                       weight_decay=config.weight_decay)

    # NOTE(review): net.parameters() presumably also yields the discriminator
    # parameters collected below — confirm the trainer accounts for that.
    optimizer = optim.Adam(net.parameters(), **adam_kwargs)

    d_params = list(net.local_disc.parameters()) + list(net.global_disc.parameters())
    optimizer_disc = optim.Adam(d_params, **adam_kwargs)

    def linear_decay(epoch):
        # LR multiplier: 1.0 at epoch 0, falling linearly to 0.0 at train_epochs.
        return (config.train_epochs - epoch) / config.train_epochs

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, linear_decay, last_epoch=-1)
    scheduler_disc = optim.lr_scheduler.LambdaLR(optimizer_disc, linear_decay, last_epoch=-1)

    # init trainer
    trainer = SSuperVAEContextualAttentionalTrainer(
        model=net,
        config_disc=config,
        model_name=model_name,
        criterion=criterion,
        train_loader=data_loader,
        test_loader=None,
        epochs=config.train_epochs,
        optimizer=optimizer,
        optimizer_disc=optimizer_disc,
        scheduler=scheduler,
        scheduler_disc=scheduler_disc,
        grad_clip=config.g_clip,
        best_loss_action=lambda m, l: save_best_loss_model(model_name, m, l),
        save_dir=base_dir + 'playground/ssupervae/',
        checkpoint_every_epoch=True
    )

    if cont_epoch > -1:
        epoch, losses = trainer.load_checkpoint(epoch=cont_epoch)
    elif cont_model is not None:
        epoch, losses = trainer.load_checkpoint(alternative_chkpt_path=cont_model)
        print("[INFO] Continues from loaded model in epoch:", epoch)
        # NOTE(review): only the generator scheduler is advanced here (once);
        # scheduler_disc is left untouched — confirm this asymmetry is intended.
        scheduler.step()
    else:
        # Fresh run: no starting epoch and no loss history.
        epoch, losses = None, {}

    train_losses, test_losses = trainer.train_epochs(starting_epoch=epoch, losses=losses)

    print("[INFO] Completed training!")

    # Write the plot next to the other run artifacts. The original targeted
    # 'playground/supervae/', unlike every other path in this notebook.
    save_training_plot(train_losses['loss'],
                       test_losses['loss'],
                       "Plain_SSuperVAE Losses",
                       base_dir + 'playground/ssupervae/' + 'results/ssupervae_plot.png'
                       )
    return net

In [5]:
# --- Run switches -----------------------------------------------------------
# Resume controls; these defaults start a fresh run.
cont_epoch = -1
cont_model = None  # "playground/ssupervae/weights/model-18.pth"
# TODO: move this to config
# presumably -1 disables the size limit — confirm in GoldenPanelsDataset
limit_size = -1

# --- Device and configuration ----------------------------------------------
ptu.set_gpu_mode(True)
config = read_config(Config.VAE_CONTEXT_ATTN)
golden_age_config = read_config(Config.GOLDEN_AGE)

# Only the first entry of the configured panel dimensions is forwarded to train().
panel_dim = golden_age_config.panel_dim[0]

# --- Dataset ----------------------------------------------------------------
# Synthetic alternative kept for quick smoke runs:
# data = RandomDataset((3, 3, 360, 360), (3, config.image_dim, config.image_dim))
data = GoldenPanelsDataset(
    golden_age_config.panel_path,
    golden_age_config.sequence_path,
    golden_age_config.panel_dim,
    config.image_dim,
    train_mode=True,
    augment=False,
    limit_size=limit_size,
    train_test_ratio=golden_age_config.train_test_ratio,
    mask_val=golden_age_config.mask_val,
    mask_all=golden_age_config.mask_all,
    return_mask=golden_age_config.return_mask,
)
data_loader = DataLoader(
    data,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Create the results folder up front: an epoch takes roughly 40 minutes, so a
# missing directory must not be discovered only at the final save.
results_dir = base_dir + 'playground/ssupervae/results/'
os.makedirs(results_dir, exist_ok=True)

model = train(data_loader,
              config,
              model_name=get_dt_string() + "_model",
              cont_epoch=cont_epoch,
              cont_model=cont_model,
              panel_dim=panel_dim)
# Serialises the whole model object (not only a state_dict).
torch.save(model, results_dir + "ssuper_vae_context_model.pth")

[INFO] Initiate training...
Loaded pretrained weights for efficientnet-b5


Epoch 0, loss 17420.7769, reconstruction_loss 18029.7846, kl_loss 39.2964, l1_fine 1.0234, wgan_g -649.3275, wgan_d -481.1501, wgan_gp 18.5179, d -295.9715: 100%|██████████| 18016/18016 [42:39<00:00,  7.04it/s]


[INFO] Current best loss: 0


Epoch 1, loss 17648.5780, reconstruction_loss 17877.0423, kl_loss 50.4639, l1_fine 1.0324, wgan_g -279.9607, wgan_d -531.4763, wgan_gp 20.4071, d -327.4054: 100%|██████████| 18016/18016 [42:39<00:00,  7.04it/s]
Epoch 2, loss 17444.7711, reconstruction_loss 17589.7201, kl_loss 66.2533, l1_fine 0.9996, wgan_g -212.2020, wgan_d -535.4358, wgan_gp 20.4809, d -330.6265: 100%|██████████| 18016/18016 [43:04<00:00,  6.97it/s]
Epoch 3, loss 17090.4572, reconstruction_loss 17252.7519, kl_loss 69.4212, l1_fine 0.9778, wgan_g -232.6935, wgan_d -512.8566, wgan_gp 19.6407, d -316.4497: 100%|██████████| 18016/18016 [42:51<00:00,  7.01it/s]
Epoch 4, loss 16996.2053, reconstruction_loss 17197.4961, kl_loss 74.1443, l1_fine 0.9797, wgan_g -276.4148, wgan_d -493.3303, wgan_gp 18.8149, d -305.1816: 100%|██████████| 18016/18016 [43:21<00:00,  6.92it/s]
Epoch 5, loss 17092.5795, reconstruction_loss 17322.9721, kl_loss 82.8913, l1_fine 0.9929, wgan_g -314.2768, wgan_d -542.6121, wgan_gp 20.5370, d -337.2419:

In [None]:
# Experiment start time: May 14 - 18.04