# Keypoint detection with PyTorch

## Imports and installs

In [None]:
# Necessary imports

import os
import cv2
import sys
import json
import copy
import time
import datetime
import warnings
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

plt.rcParams['figure.figsize'] = [20, 10]

import torch
from torch.utils.data import Dataset, DataLoader
from pycocotools.cocoeval import COCOeval

In [None]:
# Clone the updated Torchvision repo and install Torchvision from source
!git clone https://github.com/karolyartur/vision
!pip uninstall torchvision -y
!cd vision && python setup.py install

# Add the built torchvision package to sys.path
# NOTE(review): the build directory name below looks platform- and
# Python-version-specific (linux-x86_64, 3.8) - confirm it matches the runtime in use.
sys.path.append('vision/build/lib.linux-x86_64-3.8')

In [None]:
# Import the newly built torchvision and necessary modules
import torchvision

sys.path.append('vision/references/detection')

from torchvision.transforms import functional as F
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor

# Import modules from vision/references
import vision.references.detection.utils as utils
import vision.references.detection.engine as engine
from vision.references.detection.utils import collate_fn
from vision.references.detection.engine import train_one_epoch, _get_iou_types
from vision.references.detection.coco_utils import get_coco_api_from_dataset
from vision.references.detection.coco_eval import CocoEvaluator

## Constants and function definitions

In [None]:
# Constants and functions

# Locations of the dataset, its annotations and the training output
DATASET_PATH = 'drive/MyDrive/KK/SlideBot/data/synthetic_data_keypoints'
SEGMENTATION_PATH = os.path.join(DATASET_PATH, 'Segmentation_annotations')
TRAIN_IMGS_PATH, VAL_IMGS_PATH = (
    os.path.join(DATASET_PATH, 'datasets', 'images', split) for split in ('train', 'val')
)
KEYPOINT_ANNOT_PATH = os.path.join(DATASET_PATH, 'keypoint_annotations_coco')
MODEL_OUTPUT_PATH = 'drive/MyDrive/KK/SlideBot/data/training_results/keypoint_rcnn'

# Colors used in the segmentation masks for the box and for the background
BOX_COLOR = [255, 0, 0, 255]
BG_COLOR = [0, 0, 0, 255]

NUM_KEYPOINTS = 48

# Sorted file names of both splits; the combined list holds the validation names first
train_imgs = sorted(os.listdir(TRAIN_IMGS_PATH))
val_imgs = sorted(os.listdir(VAL_IMGS_PATH))
all_image_files = val_imgs + train_imgs

def img_name_to_annot_name(img_name):
    '''Return the name of the annotation image given the name of the rendered image

    The suffix "_annotation" is inserted right before the last dot of the name
    (i.e. before the file extension), so "img_0001.png" becomes
    "img_0001_annotation.png". Only the last dot is considered, therefore names
    that contain additional dots keep them untouched. A name without any dot is
    returned unchanged.
    '''
    stem, dot, extension = img_name.rpartition('.')
    if not dot:
        # No extension separator found: there is no place to insert the suffix
        return img_name
    return f'{stem}_annotation.{extension}'

def visualize_annot(img, bboxes, keypoints, threshold=0.5):
    '''Show an image with its bounding boxes and keypoints drawn on top

    Every box is read as [x_min, y_min, x_max, y_max] and drawn as a green
    rectangle. The keypoints belonging to the i-th box are taken from
    keypoints[i]; each keypoint is read as [x, y, score] and is drawn (red
    circle plus its index as text) only if its score reaches the threshold.
    '''
    _, axes = plt.subplots()
    axes.imshow(img)
    for box_idx, box in enumerate(bboxes):
        corner = (box[0], box[1])
        width = box[2] - box[0]
        height = box[3] - box[1]
        axes.add_patch(
            patches.Rectangle(corner, width, height, linewidth=2, edgecolor='g', facecolor='none')
        )
        for kp_idx, kp in enumerate(keypoints[box_idx]):
            if not (kp[2] >= threshold):
                continue
            axes.add_patch(
                patches.Circle((kp[0], kp[1]), radius=2, linewidth=2, edgecolor='r', facecolor='none')
            )
            axes.text(kp[0], kp[1], str(kp_idx), color='r', fontsize=10)
    plt.show()

