In [None]:
%load_ext autoreload
%autoreload 2

from collections import defaultdict
import json
from matplotlib import pyplot as plt
import pandas as pd
from utils import *

<h1> Load lateral, SIFT corrected GTSF data </h1>

In [None]:
# Hard-coded absolute path to the filtered, lateral, SIFT-corrected GTSF CSV;
# it only resolves on a machine that has the same /root/data layout.
f = '/root/data/alok/biomass_estimation/playground/filtered_lateral_sift_corrected_gtsf.csv'
# Later cells read the `keypoints`, `camera_metadata` and `weight` columns of this frame.
df = pd.read_csv(f)

In [None]:
# Triangulate every annotated stereo keypoint pair into 3D coordinates and store
# the result per row in `df['modified_X']`.

# CameraMetadata field name -> key used in the serialized camera metadata.
CAMERA_KEY_BY_FIELD = {
    'focal_length': 'focalLength',
    'focal_length_pixel': 'focalLengthPixel',
    'baseline_m': 'baseline',
    'pixel_count_width': 'pixelCountWidth',
    'pixel_count_height': 'pixelCountHeight',
    'image_sensor_width': 'imageSensorWidth',
    'image_sensor_height': 'imageSensorHeight',
}

X_list = []
for _, gtsf_row in df.iterrows():
    # Both columns hold single-quoted dict strings, so the quotes are swapped
    # before JSON-decoding. (An earlier revision decoded `row.modified_ann`
    # here instead of `row.keypoints`.)
    keypoint_ann = json.loads(gtsf_row.keypoints.replace("'", '"'))
    raw_camera = json.loads(gtsf_row.camera_metadata.replace("'", '"'))
    cm = CameraMetadata(**{field: raw_camera[key] for field, key in CAMERA_KEY_BY_FIELD.items()})

    left_2d, right_2d = get_2D_coords_from_ann(keypoint_ann)
    X_list.append(get_3D_coords_from_2D(left_2d, right_2d, cm))

df['modified_X'] = X_list

In [None]:
# Scatter the Euclidean distance between 3D keypoints 6 and 7 of each fish
# (presumably a body-length proxy -- confirm keypoint ordering) against its weight.
keypoint_6_7_distance = df.modified_X.apply(lambda coords: np.linalg.norm(coords[6] - coords[7]))
plt.scatter(keypoint_6_7_distance, df.weight)

<h1> Generate Training Dataset </h1>

In [None]:
def compute_orthonormal_basis(coords):
    """Build an orthonormal basis that is axis-aligned with a set of fish keypoints.

    The first axis is the unit vector from keypoint 6 to keypoint 7 (labelled
    TAIL_NOTCH -> UPPER_LIP by the original author). The third axis is the
    keypoint 5 -> keypoint 2 direction (labelled PELVIC_FIN -> DORSAL_FIN) with
    its component along the first axis removed, then normalized. The second
    axis is the normalized cross product (third x first).

    Args:
        coords: indexable collection of 3D keypoint vectors; only indices
            2, 5, 6 and 7 are read.

    Returns:
        A 3x3 array whose columns are the three unit vectors, in that order.
    """
    body_axis = coords[7] - coords[6]
    body_axis = body_axis / np.linalg.norm(body_axis)

    # Gram-Schmidt step: strip the part of the dorsal direction that lies along the body axis.
    dorsal_axis = coords[2] - coords[5]
    dorsal_axis = dorsal_axis - (np.dot(body_axis, dorsal_axis)) * body_axis
    dorsal_axis = dorsal_axis / np.linalg.norm(dorsal_axis)

    lateral_axis = np.cross(dorsal_axis, body_axis)
    lateral_axis = lateral_axis / np.linalg.norm(lateral_axis)

    basis_rows = np.vstack((body_axis, lateral_axis, dorsal_axis))
    return basis_rows.T

In [None]:
# Synthetic-dataset generation: every row of `df` is re-posed N times with a random
# orientation and position, re-projected to the stereo image planes, jittered in 2D and
# triangulated back to 3D. No random seed is set, so each run produces different samples.

