In [7]:
%matplotlib widget
%load_ext autoreload
%autoreload 2
from src.folder_handler import *
from src.cort_processor import CortProcessor
from src.tdt_support import *
from src.plotter import *
from src.decoders import *
import math
import pickle
import scipy as spicy
import numpy as np
import matplotlib.pyplot as plt
from  matplotlib.colors import LinearSegmentedColormap
from src.wiener_filter import *
from matplotlib.pyplot import cm
from scipy import signal
import scipy.stats as stats

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
'''
rates and angles come in as lists of arrays; each angles array holds the 7 anatomical angles ['ankle', 'knee', 'hip', 'limbfoot', 'elbow', 'shoulder', 'forelimb'], one per column

1. peaks are identified and the angle is converted to phase
    phase is converted to a sine and cosine wave

a Wiener filter is trained to map the reformatted neural data to predictions of the sine and cosine

2. predicted sine and cosine are used to find the angle of phase via arctan2
'''
def phase_train(rates, angles):
    """Decode per-joint gait phase from firing rates via sine/cosine Wiener filters.

    Steps:
      1. stitch_and_format builds one rate matrix and one angle matrix.
      2. For each of the 7 angle columns, tailored_peaks finds cycle peaks and
         to_phasex turns the trace into a 0-360 sawtooth phase between peaks.
      3. The phase is encoded as sine/cosine (sine_and_cosine), one filter set is
         fit for each with decode_kfolds, and the predicted sine/cosine are turned
         back into a phase in degrees with arctan_fn.
      4. The Pearson r between the actual and decoded phase is computed per column.

    Parameters
    ----------
    rates, angles : list of arrays
        Per-recording firing rates and joint angles; see stitch_and_format.

    Returns
    -------
    tuple
        (arctans, phase_list, h_sin, h_cos, r_array, sin_array, cos_array,
         predicted_sin, predicted_cos); r_array is a list holding one Pearson r
        per angle column.
    """
    full_rates, full_angles = stitch_and_format(rates, angles)
    n_angles = 7  # tailored_peaks only defines settings for columns 0-6
    # Detect every column's peaks before any phase conversion: to_phasex may
    # write the phase into the column view it is given.
    peak_lists = [tailored_peaks(full_angles, col) for col in range(n_angles)]
    phase_list = np.vstack([
        to_phasex(peak_lists[col], full_angles[:, col]) for col in range(n_angles)
    ]).T

    sin_array, cos_array = sine_and_cosine(phase_list)
    h_sin, _, _, _ = decode_kfolds(X=full_rates, Y=sin_array)
    h_cos, _, _, _ = decode_kfolds(X=full_rates, Y=cos_array)
    predicted_sin = predicted_lines(full_rates, h_sin)
    predicted_cos = predicted_lines(full_rates, h_cos)
    arctans = arctan_fn(predicted_sin, predicted_cos)

    # Plain linear correlation on 0-360 phases: wrap-around disagreements
    # (e.g. 359 vs 1) count as large errors.
    r_array = []
    for col in range(phase_list.shape[1]):
        r, _ = stats.pearsonr(phase_list[:, col], arctans[:, col])
        r_array.append(r)
    return arctans, phase_list, h_sin, h_cos, r_array, sin_array, cos_array, predicted_sin, predicted_cos

