In [None]:
%load_ext autoreload
%autoreload 2

import json
import pandas as pd
import numpy as np
from keras.models import load_model
from research_lib.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from weight_estimation.dataset import prepare_gtsf_data, compute_akpd_score, generate_akpd_scores
from weight_estimation.train import train, augment, normalize, get_data_split, train_model
from typing import Dict, Tuple


In [None]:
def augment(df: pd.DataFrame, augmentation_config: Dict) -> pd.DataFrame:
    """Oversample rows by re-scaling their keypoints to random depths and adding jitter.

    Rows are binned by ``weight`` into 1000-wide bins over [0, 9000]. Each row
    is replicated ``int(5 * max_bin_count / own_bin_count)`` times, so sparsely
    populated weight bins receive more augmented copies. Weights beyond the
    last edge are clamped to the last bin. A row whose bin is empty in the
    histogram (only possible through that clamping) gets zero copies instead
    of the undefined result of casting ``inf`` to ``int``.

    Args:
        df: Frame with ``weight``, ``keypoints``, ``camera_metadata``,
            ``fish_id`` and ``k_factor`` columns. ``camera_metadata`` must
            expose the camelCase keys read below.
        augmentation_config: Mapping providing ``max_jitter_std``,
            ``min_depth`` and ``max_depth``. Any other key (e.g. ``trials``)
            is ignored by this function.

    Returns:
        A new DataFrame with columns ``annotation``, ``fish_id``, ``weight``,
        ``kf`` and ``camera_metadata``, one row per augmented sample. It has
        no columns at all when nothing was generated.
    """
    # Imported locally because the notebook never imports defaultdict.
    from collections import defaultdict

    # NOTE(review): CameraMetadata, get_left_right_keypoint_arrs,
    # convert_to_world_point_arr and get_ann_from_keypoint_arrs are not
    # imported in the visible cells; they must be in scope before this runs.
    # NOTE(review): this definition shadows the `augment` imported from
    # weight_estimation.train in the first cell.

    counts, _ = np.histogram(df.weight, bins=np.arange(0, 10000, 1000))

    # Only populated bins get a finite, well-defined trial count.
    trial_values = np.zeros(len(counts), dtype=int)
    populated = counts > 0
    if populated.any():
        trial_values[populated] = (5.0 / (counts[populated] / np.max(counts))).astype(int)

    max_jitter_std = augmentation_config['max_jitter_std']
    min_depth = augmentation_config['min_depth']
    max_depth = augmentation_config['max_depth']

    augmented_data = defaultdict(list)
    for _, row in df.iterrows():
        weight = row.weight
        trials = trial_values[min(int(weight / 1000), len(trial_values) - 1)]
        if trials <= 0:
            continue

        camera_metadata = row.camera_metadata
        cm = CameraMetadata(
            focal_length=camera_metadata['focalLength'],
            focal_length_pixel=camera_metadata['focalLengthPixel'],
            baseline_m=camera_metadata['baseline'],
            pixel_count_width=camera_metadata['pixelCountWidth'],
            pixel_count_height=camera_metadata['pixelCountHeight'],
            image_sensor_width=camera_metadata['imageSensorWidth'],
            image_sensor_height=camera_metadata['imageSensorHeight']
        )

        # Per-row work that does not depend on the trial: base keypoint arrays
        # and the depth taken as the median of world-point column 1.
        base_left, base_right = get_left_right_keypoint_arrs(row.keypoints)
        wkps = convert_to_world_point_arr(base_left, base_right, cm)
        original_depth = float(np.median(wkps[:, 1]))

        for _ in range(trials):
            depth = np.random.uniform(min_depth, max_depth)
            scaling_factor = original_depth / depth
            jitter_std = np.random.uniform(0, max_jitter_std)

            # rescale (the multiplication yields fresh arrays, so the base
            # arrays stay untouched for the next trial)
            X_left = base_left * scaling_factor
            X_right = base_right * scaling_factor

            # add jitter to column 0 only
            X_left[:, 0] += np.random.normal(0, jitter_std, X_left.shape[0])
            X_right[:, 0] += np.random.normal(0, jitter_std, X_right.shape[0])

            # reconstruct annotation
            ann = get_ann_from_keypoint_arrs(X_left, X_right)
            augmented_data['annotation'].append(ann)
            augmented_data['fish_id'].append(row.fish_id)
            augmented_data['weight'].append(row.weight)
            augmented_data['kf'].append(row.k_factor)
            augmented_data['camera_metadata'].append(row.camera_metadata)

    augmented_df = pd.DataFrame(augmented_data)
    return augmented_df

