In [1]:
import torch
import torch.nn as nn
import numpy as np
import pickle
import imageio
import cv2
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter
from vectormath import Vector2

import H36M
import MPII
import model
from util import config
from util.log import get_logger

In [2]:
# get_logger yields the logger, the run's log directory (used below to locate the
# hourglass checkpoint) and the comment tag (echoed in the run banner).
# NOTE(review): presumably comment='old' selects which run directory is used — confirm in util.log.
logger, log_dir, comment = get_logger(
    comment='old',
)

In [15]:
# Dataset split to convert. The *_SH.bin cells below write train_SH.bin or
# valid_SH.bin, so this must match the split being saved.
task = H36M.Task.Valid

In [4]:
# Restore the hourglass network from this run's `parameter` directory. Judging by
# the returned names it also yields the optimizer, step counter and epoch; only
# `hourglass` and `train_epoch` are used further down.
hourglass, optimizer, step, train_epoch = model.hourglass.load(
    parameter_dir='{}/parameter'.format(log_dir),
    device=config.hourglass.device,
)

In [5]:
# Banner describing this conversion run: checkpoint tag, its epoch and the split.
logger.info('===========================================================')
logger.info('Convert from GT to SH' + '                                 ')
logger.info('    -parameter: ' + str(comment) + '                             ')
logger.info('    -epoch: ' + str(train_epoch) + '                       ')
# str() keeps this working even if H36M.Task members are not plain strings.
logger.info('    -task: ' + str(task) + '                                    ')
logger.info('===========================================================')

[INFO|<ipython-input-5-ef50c8b5b93e>:2] 2018-09-10 20:36:11,051 > Convert from GT to SH                                 
[INFO|<ipython-input-5-ef50c8b5b93e>:3] 2018-09-10 20:36:11,052 >     -paramter: old                             
[INFO|<ipython-input-5-ef50c8b5b93e>:4] 2018-09-10 20:36:11,053 >     -epoch: 178                       
[INFO|<ipython-input-5-ef50c8b5b93e>:5] 2018-09-10 20:36:11,055 >     -task: train                                    


In [16]:
# Release whatever an earlier run bound to `data` (a DataLoader or the pickled
# annotation dict) before the loader is rebuilt. A bare `del data` raises
# NameError on a fresh kernel, so remove the name only if it exists.
globals().pop('data', None)

In [17]:
# Unshuffled loader: predictions are later written back into the annotation
# file in loader order, so presumably they must stay aligned sample-for-sample.
data = DataLoader(
    dataset=H36M.Dataset(
        data_dir=config.bilinear.data_dir,
        task=task,
        position_only=False,
    ),
    batch_size=2 * config.hourglass.batch_size,
    num_workers=config.hourglass.num_workers,
    pin_memory=True,
    shuffle=False,
)

In [8]:
# Index map that reorders the 16 MPII joints predicted by the hourglass into the
# 17-slot H36M order. Index 9 appears twice to fill 17 slots; one of the
# duplicated values, 9, will be removed at H36M/data.py.
# Built directly as an int64 tensor on the target device (no float round-trip).
from_MPII_to_H36M = torch.tensor(
    [6, 3, 4, 5, 2, 1, 0, 7, 8, 9, 9, 13, 14, 15, 12, 11, 10],
    dtype=torch.long,
    device=config.bilinear.device,
)

