In [250]:
%load_ext autoreload
%autoreload 2

Could not find coco.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [251]:
%matplotlib inline

In [259]:
import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import pose_estimation._init_paths

from collections import defaultdict
import pandas as pd
pd.set_option('mode.chained_assignment',None)
import glob
import torch
import torchvision.transforms as transforms
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator
from IPython.display import display
import numpy as np
from core.evaluate import accuracy
from matplotlib.ticker import MaxNLocator
import dataset
from core.config import config
# from core.config import update_config
import yaml
from easydict import EasyDict as edict
import utils.flops_benchmarker
import argparse
import models.pose_stacked_hg
from utils.utils import create_experiment_directory
%matplotlib inline

In [260]:

def _update_dict(k, v):
    if k == "DATASET":
        if "MEAN" in v and v["MEAN"]:
            v["MEAN"] = np.array([eval(x) if isinstance(x, str) else x
                                  for x in v["MEAN"]])
        if "STD" in v and v["STD"]:
            v["STD"] = np.array([eval(x) if isinstance(x, str) else x
                                 for x in v["STD"]])
    if k == "MODEL":
        if "EXTRA" in v and "HEATMAP_SIZE" in v["EXTRA"]:
            if isinstance(v["EXTRA"]["HEATMAP_SIZE"], int):
                v["EXTRA"]["HEATMAP_SIZE"] = np.array(
                    [v["EXTRA"]["HEATMAP_SIZE"], v["EXTRA"]["HEATMAP_SIZE"]])
            else:
                v["EXTRA"]["HEATMAP_SIZE"] = np.array(
                    v["EXTRA"]["HEATMAP_SIZE"])
        if "IMAGE_SIZE" in v:
            if isinstance(v["IMAGE_SIZE"], int):
                v["IMAGE_SIZE"] = np.array([v["IMAGE_SIZE"], v["IMAGE_SIZE"]])
            else:
                v["IMAGE_SIZE"] = np.array(v["IMAGE_SIZE"])
    for vk, vv in v.items():
        if vk in config[k]:
            config[k][vk] = vv
        else:
            raise ValueError(f"{k}.{vk} not exist in config.py")

In [261]:
def update_config(config, config_file):
    exp_config = None
    with open(config_file) as f:
        exp_config = edict(yaml.load(f))
        for k, v in exp_config.items():
            if k in config:
                if isinstance(v, dict):
                    _update_dict(k, v)
                else:
                    if k == "SCALES":
                        config[k][0] = (tuple(v))
                    else:
                        config[k] = v
            else:
                raise ValueError(f"{k} not exist in config.py")
    return config

In [265]:
def parse_args(cfg_path=""):
    parser = argparse.ArgumentParser(description='Train keypoints network')
#     # general
#     parser.add_argument('--cfg',
#                         help='experiment configure file name',
#                         required=True,
#                         type=str)
    if cfg_path:
      default_cfg = cfg_path
    else:
      default_cfg = "experiments/mpii/hourglass_8__td_1.yaml"
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default=default_cfg,
                        type=str)

    args, rest = parser.parse_known_args()
    # update config
#     update_config(args.cfg)

    # training
    parser.add_argument('--frequent',
                        help='frequency of logging',
                        default=config.PRINT_FREQ,
                        type=int)
    parser.add_argument('--max_batch_logs',
                        help='Max # of batches to save data from',
                        default=5,
                        type=int)
    parser.add_argument('--gpus',
                        help='gpus',
                        type=str)
    parser.add_argument('--workers',
                        help='num of dataloader workers',
                        type=int)
    parser.add_argument('--model-file',
                        help='model state file',
                        type=str)
    parser.add_argument('--result_root',
                        default="/hdd/mliuzzolino/TDPoseEstimation/results/",
                        help='Root for results',
                        type=str)
    parser.add_argument('--threshold',
                        type=float,
                        default=0.5,
                        help='Accuracy threshold [default=0.5]')
    parser.add_argument('--use-detect-bbox',
                        help='use detect bbox',
                        action='store_true')
    parser.add_argument('--flip-test',
                        help='use flip test',
                        action='store_true')
    parser.add_argument('--load_best_ckpt',
                        help='Load best checkpoint [default: load final]',
                        action='store_true')
    parser.add_argument('--post-process',
                        help='use post process',
                        action='store_true')
    parser.add_argument('--shift-heatmap',
                        help='shift heatmap',
                        action='store_true')
    parser.add_argument('--force_overwrite',
                        help='Force overwrite',
                        action='store_true')
    parser.add_argument('--vis_output_only',
                        help='Visualize output only; dont save results',
                        action='store_true')
    parser.add_argument('--save_all_data',
                        help='Save all data',
                        action='store_true')
    parser.add_argument('--coco-bbox-file',
                        help='coco detection bbox file',
                        type=str)

    args = parser.parse_args("")

    return args

