In [None]:
import datetime as dt
import json, os
import pandas as pd
from matplotlib import pyplot as plt
from collections import defaultdict
import numpy as np
from itertools import combinations
from aquabyte.accuracy_metrics import AccuracyMetricsGenerator
from aquabyte.data_access_utils import S3AccessUtils, RDSAccessUtils
from aquabyte.visualize import Visualizer, _normalize_world_keypoints
from aquabyte.optics import euclidean_distance, pixel2world, depth_from_disp, convert_to_world_point

from aquabyte.data_loader import KeypointsDataset, NormalizeCentered2D, ToTensor, BODY_PARTS
from aquabyte.biomass_estimator import NormalizeCentered2D, NormalizedStabilityTransform, ToTensor, Network

import random
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from sklearn.model_selection import train_test_split
from copy import copy

# Let pandas render up to 500 rows before truncating displayed frames.
pd.options.display.max_rows = 500

<h1> Prepare GTSF dataset </h1>

In [None]:
# Read the database credentials from the JSON file whose path is given by the
# PROD_RESEARCH_SQL_CREDENTIALS environment variable. The context manager
# closes the file handle as soon as the credentials have been parsed (the bare
# json.load(open(...)) form left it open until garbage collection).
with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_file:
    rds_access_utils = RDSAccessUtils(json.load(credentials_file))

# Fish metadata joined to keypoint annotations on the left image URL, keeping
# only non-QA annotations that carry both a left and a right crop and were
# captured before 2019-09-19.
query = """
    select * from research.fish_metadata a left join keypoint_annotations b
    on a.left_url = b.left_image_url 
    where b.keypoints -> 'leftCrop' is not null
    and b.keypoints -> 'rightCrop' is not null
    and b.is_qa = false 
    and b.captured_at < '2019-09-19';
"""
df = rds_access_utils.extract_from_database(query)

In [None]:
# Annotation ids that are excluded from the dataset before any processing.
blacklisted_keypoint_annotation_ids = [
    606484, 635806, 637801, 508773, 640493, 639409, 648536,
    507003, 706002, 507000, 709298, 714073, 719239,
]

# Keep only the rows whose id is not on the blacklist.
is_blacklisted = df.id.isin(blacklisted_keypoint_annotation_ids)
df = df[~is_blacklisted]

def get_world_keypoints(row):
    """Convert a row's stereo keypoint annotations into world keypoints.

    Args:
        row: object exposing ``keypoints`` (a mapping that may contain the
            keys 'leftCrop' and 'rightCrop') and ``camera_metadata``.

    Returns:
        The result of ``pixel2world`` for the left/right crops and the camera
        metadata, or ``None`` when either crop is missing from the keypoints.
    """
    annotation = row.keypoints
    # Both crops are required for the stereo conversion; bail out otherwise.
    if 'leftCrop' not in annotation or 'rightCrop' not in annotation:
        return None
    return pixel2world(annotation['leftCrop'], annotation['rightCrop'], row.camera_metadata)
    
def is_well_behaved(wkps, cutoff_depth=10.0):
    """Tell whether every world keypoint stays within the cutoff.

    Args:
        wkps: mapping whose values are indexable points, or ``None`` (which is
            what ``get_world_keypoints`` returns when a crop is missing).
        cutoff_depth: maximum allowed absolute value of each point's
            coordinate at index 1.

    Returns:
        ``False`` if ``wkps`` is ``None`` or if any point's index-1 coordinate
        exceeds ``cutoff_depth`` in absolute value; ``True`` otherwise
        (including for an empty mapping).
    """
    # Missing keypoints cannot be validated; treat them as not well behaved
    # instead of raising AttributeError on None.values().
    if wkps is None:
        return False
    # Generator form short-circuits on the first offending keypoint.
    return not any(abs(wkp[1]) > cutoff_depth for wkp in wkps.values())

# Compute world keypoints row by row (None where a crop is missing).
df['world_keypoints'] = df.apply(get_world_keypoints, axis=1)

