In [None]:
import numpy as np
import flylib as flb
from matplotlib import pyplot as plt
import numpy as np
import scipy
from flylib import util

In [None]:
def get_transitions(state_mtrx_dict,muscle_key,pre_trig_idx = 10,post_trig_idx = 100):
    off_on_list = []
    on_off_list = []
    for flynum,tdict in state_mtrx_dict.items():
        for row in tdict[muscle_key]:
            off_on_idx = np.argwhere(np.diff(row[0]) == 1)
            on_off_idx = np.argwhere(np.diff(row[0]) == -1)
            
            for idx in off_on_idx:
                if (idx+post_trig_idx < len(row[0])) & (idx-pre_trig_idx > 0):
                    off_on_list.append((flynum,row[1][idx-pre_trig_idx:idx+post_trig_idx]))
            for idx in on_off_idx:
                if (idx+post_trig_idx < len(row[0])) & (idx-pre_trig_idx > 0):
                    on_off_list.append((flynum,row[1][idx-pre_trig_idx:idx+post_trig_idx]))
    return {'off_on':off_on_list,'on_off':on_off_list}            

In [None]:
def plot_trig_panel(ax_group,
                    trig_key,
                    direction = 'off_on',
                    ts = 0.02,
                    pretrig = 10,
                    posttrig = 50,
                    state_mtrx_dict = None,
                    sorted_keys = None,
                    flydict = None,
                    colors = None,
                    n_raster_rows = 100,
                    rng = None):
    """Plot transition-triggered raster and averages for one trigger muscle.

    Windows around every ``direction`` transition of ``trig_key`` are found
    with get_transitions. A window is kept only if the trigger muscle stays in
    the "from" state for the whole pre-trigger segment. The kept windows are
    then plotted as follows:

    * ``n_raster_rows`` trigger-muscle traces, drawn at random with
      replacement, form a raster on ``ax_group['raster']``;
    * the mean left/right wing amplitude, passed through ``np.rad2deg``, is
      drawn on ``ax_group['kine']``;
    * for each ``key -> ax`` in ``ax_group['left']`` / ``ax_group['right']``,
      the mean of the ``('left', key)`` / ``('right', key)`` spike states is
      drawn on ``ax``.

    Parameters
    ----------
    ax_group : dict
        ``{'raster': ax, 'kine': ax, 'left': {key: ax}, 'right': {key: ax}}``.
    trig_key :
        Spike-state key of the triggering muscle; must be in ``sorted_keys``.
    direction : {'off_on', 'on_off'}
        Which transition type to trigger on.
    ts : float
        Sample period used to build the time axis.
    pretrig, posttrig : int
        Window extent, passed to get_transitions.
    state_mtrx_dict : dict
        Passed to get_transitions.
    sorted_keys : iterable
        Spike-state keys to extract; must include every ``('left', key)`` /
        ``('right', key)`` referenced by ``ax_group``.
    flydict : dict
        Fly id -> object with ``spikestates``, ``left_amp`` and ``right_amp``.
    colors : dict
        Line colours under ``'l'`` (left) and ``'r'`` (right).
    n_raster_rows : int
        Number of traces sampled for the raster; also its y extent.
    rng : optional
        Object with a ``randint`` method (e.g. ``np.random.RandomState(seed)``)
        for reproducible raster sampling. ``None`` uses the global
        ``np.random`` state.

    Raises
    ------
    ValueError
        If no complete transition window is found.
    """
    time = np.arange(pretrig + posttrig) * ts
    windows = get_transitions(state_mtrx_dict, trig_key, pretrig, posttrig)[direction]
    if not windows:
        raise ValueError('no %r transitions of %r fit a %d+%d sample window'
                         % (direction, trig_key, pretrig, posttrig))

    signal_mtrxs = {}
    for key in sorted_keys:
        signal_mtrxs[key] = np.vstack([flydict[fnum].spikestates[key][idx]
                                       for fnum, idx in windows])
    # NOTE(review): hstack followed by .T presumes each amplitude slice is a
    # column (len(window), 1), giving (n_windows, len(window)) -- confirm.
    signal_mtrxs['left', 'amp'] = np.hstack(
        [np.array(flydict[fnum].left_amp)[idx] for fnum, idx in windows]).T
    signal_mtrxs['right', 'amp'] = np.hstack(
        [np.array(flydict[fnum].right_amp)[idx] for fnum, idx in windows]).T

    # Column `pretrig` of each window is the last sample before the step
    # (get_transitions slices from diff-index - pretrig), so the first
    # pretrig + 1 columns must all be off (off_on) or all on (on_off).
    # Assumes spikestates[trig_key] is the 0/1 trace the transitions were
    # detected on -- confirm.
    n_pre = pretrig + 1
    expected_pre_sum = {'off_on': 0, 'on_off': n_pre}[direction]
    filter_cond = np.sum(signal_mtrxs[trig_key][:, :n_pre], axis=1) == expected_pre_sum

    rast_mtrx = signal_mtrxs[trig_key][filter_cond, :]
    if rast_mtrx.shape[0] > 0:  # randint(0, 0) would raise on an empty selection
        sampler = np.random if rng is None else rng
        picks = sampler.randint(0, rast_mtrx.shape[0], size=n_raster_rows)
        ax_group['raster'].imshow(rast_mtrx[picks, :],
                                  aspect='auto', interpolation='nearest',
                                  extent=[0, time[-1], 0, n_raster_rows],
                                  cmap='binary')
        ax_group['raster'].set_ybound(0, n_raster_rows)

    side_colors = (('left', colors['l']), ('right', colors['r']))
    for side, color in side_colors:
        mean_amp = np.nanmean(signal_mtrxs[side, 'amp'][filter_cond, :], axis=0)
        ax_group['kine'].plot(time, np.rad2deg(mean_amp), color=color)
    for side, color in side_colors:
        for key, ax in ax_group[side].items():
            ax.plot(time,
                    np.nanmean(signal_mtrxs[side, key][filter_cond, :], axis=0),
                    color=color)

In [None]:
def make_state_matrix(flylist,
                      sorted_keys,
                      block_key = 'cl_blocks, g_x=-1, g_y=0 b_x=0, b_y=0',
                      skip = 100):
    """Concatenate spike states and wing amplitudes over all trials of one block.

    Other block keys previously used with this function include
    ``'cl_blocks, g_x=-1, g_y=0 b_x=-8, b_y=0'`` and
    ``'cl_blocks, g_x=-1, g_y=0 b_x=8, b_y=0'``.

    Parameters
    ----------
    flylist : iterable
        Fly objects exposing ``spikestates`` (key -> trace), ``block_data``
        (dict), ``left_amp`` and ``right_amp``.
    sorted_keys : sequence
        Spike-state keys; each becomes one row of the state matrix, in order.
    block_key : str
        Third element of the ``('common', 'idx', block_key)`` key looked up in
        ``fly.block_data``; that entry is a list of per-trial index arrays.
    skip : int
        Leading samples dropped from every trial (presumably to exclude the
        trial-onset transient -- confirm).

    Returns
    -------
    state_mtrx : ndarray of int, shape (len(sorted_keys), total_samples)
        Trials from all flies concatenated along columns.
    left, right : ndarray
        ``np.vstack`` of the matching per-trial amplitude slices, in the same
        sample order as ``state_mtrx``'s columns.
    """
    stim_key = ('common', 'idx', block_key)
    state_blocks, left_blocks, right_blocks = [], [], []
    for fly in flylist:
        fly_states = np.vstack([fly.spikestates[key] for key in sorted_keys])
        # Convert once per fly rather than once per trial.
        left_amp = np.array(fly.left_amp)
        right_amp = np.array(fly.right_amp)
        for idx in fly.block_data[stim_key]:
            kept = idx[skip:]
            state_blocks.append(fly_states[:, kept])
            left_blocks.append(left_amp[kept])
            right_blocks.append(right_amp[kept])
    state_mtrx = np.hstack(state_blocks).astype(int)
    return state_mtrx, np.vstack(left_blocks), np.vstack(right_blocks)

def get_transiton_prob(state_mtrx):
    """Count joint states and first-order transitions in a state matrix.

    Each column of ``state_mtrx`` is one time sample and is treated as a
    joint state (a tuple across all rows).

    Returns
    -------
    tprob : dict
        ``{(state, next_state): count}`` over consecutive column pairs, in
        order of first occurrence. Despite the name, these are raw counts,
        not probabilities.
    state_counts : dict
        ``{state: count}``, the number of columns equal to each state.

    Notes
    -----
    Consecutive columns are paired across any concatenation seams in the
    input (e.g. trials joined with np.hstack), so transitions between the
    last sample of one trial and the first of the next are also counted.
    """
    from collections import Counter

    states = [tuple(col) for col in np.asarray(state_mtrx).T]
    # Counter compares whole tuples, so this works for any number of rows.
    state_counts = dict(Counter(states))
    tprob = dict(Counter(zip(states[:-1], states[1:])))
    return tprob, state_counts

def make_transition_matrix(tprob,
                           state_counts,
                           min_tran_num = 1,
                           min_state_num = 10):
    """Build a filtered, normalised transition matrix from transition counts.

    Parameters
    ----------
    tprob : dict
        ``{(from_state, to_state): count}``, e.g. from get_transiton_prob.
    state_counts : dict
        ``{state: occurrences}``. It must contain both endpoints of every
        transition that passes the ``min_tran_num`` test.
    min_tran_num : int
        A transition is kept only if its count is strictly greater than this.
    min_state_num : int
        A transition is kept only if both of its states occur strictly more
        than this many times.

    Returns
    -------
    transition_mtrx : ndarray, shape (S, S)
        ``transition_mtrx[j, i]`` is the probability of moving from state
        ``i`` to state ``j``. Each column is therefore the outgoing
        distribution of one state, which is the layout next_state indexes.
        A state with no kept outgoing transition has an all-zero column.
    state_table : ndarray, shape (S, n_rows)
        ``state_table[i]`` is the state for index ``i``. States are ordered by
        decreasing self-transition probability. Both arrays have shape
        (0, 0) when no transition survives filtering.
    """
    filtered = {
        key: tnum for key, tnum in tprob.items()
        if tnum > min_tran_num
        and state_counts[key[0]] > min_state_num
        and state_counts[key[1]] > min_state_num
    }
    if not filtered:
        return np.zeros((0, 0)), np.zeros((0, 0), dtype=int)

    inkeys = [pair[0] for pair in filtered]
    outkeys = [pair[1] for pair in filtered]
    states = list(set(inkeys + outkeys))
    position = {state: i for i, state in enumerate(states)}

    counts = np.zeros((len(states), len(states)))
    for (src, dst), tnum in filtered.items():
        counts[position[src], position[dst]] = tnum

    # Row-normalise. Rows with no outgoing transitions stay zero, which avoids
    # 0/0 -> NaN and the RuntimeWarning that comes with it.
    row_sums = counts.sum(axis=1, keepdims=True)
    row_norm = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)

    order = np.argsort(np.diag(row_norm))[::-1]
    # Permute rows, transpose, permute again: the result is the transpose of
    # the reordered row-stochastic matrix, i.e. column i = distribution from state i.
    transition_mtrx = row_norm[order].T[order]
    state_table = np.array(states)[order, :]
    return transition_mtrx, state_table

def next_state(current_state,state_table,tmtrx,rng = None):
    """Simulate one Markov step using a transition matrix.

    Parameters
    ----------
    current_state : array_like
        A state equal to exactly one row of ``state_table``.
    state_table : ndarray, shape (S, n_rows)
        One state per row.
    tmtrx : ndarray, shape (S, S)
        Column-indexed: ``tmtrx[:, i]`` is used as the next-state
        distribution for ``state_table[i]``.
    rng : optional
        Object with a ``choice`` method, e.g. ``np.random.default_rng(seed)``
        or ``np.random.RandomState(seed)``. ``None`` uses the global
        ``np.random`` state.

    Returns
    -------
    ndarray
        The sampled next state (a row of ``state_table``).

    Raises
    ------
    ValueError
        If ``current_state`` does not match exactly one row of
        ``state_table``. ``choice`` also raises it when the selected column
        does not sum to 1, for example an all-zero column.
    """
    matches = np.flatnonzero(np.all(state_table == current_state, axis=1))
    if matches.size != 1:
        raise ValueError('current_state must match exactly one row of state_table; '
                         'found %d matches' % matches.size)
    prob_vector = tmtrx[:, matches[0]]
    sampler = np.random if rng is None else rng
    idx = sampler.choice(np.arange(len(state_table)), p=prob_vector)
    return state_table[idx]