In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from utils import rotate_and_reposition, get_2D_coords_from_3D, get_3D_coords_from_2D, jitter_2D_coords, \
    deg_to_rad, rad_to_deg, CameraMetadata, center_3D_coordinates, rotation_matrix_to_euler_angles

<h1> Run simple fish experiment </h1>

<h2> Generate base model </h2>

In [None]:
# Establish base 3D coordinates: the reference fish model (one row per keypoint, columns
# x, y, z) that later cells rescale, rotate and reposition.


base_3D_coordinates = np.array([
    [ 0.35766454,  1.16197692, -0.01824493],
    [ 0.30937074,  1.20844312, -0.10601642],
    [ 0.13337256,  1.20080287,  0.06284855],
    [-0.10068471,  1.17529033, -0.01376725],
    [-0.04026903,  1.18897391, -0.06134045],
    [ 0.16120037,  1.22363095, -0.09758986],
    [ 0.51754564,  1.11791647, -0.11590627],
    [-0.16064665,  1.16861833, -0.00902898]
])

# NOTE(review): uncommenting the line below reflects the model across the y-z plane
# (negates every x coordinate).
# base_3D_coordinates[:, 0] = -base_3D_coordinates[:, 0]

In [None]:
# Stereo-camera description consumed by the projection helpers (get_2D_coords_from_3D,
# get_3D_coords_from_2D, convert_to_world_point_arr).
# NOTE(review): units are not stated here; baseline_m is presumably metres, and
# focal_length / image_sensor_* look like metres as well (~13.8 mm focal length,
# ~14 x 10 mm sensor) — confirm against utils.CameraMetadata.
camera_metadata = CameraMetadata(
    focal_length=0.013842509663066934,
    focal_length_pixel=4012.3216414686767,
    baseline_m=0.10079791852561114,
    pixel_count_width=4096,
    pixel_count_height=3000,
    image_sensor_width=0.01412,
    image_sensor_height=0.01035
)

In [None]:
def compute_orthonormal_basis(coords):
    """Estimate a right-handed orthonormal frame aligned with a (possibly jittered) set of keypoints.

    * first axis (u): unit vector from keypoint 6 to keypoint 7;
    * third axis (w): direction from keypoint 5 to keypoint 2 with its component along u
      removed (one Gram-Schmidt step), then normalised;
    * second axis (v): w x u, which completes a right-handed frame (u x v == w).

    Args:
        coords: array of 3D keypoints with at least 8 rows.

    Returns:
        3x3 array whose columns are the unit axes u, v, w.
    """
    def unit(vec):
        return vec / np.linalg.norm(vec)

    axis_u = unit(coords[7] - coords[6])
    raw_w = coords[2] - coords[5]
    axis_w = unit(raw_w - np.dot(axis_u, raw_w) * axis_u)
    axis_v = unit(np.cross(axis_w, axis_u))
    return np.column_stack((axis_u, axis_v, axis_w))

In [None]:
# Orientation of the untransformed base model: build its keypoint-aligned frame and convert
# it to Euler angles. (These names are reassigned by the validation cell below.)
B = compute_orthonormal_basis(base_3D_coordinates)
local_yaw, local_pitch, local_roll = rotation_matrix_to_euler_angles(B)

<h2> Validate data and functions via 2D / 3D rendering </h2>

In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import plotly.express as px

In [None]:
def transform_into_df(coords):
    """Wrap a coordinate array as a DataFrame with columns 'x', 'y' and 'z'.

    Columns 0, 1 and 2 of `coords` become 'x', 'y' and 'z' respectively; any further
    columns are ignored.
    """
    return pd.DataFrame({axis: list(coords[:, col]) for col, axis in enumerate(('x', 'y', 'z'))})

In [None]:
# Pose one (optionally stretched) copy of the base model and recover its orientation from
# the keypoints, both after re-centring and in the original camera frame.
scaling_factor_x, scaling_factor_y, scaling_factor_z = 1.0, 1.0, 1.0
rescaled_3D_coordinates = base_3D_coordinates * np.array([scaling_factor_x, scaling_factor_y, scaling_factor_z])
volume = scaling_factor_x * scaling_factor_y * scaling_factor_z

