In [None]:
from collections import defaultdict
import random
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import torch
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from research.weight_estimation.weight_estimator import WeightEstimator
from research.gtsf_data.gtsf_dataset import GTSFDataset
from research.gtsf_data.body_parts import BodyParts
from research.utils.keypoint_transformations import get_keypoint_arr
from research.utils.optics import pixel2world
from research.weight_estimation.akpd_scorer import generate_confidence_score
from scipy.spatial.distance import pdist
from research.weight_estimation.data_loader import *
from research.weight_estimation.biomass_estimator import *

# Let pandas display up to 500 rows before truncating DataFrame output in this notebook.
pd.set_option('display.max_rows', 500)

In [None]:
from collections import defaultdict
from copy import deepcopy
import json, os
import numpy as np
import pandas as pd

from keras.models import load_model
from scipy.spatial import Delaunay

from research.utils.optics import pixel2world
from research.weight_estimation.akpd_scorer import generate_confidence_score
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from research.gtsf_data.body_parts import BodyParts
from research.utils.keypoint_transformations import get_keypoint_arr, get_raw_3d_coordinates

# Ordered list of core body-part names; the code below uses positions in this list
# to index rows of the world-keypoint arrays.
BODY_PARTS = BodyParts().get_core_body_parts()


class GTSFDataset(object):
    """Ground-truth fish dataset built from the research database.

    On construction the annotation rows captured between ``start_date`` and
    ``end_date`` are queried, restricted to one species, de-duplicated (QA rows
    take precedence over non-QA rows for the same left image) and enriched with
    world keypoints, median depth, k-factor, AKPD scores and matrix-form
    keypoint arrays. Use ``get_prepared_dataset`` to obtain the resulting
    DataFrame and ``generate_augmented_dataset`` to synthesize re-posed copies.
    """

    def __init__(self, start_date, end_date, akpd_scorer_url, species='salmon', add_template_matching_keypoints=False):
        """Query the raw rows and run the full preparation pipeline.

        Args:
            start_date: lower bound for ``captured_at`` (interpolated into the SQL query).
            end_date: upper bound for ``captured_at`` (interpolated into the SQL query).
            akpd_scorer_url: URL of the Keras AKPD scorer model to download.
            species: species name compared (lower-cased) against ``data['species']``.
            add_template_matching_keypoints: when True, dense BODY keypoints from the
                template-matching parquet are appended to each annotation.
        """
        self.s3_access_utils = S3AccessUtils('/root/data')
        self.df = self.generate_raw_df(start_date, end_date)
        self.prepare_df(akpd_scorer_url, species, add_template_matching_keypoints)

    @staticmethod
    def generate_raw_df(start_date, end_date):
        """Return the joined fish-metadata / keypoint-annotation rows in the date range."""
        # Read the credentials inside a context manager so the file handle is closed.
        with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_f:
            rds_access_utils = RDSAccessUtils(json.load(credentials_f))
        # NOTE(review): the dates are formatted directly into the SQL text; only pass
        # trusted values here.
        query = """
            select * from research.fish_metadata a left join keypoint_annotations b
            on a.left_url = b.left_image_url 
            where b.keypoints -> 'leftCrop' is not null
            and b.keypoints -> 'rightCrop' is not null
            and b.captured_at between '{0}' and '{1}';
        """.format(start_date, end_date)
        df = rds_access_utils.extract_from_database(query)
        print('Raw dataframe loaded!')
        return df

    @staticmethod
    def get_world_keypoints(row):
        """Triangulate one row's left/right crop keypoints with its camera metadata."""
        return pixel2world(row.keypoints['leftCrop'], row.keypoints['rightCrop'], row.camera_metadata)

    def prepare_df(self, akpd_scorer_url, species, add_template_matching_keypoints):
        """Filter ``self.df`` and add all derived columns in place."""
        # use QA'ed entries, and only use Cogito entries when QA data is unavailable
        self.df = self.df[self.df.data.apply(lambda x: x['species'].lower()) == species].copy(deep=True)
        qa_df = self.df[self.df.is_qa == True]
        cogito_df = self.df[(self.df.is_qa != True) & ~(self.df.left_image_url.isin(qa_df.left_image_url))]
        self.df = pd.concat([qa_df, cogito_df], axis=0)
        print('Dataset preparation beginning...')

        # add 3D spatial information
        self.df['world_keypoints'] = self.df.apply(lambda x: self.get_world_keypoints(x), axis=1)
        # index 1 of each world keypoint is treated as depth
        self.df['median_depth'] = self.df.world_keypoints.apply(lambda x: np.median([wkp[1] for wkp in x.values()]))
        print('3D spatial information added!')

        # add k-factor
        self.df['k_factor'] = 1e5 * self.df.weight / self.df.data.apply(lambda x: x['lengthMms']**3).astype(float)

        # add AKPD scores and convert world keypoints to matrix form
        self.add_akpd_scores(akpd_scorer_url)
        if add_template_matching_keypoints:
            self.add_template_matching_keypoints()
        self.convert_wkps_to_matrix_form()

    @staticmethod
    def in_hull(p, hull):
        """Return a boolean array: which points of ``p`` lie inside the hull of ``hull``."""
        hull = Delaunay(hull)
        return hull.find_simplex(p)>=0

    def add_template_matching_keypoints(self):
        """Append dense BODY keypoints from the template-matching results to each row.

        Rows without template-matching output are dropped (inner merge), as are rows
        whose match list is missing and rows with 500 or fewer BODY points inside the
        hull of the annotated keypoints.
        """
        print('Adding template matching body keypoints...')

        # load data
        gen = self.s3_access_utils.get_matching_s3_keys(
            'aquabyte-research',
            prefix='template-matching/2019-12-05T02:50:57',
            suffixes=['.parquet']
        )

        keys = [key for key in gen]
        f = self.s3_access_utils.download_from_s3('aquabyte-research', keys[0])
        pdf = pd.read_parquet(f)
        # np.float / np.int were plain aliases of the builtins and were removed in
        # NumPy 1.24; the builtins give the same dtypes.
        pdf['homography'] = pdf.homography_and_matches.apply(lambda x: np.array(x[0].tolist(), dtype=float))
        pdf['matches'] = pdf.homography_and_matches.apply(lambda x: np.array(x[1].tolist(), dtype=int) if len(x) > 1 else None)

        # merge with existing dataframe
        self.df = pd.merge(self.df, pdf[['left_image_url', 'homography', 'matches']], how='inner', on='left_image_url')

        # Rows with no match array cannot be indexed below; they would fail the
        # "> 500 body points" filter at the end anyway, so drop them up front.
        self.df = self.df[self.df.matches.apply(lambda m: m is not None)].copy()

        # generate list of modified keypoints
        modified_keypoints_list = []
        for count, (_, row) in enumerate(self.df.iterrows()):
            if count % 100 == 0:
                print(count)
            X_keypoints = np.array([[item['xFrame'], item['yFrame']] for item in row.keypoints['leftCrop']])
            X_body = np.array(row.matches)
            # keep only matches whose left-image point lies inside the keypoint hull
            is_valid = self.in_hull(X_body[:, :2], X_keypoints)
            X_body = X_body[np.where(is_valid)]

            keypoints = deepcopy(row.keypoints)
            left_keypoints, right_keypoints = keypoints['leftCrop'], keypoints['rightCrop']
            # columns 0/1 are used as the left x/y, columns 2/3 as the right x/y
            left_item = {'keypointType': 'BODY', 'xFrame': X_body[:, 0], 'yFrame': X_body[:, 1]}
            right_item = {'keypointType': 'BODY', 'xFrame': X_body[:, 2], 'yFrame': X_body[:, 3]}

            left_keypoints.append(left_item)
            right_keypoints.append(right_item)
            modified_keypoints = {'leftCrop': left_keypoints, 'rightCrop': right_keypoints}
            modified_keypoints_list.append(modified_keypoints)

        # add modified keypoints information to dataframe
        self.df['old_keypoints'] = self.df.keypoints
        self.df['keypoints'] = modified_keypoints_list
        self.df = self.df[self.df.keypoints.apply(lambda x: x['leftCrop'][-1]['xFrame'].shape[0]) > 500]

    def add_akpd_scores(self, akpd_scorer_url):
        """Download the AKPD scorer network and add an ``akpd_score`` column."""
        print('Adding AKPD scores...')
        # load neural network weights
        akpd_scorer_path, _, _ = self.s3_access_utils.download_from_url(akpd_scorer_url)
        akpd_scorer_network = load_model(akpd_scorer_path)

        akpd_scores = []
        for idx, row in self.df.iterrows():
            input_sample = {
                'keypoints': row.keypoints,
                'cm': row.camera_metadata,
                'stereo_pair_id': row.id,
                'single_point_inference': True
            }
            akpd_score = generate_confidence_score(input_sample, akpd_scorer_network)
            akpd_scores.append(akpd_score)
        self.df['akpd_score'] = akpd_scores

    def convert_wkps_to_matrix_form(self):
        """Add ``raw_keypoint_arr`` / ``norm_keypoint_arr`` columns.

        Rows whose conversion fails are reported and get ``None`` in both columns.
        """
        print('Converting world keypoints to matrix form...')
        raw_keypoint_arr_list, norm_keypoint_arr_list = [], []
        for idx, row in self.df.iterrows():
            keypoints, cm = row.keypoints, row.camera_metadata
            try:
                raw_keypoint_arr, norm_keypoint_arr = get_keypoint_arr(keypoints, cm)
            except Exception as e:
                # best-effort: report and keep going, but do not swallow
                # KeyboardInterrupt / SystemExit as a bare except would
                print(row)
                print('Keypoint conversion failed: {}'.format(e))
                raw_keypoint_arr, norm_keypoint_arr = None, None

            raw_keypoint_arr_list.append(raw_keypoint_arr)
            norm_keypoint_arr_list.append(norm_keypoint_arr)
        self.df['raw_keypoint_arr'] = raw_keypoint_arr_list
        self.df['norm_keypoint_arr'] = norm_keypoint_arr_list

    def get_prepared_dataset(self):
        """Return the prepared DataFrame (the same object held by this instance)."""
        return self.df

    @staticmethod
    def randomly_rotate_and_translate(wkps, random_x_addition, random_y_addition,
                                      random_z_addition, yaw, pitch, roll):
        """Rotate ``wkps`` by yaw/pitch/roll (degrees) and then translate them.

        Args:
            wkps: array of 3D points, one point per row.
            random_x_addition, random_y_addition, random_z_addition: offsets added to
                columns 0, 1 and 2 after the rotation.
            yaw, pitch, roll: rotation angles in degrees.

        Returns:
            A new array with the transformed points.
        """

        # convert to radians
        yaw, pitch, roll = [theta * np.pi / 180.0 for theta in [yaw, pitch, roll]]

        # compute rotation matrix
        R_yaw = np.array([
            [np.cos(yaw), -np.sin(yaw), 0],
            [np.sin(yaw), np.cos(yaw), 0],
            [0, 0, 1]
        ])

        R_pitch = np.array([
            [np.cos(pitch), 0, np.sin(pitch)],
            [0, 1, 0],
            [-np.sin(pitch), 0, np.cos(pitch)]
        ])

        R_roll = np.array([
            [1, 0, 0],
            [0, np.cos(roll), -np.sin(roll)],
            [0, np.sin(roll), np.cos(roll)]
        ])

        R = np.dot(R_yaw, (np.dot(R_pitch, R_roll)))

        # apply rotation
        wkps = np.dot(R, wkps.T).T

        # perform translation
        wkps[:, 0] += random_x_addition
        wkps[:, 1] += random_y_addition
        wkps[:, 2] += random_z_addition
        return wkps

    @staticmethod
    def apply_jitter(wkps, cm, jitter):
        """Project world keypoints to both images, add x-noise, and re-triangulate.

        Args:
            wkps: array whose first ``len(BODY_PARTS)`` rows are (x, y, z) world points.
            cm: camera metadata dict (reads focalLengthPixel, pixelCountWidth,
                pixelCountHeight and baseline).
            jitter: standard deviation of the Gaussian noise added to each xFrame.

        Returns:
            ``(jittered_kps, jittered_wkps_arr)``: the synthetic annotation dict and the
            world-keypoint array recomputed from it.
        """
        ann_left, ann_right = [], []
        for idx, body_part in enumerate(BODY_PARTS):
            # generate left item
            x, y, z = wkps[idx]
            x_frame_left = x * cm['focalLengthPixel'] / y + cm['pixelCountWidth'] / 2
            y_frame_left = -z * cm['focalLengthPixel'] / y + cm['pixelCountHeight'] / 2
            item_left = {
                'keypointType': body_part,
                'xFrame': x_frame_left + np.random.normal(0, jitter),
                'yFrame': y_frame_left
            }
            ann_left.append(item_left)

            # generate right item
            disparity = cm['focalLengthPixel'] * cm['baseline'] / y
            x_frame_right = x_frame_left - disparity
            y_frame_right = y_frame_left
            item_right = {
                'keypointType': body_part,
                'xFrame': x_frame_right + np.random.normal(0, jitter),
                'yFrame': y_frame_right
            }
            ann_right.append(item_right)

        jittered_kps = {'leftCrop': ann_left, 'rightCrop': ann_right}
        jittered_wkps_arr = get_raw_3d_coordinates(jittered_kps, cm)
        return jittered_kps, jittered_wkps_arr

    def generate_augmented_dataset(self, x_bounds, y_bounds, z_bounds,
                                   yaw_bounds, pitch_bounds, roll_bounds,
                                   jitter, trials):
        """Create ``trials`` randomly re-posed copies of every prepared row.

        Args:
            x_bounds, y_bounds, z_bounds: ``(low, high)`` pairs for the uniform
                translation along each axis.
            yaw_bounds, pitch_bounds, roll_bounds: ``(low, high)`` pairs, in degrees,
                for the uniform rotation angles.
            jitter: pixel noise standard deviation passed to ``apply_jitter``.
            trials: number of augmented samples generated per source row.

        Returns:
            DataFrame with one row per (source row, trial). Source rows whose
            normalized keypoints are missing (``None``) are skipped.
        """

        augmented_data = defaultdict(list)
        for idx, row in self.df.iterrows():
            cm = row.camera_metadata
            original_wkps = row.raw_keypoint_arr
            original_ann = row.keypoints
            norm_wkps = row.norm_keypoint_arr
            weight = row.weight

            # convert_wkps_to_matrix_form stores None when conversion failed;
            # such rows cannot be re-posed.
            if norm_wkps is None:
                continue

            for t in range(trials):

                # generate random position and orientation
                random_x_addition = np.random.uniform(*x_bounds)
                random_y_addition = np.random.uniform(*y_bounds)
                random_z_addition = np.random.uniform(*z_bounds)
                yaw = np.random.uniform(*yaw_bounds)
                pitch = np.random.uniform(*pitch_bounds)
                roll = np.random.uniform(*roll_bounds)

                # perform rotation and translation
                wkps = self.randomly_rotate_and_translate(deepcopy(norm_wkps), random_x_addition, random_y_addition,
                                                          random_z_addition, yaw, pitch, roll)

                # add jitter
                ann, wkps = self.apply_jitter(wkps, cm, jitter)

                # create output row
                output_row = {
                    'original_ann': original_ann,
                    'original_wkps': original_wkps,
                    'norm_wkps': norm_wkps,
                    'wkps': wkps,
                    'ann': ann,
                    'cm': cm,
                    'mean_x': random_x_addition,
                    'mean_y': random_y_addition,
                    'mean_z': random_z_addition,
                    'yaw': yaw,
                    'pitch': pitch,
                    'roll': roll,
                    'weight': weight,
                    'jitter': jitter,
                    'trial': t
                }
                for k, v in output_row.items():
                    augmented_data[k].append(v)

        augmented_df = pd.DataFrame(augmented_data)
        return augmented_df




