In [None]:
from akpr.src.akpr import generate_refined_keypoints

In [None]:

from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
import json
import pandas as pd
import os
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import numpy as np
# Read AWS credentials from the JSON file named by $AWS_CREDENTIALS; the `with`
# block closes the file (the previous inline open() leaked the handle).
with open(os.environ['AWS_CREDENTIALS']) as credentials_file:
    aws_credentials = json.load(credentials_file)
s3_access_utils = S3AccessUtils('/root/data', aws_credentials)


In [None]:
# Fetch the weight-estimation export CSV into the local cache and load it.
# Later cells read its 'annotation', 'left_crop_url' and 'right_crop_url' columns.
# (The filename suggests pen 37, 2020-06-13..2020-06-20 -- confirm against the export.)
f = s3_access_utils.download_from_s3(
    "aquabyte-images-adhoc",
    "jane/weight_estimation/pen_37_2020-06-13_2020-06-20_3950_-1.csv",
)
dat = pd.read_csv(f)


In [None]:
# Snapshot the first rows of `dat` as regression fixtures for
# generate_refined_keypoints. For each row this writes three files:
#   t{ind}_crop_url.json, t{ind}_ann.json and t{ind}_modified_ann.json.
# NOTE(review): the annotation is parsed by swapping ' for " before json.loads,
# which corrupts any string value containing an apostrophe -- confirm the data
# never does (ast.literal_eval would be the robust parser for Python-repr dicts).
n_test_cases = min(10, len(dat))  # avoid IndexError on exports with < 10 rows
test_case_dir = os.path.join('tests', 'test_case')
os.makedirs(test_case_dir, exist_ok=True)  # open(..., 'w') fails if the folder is missing

for ind in range(n_test_cases):
    row = dat.iloc[ind]
    ann = json.loads(row['annotation'].replace("'", '"'))
    left_crop_url, right_crop_url = row['left_crop_url'], row['right_crop_url']
    crop_url = {'left_crop_url': left_crop_url,
                'right_crop_url': right_crop_url}
    modified_ann = generate_refined_keypoints(ann, left_crop_url, right_crop_url)

    # Dump after refinement (same order as before), so the saved `ann` reflects
    # whatever state generate_refined_keypoints leaves it in.
    for suffix, payload in (('crop_url', crop_url),
                            ('ann', ann),
                            ('modified_ann', modified_ann)):
        out_path = os.path.join(test_case_dir, 't{}_{}.json'.format(ind, suffix))
        with open(out_path, 'w') as json_file:
            json.dump(payload, json_file)

In [None]:
def str2dict(x):
    """
    Parse a single-quoted, dict-like string (e.g. an 'annotation' cell) into Python.

    Parameters:
    ----------
    x : str or missing value (NaN / None / pd.NA)

    Returns:
    ----------
    The decoded value (a dict for annotation strings), or None when x is missing.
    """
    # The old `x is not np.nan` identity check only recognised the np.nan
    # singleton; any other NaN object or None reached .replace and raised.
    if not isinstance(x, str) and pd.isna(x):
        return None
    # Swap quotes so json can read the Python-repr style string.
    return json.loads(x.replace("'", "\""))

def ann2dict(kps):
    """
    Map each keypoint's type to its crop-space coordinates.

    Parameters:
    ----------
    kps : iterable of dict
        Keypoint records such as annotation['leftCrop'] or annotation['rightCrop'];
        each record needs 'keypointType', 'xCrop' and 'yCrop' keys.

    Returns:
    ----------
    dict
        {keypointType: [xCrop, yCrop]}; if a type repeats, the last record wins.
    """
    coords_by_type = {}
    for record in kps:
        keypoint_type = record['keypointType']
        coords_by_type[keypoint_type] = [record['xCrop'], record['yCrop']]
    return coords_by_type

