In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../q1_o2kr2_dataset_annotations/')

from collections import defaultdict
import json
import os
import cv2
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pandas as pd
from research_lib.utils.data_access_utils import S3AccessUtils
from research.utils.data_access_utils import RDSAccessUtils
import uuid
from construct_fish_detection_dataset_o2kr2 import establish_plali_connection, insert_into_plali
from rectification import rectify
from weight_estimation.weight_estimator import WeightEstimator, CameraMetadata

from PIL import Image, ImageDraw, ImageFont


<h1> Get raw images from toy fish experiment </h1>

In [None]:
# Shared S3 helper used by every listing / download / upload cell below.
# NOTE(review): '/root/data' is presumably the local cache root for downloads
# (S3_DIR further down is '/root/data/s3') — confirm against S3AccessUtils.
s3 = S3AccessUtils('/root/data')

In [None]:
# List every object under the experiment's hour prefix whose key ends in 'frame.jpg'.
prefix = 'environment=production/site-id=55/pen-id=97/date=2021-02-24/hour=13'
suffixes = ['frame.jpg']
keygen = s3.get_matching_s3_keys('aquabyte-images-raw', prefix, suffixes=suffixes)

# Drain the key iterator into a list so later cells can iterate it repeatedly.
keys = list(keygen)


In [None]:
# Group the listed keys into (left, right) stereo pairs. The directory part of
# the key is used as the grouping key, so both frames of a pair must share it.
image_pair_dict = defaultdict(dict)
for key in keys:
    dirname = os.path.dirname(key)
    if 'left' in key:
        image_pair_dict[dirname]['left'] = key
    elif 'right' in key:
        image_pair_dict[dirname]['right'] = key
    else:
        raise Exception('Key not valid')

image_pairs = []
for dirname in sorted(image_pair_dict):
    # Deliberately not named `keys`: rebinding `keys` here would replace the
    # key list with a dict, and re-running this cell would then silently
    # iterate over 'left'/'right' instead of the S3 keys.
    pair = image_pair_dict[dirname]
    try:
        image_pairs.append((pair['left'], pair['right']))
    except KeyError as err:
        # One side of the pair is missing: report which one and skip the pair.
        print(err)
        
    

<h1> Check frames </h1>

In [None]:
def download_image(image_url):
    """Download the object referenced by an ``s3://bucket/key`` URL.

    The first path component after the ``s3://`` scheme is taken as the
    bucket and the remaining components are re-joined into the object key.

    Args:
        image_url: URL of the form ``s3://<bucket>/<key...>``.

    Returns:
        Whatever ``s3.download_from_s3(bucket, key)`` returns (callers in this
        notebook pass it to ``cv2.imread`` as a file path).

    Raises:
        TypeError: if the URL has no key component after the bucket.
    """
    bucket, *key_parts = image_url.replace('s3://', '').split('/')
    return s3.download_from_s3(bucket, os.path.join(*key_parts))


def plot_stereo_image(left_image_f, right_image_f):
    """Show a stereo pair side by side in a 1x2 matplotlib figure.

    Args:
        left_image_f: path of the left image file (first panel).
        right_image_f: path of the right image file (second panel).
    """
    fig, axes = plt.subplots(1, 2)

    for ax, image_f in zip(axes, (left_image_f, right_image_f)):
        bgr_im = cv2.imread(image_f)
        # convert the channel order from BGR to RGB before handing it to matplotlib
        ax.imshow(cv2.cvtColor(bgr_im, cv2.COLOR_BGR2RGB))

    plt.show()
    


In [None]:
# Eyeball every stereo pair using its '.resize_512_512' thumbnails.
idx = 0
for left_key, right_key in image_pairs:
    print(idx)
    idx += 1

    # URLs of the full-resolution frames
    left_full_res_frame_s3_url = os.path.join('s3://', 'aquabyte-images-raw', left_key)
    right_full_res_frame_s3_url = os.path.join('s3://', 'aquabyte-images-raw', right_key)

    # thumbnail URLs are derived from the full-resolution ones by swapping the suffix
    left_frame_s3_url = left_full_res_frame_s3_url.replace('.jpg', '.resize_512_512.jpg')
    right_frame_s3_url = right_full_res_frame_s3_url.replace('.jpg', '.resize_512_512.jpg')

    # download both thumbnails, then plot them side by side
    left_image_f = download_image(left_frame_s3_url)
    right_image_f = download_image(right_frame_s3_url)
    plot_stereo_image(left_image_f, right_image_f)
    

