In [None]:
from research.weight_estimation.akpd_utils.akpd import AKPD
from research.utils.data_access_utils import S3AccessUtils, RDSAccessUtils
import json
import os
import  pandas as pd
import numpy as np
import matplotlib.patches as patches
from tqdm import tqdm
from utils import utils, data_prep, sector
import cv2
from PIL import Image
import matplotlib.pyplot as plt
# Show up to 500 rows and up to 500 characters per cell when displaying DataFrames.
for _option in ('display.max_rows', 'display.max_colwidth'):
    pd.set_option(_option, 500)

In [None]:
import importlib
# Re-import the local `sector` module so edits to it are picked up without a kernel restart.
importlib.reload(sector)

In [None]:
# S3 helper used below to download crop images. Credentials come from the JSON file
# whose path is in the AWS_CREDENTIALS environment variable; the file is closed
# as soon as it has been parsed (the original bare open() leaked the handle).
# NOTE(review): '/root/data' is presumably the local download directory — confirm.
with open(os.environ['AWS_CREDENTIALS']) as credentials_file:
    s3_access_utils = S3AccessUtils('/root/data', json.load(credentials_file))

In [None]:
# Lice categories analysed in this notebook; LICE_BBOX_COLOR[i] is the bbox edge
# colour used when drawing a louse of category LICE_CATEGORY[i].
LICE_CATEGORY = ['ADULT_FEMALE', 'MOVING']
LICE_BBOX_COLOR = ['b', 'r']

# Load the annotation data (a pickled DataFrame).
# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
annotation_data_akpd = pd.read_pickle(
    "annotation_data_akpd_2020_05_27.pkl"
)


In [None]:
# Pens (group_id values, compared as strings) analysed in the per-pen plots below.
pen_ids = ('56', '60', '37', '66', '85', '86')

In [None]:
# (rows, columns) of the loaded annotation table
annotation_data_akpd.shape

In [None]:
# Peek at one row to see the available columns
annotation_data_akpd.head(1)

# Augment annotation_data_akpd with per-sector lice counts and build lice_data (one row per louse)


In [None]:
# Fish-sector identifiers; each becomes a per-sector lice-count column below.
aux_points = [
    getattr(sector, sector_name)
    for sector_name in (
        'DORSAL_BACK',
        'VENTRAL_BACK',
        'DORSAL_MID',
        'VENTRAL_MID',
        'DORSAL_FRONT',
        'VENTRAL_FRONT',
        'HEAD',
    )
]

In [None]:
# Start every sector counter column at zero.
for sector_col in aux_points:
    annotation_data_akpd[sector_col] = 0

In [None]:
 annotation_data_akpd["date"] = annotation_data_akpd.captured_at.apply(lambda x: x.strftime("%m-%d-%Y"))

In [None]:
# Inspect the table after the sector-counter and date columns were added
annotation_data_akpd.head(5)

In [None]:
# Fill the per-sector lice-count columns (aux_points) of annotation_data_akpd and
# build lice_data: one row per louse whose category is in LICE_CATEGORY.
#
# The counters are reset first so that re-running this cell does not double-count.
# Rows are collected in a list and turned into a DataFrame once; DataFrame.append in
# a loop is quadratic and was removed in pandas 2.0.
for sector_col in aux_points:
    annotation_data_akpd[sector_col] = 0

lice_records = []
for idx, sf in tqdm(annotation_data_akpd.iterrows(), total=len(annotation_data_akpd)):
    if not sf['annotation']:
        continue
    for lice in sf['annotation']:
        if lice['category'] not in LICE_CATEGORY:
            continue
        # bbox in crop coordinates; its centre decides which fish sector it belongs to
        x, y, w, h = lice["xCrop"], lice["yCrop"], lice["width"], lice["height"]
        fish_sector = sector.get_sector(np.array([x + w / 2, y + h / 2]), sf["kps"])
        # NOTE(review): .at assumes annotation_data_akpd has a unique index
        annotation_data_akpd.at[idx, fish_sector] += 1

        lice_records.append({
            "group_id": sf.group_id,
            "category": lice['category'],
            "fish_image_url": sf['url_key'],
            "location": lice['liceLocation'],
            "left": x,
            "top": y,
            "width": w,
            "height": h,
            "fish_sector": fish_sector,
            "fish_image_width": sf['crop_metadata']['width'],
            "fish_image_height": sf['crop_metadata']['height'],
            "left_kps": sf["kps"],
            "captured_at": sf.captured_at,
        })

# Explicit columns keep the schema even when no lice were found.
lice_data = pd.DataFrame(lice_records, columns=[
    "group_id", "category", "fish_image_url", "location", "left", "top", "width",
    "height", "fish_sector", "fish_image_width", "fish_image_height", "left_kps",
    "captured_at",
])

In [None]:
# (number of analysed lice, number of lice_data columns)
lice_data.shape

In [None]:
# Spot-check the sector assignment: draw lice rows first_pic .. first_pic + max_num_pic - 1
# of lice_data on a contrast-enhanced crop, showing the bbox labelled with its sector
# and the fish keypoints.
first_pic = 10
max_num_pic = 15
alpha = 3  # Contrast control (1.0-3.0)
beta = 20  # Brightness control (0-100)

for idx, lice in tqdm(lice_data.iloc[first_pic:first_pic + max_num_pic].iterrows()):
    image_f, _, _ = s3_access_utils.download_from_url(lice.fish_image_url)
    # Close the image file once its pixels have been copied into an array.
    with Image.open(image_f) as img:
        adjusted = cv2.convertScaleAbs(np.asarray(img), alpha=alpha, beta=beta)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(adjusted)

    ec = LICE_BBOX_COLOR[LICE_CATEGORY.index(lice.category)]
    rect = patches.Rectangle((lice.left, lice.top), lice.width, lice.height,
                             linewidth=1, edgecolor=ec, facecolor='none')
    ax.add_patch(rect)
    ax.annotate(lice.fish_sector, (lice.left, lice.top), color=ec)

    for kp in lice.left_kps:
        x, y = kp['xCrop'], kp['yCrop']
        ax.scatter(x, y, c='orange', marker='o')
        ax.annotate(kp['keypointType'], (x, y), color='orange')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop


In [None]:
# Lice counts per louse location (rows), split by fish sector (bars).
# (Former title "Closest Key Point" was stale: the grouping is by fish sector.)
ax = (lice_data.groupby('fish_sector').location.value_counts()
      .unstack(0)
      .plot(kind='barh', legend=True, title="Bar Chart of Lice Location by Fish Sector"))
ax.set(xlabel='lice count', ylabel='location');


In [None]:
# NOTE(review): exact duplicate of the previous cell (same data, same title — and the
# bars are fish sectors, not closest key points); remove one of the two cells.
lice_data.groupby('fish_sector').location.value_counts().unstack(0).plot(kind='barh', 
                                                               legend=True,
                                                               title = "Bar Chart of Closest Key Point by Location")


In [None]:

# Lice counts per category (rows), split by fish sector (bars).
# (Former title "Closest Key Point" was stale: the grouping is by fish sector.)
ax = (lice_data.groupby('fish_sector').category.value_counts()
      .unstack(0)
      .plot(kind='barh', legend=True, title="Bar Chart of Lice Category by Fish Sector"))
ax.set(xlabel='lice count', ylabel='category');


In [None]:
# Lice counts per fish sector (rows), one bar colour per lice category.
# Colours are looked up by category name instead of by column position, so each
# category keeps its colour even if another category is absent from the data.
sector_counts = lice_data.groupby("category").fish_sector.value_counts().unstack(0)
ax = sector_counts.plot(kind='barh',
                        legend=True,
                        color=[LICE_BBOX_COLOR[LICE_CATEGORY.index(c)] for c in sector_counts.columns],
                        title="Bar Chart by sector")
ax.set(xlabel='lice count', ylabel='fish sector');


