# 6 DoF Pose from Semantic Keypoints #

The following primarily comes from https://github.com/vaishak2future/posefromkeypoints.git

The code is an implementation of https://www.seas.upenn.edu/~pavlakos/projects/object3d/
Find the associated tutorial at: https://medium.com/@vaishakvk/geometric-deep-learning-for-pose-estimation-6af45da05922

This is an adaptation of code used in Prof. Kostas Daniilidis' course.

### Imports
Change to the workspace (ws) directory. This must be run within the Docker container.

In [None]:
%cd ~/mines_ws

from __future__ import division
from __future__ import print_function

import torch
import json, random
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from src.keypoints_detection.hourglass import hg
from tqdm import tqdm
from PIL import Image, ImageFilter, ImageEnhance
import os
#from os import join
from time import time
import numpy as np
import cv2 

# Make sure the ./models directory exists before any checkpoint is written into it.
os.makedirs('./models', exist_ok=True)

## Keypoint Training ##

The following code is what trains the keypoint detection model

#### Dataset
The cell below defines the dataset used for this model

In [None]:
class Dataset(Dataset):
    """Keypoint dataset driven by a JSON index of frames.

    The index file (``frame_data_train.json`` or ``frame_data_test.json`` inside
    ``dataset_dir``) maps each image name to a record that provides at least:

    - ``"img_dir"``: directory containing an ``Images/<image_name>.png`` file,
    - ``"bbox"``: bounding box handed to the transforms as ``'bb'``,
    - ``"keypoints"``: keypoint list, converted to a numpy array per item.
    """

    def __init__(self, dataset_dir:str, is_train=True, transform=None):
        """Initialize Dataset

        Args:
            dataset_dir (str): path to the dataset.
            is_train (bool, optional): Tracks if the dataset is a train dataset. Defaults to True.
            transform (torch.transform, optional): pytorch transformations. Defaults to None.
        """
        self.dataset_dir = dataset_dir
        self.is_train = is_train
        self.transform = transform
        file_name = "frame_data_train.json" if is_train else "frame_data_test.json"

        path = f"{dataset_dir}/{file_name}"
        with open(path, 'r') as f:
            self.data = dict(json.load(f))

        # Index order follows the key order of the JSON file.
        self.img_names = list(self.data.keys())

    def __len__(self):
        """Return the number of frames listed in the index."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load one frame and its annotations.

        Args:
            idx (int): position in ``self.img_names``.

        Returns:
            dict: ``{'image', 'bb', 'keypoints'}``, or whatever ``self.transform``
            returns for that dict when a transform is set.

        Raises:
            FileNotFoundError: if the image file is missing or cannot be decoded.
        """
        image_name = self.img_names[idx]
        entry = self.data[image_name]

        image_path = f"{entry['img_dir']}/Images/{image_name}.png"
        image = cv2.imread(image_path)
        # cv2.imread does not raise on failure; it returns None, which would
        # otherwise surface as an obscure error inside Image.fromarray.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        # NOTE(review): cv2.imread yields BGR channel order and no conversion is
        # applied before wrapping in PIL — confirm inference uses the same order
        # before changing this.
        image = Image.fromarray(image)

        # np.array makes a fresh copy, so in-place edits by the transforms do
        # not leak back into self.data.
        keypoints = np.array(entry["keypoints"])
        item = {'image': image, 'bb': entry["bbox"], 'keypoints': keypoints}
        if self.transform is not None:
            item = self.transform(item)
        return item

#### Transformations
The following defines the necessary transformations, including
- Blurr: applies a blur filter and dims the image. This is necessary for training on the simulated dataset
- Crop and Pad: crops the image to the bounding box and pads it to a square image
- Locs to Heatmaps: creates a heatmap for each keypoint
- To Tensor: converts the necessary information into tensors

In [None]:
#Transformations

from typing import Any