In [None]:
# Show the value returned by download_image for the last left thumbnail
# fetched by the loop above.
left_image_f

<h1> Rectify raw images and upload to s3 </h1>

In [None]:
def download_from_s3_url(s3_url):
    """Download an ``s3://bucket/key`` object and also report where it came from.

    Args:
        s3_url: URL of the form ``s3://<bucket>/<key...>``.

    Returns:
        A 3-tuple ``(f, bucket, key)`` where ``f`` is the result of
        ``s3.download_from_s3(bucket, key)``, ``bucket`` is the first path
        component of the URL and ``key`` is the remaining components re-joined.

    Raises:
        TypeError: if the URL has no key component after the bucket.
    """
    bucket, *key_parts = s3_url.replace('s3://', '').split('/')
    key = os.path.join(*key_parts)
    return s3.download_from_s3(bucket, key), bucket, key

In [None]:
# Rectify every full-resolution stereo pair and upload the result next to the
# raw frames, collecting the S3 URLs of the rectified images as we go.
left_image_rectified_s3_urls, right_image_rectified_s3_urls = [], []
stereo_parameters_url = 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12%3A11%3A24.770071000Z_L40020313_R40013177_stereo-parameters.json'.replace('%3A', ':')
# The stereo parameters are the same for every pair, so fetch them once rather
# than once per iteration.
stereo_parameters_f, _, _ = s3.download_from_url(stereo_parameters_url)
count = 0

for left_key, right_key in image_pairs:

    # get unrectified full resolution frames
    left_full_res_frame_s3_url, right_full_res_frame_s3_url = [os.path.join('s3://', 'aquabyte-images-raw', key) for key in (left_key, right_key)]
    left_full_res_frame_f, _, left_full_res_frame_key = download_from_s3_url(left_full_res_frame_s3_url)
    right_full_res_frame_f, _, right_full_res_frame_key = download_from_s3_url(right_full_res_frame_s3_url)

    # rectify into full resolution stereo frame pair and save to disk
    left_image_rectified, right_image_rectified = rectify(left_full_res_frame_f, right_full_res_frame_f, stereo_parameters_f)
    left_image_rectified_f = os.path.join(os.path.dirname(left_full_res_frame_f), 'left_frame.rectified.jpg')
    right_image_rectified_f = os.path.join(os.path.dirname(right_full_res_frame_f), 'right_frame.rectified.jpg')
    # cv2.imwrite reports failure through its boolean return value instead of
    # raising, so check it before uploading a missing or stale file.
    if not cv2.imwrite(left_image_rectified_f, left_image_rectified):
        raise IOError('Could not write rectified image to {}'.format(left_image_rectified_f))
    if not cv2.imwrite(right_image_rectified_f, right_image_rectified):
        raise IOError('Could not write rectified image to {}'.format(right_image_rectified_f))

    # upload rectified stereo frame pairs to s3
    left_rectified_full_res_frame_key = left_full_res_frame_key.replace('.jpg', '.rectified.jpg')
    right_rectified_full_res_frame_key = right_full_res_frame_key.replace('.jpg', '.rectified.jpg')
    s3.s3_client.upload_file(left_image_rectified_f, 'aquabyte-images-raw', left_rectified_full_res_frame_key)
    s3.s3_client.upload_file(right_image_rectified_f, 'aquabyte-images-raw', right_rectified_full_res_frame_key)

    # append to url lists
    left_image_rectified_s3_url = os.path.join('s3://', 'aquabyte-images-raw', left_rectified_full_res_frame_key)
    right_image_rectified_s3_url = os.path.join('s3://', 'aquabyte-images-raw', right_rectified_full_res_frame_key)
    left_image_rectified_s3_urls.append(left_image_rectified_s3_url)
    right_image_rectified_s3_urls.append(right_image_rectified_s3_url)

    # progress counter
    print(count)
    count += 1




<h1> Insert into PLALI for key-point annotation </h1>