def plot_image_url(url, alpha=2, beta=20):
    """
    Download an image from S3 and return a contrast/brightness-adjusted copy.

    Despite the name, nothing is plotted; callers pass the returned array to
    ``imshow``.

    Parameters:
    ----------
    url : str
        Image URL accepted by ``s3_access_utils.download_from_url``.
    alpha : float, default 2
        Contrast gain passed to cv2.convertScaleAbs (1.0-3.0 is the usual range).
    beta : float, default 20
        Brightness offset passed to cv2.convertScaleAbs (0-100 is the usual range).

    Returns:
    ----------
    numpy.ndarray
        Result of cv2.convertScaleAbs(pixels, alpha, beta), i.e. |alpha*pixel + beta|
        saturated to 8-bit.
    """
    image_f, _bucket, _image_key = s3_access_utils.download_from_url(url)

    # The context manager releases the file handle; np.asarray forces PIL to
    # load the pixel data before the file is closed.
    with Image.open(image_f) as img:
        pixels = np.asarray(img)
    return cv2.convertScaleAbs(pixels, alpha=alpha, beta=beta)


def display_crops(left_url, right_url, ann, overlay_keypoints=True, show_labels=False):
    """
    Show the left and right crops side by side, optionally with their keypoints.

    Parameters:
    ----------
    left_url, right_url : str
        S3 URLs of the left / right crop images.
    ann : dict
        Annotation holding 'leftCrop' and 'rightCrop' keypoint lists.
    overlay_keypoints : bool
        Draw every keypoint as a small red dot.
    show_labels : bool
        When overlaying, also write the keypoint type beside each dot.
    """
    _, axes = plt.subplots(1, 2, figsize=(20, 20))
    images = [plot_image_url(left_url), plot_image_url(right_url)]
    for ax, image in zip(axes, images):
        ax.imshow(image)

    keypoint_sets = [ann2dict(ann['leftCrop']), ann2dict(ann['rightCrop'])]

    if overlay_keypoints:
        for ax, keypoints in zip(axes, keypoint_sets):
            for part, (x, y) in keypoints.items():
                ax.scatter([x], [y], color='red', s=5)
                if show_labels:
                    ax.annotate(part, (x, y), color='red')
    plt.show()

def display_refinement(right_url, ann, modified_ann, overlay_keypoints=True, show_labels=False):
    """
    Compare original (red) and refined (blue) right-crop keypoints on the right image.

    Parameters:
    ----------
    right_url : str
        S3 URL of the right crop image.
    ann : dict
        Original annotation; its 'rightCrop' keypoints are drawn in red.
    modified_ann : dict
        Refined annotation; its 'rightCrop' keypoints are drawn in blue.
    overlay_keypoints : bool
        When False, only the image is shown.
    show_labels : bool
        Also write each keypoint type beside its dot.
    """
    _, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(plot_image_url(right_url))

    if overlay_keypoints:
        layers = (
            (ann2dict(ann['rightCrop']), 'red', 'original'),
            (ann2dict(modified_ann['rightCrop']), 'blue', 'refined'),
        )
        for keypoints, color, legend_label in layers:
            if not keypoints:
                continue
            names = list(keypoints)
            xs = [keypoints[name][0] for name in names]
            ys = [keypoints[name][1] for name in names]
            # One scatter call per layer so the legend gets exactly one entry each.
            ax.scatter(xs, ys, color=color, s=5, label=legend_label)
            if show_labels:
                for name, x, y in zip(names, xs, ys):
                    ax.annotate(name, (x, y), color=color)
        # Skip the legend when nothing was drawn (avoids matplotlib's warning).
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
    plt.show()

In [None]:
# Reload one saved test case from disk so this cell re-runs on a fresh kernel,
# instead of relying on ann / modified_ann / right_crop_url leaked from the loop.
# t9 is the last case the generation loop writes when the CSV has >= 10 rows.
sample_ind = 9
sample_case = {}
for suffix in ('crop_url', 'ann', 'modified_ann'):
    with open('tests/test_case/t{}_{}.json'.format(sample_ind, suffix)) as json_file:
        sample_case[suffix] = json.load(json_file)

display_refinement(sample_case['crop_url']['right_crop_url'],
                   sample_case['ann'],
                   sample_case['modified_ann'])