# Number of synthetic samples generated per source row.
N = 500
# Sampling ranges (degrees) for the rotation applied to each sample.
yaw_range_deg = [-30, 30]
pitch_range_deg = [-50, 50]
roll_range_deg = [-30, 30]
# Sampling ranges for the new centroid position; x and z are pinned to 0, only y varies.
# (presumably metres, with y the depth axis -- TODO confirm against utils)
centroid_range_x = [-0.0, 0.0]
centroid_range_y = [0.5, 1.5]
centroid_range_z = [-0.0, 0.0]

# Passed to jitter_2D_coords as the jitter magnitude (presumably a std-dev in pixels -- TODO confirm).
jitter_std = 5

# Column name -> list of per-sample values; converted to a DataFrame in a later cell.
dataset = defaultdict(list)

count = 0
for idx, row in df.iterrows():
    # camera_metadata is stored as a single-quoted dict string; swap quotes so json can parse it.
    camera_metadata_dict = json.loads(row.camera_metadata.replace("'", '"'))
    camera_metadata = CameraMetadata(
        focal_length=camera_metadata_dict['focalLength'],
        focal_length_pixel=camera_metadata_dict['focalLengthPixel'],
        baseline_m=camera_metadata_dict['baseline'],
        pixel_count_width=camera_metadata_dict['pixelCountWidth'],
        pixel_count_height=camera_metadata_dict['pixelCountHeight'],
        image_sensor_width=camera_metadata_dict['imageSensorWidth'],
        image_sensor_height=camera_metadata_dict['imageSensorHeight']
    )
    for n in range(N):
        # uniform(1.0, 1.0) always yields 1.0, so volume augmentation is effectively disabled;
        # the call still consumes one draw from the global NumPy RNG.
        volume_factor = np.random.uniform(1.0, 1.0)
        # Scaling lengths by the cube root scales volume (and hence the target) by volume_factor.
        scaling_factor = volume_factor**(1.0/3)
        X = scaling_factor * row.modified_X
        # Target is the weight scaled by 1e-4; the toy-fish cell multiplies predictions by 1e4 to undo this.
        y = row.weight * volume_factor * 1e-4
        
        yaw = deg_to_rad(np.random.uniform(*yaw_range_deg))
        pitch = deg_to_rad(np.random.uniform(*pitch_range_deg))
        roll = deg_to_rad(np.random.uniform(*roll_range_deg))
        
        # Pose the fish, project into both cameras, add 2D noise, then triangulate back to 3D.
        new_centroid_position = np.array([np.random.uniform(*x) for x in (centroid_range_x, centroid_range_y, centroid_range_z)])
        repositioned_3D_coords = rotate_and_reposition(X, yaw, pitch, roll, new_centroid_position)
        repositioned_X_left, repositioned_X_right = get_2D_coords_from_3D(repositioned_3D_coords, camera_metadata)
        jittered_X_left, jittered_X_right = jitter_2D_coords(repositioned_X_left, repositioned_X_right, jitter_std)
        jittered_3D_coords = get_3D_coords_from_2D(jittered_X_left, jittered_X_right, camera_metadata)

        # get local orientation (i.e. orientation of jittered key-points relative to 
        # rotated coordinate system where depth axis passes through fish medoid)
        centered_3D_coords = center_3D_coordinates(jittered_3D_coords)
        B = compute_orthonormal_basis(centered_3D_coords)
        local_yaw, local_pitch, local_roll = rotation_matrix_to_euler_angles(B)

        # get global orientation (i.e. coordinate system's depth axis is aligned with 
        # camera line of sight)
        # NOTE(review): global_yaw / global_pitch / global_roll computed here are never used;
        # the 'global_*' columns below store the *sampled* yaw / pitch / roll instead.
        # Confirm which of the two was intended.
        B = compute_orthonormal_basis(jittered_3D_coords)
        global_yaw, global_pitch, global_roll = rotation_matrix_to_euler_angles(B)
        
#         dataset['X'].append(jittered_3D_coords.tolist())
        # The network input is the jittered (not centered) 3D keypoint array.
        dataset['X'].append(jittered_3D_coords)
        dataset['y'].append(y)
        # All angles are stored in degrees.
        dataset['global_yaw'].append(rad_to_deg(yaw))
        dataset['global_pitch'].append(rad_to_deg(pitch))
        dataset['global_roll'].append(rad_to_deg(roll))
        dataset['local_yaw'].append(rad_to_deg(local_yaw))
        dataset['local_pitch'].append(rad_to_deg(local_pitch))
        dataset['local_roll'].append(rad_to_deg(local_roll))
        
        # Progress indicator: prints the running sample count every 10,000 samples.
        if count % 10000 == 0:
            print(count)
        count += 1

