In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline

In [3]:
import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import pose_estimation._init_paths

from collections import defaultdict
import pandas as pd
pd.set_option('mode.chained_assignment',None)
import glob
import torch
import torchvision.transforms as transforms
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator
from IPython.display import display
import numpy as np
from core.evaluate import accuracy
from matplotlib.ticker import MaxNLocator
import dataset
from core.config import config
# from core.config import update_config
import yaml
from easydict import EasyDict as edict
import utils.flops_benchmarker
import argparse
import models.pose_stacked_hg
%matplotlib inline

this_dir:  /home/michael/TDPoseEstimation/pose_estimation
lib_path:  /home/michael/TDPoseEstimation/lib
Adding /home/michael/TDPoseEstimation/lib to sys.path
Could not find coco.


In [4]:

def _update_dict(k, v):
    if k == "DATASET":
        if "MEAN" in v and v["MEAN"]:
            v["MEAN"] = np.array([eval(x) if isinstance(x, str) else x
                                  for x in v["MEAN"]])
        if "STD" in v and v["STD"]:
            v["STD"] = np.array([eval(x) if isinstance(x, str) else x
                                 for x in v["STD"]])
    if k == "MODEL":
        if "EXTRA" in v and "HEATMAP_SIZE" in v["EXTRA"]:
            if isinstance(v["EXTRA"]["HEATMAP_SIZE"], int):
                v["EXTRA"]["HEATMAP_SIZE"] = np.array(
                    [v["EXTRA"]["HEATMAP_SIZE"], v["EXTRA"]["HEATMAP_SIZE"]])
            else:
                v["EXTRA"]["HEATMAP_SIZE"] = np.array(
                    v["EXTRA"]["HEATMAP_SIZE"])
        if "IMAGE_SIZE" in v:
            if isinstance(v["IMAGE_SIZE"], int):
                v["IMAGE_SIZE"] = np.array([v["IMAGE_SIZE"], v["IMAGE_SIZE"]])
            else:
                v["IMAGE_SIZE"] = np.array(v["IMAGE_SIZE"])
    for vk, vv in v.items():
        if vk in config[k]:
            config[k][vk] = vv
        else:
            raise ValueError(f"{k}.{vk} not exist in config.py")

In [5]:
def update_config(config, config_file):
    exp_config = None
    with open(config_file) as f:
        exp_config = edict(yaml.load(f))
        for k, v in exp_config.items():
            if k in config:
                if isinstance(v, dict):
                    _update_dict(k, v)
                else:
                    if k == "SCALES":
                        config[k][0] = (tuple(v))
                    else:
                        config[k] = v
            else:
                raise ValueError(f"{k} not exist in config.py")
    return config

In [6]:
def parse_args(cfg_path=""):
    parser = argparse.ArgumentParser(description='Train keypoints network')
#     # general
#     parser.add_argument('--cfg',
#                         help='experiment configure file name',
#                         required=True,
#                         type=str)
    if cfg_path:
      default_cfg = cfg_path
    else:
      default_cfg = "experiments/mpii/hourglass_8__teacher.yaml"
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default=default_cfg,
                        type=str)

    args, rest = parser.parse_known_args()
    # update config
#     update_config(args.cfg)

    # training
    parser.add_argument('--frequent',
                        help='frequency of logging',
                        default=config.PRINT_FREQ,
                        type=int)
    parser.add_argument('--max_batch_logs',
                        help='Max # of batches to save data from',
                        default=5,
                        type=int)
    parser.add_argument('--gpus',
                        help='gpus',
                        type=str)
    parser.add_argument('--workers',
                        help='num of dataloader workers',
                        type=int)
    parser.add_argument('--model-file',
                        help='model state file',
                        type=str)
    parser.add_argument('--result_root',
                        default="/hdd/mliuzzolino/TDPoseEstimation/results/",
                        help='Root for results',
                        type=str)
    parser.add_argument('--threshold',
                        type=float,
                        default=0.5,
                        help='Accuracy threshold [default=0.5]')
    parser.add_argument('--use-detect-bbox',
                        help='use detect bbox',
                        action='store_true')
    parser.add_argument('--flip-test',
                        help='use flip test',
                        action='store_true')
    parser.add_argument('--load_best_ckpt',
                        help='Load best checkpoint [default: load final]',
                        action='store_true')
    parser.add_argument('--post-process',
                        help='use post process',
                        action='store_true')
    parser.add_argument('--shift-heatmap',
                        help='shift heatmap',
                        action='store_true')
    parser.add_argument('--force_overwrite',
                        help='Force overwrite',
                        action='store_true')
    parser.add_argument('--vis_output_only',
                        help='Visualize output only; dont save results',
                        action='store_true')
    parser.add_argument('--save_all_data',
                        help='Save all data',
                        action='store_true')
    parser.add_argument('--coco-bbox-file',
                        help='coco detection bbox file',
                        type=str)

    args = parser.parse_args("")

    return args

