In [None]:
import datetime as dt
import json, os
import pandas as pd
from matplotlib import pyplot as plt
from collections import defaultdict
import numpy as np
from itertools import combinations
from research.weight_estimation.akpd_scorer import generate_confidence_score
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from research.weight_estimation.visualize import Visualizer, _normalize_world_keypoints
from research.weight_estimation.optics import euclidean_distance, pixel2world, depth_from_disp, convert_to_world_point
from research.weight_estimation.biomass_estimator import NormalizeCentered2D, NormalizedStabilityTransform, ToTensor, Network
from keras.models import load_model

import random
import torch
from research.weight_estimation.data_loader import KeypointsDataset, NormalizeCentered2D, ToTensor, BODY_PARTS
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from sklearn.model_selection import train_test_split
from copy import copy, deepcopy
from scipy.spatial import Delaunay
from mpl_toolkits.mplot3d import Axes3D

# Show up to 500 rows when displaying DataFrames in this notebook.
pd.options.display.max_rows = 500

In [None]:
from copy import copy, deepcopy
import json, os, random
import numpy as np
import pandas as pd

from keras.models import load_model

from research.weight_estimation.optics import pixel2world
from research.weight_estimation.akpd_scorer import generate_confidence_score
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from research.gtsf_data.body_parts import BodyParts
import pyarrow.parquet as pq
from scipy.spatial import Delaunay

# Core body-part names from the GTSF BodyParts helper. This rebinds BODY_PARTS,
# shadowing the BODY_PARTS imported from research.weight_estimation.data_loader in the
# first cell; the code below (e.g. GTSFDataset's BODY_PARTS.index(...) lookups) uses this list.
BODY_PARTS = BodyParts().get_core_body_parts()

