In [None]:
import json
import os
from collections import defaultdict

import cv2
import numpy as np
import pandas as pd
import torch

from research.utils.data_access_utils import RDSAccessUtils
from research.weight_estimation.keypoint_utils.optics import pixel2world
from research_lib.utils.data_access_utils import S3AccessUtils
from weight_estimation.body_parts import core_body_parts
from weight_estimation.utils import convert_to_world_point_arr, get_left_right_keypoint_arrs, normalize_left_right_keypoint_arrs, CameraMetadata, \
    stabilize_keypoints, convert_to_nn_input
from weight_estimation.weight_estimator import WeightEstimator, CameraMetadata

# NOTE(review): `plt` is used by many cells below but matplotlib is not imported
# anywhere in the visible notebook -- confirm where it is expected to come from.


In [None]:
# Path of the JSON credentials file that the RDSAccessUtils setup in the next
# code cell reads via this environment variable.
os.environ['PLALI_SQL_CREDENTIALS'] = '/run/secrets/plali_sql_credentials'

In [None]:
# Experiment catalogue consumed by the driver loop further down. Each entry has:
#   name                  -- key under which the resulting frame is stored in `dfs`
#   stereo_parameters_url -- calibration JSON passed to add_camera_metadata
#   workflow_id           -- PLALI workflow id used in the SQL filter
#   metadata_type         -- compared against str(metadata.get('type')) in
#                            get_annotated_data, so the string 'None' selects rows
#                            whose metadata has no 'type' (or a None type)
toy_fish_experiments = [
    {
        'name': 'B2_fish_moving_around_v1',
        'stereo_parameters_url': 'https://aquabyte-abc.s3-eu-west-1.amazonaws.com/rook/2021-03-10T13:57:48Z-pfe-1421920048928-187-4bd8/cal_output/2021-03-10T14-07-03.821272000Z/stereo_params.json',
        'workflow_id': '00000000-0000-0000-0000-000000000055',
        'metadata_type': 'Dale P3 post-swap enclosure'
    },
    # NOTE(review): this is the only entry with workflow id ...056 (all others use ...055).
    {
        'name': 'C_fish_moving_around_v1',
        'stereo_parameters_url': 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12:11:24.770071000Z_L40020313_R40013177_stereo-parameters.json',
        'workflow_id': '00000000-0000-0000-0000-000000000056',
        'metadata_type': 'None'
    },
    {
        'name': 'A2_fish_moving_around_v2',
        'stereo_parameters_url': 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40029797_R40020184/2021-02-25T11:30:42.149694000Z_L40029797_R40020184_stereo-parameters.json',
        'workflow_id': '00000000-0000-0000-0000-000000000055',
        'metadata_type': 'Dale P3 pre-swap -- Round 2'
    },
    {
        'name': 'B2_fish_moving_around_v2',
        'stereo_parameters_url': 'https://aquabyte-abc.s3-eu-west-1.amazonaws.com/rook/2021-03-10T13:57:48Z-pfe-1421920048928-187-4bd8/cal_output/2021-03-10T14-07-03.821272000Z/stereo_params.json',
        'workflow_id': '00000000-0000-0000-0000-000000000055',
        'metadata_type': 'Dale P3 post-swap -- Round 2'
    },
    {
        'name': 'C_fish_moving_around_v2',
        'stereo_parameters_url': 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12:11:24.770071000Z_L40020313_R40013177_stereo-parameters.json',
        'workflow_id': '00000000-0000-0000-0000-000000000055',
        'metadata_type': 'other enclosure -- Round 2'
    },
    # NOTE(review): this A2 entry points at the L40020313_R40013177 calibration file
    # (the same URL as the C_* entries), whereas A2_fish_moving_around_v2 uses
    # L40029797_R40020184 -- confirm this is intentional and not a copy/paste slip.
    {
        'name': 'A2_fish_static_v1',
        'stereo_parameters_url': 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12:11:24.770071000Z_L40020313_R40013177_stereo-parameters.json',
        'workflow_id': '00000000-0000-0000-0000-000000000055',
        'metadata_type': 'Dale P3 pre-swap -- static - Round 1'
    },
    {
        'name': 'B2_fish_static_v1',
        'stereo_parameters_url': 'https://aquabyte-abc.s3-eu-west-1.amazonaws.com/rook/2021-03-10T13:57:48Z-pfe-1421920048928-187-4bd8/cal_output/2021-03-10T14-07-03.821272000Z/stereo_params.json',
        'workflow_id': '00000000-0000-0000-0000-000000000055',
        'metadata_type': 'Dale P3 post-swap -- static - Round 1'
    },
    {
        'name': 'C_fish_static_v1',
        'stereo_parameters_url': 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12:11:24.770071000Z_L40020313_R40013177_stereo-parameters.json',
        'workflow_id': '00000000-0000-0000-0000-000000000055',
        'metadata_type': 'other enclosure -- static - Round 1'
    }
]

In [None]:
# Database handle used by get_annotated_data. The credentials JSON is read inside
# a `with` block so the file handle is closed instead of being leaked.
with open(os.environ['PLALI_SQL_CREDENTIALS']) as credentials_f:
    rds = RDSAccessUtils(json.load(credentials_f))

# add_camera_metadata and add_weights (defined in this cell) read a module-level
# `s3`, which the notebook otherwise only creates in a much later cell. Creating it
# here lets the notebook run top to bottom on a fresh kernel.
s3 = S3AccessUtils('/root/data')

