In [1]:
import os
# Move to the repository root so the project-local imports below resolve.
# The notebook's starting directory is recorded on the first run only, so
# re-executing this cell does not keep climbing three more levels each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, "../../.."))

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from ai import cs
import torch

from databases.joint_sets import MuPoTSJoints
from databases.datasets import PersonStackedMuPoTsDataset, Mpi3dTestDataset
from util.misc import load
from util.viz import *
from util.pose import remove_root
from training.torch_tools import *
from training.preprocess import get_postprocessor, SaveableCompose, MeanNormalize3D
from training.loaders import UnchunkedGenerator
from scripts.eval import load_model

In [3]:
def joint2bone(nd):
    """Turn joint positions into one vector per limb of the MuPoTS limb graph.

    ``nd`` is indexed as ``nd[:, joint, :]``. For every limb pair returned by
    ``get_cjs()`` the result holds ``nd[:, pair[0], :] - nd[:, pair[1], :]``,
    i.e. the vector pointing from the second joint of the pair to the first.
    """
    limbs = get_cjs()
    heads = limbs[:, 0]
    tails = limbs[:, 1]
    return nd[:, heads, :] - nd[:, tails, :]

def bone2joint(pred_bx, pred_by, pred_bz, root):
    """Rebuild 17 joint positions by chaining bone vectors outward from a root.

    Inverse of ``joint2bone``: ``pred_bx/pred_by/pred_bz`` are the x, y, z
    components of each limb vector (indexed ``[:, limb]``) and ``root`` is
    written to joint 14. Every other joint is its limb's second joint plus the
    limb vector.

    Returns a float array of shape ``(root.shape[0], 17, 3)``.
    """
    limbs = get_cjs()
    # Limb indices in the order they are applied. Presumably chosen so that
    # each limb's second joint is already placed when it is reached — TODO
    # confirm against MuPoTSJoints.LIMBGRAPH if the joint set changes.
    traversal = [2, 1, 0, 5, 4, 3, 9, 8, 12, 11, 10, 15, 14, 13, 7, 6]
    bones = np.stack((pred_bx, pred_by, pred_bz), axis=-1)

    joints = np.zeros((root.shape[0], 17, 3))
    joints[:, 14, :] = root
    for limb_idx in traversal:
        child, parent = limbs[limb_idx]
        joints[:, child, :] = joints[:, parent, :] + bones[:, limb_idx, :]
    return joints

def get_cjs():
    """Return ``MuPoTSJoints().LIMBGRAPH`` (pairs of connected joint indices) as a numpy array."""
    return np.array(MuPoTSJoints().LIMBGRAPH)

def get_rtp(nd):
    """Express every bone of ``nd`` in spherical coordinates.

    The bone vectors from ``joint2bone(nd)`` are split into their x, y, z
    components and converted with ``cs.cart2sp``; its three outputs are
    returned as the tuple ``(r, t, p)``.
    """
    bones = joint2bone(nd)
    radius, theta, phi = cs.cart2sp(bones[:, :, 0], bones[:, :, 1], bones[:, :, 2])
    return radius, theta, phi

def get_lengths(nd):
    """Return the radial component ``r`` of every bone of ``nd`` (its bone length,
    assuming ``cs.cart2sp`` returns the Euclidean norm as ``r``)."""
    return get_rtp(nd)[0]

def get_xyz(r, t, p, root):
    """Rebuild joint positions from spherical bone coordinates.

    ``(r, t, p)`` are converted back to Cartesian bone components with
    ``cs.sp2cart`` and then chained outward from ``root`` by ``bone2joint``.
    """
    return bone2joint(*cs.sp2cart(r, t, p), root)

