In [None]:
import datetime as dt
import json, os
import pandas as pd
from matplotlib import pyplot as plt
from collections import defaultdict
import numpy as np
from itertools import combinations
from aquabyte.accuracy_metrics import AccuracyMetricsGenerator
from aquabyte.data_access_utils import S3AccessUtils, RDSAccessUtils
from aquabyte.visualize import Visualizer, _normalize_world_keypoints
from aquabyte.optics import euclidean_distance, pixel2world, depth_from_disp, convert_to_world_point
import random
import torch
from aquabyte.data_loader import KeypointsDataset, NormalizeCentered2D, ToTensor, BODY_PARTS
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from sklearn.model_selection import train_test_split
from copy import copy, deepcopy
import pyarrow.parquet as pq
from scipy.spatial import Delaunay
from pyobb.obb import OBB
from mpl_toolkits.mplot3d import Axes3D

pd.options.display.max_rows = 500

In [None]:
# Anatomical keypoint names in alphabetical order. This order defines the row
# order of the keypoint arrays built from them below.
BODY_PARTS = sorted((
    'ADIPOSE_FIN',
    'ANAL_FIN',
    'DORSAL_FIN',
    'EYE',
    'HYPURAL_PLATE',
    'LOWER_PRECAUDAL_PIT',
    'PECTORAL_FIN',
    'PELVIC_FIN',
    'TAIL_NOTCH',
    'UPPER_LIP',
    'UPPER_PRECAUDAL_PIT',
))

In [None]:
# Load fish metadata joined with non-QA keypoint annotations that have both a
# left and a right crop, captured before 2019-09-20.
# The credentials file is opened in a `with` block so the handle is closed
# (json.load(open(...)) leaked it).
with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_file:
    rds_access_utils = RDSAccessUtils(json.load(credentials_file))
query = """
    select * from research.fish_metadata a left join keypoint_annotations b
    on a.left_url = b.left_image_url 
    where b.keypoints -> 'leftCrop' is not null
    and b.keypoints -> 'rightCrop' is not null
    and b.is_qa = false
    and b.captured_at < '2019-09-20';
"""
df = rds_access_utils.extract_from_database(query)

In [None]:
# Keypoint-annotation ids excluded from the dataset.
blacklisted_keypoint_annotation_ids = [
    606484, 635806, 637801, 508773, 640493, 639409, 648536,
    507003, 706002, 507000, 709298, 714073, 719239,
]

is_blacklisted = df.id.isin(blacklisted_keypoint_annotation_ids)
df = df[~is_blacklisted]

<h1> Append World Keypoints to this Data </h1>

In [None]:
def get_world_keypoints(row):
    """Triangulate a row's stereo keypoints into world coordinates.

    Returns pixel2world(leftCrop, rightCrop, camera_metadata) for the row, or
    None when either the 'leftCrop' or the 'rightCrop' keypoints are missing.
    """
    if 'leftCrop' in row.keypoints and 'rightCrop' in row.keypoints:
        return pixel2world(row.keypoints['leftCrop'], row.keypoints['rightCrop'], row.camera_metadata)
    return None


def is_well_behaved(wkps, cutoff_depth=10.0):
    """Return True unless some world keypoint has |coordinate 1| > cutoff_depth.

    Args:
        wkps: mapping body part -> world coordinate (indexable; index 1 is
            presumably depth, matching the (x, depth, z) layout used elsewhere
            in this notebook), or None.
        cutoff_depth: largest accepted absolute value of coordinate 1.

    Returns:
        bool. None (what get_world_keypoints returns for rows without both
        crops) is treated as not well behaved instead of raising AttributeError.
        NaN coordinates are accepted, as before (NaN > cutoff is False).
    """
    if wkps is None:
        return False
    return not any(abs(wkp[1]) > cutoff_depth for wkp in wkps.values())

# Triangulate every annotation. df is a filtered slice at this point, so use
# assign() (which returns a new frame) to avoid SettingWithCopyWarning.
df = df.assign(world_keypoints=df.apply(get_world_keypoints, axis=1))

# Rows without both crops get None from get_world_keypoints; drop them along
# with rows whose keypoints fail the depth cutoff.
is_well_behaved_mask = df.world_keypoints.apply(
    lambda wkps: wkps is not None and is_well_behaved(wkps)
)
df = df[is_well_behaved_mask]

<h1> Add template matching results to this base dataset </h1>

In [None]:
s3_access_utils = S3AccessUtils('/root/data')

# All parquet objects written under the template-matching run prefix.
# NOTE(review): only the first key is loaded below. If the run wrote several
# parquet parts, the others are ignored; confirm that is intended.
keys = list(s3_access_utils.get_matching_s3_keys(
    'aquabyte-research',
    prefix='template-matching/2019-12-05T02:50:57',
    suffixes=['.parquet'],
))

f = s3_access_utils.download_from_s3('aquabyte-research', keys[0])
pdf = pd.read_parquet(f)