def phase_train_injured(rates, angles, intact_phase):
    """Like phase_train, but hindlimb phase targets are replaced by matched forelimb phases.

    Builds the same 7-column phase target as phase_train. Then, for each of
    columns 0-3 (ankle, knee, hip, limbfoot), it does three things:
      - uses align_to_hind on intact_phase to get per-column shifts;
      - uses best_hindlimb_match to pick the column among 4-6 whose shifted
        intact phase correlates best with that joint's intact phase;
      - overwrites the column with this recording's chosen column, rolled by
        the same shift.
    The sine/cosine filters are then trained on this substituted target. The
    chosen column for each of columns 0-3 is printed.

    Parameters
    ----------
    rates, angles : list of arrays
        Per-recording firing rates and joint angles; see stitch_and_format.
    intact_phase : 2-D array
        Phase matrix with the same 7-column layout, presumably the phase_list
        returned by phase_train on intact data - confirm at call sites.

    Returns
    -------
    tuple
        (arctans, phase_list, h_sin, h_cos, r_array, sin_array, cos_array,
         predicted_sin, predicted_cos), with phase_list holding the substituted
        target.
    """
    full_rates, full_angles = stitch_and_format(rates, angles)
    n_angles = 7  # tailored_peaks only defines settings for columns 0-6
    # Detect every column's peaks before any phase conversion: to_phasex may
    # write the phase into the column view it is given.
    peak_lists = [tailored_peaks(full_angles, col) for col in range(n_angles)]
    phase_list = np.vstack([
        to_phasex(peak_lists[col], full_angles[:, col]) for col in range(n_angles)
    ]).T

    # Only columns 0-3 are written; best_hindlimb_match only returns 4-6, so
    # every source column read below is still unmodified.
    mates = []
    for hind_col in range(4):
        rolls = align_to_hind(intact_phase, hind_col)
        mate = best_hindlimb_match(intact_phase, rolls, hind_col)
        phase_list[:, hind_col] = np.roll(phase_list[:, mate], int(rolls[mate]))
        mates.append(mate)
    print("mapping is:", *mates)

    sin_array, cos_array = sine_and_cosine(phase_list)
    h_sin, _, _, _ = decode_kfolds(X=full_rates, Y=sin_array)
    h_cos, _, _, _ = decode_kfolds(X=full_rates, Y=cos_array)
    predicted_sin = predicted_lines(full_rates, h_sin)
    predicted_cos = predicted_lines(full_rates, h_cos)
    arctans = arctan_fn(predicted_sin, predicted_cos)

    # Plain linear correlation on 0-360 phases: wrap-around disagreements
    # (e.g. 359 vs 1) count as large errors.
    r_array = []
    for col in range(phase_list.shape[1]):
        r, _ = stats.pearsonr(phase_list[:, col], arctans[:, col])
        r_array.append(r)
    return arctans, phase_list, h_sin, h_cos, r_array, sin_array, cos_array, predicted_sin, predicted_cos
    
def arctan_fn_alt(predicted_sin, predicted_cos):
    """Recover phase angles in degrees from predicted sine/cosine values.

    Vectorized equivalent of applying math.atan2 element by element: computes
    arctan2(sin, cos) in degrees (range [-180, 180]) and adds 180, giving values
    in [0, 360]. The output has the same shape as the (broadcast) inputs.

    Parameters
    ----------
    predicted_sin, predicted_cos : array_like
        Predicted sine and cosine, matching shapes (typically samples x columns).

    Returns
    -------
    numpy.ndarray of float64 angles in degrees.
    """
    sin_values = np.asarray(predicted_sin, dtype=float)
    cos_values = np.asarray(predicted_cos, dtype=float)
    # One ufunc call replaces the per-element loop, whose np.append made it O(n^2).
    return np.degrees(np.arctan2(sin_values, cos_values)) + 180

def arctan_fn(predicted_sin, predicted_cos):
    """Recover phase angles in degrees from predicted sine/cosine columns.

    np.arctan2 yields angles in [-180, 180] degrees; adding 180 moves them onto
    [0, 360]. Works column by column on 2-D inputs of matching shape and returns
    an array with the same (rows, columns) layout.
    """
    n_cols = predicted_sin.shape[1]
    degree_cols = [
        np.degrees(np.arctan2(predicted_sin[:, c], predicted_cos[:, c])) + 180
        for c in range(n_cols)
    ]
    return np.array(degree_cols).T

def sine_and_cosine(phase_list):
    """Encode phase columns (degrees) as sine and cosine columns.

    Each phase is shifted by -180 degrees before conversion; arctan_fn adds 180
    back after arctan2, so the two functions round-trip. Works column by column
    on a 2-D array.

    Returns
    -------
    (sin_array, cos_array) : two arrays with the same (rows, columns) layout
        as phase_list.
    """
    shifted_rad = np.radians(phase_list - 180)
    n_cols = shifted_rad.shape[1]
    sin_cols = [np.sin(shifted_rad[:, c]) for c in range(n_cols)]
    cos_cols = [np.cos(shifted_rad[:, c]) for c in range(n_cols)]
    return np.array(sin_cols).T, np.array(cos_cols).T
    