In [None]:
# Build the S3 helper; '/root/data' is presumably its local download root — TODO confirm.
s3 = S3AccessUtils('/root/data')
akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'
# download_from_url is unpacked as a 3-tuple; only the first element is kept
# and handed to prepare_gtsf_data below.
akpd_scorer_f, _, _ = s3.download_from_url(akpd_scorer_url)
# Two date windows are prepared with the same scorer file and the same trailing
# numeric arguments (0.5, 1.0) — their meaning is defined by prepare_gtsf_data;
# verify there.
df1 = prepare_gtsf_data('2019-03-01', '2019-09-20', akpd_scorer_f, 0.5, 1.0)

df2 = prepare_gtsf_data('2020-06-01', '2020-08-20', akpd_scorer_f, 0.5, 1.0)
# Stack both windows into the working frame. No ignore_index is passed, so each
# part keeps its own index labels.
df = pd.concat([df1, df2])

In [None]:
# augmentation_config = dict(
#     trials=10,
#     max_jitter_std=10,
#     min_depth=0.5,
#     max_depth=2.5
# )

# augmented_df = augment(df, augmentation_config)

In [None]:
import json
import os
import cv2
import numpy as np
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
from research.weight_estimation.keypoint_utils.body_parts import core_body_parts
from research.utils.image_utils import Picture
from scipy.spatial import Delaunay
from itertools import compress

def in_hull(p, hull):
    """Report which points of ``p`` fall inside the convex hull of ``hull``.

    Args:
        p: Array-like of query points, one point per row.
        hull: Either the hull-defining points (one per row) or an already
            built ``scipy.spatial.Delaunay`` triangulation. Passing a prebuilt
            triangulation avoids re-triangulating when the same hull is
            queried repeatedly.

    Returns:
        Boolean array, ``True`` where ``find_simplex`` located the point in
        some simplex of the triangulation (index ``>= 0``).
    """
    if not isinstance(hull, Delaunay):
        hull = Delaunay(hull)
    return hull.find_simplex(p) >= 0


def apply_convex_hull_filter(kp, des, canonical_kps, bbox):
    """Keep only detected features that lie inside the hull of the canonical keypoints.

    Args:
        kp: Sequence of feature objects exposing a ``pt`` attribute, with
            coordinates local to the crop.
        des: Descriptor array with one row per entry of ``kp``.
        canonical_kps: Mapping whose values are the hull-defining 2-D points.
        bbox: Mapping with ``x_min`` and ``y_min``; this offset is added to
            the feature coordinates before the hull test.

    Returns:
        ``(kept_features, kept_descriptors)``, filtered by the same mask.
    """
    hull_vertices = np.array(list(canonical_kps.values()))
    local_points = np.array([feature.pt for feature in kp]).reshape(-1, 2)
    crop_offset = np.array([bbox['x_min'], bbox['y_min']])

    inside = in_hull(local_points + crop_offset, hull_vertices)

    kept_features = [feature for feature, keep in zip(kp, inside) if keep]
    kept_descriptors = des[inside]
    return kept_features, kept_descriptors


def get_homography_and_matches(sift, left_patch, right_patch,
                               left_kps, right_kps,
                               left_bbox, right_bbox,
                               good_perc=0.7, min_match_count=3):
    """Match features between two patches and estimate a left-to-right homography.

    Args:
        sift: Feature detector exposing ``detectAndCompute(image, mask)``.
        left_patch, right_patch: Image patches handed to the detector.
        left_kps, right_kps: Canonical keypoint mappings used for the
            convex-hull filter of each side.
        left_bbox, right_bbox: Crop boxes (``x_min``/``y_min``) of the patches.
        good_perc: Ratio-test threshold; a match is kept when its distance is
            below ``good_perc`` times the second-best distance.
        min_match_count: Minimum number of good matches before a homography is
            attempted. Values below 4 are raised to 4, because a homography
            cannot be estimated from fewer than four correspondences.

    Returns:
        ``(H, kp1, kp2, good, matches_mask)``:
        * ``(None, kp1, kp2, None, [0])`` when either side has no usable
          descriptors.
        * ``(identity, kp1, kp2, good, [])`` when matching was not possible or
          too few good matches were found.
        * otherwise the result of ``cv2.findHomography``. ``H`` may still be
          ``None`` if the estimation failed; ``matches_mask`` is then the
          returned mask flattened, or ``[]`` if no mask was returned.
    """
    kp1, des1 = sift.detectAndCompute(left_patch, None)
    kp2, des2 = sift.detectAndCompute(right_patch, None)
    try:
        if not (des1.any() and des2.any()):
            return None, kp1, kp2, None, [0]
    except AttributeError:
        # Descriptors are None (no .any()) when nothing was detected.
        print("None type for detectAndComputer descriptor")
        return None, kp1, kp2, None, [0]

    # apply convex hull filter
    kp1, des1 = apply_convex_hull_filter(kp1, des1, left_kps, left_bbox)
    kp2, des2 = apply_convex_hull_filter(kp2, des2, right_kps, right_bbox)

    H, matches_mask = np.eye(3), []
    good = []

    # The hull filter may have removed every feature on one side; there is
    # nothing to match in that case.
    if len(des1) == 0 or len(des2) == 0:
        return H, kp1, kp2, good, matches_mask

    bf = cv2.BFMatcher()
    matches = bf.knnMatch(des1, des2, k=2)

    # The ratio test needs (best, second-best) pairs.
    if len(matches) > 0 and len(matches[0]) != 2:
        print('Aborting: matches list does not contain pairs')
        return H, kp1, kp2, good, matches_mask

    good = [m for m, n in matches if m.distance < good_perc * n.distance]

    # findHomography needs at least four correspondences, whatever the caller
    # asked for.
    required_matches = max(min_match_count, 4)
    if len(good) >= required_matches:
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        # A failed estimation can come back without a mask.
        matches_mask = mask.ravel().tolist() if mask is not None else []
    return H, kp1, kp2, good, matches_mask


