In [None]:
import json, os
import random
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from weight_estimation.utils import CameraMetadata, get_ann_from_keypoint_arrs, get_left_right_keypoint_arrs, normalize_left_right_keypoint_arrs
from weight_estimation.body_parts import core_body_parts


<h1> Load base dataset </h1>

In [None]:
# Read the production SQL credentials from the JSON file named by the
# PROD_SQL_CREDENTIALS environment variable. The context manager closes the
# file handle, which an inline open() inside json.load() would leave open.
with open(os.environ['PROD_SQL_CREDENTIALS']) as credentials_f:
    prod_sql_credentials = json.load(credentials_f)
rds_access_utils = RDSAccessUtils(prod_sql_credentials)

# Fetch up to 10,000 annotations that carry both a leftCrop and a rightCrop.
# NOTE(review): the query has no ORDER BY, so which 10,000 rows are returned is
# up to the database and may differ between runs -- confirm this is acceptable.
query = """
    select * from keypoint_annotations
    where keypoints is not null
    and keypoints -> 'leftCrop' is not null
    and keypoints -> 'rightCrop' is not null
    limit 10000;
"""
df = rds_access_utils.extract_from_database(query)

<h1> Construct "good" and "bad" classes </h1>

In [None]:
# Divisor applied to the stacked keypoint arrays before they reach the network.
PIXEL_COUNT_WIDTH = 4096


def convert_to_akpd_nn_input(ann):
    """Turn a keypoint annotation into the array fed to the classifier.

    The annotation is split with ``get_left_right_keypoint_arrs``, the two
    arrays are passed through ``normalize_left_right_keypoint_arrs``, joined
    with ``np.hstack`` (left first, right second) and divided by
    ``PIXEL_COUNT_WIDTH``.

    Args:
        ann: keypoint annotation accepted by ``get_left_right_keypoint_arrs``.

    Returns:
        The horizontally stacked, scaled array.
    """
    left_arr, right_arr = get_left_right_keypoint_arrs(ann)
    left_scaled, right_scaled = normalize_left_right_keypoint_arrs(left_arr, right_arr)
    stacked = np.hstack((left_scaled, right_scaled))
    return stacked / PIXEL_COUNT_WIDTH


