In [None]:
import datetime as dt
import json, os
import pandas as pd
from matplotlib import pyplot as plt
from collections import defaultdict
import numpy as np
from itertools import combinations
from aquabyte.accuracy_metrics import AccuracyMetricsGenerator
from aquabyte.data_access_utils import S3AccessUtils, RDSAccessUtils
from aquabyte.visualize import Visualizer, _normalize_world_keypoints
from aquabyte.optics import euclidean_distance, pixel2world, depth_from_disp, convert_to_world_point
import random
import torch
from aquabyte.data_loader import KeypointsDataset, NormalizeCentered2D, ToTensor, BODY_PARTS
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from sklearn.model_selection import train_test_split
from copy import copy, deepcopy
import pyarrow.parquet as pq
from scipy.spatial import Delaunay
from pyobb.obb import OBB
from mpl_toolkits.mplot3d import Axes3D

pd.options.display.max_rows = 500

In [None]:
# Credentials are read from the JSON file named by PROD_RESEARCH_SQL_CREDENTIALS; the
# `with` block closes the handle instead of leaking it.
with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_file:
    sql_credentials = json.load(credentials_file)
rds_access_utils = RDSAccessUtils(sql_credentials)

# NOTE: the WHERE conditions on `b` discard fish_metadata rows without a matching
# annotation, so this LEFT JOIN effectively behaves like an INNER JOIN.
query = """
    select * from research.fish_metadata a left join keypoint_annotations b
    on a.left_url = b.left_image_url 
    where b.keypoints -> 'leftCrop' is not null
    and b.keypoints -> 'rightCrop' is not null
    and b.is_qa = false
    and b.captured_at < '2019-09-20';
"""
df = rds_access_utils.extract_from_database(query)

In [None]:
# Keypoint annotation ids excluded from the analysis.
blacklisted_keypoint_annotation_ids = [
    606484, 635806, 637801, 508773, 640493, 639409, 648536,
    507003, 706002, 507000, 709298, 714073, 719239,
]

df = df.loc[~df['id'].isin(blacklisted_keypoint_annotation_ids)]

<h1> Append World Keypoints to this Data </h1>

In [None]:
def get_world_keypoints(row):
    """Triangulate a row's stereo keypoint annotations into world coordinates.

    Returns whatever ``pixel2world`` produces for the row's left/right crops and
    camera metadata, or None when either crop is missing from ``row.keypoints``.
    """
    annotations = row.keypoints
    if 'leftCrop' not in annotations or 'rightCrop' not in annotations:
        return None
    return pixel2world(annotations['leftCrop'], annotations['rightCrop'], row.camera_metadata)
    
def is_well_behaved(wkps, cutoff_depth=10.0):
    """Sanity-check a dict of world keypoints.

    Args:
        wkps: mapping of body part -> world point (indexable; element 1 is compared
            against the cutoff, presumably the depth axis). May be None, which is
            what ``get_world_keypoints`` returns when a crop is missing.
        cutoff_depth: maximum allowed absolute value of element 1.

    Returns:
        False for None or when any point's |element 1| exceeds ``cutoff_depth``
        (NaN values also fail the check); True otherwise.
    """
    if wkps is None:
        return False
    # written as `<=` inside all() so NaN depths are rejected rather than silently accepted
    return all(abs(wkp[1]) <= cutoff_depth for wkp in wkps.values())

# Triangulate every annotation, then keep only rows whose world keypoints pass is_well_behaved.
df['world_keypoints'] = df.apply(get_world_keypoints, axis=1)
is_well_behaved_mask = df['world_keypoints'].apply(is_well_behaved)
df = df.loc[is_well_behaved_mask]

<h1> Add template matching results to this base dataset </h1>

In [None]:
s3_access_utils = S3AccessUtils('/root/data')

# Template-matching output of one run, identified by its timestamped S3 prefix.
template_matching_prefix = 'template-matching/2019-12-05T02:50:57'
keys = list(s3_access_utils.get_matching_s3_keys(
    'aquabyte-research', prefix=template_matching_prefix, suffixes=['.parquet']))
if not keys:
    raise FileNotFoundError(
        'no .parquet files found under s3://aquabyte-research/{}'.format(template_matching_prefix))

# NOTE(review): only the first matching file is read — confirm the run wrote a single parquet file.
f = s3_access_utils.download_from_s3('aquabyte-research', keys[0])
pdf = pd.read_parquet(f)

In [None]:
# Column `homography_and_matches` holds [homography, matches] (matches may be missing).
# `np.float` / `np.int` were aliases of the builtins and were removed in NumPy 1.24;
# the builtins give the same dtypes (float64 / platform int).
pdf['homography'] = pdf.homography_and_matches.apply(lambda x: np.array(x[0].tolist(), dtype=float))
pdf['matches'] = pdf.homography_and_matches.apply(lambda x: np.array(x[1].tolist(), dtype=int) if len(x) > 1 else None)