def predicted_lines(actual, H):
    """Run test_wiener_filter once per column of H and stack the predictions.

    actual : input passed unchanged to test_wiener_filter for every column.
    H : 2-D weight matrix; H[:, c] is passed as the filter weights for output c.
    Returns an array with one column per column of H (the per-column results
    are stacked as rows, then transposed).
    """
    per_column = [test_wiener_filter(actual, H[:, col]) for col in range(H.shape[1])]
    return np.array(per_column).T
    
def to_phasex(peaks, angles):
    """Convert an angle trace into a 0-360 sawtooth phase between consecutive peaks.

    For each pair of consecutive peak indices (start, stop), samples
    start..stop-1 get phase j * 360 / (stop - start) for j = 0, 1, ...
    The last sample is set to 0 when at least two peaks are given.

    Parameters
    ----------
    peaks : sequence of int
        Sorted sample indices; tailored_peaks returns them with 0 and the last
        index included, so every sample is covered.
    angles : 1-D array
        The trace to convert. It is NOT modified, and samples outside
        [peaks[0], peaks[-1]) keep their original value in the copy.

    Returns
    -------
    numpy.ndarray (float64): a new phase array the same length as angles.
    """
    # Work on a float copy: writing in place clobbered the caller's array, and an
    # integer-dtype input truncated the fractional phases.
    phase = np.array(angles, dtype=float)
    for start, stop in zip(peaks[:-1], peaks[1:]):
        span = stop - start
        if span > 0:  # duplicate peak indices contribute an empty segment
            phase[start:stop] = np.arange(span) * 360 / span
    if len(peaks) >= 2:
        phase[-1] = 0
    return phase

                
def tailored_peaks(angles, index):
    """Find cycle peaks in one angle column using per-column hand-tuned settings.

    Parameters
    ----------
    angles : 2-D array
        Samples x angle columns. Only column `index` is read.
    index : int
        Column to analyse, 0-6; it also selects the tuned find_peaks settings.
        Any other value raises KeyError.

    Returns
    -------
    numpy.ndarray of sample indices: the detected peaks with 0 prepended and the
    last sample index appended, so consecutive pairs bound every cycle.
    """
    # index -> (invert, prominence, distance, width, height_factor)
    # invert=True searches the negated trace (i.e. troughs); the height threshold
    # is then -height_factor * mean(trace) so it applies to the negated signal.
    # height_factor=None means no height threshold.
    settings = {
        0: (True, 5, 5, 2, 1.2),
        1: (False, 10, None, None, 1.0),
        2: (False, 5, None, 2, 1.0),
        3: (False, 10, None, None, 1.0),
        4: (False, 6.5, 5, None, 1.1),
        5: (True, 5, None, None, 1.1),
        6: (False, 9, 5, None, None),
    }
    invert, prominence, distance, width, height_factor = settings[index]

    trace = angles[:, index]
    peak_signal = -trace if invert else trace
    if height_factor is None:
        height = None
    else:
        height = height_factor * np.mean(trace)
        if invert:
            height = -height

    peaks, _ = spicy.signal.find_peaks(peak_signal, prominence=prominence, distance=distance,
                                       width=width, height=height)
    return np.concatenate([[0], peaks, [trace.shape[0] - 1]])

# def wave_align(waves):#, plot_req):
#     temp_shift = []
#     ts = np.linspace(0, (waves.shape[0]*50)/1000,waves.shape[0])
#     dx = np.mean(np.diff(ts))
#     for target in [waves[:,0], waves[:,1], waves[:,2]]:
#         shift = (np.argmax(signal.correlate(waves[:,3], target)) - len(target)) * dx
#         temp_shift = np.append(temp_shift, shift)
#     temp_shift = np.append(temp_shift, 0)
#     for target in [waves[:,4], waves[:,5]]:
#         shift = (np.argmax(signal.correlate(waves[:,6], target)) - len(target)) * dx
#         temp_shift = np.append(temp_shift, shift)
#     temp_shift = np.append(temp_shift, 0)
#     # temp_shift = np.array(temp_shift)
#     return temp_shift

