In [None]:
from research.utils.data_access_utils import RDSAccessUtils, S3AccessUtils
import json
import os
from utils import utils, data_prep, visualize
import matplotlib.pyplot as plt

from tqdm import tqdm
import  pandas as pd
from random import seed


from PIL import Image
import numpy as np
import matplotlib.patches as patches
import cv2


In [None]:
# Reproducibility seed and crop window size (image pixels).
SEED = 33
CROP_WIDTH, CROP_HEIGHT = 512, 512

# Parallel lists: boxes for category LICE_CATEGORY[i] are drawn with
# matplotlib edge color LICE_BBOX_COLOR[i].
LICE_CATEGORY = ['ADULT_FEMALE', 'MOVING']
LICE_BBOX_COLOR = ['b', 'r']

In [None]:
# Data-warehouse credentials. Under docker-compose the file path is supplied via the
# DATA_WAREHOUSE_SQL_CREDENTIALS environment variable; when that variable is unset
# (e.g. a plain `docker run`), fall back to the mounted file path.
SQL_CREDENTIALS_PATH = os.environ.get(
    "DATA_WAREHOUSE_SQL_CREDENTIALS",
    "/root/jane/cv_research/jane/deploy/data_warehouse_sql_credentials.json",
)
# Context manager so the credentials file handle is closed promptly.
with open(SQL_CREDENTIALS_PATH) as credentials_file:
    credentials = json.load(credentials_file)

rds_access_utils = RDSAccessUtils(credentials)

# NOTE(review): later cells pass `s3_access_utils` to visualize.Visualizer, but it is
# never assigned anywhere in this notebook (the constructor call below is commented
# out), so those cells raise NameError on a fresh kernel. Re-enable once
# S3AccessUtils works in this environment — TODO confirm why it was disabled.
#s3_access_utils = S3AccessUtils('/root/data')

In [None]:
# Fetch group_id, left-crop metadata, left-crop URL and the annotation payload from
# prod.crop_annotation for groups '56', '65' and '37', captured between 2020-01-01
# and 2020-02-01, restricted to annotation_state_id 7 and service_id 1.
# NOTE(review): the meaning of annotation_state_id = 7 and service_id = 1 is not
# visible here — confirm against the schema before reusing this query.
# Later cells call .iterrows()/.head() on the result, so it is assumed to be a
# pandas DataFrame.
get_annotation_data = """
    SELECT 
        group_id,
        left_crop_metadata,
        left_crop_url,
        annotation 
    FROM 
        prod.crop_annotation 
    WHERE 
        (captured_at BETWEEN '2020-01-01' AND '2020-02-01') AND 
        (group_id IN ('56', '65', '37')) AND
        (annotation_state_id IN (7)) AND
        (service_id = 1);
"""
annotation_data = rds_access_utils.extract_from_database(get_annotation_data)

In [None]:
# Sort each image's lice left-to-right (ties broken top-to-bottom) and store the
# result back into the DataFrame.
#
# The previous version assigned to the row yielded by DataFrame.iterrows(), which is
# a copy of the row, so the sorted list was silently discarded. Mapping over the
# column and assigning the column back makes the change persist.
def _sort_lice_by_position(annotation):
    """Return `annotation` sorted by (position['left'], position['top']).

    Falsy annotations (e.g. None or an empty list) are returned unchanged.
    """
    if not annotation:
        return annotation
    return sorted(
        annotation,
        key=lambda lice: (lice['position']['left'], lice['position']['top']),
    )


annotation_data['annotation'] = annotation_data['annotation'].map(_sort_lice_by_position)


In [None]:
# Preview the first rows of the query result (head's default is 5 rows).
annotation_data.head(n=5)

In [None]:
# Debug peek: for non-empty rows up to index label 100, print a separator line
# followed by the position dict of every louse in that row.
for row_label, lice_list in tqdm(annotation_data['annotation'].items()):
    if not lice_list:
        continue
    if row_label > 100:
        break
    print("_____")
    for louse in lice_list:
        print(louse['position'])