In [None]:
def process_into_plali_records(image_url_pairs, workflow_id):
    """Build one PLALI insert record per image URL pair.

    Args:
        image_url_pairs: sized sequence of URL pairs (tuples/lists of URLs).
        workflow_id: PLALI workflow id stored on every record.

    Returns:
        A list of dicts with keys ``id`` (fresh random UUID string),
        ``workflow_id``, ``images`` (the pair as a ``set``), ``metadata``
        (an empty dict) and ``priority`` (position / number of pairs, so it
        grows from 0.0 towards 1.0 along the input order).
    """
    # NOTE(review): a set does not preserve the (left, right) order of the pair,
    # while later cells read row.images[0] / row.images[1] as left / right —
    # confirm how insert_into_plali serialises this value.
    return [
        {
            'id': str(uuid.uuid4()),
            'workflow_id': workflow_id,
            'images': set(image_url_pair),
            'metadata': {},
            'priority': float(position) / len(image_url_pairs),
        }
        for position, image_url_pair in enumerate(image_url_pairs)
    ]


def chunker(seq, size):
    """Lazily split ``seq`` into consecutive slices of at most ``size`` items.

    Args:
        seq: a sliceable, sized sequence.
        size: maximum length of each slice (non-zero step for ``range``).

    Returns:
        A generator yielding ``seq[0:size]``, ``seq[size:2*size]``, ...; the
        last slice may be shorter.
    """
    starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in starts)



In [None]:
# Pair each rectified left URL with the right URL at the same position; both
# lists were appended to in lock-step by the rectification loop.
image_url_pairs = list(zip(left_image_rectified_s3_urls, right_image_rectified_s3_urls))

In [None]:
# PLALI workflow id stored on every record; the annotation query further down
# filters on the same literal value.
WORKFLOW_ID = '00000000-0000-0000-0000-000000000056'
values_to_insert = process_into_plali_records(image_url_pairs, WORKFLOW_ID)



In [None]:
# Point the PLALI helpers at the credentials file, then open a connection.
os.environ['PLALI_SQL_CREDENTIALS'] = '/run/secrets/plali_sql_credentials.json'
engine, sql_metadata = establish_plali_connection()

# Insert the records in batches of n, printing the number of batches done so far.
n = 10
count = 0
for count, chunk in enumerate(chunker(values_to_insert, n), start=1):
    insert_into_plali(chunk, engine, sql_metadata)
    print(count)

In [None]:
# NOTE(review): `annotated_df` is first assigned in the "Parse annotations into
# standard form" cell further down, so this cell raises NameError on a fresh
# kernel (Restart & Run All); it only works after that later cell has been run.
annotated_df.images.iloc[0]

<h1> Calculate weights </h1>

<h2> Parse annotations into standard form </h2>

In [None]:
# Load the PLALI credentials and pull every annotation of this workflow,
# joined with its image record and workflow name.
os.environ['PLALI_SQL_CREDENTIALS'] = '/run/secrets/plali_sql_credentials.json'
# Use a context manager so the credentials file is closed right after parsing
# instead of leaking the open handle.
with open(os.environ['PLALI_SQL_CREDENTIALS']) as credentials_file:
    rds = RDSAccessUtils(json.load(credentials_file))

query = """
    select * from plali.plali_annotations x
    inner join 
    ( select a.id as plali_image_id, a.images, a.metadata, b.id as workflow_id, b.name from plali.plali_images a
    inner join plali.plali_workflows b
    on a.workflow_id = b.id ) y
    on x.plali_image_id = y.plali_image_id
    where workflow_id = '00000000-0000-0000-0000-000000000056';
"""

annotated_df = rds.extract_from_database(query)

In [None]:
class AnnotationFormatError(Exception):
    """Raised when a raw PLALI annotation cannot be converted to the standard form."""


# Number of key-points required on each side for an annotation to be usable.
N_KEYPOINTS_PER_SIDE = 11
ANNOTATION_SIDES = ('leftCrop', 'rightCrop')


