In [None]:
from collections import defaultdict
import json
import os
import numpy as np
import cv2
from research.utils.data_access_utils import RDSAccessUtils
from research_lib.utils.data_access_utils import S3AccessUtils
from weight_estimation.body_parts import core_body_parts
from weight_estimation.utils import convert_to_world_point_arr, get_left_right_keypoint_arrs, normalize_left_right_keypoint_arrs, CameraMetadata, \
    stabilize_keypoints, convert_to_nn_input
from weight_estimation.weight_estimator import WeightEstimator, CameraMetadata
import torch
from research.weight_estimation.keypoint_utils.optics import pixel2world

from matplotlib import pyplot as plt



In [None]:
# Record where the PLALI SQL credentials are mounted, then build the S3 and RDS
# helpers used by the data-loading functions in this notebook.
os.environ['PLALI_SQL_CREDENTIALS'] = '/run/secrets/plali_sql_credentials'
s3 = S3AccessUtils('/root/data')
# Read the credentials inside a context manager so the file handle is closed
# immediately instead of lingering for the life of the kernel.
with open(os.environ['PLALI_SQL_CREDENTIALS']) as credentials_f:
    rds = RDSAccessUtils(json.load(credentials_f))

def get_annotated_data(workflow_id):
    """Load every PLALI annotation for one workflow, joined with its image and workflow rows.

    ``workflow_id`` is substituted verbatim into the SQL text via ``str.format``,
    so only pass trusted, well-formed workflow IDs.

    Returns whatever ``rds.extract_from_database`` returns for the query; the
    rest of this notebook treats the result as a pandas DataFrame (it is
    iterated with ``iterrows`` and given new columns).
    """
    sql_template = """
        select * from plali.plali_annotations x
        inner join 
        ( select a.id as plali_image_id, a.images, a.metadata, b.id as workflow_id, b.name from plali.plali_images a
        inner join plali.plali_workflows b
        on a.workflow_id = b.id ) y
        on x.plali_image_id = y.plali_image_id
        where workflow_id = '{}';
    """
    return rds.extract_from_database(sql_template.format(workflow_id))

class AnnotationFormatError(Exception):
    """Raised while parsing a raw annotation that should be recorded as ``None``."""

def add_anns(annotated_df):
    """Parse each row's raw ``annotation`` into a keypoint dict stored in a new ``ann`` column.

    For both ``leftCrop`` and ``rightCrop``, every entry of
    ``annotation[side]['annotation']['annotations']`` becomes a dict with keys
    ``xCrop``, ``yCrop``, ``xFrame``, ``yFrame`` and ``keypointType``.

    A row's ``ann`` is ``None`` when the annotation contains ``skipReasons``,
    when any entry lacks ``xCrop``/``yCrop``, or when either side does not have
    exactly 11 keypoints. Other malformed input (e.g. a missing side key or a
    missing ``category``) raises instead of being skipped.

    Mutates ``annotated_df`` in place and returns ``None``.
    """
    sides = ('leftCrop', 'rightCrop')

    def _parse(raw_ann):
        # Signals "record None for this row" by raising AnnotationFormatError.
        if 'skipReasons' in raw_ann:
            raise AnnotationFormatError
        parsed = {}
        for side in sides:
            keypoints = []
            for raw_item in raw_ann[side]['annotation']['annotations']:
                if not ('xCrop' in raw_item and 'yCrop' in raw_item):
                    raise AnnotationFormatError
                # NOTE(review): frame coordinates are copied from the crop
                # coordinates; presumably the crop spans the full frame — confirm.
                keypoints.append({
                    'xCrop': raw_item['xCrop'],
                    'yCrop': raw_item['yCrop'],
                    'xFrame': raw_item['xCrop'],
                    'yFrame': raw_item['yCrop'],
                    'keypointType': raw_item['category']
                })
            parsed[side] = keypoints
        # Each side must carry exactly 11 keypoints to be usable downstream.
        if any(len(parsed[side]) != 11 for side in sides):
            raise AnnotationFormatError
        return parsed

    def _parse_or_none(raw_ann):
        try:
            return _parse(raw_ann)
        except AnnotationFormatError:
            return None

    annotated_df['ann'] = [_parse_or_none(row.annotation) for _, row in annotated_df.iterrows()]
    

def add_camera_metadata(df):
    """Attach JSON-encoded stereo camera metadata to every row of ``df``.

    The stereo calibration JSON is downloaded from S3 and converted into the
    camera-parameter dict that ``add_weights`` and ``add_spatial_attributes``
    later read back with ``json.loads``. Every row receives the same serialized
    dict in a new ``camera_metadata`` column. Nothing is downloaded when ``df``
    is empty.

    Mutates ``df`` in place and returns ``None``.
    """
    if len(df) == 0:
        df['camera_metadata'] = []
        return

    # NOTE(review): every row uses the calibration captured on 2020-08-06. If an
    # older calibration exists for images taken before that date, choose its URL
    # per row from row.metadata['time'] instead of using this single constant.
    stereo_parameters_url = (
        'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40029773_R40038903/'
        '2020-08-06T12%3A35%3A26.754586000Z_L40029773_R40038903_stereo-parameters.json'
    ).replace('%3A', ':')

    stereo_parameters_f, _, _ = s3.download_from_url(stereo_parameters_url)
    # Context manager so the calibration file handle is closed promptly.
    with open(stereo_parameters_f) as f:
        stereo_parameters = json.load(f)

    # Sensor pixel pitch in metres: 4096 px and 3000 px times this value give
    # ~0.01413 m and 0.01035 m, consistent with the sensor dimensions below.
    pixel_size_m = 3.45e-6
    focal_length_pixel = stereo_parameters['CameraParameters1']['FocalLength'][0]
    cm = {
        'focalLengthPixel': focal_length_pixel,
        # Divided by 1e3 because the baseline is passed on as ``baseline_m``;
        # presumably the calibration file stores millimetres — confirm.
        'baseline': abs(stereo_parameters['TranslationOfCamera2'][0] / 1e3),
        'focalLength': focal_length_pixel * pixel_size_m,
        'pixelCountWidth': 4096,
        'pixelCountHeight': 3000,
        'imageSensorWidth': 0.01412,
        'imageSensorHeight': 0.01035
    }
    df['camera_metadata'] = [json.dumps(cm)] * len(df)
    