def generate_heatmap(heatmap, pt, sigma):
    """Draw a single blurred peak at ``pt`` and scale it to a maximum of 1.

    Args:
        heatmap: 2-D float array indexed as ``[row][col]``; one cell is set in
            place before blurring.
        pt: ``(x, y)`` location of the peak; each value is truncated with ``int``.
        sigma: passed to ``cv2.GaussianBlur`` as the kernel size (e.g. ``(7, 7)``);
            the Gaussian sigma argument itself is given as 0.

    Returns:
        The blurred heatmap returned by ``cv2.GaussianBlur``, divided by its maximum.
    """
    rows, cols = heatmap.shape[:2]
    # Clamp to the valid index range: a point on the far image border maps to
    # index == size (IndexError) and a negative index would silently wrap around.
    row = min(max(int(pt[1]), 0), rows - 1)
    col = min(max(int(pt[0]), 0), cols - 1)
    heatmap[row][col] = 1
    heatmap = cv2.GaussianBlur(heatmap, sigma, 0)
    peak = np.amax(heatmap)
    # Guard the division so an all-zero result cannot turn into NaNs.
    if peak > 0:
        heatmap /= peak
    return heatmap


class Blurr:
    """Augmentation that blurs and darkens ``sample['image']`` by random amounts.

    The blur radius is drawn from ``random.uniform(2, 4)`` and the brightness
    factor from ``random.uniform(0.5, 1)`` on every call. The ``radius``
    constructor argument is stored on the instance but is not read when the
    transform is applied.
    """

    def __init__(self, radius) -> None:
        import time
        # Re-seeds the module-level `random` generator from the wall clock.
        random.seed(time.time())
        self.radius = radius

    def __call__(self, sample) -> Any:
        """Replace ``sample['image']`` with its blurred, dimmed version and return the sample."""
        blur_radius = random.uniform(2,4)
        blurred = sample['image'].filter(ImageFilter.GaussianBlur(radius=blur_radius))
        brightness = random.uniform(0.5, 1)
        sample['image'] = ImageEnhance.Brightness(blurred).enhance(brightness)
        return sample


class CropAndPad:
    """Crop a square region around the bounding box and resize it to ``out_size``.

    The square is centred on the bounding-box centre and its side is the larger
    of the box's width and height. Keypoints, when present, are moved into the
    coordinate frame of the resized crop.
    """

    def __init__(self, out_size=(256,256)):
        """
        Args:
            out_size: target size; stored reversed because it is handed to the
                image's ``resize`` method as-is.
        """
        self.out_size = out_size[::-1]

    def __call__(self, sample):
        """Crop/resize the image (and mask) and remap the keypoints.

        Reads ``'image'`` and ``'bb'`` (``[min_x, min_y, max_x, max_y]``), plus the
        optional ``'mask'`` and ``'keypoints'`` (rows of ``x, y, visibility``).
        Writes ``'image'``, ``'orig_image'``, ``'center'``, ``'min'``, ``'scale'``
        and ``'width'``, updates ``'mask'``/``'keypoints'`` when present, and
        removes ``'bb'``.

        Returns:
            The same ``sample`` dict, modified in place.
        """
        image, bb = sample['image'], sample['bb']
        min_x, min_y = bb[:2]
        max_x, max_y = bb[2:]
        center_x = (min_x + max_x) / 2
        center_y = (min_y + max_y) / 2
        width = max([max_x-min_x, max_y-min_y])

        # Square crop window around the (truncated) box centre.
        half = int(width)//2
        min_x = int(center_x) - half
        min_y = int(center_y) - half
        max_x = int(center_x) + half
        max_y = int(center_y) + half
        crop_box = (min_x, min_y, max_x, max_y)

        sample['image'] = image.crop(box=crop_box)
        sample['orig_image'] = image
        sample['center'] = np.array([center_x, center_y], dtype=np.float32)
        sample['min'] = np.array([min_x, min_y], dtype=np.float32)
        sample['scale'] = np.array([width/self.out_size[0]], dtype=np.float32)
        sample['width'] = width
        if width != self.out_size[0]:
            sample['image'] = sample['image'].resize(self.out_size)
        if 'mask' in sample:
            sample['mask'] = sample['mask'].crop(box=crop_box).resize(self.out_size)
        if 'keypoints' in sample:
            keypoints = sample['keypoints']
            origin = np.array([min_x, min_y])
            for i in range(keypoints.shape[0]):
                x, y = keypoints[i,0], keypoints[i,1]
                if x < min_x or x > max_x or y < min_y or y > max_y:
                    # Outside the crop: zero the row, which also clears visibility.
                    keypoints[i,:] = [0,0,0]
                else:
                    keypoints[i,:2] = (keypoints[i,:2]-origin)*self.out_size/width
            # Written back only when keypoints exist; doing this unconditionally
            # raised NameError for samples without a 'keypoints' entry.
            sample['keypoints'] = keypoints
        sample.pop('bb')
        return sample