yaw_deg, pitch_deg, roll_deg = 0, 30, 30
yaw, pitch, roll = map(deg_to_rad, (yaw_deg, pitch_deg, roll_deg))
new_centroid_position = [0.5, 0.5, 0]
repositioned_3D_coords = rotate_and_reposition(
    rescaled_3D_coordinates, yaw, pitch, roll, new_centroid_position
)

# Local orientation: keypoints re-centred so the depth axis passes through the fish medoid.
centered_3D_coords = center_3D_coordinates(repositioned_3D_coords)
B = compute_orthonormal_basis(centered_3D_coords)
local_yaw, local_pitch, local_roll = rotation_matrix_to_euler_angles(B)

# Global orientation: depth axis aligned with the camera line of sight (no re-centring).
B = compute_orthonormal_basis(repositioned_3D_coords)
global_yaw, global_pitch, global_roll = rotation_matrix_to_euler_angles(B)

In [None]:
# Local then global (yaw, pitch, roll), converted from radians to degrees.
for angles_rad in ((local_yaw, local_pitch, local_roll), (global_yaw, global_pitch, global_roll)):
    print(np.array(angles_rad) * 180 / np.pi)

In [None]:
df1 = transform_into_df(centered_3D_coords)
df2 = transform_into_df(repositioned_3D_coords)
fig = px.scatter_3d(df1, x='x', y='y', z='z')
# aspectmode='data' keeps the axes proportional to the data ranges, so the fish is not stretched.
fig.update_layout(scene={'xaxis': {}, 'yaxis': {}, 'zaxis': {}, 'aspectmode': 'data'})
fig.show()

In [None]:
fig = px.scatter_3d(df2, x='x', y='y', z='z')
# aspectmode='data' keeps the axes proportional to the data ranges, so the fish is not stretched.
fig.update_layout(scene={'xaxis': {}, 'yaxis': {}, 'zaxis': {}, 'aspectmode': 'data'})
fig.show()

In [None]:
# Project the base model into both stereo images: left in blue, right in red.
X_left, X_right = get_2D_coords_from_3D(base_3D_coordinates, camera_metadata)
for image_points, colour in ((X_left, 'blue'), (X_right, 'red')):
    plt.scatter(image_points[:, 0], image_points[:, 1], color=colour)
plt.grid()
plt.show()

<h2> Generate large dataset </h2>

Notes on orientation: yaw, pitch, and roll are expressed as Euler angles. Yaw, pitch, and roll are all zero for a perfectly lateral fish facing right (i.e. towards the positive x-axis). 

Any given orientation can be described as a sequence of three rotations. The first rotation is yaw: rotation of the fish about the z-axis (i.e. the axis cutting vertically through the fish from top to bottom). A positive value means a counter-clockwise rotation when looking down from above towards the upper back of the fish (i.e. the dorsal fin). The second rotation is pitch: rotation about the y-axis, which cuts horizontally through the body of the fish from its right side to its left side. Note that this axis is defined AFTER the yaw rotation has been applied. A positive value means a counter-clockwise rotation about the new positive y-axis (i.e. the fish is looking "up"). The third rotation is roll: rotation about the axis cutting longitudinally through the fish (i.e. from the tail notch through the upper lip). A positive value means a counter-clockwise rotation about this axis, which is the new positive x-axis after the yaw and pitch rotations. 

In [None]:
from collections import defaultdict

In [None]:
# Sampling ranges for the synthetic dataset (angles in degrees).
volume_range = [0.5, 10.0]

yaw_range_deg = [-50, 50]
pitch_range_deg = [-50, 50]
roll_range_deg = [-50, 50]
centroid_range_x = [-0.5, 0.5]
centroid_range_y = [0.5, 1.5]
centroid_range_z = [-0.5, 0.5]

N = 5000
jitter_std = 10

# Seed NumPy's global RNG so Restart & Run All regenerates the same dataset.
# NOTE(review): this assumes the utils helpers (e.g. jitter_2D_coords) also draw from
# np.random's global state — confirm.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

