In [None]:
import datetime as dt
import os, random

from matplotlib import pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from research.gtsf_data.gtsf_dataset import GTSFDataset, BODY_PARTS


<h1> Load GTSF data </h1>

In [None]:
# URL of the AKPD scorer model file that is handed to GTSFDataset.
akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'

# First two positional arguments of GTSFDataset.
# NOTE(review): presumably the start / end dates of the data window -- confirm
# against the GTSFDataset signature.
gtsf_start_date, gtsf_end_date = '2019-02-01', '2020-03-30'

gtsf_dataset = GTSFDataset(gtsf_start_date, gtsf_end_date, akpd_scorer_url)

# Prepared DataFrame consumed by every later cell.
df = gtsf_dataset.get_prepared_dataset()

<h1> Define Augmentation Classes </h1>

In [None]:
def jitter_wkps(wkps, cm, base_jitter):
    """Perturb world keypoints with Gaussian noise applied in pixel space.

    Each of the first ``len(BODY_PARTS)`` rows of ``wkps`` is read as
    ``[x, depth, y]``. The row is projected to a left-image x pixel and a
    disparity using ``cm['focalLengthPixel']`` and ``cm['baseline']``; two
    independent ``N(0, base_jitter)`` offsets (left pixel, right pixel) are
    drawn, and x / depth are recomputed from the noisy pixel and disparity.
    The third column (y) is passed through unchanged.

    Args:
        wkps: 2-D array indexed as ``wkps[row, col]``.
        cm: mapping providing ``'focalLengthPixel'`` and ``'baseline'``.
        base_jitter: standard deviation passed to ``np.random.normal``.

    Returns:
        ``np.ndarray`` with one ``[x, depth, y]`` row per body part. When
        ``wkps`` has exactly one row more than ``BODY_PARTS``, that extra row
        is appended unmodified.
    """
    focal_length = cm['focalLengthPixel']
    baseline = cm['baseline']

    jittered_rows = []
    for part_idx, _ in enumerate(BODY_PARTS):
        x_world = wkps[part_idx, 0]
        depth = wkps[part_idx, 1]
        y_world = wkps[part_idx, 2]

        left_px = x_world * focal_length / depth
        disparity = focal_length * baseline / depth

        # Draw order (left first, then right) is kept so that seeded runs
        # consume the global NumPy RNG in the same sequence.
        left_noise = np.random.normal(0, base_jitter)
        right_noise = np.random.normal(0, base_jitter)

        noisy_disparity = disparity + (left_noise + right_noise)
        noisy_depth = focal_length * baseline / noisy_disparity
        noisy_x = (left_px + left_noise) * noisy_depth / focal_length

        jittered_rows.append([noisy_x, noisy_depth, y_world])

    n_parts = len(BODY_PARTS)
    if wkps.shape[0] == n_parts + 1:
        # One trailing non-body-part row: carry it over untouched.
        jittered_rows.append(wkps[n_parts, :].tolist())

    return np.array(jittered_rows)


def get_jittered_keypoints(wkps, cm, base_jitter=5):
    """Shift keypoints to a random depth, then apply pixel-space jitter.

    A single offset drawn from ``U(0.5, 2.0)`` is added to column 1 (depth) of
    every row, and the shifted keypoints are passed to ``jitter_wkps``.

    The input is never modified: the shift is applied to a private copy.
    Writing into ``wkps`` directly would alter the array owned by the caller
    (e.g. the one stored in the DataFrame), so every repeated read of the same
    sample in one process would stack another depth offset on top.

    Args:
        wkps: 2-D array-like of keypoints, column 1 being depth.
        cm: camera metadata mapping forwarded to ``jitter_wkps``.
        base_jitter: jitter standard deviation forwarded to ``jitter_wkps``.

    Returns:
        The ``np.ndarray`` produced by ``jitter_wkps``.
    """
    # np.array copies by default; float dtype keeps the offset from being
    # truncated if integer keypoints are ever passed in.
    shifted_wkps = np.array(wkps, dtype=float)

    # put at random depth
    depth_offset = np.random.uniform(low=0.5, high=2.0)
    shifted_wkps[:, 1] = shifted_wkps[:, 1] + depth_offset

    # apply jitter
    return jitter_wkps(shifted_wkps, cm, base_jitter)
    
