In [3]:
import atexit
import gc
import math
import os
import sys
from argparse import ArgumentParser

import numpy as np
import torch
import wandb

from src import train_utils
from src import viz
from src.dataloader import train_dataloader, val_dataloader
from src.models import PJPE, weight_init, Critic
from src.trainer import training_epoch, validation_epoch
from src.callbacks import CallbackList, ModelCheckpoint, Logging, BetaScheduler, Analyze, MaxNorm
from src.train import training_specific_args
# Experiment configuration. `config` is distributed to all the other modules.
# The Jupyter kernel is started with its own command-line arguments (e.g. the
# connection file), which argparse may reject. Parsing an explicit empty list
# yields the parser defaults without overwriting the global `sys.argv` or
# unbinding the `sys` import.
parser = training_specific_args()
config = parser.parse_args(args=[])

# Seed torch and numpy so runs are reproducible.
torch.manual_seed(config.seed)
np.random.seed(config.seed)

# GPU setup: use CUDA only if it is both requested in the config and available.
use_cuda = config.cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
config.device = device  # not an argparse option; attached so other modules can read it

# DataLoader worker processes. The count is the same on CPU and GPU.
config.num_workers = 4

# Experiment tracking with wandb.
# CPU runs (typically debugging) are tagged "CPU" so they can be filtered in
# the project. The tag must be set before wandb.init() reads the environment.
# To run without syncing to the project, set os.environ['WANDB_MODE'] = 'dryrun'
# before the init call.
if not use_cuda:
    os.environ['WANDB_TAGS'] = 'CPU'
wandb.init(anonymous='allow', project="hpe3d", config=config)

config.logger = wandb
config.logger.run.save()
config.run_name = config.logger.run.name  # handle name change in wandb
# Data loading
# NOTE(review): train and val use the same subjects ([9, 11]); the captured
# output reports the same sample count (109867) for both loaders. That is
# harmless if this notebook only evaluates (the next cell only calls
# validation_epoch), but any training on train_loader would then be scored on
# its own training data. Confirm this is intentional before training.
config.train_subjects = [9, 11]
train_loader = train_dataloader(config)
# val_subjects is assigned only after the train loader is built; keep this
# order unless train_dataloader is confirmed not to read it.
config.val_subjects = [9, 11]
val_loader = val_dataloader(config)

# Model pairs to build. Each `variant` entry is [input, output] modality; the
# evaluation loop looks up f"Encoder{IN.upper()}" / f"Decoder{OUT.upper()}"
# in `models` for each entry.
variant = [['2d', '3d']]

models = train_utils.get_models(variant, config)  # model instances keyed by name
if config.self_supervised:
    critic = Critic()
    models['Critic'] = critic

# Move every model to the target device *before* constructing optimizers,
# as the torch.optim documentation recommends, so the optimizers are built
# over the on-device parameters.
for key in models.keys():
    models[key] = models[key].to(device)
    # models[key].apply(weight_init)  # custom weight init, currently disabled

optimizers = train_utils.get_optims(variant, models, config)  # optimizer for each pair
schedulers = train_utils.get_schedulers(optimizers)

# Multi-GPU data parallelism, only when CUDA is actually in use. DataParallel
# requires the wrapped module's parameters on its first device, so wrapping
# CPU-resident models (config.cuda=False on a multi-GPU host) would fail at
# forward time.
if use_cuda and torch.cuda.device_count() > 1:
    print(f'[INFO]: Using {torch.cuda.device_count()} GPUs')
    for key in models.keys():
        models[key] = torch.nn.DataParallel(models[key])

# Best-so-far validation metrics. They start at +inf so the first measured
# value counts as an improvement (presumably updated by the callbacks; confirm).
config.mpjpe_min = float('inf')
config.mpjpe_at_min_val = float('inf')

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
wandb: Wandb version 0.9.6 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
[INFO]: Training data loader called
[INFO]: processing subjects: [9, 11]
samples - 109867
[INFO]: Validation data loader called
[INFO]: processing subjects: [9, 11]
samples - 109867


