In [3]:
import numpy as np
import os
import torch
from torch.utils.data import DataLoader
import wandb

from camera import world_to_camera, normalize_screen_coordinates
from humaneva_dataset import HumanEvaDataset
from loss import mpjpe
from model import FrameModel
from run import run
from preprocessed_dataset import PreprocessedDataset
from main import Args, fetch

In [5]:
# Build the run configuration and open the HumanEva dataset it points at.
args = Args()
he_dataset = HumanEvaDataset(args.dataset_path)

# Re-express each animation's world-space 3D pose once per camera and store
# the resulting list under 'positions_3d'.
for subject in he_dataset.subjects():
    for action in he_dataset[subject].keys():
        clip = he_dataset[subject][action]
        if 'positions' not in clip:
            # No 3D pose stored for this clip, so there is nothing to convert.
            continue

        per_camera_poses = []
        for camera in clip['cameras']:
            camera_pose = world_to_camera(
                clip['positions'],
                R=camera['orientation'],
                t=camera['translation'],
            )
            # Make entries 1.. along axis 1 relative to entry 0; entry 0 is
            # left untouched so it keeps the global trajectory.
            camera_pose[:, 1:] -= camera_pose[:, :1]
            per_camera_poses.append(camera_pose)
        clip['positions_3d'] = per_camera_poses

# Load the 2D keypoints and their metadata from the .npz archive.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, which can
# execute code on load -- only point args.dataset_2d_path at trusted archives.
# The archive is opened in a `with` block so its file handle is released as soon
# as both entries have been read, instead of lingering on a rebound name.
with np.load(args.dataset_2d_path, allow_pickle=True) as keypoints_archive:
    keypoints_metadata = keypoints_archive['metadata'].item()
    keypoints = keypoints_archive['positions_2d'].item()

# Left/right index lists for the 2D keypoints (from the archive metadata) and
# for the 3D skeleton joints (from the dataset skeleton).
keypoints_symmetry = keypoints_metadata['keypoints_symmetry']
kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])
joints_left, joints_right = list(he_dataset.skeleton().joints_left()), list(he_dataset.skeleton().joints_right())

# Normalize the 2D keypoints of every camera view using that camera's
# resolution; only the first two channels (x, y) are rewritten.
for subject in keypoints:
    for action in keypoints[subject]:
        views = keypoints[subject][action]
        for view_idx, view_kps in enumerate(views):
            # The view index doubles as the index into the subject's cameras.
            camera = he_dataset.cameras()[subject][view_idx]
            xy = view_kps[..., :2]
            view_kps[..., :2] = normalize_screen_coordinates(
                xy,
                w=camera['res_w'],
                h=camera['res_h'],
            )
            views[view_idx] = view_kps

# Collect the 3D poses, 2D poses and cameras for the configured training and
# validation subjects/actions.
poses_train_3d, poses_train_2d, cameras_train = fetch(
    args.subjects_train, keypoints, he_dataset, args.actions_train
)
poses_val_3d, poses_val_2d, cameras_val = fetch(
    args.subjects_val, keypoints, he_dataset, args.actions_val
)

# Training split: wrapped in a dataset and served in shuffled batches.
train_dataset = PreprocessedDataset(
    poses_train_2d,
    poses_train_3d,
    cameras_train,
    keypoints_metadata,
    he_dataset.skeleton(),
    he_dataset.fps(),
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=0,
    shuffle=True,
)

# Validation split: same construction, but batches keep the dataset order.
val_dataset = PreprocessedDataset(
    poses_val_2d,
    poses_val_3d,
    cameras_val,
    keypoints_metadata,
    he_dataset.skeleton(),
    he_dataset.fps(),
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=0,
    shuffle=False,
)