## Create COCO-style annotations from segmentation masks

In [None]:
# Test data availability

# For each split show the first rendered image followed by its segmentation annotation
for imgs_dir, img_names in ((TRAIN_IMGS_PATH, train_imgs), (VAL_IMGS_PATH, val_imgs)):
    first_name = img_names[0]
    for path in (
        os.path.join(imgs_dir, first_name),
        os.path.join(SEGMENTATION_PATH, img_name_to_annot_name(first_name)),
    ):
        img = np.array(Image.open(path))
        plt.imshow(img)
        plt.show()

In [None]:
# Get all unique colors (This is separated because it takes a lot of time)

# Background + Box + one color per slide (two keypoints belong to each slide)
expected_unique_color_num = NUM_KEYPOINTS/2 + 2

segment_imgs = sorted(os.listdir(SEGMENTATION_PATH))

unique_colors = []

# Scan the annotation images one by one and stop as soon as enough colors were seen
for filename in segment_imgs:
    pixels = np.array(Image.open(os.path.join(SEGMENTATION_PATH, filename)))
    # Flatten the image into a list of pixels, keeping the channel axis
    pixels = pixels.reshape(pixels.shape[0]*pixels.shape[1], pixels.shape[2])
    for color in np.unique(pixels, axis=0).tolist():
        if color in unique_colors:
            continue
        unique_colors.append(color)
    if len(unique_colors) >= expected_unique_color_num:
        break

In [None]:
# Generate bbox and keypoint annotation JSON files for the train and valid images

unique_colors.sort()
# Every color that is neither the box nor the background belongs to a slide
slide_colors = [color for color in unique_colors if color not in (BOX_COLOR, BG_COLOR)]

def mask_img_to_bbox(mask_img, color):
    '''Return the extent of the pixels of a given color in a segmentation mask

    The result is [min0, min1, max0, max1], where 0 and 1 refer to the first
    and the second axis of mask_img (so for an image stored as rows x columns
    it is [row_min, col_min, row_max, col_max]). None is returned if no pixel
    of the mask has the given color.
    '''
    matches = np.all(mask_img == color, axis=2)
    idx0, idx1 = np.where(matches)
    if len(idx0) == 0 or len(idx1) == 0:
        return None
    return [int(idx0.min()), int(idx1.min()), int(idx0.max()), int(idx1.max())]


def _mean_keypoint(rows, cols):
    '''Return a visible keypoint [mean_row, mean_col, 1] for the given pixel indices'''
    return [int(np.mean(rows)), int(np.mean(cols)), 1]


def mask_image_to_keypoints(img, color_ids, bbox):
    '''Return list of keypoints given a segmentation image, the color ID-s for the keypoints and the bounding box for the containing instance

    For each color in color_ids exactly two keypoints are appended, in this
    order: one for the left side of the bounding box and one for the right
    side. The returned list therefore always has 2 * len(color_ids) entries.

    Each keypoint is [i0, i1, visibility], where i0 / i1 are the mean indices
    of the selected pixels along the first / second axis of img and
    visibility is 1. A keypoint that cannot be determined is [0, 0, 0].

    bbox is indexed the same way as the image: bbox[1] and bbox[3] are the
    minimum and maximum index of the instance along the second axis of img.
    '''
    bbox_middle = (bbox[1] + bbox[3]) / 2
    box_touches_far_edge = bbox[3] == img.shape[1] - 1
    box_touches_near_edge = bbox[1] == 0

    keypoints = []
    for slide_color in color_ids:
        rows, cols = np.where(np.all(img == slide_color, axis=2))
        # Start with both keypoints marked as not visible (two distinct lists,
        # because callers modify the keypoints in place)
        left_kp, right_kp = [0, 0, 0], [0, 0, 0]
        if cols.size:
            on_left = cols < bbox_middle
            on_right = cols > bbox_middle
            if on_left.any() and on_right.any() and cols.max() - cols.min() > 50:
                # Both keypoints for the slide are visible.
                # Requiring pixels on both sides avoids taking the mean of an
                # empty selection when pixels lie exactly on the middle.
                left_kp = _mean_keypoint(rows[on_left], cols[on_left])
                right_kp = _mean_keypoint(rows[on_right], cols[on_right])
            elif on_left.all() or box_touches_far_edge:
                # Only left side of the slide is visible
                left_kp = _mean_keypoint(rows, cols)
            elif on_right.all() or box_touches_near_edge:
                # Only right side of the slide is visible
                right_kp = _mean_keypoint(rows, cols)
            # Otherwise the side cannot be decided: both keypoints stay
            # [0, 0, 0] so that the keypoints of the following colors keep
            # their position in the list.
        keypoints.append(left_kp)
        keypoints.append(right_kp)
    return keypoints