def generate_sift_adjustment(bp, left_crop_metadata, left_fish_picture, left_kps, right_crop_metadata,
                             right_fish_picture, right_kps, sift):
    """Re-derive the right-image keypoint of one body part from feature matches.

    A 600x200 crop is taken around the body part in each picture and a
    homography between the crops is estimated. When one is available, the left
    keypoint is mapped through it to replace the right keypoint. The left
    keypoint is always returned unchanged.

    Args:
        bp: Body-part key present in both ``left_kps`` and ``right_kps``.
        left_crop_metadata, right_crop_metadata: Mappings with ``x_coord`` and
            ``y_coord``, added to crop coordinates to produce frame coordinates.
        left_fish_picture, right_fish_picture: Objects exposing
            ``generate_crop_given_center``.
        left_kps, right_kps: Mappings from body part to ``[x, y]`` crop
            coordinates.
        sift: Feature detector forwarded to ``get_homography_and_matches``.

    Returns:
        ``(left_item, right_item, num_matches)``, where ``num_matches`` is the
        sum of the returned match mask.
    """
    def build_item(point, crop_metadata):
        # Annotation record in the same key order as the input annotations.
        return {
            'keypointType': bp,
            'xCrop': point[0],
            'yCrop': point[1],
            'xFrame': crop_metadata['x_coord'] + point[0],
            'yFrame': crop_metadata['y_coord'] + point[1]
        }

    left_kp = left_kps[bp]
    right_kp = right_kps[bp]

    left_crop, left_bbox = left_fish_picture.generate_crop_given_center(left_kp[0], left_kp[1], 600, 200)
    right_crop, right_bbox = right_fish_picture.generate_crop_given_center(right_kp[0], right_kp[1], 600, 200)

    H, _kp1, _kp2, _good, matches_mask = get_homography_and_matches(
        sift, left_crop, right_crop, left_kps, right_kps, left_bbox, right_bbox)
    num_matches = sum(matches_mask)

    if H is not None:
        # Left keypoint in left-crop-local coordinates, shaped (1, 1, 2) for OpenCV.
        local_left = np.array(
            [left_kp[0] - left_bbox['x_min'], left_kp[1] - left_bbox['y_min']]
        ).reshape(-1, 1, 2).astype(float)
        local_right = cv2.perspectiveTransform(local_left, H).squeeze()
        # Back from right-crop-local coordinates to the right image.
        right_kp = [local_right[0] + right_bbox['x_min'], local_right[1] + right_bbox['y_min']]

    return build_item(left_kp, left_crop_metadata), build_item(right_kp, right_crop_metadata), num_matches


def generate_refined_keypoints(ann, left_crop_url, right_crop_url):
    """Produce a new annotation whose right keypoints are refined by feature matching.

    Args:
        ann: Annotation mapping with ``leftCrop`` and ``rightCrop`` lists. Each
            item carries ``keypointType``, ``xCrop``, ``yCrop``, ``xFrame``
            and ``yFrame``.
        left_crop_url, right_crop_url: Image locations passed to ``Picture``.

    Returns:
        ``{'leftCrop': [...], 'rightCrop': [...]}`` with one item per entry of
        ``core_body_parts``, in that order.
    """
    def keypoints_by_type(side):
        return {item['keypointType']: [item['xCrop'], item['yCrop']] for item in ann[side]}

    def crop_origin(side):
        # Frame-minus-crop offset, taken from the first keypoint of that side.
        first = ann[side][0]
        return {
            'x_coord': first['xFrame'] - first['xCrop'],
            'y_coord': first['yFrame'] - first['yCrop']
        }

    left_kps = keypoints_by_type('leftCrop')
    right_kps = keypoints_by_type('rightCrop')
    left_crop_metadata = crop_origin('leftCrop')
    right_crop_metadata = crop_origin('rightCrop')

    left_fish_picture = Picture(image_url=left_crop_url)
    right_fish_picture = Picture(image_url=right_crop_url)
    left_fish_picture.enhance(in_place=True)
    right_fish_picture.enhance(in_place=True)

    # The detector is KAZE, even though downstream parameters call it `sift`.
    detector = cv2.KAZE_create()

    adjustments = [
        generate_sift_adjustment(body_part, left_crop_metadata, left_fish_picture,
                                 left_kps, right_crop_metadata,
                                 right_fish_picture, right_kps, detector)
        for body_part in core_body_parts
    ]
    return {
        'leftCrop': [left_item for left_item, _, _ in adjustments],
        'rightCrop': [right_item for _, right_item, _ in adjustments]
    }

