In [None]:
import argparse
import datetime
import glob
import json
import os
import shutil
import sys
from multiprocessing import Pool
from urllib.parse import urlparse

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw

from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils

In [None]:
# Resized thumbnails downloaded from S3 are square ("resize_512_512"), so height == width.
THUMBNAIL_WIDTH = 512
THUMBNAIL_HEIGHT = THUMBNAIL_WIDTH
# Full-resolution frame size the padding below is expressed in.
PIXEL_COUNT_WIDTH = 4096
PIXEL_COUNT_HEIGHT = 3000
# Padding added around each crop box, in full-resolution pixels ...
X_PADDING_FULLRES = 190
Y_PADDING_FULLRES = 140
# ... rescaled per axis to thumbnail pixels.
X_PADDING = X_PADDING_FULLRES * THUMBNAIL_WIDTH / PIXEL_COUNT_WIDTH
Y_PADDING = Y_PADDING_FULLRES * THUMBNAIL_HEIGHT / PIXEL_COUNT_HEIGHT
# Local prefix stripped from downloaded file paths when building output paths
# (presumably where S3AccessUtils('/root/data') stores downloads -- TODO confirm).
ROOT_DIR = '/root/data/s3'
OUTPUT_BASE_DIR = 'generated_video'

s3_access_utils = S3AccessUtils('/root/data')
# Parse the credentials JSON and close the file instead of leaking the handle.
with open(os.environ['DATA_WAREHOUSE_SQL_CREDENTIALS']) as credentials_file:
    rds_access_utils = RDSAccessUtils(json.load(credentials_file))

In [None]:
def _refresh_directory(dirname):
    if os.path.exists(dirname):
        shutil.rmtree(dirname)
    os.makedirs(dirname)

def _get_bucket_key(url):
    parsed_url = urlparse(url, allow_fragments=False)
    if parsed_url.netloc.startswith('s3'):
        url_components = parsed_url.path.lstrip('/').split('/')
        bucket, key = url_components[0], os.path.join(*url_components[1:])
    else:
        bucket = parsed_url.netloc.split('.')[0]
        key = parsed_url.path.lstrip('/')
    return bucket, key


def _read_depth(crop_metadata):
    """Return ``capture.sensors.aquabyte_depth_meters`` from a crops.json dict, or 'Depth not found'."""
    try:
        depth = crop_metadata['capture']['sensors'].get('aquabyte_depth_meters')
    except (KeyError, TypeError, AttributeError):
        # Missing 'capture'/'sensors' sections, or a non-dict value at either level.
        depth = None
    # `.get` yields None for a missing reading; show the placeholder instead of "Depth: Nonem".
    return 'Depth not found' if depth is None else depth


def _padded_box(bbox):
    """Grow an ``[x, y, w, h]`` box by (X_PADDING, Y_PADDING) and clamp it to the thumbnail.

    Assumes ``bbox`` is already in thumbnail pixel coordinates (the padding is rescaled to
    thumbnail pixels) -- TODO confirm against the crops.json producer.
    Returns the two corner points in the form ``ImageDraw.rectangle`` accepts.
    """
    left = max(bbox[0] - X_PADDING, 0)
    top = max(bbox[1] - Y_PADDING, 0)
    right = min(bbox[0] + bbox[2] + X_PADDING, THUMBNAIL_WIDTH)
    # Thumbnails are square (resize_512_512), so the width also bounds the height.
    bottom = min(bbox[1] + bbox[3] + Y_PADDING, THUMBNAIL_WIDTH)
    return [(left, top), (right, bottom)]