In [4]:
# Directory of the trained model (relative to the working directory set by the
# first cell); preprocess_params.pkl is read from this directory as well.
model_dir = '../models/4b1006aa968a47139217c9e7ac31e52f/'
# NOTE(review): the stored checkpoint contains CUDA tensors, so on a CPU-only
# machine this call raises "Attempting to deserialize object on a CUDA device".
# Supporting CPU would need load_model to pass map_location to torch.load —
# confirm in scripts/eval.
config, model = load_model(model_dir)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
def get_dataset(config, mpi3d=False):
    """Build the evaluation dataset described by a model ``config``.

    Parameters
    ----------
    config : dict-like
        Model configuration. ``config["pose2d_type"]`` is required;
        ``"pose3d_scaling"`` falls back to ``"normal"`` when absent.
    mpi3d : bool, optional
        False (default): return a ``PersonStackedMuPoTsDataset`` with
        ``pose_validity="all"``.
        True: return an ``Mpi3dTestDataset`` limited to evaluation frames
        (``eval_frames_only=True``). This replaces the alternative that used to
        be toggled by commenting code in and out. NOTE(review): later cells
        evaluate with ``MuPoTSJoints()``; confirm the joint sets match before
        using this option.

    Returns
    -------
    The constructed dataset object.
    """
    pose2d_type = config["pose2d_type"]
    pose3d_scaling = config.get("pose3d_scaling", "normal")
    if mpi3d:
        return Mpi3dTestDataset(pose2d_type, pose3d_scaling, eval_frames_only=True)
    return PersonStackedMuPoTsDataset(pose2d_type, pose3d_scaling, pose_validity="all")

In [None]:
# Build the evaluation dataset from the settings stored with the trained model.
dataset = get_dataset(config)

In [None]:
# Re-create the preprocessing pipeline saved alongside the model and attach it
# to the dataset. `globals()` is handed to from_file, presumably so it can
# resolve the saved transform classes by name (hence the preprocess imports).
# NOTE(review): a .pkl file is likely unpickled here — only load trusted models.
params_path = f"{model_dir}/preprocess_params.pkl"
transform = SaveableCompose.from_file(params_path, dataset, globals())
dataset.transform = transform

# The second stage of the saved pipeline must carry the 3D mean normaliser;
# the post-processor uses it to map network outputs back to real poses.
normalizer3d = transform.transforms[1].normalizer
assert isinstance(normalizer3d, MeanNormalize3D)
post_process_func = get_postprocessor(config, dataset, normalizer3d)

In [None]:
# Evaluation generator: `pad` frames of context on each side, matching the
# model's receptive field; `augment` enables the flipped second prediction.
augment = True
pad = (model.receptive_field() - 1) // 2
generator = UnchunkedGenerator(dataset, pad, augment)
seqs = sorted(np.unique(dataset.index.seq))

# Per-sequence arrays restricted to frames with a valid pose. A single
# get_samples call per sequence feeds every cache; previously two identical
# passes fetched the same batches twice.
data_3d_mm = {}
preprocessed3d = {}
bl = {}
root = {}
org_pose3d = {}
for seq in seqs:
    inds = np.where(dataset.index.seq == seq)[0]
    batch = dataset.get_samples(inds, False)
    valid_mask = batch["valid_pose"]
    preprocessed3d[seq] = batch["pose3d"][valid_mask]
    data_3d_mm[seq] = dataset.poses3d[inds][valid_mask]
    bl[seq] = batch["length"][valid_mask]
    root[seq] = batch["root"][valid_mask]
    org_pose3d[seq] = batch["org_pose3d"][valid_mask]

In [None]:
# Ground-truth 3D poses (valid frames only) per sequence. These are exactly the
# arrays already cached in `data_3d_mm` (dataset.poses3d[inds][valid_pose]),
# so reuse them instead of calling dataset.get_samples over every sequence again.
_dgt = {seq: data_3d_mm[seq] for seq in seqs}

