In [None]:
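'''
Detects sea-lice candidate points in annotated frames, classifies each
candidate with a saved ORB + SVM model, and saves/plots the frames annotated
with the confident detections.
'''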

import cv2
from PIL import Image, ImageDraw, ImageFont
import os
import numpy as np

import matplotlib.pyplot as plt

import lib.utils as utils
import lib.features as features
import lib.detections as detections

# TrueType font used by PIL's ImageDraw.text to label detections (20 px).
font_file = os.path.join('/root/bryton/aquabyte_sealice', 'Helvetica-Regular.ttf')
font = ImageFont.truetype(font_file, size=20)

In [None]:
# ---- Candidate extraction parameters ---------------------------------------
# Grayscale cut-off used to build the binary mask (cv2.threshold, THRESH_BINARY).
mask_thresh = 100
# The values below are forwarded unchanged to
# features.extract_sealice_candidate_kps; presumably a candidate half-width in
# pixels plus min/max blob area and eccentricity bounds -- TODO confirm.
bbox_halfwidth = 24
area_min, area_max = 100, 200
ecc_min, ecc_max = 0.7, 0.99

# ---- Paths -----------------------------------------------------------------
base_directory = '/root/bryton/aquabyte_sealice'

annotations_file = '%s/annotations.csv' % (base_directory, )
# Alternative (older) model file: models/sealice_detection_ORB_SVM_model.yml
svm_model_filepath = '%s/models/sealice_detection_ORB_SVM_model_20180506-162646.yml' % (base_directory, )
svm_pipeline_output_directory = '%s/svm_pipeline_output' % (base_directory, )

# Create the output directory. os.makedirs has no exist_ok flag on Python 2,
# so an OSError is tolerated only when the directory already exists.
try:
    os.makedirs(svm_pipeline_output_directory)
except OSError:
    if not os.path.isdir(svm_pipeline_output_directory):
        raise

# Load the trained ORB-descriptor SVM classifier from disk.
svm_model = cv2.ml.SVM_load(svm_model_filepath)

In [None]:
# Run the candidate-extraction + ORB/SVM classification pipeline over the first
# annotated frames: draw confident detections, save each annotated frame once,
# and plot up to `max_plotted_frames` processed frames.
orb_descriptor = cv2.ORB_create()

annotations = utils.get_lice_annotations_from_file(annotations_file)[0:20]

# Half-width (pixels) of the box drawn around each classified keypoint.
# NOTE(review): differs from `bbox_halfwidth` given to the extractor -- confirm.
draw_box_halfwidth = 28
# Minimum sigmoid-squashed SVM score for a candidate to be reported.
confidence_thresh = 0.6
# Number of processed frames shown in the figure (one subplot row each).
max_plotted_frames = 10

f, ax = plt.subplots(max_plotted_frames, 1, figsize=(50, 100))

# Index of the last frame that actually went through classification;
# -1 means none yet.
processed_index = -1

for annotation_index, annotation in enumerate(annotations):
    if annotation_index % 10 == 0:
        print('Processing annotation %i of %i' % (annotation_index, len(annotations)))

    # Only the image path is used; the annotated box and label are not needed here.
    image_filename, ann_x1, ann_y1, ann_x2, ann_y2, label = annotation

    # Output file stem: the image file name up to its first '.', independent
    # of how deep the file sits in the directory tree.
    image_stem = os.path.basename(image_filename).split('.')[0]

    image = Image.open(image_filename)
    frame = np.array(image)
    draw = ImageDraw.Draw(image)

    # PIL decodes to RGB channel order, so RGB2GRAY (not BGR2GRAY) is correct.
    gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    _, thresh_mask = cv2.threshold(gray_frame, mask_thresh, 255, cv2.THRESH_BINARY)

    candidate_kps = features.extract_sealice_candidate_kps(
        thresh_mask, bbox_halfwidth, area_min, area_max, ecc_min, ecc_max)

    kps, orb_descriptors = orb_descriptor.compute(frame, candidate_kps)
    # ORB yields None when no descriptor could be computed; skip such frames
    # instead of crashing on len() of a 0-d array. A single descriptor is a
    # valid frame (the previous `> 1` check silently dropped it).
    if orb_descriptors is None or len(orb_descriptors) == 0:
        continue
    orb_descriptors = np.asarray(orb_descriptors, dtype=np.float32)

    processed_index += 1

    # STAT_MODEL_RAW_OUTPUT makes predict() return the raw SVM decision value
    # per sample (shape (N, 1)) instead of a class label.
    _, decision_values = svm_model.predict(orb_descriptors, None, cv2.ml.STAT_MODEL_RAW_OUTPUT)

    raw_lice_detections = []

    for kp, decision_value in zip(kps, decision_values):
        center_x, center_y = kp.pt
        det_x1 = int(center_x - draw_box_halfwidth)
        det_y1 = int(center_y - draw_box_halfwidth)
        det_x2 = int(center_x + draw_box_halfwidth)
        det_y2 = int(center_y + draw_box_halfwidth)

        # Squash the raw decision value into (0, 1) with a logistic sigmoid.
        confidence = float(1.0 / (1.0 + np.exp(-decision_value[0])))
        if confidence <= confidence_thresh:
            continue

        top_left_point = (det_x1, det_y1)
        bottom_right_point = (det_x2, det_y2)
        # The PIL label is drawn 25 px above the box so it does not overlap it.
        top_left_point_elevated = (det_x1, det_y1 - 25)
        confidence_text = '%0.2f%%' % (confidence * 100, )

        # Annotate the numpy frame (shown in the figure) ...
        cv2.putText(frame, confidence_text, top_left_point, cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.rectangle(frame, top_left_point, bottom_right_point, (0, 255, 0), 3)

        # ... and the PIL image (saved to disk); np.array() made a copy above.
        draw.text(top_left_point_elevated, confidence_text, (255, 255, 0), font=font)
        draw.rectangle((top_left_point, bottom_right_point), outline='green')

        raw_lice_detections.append({'x1': det_x1, 'y1': det_y1, 'x2': det_x2, 'y2': det_y2, 'confidence': confidence})

    # Save once per frame, after all confident detections have been drawn.
    if raw_lice_detections:
        output_file = '%s/%s.jpg' % (svm_pipeline_output_directory, image_stem)
        image.save(output_file)

    output = detections.create_fish_detection(None, raw_lice_detections)

    output_json = {
        'fish_detection': output['fish_detection'],
        'lice_detections': output['lice_detections']
    }

    # Alok / Thomas - this is what you can use

    # Only frames that were classified are plotted, each in its own row.
    if processed_index < max_plotted_frames:
        ax[processed_index].imshow(frame)

print('Wait for the images...')