# Convert keypoint locations to heatmaps
class LocsToHeatmaps:
    """Convert keypoint locations into one heatmap per keypoint.

    Sizes are treated as ``(rows, cols)``: the heatmap stack has shape
    ``(out_size[0], out_size[1], num_keypoints)`` and the coordinate grids are
    built with x running over ``out_size[1]`` and y over ``out_size[0]``.
    """

    def __init__(self, img_size=(256,256), out_size=(64,64), sigma=1):
        """
        Args:
            img_size: size of the image the keypoints refer to.
            out_size: size of the generated heatmaps.
            sigma: stored on the instance; the blur applied in ``__call__`` uses
                a fixed 7x7 kernel and does not read it.
        """
        self.img_size = img_size
        self.out_size = out_size
        # x follows the column axis (index 1) and y the row axis (index 0),
        # matching the grids below. Identical to before for square sizes.
        self.x_scale = 1.0 * out_size[1]/img_size[1]
        self.y_scale = 1.0 * out_size[0]/img_size[0]
        self.sigma=sigma
        x = np.arange(0, out_size[1], dtype=np.float32)
        y = np.arange(0, out_size[0], dtype=np.float32)
        self.yg, self.xg = np.meshgrid(y,x, indexing='ij')

    def __call__(self, sample):
        """Add ``'keypoint_locs'``, ``'visible_keypoints'`` and ``'keypoint_heatmaps'``.

        ``sample['keypoints']`` is read as rows of ``(x, y, visibility)``; rows
        whose visibility is 0 keep an all-zero heatmap.

        Returns:
            The same ``sample`` dict, modified in place.
        """
        kernel = (7, 7)  # Gaussian kernel size handed to generate_heatmap
        keypoints = sample['keypoints']
        rows, cols = self.out_size[0], self.out_size[1]
        gaussian_hm = np.zeros((rows, cols, keypoints.shape[0]))
        for i, keypoint in enumerate(keypoints):
            if keypoint[2] != 0:
                kx, ky = keypoint[:2].astype(np.int64)
                # Scale each axis with its own factor and clamp into the map: a
                # keypoint on the far crop edge would otherwise land on index
                # == size and raise IndexError.
                col = min(max(int(kx * self.x_scale), 0), cols - 1)
                row = min(max(int(ky * self.y_scale), 0), rows - 1)
                gaussian_hm[:,:,i] = generate_heatmap(gaussian_hm[:,:,i], (col, row), kernel)
        sample['keypoint_locs'] = keypoints[:,:2]
        sample['visible_keypoints'] = keypoints[:,2]
        sample['keypoint_heatmaps'] = gaussian_hm
        return sample