dataset = defaultdict(list)
for t in range(N):

    # Isotropic rescale: scaling every coordinate by volume**(1/3) multiplies the
    # model's volume by `volume`.
    volume = np.random.uniform(*volume_range)
    scaling_factor = volume**(1.0/3)
    rescaled_3D_coordinates = scaling_factor * base_3D_coordinates

    yaw = deg_to_rad(np.random.uniform(*yaw_range_deg))
    pitch = deg_to_rad(np.random.uniform(*pitch_range_deg))
    roll = deg_to_rad(np.random.uniform(*roll_range_deg))

    # Pose the model, project it into both cameras, jitter the 2D keypoints and
    # triangulate back to 3D (simulating keypoint-detection noise).
    new_centroid_position = np.array([np.random.uniform(*x) for x in (centroid_range_x, centroid_range_y, centroid_range_z)])
    repositioned_3D_coords = rotate_and_reposition(rescaled_3D_coordinates, yaw, pitch, roll, new_centroid_position)
    repositioned_X_left, repositioned_X_right = get_2D_coords_from_3D(repositioned_3D_coords, camera_metadata)
    jittered_X_left, jittered_X_right = jitter_2D_coords(repositioned_X_left, repositioned_X_right, jitter_std)
    jittered_3D_coords = get_3D_coords_from_2D(jittered_X_left, jittered_X_right, camera_metadata)

    # get local orientation (i.e. orientation of jittered key-points relative to
    # rotated coordinate system where depth axis passes through fish medoid)
    centered_3D_coords = center_3D_coordinates(jittered_3D_coords)
    B = compute_orthonormal_basis(centered_3D_coords)
    local_yaw, local_pitch, local_roll = rotation_matrix_to_euler_angles(B)

    # get global orientation (i.e. coordinate system's depth axis is aligned with
    # camera line of sight)
    B = compute_orthonormal_basis(jittered_3D_coords)
    global_yaw, global_pitch, global_roll = rotation_matrix_to_euler_angles(B)

    dataset['X'].append(jittered_3D_coords.tolist())
    dataset['y'].append(volume)
    # NOTE(review): the global_* columns store the *sampled* rotation, not the
    # global_yaw/global_pitch/global_roll estimated from the jittered keypoints above.
    dataset['global_yaw'].append(rad_to_deg(yaw))
    dataset['global_pitch'].append(rad_to_deg(pitch))
    dataset['global_roll'].append(rad_to_deg(roll))
    dataset['local_yaw'].append(rad_to_deg(local_yaw))
    dataset['local_pitch'].append(rad_to_deg(local_pitch))
    dataset['local_roll'].append(rad_to_deg(local_roll))
    # Short names for the sampled rotation. The analysis and heatmap cells read
    # df.yaw / df.pitch / df.roll and raise AttributeError without these columns.
    dataset['yaw'].append(rad_to_deg(yaw))
    dataset['pitch'].append(rad_to_deg(pitch))
    dataset['roll'].append(rad_to_deg(roll))

    if t % 1000 == 0:
        print(t)

In [None]:
# One row per synthetic fish; columns follow the keys appended in the generation loop.
df = pd.DataFrame.from_dict(dataset)

<h1> Train neural network architecture </h1>

In [None]:
from keras.layers import Input, Dense, Flatten
from keras.models import Model
import keras

In [None]:

def get_model():
    """Build the fully-connected Keras regressor: 24 inputs -> 256 -> 128 -> 64 (ReLU) -> 1.

    The 24 inputs are presumably the 8 x 3 keypoint coordinates flattened per fish (see the
    X_train reshape); the single linear output is the regression target.
    """
    inputs = Input(shape=(24,))
    hidden = inputs
    for width in (256, 128, 64):
        hidden = Dense(width, activation='relu')(hidden)
    pred = Dense(1)(hidden)
    return Model(inputs, pred)


def train_model(model, X_train, y_train, X_val, y_val, train_config):
    """Compile and fit a Keras regression model with Adam and early stopping.

    Args:
        model: uncompiled Keras model (e.g. from get_model()).
        X_train, y_train: training features / targets, passed straight to `model.fit`.
        X_val, y_val: validation data monitored for early stopping.
        train_config: dict with keys 'epochs', 'batch_size', 'learning_rate', 'patience'.

    Returns:
        The same model object, fitted in place. When early stopping triggers, the weights
        from the epoch with the lowest validation loss are restored.
    """
    epochs = train_config['epochs']
    batch_size = train_config['batch_size']
    lr = train_config['learning_rate']
    patience = train_config['patience']

    # restore_best_weights: without it the model keeps the weights from `patience`
    # epochs past the best validation loss.
    callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss',
                                               min_delta=0,
                                               patience=patience,
                                               verbose=0,
                                               mode='auto',
                                               restore_best_weights=True)]

    optimizer = keras.optimizers.Adam(learning_rate=lr)
    # 'accuracy' is meaningless for a continuous regression target; report MAE instead.
    model.compile(optimizer=optimizer,
                  loss='mean_squared_error',
                  metrics=['mean_absolute_error'])
    model.fit(X_train, y_train, validation_data=(X_val, y_val), callbacks=callbacks,
              batch_size=batch_size, epochs=epochs)

    return model

