In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from research.weight_estimation.weight_estimator import WeightEstimator
from research.gtsf_data.gtsf_dataset import GTSFDataset

# Let pandas render up to 500 rows before truncating displayed frames.
pd.options.display.max_rows = 500

In [None]:
# URL of the AKPD scorer model (.h5) that is handed to the dataset builder.
akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'
# NOTE(review): the two date strings are presumably the start/end of the
# collection window — confirm order and inclusivity against GTSFDataset.
gtsf_dataset = GTSFDataset('2019-02-01', '2020-03-30', akpd_scorer_url)
# Prepared frame used by the rest of the notebook; the cells below read its
# median_depth, akpd_score, captured_at, keypoints, camera_metadata and weight columns.
df = gtsf_dataset.get_prepared_dataset()

In [None]:
# Row filters on the prepared dataset; a row is kept only if all three hold.
depth_mask = df.median_depth > 0.7
score_mask = df.akpd_score > 0.5
date_mask = df.captured_at < '2019-09-27'

# Deep copy so that later column/index assignments on tdf leave df untouched.
tdf = df[depth_mask & score_mask & date_mask].copy(deep=True)

In [None]:
# Download the trained model file and wrap it in a WeightEstimator.
model_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-03-26T11-58-00/nn_8_keypoints_jitter_10.pb'
# NOTE(review): '/root/data' is a hard-coded absolute path passed to
# S3AccessUtils (presumably its local download root) — confirm it exists
# in the environment this notebook runs in.
s3_access_utils = S3AccessUtils('/root/data')
# download_from_url returns a 3-tuple; only the first element is used and it
# is passed straight to WeightEstimator.
model_f, _, _ = s3_access_utils.download_from_url(model_url)
weight_estimator = WeightEstimator(model_f)

# Generate one weight prediction per row of tdf, in row order.
# The two required columns are iterated directly instead of going through
# DataFrame.iterrows(), which materialises a throwaway Series for every row
# and yielded an index value that was never used.
weights = []
for keypoints, camera_metadata in zip(tdf['keypoints'], tdf['camera_metadata']):
    weight_prediction = weight_estimator.predict(keypoints, camera_metadata)
    weights.append(weight_prediction)
    # Coarse progress indicator: print the running count every 1000 predictions.
    if len(weights) % 1000 == 0:
        print(len(weights))


In [None]:
# Attach the predictions as a new column. weights was filled by iterating
# tdf's rows in order, so it needs exactly one entry per row of tdf; re-run
# the prediction cell first if tdf has been re-filtered since.
tdf['pred_weight'] = weights

In [None]:
# Scatter of the model's predictions (pred_weight) against the dataset's
# weight column. The red y = x line marks where prediction equals weight.
# Points beyond axis_max on either axis fall outside the visible area.
axis_max = 10000  # shared upper bound for both axes and the reference line

fig, ax = plt.subplots(figsize=(20, 10))
ax.scatter(tdf.weight.values, tdf.pred_weight.values)
ax.plot([0, axis_max], [0, axis_max], color='red', label='y = x')
ax.set_xlim([0, axis_max])
ax.set_ylim([0, axis_max])
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set_xlabel('weight')
ax.set_ylabel('pred_weight')
ax.set_title('pred_weight vs. weight')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Re-index tdf in place by capture time, parsed to datetimes. The captured_at
# column itself is kept, so the values exist both as index and as a column.
tdf.index = pd.to_datetime(tdf.captured_at)