def perturb_ann(ann, p_perturbation=0.2, min_magnitude=30, max_magnitude=200):
    """Build a synthetic "bad" annotation by displacing some keypoints of ``ann``.

    Each of the first ``min(len(core_body_parts), len(ann['leftCrop']))``
    keypoint positions is selected independently with probability
    ``p_perturbation``; the draw is repeated until at least one position is
    selected. Every selected position is displaced in both crops using one of
    three equally likely schemes:

    * case 0 -- independent zero-mean Gaussian offsets for left x, right x,
      left y and right y, each with its own standard deviation drawn uniformly
      between ``min_magnitude`` and ``max_magnitude``;
    * case 1 -- one x and one y standard deviation are drawn; the left crop
      uses them directly and the right crop uses them jittered by N(0, 20);
    * case 2 -- the keypoint is moved to the position of another randomly
      chosen keypoint, plus N(0, 20) noise on each coordinate.

    The i-th entries of ``ann['leftCrop']`` and ``ann['rightCrop']`` are
    treated as the same body part. The input annotation is not modified.

    Args:
        ann: dict with 'leftCrop' and 'rightCrop' lists of keypoint dicts, each
            holding 'keypointType', 'xFrame' and 'yFrame'.
        p_perturbation: per-keypoint selection probability; must be > 0.
        min_magnitude: lower bound of the sampled standard deviation.
        max_magnitude: upper bound of the sampled standard deviation.

    Returns:
        dict with 'leftCrop' and 'rightCrop' lists; every item carries only
        'keypointType', 'xFrame' and 'yFrame'.

    Raises:
        ValueError: if ``p_perturbation`` is not positive (the selection loop
            could never finish) or if there is no keypoint to perturb.
    """
    if not p_perturbation > 0:
        raise ValueError(
            'p_perturbation must be > 0, got {!r}'.format(p_perturbation))

    left_keypoints, right_keypoints = ann['leftCrop'], ann['rightCrop']

    # Only positions that actually exist in the annotation are candidates;
    # otherwise a selected index could silently perturb nothing.
    n_candidates = min(len(core_body_parts), len(left_keypoints))
    if n_candidates == 0:
        raise ValueError('no keypoints available to perturb')

    # pick body parts to perturb (at least one)
    indices = set()
    while not indices:
        indices = {i for i in range(n_candidates)
                   if random.random() < p_perturbation}

    # The swap case (2) needs a second keypoint to move towards.
    cases = [0, 1, 2] if n_candidates > 1 else [0, 1]

    def sample_sigma():
        return np.random.uniform(low=min_magnitude, high=max_magnitude)

    # apply perturbation
    perturbed_left_keypoints, perturbed_right_keypoints = [], []
    for idx, left_item in enumerate(left_keypoints):
        right_item = right_keypoints[idx]
        left_dx = right_dx = left_dy = right_dy = 0.0
        if idx in indices:
            case = np.random.choice(cases, 1).item()
            if case == 0:
                left_dx = np.random.normal(0, sample_sigma())
                right_dx = np.random.normal(0, sample_sigma())
                left_dy = np.random.normal(0, sample_sigma())
                right_dy = np.random.normal(0, sample_sigma())
            elif case == 1:
                x_magnitude = sample_sigma()
                y_magnitude = sample_sigma()
                left_dx = np.random.normal(0, x_magnitude)
                # abs() keeps the jittered standard deviation non-negative
                right_dx = np.random.normal(0, abs(x_magnitude + np.random.normal(0, 20)))
                left_dy = np.random.normal(0, y_magnitude)
                right_dy = np.random.normal(0, abs(y_magnitude + np.random.normal(0, 20)))
            else:
                others = [i for i in range(n_candidates) if i != idx]
                random_idx = np.random.choice(others, 1).item()
                target_left = left_keypoints[random_idx]
                target_right = right_keypoints[random_idx]
                left_dx = target_left['xFrame'] - left_item['xFrame'] + np.random.normal(0, 20)
                left_dy = target_left['yFrame'] - left_item['yFrame'] + np.random.normal(0, 20)
                right_dx = target_right['xFrame'] - right_item['xFrame'] + np.random.normal(0, 20)
                right_dy = target_right['yFrame'] - right_item['yFrame'] + np.random.normal(0, 20)

        perturbed_left_keypoints.append({
            'keypointType': left_item['keypointType'],
            'xFrame': left_item['xFrame'] + left_dx,
            'yFrame': left_item['yFrame'] + left_dy
        })
        perturbed_right_keypoints.append({
            'keypointType': right_item['keypointType'],
            'xFrame': right_item['xFrame'] + right_dx,
            'yFrame': right_item['yFrame'] + right_dy
        })

    return {
        'leftCrop': perturbed_left_keypoints,
        'rightCrop': perturbed_right_keypoints
    }
        


# perturb_ann draws from both the stdlib `random` module and `np.random`; seed
# both so that re-running this cell rebuilds the same "bad" examples.
PERTURBATION_SEED = 0
random.seed(PERTURBATION_SEED)
np.random.seed(PERTURBATION_SEED)

X_good_arr, X_bad_arr = [], []
for ann in df['keypoints']:
    # construct "good" class: the annotation as stored
    X_good_arr.append(convert_to_akpd_nn_input(ann))

    # construct "bad" class: the same annotation with perturbed keypoints
    X_bad_arr.append(convert_to_akpd_nn_input(perturb_ann(ann)))

# Row i of X_good_arr and row i of X_bad_arr come from the same annotation.
X_good_arr = np.array(X_good_arr)
X_bad_arr = np.array(X_bad_arr)

In [None]:
# Visual sanity check of one "good" sample: columns 0-1 are drawn in blue and
# columns 2-3 in red (presumably the left and right crop coordinates, given
# the hstack order in convert_to_akpd_nn_input -- confirm).
x = X_good_arr[6]
fig, ax = plt.subplots()
ax.scatter(x[:, 0], x[:, 1], color='blue')
ax.scatter(x[:, 2], x[:, 3], color='red')
ax.grid()
plt.show()

<h1> Train Model </h1>

<h2> Create train / val split </h2>

In [None]:
train_pct = 0.8

# Cut both classes at the same relative position: the leading train_pct of the
# rows go to training, the remainder to validation.
n_good_train = int(len(X_good_arr) * train_pct)
n_bad_train = int(len(X_bad_arr) * train_pct)
X_good_train, X_good_val = X_good_arr[:n_good_train], X_good_arr[n_good_train:]
X_bad_train, X_bad_val = X_bad_arr[:n_bad_train], X_bad_arr[n_bad_train:]

