In [None]:
import os
import glob
from location_test_gui_functions import TestLocationsGUI
import cv2

In [None]:
import cv2
import numpy as np
import pandas as pd
import random
import matplotlib.cm as cm



def rescale_image(image, max_size):
    """Rescale an image so that its longest spatial side equals ``max_size``.

    Args:
        image: numpy image array of shape (rows, cols) or (rows, cols, channels),
            e.g. as returned by ``cv2.imread``.
        max_size: target length in pixels of the longest output side.

    Returns:
        (new_image, scale): the resized image and the factor applied to both
        spatial dimensions (original pixel coordinates * scale = display
        coordinates).

    Raises:
        ValueError: if ``image`` is None (e.g. a failed ``cv2.imread``) or has
            an empty spatial dimension.
    """
    if image is None or image.ndim < 2 or min(image.shape[:2]) == 0:
        raise ValueError("rescale_image expects a non-empty 2d or 3d image array")

    rows, cols = image.shape[:2]
    # Only the spatial dimensions determine the scale; the channel count must not.
    scale = max_size / max(rows, cols)
    # Truncate like an int cast, but never collapse a side to 0 px (cv2.resize rejects it).
    new_rows = max(int(rows * scale), 1)
    new_cols = max(int(cols * scale), 1)
    # cv2.resize takes the target size as (width, height) = (cols, rows).
    new_image = cv2.resize(image,
                           (new_cols, new_rows),
                           interpolation=cv2.INTER_NEAREST
                           )

    return new_image, scale


class TestLocationsGUI():
    """OpenCV tool for reviewing and correcting two wingtip annotations per image.

    Each image in ``image_files`` is paired, by position, with a CSV in
    ``csv_files``. Row 0 of each CSV holds the columns ``wingtip1_x``,
    ``wingtip1_y``, ``wingtip2_x``, ``wingtip2_y`` and ``hard``. Coordinates
    are stored in original-image pixels. The display is rescaled so that its
    longest side is ``max_size`` pixels.

    Negative coordinates are sentinels and are never drawn. ``not_present``
    writes -2, and -1 is also treated as "no point".
    """

    #constructor
    def __init__(self, image_files, csv_files, max_size):
        """
        Args:
            image_files: list of bat image paths
            csv_files: list of CSV paths with each frame's bat info, in the same
                order as ``image_files``
            max_size: largest dimension in pixels of output display

        Raises:
            ValueError: if ``image_files`` is empty.
        """
        assert len(image_files) == len(csv_files)
        if len(image_files) == 0:
            raise ValueError("image_files is empty; nothing to annotate")
        self.total_images = len(image_files)

        self.test_num = 0
        self.focal_tip = 1

        self.image_files = image_files
        self.csv_files = csv_files

        self.max_size = max_size
        self.show_points = True

        self.load_new_test()

    def load_new_test(self):
        """Load the current test's image (rescaled for display) and its CSV.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_file = self.image_files[self.test_num]
        raw_image = cv2.imread(image_file)
        # cv2.imread signals failure by returning None rather than raising.
        if raw_image is None:
            raise FileNotFoundError(f"Could not read image: {image_file}")
        self.raw_image, self.image_rescale = rescale_image(raw_image, self.max_size)

        self.info = pd.read_csv(self.csv_files[self.test_num])

    def refresh_windows(self):
        """Reset the drawing canvas to a clean copy of the rescaled image."""
        self.image = np.copy(self.raw_image)

    def save_validation(self):
        """ Overwrite test's wing positions file with current values."""

        self.info.to_csv(self.csv_files[self.test_num], index=False)

    def change_frame(self, amount):
        """Change test frame forward or backward by amount.

        The index is clamped to [0, number of frames - 1] (no periodic boundaries).
        The current frame is saved before moving.

        Args:
            amount: number of frames to move, positive or negative
        """
        self.save_validation()

        self.test_num = min(max(self.test_num + amount, 0), self.total_images - 1)
        self.load_new_test()

    def change_focal_wingtip(self):
        """Save, then switch the focal wingtip to the other one (1 <-> 2)."""

        self.save_validation()

        if self.focal_tip == 1:
            self.focal_tip = 2
        elif self.focal_tip == 2:
            self.focal_tip = 1

    @staticmethod
    def _is_drawable(wingtip):
        """Return True if ``wingtip`` is a real (x, y) position rather than missing or a sentinel.

        NaN means the value is missing. Negative coordinates are sentinels
        (``not_present`` writes -2, and -1 is also skipped) that lie outside
        the image.
        """
        point = np.asarray(wingtip, dtype=float)
        return not (np.any(np.isnan(point)) or np.any(point < 0))

    def _draw_wingtip(self, wingtip, color, radius, thickness):
        """Draw one circle at ``wingtip`` (original-image pixels) on the display image."""
        if not self._is_drawable(wingtip):
            return
        center = (int(wingtip[0] * self.image_rescale),
                  int(wingtip[1] * self.image_rescale))
        cv2.circle(self.image, center, radius=radius, color=color, thickness=thickness)

    def draw_wingtip_positions(self):
        """ Draw the location of all wingtip points.

        Each drawable tip gets a filled dot; the focal tip also gets an outline
        circle. Markers are red when the frame is flagged ``hard``, white
        otherwise.
        """
        point_thickness = -1  # negative thickness -> filled circle in cv2
        circle_thickness = 1
        point_radius = int(.006 * self.max_size)
        circle_radius = int(.018 * self.max_size)
        color = [255, 255, 255]
        if self.info.loc[0, 'hard']:
            color = [0, 0, 255]  # red in OpenCV's BGR channel order

        if not self.show_points:
            return

        for tip in (1, 2):
            wingtip = np.array(
                [self.info.loc[0, f'wingtip{tip}_x'], self.info.loc[0, f'wingtip{tip}_y']]
            )
            self._draw_wingtip(wingtip, color, point_radius, point_thickness)
            if self.focal_tip == tip:
                self._draw_wingtip(wingtip, color, circle_radius, circle_thickness)

    def add_frame_info(self):
        """Overlay '<test index> / <total images>' in the top-left of the display."""
        font = cv2.FONT_HERSHEY_SIMPLEX
        bottom_left_corner_of_text = (20, 40)
        font_scale = 1
        font_color = (255, 255, 255)
        line_type = 2

        cv2.putText(self.image,
                    f'{self.test_num} / {self.total_images}',
                    bottom_left_corner_of_text,
                    font,
                    font_scale,
                    font_color,
                    line_type
                    )

    def change_wingtip_locations(self, x, y):
        """Record the click position as the focal wingtip's location, then switch focal tip.

        Args:
            x, y: mouse-click coordinates in the rescaled display image; they are
                divided by ``image_rescale`` to store original-image pixels.
        """
        corrected_x = x / self.image_rescale
        corrected_y = y / self.image_rescale

        if self.focal_tip == 1 or self.focal_tip == 2:
            self.info.loc[0, f'wingtip{self.focal_tip}_x'] = corrected_x
            self.info.loc[0, f'wingtip{self.focal_tip}_y'] = corrected_y

        # Switching also saves the CSV, so every click is persisted immediately.
        self.change_focal_wingtip()

    def show_windows(self):
        """Display the current annotated image in the 'bat-image' window."""
        cv2.imshow('bat-image', self.image)

    def not_present(self):
        """Mark the focal wingtip as not visible (-2, -2), then advance one frame."""
        if self.focal_tip == 1 or self.focal_tip == 2:
            self.info.loc[0, f'wingtip{self.focal_tip}_x'] = -2
            self.info.loc[0, f'wingtip{self.focal_tip}_y'] = -2
        self.change_frame(1)

    def mark_hard(self):
        """Toggle the frame's 'hard' flag (a missing/NaN value becomes True). Not saved here."""
        currently_hard = self.info.loc[0, 'hard'] == True  # noqa: E712 - explicit compare keeps NaN -> not hard
        self.info.loc[0, 'hard'] = not currently_hard

    def react_to_keypress(self, key):
        """ Process key press.

        Keys: l = next frame, j = previous frame, i = switch focal wingtip,
        k = toggle 'hard', p = toggle point display, n = mark focal tip not present.
        Any other key is ignored ('q' is handled by the caller's loop).

        Args:
            key: return value of cv2.waitKey
        """
        actions = {
            ord('l'): lambda: self.change_frame(1),
            ord('j'): lambda: self.change_frame(-1),
            ord('i'): self.change_focal_wingtip,
            ord('k'): self.mark_hard,
            ord('p'): self.toggle_show_points,
            ord('n'): self.not_present,
        }
        action = actions.get(key)
        if action is not None:
            action()

    def clicked(self, event, x, y, flags, param):
        """OpenCV mouse callback: a left-button press sets the focal wingtip position."""
        if event == cv2.EVENT_LBUTTONDOWN:
            self.change_wingtip_locations(x, y)

    def toggle_show_points(self):
        """Show or hide the wingtip markers."""
        self.show_points = not self.show_points