class GTSFDataset(object):
    """Prepared GTSF keypoint dataset loaded from the research database.

    Construction queries ``research.fish_metadata`` left-joined with
    ``keypoint_annotations`` for a ``captured_at`` window. ``prepare_df`` then:

    - keeps QA annotations, falling back to Cogito ones only for images without QA;
    - adds world keypoints and a median depth;
    - adds AKPD scores;
    - optionally adds template-matching BODY keypoints;
    - adds a normalized per-row keypoint matrix (``keypoint_arr``).

    Use ``get_prepared_dataset`` to retrieve the resulting frame.
    """

    def __init__(self, start_date, end_date, akpd_scorer_url):
        """Load and prepare the dataset.

        Args:
            start_date: lower bound of ``captured_at`` (SQL ``between``, inclusive),
                e.g. ``'2019-03-01'``.
            end_date: upper bound of ``captured_at`` (inclusive).
            akpd_scorer_url: URL of the Keras AKPD scorer model; downloaded through
                ``S3AccessUtils.download_from_url``.
        """
        self.s3_access_utils = S3AccessUtils('/root/data')
        self.df = self.generate_raw_df(start_date, end_date)
        self.prepare_df(akpd_scorer_url)

    @staticmethod
    def generate_raw_df(start_date, end_date):
        """Return raw annotation rows captured between ``start_date`` and ``end_date``.

        Only rows whose annotation has both a ``leftCrop`` and a ``rightCrop`` are kept.
        NOTE(review): the dates are interpolated directly into the SQL text; pass
        trusted values only.
        """
        # use a context manager so the credentials file handle is closed
        with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_file:
            rds_access_utils = RDSAccessUtils(json.load(credentials_file))
        query = f"""
            select * from research.fish_metadata a left join keypoint_annotations b
            on a.left_url = b.left_image_url 
            where b.keypoints -> 'leftCrop' is not null
            and b.keypoints -> 'rightCrop' is not null
            and b.captured_at between '{start_date}' and '{end_date}';
        """
        df = rds_access_utils.extract_from_database(query)
        print('Raw dataframe loaded!')
        return df

    @staticmethod
    def get_world_keypoints(row):
        """Triangulate a row's left/right crop annotations with ``pixel2world``."""
        return pixel2world(row.keypoints['leftCrop'], row.keypoints['rightCrop'], row.camera_metadata)

    def prepare_df(self, akpd_scorer_path, add_template_matching_keypoints=True):
        """Deduplicate annotations and enrich ``self.df``.

        Args:
            akpd_scorer_path: URL of the AKPD scorer model (forwarded to
                ``add_akpd_scores``, which downloads it).
            add_template_matching_keypoints: when True, append template-matching
                BODY keypoints (this also drops rows without enough matches).
        """
        # use QA'ed entries, and only use Cogito entries when QA data is unavailable
        qa_df = self.df[self.df.is_qa == True]
        cogito_df = self.df[(self.df.is_qa != True) & ~(self.df.left_image_url.isin(qa_df.left_image_url))]
        self.df = pd.concat([qa_df, cogito_df], axis=0)
        print('Dataset preparation beginning...')

        # add 3D spatial information; median_depth is the median of coordinate index 1
        # (presumably depth, matching world_y = depths in get_raw_3D_coordinates)
        self.df['world_keypoints'] = self.df.apply(self.get_world_keypoints, axis=1)
        self.df['median_depth'] = self.df.world_keypoints.apply(lambda x: np.median([wkp[1] for wkp in x.values()]))
        print('3D spatial information added!')

        self.add_akpd_scores(akpd_scorer_path)
        if add_template_matching_keypoints:
            self.add_template_matching_keypoints()
        self.convert_wkps_to_matrix_form()

    @staticmethod
    def in_hull(p, hull):
        """Return a boolean array: True where a point of ``p`` lies in the convex hull of ``hull``.

        ``hull`` is triangulated with ``scipy.spatial.Delaunay``; ``find_simplex``
        returns -1 for points outside every simplex.
        """
        hull = Delaunay(hull)
        return hull.find_simplex(p) >= 0

    def add_template_matching_keypoints(self):
        """Append template-matching BODY keypoints to each row's annotations.

        Steps:

        1. Download the first ``.parquet`` under ``template-matching/2019-12-05T02:50:57``
           in the ``aquabyte-research`` bucket.
        2. Inner-join its ``homography`` / ``matches`` onto ``self.df`` by ``left_image_url``.
        3. Keep only matches whose left-image point lies inside the convex hull of the
           annotated left keypoints.
        4. Append the kept matches as one extra ``'BODY'`` item per crop; ``xFrame`` and
           ``yFrame`` are arrays.

        The original annotations are kept in ``old_keypoints``. Rows with 500 or fewer
        retained BODY points (including rows with no matches) are dropped.
        """
        print('Adding template matching body keypoints...')

        # load data
        gen = self.s3_access_utils.get_matching_s3_keys(
            'aquabyte-research',
            prefix='template-matching/2019-12-05T02:50:57',
            suffixes=['.parquet']
        )
        keys = list(gen)
        f = self.s3_access_utils.download_from_s3('aquabyte-research', keys[0])
        pdf = pd.read_parquet(f)
        # `np.float` / `np.int` were removed in NumPy 1.24; the builtins are the exact equivalents
        pdf['homography'] = pdf.homography_and_matches.apply(lambda x: np.array(x[0].tolist(), dtype=float))
        pdf['matches'] = pdf.homography_and_matches.apply(
            lambda x: np.array(x[1].tolist(), dtype=int) if len(x) > 1 else None)

        # merge with existing dataframe
        self.df = pd.merge(self.df, pdf[['left_image_url', 'homography', 'matches']], how='inner', on='left_image_url')

        # generate list of modified keypoints
        modified_keypoints_list = []
        for count, (_, row) in enumerate(self.df.iterrows()):
            if count % 100 == 0:
                print(count)
            X_keypoints = np.array([[item['xFrame'], item['yFrame']] for item in row.keypoints['leftCrop']])
            # columns 0-1 are left-image (x, y), columns 2-3 right-image (x, y) — see the items below
            X_body = np.asarray(row.matches)
            if X_body.ndim == 2 and X_body.shape[0] > 0:
                X_body = X_body[self.in_hull(X_body[:, :2], X_keypoints)]
            else:
                # `matches` is None or empty: emit an empty BODY item instead of crashing
                # on X_body[:, :2]; the size filter below drops this row
                X_body = np.empty((0, 4))

            keypoints = deepcopy(row.keypoints)
            left_keypoints, right_keypoints = keypoints['leftCrop'], keypoints['rightCrop']
            left_item = {'keypointType': 'BODY', 'xFrame': X_body[:, 0], 'yFrame': X_body[:, 1]}
            right_item = {'keypointType': 'BODY', 'xFrame': X_body[:, 2], 'yFrame': X_body[:, 3]}

            left_keypoints.append(left_item)
            right_keypoints.append(right_item)
            modified_keypoints_list.append({'leftCrop': left_keypoints, 'rightCrop': right_keypoints})

        # add modified keypoints information to dataframe
        self.df['old_keypoints'] = self.df.keypoints
        self.df['keypoints'] = modified_keypoints_list
        # get_raw_3D_coordinates samples 500 BODY points, so require more than 500;
        # .copy() so later column assignments don't write into a slice of the merged frame
        has_enough_body_points = self.df.keypoints.apply(lambda x: x['leftCrop'][-1]['xFrame'].shape[0]) > 500
        self.df = self.df[has_enough_body_points].copy()

    def add_akpd_scores(self, akpd_scorer_url):
        """Score every row with the AKPD scorer network and store it in ``akpd_score``.

        Args:
            akpd_scorer_url: URL of the Keras model, fetched with
                ``S3AccessUtils.download_from_url``.
        """
        print('Adding AKPD scores...')
        # load neural network weights
        akpd_scorer_path, _, _ = self.s3_access_utils.download_from_url(akpd_scorer_url)
        akpd_scorer_network = load_model(akpd_scorer_path)

        akpd_scores = []
        for _, row in self.df.iterrows():
            input_sample = {
                'keypoints': row.keypoints,
                'cm': row.camera_metadata,
                'stereo_pair_id': row.id,
                'single_point_inference': True
            }
            akpd_scores.append(generate_confidence_score(input_sample, akpd_scorer_network))
        self.df['akpd_score'] = akpd_scores

    def convert_wkps_to_matrix_form(self):
        """Store ``get_keypoint_arr(keypoints, camera_metadata)`` for every row in ``keypoint_arr``."""
        print('Converting world keypoints to matrix form...')
        self.df['keypoint_arr'] = [
            self.get_keypoint_arr(row.keypoints, row.camera_metadata) for _, row in self.df.iterrows()
        ]

    @staticmethod
    def get_raw_3D_coordinates(keypoints, cm):
        """Return an (N, 3) array of world coordinates for one annotation.

        The first ``len(BODY_PARTS)`` rows are the annotated body parts in
        ``BODY_PARTS`` order, triangulated with ``pixel2world``. When the annotation
        has a ``'BODY'`` item (added by ``add_template_matching_keypoints``), 500
        randomly sampled BODY points are appended. Those points are triangulated here
        from their left/right x disparity with a pinhole model
        (depth = focalLengthPixel * baseline / disparity).

        Args:
            keypoints: dict with ``leftCrop`` / ``rightCrop`` lists of items carrying
                ``keypointType``, ``xFrame`` and ``yFrame``.
            cm: camera metadata dict (focal lengths, baseline, pixel counts, sensor size).

        NOTE(review): ``random.sample`` is unseeded, so the sampled BODY points, and
        the body keypoint ``get_keypoint_arr`` picks from them, vary between runs.
        It raises ValueError if fewer than 500 BODY points are present.
        """
        wkps = pixel2world([item for item in keypoints['leftCrop'] if item['keypointType'] != 'BODY'],
                           [item for item in keypoints['rightCrop'] if item['keypointType'] != 'BODY'],
                           cm)
        all_wkps = [list(wkps[bp]) for bp in BODY_PARTS]

        # compute BODY world keypoint coordinates
        if 'BODY' in [item['keypointType'] for item in keypoints['leftCrop']]:
            left_item = [item for item in keypoints['leftCrop'] if item['keypointType'] == 'BODY'][0]
            right_item = [item for item in keypoints['rightCrop'] if item['keypointType'] == 'BODY'][0]
            disps = np.abs(left_item['xFrame'] - right_item['xFrame'])
            focal_length_pixel = cm["focalLengthPixel"]
            baseline = cm["baseline"]
            depths = focal_length_pixel * baseline / np.array(disps)

            pixel_count_width = cm["pixelCountWidth"]
            pixel_count_height = cm["pixelCountHeight"]
            sensor_width = cm["imageSensorWidth"]
            sensor_height = cm["imageSensorHeight"]
            focal_length = cm["focalLength"]

            # pixel offsets from the image center (z grows upward, hence the flipped y)
            image_center_x = pixel_count_width / 2.0
            image_center_y = pixel_count_height / 2.0
            x = left_item['xFrame']
            y = left_item['yFrame']
            px_x = x - image_center_x
            px_z = image_center_y - y

            sensor_x = px_x * (sensor_width / pixel_count_width)
            sensor_z = px_z * (sensor_height / pixel_count_height)

            world_y = depths
            world_x = (world_y * sensor_x) / focal_length
            world_z = (world_y * sensor_z) / focal_length
            wkps['BODY'] = np.column_stack([world_x, world_y, world_z])

            body_wkps = random.sample([list(wkp) for wkp in list(wkps['BODY'])], 500)
            all_wkps.extend(body_wkps)
        return np.array(all_wkps)

    @staticmethod
    def _generate_rotation_matrix(n, theta):
        """Return the 3x3 axis-angle rotation matrix for angle ``theta`` about axis ``n``.

        Expects ``n`` to be a unit-length 3-vector.
        """
        R = np.array([[
            np.cos(theta) + n[0] ** 2 * (1 - np.cos(theta)),
            n[0] * n[1] * (1 - np.cos(theta)) - n[2] * np.sin(theta),
            n[0] * n[2] * (1 - np.cos(theta)) + n[1] * np.sin(theta)
        ], [
            n[1] * n[0] * (1 - np.cos(theta)) + n[2] * np.sin(theta),
            np.cos(theta) + n[1] ** 2 * (1 - np.cos(theta)),
            n[1] * n[2] * (1 - np.cos(theta)) - n[0] * np.sin(theta),
        ], [
            n[2] * n[0] * (1 - np.cos(theta)) - n[1] * np.sin(theta),
            n[2] * n[1] * (1 - np.cos(theta)) + n[0] * np.sin(theta),
            np.cos(theta) + n[2] ** 2 * (1 - np.cos(theta))
        ]])

        return R

    def normalize_3D_coordinates(self, wkps):
        """Center, rotate and (if needed) mirror an (N, 3) keypoint array into a canonical pose.

        After this call:

        - the bounding-box center of the first 8 rows is at the origin;
        - UPPER_LIP has z == 0, via a rotation about the y axis;
        - UPPER_LIP has a larger x than TAIL_NOTCH, via an x-mirror when needed.

        NOTE(review): the hard-coded 8 presumably equals len(BODY_PARTS) — confirm.
        """
        # translate keypoints so the center of the first 8 keypoints' bounding box is at the origin
        wkps = wkps - 0.5 * (np.max(wkps[:8], axis=0) + np.min(wkps[:8], axis=0))

        # perform rotation about the y axis so the upper lip lands on the x-y plane (z = 0)
        upper_lip_idx = BODY_PARTS.index('UPPER_LIP')

        n = np.array([0, 1, 0])
        theta = np.arctan(wkps[upper_lip_idx][2] / wkps[upper_lip_idx][0])
        R = self._generate_rotation_matrix(n, theta)
        wkps = np.dot(R, wkps.T).T

        # perform reflection if necessary so the head points toward +x
        tail_notch_idx = BODY_PARTS.index('TAIL_NOTCH')
        if wkps[upper_lip_idx][0] < wkps[tail_notch_idx][0]:
            R = np.array([
                [-1, 0, 0],
                [0, 1, 0],
                [0, 0, 1]
            ])
            wkps = np.dot(R, wkps.T).T

        return wkps

    def get_keypoint_arr(self, keypoints, cm):
        """Return the normalized keypoint matrix for one annotation.

        Without BODY keypoints this is the full normalized array. With them, it is the
        first 8 normalized rows plus the single BODY point whose (x, z) is closest to
        the midpoint of DORSAL_FIN and PELVIC_FIN.
        """
        dorsal_fin_idx, pelvic_fin_idx = BODY_PARTS.index('DORSAL_FIN'), BODY_PARTS.index('PELVIC_FIN')
        wkps = self.get_raw_3D_coordinates(keypoints, cm)
        norm_wkps = self.normalize_3D_coordinates(wkps)
        if any([item['keypointType'] == 'BODY' for item in keypoints['leftCrop']]):
            body_norm_wkps = norm_wkps[8:, :]
            mid_point = 0.5 * (norm_wkps[dorsal_fin_idx] + norm_wkps[pelvic_fin_idx])
            idx = np.argmin(np.linalg.norm(body_norm_wkps[:, [0, 2]] - np.array([mid_point[0], mid_point[2]]), axis=1))
            body_wkp = body_norm_wkps[idx]
            keypoint_arr = np.vstack([norm_wkps[:8, :], body_wkp])
            return keypoint_arr
        else:
            return norm_wkps

    def get_prepared_dataset(self):
        """Return the prepared DataFrame."""
        return self.df


akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'
# Build the prepared GTSF dataset for this capture window (queries the DB and downloads models).
gtsf_dataset = GTSFDataset(
    start_date='2019-03-01',
    end_date='2020-03-30',
    akpd_scorer_url=akpd_scorer_url,
)
df = gtsf_dataset.get_prepared_dataset()
print(df.shape)



In [None]:
# `wkps` was never bound at notebook scope (only as a local inside GTSFDataset methods),
# so this cell raised NameError on a fresh kernel; bind it explicitly.
# NOTE(review): assumed intent is to inspect the first sample's normalized keypoint matrix — confirm.
wkps = df.keypoint_arr.iloc[0]
# per-axis max / min of that sample's keypoints
wkps.max(axis=0), wkps.min(axis=0)

In [None]:
# Print the distance between every unordered pair of rows of `wkps` (pairs in i < j order).
for point_a, point_b in combinations(wkps, 2):
    print(euclidean_distance(point_a, point_b))

In [None]:
# Distances between every pair of body parts for the first row's world keypoints.
wkps_dict = df.world_keypoints.iloc[0]
for bp_1, bp_2 in combinations(BODY_PARTS, 2):
    d = euclidean_distance(wkps_dict[bp_1], wkps_dict[bp_2])
    print('{}-{}: {}'.format(bp_1, bp_2, d))

