In [None]:
import datetime as dt
import json, os
import pandas as pd
from matplotlib import pyplot as plt
from collections import defaultdict
import numpy as np
from itertools import combinations
from aquabyte.accuracy_metrics import AccuracyMetricsGenerator
from aquabyte.data_access_utils import S3AccessUtils, RDSAccessUtils
from aquabyte.visualize import Visualizer, _normalize_world_keypoints
from aquabyte.optics import euclidean_distance, pixel2world, depth_from_disp, convert_to_world_point

from aquabyte.data_loader import KeypointsDataset, NormalizeCentered2D, ToTensor, BODY_PARTS
from aquabyte.biomass_estimator import NormalizeCentered2D, NormalizedStabilityTransform, ToTensor, Network

import random
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from sklearn.model_selection import train_test_split
from copy import copy

# Let pandas display up to 500 rows before truncating DataFrame/Series output.
pd.set_option('display.max_rows', 500)

<h1> Prepare GTSF dataset </h1>

In [None]:
# Read the database credentials from the JSON file whose path is stored in the
# PROD_RESEARCH_SQL_CREDENTIALS environment variable. The context manager
# guarantees the file handle is closed instead of leaking it for the lifetime
# of the kernel.
with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_file:
    rds_access_utils = RDSAccessUtils(json.load(credentials_file))

# Fish metadata joined to keypoint annotations on the left image URL, keeping
# only non-QA annotations captured before 2019-09-19 that carry both a
# 'leftCrop' and a 'rightCrop' keypoint set.
# NOTE: because the WHERE clause filters on columns of `b`, rows of `a` without
# a matching annotation are dropped even though this is written as a LEFT JOIN.
query = """
    select * from research.fish_metadata a left join keypoint_annotations b
    on a.left_url = b.left_image_url 
    where b.keypoints -> 'leftCrop' is not null
    and b.keypoints -> 'rightCrop' is not null
    and b.is_qa = false 
    and b.captured_at < '2019-09-19';
"""
df = rds_access_utils.extract_from_database(query)

In [None]:
# Annotation ids that are excluded from the dataset before any further
# processing.
blacklisted_keypoint_annotation_ids = [
    606484,
    635806,
    637801,
    508773,
    640493,
    639409,
    648536,
    507003,
    706002,
    507000,
    709298,
    714073,
    719239
]

# Drop the blacklisted rows. The explicit .copy() detaches the result from the
# original frame so that the column assignment performed on `df` further down
# the notebook does not operate on a slice (which makes pandas emit
# SettingWithCopyWarning).
df = df[~df.id.isin(blacklisted_keypoint_annotation_ids)].copy()

def get_world_keypoints(row):
    """Convert one row's stereo keypoint annotation into world keypoints.

    Args:
        row: object exposing `keypoints` (a mapping that may contain the
            'leftCrop' and 'rightCrop' entries) and `camera_metadata`.

    Returns:
        The value returned by `pixel2world` for the left crop, right crop and
        camera metadata, or None when either crop is missing.
    """
    annotation = row.keypoints
    has_both_crops = 'leftCrop' in annotation and 'rightCrop' in annotation
    if not has_both_crops:
        return None
    return pixel2world(annotation['leftCrop'], annotation['rightCrop'], row.camera_metadata)
    
def is_well_behaved(wkps, cutoff_depth=10.0):
    """Tell whether a set of world keypoints stays within the cutoff.

    Args:
        wkps: mapping of keypoint name -> indexable world coordinate, or None
            (which is what `get_world_keypoints` yields when a crop is missing).
        cutoff_depth: largest allowed absolute value of the coordinate at
            index 1 of every keypoint (presumably the depth axis -- confirm
            against `pixel2world`).

    Returns:
        False if `wkps` is None or if any keypoint's index-1 coordinate exceeds
        `cutoff_depth` in absolute value; True otherwise (including for an
        empty mapping).
    """
    # A missing annotation cannot be validated; treat it as not well behaved
    # instead of failing with AttributeError on `None.values()`.
    if wkps is None:
        return False
    return not any(abs(wkp[1]) > cutoff_depth for wkp in wkps.values())