In [None]:
df = df.merge(pdf[['left_image_url', 'homography', 'matches']], on='left_image_url', how='inner')


<h1> Oriented Bounding Box (OBB) Class </h1>

In [None]:
from numpy import ndarray, array, asarray, dot, cross, cov, array, finfo, min as npmin, max as npmax
from numpy.linalg import eigh, norm


########################################################################################################################
# adapted from: http://jamesgregson.blogspot.com/2011/03/latex-test.html
########################################################################################################################
class OBB:
    """Oriented bounding box of a 3D point cloud, aligned with its principal axes.

    Adapted from http://jamesgregson.blogspot.com/2011/03/latex-test.html (via pyobb).
    NOTE: this class shadows the ``OBB`` imported from ``pyobb.obb`` earlier in the notebook.

    Attributes (set by the builders):
        rotation: (3, 3) array whose ROWS are the unit principal axes (r, u, f), ordered by
            ascending variance; ``rotation.dot(p)`` maps a world point into the box frame
            and ``transform`` maps box-frame coordinates back to world coordinates.
        min, max: per-axis minimum / maximum of the points expressed in the box frame.
    """

    def __init__(self):
        self.rotation = None
        self.min = None
        self.max = None

    def transform(self, point):
        """Map a point from the box frame back to world coordinates."""
        return dot(array(point), self.rotation)

    @property
    def centroid(self):
        """Center of the box in world coordinates."""
        return self.transform((self.min + self.max) / 2.0)

    @property
    def extents(self):
        """Half side lengths of the box along (r, u, f)."""
        # Previously this rotated the half-diagonal into world space, which is not an extent.
        return (self.max - self.min) / 2.0

    @property
    def points(self):
        """The 8 box corners in world coordinates (0-3 upper cap, 4-7 lower cap)."""
        return [
            # upper cap: ccw order in a right-hand system
            # rightmost, topmost, farthest
            self.transform((self.max[0], self.max[1], self.min[2])),
            # leftmost, topmost, farthest
            self.transform((self.min[0], self.max[1], self.min[2])),
            # leftmost, topmost, closest
            self.transform((self.min[0], self.max[1], self.max[2])),
            # rightmost, topmost, closest
            self.transform(self.max),
            # lower cap: cw order in a right-hand system
            # leftmost, bottommost, farthest
            self.transform(self.min),
            # rightmost, bottommost, farthest
            self.transform((self.max[0], self.min[1], self.min[2])),
            # rightmost, bottommost, closest
            self.transform((self.max[0], self.min[1], self.max[2])),
            # leftmost, bottommost, closest
            self.transform((self.min[0], self.min[1], self.max[2])),
        ]

    @classmethod
    def build_from_covariance_matrix(cls, covariance_matrix, points):
        """Build the box whose axes are the eigenvectors of ``covariance_matrix``.

        Args:
            covariance_matrix: (3, 3) symmetric matrix.
            points: (N, 3) array-like of world points to enclose.

        Returns:
            (obb, eigen_vectors) — the box and the raw eigenvector matrix from ``eigh``
            (eigenvectors in columns, ascending eigenvalue order).

        Raises:
            ZeroDivisionError: if an eigenvector has (near-)zero norm.
        """
        if not isinstance(points, ndarray):
            points = array(points, dtype=float)
        assert points.shape[1] == 3

        obb = cls()

        _, eigen_vectors = eigh(covariance_matrix)

        def try_to_normalize(v):
            n = norm(v)
            if n < finfo(float).resolution:
                raise ZeroDivisionError
            return v / n

        r = try_to_normalize(eigen_vectors[:, 0])
        u = try_to_normalize(eigen_vectors[:, 1])
        f = try_to_normalize(eigen_vectors[:, 2])

        # Axes as ROWS so rotation.dot(p) == (r.p, u.p, f.p), the projection onto each
        # principal axis (as in the reference C++). Storing them as columns (the previous
        # `.T`) aligned the box with the rows of the eigenvector matrix instead.
        obb.rotation = array((r, u, f))

        # project every point into the box frame in a single matrix product
        p_primes = points.dot(obb.rotation.T)
        obb.min = npmin(p_primes, axis=0)
        obb.max = npmax(p_primes, axis=0)

        return obb, eigen_vectors

    @classmethod
    def build_from_points(cls, points):
        """Build the box from the (population) covariance of ``points``; see build_from_covariance_matrix."""
        if not isinstance(points, ndarray):
            points = array(points, dtype=float)
        assert points.shape[1] == 3, 'points have to have 3-elements'
        # no need to store the covariance matrix
        return cls.build_from_covariance_matrix(cov(points, y=None, rowvar=False, bias=True), points)

<h1> Add Convex Hull Filtration and Volume Computation </h1>

In [None]:
%matplotlib inline
# Overlay template-matching correspondences (blue) on the annotated keypoints (red),
# both in right-image pixel coordinates (match columns 2/3), for one example row.
idx = 30
row = df.loc[~df.left_image_url.str.contains('aquabyte-crops')].iloc[idx]
X_body = np.array(row.matches)
X_keypoints = np.array([[kp['xFrame'], kp['yFrame']] for kp in row.keypoints['rightCrop']])

