In [None]:
from matplotlib import pyplot as plt
import pandas as pd


from weight_estimation.weight_estimator import WeightEstimator, CameraMetadata
from research_lib.utils.data_access_utils import S3AccessUtils, RDSAccessUtils

In [None]:
# Directory handed to S3AccessUtils; presumably where files fetched with
# s3.download_from_url are stored locally — NOTE(review): confirm.
S3_DATA_DIR = '/root/data'
s3 = S3AccessUtils(S3_DATA_DIR)

# RDSAccessUtils is constructed without arguments, so its connection settings are
# resolved outside this notebook — NOTE(review): confirm env/config source.
rds = RDSAccessUtils()

<h1> Load sample data </h1>

In [None]:
query = """
    SELECT *
    FROM 
        prod.biomass_computations
    WHERE
        pen_id = 173 AND
        akpd_score > 0.95 AND
        captured_at BETWEEN '2021-02-10' and '2021-02-15'
"""

df = rds.extract_from_database(query)

In [None]:
# Model artifacts: the main weight model, a separate small-fish weight model and a
# k-factor model, each downloaded from S3 by URL and handed to WeightEstimator.
weight_model_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb'
small_weight_model_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/playground/small_fish_weight_model.pb'
kf_model_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb'

weight_model_f, _, _ = s3.download_from_url(weight_model_url)
small_weight_model_f, _, _ = s3.download_from_url(small_weight_model_url)
kf_model_f, _, _ = s3.download_from_url(kf_model_url)
weight_estimator = WeightEstimator(weight_model_f, small_weight_model_f, kf_model_f)


def to_camera_metadata(cm):
    """Build a CameraMetadata from one row's `camera_metadata` dict.

    Args:
        cm: mapping with the camelCase keys read below ('focalLength',
            'focalLengthPixel', 'baseline', 'pixelCountWidth', 'pixelCountHeight',
            'imageSensorWidth', 'imageSensorHeight').

    Returns:
        CameraMetadata with those values mapped onto its keyword arguments.
    """
    return CameraMetadata(
        focal_length=cm['focalLength'],
        focal_length_pixel=cm['focalLengthPixel'],
        baseline_m=cm['baseline'],
        pixel_count_width=cm['pixelCountWidth'],
        pixel_count_height=cm['pixelCountHeight'],
        image_sensor_width=cm['imageSensorWidth'],
        image_sensor_height=cm['imageSensorHeight']
    )


# Re-run the weight estimator on every row, keeping all three outputs of predict().
weights, lengths, k_factors = [], [], []
for i, (annotation, cm) in enumerate(zip(df['annotation'], df['camera_metadata'])):
    camera_metadata = to_camera_metadata(cm)
    weight, length, kf = weight_estimator.predict(annotation, camera_metadata)
    weights.append(weight)
    lengths.append(length)
    k_factors.append(kf)

    # Progress indicator: index of the row just processed, every 100 rows.
    if i % 100 == 0:
        print(i)

# Total number of rows scored.
count = len(weights)

# Re-estimated values aligned with df's index, so they can be compared directly
# against columns such as df.estimated_weight_g / df.estimated_length_mm.
# NOTE(review): units are whatever WeightEstimator.predict returns — presumably
# grams / millimetres to match those columns; confirm.
predictions = pd.DataFrame(
    {'weight': weights, 'length': lengths, 'k_factor': k_factors},
    index=df.index,
)
    
    


In [None]:
def display_crops(left_image_f, right_image_f, ann, show_labels=False):
    """Show the left and right crops of a stereo pair with their annotated keypoints.

    The left crop is drawn in the top panel and the right crop in the bottom panel;
    every keypoint is marked with a small red dot.

    Args:
        left_image_f: image source for the left crop, passed to ``plt.imread``
            (in this notebook, the first value returned by ``s3.download_from_url``).
        right_image_f: image source for the right crop, same form as ``left_image_f``.
        ann: annotation dict; ``ann['leftCrop']`` and ``ann['rightCrop']`` are lists of
            items carrying ``'keypointType'``, ``'xCrop'`` and ``'yCrop'``.
        show_labels: if True, write each keypoint's type next to its marker.
    """
    fig, axes = plt.subplots(2, 1, figsize=(20, 20))
    panels = (
        (axes[0], left_image_f, ann['leftCrop'], 'Left crop'),
        (axes[1], right_image_f, ann['rightCrop'], 'Right crop'),
    )
    for ax, image_f, crop_items, title in panels:
        ax.imshow(plt.imread(image_f))
        ax.set_title(title)
        # Keyed by keypoint type, so a repeated type keeps only its last position.
        keypoints = {item['keypointType']: (item['xCrop'], item['yCrop']) for item in crop_items}
        for bp, (x, y) in keypoints.items():
            ax.scatter([x], [y], color='red', s=1)
            if show_labels:
                ax.annotate(bp, (x, y), color='red')
    plt.show()
    # This is called once per fish in a loop; close the figure so it does not
    # stay registered with pyplot on backends that don't close figures on show.
    plt.close(fig)


In [None]:
# Visually inspect every fish whose stored weight estimate exceeds 2000 (the `_g`
# suffix suggests grams): fetch both crops and overlay the annotated keypoints.
mask = df.estimated_weight_g > 2000
heavy_fish = df.loc[mask, ['annotation', 'left_crop_url', 'right_crop_url']]
for ann, left_crop_url, right_crop_url in heavy_fish.itertuples(index=False, name=None):
    left_crop_f, _, _ = s3.download_from_url(left_crop_url)
    right_crop_f, _, _ = s3.download_from_url(right_crop_url)
    display_crops(left_crop_f, right_crop_f, ann, show_labels=True)
    

In [None]:
# Size of the query result as (rows, columns).
len(df.index), len(df.columns)

In [None]:
# Stored weight estimates for rows with estimated_weight_g above 2000.
df.loc[df.estimated_weight_g > 2000, 'estimated_weight_g']

In [None]:
# Stored length estimates for the same rows (estimated_weight_g above 2000).
df.loc[df.estimated_weight_g > 2000, 'estimated_length_mm']