In [10]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt5

In [11]:
import glob
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, CheckButtons, RadioButtons
from matplotlib.animation import FuncAnimation
from skimage import io

In [14]:
# Integer class id -> state name. The tuple order defines the class ids, and
# the name -> id mapping below is derived from it so the two never disagree.
class_to_states = dict(enumerate((
    'all_growing',
    'partly_growing',
    'stopped_growing',
    'stopped_fading',
    'stopped_vanishing',
    'channel_empty',
)))
# State name -> integer class id (inverse of class_to_states).
states_to_classes = {state: class_id for class_id, state in class_to_states.items()}

### Data creator using matplotlib widgets

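Mark the state transitions: frame *i* and frame *i + frame_rate* are shown side by side. Choose a state with the radio buttons, then press **Previous**/**Next** to record it for the displayed pair and move on. **Save** pickles the collected labels into the frame directory (`transitions.pickle` by default).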


In [90]:
class StateTransition(object):
    """Interactive matplotlib tool for labelling the state of frame pairs.

    Frame ``i`` and frame ``i + frame_rate`` of ``phaseDir`` are shown side by
    side. Pick a state with the radio buttons and press Previous/Next: the
    selected state is stored for the pair being left, and the view moves one
    frame back/forward. When the new pair already has a stored label, the
    radio buttons switch to it so it is not overwritten by a stale selection.
    Save pickles the collected labels to ``phaseDir/labels_file``.

    Parameters
    ----------
    phaseDir : str
        Directory containing the frames, named ``<integer><fileformat>``
        (zero-padded names such as ``0007.tiff`` are accepted; files whose
        name is not an integer are ignored).
    fileformat : str
        Frame file suffix, including the leading dot.
    labels_file : str
        File name (inside ``phaseDir``) that ``save`` writes.
    frame_rate : int
        Offset between the two frames shown together.
    resume : bool
        If True and the labels file already exists, start from its labels
        instead of an empty dict, so the next ``save`` does not discard them.

    Raises
    ------
    ValueError
        If ``phaseDir`` holds fewer than ``frame_rate + 1`` frames.
    """

    def __init__(self, phaseDir, fileformat='.tiff',
                 labels_file='transitions.pickle', frame_rate=1, resume=False):

        self.phaseDir = phaseDir
        self.fileformat = fileformat
        self.labelsFile = labels_file
        self.frameRate = frame_rate
        # Sorted frame numbers plus the matching paths; the real paths are kept
        # because str(int(stem)) does not round-trip zero-padded file names.
        self.indices, self.filenames = self._discoverFrames()

        # Keys are "<pos1>_<pos2>" where pos* are positions in self.indices
        # (not the numbers in the file names); values are numeric states.
        self.states = {}
        if resume and os.path.exists(self._labelsPath()):
            # pickle.load can execute arbitrary code: only resume from label
            # files written by this tool.
            with open(self._labelsPath(), 'rb') as f:
                self.states = pickle.load(f)

        self.currentFrame = 0
        firstImage = self[self.currentFrame]
        if firstImage is None:
            # Fail early with a clear message instead of imshow(None) crashing.
            raise ValueError(
                f"need at least {self.frameRate + 1} '*{self.fileformat}' frames "
                f"in {self.phaseDir!r}, found {len(self)}")

        self.fig, self.ax = plt.subplots(1, 1, num=str(self.phaseDir), figsize=(12, 10))
        self.fig.subplots_adjust(left=0.35, bottom=0.25)
        self.axcolor = 'lightgoldenrodyellow'

        self.imagesPlot = self.ax.imshow(firstImage, cmap='gray')
        self.ax.set_title(self._title())

        # Widget axes are added to self.fig explicitly (not to pyplot's
        # "current figure") so they always land on this tool's window.
        self.previousax = self.fig.add_axes([0.5, 0.025, 0.1, 0.03])
        self.previousButton = Button(self.previousax, 'Previous', color=self.axcolor, hovercolor='0.975')
        self.previousButton.on_clicked(self.previousDatapoint)

        self.nextax = self.fig.add_axes([0.65, 0.025, 0.1, 0.03])
        self.nextButton = Button(self.nextax, 'Next', color=self.axcolor, hovercolor='0.975')
        self.nextButton.on_clicked(self.nextDatapoint)

        self.saveax = self.fig.add_axes([0.8, 0.025, 0.1, 0.03])
        self.saveButton = Button(self.saveax, 'Save', color=self.axcolor, hovercolor='0.875')
        self.saveButton.on_clicked(self.save)

        # Radio label -> numeric state; the dict order is also the button order.
        self.radioToState = {
            '0 : All Growing': 0,
            '1 : Partly Growing': 1,
            '2 : Stopped Growing': 2,
            '3 : Stopped Fading': 3,
            '4 : Stopped Vanishing': 4,
            '5 : Channel Empty': 5
        }
        # Numeric state -> radio button position, used to display stored labels.
        self._stateToRadioIndex = {
            state: position for position, state in enumerate(self.radioToState.values())}
        self.rax = self.fig.add_axes([0.05, 0.5, 0.20, 0.25], facecolor=self.axcolor)
        self.radio = RadioButtons(self.rax, tuple(self.radioToState))
        self._syncRadio()

        plt.pause(0.01)

    def __len__(self):
        """Number of frames found in ``phaseDir``."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return frames ``idx`` and ``idx + frameRate`` concatenated horizontally.

        Returns None (rather than raising IndexError) when either frame is out
        of range; the navigation code relies on that.
        """
        if 0 <= idx and idx + self.frameRate < len(self):
            img1 = io.imread(self.filenames[idx])
            img2 = io.imread(self.filenames[idx + self.frameRate])
            return np.concatenate((img1, img2), axis=1)
        return None

    def _discoverFrames(self):
        """Return (sorted frame numbers, matching paths) for frames in ``phaseDir``.

        Only the file name is parsed, so dots elsewhere in the directory path
        are harmless. Files whose stem is not an integer are skipped.
        """
        # Escape the directory so characters like '[' are not glob wildcards.
        pattern = os.path.join(glob.escape(self.phaseDir), '*' + glob.escape(self.fileformat))
        frames = []
        for path in glob.glob(pattern):
            name = os.path.basename(path)
            # glob guarantees the suffix; slice it off (works for '.ome.tiff' too).
            stem = name[:len(name) - len(self.fileformat)]
            try:
                frames.append((int(stem), path))
            except ValueError:
                continue  # not a numbered frame, e.g. a stray "mask.tiff"
        frames.sort()
        return [number for number, _ in frames], [path for _, path in frames]

    def _labelsPath(self):
        """Path of the pickle that ``save`` writes (and ``resume`` reads)."""
        return os.path.join(self.phaseDir, self.labelsFile)

    def _pairKey(self):
        """Key of the currently displayed pair in ``self.states``."""
        return f"{self.currentFrame}_{self.currentFrame + self.frameRate}"

    def _title(self):
        """Axes title for the currently displayed pair."""
        return f"{self.currentFrame} _ {self.currentFrame + self.frameRate}"

    def _recordState(self):
        """Store the radio selection as the label of the displayed pair."""
        self.states[self._pairKey()] = self.radioToState[self.radio.value_selected]
        print(self.states)

    def _syncRadio(self):
        """Select the radio button of the displayed pair's stored label, if any."""
        state = self.states.get(self._pairKey())
        if state in self._stateToRadioIndex:
            self.radio.set_active(self._stateToRadioIndex[state])

    def _moveTo(self, frame):
        """Display pair ``frame``; return False and stay put if it is out of range."""
        image = self[frame]
        if image is None:
            return False
        self.currentFrame = frame
        self.imagesPlot.set_data(image)
        self.ax.set_title(self._title())
        self._syncRadio()
        self.fig.canvas.draw_idle()
        return True

    def nextDatapoint(self, buttonPress):
        """Button callback: record the selected state, then show the next pair."""
        self._recordState()
        self._moveTo(self.currentFrame + 1)

    def previousDatapoint(self, buttonPress):
        """Button callback: record the selected state, then show the previous pair."""
        self._recordState()
        self._moveTo(self.currentFrame - 1)

    def save(self, save):
        """Button callback: pickle ``self.states`` to ``phaseDir/labels_file``.

        ``save`` is the matplotlib click event passed by the button (unused).
        """
        print(self.states)
        saveFilename = self._labelsPath()
        with open(saveFilename, 'wb') as f:
            pickle.dump(self.states, f, protocol=pickle.HIGHEST_PROTOCOL)

        print(f"File save: {saveFilename}")

In [91]:
# Frame directory to label. Set the PHASE_DIR environment variable to point at
# another dataset instead of editing this machine-specific default.
phaseDir = os.environ.get('PHASE_DIR', '/home/pk/Documents/trainingData/deadalive1/0/')
markStates = StateTransition(phaseDir, frame_rate=1)

{'0_1': 0}
{'0_1': 1}
{'0_1': 1}
{'0_1': 1}
{'0_1': 1}
{'0_1': 2}
{'0_1': 5}
{'0_1': 5}
{'0_1': 5, '1_2': 5}
{'0_1': 5, '1_2': 5, '2_3': 4}
{'0_1': 5, '1_2': 5, '2_3': 4, '3_4': 4}
{'0_1': 5, '1_2': 5, '2_3': 3, '3_4': 4}
{'0_1': 5, '1_2': 5, '2_3': 3, '3_4': 5}
{'0_1': 5, '1_2': 5, '2_3': 3, '3_4': 5, '4_5': 5}
{'0_1': 5, '1_2': 5, '2_3': 3, '3_4': 5, '4_5': 5}
{'0_1': 5, '1_2': 5, '2_3': 5, '3_4': 5, '4_5': 5}
{'0_1': 5, '1_2': 5, '2_3': 5, '3_4': 5, '4_5': 5}
{'0_1': 0, '1_2': 5, '2_3': 5, '3_4': 5, '4_5': 5}
{'0_1': 0, '1_2': 0, '2_3': 5, '3_4': 5, '4_5': 5}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 5, '4_5': 5}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 5}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0, '5_6': 0}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0, '5_6': 0, '6_7': 0}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0, '5_6': 4, '6_7': 0}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0, '5_6': 4, '6_7': 4}
{'0_

{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0, '5_6': 4, '6_7': 4, '7_8': 4, '8_9': 4, '9_10': 4, '10_11': 4, '11_12': 4, '12_13': 4, '13_14': 4, '14_15': 4, '15_16': 4, '16_17': 4, '17_18': 4, '18_19': 4, '19_20': 4, '20_21': 4, '21_22': 4, '22_23': 4, '23_24': 4, '24_25': 4, '25_26': 4, '26_27': 4, '27_28': 4, '28_29': 4, '29_30': 4, '30_31': 4, '31_32': 4, '32_33': 4, '33_34': 4, '34_35': 4, '35_36': 4, '36_37': 4, '37_38': 4}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0, '5_6': 4, '6_7': 4, '7_8': 4, '8_9': 4, '9_10': 4, '10_11': 4, '11_12': 4, '12_13': 4, '13_14': 4, '14_15': 4, '15_16': 4, '16_17': 4, '17_18': 4, '18_19': 4, '19_20': 4, '20_21': 4, '21_22': 4, '22_23': 4, '23_24': 4, '24_25': 4, '25_26': 4, '26_27': 4, '27_28': 4, '28_29': 4, '29_30': 4, '30_31': 4, '31_32': 4, '32_33': 4, '33_34': 4, '34_35': 4, '35_36': 4, '36_37': 4, '37_38': 4, '38_39': 4}
{'0_1': 0, '1_2': 0, '2_3': 0, '3_4': 0, '4_5': 0, '5_6': 4, '6_7': 4, '7_8': 4, '8_9': 4, '9_10': 4, '10_11': 4