In [None]:
# Contiguous 60/20/20 split by row position; every row is sampled independently in the
# generation loop, so no shuffle is applied.
train_pct, val_pct, test_pct = 0.6, 0.2, 0.2
train_idx = int(train_pct * len(df))
val_idx = int((train_pct + val_pct) * len(df))

train_mask = df.index < train_idx
val_mask = (train_idx <= df.index) & (df.index < val_idx)
test_mask = df.index >= val_idx

In [None]:
def _features_and_targets(split_mask):
    """Stack one split's keypoint lists into an (n, keypoints * dims, 1) array and pair it with its targets."""
    keypoints = np.array(list(df[split_mask].X.values))
    keypoints = keypoints.reshape(keypoints.shape[0], keypoints.shape[1] * keypoints.shape[2], -1)
    return keypoints, df[split_mask].y.values


X_train, y_train = _features_and_targets(train_mask)
X_val, y_val = _features_and_targets(val_mask)
X_test, y_test = _features_and_targets(test_mask)


In [None]:
model = get_model()

# Hyper-parameters consumed by train_model.
train_config = {
    'epochs': 1000,
    'batch_size': 64,
    'learning_rate': 1e-4,
    'patience': 30,
}

train_model(model, X_train, y_train, X_val, y_val, train_config)

In [None]:
import torch
from torch import nn
from sklearn.linear_model import LinearRegression