# Training set: good rows labelled 1, bad rows labelled 0, then one shared
# random permutation applied to features and labels.
X_train = np.vstack([X_good_train, X_bad_train])
y_train = np.repeat([1, 0], [len(X_good_train), len(X_bad_train)])
train_order = np.arange(len(X_train))
np.random.shuffle(train_order)
X_train, y_train = X_train[train_order], y_train[train_order]

# Validation set: built the same way from the held-out rows.
X_val = np.vstack([X_good_val, X_bad_val])
y_val = np.repeat([1, 0], [len(X_good_val), len(X_bad_val)])
val_order = np.arange(len(X_val))
np.random.shuffle(val_order)
X_val, y_val = X_val[val_order], y_val[val_order]

# Flatten every sample into a single feature vector per row.
X_train = X_train.reshape(len(X_train), -1)
X_val = X_val.reshape(len(X_val), -1)


<h2> Train model </h2>

In [None]:
import os
import math
import keras
from keras.layers import Input, Dense, Flatten
from keras.models import Model
from keras.optimizers import RMSprop
from keras.models import load_model


In [None]:
# Fully connected binary classifier: a 32-value input, three ReLU hidden layers
# of decreasing width, and a single sigmoid output unit.
inputs = Input(shape=(32,))

x = inputs
for hidden_units in (256, 128, 64):
    x = Dense(hidden_units, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)


In [None]:
# Wrap the layer graph (inputs -> predictions) into a Keras Model.
model = Model(inputs=inputs, outputs=predictions)

In [None]:
# Print the layer-by-layer summary of the model.
model.summary()

In [None]:
# NOTE(review): newer Keras releases name this argument `learning_rate`; `lr`
# is kept because the installed Keras version is not visible here -- confirm.
optimizer = RMSprop(lr=0.0001)
model.compile(optimizer=optimizer,
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Stop once val_loss has not improved for 30 epochs. restore_best_weights=True
# rolls the model back to the best-val_loss epoch when early stopping triggers;
# without it the model keeps the weights from the final epoch, which is up to
# 30 epochs past the best one.
callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss',
                                           min_delta=0,
                                           patience=30,
                                           verbose=0,
                                           mode='auto',
                                           restore_best_weights=True)]

# Keep the History object so the loss/accuracy curves can be inspected later.
history = model.fit(X_train, y_train, validation_data=(X_val, y_val), callbacks=callbacks, batch_size=32, epochs=100)

<h1> Test on Real Examples </h1>

In [None]:
# Load the AWS credentials from the JSON file named by the AWS_CREDENTIALS
# environment variable; the context manager closes the file handle, which an
# inline open() inside json.load() would leave open.
with open(os.environ['AWS_CREDENTIALS']) as credentials_f:
    aws_credentials = json.load(credentials_f)
s3 = S3AccessUtils('/root/data', aws_credentials)

In [None]:
def display_crops(left_image_f, right_image_f, left_keypoints, right_keypoints, side='both', overlay_keypoints=True, show_labels=False):
    """Show the left and/or right crop image, optionally with keypoints on top.

    Args:
        left_image_f: path of the left crop image, read with ``plt.imread``.
        right_image_f: path of the right crop image.
        left_keypoints: mapping from body-part name to a sequence whose first
            two entries are the x and y position on the left image.
        right_keypoints: same mapping for the right image.
        side: 'left' or 'right' for a single 20x10 figure, 'both' for a 20x20
            figure with the left crop above the right crop.
        overlay_keypoints: draw each keypoint as a small red dot when True.
        show_labels: additionally annotate each dot with its body-part name.

    Raises:
        AssertionError: if ``side`` is not 'left', 'right' or 'both'.
    """
    assert side in ('left', 'right', 'both'), \
        'Invalid side value: {}'.format(side)

    def mark_keypoints(axis, keypoints):
        # One red dot per keypoint, labelled with the body part if requested.
        for body_part, point in keypoints.items():
            px, py = point[0], point[1]
            axis.scatter([px], [py], color='red', s=1)
            if show_labels:
                axis.annotate(body_part, (px, py), color='red')

    # Each panel pairs an axis with the image file and keypoints it shows.
    if side == 'both':
        _, axes = plt.subplots(2, 1, figsize=(20, 20))
        panels = [
            (axes[0], left_image_f, left_keypoints),
            (axes[1], right_image_f, right_keypoints),
        ]
    else:
        _, single_axis = plt.subplots(figsize=(20, 10))
        if side == 'left':
            panels = [(single_axis, left_image_f, left_keypoints)]
        else:
            panels = [(single_axis, right_image_f, right_keypoints)]

    # Read every image first, then draw, then overlay.
    images = [plt.imread(image_f) for _, image_f, _ in panels]
    for (axis, _, _), image in zip(panels, images):
        axis.imshow(image)
    if overlay_keypoints:
        for axis, _, keypoints in panels:
            mark_keypoints(axis, keypoints)
    plt.show()

