In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import numpy as np
import glob
import torch
import yaml

import _init_paths
import models.pose_resnet
import models.unet
import models.cascaded_pose_resnet

this_dir:  /home/michael/CascadedPoseEstimation/pose_estimation
lib_path:  /home/michael/CascadedPoseEstimation/lib


In [3]:
roots = glob.glob(f"../output/mpii/*resnet_18*")
for root in roots:
  final_path = os.path.join(root, "final_state.pth.tar")
  best_path = os.path.join(root, "model_best.pth.tar")
  break

In [4]:
root

'../output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel'

In [5]:
config_paths = glob.glob(f"../experiments/mpii/resnet18/*")
config_paths

['../experiments/mpii/resnet18/cascaded__td_0_parallel.yaml',
 '../experiments/mpii/resnet18/cascaded__td_1_serial.yaml',
 '../experiments/mpii/resnet18/cascaded__td_0_5_parallel.yaml',
 '../experiments/mpii/resnet18/baseline.yaml',
 '../experiments/mpii/resnet18/cascaded__td_1_parallel.yaml']

In [6]:
config_path = config_paths[-1]
config_path

'../experiments/mpii/resnet18/cascaded__td_1_parallel.yaml'

In [7]:
class dotdict(dict):
    """
    a dictionary that supports dot notation 
    as well as dictionary access notation 
    usage: d = DotDict() or d = DotDict({'val1':'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, dct):
        for key, value in dct.items():
            if hasattr(value, 'keys'):
                value = dotdict(value)
            self[key] = value
    
with open(config_path, "r") as infile:
  config = dotdict(yaml.load(infile))



In [8]:
final = torch.load(final_path)
best = torch.load(best_path)

In [9]:
# Setup model
if config.MODEL.NAME == "pose_resnet":
  if config.MODEL.CASCADED:
    model = models.cascaded_pose_resnet.get_pose_net(config, is_train=True)
  else:
    model = models.pose_resnet.get_pose_net(config, is_train=True)
elif config.MODEL.NAME == "unet":
    model = models.unet.get_pose_net(config, is_train=True)

if config.MODEL.CASCADED:
    config.MODEL.N_TIMESTEPS = model.timesteps

In [15]:
from collections import OrderedDict

In [18]:
new_state_dict = OrderedDict()
for k, v in best.items():
  k = k.replace("module.", "")
  new_state_dict[k] = v

In [19]:
model.load_state_dict(new_state_dict)

<All keys matched successfully>

AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'

In [17]:
final.keys()

odict_keys(['layer0.conv1.weight', 'layer0.bn1.weight', 'layer0.bn1.bias', 'layer0.bn1.running_mean', 'layer0.bn1.running_var', 'layer0.bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.weight', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.num_batches_tracked', 'layer2.0.downsample.0.weight', 'layer2.0.conv1.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.num_batches_tracke

In [12]:
glob.glob(f"{root}/*")

['output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/model_best.pth.tar',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/cascaded__td_1_2021-08-18-15-29_valid.log',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/pred.mat',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/cascaded__td_1_2021-08-18-12-12_valid.log',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/checkpoint.pth.tar',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/cascaded__td_1_2021-08-16-19-31_train.log',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/pose_resnet.py',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/final_state.pth.tar',
 'output/mpii/pose_resnet_18__cascaded_td(1.0)__parallel/cascaded__td_1_parallel_2021-08-23-10-46_valid.log']