def get_annotated_data(workflow_id, metadata_type):
    """Fetch the PLALI annotations of one workflow, filtered by image metadata type.

    Args:
        workflow_id: workflow id substituted into the SQL ``where`` clause.
        metadata_type: string compared against ``str(metadata.get('type'))`` of each
            row; because of the ``str()`` call, passing ``'None'`` selects rows whose
            metadata has no ``'type'`` entry.

    Returns:
        An independent DataFrame (a copy) with the matching rows, so that callers
        can add columns in place without pandas chained-assignment warnings.
    """
    # NOTE(review): workflow_id is interpolated straight into the SQL text. That is
    # only acceptable while ids come from the hard-coded experiment list; switch to
    # a parameterized query if this is ever fed external input.
    query = """
        select * from plali.plali_annotations x
        inner join 
        ( select a.id as plali_image_id, a.images, a.metadata, b.id as workflow_id, b.name from plali.plali_images a
        inner join plali.plali_workflows b
        on a.workflow_id = b.id ) y
        on x.plali_image_id = y.plali_image_id
        where workflow_id = '{}';
    """.format(workflow_id)

    annotated_df = rds.extract_from_database(query)
    type_matches = annotated_df.metadata.apply(lambda x: str(x.get('type')) == metadata_type)
    # The add_* helpers assign new columns on the returned frame; returning a copy
    # rather than a boolean-mask slice keeps those assignments unambiguous.
    return annotated_df[type_matches].copy()



class AnnotationFormatError(Exception):
    """Raised when a raw annotation cannot be converted to the keypoint format."""


_ANNOTATION_SIDES = ('leftCrop', 'rightCrop')
_EXPECTED_KEYPOINTS_PER_SIDE = 11


def _convert_annotation(raw_ann):
    """Convert one raw annotation into ``{'leftCrop': [...], 'rightCrop': [...]}``.

    Raises:
        AnnotationFormatError: if the annotation was skipped, is structurally
            malformed (not a mapping, missing side, missing coordinate/category
            keys) or does not contain exactly 11 keypoints on each side.
    """
    try:
        if 'skipReasons' in raw_ann:
            raise AnnotationFormatError('annotation was skipped')
        ann = {
            side: [
                {
                    'xCrop': raw_item['xCrop'],
                    'yCrop': raw_item['yCrop'],
                    # NOTE(review): frame coordinates are filled from the crop
                    # coordinates, exactly as before -- confirm crops are full frames.
                    'xFrame': raw_item['xCrop'],
                    'yFrame': raw_item['yCrop'],
                    'keypointType': raw_item['category'],
                }
                for raw_item in raw_ann[side]['annotation']['annotations']
            ]
            for side in _ANNOTATION_SIDES
        }
    except (KeyError, TypeError) as err:
        # Missing keys / wrong container types mean the payload is malformed;
        # previously these escaped as raw KeyError/TypeError and aborted the run.
        raise AnnotationFormatError('malformed annotation: {!r}'.format(err)) from err

    if any(len(ann[side]) != _EXPECTED_KEYPOINTS_PER_SIDE for side in _ANNOTATION_SIDES):
        raise AnnotationFormatError('unexpected number of keypoints')
    return ann


def add_anns(annotated_df):
    """Add an ``ann`` column holding the converted keypoint annotation of each row.

    Reads ``row.annotation`` for every row and stores the converted dict, or
    ``None`` when the annotation is skipped or malformed. The frame is modified in
    place; nothing is returned.
    """
    anns = []
    for _, row in annotated_df.iterrows():
        try:
            anns.append(_convert_annotation(row.annotation))
        except AnnotationFormatError:
            anns.append(None)

    annotated_df['ann'] = anns


def add_camera_metadata(df, stereo_parameters_url):
    """Add a ``camera_metadata`` column derived from a stereo-parameters file.

    Downloads the JSON at ``stereo_parameters_url`` through the module-level ``s3``
    helper, builds one camera-metadata dict from it and stores its JSON encoding
    (the same string for every row) in ``df['camera_metadata']``. Modifies ``df``
    in place and returns nothing.
    """
    stereo_parameters_f, _, _ = s3.download_from_url(stereo_parameters_url)
    # Use a context manager so the downloaded file is closed after parsing.
    with open(stereo_parameters_f) as stereo_parameters_file:
        stereo_parameters = json.load(stereo_parameters_file)

    focal_length_pixel = stereo_parameters['CameraParameters1']['FocalLength'][0]
    camera_metadata = {
        'focalLengthPixel': focal_length_pixel,
        # presumably converts a millimetre translation to metres -- confirm
        'baseline': abs(stereo_parameters['TranslationOfCamera2'][0] / 1e3),
        # 3.45e-6 is presumably the pixel pitch in metres
        # (imageSensorWidth / pixelCountWidth = 0.01412 / 4096 ~= 3.45e-6) -- confirm
        'focalLength': focal_length_pixel * 3.45e-6,
        'pixelCountWidth': 4096,
        'pixelCountHeight': 3000,
        'imageSensorWidth': 0.01412,
        'imageSensorHeight': 0.01035
    }

    df['camera_metadata'] = json.dumps(camera_metadata)
    
    