def _parse_annotation(raw_ann):
    """Convert one raw PLALI annotation into the standard key-point form.

    Args:
        raw_ann: the value of the ``annotation`` column for one row.

    Returns:
        ``{'leftCrop': [...], 'rightCrop': [...]}`` where every item has the
        keys ``xCrop``, ``yCrop``, ``xFrame``, ``yFrame`` and ``keypointType``.

    Raises:
        AnnotationFormatError: if the annotation is missing, was skipped by the
            annotator, lacks a side or a required field, or does not have
            exactly ``N_KEYPOINTS_PER_SIDE`` key-points on each side.
    """
    try:
        skipped = 'skipReasons' in raw_ann
    except TypeError as err:
        # No annotation payload at all (e.g. a null value): treat it like any
        # other unusable annotation instead of aborting the whole loop.
        raise AnnotationFormatError('annotation is not a mapping') from err
    if skipped:
        raise AnnotationFormatError('annotation was skipped')

    ann = {}
    for side in ANNOTATION_SIDES:
        try:
            raw_items = raw_ann[side]['annotation']['annotations']
        except (KeyError, TypeError) as err:
            raise AnnotationFormatError('missing annotations for {}'.format(side)) from err

        items = []
        for raw_item in raw_items:
            if any(field not in raw_item for field in ('xCrop', 'yCrop', 'category')):
                raise AnnotationFormatError('key-point is missing a required field')
            items.append({
                'xCrop': raw_item['xCrop'],
                'yCrop': raw_item['yCrop'],
                # Frame coordinates are copied from the crop coordinates.
                # NOTE(review): presumably valid because whole frames (not
                # crops) were annotated — confirm.
                'xFrame': raw_item['xCrop'],
                'yFrame': raw_item['yCrop'],
                'keypointType': raw_item['category']
            })
        ann[side] = items

    if any(len(ann[side]) != N_KEYPOINTS_PER_SIDE for side in ANNOTATION_SIDES):
        raise AnnotationFormatError('unexpected number of key-points')

    return ann


# One entry per row of annotated_df: the parsed annotation, or None when the
# raw annotation is unusable.
anns = []
for idx, row in annotated_df.iterrows():
    try:
        anns.append(_parse_annotation(row.annotation))
    except AnnotationFormatError:
        anns.append(None)
    
    
    

In [None]:
# Attach the parsed annotations; `anns` holds exactly one entry per row
# (None for rejected annotations), in row order.
annotated_df['ann'] = anns

<h2> Check annotations / disparity values </h2>

In [None]:
# Sanity check: print, for every usable annotation, the difference between the
# mean x coordinate of the left key-points and that of the right key-points.
for idx, row in annotated_df.iterrows():
    ann = row.ann
    if ann is None:
        continue

    left_mean_x = np.mean([item['xFrame'] for item in ann['leftCrop']])
    right_mean_x = np.mean([item['xFrame'] for item in ann['rightCrop']])
    print(left_mean_x - right_mean_x)


<h2> Compute weights </h2>

In [None]:
# Download the stereo calibration and derive the camera metadata needed by the
# weight estimator.
stereo_parameters_url = 'https://aquabyte-stereo-parameters.s3-eu-west-1.amazonaws.com/L40020313_R40013177/2021-02-25T12%3A11%3A24.770071000Z_L40020313_R40013177_stereo-parameters.json'.replace('%3A', ':')
stereo_parameters_f, _, _ = s3.download_from_url(stereo_parameters_url)

# Context manager closes the parameters file instead of leaking the handle.
with open(stereo_parameters_f) as stereo_parameters_file:
    stereo_params = json.load(stereo_parameters_file)

# First focal-length entry of camera 1, looked up once and reused below.
focal_length_pixel = stereo_params['CameraParameters1']['FocalLength'][0]

camera_metadata = {
    'focalLengthPixel': focal_length_pixel,
    # translation is divided by 1e3 — presumably mm -> m; TODO confirm units
    'baseline': abs(stereo_params['TranslationOfCamera2'][0] / 1e3),
    # 3.45e-6 is presumably the sensor pixel pitch in metres — TODO confirm
    'focalLength': focal_length_pixel * 3.45e-6,
    'pixelCountWidth': 4096,
    'pixelCountHeight': 3000,
    'imageSensorWidth': 0.01412,
    'imageSensorHeight': 0.01035
}

In [None]:
# Download the trained models and predict a weight for every usable annotation.
weight_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/biomass/trained_models/2020-11-27T00-00-00/weight_model_synthetic_data.pb')
kf_model_f, _, _ = s3.download_from_url('https://aquabyte-models.s3-us-west-1.amazonaws.com/k-factor/trained_models/2020-08-08T000000/kf_predictor_v2.pb')
weight_estimator = WeightEstimator(weight_model_f, kf_model_f)

# The camera description is built from constants only, so construct it once
# instead of rebuilding an identical object for every row.
cm = CameraMetadata(
    focal_length=camera_metadata['focalLength'],
    focal_length_pixel=camera_metadata['focalLengthPixel'],
    baseline_m=camera_metadata['baseline'],
    pixel_count_width=camera_metadata['pixelCountWidth'],
    pixel_count_height=camera_metadata['pixelCountHeight'],
    image_sensor_width=camera_metadata['imageSensorWidth'],
    image_sensor_height=camera_metadata['imageSensorHeight']
)