In [None]:
# Unpack the template-matching output: element 0 is the homography, element 1
# (when present) the matches; rows without matches get None.
# np.float / np.int were deprecated aliases of the builtins and were removed in
# NumPy 1.24. dtype=float / dtype=int are their exact replacements.
pdf['homography'] = pdf.homography_and_matches.apply(lambda x: np.array(x[0].tolist(), dtype=float))
pdf['matches'] = pdf.homography_and_matches.apply(
    lambda x: np.array(x[1].tolist(), dtype=int) if len(x) > 1 else None
)
df = pd.merge(df, pdf[['left_image_url', 'homography', 'matches']], how='inner', on='left_image_url')


<h1> Add Body Keypoints </h1>

In [None]:
def in_hull(p, hull):
    """Return a boolean mask that is True for each point of `p` inside the convex hull of `hull`.

    `hull` is triangulated with Delaunay; find_simplex returns -1 for points
    outside every simplex.
    """
    triangulation = Delaunay(hull)
    simplex_ids = triangulation.find_simplex(p)
    return simplex_ids >= 0

# Attach the template-matching correspondences that fall inside the annotated
# keypoint hull as one extra 'BODY' keypoint item per crop. In row.matches,
# columns 0-1 are used as left-crop x/y and columns 2-3 as right-crop x/y.
# Existing 'BODY' items are stripped first. Their array-valued xFrame/yFrame
# would otherwise break the hull input when this cell is re-run, so the cell
# is now idempotent.
modified_keypoints_list = []
for count, (idx, row) in enumerate(df.iterrows()):
    if count % 100 == 0:
        print(count)

    left_keypoints = [deepcopy(item) for item in row.keypoints['leftCrop'] if item['keypointType'] != 'BODY']
    right_keypoints = [deepcopy(item) for item in row.keypoints['rightCrop'] if item['keypointType'] != 'BODY']

    X_keypoints = np.array([[item['xFrame'], item['yFrame']] for item in left_keypoints])
    X_body = np.array(row.matches)
    X_body = X_body[in_hull(X_body[:, :2], X_keypoints)]

    left_keypoints.append({'keypointType': 'BODY', 'xFrame': X_body[:, 0], 'yFrame': X_body[:, 1]})
    right_keypoints.append({'keypointType': 'BODY', 'xFrame': X_body[:, 2], 'yFrame': X_body[:, 3]})
    modified_keypoints_list.append({'leftCrop': left_keypoints, 'rightCrop': right_keypoints})

# Keep the raw annotations from the first run; don't overwrite them on re-runs.
if 'old_keypoints' not in df.columns:
    df['old_keypoints'] = df.keypoints
df['keypoints'] = modified_keypoints_list

# get_raw_3D_coordinates samples 500 BODY points, so keep rows with more than that.
# .copy() so later column assignments don't act on a slice.
body_point_counts = df.keypoints.apply(lambda kps: kps['leftCrop'][-1]['xFrame'].shape[0])
df = df[body_point_counts > 500].copy()

<h1> Construct Point Cloud Data Transform </h1>

In [None]:
def in_hull(p, hull):
    """Boolean mask: which points of `p` lie inside the convex hull of the points `hull`."""
    # find_simplex yields -1 for points outside every Delaunay simplex.
    return Delaunay(hull).find_simplex(p) >= 0

def get_raw_3D_coordinates(keypoints, cm, num_body_points=500, body_sample_seed=0):
    """Triangulate anatomical keypoints plus a sample of BODY points into world space.

    Args:
        keypoints: dict with 'leftCrop' and 'rightCrop' lists of items carrying
            'keypointType', 'xFrame' and 'yFrame'. Non-BODY items are passed to
            pixel2world. For 'BODY', only the first such item per crop is used;
            its xFrame/yFrame are arrays of matched points.
        cm: camera metadata dict (focalLengthPixel, baseline, pixelCountWidth,
            pixelCountHeight, imageSensorWidth, imageSensorHeight, focalLength).
        num_body_points: maximum number of BODY points to keep (default 500).
            If fewer are available, all of them are kept in shuffled order
            instead of raising ValueError.
        body_sample_seed: seed of a *local* random.Random used for sampling. It
            produces the same sample as `random.seed(seed); random.sample(...)`
            but no longer resets the global `random` state on every call.

    Returns:
        np.ndarray with one row per BODY_PARTS entry (in that order), followed
        by the sampled BODY points; columns are (x, y, z) with y = depth.
    """
    left_crop, right_crop = keypoints['leftCrop'], keypoints['rightCrop']
    wkps = pixel2world([item for item in left_crop if item['keypointType'] != 'BODY'],
                       [item for item in right_crop if item['keypointType'] != 'BODY'],
                       cm)

    # compute BODY world keypoint coordinates (vectorised over all matched points)
    left_body_items = [item for item in left_crop if item['keypointType'] == 'BODY']
    if left_body_items:
        left_item = left_body_items[0]
        right_item = [item for item in right_crop if item['keypointType'] == 'BODY'][0]
        x = np.asarray(left_item['xFrame'])
        y = np.asarray(left_item['yFrame'])
        disps = np.abs(x - np.asarray(right_item['xFrame']))
        depths = cm["focalLengthPixel"] * cm["baseline"] / disps

        pixel_count_width = cm["pixelCountWidth"]
        pixel_count_height = cm["pixelCountHeight"]
        sensor_width = cm["imageSensorWidth"]
        sensor_height = cm["imageSensorHeight"]
        focal_length = cm["focalLength"]

        # pixel offsets from the image center (z grows upwards)
        px_x = x - pixel_count_width / 2.0
        px_z = pixel_count_height / 2.0 - y

        sensor_x = px_x * (sensor_width / pixel_count_width)
        sensor_z = px_z * (sensor_height / pixel_count_height)

        world_y = depths
        world_x = (world_y * sensor_x) / focal_length
        world_z = (world_y * sensor_z) / focal_length
        wkps['BODY'] = np.column_stack([world_x, world_y, world_z])

    all_wkps = [list(wkps[bp]) for bp in BODY_PARTS]
    if 'BODY' in wkps:
        rng = random.Random(body_sample_seed)
        body_points = [list(wkp) for wkp in wkps['BODY']]
        all_wkps.extend(rng.sample(body_points, min(num_body_points, len(body_points))))
    return np.array(all_wkps)
    

