In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import sys
import os

import torch
import torchvision.transforms as transforms

from visualize import update_config, add_path

lib_path = osp.join('lib')
add_path(lib_path)

import dataset as dataset
from config import cfg
import models
import os
import torchvision.transforms as T


# --- Experiment configuration ------------------------------------------------
file_name = 'experiments/TP_H_w48_256x192_stage3_1_4_d96_h192_relu_enc6_mh1.yaml' # choose a yaml file
# file_name = 'experiments/TP_R_256x192_d256_h1024_enc4_mh8.yaml'

# Fail fast on a wrong path. Previously the file was opened here and the
# handle was never used or closed; update_config() is given the path itself.
if not osp.isfile(file_name):
    raise FileNotFoundError(file_name)
update_config(cfg, file_name)

model_name = 'T-H-A6'
assert model_name in ['T-R', 'T-H','T-H-L','T-R-A4', 'T-H-A6', 'T-H-A5', 'T-H-A4' ,'T-R-A4-DirectAttention']

# --- Validation dataset ------------------------------------------------------
# Per-channel normalisation applied after ToTensor().
# NOTE(review): these look like the usual ImageNet mean/std - confirm.
normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# Look the dataset class up by name with getattr() instead of eval(), so the
# config value is never executed as code.
dataset_cls = getattr(dataset, cfg.DATASET.DATASET)
# From here on `dataset` refers to the dataset instance rather than the
# imported `dataset` package; later cells pass it to the DataLoader.
# NOTE(review): the positional `True` is presumably the dataset's is_train
# flag - confirm that this is intended for the 'val' split.
dataset = dataset_cls(
    cfg, cfg.DATASET.ROOT, 'val', True,
    transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
)

# --- Model -------------------------------------------------------------------
# Keep the original choice of the second GPU when it exists; otherwise fall
# back to the first GPU or the CPU instead of failing on 'cuda:1'.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('=> using device {}'.format(device))

# NOTE(review): is_train=True is kept from the original even though this
# notebook only evaluates - confirm what get_pose_net does with the flag.
model = getattr(models, cfg.MODEL.NAME).get_pose_net(
    cfg, is_train=True
)

if not cfg.TEST.MODEL_FILE:
    raise ValueError("please choose one ckpt in cfg.TEST.MODEL_FILE")

print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
# map_location reuses `device` so the checkpoint and the model always end up
# on the same device.
model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE, map_location=device), strict=True)

model.to(device)
print("model params:{:.3f}M".format(sum(p.numel() for p in model.parameters())/1000**2))

In [None]:
class AverageMeter(object):
    """Running tracker for the latest value and the weighted mean of a metric.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as an observation that counts ``n`` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so an empty meter reports 0 instead of raising.
        if self.count != 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0

In [None]:
from core.inference import get_max_preds

def get_pred_target(output, target):
    """Run ``get_max_preds`` on both the model output and the target.

    Only the first element of each ``get_max_preds`` result is kept; the
    second element is discarded.
    NOTE(review): presumably the inputs are heatmap batches and the kept
    element holds the per-joint peak coordinates - confirm against
    core.inference.get_max_preds.

    Returns:
        Tuple ``(pred, target)`` with the first result for each input.
    """
    pred_first, _unused_pred = get_max_preds(output)
    target_first, _unused_target = get_max_preds(target)
    return pred_first, target_first
  

In [None]:
import numpy as np

def compute_error_accel(joints_gt, joints_pred, vis=None):
    r"""Compute the per-frame acceleration error between two joint sequences.

    Acceleration is approximated with the second-order finite difference
    along the first (frame) axis::

        X_{i-1} - 2 X_i + X_{i+1}

    so the mean of the returned values over all windows is
    ``1/(n-2) \sum_{i=1}^{n-2} |accel_pred_i - accel_gt_i|`` averaged over
    joints. For every window the L2 norm of the difference between the
    predicted and the ground-truth acceleration is taken per joint (over the
    last axis) and then averaged over the joint axis.

    Note that each frame marked as not visible removes the (up to three)
    windows that contain it from the result.

    Args:
        joints_gt: array-like of shape (N, J, D), e.g. Nx14x3.
        joints_pred: array-like with the same shape as ``joints_gt``.
        vis: optional array-like of N booleans, True where the frame is
            visible. When omitted every frame counts as visible.

    Returns:
        1-D ``np.ndarray`` with one error per kept window: N-2 entries when
        ``vis`` is None, otherwise only the windows whose three frames are
        all visible.
    """
    # Accept plain nested lists as well as arrays; list slices do not support
    # the arithmetic below.
    joints_gt = np.asarray(joints_gt)
    joints_pred = np.asarray(joints_pred)

    # (N-2) x J x D
    accel_gt = joints_gt[:-2] - 2 * joints_gt[1:-1] + joints_gt[2:]
    accel_pred = joints_pred[:-2] - 2 * joints_pred[1:-1] + joints_pred[2:]

    # (N-2) x J
    normed = np.linalg.norm(accel_pred - accel_gt, axis=2)

    if vis is None:
        new_vis = np.ones(len(normed), dtype=bool)
    else:
        invis = np.logical_not(vis)
        # Window i spans frames i, i+1 and i+2; rolling by -1 and -2 lines
        # those three flags up at index i. The last two entries wrap around
        # and are cut off by the [:-2] slice.
        invis1 = np.roll(invis, -1)
        invis2 = np.roll(invis, -2)
        new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
        new_vis = np.logical_not(new_invis)

    return np.mean(normed[new_vis], axis=1)

In [None]:
from core.evaluate import accuracy

# Evaluate the loaded model sample by sample: accumulate the accuracy reported
# by accuracy() and collect per-sample predictions/targets for the
# acceleration error computed over the whole sequence of samples.
model.eval()

# batch_size=1 and shuffle=False keep one sample per step in dataset order,
# which the frame-to-frame differences in compute_error_accel rely on.
test_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=cfg.WORKERS,
    pin_memory=cfg.PIN_MEMORY
)

acc = AverageMeter()

pred_list = []
target_list = []

with torch.no_grad():
    # `images` replaces the former loop variable `input`, which shadowed the
    # builtin; target_weight and meta are part of each batch but unused here.
    for images, target, target_weight, meta in test_loader:
        output = model(images.to(device))

        # Convert once and share the arrays between both consumers.
        # NOTE(review): assumes accuracy() does not modify its inputs in
        # place - confirm against core.evaluate.
        output_np = output.detach().cpu().numpy()
        target_np = target.detach().cpu().numpy()

        _, avg_acc, cnt, _ = accuracy(output_np, target_np)
        acc.update(avg_acc, cnt)

        pred_coords, target_coords = get_pred_target(output_np, target_np)

        # Each batch holds a single sample, so keep only its first entry.
        pred_list.append(pred_coords[0])
        target_list.append(target_coords[0])

print('acc: ', acc.avg)

# Keyword arguments make the ground-truth / prediction roles explicit; the
# metric is symmetric in its two inputs, so the reported numbers are the same
# as with the previous (swapped) positional call.
accel_error = compute_error_accel(
    joints_gt=np.array(target_list),
    joints_pred=np.array(pred_list),
)
print('accel_error: ', accel_error)
print('average accel_error: ', np.average(accel_error))