In [None]:
import torch
import torch.nn as nn
import numpy as np
import pickle
import imageio
import cv2
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter
from vectormath import Vector2

import H36M
import MPII
import model
from util import config

In [None]:
# Restore the stacked-hourglass network from its saved parameters.
# `load` returns four values; only the first one (the network) is used here.
hourglass, _, _, _ = model.hourglass.load(
    config.hourglass.parameter_dir,
    config.hourglass.device,
)
# The network is only used for prediction below, so switch it to evaluation mode.
hourglass.eval()

In [None]:
# Human3.6M split to convert. The loop below unpacks four items per batch
# (annotation subset, image, heatmap, action), which is presumably what
# position_only=False enables -- TODO confirm against H36M.Dataset.
h36m_dataset = H36M.Dataset(
    data_dir=config.bilinear.data_dir,
    task=H36M.Task.Train,
    position_only=False,
)

# shuffle=False is required: the predictions are later written back over the
# ground-truth annotation, so sample order must be preserved (this assumes
# the dataset iterates in the same order as the pickled annotation).
data = DataLoader(
    dataset=h36m_dataset,
    batch_size=2 * config.hourglass.batch_size,
    num_workers=config.hourglass.num_workers,
    pin_memory=True,
    shuffle=False,
)

In [None]:
# Joint re-ordering table: entry i is the hourglass (MPII-ordered) channel
# that supplies joint i of the 17-entry output layout.
# One of duplicated values, 9, will be removed at H36M/data.py
from_MPII_to_H36M = torch.tensor(
    [6, 3, 4, 5, 2, 1, 0, 7, 8, 9, 9, 13, 14, 15, 12, 11, 10],
    dtype=torch.long,  # torch.index_select needs an integer index tensor
    device=config.bilinear.device,
)

In [None]:
# Run the hourglass over every batch of `data` and collect the predicted 2D
# joint positions (image-space pixels, re-ordered with `from_MPII_to_H36M`)
# as one NumPy array per batch in `part`.
part = list()
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for subset, image, _heatmap, _action in tqdm(data, desc='SH preprocessing'):
        # Only the crop centre and scale are needed to undo the cropping.
        image = image.to(config.hourglass.device)
        center = subset[H36M.Annotation.Center].to(config.hourglass.device)
        scale = subset[H36M.Annotation.Scale].to(config.hourglass.device)

        # Heatmaps from the last stack in batch-channel-height-width shape.
        output = hourglass(image)[-1]
        n_batch, n_joint, _height, width = output.shape

        # Per joint, locate the strongest response in the flattened heatmap
        # and split the flat index back into (x, y) heatmap coordinates.
        peak = torch.argmax(output.view(n_batch, n_joint, -1), dim=-1)
        pose = torch.stack([
            peak % width,
            peak // width,
        ], dim=-1).float()

        # Heatmap pixels -> image pixels: measure from the heatmap centre,
        # normalise by the heatmap size and stretch to the crop size.
        # Assumes square heatmaps (the original code hard-coded 64); the
        # factor 200 is presumably the MPII reference size that `scale` is
        # expressed in -- TODO confirm.
        pose = pose - width / 2
        pose = center.view(n_batch, 1, 2) + pose / width * scale.view(n_batch, 1, 1) * 200

        # The index tensor lives on the bilinear device, so move there first.
        pose = pose.to(config.bilinear.device)
        pose = torch.index_select(pose, dim=1, index=from_MPII_to_H36M)

        part.append(pose.cpu().numpy())

In [None]:
# Join the per-batch prediction arrays along the first (sample) axis so the
# whole split is held in a single array.
train = np.concatenate(part)

In [None]:
# Take the ground-truth annotation, replace its 2D joint positions with the
# hourglass predictions, and store the result as a separate file.
# A dedicated name is used so the `data` DataLoader stays intact and the
# inference cell can be re-run; `with` guarantees both files are closed (and
# the output flushed) even if pickling fails.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('data/Human3.6M/train_GT.bin', 'rb') as gt_file:
    annotation = pickle.load(gt_file)

annotation[H36M.Annotation.Part] = train

with open('data/Human3.6M/train_SH.bin', 'wb') as sh_file:
    pickle.dump(annotation, sh_file)

In [None]:
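# Run the cells above again with task=H36M.Task.Valid to convert the validation split.
# Before doing so, change the two pickle paths in the previous cell (train_GT.bin / train_SH.bin)
# to the validation files; otherwise train_SH.bin is overwritten with validation predictions.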