# Drop every row whose world keypoints fail the is_well_behaved check.
is_well_behaved_mask = df['world_keypoints'].apply(is_well_behaved)
df = df[is_well_behaved_mask]

In [None]:
# Split by fish identifier so that all rows of a given fish land on the same
# side of the train/test split: 80% of the fish ids go to training.
gtsf_fish_identifiers = list(df.fish_id.unique())
train_size = int(0.8 * len(gtsf_fish_identifiers))

# Use a dedicated, seeded generator so the split is the same on every
# Restart & Run All (given the same identifier order) without touching the
# global `random` state.
# NOTE(review): the identifier order follows the row order returned by the
# query, which has no ORDER BY -- confirm it is stable if exact
# reproducibility across database reads matters.
split_seed = 0
fish_ids = random.Random(split_seed).sample(gtsf_fish_identifiers, train_size)

# Only rows captured before 2019-09-10 take part in either split.
date_mask = (df.captured_at < '2019-09-10')
in_train_fish = df.fish_id.isin(fish_ids)
train_mask = date_mask & in_train_fish
test_mask = date_mask & ~in_train_fish

In [None]:
# Preprocessing pipeline applied to every training sample (no jitter).
train_transform = transforms.Compose([
    NormalizeCentered2D(lo=0.3, hi=2.0, jitter=0),
    NormalizedStabilityTransform(),
    ToTensor(),
])

# Training rows are selected by train_mask; batches of 25 are reshuffled on
# every pass over the loader.
train_dataset = KeypointsDataset(df[train_mask], transform=train_transform)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=25,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Test split: same preprocessing as the training split (no jitter), applied to
# the rows selected by test_mask.
test_dataset = KeypointsDataset(df[test_mask], transform=transforms.Compose([
                                                      NormalizeCentered2D(lo=0.3, hi=2.0, jitter=0),
                                                      NormalizedStabilityTransform(),
                                                      ToTensor()
                                                  ]))

# Evaluation loader: shuffling is disabled so batch composition is identical on
# every pass. The reported loss is a mean of per-batch means, so with a smaller
# final batch a reshuffle would change the value from one epoch to the next
# even for an unchanged model.
test_dataloader = DataLoader(test_dataset, batch_size=25, shuffle=False, num_workers=1)

In [None]:
# TODO: Define your network architecture here
import torch
from torch import nn

