In [1]:
import argparse
import logging
import os
import random

import numpy as np
import pytorch_lightning as pl
import torch
import yaml
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from sklearn.model_selection import ParameterGrid

from src.utils.export_model import ModelExport
from src.utils.tensorboard import TensorBoardLoggerWithMetrics
from src.utils.model_factory import ModelFactory
from src.utils.options import BaseOptions
from src.utils.versioning import get_git_diff
from src.utils.checkpointing import set_latest_checkpoint
from src.data.sequence_module import AlternateSequenceDataModule

# register models
import src.models

In [2]:
# Load the experiment description (a base Hydra config name plus parameter
# values), initialise Hydra's compose API on the config directory, and expand
# the parameters into a scikit-learn ParameterGrid.
filepath = 'src/configs/transformer.yaml'

with open(filepath) as f:
    experiment_cfg = yaml.safe_load(f)

config_path = "src/configs"
# initialize() registers a process-wide GlobalHydra singleton and raises if one
# already exists; clearing it first lets this cell be re-run in the same kernel.
GlobalHydra.instance().clear()
initialize(config_path=config_path)

base_config = experiment_cfg["base_config"]
# ParameterGrid wants a list of candidate values per key, so scalars are
# wrapped. A missing or empty "parameters" section yields a single empty
# override set instead of failing.
experiment_params = {
    name: values if isinstance(values, list) else [values]
    for name, values in (experiment_cfg.get("parameters") or {}).items()
}

param_grid = ParameterGrid(experiment_params)

See https://github.com/omry/omegaconf/issues/426 for migration instructions.

  OmegaConf.register_resolver(name, f)


In [3]:
# Use the first parameter combination from the grid and turn it into Hydra
# command-line style "key=value" overrides.
param_set = param_grid[0]
param_overrides = [key + "=" + str(value) for key, value in param_set.items()]

config_name = base_config + ".yaml"
full_cfg = compose(config_name, overrides=param_overrides)
set_latest_checkpoint(full_cfg)

# Keep only the `model` sub-tree. Interpolations are resolved into plain
# containers, then re-wrapped as a DictConfig with the struct flag set.
model_container = OmegaConf.to_container(full_cfg.model, resolve=True)
cfg = OmegaConf.create(model_container)
OmegaConf.set_struct(cfg, True)

In [4]:
# Instantiate the data module described by `cfg.dataset`, overriding its batch
# size, and run its setup step; the next cell takes the skeleton from it.
datamodule_overrides = {"batch_size": 8}
dm = instantiate(cfg.dataset, **datamodule_overrides)
dm.setup()

  if OmegaConf.is_none(config):
  if OmegaConf.is_none(config):


In [5]:
# Build the model from the resolved model config, handing the factory the
# skeleton exposed by the data module.
dm_skeleton = dm.get_skeleton()
model = ModelFactory.instantiate(cfg, skeleton=dm_skeleton)

In [6]:
# Build the LaFAN validation set that the inference loop iterates over: the
# "Validation" subset cut into fixed-length sliding windows, served one window
# per item in an unshuffled, seeded order.
from src.data.datasets import SplitFileDatabaseLoader
from src.data.typed_table import TypedColumnSequenceDataset
from src.geometry.skeleton import Skeleton
from src.data.batched_sequence_dataset import LafanSequenceDataset
from src.evaluation.l2q_error import L2Q
from src.evaluation.l2p_error import L2P
from src.evaluation.npss_error import NPSS
from src.data.frame_sampler import MiddleFramesRemover
from src.data.augmentation import BatchRemoveQuatDiscontinuities, BatchYRotateOnFrame, \
    BatchCenterXZ
from tqdm.auto import tqdm

DATASET_ROOT = '../datasets'
DATASET_NAME = 'deeppose_lafan_v1_fps30'
# A single window length drives the short-sequence filter, the windowing and
# the dataset's min/max length; one constant keeps the four uses in agreement.
WINDOW_LENGTH = 65
WINDOW_OFFSET = 40  # 2nd arg of format_as_sliding_windows; presumably the step between window starts — confirm
JOINTS_TO_REMOVE = ['LeftToeEnd', 'RightToeEnd', 'LeftHandEnd', 'RightHandEnd', 'HeadEnd']
# Frame-sampler layout. Going by the parameter names: 10 past frames, 30 removed
# middle frames, 1 future frame — confirm against MiddleFramesRemover.
PAST_CONTEXT = 10
MIDDLE_FRAMES = 30
FUTURE_CONTEXT = 1

split = SplitFileDatabaseLoader(DATASET_ROOT).pull(DATASET_NAME)
lafan_val_raw = TypedColumnSequenceDataset(split, subset="Validation")

# Presumably drops sequences shorter than one window before windowing — confirm.
lafan_val_raw.remove_short_sequences(WINDOW_LENGTH)
lafan_val_raw.format_as_sliding_windows(WINDOW_LENGTH, WINDOW_OFFSET)

skeleton_data = lafan_val_raw.config["skeleton"]
skeleton = Skeleton(skeleton_data)
skeleton.remove_joints(JOINTS_TO_REMOVE)

validation_dataset = LafanSequenceDataset(source=lafan_val_raw, skeleton=skeleton,
                                           batch_size=1, shuffle=False, drop_last=False,
                                           seed=0, min_length=WINDOW_LENGTH, max_length=WINDOW_LENGTH)

frame_sampler = MiddleFramesRemover(past_context=PAST_CONTEXT,
                                    future_context=FUTURE_CONTEXT,
                                    middle_frames=MIDDLE_FRAMES)

In [7]:
# Run the model over every validation window.
# model.eval() switches training-only behaviour (e.g. dropout, if the model has
# any) off, and torch.no_grad() avoids building autograd graphs that are never
# backpropagated. Without them the pass is slower and may not be deterministic.
# Assumes `model` is a torch.nn.Module (e.g. a LightningModule) — TODO confirm.
NUM_THREADS = 12
torch.set_num_threads(NUM_THREADS)

model.eval()
with torch.no_grad():
    for i in tqdm(range(len(validation_dataset))):
        b = validation_dataset[i]

        past_frames, future_frames, target_frames = model.get_data_from_batch(b, frame_sampler=frame_sampler)
        # Only the last window's outputs remain after the loop; nothing is
        # accumulated or scored here.
        target_data, predicted = model.forward_wrapped(past_frames, future_frames, target_frames)
    

HBox(children=(FloatProgress(value=0.0, max=2232.0), HTML(value='')))