In [None]:
# Visual sanity check of data_prep.generate_crops_smart on rows with index label
# <= MAX_PREVIEW_IDX: draws every proposed crop and the lice assigned to it, then
# prints
#   * the mean number of crops per image, and
#   * the mean of 1 / (number of lice in the crop) over all non-empty crops.
seed(SEED)
import importlib
importlib.reload(visualize)
importlib.reload(data_prep)

MAX_PREVIEW_IDX = 50

# NOTE(review): `s3_access_utils` is never assigned in this notebook (its
# constructor is commented out), so Visualizer(...) raises NameError on a fresh
# kernel — TODO define it before running this cell.
crop_len, lice_len, cp_avg = [], [], []
for idx, sf in tqdm(annotation_data.iterrows()):
    if not sf['annotation']:
        continue
    if idx > MAX_PREVIEW_IDX:
        break
    visualizer = visualize.Visualizer(s3_access_utils, rds_access_utils)

    visualizer.load_image(sf)
    iw = sf['left_crop_metadata']['width']
    ih = sf['left_crop_metadata']['height']
    crops = data_prep.generate_crops_smart(sf["annotation"], [iw, ih], [CROP_WIDTH, CROP_HEIGHT])
    crop_len.append(len(crops))
    for crop in crops:
        crop_lice = crops[crop]
        # Skip empty crops so the ratio cannot divide by zero.
        if crop_lice:
            cp_avg.append(1 / len(crop_lice))
        crop_left, crop_top = crop
        visualizer.display_crop(crop_left, crop_top, CROP_WIDTH, CROP_HEIGHT, "TOP")
        for lice in crop_lice:
            visualizer.display_lice(lice, lice['position']['left'], lice['position']['top'])
    visualizer.show()

# Guard against empty lists (e.g. no annotated rows in range).
if crop_len:
    print(sum(crop_len) / len(crop_len))
if cp_avg:
    print(sum(cp_avg) / len(cp_avg))

In [None]:
# Local S3 download helpers built on a plain boto3 client (used in this notebook in
# place of S3AccessUtils, whose construction is commented out).
import boto3
from urllib.parse import urlparse

AWS_CREDENTIALS_PATH = '/root/jane/cv_research/jane/deploy/aws_credentials.json'
# Objects downloaded without a custom location are cached under <S3_CACHE_DIR>/<bucket>/<key>.
S3_CACHE_DIR = os.path.join('/root/data', 's3')

# Context manager so the credentials file handle is closed promptly.
with open(AWS_CREDENTIALS_PATH) as aws_credentials_file:
    aws_credentials = json.load(aws_credentials_file)
s3_client = boto3.client(
    's3',
    aws_access_key_id=aws_credentials["aws_access_key_id"],
    aws_secret_access_key=aws_credentials["aws_secret_access_key"],
    region_name="eu-west-1",
)


def recursive_mkdir(dirname):
    """Create `dirname` and any missing parents; no-op if it exists or is empty."""
    # An empty dirname means "current directory" (e.g. a bare file name); nothing to create.
    if dirname:
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(dirname, exist_ok=True)


def download_from_s3(bucket, key, custom_location=None):
    """Download s3://<bucket>/<key> to local disk and return the local path.

    With `custom_location`, the object is always (re)downloaded to that path.
    Otherwise it is cached under S3_CACHE_DIR/<bucket>/<key> and only fetched when
    that file does not exist yet.
    """
    if custom_location:
        recursive_mkdir(os.path.dirname(custom_location))
        s3_client.download_file(bucket, key, custom_location)
        # Previously this branch returned None; return the path like the cached branch.
        return custom_location
    local_path = os.path.join(S3_CACHE_DIR, bucket, key)
    if not os.path.exists(local_path):
        recursive_mkdir(os.path.dirname(local_path))
        s3_client.download_file(bucket, key, local_path)
    return local_path


