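- Bulat et al., "Toward fast and accurate human pose estimation via soft-gated skip connections" (FG 2020): https://www.adrianbulat.com/downloads/FG20/fast_human_pose.pdf
- Newell et al., "Stacked Hourglass Networks for Human Pose Estimation" (ECCV 2016), see Figure 4: https://arxiv.org/pdf/1603.06937.pdf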

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline

In [3]:
import os
import sys
import torch
sys.path.append("/home/michael/CascadedPoseEstimation/lib")
from core.config import config
from core.config import update_config

In [None]:
import models.pose_stacked_hg
import torch.nn as nn
from typing import Dict, Iterable, Callable
from utils.utils import create_logger

In [None]:
def count_parameters(model, trainable=False):
  """Return the number of scalar parameters held by `model`.

  Args:
    model: a torch module exposing `.parameters()`.
    trainable: when True, skip parameters whose `requires_grad` is False.

  Returns:
    int: sum of `numel()` over the selected parameters (0 for no parameters).
  """
  total = 0
  for param in model.parameters():
    if trainable and not param.requires_grad:
      continue
    total += param.numel()
  return total

In [None]:
def get_state_dict(output_dir, config, logger, use_best=False, map_location=None):
  """Load a model state dict from disk.

  Checkpoint resolution order:
    1. `config.TEST.MODEL_FILE`, if it is set (non-empty);
    2. `<output_dir>/final_state.pth.tar`, if it exists and `use_best` is False;
    3. `<output_dir>/model_best.pth.tar` otherwise.

  If the loaded object contains a "state_dict" entry (a training checkpoint),
  that entry is returned instead of the whole object. When *every* key starts
  with "module." (the prefix added by wrappers such as nn.DataParallel), the
  prefix is stripped so the dict loads into an unwrapped model.

  Args:
    output_dir: directory searched for checkpoints when MODEL_FILE is unset.
    config: experiment config; only `config.TEST.MODEL_FILE` is read.
    logger: used to report which file is loaded.
    use_best: prefer `model_best.pth.tar` over `final_state.pth.tar`.
    map_location: forwarded to `torch.load` (e.g. "cpu" to load a GPU
      checkpoint on a machine without CUDA). None keeps torch's default.

  Returns:
    The state dict mapping.

  Raises:
    FileNotFoundError: (from `torch.load`) if the resolved checkpoint is missing.
  """
  from collections import OrderedDict

  if config.TEST.MODEL_FILE:
    ckpt_path = config.TEST.MODEL_FILE
  else:
    ckpt_path = os.path.join(output_dir, "final_state.pth.tar")
    if use_best or not os.path.exists(ckpt_path):
      ckpt_path = os.path.join(output_dir, "model_best.pth.tar")

  logger.info('=> loading model from {}'.format(ckpt_path))
  state_dict = torch.load(ckpt_path, map_location=map_location)

  if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]

  # Only strip when all keys carry the prefix, so a model that legitimately has
  # a submodule named "module" somewhere inside is left untouched. Non-prefixed
  # dicts are returned as-is (preserving any metadata torch attached).
  prefix = "module."
  if state_dict and all(k.startswith(prefix) for k in state_dict):
    state_dict = OrderedDict((k[len(prefix):], v) for k, v in state_dict.items())

  return state_dict

In [None]:
def load_model(config, output_dir, logger, load=True, load_best_ckpt=True):
  """Build the stacked-hourglass pose network and optionally load its weights.

  Args:
    config: experiment config passed to `get_pose_net` and `get_state_dict`.
    output_dir: directory searched for checkpoints (see `get_state_dict`).
    logger: logger used to report which checkpoint is loaded.
    load: if False, skip checkpoint loading and return the freshly built network.
    load_best_ckpt: forwarded to `get_state_dict` as `use_best`.

  Returns:
    The network, switched to eval mode.
  """
  model = models.pose_stacked_hg.get_pose_net(config, is_train=False)
  if load:
    state_dict = get_state_dict(output_dir,
                                config,
                                logger,
                                use_best=load_best_ckpt)
    model.load_state_dict(state_dict)
  # nn.Module instances start in training mode. This notebook only runs
  # inference / activation analysis, so switch layers such as BatchNorm or
  # Dropout (if present) to their inference behaviour.
  model.eval()
  return model

In [74]:
# Experiment configs to compare; all live in the same hourglass experiments
# directory (absolute path kept in one place so it is easy to change).
HG_EXPERIMENTS_DIR = "/home/michael/CascadedPoseEstimation/experiments/mpii/hourglass"
cfg_paths = [
  os.path.join(HG_EXPERIMENTS_DIR, name)
  for name in (
    "hourglass_4__td_1__double.yaml",
    "hourglass_4__td_1.yaml",
    "hourglass_4__td_1__shared_weights.yaml",
  )
]

In [84]:
cfg_path = cfg_paths[0]
cfg_path

'/home/michael/CascadedPoseEstimation/experiments/mpii/hourglass/hourglass_4__td_1__double.yaml'

In [85]:
# NOTE(review): this cell ran before `update_config(cfg_path)` (execution count
# 85 vs 91), so its saved output likely reflects an earlier `update_config` call
# in the same kernel session. On a fresh kernel it would show whatever
# `core.config` holds before any YAML is merged; move it after the
# `update_config` cell to inspect the selected experiment's settings.
config["MODEL"]["EXTRA"]

{'DOUBLE_STACK': True,
 'NUM_DOUBLE_CHANNELS': 144,
 'SHARE_HG_WEIGHTS': False,
 'TARGET_TYPE': 'gaussian',
 'SIGMA': 2,
 'HEATMAP_SIZE': array([64, 64]),
 'N_HG_STACKS': 4}

In [91]:
update_config(cfg_path)

In [139]:
# Setup logger
logger, output_dir, tb_log_dir = create_logger(config, cfg_path, 'valid')
output_dir = "../" + output_dir
output_dir

=> creating output/mpii/hourglass_x4__TD_1.0__double
=> creating log/mpii/hourglass_x4/hourglass_4__td_1__double_2021-09-25-16-21


'../output/mpii/hourglass_x4__TD_1.0__double'

In [143]:
model = load_model(config, output_dir, logger, load=False, load_best_ckpt=True)
n_params = count_parameters(model, trainable=True)
print(f"n_params: {n_params:,}")

n_params: 13,180,288


In [144]:
# All-zeros dummy input of shape (1, 3, 256, 256), used by the next cells only
# to check the network's output shape.
X = torch.zeros((1, 3, 256, 256))

In [145]:
# Inference-only forward pass to check the output shape; no autograd graph is needed.
with torch.no_grad():
  out = model(X)

In [146]:
out.shape

torch.Size([8, 1, 16, 64, 64])

# Dataset

In [None]:
import torchvision.transforms as transforms
import dataset
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import pandas as pd
import seaborn as sns

In [None]:
# Data loading code
# Per-channel normalisation with the standard ImageNet RGB mean/std
# (the statistics torchvision's pretrained models use).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

In [None]:
gpus = [0]

# Look the dataset class up by name on the `dataset` package rather than
# eval()-ing a string assembled from the config file.
dataset_cls = getattr(dataset, config.DATASET.DATASET)
valid_dataset = dataset_cls(
    config,
    "../" + config.DATASET.ROOT,
    config.DATASET.TEST_SET,
    False,
    transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=config.TEST.BATCH_SIZE * len(gpus),
    shuffle=False,
    num_workers=config.WORKERS,
    pin_memory=True
)

In [None]:
# Grab the first validation batch. next(iter(...)) raises StopIteration on an
# empty loader instead of silently leaving `X` (and the other names) stale.
X, y, z, h = next(iter(valid_loader))

# Model

In [None]:
# config['MODEL']['EXTRA']['CASCADED_SCHEME'] = 'parallel'

In [None]:
class FeatureExtractor(nn.Module):
  """Capture the input/output of selected submodules via forward hooks.

  On construction, a forward hook is registered on every submodule of `model`
  whose qualified name ends with `layer_suffix` and contains "hg". Calling the
  extractor runs `model(x, t_i)` for t_i = 0..t and records activations only
  from the final call (t_i == t).

  The hooks are removed after the first call (even if the model raises), so an
  instance is single-use: build a new one for every extraction.
  """
  def __init__(self, model, layer_suffix="identity_mapping"):
    super().__init__()
    self.model = model
    self.layer_suffix = layer_suffix
    self._activation = {}
    # Defined before any hook exists so a hook firing outside `forward`
    # (e.g. a direct call to `self.model`) does not raise AttributeError.
    self._log_active = False
    self.setup()

  def _hook_fxn(self, name):
    """Return a forward hook storing the detached first input and output under `name`."""
    # Hook signature: (module, inputs tuple, output).
    def hook(module, inputs, output):
      if self._log_active:
        self._activation[name] = {
            "input": inputs[0].detach(),
            "output": output.detach(),
        }
    return hook

  def setup(self):
    """Register a forward hook on every matching submodule of `self.model`."""
    self.hooks = []
    for name, module in self.model.named_modules():
      if name.endswith(self.layer_suffix) and "hg" in name:
        print(f"Hooking into:\t{name}")
        self.hooks.append(module.register_forward_hook(self._hook_fxn(name)))

  def forward(self, x, t=0):
    """Run the model for timesteps 0..t and return activations from timestep t.

    Returns:
      dict mapping hooked layer name -> {"input": Tensor, "output": Tensor}.
    """
    try:
      # Stored activations are detached anyway, so skip building autograd graphs.
      with torch.no_grad():
        for t_i in range(t + 1):
          self._log_active = t_i == t
          self.model(x, t_i)
    finally:
      # Always detach the hooks, even if the forward pass raised.
      self._log_active = False
      for hook in self.hooks:
        hook.remove()
      self.hooks = []
    return self._activation

In [None]:
def avg_results(alpha_features):
  """Reduce captured activations to one mean value per index of dimension 1.

  Args:
    alpha_features: mapping `layer name -> {"input": Tensor, "output": Tensor}`
      (e.g. what `FeatureExtractor` returns). Every dimension except dim 1
      (presumably channels) is averaged out; for 4-D tensors this is exactly
      `dim=(0, 2, 3)`.

  Returns:
    dict with the same keys, each mapping to `{"input": ndarray, "output": ndarray}`.
  """
  def _mean_keep_dim1(tensor):
    reduce_dims = tuple(d for d in range(tensor.dim()) if d != 1)
    # .cpu() so activations captured on a GPU can still be converted to numpy.
    return tensor.mean(dim=reduce_dims).cpu().numpy()

  return {
    k: {
      "input": _mean_keep_dim1(v["input"]),
      "output": _mean_keep_dim1(v["output"]),
    }
    for k, v in alpha_features.items()
  }

In [None]:
def plot_alpha_features(avg_alpha_features, t=0, binwidth=0.05, xlim=(-1.5, 1.5)):
  """Plot, per layer, overlaid histograms of averaged input vs. output values.

  Args:
    avg_alpha_features: `{layer: {"input": array, "output": array}}`, as
      returned by `avg_results`.
    t: timestep; only used in the subplot titles.
    binwidth: histogram bin width passed to `sns.histplot`.
    xlim: x-axis limits applied to every subplot.

  Draws one vertically stacked subplot per layer; does nothing for an empty mapping.
  """
  n_plots = len(avg_alpha_features)
  if n_plots == 0:
    return
  # squeeze=False keeps `axes` 2-D even for a single layer; otherwise
  # plt.subplots returns a bare Axes that cannot be indexed.
  fig, axes = plt.subplots(n_plots, 1, figsize=(12, 4 * n_plots), squeeze=False)
  for ax_i, (layer_name, layer_vals) in zip(axes[:, 0], avg_alpha_features.items()):
    in_vals = np.ravel(layer_vals["input"])
    out_vals = np.ravel(layer_vals["output"])
    # Long format: one row per value, labelled "in" or "out" for the hue.
    layer_df = pd.DataFrame({
        "key": ["in"] * len(in_vals) + ["out"] * len(out_vals),
        "val": np.concatenate([in_vals, out_vals]),
    })
    sns.histplot(x="val",
                 hue="key",
                 binwidth=binwidth,
                 data=layer_df,
                 stat="probability",
                 ax=ax_i)
    ax_i.set_xlim(xlim)
    ax_i.set_title(f"{layer_name} (t={t})")
  fig.tight_layout()

In [None]:
# skip_conv_feature_extractor = FeatureExtractor(model, layer_suffix="identity_mapping.skip_conv")
# skip_conv_features = skip_conv_feature_extractor(X, t=0)

In [None]:
def check_same(prev_avg_alpha, prev_vals):
  """Report layers whose averaged activations are identical in two snapshots.

  Args:
    prev_avg_alpha: per-layer averages from the previous timestep, shaped like
      `avg_results` output: `{layer: {"input": array, "output": array}}`.
    prev_vals: the averages to compare against (in practice the *current*
      timestep's `avg_results` output). The parameter name is kept for
      backward compatibility with existing callers.

  Returns:
    list of `(layer, key)` pairs whose arrays are equal in both snapshots;
    each match is also printed.
  """
  unchanged = []
  for layer, prev_layer_vals in prev_avg_alpha.items():
    for key, prev_val in prev_layer_vals.items():
      cur_val = prev_vals[layer][key]
      if np.array_equal(cur_val, prev_val):
        print(f"{layer} {key} All same!")
        unchanged.append((layer, key))
  return unchanged

In [None]:
# Collect per-layer averaged activations (input and output of every hooked
# `identity_mapping` layer) at each timestep into a long-format DataFrame.
df_dict = defaultdict(list)
n_timesteps = model.timesteps
prev_avg_alpha = None
for t in range(n_timesteps):
  # Fresh model per timestep: the extractor replays t_i = 0..t from scratch
  # (presumably to reset any state the model keeps across timesteps) and
  # removes its hooks after one call. A loop-local name keeps the
  # notebook-level `model` alive, so this cell can be re-run.
  model_t = load_model(config, output_dir, logger, load_best_ckpt=True)
  print(f"t={t}/{n_timesteps}...")
  alpha_feature_extractor = FeatureExtractor(model_t, layer_suffix="identity_mapping")
  alpha_features = alpha_feature_extractor(X, t=t)
  avg_alpha = avg_results(alpha_features)
  if prev_avg_alpha is not None:
    # Flag layers whose averages did not change from the previous timestep.
    check_same(prev_avg_alpha, avg_alpha)
  prev_avg_alpha = avg_alpha
  # plot_alpha_features(avg_alpha, t=t)
  for layer_key, layer_vals in avg_alpha.items():
    for in_val, out_val in zip(layer_vals["input"], layer_vals["output"]):
      df_dict["in_val"].append(in_val)
      df_dict["out_val"].append(out_val)
      df_dict["layer"].append(layer_key)
      df_dict["t"].append(t)

  # Release the per-timestep model before loading the next one.
  del model_t
  del alpha_feature_extractor
df = pd.DataFrame(df_dict)

In [None]:
# NOTE(review): `prev_val` is only a loop variable inside `check_same`, never a
# notebook-level name, so this cell raises NameError on a fresh kernel.
# Leftover debugging — remove it or replace it with the value you meant to inspect.
prev_val

In [None]:
# NOTE(review): `prev_v` is not assigned anywhere in the cells shown, so this
# cell raises NameError on a fresh kernel. Leftover debugging — remove or replace.
prev_v

In [None]:
# NOTE(review): `v` only exists as a local inside the helper functions above,
# not at notebook level, so this cell raises NameError on a fresh kernel.
# Leftover debugging — remove or replace.
v

In [None]:
avg_alpha

In [None]:
# Select the first layer group (groupby sorts keys by default) for the plots below.
# next(iter(...)) raises StopIteration if `df` is empty instead of silently
# leaving `layer` / `layer_df` stale from an earlier run.
layer, layer_df = next(iter(df.groupby("layer")))

In [None]:
# Distribution of the layer's averaged inputs, one colour per timestep.
ax = sns.histplot(x="in_val", hue="t", binwidth=0.1, alpha=0.3, data=layer_df)
ax.set(title=f"{layer}: per-channel mean input by timestep", xlabel="mean input");

In [None]:
# Distribution of the layer's averaged outputs, one colour per timestep.
ax = sns.histplot(x="out_val", hue="t", binwidth=0.05, alpha=0.3, data=layer_df)
ax.set(title=f"{layer}: per-channel mean output by timestep", xlabel="mean output");

In [None]:
# One histogram per timestep for the selected layer. Each iteration draws on its
# own explicit figure and closes it after display; calling `plt.clf()` after
# `plt.show()` instead can leave a new, empty current figure behind.
for t, t_df in layer_df.groupby("t"):
  fig, ax = plt.subplots()
  sns.histplot(x="out_val", data=t_df, binwidth=0.05, alpha=0.3, ax=ax)
  ax.set_title(f"{layer}: out_val at t={t}")
  plt.show()
  plt.close(fig)

In [None]:
layer_df.head()

In [None]:
# net = pose_resnet.get_pose_net(config, is_train=False)

In [None]:
# `net` was previously only created in the commented-out cell above (via
# `pose_resnet`, which is never imported), so this and the following cells
# failed on a fresh kernel. Build an un-loaded network from the current config.
net = models.pose_stacked_hg.get_pose_net(config, is_train=False)
net.timesteps

In [None]:
n_params = count_parameters(net, trainable=False)
print(f"n_params: {n_params:,}")

In [None]:
# X = torch.zeros((4, 3, 256, 256))

In [None]:
# Random batch whose values lie on the 256-level grid k/255*2-1 (k = 0..255),
# i.e. spanning [-1, 1]. torch.randint's upper bound is exclusive, so it must be
# 256 for k = 255 (value +1) to occur. A seeded generator makes it reproducible.
RANDOM_SEED = 0
_rand_gen = torch.Generator().manual_seed(RANDOM_SEED)
X = torch.randint(0, 256, (4, 3, 256, 256), generator=_rand_gen) / 255 * 2 - 1
X.min(), X.max()

In [None]:
# NOTE(review): hard-coded timestep 3 assumes net.timesteps > 3 — confirm, or
# use `net.timesteps - 1` to target the final timestep explicitly.
o = net(X, 3)

In [None]:
for t in range(net.timesteps):
  print("T: ", t)
  o = net(X, t)
  print("\n")

In [None]:
# Collect the network output at every timestep, calling t = 0, 1, ... in order.
# no_grad: the outputs are only inspected afterwards, and keeping each tensor's
# autograd graph alive in `outs` would hold every timestep's activations in memory.
outs = []
with torch.no_grad():
  for t in range(net.timesteps):
    print(t)
    out = net(X, t=t)
    outs.append(out)

In [None]:
for t, out in enumerate(outs):
  x1 = out.sum()
  print(t, x1)