# Compute world keypoints row by row and store them in a new column.
df['world_keypoints'] = df.apply(get_world_keypoints, axis=1)

# Keep only the rows whose world keypoints pass the is_well_behaved check.
is_well_behaved_mask = df['world_keypoints'].apply(is_well_behaved)
df = df[is_well_behaved_mask]

In [None]:
# Seed for the fish-level train/test split. Using a dedicated, seeded generator
# makes the split identical on every "Restart & Run All" without touching the
# global `random` state.
SPLIT_SEED = 0

# Sort the identifiers (by their string form, so mixed/missing values cannot
# break the comparison) because the query has no ORDER BY: without a stable
# order the same seed could still select different fish.
gtsf_fish_identifiers = sorted(df.fish_id.unique(), key=str)

# 80% of the distinct fish go to the training set; the split is done per fish
# so that no fish contributes rows to both sets.
train_size = int(0.8 * len(gtsf_fish_identifiers))
fish_ids = random.Random(SPLIT_SEED).sample(gtsf_fish_identifiers, train_size)

# Both sets are restricted to rows captured before 2019-09-10.
date_mask = (df.captured_at < '2019-09-10')
is_train_fish = df.fish_id.isin(fish_ids)
train_mask = date_mask & is_train_fish
test_mask = date_mask & ~is_train_fish