class Network(nn.Module):
    """Fully connected regressor: 24 inputs -> 256 -> 128 -> 64 -> 1 output.

    Every hidden layer is followed by a ReLU; the output layer is linear.
    """

    def __init__(self):
        super().__init__()
        input_dim, hidden_dims = 24, (256, 128, 64)
        # Layers are created in forward order under fixed attribute names.
        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.output = nn.Linear(hidden_dims[2], 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten each sample to a vector and return a (batch, 1) prediction."""
        hidden = x.view(x.shape[0], -1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        return self.output(hidden)
        



In [None]:
run_name = 'batch_25_rescaling_no_jitter_v1'
write_outputs = True

# establish output directory where model .pb files will be written
if write_outputs:
    dt_now = dt.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    output_base = '/root/data/alok/biomass_estimation/results/neural_network'
    output_dir = os.path.join(output_base, run_name, dt_now)
    # exist_ok makes this a single idempotent call (no exists-then-create gap)
    os.makedirs(output_dir, exist_ok=True)

# instantiate neural network
network = Network()
epochs = 1000
optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# track train and test losses (one entry per epoch)
train_losses, test_losses = [], []

# zero-pad width for checkpoint names: 3 digits as before, wider only when the
# epoch count needs it, so file names keep sorting in epoch order
epoch_digits = max(3, len(str(epochs - 1)))

seed = 0
for epoch in range(epochs):
    # ---- training pass ----
    network.train()
    # re-seed numpy with a fresh value each epoch (0, 1, 2, ...)
    np.random.seed(seed)
    seed += 1
    running_loss = 0.0
    for i, data_batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X_batch, y_batch, kpid_batch = \
            data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        # NOTE(review): if Network ends in a single-output linear layer, y_pred
        # is (batch, 1); should 'label' be 1-D, MSELoss would broadcast the
        # pair to (batch, batch). Confirm the label shape from KeypointsDataset.
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i > 0 and i % 100 == 0:
            # i is zero-based, so i + 1 batches have been accumulated so far
            print(running_loss / (i + 1))

    # ---- evaluation pass on the test set (no gradient tracking) ----
    test_running_loss = 0.0
    network.eval()
    with torch.no_grad():
        for i, data_batch in enumerate(test_dataloader):
            X_batch, y_batch, kpid_batch = \
                data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
            y_pred = network(X_batch)
            loss = criterion(y_pred, y_batch)
            test_running_loss += loss.item()

    # per-epoch loss = mean of the per-batch mean losses
    train_loss_for_epoch = running_loss / len(train_dataloader)
    test_loss_for_epoch = test_running_loss / len(test_dataloader)
    train_losses.append(train_loss_for_epoch)
    test_losses.append(test_loss_for_epoch)

    # save current state of network (the whole module object, once per epoch)
    if write_outputs:
        f_name = 'nn_epoch_{}.pb'.format(str(epoch).zfill(epoch_digits))
        f_path = os.path.join(output_dir, f_name)
        torch.save(network, f_path)

    # print current loss values
    print('-'*20)
    print('Epoch: {}'.format(epoch))
    print('Train Loss: {}'.format(train_loss_for_epoch))
    print('Test Loss: {}'.format(test_loss_for_epoch))
    
    


In [None]:
# Training vs. validation loss per epoch, on a y-axis clipped to [0, 0.01].
fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(train_losses, color='blue', label='training loss')
ax.plot(test_losses, color='orange', label='validation loss')
ax.set_ylim([0, 0.01])
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss value (MSE)')
ax.set_title('Loss curves (MSE - Adam optimizer)')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Lowest per-epoch test loss recorded during training.
np.asarray(test_losses).min()

In [None]:
# "OOS" dataset: built from the same rows as the test split (df[test_mask]),
# but NormalizeCentered2D is given jitter=10 instead of jitter=0.
oos_dataset = KeypointsDataset(df[test_mask], transform=transforms.Compose([
                                                      NormalizeCentered2D(lo=0.3, hi=2.0, jitter=10),
                                                      NormalizedStabilityTransform(),
                                                      ToTensor()
                                                  ]))

# Evaluation loader: shuffling is disabled so batch composition does not change
# between passes. The reported loss is a mean of per-batch means, which depends
# on which samples fall into a smaller final batch.
oos_dataloader = DataLoader(oos_dataset, batch_size=25, shuffle=False, num_workers=1)

In [None]:
# Evaluate the current network on the test loader without tracking gradients.
test_batch_losses = []
with torch.no_grad():
    for i, data_batch in enumerate(test_dataloader):
        X_batch = data_batch['kp_input']
        y_batch = data_batch['label']
        kpid_batch = data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        test_batch_losses.append(loss.item())

# Mean of the per-batch losses, accumulated in loader order.
test_running_loss = sum(test_batch_losses, 0.0)
test_loss_for_epoch = test_running_loss / len(test_dataloader)
# print current loss values
print('Test Loss: {}'.format(test_loss_for_epoch))



In [None]:
# Evaluate the current network on the OOS loader without tracking gradients.
oos_running_loss = 0.0
with torch.no_grad():
    for i, data_batch in enumerate(oos_dataloader):
        X_batch, y_batch, kpid_batch = \
            data_batch['kp_input'], data_batch['label'], data_batch['stereo_pair_id']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        oos_running_loss += loss.item()

# mean of the per-batch mean losses
oos_loss_for_epoch = oos_running_loss / len(oos_dataloader)
# print current loss values; labelled "OOS" so the output cannot be mistaken
# for the test-set loss printed by the test evaluation cell
print('OOS Loss: {}'.format(oos_loss_for_epoch))