In [266]:
def reset_config(config, args):
    if args.gpus:
        config.GPUS = args.gpus
    if args.workers:
        config.WORKERS = args.workers
    if args.use_detect_bbox:
        config.TEST.USE_GT_BBOX = not args.use_detect_bbox
    if args.flip_test:
        config.TEST.FLIP_TEST = args.flip_test
    if args.post_process:
        config.TEST.POST_PROCESS = args.post_process
    if args.shift_heatmap:
        config.TEST.SHIFT_HEATMAP = args.shift_heatmap
    if args.model_file:
        config.TEST.MODEL_FILE = args.model_file
    if args.coco_bbox_file:
        config.TEST.COCO_BBOX_FILE = args.coco_bbox_file

In [267]:
args = parse_args()
reset_config(config, args)
config = update_config(config, args.cfg)

  exp_config = edict(yaml.load(f))


## Load Dataset

In [268]:
normalize = transforms.Normalize(
  mean=[0.485, 0.456, 0.406],
  std=[0.229, 0.224, 0.225],
)
valid_dataset = eval('dataset.coco')(
    config,
    config.DATASET.ROOT,
    config.DATASET.TEST_SET,
    False,
    transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=False
)

AttributeError: module 'dataset' has no attribute 'coco'

In [269]:
for i, (x_data, target, target_weight, meta) in enumerate(valid_loader):
  break

## Load Model

In [270]:
root = f"experiments/mpii"

In [271]:
[print(ele) for ele in np.sort(os.listdir(root)) if ele.endswith(".yaml")];

hourglass_4__td_0.yaml
hourglass_4__td_0_25.yaml
hourglass_4__td_0_25__distill_td_0.yaml
hourglass_4__td_0_25__distill_td_0_25.yaml
hourglass_4__td_0_25__distill_td_0_5.yaml
hourglass_4__td_0_25__distill_td_0_9.yaml
hourglass_4__td_0_25__distill_td_1.yaml
hourglass_4__td_0_25__double.yaml
hourglass_4__td_0_5.yaml
hourglass_4__td_0_5__distill_td_0.yaml
hourglass_4__td_0_5__distill_td_0_25.yaml
hourglass_4__td_0_5__distill_td_0_5.yaml
hourglass_4__td_0_5__distill_td_0_9.yaml
hourglass_4__td_0_5__distill_td_1.yaml
hourglass_4__td_0_5__double.yaml
hourglass_4__td_0_9.yaml
hourglass_4__td_0_9__distill_td_0.yaml
hourglass_4__td_0_9__distill_td_0_25.yaml
hourglass_4__td_0_9__distill_td_0_5.yaml
hourglass_4__td_0_9__distill_td_0_9.yaml
hourglass_4__td_0_9__distill_td_1.yaml
hourglass_4__td_0_9__double.yaml
hourglass_4__td_0__distill_td_0.yaml
hourglass_4__td_0__distill_td_0_25.yaml
hourglass_4__td_0__distill_td_0_5.yaml
hourglass_4__td_0__distill_td_0_9.yaml
hourglass_4__td_0__distill_td_1.yam

In [288]:
# cfg_path = os.path.join(root, "hourglass_8__teacher.yaml")
# cfg_path = os.path.join(root, "hourglass_8__td_1__no_shared__distill.yaml")
cfg_path = os.path.join(root, "hourglass_8__td_1__distill_td_1_singlehead.yaml")


In [289]:
args = parse_args(cfg_path)
reset_config(config, args)
config = update_config(config, args.cfg)

  exp_config = edict(yaml.load(f))


In [290]:
output_dir = create_experiment_directory(
    config,
    args.cfg,
    distillation=False,
    make_dir=False,
)
print(output_dir)

output/mpii/hourglass_x8__TD_1.0__single_head


In [291]:
model = models.pose_stacked_hg.get_pose_net(config)

Sharing weights!


In [292]:
n_params = 0
for key, param in model.named_parameters():
  if param.requires_grad:
    n_params += param.numel() 
print(f"n_params: {n_params:,}") 

n_params: 660,496


In [247]:
# model

In [248]:
out = model(x_data, log_flops=True)

In [249]:
model.total_flops

[0.518339072,
 1.036678144,
 1.555017216,
 2.073356288,
 2.59169536,
 3.110034432,
 3.628373504,
 4.146712576]

In [222]:
# torch.save(model.total_flops, "student_gflops.pt")

In [223]:
np.cumsum(model.total_flops)

array([  8.18285158,  24.54855475,  49.0971095 ,  81.82851584,
       122.74277376, 171.83988326, 229.11984435, 294.58265702])

In [224]:
model.total_flops

[8.182851584,
 16.365703168,
 24.548554752,
 32.731406336,
 40.91425792,
 49.097109504,
 57.279961088,
 65.462812672]

### Model GPU memory consumption

In [225]:
def format_bytes(size):
    # 2**10 = 1024
    power = 2**10
    n = 0
    power_labels = {0 : '', 1: 'kilo', 2: 'mega', 3: 'giga', 4: 'tera'}
    while size > power:
        size /= power
        n += 1
    return size, power_labels[n]+'bytes'

In [226]:
mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes
format_bytes(mem)

(219.00628280639648, 'megabytes')

14857660

In [None]:
model.total_flops

In [None]:
out.shape

In [None]:
utils.flops_benchmarker