In [None]:
# Switch to the data-warehouse credentials (JSON file named by the
# DATA_WAREHOUSE_SQL_CREDENTIALS environment variable). The context manager
# closes the file handle, which an inline open() would leave open.
# Note that this rebinds rds_access_utils, which earlier pointed at the
# production database.
with open(os.environ['DATA_WAREHOUSE_SQL_CREDENTIALS']) as credentials_f:
    data_warehouse_sql_credentials = json.load(credentials_f)
rds_access_utils = RDSAccessUtils(data_warehouse_sql_credentials)

# Biomass computations for pen 144 captured between 2020-12-27 and 2021-01-12.
query = """
        select * from prod.biomass_computations
        where pen_id=144 and captured_at >= '2020-12-27' and captured_at <= '2021-01-12';
    """
tdf = rds_access_utils.extract_from_database(query)

In [None]:
# Score every annotation in tdf with the trained model. All inputs are stacked
# and sent through a single model.predict call instead of one call per row.
# NOTE(review): later cells filter on tdf.akpd_score, which this cell does not
# write -- presumably that column comes from the query. Confirm whether these
# scores were meant to be attached to tdf.
akpd_nn_inputs = [convert_to_akpd_nn_input(ann).reshape(-1) for ann in tdf['annotation']]
if akpd_nn_inputs:
    # predict returns one row per sample with a single sigmoid output column
    akpd_scores = list(model.predict(np.array(akpd_nn_inputs))[:, 0])
else:
    akpd_scores = []
    
    

In [None]:
# Keep rows with an estimated weight above 10 kg and an AKPD score strictly
# between 0.1 and 0.95. .copy() gives tdf its own data, so the column that is
# assigned to it further down does not trigger SettingWithCopyWarning.
tdf = tdf[(tdf.estimated_weight_g > 10000) & (tdf.akpd_score > 0.1) & (tdf.akpd_score < 0.95)].copy()

In [None]:
# Position (within the filtered tdf) of the example to inspect.
idx = 3
example_row = tdf.iloc[idx]

# Download both crop images; only the first element of each result is used.
left_image_f, _, _ = s3.download_from_url(example_row.left_crop_url)
right_image_f, _, _ = s3.download_from_url(example_row.right_crop_url)


def _crop_coordinates(crop_items):
    """Map each keypointType to its [xCrop, yCrop] pair."""
    coordinates = {}
    for item in crop_items:
        coordinates[item['keypointType']] = [item['xCrop'], item['yCrop']]
    return coordinates


ann = example_row.annotation
left_keypoints = _crop_coordinates(ann['leftCrop'])
right_keypoints = _crop_coordinates(ann['rightCrop'])

display_crops(left_image_f, right_image_f, left_keypoints, right_keypoints)

In [None]:
# akpd_score stored on the row at position idx.
tdf.akpd_score.iloc[idx]

In [None]:
# Left-crop keypoints for the selected row: keypointType -> [xCrop, yCrop].
left_keypoints

In [None]:
# Right-crop keypoints for the selected row: keypointType -> [xCrop, yCrop].
right_keypoints

In [None]:
def _crop_area(crop_metadata):
    """Area of a crop: its 'width' times its 'height'."""
    return crop_metadata['width'] * crop_metadata['height']


# Add the left crop's area as a new column on tdf.
tdf['left_crop_area'] = tdf['left_crop_metadata'].apply(_crop_area)

In [None]:
# Summary statistics of the left crop areas.
tdf.left_crop_area.describe()

In [None]:
# Mean akpd_score per 500 g bucket of estimated weight, from 7000 g up to
# 15500 g, ignoring rows with akpd_score <= 0.01. Buckets are half-open
# [low, high) so a weight sitting exactly on a bucket edge is counted once
# instead of being excluded from both neighbouring buckets.
weights = np.arange(7000, 16000, 500)
for low_weight_g, high_weight_g in zip(weights[:-1], weights[1:]):
    mask = (
        (tdf.estimated_weight_g >= low_weight_g)
        & (tdf.estimated_weight_g < high_weight_g)
        & (tdf.akpd_score > 0.01)
    )
    # An empty bucket prints nan.
    print('{}-{} g: {}'.format(low_weight_g, high_weight_g, tdf[mask].akpd_score.mean()))