# One entry per row of annotated_df; None where there is no usable annotation.
pred_weights = []

count = 0
for idx, row in annotated_df.iterrows():
    ann = row.ann
    if ann is not None:
        # predict returns a 3-tuple; only the first element (the weight) is kept
        weight, _, _ = weight_estimator.predict(ann, cm)
        pred_weights.append(weight)
    else:
        pred_weights.append(None)

    # progress output every 1000 rows
    if count % 1000 == 0:
        print(count)
    count += 1
    

In [None]:
# Attach the predictions; `pred_weights` holds one entry per row (None where
# the row had no usable annotation), in row order.
annotated_df['weight'] = pred_weights

In [None]:
# Mean of the predicted weights.
annotated_df.weight.mean()

In [None]:
def display_crops(left_image_f, right_image_f, ann, overlay_keypoints=True, show_labels=True):
    """Show the left and right images stacked vertically, optionally with key-points.

    Args:
        left_image_f: path of the left image (top panel).
        right_image_f: path of the right image (bottom panel).
        ann: annotation dict with ``leftCrop`` / ``rightCrop`` lists whose items
            carry ``keypointType``, ``xFrame`` and ``yFrame``.
        overlay_keypoints: draw each key-point as a small red dot when True.
        show_labels: additionally write the key-point type next to each dot
            (only has an effect when ``overlay_keypoints`` is True).
    """
    fig, axes = plt.subplots(2, 1, figsize=(20, 20))

    # read both images first, then draw left on top and right below
    images = [plt.imread(image_f) for image_f in (left_image_f, right_image_f)]
    for ax, image in zip(axes, images):
        ax.imshow(image)

    # one {keypointType: [x, y]} mapping per side, in (left, right) order;
    # built even when nothing is overlaid, so a malformed `ann` still raises
    keypoints_by_side = [
        {item['keypointType']: [item['xFrame'], item['yFrame']] for item in ann[side]}
        for side in ('leftCrop', 'rightCrop')
    ]

    if overlay_keypoints:
        for ax, keypoints in zip(axes, keypoints_by_side):
            for body_part, (x, y) in keypoints.items():
                ax.scatter([x], [y], color='red', s=1)
                if show_labels:
                    ax.annotate(body_part, (x, y), color='red')
    plt.show()

In [None]:
# Peek at the image URLs stored on the third row of annotated_df.
annotated_df.images.iloc[2]

In [None]:
# Display the first few annotated stereo pairs with their key-points overlaid.
count = 0
for idx, row in annotated_df.iterrows():
    ann = row.ann
    if ann is not None:
        # row.images[0] / [1] are read as the left / right image URLs
        left_image_s3_url = row.images[0]
        right_image_s3_url = row.images[1]

        # drop the scheme and the bucket component to obtain the object keys
        left_image_key = os.path.join(*left_image_s3_url.replace('s3://', '').split('/')[1:])
        right_image_key = os.path.join(*right_image_s3_url.replace('s3://', '').split('/')[1:])

        left_image_f = s3.download_from_s3('aquabyte-images-raw', left_image_key)
        right_image_f = s3.download_from_s3('aquabyte-images-raw', right_image_key)

        display_crops(left_image_f, right_image_f, ann)

        # stop once the counter has passed 10
        if count > 10:
            break
        count += 1

    

<h1> Generate video </h1>

In [None]:
# Path prefix of downloaded thumbnails; stitch_frames swaps it for
# OUTPUT_BASE_DIR when deriving the output path of a stitched frame.
S3_DIR = '/root/data/s3'
# Directory tree that receives the stitched frames.
OUTPUT_BASE_DIR = '/root/data/alok/biomass_estimation/playground/toy_fish_video_second_enclosure'
# Side length, in pixels, of one panel of the stitched canvas (canvas is 2*WIDTH x WIDTH).
WIDTH = 512