In [None]:
from research.utils.optics import pixel2world
from research.gtsf_data.body_parts import BodyParts
import random
import numpy as np

# Ordered list of core body-part names; the functions in this cell order and index
# world keypoints by position in this list.
BODY_PARTS = BodyParts().get_core_body_parts()


def get_raw_3d_coordinates(keypoints, cm):
    """Return the world-space keypoints of one stereo annotation as an array.

    The named (non-``BODY``) keypoints are triangulated with ``pixel2world`` and
    emitted in ``BODY_PARTS`` order. If the left crop carries a ``BODY`` item,
    its points are triangulated from their disparity and 500 of them, picked with
    ``random.sample``, are appended after the named keypoints.

    Args:
        keypoints: dict with ``'leftCrop'`` and ``'rightCrop'`` lists of keypoint items.
        cm: camera metadata dict.

    Returns:
        ``np.ndarray`` with one (x, y, z) row per keypoint.
    """
    left_crop, right_crop = keypoints['leftCrop'], keypoints['rightCrop']

    # triangulate the named keypoints and order them by BODY_PARTS
    named_left = [kp for kp in left_crop if kp['keypointType'] != 'BODY']
    named_right = [kp for kp in right_crop if kp['keypointType'] != 'BODY']
    wkps = pixel2world(named_left, named_right, cm)
    all_wkps = [list(wkps[bp]) for bp in BODY_PARTS]

    body_left = [kp for kp in left_crop if kp['keypointType'] == 'BODY']
    if not body_left:
        return np.array(all_wkps)

    # compute BODY world keypoint coordinates
    left_item = body_left[0]
    right_item = [kp for kp in right_crop if kp['keypointType'] == 'BODY'][0]

    # depth from the horizontal disparity between the two views
    disps = np.abs(left_item['xFrame'] - right_item['xFrame'])
    world_y = cm["focalLengthPixel"] * cm["baseline"] / np.array(disps)

    # pixel offsets relative to the image centre (z is centre row minus pixel row)
    px_x = left_item['xFrame'] - cm["pixelCountWidth"] / 2.0
    px_z = cm["pixelCountHeight"] / 2.0 - left_item['yFrame']

    # convert pixel offsets to sensor units, then scale by depth over focal length
    sensor_x = px_x * (cm["imageSensorWidth"] / cm["pixelCountWidth"])
    sensor_z = px_z * (cm["imageSensorHeight"] / cm["pixelCountHeight"])
    focal_length = cm["focalLength"]
    world_x = (world_y * sensor_x) / focal_length
    world_z = (world_y * sensor_z) / focal_length
    wkps['BODY'] = np.column_stack([world_x, world_y, world_z])

    # random.sample raises ValueError if fewer than 500 BODY points are available
    body_wkps = random.sample([list(point) for point in wkps['BODY']], 500)
    all_wkps.extend(body_wkps)
    return np.array(all_wkps)


