In [None]:
# %load_ext autoreload
# %autoreload 2

from collections import defaultdict
import json
import pandas as pd
import numpy as np
from keras.models import load_model
from research_lib.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from research.weight_estimation.keypoint_utils.optics import pixel2world
from dataset import prepare_gtsf_data, compute_akpd_score
from weight_estimation.weight_estimator import WeightEstimator
from weight_estimation.utils import CameraMetadata

from matplotlib import pyplot as plt
# Let pandas render up to 500 rows when a DataFrame is displayed in a cell output.
pd.set_option('display.max_rows', 500)


<h1> Prepare Augmented GTSF Dataset </h1>

<h2> Load raw data </h2>

In [None]:
# S3 helper built with '/root/data' (presumably the local download directory -- TODO confirm).
s3 = S3AccessUtils('/root/data')
# Download the AKPD scorer model; only the first element of the returned triple is kept
# (presumably the local file path -- verify against S3AccessUtils.download_from_url).
akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'
akpd_scorer_f, _, _ = s3.download_from_url(akpd_scorer_url)
# First GTSF date range (2019-03-01 to 2019-09-20).
# NOTE(review): the meaning of the trailing 0.5 and 1.0 arguments is not visible in this
# notebook -- confirm against prepare_gtsf_data before changing them.
df1 = prepare_gtsf_data('2019-03-01', '2019-09-20', akpd_scorer_f, 0.5, 1.0)

In [None]:
# Second GTSF date range (2020-06-01 to 2020-08-20), prepared with the same scorer file
# and the same trailing 0.5 / 1.0 arguments as df1.
df2 = prepare_gtsf_data('2020-06-01', '2020-08-20', akpd_scorer_f, 0.5, 1.0)

In [None]:
# Stack both date ranges row-wise into one frame. Index labels are carried over
# from df1 and df2 unchanged (no re-indexing), so labels may repeat.
df = pd.concat(objs=[df1, df2], axis=0)

<h1> Generate weight estimates with production model </h1>

In [None]:
# Fetch the production weight model and the k-factor model, then build one
# estimator from the two downloaded files.
weight_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb')
kf_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb')
weight_estimator = WeightEstimator(weight_model_f, kf_model_f)

# CameraMetadata keyword argument -> key holding that value in a row's
# camera_metadata mapping.
camera_kwarg_keys = {
    'focal_length': 'focalLength',
    'focal_length_pixel': 'focalLengthPixel',
    'baseline_m': 'baseline',
    'pixel_count_width': 'pixelCountWidth',
    'pixel_count_height': 'pixelCountHeight',
    'image_sensor_width': 'imageSensorWidth',
    'image_sensor_height': 'imageSensorHeight',
}

pred_weights = []
count = 0

for idx, row in df.iterrows():
    ann = row.keypoints
    # Kept under its original notebook-level name in case other cells refer to it.
    camera_metadata = row.camera_metadata
    cm = CameraMetadata(**{
        kwarg: camera_metadata[key] for kwarg, key in camera_kwarg_keys.items()
    })

    # predict() returns a triple; only its first element (the weight) is stored.
    weight, _, _ = weight_estimator.predict(ann, cm)
    pred_weights.append(weight)

    # Progress marker: print the running row count every 1000 rows.
    if count % 1000 == 0:
        print(count)
    count += 1

# One prediction per row, in iteration order.
df['pred_weight'] = pred_weights

In [None]:
def add_spatial_attributes(df):
    """Add ``yaw``, ``pitch``, ``roll`` and ``depth`` columns to ``df`` in place.

    For every row, the left/right crop keypoints are converted to world
    coordinates with ``pixel2world`` using that row's own ``camera_metadata``.
    From the resulting points:

    * ``depth`` is the median of coordinate index 1 over all keypoints,
    * ``yaw`` and ``pitch`` come from the TAIL_NOTCH -> UPPER_LIP vector,
    * ``roll`` comes from the ANAL_FIN -> ADIPOSE_FIN vector.

    Angles are in degrees. Rows for which the computation raises ``TypeError``
    get ``None`` for all four values.

    Args:
        df: DataFrame with a ``keypoints`` column (mapping with ``leftCrop`` and
            ``rightCrop`` entries) and a ``camera_metadata`` column.

    Returns:
        None. The four columns are written onto ``df`` itself.
    """
    rad_to_deg = 180.0 / np.pi
    yaws, pitches, rolls, depths = [], [], [], []
    for _, row in df.iterrows():
        # Use the camera metadata stored on this row. Relying on a notebook-level
        # `camera_metadata` variable would apply one camera's parameters to every
        # row and fail on a fresh kernel.
        ann, cm = row.keypoints, row.camera_metadata
        try:
            world_keypoints = pixel2world(ann['leftCrop'], ann['rightCrop'], cm)
            depth = np.median([x[1] for x in world_keypoints.values()])
            u = world_keypoints['ADIPOSE_FIN'] - world_keypoints['ANAL_FIN']
            v = world_keypoints['UPPER_LIP'] - world_keypoints['TAIL_NOTCH']
            yaw = np.arctan(v[1] / abs(v[0])) * rad_to_deg
            pitch = np.arctan(v[2] / abs(v[0])) * rad_to_deg
            roll = np.arctan(u[1] / u[2]) * rad_to_deg
        except TypeError:
            # Best-effort: keep the row and mark its spatial attributes as missing.
            yaw, pitch, roll, depth = None, None, None, None
        yaws.append(yaw)
        pitches.append(pitch)
        rolls.append(roll)
        depths.append(depth)

    df['yaw'] = yaws
    df['pitch'] = pitches
    df['roll'] = rolls
    df['depth'] = depths


