## Create a lice heatmap from fish masks using image warping

1. Load the masks created by the segmenter
2. Flip the left/right-facing ones so that all fish point the same way
3. Warp the masks onto a common reference mask
4. Warp the lice locations with the same mapping
5. Plot the resulting heatmap

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import json
import os
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
from dipy.viz import regtools
from keras.models import load_model
from PIL import Image, ImageDraw
from sqlalchemy import create_engine
from sqlalchemy import MetaData
from sqlalchemy import Table, select, func, and_, insert, delete, update, or_
from tqdm import tqdm

from unet import jaccard_coef_loss, jaccard_coef_int
from warping_utils import translate_moving, register, bbox_mask


In [None]:
# Read the database credentials; the context manager closes the file handle
# as soon as the JSON has been parsed.
with open("/root/thomas/sqlcredentials.json") as credentials_file:
    sql_credentials = json.load(credentials_file)

# Engine shared by the cells below; the URL is assembled from the
# user / password / host / port / database entries of the credentials file.
sql_engine = create_engine(
    "postgresql://{}:{}@{}:{}/{}".format(sql_credentials["user"], sql_credentials["password"],
                                         sql_credentials["host"], sql_credentials["port"],
                                         sql_credentials["database"]))

1. Load the masks and get the directions

In [None]:
# Load the crop annotations; the context manager closes the file handle after parsing.
with open('/root/thomas/single_fish.json') as crops_file:
    crops = json.load(crops_file)

In [None]:
# Build the S3 key of every annotated crop.
# The 'External ID' is an underscore-separated file name whose fields 1..3 are
# the farm, the pen and a capture timestamp in milliseconds, e.g.
# left_blom-kjeppevikholmen_2_1543838255793_203_0_4054_1132.jpg
keys = []
for crop_entry in crops:
    file_name = crop_entry['External ID']
    name_fields = file_name.split('_')
    farm_id, pen_id, timestamp_ms = name_fields[1:4]
    # The key is partitioned by the UTC calendar date of the capture.
    capture_day = datetime.utcfromtimestamp(int(timestamp_ms) / 1000.0).date()
    keys.append('dev/{}/{}/{}/{}'.format(farm_id, pen_id, capture_day, file_name))

In [None]:
import os 
import shutil

In [None]:
# for folder in tqdm(glob.glob('/root/data/lice-data/crops/blom-kjeppevikholmen/2/*')):
#     penfolder = os.path.join(folder, '2')
#     for file in glob.glob(penfolder + '/*'):
# #         print(file)
#         new_path = '/'.join(file.split('/')[0:8] + file.split('/')[9:])
# #         print(new_path)
#         shutil.copy(file, new_path)
#     shutil.rmtree(penfolder)

In [None]:
# Reflect the two tables from the database and build the query that returns,
# for every crop key collected above, the fish direction, the lice bounding
# boxes and the image dimensions.
metadata = MetaData()
reflect_options = dict(autoload=True, autoload_with=sql_engine)
fd = Table('lati_fish_detections', metadata, **reflect_options)
fdla = Table('lati_fish_detections_lice_annotations_reconciled', metadata, **reflect_options)

# Column order matters: the consuming cell reads the result rows by position.
selected_columns = [
    fd.c.image_key,
    fdla.c.direction,
    fdla.c.lice_bbox_list,
    fd.c.image_width_px,
    fd.c.image_height_px,
]
# Reconciled annotations joined to the detection they belong to.
annotated_detections = fdla.join(fd, fdla.c.lati_fish_detections_id == fd.c.id)

query = (
    select(selected_columns)
    .select_from(annotated_detections)
    .where(fd.c.image_key.in_(keys))
)

In [None]:
# Get the fish direction + lice info, indexed by the crop's file name.
# Row layout follows the SELECT: (image_key, direction, lice_bbox_list,
# image_width_px, image_height_px). If several rows share an image key, the
# last one returned wins.
infodic = {}
# The context manager returns the connection to the pool once all rows are read.
with sql_engine.connect() as connection:
    for result in connection.execute(query):
        key = result[0]
        infodic[os.path.basename(key)] = {
            # Swap only the leading 'dev' prefix for the local crops folder;
            # without the count, a 'dev' substring elsewhere in the key
            # (farm or file name) would be rewritten as well.
            'local_path': key.replace('dev', '/root/data/lice-data/crops', 1),
            'direction': result[1],
            'lice': result[2],
            'width': result[3],
            'height': result[4],
        }

2. main loop