def normalize_3d_coordinates(wkps):
    """Centre the keypoints and express them in their principal-axis frame.

    The mean of the first 8 rows is moved to the origin, then all points are
    projected onto the eigenvectors of ``wkps.T @ wkps``.

    The input array is left untouched: centring is done on a new array, so callers
    do not need to pass a copy and integer-typed input no longer raises.

    Args:
        wkps: array-like of shape (n_points, 3).

    Returns:
        New ``np.ndarray`` of the same shape with the transformed coordinates.
    """
    wkps = np.asarray(wkps)

    # translate fish to origin (out-of-place, so the caller's array is not modified)
    centered = wkps - np.mean(wkps[:8], axis=0)

    # perform PCA decomposition and rotate with respect to new axes
    # NOTE(review): np.linalg.eig does not sort its eigenvectors, so the axis order
    # of the output is whatever eig returns — confirm downstream code expects that.
    _, eigen_vectors = np.linalg.eig(np.dot(centered.T, centered))
    return np.dot(eigen_vectors.T, centered.T).T


def get_keypoint_arr(keypoints, cm):
    """Return ``(raw_wkps, norm_wkps)`` for one stereo annotation.

    ``raw_wkps`` is the output of ``get_raw_3d_coordinates``. ``norm_wkps`` is its
    normalized copy; when the annotation carries ``BODY`` points it is reduced to the
    first 8 rows plus the single body point closest, in the (x, z) plane, to the
    midpoint between the dorsal-fin and pelvic-fin keypoints.
    """
    dorsal_fin_idx = BODY_PARTS.index('DORSAL_FIN')
    pelvic_fin_idx = BODY_PARTS.index('PELVIC_FIN')

    raw_wkps = get_raw_3d_coordinates(keypoints, cm)
    # normalize a copy so the raw array stays as triangulated
    norm_wkps = normalize_3d_coordinates(deepcopy(raw_wkps))

    has_body = any(item['keypointType'] == 'BODY' for item in keypoints['leftCrop'])
    if not has_body:
        return raw_wkps, norm_wkps

    # rows from index 8 onward are the sampled BODY points
    body_norm_wkps = norm_wkps[8:, :]
    mid_point = 0.5 * (norm_wkps[dorsal_fin_idx] + norm_wkps[pelvic_fin_idx])

    # distance to the midpoint using only columns 0 and 2
    offsets_xz = body_norm_wkps[:, [0, 2]] - mid_point[[0, 2]]
    nearest = np.argmin(np.linalg.norm(offsets_xz, axis=1))

    keypoint_arr = np.vstack([norm_wkps[:8, :], body_norm_wkps[nearest]])
    return raw_wkps, keypoint_arr