def align_to_hind(phases, target):
    """Per-column sample shift that best aligns each column with column `target`.

    For every column i != target, finds the lag that maximises the full linear
    cross-correlation of phases[:, i] with phases[:, target]. It returns the
    shift s such that np.roll(phases[:, i], s) lines up with the target. For a
    pure delay of d samples, s == -d exactly.

    Parameters
    ----------
    phases : 2-D array
        Samples x columns.
    target : int
        Reference column; its own entry is 0.

    Returns
    -------
    numpy.ndarray (float64) with one integer-valued shift per column.
    """
    n_samples = phases.shape[0]
    reference = phases[:, target]
    shifts = np.zeros(phases.shape[1])
    for col in range(phases.shape[1]):
        if col == target:
            continue
        xcorr = signal.correlate(phases[:, col], reference)
        # In 'full' mode zero lag sits at index n_samples - 1 (not n_samples).
        # Working directly in samples also avoids the old seconds<->samples
        # round-trip, whose dx = N*0.05/(N-1) skewed and truncated the result.
        lag = np.argmax(xcorr) - (n_samples - 1)
        shifts[col] = -lag
    return shifts

# def alignment_check(tsf, arctans, full_phase):
#     dx = np.mean(np.diff(tsf))
#     temp_shift = []
#     target = [arctans[:,0], arctans[:,1], arctans[:,2], arctans[:,3], arctans[:,4], arctans[:,5], arctans[:,6]]
#     for i in range(len(target)):
#         shift = (np.argmax(signal.correlate(full_phase[:,i], target[i])) - len(target[i])) * dx
#         temp_shift.append(shift)
#     return temp_shift

def stitch_and_format(firing_rates_list, resampled_angles_list):
    """
    Format each (rates, angles) pair with format_data and stitch the results
    into one rates array and one angles array.

    The original note says format_data (from wiener_filter.py) converts the
    data into lags of 10 - confirm there. Both lists must have the same number
    of elements, and each rates array must match its angles array in size.

    A single pair is wrapped with np.array. Several pairs are stacked:
    rates with vstack; angles with vstack when 2-D, hstack when 1-D. Both
    results are then np.squeeze'd, which drops any length-1 axes.
    """
    assert isinstance(firing_rates_list, list), 'rates must be list'
    assert isinstance(resampled_angles_list, list), 'angles must be list'

    formatted_rates, formatted_angles = [], []
    for k, rate_arr in enumerate(firing_rates_list):
        f_rate, f_angle = format_data(rate_arr, resampled_angles_list[k])
        formatted_rates.append(f_rate)
        formatted_angles.append(f_angle)

    rates = np.array(formatted_rates) if len(formatted_rates) == 1 else np.vstack(formatted_rates)

    if len(formatted_angles) == 1:
        kin = np.array(formatted_angles)
    elif formatted_angles[0].ndim > 1:
        kin = np.vstack(formatted_angles)
    else:
        kin = np.hstack(formatted_angles)
    return np.squeeze(rates), np.squeeze(kin)
    
def impulse_response(H_mat, AOI, n_channels=32, n_lags=10):
    """Plot each input channel's lagged impulse response on one output column of H_mat.

    Each response is the product of H_mat with a probe vector [1, e]. Here e
    is an (n_lags, n_channels) block that is zero except for a 1 at row
    n_lags-1-j, column i (channel i, shift j), flattened row-major. Because
    the probe is one-hot apart from the leading 1, the product equals
    H_mat[0] + H_mat[1 + (n_lags-1-j)*n_channels + i], which is computed
    directly here. Whether row n_lags-1 is the most recent bin depends on
    format_data - not visible here.

    Parameters
    ----------
    H_mat : 2-D array
        Filter weights with 1 + n_lags * n_channels rows; row 0 multiplies the
        constant 1.
    AOI : int
        Output column of H_mat to plot.
    n_channels, n_lags : int
        Probe block layout; the defaults (32, 10) give the 321-row layout.

    Raises
    ------
    ValueError if H_mat's row count does not equal 1 + n_lags * n_channels.

    Draws one line per channel (x = shift j) on a new figure; returns None.
    """
    H_mat = np.asarray(H_mat)
    expected_rows = 1 + n_lags * n_channels
    if H_mat.shape[0] != expected_rows:
        raise ValueError(
            f"H_mat has {H_mat.shape[0]} rows; expected {expected_rows} "
            f"(1 + {n_lags} lags * {n_channels} channels)")

    fig1, ax1 = plt.subplots()
    lag_axis = np.arange(n_lags)
    for channel in range(n_channels):
        # Probe j places its 1 in lag row n_lags-1-j; +1 skips the bias row.
        rows = 1 + (n_lags - 1 - lag_axis) * n_channels + channel
        response = H_mat[0] + H_mat[rows]
        ax1.plot(lag_axis, response[:, AOI])
    ax1.set_xticks(lag_axis)
    ax1.set_xlabel("lag index")
    ax1.set_ylabel(f"response of output column {AOI}")
        