fig, ax = plt.subplots(figsize=(20, 10))
ax.scatter(X_body[:, 2], X_body[:, 3], color='blue')
ax.scatter(X_keypoints[:, 0], X_keypoints[:, 1], color='red')
ax.grid()
plt.show()

In [None]:
def in_hull(p, hull):
    """Test which query points lie inside the convex hull of ``hull``.

    Args:
        p: (N, K) array of query points.
        hull: either an (M, K) array of points whose convex hull is used, or an
            already-built ``scipy.spatial.Delaunay`` triangulation (lets callers reuse
            one triangulation across many queries).

    Returns:
        Boolean array of length N, True where ``find_simplex`` locates the point in
        some simplex of the triangulation.
    """
    if not isinstance(hull, Delaunay):
        hull = Delaunay(hull)
    return hull.find_simplex(p) >= 0

In [None]:
# Build a 3D point cloud for one fish: template-matching correspondences that fall inside
# the convex hull of the annotated left-image keypoints, plus the triangulated keypoints.
row = df.iloc[3]
X_keypoints = np.array([[kp['xFrame'], kp['yFrame']] for kp in row.keypoints['leftCrop']])
X_body = np.array(row.matches)
is_valid = in_hull(X_body[:, :2], X_keypoints)
X_body = X_body[np.where(is_valid)]

# generate 3D point cloud
cm = row.camera_metadata
all_wkps = []
for i, match in enumerate(X_body):
    # columns 0 and 2 are the left / right x coordinates; their gap is the disparity
    d = depth_from_disp(abs(match[2] - match[0]), cm)
    wkp = convert_to_world_point(match[0], match[1], d, cm)
    all_wkps.append(list(wkp))

additional_wkps = pixel2world(row.keypoints['leftCrop'], row.keypoints['rightCrop'], cm)
all_wkps += [list(part_wkp) for part_wkp in additional_wkps.values()]

obb, eigen_vectors = OBB.build_from_points(all_wkps)
obb_points = np.array(obb.points)


In [None]:
%matplotlib notebook
# Render the point cloud together with the 12 edges of its oriented bounding box.
fig = plt.figure()
# `Axes3D(fig)` stopped attaching itself to the figure in Matplotlib 3.6 (deprecated in 3.4),
# which leaves the plot empty; create the 3D axes through the figure instead.
ax = fig.add_subplot(111, projection='3d')

# x, y, z coordinates: box corners first, then the point cloud
cloud = np.array(all_wkps)
x_values = np.concatenate([obb_points[:, 0], cloud[:, 0]])
y_values = np.concatenate([obb_points[:, 1], cloud[:, 1]])
z_values = np.concatenate([obb_points[:, 2], cloud[:, 2]])

ax.scatter(x_values, y_values, z_values)

# corner indices follow the ordering of OBB.points (0-3 upper cap, 4-7 lower cap)
obb_edges = [(0, 1), (1, 2), (2, 3), (3, 0),
             (4, 5), (5, 6), (6, 7), (7, 4),
             (0, 5), (1, 4), (2, 7), (3, 6)]
for i, j in obb_edges:
    edge_x_values = [obb_points[i][0], obb_points[j][0]]
    edge_y_values = [obb_points[i][1], obb_points[j][1]]
    edge_z_values = [obb_points[i][2], obb_points[j][2]]
    ax.plot(edge_x_values, edge_y_values, edge_z_values, color='red')

ax.set_xlabel('world x')
ax.set_ylabel('world y')
ax.set_zlabel('world z')

# Simulate an equal aspect ratio: plot invisible (white) points at the corners of a cube
# spanning the largest data range so all three axes get the same scale.
max_range = np.array([x_values.max() - x_values.min(),
                      y_values.max() - y_values.min(),
                      z_values.max() - z_values.min()]).max()
cube_offsets = np.mgrid[-1:2:2, -1:2:2, -1:2:2]
Xb = 0.5 * max_range * cube_offsets[0].flatten() + 0.5 * (x_values.max() + x_values.min())
Yb = 0.5 * max_range * cube_offsets[1].flatten() + 0.5 * (y_values.max() + y_values.min())
Zb = 0.5 * max_range * cube_offsets[2].flatten() + 0.5 * (z_values.max() + z_values.min())
for xb, yb, zb in zip(Xb, Yb, Zb):
    ax.plot([xb], [yb], [zb], 'w')
plt.show()

<h1> Train Neural Network </h1>

