In [147]:
from __future__ import print_function, division
import torch

# Utils

In [148]:
%matplotlib inline

import matplotlib.pyplot as plt
import cv2

def show_image(img):
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

# Dataset

In [334]:
from os import listdir

import scipy.io

from torch.utils.data.dataset import Dataset


joint_to_name = {0:'r ankle', 
                 1:'r knee',
                 2:'r hip',
                 3:'l hip', 
                 4:'l knee',
                 5:'l ankle',
                 6:'pelvis', 
                 7:'thorax',
                 8:'upper neck',
                 9:'head top',
                 10:'r wrist',
                 11:'r elbow',
                 12:'r shoulder',
                 13:'l shoulder',
                 14:'l elbow',
                 15:'l wrist'}

def onePersonToDict(p, joint_id_to_name):
    res = {}
    if hasattr(p, 'x1'):
        res['head_rect'] = (p.x1, p.y1, p.x2, p.y2)
    if hasattr(p, 'scale'):
        res['scale'] = p.scale
    if hasattr(p, 'objpos') and hasattr(p.objpos, 'x'):
        res['position'] = (p.objpos.x, p.objpos.y)

    try:
        for joint in p.annopoints.point if hasattr(p, 'annopoints') and hasattr(p.annopoints, 'point') else []:
            if not isinstance(joint.is_visible, np.ndarray):
                res[joint_id_to_name[joint.id]] = (joint.x, joint.y, True if joint.is_visible == u'1' else False)
            else:
                res[joint_id_to_name[joint.id]] = (joint.x, joint.y, True)
    except TypeError:
        pass
        
    return res

def transformMatlabToList(matlab):
    img_anno_transformed = []
    for i in range(matlab.shape[0]):
        img_anno = matlab[i]
        people_anno = []
        if isinstance(img_anno.annorect, np.ndarray):
            for j in range(img_anno.annorect.shape[0]):
                people_anno += [onePersonToDict(img_anno.annorect[j], joint_to_name)]
        else:
            people_anno = [onePersonToDict(img_anno.annorect, joint_to_name)]
        img_anno_transformed += [(img_anno.image.name, people_anno)]
        
    return img_anno_transformed


class MpiiHumanPoseDataset(Dataset):
    def __init__(self, imgs_path, annotations_path):
        self.img_base_path = imgs_path
        self.images_names = listdir(imgs_path)
        self.images_num = len(self.images_names)
        self.annotations = \
            transformMatlabToList(scipy.io.loadmat(annotations_path, struct_as_record=False, squeeze_me=True)["RELEASE"].annolist)
        
        self.file_name_to_anno_id = {}
        for i in range(len(self.annotations)):
            self.file_name_to_anno_id[self.annotations[i][0]] = i
        
    def __len__(self):
        return self.images_num
    
    def __getitem__(self, index):
        img_path = self.img_base_path + '/' + self.images_names[index]
        img = torch.IntTensor(cv2.imread(img_path))
        anno_data = self.annotations[self.file_name_to_anno_id[self.images_names[index]]]
        return img, anno_data

# Model

In [335]:
import torch.nn as nn

class PoseFittingNetwork(nn.Module):
    def __init__(self):
        super(PoseFittingNetwork, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    
    def forward(self, input):
        return input

# Training

In [336]:
def train(network, optimizer, data_loader, num_epoch, loss_fun, log_interval):
    for epoch in range(num_epoch):
        for batch_idx, data in enumerate(data_loader):
            img, annotations = data
            
            optimizer.zero_grad()
            output = network(img)
            loss = loss_fun(output, annotations)
            loss.backward()
            optimizer.step()
            
            if batch_idx % log_interval == 0:
                print(batch_idx)    

# Main

In [337]:
from  torch.utils.data import DataLoader
import torch.optim as optim

# GLOBALS #
data_loader_globals = { 'batch_size':1, 
           'shuffle':True, 
           'num_workers':1, 
           'pin_memory':False}
###########

data_set = MpiiHumanPoseDataset('images', 'mpii_human_pose_v1_u12_2/mpii_human_pose_v1_u12_1.mat')
data_loader = DataLoader(data_set, **data_loader_globals)
net = PoseFittingNetwork()
optimizer = optim.SGD(net.parameters(), lr=0.01)

train(net, optimizer, data_loader, 1, lambda x, y: torch.IntTensor([0]), 10)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn