In [1]:
%pylab inline
%gui qt

Populating the interactive namespace from numpy and matplotlib


In [2]:
from spiketag.utils import EventEmitter
from spiketag.base import CLU

In [3]:
class status_manager(EventEmitter):
    def __init__(self):
        super(status_manager, self).__init__()
        self.reporters = {}
    
    def append(self, state_reporter):
        self.reporters[state_reporter._id] = state_reporter
        
        @state_reporter.connect
        def on_report(state):
#             print(state+' from group '+ str(state_reporter._id))
              self.emit('update', state=self.state_list, nclu=self.nclu_list)
            
    def __getitem__(self, i):
        return self.reporters[i]
    
    def __setitem__(self, i, state_str):
        self.reporters[i].state = state_str
    
    def __repr__(self):
        _str = ''
        for _, _reporter in self.reporters.items():
            _str = _str + str(_reporter._id) + ":" + str(_reporter.state) + '\n'
        return _str
    
    def reset(self):
        for reports in self.reporters.values():
            reports._state = 'IDLE'
        self.emit('update', state=self.state_list, nclu=self.nclu_list)
    
    @property
    def state_list(self):
        return [troy.s.index(troy.state) for troy in self.reporters.values()]
    
    @property
    def ngroup(self):
        return len(self.reporters.keys())
    
    @property
    def nclu_list(self):
        return [troy.nclu for troy in self.reporters.values()]

In [4]:
from vispy import app, scene, visuals
from vispy.util import keys
import numpy as np


class cluster_view(scene.SceneCanvas):
    def __init__(self):
        scene.SceneCanvas.__init__(self, keys=None, title='clusters overview')
        self.unfreeze()
        self.view = self.central_widget.add_view()
        self.view.camera = 'panzoom'    
        # every group (grp_marker) has a clustering result each cluster will has a (clu_marker)
        # every group has its status: finish, unfinished
        # every clu has its status: spike counts, high quality, low quality, information bits
        self.grp_marker  = scene.visuals.Markers(parent=self.view.scene)
        self.nclu_text = scene.visuals.Text(parent=self.view.scene)
        self.event = EventEmitter() 


    def set_data(self, status_manager, size=25):
        '''
        group_No is a scala number #grp
        nclu_list is a list with length = group_No
        '''
        
        self.sman = status_manager
        
        self.group_No = self.sman.ngroup
        self.nclu_list = np.array(self.sman.nclu_list)
        self.sorting_status = np.array(self.sman.state_list)
        self.nspks_list = None
        self._size = size

        self.xmin = -0.02
        self.xmax =  0.04
        grp_x_pos = np.zeros((self.group_No,))
        grp_y_pos = np.arange(self.group_No)
        self.grp_pos = np.vstack((grp_x_pos, grp_y_pos)).T
        self.nclu_text_pos = np.vstack((grp_x_pos+0.02, grp_y_pos)).T

#         if selected_group_id is None and group_No>1:
#             selected_group_id = np.min(np.where(self.sorting_status==1)[0])
#             self.current_group = selected_group_id
#         elif group_No == 1:
#             self.current_group = 0
#             self._previous_group = 0
#             self._next_group = 0
#         else:
        
        self.current_group, self.selected_group_id, self._previous_group, self._next_group = [0]*4

        self.color = self.generate_color(self.sorting_status, self.nspks_list, self.selected_group_id) 

        self.grp_marker.set_data(self.grp_pos, symbol='square', face_color=self.color, size=size)
        self.nclu_text.text = [str(i) for i in self.nclu_list]
        self.nclu_text.pos  = self.nclu_text_pos
        self.nclu_text.color = 'g'
        self.nclu_text.font_size = size*0.50

        self.view.camera.set_range(x=[self.xmin, self.xmax])
        # self.view.camera.interactive = False


    def generate_color(self, sorting_status, nspks_list, selected_group_id):
        self.color = np.ones((self.group_No, 4)) * 0.5
        self.color[sorting_status==0] = np.array([1,1,1, .2]) # IDLE
        self.color[sorting_status==1] = np.array([0,1,1, .3]) # BUSY
        self.color[sorting_status==2] = np.array([1,0,1, .3]) # READY
        self.color[sorting_status==3] = np.array([1,1,0, .3]) # READY
        # self.color[selected_group_id] = np.array([1,1,1,  1]) # selected group id (current_group)
        self.color[selected_group_id, -1] = 1
        if nspks_list is not None:
            self.transparency = np.array(nspks_list)/np.array(nspks_list).max()
            self.color[:, -1] = self.transparency
        return self.color 


    def on_key_press(self, e):
        if e.text == 'r':
            self.view.camera.set_range(x=[self.xmin, self.xmax])
        if e.text == 'k':
            self.moveto(self.next_group)
        if e.text == 'j':
            self.moveto(self.previous_group)
        if e.text == 'd':
            self.set_cluster_done(self.current_group)
            self.moveto(self.next_group)
        if e.text == 'o':
            self.select(self.current_group)

    @property
    def cpu_ready_list(self):
        return np.where(self.sorting_status==1)[0]

    def set_cluster_ready(self, grp_id):
        self.sorting_status[grp_id] = 1
        self.refresh()

    def set_cluster_done(self, grp_id):
        self.sorting_status[grp_id] = 2
        self.refresh()

    def refresh(self):
        self.set_data(self.sman, self._size)


    @property
    def previous_group(self):
        if self.current_group>0:
            self._previous_group = self.current_group - 1
            return self._previous_group
        else:
            return self._previous_group 


    @property
    def next_group(self):
        if self.current_group<self.group_No-1:
            self._next_group = self.current_group + 1
            return self._next_group
        else:
            return self._next_group 


    def moveto(self, group_id):
        self.current_group = group_id
        self.set_data(self.sman, self._size) 


    def select(self, group_id):
        # if self.sorting_status[group_id] != 0:
        self.event.emit('select', group_id=self.current_group)
        # else:
            # print('unable to select busy cpu {}'.format(self.current_group)) 


    def run(self):
        self.show()
        self.app.run()

In [16]:
cluview = cluster_view()

In [10]:
clu_manager = status_manager()

In [11]:
for i in range(10):
    clu = CLU(np.arange(i+1))
    clu._id = i
    s.append(clu)

In [12]:
s.state_list

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [13]:
s.nclu_list

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

In [17]:
@s.connect
def on_update(state, nclu):
    print(state)
    print(nclu)
    cluview.set_data(s)

In [18]:
# disconnect if `update` function has something wrong
# s._callbacks.pop('update')

In [19]:
cluview.show()

In [23]:
s[8] = 'READY'
s[6] = 'DONE'

[0, 0, 0, 0, 0, 0, 0, 0, 2, 0]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[0, 0, 0, 0, 0, 0, 0, 0, 2, 0]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[0, 0, 0, 0, 0, 0, 3, 0, 2, 0]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[0, 0, 0, 0, 0, 0, 3, 0, 2, 0]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]


In [24]:
s.reset()

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