In [None]:
def convert_to_world_point(x, y, d, parameters):
    """Back-project pixel (x, y) observed at depth d into world coordinates.

    NOTE: redefines (shadows) the ``convert_to_world_point`` imported from aquabyte.optics.

    The pixel offset from the image center is converted to a sensor-plane offset and
    scaled by ``d / focalLength``. Image rows grow downward, so the vertical offset is
    negated before becoming world z. Inputs may be numpy arrays (element-wise).

    Returns:
        np.array([world_x, world_y, world_z]) where world_y == d.
    """
    # read every camera parameter up front
    width_px = parameters["pixelCountWidth"]
    height_px = parameters["pixelCountHeight"]
    sensor_width = parameters["imageSensorWidth"]
    sensor_height = parameters["imageSensorHeight"]
    focal_length = parameters["focalLength"]

    # sensor-plane offsets from the optical center
    sensor_x = (x - width_px / 2.0) * (sensor_width / width_px)
    sensor_z = (height_px / 2.0 - y) * (sensor_height / height_px)

    return np.array([(d * sensor_x) / focal_length, d, (d * sensor_z) / focal_length])


def depth_from_disp(disp, parameters):
    """Stereo depth from disparity: focalLengthPixel * baseline / disp.

    NOTE: redefines (shadows) the ``depth_from_disp`` imported from aquabyte.optics.
    ``disp`` may be a scalar or array-like; it is converted with ``np.array`` so a zero
    disparity yields inf (numpy division) rather than raising.
    """
    focal_times_baseline = parameters["focalLengthPixel"] * parameters["baseline"]
    return focal_times_baseline / np.array(disp)