def process_s3_key_dir(s3_key_dir, inbound_bucket='aquabyte-frames-resized-inbound'):
    """Render one annotated side-by-side stereo frame for a capture directory and save it as a JPEG.

    Downloads the left/right 512x512 thumbnails and ``crops.json`` under ``s3_key_dir`` from
    ``inbound_bucket``, draws each padded annotation box on its image (``image_id`` 1 -> left,
    2 -> right), pastes the pair side by side, writes the ``at=`` timestamp path component and the
    depth reading in the top-left corner, and saves the result under OUTPUT_BASE_DIR with the file
    name's ``left_`` prefix replaced by ``stereo_``.

    Designed to run inside ``Pool.map``: download failures are printed and the directory is
    skipped (returns None) rather than aborting the whole map.
    """
    try:
        left_f = s3_access_utils.download_from_s3(inbound_bucket, os.path.join(s3_key_dir,
                                                                               'left_frame.resize_512_512.jpg'))
        right_f = s3_access_utils.download_from_s3(inbound_bucket, os.path.join(s3_key_dir,
                                                                                'right_frame.resize_512_512.jpg'))
        crop_metadata_f = s3_access_utils.download_from_s3(inbound_bucket, os.path.join(s3_key_dir,
                                                                                        'crops.json'))
    except Exception as e:
        print(e)
        return

    with open(crop_metadata_f) as metadata_file:
        crop_metadata = json.load(metadata_file)
    depth = _read_depth(crop_metadata)

    # Context managers close the underlying JPEG files once the pixels have been copied out.
    with Image.open(left_f) as left_im, Image.open(right_f) as right_im:
        left_draw = ImageDraw.Draw(left_im)
        right_draw = ImageDraw.Draw(right_im)
        # A frame without annotations simply gets no boxes.
        for ann in crop_metadata.get('annotations') or []:
            box = _padded_box(ann['bbox'])
            if ann['image_id'] == 1:
                left_draw.rectangle(box)
            elif ann['image_id'] == 2:
                right_draw.rectangle(box)

        # Stitch: left thumbnail at x=0, right thumbnail directly to its right.
        result = Image.new('RGB', (2 * THUMBNAIL_WIDTH, THUMBNAIL_WIDTH))
        result.paste(im=left_im, box=(0, 0))
        result.paste(im=right_im, box=(THUMBNAIL_WIDTH, 0))

    # Overlay the capture timestamp (taken from the "at=..." path component) and the depth.
    result_draw = ImageDraw.Draw(result)
    ts = [c for c in left_f.split('/') if c.startswith('at=')][0]
    display_ts = 'UTC Time: {}'.format(ts.replace('at=', ''))
    display_depth = 'Depth: {}m'.format(depth)
    result_draw.text((0, 0), display_ts, (255, 255, 255))
    result_draw.text((0, 10), display_depth, (255, 255, 255))

    # Mirror the download path under OUTPUT_BASE_DIR; rename only the file, not the directories.
    output_dir = os.path.dirname(left_f.replace(ROOT_DIR, OUTPUT_BASE_DIR))
    output_f = os.path.join(output_dir, os.path.basename(left_f).replace('left_', 'stereo_'))
    # Several Pool workers may create the same directory concurrently; exist_ok avoids the
    # exists()/makedirs() FileExistsError race.
    os.makedirs(output_dir, exist_ok=True)
    result.save(output_f)


def stitch_frames_into_video(image_fs, video_f, fps=4):
    """Encode image files, in the given order, into an XVID video at ``video_f``.

    The first image fixes the output resolution. Later images of a different size are resized to
    it, because the writer is opened for one frame size. Images that cannot be read are reported
    and skipped.

    Args:
        image_fs: Ordered list of image file paths.
        video_f: Output video path.
        fps: Frame rate of the output video (default 4, the previously hard-coded rate).

    Raises:
        ValueError: if ``image_fs`` is empty or its first image cannot be read.
    """
    if not image_fs:
        raise ValueError('No frames to stitch into {}'.format(video_f))
    first_frame = cv2.imread(image_fs[0], cv2.IMREAD_COLOR)
    if first_frame is None:
        raise ValueError('Could not read first frame {}'.format(image_fs[0]))
    height, width = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    video = cv2.VideoWriter(video_f, fourcc, fps, (width, height), True)
    try:
        for idx, image_f in enumerate(image_fs):
            if idx % 1000 == 0:
                print(idx)
            frame = cv2.imread(image_f, cv2.IMREAD_COLOR)
            if frame is None:
                print('Skipping unreadable frame: {}'.format(image_f))
                continue
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            video.write(frame)
    finally:
        # Always finalize the container, even if a frame fails part-way through.
        video.release()
    print('Video generation complete!')


def _captured_in_hour_range(key, start_hour, end_hour):
    hour = int([component for component in key.split('/') if component.startswith('hour=')][0].split('=')[-1])
    return start_hour <= hour <= end_hour