In [None]:
# Run the model over every sequence and map the outputs back to real poses.
# When `augment` is on, the second output is averaged in after being mirrored back.
_dpred = {}
raw_preds = {}
losses = {}
# Send inputs to whichever device holds the model's weights instead of
# hard-coding `.cuda()`, so this cell itself does not require a GPU.
device = next(model.parameters()).device
with torch.no_grad():
    for i, (pose2d, valid) in enumerate(generator):
        # NOTE(review): assumes `generator` yields sequences in the same order
        # as the sorted `seqs` list — confirm in UnchunkedGenerator.
        seq = seqs[i]
        pred3d = model(torch.from_numpy(pose2d).to(device)).cpu().numpy()
        raw_preds[seq] = pred3d.copy()

        valid_mask = valid[0]
        pred_real_pose = post_process_func(pred3d[0], seq)

        if augment:
            # pred3d[1] is presumably the prediction for the horizontally
            # flipped input: negate x and let the joint set flip the joint
            # labels back before averaging with the un-flipped prediction.
            pred_real_pose_aug = post_process_func(pred3d[1], seq)
            pred_real_pose_aug[:, :, 0] *= -1
            pred_real_pose_aug = dataset.pose3d_jointset.flip(pred_real_pose_aug)
            pred_real_pose = (pred_real_pose + pred_real_pose_aug) / 2

        mpred = pred_real_pose[valid_mask]
        _dpred[seq] = mpred

In [None]:
# Baseline metrics: unmodified predictions (_dpred) against ground truth (_dgt).
_ = eval_results(_dgt, _dpred, MuPoTSJoints())

In [None]:
# Oracle experiment: keep each predicted bone's direction (t, p) but give it
# the ground-truth length, then rebuild the skeleton from the GT root (joint 14).
dgt = {}
dpred = {}
for seq in seqs:
    mgt = _dgt[seq]
    _, theta, phi = get_rtp(_dpred[seq])
    mpred = get_xyz(get_lengths(mgt), theta, phi, mgt[:, 14, :])
    dgt[seq] = mgt
    dpred[seq] = mpred
_ = eval_results(dgt, dpred, MuPoTSJoints())

In [None]:
# Oracle experiment: keep each predicted bone's x/y components and solve for
# |dz| so the bone gets its ground-truth length. The sign of dz comes from the
# prediction; the skeleton is rebuilt from the GT root (joint 14).
dgt = {}
dpred = {}
for seq in seqs:
    mgt = _dgt[seq]
    mpred = _dpred[seq]
    gt_r = get_lengths(mgt)
    diff = joint2bone(mpred)
    dx = diff[:, :, 0]
    dy = diff[:, :, 1]
    dz = diff[:, :, 2]
    # np.sign(0) == 0 would zero the refit depth of any bone whose predicted
    # dz is exactly 0; treat those as +z instead. NaNs still propagate.
    sign_z = np.where(dz == 0, 1.0, np.sign(dz))
    # Clamp at 0: when predicted x/y already exceed the GT length, dz becomes 0
    # and that bone's length will differ from GT.
    adj_dz = sign_z * np.sqrt(np.maximum(gt_r**2 - dx**2 - dy**2, 0))
    mpred = bone2joint(dx, dy, adj_dz, mgt[:, 14, :])
    dgt[seq] = mgt
    dpred[seq] = mpred
_ = eval_results(dgt, dpred, MuPoTSJoints())

In [None]:
# Visual spot-check: plot ground truth vs. corrected prediction for a few
# frames of the last sequence, and print eval_results' entry for each frame.
# The sequence is picked explicitly instead of relying on `mgt`/`mpred`
# leaking out of the previous cell's loop.
seq = seqs[-1]
mgt = dgt[seq]
mpred = dpred[seq]
# Every 10th frame among the first 30, clipped to the sequence length.
for i in range(0, min(30, len(mgt)), 10):
    show3Dpose(np.array([mgt[i], mpred[i]]), MuPoTSJoints(), invert_vertical=True)
    plt.show()
    print(eval_results({'0': mgt[i:i+1]}, {'0': mpred[i:i+1]}, MuPoTSJoints(), verbose=False)[0]['0'])