class NormalizeCentered2D(object):

    """
    Transforms the 2D left and right keypoints such that:
        (1) The center of the left image 2D keypoints is located at the center of the left image
            (i.e. 2D translation)
        (2) The left image keypoints are possibly flipped such that the upper-lip x-coordinate
            is greater than the tail-notch coordinate. This is done to reduce the total number of
            spatial orientations the network must learn from -> reduces the training size
        (3) The left image keypoints are then rotated such that upper-lip is located on the x-axis.
            As in (2), this is done to reduce the total number of spatial orientations the network
            must learn from -> reduces the training size
        (4) Rescale all left image keypoints by some random number between 'lo' and 'hi' args
        (5) Apply Gaussian random noise "jitter" to each keypoint to mimic annotation error
        (6) For all transformations above, the right image keypoint coordinates are accordingly
            transformed such that the original disparity values are preserved for all keypoints
            (or adjusted during rescaling event)

    Keypoint dicts map each name in BODY_PARTS to a 2-vector. An optional 'BODY' entry holds an
    (N, 2) array of extra correspondences; it goes through the same flip / rotation / scaling
    as the named parts but is not jittered.
    """

    def __init__(self, lo=None, hi=None, jitter=0.0, rotate=True, center=False):
        # lo / hi: bounds of the uniform random scale factor (only used when both are truthy)
        # jitter: upper bound of the per-sample Gaussian noise std drawn in __call__
        # rotate: apply step (3); center: keep output centered on 0 instead of the image center
        self.lo = lo
        self.hi = hi
        self.jitter = jitter
        self.rotate = rotate
        self.center = center

    def flip_center_kps(self, left_kps, right_kps):
        """Center both views on their keypoint bounding boxes, mirroring if needed.

        If the upper lip is not to the right of the tail notch in the left image, the pair
        is mirrored horizontally: the mirrored right-image points become the new left view
        and vice versa, both centered on the right view's bounding-box midpoint. Otherwise
        both views are translated by the left view's bounding-box midpoint.
        """
        def _bbox_mid(kps, axis):
            # midpoint of the named keypoints' bounding box along one axis
            values = [kps[bp][axis] for bp in BODY_PARTS]
            return np.mean([min(values), max(values)])

        x_mid_l, y_mid_l = _bbox_mid(left_kps, 0), _bbox_mid(left_kps, 1)
        x_mid_r, y_mid_r = _bbox_mid(right_kps, 0), _bbox_mid(right_kps, 1)

        flip = not (left_kps['UPPER_LIP'][0] > left_kps['TAIL_NOTCH'][0])

        fc_left_kps, fc_right_kps = {}, {}
        for bp in BODY_PARTS:
            left_kp, right_kp = left_kps[bp], right_kps[bp]
            if flip:
                fc_left_kps[bp] = np.array([x_mid_r - right_kp[0], right_kp[1] - y_mid_r])
                fc_right_kps[bp] = np.array([x_mid_r - left_kp[0], left_kp[1] - y_mid_r])
            else:
                fc_left_kps[bp] = np.array([left_kp[0] - x_mid_l, left_kp[1] - y_mid_l])
                fc_right_kps[bp] = np.array([right_kp[0] - x_mid_l, right_kp[1] - y_mid_l])

        if 'BODY' in left_kps.keys():
            left_body_kps, right_body_kps = np.array(left_kps['BODY']), np.array(right_kps['BODY'])
            if flip:
                mirror = np.array([[-1, 0], [0, 1]])
                mid_r = np.array([x_mid_r, y_mid_r])
                # Swap the views exactly as for the named parts above; without the swap the
                # BODY disparities came out with the opposite sign to the named keypoints.
                fc_left_kps['BODY'] = np.dot(right_body_kps - mid_r, mirror)
                fc_right_kps['BODY'] = np.dot(left_body_kps - mid_r, mirror)
            else:
                mid_l = np.array([x_mid_l, y_mid_l])
                fc_left_kps['BODY'] = left_body_kps - mid_l
                fc_right_kps['BODY'] = right_body_kps - mid_l

        return fc_left_kps, fc_right_kps

    def _rotate_cc(self, p, theta):
        """Rotate 2D point(s) ``p`` counter-clockwise by ``theta`` radians; p is (2,) or (2, N)."""
        R = np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ])

        rotated_kp = np.dot(R, p)
        return rotated_kp

    def rotate_kps(self, left_kps, right_kps):
        """Rotate the left keypoints so the upper lip lies on the x-axis.

        Right keypoints are not rotated themselves: each is rebuilt from its rotated left
        counterpart shifted left by the absolute x-disparity, so disparities are preserved.
        """
        upper_lip_x, upper_lip_y = left_kps['UPPER_LIP']
        theta = np.arctan(upper_lip_y / upper_lip_x)
        r_left_kps, r_right_kps = {}, {}
        for bp in BODY_PARTS:
            rotated_kp = self._rotate_cc(left_kps[bp], -theta)
            r_left_kps[bp] = rotated_kp
            disp = abs(left_kps[bp][0] - right_kps[bp][0])
            r_right_kps[bp] = np.array([rotated_kp[0] - disp, rotated_kp[1]])

        if 'BODY' in left_kps.keys():
            left_body_kps, right_body_kps = np.array(left_kps['BODY']), np.array(right_kps['BODY'])
            r_left_body_kps = self._rotate_cc(left_body_kps.T, -theta).T
            disp = np.abs(left_body_kps[:, 0] - right_body_kps[:, 0])
            r_right_body_kps = np.column_stack([r_left_body_kps[:, 0] - disp, r_left_body_kps[:, 1]])
            r_left_kps['BODY'] = r_left_body_kps
            r_right_kps['BODY'] = r_right_body_kps

        return r_left_kps, r_right_kps

    def scale_kps(self, left_kps, right_kps, factor):
        """Multiply every left and right keypoint (including BODY) by ``factor``."""
        s_left_kps, s_right_kps = {}, {}
        for bp in BODY_PARTS:
            s_left_kps[bp] = factor * np.array(left_kps[bp])
            s_right_kps[bp] = factor * np.array(right_kps[bp])

        if 'BODY' in left_kps.keys():
            s_left_kps['BODY'] = factor * np.array(left_kps['BODY'])
            s_right_kps['BODY'] = factor * np.array(right_kps['BODY'])

        return s_left_kps, s_right_kps

    def jitter_kps(self, left_kps, right_kps, jitter):
        """Add N(0, jitter) noise independently to each named keypoint coordinate.

        BODY points are passed through unchanged.
        """
        j_left_kps, j_right_kps = {}, {}
        for bp in BODY_PARTS:
            # draw order (left x, left y, right x, right y) kept stable for seeded runs
            j_left_kps[bp] = np.array([left_kps[bp][0] + np.random.normal(0, jitter),
                                       left_kps[bp][1] + np.random.normal(0, jitter)])
            j_right_kps[bp] = np.array([right_kps[bp][0] + np.random.normal(0, jitter),
                                        right_kps[bp][1] + np.random.normal(0, jitter)])

        if 'BODY' in left_kps.keys():
            j_left_kps['BODY'] = left_kps['BODY']
            j_right_kps['BODY'] = right_kps['BODY']

        return j_left_kps, j_right_kps

    @staticmethod
    def _frame_item(keypoint_type, x, y, x_offset=None, y_offset=None):
        """Build one annotation-style keypoint dict, shifted by (x_offset, y_offset) if given."""
        if x_offset is not None:
            x, y = x + x_offset, y + y_offset
        return {'keypointType': keypoint_type, 'xFrame': x, 'yFrame': y}

    def modify_kps(self, left_kps, right_kps, factor, jitter, cm, rotate=True, center=False):
        """Apply flip/center -> (rotate) -> scale -> jitter and emit annotation-style lists.

        Returns {'leftCrop': [...], 'rightCrop': [...]} with one keypoint dict per body part
        (plus a BODY item, whose xFrame / yFrame are arrays, when present). Unless ``center``
        is True, coordinates are shifted by half the image size from ``cm``.
        """
        left, right = self.flip_center_kps(left_kps, right_kps)
        if rotate:
            left, right = self.rotate_kps(left, right)
        left, right = self.scale_kps(left, right, factor)
        j_left_kps, j_right_kps = self.jitter_kps(left, right, jitter)

        if center:
            x_offset = y_offset = None
        else:
            x_offset, y_offset = cm['pixelCountWidth'] / 2.0, cm['pixelCountHeight'] / 2.0

        j_left_kps_list, j_right_kps_list = [], []
        for bp in BODY_PARTS:
            j_left_kps_list.append(
                self._frame_item(bp, j_left_kps[bp][0], j_left_kps[bp][1], x_offset, y_offset))
            j_right_kps_list.append(
                self._frame_item(bp, j_right_kps[bp][0], j_right_kps[bp][1], x_offset, y_offset))

        if 'BODY' in left_kps.keys():
            j_left_kps_list.append(self._frame_item(
                'BODY', j_left_kps['BODY'][:, 0], j_left_kps['BODY'][:, 1], x_offset, y_offset))
            j_right_kps_list.append(self._frame_item(
                'BODY', j_right_kps['BODY'][:, 0], j_right_kps['BODY'][:, 1], x_offset, y_offset))

        modified_kps = {
            'leftCrop': j_left_kps_list,
            'rightCrop': j_right_kps_list
        }

        return modified_kps

    def __call__(self, sample):
        """Normalize one sample and package it for the downstream transforms.

        ``sample`` needs 'keypoints' ({'leftCrop': [...], 'rightCrop': [...]}) and 'cm';
        'stereo_pair_id', 'label' and 'single_point_inference' are passed through if present.
        ``kp_input`` maps each keypoint type to [left x, left y, right x, right y].
        """
        keypoints, cm, stereo_pair_id, label = \
            sample['keypoints'], sample['cm'], sample.get('stereo_pair_id'), sample.get('label')
        left_keypoints_list = keypoints['leftCrop']
        right_keypoints_list = keypoints['rightCrop']
        # BODY items carry arrays of x / y values -> (N, 2); named parts become 2-vectors
        left_kps = {item['keypointType']: np.column_stack([item['xFrame'], item['yFrame']])
                    for item in left_keypoints_list if item['keypointType'] == 'BODY'}
        right_kps = {item['keypointType']: np.column_stack([item['xFrame'], item['yFrame']])
                     for item in right_keypoints_list if item['keypointType'] == 'BODY'}
        left_kps.update({item['keypointType']: np.array([item['xFrame'], item['yFrame']])
                         for item in left_keypoints_list if item['keypointType'] != 'BODY'})
        right_kps.update({item['keypointType']: np.array([item['xFrame'], item['yFrame']])
                          for item in right_keypoints_list if item['keypointType'] != 'BODY'})
        factor = 1.0
        if self.lo and self.hi:
            factor = np.random.uniform(low=self.lo, high=self.hi)

        jitter = np.random.uniform(high=self.jitter)

        modified_kps = self.modify_kps(left_kps, right_kps, factor, jitter, cm,
                                       rotate=self.rotate, center=self.center)

        kp_input = {}
        for left_item, right_item in zip(modified_kps['leftCrop'], modified_kps['rightCrop']):
            bp = left_item['keypointType']
            kp_input[bp] = [left_item['xFrame'], left_item['yFrame'], right_item['xFrame'], right_item['yFrame']]

        transformed_sample = {
            'kp_input': kp_input,
            'modified_kps': modified_kps,
            'label': label,
            'stereo_pair_id': stereo_pair_id,
            'cm': cm,
            'single_point_inference': sample.get('single_point_inference')
        }

        return transformed_sample