#     # normalize
#     final_wkps = np.column_stack([0.5 * jittered_wkps[:, 0] / jittered_wkps[:, 1], 
#                             0.5 * jittered_wkps[:, 2] / jittered_wkps[:, 1], 
#                             0.05 / jittered_wkps[:, 1]])
#     return final_wkps

    

class KeypointsDataset(Dataset):
    """Row-per-sample dataset of keypoints and fish weight.

    Wraps a DataFrame in which each row is one stereo pair. Without a
    transform, a sample is the row's 3D world keypoints with the weight as the
    label. With a transform, the raw ``keypoint_arr`` and camera metadata are
    handed to the transform, which produces the final sample.
    """

    def __init__(self, df, transform=None):
        """Store the backing DataFrame and the optional sample transform."""
        self.df = df
        self.transform = transform

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the sample for the row at integer position ``idx``.

        Returns:
            Without a transform: a dict with ``'kp_input'`` (the row's
            ``world_keypoints``), ``'label'`` (``weight``) and
            ``'stereo_pair_id'`` (``id``).
            With a transform: whatever the transform returns for a dict with
            ``'kp_input'`` (``keypoint_arr``), ``'cm'`` (``camera_metadata``),
            ``'stereo_pair_id'`` and, only if the row has a ``weight`` field,
            ``'label'``.
        """
        record = self.df.iloc[idx]

        if not self.transform:
            return {
                'kp_input': record.world_keypoints,
                'label': record.weight,
                'stereo_pair_id': record.id,
            }

        raw_sample = {
            'kp_input': record.keypoint_arr,
            'cm': record.camera_metadata,
            'stereo_pair_id': record.id,
        }
        # Unlabeled frames (no weight column) are allowed on this path.
        if 'weight' in record.index:
            raw_sample['label'] = record.weight
        return self.transform(raw_sample)

class NormalizedCentered3D(object):
    """Transform that jitters a sample's keypoints and rescales its label.

    The keypoints are passed through ``get_jittered_keypoints`` with this
    instance's ``base_jitter``; the label, when present, is multiplied by
    ``LABEL_SCALE``.
    """

    # Factor applied to the raw label before it is used as a training target.
    LABEL_SCALE = 1e-4

    def __init__(self, base_jitter):
        """Remember the jitter standard deviation used for every sample."""
        self.base_jitter = base_jitter

    def __call__(self, sample):
        """Return a new sample dict with jittered keypoints.

        Args:
            sample: dict with ``'kp_input'`` and ``'cm'``; ``'label'``,
                ``'stereo_pair_id'`` and ``'single_point_inference'`` are
                optional.

        Returns:
            dict with keys ``'kp_input'``, ``'label'``, ``'stereo_pair_id'``,
            ``'cm'`` and ``'single_point_inference'``. ``'label'`` is ``None``
            when the input sample carries no label.
        """
        keypoint_arr, cm = sample['kp_input'], sample['cm']
        label = sample.get('label')

        jittered_wkps = get_jittered_keypoints(keypoint_arr, cm, base_jitter=self.base_jitter)

        # KeypointsDataset omits 'label' for rows without a weight; scaling
        # None would raise TypeError, so unlabeled samples pass through.
        normalized_label = None if label is None else label * self.LABEL_SCALE

        return {
            'kp_input': jittered_wkps,
            'label': normalized_label,
            'stereo_pair_id': sample.get('stereo_pair_id'),
            'cm': cm,
            'single_point_inference': sample.get('single_point_inference'),
        }
    
class ToTensor(object):
    """Transform that converts a sample's keypoints and label to float tensors."""

    def __call__(self, sample):
        """Tensorize one sample.

        Args:
            sample: dict with ``'kp_input'``; ``'label'``, ``'stereo_pair_id'``
                and ``'single_point_inference'`` are optional.

        Returns:
            dict with ``'kp_input'`` as a float tensor, plus ``'label'`` (float
            tensor of shape ``(1,)``) and ``'stereo_pair_id'`` (passed through
            as-is) whenever those values are not ``None``.
        """
        x = sample['kp_input']
        label = sample.get('label')
        stereo_pair_id = sample.get('stereo_pair_id')

        # Single-point inference wraps the keypoints in an extra leading axis.
        if sample.get('single_point_inference'):
            x = np.array([x])
        else:
            x = np.array(x)

        tensorized_sample = {
            'kp_input': torch.from_numpy(x).float()
        }

        # Compare against None rather than relying on truthiness: a label of
        # 0.0 or a stereo pair id of 0 is valid and must not be dropped.
        if label is not None:
            tensorized_sample['label'] = torch.from_numpy(np.array([label])).float()

        if stereo_pair_id is not None:
            tensorized_sample['stereo_pair_id'] = stereo_pair_id

        return tensorized_sample
        