def extract_s3_keys(pen_id, date, start_hour, end_hour, inbound_bucket='aquabyte-frames-resized-inbound'):
    """List the capture "directories" for one pen and day within an inclusive hour range.

    One crop-annotation row for the pen/date is queried only to learn the S3 prefix layout:
    everything before the ``date=`` partition. That prefix, pinned to ``date``, is then listed in
    ``inbound_bucket`` for ``capture.json`` objects, filtered by their ``hour=`` partition.

    Args:
        pen_id: Pen identifier; must be convertible to ``int``.
        date: Day as a ``YYYY-MM-DD`` string.
        start_hour, end_hour: Inclusive bounds on the ``hour=`` partition.
        inbound_bucket: Bucket to list.

    Returns:
        Sorted, de-duplicated list of key prefixes (``os.path.dirname`` of each capture.json key).

    Raises:
        ValueError: if ``pen_id``/``date`` are malformed, no crop annotation exists for that day,
            or the crop URL has no ``date=`` component.
    """
    # Both values are interpolated into SQL below, so normalize them first: int() rejects
    # non-numeric input, and the strptime/strftime round trip only lets a real date through.
    pen_id = int(pen_id)
    date = datetime.datetime.strptime(str(date), '%Y-%m-%d').strftime('%Y-%m-%d')

    query = """
        SELECT captured_at, left_crop_url
        FROM prod.crop_annotation ca
        WHERE ca.pen_id={} AND ca.service_id = 2
        AND to_char(ca.captured_at, 'YYYY-MM-DD') IN ('{}')
        LIMIT 1;
    """.format(pen_id, date)

    df = rds_access_utils.extract_from_database(query)
    if df.empty:
        raise ValueError('No crop annotations found for pen_id={} on {}'.format(pen_id, date))
    # The URL is rewritten from "parallel" to "production" before deriving the frame prefix
    # (presumably frames only live under the production environment -- TODO confirm).
    image_url = df.left_crop_url.iloc[0].replace("parallel", "production")
    _, crop_key = _get_bucket_key(image_url)
    # Keep everything up to the date partition and pin it to the requested day.
    s3_folder = os.path.join(crop_key[:crop_key.index('date=')], 'date={}'.format(date))
    generator = s3_access_utils.get_matching_s3_keys(inbound_bucket, s3_folder, suffixes=['capture.json'])
    keys = [k for k in generator if _captured_in_hour_range(k, start_hour, end_hour)]
    return sorted({os.path.dirname(k) for k in keys})


def generate_video(pen_id, date, start_hour, end_hour, upload_to_s3=True, video_bucket='aquabyte-images-adhoc',
                   num_processes=20):
    """Build the annotated stereo video for one pen and day, and optionally upload it to S3.

    Steps: empty OUTPUT_BASE_DIR, list the capture directories for the hour range, render one
    stereo frame per directory in a process pool, stitch all rendered ``stereo`` JPEGs (sorted by
    path) into ``pen_id_<pen>_date_<date>_video.avi``, and upload it to
    ``<video_bucket>/videos/<pen_id>/`` when ``upload_to_s3`` is true.

    If no frame was rendered, a message is printed and nothing is written or uploaded.
    """
    # Start from an empty directory so frames from earlier runs are not globbed into this video.
    _refresh_directory(OUTPUT_BASE_DIR)

    print('Extracting s3 keys...')
    s3_key_dirs = extract_s3_keys(pen_id, date, start_hour, end_hour)
    print('S3 keys extraction complete!')

    print('Generating frames...')
    # The context manager shuts the worker processes down once map() has returned.
    with Pool(num_processes) as pool:
        pool.map(process_s3_key_dir, s3_key_dirs)
    print('Frame generation complete!')

    print('Generating video...')
    image_fs = sorted(
        path for path in glob.glob(os.path.join(OUTPUT_BASE_DIR, '**', '*.jpg'), recursive=True)
        if 'stereo' in path
    )
    if not image_fs:
        print('No frames were generated for pen_id {} on {}; no video created.'.format(pen_id, date))
        return
    video_f = os.path.join(OUTPUT_BASE_DIR, 'pen_id_{}_date_{}_video.avi'.format(str(pen_id), date))
    # stitch_frames_into_video prints its own completion message.
    stitch_frames_into_video(image_fs, video_f)

    if upload_to_s3:
        print('Uploading to S3...')
        # S3 keys always use '/', independent of the local OS path separator.
        video_key = '/'.join(['videos', str(pen_id), os.path.basename(video_f)])
        s3_access_utils.s3_client.upload_file(video_f, video_bucket, video_key)
        print('Upload complete! Result available here: {}/{}'.format(video_bucket, video_key))