In [None]:
class WorldKeypointTransform(object):
    """
        Transforms into world keypoints

    Named keypoints in ``sample['modified_kps']`` are converted with ``pixel2world``. When both
    crops also carry a 'BODY' item, its correspondences are triangulated into an (N, 3) array
    stored under ``modified_wkps['BODY']``.
    """

    @staticmethod
    def _find_body_item(items):
        """Return the 'BODY' item from a crop's keypoint list, or None if there is none."""
        return next((item for item in items if item['keypointType'] == 'BODY'), None)

    @staticmethod
    def _body_world_keypoints(left_body, right_body, cm):
        """Triangulate BODY correspondences into an (N, 3) array of [x, y(depth), z] points.

        Same arithmetic as ``depth_from_disp`` + ``convert_to_world_point`` in this notebook,
        vectorized over all points. Points with zero disparity (infinite depth) are dropped.
        """
        x_left = np.asarray(left_body['xFrame'], dtype=float)
        y_left = np.asarray(left_body['yFrame'], dtype=float)
        x_right = np.asarray(right_body['xFrame'], dtype=float)

        pixel_count_width = cm["pixelCountWidth"]
        pixel_count_height = cm["pixelCountHeight"]
        sensor_width = cm["imageSensorWidth"]
        sensor_height = cm["imageSensorHeight"]
        focal_length = cm["focalLength"]

        with np.errstate(divide='ignore', invalid='ignore'):
            depths = cm["focalLengthPixel"] * cm["baseline"] / np.abs(x_left - x_right)

            px_x = x_left - pixel_count_width / 2.0
            px_z = pixel_count_height / 2.0 - y_left
            sensor_x = px_x * (sensor_width / pixel_count_width)
            sensor_z = px_z * (sensor_height / pixel_count_height)

            world_x = (depths * sensor_x) / focal_length
            world_z = (depths * sensor_z) / focal_length

        wkps = np.column_stack([world_x, depths, world_z])
        return wkps[np.isfinite(wkps).all(axis=1)]

    def __call__(self, sample):
        modified_kps, label, stereo_pair_id, cm = \
            sample['modified_kps'], sample['label'], sample['stereo_pair_id'], sample['cm']
        left_items, right_items = modified_kps['leftCrop'], modified_kps['rightCrop']

        modified_wkps = pixel2world([item for item in left_items if item['keypointType'] != 'BODY'],
                                    [item for item in right_items if item['keypointType'] != 'BODY'],
                                    cm)

        # pixel2world only receives the named parts, so BODY has to be looked up in the crop
        # lists (which are lists of dicts, not dicts keyed by keypoint type).
        left_body, right_body = self._find_body_item(left_items), self._find_body_item(right_items)
        if left_body is not None and right_body is not None:
            modified_wkps['BODY'] = self._body_world_keypoints(left_body, right_body, cm)

        transformed_sample = {
            'modified_wkps': modified_wkps,
            'label': label,
            'stereo_pair_id': stereo_pair_id,
            'single_point_inference': sample.get('single_point_inference')
        }

        return transformed_sample
    