In [None]:
# Target size handed to cv2.resize for every mask and lice map in the main loop.
# NOTE(review): cv2.resize reads its size argument as (width, height), so this
# presumably yields arrays with 256 rows and 768 columns — confirm.
input_shape = (768, 256)

In [None]:
def create_mask(annotation):
    """Rasterize the first 'Salmon' polygon of an annotation into a 0/1 mask.

    Args:
        annotation: one entry of ``crops``; must provide 'External ID' and
            ``['Label']['Salmon'][0]['geometry']`` as a list of points with
            'x' and 'y' entries.

    Returns:
        A numpy array built from an 8-bit ('L') image of the size recorded in
        ``infodic`` for this crop: 1 on and inside the polygon, 0 elsewhere.

    Raises:
        KeyError: if the crop is missing from ``infodic`` or the annotation
            lacks one of the expected entries.
    """
    crop_info = infodic[annotation['External ID']]
    canvas_size = (crop_info['width'], crop_info['height'])

    outline_points = annotation['Label']['Salmon'][0]['geometry']
    vertices = [(point['x'], point['y']) for point in outline_points]

    # Start from an all-zero canvas and paint the polygon (border and interior) with 1.
    canvas = Image.new('L', canvas_size, 0)
    painter = ImageDraw.Draw(canvas)
    painter.polygon(vertices, outline=1, fill=1)
    return np.array(canvas)

In [None]:
# example = infodic['left_blom-kjeppevikholmen_2_1543838255793_203_0_4054_1132.jpg']
# image = cv2.imread(example['local_path'])
# lices = example['lice']
# heatmap = np.zeros_like(image)
# f, ax = plt.subplots(1, figsize=(20, 10))
# ax.imshow(image)
# for l in lices:
#     print(l)
#     position = l['position']
#     x1 = position['top']
#     y1 = position['left']
#     width = position['width']
#     height=position['height']
#     heatmap[x1:x1+height, y1:y1+width, :] = 1
#     rect = Rectangle((y1, x1) ,width, height,
#                      linewidth=2, edgecolor='r', facecolor='none')

#     ax.add_patch(rect)
# plt.show()

In [None]:
# Align every fish mask with a common reference (the "static" mask) and warp
# the matching lice map with the same mapping, so that all lice locations end
# up in one coordinate frame.
lice_maps = []
warped_masks = []
# Reference mask: the first crop that yields a usable mask. Tracking it with a
# sentinel (rather than "index == 0") keeps the loop working when the first
# crop has to be skipped.
static_mask = None

# Wrapping the list itself lets tqdm know the total and show real progress.
for crop in tqdm(crops):
    name = crop['External ID']
    info = infodic.get(name)
    # Crops the query returned no row for, or whose direction is unknown,
    # cannot be oriented consistently and are skipped.
    if info is None or info['direction'] is None:
        continue
    direction = info['direction']

    try:
        mask = create_mask(crop)
    except Exception as err:
        # Best effort: a crop without a usable polygon is skipped, while
        # KeyboardInterrupt / SystemExit still propagate.
        print('Mask does not exist for {}: {!r}'.format(name, err))
        continue

    # Build the lice map before any flipping, then mirror mask and lice map
    # together so the lice stay attached to the fish.
    # NOTE(review): assumes the lice boxes are expressed in the un-flipped
    # crop's pixel coordinates — confirm against bbox_mask.
    lice_map = bbox_mask(mask, info['lice'])
    if direction == 'RIGHT':
        # flipCode=1 mirrors around the vertical axis (left <-> right).
        mask = cv2.flip(mask, 1)
        lice_map = cv2.flip(lice_map, 1)

    # Bring everything to the common working size.
    mask = cv2.resize(mask, input_shape)
    lice_map = cv2.resize(lice_map, input_shape)

    if static_mask is None:
        # The reference needs no warping; its lice map is used as is.
        static_mask = mask
        lice_maps.append(lice_map)
        continue

    # No separate translation step (translate_moving) is applied before registration.
    mapping = register(static_mask, mask)
    # Forward transform: moving mask and its lice map into the reference frame.
    warped_masks.append(mapping.transform(mask, 'linear'))
    lice_maps.append(mapping.transform(lice_map, 'linear'))

In [None]:
# Stack the aligned lice maps along a new trailing axis, then average over
# that axis to get the per-pixel mean across all fish.
heatmap = np.stack(lice_maps, axis=-1)
mean_heatmap = heatmap.mean(axis=-1)

In [None]:
# Show the mean lice heatmap with the reference mask drawn half-transparent
# on top, without axis ticks or frame.
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(mean_heatmap)
ax.imshow(static_mask, alpha=0.5)
ax.axis('off')
plt.show()