def _generate_rotation_matrix(n, theta):

    R = np.array([[
        np.cos(theta) + n[0]**2*(1-np.cos(theta)), 
        n[0]*n[1]*(1-np.cos(theta)) - n[2]*np.sin(theta),
        n[0]*n[2]*(1-np.cos(theta)) + n[1]*np.sin(theta)
    ], [
        n[1]*n[0]*(1-np.cos(theta)) + n[2]*np.sin(theta),
        np.cos(theta) + n[1]**2*(1-np.cos(theta)),
        n[1]*n[2]*(1-np.cos(theta)) - n[0]*np.sin(theta),
    ], [
        n[2]*n[0]*(1-np.cos(theta)) - n[1]*np.sin(theta),
        n[2]*n[1]*(1-np.cos(theta)) + n[0]*np.sin(theta),
        np.cos(theta) + n[2]**2*(1-np.cos(theta))
    ]])
    
    return R

def normalize_3D_coordinates(wkps):
    """Move a keypoint cloud into a canonical pose.

    Args:
        wkps: (N, 3) array whose first len(BODY_PARTS) rows are the anatomical
            keypoints in BODY_PARTS order. Any further rows (e.g. BODY points)
            are transformed along with them.

    Returns:
        New (N, 3) array with three transforms applied:
        - translated so the bounding-box center of the anatomical keypoints is
          at the origin;
        - rotated about the y axis so UPPER_LIP has z == 0;
        - mirrored in x if needed so that UPPER_LIP has a larger x than TAIL_NOTCH.
    """
    # Center on the bounding-box midpoint of ALL anatomical keypoints.
    # Previously this used wkps[:8]. With the 11 sorted BODY_PARTS, that
    # excluded TAIL_NOTCH, UPPER_LIP and UPPER_PRECAUDAL_PIT, i.e. the extremes
    # of the fish.
    anatomical = wkps[:len(BODY_PARTS)]
    wkps = wkps - 0.5 * (np.max(anatomical, axis=0) + np.min(anatomical, axis=0))

    # rotate about the y axis by the angle that brings UPPER_LIP onto the x-y plane
    upper_lip_idx = BODY_PARTS.index('UPPER_LIP')
    n = np.array([0, 1, 0])
    theta = np.arctan(wkps[upper_lip_idx][2] / wkps[upper_lip_idx][0])
    R = _generate_rotation_matrix(n, theta)
    wkps = np.dot(R, wkps.T).T

    # reflect in x so the head (UPPER_LIP) points towards +x
    tail_notch_idx = BODY_PARTS.index('TAIL_NOTCH')
    if wkps[upper_lip_idx][0] < wkps[tail_notch_idx][0]:
        R = np.array([
            [-1, 0, 0],
            [0, 1, 0],
            [0, 0, 1]
        ])
        wkps = np.dot(R, wkps.T).T

    return wkps
    