In [None]:
class NormalizeCentered2D(object):

    """
    Transforms the 2D left and right keypoints such that:
        (1) The center of the left image 2D keypoints is located at the center of the left image
            (i.e. 2D translation)
        (2) The left image keypoints are possibly flipped such that the upper-lip x-coordinate 
            is greater than the tail-notch coordinate. This is done to reduce the total number of 
            spatial orientations the network must learn from -> reduces the training size
        (3) The left image keypoints are then rotated such that upper-lip is located on the x-axis.
            As in (2), this is done to reduce the total number of spatial orientations the network 
            must learn from -> reduces the training size
        (4) Rescale all left image keypoints by some random number between 'lo' and 'hi' args
        (5) Apply Gaussian random noise "jitter" to each keypoint to mimic annotation error
        (6) For all transformations above, the right image keypoint coordinates are accordingly
            transformed such that the original disparity values are preserved for all keypoints
            (or adjusted during rescaling event)

    Args:
        lo, hi: bounds of the uniform distribution the rescale factor is drawn
            from. Rescaling is applied only when both are truthy; otherwise the
            factor is 1.0.
        jitter: upper bound used when drawing the per-sample noise standard
            deviation (see `__call__`).
        rotate: whether step (3) is applied.
        center: if True the output coordinates stay centered on the origin; if
            False they are shifted by half the image width/height taken from
            the sample's camera metadata.
    """

    def __init__(self, lo=None, hi=None, jitter=0.0, rotate=True, center=False):
        self.lo = lo
        self.hi = hi
        self.jitter = jitter
        self.rotate = rotate
        self.center = center


    @staticmethod
    def _bbox_center(kps):
        """Return the (x, y) midpoint of the bounding box of the keypoints in `kps`."""
        xs = [kp[0] for kp in kps.values()]
        ys = [kp[1] for kp in kps.values()]
        return np.mean([min(xs), max(xs)]), np.mean([min(ys), max(ys)])


    def flip_center_kps(self, left_kps, right_kps):
        """Center the keypoints on the origin and flip them if needed.

        Args:
            left_kps, right_kps: dicts mapping body part -> (x, y) pixel
                coordinates for the left and right image.

        Returns:
            (fc_left_kps, fc_right_kps): dicts of 2-element arrays, one entry
            per name in BODY_PARTS.
        """
        x_mid_l, y_mid_l = self._bbox_center(left_kps)
        x_mid_r, y_mid_r = self._bbox_center(right_kps)

        # No flip is needed only when the upper lip is strictly to the right
        # of the tail notch in the left image.
        needs_flip = not (left_kps['UPPER_LIP'][0] > left_kps['TAIL_NOTCH'][0])

        fc_left_kps, fc_right_kps = {}, {}
        for bp in BODY_PARTS:
            left_kp, right_kp = left_kps[bp], right_kps[bp]
            if needs_flip:
                # Mirror about the right image's bounding-box center. The
                # mirrored left keypoint becomes the new right keypoint and
                # vice versa.
                fc_right_kps[bp] = np.array([x_mid_r - left_kp[0], left_kp[1] - y_mid_r])
                fc_left_kps[bp] = np.array([x_mid_r - right_kp[0], right_kp[1] - y_mid_r])
            else:
                # Both images are translated by the LEFT image's center so the
                # left/right x-difference of every keypoint is unchanged.
                fc_left_kps[bp] = np.array([left_kp[0] - x_mid_l, left_kp[1] - y_mid_l])
                fc_right_kps[bp] = np.array([right_kp[0] - x_mid_l, right_kp[1] - y_mid_l])

        return fc_left_kps, fc_right_kps


    def _rotate_cc(self, p, theta):
        """Rotate the 2D point `p` counter-clockwise by `theta` radians."""
        R = np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ])

        return np.dot(R, p)


    def rotate_kps(self, left_kps, right_kps):
        """Rotate the left keypoints so the upper lip lies on the x-axis.

        Each right keypoint is rebuilt from the rotated left keypoint by
        subtracting the absolute pre-rotation x-difference between the left
        and right keypoint, and shares the rotated left y-coordinate.

        Returns:
            (r_left_kps, r_right_kps): dicts of 2-element arrays keyed by body
            part.
        """
        upper_lip_x, upper_lip_y = left_kps['UPPER_LIP']
        if upper_lip_x == 0 and upper_lip_y == 0:
            # The upper lip sits exactly on the origin, so its angle is
            # undefined (0/0). Skip the rotation instead of letting NaNs
            # propagate into every keypoint.
            theta = 0.0
        else:
            theta = np.arctan(upper_lip_y / upper_lip_x)

        r_left_kps, r_right_kps = {}, {}
        for bp in BODY_PARTS:
            rotated_kp = self._rotate_cc(left_kps[bp], -theta)
            disp = abs(left_kps[bp][0] - right_kps[bp][0])
            r_left_kps[bp] = rotated_kp
            r_right_kps[bp] = np.array([rotated_kp[0] - disp, rotated_kp[1]])

        return r_left_kps, r_right_kps


    def scale_kps(self, left_kps, right_kps, factor):
        """Multiply every left and right keypoint by `factor`."""
        s_left_kps = {bp: factor * np.array(left_kps[bp]) for bp in BODY_PARTS}
        s_right_kps = {bp: factor * np.array(right_kps[bp]) for bp in BODY_PARTS}
        return s_left_kps, s_right_kps


    def jitter_kps(self, left_kps, right_kps, jitter):
        """Add zero-mean Gaussian noise with standard deviation `jitter` to every coordinate.

        The draw order (left x, left y, right x, right y for each body part in
        BODY_PARTS order) is kept fixed so results are reproducible under a
        given NumPy seed.
        """
        j_left_kps, j_right_kps = {}, {}
        for bp in BODY_PARTS:
            left_x = left_kps[bp][0] + np.random.normal(0, jitter)
            left_y = left_kps[bp][1] + np.random.normal(0, jitter)
            right_x = right_kps[bp][0] + np.random.normal(0, jitter)
            right_y = right_kps[bp][1] + np.random.normal(0, jitter)
            j_left_kps[bp] = np.array([left_x, left_y])
            j_right_kps[bp] = np.array([right_x, right_y])

        return j_left_kps, j_right_kps


    @staticmethod
    def _to_keypoint_list(kps, x_shift=None, y_shift=None):
        """Convert a body-part dict into a list of keypoint items, optionally shifted."""
        items = []
        for bp in BODY_PARTS:
            x, y = kps[bp][0], kps[bp][1]
            if x_shift is not None:
                x, y = x + x_shift, y + y_shift
            items.append({'keypointType': bp, 'xFrame': x, 'yFrame': y})
        return items


    def modify_kps(self, left_kps, right_kps, factor, jitter, cm, rotate=True, center=False):
        """Run flip/center, optional rotation, scaling and jitter on one stereo pair.

        Args:
            left_kps, right_kps: dicts mapping body part -> (x, y).
            factor: scale factor applied to all keypoints.
            jitter: standard deviation of the Gaussian noise.
            cm: camera metadata; 'pixelCountWidth' and 'pixelCountHeight' are
                read only when `center` is False.
            rotate: apply `rotate_kps` before scaling.
            center: leave the result centered on the origin instead of
                shifting it by half the image size.

        Returns:
            dict with 'leftCrop' and 'rightCrop' lists of
            {'keypointType', 'xFrame', 'yFrame'} items in BODY_PARTS order.
        """
        out_left, out_right = self.flip_center_kps(left_kps, right_kps)
        if rotate:
            out_left, out_right = self.rotate_kps(out_left, out_right)
        out_left, out_right = self.scale_kps(out_left, out_right, factor)
        out_left, out_right = self.jitter_kps(out_left, out_right, jitter)

        if center:
            x_shift = y_shift = None
        else:
            x_shift = cm['pixelCountWidth'] / 2.0
            y_shift = cm['pixelCountHeight'] / 2.0

        return {
            'leftCrop': self._to_keypoint_list(out_left, x_shift, y_shift),
            'rightCrop': self._to_keypoint_list(out_right, x_shift, y_shift)
        }


    def __call__(self, sample):
        """Transform one dataset sample.

        Args:
            sample: dict with 'keypoints' (holding 'leftCrop' and 'rightCrop'
                lists of {'keypointType', 'xFrame', 'yFrame'} items) and 'cm';
                'stereo_pair_id', 'label' and 'single_point_inference' are
                optional and passed through.

        Returns:
            dict with 'kp_input' (body part -> [left x, left y, right x,
            right y]), 'modified_kps', 'label', 'stereo_pair_id', 'cm' and
            'single_point_inference'.
        """
        keypoints, cm = sample['keypoints'], sample['cm']
        left_kps = {
            item['keypointType']: np.array([item['xFrame'], item['yFrame']])
            for item in keypoints['leftCrop']
        }
        right_kps = {
            item['keypointType']: np.array([item['xFrame'], item['yFrame']])
            for item in keypoints['rightCrop']
        }

        # Rescale only when both bounds are set (and non-zero).
        factor = 1.0
        if self.lo and self.hi:
            factor = np.random.uniform(low=self.lo, high=self.hi)

        # Per-sample noise level: a value drawn from [0, self.jitter) divided
        # by a second value drawn from [0.3, 2.0).
        jitter = np.random.uniform(high=self.jitter) * (1.0 / np.random.uniform(low=0.3, high=2.0))

        modified_kps = self.modify_kps(left_kps, right_kps, factor, jitter, cm,
            rotate=self.rotate, center=self.center)

        kp_input = {}
        for left_item, right_item in zip(modified_kps['leftCrop'], modified_kps['rightCrop']):
            kp_input[left_item['keypointType']] = [
                left_item['xFrame'], left_item['yFrame'],
                right_item['xFrame'], right_item['yFrame']
            ]

        return {
            'kp_input': kp_input,
            'modified_kps': modified_kps,
            'label': sample.get('label'),
            'stereo_pair_id': sample.get('stereo_pair_id'),
            'cm': cm,
            'single_point_inference': sample.get('single_point_inference')
        }