def add_weights(df):
    """Add a ``weight`` column with the model-predicted weight of each row.

    Downloads the weight and k-factor models through the module-level ``s3``
    helper, then runs ``WeightEstimator.predict`` on every row that has an ``ann``.
    Rows whose ``ann`` is ``None`` get ``None``. Modifies ``df`` in place.
    """
    weight_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb')
    kf_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb')
    estimator = WeightEstimator(weight_model_f, kf_model_f)

    def _predict_row(row):
        """Return the predicted weight for one row, or None without an annotation."""
        metadata = json.loads(row.camera_metadata)
        if row.ann is None:
            return None
        cm = CameraMetadata(
            focal_length=metadata['focalLength'],
            focal_length_pixel=metadata['focalLengthPixel'],
            baseline_m=metadata['baseline'],
            pixel_count_width=metadata['pixelCountWidth'],
            pixel_count_height=metadata['pixelCountHeight'],
            image_sensor_width=metadata['imageSensorWidth'],
            image_sensor_height=metadata['imageSensorHeight']
        )
        # predict returns a 3-tuple; only its first element is kept.
        predicted, _, _ = estimator.predict(row.ann, cm)
        return predicted

    df['weight'] = [_predict_row(row) for _, row in df.iterrows()]
    
    

def add_spatial_attributes(df):
    """Add ``yaw``, ``pitch``, ``roll`` (degrees) and ``depth`` columns to ``df``.

    For each row the keypoints are converted with ``pixel2world``; depth is the
    median of component 1 over all keypoints, yaw/pitch come from the
    TAIL_NOTCH -> UPPER_LIP vector and roll from the ANAL_FIN -> ADIPOSE_FIN
    vector. If a ``TypeError`` is raised (e.g. ``ann`` is ``None``), all four
    values are ``None`` for that row. Modifies ``df`` in place.
    """
    deg_per_rad = 180.0 / np.pi
    per_row = []
    for _, row in df.iterrows():
        ann, cm = row.ann, json.loads(row.camera_metadata)
        try:
            keypoints_3d = pixel2world(ann['leftCrop'], ann['rightCrop'], cm)
            depth = np.median([point[1] for point in keypoints_3d.values()])
            fin_axis = keypoints_3d['ADIPOSE_FIN'] - keypoints_3d['ANAL_FIN']
            body_axis = keypoints_3d['UPPER_LIP'] - keypoints_3d['TAIL_NOTCH']
            yaw = np.arctan(body_axis[1] / abs(body_axis[0])) * deg_per_rad
            pitch = np.arctan(body_axis[2] / abs(body_axis[0])) * deg_per_rad
            roll = np.arctan(fin_axis[1] / fin_axis[2]) * deg_per_rad
            per_row.append((yaw, pitch, roll, depth))
        except TypeError:
            per_row.append((None, None, None, None))

    df['yaw'] = [values[0] for values in per_row]
    df['pitch'] = [values[1] for values in per_row]
    df['roll'] = [values[2] for values in per_row]
    df['depth'] = [values[3] for values in per_row]


    
    


In [None]:
# Build one fully annotated DataFrame per experiment, keyed by experiment name.
dfs = {}
for experiment in toy_fish_experiments:
    name = experiment['name']
    stereo_parameters_url = experiment['stereo_parameters_url']
    annotated_df = get_annotated_data(experiment['workflow_id'], experiment['metadata_type'])
    # Each helper below adds columns to annotated_df in place.
    add_anns(annotated_df)
    add_camera_metadata(annotated_df, stereo_parameters_url)
    # NOTE: add_weights downloads the models and builds a WeightEstimator on every call,
    # i.e. once per experiment.
    add_weights(annotated_df)
    add_spatial_attributes(annotated_df)
    # Store an independent copy of the finished frame.
    dfs[name] = annotated_df.copy(deep=True)
    
    

In [None]:
# Per-experiment summary: mean predicted weight and mean yaw / pitch / roll.
result_data = defaultdict(list)
for experiment_name, df in dfs.items():
    result_data['experiment_name'].append(experiment_name)
    result_data['average_weight'].append(df.weight.mean())
    result_data['average_yaw'].append(df.yaw.mean())
    result_data['average_pitch'].append(df.pitch.mean())
    result_data['average_roll'].append(df.roll.mean())


# One row per experiment.
result_df = pd.DataFrame(result_data)

In [None]:
# Display the per-experiment summary table.
result_df

In [None]:
# Histogram of predicted weights for C_fish_static_v1, keeping only rows with weight < 6000.
plt.figure(figsize=(15, 8))
mask = dfs['C_fish_static_v1'].weight < 6000
plt.hist(dfs['C_fish_static_v1'][mask].weight.values)
plt.grid()
plt.show()

In [None]:
# Histogram of predicted weights for A2_fish_moving_around_v2 (no outlier filter).
plt.figure(figsize=(15, 8))
plt.hist(dfs['A2_fish_moving_around_v2'].weight.values)
plt.grid()
plt.show()