class PrismTransform(object):
    """Summarize a fish's world keypoints by the 8 corners of their oriented bounding box.

    ``kp_input`` maps 'BP0'.. to the corners returned by ``OBB.points``; the label is
    multiplied by 1e-4 (a falsy label becomes None).
    """

    def __call__(self, sample):
        modified_wkps = sample['modified_wkps']
        label = sample['label']
        stereo_pair_id = sample['stereo_pair_id']

        # named body parts first, then any BODY correspondences
        cloud = [list(modified_wkps[part]) for part in BODY_PARTS]
        if 'BODY' in modified_wkps.keys():
            cloud += [list(point) for point in modified_wkps['BODY']]

        box, _ = OBB.build_from_points(cloud)
        corners = np.array(box.points)

        return {
            'kp_input': {'BP{}'.format(k): corner for k, corner in enumerate(corners)},
            'modified_wkps': modified_wkps,
            'label': label * 1e-4 if label else None,
            'stereo_pair_id': stereo_pair_id,
            'single_point_inference': sample.get('single_point_inference'),
        }
        
        
class ToTensor(object):
    """Convert a sample's ``kp_input`` dict into a float32 tensor (one row per key).

    With ``single_point_inference`` truthy an extra leading batch axis is added. The label
    (as a 1-element tensor) and stereo_pair_id are included whenever they are not None.
    """

    def __call__(self, sample):
        kp_input, label, stereo_pair_id = \
            sample['kp_input'], sample.get('label'), sample.get('stereo_pair_id')

        # rows follow the dict's insertion order
        x = [kp_input[bp] for bp in kp_input.keys()]
        if sample.get('single_point_inference'):
            x = np.array([x])
        else:
            x = np.array(x)

        tensorized_sample = {
            'kp_input': torch.from_numpy(x).float()
        }

        # `is not None` (not truthiness) so a label of 0.0 or a stereo_pair_id of 0 is kept;
        # dropping them made batch keys inconsistent and broke `data_batch['stereo_pair_id']`.
        if label is not None:
            tensorized_sample['label'] = torch.from_numpy(np.array([label])).float()

        if stereo_pair_id is not None:
            tensorized_sample['stereo_pair_id'] = stereo_pair_id

        return tensorized_sample
        

In [None]:
# Smoke test of the transform pipeline on the full frame. `train_mask` is only defined by the
# train/test split cell further down, so indexing with it here raised NameError on a fresh
# kernel; the real train/test datasets are rebuilt after the split.
train_dataset = KeypointsDataset(df, transform=transforms.Compose([
                                                  NormalizeCentered2D(lo=0.3, hi=2.0, jitter=10),
                                                  WorldKeypointTransform(),
                                                  PrismTransform(),
                                                  ToTensor()
                                              ]))

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=1)

In [None]:
# Hand-built sample from the first row, for stepping through the transforms manually.
sample = {}
sample['keypoints'] = df.keypoints.iloc[0]
sample['stereo_pair_id'] = 0
sample['cm'] = df.camera_metadata.iloc[0]

In [None]:
# Attach template-matching correspondences to each annotation as an extra 'BODY' keypoint
# item. Only matches inside the convex hull of the annotated left-image keypoints are kept
# (match columns 0/1 = left x/y, 2/3 = right x/y).
# Idempotent: the untouched annotations are kept in `old_keypoints` and reused on re-runs,
# and any existing BODY item is stripped before a new one is appended.
source_keypoints = df['old_keypoints'] if 'old_keypoints' in df.columns else df['keypoints']