In [None]:
# Display the number of rows in the working frame before the per-row refinement loop.
len(df)

In [None]:
modified_keypoints = []

count = 0

for idx, row in df.iterrows():
    count = count + 1
    
    if count > 10:
        modified_keypoints.append(None)
    if count % 1000 == 0:
        print(count, len(df))

    ann, cm = (row.keypoints), (row.camera_metadata)
    left_crop_url, right_crop_url = row.left_image_url, row.right_image_url
    
    modified_ann = generate_refined_keypoints(ann, left_crop_url, right_crop_url)
    modified_keypoints.append(modified_ann)
    
df['modified_keypoints'] = modified_keypoints

In [None]:
modified_keypoints[30722]

In [None]:
new_modified = []

for i in range(len(modified_keypoints)):
    if i >= 10 and i % 2 == 0:
        pass
    else:
        new_modified.append(modified_keypoints[i])

In [None]:
df['modified_keypoints'] = new_modified

In [None]:
# Checkpoint the frame (including the refined keypoints) to CSV. Dict-valued
# cells are stored as text, which is why a later cell parses them back.
df.to_csv(r'/root/data/alok/biomass_estimation/playground/gtsf_akpr.csv', header = True)

In [None]:
# Reload the checkpoint, replacing the in-memory frame. No index_col is given,
# so the index written above presumably comes back as an extra column — verify.
df = pd.read_csv(r'/root/data/alok/biomass_estimation/playground/gtsf_akpr.csv')


In [None]:
# Inspect the row labelled 0 of the reloaded frame.
df.loc[0]

In [None]:
from research_lib.utils.data_access_utils import S3AccessUtils
# Same scorer download as in the data-loading cell, repeated here; presumably
# so the notebook can be resumed from the CSV checkpoint on a fresh kernel —
# TODO confirm.
s3 = S3AccessUtils('/root/data')
akpd_scorer_url = 'https://aquabyte-models.s3-us-west-1.amazonaws.com/keypoint-detection-scorer/akpd_scorer_model_TF.h5'
# Only the first element of the returned 3-tuple is used later.
akpd_scorer_f, _, _ = s3.download_from_url(akpd_scorer_url)

In [None]:
# The CSV round trip turned the dict-valued columns into text; parse them back.
# NOTE(review): eval() executes whatever the CSV contains. If the file can be
# produced by anyone else, switch to a literal-only parser after confirming
# the stored text contains nothing but plain literals.
parsed_rows = [
    (eval(keypoints_text), eval(camera_text))
    for keypoints_text, camera_text in zip(df['modified_keypoints'], df['camera_metadata'])
]
modified_keypoints = [keypoints for keypoints, _camera in parsed_rows]
camera_metadatas = [camera for _keypoints, camera in parsed_rows]

# The refined keypoints replace the original `keypoints` column from here on.
df['keypoints'] = modified_keypoints
df['camera_metadata'] = camera_metadatas

In [None]:
# Score the frame with the AKPD scorer. `keypoints` was overwritten with the
# refined keypoints in the previous cell, so these are presumably scores of
# the refined annotations — confirm against generate_akpd_scores.
akpd_scores = generate_akpd_scores(df, akpd_scorer_f)

# NOTE(review): this plural column holds the same values as
# `modified_akpd_score` assigned in the next cell. Only the singular name is
# read later, yet both columns end up in the exported CSV.
df['modified_akpd_scores'] = akpd_scores


In [None]:
# Column actually used by the comparison plot.
df['modified_akpd_score'] = akpd_scores

In [None]:
import matplotlib.pyplot as plt
# Compare the original score (x) against the score after keypoint refinement (y).
plt.scatter(df['akpd_score'], df['modified_akpd_score'])

In [None]:
# Export the re-scored frame to a second CSV. The frame's index is written as
# well, since no index argument is passed.
df.to_csv(r'/root/data/alok/biomass_estimation/playground/gtsf_akpr2.csv', header = True)