def jitter_wkps(wkps, cm, jitter_pixels=10,
                jittered_parts=('TAIL_NOTCH', 'HYPURAL_PLATE', 'UPPER_PRECAUDAL_PIT', 'UPPER_LIP', 'DORSAL_FIN')):
    """Simulate stereo annotation noise on world keypoints.

    Each anatomical keypoint (rows in BODY_PARTS order, columns (x, depth, z))
    is projected back to left-image pixel coordinates and disparity. Gaussian
    pixel noise (std `jitter_pixels`) is added for parts in `jittered_parts`,
    and the point is re-projected with the jittered depth.

    Args:
        wkps: (len(BODY_PARTS) + k, 3) array. The k trailing rows (e.g. the
            BODY point) are copied through unchanged.
        cm: camera metadata with 'focalLengthPixel' and 'baseline'.
        jitter_pixels: pixel noise std for the jittered parts (others get 0).
        jittered_parts: names of the BODY_PARTS that receive noise.

    Returns:
        np.ndarray with the same number of rows as `wkps`.
    """
    f_px = cm['focalLengthPixel']
    baseline = cm['baseline']
    wkps_jittered = []
    for idx, body_part in enumerate(BODY_PARTS):
        jitter = jitter_pixels if body_part in jittered_parts else 0
        x, depth, z = wkps[idx, 0], wkps[idx, 1], wkps[idx, 2]
        x_p_left = x * f_px / depth
        z_p_left = z * f_px / depth
        disparity = f_px * baseline / depth
        # Two draws per part, in the same order as before, so seeded runs
        # consume the numpy RNG identically.
        x_p_left_jitter = np.random.normal(0, jitter)
        x_p_right_jitter = np.random.normal(0, jitter)
        disparity_jitter = x_p_left_jitter + x_p_right_jitter

        x_p_left_jittered = x_p_left + x_p_left_jitter
        disparity_jittered = disparity + disparity_jitter
        depth_jittered = f_px * baseline / disparity_jittered
        x_jittered = x_p_left_jittered * depth_jittered / f_px
        # Re-project z with the jittered depth as well. The previous version kept
        # z unchanged (its pixel value was computed but never used), so x and z
        # scaled inconsistently when the depth moved.
        z_jittered = z_p_left * depth_jittered / f_px
        wkps_jittered.append([x_jittered, depth_jittered, z_jittered])
    wkps_jittered.extend(wkps[len(BODY_PARTS):, :].tolist())
    return np.array(wkps_jittered)


# def jitter_wkps(wkps, cm, base_jitter=5):
#     wkps_jittered = []
#     for idx in range(len(BODY_PARTS)):
#         x_p_left = wkps[idx, 0] * cm['focalLengthPixel'] / wkps[idx, 1]
#         y_p_left = wkps[idx, 2] * cm['focalLengthPixel'] / wkps[idx, 1]
#         disparity = cm['focalLengthPixel'] * cm['baseline'] / wkps[idx, 1]
#         x_p_left_jitter = np.random.normal(0, base_jitter)
#         x_p_right_jitter = np.random.normal(0, base_jitter)
#         disparity_jitter = x_p_left_jitter + x_p_right_jitter
        
#         x_p_left_jittered = x_p_left + x_p_left_jitter
#         disparity_jittered = disparity + disparity_jitter
#         depth_jittered = cm['focalLengthPixel'] * cm['baseline'] / disparity_jittered
#         x_jittered = x_p_left_jittered * depth_jittered / cm['focalLengthPixel']
#         y_jittered = wkps[idx, 2]
#         wkp_jittered = [x_jittered, depth_jittered, y_jittered]
#         wkps_jittered.append(wkp_jittered)
#     wkps_jittered.append(wkps[len(BODY_PARTS), :].tolist())
#     wkps_jittered = np.array(wkps_jittered)
#     return wkps_jittered


def get_augmented_keypoints(keypoints, cm, base_depth=0.5, anchor_parts=('UPPER_LIP', 'HYPURAL_PLATE')):
    """Build the model input: canonical anatomical keypoints plus one BODY point.

    The keypoints are triangulated (get_raw_3D_coordinates) and canonicalized
    (normalize_3D_coordinates). The single BODY point kept is the one closest,
    in the x-z plane, to the midpoint of the two `anchor_parts`.

    Args:
        keypoints: stereo keypoint dict as accepted by get_raw_3D_coordinates.
        cm: camera metadata dict.
        base_depth: unused. It is kept only so existing callers remain valid.
        anchor_parts: two BODY_PARTS names defining the midpoint. The default is
            the current choice. An earlier variant used
            ('DORSAL_FIN', 'PELVIC_FIN').

    Returns:
        (len(BODY_PARTS) + 1, 3) array: anatomical rows followed by the BODY point.

    Raises:
        ValueError (from argmin) if there are no BODY points.
    """
    anchor_a_idx, anchor_b_idx = (BODY_PARTS.index(part) for part in anchor_parts)
    wkps = get_raw_3D_coordinates(keypoints, cm)
    norm_wkps = normalize_3D_coordinates(wkps)
    body_norm_wkps = norm_wkps[len(BODY_PARTS):, :]
    mid_point = 0.5 * (norm_wkps[anchor_a_idx] + norm_wkps[anchor_b_idx])
    # distance measured in the x-z plane only (columns 0 and 2)
    distances = np.linalg.norm(body_norm_wkps[:, [0, 2]] - np.array([mid_point[0], mid_point[2]]), axis=1)
    body_wkp = body_norm_wkps[np.argmin(distances)]

    return np.vstack([norm_wkps[:len(BODY_PARTS), :], body_wkp])