In [None]:
# Writes the yaw, pitch, roll and depth columns onto df in place (nothing is returned).
add_spatial_attributes(df)

In [None]:
# Signed relative error of the prediction against ground truth. Despite the
# column name this is a fraction (0.05 == 5 %), not a percentage.
df['error_pct'] = df.pred_weight.sub(df.weight).div(df.weight)

In [None]:
# Ground-truth weight (x) against predicted weight (y) for every row. The red
# line from (0, 0) to (9000, 9000) is the y == x reference.
fig, ax = plt.subplots(figsize=(15, 8))
ax.scatter(df.weight.values, df.pred_weight.values)
ax.plot([0, 9000], [0, 9000], color='red')
ax.grid()
plt.show()

<h1> Single Fish Multi-Image Analysis </h1>

In [None]:
# Per-fish summary: one output row for every fish_id that has more than 10 rows in df.
analysis_data = defaultdict(list)
fish_ids = list(df.fish_id.unique())
for fish_id in fish_ids:
    fish_rows = df[df.fish_id == fish_id]
    n_rows = fish_rows.shape[0]
    if n_rows <= 10:
        # Too few rows for this fish; leave it out of the summary.
        continue

    mean_pred = fish_rows.pred_weight.mean()
    mean_gt = fish_rows.weight.mean()
    summary = {
        'fish_id': fish_id,
        'count': n_rows,
        'pred_weight': mean_pred,
        'gt_weight': mean_gt,
        # Relative error of the mean prediction against the mean ground truth.
        'pct_error': (mean_pred - mean_gt) / mean_gt,
        # Standard deviation of the predictions divided by their mean.
        'pct_variation': fish_rows.pred_weight.std() / mean_pred,
    }
    for column, value in summary.items():
        analysis_data[column].append(value)

analysis_df = pd.DataFrame(analysis_data)
        

In [None]:
# Fish whose mean ground-truth weight exceeds 5000, ordered by row count, largest first.
mask = analysis_df.gt_weight.gt(5000)
analysis_df.loc[mask].sort_values(by='count', ascending=False)

In [None]:
# Pitch against relative error for every row of one selected fish.
mask = df.fish_id == '190711-c500494a-6c55-440e-8a90-cba094063c53'
fish_rows = df[mask]
fig, ax = plt.subplots()
ax.scatter(fish_rows.pitch, fish_rows.error_pct)
ax.grid()
plt.show()

In [None]:
def display_crops(left_image_f, right_image_f, ann, overlay_keypoints=True, show_labels=True, title=None):
    """Show the left and right crop images stacked vertically, optionally with keypoints.

    Args:
        left_image_f: Path of the left crop image, read with ``plt.imread``.
        right_image_f: Path of the right crop image, read with ``plt.imread``.
        ann: Mapping with ``leftCrop`` and ``rightCrop`` lists; every item provides
            ``keypointType``, ``xFrame`` and ``yFrame``.
        overlay_keypoints: When true, draw each keypoint as a small red dot.
        show_labels: When true (and keypoints are drawn), annotate each dot with
            its keypoint type.
        title: Optional title placed on the top (left-image) axes.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    fig, axes = plt.subplots(2, 1, figsize=(10, 10))

    # Top axes holds the left image, bottom axes the right image.
    images = [plt.imread(path) for path in (left_image_f, right_image_f)]
    for ax, image in zip(axes, images):
        ax.imshow(image)

    # keypointType -> [x, y] for each side, built whether or not they get drawn.
    keypoints_by_side = [
        {item['keypointType']: [item['xFrame'], item['yFrame']] for item in ann[side]}
        for side in ('leftCrop', 'rightCrop')
    ]

    if overlay_keypoints:
        for ax, keypoints in zip(axes, keypoints_by_side):
            for body_part, (x, y) in keypoints.items():
                ax.scatter([x], [y], color='red', s=1)
                if show_labels:
                    ax.annotate(body_part, (x, y), color='red')

    if title:
        axes[0].set_title(title)
    plt.show()
    


# Walk through every row of one fish in capture order and show its crop pair
# together with the predicted weight, ground-truth weight and error.
fish_rows = df[df.fish_id == '190730-8d4936bb-2de9-4379-8e36-1a0c3a3c600e'].sort_values('captured_at')
for idx, row in fish_rows.iterrows():
    ann = row.keypoints
    # NOTE(review): the two URL columns are named inconsistently (`left_url` vs
    # `right_image_url`). One of them is probably a typo -- confirm against the
    # columns produced by prepare_gtsf_data.
    left_image_url = row.left_url
    right_image_url = row.right_image_url
    left_image_f, _, _ = s3.download_from_url(left_image_url)
    right_image_f, _, _ = s3.download_from_url(right_image_url)

    pred_weight = round(row.pred_weight, 2)
    gt_weight = round(row.weight, 2)
    # Error in percent, computed from the already-rounded weights.
    error_pct = round(100 * (pred_weight - gt_weight) / gt_weight, 2)
    title = 'Predicted weight: {}; GT weight: {}; Error: {}'.format(pred_weight, gt_weight, error_pct)

    display_crops(left_image_f, right_image_f, ann, show_labels=False, title=title)
    
    
