In [1]:
# Auto-reload imported modules before each cell runs, so edits to the project's
# .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from torch import optim
from torch.utils.data import DataLoader
from data.datasets.golden_panels import GoldenPanelsDataset
from networks.ssupervae_contextual_attentional import SSuperVAEContextualAttentional
from training.ssupervae_contextual_attn_trainer import SSuperVAEContextualAttentionalTrainer
from utils.config_utils import read_config, Config
from utils.plot_utils import *
from utils.logging_utils import *
from utils.image_utils import *
from configs.base_config import *
from functional.losses.elbo import elbo

In [3]:
# Set up logging before any training output is produced.
# NOTE(review): initiate_logger is not defined in this notebook; presumably it
# comes from the `utils.logging_utils` star import above — confirm.
initiate_logger()

In [4]:
def save_best_loss_model(model_name, model, best_loss):
    """Announce a new best loss and persist the whole model object.

    The model is written with ``torch.save`` to
    ``<base_dir>playground/ssupervae_contextual_attention/weights/<model_name>.pth``.

    Args:
        model_name: File stem (string) used for the saved ``.pth`` file.
        model: Object handed to ``torch.save`` as-is.
        best_loss: Value reported in the log line; converted with ``str``.
    """
    print(f'[INFO] Current best loss: {best_loss!s}')
    weights_file = "".join([
        base_dir,
        'playground/ssupervae_contextual_attention/weights/',
        model_name,
        ".pth",
    ])
    torch.save(model, weights_file)


def train(data_loader,
          config,
          panel_dim,
          model_name='ssupervae_contextual_attention',
          cont_epoch=-1,
          cont_model=None):
    """Build the contextual-attention SSuperVAE, train it and return the network.

    Creates the network, one Adam optimizer for ``net.parameters()`` and one for
    the local/global discriminators, a linearly decaying LR schedule for each,
    wires everything into ``SSuperVAEContextualAttentionalTrainer``, optionally
    resumes from a checkpoint, runs ``trainer.train_epochs`` and finally saves a
    loss plot under the experiment's ``results/`` folder.

    Args:
        data_loader: Loader passed to the trainer as ``train_loader``.
        config: Configuration object. Read here for ``backbone``, ``latent_dim``,
            ``embed_dim``, ``seq_size``, ``decoder_channels``, ``image_dim``,
            ``lr``, ``beta_1``, ``beta_2``, ``weight_decay``, ``train_epochs``
            and ``g_clip``; also passed whole to the trainer as ``config_disc``.
        panel_dim: Forwarded to the network as ``panel_img_size``.
        model_name: Name given to the trainer and used as the file stem when the
            best-loss callback saves the model.
        cont_epoch: When greater than -1, resume from the trainer checkpoint of
            that epoch. Takes precedence over ``cont_model``.
        cont_model: Optional checkpoint path, used only when ``cont_epoch`` is
            not set.

    Returns:
        The ``SSuperVAEContextualAttentional`` instance that was trained.
    """
    print("[INFO] Initiate training...")

    # Single definition of the experiment folder (used for checkpoints and plot).
    save_dir = base_dir + 'playground/ssupervae_contextual_attention/'

    # creating model and training details
    net = SSuperVAEContextualAttentional(config.backbone,
                                         panel_img_size=panel_dim,
                                         latent_dim=config.latent_dim,
                                         embed_dim=config.embed_dim,
                                         seq_size=config.seq_size,
                                         decoder_channels=config.decoder_channels,
                                         gen_img_size=config.image_dim).to(ptu.device)

    criterion = elbo

    # NOTE(review): if local_disc/global_disc are registered submodules of net,
    # net.parameters() also yields the discriminator weights that optimizer_disc
    # owns below — confirm the trainer accounts for this.
    optimizer = optim.Adam(net.parameters(),
                           lr=config.lr,
                           betas=(config.beta_1, config.beta_2),
                           weight_decay=config.weight_decay)

    d_params = list(net.local_disc.parameters()) + list(net.global_disc.parameters())
    optimizer_disc = optim.Adam(d_params,
                                lr=config.lr,
                                betas=(config.beta_1, config.beta_2),
                                weight_decay=config.weight_decay)

    def linear_decay(epoch):
        # LR multiplier: 1.0 at epoch 0, falling linearly to 0 at train_epochs.
        return (config.train_epochs - epoch) / config.train_epochs

    # Both optimizers follow the same schedule.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, linear_decay, last_epoch=-1)
    scheduler_disc = optim.lr_scheduler.LambdaLR(optimizer_disc, linear_decay, last_epoch=-1)

    # init trainer
    trainer = SSuperVAEContextualAttentionalTrainer(model=net,
                                                    config_disc=config,
                                                    model_name=model_name,
                                                    criterion=criterion,
                                                    train_loader=data_loader,
                                                    test_loader=None,
                                                    epochs=config.train_epochs,
                                                    optimizer=optimizer,
                                                    optimizer_disc=optimizer_disc,
                                                    scheduler=scheduler,
                                                    scheduler_disc=scheduler_disc,
                                                    grad_clip=config.g_clip,
                                                    best_loss_action=lambda m, l: save_best_loss_model(model_name, m,
                                                                                                       l),
                                                    save_dir=save_dir,
                                                    checkpoint_every_epoch=True
                                                    )

    if cont_epoch > -1:
        # NOTE(review): no scheduler step on this path, unlike the branch below;
        # presumably load_checkpoint(epoch=...) restores scheduler state — confirm.
        epoch, losses = trainer.load_checkpoint(epoch=cont_epoch)
    elif cont_model is not None:
        epoch, losses = trainer.load_checkpoint(alternative_chkpt_path=cont_model)
        print("[INFO] Continues from loaded model in epoch:", epoch)
        # Advance both schedulers together: they share one decay schedule, so
        # stepping only the generator's would leave the discriminator LR one
        # step behind after a resume.
        scheduler.step()
        scheduler_disc.step()
    else:
        epoch, losses = None, {}

    train_losses, test_losses = trainer.train_epochs(starting_epoch=epoch, losses=losses)

    print("[INFO] Completed training!")

    # NOTE(review): test_loader is None above; confirm train_epochs still returns
    # a 'loss' entry in test_losses, otherwise this lookup fails.
    save_training_plot(train_losses['loss'],
                       test_losses['loss'],
                       "SSuperVAE Contextual Attention Losses",
                       save_dir + 'results/ssupervae_plot.png'
                       )
    return net