In [None]:
# Histogram of predicted weights for B2_fish_static_v1 (no outlier filter).
plt.figure(figsize=(15, 8))
plt.hist(dfs['B2_fish_static_v1'].weight.values)
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for B2_fish_moving_around_v1.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['B2_fish_moving_around_v1'].roll.values, dfs['B2_fish_moving_around_v1'].weight.values)
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for B2_fish_moving_around_v2.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['B2_fish_moving_around_v2'].roll.values, dfs['B2_fish_moving_around_v2'].weight.values)
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for C_fish_moving_around_v1.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['C_fish_moving_around_v1'].roll.values, dfs['C_fish_moving_around_v1'].weight.values)
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for C_fish_moving_around_v2.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['C_fish_moving_around_v2'].roll.values, dfs['C_fish_moving_around_v2'].weight.values)
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for A2_fish_moving_around_v2.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['A2_fish_moving_around_v2'].roll.values, dfs['A2_fish_moving_around_v2'].weight.values)
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Yaw vs. predicted weight for A2_fish_moving_around_v2.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['A2_fish_moving_around_v2'].yaw.values, dfs['A2_fish_moving_around_v2'].weight.values)
plt.xlabel('Yaw')  # the x data is the yaw column; the label previously read 'Roll'
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Pitch vs. predicted weight for A2_fish_moving_around_v2.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['A2_fish_moving_around_v2'].pitch.values, dfs['A2_fish_moving_around_v2'].weight.values)
plt.xlabel('Pitch')  # the x data is the pitch column; the label previously read 'Roll'
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for B2_fish_moving_around_v2 (repeat of an earlier plot).
plt.figure(figsize=(10, 5))
plt.scatter(dfs['B2_fish_moving_around_v2'].roll.values, dfs['B2_fish_moving_around_v2'].weight.values)
# Axis labels added to match the other roll/weight scatter cells.
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for C_fish_moving_around_v2 (repeat of an earlier plot).
plt.figure(figsize=(10, 5))
plt.scatter(dfs['C_fish_moving_around_v2'].roll.values, dfs['C_fish_moving_around_v2'].weight.values)
# Axis labels added to match the other roll/weight scatter cells.
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Roll vs. predicted weight for B2_fish_static_v1.
plt.figure(figsize=(10, 5))
plt.scatter(dfs['B2_fish_static_v1'].roll.values, dfs['B2_fish_static_v1'].weight.values)
# Axis labels added to match the other roll/weight scatter cells.
plt.xlabel('Roll')
plt.ylabel('Weight')
plt.grid()
plt.show()

In [None]:
# Histogram of the roll values for B2_fish_static_v1.
plt.figure(figsize=(10, 5))
plt.hist(dfs['B2_fish_static_v1'].roll.values)
plt.grid()
plt.show()

In [None]:
from typing import Tuple
# Module-level S3 helper used by every s3.download_from_url call in this notebook.
# '/root/data' is its only constructor argument (presumably a local download
# directory -- confirm against S3AccessUtils).
s3 = S3AccessUtils('/root/data')

def load_params(params):
    """Return the (left, right) undistort/rectify maps for a stereo-parameters dict.

    This used to repeat, line for line, the rectification logic of
    ``get_camera_parameters`` (with the same 4096x3000 image size hard-coded). It
    now delegates to that function and keeps only the first two of its return
    values, so both entry points stay in sync.

    Args:
        params: parsed stereo-parameters JSON (``CameraParameters1/2``,
            ``RotationOfCamera2``, ``TranslationOfCamera2``).

    Returns:
        ``(left_maps, right_maps)`` as produced by ``cv2.initUndistortRectifyMap``.
    """
    print("Loading params...")
    left_maps, right_maps = get_camera_parameters(params)[:2]
    print("Params loaded.")
    return left_maps, right_maps

# Image size (in pixels) handed to the OpenCV rectification calls below.
IMAGE_WIDTH = 4096
IMAGE_HEIGHT = 3000


def get_camera_parameters(params: dict) -> Tuple:
    """Compute rectification maps and per-camera matrices from stereo parameters.

    Args:
        params: parsed stereo-parameters JSON with ``CameraParameters1``,
            ``CameraParameters2``, ``RotationOfCamera2`` and
            ``TranslationOfCamera2`` entries.

    Returns:
        The 10-tuple ``(left_maps, right_maps, cameraMatrix1, distCoeffs1, R1, P1,
        cameraMatrix2, distCoeffs2, R2, P2)``.
    """

    def _intrinsics(camera_key):
        """Return (transposed intrinsic matrix, distortion vector) for one camera."""
        camera = params[camera_key]
        matrix = np.array(camera['IntrinsicMatrix']).transpose()
        radial = camera['RadialDistortion']
        # Radial terms 0-1, then the tangential terms, then radial term 2.
        coeffs = np.array(radial[0:2] + camera['TangentialDistortion'] + [radial[2]])
        return matrix, coeffs

    camera_matrix_1, dist_coeffs_1 = _intrinsics('CameraParameters1')
    camera_matrix_2, dist_coeffs_2 = _intrinsics('CameraParameters2')

    rotation = np.array(params['RotationOfCamera2']).transpose()
    translation = np.array(params['TranslationOfCamera2']).transpose()

    size = (IMAGE_WIDTH, IMAGE_HEIGHT)
    rectified = cv2.stereoRectify(camera_matrix_1, dist_coeffs_1, camera_matrix_2,
                                  dist_coeffs_2, size, rotation, translation, None, None,
                                  None, None, None, cv2.CALIB_ZERO_DISPARITY, 0)
    # stereoRectify returns (R1, R2, P1, P2, Q, leftROI, rightROI); only the
    # rectification rotations and projection matrices are used here.
    rect_1, rect_2, proj_1, proj_2 = rectified[:4]

    left_maps = cv2.initUndistortRectifyMap(camera_matrix_1, dist_coeffs_1, rect_1, proj_1, size, cv2.CV_16SC2)
    right_maps = cv2.initUndistortRectifyMap(camera_matrix_2, dist_coeffs_2, rect_2, proj_2, size, cv2.CV_16SC2)

    return (left_maps, right_maps, camera_matrix_1, dist_coeffs_1, rect_1, proj_1,
            camera_matrix_2, dist_coeffs_2, rect_2, proj_2)