In [None]:
# Build the dataset for captures between 2019-02-01 and 2019-09-20 (default species,
# no template-matching keypoints), scoring annotations with the AKPD model at this URL.
akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'
gtsf_dataset = GTSFDataset('2019-02-01', '2019-09-20', akpd_scorer_url)
# Prepared DataFrame held by the dataset object (not a copy).
df = gtsf_dataset.get_prepared_dataset()

In [None]:
# Positional arguments, in order: x/y/z translation bounds, then yaw/pitch/roll bounds
# in degrees, then jitter (0 = no pixel noise) and trials (3 samples per source row).
augmented_df = gtsf_dataset.generate_augmented_dataset((-0.1, 0.1), (0.8, 1.3), (-0.1, 0.1), 
                                                    (-45, 45), (-45, 45), (-5, 5), 0, 3)

In [None]:
from research.utils.optics import pixel2world
from research.gtsf_data.body_parts import BodyParts
import random
import numpy as np

# Ordered list of core body-part names; re-assigned here with the same expression
# used in the earlier cells.
BODY_PARTS = BodyParts().get_core_body_parts()


def get_raw_3D_coordinates(keypoints, cm):
    """Triangulate one stereo annotation into an array of world keypoints.

    Named keypoints come from ``pixel2world`` and are ordered by ``BODY_PARTS``.
    When a ``BODY`` item is present in the left crop, its points are converted to
    world coordinates from their disparity and a ``random.sample`` of 500 of them
    is appended.

    Args:
        keypoints: dict with ``'leftCrop'`` and ``'rightCrop'`` keypoint lists.
        cm: camera metadata dict.

    Returns:
        ``np.ndarray`` with one (x, y, z) row per keypoint.
    """

    def split_by_type(items):
        # separate the dense BODY item(s) from the named keypoints, preserving order
        named = [item for item in items if item['keypointType'] != 'BODY']
        body = [item for item in items if item['keypointType'] == 'BODY']
        return named, body

    named_left, body_left = split_by_type(keypoints['leftCrop'])
    named_right, body_right = split_by_type(keypoints['rightCrop'])

    wkps = pixel2world(named_left, named_right, cm)
    all_wkps = [list(wkps[bp]) for bp in BODY_PARTS]

    # compute BODY world keypoint coordinates
    if body_left:
        left_item, right_item = body_left[0], body_right[0]

        # depth from the horizontal disparity between the two views
        disps = np.abs(left_item['xFrame'] - right_item['xFrame'])
        depths = cm["focalLengthPixel"] * cm["baseline"] / np.array(disps)

        pixel_count_width = cm["pixelCountWidth"]
        pixel_count_height = cm["pixelCountHeight"]
        focal_length = cm["focalLength"]

        # pixel offsets from the image centre, rescaled to sensor units
        px_x = left_item['xFrame'] - pixel_count_width / 2.0
        px_z = pixel_count_height / 2.0 - left_item['yFrame']
        sensor_x = px_x * (cm["imageSensorWidth"] / pixel_count_width)
        sensor_z = px_z * (cm["imageSensorHeight"] / pixel_count_height)

        world_y = depths
        world_x = (world_y * sensor_x) / focal_length
        world_z = (world_y * sensor_z) / focal_length
        wkps['BODY'] = np.column_stack([world_x, world_y, world_z])

        # random.sample raises ValueError if fewer than 500 BODY points are available
        all_wkps.extend(random.sample([list(point) for point in wkps['BODY']], 500))

    return np.array(all_wkps)