def get_jittered_keypoints(wkps, cm, depth_low=0.5, depth_high=2.0, normalize=False):
    """Place a canonical keypoint cloud at a random depth and apply pixel jitter.

    Args:
        wkps: (N, 3) keypoint array with columns (x, depth, z). It is NOT
            modified; the previous version shifted the caller's array in place.
            NormalizedCentered3D passes the array stored in the DataFrame, so
            every fetch moved the stored keypoints further away.
        cm: camera metadata passed on to jitter_wkps.
        depth_low, depth_high: bounds of the uniform depth offset added to column 1.
        normalize: if True, also apply the projective normalization that used to
            sit (unreachable) after the return statement:
            (0.5 * x / depth, 0.5 * z / depth, 0.05 / depth).

    Returns:
        (N, 3) array of jittered keypoints (normalized if `normalize`).
    """
    wkps = np.array(wkps, dtype=float)  # defensive copy

    # put at random depth and apply jitter
    depth = np.random.uniform(low=depth_low, high=depth_high)
    wkps[:, 1] = wkps[:, 1] + depth
    jittered_wkps = jitter_wkps(wkps, cm)
    if not normalize:
        return jittered_wkps

    return np.column_stack([0.5 * jittered_wkps[:, 0] / jittered_wkps[:, 1],
                            0.5 * jittered_wkps[:, 2] / jittered_wkps[:, 1],
                            0.05 / jittered_wkps[:, 1]])


    

In [None]:
# Canonicalize every row's keypoints (anatomical parts + one BODY point).
augmented_keypoints_list = [
    get_augmented_keypoints(row_keypoints, row_cm)
    for row_keypoints, row_cm in zip(df.keypoints, df.camera_metadata)
]
df['augmented_keypoints'] = augmented_keypoints_list

In [None]:
# Sanity check on one row: distance between two keypoints measured in the raw
# triangulation, after canonicalization, and after random depth + jitter.
wkps = df.world_keypoints.iloc[1]
augmented_wkps = df.augmented_keypoints.iloc[1]
cm = df.camera_metadata.iloc[1]
jittered_wkps = get_jittered_keypoints(deepcopy(augmented_wkps), cm)

idx_0, idx_1 = BODY_PARTS.index('DORSAL_FIN'), BODY_PARTS.index('PELVIC_FIN')  # 2, 7
part_0, part_1 = BODY_PARTS[idx_0], BODY_PARTS[idx_1]
print(euclidean_distance(wkps[part_0], wkps[part_1]))
for keypoint_array in (augmented_wkps, jittered_wkps):
    print(euclidean_distance(keypoint_array[idx_0], keypoint_array[idx_1]))

<h1> Train Neural Network </h1>

