# Dataset inspection

We want to be able to look at our labeling to make sure it is correct. 

Nothing like training a network with erronous labels to get bad tracking!


In [1]:
from unetTracker.trackingProject import TrackingProject
from unetTracker.multiClassUNetDataset import MultiClassUNetDataset

In [2]:
project = TrackingProject(name="faceTrack",root_folder = "/home/kevin/Documents/trackingProjects/")
dataset = MultiClassUNetDataset(image_dir=project.image_dir, mask_dir=project.mask_dir, coordinates_dir=project.coordinates_dir)

Project directory: /home/kevin/Documents/trackingProjects/faceTrack
Loading /home/kevin/Documents/trackingProjects/faceTrack/config.yalm
{'augmentation_HorizontalFlipProb': 0.0, 'augmentation_RandomBrightnessContrastProb': 0.2, 'augmentation_RandomSizedCropProb': 1.0, 'augmentation_RotateProb': 0.3, 'image_size': [480, 640], 'labeling_ImageEnlargeFactor': 2.0, 'name': 'faceTrack', 'object_colors': [(0.0, 0.0, 255.0), (255.0, 0.0, 0.0), (255.0, 255.0, 0.0), (128.0, 0.0, 128.0)], 'objects': ['nose', 'chin', 'rEye', 'lEye'], 'target_radius': 10}


In [17]:
import ipywidgets as widgets
import threading
from ipywidgets import Label, HTML, HBox, Image, VBox, Box, HBox
from ipyevents import Event 
from IPython.display import display
from unetTracker.camera import bgr8_to_jpeg
import cv2
import numpy as np
import glob
import os
import torch
import ntpath
    
################################
####### work with images #######
################################
class ReviewDatasetGUI():
    """
    Class to label frames from a camera feed.
    """
    def __init__(self,project,dataset):

        self.project = project
        self.dataset = dataset
        
        self.imgWidget = Image(format='jpeg',height=project.image_size[0], width=project.image_size[1])
        self.htmlWidget = HTML('Event info')
        self.frameNameWidget = HTML('Frame name')
        
        
        self.previousButton = widgets.Button(description='Previous frame',
                            disabled=False,
                            button_style='', # 'success', 'info', 'warning', 'danger' or ''
                            tooltip='Click me',
                            icon='check') # (FontAwesome names without the `fa-` prefix)
        
        self.nextButton = widgets.Button(description='Next frame',
                            disabled=False,
                            button_style='', # 'success', 'info', 'warning', 'danger' or ''
                            tooltip='Click me',
                            icon='check') # (FontAwesome names without the `fa-` prefix)
        
        
        self.deleteButton = widgets.Button(description='Delete frame',
                            disabled=False,
                            button_style='', # 'success', 'info', 'warning', 'danger' or ''
                            tooltip='Click me',
                            icon='check') # (FontAwesome names without the `fa-` prefix)
        
        
        self.previousEvent = Event(source=self.previousButton, watched_events=['click'])
        self.previousEvent.on_dom_event(self.previous_handle_event)
        
        self.nextEvent = Event(source=self.nextButton, watched_events=['click'])
        self.nextEvent.on_dom_event(self.next_handle_event)
        
        
        self.deleteEvent = Event(source=self.deleteButton, watched_events=['click'])
        self.deleteEvent.on_dom_event(self.delete_handle_event)
        
        
        
        self.imageIndex = 0
        frame = self.get_labelled_image(self.imageIndex)
        self.imgWidget.value = bgr8_to_jpeg(frame)
        
        fn = self.dataset.get_image_file_name(self.imageIndex)
        lines = ntpath.basename(fn) 
        content = "  ".join(lines)
        self.frameNameWidget.value = content   
        
        
        
        lines = "{} / {}".format(self.imageIndex,len(dataset))
        content = "  ".join(lines)
        self.htmlWidget.value = content   
            
        display(VBox([HBox([self.htmlWidget, self.previousButton, self.nextButton, self.deleteButton]),
                      self.frameNameWidget,
                      self.imgWidget]))
    
    def delete_handle_event(self,event):
        self.dataset.delete_entry(self.imageIndex)
        
        frame = self.get_labelled_image(self.imageIndex)
        fn = self.dataset.get_image_file_name(self.imageIndex)
        self.imgWidget.value = bgr8_to_jpeg(frame)
        
        fn = self.dataset.get_image_file_name(self.imageIndex)
        lines = ntpath.basename(fn) 
        content = "  ".join(lines)
        self.frameNameWidget.value = content   
        

    def previous_handle_event(self,event):
        self.imageIndex-=1
        if self.imageIndex < 0:
            self.imageIndex =  len(self.dataset)-1
       
        fn = self.dataset.get_image_file_name(self.imageIndex)
        lines = ntpath.basename(fn) 
        content = "  ".join(lines)
        self.frameNameWidget.value = content   
    
        lines = "{} / {}".format(self.imageIndex,len(self.dataset))
        content = "  ".join(lines)
        self.htmlWidget.value = content   
        
        frame = self.get_labelled_image(self.imageIndex)
        self.imgWidget.value = bgr8_to_jpeg(frame)    
       
        
    def next_handle_event(self,event):
        self.imageIndex+=1
        if self.imageIndex >= len(self.dataset):
            self.imageIndex = 0
        
        fn = self.dataset.get_image_file_name(self.imageIndex)
        lines = ntpath.basename(fn) 
        content = "  ".join(lines)
        self.frameNameWidget.value = content   
       
        lines = "{} / {}".format(self.imageIndex,len(self.dataset))
        content = "  ".join(lines)
        self.htmlWidget.value = content   
    
    
        frame = self.get_labelled_image(self.imageIndex)
        
       
        
        self.imgWidget.value = bgr8_to_jpeg(frame)
        
    
    def get_labelled_image(self,index):
        frame,mask,coord = self.dataset[self.imageIndex]
        frame = frame.permute(1,2,0).numpy()*255
        
        for i,myObject in enumerate(self.project.object_list):
            x=int(coord[i][0])
            y=int(coord[i][1])
            if(x!=0 and y!=0):
                cv2.circle(frame,(x,y), self.project.target_radius, self.project.object_colors[i], -1)
        return frame
        
        

In [18]:
gui = ReviewDatasetGUI(project,dataset)

VBox(children=(HBox(children=(HTML(value='0     /     4  8  8'), Button(description='Previous frame', icon='châ€¦

In [19]:
len(dataset)

461