# Write one JSON annotation file per image (validation and training images alike)
for img_file_name in all_image_files:
    print(img_file_name)
    mask_path = os.path.join(SEGMENTATION_PATH, img_name_to_annot_name(img_file_name))
    mask_img = np.array(Image.open(mask_path))
    bbox = mask_img_to_bbox(mask_img, BOX_COLOR)
    if bbox:
        keypoints = mask_image_to_keypoints(mask_img, slide_colors, bbox)
        # Swap the first two entries of every coordinate pair before saving
        bbox = [bbox[1], bbox[0], bbox[3], bbox[2]]
        keypoints = [[kp[1], kp[0], *kp[2:]] for kp in keypoints]
        annotation = {'bboxes': [bbox], 'keypoints': [keypoints]}
    else:
        # No box in the image: store empty annotations
        annotation = {'bboxes': [], 'keypoints': []}

    json_name = img_file_name.split('.')[0] + '.json'
    with open(os.path.join(KEYPOINT_ANNOT_PATH, json_name), 'w') as f:
        f.write(json.dumps(annotation))

## Keypoint detector training

### Dataset definition

In [None]:
class CustomDataset(Dataset):
    '''Dataset of images with bounding box and keypoint annotations stored in JSON files

    Every image in img_path is paired with the annotation file in annot_path
    that has the same name up to the first dot (e.g. "a.png" <-> "a.json").
    The annotation directory may contain additional files (for example the
    annotations of another split); they are ignored.

    Arguments:
        img_path: directory of the images
        annot_path: directory of the JSON annotation files
        num: if given, only the first num images (sorted by name) are used
        size: (height, width) the images are resized to
        num_keypoints: number of keypoints of an instance
        orig_size: (height, width) of the image the annotations refer to;
            annotations are rescaled from this size to size
    '''

    def __init__(self, img_path, annot_path, num=None, size=(720, 1280), num_keypoints=48, orig_size=(1080, 1920)):
        self.size = size
        self.orig_size = orig_size
        # Stored because __getitem__ needs it for images without annotations
        self.num_keypoints = num_keypoints

        img_names = sorted(os.listdir(img_path))
        if num:
            if not isinstance(num, int):
                raise TypeError(f'"num" argument must be integer, instead got {type(num)}')
            if num < len(img_names):
                img_names = img_names[0:num]
            elif num > len(img_names):
                warnings.warn(f'"num" is greater than the number of images in the dataset ({len(img_names)}) all images will be used!')

        # Look annotations up by name instead of relying on the position in the
        # directory listing, so image i is always paired with its own annotation
        annot_by_stem = {}
        for annot_name in sorted(os.listdir(annot_path)):
            annot_by_stem.setdefault(annot_name.split('.')[0], annot_name)

        missing = [name for name in img_names if name.split('.')[0] not in annot_by_stem]
        if missing:
            raise FileNotFoundError(f'No annotation file found in "{annot_path}" for image(s): {missing}')

        self.imgs_files = [os.path.join(img_path, name) for name in img_names]
        self.annotations_files = [
            os.path.join(annot_path, annot_by_stem[name.split('.')[0]]) for name in img_names
        ]

    def __getitem__(self, idx):
        '''Return the (image tensor, target dict) pair of the idx-th sample

        The target holds "labels", "image_id", "boxes", "area", "iscrowd" and
        "keypoints". Boxes and keypoints are rescaled from orig_size to size.
        '''
        # Keep only the first three channels and resize (cv2 expects (width, height))
        img = np.array(Image.open(self.imgs_files[idx]))[:, :, :3]
        img = cv2.resize(img, (self.size[1], self.size[0]), interpolation=cv2.INTER_LINEAR)

        with open(self.annotations_files[idx]) as f:
            data = json.load(f)

        scale_x = self.size[1] / self.orig_size[1]
        scale_y = self.size[0] / self.orig_size[0]

        # Rescale every instance (not only the first one) to the resized image
        bboxes = [
            [box[0] * scale_x, box[1] * scale_y, box[2] * scale_x, box[3] * scale_y]
            for box in data['bboxes']
        ]
        keypoints = [
            [[kp[0] * scale_x, kp[1] * scale_y, *kp[2:]] for kp in instance]
            for instance in data['keypoints']
        ]
        bbox_exists = bool(bboxes)

        # Convert everything into a torch tensor
        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
        target = {}
        target["labels"] = torch.ones(len(bboxes), dtype=torch.int64)  # all objects are boxes
        target["image_id"] = torch.tensor([idx])
        if bbox_exists:
            target["boxes"] = bboxes
            target["area"] = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
            target["iscrowd"] = torch.zeros(len(bboxes), dtype=torch.int64)
            target["keypoints"] = torch.as_tensor(keypoints, dtype=torch.float32)
        else:
            # Image without any instance.
            # NOTE(review): "area" and "keypoints" keep a single placeholder entry
            # although there are zero boxes - confirm this is what the training
            # and evaluation code expects.
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["area"] = torch.tensor([0])
            target["iscrowd"] = torch.ones(len(bboxes), dtype=torch.int64)
            target["keypoints"] = torch.zeros((1, self.num_keypoints, 3), dtype=torch.float32)
        img = F.to_tensor(img)
        return img, target

    def __len__(self):
        '''Return the number of images in the dataset'''
        return len(self.imgs_files)