def stitch_frames(left_thumbnail_f, right_thumbnail_f, weight):
    """Paste two thumbnails side by side, label them with the weight, and save.

    Args:
        left_thumbnail_f: path of the left thumbnail (pasted at x = 0).
        right_thumbnail_f: path of the right thumbnail (pasted at x = WIDTH).
        weight: value written as ``'<weight> g'`` in the top-left corner.

    Returns:
        Path of the saved stitched image: ``left_thumbnail_f`` with ``S3_DIR``
        replaced by ``OUTPUT_BASE_DIR`` and ``'left_'`` replaced by ``'stereo_'``.
    """
    result = Image.new('RGB', (2 * WIDTH, WIDTH))

    # Open both thumbnails in a context manager so their file handles are
    # closed as soon as the pixels have been pasted onto the canvas.
    with Image.open(left_thumbnail_f) as left_im, Image.open(right_thumbnail_f) as right_im:
        result.paste(im=left_im, box=(0, 0))
        result.paste(im=right_im, box=(WIDTH, 0))

    # write the weight label (white text, default font) on the stitched image
    result_draw = ImageDraw.Draw(result)
    result_draw.text((0, 0), '{} g'.format(str(weight)), (255, 255, 255))

    output_f = left_thumbnail_f.replace(S3_DIR, OUTPUT_BASE_DIR).replace('left_', 'stereo_')
    # exist_ok avoids the check-then-create race of a separate existence test
    os.makedirs(os.path.dirname(output_f), exist_ok=True)
    result.save(output_f)
    return output_f
    

def stitch_frames_into_video(image_fs, video_f):
    """Encode a sequence of image files into an XVID video at 1 frame per second.

    The frame size of the video is taken from the first image.

    Args:
        image_fs: non-empty sequence of image file paths, in playback order.
        video_f: path of the video file to write.

    Raises:
        ValueError: if ``image_fs`` is empty.
        IOError: if an image file cannot be read.
    """
    if len(image_fs) == 0:
        raise ValueError('image_fs is empty; at least one frame is required')

    first_frame = cv2.imread(image_fs[0])
    if first_frame is None:
        # cv2.imread signals failure by returning None rather than raising
        raise IOError('Could not read image {}'.format(image_fs[0]))
    height, width = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    video = cv2.VideoWriter(video_f, fourcc, 1, (width, height), True)
    try:
        for idx, image_f in enumerate(image_fs):
            # progress output every 1000 frames
            if idx % 1000 == 0:
                print(idx)
            im = cv2.imread(image_f, cv2.IMREAD_COLOR)
            if im is None:
                raise IOError('Could not read image {}'.format(image_f))
            video.write(im)
    finally:
        # always finalise the file, even when a frame fails mid-way
        video.release()

In [None]:
# Build one stitched, weight-labelled frame for every row with a usable annotation.
mask = ~annotated_df.ann.isnull()
output_fs = []
for idx, row in annotated_df[mask].iterrows():
    ann = row.ann
    left_image_s3_url, right_image_s3_url = row.images[0], row.images[1]
    weight = round(row.weight, 2)

    # drop the scheme and the bucket component to obtain the object keys
    left_image_key = os.path.join(*left_image_s3_url.replace('s3://', '').split('/')[1:])
    right_image_key = os.path.join(*right_image_s3_url.replace('s3://', '').split('/')[1:])

    # thumbnail keys are derived from the rectified keys by swapping the suffix
    left_thumbnail_key = left_image_key.replace('.rectified.jpg', '.resize_512_512.jpg')
    right_thumbnail_key = right_image_key.replace('.rectified.jpg', '.resize_512_512.jpg')

    left_thumbnail_f = s3.download_from_s3('aquabyte-images-raw', left_thumbnail_key)
    right_thumbnail_f = s3.download_from_s3('aquabyte-images-raw', right_thumbnail_key)

    output_f = stitch_frames(left_thumbnail_f, right_thumbnail_f, weight)
    output_fs.append(output_f)
    

    

In [None]:
# Encode the stitched frames, ordered by file path, into a single video stored
# in the output directory.
stitch_frames_into_video(
    sorted(output_fs),
    os.path.join(OUTPUT_BASE_DIR, 'video_second_enclosure.avi'),
)



In [None]:
# Rows without a usable annotation have no predicted weight; drop them
# explicitly instead of passing missing values to the histogram.
valid_weights = annotated_df.weight.dropna()

plt.figure(figsize=(15, 8))
plt.hist(valid_weights.values, bins=20)
# the video overlay labels these values in grams ('{} g')
plt.xlabel('Predicted weight (g)')
plt.ylabel('Number of annotated stereo pairs')
plt.title('Distribution of predicted weights')
plt.grid()
plt.show()

In [None]:
# Median of the predicted weights.
annotated_df.weight.median()