def _parse_s3_url(url):
    """Split an S3 HTTP(S) URL into (bucket, key).

    Hosts starting with 's3' are treated as path-style (bucket is the first path
    segment); any other host is treated as virtual-hosted (bucket is the first
    label of the host name).

    Raises:
        ValueError: if no bucket or key can be extracted from `url`.
    """
    parsed_url = urlparse(url, allow_fragments=False)
    if parsed_url.netloc.startswith('s3'):
        # partition keeps the key exactly as in the URL (S3 keys may contain '//').
        bucket, _, key = parsed_url.path.lstrip('/').partition('/')
    else:
        bucket = parsed_url.netloc.split('.')[0]
        key = parsed_url.path.lstrip('/')
    if not bucket or not key:
        raise ValueError("Cannot parse S3 bucket/key from URL: {!r}".format(url))
    return bucket, key


def download_from_url(url):
    """Download the S3 object referenced by `url`; return (local_path, bucket, key)."""
    bucket, key = _parse_s3_url(url)
    image_f = download_from_s3(bucket, key)
    return image_f, bucket, key

In [None]:

# Grid preview (nrows x 2 subplots) of contrast-enhanced left crops for annotated rows
# with index label >= START_IDX. Each smart crop is outlined in yellow, and the lice
# boxes it contains are colored by category via LICE_CATEGORY / LICE_BBOX_COLOR.
import importlib
importlib.reload(data_prep)

nrows = 20
START_IDX = 90
alpha = 3  # Contrast control (1.0-3.0)
beta = 20  # Brightness control (0-100)

figure, axes = plt.subplots(nrows=nrows, ncols=2, figsize=(20, nrows * 6))