def un_re_rectify(annotated_df, stereo_parameters_o_url, stereo_parameters_n_url):
    """Re-express annotated keypoints under a different stereo calibration.

    For every keypoint in ``row.ann`` the pixel is looked up in the *old*
    rectification map (``stereo_parameters_o_url``) and the resulting point is
    passed through ``cv2.undistortPoints`` with the *new* calibration
    (``stereo_parameters_n_url``). The rounded results are stored in a new
    ``ann_u_r`` column (``None`` where ``ann`` is ``None``). The frame is modified
    in place and nothing is returned.

    Changes versus the previous version: the parameter files are closed after
    parsing, and an unused ``camera_metadata`` dict built from the new parameters
    has been removed. If metadata for the new calibration is needed, call
    ``add_camera_metadata(annotated_df, stereo_parameters_n_url)``.
    """
    stereo_parameters_o_f, _, _ = s3.download_from_url(stereo_parameters_o_url)
    stereo_parameters_n_f, _, _ = s3.download_from_url(stereo_parameters_n_url)

    with open(stereo_parameters_o_f) as params_file:
        stereo_params_o = json.load(params_file)
    with open(stereo_parameters_n_f) as params_file:
        stereo_params_n = json.load(params_file)

    # Only the rectification maps are needed from the old calibration ...
    left_maps_o, right_maps_o = get_camera_parameters(stereo_params_o)[:2]
    # ... and only the camera matrices / R / P from the new one.
    (_, _, cameraMatrix1_n, distCoeffs1_n, R1_n, P1_n,
     cameraMatrix2_n, distCoeffs2_n, R2_n, P2_n) = get_camera_parameters(stereo_params_n)

    # Per side: (old lookup map, new camera matrix, new distortion, new R, new P).
    side_params = {
        'leftCrop': (left_maps_o[0], cameraMatrix1_n, distCoeffs1_n, R1_n, P1_n),
        'rightCrop': (right_maps_o[0], cameraMatrix2_n, distCoeffs2_n, R2_n, P2_n),
    }

    ann_u_rs = []
    for _, row in annotated_df.iterrows():
        ann = row.ann
        if ann is None:
            ann_u_rs.append(None)
            continue

        # un-rectify with matlab params, re-rectify with circular params
        ann_u_r = {}
        for side, (old_map, camera_matrix, dist_coeffs, R_n, P_n) in side_params.items():
            converted = []
            for item in ann[side]:
                x = item['xFrame']
                y = item['yFrame']
                # old_map is indexed [row, column]; per the OpenCV docs the first
                # CV_16SC2 map holds the source (x, y) for each rectified pixel --
                # confirm against the installed OpenCV version.
                source_point = np.array([[old_map[y, x]]]).astype(float)
                x_new, y_new = cv2.undistortPoints(source_point, camera_matrix, dist_coeffs,
                                                   R=R_n, P=P_n)[0][0]
                converted.append({
                    'keypointType': item['keypointType'],
                    'xFrame': int(round(x_new)),
                    'yFrame': int(round(y_new)),
                })
            ann_u_r[side] = converted

        ann_u_rs.append(ann_u_r)

    annotated_df['ann_u_r'] = ann_u_rs
    

In [None]:
# NOTE(review): stereo_param_dict is only assigned in the following cell, so on a
# fresh kernel this display raises NameError.
stereo_param_dict