In [None]:
def get_world_keypoints(row):
    """Triangulate a row's left/right crop annotations into world keypoints via pixel2world."""
    annotation = row.keypoints
    return pixel2world(annotation['leftCrop'], annotation['rightCrop'], row.camera_metadata)

def prepare_df(aggregate_df):
    """Prefer QA annotations over Cogito ones and attach world keypoints.

    Rows with ``is_qa == True`` are always kept. Other rows are kept only when no QA
    row shares their ``left_image_url``. Returns a new frame (QA rows first) with a
    ``world_keypoints`` column computed by ``get_world_keypoints``.
    """
    qa_mask = aggregate_df.is_qa == True  # noqa: E712 — elementwise comparison, is_qa may hold None
    qa_urls = aggregate_df.loc[qa_mask, 'left_image_url']
    cogito_mask = ~qa_mask & ~aggregate_df.left_image_url.isin(qa_urls)
    df = pd.concat([aggregate_df[qa_mask], aggregate_df[cogito_mask]], axis=0)

    df['world_keypoints'] = df.apply(get_world_keypoints, axis=1)
    return df


# Load non-QA annotations captured before 2019-09-20 and prepare them with the helpers above.
# NOTE: this rebinds `df`, replacing the GTSFDataset frame built in the earlier cell.
# Use a context manager so the credentials file handle is closed.
with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_file:
    rds_access_utils = RDSAccessUtils(json.load(credentials_file))
