# CPSC 532R/533R Visual AI - Assignment 2

This Jupyter notebook downloads the data and provides a PyTorch dataset of egocentric images with corresponding 2D poses, a pre-defined neural network, and plotting utility functions. We also provide training code for regressing 2D pose directly from the image. All modules should integrate seamlessly into your Assignment 1 solution, as they use dictionaries for storing the input images and output labels. You need to extend this notebook or your Assignment 1 notebook with the tasks described in Assignment2.pdf.

In [None]:
# download dataset from the web (400 MB file from https://www.cs.ubc.ca/~rhodin/20_CPSC_532R_533R/assignments/EgoCap_nth10.hdf5)
# The file is only fetched when it is not already present in the working directory.
file_name = "EgoCap_nth10.hdf5"
import os.path
import urllib.request # a plain 'import urllib' does not guarantee that the 'request' submodule is loaded
if not os.path.exists(file_name):
    print("Downloading dataset, might take a while... its 400 MB")
    # download under a temporary name first, so that an interrupted transfer
    # is not mistaken for the complete dataset on the next run
    partial_name = file_name + ".part"
    urllib.request.urlretrieve("https://www.cs.ubc.ca/~rhodin/20_CPSC_532R_533R/assignments/"+file_name, partial_name)
    os.replace(partial_name, file_name)
    print("Done downloading")
else:
    print("Dataset already present, nothing to be done")

In [None]:
# utility dictionary that can move tensor values between devices via the 'to(device)' function
from collections import OrderedDict 
class DeviceDict(dict):
    """Dictionary whose tensor values can be moved between devices in one call.

    Behaves like a regular dict; the additional 'to(device)' method mirrors
    the tensor method of the same name.
    """
    # following https://stackoverflow.com/questions/3387691/how-to-perfectly-override-a-dict
    def __init__(self, *args, **kwargs):
        # forward keyword arguments as well, so that DeviceDict(a=1) works like dict(a=1)
        super().__init__(*args, **kwargs)

    def to(self, device):
        """Returns a shallow copy with every tensor value moved to 'device'.

        Args:
            device: anything accepted by torch.Tensor.to, e.g. 'cpu' or 'cuda'
        Returns:
            DeviceDict: new dictionary; non-tensor values are shared with the original
        """
        return DeviceDict(
            (key, value.to(device) if torch.is_tensor(value) else value)
            for key, value in self.items()
        )

In [None]:
# Definition of the EgoCap dataset (small version)
import torch
import torchvision
import torchvision.transforms as transforms
import h5py
import os

class EgoCapDataset(torch.utils.data.Dataset):
    """EgoCap dataset (small version) that is loaded completely into memory.

    Each item is a DeviceDict with the keys 'img' (normalized image),
    'pose_2d' and 'pose_3d'.
    """
    def __init__(self, data_folder):
        """Reads 'EgoCap_nth10.hdf5' and sets up the image normalization.

        Args:
            data_folder (str): folder that contains 'EgoCap_nth10.hdf5'. If the file
                is not found there, the working directory is used instead.
        """
        # the one-argument form super(EgoCapDataset) never calls the parent initializer
        super().__init__()
        file_name = 'EgoCap_nth10.hdf5'
        data_file = os.path.join(data_folder or '', file_name)
        if not os.path.exists(data_file):
            # data_folder used to be ignored and the file was always read from the
            # working directory; keep that location as a fallback
            data_file = file_name
        print("Loading dataset to memory, can take some seconds")
        with h5py.File(data_file, 'r') as hf:
            self.poses_2d = torch.from_numpy(hf['pose_2d'][...])
            self.poses_3d = torch.from_numpy(hf['pose_3d'][...])
            self.imgs  = torch.from_numpy(hf['img'][...])
        print(".. done loading")
        # per-channel statistics commonly used for ImageNet-pretrained networks
        self.mean, self.std = torch.FloatTensor([0.485, 0.456, 0.406]), torch.FloatTensor([0.229, 0.224, 0.225])
        self.normalize = transforms.Normalize(self.mean, self.std)
        # inverse of 'normalize': first undo the division by std, then undo the mean subtraction
        self.denormalize = transforms.Compose([transforms.Normalize(mean = [ 0., 0., 0. ], std = 1/self.std),
                                               transforms.Normalize(mean = -self.mean, std = [ 1., 1., 1. ])])

    def __len__(self):
        """Returns the number of samples (first dimension of the 2D pose tensor)."""
        return self.poses_2d.shape[0]

    def __getitem__(self, idx):
        """Returns sample 'idx' as a DeviceDict with keys 'img', 'pose_2d' and 'pose_3d'."""
        # presumably the stored images hold values in 0..255 - they are divided by 255 before normalization
        sample = DeviceDict(
                  {'img': self.normalize(self.imgs[idx].float()/255),
                  'pose_2d': self.poses_2d[idx],
                  'pose_3d': self.poses_3d[idx]})
        return sample