In [None]:
# NOTE: this rebinds `df` from the raw GTSF frame to the synthetic dataset. The generation
# loop reads `camera_metadata`, `modified_X` and `weight` from `df`, so it cannot be re-run
# after this cell without first reloading the CSV and recomputing `modified_X`.
df = pd.DataFrame(dataset)

<h1> Train neural network </h1>

In [None]:
from keras.layers import Input, Dense, Flatten
from keras.models import Model
import keras

In [None]:
def get_model():
    """Create the (uncompiled) Keras regression network.

    Architecture: a 24-feature input, three ReLU hidden layers of width
    256, 128 and 64, and a single linear output unit.

    Returns:
        keras.models.Model mapping a (batch, 24) input to a (batch, 1) prediction.
    """
    hidden_widths = (256, 128, 64)

    inputs = Input(shape=(24,))
    x = inputs
    for width in hidden_widths:
        x = Dense(width, activation='relu')(x)

    return Model(inputs, Dense(1)(x))


def train_model(model, X_train, y_train, X_val, y_val, train_config):
    """Compile `model` for mean-squared-error regression and fit it with early stopping.

    Args:
        model: Keras model to train; it is compiled and fitted in place.
        X_train, y_train: training inputs and targets.
        X_val, y_val: validation inputs and targets; `val_loss` on these drives early stopping.
        train_config: mapping with the keys 'epochs', 'batch_size',
            'learning_rate' and 'patience'.

    Returns:
        The same `model` object after fitting.

    Raises:
        KeyError: if `train_config` lacks one of the required keys.
    """
    epochs = train_config['epochs']
    batch_size = train_config['batch_size']
    lr = train_config['learning_rate']
    patience = train_config['patience']

    # Stop once val_loss has not improved for `patience` consecutive epochs.
    callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss',
                                               min_delta=0,
                                               patience=patience,
                                               verbose=0,
                                               mode='auto')]

    optimizer = keras.optimizers.Adam(learning_rate=lr)
    # 'accuracy' is a classification metric and carries no information for a
    # continuous target, so report mean absolute error next to the MSE loss.
    model.compile(optimizer=optimizer,
                  loss='mean_squared_error',
                  metrics=['mean_absolute_error'])
    model.fit(X_train, y_train, validation_data=(X_val, y_val), callbacks=callbacks,
              batch_size=batch_size, epochs=epochs)

    return model

In [None]:
# Index-order 60 / 20 / 20 split (rows are not shuffled). The masks compare
# against index labels, so this assumes `df` carries a default 0..n-1 index.
train_pct, val_pct, test_pct = 0.6, 0.2, 0.2

n_rows = len(df)
train_idx = int(train_pct * n_rows)
val_idx = int((train_pct + val_pct) * n_rows)

row_index = df.index
train_mask = row_index < train_idx
val_mask = (train_idx <= row_index) & (row_index < val_idx)
test_mask = val_idx <= row_index

In [None]:
def _flatten_keypoints(frame):
    """Stack the per-row keypoint arrays of `frame.X` into a 2-D (rows, features) matrix."""
    stacked = np.array(list(frame.X.values))
    # Collapse every per-sample axis into one feature axis. The previous
    # reshape(n, k, -1) left a trailing singleton axis, i.e. (n, 24, 1), which
    # does not match the 24-feature input the network declares.
    return stacked.reshape(stacked.shape[0], -1)


X_train = _flatten_keypoints(df[train_mask])
y_train = df[train_mask].y.values

X_val = _flatten_keypoints(df[val_mask])
y_val = df[val_mask].y.values

X_test = _flatten_keypoints(df[test_mask])
y_test = df[test_mask].y.values


In [None]:
# Hyper-parameters consumed by `train_model`.
train_config = {
    'epochs': 1000,
    'batch_size': 64,
    'learning_rate': 1e-4,
    'patience': 15,
}

# Build a fresh network and fit it on the training split, early-stopping on the validation split.
model = train_model(get_model(), X_train, y_train, X_val, y_val, train_config)