In [None]:
class KeypointsDataset(Dataset):
    """Keypoints dataset.

    Maps 3D keypoints to a biomass estimate: the label is the weight, and the
    input is the 3D world keypoints obtained during triangulation.

    With a transform, each item is the dict {'kp_input': augmented_keypoints,
    'cm': camera_metadata, 'stereo_pair_id': id[, 'label': weight]} passed
    through the transform. Without one, it is {'kp_input': world_keypoints,
    'label': weight, 'stereo_pair_id': id}.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        if not self.transform:
            return {'kp_input': row.world_keypoints, 'label': row.weight, 'stereo_pair_id': row.id}

        input_sample = {
            'kp_input': row.augmented_keypoints,
            'cm': row.camera_metadata,
            'stereo_pair_id': row.id,
        }
        # the label is optional so the dataset also works for unlabeled inference data
        if 'weight' in row.index:
            input_sample['label'] = row.weight
        return self.transform(input_sample)

class NormalizedCentered3D(object):
    """Transform: random depth placement + pixel jitter, and label scaling.

    Expects a sample dict with 'kp_input' (canonical keypoints) and 'cm'
    (camera metadata). 'stereo_pair_id', 'label' (weight) and
    'single_point_inference' are optional.
    """

    def __call__(self, sample):
        augmented_wkps, cm, stereo_pair_id, label = \
            sample['kp_input'], sample['cm'], sample.get('stereo_pair_id'), sample.get('label')

        # Copy first: 'kp_input' is the array stored in the DataFrame row, and
        # jittering must never modify it.
        jittered_wkps = get_jittered_keypoints(np.array(augmented_wkps, dtype=float), cm)
        # weights are scaled by 1e-4 for training; unlabeled samples keep None
        normalized_label = None if label is None else label * 1e-4

        return {
            'kp_input': jittered_wkps,
            'label': normalized_label,
            'stereo_pair_id': stereo_pair_id,
            'cm': cm,
            'single_point_inference': sample.get('single_point_inference')
        }
    
class ToTensor(object):
    """Convert a transformed sample dict into float tensors.

    Output: {'kp_input': float tensor} plus 'label' (shape (1,) float tensor)
    and 'stereo_pair_id' when those are present (not None) in the sample.
    """

    def __call__(self, sample):
        x, label, stereo_pair_id = \
            sample['kp_input'], sample.get('label'), sample.get('stereo_pair_id')

        # single-point inference wraps the input in an extra leading dimension
        x = np.array([x]) if sample.get('single_point_inference') else np.array(x)

        tensorized_sample = {
            'kp_input': torch.from_numpy(x).float()
        }

        # Explicit None checks: a label of 0.0 or a stereo_pair_id of 0 is valid
        # and was previously dropped by the truthiness tests.
        if label is not None:
            tensorized_sample['label'] = torch.from_numpy(np.array([label])).float()

        if stereo_pair_id is not None:
            tensorized_sample['stereo_pair_id'] = stereo_pair_id

        return tensorized_sample
        

<h1> Define train and test data loaders </h1>

In [None]:
# Split by fish (not by image) so that no fish appears in both train and test.
gtsf_fish_identifiers = list(df.fish_id.unique())
train_size = int(0.8 * len(gtsf_fish_identifiers))
# Dedicated, explicitly seeded RNG. Otherwise the split depends on whatever
# state earlier cells left the global `random` module in.
split_rng = random.Random(0)
fish_ids = split_rng.sample(gtsf_fish_identifiers, train_size)
date_mask = (df.captured_at < '2019-09-10')
train_mask = date_mask & df.fish_id.isin(fish_ids)
test_mask = date_mask & ~df.fish_id.isin(fish_ids)

train_dataset = KeypointsDataset(df[train_mask], transform=transforms.Compose([
                                                  NormalizedCentered3D(),
                                                  ToTensor()
                                              ]))

train_dataloader = DataLoader(train_dataset, batch_size=25, shuffle=True, num_workers=1)

test_dataset = KeypointsDataset(df[test_mask], transform=transforms.Compose([
                                                      NormalizedCentered3D(),
                                                      ToTensor()
                                                  ]))

# evaluation does not need shuffling; a fixed order makes test passes comparable
test_dataloader = DataLoader(test_dataset, batch_size=25, shuffle=False, num_workers=1)

In [None]:
# Take a single training batch for the scatter plot below.
data = next(iter(train_dataloader))
new_wkps = data['kp_input']

In [None]:
%matplotlib inline
fig, ax = plt.subplots(figsize=(20, 10))
# second sample of the batch: column 0 (x) against column 2 (z)
ax.scatter(new_wkps[1][:, 0], new_wkps[1][:, 2], color='red')
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_title('Jittered keypoints of one training sample (x vs z)')
plt.show()

In [None]:
# TODO: Define your network architecture here
import torch
from torch import nn

class Network(nn.Module):
    """MLP regressor: flattened keypoint cloud -> one output (scaled weight).

    Args:
        input_dim: number of features after flattening each sample. The default
            36 = 12 points x 3 coordinates, i.e. the 11 BODY_PARTS plus one BODY
            point as built by get_augmented_keypoints.
    """

    def __init__(self, input_dim=36):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # reshape (not view) so non-contiguous batches are accepted as well
        x = x.reshape(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.output(x)
        


In [None]:
run_name = 'batch_25_with_scaled_jitter_v1'
write_outputs = False

# establish output directory where model .pb files will be written
if write_outputs:
    dt_now = dt.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    output_base = '/root/data/alok/biomass_estimation/results/neural_network'
    output_dir = os.path.join(output_base, run_name, dt_now)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

# instantiate neural network
network = Network()
epochs = 1000
optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# track train and test losses (one entry per epoch)
train_losses, test_losses = [], []

seed = 0
for epoch in range(epochs):
    # Reseed numpy each epoch so the random depth/jitter augmentation is
    # reproducible. NOTE(review): the jitter runs inside DataLoader workers
    # (num_workers=1); this relies on them inheriting numpy's state from this
    # process. Confirm for your platform / multiprocessing start method.
    network.train()
    np.random.seed(seed)
    seed += 1
    running_loss = 0.0
    for i, data_batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X_batch, y_batch = data_batch['kp_input'], data_batch['label']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i > 0 and i % 100 == 0:
            # i is zero-based: i + 1 batches have contributed to running_loss
            print(running_loss / (i + 1))

    # run on test set (this used to live in the `else` clause of the for-loop)
    network.eval()
    test_running_loss = 0.0
    with torch.no_grad():
        for data_batch in test_dataloader:
            X_batch, y_batch = data_batch['kp_input'], data_batch['label']
            y_pred = network(X_batch)
            test_running_loss += criterion(y_pred, y_batch).item()

    train_loss_for_epoch = running_loss / len(train_dataloader)
    test_loss_for_epoch = test_running_loss / len(test_dataloader)
    train_losses.append(train_loss_for_epoch)
    test_losses.append(test_loss_for_epoch)

    # save current state of network
    if write_outputs:
        f_name = 'nn_epoch_{}.pb'.format(str(epoch).zfill(3))
        f_path = os.path.join(output_dir, f_name)
        torch.save(network, f_path)

    # print current loss values
    print('-'*20)
    print('Epoch: {}'.format(epoch))
    print('Train Loss: {}'.format(train_loss_for_epoch))
    print('Test Loss: {}'.format(test_loss_for_epoch))
    
    


In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(range(len(train_losses)), train_losses, color='blue', label='training loss')
ax.plot(range(len(test_losses)), test_losses, color='orange', label='validation loss')
ax.set_ylim([0, 0.01])
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss value (MSE)')
ax.set_title('Loss curves (MSE - Adam optimizer)')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Predict on the test set and undo the label scaling to get weights back in
# their original units. NormalizedCentered3D multiplies the weight by 1e-4, so
# the inverse is *1e4. The previous `1e4*x**0.5` took a square root (and gives
# complex numbers for negative predictions).
network.eval()
predictions = []
weights = []
with torch.no_grad():
    for i, data_batch in enumerate(test_dataloader):
        X_batch, y_batch = data_batch['kp_input'], data_batch['label']
        y_pred = network(X_batch)
        predictions.extend(1e4 * x[0] for x in y_pred.tolist())
        weights.extend(1e4 * x[0] for x in y_batch.tolist())

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
ax.scatter(weights, predictions)
if weights:
    # y = x reference line: points on it are perfect predictions
    lo = min(min(weights), min(predictions))
    hi = max(max(weights), max(predictions))
    ax.plot([lo, hi], [lo, hi], color='gray', linestyle='--', label='y = x')
    ax.legend()
ax.set_xlabel('Ground-truth weight')
ax.set_ylabel('Predicted weight')
ax.set_title('Predicted vs. ground-truth weight (test set)')
plt.show()

In [None]:
amg = AccuracyMetricsGenerator()
prediction_array = np.array(predictions)
amg.generate_accuracy_metrics(prediction_array, np.array(weights), w=np.ones(len(predictions)))

In [None]:
relative_errors = np.abs((np.array(predictions) - np.array(weights)) / np.array(weights))
np.percentile(relative_errors, 95)

In [None]:
# Grab one training batch for the 3D inspection below.
data = next(iter(train_dataloader))
x = data['kp_input']
    

In [None]:
sample_tensor = data['kp_input'][1]
all_wkps = sample_tensor.numpy()

In [None]:
%matplotlib notebook
fig = plt.figure()
# Axes3D(fig) stopped attaching itself to the figure in Matplotlib 3.6
# (deprecated since 3.4), which leaves the plot empty. add_subplot with
# projection='3d' is the supported constructor; mpl_toolkits.mplot3d is
# imported at the top.
ax = fig.add_subplot(projection='3d')

# get x, y, and z lists
x_values = all_wkps[:, 0]
y_values = all_wkps[:, 1]
z_values = all_wkps[:, 2]

ax.scatter(x_values, y_values, z_values)

# Simulate an equal aspect ratio: plot the 8 corners of a cube (side = largest
# data range) in white so they are invisible but force equal axis extents.
max_range = np.array([x_values.max() - x_values.min(),
                      y_values.max() - y_values.min(),
                      z_values.max() - z_values.min()]).max()
corners = np.mgrid[-1:2:2, -1:2:2, -1:2:2]
Xb = 0.5 * max_range * corners[0].flatten() + 0.5 * (x_values.max() + x_values.min())
Yb = 0.5 * max_range * corners[1].flatten() + 0.5 * (y_values.max() + y_values.min())
Zb = 0.5 * max_range * corners[2].flatten() + 0.5 * (z_values.max() + z_values.min())
for xb, yb, zb in zip(Xb, Yb, Zb):
    ax.plot([xb], [yb], [zb], 'w')

plt.show()

In [None]:
# Re-triangulate the annotated keypoints of annotation id 710764.
# df.keypoints now also carries an array-valued 'BODY' item per crop (added in
# the body-keypoint step). Filter it out, as get_raw_3D_coordinates does, so
# pixel2world only sees the annotated keypoints.
target_rows = df[df.id == 710764]
kps = target_rows.keypoints.iloc[0]
cm = target_rows.camera_metadata.iloc[0]
wkps = pixel2world([item for item in kps['leftCrop'] if item['keypointType'] != 'BODY'],
                   [item for item in kps['rightCrop'] if item['keypointType'] != 'BODY'],
                   cm)

In [None]:
upper_lip_wkp, eye_wkp = wkps['UPPER_LIP'], wkps['EYE']
euclidean_distance(upper_lip_wkp, eye_wkp)

In [None]:
# rows follow BODY_PARTS order: EYE is index 3, PELVIC_FIN index 7
eye_idx, pelvic_fin_idx = BODY_PARTS.index('EYE'), BODY_PARTS.index('PELVIC_FIN')
euclidean_distance(all_wkps[eye_idx], all_wkps[pelvic_fin_idx])

In [None]:
# index -> body-part order used for the rows of the keypoint arrays
list(BODY_PARTS)

In [None]:
sample = dict(
    keypoints=df.keypoints.iloc[0],
    stereo_pair_id=0,
    cm=df.camera_metadata.iloc[0],
)

In [None]:
np.array([[1, 2, 3], [4, 5, 6]]).mean(axis=0)

In [None]:
# Re-attach template-matching BODY points. df.keypoints already carries a BODY
# item from the earlier pass, and its array-valued xFrame/yFrame would break
# the np.array hull input. Strip BODY items first so this pass is idempotent.
# df is a filtered slice here; copy it so the column assignments below do not
# trigger SettingWithCopyWarning.
df = df.copy()
modified_keypoints_list = []
for count, (idx, row) in enumerate(df.iterrows()):
    if count % 100 == 0:
        print(count)

    left_keypoints = [deepcopy(item) for item in row.keypoints['leftCrop'] if item['keypointType'] != 'BODY']
    right_keypoints = [deepcopy(item) for item in row.keypoints['rightCrop'] if item['keypointType'] != 'BODY']

    X_keypoints = np.array([[item['xFrame'], item['yFrame']] for item in left_keypoints])
    X_body = np.array(row.matches)
    # columns 0-1 are used as left-crop x/y, columns 2-3 as right-crop x/y
    X_body = X_body[in_hull(X_body[:, :2], X_keypoints)]

    left_keypoints.append({'keypointType': 'BODY', 'xFrame': X_body[:, 0], 'yFrame': X_body[:, 1]})
    right_keypoints.append({'keypointType': 'BODY', 'xFrame': X_body[:, 2], 'yFrame': X_body[:, 3]})
    modified_keypoints_list.append({'leftCrop': left_keypoints, 'rightCrop': right_keypoints})

# keep the raw annotations captured on the first pass
if 'old_keypoints' not in df.columns:
    df['old_keypoints'] = df.keypoints
df['keypoints'] = modified_keypoints_list

In [None]:
# Split by fish so that no fish appears in both train and test.
gtsf_fish_identifiers = list(df.fish_id.unique())
train_size = int(0.8 * len(gtsf_fish_identifiers))
# Explicitly seeded local RNG, so the split does not depend on the global
# `random` state left behind by earlier cells.
split_rng = random.Random(0)
fish_ids = split_rng.sample(gtsf_fish_identifiers, train_size)
date_mask = (df.captured_at < '2019-09-10')
train_mask = date_mask & df.fish_id.isin(fish_ids)
test_mask = date_mask & ~df.fish_id.isin(fish_ids)

In [None]:
# FIXME(review): WorldKeypointTransform and PrismTransform are neither defined
# nor imported anywhere in this notebook, so this cell raises NameError on a
# fresh kernel. KeypointsDataset and ToTensor here also resolve to the
# versions redefined above (shadowing the aquabyte imports). Confirm which
# classes this 2D pipeline expects.
def _build_2d_transform():
    """Build the 2D-normalize -> world-keypoint -> prism -> tensor chain (fresh instances per call)."""
    return transforms.Compose([
        NormalizeCentered2D(lo=0.3, hi=2.0, jitter=10),
        WorldKeypointTransform(),
        PrismTransform(),
        ToTensor(),
    ])


train_dataset = KeypointsDataset(df[train_mask], transform=_build_2d_transform())
train_dataloader = DataLoader(train_dataset, batch_size=25, shuffle=True, num_workers=1)

test_dataset = KeypointsDataset(df[test_mask], transform=_build_2d_transform())
test_dataloader = DataLoader(test_dataset, batch_size=25, shuffle=True, num_workers=1)

In [None]:
# TODO: Define your network architecture here
import torch
from torch import nn

class Network(nn.Module):
    """MLP regressor: flattened keypoint input -> one output (scaled weight).

    NOTE: this redefinition shadows the 36-input Network defined earlier.

    Args:
        input_dim: number of features after flattening each sample. Default 24;
            presumably 8 points x 3 coordinates from the 2D pipeline. TODO
            confirm against that pipeline's output shape.
    """

    def __init__(self, input_dim=24):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # reshape (not view) so non-contiguous batches are accepted as well
        x = x.reshape(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.output(x)
        



In [None]:
write_outputs = False
# Set to True to also evaluate on test_dataloader every epoch (this used to be
# commented out).
run_test_eval = False

# establish output directory where model .pb files will be written
if write_outputs:
    dt_now = dt.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    output_base = '/root/data/alok/biomass_estimation/results/neural_network'
    output_dir = os.path.join(output_base, dt_now)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

# instantiate neural network
network = Network()
epochs = 1000
optimizer = torch.optim.Adam(network.parameters(), lr=0.00001)
criterion = torch.nn.MSELoss()

# track train and test losses (one entry per epoch)
train_losses, test_losses = [], []

seed = 0
for epoch in range(epochs):
    network.train()
    np.random.seed(seed)
    seed += 1
    running_loss = 0.0
    for i, data_batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X_batch, y_batch = data_batch['kp_input'], data_batch['label']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i > 0 and i % 100 == 0:
            # i is zero-based: i + 1 batches have contributed to running_loss
            print(running_loss / (i + 1))

    train_loss_for_epoch = running_loss / len(train_dataloader)
    # record the loss so the loss-curve plot has data (the append was commented out)
    train_losses.append(train_loss_for_epoch)

    if run_test_eval:
        network.eval()
        test_running_loss = 0.0
        with torch.no_grad():
            for data_batch in test_dataloader:
                X_batch, y_batch = data_batch['kp_input'], data_batch['label']
                y_pred = network(X_batch)
                test_running_loss += criterion(y_pred, y_batch).item()
        test_loss_for_epoch = test_running_loss / len(test_dataloader)
        test_losses.append(test_loss_for_epoch)
        print('Test Loss: {}'.format(test_loss_for_epoch))

    # save current state of network
    if write_outputs:
        f_name = 'nn_epoch_{}.pb'.format(str(epoch).zfill(3))
        f_path = os.path.join(output_dir, f_name)
        torch.save(network, f_path)

    print('Train Loss: {}'.format(train_loss_for_epoch))
    
    