def generate_rotation_matrix(n, theta):
    """Return the 3x3 matrix rotating by ``theta`` about the axis ``n``.

    Args:
        n: indexable with three components; used as given (it is not normalized here).
        theta: rotation angle in radians (passed directly to ``np.cos`` / ``np.sin``).

    Returns:
        ``np.ndarray`` of shape (3, 3).
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    one_minus_cos = 1 - cos_t
    nx, ny, nz = n[0], n[1], n[2]

    row_x = [
        cos_t + nx ** 2 * one_minus_cos,
        nx * ny * one_minus_cos - nz * sin_t,
        nx * nz * one_minus_cos + ny * sin_t,
    ]
    row_y = [
        ny * nx * one_minus_cos + nz * sin_t,
        cos_t + ny ** 2 * one_minus_cos,
        ny * nz * one_minus_cos - nx * sin_t,
    ]
    row_z = [
        nz * nx * one_minus_cos - ny * sin_t,
        nz * ny * one_minus_cos + nx * sin_t,
        cos_t + nz ** 2 * one_minus_cos,
    ]

    return np.array([row_x, row_y, row_z])


def _rotation_about_axis_rad(n, theta):
    """Rotation matrix for an angle ``theta`` in RADIANS about the unit axis ``n``.

    Kept private to ``normalize_3D_coordinates`` so that the normalization does not
    depend on whichever ``generate_rotation_matrix`` is currently bound in the
    notebook namespace (a later cell rebinds that name to a degrees-based version).
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array([[
        cos_t + n[0] ** 2 * (1 - cos_t),
        n[0] * n[1] * (1 - cos_t) - n[2] * sin_t,
        n[0] * n[2] * (1 - cos_t) + n[1] * sin_t
    ], [
        n[1] * n[0] * (1 - cos_t) + n[2] * sin_t,
        cos_t + n[1] ** 2 * (1 - cos_t),
        n[1] * n[2] * (1 - cos_t) - n[0] * sin_t,
    ], [
        n[2] * n[0] * (1 - cos_t) - n[1] * sin_t,
        n[2] * n[1] * (1 - cos_t) + n[0] * sin_t,
        cos_t + n[2] ** 2 * (1 - cos_t)
    ]])