In [None]:
# Small training set: only the first 10 rows selected by train_mask.
# lo == hi == 1.0 pins the rescale factor to 1 and jitter=0 disables the
# keypoint noise in NormalizeCentered2D.
train_dataset = KeypointsDataset(
    df[train_mask].head(10),
    transform=transforms.Compose([
        NormalizeCentered2D(lo=1.0, hi=1.0, jitter=0),
        NormalizedStabilityTransform(),
        ToTensor(),
    ]),
)

# Shuffled mini-batches of 2 samples, loaded by a single worker process.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Small held-out set: only the first 20 rows selected by test_mask, run through
# the same transform pipeline as the training set (no rescaling, no jitter).
test_dataset = KeypointsDataset(
    df[test_mask].head(20),
    transform=transforms.Compose([
        NormalizeCentered2D(lo=1.0, hi=1.0, jitter=0),
        NormalizedStabilityTransform(),
        ToTensor(),
    ]),
)

# Shuffled mini-batches of 2 samples, loaded by a single worker process.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=1,
)

In [None]:
# TODO: Define your network architecture here
import torch
from torch import nn

class Network(nn.Module):
    """Fully-connected regressor mapping 24 input features to a single value.

    Layer stack: Linear(24, 256) -> ReLU -> Linear(256, 128) -> ReLU ->
    BatchNorm1d(128, momentum=0.2) -> Linear(128, 64) -> ReLU -> Linear(64, 1).
    """

    def __init__(self):
        super(Network, self).__init__()
        # Layers are created in this order so parameter initialisation and
        # state_dict keys stay stable.
        self.fc1 = nn.Linear(in_features=24, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.bn2 = nn.BatchNorm1d(num_features=128, momentum=0.2)
        self.fc3 = nn.Linear(in_features=128, out_features=64)
        self.output = nn.Linear(in_features=64, out_features=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten each sample of the batch and return a (batch, 1) tensor."""
        batch_size = x.shape[0]
        flat = x.view(batch_size, -1)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        # Batch norm is applied after the second activation.
        hidden = self.bn2(hidden)
        hidden = self.relu(self.fc3(hidden))
        return self.output(hidden)
        



In [None]:
run_name = 'batch_25_no_rescaling_higher_bn_lr_v1'
write_outputs = False

# establish output directory where model .pb files will be written
if write_outputs:
    dt_now = dt.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    output_base = '/root/data/alok/biomass_estimation/results/neural_network'
    output_dir = os.path.join(output_base, run_name, dt_now)
    # exist_ok avoids the check-then-create race of a separate os.path.exists call
    os.makedirs(output_dir, exist_ok=True)

# instantiate neural network
network = Network()
epochs = 1000
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# track train and test losses (one mean value per epoch)
train_losses, test_losses = [], []

seed = 0
for epoch in range(epochs):
    # ---- training pass ----
    network.train()
    # NumPy is re-seeded with a new value every epoch
    # (presumably to make the random parts of the transform pipeline
    # reproducible per epoch -- confirm how this interacts with DataLoader workers)
    np.random.seed(seed)
    seed += 1
    running_loss = 0.0
    for i, data_batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X_batch, y_batch, kpid_batch = \
            data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i > 0 and i % 100 == 0:
            # i is zero-based, so i + 1 batches have been accumulated so far
            print(running_loss / (i + 1))

    # ---- evaluation pass on the test set ----
    # eval() makes BatchNorm use its running statistics; no_grad() skips
    # gradient bookkeeping.
    test_running_loss = 0.0
    with torch.no_grad():
        network.eval()
        for i, data_batch in enumerate(test_dataloader):
            X_batch, y_batch, kpid_batch = \
                data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
            y_pred = network(X_batch)
            loss = criterion(y_pred, y_batch)
            test_running_loss += loss.item()

    # mean of the per-batch losses for this epoch
    train_loss_for_epoch = running_loss / len(train_dataloader)
    test_loss_for_epoch = test_running_loss / len(test_dataloader)
    train_losses.append(train_loss_for_epoch)
    test_losses.append(test_loss_for_epoch)

    # save current state of network
    if write_outputs:
        f_name = 'nn_epoch_{}.pb'.format(str(epoch).zfill(3))
        f_path = os.path.join(output_dir, f_name)
        torch.save(network, f_path)

    # print current loss values
    print('-'*20)
    print('Epoch: {}'.format(epoch))
    print('Train Loss: {}'.format(train_loss_for_epoch))
    print('Test Loss: {}'.format(test_loss_for_epoch))
    
    


In [None]:
# Plot the per-epoch training and validation losses on one set of axes.
fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(range(len(train_losses)), train_losses, color='blue', label='training loss')
ax.plot(range(len(test_losses)), test_losses, color='orange', label='validation loss')
# Clip the y-axis to [0, 0.01]; larger loss values fall outside the view.
ax.set_ylim([0, 0.01])
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss value (MSE)')
ax.set_title('Loss curves (MSE - Adam optimizer)')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Smallest per-epoch training loss recorded in train_losses (displayed as the cell output).
np.min(train_losses)

In [None]:
# Evaluation set built from every row selected by test_mask. Unlike the
# train/test datasets above, the normalisation step is configured with
# jitter=10, so keypoint noise is applied; the rescale factor stays pinned
# to 1 (lo == hi == 1.0).
oos_dataset = KeypointsDataset(
    df[test_mask],
    transform=transforms.Compose([
        NormalizeCentered2D(lo=1.0, hi=1.0, jitter=10),
        NormalizedStabilityTransform(),
        ToTensor(),
    ]),
)

# Shuffled mini-batches of 25 samples, loaded by a single worker process.
oos_dataloader = DataLoader(
    dataset=oos_dataset,
    batch_size=25,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Evaluate the current network on the test dataloader.
# Switch to inference mode explicitly: the network contains a BatchNorm layer,
# so the result must not depend on whichever cell last called train()/eval().
network.eval()
test_running_loss = 0.0
with torch.no_grad():
    for i, data_batch in enumerate(test_dataloader):
        X_batch, y_batch, kpid_batch = \
            data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        test_running_loss += loss.item()

# mean of the per-batch losses
test_loss_for_epoch = test_running_loss / len(test_dataloader)
# print current loss values
print('Test Loss: {}'.format(test_loss_for_epoch))



In [None]:
# Evaluate the current network on the out-of-sample dataloader.
# Switch to inference mode explicitly: the network contains a BatchNorm layer,
# so the result must not depend on whichever cell last called train()/eval().
network.eval()
oos_running_loss = 0.0
with torch.no_grad():
    for i, data_batch in enumerate(oos_dataloader):
        X_batch, y_batch, kpid_batch = \
            data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        oos_running_loss += loss.item()

# mean of the per-batch losses
oos_loss_for_epoch = oos_running_loss / len(oos_dataloader)
# print current loss values (the label reads 'Test Loss' although this is the OOS loader)
print('Test Loss: {}'.format(oos_loss_for_epoch))



In [None]:
# Visual sanity check on a single sample: the loop breaks after its first
# iteration, so only the first item yielded by train_dataset is plotted.
# Column 0 of the sample's 'kp_input' array is scattered against column 2.
# NOTE(review): if the columns are [left x, left y, right x, right y] this
# plots left x vs. right x -- confirm whether column 1 (left y) was intended.
for data in train_dataset:
    X = data['kp_input'].numpy()
    plt.scatter(X[:, 0], X[:, 2])
    plt.show()
    break

In [None]:
# Rebuild the small training set (first 10 rows selected by train_mask) with
# the rescale factor pinned to 1 (lo == hi == 1.0) and jitter disabled.
train_dataset = KeypointsDataset(
    df[train_mask].head(10),
    transform=transforms.Compose([
        NormalizeCentered2D(lo=1.0, hi=1.0, jitter=0),
        NormalizedStabilityTransform(),
        ToTensor(),
    ]),
)

# Shuffled mini-batches of 2 samples, loaded by a single worker process.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=1,
)

In [None]:
write_outputs = False

# establish output directory where model .pb files will be written
if write_outputs:
    dt_now = dt.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    output_base = '/root/data/alok/biomass_estimation/results/neural_network'
    output_dir = os.path.join(output_base, dt_now)
    # exist_ok avoids the check-then-create race of a separate os.path.exists call
    os.makedirs(output_dir, exist_ok=True)

# instantiate neural network
network = Network()
epochs = 1000
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# track train and test losses
# (this train-only run never appends to them; they are just reset here)
train_losses, test_losses = [], []

# Train-only run: no test-set pass, no checkpointing, and only the training
# loss is printed each epoch. The network is left in train() mode afterwards.
seed = 0
for epoch in range(epochs):
    network.train()
    # NumPy is re-seeded with a new value every epoch
    # (presumably to make the random parts of the transform pipeline
    # reproducible per epoch -- confirm how this interacts with DataLoader workers)
    np.random.seed(seed)
    seed += 1
    running_loss = 0.0
    for i, data_batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X_batch, y_batch, kpid_batch = \
            data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i > 0 and i % 100 == 0:
            # i is zero-based, so i + 1 batches have been accumulated so far
            print(running_loss / (i + 1))

    # mean of the per-batch losses for this epoch
    train_loss_for_epoch = running_loss / len(train_dataloader)
    print('Train Loss: {}'.format(train_loss_for_epoch))
    
    