In [None]:
# Check dataset and annotations

demo = False

# Build a dataset with a single sample and fetch the first batch of its loader
train_dataset = CustomDataset(TRAIN_IMGS_PATH, KEYPOINT_ANNOT_PATH, num=1, num_keypoints=NUM_KEYPOINTS)
data_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
iterator = iter(data_loader)
batch = next(iterator)
print("Targets:\n", batch[1])

# Convert the first image and its annotations into plain types for plotting
sample_img, sample_target = batch[0][0], batch[1][0]
image = (sample_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
bboxes, keypoints = (
    sample_target[key].detach().cpu().numpy().astype(np.int32).tolist()
    for key in ('boxes', 'keypoints')
)

visualize_annot(image, bboxes, keypoints)

### Evaluation definition

In [None]:
@torch.inference_mode()
def evaluate(model, data_loader, device, keypointnum=48, print_freq=100):
    '''Run the model on every batch of data_loader and summarize the COCO metrics

    The model is switched to eval mode, its outputs are moved to the CPU and
    passed to a CocoEvaluator built from the dataset of the data loader. The
    evaluator is returned after accumulate() and summarize() were called on it.

    While the evaluation runs torch is limited to a single thread; the previous
    thread count is restored afterwards, also if the evaluation fails.

    Arguments:
        model: detection model to evaluate
        data_loader: loader yielding (images, targets) batches
        device: device the images are moved to before the forward pass
        keypointnum: forwarded to CocoEvaluator as "keypointnum"
        print_freq: logging frequency passed to MetricLogger.log_every
    '''
    n_threads = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        cpu_device = torch.device("cpu")
        model.eval()
        metric_logger = utils.MetricLogger(delimiter="  ")
        header = "Test:"

        coco = get_coco_api_from_dataset(data_loader.dataset)
        iou_types = _get_iou_types(model)
        coco_evaluator = CocoEvaluator(coco, iou_types, keypointnum=keypointnum)

        for images, targets in metric_logger.log_every(data_loader, print_freq, header):
            images = list(img.to(device) for img in images)

            # Wait for pending GPU work so that the measured time belongs to this batch
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            model_time = time.time()
            outputs = model(images)

            outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
            model_time = time.time() - model_time

            # Map the predictions to the image ids stored in the targets
            res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
            evaluator_time = time.time()
            coco_evaluator.update(res)
            evaluator_time = time.time() - evaluator_time
            metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        coco_evaluator.synchronize_between_processes()

        # accumulate predictions from all images
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        return coco_evaluator
    finally:
        # Restore the global thread setting even if the evaluation raised
        torch.set_num_threads(n_threads)

### Training

In [None]:
def get_model(num_keypoints, weights_path=None):
    '''Create a Keypoint R-CNN model (ResNet50 FPN) for the given number of keypoints

    The model has two classes (background and box) and uses a custom anchor
    generator. If weights_path is given, the state dict stored in that file is
    loaded into the model.

    The weights are loaded onto the CPU first, so a checkpoint saved from a GPU
    can also be loaded on a machine without one; move the returned model to the
    desired device with model.to(device).
    '''
    anchor_generator = AnchorGenerator(sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0))
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
        pretrained=False,
        pretrained_backbone=True,
        num_keypoints=num_keypoints,
        num_classes = 2, # Background is the first class, box is the second class
        rpn_anchor_generator=anchor_generator)
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(512, num_keypoints)
    if weights_path:
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Create datasets and data loaders
dataset_train = CustomDataset(TRAIN_IMGS_PATH, KEYPOINT_ANNOT_PATH, num=4, num_keypoints=NUM_KEYPOINTS)
dataset_val = CustomDataset(VAL_IMGS_PATH, KEYPOINT_ANNOT_PATH, num=4, num_keypoints=NUM_KEYPOINTS)