In [7]:
def reset_config(config, args):
    if args.gpus:
        config.GPUS = args.gpus
    if args.workers:
        config.WORKERS = args.workers
    if args.use_detect_bbox:
        config.TEST.USE_GT_BBOX = not args.use_detect_bbox
    if args.flip_test:
        config.TEST.FLIP_TEST = args.flip_test
    if args.post_process:
        config.TEST.POST_PROCESS = args.post_process
    if args.shift_heatmap:
        config.TEST.SHIFT_HEATMAP = args.shift_heatmap
    if args.model_file:
        config.TEST.MODEL_FILE = args.model_file
    if args.coco_bbox_file:
        config.TEST.COCO_BBOX_FILE = args.coco_bbox_file

In [8]:
args = parse_args()
reset_config(config, args)
config = update_config(config, args.cfg)

  exp_config = edict(yaml.load(f))


## Load Dataset

In [9]:
normalize = transforms.Normalize(
  mean=[0.485, 0.456, 0.406],
  std=[0.229, 0.224, 0.225],
)
valid_dataset = eval('dataset.mpii')(
    config,
    config.DATASET.ROOT,
    config.DATASET.TEST_SET,
    False,
    transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=config.TEST.BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=False
)

/home/michael/anaconda3/envs/ml/lib/python3.8/site-packages/json_tricks/nonp.py:221: JsonTricksDeprecation: `json_tricks.load(s)` stripped some comments, but `ignore_comments` was not passed; in the next major release, the behaviour when `ignore_comments` is not passed will change; it is recommended to explicitly pass `ignore_comments=True` if you want to strip comments; see https://github.com/mverleg/pyjson_tricks/issues/74


In [10]:
for i, (x_data, target, target_weight, meta) in enumerate(valid_loader):
  break

## Load Model

In [11]:
root = f"experiments/mpii"

In [12]:
[print(ele) for ele in np.sort(os.listdir(root)) if ele.endswith(".yaml")];

hourglass_8__td_0__no_shared__distill.yaml
hourglass_8__td_0__no_shared__no_distill.yaml
hourglass_8__td_0__shared__distill.yaml
hourglass_8__td_0__shared__no_distill.yaml
hourglass_8__teacher.yaml


In [35]:
cfg_path = os.path.join(root, "hourglass_8__teacher.yaml")
# cfg_path = os.path.join(root, "hourglass_8__td_0__shared__no_distill.yaml")
# cfg_path = os.path.join(root, "hourglass_8__td_0__no_shared__no_distill.yaml")

In [36]:
args = parse_args(cfg_path)
reset_config(config, args)
config = update_config(config, args.cfg)

  exp_config = edict(yaml.load(f))


In [37]:
model = models.pose_stacked_hg.get_pose_net(config)

In [38]:
n_params = 0
for key, param in model.named_parameters():
  if param.requires_grad:
    n_params += param.numel() 
print(f"n_params: {n_params:,}") 

n_params: 71,529,259


In [39]:
# model

In [None]:
out = model(x_data, log_flops=True)

In [28]:
model.total_flops

[18.172297216,
 36.344594432,
 54.516891648,
 72.689188864,
 90.86148608,
 109.033783296,
 127.206080512,
 145.378377728]

In [216]:
model.total_flops

[91.167379456,
 182.334758912,
 273.502138368,
 364.669517824,
 455.83689728,
 547.004276736,
 638.171656192,
 729.339035648]

In [189]:
out.shape

torch.Size([8, 32, 16, 64, 64])

In [None]:
utils.flops_benchmarker