In [None]:
# skeleton pose definition
# Each label is a 2D coordinate vector per joint, zero-based starting from the top-left pixel, documented as (y, x).
# NOTE(review): plot_skeleton and heatmap2pose in this file treat index 0 as x and index 1 as y - confirm the (y, x) order against the data.
# The joints appear in the following order:
joint_names = [
    'head', 'neck',
    'left-shoulder', 'left-elbow', 'left-wrist', 'left-finger',
    'right-shoulder', 'right-elbow', 'right-wrist', 'right-finger',
    'left-hip', 'left-knee', 'left-ankle', 'left-toe',
    'right-hip', 'right-knee', 'right-ankle', 'right-toe',
]
# the skeleton is defined as a set of bones (pairs of skeleton joint names):
bones_ego_str = [
    ('head', 'neck'),
    # left arm
    ('neck', 'left-shoulder'), ('left-shoulder', 'left-elbow'), ('left-elbow', 'left-wrist'), ('left-wrist', 'left-finger'),
    # right arm
    ('neck', 'right-shoulder'), ('right-shoulder', 'right-elbow'), ('right-elbow', 'right-wrist'), ('right-wrist', 'right-finger'),
    # left leg
    ('left-shoulder', 'left-hip'), ('left-hip', 'left-knee'), ('left-knee', 'left-ankle'), ('left-ankle', 'left-toe'),
    # right leg
    ('right-shoulder', 'right-hip'), ('right-hip', 'right-knee'), ('right-knee', 'right-ankle'), ('right-ankle', 'right-toe'),
    # connections across the torso
    ('right-shoulder', 'left-shoulder'), ('right-hip', 'left-hip'),
]
# the same bones, expressed as pairs of indices into joint_names
bones_ego_idx = [tuple(joint_names.index(name) for name in bone) for bone in bones_ego_str]

In [None]:
# plotting utility functions
import matplotlib.pyplot as plt