<h1> Define train and test data loaders </h1>

In [None]:
# Split by fish id (not by row) so that no fish contributes samples to both
# the training and the test set.
SPLIT_SEED = 0        # fixes which fish land in the training set across runs
TRAIN_FRACTION = 0.8
BASE_JITTER = 10
BATCH_SIZE = 25

gtsf_fish_identifiers = list(df.fish_id.unique())
train_size = int(TRAIN_FRACTION * len(gtsf_fish_identifiers))

# A dedicated Random instance makes the split reproducible without touching
# the global `random` state. Reproducibility also relies on `df` listing the
# fish ids in the same order on every run.
fish_ids = random.Random(SPLIT_SEED).sample(gtsf_fish_identifiers, train_size)

is_train_fish = df.fish_id.isin(fish_ids)
filter_mask = df.median_depth < 1.0
train_mask = filter_mask & is_train_fish
test_mask = filter_mask & ~is_train_fish


def _make_keypoints_dataset(frame):
    """Wrap `frame` in a KeypointsDataset using the jitter + tensor transforms."""
    return KeypointsDataset(frame, transform=transforms.Compose([
        NormalizedCentered3D(BASE_JITTER),
        ToTensor()
    ]))


train_dataset = _make_keypoints_dataset(df[train_mask])
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

test_dataset = _make_keypoints_dataset(df[test_mask])
# Evaluation gains nothing from shuffling; a fixed order keeps batch
# composition stable from epoch to epoch.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

<h1> Visual Check for Correctness </h1>

In [None]:
%matplotlib inline

count = 0
for data in train_dataloader:
    new_wkps = data['kp_input']
    count +=1 
    if count > 2:
        break

plt.figure(figsize=(20, 10))
plt.scatter(new_wkps[0][:, 0], new_wkps[0][:, 2], color='red')

stereo_pair_id = data['stereo_pair_id'][0].item()
keypoint_arr = df[df.id == stereo_pair_id].keypoint_arr.iloc[0]
plt.scatter(keypoint_arr[:, 0], keypoint_arr[:, 2], color='blue')
plt.show()

In [None]:
# Fully connected network that regresses a single value from the flattened keypoints
import torch
from torch import nn