In [5]:
# Evaluate a previously trained wandb run: the callbacks are set up (which, per
# the captured output, loads the run's checkpoint), then one validation pass
# is run per encoder/decoder pair in `variant`.
#
# NOTE(review): the captured output shows the checkpoint load failing. The saved
# Encoder2D has hidden width 1024, only LBAD_1..LBAD_2, and a 51-dim latent
# (fc_mean [51, 1024]). The model built in the setup cell has width 512,
# LBAD_1..LBAD_4, and a 100-dim latent (fc_mean [100, 512]). The config used to
# build `models` must match the checkpoint's hyperparameters, and the models
# must be rebuilt after changing it. TODO confirm which config fields control
# width / depth / latent size.
config.resume_run = "colorful-planet-2357"  # wandb run name whose checkpoint is loaded
# initiate all required callbacks, keep the order in mind!!!
cb = CallbackList([ModelCheckpoint(),
                    Logging(),
                    BetaScheduler(config, strategy="cycling"),
                    Analyze(500)])

# Presumably the checkpoint is restored here by ModelCheckpoint (output shows
# "[INFO] Loaded Checkpoint ..." just before the error); confirm.
cb.setup(config = config, models = models, optimizers = optimizers,
            train_loader = train_loader, val_loader = val_loader, variant = variant)

# Single evaluation pass (epoch index 0).
for epoch in range(1):
    for n_pair, pair in enumerate(variant):
        # e.g. ['2d', '3d'] -> "2d_2_3d"
        vae_type="_2_".join(pair)
        # model -- encoder, decoder / critic
        model=[models[f"Encoder{pair[0].upper()}"],
                    models[f"Decoder{pair[1].upper()}"]]
        # Per-pair optimizer/scheduler lists. They are not used by this
        # validation-only loop (on_epoch_end receives the full `optimizers`).
        optimizer = [optimizers[n_pair]]
        scheduler = [schedulers[n_pair]]

        if config.self_supervised:
            model.append(models['Critic'])
            # assumes the critic's optimizer/scheduler are the last entries; verify
            optimizer.append(optimizers[-1])
            scheduler.append(schedulers[-1])

        val_loss = validation_epoch(
                config, cb, model, val_loader, epoch, vae_type)


        # NOTE(review): if ModelCheckpoint saves in on_epoch_end, an evaluation
        # run could overwrite checkpoints for the resumed run; confirm.
        cb.on_epoch_end(config=config, val_loss=val_loss, model=model,
                        n_pair=n_pair, optimizers=optimizers, epoch=epoch)



[INFO] Loaded Checkpoint colorful-planet-2357: Encoder2D @ epoch 1


RuntimeError: Error(s) in loading state_dict for Encoder2D:
	Missing key(s) in state_dict: "LBAD_3.w1.weight", "LBAD_3.bn1.weight", "LBAD_3.bn1.bias", "LBAD_3.bn1.running_mean", "LBAD_3.bn1.running_var", "LBAD_4.w1.weight", "LBAD_4.bn1.weight", "LBAD_4.bn1.bias", "LBAD_4.bn1.running_mean", "LBAD_4.bn1.running_var". 
	size mismatch for enc_inp_block.0.weight: copying a param with shape torch.Size([1024, 32]) from checkpoint, the shape in current model is torch.Size([512, 32]).
	size mismatch for enc_inp_block.0.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for enc_inp_block.1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for enc_inp_block.1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for enc_inp_block.1.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for enc_inp_block.1.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_1.w1.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for LBAD_1.bn1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_1.bn1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_1.bn1.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_1.bn1.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_2.w1.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for LBAD_2.bn1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_2.bn1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_2.bn1.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for LBAD_2.bn1.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for fc_mean.weight: copying a param with shape torch.Size([51, 1024]) from checkpoint, the shape in current model is torch.Size([100, 512]).
	size mismatch for fc_mean.bias: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([100]).
	size mismatch for fc_logvar.weight: copying a param with shape torch.Size([51, 1024]) from checkpoint, the shape in current model is torch.Size([100, 512]).
	size mismatch for fc_logvar.bias: copying a param with shape torch.Size([51]) from checkpoint, the shape in current model is torch.Size([100]).