# Convert numpy arrays to Tensor objects
# Permute the image dimensions
class ToTensor:
    """Convert the images and numpy arrays held in a sample into torch tensors."""

    def __init__(self, downsample_mask=False):
        """
        Args:
            downsample_mask (bool): when True, ``sample['mask']`` is resized to
                64x64 before conversion.
        """
        self.tt = transforms.ToTensor()
        self.downsample_mask=downsample_mask

    def __call__(self, sample):
        """Convert the known entries of ``sample`` in place and return it.

        ``'image'`` is always converted; ``'orig_image'``, ``'mask'`` and
        ``'in_mask'`` are converted when present. When ``'keypoint_heatmaps'`` is
        present, it is moved from (H, W, K) to (K, H, W) layout and, together
        with ``'keypoint_locs'`` and ``'visible_keypoints'``, turned into a
        float32 tensor.
        """
        sample['image'] = self.tt(sample['image'])
        if 'orig_image' in sample:
            sample['orig_image'] = self.tt(sample['orig_image'])
        if 'mask' in sample:
            mask = sample['mask']
            if self.downsample_mask:
                # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
                # same filter under its current name.
                mask = mask.resize((64,64), Image.LANCZOS)
            sample['mask'] = self.tt(mask)
        if 'in_mask' in sample:
            sample['in_mask'] = self.tt(sample['in_mask'])
        if 'keypoint_heatmaps' in sample:
            sample['keypoint_heatmaps'] =\
                torch.from_numpy(sample['keypoint_heatmaps'].astype(np.float32).transpose(2,0,1))
            sample['keypoint_locs'] =\
                torch.from_numpy(sample['keypoint_locs'].astype(np.float32))
            sample['visible_keypoints'] =\
                torch.from_numpy(sample['visible_keypoints'].astype(np.float32))
        return sample

class Normalize:
    """Shift and scale image entries with ``v -> 2 * (v - 0.5)`` (0 maps to -1, 1 maps to 1)."""

    def __call__(self, sample):
        """Rescale ``sample['image']`` and, when present, ``sample['in_mask']``; return the sample."""
        targets = ['image']
        if 'in_mask' in sample:
            targets.append('in_mask')
        for key in targets:
            sample[key] = (sample[key] - 0.5) * 2
        return sample

#### Trainer
Defines the model trainer