data_loader_train = DataLoader(dataset_train, batch_size=2, shuffle=True, collate_fn=collate_fn)
data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Create model
model = get_model(num_keypoints = NUM_KEYPOINTS)
model.to(device)

# Set optimizer and hyperparams
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
num_epochs = 100

# Make sure the output directory exists before the (long) training starts,
# so that saving the weights at the end cannot fail because of a missing folder
os.makedirs(MODEL_OUTPUT_PATH, exist_ok=True)

# Training loop: one training pass followed by an evaluation on the validation set
print_freq=50
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=print_freq)
    evaluate(model, data_loader_val, device, keypointnum=NUM_KEYPOINTS, print_freq=print_freq)
    
# Save model weights after training (the file name contains the current date and time)
torch.save(model.state_dict(), os.path.join(MODEL_OUTPUT_PATH, f"keypointrcnn_weights{str(datetime.datetime.now()).split('.')[0].replace(' ','_')}.pth"))

### Predictions

In [None]:
# Get image from validation set
iterator = iter(data_loader_val)
images, targets = next(iterator)
images_copy = copy.copy(images)  # keeps the original (not moved) image tensors for plotting
images = [image.to(device) for image in images]

# Predict
with torch.no_grad():
    model.to(device)
    model.eval()
    output = model(images)

# Visualize predictions
print("Predictions: \n", output)
# Only the first predicted box is drawn together with its keypoints. Slicing with
# [:1] (instead of indexing with [0]) keeps this working when nothing was
# detected: in that case just the image is shown.
first_boxes = output[0]['boxes'].cpu()[:1]
visualize_annot(images_copy[0].permute(1,2,0), first_boxes, output[0]['keypoints'].cpu())