def normalize_3D_coordinates(wkps):
    """Rotate world keypoints into a canonical orientation.

    Steps:
      1. Rotate so the median of the first 8 keypoints points along the +y axis.
      2. Rotate about the y-axis so the upper-lip -> tail-notch vector has no
         z component.
      3. Mirror the x-axis if the upper lip ends up at a smaller x than the tail notch.

    Args:
        wkps: ``np.ndarray`` of shape (n_points, 3); it is not modified.

    Returns:
        New ``np.ndarray`` of the same shape.
    """

    # direction from the origin to the median keypoint
    v = np.median(wkps[:8], axis=0)
    v /= np.linalg.norm(v)
    y = np.array([0, 1, 0])
    n = np.cross(y, v)
    n /= np.linalg.norm(n)
    # clamp guards against |dot| exceeding 1 by a rounding error, which would give NaN
    theta = -np.arccos(np.clip(np.dot(y, v), -1.0, 1.0))
    R = _rotation_about_axis_rad(n, theta)
    wkps = np.dot(R, wkps.T).T

    # rotate about y-axis so that fish is straight
    upper_lip_idx = BODY_PARTS.index('UPPER_LIP')
    tail_notch_idx = BODY_PARTS.index('TAIL_NOTCH')
    v = wkps[upper_lip_idx] - wkps[tail_notch_idx]

    n = np.array([0, 1, 0])
    theta = np.arctan(v[2] / v[0])
    R = _rotation_about_axis_rad(n, theta)
    wkps = np.dot(R, wkps.T).T

    # perform reflection if necessary
    if wkps[upper_lip_idx][0] < wkps[tail_notch_idx][0]:
        R = np.array([
            [-1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
        ])
        wkps = np.dot(R, wkps.T).T

    return wkps


def get_keypoint_arr(keypoints, cm):
    """Return the normalized keypoint array for one stereo annotation.

    Without ``BODY`` points this is the normalized world-keypoint array. With
    ``BODY`` points it is the first 8 normalized rows plus the single body point
    closest, in the (x, z) plane, to the midpoint between the dorsal-fin and
    pelvic-fin keypoints.
    """
    dorsal_fin_idx = BODY_PARTS.index('DORSAL_FIN')
    pelvic_fin_idx = BODY_PARTS.index('PELVIC_FIN')

    norm_wkps = normalize_3D_coordinates(get_raw_3D_coordinates(keypoints, cm))

    has_body = any(item['keypointType'] == 'BODY' for item in keypoints['leftCrop'])
    if not has_body:
        return norm_wkps

    # rows from index 8 onward are the sampled BODY points
    body_norm_wkps = norm_wkps[8:, :]
    mid_point = 0.5 * (norm_wkps[dorsal_fin_idx] + norm_wkps[pelvic_fin_idx])

    # distance to the midpoint using only columns 0 and 2
    offsets_xz = body_norm_wkps[:, [0, 2]] - mid_point[[0, 2]]
    nearest = np.argmin(np.linalg.norm(offsets_xz, axis=1))

    return np.vstack([norm_wkps[:8, :], body_norm_wkps[nearest]])


import json, os
import numpy as np
import torch
from torch import nn
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils


# network architecture
class Network(nn.Module):
    """Fully connected regressor: 24 -> 256 -> 128 -> 64 -> 1 with ReLU between layers."""

    def __init__(self):
        super().__init__()
        # Layer attribute names determine the state_dict keys that
        # load_state_dict matches against, so they are kept as-is.
        self.fc1 = nn.Linear(24, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten each sample to a vector and return the scalar output per sample."""
        hidden = x.view(x.size(0), -1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        return self.output(hidden)


class WeightEstimator(object):
    """Predicts fish weight from a stereo keypoint annotation using ``Network``."""

    def __init__(self, model_f):
        """Load the network weights from the state-dict file ``model_f``."""
        self.model = Network()
        # The network lives on the CPU, so map the checkpoint there too; without
        # this a checkpoint saved from a GPU fails to load on a CPU-only host.
        self.model.load_state_dict(torch.load(model_f, map_location='cpu'))
        # inference only
        self.model.eval()

    def predict(self, keypoints, camera_metadata):
        """Return the predicted weight for one annotation.

        Args:
            keypoints: dict with ``'leftCrop'`` / ``'rightCrop'`` keypoint lists.
            camera_metadata: camera metadata dict for the stereo pair.

        Returns:
            The network output multiplied by 1e4, as a Python float.
        """
        norm_keypoint_arr = get_keypoint_arr(keypoints, camera_metadata)
        # add a leading batch dimension of size 1
        keypoint_tensor = torch.from_numpy(np.array([norm_keypoint_arr])).float()
        # no gradients are needed for a forward-only prediction
        with torch.no_grad():
            weight_prediction = 1e4 * self.model(keypoint_tensor).item()
        return weight_prediction


In [None]:
# Load the trained weight-estimation network from a local checkpoint file.
model_f = '/root/data/alok/biomass_estimation/playground/nn_epoch_253.pb'
weight_estimator = WeightEstimator(model_f)

# For every augmented row, predict the weight twice: from the original annotation
# and from the synthetically re-posed annotation, using the same camera metadata.
original_preds, preds = [], []
for idx, row in augmented_df.iterrows():
    original_ann, ann, cm = row.original_ann, row.ann, row.cm
    original_pred = weight_estimator.predict(original_ann, cm)
    pred = weight_estimator.predict(ann, cm)
    preds.append(pred)
    original_preds.append(original_pred)
    
    
    

In [None]:
# Attach both prediction lists as columns; they were appended in row order, so they
# line up with augmented_df positionally.
augmented_df['original_pred'] = original_preds
augmented_df['pred'] = preds

In [None]:
# Raw (un-normalized) world keypoints of the first augmented row's source annotation.
x = augmented_df.original_wkps.iloc[0]
# Offset between keypoint rows 7 and 6 (presumably positions in BODY_PARTS — TODO confirm which parts).
x[7] - x[6]

In [None]:
# World keypoints of the first augmented (re-posed) sample, for comparison with x.
y = augmented_df.wkps.iloc[0]

In [None]:
# Display the raw world keypoints selected above.
x

In [None]:
# Relative error of mean prediction vs. mean ground-truth weight, per 5 cm bin of the
# sampled y translation (mean_y), restricted to samples passing an orientation filter.
y_vals = np.arange(0.7, 1.4, 0.05)
errs = []

# The orientation filter does not depend on the depth bin, so evaluate it once
# instead of re-running two row-wise applies for every bin.
# NOTE(review): rows 7/6 and 2/5 are presumably positions in BODY_PARTS — confirm which parts.
orientation_mask = (augmented_df.wkps.apply(lambda x: x[7][0] > x[6][0])) & (augmented_df.wkps.apply(lambda x: x[2][2] > x[5][2]))

for idx in range(len(y_vals) - 1):
    # strict inequalities: a value exactly on a bin edge falls in no bin
    y_mask = (augmented_df.mean_y > y_vals[idx]) & (augmented_df.mean_y < y_vals[idx + 1])
    mask = y_mask & orientation_mask
    binned_df = augmented_df[mask]
    mean_weight = binned_df.weight.mean()
    # NaN when the bin is empty
    error_pct = (binned_df.pred.mean() - mean_weight) / mean_weight
    errs.append(error_pct)
    print('Error at depth {}: {}'.format(round(y_vals[idx], 2), round(100*error_pct, 2)))
    

In [None]:
def generate_error_breakdown(df, vals, field):
    """Print the relative prediction error for each bin of ``field``.

    Consecutive entries of ``vals`` are used as open bin edges: a row belongs to a bin
    when ``lower < df[field] < upper`` (rows exactly on an edge fall in no bin). For
    each bin the printed figure is ``(mean(pred) - mean(weight)) / mean(weight)`` as a
    percentage rounded to two decimals; an empty bin prints ``nan``.

    Args:
        df: DataFrame with ``pred`` and ``weight`` columns plus the ``field`` column.
        vals: ordered sequence of bin edges.
        field: name of the column to bin on.
    """
    for lower, upper in zip(vals[:-1], vals[1:]):
        in_bin = df[(df[field] > lower) & (df[field] < upper)]
        mean_weight = in_bin.weight.mean()
        error_pct = (in_bin.pred.mean() - mean_weight) / mean_weight
        message = 'Error for {} in range {} <-> {}: {}'.format(
            field, round(lower, 2), round(upper, 2), round(100*error_pct, 2)
        )
        print(message)



In [None]:
# Error per 5-degree pitch bin across the sampled range of -45 to 45 degrees.
generate_error_breakdown(augmented_df, np.arange(-45, 50, 5), 'pitch')

In [None]:
# Overall relative error (mean pred vs. mean weight) over samples passing the orientation filter.
# NOTE(review): rows 7/6 and 2/5 are presumably positions in BODY_PARTS — confirm which parts.
mask = (augmented_df.wkps.apply(lambda x: x[7][0] > x[6][0])) & (augmented_df.wkps.apply(lambda x: x[2][2] > x[5][2]))
(augmented_df[mask].pred.mean() - augmented_df[mask].weight.mean()) / (augmented_df[mask].weight.mean())

In [None]:
# Histogram of each augmented sample's mean value in column 1 of its world keypoints
# (the coordinate used as depth elsewhere in this notebook).
plt.hist(augmented_df.wkps.apply(lambda x: x[:, 1].mean()))
plt.show()

In [None]:
# Download and load the AKPD scorer network, then score every augmented annotation.
s3_access_utils = S3AccessUtils('/root/data')
akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'
akpd_scorer_path, _, _ = s3_access_utils.download_from_url(akpd_scorer_url)
akpd_scorer_network = load_model(akpd_scorer_path)
akpd_scores = []

for idx, row in augmented_df.iterrows():
    # synthetic samples have no database id, so a constant 0 is passed as stereo_pair_id
    input_sample = {
        'keypoints': row.ann,
        'cm': row.cm,
        'stereo_pair_id': 0,
        'single_point_inference': True
    }
    akpd_score = generate_confidence_score(input_sample, akpd_scorer_network)
    akpd_scores.append(akpd_score)

In [None]:
# Attach the scores as a column; they were appended in row order.
augmented_df['akpd_score'] = akpd_scores

In [None]:
def generate_rotation_matrix(n, theta):
    """Return the 3x3 matrix rotating by ``theta`` DEGREES about the axis ``n``.

    Note: this takes degrees, whereas the ``generate_rotation_matrix`` defined in an
    earlier cell passes its angle straight to ``np.cos`` / ``np.sin`` (radians). Running
    this cell rebinds the name for every later caller in the notebook.

    Args:
        n: indexable with three components; used as given (it is not normalized here).
        theta: rotation angle in degrees.

    Returns:
        ``np.ndarray`` of shape (3, 3).
    """
    theta = theta * np.pi / 180.0
    c, s = np.cos(theta), np.sin(theta)
    k = 1 - c
    ax, ay, az = n[0], n[1], n[2]

    return np.array([
        [c + ax ** 2 * k, ax * ay * k - az * s, ax * az * k + ay * s],
        [ay * ax * k + az * s, c + ay ** 2 * k, ay * az * k - ax * s],
        [az * ax * k - ay * s, az * ay * k + ax * s, c + az ** 2 * k],
    ])


In [None]:
# Inspect the world keypoints of the first row whose AKPD score exceeds 0.9.
df[df.akpd_score > 0.9].world_keypoints.iloc[0]

In [None]:
# Synthetic robustness experiment: for every high-confidence row (AKPD score > 0.9),
# re-pose the fish at a random depth and random rotation about the z-axis, re-project it
# into both images with pixel noise, and record the predictions for the synthetic pair.
n = [0, 0, 1]  # rotation axis
akpd_scores = []
ids, preds, preds_original, gts, thetas, depths = [], [], [], [], [], []
analysis_df = pd.DataFrame()
for idx, row in df[df.akpd_score > 0.9].iterrows():
    wkps = row.raw_keypoint_arr
    # Use this row's own camera metadata for the projection. Previously `cm` was never
    # assigned in this cell and silently reused whatever an earlier cell left behind.
    cm = row.camera_metadata
    for t in range(3):
        new_wkps = normalize_3D_coordinates(wkps)
        # place the fish so that its mean value in column 1 (depth) equals `depth`
        depth = np.random.uniform(0.5, 2.0)
        new_wkps[:, 1] = new_wkps[:, 1] - new_wkps[:, 1].mean() + depth
        # theta is in degrees, matching the generate_rotation_matrix defined in the cell above
        theta = np.random.uniform(-30, 30)
        R = generate_rotation_matrix(n, theta)
        new_wkps = np.dot(R, new_wkps.T).T

        ann_left, ann_right = [], []
        # separate index name so the outer DataFrame index `idx` is not clobbered
        for bp_idx, body_part in enumerate(BODY_PARTS):
            # generate left item
            x, y, z = new_wkps[bp_idx]
            x_frame_left = x * cm['focalLengthPixel'] / y + cm['pixelCountWidth'] / 2
            y_frame_left = -z * cm['focalLengthPixel'] / y + cm['pixelCountHeight'] / 2
            item_left = {
                'keypointType': body_part,
                'xFrame': x_frame_left + np.random.normal(0, 10),
                'yFrame': y_frame_left
            }
            ann_left.append(item_left)

            # generate right item
            disparity = cm['focalLengthPixel'] * cm['baseline'] / y
            x_frame_right = x_frame_left - disparity
            y_frame_right = y_frame_left
            item_right = {
                'keypointType': body_part,
                'xFrame': x_frame_right + np.random.normal(0, 10),
                'yFrame': y_frame_right
            }
            ann_right.append(item_right)

        new_kps = {'leftCrop': ann_left, 'rightCrop': ann_right}
        input_sample = {
            'keypoints': new_kps,
            'cm': row.camera_metadata,
            'stereo_pair_id': row.id,
            'single_point_inference': True
        }
        akpd_score = generate_confidence_score(input_sample, akpd_scorer_network)
        akpd_scores.append(akpd_score)

        # NOTE(review): the three transform objects and `network` are not defined in any
        # visible cell of this notebook — confirm they are provided by the star imports
        # or by a cell that was removed.
        normalized_centered_2D_kps = normalize_centered_2D_transform(input_sample)
        normalized_stability_kps = normalized_stability_transform(normalized_centered_2D_kps)
        tensorized_kps = to_tensor_transform(normalized_stability_kps)
        pred_original = network(tensorized_kps['kp_input']).item() * 1e4

        pred = weight_estimator.predict(new_kps, row.camera_metadata)
        preds.append(pred)
        preds_original.append(pred_original)
        gts.append(row.weight)
        ids.append(row.id)
        thetas.append(theta)
        depths.append(depth)

In [None]:
# One row per synthetic sample from the experiment above: prediction, ground-truth
# weight, source row id, and the sampled rotation angle and depth.
kdf = pd.DataFrame({
    'pred': preds,
    'weight': gts,
    'id': ids,
    'theta': thetas,
    'depth': depths
})

In [None]:
# Per-sample signed relative error of the prediction.
kdf['error'] = (kdf.pred - kdf.weight) / kdf.weight

In [None]:
# Inspect the keypoint annotation of the first prepared row.
df.keypoints.iloc[0]

In [None]:
# Error per 5-degree bin of the sampled rotation angle (-30 to 30 degrees).
generate_error_breakdown(kdf, np.arange(-30, 35, 5), 'theta')

In [None]:
# Error per 0.1 bin of the sampled depth (0.5 to 2.0).
generate_error_breakdown(kdf, np.arange(0.5, 2.1, 0.1), 'depth')