def vaf(x,xhat):
    """
    Variance accounted for (VAF) of predictions xhat relative to actual x.

    Both arrays are mean-centred along axis 0 first, so a constant offset in
    xhat is not penalised. Returns 1 - SS(x - xhat) / SS(x) on the centred data.
    x: actual values, a numpy array
    xhat: predicted values, a numpy array
    """
    centred_actual = x - x.mean(axis=0)
    centred_pred = xhat - xhat.mean(axis=0)
    residual_ss = np.sum(np.square(centred_actual - centred_pred))
    total_ss = np.sum(np.square(centred_actual))
    return 1 - residual_ss / total_ss

def best_hindlimb_match(phase_list, roll, AOI):
    """Pick which of columns 4-6 best matches column AOI after shifting.

    Each candidate column c is rolled by int(roll[c]) samples and correlated
    (Pearson r) with phase_list[:, AOI]. The column with the highest r is
    returned; on ties the lowest column index wins.

    Returns
    -------
    int in {4, 5, 6}.
    """
    target_phase = phase_list[:, AOI]
    best_col, best_r = None, None
    for col in (4, 5, 6):
        shifted = np.roll(phase_list[:, col], int(roll[col]))
        r, _ = stats.pearsonr(shifted, target_phase)
        if best_r is None or r > best_r:
            best_col, best_r = col, r
    return best_col

In [79]:
# with open('/mnt/c/oobootoo/rat-fes/data/pickles/3results-gregintact_729_session.pkl', 'rb') as inp:
#     session729 = pickle.load(inp)
# rates729 = session729.data['rates']
# angles729 = session729.data['angles']
# coords729 = session729.data['coords']
# arctans729, phase_list729, H_sin729, H_cos729, VAF729, sin_array, cos_array, predicted_sin, predicted_cos = phase_train(rates729, angles729)
# tsf729 = np.linspace(0, (phase_list729.shape[0]*50)/1000,phase_list729.shape[0])
# fig512, ax = plt.subplots(figsize=(12,8), sharex = True)
# ax.set_title('7/29 knee_phase r=' + "{0:.3f}".format(VAF729[1]))  # VAF729 holds Pearson r values (r_array), not VAF
# ax.plot(tsf729, arctans729[:,1], c='r', alpha=0.5, label = "predicted")
# ax.plot(tsf729, phase_list729[:,1], c='k', alpha=0.5, label = "actual")
# ax.legend(loc="lower right")
# fig512.tight_layout()

In [80]:
# arctans729injured, phase_list729injured, H_sin729injured, H_cos729injured, VAF729injured, sin_arrayinjured, cos_arrayinjured, predicted_sininjured, predicted_cosinjured = phase_train_injured(rates729, angles729, phase_list729)
# tsf729injured = np.linspace(0, (phase_list729injured.shape[0]*50)/1000,phase_list729injured.shape[0])
# fig599, ax = plt.subplots(figsize=(12,8), sharex = True)
# ax.set_title('knee phase on hypothetical training data')
# ax.plot(tsf729injured, arctans729injured[:,1], c='r', alpha=0.5, label = "predicted phase")
# ax.plot(tsf729injured, phase_list729injured[:,1], c='k', linestyle='--',alpha=0.5, label = "hypothetical phase (training data)")
# ax.plot(tsf729, phase_list729[:,1], c='k', alpha=0.5, label = "actual phase")
# ax.legend(loc="lower right")
# fig599.tight_layout()