In [None]:
# Data root holding 'validation-images' and 'validation-csvs' (placeholder path; adjust locally).
root_folder = '.../bats-data/wing-validation'
# Name of the person doing the annotation.
annotater_name = 'koger'

In [None]:
# Sub-folders of root_folder holding the images and their matching CSVs.
images_folder, csvs_folder = (
    os.path.join(root_folder, sub_folder)
    for sub_folder in ('validation-images', 'validation-csvs')
)

In [None]:
# Images and CSVs are paired by sorted filename order: the i-th image goes with the i-th CSV.
image_files = sorted(glob.glob(os.path.join(images_folder, '*.png')))
csv_files = sorted(glob.glob(os.path.join(csvs_folder, '*.csv')))
if not image_files:
    # An empty match would pass the length check below (0 == 0) and only fail later inside the GUI.
    raise FileNotFoundError(f"No .png images found in {images_folder}; check root_folder")
assert len(image_files) == len(csv_files), f"{len(image_files)} {len(csv_files)} every image should have coresponding .csv file"

In [None]:
# Longest side, in pixels, of the displayed image.
max_size = 600

gui = TestLocationsGUI(image_files=image_files, csv_files=csv_files, max_size=max_size)

In [None]:
cv2.namedWindow('bat-image')
# Left clicks in the window are routed to gui.clicked.
cv2.setMouseCallback('bat-image', gui.clicked)

try:
    while True:
        gui.refresh_windows()
        gui.draw_wingtip_positions()
        gui.add_frame_info()
        gui.show_windows()
        # Wait up to 500 ms for a key. waitKey returns -1 when none is pressed;
        # masking keeps the low byte, so "no key" becomes 255 and is ignored.
        key = cv2.waitKey(500) & 0xFF
        gui.react_to_keypress(key)
        if key == ord('q'):
            gui.save_validation()
            print('quitting')
            break
finally:
    # Close the window even if the loop is interrupted (e.g. a kernel interrupt) or raises.
    cv2.destroyAllWindows()
    # Give the GUI event loop a tick so the close actually takes effect.
    cv2.waitKey(1)

In [None]:
# Inspect the first annotation file.
x = pd.read_csv(filepath_or_buffer=csv_files[0])

In [None]:
# Display the 'hard' flag of the first row.
x['hard'].loc[0]