In [None]:
# One panel per pen: lice counts per fish sector, one bar colour per lice category.
# The grid is sized from pen_ids (it was hard-coded to 3x2).
n_cols = 2
n_rows = -(-len(pen_ids) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False)

for idx, pen_id in enumerate(pen_ids):
    ax = axes[idx // n_cols, idx % n_cols]
    title = "Bar Chart of lice count by sector, pen_id {}".format(pen_id)
    pen_lice = lice_data.loc[lice_data['group_id'] == pen_id]
    if pen_lice.empty:
        # Nothing to plot for this pen; label the panel instead of crashing.
        ax.set_title(title)
        ax.text(0.5, 0.5, 'no lice', ha='center', va='center', transform=ax.transAxes)
        continue
    pen_counts = pen_lice.groupby("category").fish_sector.value_counts().unstack(0)
    # Look colours up by category name: a single pen may lack one of the categories,
    # in which case positional colours would be assigned to the wrong category.
    pen_counts.plot(kind='barh',
                    legend=True,
                    ax=ax,
                    color=[LICE_BBOX_COLOR[LICE_CATEGORY.index(c)] for c in pen_counts.columns],
                    title=title)

# Hide leftover panels when the number of pens is odd.
for ax in axes.flat[len(pen_ids):]:
    ax.set_visible(False)
plt.tight_layout()
    
    
    
    

In [None]:
# Lice bbox `left` (xCrop) coordinate against capture time.
# plt.scatter takes x/y positionally, so passing the frame first plus x=... raised a
# TypeError; use the `data=` keyword instead.
# NOTE(review): plots lice_data because annotation_data_akpd has no `left` column — confirm intent.
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter("captured_at", "left", data=lice_data, color="r")
ax.set(xlabel="captured_at", ylabel="lice bbox left (xCrop)",
       title="Lice bbox left coordinate over time");

In [None]:

from datetime import datetime



In [None]:

# Sanity check: every analysed louse (category in LICE_CATEGORY) is counted in exactly
# one sector column by the counting loop, so per-image totals must agree.
# (Replaces an unterminated row-wise apply that summed only two of the sectors and
# compared against all annotations, including non-analysed categories.)
n_analysed_lice = annotation_data_akpd.annotation.apply(
    lambda ann: sum(lice['category'] in LICE_CATEGORY for lice in ann) if ann else 0
)
sector_totals = annotation_data_akpd[aux_points].sum(axis=1)
sector_counts_match = n_analysed_lice.eq(sector_totals)
assert sector_counts_match.all(), \
    "{} images have inconsistent sector counts".format((~sector_counts_match).sum())
sector_counts_match.value_counts()

In [None]:
# Visual check of the sector boundaries: for the first max_num_image annotated,
# left-facing fish of pen inspect_pen_id, draw the keypoints (orange) and the
# sector dividing lines (black) on a contrast-enhanced crop.
inspect_pen_id = "56"
max_num_image = 15
show_lice = False  # True: also overlay lice bboxes labelled with their assigned sector
alpha = 3  # Contrast control (1.0-3.0)
beta = 20  # Brightness control (0-100)

num_image = 0
pen_fish = annotation_data_akpd.loc[annotation_data_akpd['group_id'] == inspect_pen_id]
for idx, sf in tqdm(pen_fish.iterrows(), total=len(pen_fish)):
    if not sf.annotation or not sector.face_left(sf.kps):
        continue
    num_image += 1
    if num_image > max_num_image:
        break

    image_f, _, _ = s3_access_utils.download_from_url(sf["url_key"])
    # Close the image file once its pixels have been copied into an array.
    with Image.open(image_f) as img:
        adjusted = cv2.convertScaleAbs(np.asarray(img), alpha=alpha, beta=beta)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(adjusted)

    kps = sf.kps
    for kp in kps:
        kx, ky = kp['xCrop'], kp['yCrop']
        ax.scatter(kx, ky, c='orange', marker='o')
        ax.annotate(kp['keypointType'], (kx, ky), color='orange')

    tn = sector.get_kp_location(kps, "TAIL_NOTCH")
    ad_fin = sector.get_kp_location(kps, "ADIPOSE_FIN")
    an_fin = sector.get_kp_location(kps, "ANAL_FIN")
    aux_kps = sector.get_auxiliary_kps(kps)

    # Sector dividing lines, each drawn as a segment between two points.
    segments = [
        (aux_kps["ad_an_mid"], tn),
        (aux_kps["ad_an_mid"], aux_kps["ds_pv_mid"]),
        (aux_kps["ds_pv_mid"], aux_kps["h_mid"]),
        (ad_fin, an_fin),
        (aux_kps["pv_back"], aux_kps["ds_back"]),
        (aux_kps["h0"], aux_kps["h1"]),
    ]
    for start, end in segments:
        ax.plot([start[0], end[0]], [start[1], end[1]], 'k-')

    if show_lice:
        for lice in sf['annotation']:
            if lice['category'] not in LICE_CATEGORY:
                continue
            x, y, w, h = lice["xCrop"], lice["yCrop"], lice["width"], lice["height"]
            ec = LICE_BBOX_COLOR[LICE_CATEGORY.index(lice['category'])]
            ax.add_patch(patches.Rectangle((x, y), w, h, linewidth=1, edgecolor=ec, facecolor='none'))
            # Same bbox-centre rule as the counting loop.
            sector_label = sector.get_sector(np.array([x + w / 2, y + h / 2]), kps)
            ax.annotate(sector_label, (x, y), color=ec)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

In [None]:
# Same view as the earlier head(1), now with the sector counters filled in by the lice loop
annotation_data_akpd.head(1)

In [None]:
# Per (pen, day): number of images (non-null url_key) and non-null annotations.
annotation_data_akpd.groupby(["group_id", "date"])[["url_key", "annotation"]].count().head(10)

In [None]:
# One figure per pen: daily number of images and of non-null annotations.
for pen_id in pen_ids:
    (annotation_data_akpd
     .loc[annotation_data_akpd['group_id'] == pen_id]
     .groupby('date')
     .count()
     .loc[:, ['url_key', 'annotation']]
     .plot(figsize=(8, 5),
           legend=True,
           title="number of accepted images, pen_id = {}".format(pen_id)))

In [None]:
# Per pen, two panels: the daily average lice count per fish in each sector (left)
# and daily fish counts (right). The grid is sized from pen_ids (was hard-coded to 6).
n_pens = len(pen_ids)
fig, axes = plt.subplots(n_pens, 2, figsize=(12, 6 * n_pens), squeeze=False)

for idx, pen_id in enumerate(pen_ids):
    ax_lice, ax_fish = axes[idx, 0], axes[idx, 1]
    pen_fish = annotation_data_akpd[annotation_data_akpd.group_id == pen_id]
    if pen_fish.empty:
        ax_lice.set_title("no data, pen_id = {}".format(pen_id))
        ax_fish.set_visible(False)
        continue
    daily = pen_fish.groupby('date')

    # Average only the sector-count columns: a bare GroupBy.mean() also tries to
    # average non-numeric columns (url_key, kps, ...) and raises TypeError on
    # pandas >= 2.0, or silently mixes in unrelated numeric columns on older pandas.
    daily[aux_points].mean().plot(ax=ax_lice,
                                  title="trend of avg number lice by sector, pen_id = {}".format(pen_id))
    ax_lice.set(ylabel='avg lice count')

    # NOTE(review): count() counts non-null entries, so 'fish with lice' is only accurate
    # if lice-free fish have a null (not empty-list) annotation — confirm.
    daily[['url_key', 'annotation']].count().plot(ax=ax_fish,
                                                 legend=True,
                                                 color=['black', 'gray'],
                                                 title="fish count, pen_id = {}".format(pen_id))
    ax_fish.set(ylabel='fish count')
    ax_fish.legend(['fish (QA accept)', 'fish with lice'])
plt.tight_layout()

                                                                                                              
                                                                                                              
                                                                                                              