class Network(nn.Module):
    """Fully connected regressor: flattened keypoints -> one scalar.

    Layer widths are ``input_size -> 256 -> 128 -> 64 -> 1`` with a ReLU after
    each hidden layer and no activation on the output.

    Args:
        input_size: number of features per sample after flattening. Defaults
            to 27, the width this notebook was written with.
    """

    def __init__(self, input_size=27):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of shape ``(batch, ...)`` to predictions of shape ``(batch, 1)``."""
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after a transpose), reshape copies only when it has to.
        x = x.reshape(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.output(x)
        


In [None]:
# Train the network, evaluate on the test loader after every epoch, and
# (optionally) write a snapshot of the network to disk per epoch.
run_name = 'batch_25_jitter_10_body_kp_lr_1e-4_v1'
write_outputs = True

# establish output directory where model .pb files will be written
if write_outputs:
    dt_now = dt.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    output_base = '/root/data/alok/biomass_estimation/results/neural_network'
    output_dir = os.path.join(output_base, run_name, dt_now)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(output_dir, exist_ok=True)

# instantiate neural network
network = Network()
epochs = 1000
optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# track train and test losses (one mean-over-batches value per epoch)
train_losses, test_losses = [], []

for epoch in range(epochs):
    # ---- training pass ----
    network.train()
    # Reseed NumPy's global RNG with the epoch index at the start of every epoch.
    # NOTE(review): presumably meant to make the np.random-based jitter
    # repeatable per epoch -- confirm the seed reaches the DataLoader workers.
    np.random.seed(epoch)
    running_loss = 0.0
    for i, data_batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X_batch, y_batch = data_batch['kp_input'], data_batch['label']
        y_pred = network(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i > 0 and i % 100 == 0:
            # i is 0-based, so i + 1 batches have been accumulated so far
            print(running_loss / (i + 1))

    # ---- evaluation pass on the test set ----
    network.eval()
    test_running_loss = 0.0
    with torch.no_grad():
        for data_batch in test_dataloader:
            X_batch, y_batch = data_batch['kp_input'], data_batch['label']
            y_pred = network(X_batch)
            test_running_loss += criterion(y_pred, y_batch).item()

    train_loss_for_epoch = running_loss / len(train_dataloader)
    test_loss_for_epoch = test_running_loss / len(test_dataloader)
    train_losses.append(train_loss_for_epoch)
    test_losses.append(test_loss_for_epoch)

    # save current state of network
    # NOTE(review): torch.save(network, ...) pickles the whole module object,
    # so loading needs the Network class importable under the same name.
    # Saving network.state_dict() is the more portable form; it is left as-is
    # here because existing consumers of these files may call torch.load on
    # the full object.
    if write_outputs:
        f_name = 'nn_epoch_{}.pb'.format(str(epoch).zfill(3))
        f_path = os.path.join(output_dir, f_name)
        torch.save(network, f_path)

    # print current loss values
    print('-'*20)
    print('Epoch: {}'.format(epoch))
    print('Train Loss: {}'.format(train_loss_for_epoch))
    print('Test Loss: {}'.format(test_loss_for_epoch))
    
    

<h1> Show Train / Test Results </h1>

In [None]:
# Per-epoch training and validation loss on one set of axes.
fig, ax = plt.subplots(figsize=(20, 10))

loss_series = (
    (train_losses, 'blue', 'training loss'),
    (test_losses, 'orange', 'validation loss'),
)
for losses, line_color, line_label in loss_series:
    ax.plot(range(len(losses)), losses, color=line_color, label=line_label)

# Fixed y-range: anything above 0.01 is clipped out of view.
ax.set_ylim([0, 0.01])
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss value (MSE)')
ax.set_title('Loss curves (MSE - Adam optimizer)')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Report the best per-epoch test loss. np.argmin returns the 0-based position
# in test_losses, which holds one entry per epoch in training order.
print('Minimum test loss: {}; Occurred at epoch {}'.format(np.min(test_losses), np.argmin(test_losses)))

In [None]:
def _collect_predictions(dataloader):
    """Run `network` over every batch of `dataloader`.

    Returns:
        (predictions, ground_truth): two flat lists with one entry per sample,
        both still in the scaled label space the network was trained on.
    """
    predictions, ground_truth = [], []
    for data_batch in dataloader:
        X_batch, y_batch = data_batch['kp_input'], data_batch['label']
        y_pred = network(X_batch)
        predictions.extend(list(y_pred.numpy().flatten()))
        ground_truth.extend(list(y_batch.numpy().flatten()))
    return predictions, ground_truth


network.eval()
with torch.no_grad():
    y_preds_train, y_gt_train = _collect_predictions(train_dataloader)
    # The test metrics must come from the held-out loader, not the training one.
    y_preds_test, y_gt_test = _collect_predictions(test_dataloader)

# Undo the 1e-4 label scaling applied by NormalizedCentered3D.
y_preds_test = 1e4 * np.array(y_preds_test)
y_gt_test = 1e4 * np.array(y_gt_test)

# The printed value is a fraction (0.05 means 5%); it is not multiplied by 100.
print('Median absolute error percentage: {}'.format(np.median(np.abs((y_preds_test - y_gt_test) / y_gt_test))))

In [None]:
# Predicted vs. ground-truth weight, one point per evaluated sample.
fig, ax = plt.subplots()
ax.scatter(y_gt_test, y_preds_test)
# Labels and a title so the figure can be read on its own.
ax.set_xlabel('Ground-truth weight')
ax.set_ylabel('Predicted weight')
ax.set_title('Predicted vs. ground-truth weight')
ax.grid()
plt.show()