query = """
    select * from research.fish_metadata a left join keypoint_annotations b
    on a.left_url = b.left_image_url 
    where b.keypoints -> 'leftCrop' is not null
    and b.keypoints -> 'rightCrop' is not null
    and b.is_qa = false
    and b.captured_at < '2019-09-20';
"""
aggregate_df = rds_access_utils.extract_from_database(query)
df = prepare_df(aggregate_df)


In [None]:
# Histogram of each row's median world-keypoint coordinate at index 1.
fig, ax = plt.subplots(figsize=(20, 10))
ax.hist(df.world_keypoints.apply(lambda world_kps: np.median([coords[1] for coords in world_kps.values()])))
ax.grid()
plt.show()

In [None]:
# Per-row median of the index-1 coordinate across all world keypoints.
df['median_depth'] = df.world_keypoints.apply(
    lambda world_kps: np.median([coords[1] for coords in world_kps.values()])
)

In [None]:
s3_access_utils = S3AccessUtils('/root/data')

# initialize data transforms so that we can run inference with biomass neural network
# NOTE(review): NormalizeCentered2D and ToTensor are imported twice in the first cell, and the
# later research.weight_estimation.data_loader import shadows the biomass_estimator one.
# These are therefore the data_loader classes — confirm that is intended.
normalize_centered_2D_transform_biomass = NormalizeCentered2D()
normalized_stability_transform = NormalizedStabilityTransform()
to_tensor_transform = ToTensor()

# load neural network weights: fetch the same AKPD scorer model (akpd_scorer_url) that
# GTSFDataset.add_akpd_scores downloads, instead of a hard-coded per-user local path
akpd_scorer_path = s3_access_utils.download_from_url(akpd_scorer_url)[0]
akpd_scorer_network = load_model(akpd_scorer_path)