def main(argv=None):
    """Command-line entry point: build (and upload) the stereo video for one pen and day.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--pen_id', type=int, help='Pen ID for this video generation', required=True)
    parser.add_argument('--date', type=str, help='Date for this video generation', required=True)
    # Hours come from the "hour=" S3 partition (hour of day, 0-23). The defaults cover the whole
    # day; a default of 24 could never match any partition.
    parser.add_argument('--start_hour', type=int, help='Start hour (full hour included)', required=False, default=0)
    parser.add_argument('--end_hour', type=int, help='End hour (full hour included)', required=False, default=23)
    args = parser.parse_args(argv)
    pen_id, date, start_hour, end_hour = args.pen_id, args.date, args.start_hour, args.end_hour
    if not (0 <= start_hour <= 23 and 0 <= end_hour <= 23):
        parser.error('--start_hour and --end_hour must be between 0 and 23')
    if start_hour > end_hour:
        parser.error('--start_hour must not be greater than --end_hour')
    print(f'Building video for {pen_id} on {date} from {start_hour} to {end_hour} (inclusive)')
    generate_video(pen_id, date, start_hour, end_hour)


# Inside a Jupyter kernel __name__ is also '__main__', but sys.argv holds the kernel's own
# launch arguments, so argparse would exit and abort "Run All". Only run the CLI outside ipykernel.
if __name__ == '__main__' and 'ipykernel' not in sys.modules:
    main()


In [None]:
# Parameters for the interactive run in the cells below.
pen_id, date, start_hour, end_hour, upload_to_s3 = 56, '2020-04-22', 10, 12, False
num_processes = 12
inbound_bucket = 'aquabyte-frames-resized-inbound'
# Destination bucket for the optional upload step (same default as generate_video).
video_bucket = 'aquabyte-images-adhoc'

In [None]:
query = """
    SELECT captured_at, left_crop_url
    FROM prod.crop_annotation ca
    WHERE ca.pen_id={} AND ca.service_id = 2
    AND to_char(ca.captured_at, 'YYYY-MM-DD') IN ('{}')
    LIMIT 1;
""".format(pen_id, date)

df = rds_access_utils.extract_from_database(query)
image_url = df.left_crop_url.iloc[0]
print(image_url)

In [None]:
image_url = image_url.replace("parallel", "production")
print(image_url)

In [None]:
bucket, key = _get_bucket_key(image_url)
# Everything before the date partition is the pen's frame prefix; pin it to the requested day.
s3_folder = os.path.join(key[:key.index('date=')], 'date={}'.format(date))
generator = s3_access_utils.get_matching_s3_keys(inbound_bucket, s3_folder, suffixes=['capture.json'])
# Distinct loop name so the comprehension is not confused with the `key` bound above.
keys = [k for k in generator if _captured_in_hour_range(k, start_hour, end_hour)]
s3_key_dirs = sorted({os.path.dirname(k) for k in keys})
assert s3_key_dirs, 'No capture.json found under {} for hours {}-{}'.format(s3_folder, start_hour, end_hour)

In [None]:
# Disabled: this is the first half of generate_video(). The preceding cells perform the same
# key extraction as extract_s3_keys() step by step, so each intermediate value can be inspected.
# # refresh output directory (i.e. clean out its contents)
# _refresh_directory(OUTPUT_BASE_DIR)

# # extract s3 keys
# print('Extracting s3 keys...')
# s3_key_dirs = extract_s3_keys(pen_id, date, start_hour, end_hour)
# print('S3 keys extraction complete!')

In [None]:
# Start from an empty output directory: the glob below collects every stereo frame under
# OUTPUT_BASE_DIR, so frames left over from earlier runs would otherwise end up in this video.
_refresh_directory(OUTPUT_BASE_DIR)

print('Generating frames...')
# The context manager shuts the worker processes down once map() has returned.
with Pool(num_processes) as pool:
    pool.map(process_s3_key_dir, s3_key_dirs)
print('Frame generation complete!')

print('Generating video...')
image_fs = sorted(
    path for path in glob.glob(os.path.join(OUTPUT_BASE_DIR, '**', '*.jpg'), recursive=True)
    if 'stereo' in path
)
assert image_fs, 'No stereo frames were generated under {}'.format(OUTPUT_BASE_DIR)
video_f = os.path.join(OUTPUT_BASE_DIR, 'pen_id_{}_date_{}_video.avi'.format(str(pen_id), date))
# stitch_frames_into_video prints its own completion message.
stitch_frames_into_video(image_fs, video_f)

if upload_to_s3:
    print('Uploading to S3...')
    # S3 keys always use '/', independent of the local OS path separator.
    video_key = '/'.join(['videos', str(pen_id), os.path.basename(video_f)])
    s3_access_utils.s3_client.upload_file(video_f, video_bucket, video_key)
    print('Upload complete! Result available here: {}/{}'.format(video_bucket, video_key))

In [None]:
# Re-render just the first capture directory, e.g. to eyeball a single stitched frame.
process_s3_key_dir(s3_key_dir=s3_key_dirs[0])

In [None]:
# Display the S3 prefix that is listed for capture.json files (prefix up to the date partition).
os.path.join(key[:key.index('date=')], 'date={}'.format(date))