In [None]:
# Maps each df_dict key to an (old, new) pair of stereo-parameter URLs for
# un_re_rectify. The '%3A' sequences are replaced by ':' before use.
# NOTE(review): in every entry the old and new URLs are identical, so the
# un-rectify / re-rectify step uses the same calibration on both sides.
stereo_param_dict = {
    'pre-swap': ('https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40029797_R40020184/2021-02-25T11:30:42.149694000Z_L40029797_R40020184_stereo-parameters.json', 
                 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40029797_R40020184/2021-02-25T11:30:42.149694000Z_L40029797_R40020184_stereo-parameters.json'),
    'post-swap': ('https://aquabyte-abc.s3-eu-west-1.amazonaws.com/rook/2021-03-10T13%3A57%3A48Z-pfe-1421920048928-187-4bd8/cal_output/2021-03-10T14-07-03.821272000Z/stereo_params.json'.replace('%3A', ':'),
                  'https://aquabyte-abc.s3-eu-west-1.amazonaws.com/rook/2021-03-10T13%3A57%3A48Z-pfe-1421920048928-187-4bd8/cal_output/2021-03-10T14-07-03.821272000Z/stereo_params.json'.replace('%3A', ':')),
    'other': ('https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12%3A11%3A24.770071000Z_L40020313_R40013177_stereo-parameters.json'.replace('%3A', ':'), 
              'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12%3A11%3A24.770071000Z_L40020313_R40013177_stereo-parameters.json'.replace('%3A', ':'))
}


# NOTE(review): neither `df_dict` nor `_add_depth` is defined anywhere in the
# visible notebook; this loop fails on a fresh kernel unless they come from
# elsewhere.
for k, v in df_dict.items():
    add_anns(v)
    _add_depth(v)
    stereo_parameters_o_url, stereo_parameters_n_url = stereo_param_dict[k]
    un_re_rectify(v, stereo_parameters_o_url, stereo_parameters_n_url)


<h1> Conduct pairwise distance comparison </h1>

In [None]:
def convert_to_world_point_arr(X_left: np.ndarray, X_right: np.ndarray,
                               camera_metadata: CameraMetadata) -> np.ndarray:
    """Triangulate left/right keypoint arrays into an (N, 3) world-coordinate array.

    Column 0 of each input is the horizontal pixel coordinate and column 1 the
    vertical one. The pixel coordinates are used as they are (no principal-point
    offset is subtracted), so the inputs are expected to be centred already.
    This definition shadows the function of the same name imported from
    ``weight_estimation.utils``.

    Returns:
        Array with columns ``(x, y, z)``, where ``y`` is the depth obtained from
        the left/right disparity.
    """
    cm = camera_metadata
    disparity = X_left[:, 0] - X_right[:, 0]
    depth = cm.focal_length_pixel * cm.baseline_m / disparity

    # Pixel offsets scaled to sensor units; the vertical axis is flipped.
    sensor_x = X_left[:, 0] * cm.image_sensor_width / cm.pixel_count_width
    sensor_z = -(X_left[:, 1] * cm.image_sensor_height / cm.pixel_count_height)

    lateral = (sensor_x * depth) / cm.focal_length
    vertical = (sensor_z * depth) / cm.focal_length
    return np.vstack([lateral, depth, vertical]).T

In [None]:
def convert_to_world_point_arr_2(X_left: np.ndarray, X_right: np.ndarray,
                               camera_metadata: CameraMetadata) -> np.ndarray:
    """Triangulate raw-pixel left/right keypoint arrays into (N, 3) world coordinates.

    Unlike ``convert_to_world_point_arr``, this variant subtracts half the image
    width/height from the left-image pixel coordinates before scaling, i.e. it
    takes un-centred frame coordinates as input.

    Returns:
        Array with columns ``(x, y, z)``, where ``y`` is the depth obtained from
        the left/right disparity.
    """
    cm = camera_metadata
    disparity = X_left[:, 0] - X_right[:, 0]
    depth = cm.focal_length_pixel * cm.baseline_m / disparity

    # Offsets from the image centre, scaled to sensor units; vertical axis flipped.
    centred_x = X_left[:, 0] - cm.pixel_count_width / 2.0
    centred_z = X_left[:, 1] - cm.pixel_count_height / 2.0
    sensor_x = centred_x * cm.image_sensor_width / cm.pixel_count_width
    sensor_z = -(centred_z * cm.image_sensor_height / cm.pixel_count_height)

    lateral = (sensor_x * depth) / cm.focal_length
    vertical = (sensor_z * depth) / cm.focal_length
    return np.vstack([lateral, depth, vertical]).T

In [None]:
# For every frame in df_dict, convert each annotated row into world coordinates
# using *normalized* keypoints and the convert_to_world_point_arr defined above.
# NOTE(review): df_dict is not defined in the visible notebook.
X_w_dict = {}
for key, df in df_dict.items():
    # Keep only rows that have a usable annotation.
    tdf = df[~df.ann.isnull()]
    
    X_w = []
    for idx, row in tdf.iterrows():
        cm = json.loads(row.camera_metadata)
        camera_metadata = CameraMetadata(
            focal_length=cm['focalLength'],
            focal_length_pixel=cm['focalLengthPixel'],
            baseline_m=cm['baseline'],
            pixel_count_width=cm['pixelCountWidth'],
            pixel_count_height=cm['pixelCountHeight'],
            image_sensor_width=cm['imageSensorWidth'],
            image_sensor_height=cm['imageSensorHeight']
        )
        
        X_left, X_right = get_left_right_keypoint_arrs(row.ann)
        X_left, X_right = normalize_left_right_keypoint_arrs(X_left, X_right)
        X_world = convert_to_world_point_arr(X_left, X_right, camera_metadata)
        X_w.append(X_world)
        
    # One world-coordinate array per annotated row, stacked along the first axis.
    X_w_dict[key] = np.array(X_w)
        

In [None]:
from scipy.spatial.distance import pdist

# For each df_dict key, compute the pairwise keypoint distances (pdist) of every
# sample in X_w_dict and stack them into one array per key (one row per sample).
dist_arr_dict = {}
for key, _ in df_dict.items():
    dist_arr = []
    for idx in range(X_w_dict[key].shape[0]):
        dist_arr.append(pdist(X_w_dict[key][idx]))
    
    dist_arr = np.array(dist_arr)
    dist_arr_dict[key] = dist_arr

In [None]:
# Median of each pairwise distance across all 'pre-swap' samples.
np.median(dist_arr_dict['pre-swap'], axis=0)

In [None]:
# Median of each pairwise distance across all 'post-swap' samples.
np.median(dist_arr_dict['post-swap'], axis=0)

In [None]:
# Median of each pairwise distance across all 'other' samples.
np.median(dist_arr_dict['other'], axis=0)

In [None]:
# For each 'pre-swap' sample: length of the vector between keypoints 0 and 1 and
# the angle arctan(v[1] / v[2]) in degrees.
# NOTE(review): which body parts indices 0 and 1 refer to depends on the order
# used by get_left_right_keypoint_arrs -- confirm.
lengths, angles = [], []
for X_w in X_w_dict['pre-swap']:
    v = X_w[0] - X_w[1]
    length = np.linalg.norm(v)
    angle = np.arctan(v[1] / v[2]) * 180.0 / np.pi
    lengths.append(length)
    angles.append(angle)
    

In [None]:
# Angle vs. length for the values computed in the previous cell.
plt.scatter(angles, lengths)

In [None]:
# Same computation as for 'pre-swap', now for the 'other' samples; this
# overwrites the `lengths` and `angles` lists from the earlier cell.
lengths, angles = [], []
for X_w in X_w_dict['other']:
    v = X_w[0] - X_w[1]
    length = np.linalg.norm(v)
    angle = np.arctan(v[1] / v[2]) * 180.0 / np.pi
    lengths.append(length)
    angles.append(angle)
    

In [None]:
# Angle vs. length for the values computed in the previous cell.
plt.scatter(angles, lengths)

In [None]:
import pandas as pd
# Tabulate the most recently computed angles/lengths (the 'other' samples when
# the cells are run in order).
kdf = pd.DataFrame({'angle': angles, 'length': lengths})

In [None]:
# Median length over the rows whose angle is below 20 degrees.
kdf[kdf.angle < 20].length.median()

In [None]:
# Same as the X_w_dict computation, except that the keypoints are NOT normalized
# and convert_to_world_point_arr_2 (which subtracts the image centre) is used.
X_w_dict_2 = {}
for key, df in df_dict.items():
    # Keep only rows that have a usable annotation.
    tdf = df[~df.ann.isnull()]
    
    X_w = []
    for idx, row in tdf.iterrows():
        cm = json.loads(row.camera_metadata)
        camera_metadata = CameraMetadata(
            focal_length=cm['focalLength'],
            focal_length_pixel=cm['focalLengthPixel'],
            baseline_m=cm['baseline'],
            pixel_count_width=cm['pixelCountWidth'],
            pixel_count_height=cm['pixelCountHeight'],
            image_sensor_width=cm['imageSensorWidth'],
            image_sensor_height=cm['imageSensorHeight']
        )
        
        X_left, X_right = get_left_right_keypoint_arrs(row.ann)
        X_world = convert_to_world_point_arr_2(X_left, X_right, camera_metadata)
        X_w.append(X_world)
        
    X_w_dict_2[key] = np.array(X_w)
        

In [None]:
from scipy.spatial.distance import pdist

# Pairwise keypoint distances per sample, computed from X_w_dict_2.
dist_arr_dict_2 = {}
for key, _ in df_dict.items():
    dist_arr = []
    for idx in range(X_w_dict_2[key].shape[0]):
        dist_arr.append(pdist(X_w_dict_2[key][idx]))
    
    dist_arr = np.array(dist_arr)
    dist_arr_dict_2[key] = dist_arr

In [None]:
# Mean of each pairwise distance across 'pre-swap' samples (the first pass above used the median).
np.mean(dist_arr_dict_2['pre-swap'], axis=0)

In [None]:
# Mean of each pairwise distance across 'post-swap' samples.
np.mean(dist_arr_dict_2['post-swap'], axis=0)

In [None]:
# Mean of each pairwise distance across 'other' samples.
np.mean(dist_arr_dict_2['other'], axis=0)

<h1> Home in on a single case </h1>

In [None]:
from typing import Dict, List, Tuple

def get_camera_metadata_from_cm(cm):
    """Build a ``CameraMetadata`` from a camelCase camera-metadata dict.

    Args:
        cm: mapping with the keys ``focalLength``, ``focalLengthPixel``,
            ``baseline``, ``pixelCountWidth``, ``pixelCountHeight``,
            ``imageSensorWidth`` and ``imageSensorHeight``.

    Returns:
        The corresponding ``CameraMetadata`` instance.
    """
    # (CameraMetadata keyword, key in cm), in the order the values are read.
    field_to_key = (
        ('focal_length', 'focalLength'),
        ('focal_length_pixel', 'focalLengthPixel'),
        ('baseline_m', 'baseline'),
        ('pixel_count_width', 'pixelCountWidth'),
        ('pixel_count_height', 'pixelCountHeight'),
        ('image_sensor_width', 'imageSensorWidth'),
        ('image_sensor_height', 'imageSensorHeight'),
    )
    return CameraMetadata(**{field: cm[key] for field, key in field_to_key})
    

def normalize_left_right_keypoint_arrs(X_left: np.ndarray, X_right: np.ndarray) -> Tuple:
    """Centre the left/right keypoint arrays and make the fish face one direction.

    Both arrays are translated by the same offset: the midpoint of the bounding
    box of the left array, or of the right array when a reflection is applied.
    If UPPER_LIP lies at a smaller x than TAIL_NOTCH in the left array, the left
    and right arrays are swapped and their x coordinates negated.

    The rotation step of the original normalisation is disabled (see the
    commented-out block after this function), so no rotation is applied here.
    This definition shadows the function of the same name imported from
    ``weight_estimation.utils``.

    Returns:
        ``(X_left_centered, X_right_centered)``; the inputs are not modified.
    """
    # The notebook imports only `core_body_parts` from this module, so the
    # `body_parts` name used below was undefined (NameError); import it locally.
    from weight_estimation import body_parts

    # translate key-points, perform reflection if necessary
    upper_lip_idx = body_parts.core_body_parts.index(body_parts.UPPER_LIP)
    tail_notch_idx = body_parts.core_body_parts.index(body_parts.TAIL_NOTCH)
    if X_left[upper_lip_idx, 0] > X_left[tail_notch_idx, 0]:
        X_center = 0.5 * (np.max(X_left, axis=0) + np.min(X_left, axis=0))
        X_left_centered = X_left - X_center
        X_right_centered = X_right - X_center
    else:
        X_center = 0.5 * (np.max(X_right, axis=0) + np.min(X_right, axis=0))
        # Swap the roles of the two images and mirror x so the fish faces the
        # same way as in the first branch.
        X_left_centered = X_right - X_center
        X_right_centered = X_left - X_center
        X_left_centered[:, 0] = -X_left_centered[:, 0]
        X_right_centered[:, 0] = -X_right_centered[:, 0]

    return X_left_centered, X_right_centered

#     # rotate key-points
#     upper_lip_x, upper_lip_y = tuple(X_left_centered[upper_lip_idx])
#     theta = np.arctan(upper_lip_y / upper_lip_x)
#     R = np.array([
#         [np.cos(theta), -np.sin(theta)],
#         [np.sin(theta), np.cos(theta)]
#     ])

#     D = X_left_centered - X_right_centered
#     X_left_rot = np.dot(X_left_centered, R)
#     X_right_rot = X_left_rot - D
#     return X_left_rot, X_right_rot

    
    

In [None]:
# Single-sample check: compare pairwise keypoint distances computed from raw
# versus normalized keypoints for the first annotated 'pre-swap' row.
idx = 0
# NOTE(review): `ann` and `camera_metadata` are dropna()'d independently, so
# iloc[idx] refers to the same row only if both columns have the same null pattern.
ann = df_dict['pre-swap'].ann.dropna().iloc[idx]
cm = json.loads(df_dict['pre-swap'].camera_metadata.dropna().iloc[idx])
camera_metadata = get_camera_metadata_from_cm(cm)
X_left, X_right = get_left_right_keypoint_arrs(ann)
# Raw (un-normalized) keypoints through the conversion that applies no centring.
X_w = convert_to_world_point_arr(X_left, X_right, camera_metadata)
X_left_norm, X_right_norm = normalize_left_right_keypoint_arrs(X_left, X_right)
X_w_norm = convert_to_world_point_arr(X_left_norm, X_right_norm, camera_metadata)

# Difference in pairwise distances between the two variants.
print(pdist(X_w_norm) - pdist(X_w))


In [None]:
# Raw keypoint arrays of the selected sample.
X_left, X_right

In [None]:
# Normalized keypoint arrays of the selected sample.
X_left_norm, X_right_norm

In [None]:
# World coordinates of the first 'pre-swap' sample (normalized-keypoint variant).
X_w_dict['pre-swap'][0]

In [None]:
# World coordinates of the first 'pre-swap' sample (image-centre variant).
X_w_dict_2['pre-swap'][0]

In [None]:
# Norm of the difference between the keypoint arrays of samples 0 and 4.
# NOTE(review): the first index selects a sample, not a keypoint; if a
# keypoint-to-keypoint distance within one sample was intended, index [0][0] and [0][4].
np.linalg.norm(X_w_dict['pre-swap'][0] - X_w_dict['pre-swap'][4])

In [None]:
# Same comparison between samples 0 and 4, for the image-centre variant.
np.linalg.norm(X_w_dict_2['pre-swap'][0] - X_w_dict_2['pre-swap'][4])

In [None]:
# Distribution of the first pairwise distance (pdist column 0) over 'other' samples.
plt.figure(figsize=(10, 5))
plt.hist(dist_arr_dict['other'][:, 0])
plt.grid()
plt.show()

In [None]:
# Distribution of the last pairwise distance (final pdist column) over 'other' samples.
plt.figure(figsize=(10, 5))
plt.hist(dist_arr_dict['other'][:, -1])
plt.grid()
plt.show()

In [None]:
# Show the body-part list imported from weight_estimation.body_parts.
core_body_parts

In [None]:
# Difference in mean pairwise distances: 'post-swap' minus 'pre-swap'.
np.mean(dist_arr_dict['post-swap'], axis=0) - np.mean(dist_arr_dict['pre-swap'], axis=0)

In [None]:
# Difference in mean pairwise distances: 'other' minus 'pre-swap'.
np.mean(dist_arr_dict['other'], axis=0) - np.mean(dist_arr_dict['pre-swap'], axis=0)

In [None]:
from weight_estimation.weight_estimator import WeightEstimator, CameraMetadata

# Predict a weight for every 'pre-swap' row. This repeats the body of add_weights
# inline, with an added progress counter; the next cell stores pred_weights.
weight_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb')
kf_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb')
weight_estimator = WeightEstimator(weight_model_f, kf_model_f)

pred_weights = []

count = 0
for idx, row in df_dict['pre-swap'].iterrows():
    ann = row.ann
    camera_metadata = json.loads(row.camera_metadata)
    if ann is not None:
        cm = CameraMetadata(
            focal_length=camera_metadata['focalLength'],
            focal_length_pixel=camera_metadata['focalLengthPixel'],
            baseline_m=camera_metadata['baseline'],
            pixel_count_width=camera_metadata['pixelCountWidth'],
            pixel_count_height=camera_metadata['pixelCountHeight'],
            image_sensor_width=camera_metadata['imageSensorWidth'],
            image_sensor_height=camera_metadata['imageSensorHeight']
        )

        # predict returns a 3-tuple; only the first element is kept.
        weight, _, _ = weight_estimator.predict(ann, cm)
        pred_weights.append(weight)
    else:
        # Keep the list aligned with the frame rows.
        pred_weights.append(None)
    
    # Progress indicator: prints every 1000th row index.
    if count % 1000 == 0:
        print(count)
    count += 1
    

In [None]:
# Store the predictions from the previous cell; pred_weights has one entry per 'pre-swap' row.
df_dict['pre-swap']['weight'] = pred_weights

In [None]:
# Roll vs. weight for the 'other' frame; assumes `roll` and `weight` columns already exist on it.
plt.scatter(df_dict['other'].roll, df_dict['other'].weight)

In [None]:
# Roll vs. weight for the 'pre-swap' frame.
plt.scatter(df_dict['pre-swap'].roll, df_dict['pre-swap'].weight)

In [None]:
# Mean weight of the 'other' rows whose roll exceeds 20 degrees.
mask = df_dict['other'].roll > 20
df_dict['other'][mask].weight.mean()

In [None]:
# Roll vs. weight for the 'other' frame (repeat of an earlier cell).
plt.scatter(df_dict['other'].roll, df_dict['other'].weight)