In [None]:
class Trainer(object):
    """Builds the datasets, model, loss and optimizer, and runs keypoint training."""

    def __init__(self, dataset_dir, learning_rate, model_path: str = None):
        """
        Args:
            dataset_dir: directory handed to both the train and test ``Dataset``.
            learning_rate: learning rate for the RMSprop optimizer.
            model_path: optional checkpoint to start from. The file is expected
                to hold a dict with the state dict under ``'model'``. A failed
                load is reported and training starts from fresh weights.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # The blur/dim augmentation is applied to the training split only.
        train_transform_list = [Blurr(2.2), CropAndPad(out_size=(256, 256)),LocsToHeatmaps(out_size=(64, 64)),ToTensor(),Normalize()]
        test_transform_list = [CropAndPad(out_size=(256, 256)),LocsToHeatmaps(out_size=(64, 64)),ToTensor(),Normalize()]
        self.train_ds = Dataset(is_train=True, transform=transforms.Compose(train_transform_list), dataset_dir=dataset_dir)
        self.test_ds = Dataset(is_train=False, transform=transforms.Compose(test_transform_list), dataset_dir=dataset_dir)

        self.model = hg(num_stacks=1, num_blocks=1, num_classes=10).to(self.device)

        if model_path is not None:
            print("loading model from checkpoint")
            try:
                # map_location lets a checkpoint saved on GPU load on a CPU-only host.
                checkpoint = torch.load(model_path, map_location=self.device)
                self.model.load_state_dict(checkpoint['model'])
            except Exception as exc:
                # Still best-effort, but no longer swallows KeyboardInterrupt /
                # SystemExit, and the reason for the failure is shown.
                print("No checkpoint loaded")
                print(f"Reason: {exc!r}")
        else:
            print("Not using checkpointed model")
        # define loss function and optimizer
        self.heatmap_loss = torch.nn.MSELoss().to(self.device) # for Global loss
        self.optimizer = torch.optim.RMSprop(self.model.parameters(),
                                             lr = learning_rate)#2.5e-4)
        self.train_data_loader = DataLoader(self.train_ds, batch_size=16,
                                            num_workers=8,
                                            pin_memory=True,
                                            shuffle=True)
        self.test_data_loader = DataLoader(self.test_ds, batch_size=32,
                                           num_workers=8,
                                           pin_memory=True,
                                           shuffle=True)

        # Bookkeeping lists; only running_losses is filled by train().
        self.summary_iters = []
        self.losses = []
        self.pcks = []
        self.running_losses = []

    def _write_checkpoint(self, path):
        """Save the model state dict under ``'model'`` at ``path``, creating its directory if needed."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({'model': self.model.state_dict()}, path)

    def train(self, num_epochs):
        """Train for ``num_epochs`` epochs.

        Appends each epoch's summed loss to ``self.running_losses``, writes
        ``./output/model_checkpoint.pt`` every 10 epochs and
        ``./models/kpt_checkpoint.pt`` once training finishes.
        """
        self.total_step_count = 0
        for epoch in range(1,num_epochs+1):

            print("Epoch%d/%d"%
                    (epoch,num_epochs), end="\r")

            running_loss = 0
            # Nothing inside the loop switches the mode, so set it once per epoch.
            self.model.train()

            for step, batch in enumerate(self.train_data_loader):
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k,v in batch.items()}
                self.optimizer.zero_grad()
                pred_heatmap_list = self.model(batch['image'])
                # The loss is taken on the last entry of the model's output list.
                loss = self.heatmap_loss(pred_heatmap_list[-1], batch['keypoint_heatmaps'])
                loss.backward()
                self.optimizer.step()

                self.total_step_count += 1

                running_loss += loss.detach()

                if step % 10 == 0:
                    print(f"running loss at step {step} = \t{running_loss}", end="\r")

            self.running_losses.append(running_loss)
            # The summary line previously ended at the tab without the value.
            print(f"Epoch {epoch} / {num_epochs} total Loss\t{running_loss}")

            if epoch % 10 == 0:
                print(f"Savig Checkpoint at epoch {epoch}")
                # ./output is not created anywhere else, so the helper creates it.
                self._write_checkpoint('./output/model_checkpoint.pt')

        self._write_checkpoint('./models/kpt_checkpoint.pt')

    def save(self, checkpoint_name):
        """Save the current model weights to ``./models/<checkpoint_name>.pt``."""
        print("saving checkpoint")
        self._write_checkpoint(f'./models/{checkpoint_name}.pt')

### Training the Model

In [None]:
# Root directory of the collected data.
data_directory = "./data/NewData"

# if the simulated data is given in an improper format
# (runs the fix_json.py preprocessing script with the data directory and the
# JSON file name as arguments; presumably it rewrites the data into the layout
# the Dataset expects — TODO confirm against that script)
%run ~/mines_ws/src/data_collection/preprocessing/fix_json.py {data_directory} DescriptiveFrameData.json
# From here on, use the "accumulated_data" subdirectory as the dataset directory.
data_directory += "/accumulated_data"

# Checkpoint path passed to Trainer as model_path in the next cell.
# NOTE(review): Trainer.train writes its periodic checkpoint to
# ./output/model_checkpoint.pt, not to this path — confirm which is intended.
model_path = "./models/model_checkpoint.pt"

In [None]:
# Build the trainer (it tries to load the checkpoint at model_path) and train for 50 epochs.
trainer = Trainer(dataset_dir=data_directory, learning_rate=0.00005, model_path=model_path)
trainer.train(num_epochs=50)

#### Save the model (this cell does not need to be run sequentially)

In [None]:
# Writes the current weights to ./models/kpt_checkpoint.pt.
trainer.save("kpt_checkpoint")