class Network(nn.Module):
    """Network class defines neural-network architecture for both weight and k-factor estimation
    (currently both neural networks share identical architecture).

    Fully-connected regressor 24 -> 256 -> 128 -> 64 -> 1 with ReLU activations, mirroring the
    Keras model built by get_model().
    """

    def __init__(self):
        super().__init__()
        # Attribute names determine the state_dict keys (fc1.weight, ..., output.bias); keep
        # them stable so saved checkpoints (e.g. the production weight model) still load.
        self.fc1 = nn.Linear(24, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run inference on input keypoint tensor; returns predictions of shape (batch, 1)."""
        return self.output(self.forward_intermediate(x))

    def forward_intermediate(self, x):
        """Return the final hidden-layer activations (batch, 64): the ReLU output that feeds `output`.

        Each sample is flattened first, so (batch, 24) and (batch, 24, 1) inputs both work.
        """
        x = x.view(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.relu(self.fc3(x))
        

def convert_to_pytorch(model):
    """Copy the weights of a Keras model built by get_model() into a fresh Network.

    `model.get_weights()` yields (kernel, bias) pairs layer by layer. Keras Dense kernels are
    (in, out) whereas nn.Linear weights are (out, in), hence the transpose; transposing the
    1-D biases is a no-op.
    """
    torch_net = Network()
    keras_weights = model.get_weights()

    target_layers = (torch_net.fc1, torch_net.fc2, torch_net.fc3, torch_net.output)
    for layer_no, layer in enumerate(target_layers):
        kernel = keras_weights[2 * layer_no]
        bias = keras_weights[2 * layer_no + 1]
        layer.weight.data = torch.from_numpy(np.transpose(kernel))
        layer.bias.data = torch.from_numpy(np.transpose(bias))

    return torch_net


def apply_final_layer_ols(pytorch_model, X=None, y=None):
    """Refit the output layer of `pytorch_model` in place by ordinary least squares.

    The hidden layers stay fixed. Their final activations (`forward_intermediate`) on X are
    regressed onto y with sklearn's LinearRegression, and the fitted coefficients and
    intercept replace `pytorch_model.output.weight` / `.bias`.

    Args:
        pytorch_model: Network-like module exposing `forward_intermediate` and `output`.
        X: keypoint features as a NumPy array; defaults to the notebook-global `X_train`.
        y: regression targets; defaults to the notebook-global `y_train`.
    """
    if X is None:
        X = X_train
    if y is None:
        y = y_train

    # sklearn returns float64 coefficients. Writing them into an otherwise float32 network
    # makes every later forward pass on float32 input fail with a dtype mismatch, so cast
    # to the layer's current dtype.
    param_dtype = pytorch_model.output.weight.dtype

    with torch.no_grad():
        X_ols = pytorch_model.forward_intermediate(torch.from_numpy(np.asarray(X)).float()).numpy()
    ols = LinearRegression().fit(X_ols, y)

    pytorch_model.output.weight.data = torch.as_tensor(np.asarray(ols.coef_).reshape(1, -1), dtype=param_dtype)
    pytorch_model.output.bias.data = torch.as_tensor(np.asarray(ols.intercept_).reshape(-1), dtype=param_dtype)


In [None]:
# Copy the trained Keras weights into the equivalent PyTorch Network, then refit its output
# layer by OLS on the hidden activations of the training split (X_train / y_train).
pytorch_model = convert_to_pytorch(model)
apply_final_layer_ols(pytorch_model)

<h1> Accuracy Reporting </h1>

In [None]:
import seaborn as sns

In [None]:
# Score every synthetic fish (all splits) with the OLS-refitted PyTorch model.
X = np.array(list(df.X.values))
X = X.reshape(X.shape[0], X.shape[1]*X.shape[2], -1)

model_input = torch.from_numpy(X).float()
y_pred = pytorch_model(model_input).detach().numpy().squeeze()
df['y_pred'] = y_pred
# Signed relative error: positive means over-prediction.
df['pct_error'] = df['y_pred'].sub(df['y']).div(df['y'])

In [None]:
yaw_bucket_cutoffs = np.arange(-50, 55, 5)
pitch_bucket_cutoffs = np.arange(-50, 55, 5)
roll_bucket_cutoffs = np.arange(-50, 55, 5)


def _bucket_pairs(cutoffs):
    """Consecutive (low, high) bucket edges; membership is the open interval (low, high)."""
    return list(zip(cutoffs, cutoffs[1:]))


# For every (yaw, pitch, roll) bucket: relative error of the bucket's mean prediction and
# mean absolute relative error. Empty buckets give NaN.
analysis_data = defaultdict(list)
for yaw_low, yaw_high in _bucket_pairs(yaw_bucket_cutoffs):
    yaw_mask = (df.yaw > yaw_low) & (df.yaw < yaw_high)
    for pitch_low, pitch_high in _bucket_pairs(pitch_bucket_cutoffs):
        pitch_mask = (df.pitch > pitch_low) & (df.pitch < pitch_high)
        for roll_low, roll_high in _bucket_pairs(roll_bucket_cutoffs):
            roll_mask = (df.roll > roll_low) & (df.roll < roll_high)
            orientation_mask = yaw_mask & pitch_mask & roll_mask
            subset = df[orientation_mask]
            mean_y = subset.y.mean()

            analysis_data['yaw_bucket'].append('{} <-> {}'.format(yaw_low, yaw_high))
            analysis_data['pitch_bucket'].append('{} <-> {}'.format(pitch_low, pitch_high))
            analysis_data['roll_bucket'].append('{} <-> {}'.format(roll_low, roll_high))
            analysis_data['mean_error_pct'].append((subset.y_pred.mean() - mean_y) / mean_y)
            analysis_data['mean_abs_error_pct'].append(np.mean(np.abs((subset.y_pred - subset.y) / subset.y)))

In [None]:
# One row per (yaw, pitch, roll) bucket combination.
analysis_df = pd.DataFrame.from_dict(analysis_data)

In [None]:
# Buckets that received no samples have NaN means; drop them so only populated buckets are counted.
plt.hist(analysis_df.mean_error_pct.dropna())
plt.xlabel('mean signed error (fraction of mean true y)')
plt.ylabel('number of (yaw, pitch, roll) buckets')
plt.title('Distribution of per-orientation-bucket mean error')
plt.show()

In [None]:
def produce_heatmap(df, angle_1, angle_2, bucket_cutoffs_1, bucket_cutoffs_2):
    """Plot the mean signed prediction error, bucketed by two orientation angles.

    Each cell shows 100 * (mean(y_pred) - mean(y)) / mean(y), rounded to 2 decimals, over the
    rows of `df` whose angles fall strictly inside that pair of buckets. Buckets with no
    rows come out as NaN.

    Args:
        df: DataFrame with columns `y`, `y_pred`, `angle_1` and `angle_2`.
        angle_1: column bucketed along the x-axis.
        angle_2: column bucketed along the y-axis.
        bucket_cutoffs_1: ascending bucket edges for `angle_1`.
        bucket_cutoffs_2: ascending bucket edges for `angle_2`.
    """
    buckets_1 = list(zip(bucket_cutoffs_1, bucket_cutoffs_1[1:]))
    buckets_2 = list(zip(bucket_cutoffs_2, bucket_cutoffs_2[1:]))

    # seaborn draws array rows along the y-axis and columns along the x-axis. Rows must
    # therefore index angle_2 buckets and columns angle_1 buckets to agree with the
    # tick and axis labels below.
    heatmap_arr = np.zeros([len(buckets_2), len(buckets_1)])
    for i, (angle_2_low, angle_2_high) in enumerate(buckets_2):
        angle_2_mask = (df[angle_2] > angle_2_low) & (df[angle_2] < angle_2_high)
        for j, (angle_1_low, angle_1_high) in enumerate(buckets_1):
            angle_1_mask = (df[angle_1] > angle_1_low) & (df[angle_1] < angle_1_high)
            bucket_df = df[angle_1_mask & angle_2_mask]
            mean_y = bucket_df.y.mean()
            mean_error_pct = (bucket_df.y_pred.mean() - mean_y) / mean_y
            heatmap_arr[i][j] = round(100 * mean_error_pct, 2)

    angle_1_buckets = ['{} <-> {}'.format(low, high) for low, high in buckets_1]
    angle_2_buckets = ['{} <-> {}'.format(low, high) for low, high in buckets_2]

    plt.figure(figsize=(15, 10))
    sns.heatmap(heatmap_arr, xticklabels=angle_1_buckets, yticklabels=angle_2_buckets, annot=True)
    plt.xlabel('{} range (degrees)'.format(angle_1))
    plt.ylabel('{} range (degrees)'.format(angle_2))
    plt.title('Error percentage (%) broken down by Orientation Bucket')
    plt.show()

In [None]:
# One yaw-vs-pitch heatmap per roll bucket.
for low_roll, high_roll in zip(roll_bucket_cutoffs, roll_bucket_cutoffs[1:]):
    print('Roll range: {} <-> {}'.format(low_roll, high_roll))
    roll_mask = (df.roll > low_roll) & (df.roll < high_roll)
    roll_slice = df[roll_mask].copy(deep=True)
    produce_heatmap(roll_slice, angle_1='yaw', angle_2='pitch',
                    bucket_cutoffs_1=yaw_bucket_cutoffs, bucket_cutoffs_2=pitch_bucket_cutoffs)

In [None]:
# Test split only. Each angle gets its own cutoffs, so the call stays correct if the roll buckets change.
produce_heatmap(df[test_mask].copy(deep=True), 'yaw', 'roll', yaw_bucket_cutoffs, roll_bucket_cutoffs)

In [None]:
# Each angle gets its own cutoffs, so the call stays correct if the roll buckets change.
produce_heatmap(df, 'yaw', 'roll', yaw_bucket_cutoffs, roll_bucket_cutoffs)

In [None]:
# Each angle gets its own cutoffs (pitch on x, roll on y).
produce_heatmap(df, 'pitch', 'roll', pitch_bucket_cutoffs, roll_bucket_cutoffs)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde


# Predicted vs. true volume on the held-out test split, coloured by local point density.
x = df[test_mask].y.values
y = df[test_mask].y_pred.values

# Gaussian KDE evaluated at each (true, predicted) point gives its density, used as the colour.
xy = np.vstack([x, y])
z = gaussian_kde(xy)(xy)

fig, ax = plt.subplots(figsize=(15, 8))
# edgecolor='none' rather than '': current Matplotlib rejects the empty string as a colour.
ax.scatter(x, y, c=z, s=100, edgecolor='none')
lims = [min(x.min(), y.min()), max(x.max(), y.max())]
ax.plot(lims, lims, color='black', linestyle='--', linewidth=1, label='y_pred = y')
ax.set_xlabel('true volume (y)')
ax.set_ylabel('predicted volume (y_pred)')
ax.set_title('Predicted vs. true volume on the test split (colour = point density)')
ax.legend()
ax.grid()
plt.show()

<h1> Production Model Orientation Bias Assessment </h1>

In [None]:
from research_lib.utils.data_access_utils import S3AccessUtils
from weight_estimation.utils import normalize_left_right_keypoint_arrs, stabilize_keypoints, convert_to_world_point_arr
# NOTE(review): '/root/data' is a hard-coded absolute path, presumably the local working /
# download directory for S3AccessUtils — confirm, and consider making it configurable.
s3 = S3AccessUtils('/root/data')

In [None]:
weight_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb')
weight_model = Network()
# map_location='cpu': inference here runs on CPU (outputs go through .numpy()), so remap any
# CUDA tensors in the checkpoint instead of failing on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
weight_model.load_state_dict(torch.load(weight_model_f, map_location='cpu'))
weight_model.eval()

In [None]:
# NOTE(review): this feeds the raw jittered 3D keypoints (df.X) straight into the production
# model. It skips the normalize / convert / stabilize steps and the 1e4 / base_y rescaling
# used in the next cell, so confirm these predictions are comparable to df.y.
X = np.array(list(df.X.values))
X = X.reshape(X.shape[0], X.shape[1]*X.shape[2], -1)

weight_input = torch.from_numpy(X).float()
y_pred = weight_model(weight_input).detach().numpy().squeeze()
df['y_pred'] = y_pred
df['pct_error'] = df['y_pred'].sub(df['y']).div(df['y'])

In [None]:
# Score near-lateral fish (|yaw|, |pitch|, |roll| < 5 degrees) with the production weight
# model, using its preprocessing chain: project to 2D, normalise, back to world
# coordinates, stabilise.
# NOTE(review): 1e4 * prediction / base_y appears to rescale production-model outputs onto
# the synthetic volume scale (the pose sweep below sets y = 1 for the unscaled base model);
# the meaning of 4680 is not documented here — confirm.
base_y = 4680
preds, ys = [], []
mask = (df.yaw.abs() < 5) & (df.pitch.abs() < 5) & (df.roll.abs() < 5)
for count, (idx, row) in enumerate(df[mask].iterrows()):
    X_left, X_right = get_2D_coords_from_3D(np.array(row.X), camera_metadata)
    X_left_norm, X_right_norm = normalize_left_right_keypoint_arrs(X_left, X_right)
    X_world = convert_to_world_point_arr(X_left_norm, X_right_norm, camera_metadata)
    # Separate name for the per-fish model input so the notebook-global feature matrix X
    # is not overwritten.
    keypoints = stabilize_keypoints(X_world)
    nn_input = torch.from_numpy(np.array([keypoints])).float()
    with torch.no_grad():
        pred = 1e4 * weight_model(nn_input).item() / base_y
    preds.append(pred)
    ys.append(row.y)

    if count % 100 == 0:
        print(count)
    


In [None]:
# Sweep yaw/pitch over +/-50 degrees (roll within +/-5) for the unscaled base fish and
# record the production model's rescaled prediction for each pose.
yaws, pitches, rolls, preds = [], [], [], []

yaw_range_deg = [-50, 50]
pitch_range_deg = [-50, 50]
roll_range_deg = [-5, 5]

n_pose_samples = 100000

# The model is never rescaled in this sweep (volume 1), so compute it once outside the loop.
volume = 1.0
scaling_factor = volume**(1.0/3)
rescaled_3D_coordinates = scaling_factor * base_3D_coordinates

model_inputs = []
for _ in range(n_pose_samples):
    yaw = deg_to_rad(np.random.uniform(*yaw_range_deg))
    pitch = deg_to_rad(np.random.uniform(*pitch_range_deg))
    roll = deg_to_rad(np.random.uniform(*roll_range_deg))

    new_centroid_position = np.array([np.random.uniform(*x) for x in (centroid_range_x, centroid_range_y, centroid_range_z)])
    repositioned_3D_coords = rotate_and_reposition(rescaled_3D_coordinates, yaw, pitch, roll, new_centroid_position)
    repositioned_X_left, repositioned_X_right = get_2D_coords_from_3D(repositioned_3D_coords, camera_metadata)
    # Same pixel jitter as the training dataset (jitter_std is defined with its sampling ranges).
    jittered_X_left, jittered_X_right = jitter_2D_coords(repositioned_X_left, repositioned_X_right, jitter_std)
    jittered_3D_coords = get_3D_coords_from_2D(jittered_X_left, jittered_X_right, camera_metadata)

    X_left, X_right = get_2D_coords_from_3D(jittered_3D_coords, camera_metadata)
    X_left_norm, X_right_norm = normalize_left_right_keypoint_arrs(X_left, X_right)
    X_world = convert_to_world_point_arr(X_left_norm, X_right_norm, camera_metadata)
    model_inputs.append(stabilize_keypoints(X_world))
    yaws.append(rad_to_deg(yaw))
    pitches.append(rad_to_deg(pitch))
    rolls.append(rad_to_deg(roll))

# Chunked batched inference without autograd, instead of one forward pass per pose.
# Network flattens each sample, so stacking is equivalent.
# NOTE(review): this assumes stabilize_keypoints returns a fixed-shape array per pose.
with torch.no_grad():
    batch = torch.from_numpy(np.array(model_inputs)).float()
    raw_preds = torch.cat([weight_model(chunk) for chunk in torch.split(batch, 4096)]).numpy().reshape(-1)
# Rescale in float64 to match the per-sample `.item()` arithmetic.
preds.extend((1e4 * raw_preds.astype(np.float64) / base_y).tolist())
    

In [None]:
kdf = pd.DataFrame({
    'y_pred': preds,
    'yaw': yaws,
    'pitch': pitches,
    'roll': rolls,
    # every pose in the sweep uses the base fish at volume 1.0, so the target is 1
    'y': 1,
})

In [None]:
# Bucket edges for the pose sweep. As with arange(-50, 55, 5) for +/-50, the stop value must
# exceed the last edge: arange(-5, 5, 5) gave [-5, 0] and silently dropped the 0..5 bucket.
yaw_bucket_cutoffs = np.arange(-50, 55, 5)
pitch_bucket_cutoffs = np.arange(-50, 55, 5)
roll_bucket_cutoffs = np.arange(-5, 10, 5)


In [None]:
produce_heatmap(kdf, angle_1='yaw', angle_2='pitch',
                bucket_cutoffs_1=yaw_bucket_cutoffs, bucket_cutoffs_2=pitch_bucket_cutoffs)

<h1> Repeat training: reproduce the plots above to see how the picture changes when only the random seed is altered </h1>

In [None]:
model = get_model()

# Same hyper-parameters as the first run; only the random initialisation differs.
train_config = {
    'epochs': 1000,
    'batch_size': 64,
    'learning_rate': 1e-4,
    'patience': 30,
}

train_model(model, X_train, y_train, X_val, y_val, train_config)

In [None]:
# Re-score the whole dataset with the retrained Keras model.
X = np.array(list(df.X.values))
X = X.reshape(X.shape[0], X.shape[1]*X.shape[2], -1)

# model.predict returns shape (n_samples, 1) (single-unit output layer). Flatten it to 1-D
# so the column gets one scalar per row, consistent with the `.squeeze()` used for the
# PyTorch predictions.
y_pred = model.predict(X).ravel()
df['y_pred'] = y_pred
df['pct_error'] = (df.y_pred - df.y) / df.y

In [None]:
produce_heatmap(df, angle_1='yaw', angle_2='pitch',
                bucket_cutoffs_1=yaw_bucket_cutoffs, bucket_cutoffs_2=pitch_bucket_cutoffs)

In [None]:
# roll_bucket_cutoffs was redefined to the narrow pose-sweep range above, so
# pitch_bucket_cutoffs (-50..50 in 5-degree steps) supplies the roll buckets here.
produce_heatmap(df, angle_1='yaw', angle_2='roll',
                bucket_cutoffs_1=yaw_bucket_cutoffs, bucket_cutoffs_2=pitch_bucket_cutoffs)

In [None]:
# roll_bucket_cutoffs was redefined to the narrow pose-sweep range above, so
# pitch_bucket_cutoffs (-50..50 in 5-degree steps) supplies the roll buckets here.
produce_heatmap(df, angle_1='pitch', angle_2='roll',
                bucket_cutoffs_1=yaw_bucket_cutoffs, bucket_cutoffs_2=pitch_bucket_cutoffs)