modified_keypoints_list = []
for count, (raw_keypoints, matches) in enumerate(zip(source_keypoints, df['matches'])):
    if count % 100 == 0:
        print(count)

    keypoints = deepcopy(raw_keypoints)
    left_keypoints = [item for item in keypoints['leftCrop'] if item['keypointType'] != 'BODY']
    right_keypoints = [item for item in keypoints['rightCrop'] if item['keypointType'] != 'BODY']

    X_keypoints = np.array([[item['xFrame'], item['yFrame']] for item in left_keypoints])
    # `matches` is None (or empty) when template matching produced no correspondences
    if matches is None or len(matches) == 0:
        X_body = np.empty((0, 4))
    else:
        X_body = np.array(matches)
        X_body = X_body[in_hull(X_body[:, :2], X_keypoints)]

    left_keypoints.append({'keypointType': 'BODY', 'xFrame': X_body[:, 0], 'yFrame': X_body[:, 1]})
    right_keypoints.append({'keypointType': 'BODY', 'xFrame': X_body[:, 2], 'yFrame': X_body[:, 3]})
    modified_keypoints_list.append({'leftCrop': left_keypoints, 'rightCrop': right_keypoints})

df['old_keypoints'] = source_keypoints
df['keypoints'] = modified_keypoints_list

In [None]:
# Split by fish (not by image) so no fish appears in both sets. A dedicated, seeded RNG makes
# the split reproducible across Restart & Run All without touching the global `random` state.
SPLIT_SEED = 0
gtsf_fish_identifiers = list(df.fish_id.unique())
train_size = int(0.8 * len(gtsf_fish_identifiers))
fish_ids = random.Random(SPLIT_SEED).sample(gtsf_fish_identifiers, train_size)
date_mask = (df.captured_at < '2019-09-10')
train_mask = date_mask & df.fish_id.isin(fish_ids)
test_mask = date_mask & ~df.fish_id.isin(fish_ids)

In [None]:
def _make_transform():
    """Fresh pipeline: 2D normalization/augmentation -> world keypoints -> OBB corners -> tensors."""
    return transforms.Compose([
        NormalizeCentered2D(lo=0.3, hi=2.0, jitter=10),
        WorldKeypointTransform(),
        PrismTransform(),
        ToTensor(),
    ])


# NOTE(review): the test set also gets random scaling/jitter and shuffling — confirm intended.
train_dataset = KeypointsDataset(df[train_mask], transform=_make_transform())
train_dataloader = DataLoader(train_dataset, batch_size=25, shuffle=True, num_workers=1)

test_dataset = KeypointsDataset(df[test_mask], transform=_make_transform())
test_dataloader = DataLoader(test_dataset, batch_size=25, shuffle=True, num_workers=1)

In [None]:
# TODO: Define your network architecture here
import torch
from torch import nn

class Network(nn.Module):
    """Fully connected regressor: 24 -> 256 -> 128 -> 64 -> 1 with ReLU between layers.

    Each sample is flattened, so e.g. a (B, 8, 3) batch (presumably the 8 OBB corners from
    PrismTransform) becomes (B, 24).
    """

    def __init__(self):
        super().__init__()
        # registration order fixes parameter names / init order; keep it stable
        self.fc1 = nn.Linear(24, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = x.view(x.shape[0], -1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        return self.output(hidden)
        



In [None]:
write_outputs = False
# Set True to also report the loss on the held-out fish after every epoch.
run_test_eval = False

# establish output directory where model .pb files will be written
if write_outputs:
    dt_now = dt.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    output_base = '/root/data/alok/biomass_estimation/results/neural_network'
    output_dir = os.path.join(output_base, dt_now)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

# instantiate neural network
network = Network()
epochs = 1000
optimizer = torch.optim.Adam(network.parameters(), lr=0.00001)
criterion = torch.nn.MSELoss()

# track train and test losses
train_losses, test_losses = [], []

seed = 0
for epoch in range(epochs):
    network.train()
    # the augmentation transforms draw from np.random; reseed it every epoch
    np.random.seed(seed)
    seed += 1
    running_loss = 0.0
    for i, data_batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X_batch, y_batch, kpid_batch = \
            data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i > 0 and i % 100 == 0:
            # batches 0..i have been accumulated, i.e. i + 1 of them
            print(running_loss / (i + 1))

    train_loss_for_epoch = running_loss / len(train_dataloader)
    train_losses.append(train_loss_for_epoch)

    # run on test set
    if run_test_eval:
        test_running_loss = 0.0
        network.eval()
        with torch.no_grad():
            for data_batch in test_dataloader:
                y_pred = network(data_batch['kp_input'])
                test_running_loss += criterion(y_pred, data_batch['label']).item()
        test_losses.append(test_running_loss / len(test_dataloader))

    # save current state of network
    if write_outputs:
        f_name = 'nn_epoch_{}.pb'.format(str(epoch).zfill(3))
        f_path = os.path.join(output_dir, f_name)
        torch.save(network, f_path)

    print('Train Loss: {}'.format(train_loss_for_epoch))
    if run_test_eval:
        print('Test Loss: {}'.format(test_losses[-1]))
    