# num_pic is left in the namespace (and displayed by the next cell); it reaches
# nrows * 2 when the grid fills up.
num_pic = -1
for idx, sf in tqdm(annotation_data.iterrows()):
    if not sf['annotation'] or idx < START_IDX:
        continue
    num_pic += 1
    if num_pic >= nrows * 2:
        break
    ax = axes[num_pic // 2, num_pic % 2]

    left_image_f, bucket, left_image_key = download_from_url(sf["left_crop_url"])
    # Close the file handle once the pixels have been copied into the array.
    with Image.open(left_image_f) as img:
        pixels = np.asarray(img)
    ax.imshow(cv2.convertScaleAbs(pixels, alpha=alpha, beta=beta))

    iw = sf['left_crop_metadata']['width']
    ih = sf['left_crop_metadata']['height']
    crops = data_prep.generate_crops_smart(sf["annotation"], [iw, ih], [CROP_WIDTH, CROP_HEIGHT])

    for crop in crops:
        # `crop` is the (left, top) corner of a CROP_WIDTH x CROP_HEIGHT window.
        ax.add_patch(patches.Rectangle(crop, CROP_WIDTH, CROP_HEIGHT, linewidth=3,
                                       facecolor='none', edgecolor='yellow'))
        for lice in crops[crop]:
            lp = lice['position']
            ec = LICE_BBOX_COLOR[LICE_CATEGORY.index(lice['category'])]
            ax.add_patch(patches.Rectangle((lp["left"], lp["top"]), lp["width"], lp["height"],
                                           linewidth=1, edgecolor=ec, facecolor='none'))
figure.tight_layout()

In [None]:
# Inspect the counter left by the grid-preview cell above: nrows * 2 if the grid
# filled up, otherwise the index of the last image drawn (-1 if none).
num_pic

In [None]:
# Show the smart crops for up to MAX_LARGE_LICE_IMAGES images that contain at least
# one "large" MOVING louse. A louse counts as large if its box is wider than
# LARGE_LICE_MIN_W and taller than LARGE_LICE_MIN_H.
import importlib
importlib.reload(visualize)
importlib.reload(data_prep)

LARGE_LICE_MIN_W, LARGE_LICE_MIN_H = 70, 25
MAX_LARGE_LICE_IMAGES = 30


def _has_large_moving_lice(annotation, min_w=LARGE_LICE_MIN_W, min_h=LARGE_LICE_MIN_H):
    """Return True if any MOVING louse has position width > min_w and height > min_h."""
    return any(
        lice['category'] == 'MOVING'
        and lice['position']['width'] > min_w
        and lice['position']['height'] > min_h
        for lice in annotation
    )


# NOTE(review): `s3_access_utils` is never assigned in this notebook (its
# constructor is commented out), so Visualizer(...) raises NameError on a fresh
# kernel — TODO define it before running this cell.
num_pic = -1
for idx, sf in tqdm(annotation_data.iterrows()):
    if not sf['annotation'] or not _has_large_moving_lice(sf['annotation']):
        continue
    num_pic += 1
    if num_pic >= MAX_LARGE_LICE_IMAGES:
        break
    print(num_pic)

    visualizer = visualize.Visualizer(s3_access_utils, rds_access_utils)
    visualizer.load_image(sf)
    iw = sf['left_crop_metadata']['width']
    ih = sf['left_crop_metadata']['height']
    crops = data_prep.generate_crops_smart(sf["annotation"], [iw, ih], [CROP_WIDTH, CROP_HEIGHT])

    visualizer.show_crops(crops)
    visualizer.show()


In [None]:
# Export: for every annotated row with index label >= EXPORT_START_IDX, write the
# contrast-enhanced left crop to IMAGE_PATH. Also write one label row per smart crop
# to LABEL_PATH: class 0 followed by the crop box as converted by utils.xywh2yolobbox.
import importlib
importlib.reload(visualize)
importlib.reload(data_prep)

LABEL_PATH = 'region_data/region_labels'
IMAGE_PATH = 'region_data/region_images'
# Resume point: rows with a smaller index label are skipped.
# NOTE(review): 1195 looks like a manual restart offset — set to 0 for a full export.
EXPORT_START_IDX = 1195
EXPORT_ALPHA, EXPORT_BETA = 2, 15  # Contrast (1.0-3.0), Brightness (0-100)


def _frame_name_from_url(url, start="left_frame_crop_", end=".jpg"):
    """Return "left_" + the text between `start` and the next `end` in `url`.

    Raises:
        ValueError: if either marker is missing. The previous str.find-based slicing
            silently produced a garbage (and possibly colliding) file name in that case.
    """
    begin = url.find(start)
    if begin == -1:
        raise ValueError("{!r} not found in URL {!r}".format(start, url))
    begin += len(start)
    stop = url.find(end, begin)
    if stop == -1:
        raise ValueError("{!r} not found after {!r} in URL {!r}".format(end, start, url))
    return "left_" + url[begin:stop]


seed(SEED)
for idx, sf in tqdm(annotation_data.iterrows()):
    if not sf['annotation'] or idx < EXPORT_START_IDX:
        continue

    # Use the local boto3 helper (as the grid-preview cell does): `s3_access_utils`
    # is never assigned in this notebook.
    left_image_f, bucket, left_image_key = download_from_url(sf["left_crop_url"])

    iw = sf['left_crop_metadata']['width']
    ih = sf['left_crop_metadata']['height']

    # Close the file handle once the pixels have been copied into the array.
    with Image.open(left_image_f) as pil_img:
        img = cv2.convertScaleAbs(np.asarray(pil_img), alpha=EXPORT_ALPHA, beta=EXPORT_BETA)

    file_name = _frame_name_from_url(sf["left_crop_url"])

    # NOTE(review): `img` comes from PIL (RGB channel order for colour images); if
    # data_prep.write_image uses cv2.imwrite, R and B will be swapped on disk — confirm.
    data_prep.write_image(img, file_name, IMAGE_PATH)

    # Smart crops around the lice.
    crops = data_prep.generate_crops_smart(sf["annotation"], [iw, ih], [CROP_WIDTH, CROP_HEIGHT])

    # One label row per crop: class 0 followed by the crop box in whatever format
    # utils.xywh2yolobbox returns (presumably normalized YOLO cx, cy, w, h — confirm).
    labels = [
        [0] + utils.xywh2yolobbox([crop_left, crop_top, CROP_WIDTH, CROP_HEIGHT], [iw, ih])
        for crop_left, crop_top in crops
    ]
    data_prep.write_labels(labels, file_name, LABEL_PATH)

    if len(crops) > 2:
        print("num of crops {}".format(len(crops)))
    


In [None]:
# Index label of the last row visited by the most recently executed loop. When the
# cells run in order, that loop is the export cell above; this helps pick a resume
# offset.
idx