In [18]:
def heatmaps_to_image_coords(heatmaps, center, scale):
    """Decode per-joint heatmaps into (x, y) image coordinates of their arg-max cell.

    Args:
        heatmaps: tensor of shape (batch, joints, height, width).
        center: tensor with 2 values per sample (viewable as (batch, 1, 2)), the
            crop center in image pixels. Assumed ordered (x, y) — TODO confirm
            against H36M.Dataset.
        scale: tensor with 1 value per sample; the crop spans scale * 200 image pixels.

    Returns:
        Float tensor of shape (batch, joints, 2) holding (x, y) per joint.
    """
    n_batch, n_joints, height, width = heatmaps.shape
    flat_index = torch.argmax(heatmaps.reshape(n_batch, n_joints, -1), dim=-1)
    # Row-major flattening: column (x) = index % width, row (y) = index // width.
    coords = torch.stack([flat_index % width, flat_index // width], dim=-1).float()
    # Offset from the heatmap center as a fraction of the heatmap size, scaled to the
    # crop size in image pixels and shifted to the crop center. Derived from the
    # heatmap shape rather than a hard-coded 64x64 resolution.
    size = torch.tensor([width, height], dtype=coords.dtype, device=coords.device)
    offset = (coords - size / 2) / size
    return center.view(n_batch, 1, 2) + offset * scale.view(n_batch, 1, 1) * 200


# Run the stacked hourglass over the whole split (in loader order) and collect the
# decoded 2D joints, reordered to H36M, as one numpy array per batch.
# NOTE(review): assumes `hourglass` is a torch.nn.Module. eval() is a no-op if
# model.hourglass.load already set it. Otherwise any BatchNorm/Dropout layers would
# run in training mode during this inference pass.
hourglass.eval()
part = list()
with torch.no_grad():
    for subset, image, _, _ in tqdm(data, desc='SH preprocessing'):
        image = image.to(config.hourglass.device)
        center = subset[H36M.Annotation.Center].to(config.hourglass.device)
        scale = subset[H36M.Annotation.Scale].to(config.hourglass.device)

        # Heatmaps from the last stack in batch-channel-height-width shape.
        output = hourglass(image)[-1]

        pose = heatmaps_to_image_coords(output, center, scale)
        pose = pose.to(config.bilinear.device)
        pose = torch.index_select(pose, dim=1, index=from_MPII_to_H36M)

        part.append(pose.cpu().numpy())

SH preprocessing: 100%|██████████| 6867/6867 [39:50<00:00,  3.22it/s]


In [11]:
# Join the per-batch prediction arrays into one (samples, joints, 2) array.
train = np.concatenate(part)

In [13]:
# Swap the 2D joints of the training annotation file for the stacked-hourglass
# predictions and save the result as train_SH.bin. File handles are closed
# deterministically, so the dump is fully flushed before this cell finishes.
# NOTE(review): pickle.load can execute arbitrary code — only use on trusted, locally produced files.
with open('data/Human3.6M/train_GT.bin', 'rb') as gt_file:
    data = pickle.load(gt_file)
data[H36M.Annotation.Part] = train
with open('data/Human3.6M/train_SH.bin', 'wb') as sh_file:
    pickle.dump(data, sh_file)

In [19]:
# Join the per-batch prediction arrays into one (samples, joints, 2) array.
valid = np.concatenate(part)

In [20]:
# Invariant: one (x, y) prediction per slot of the MPII->H36M index map, for every sample.
assert valid.ndim == 3 and valid.shape[1:] == (len(from_MPII_to_H36M), 2), valid.shape
valid.shape

(109867, 17, 2)

In [None]:
# Swap the 2D joints of the validation annotation file for the stacked-hourglass
# predictions and save the result as valid_SH.bin. File handles are closed
# deterministically, so the dump is fully flushed before this cell finishes.
# NOTE(review): pickle.load can execute arbitrary code — only use on trusted, locally produced files.
with open('data/Human3.6M/valid_GT.bin', 'rb') as gt_file:
    data = pickle.load(gt_file)
data[H36M.Annotation.Part] = valid
with open('data/Human3.6M/valid_SH.bin', 'wb') as sh_file:
    pickle.dump(data, sh_file)

In [14]:
# Report the written file, then close the run banner.
logger.info('Saved to %s             ' % 'data/Human3.6M/train_SH.bin')
logger.info('===========================================================')

[INFO|<ipython-input-14-2028c9167912>:1] 2018-09-11 00:52:45,178 > Saved to data/Human3.6M/train_SH.bin             


In [None]:
# Re-run the cells above with task = H36M.Task.Valid to produce the validation predictions (valid_SH.bin).