def plot_skeleton(ax, pose_2d, bones=bones_ego_idx, linewidth=2, linestyle='-'):
    r"""Plots skeleton pose on a matplotlib axis.

        Args:
            ax (Axis): plt axis to plot
            pose_2d (FloatTensor): tensor of keypoints, of shape K x 2; column 0 is plotted as x, column 1 as y
            bones (list): list of tuples, each tuple defining the keypoint indices to be connected by a bone
            linewidth (number): line width passed on to ax.plot
            linestyle (str): format string passed on to ax.plot, e.g. '-' or '--'
        Returns:
            None
    """
    cmap = plt.get_cmap('hsv')
    num_colors = len(joint_names)
    for bone in bones:
        color = cmap(bone[1] * cmap.N // num_colors) # color according to second joint index
        joint_pair = list(bone) # select both end points of the bone at once
        ax.plot(pose_2d[joint_pair,0], pose_2d[joint_pair,1], linestyle, color=color, linewidth=linewidth)

def plotPoseOnImage(poses, img, ax=plt):
    r"""Plots list of skeleton poses and image.

        Args:
            poses (list): list (or tuple) of pose tensors to be plotted; a single pose tensor is accepted as well.
                          Poses are multiplied with the image size before plotting.
            img (tensor): image to display, converted with torchvision's ToPILImage
            ax (Axis): plt axis to plot, defaults to the pyplot module itself
        Returns:
            None
    """
    img_pil = torchvision.transforms.ToPILImage()(img)
    img_size = torch.FloatTensor(img_pil.size) # PIL reports the size as (width, height)
    # isinstance instead of an exact type comparison, so that tuples and list subclasses are handled as sequences of poses
    if not isinstance(poses, (list, tuple)):
        poses = [poses]
    linestyles = ['-', '--', '-.', ':']
    for i, p in enumerate(poses):
        pose_px = p*img_size
        # cycle through the line styles to tell the poses apart
        plot_skeleton(ax, pose_px, linestyle=linestyles[i%len(linestyles)])
    ax.imshow(img_pil)

def heatmap2image(heatmap):
    r"""Converts a multi channel heatmap to an RGB color representation for display.

        Args:
            heatmap (tensor): of size C X H x W
        Returns:
            image (tensor): of size 3 X H x W, each color channel divided by its maximum.
                            A color channel without any positive value is returned as zeros.
    """
    C,H,W = heatmap.shape
    cmap = plt.get_cmap('hsv')
    img = torch.zeros(3,H,W).to(heatmap.device)
    for i in range(C):
        # one hue per heatmap channel, shaped 3 x 1 x 1 to broadcast over the map
        color = torch.FloatTensor(cmap(i * cmap.N // C)[:3]).reshape([-1,1,1]).to(heatmap.device)
        img = torch.max(img, color * heatmap[i]) # max in case of overlapping position of joints
    # heatmap and probability maps might have small maximum value. Normalize per channel to make each of them visible
    img_max, _ = torch.max(img,dim=-1,keepdim=True)
    img_max, _ = torch.max(img_max,dim=-2,keepdim=True)
    # img never drops below zero, so a maximum of zero means the whole channel is zero; divide by one to avoid 0/0 = NaN
    img_max = torch.where(img_max > 0, img_max, torch.ones_like(img_max))
    return img/img_max

In [None]:
# setting up the dataset and train/val splits
path='./'
ecds = EgoCapDataset(data_folder=path)

# the first 20% of the samples form the validation set, all remaining samples the training set
val_ratio = 0.2
val_size = int(len(ecds)*val_ratio)
indices_val = list(range(val_size))
indices_train = list(range(val_size, len(ecds)))

train_set = torch.utils.data.Subset(ecds, indices_train)
val_set   = torch.utils.data.Subset(ecds, indices_val)

In [None]:
# playing with data and plotting functions
sample_train = train_set[100]
sample_val = val_set[100]
# show one training and one validation image with its 2D pose drawn on top
for sample in (sample_train, sample_val):
    plotPoseOnImage(sample['pose_2d'], ecds.denormalize(sample['img']))
    plt.show()
# report the sizes of the splits and the shapes of one sample
for label, value in [('dataset length', len(ecds)),
                     ('train_set length', len(train_set)),
                     ('val_set length', len(val_set)),
                     ('pose shape', sample_train['pose_2d'].shape),
                     ('img shape', sample_train['img'].shape)]:
    print(label, value)

In [None]:
# define the dataset loader (batch size, shuffling, ...)
def collate_fn_device(batch):
    # collate_fn_device is necessary to preserve our custom dictionary during the collection of samples fetched from the dataset into a Tensor batch.
    collated = torch.utils.data.dataloader.default_collate(batch)
    return DeviceDict(collated)
# Hopefully, one day, pytorch might change the default collate to retain the mapping type. Currently all Mapping objects are converted to dict. Anyone wants to create a pull request? Would need to be changed in
# pytorch/torch/utils/data/_utils/collate.py :     elif isinstance(data, container_abcs.Mapping): return {key: default_convert(data[key]) for key in data}
# pytorch/torch/utils/data/_utils/pin_memory.py : if isinstance(data, container_abcs.Mapping): return {k: pin_memory(sample) for k, sample in data.items()}
# Note, setting pin_memory=False to avoid the pin_memory call
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=2,
    num_workers=0,
    pin_memory=False,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn_device,
)
# the validation loader differs only in the subset and in keeping the sample order
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=2,
    num_workers=0,
    pin_memory=False,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_fn_device,
)

## Regression-based pose inference

We provide a baseline method that regresses 2D pose straight from the image. Make sure that it runs on your hardware and configuration.

In [None]:
# define a regression network that works on dictionaries
class RegressionNet(torch.nn.Module):
    """ResNet-50 that maps the 'img' entry of a dictionary to a 2D pose.

    The network has num_joints*2 outputs, which are reshaped to
    N x num_joints x 2 and returned under the key 'pose_2d'.
    """
    def __init__(self, num_joints):
        super().__init__()
        self.num_joints = num_joints
        # the classification head is sized to hold two values per joint
        self.net = torchvision.models.resnet50(num_classes=num_joints*2)

    def forward(self, dictionary):
        flat_pose = self.net(dictionary['img'])
        pose = flat_pose.reshape(-1, self.num_joints, 2)
        return DeviceDict({'pose_2d': pose})
# one network output pair per entry of joint_names; the network is moved to the GPU
num_joints = len(joint_names)
regression_network = RegressionNet(num_joints=num_joints)
regression_network = regression_network.cuda()

In [None]:
# training loop for regression
%matplotlib inline
from IPython import display
optimizer = torch.optim.Adam(regression_network.parameters(), lr=0.001)
fig=plt.figure(figsize=(20, 5), dpi= 80, facecolor='w', edgecolor='k')
axes=fig.subplots(1,2)
num_epochs = 10
losses = []
for e in range(num_epochs):
  train_iter = iter(train_loader)
  for i in range(len(train_loader)):
      batch_cpu = next(train_iter)
      batch_gpu = batch_cpu.to('cuda')
      pred = regression_network(batch_gpu)
      pred_cpu = pred.to('cpu')

      loss = torch.nn.functional.mse_loss(pred['pose_2d'], batch_gpu['pose_2d'])
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      losses.append(loss.item())

      if i%10==0:
          # clear figures for a new update
          for ax in axes:
            ax.cla()
          # plot the predicted pose and ground truth pose on the image
          plotPoseOnImage([pred_cpu['pose_2d'][0].detach(), batch_cpu['pose_2d'][0]], 
                          ecds.denormalize(batch_cpu['img'][0]), ax=axes[0])
          # plot the training error on a log plot
          axes[1].plot(losses)
          axes[1].set_yscale('log')
          # clear output window and diplay updated figure
          display.clear_output(wait=True)
          display.display(plt.gcf())
          print("Epoch {}, iteration {} of {} ({} %), loss={}".format(e, i, len(train_loader), 100*i//len(train_loader), losses[-1]))
plt.close('all')

## Heatmap-based pose classification

In [None]:
# Detection network that handles dictionaries as input and output
class HeatNetWrapper(torch.nn.Module):
    """Adapts a network to the dictionary interface used in this notebook.

    The wrapped network receives the 'img' entry of the input dictionary;
    the 'out' entry of its result is returned under the key 'heatmap'.
    """
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, dictionary):
        net_output = self.net(dictionary['img'])
        return DeviceDict({'heatmap': net_output['out']})
# DeepLabv3 with one output channel per entry of joint_names, wrapped for dictionary input/output and moved to the GPU
num_joints = len(joint_names)
det_network = HeatNetWrapper(
    torchvision.models.segmentation.deeplabv3_resnet50(num_classes=num_joints)
).cuda()

In [None]:
# Function that takes an NxKx2 pose vector (N: batch dimension, K: number of keypoints) to create stacks of heatmaps that have Gaussian distribution with the mean at the keypoint and standard deviation equal to 3.
# The second argument specifies the output dimensions of the map. Note that the keypoints are defined in normalized coordinates, ranging from 0..1 irrespectively of the image resolution.
import math

def pose2heatmap(pose_2d, map_size, sigma=3.0):
    r"""Creates a heatmap stack, with each channel having Gaussian form with mean at the pose keypoint locations

        Args:
            pose_2d (tensor): tensor of size N x K x 2, with K the number of keypoints. Keypoint locations store relative keypoint locations, i.e. both x and y coordinates in the range 0..1
            map_size (tuple): height and width of the heatmap to be generated
            sigma (float): standard deviation of the Gaussian, in pixels of the generated map
        Returns:
            heatmap (tensor): tensor of size N x K x H x W, with K the number of keypoints.
                              The Gaussian is not normalized, i.e. its value at the mean is 1.
    """
    # NOTE(review): the peak value of 1 is a choice made here - confirm against Assignment2.pdf whether a normalized Gaussian is expected
    height, width = map_size
    pose = pose_2d if pose_2d.is_floating_point() else pose_2d.float()
    xs = torch.arange(width, dtype=pose.dtype, device=pose.device)
    ys = torch.arange(height, dtype=pose.dtype, device=pose.device)
    # relative -> pixel coordinates, the inverse of the index/resolution convention of heatmap2pose (index 0 is x, index 1 is y)
    center_x = pose[..., 0] * width    # N x K
    center_y = pose[..., 1] * height   # N x K
    # the 2D Gaussian is separable: squared distances are computed per axis and combined by broadcasting
    dist_x_sq = (xs - center_x[..., None]) ** 2   # N x K x W
    dist_y_sq = (ys - center_y[..., None]) ** 2   # N x K x H
    dist_sq = dist_y_sq[..., :, None] + dist_x_sq[..., None, :]   # N x K x H x W
    return torch.exp(-dist_sq / (2 * sigma ** 2))

def heatmap2pose(heatmap):
    r"""Takes a heatmap and returns the location of the maximum value in the heatmap

        Args:
            heatmap (tensor): tensor of size N x K x H x W, with K the number of keypoints
        Returns:
            pose (tensor): tensor of size N x K x 2, the 2D pose for each image in the batch.
                           Index 0 holds x (column index / W), index 1 holds y (row index / H).
    """
    res_y, res_x = heatmap.shape[-2:]
    max_alongx, _ = torch.max(heatmap, dim=-1) # maximum of every row, N x K x H
    max_alongy, _ = torch.max(heatmap, dim=-2) # maximum of every column, N x K x W
    _, max_y_index = torch.max(max_alongx, dim=-1) # row that contains the overall maximum
    _, max_x_index = torch.max(max_alongy, dim=-1) # column that contains the overall maximum
    # convert the integer indices to float explicitly, so that the division is never an integer division
    return torch.stack([max_x_index.float()/res_x, max_y_index.float()/res_y],dim=-1)

In [None]:
# training loop for heatmap prediction
%matplotlib inline
from IPython import display
optimizer = torch.optim.Adam(det_network.parameters(), lr=0.001)
fig=plt.figure(figsize=(20, 5), dpi= 80, facecolor='w', edgecolor='k')
axes=fig.subplots(1,4)
losses = []
num_epochs = 10
for e in range(num_epochs):
  train_iter = iter(train_loader)
  for i in range(len(train_loader)):
      batch_cpu = next(train_iter)
      batch_gpu = batch_cpu.to('cuda')
      pred_gpu = det_network(batch_gpu)
      pred_cpu = pred_gpu.to('cpu')
      
      # convert between representations
      img_shape = batch_gpu['img'].shape
      gt_heatmap_gpu = pose2heatmap(batch_gpu['pose_2d'], img_shape[-2:])
      pred_pose = heatmap2pose(pred_cpu['heatmap']).cpu() # note, not differentiable
      gt_pose_max = heatmap2pose(gt_heatmap_gpu).cpu()

      # optimize network
      loss = torch.nn.functional.mse_loss(pred_gpu['heatmap'], gt_heatmap_gpu)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      losses.append(loss.item())
        
      # display progress
      if i%10==0:
          # clear figure for a new update
          for ax in axes: 
              ax.cla()
          # plot the ground truth and the predicted pose on top of the image
          plotPoseOnImage([pred_pose[0], batch_cpu['pose_2d'][0]], ecds.denormalize(batch_cpu['img'][0]), ax=axes[0])
          # plot the predicted heatmap map and the predicted pose on top
          plotPoseOnImage([pred_pose[0]], heatmap2image(pred_cpu['heatmap'][0]), ax=axes[1])
          # plot the reference heatmap map and the GT pose on top
          plotPoseOnImage([gt_pose_max[0]], heatmap2image(gt_heatmap_gpu[0].cpu()), ax=axes[2])
          # plot the current training error on a logplot
          axes[3].plot(losses); axes[3].set_yscale('log')
          # clear output window and diplay updated figure
          display.clear_output(wait=True)
          display.display(plt.gcf())
          print("Epoch {}, iteration {} of {} ({} %), loss={}".format(e, i, len(train_loader), 100*i//len(train_loader), losses[-1]))
plt.close('all')

## Heatmap-based pose regression

In [None]:
def integral_heatmap_layer(dict):
    r"""Turns raw heatmaps into probability maps and their expected 2D position (differentiable).

        Args:
            dict (dictionary): must contain the key 'heatmap' with a tensor of size N x K x H x W
        Returns:
            DeviceDict: 'probabilitymap' (N x K x H x W, every map sums to one) and
                        'pose_2d' (N x K x 2, index 0 is x and index 1 is y, in relative coordinates)
    """
    # NOTE(review): softmax normalization and the index/resolution coordinate convention (as in heatmap2pose)
    # are choices made here - confirm against Assignment2.pdf
    heatmap = dict['heatmap']
    N, K, H, W = heatmap.shape

    # normalize every map to a probability distribution over all of its pixels
    h_norm = torch.softmax(heatmap.reshape(N, K, H*W), dim=-1).reshape(N, K, H, W)

    # compute coordinate matrix (separable, so one vector per axis suffices)
    coords_x = torch.arange(W, dtype=heatmap.dtype, device=heatmap.device) / W
    coords_y = torch.arange(H, dtype=heatmap.dtype, device=heatmap.device) / H

    # expected position = sum over pixels of probability * coordinate, computed via the marginals
    expected_x = (h_norm.sum(dim=-2) * coords_x).sum(dim=-1) # N x K
    expected_y = (h_norm.sum(dim=-1) * coords_y).sum(dim=-1) # N x K
    pose = torch.stack([expected_x, expected_y], dim=-1)

    return DeviceDict({'probabilitymap': h_norm, 'pose_2d': pose})

In [None]:
# a separate DeepLabv3 instance for the integral approach, wrapped for dictionary input/output and moved to the GPU
int_network = HeatNetWrapper(
    torchvision.models.segmentation.deeplabv3_resnet50(num_classes=num_joints)
).cuda()

In [None]:
%matplotlib inline
from IPython import display
import time
optimizer = torch.optim.Adam(int_network.parameters(), lr=0.001)
fig=plt.figure(figsize=(20, 5), dpi= 80, facecolor='w', edgecolor='k')
axes=fig.subplots(1,3)
losses = []
num_epochs = 100
for e in range(num_epochs):
  train_iter = iter(train_loader)
  for i in range(len(train_loader)):
      batch_cpu = next(train_iter)
      batch_gpu = batch_cpu.to('cuda')
      pred_raw = int_network(batch_gpu)
      pred_integral = integral_heatmap_layer(pred_raw) # note, this function must be differentiable

      # optimize network
      loss = torch.nn.functional.mse_loss(pred_integral['pose_2d'], batch_gpu['pose_2d'])
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      losses.append(loss.item())
        
      # plot progress
      if i%10==0:
          # clear figures for a new update
          for ax in axes: 
              ax.cla()
          pred_cpu = pred_integral.to('cpu')
          # plot the ground truth and the predicted pose on top of the image
          plotPoseOnImage([pred_cpu['pose_2d'][0].detach(), batch_cpu['pose_2d'][0]], ecds.denormalize(batch_cpu['img'][0]), ax=axes[0])
          # plot the predicted probability map and the predicted pose on top
          plotPoseOnImage([pred_cpu['pose_2d'][0].detach()], heatmap2image(pred_cpu['probabilitymap'][0]).detach(), ax=axes[1])
          # plot the current training error on a logplot
          axes[2].plot(losses)
          axes[2].set_yscale('log')
          # clear output window and diplay updated figure
          display.clear_output(wait=True)
          display.display(plt.gcf())
          print("Epoch {}, iteration {} of {} ({} %), loss={}".format(e, i, len(train_loader), 100*i//len(train_loader), losses[-1]))
plt.close('all')

In [None]:
# TODO: Task III, validation, which approach is the best?