def add_weights(df):
    """Predict a weight for every row of ``df`` and store it in a new ``pred_weight`` column.

    Downloads the weight and k-factor models from S3 and builds a
    ``WeightEstimator`` from them. For each row, the JSON ``camera_metadata``
    is decoded into a ``CameraMetadata`` and passed with the row's ``ann`` to
    ``WeightEstimator.predict``; the first of its three outputs is kept. Rows
    whose ``ann`` is ``None`` get ``None``.

    Mutates ``df`` in place and returns ``None``.
    """
    weight_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb')
    kf_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb')
    estimator = WeightEstimator(weight_model_f, kf_model_f)

    def _predict(ann, camera_metadata_json):
        # Decode first so malformed metadata fails even for skipped rows.
        params = json.loads(camera_metadata_json)
        if ann is None:
            return None
        camera = CameraMetadata(
            focal_length=params['focalLength'],
            focal_length_pixel=params['focalLengthPixel'],
            baseline_m=params['baseline'],
            pixel_count_width=params['pixelCountWidth'],
            pixel_count_height=params['pixelCountHeight'],
            image_sensor_width=params['imageSensorWidth'],
            image_sensor_height=params['imageSensorHeight']
        )
        weight, _, _ = estimator.predict(ann, camera)
        return weight

    df['pred_weight'] = [_predict(row.ann, row.camera_metadata) for _, row in df.iterrows()]
    
    
def add_spatial_attributes(df):
    """Add ``yaw``, ``pitch``, ``roll`` (degrees) and ``depth`` columns to ``df``.

    Each row's left/right keypoints are triangulated with ``pixel2world``.
    ``depth`` is the median of component 1 across all world keypoints. The
    angles come from two body vectors: UPPER_LIP - TAIL_NOTCH gives yaw
    (component 1 vs |component 0|) and pitch (component 2 vs |component 0|), and
    ADIPOSE_FIN - ANAL_FIN gives roll (component 1 vs component 2).
    NOTE(review): this presumes component 1 is the camera-depth axis of
    ``pixel2world``'s output — confirm against that function.

    Rows that raise ``TypeError`` (e.g. ``ann`` is ``None``) get ``None`` in all
    four columns. Mutates ``df`` in place and returns ``None``.
    """
    rad_to_deg = 180.0 / np.pi
    records = []
    for _, row in df.iterrows():
        ann = row.ann
        cm = json.loads(row.camera_metadata)
        try:
            world = pixel2world(ann['leftCrop'], ann['rightCrop'], cm)
            depth = np.median([point[1] for point in world.values()])
            fin_vec = world['ADIPOSE_FIN'] - world['ANAL_FIN']
            body_vec = world['UPPER_LIP'] - world['TAIL_NOTCH']
            yaw = np.arctan(body_vec[1] / abs(body_vec[0])) * rad_to_deg
            pitch = np.arctan(body_vec[2] / abs(body_vec[0])) * rad_to_deg
            roll = np.arctan(fin_vec[1] / fin_vec[2]) * rad_to_deg
        except TypeError:
            yaw = pitch = roll = depth = None
        records.append((yaw, pitch, roll, depth))

    for column, position in (('yaw', 0), ('pitch', 1), ('roll', 2), ('depth', 3)):
        df[column] = [values[position] for values in records]




In [None]:
# Pull every annotation for one workflow, then enrich it step by step; each
# helper adds columns to `df` in place.
workflow_id = 'cb587143-2354-477e-998b-f06df33ffb45'
df = get_annotated_data(workflow_id)
for enrich in (add_anns, add_camera_metadata, add_weights, add_spatial_attributes):
    enrich(df)

In [None]:
# Ground-truth weight and the capture date (YYYY-MM-DD prefix of 'time') are
# read from each row's metadata dict.
df['weight'] = df['metadata'].apply(lambda meta: meta['data'].get('weightKgs'))
df['date'] = df['metadata'].apply(lambda meta: meta.get('time')[:10])


In [None]:
# Predicted vs. ground-truth weight; points on the red y = x line are exact predictions.
# NOTE(review): `weight` is read from the metadata field 'weightKgs', yet both axes
# span 0-2500, which looks like grams — confirm both columns share the same unit.
axis_limit = 2500
fig, ax = plt.subplots()
ax.scatter(df.weight.values, df.pred_weight.values)
ax.plot([0, axis_limit], [0, axis_limit], color='red')
ax.set_xlim([0, axis_limit])
ax.set_ylim([0, axis_limit])
ax.set_xlabel('Ground-truth weight (metadata weightKgs)')
ax.set_ylabel('Predicted weight (pred_weight)')
ax.set_title('Predicted vs. ground-truth weight')
plt.show()

In [None]:
# Mean ground-truth weight (metadata 'weightKgs') across all rows.
df['weight'].mean()