In [None]:
import torch
from torch import nn
from sklearn.linear_model import LinearRegression

class Network(nn.Module):
    """Fully connected regressor: 24 inputs -> 256 -> 128 -> 64 -> 1.

    ReLU follows each of the three hidden layers; the output layer is linear.
    Per the original author, the same architecture is shared by the weight
    and k-factor estimation networks.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are part of the contract: other code in this
        # notebook assigns to fc1 / fc2 / fc3 / output directly.
        self.fc1 = nn.Linear(24, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run inference on an input keypoint tensor and return the (batch, 1) prediction."""
        return self.output(self.forward_intermediate(x))

    def forward_intermediate(self, x):
        """Return the activations of the last hidden layer, shape (batch, 64).

        The input is flattened to (batch, -1) first, so any trailing layout
        with 24 values per sample is accepted.
        """
        hidden = x.view(x.shape[0], -1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        return hidden
        

def convert_to_pytorch(model):
    """Copy the weights of a trained Keras model into a fresh PyTorch `Network`.

    `model.get_weights()` is read as alternating (kernel, bias) arrays for the
    four dense layers, i.e. indices 0 through 7. Each array is passed through
    np.transpose because Keras stores dense kernels as (in, out) while
    `nn.Linear.weight` is (out, in); for the 1-D biases this is a no-op.

    Args:
        model: object exposing `get_weights()` (the trained Keras model).

    Returns:
        A `Network` whose fc1, fc2, fc3 and output parameters hold the Keras weights.
    """
    pytorch_model = Network()
    weights = model.get_weights()

    target_layers = (
        pytorch_model.fc1,
        pytorch_model.fc2,
        pytorch_model.fc3,
        pytorch_model.output,
    )
    for position, layer in enumerate(target_layers):
        kernel = weights[2 * position]
        bias = weights[2 * position + 1]
        layer.weight.data = torch.from_numpy(np.transpose(kernel))
        layer.bias.data = torch.from_numpy(np.transpose(bias))

    return pytorch_model


def apply_final_layer_ols(pytorch_model, X=None, y=None):
    """Refit the output layer of `pytorch_model` by ordinary least squares.

    The hidden layers are left untouched: the last-hidden-layer activations of
    `X` are computed with `forward_intermediate`, a linear regression from
    those activations to `y` is fitted, and its coefficients and intercept
    overwrite `pytorch_model.output`.

    Args:
        pytorch_model: model exposing `forward_intermediate(tensor)` and an
            `output` linear layer.
        X: NumPy array of network inputs. Defaults to the notebook-level
            `X_train` when omitted.
        y: regression targets aligned with `X`. Defaults to the notebook-level
            `y_train` when omitted.

    Returns:
        None; `pytorch_model` is modified in place.
    """
    if X is None:
        X = X_train
    if y is None:
        y = y_train

    X_ols = pytorch_model.forward_intermediate(torch.from_numpy(X).float()).detach().numpy()
    lr = LinearRegression().fit(X_ols, y)

    head = pytorch_model.output
    # The fitted coefficients may come back as float64 (e.g. when `y` is
    # float64). Cast them to the layer's existing dtype so the refitted head
    # still accepts the float32 activations of the hidden layers.
    param_dtype = head.weight.dtype
    head.weight.data = torch.from_numpy(np.asarray(lr.coef_).reshape(1, -1)).to(param_dtype)
    head.bias.data = torch.from_numpy(np.asarray(lr.intercept_).reshape(-1)).to(param_dtype)


In [None]:
# Copy the trained Keras weights into the PyTorch `Network`, then refit only its
# output layer by ordinary least squares on the training-set hidden activations.
pytorch_model = convert_to_pytorch(model)
apply_final_layer_ols(pytorch_model)

In [None]:
# Sanity check: `convert_to_pytorch` reads weights[0] through weights[7]
# (kernel + bias for four dense layers), so this should display 8.
len(model.get_weights())

<h1> Accuracy Reporting </h1>

In [None]:
import seaborn as sns

In [None]:
# Score every row (train, validation and test alike) with the OLS-refitted PyTorch model.
X = np.array(list(df.X.values))
# Flatten each sample's keypoint array into one feature row, giving
# (num_rows, num_features). The earlier reshape(n, k, -1) left a spurious
# trailing axis, i.e. (num_rows, 24, 1), which does not match the 24-feature
# input that both the Keras model and the PyTorch network declare.
X = X.reshape(X.shape[0], -1)

# y_pred = model.predict(X)
# reshape(-1) keeps a 1-D result even for a single-row frame, where squeeze() would collapse to 0-d.
y_pred = (pytorch_model(torch.from_numpy(X).float())).detach().numpy().reshape(-1)
df['y_pred'] = y_pred
# Signed relative error of the prediction with respect to the (scaled) target.
df['pct_error'] = (df.y_pred - df.y) / df.y

In [None]:
# Predicted vs. true (scaled) target for the training rows only.
train_rows = df[train_mask]

fig, ax = plt.subplots(figsize=(20, 10))
ax.scatter(train_rows.y, train_rows.y_pred)
ax.grid()
plt.show()

In [None]:
def produce_heatmap(df, angle_1, angle_2, bucket_cutoffs_1, bucket_cutoffs_2):
    """Plot the mean prediction error (%) per (angle_1, angle_2) orientation bucket.

    For every pair of consecutive cutoffs the rows of `df` whose angles fall
    strictly inside both intervals are selected, and the bucket value is
    100 * (mean(y_pred) - mean(y)) / mean(y), rounded to two decimals.
    Buckets with no rows come out as NaN.

    Args:
        df: frame with the columns named by `angle_1` and `angle_2`, plus `y` and `y_pred`.
        angle_1: column name bucketed along the x axis.
        angle_2: column name bucketed along the y axis.
        bucket_cutoffs_1: ascending bucket edges for `angle_1`.
        bucket_cutoffs_2: ascending bucket edges for `angle_2`.

    Returns:
        None; the heatmap is drawn with seaborn and shown.
    """
    buckets_1 = list(zip(bucket_cutoffs_1, bucket_cutoffs_1[1:]))
    buckets_2 = list(zip(bucket_cutoffs_2, bucket_cutoffs_2[1:]))

    # seaborn draws matrix rows along the y axis and columns along the x axis,
    # so the matrix must be (angle_2 buckets) x (angle_1 buckets) to agree with
    # the tick and axis labels below. It was previously built the other way
    # round, which transposed the plot relative to its labels.
    heatmap_arr = np.zeros([len(buckets_2), len(buckets_1)])

    for col, (angle_1_low, angle_1_high) in enumerate(buckets_1):
        angle_1_mask = (df[angle_1] > angle_1_low) & (df[angle_1] < angle_1_high)
        for row, (angle_2_low, angle_2_high) in enumerate(buckets_2):
            angle_2_mask = (df[angle_2] > angle_2_low) & (df[angle_2] < angle_2_high)
            bucket = df[angle_1_mask & angle_2_mask]
            mean_y = bucket.y.mean()
            mean_error_pct = (bucket.y_pred.mean() - mean_y) / mean_y
            heatmap_arr[row][col] = round(100 * mean_error_pct, 2)

    angle_1_buckets = ['{} <-> {}'.format(low, high) for low, high in buckets_1]
    angle_2_buckets = ['{} <-> {}'.format(low, high) for low, high in buckets_2]

    plt.figure(figsize=(15, 10))
    sns.heatmap(heatmap_arr, xticklabels=angle_1_buckets, yticklabels=angle_2_buckets, annot=True)
    plt.xlabel('{} range (degrees)'.format(angle_1))
    plt.ylabel('{} range (degrees)'.format(angle_2))
    plt.title('Error percentage (%) broken down by Orientation Bucket')
    plt.show()

In [None]:
# 5-degree bucket edges for each angle.
yaw_bucket_cutoffs = np.arange(-30, 35, 5)
pitch_bucket_cutoffs = np.arange(-50, 55, 5)
roll_bucket_cutoffs = np.arange(-30, 35, 5)

# One yaw-vs-pitch error heatmap per roll slice.
roll_buckets = zip(roll_bucket_cutoffs[:-1], roll_bucket_cutoffs[1:])
for roll_low, roll_high in roll_buckets:
    print('Roll range: {} <-> {}'.format(roll_low, roll_high))
    in_roll_bucket = (roll_low < df.global_roll) & (df.global_roll < roll_high)
    roll_slice = df[in_roll_bucket].copy()
    produce_heatmap(roll_slice, 'global_yaw', 'global_pitch', yaw_bucket_cutoffs, pitch_bucket_cutoffs)

<h1> Toy fish - orientation bias experiment </h1>

In [None]:
# Hard-coded absolute path to the sample toy-fish CSV; it only resolves on a
# machine that has the same /root/data layout. Note that this rebinds `f`.
f = '/root/data/alok/biomass_estimation/playground/sample_toy_fish_dataset.csv'
# Later cells read the `ann`, `camera_metadata` and `weight` columns of `tdf`.
tdf = pd.read_csv(f)


In [None]:
# Triangulate every toy-fish annotation, record its local orientation, and
# predict its weight with the PyTorch model.
X_list = []
local_yaws, local_pitches, local_rolls = [], [], []

for idx, row in tdf.iterrows():
    # Both columns hold single-quoted dict strings; swap the quotes so json can parse them.
    ann = json.loads(row.ann.replace("'", '"'))
    camera_metadata_dict = json.loads(row.camera_metadata.replace("'", '"'))
    camera_metadata = CameraMetadata(
        focal_length=camera_metadata_dict['focalLength'],
        focal_length_pixel=camera_metadata_dict['focalLengthPixel'],
        baseline_m=camera_metadata_dict['baseline'],
        pixel_count_width=camera_metadata_dict['pixelCountWidth'],
        pixel_count_height=camera_metadata_dict['pixelCountHeight'],
        image_sensor_width=camera_metadata_dict['imageSensorWidth'],
        image_sensor_height=camera_metadata_dict['imageSensorHeight']
    )

    X_left, X_right = get_2D_coords_from_ann(ann)
    X = get_3D_coords_from_2D(X_left, X_right, camera_metadata)

    # The local orientation is derived from the centered keypoints, and the
    # centered keypoints (not the raw triangulated ones) are what the network
    # is scored on below.
    centered_3D_coords = center_3D_coordinates(X)
    B = compute_orthonormal_basis(centered_3D_coords)
    local_yaw, local_pitch, local_roll = rotation_matrix_to_euler_angles(B)
    local_yaws.append(local_yaw)
    local_pitches.append(local_pitch)
    local_rolls.append(local_roll)

    X_list.append(centered_3D_coords)

tdf['X'] = X_list
# Angles are stored exactly as returned by rotation_matrix_to_euler_angles;
# the plotting cells multiply them by 180 / pi before display.
tdf['local_yaw'] = local_yaws
tdf['local_pitch'] = local_pitches
tdf['local_roll'] = local_rolls

X = np.array(list(tdf.X.values))
# Flatten each sample's keypoint array into one feature row, giving
# (num_rows, num_features). The earlier reshape(n, k, -1) left a spurious
# trailing axis, i.e. (num_rows, 24, 1), which does not match the 24-feature
# input the network declares.
X = X.reshape(X.shape[0], -1)

# y_pred = model.predict(X)
# reshape(-1) keeps a 1-D result even for a single-row frame, where squeeze() would collapse to 0-d.
y_pred = (pytorch_model(torch.from_numpy(X).float())).detach().numpy().reshape(-1)
# Predictions are multiplied by 1e4 to bring them back from the network's scaled target.
tdf['pred_weight'] = y_pred * 1e4


<h1> Visualize weight predictions vs. local yaw, local pitch, and local roll </h1>

In [None]:
# Predicted weight against each local orientation angle (converted to degrees);
# panels from top to bottom: yaw, pitch, roll.
fig, axes = plt.subplots(3, 1, figsize=(10, 20))
for ax, angle_col in zip(axes, ('local_yaw', 'local_pitch', 'local_roll')):
    ax.scatter(tdf[angle_col].values * 180 / np.pi, tdf.pred_weight.values)
    ax.grid()
plt.show()



In [None]:
# Ground-truth weight against each local orientation angle (converted to degrees);
# panels from top to bottom: yaw, pitch, roll.
fig, axes = plt.subplots(3, 1, figsize=(10, 20))
for ax, angle_col in zip(axes, ('local_yaw', 'local_pitch', 'local_roll')):
    ax.scatter(tdf[angle_col].values * 180 / np.pi, tdf.weight.values)
    ax.grid()
plt.show()