In [5]:
# --- Device -----------------------------------------------------------------
ptu.set_gpu_mode(True)

# --- Configs: one for the model/training, one for the Golden Age dataset ----
config = read_config(Config.VAE_CONTEXT_ATTN)
golden_age_config = read_config(Config.GOLDEN_AGE)

# First entry of the configured panel dimensions; passed to train() as panel_dim.
panel_dim = golden_age_config.panel_dim[0]

# --- Resume controls, forwarded to train() ----------------------------------
cont_epoch = -1
cont_model = None  # e.g. "playground/ssupervae/weights/model-18.pth"

# TODO: move this to config
limit_size = -1

# --- Dataset and loader -----------------------------------------------------
# Alternative dataset kept for reference:
# data = RandomDataset((3, 3, 360, 360), (3, config.image_dim, config.image_dim))
dataset_args = (
    golden_age_config.panel_path,
    golden_age_config.sequence_path,
    golden_age_config.panel_dim,
    config.image_dim,
)
dataset_options = dict(
    augment=False,
    mask_val=golden_age_config.mask_val,
    mask_all=golden_age_config.mask_all,
    return_mask=golden_age_config.return_mask,
    return_mask_coordinates=golden_age_config.return_mask_coordinates,
    train_test_ratio=golden_age_config.train_test_ratio,
    train_mode=golden_age_config.train_mode,
    limit_size=limit_size,
)
data = GoldenPanelsDataset(*dataset_args, **dataset_options)

data_loader = DataLoader(dataset=data,
                         batch_size=config.batch_size,
                         shuffle=True,
                         num_workers=4)

In [None]:
# Launch training under a timestamp-based model name, then store the returned
# network in the results folder.
run_name = get_dt_string() + "_model"
model = train(data_loader,
              config,
              panel_dim,
              model_name=run_name,
              cont_epoch=cont_epoch,
              cont_model=cont_model)

final_model_path = base_dir + 'playground/ssupervae_contextual_attention/results/' + "ssuper_vae_context_model.pth"
torch.save(model, final_model_path)

[INFO] Initiate training...
Loaded pretrained weights for efficientnet-b5


Epoch 0, loss 12571.9304, reconstruction_loss 12714.2855, kl_loss 11.3321, l1_fine 0.4600, wgan_g -154.1472, wgan_d -128.3612, wgan_gp 3.9481, d -88.8804: 100%|██████████| 18016/18016 [43:27<00:00,  6.91it/s] 


[INFO] Current best loss: 0


Epoch 1, loss 12522.3450, reconstruction_loss 12666.9658, kl_loss 20.8161, l1_fine 0.4586, wgan_g -165.8956, wgan_d -135.1522, wgan_gp 4.1447, d -93.7050: 100%|██████████| 18016/18016 [43:26<00:00,  6.91it/s]
Epoch 2, loss 12442.1849, reconstruction_loss 12553.7072, kl_loss 22.6358, l1_fine 0.4392, wgan_g -134.5972, wgan_d -125.9457, wgan_gp 3.7635, d -88.3103: 100%|██████████| 18016/18016 [43:27<00:00,  6.91it/s]
Epoch 3, loss 12424.9355, reconstruction_loss 12581.6878, kl_loss 28.0906, l1_fine 0.4426, wgan_g -185.2855, wgan_d -126.8436, wgan_gp 3.8968, d -87.8755: 100%|██████████| 18016/18016 [43:30<00:00,  6.90it/s]
Epoch 4, loss 12455.3959, reconstruction_loss 12583.1918, kl_loss 27.0573, l1_fine 0.4406, wgan_g -155.2936, wgan_d -129.8154, wgan_gp 3.9526, d -90.2890: 100%|██████████| 18016/18016 [43:34<00:00,  6.89it/s]
Epoch 5, loss 12389.4014, reconstruction_loss 12476.3221, kl_loss 27.9984, l1_fine 0.4239, wgan_g -115.3431, wgan_d -120.2096, wgan_gp 3.6209, d -84.0003:  29%|██▉ 

In [None]:
# Experiment start time: May 14, 18:04