In [None]:
def generate_akpd_score(row_id, ann, cm):
    """Return the AKPD confidence score for one stereo annotation.

    Args:
        row_id: passed as the sample's ``stereo_pair_id``.
        ann: keypoint annotation dict (``leftCrop`` / ``rightCrop``).
        cm: camera metadata for the stereo pair.

    Uses the notebook-level ``akpd_scorer_network``.
    """
    return generate_confidence_score(
        {
            'keypoints': ann,
            'cm': cm,
            'stereo_pair_id': row_id,
            'single_point_inference': True,
        },
        akpd_scorer_network,
    )


In [None]:
# Score every row with the AKPD network.
akpd_scores = [
    generate_akpd_score(row.id, row.keypoints, row.camera_metadata)
    for _, row in df.iterrows()
]
df['akpd_score'] = akpd_scores


In [None]:
# SQL fragment " OR id = <n>" for every row whose AKPD score is below 1e-4.
low_score_rows = df.loc[df.akpd_score < 1e-4, ['id', 'akpd_score']]
where_clause = ''.join(f' OR id = {int(row.id)}' for _, row in low_score_rows.iterrows())


In [None]:
# ids of rows whose AKPD score is below 1e-5.
# Select the `id` column directly: iterating rows of ['id', 'akpd_score'] with iterrows
# upcasts each row to a common dtype, which turned integer ids into floats.
# NOTE(review): the where-clause cell above uses 1e-4; confirm the stricter 1e-5 is intended.
ids = df.loc[df.akpd_score < 1e-5, 'id'].astype(int).tolist()


In [None]:
# ids of the rows belonging to this fish_id
df.loc[df.fish_id == '190808-d20dc94e-fc76-4ffb-a4f5-f296d9ac368d', 'id']

In [None]:
# Read DB credentials with a context manager so the file handle is closed.
with open(os.environ['PROD_RESEARCH_SQL_CREDENTIALS']) as credentials_file:
    prod_research_sql_credentials = json.load(credentials_file)
rds_access_utils = RDSAccessUtils(prod_research_sql_credentials)
s3_access_utils = S3AccessUtils('/root/data')
visualizer = Visualizer(s3_access_utils, rds_access_utils)


In [None]:
# Inspect keypoint annotation 507806 with keypoint labels drawn on the crops.
keypoint_annotation_id = 507806
visualizer.load_data(keypoint_annotation_id)
visualizer.display_crops(show_labels=True, overlay_keypoints=True)

In [None]:
# Inspect keypoint annotation 648822 with keypoints overlaid but no labels.
keypoint_annotation_id = 648822
visualizer.load_data(keypoint_annotation_id)
visualizer.display_crops(show_labels=False, overlay_keypoints=True)

In [None]:
# Left-crop keypoints of the row with id 635713, as {keypointType: [x, y]}.
{kp['keypointType']: [kp['xFrame'], kp['yFrame']]
 for kp in df.loc[df.id == 635713, 'keypoints'].iloc[0]['leftCrop']}

In [None]:
# Right-crop keypoints of the row with id 635713, as {keypointType: [x, y]}.
{kp['keypointType']: [kp['xFrame'], kp['yFrame']]
 for kp in df.loc[df.id == 635713, 'keypoints'].iloc[0]['rightCrop']}

In [None]:
# For each row, the mean over BODY_PARTS of (left yFrame - right yFrame), presumably a
# check of vertical alignment between the stereo crops. Progress is printed every 10000 rows.
diffs = []
for row_number, (_, row) in enumerate(df.iterrows()):
    if row_number % 10000 == 0:
        print(row_number)
    left_by_part = {kp['keypointType']: [kp['xFrame'], kp['yFrame']] for kp in row.keypoints['leftCrop']}
    right_by_part = {kp['keypointType']: [kp['xFrame'], kp['yFrame']] for kp in row.keypoints['rightCrop']}
    diffs.append(np.mean([left_by_part[bp][1] - right_by_part[bp][1] for bp in BODY_PARTS]))

In [None]:
df['diffs'] = diffs
df.index = pd.to_datetime(df.captured_at)
# Daily median of the per-row left/right y offset. The `how=` keyword of resample()
# was removed in pandas 1.0 (TypeError); aggregate on the Resampler instead.
df.diffs.resample('D').median()