In [4]:
import sys
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
import matplotlib
from matplotlib import pyplot as plt
from dca.dca import DynamicalComponentsAnalysis
from dca_research.kca import KalmanComponentsAnalysis
from dca_research.lqg import LQGComponentsAnalysis
from sklearn.linear_model import LinearRegression
from scipy.signal import find_peaks


In [5]:
sys.path.append('/home/akumar/nse/neural_control')

In [6]:
from loaders import load_peanut
from loaders import segment_peanut
from loaders import location_bin_peanut

In [7]:
def calc_loadings(U, d=1):
    """Collapse a projection matrix into one normalised loading per feature.

    The squared magnitude of ``U`` is summed over its last axis (the
    components), the result is folded into ``d`` equally sized rows, the rows
    are added together, and the outcome is divided by its maximum so that the
    largest loading equals 1.

    Parameters
    ----------
    U : array_like
        Projection weights. The last axis is summed out.
        Presumably shaped (d * n_features, n_components) -- confirm with caller.
    d : int, optional
        Number of rows the per-entry sums are folded into before being added
        together. The default of 1 leaves the per-entry sums unchanged.

    Returns
    -------
    numpy.ndarray
        1-D array of loadings scaled to a maximum of 1. If every loading is
        zero the division yields NaN.
    """
    # Squared magnitude, summed over the last axis (components).
    energy = np.sum(np.power(np.abs(U), 2), axis=-1)
    # Fold into d rows, then add the rows together.
    energy = np.reshape(energy, (d, -1))
    loadings = np.sum(energy, axis=0)
    # Divide out-of-place: an in-place `/=` cannot store the float quotient in
    # an integer array and raises when U has an integer dtype.
    return loadings / np.max(loadings)
    
def getLoadingsonTransitions(X, d=2, T=3):
    """Fit DCA, KCA, LQG ("FCA") and PCA to trialized transitions and return loadings.

    Parameters
    ----------
    X : sequence of array_like
        One array per transition. The sequence is handed unchanged to the
        DCA/KCA/LQG models and stacked row-wise into a single matrix for PCA.
        Presumably each array is (n_timepoints, n_neurons) -- confirm with caller.
    d : int, optional
        Number of components requested from every model (default 2).
    T : int, optional
        ``T`` argument forwarded to the DCA/KCA/LQG constructors (default 3).

    Returns
    -------
    tuple
        ``(PCA_loading, DCA_loading, KCA_loading, FCA_loading)``, each produced
        by ``calc_loadings`` with its default ``d=1``.
    """
    DCAmodel = DynamicalComponentsAnalysis(d=d, T=T)
    PCAmodel = PCA(n_components=d)
    KCAmodel = KalmanComponentsAnalysis(d=d, T=T)
    FCAmodel = LQGComponentsAnalysis(d=d, T=T)
    DCAmodel.fit(X)
    KCAmodel.fit(X)
    FCAmodel.fit(X)
    # PCA takes a single 2-D matrix, so stack all transitions in one call
    # instead of re-copying the growing array once per transition.
    PCAmodel.fit(np.vstack(X))

    # components_ is transposed before being passed on, unlike the coef_ arrays.
    PCA_loading = calc_loadings(PCAmodel.components_.T)
    DCA_loading = calc_loadings(DCAmodel.coef_)
    KCA_loading = calc_loadings(KCAmodel.coef_)
    FCA_loading = calc_loadings(FCAmodel.coef_)

    return PCA_loading, DCA_loading, KCA_loading, FCA_loading

In [8]:
supervised_df = pd.read_pickle("/home/akumar/nse/neural_control/data/peanut_segmented_supervised.dat")

In [16]:
from tqdm import tqdm

In [19]:
epochs = [2, 4, 6, 8, 10, 12, 14, 16]

# Each file is read by two different loaders below; name the paths once.
data_path = '/mnt/Secondary/data/peanut/data_dict_peanut_day14.obj'
linearization_path = '/mnt/Secondary/data/peanut/linearization_dict_peanut_day14.obj'

results_list = []

# Wrap the list itself rather than enumerate(...): tqdm can then read the
# length and show a real progress bar instead of a bare iteration counter.
for epoch in tqdm(epochs):

    dat = load_peanut(data_path, epoch, spike_threshold=200, bin_width=1,
                      boxcox=None, speed_threshold=4)

    transitions1, transitions2 = segment_peanut(dat, linearization_path, epoch)

    spike_rates = dat['spike_rates']

    # Fit PCA/DCA/KCA/FCA separately on each of the two kinds of transitions.
    spike_rates_list_transition1 = [spike_rates[transit] for transit in transitions1]
    PCA_loading_1, DCA_loading_1, KCA_loading_1, FCA_loadings1 = \
        getLoadingsonTransitions(spike_rates_list_transition1)

    spike_rates_list_transition2 = [spike_rates[transit] for transit in transitions2]
    PCA_loading_2, DCA_loading_2, KCA_loading_2, FCA_loadings2 = \
        getLoadingsonTransitions(spike_rates_list_transition2)

    # Supervised loadings: first matching row of fold 1 for each transition type.
    SS_loading_1, SS_loading_2 = [
        supervised_df.loc[(supervised_df['epoch'] == epoch)
                          & (supervised_df['fold_idx'] == 1)
                          & (supervised_df['transition_type'] == transition_type),
                          "loadings"].iloc[0]
        for transition_type in (1, 2)
    ]

    # The fit on the un-segmented spike rates is disabled; NaN placeholders keep
    # the three-entry layout of the result lists.
    # PCA_loading_3, DCA_loading_3, KCA_loading_3, FCA_loadings3 = getLoadingsonTransitions(spike_rates)
    PCA_loading_3 = np.nan
    DCA_loading_3 = np.nan
    KCA_loading_3 = np.nan
    FCA_loadings3 = np.nan

    transitions, bins_ = location_bin_peanut(data_path, linearization_path,
                                             epoch=epoch, spike_threshold=200)

    num_peaks = []

    # The per-method loadings ride along in the zip (which also caps the
    # iteration at the two transition types); the peak count itself only needs
    # `transition` and `bins`.
    for transition, bins, dcaloading, pcaloading, kcaloading, ssloading, tran_idx in zip(
            transitions, bins_,
            [DCA_loading_1, DCA_loading_2], [PCA_loading_1, PCA_loading_2],
            [KCA_loading_1, KCA_loading_2], [SS_loading_1, SS_loading_2], [1, 2]):
        num_peaks_transition = np.zeros(transition.shape[1])
        for neuron_idx in range(transition.shape[1]):
            activity = transition[:, neuron_idx]
            peak_indices = find_peaks(activity)[0]
            # Linear trend of this column against bins[1:]; only peaks that
            # rise above the fitted line are counted.
            reg = LinearRegression().fit(bins[1:, np.newaxis], activity)
            predicted_line = reg.intercept_ + np.multiply(bins[1:], np.squeeze(reg.coef_))
            above_fit_line_peak_idxs = peak_indices[activity[peak_indices] > predicted_line[peak_indices]]
            num_peaks_transition[neuron_idx] = len(above_fit_line_peak_idxs)
        num_peaks.append(num_peaks_transition)

    result = {'epoch': epoch, 'num_peaks': num_peaks,
              'PCA_loadings': [PCA_loading_1, PCA_loading_2, PCA_loading_3],
              'DCA_loadings': [DCA_loading_1, DCA_loading_2, DCA_loading_3],
              'KCA_loadings': [KCA_loading_1, KCA_loading_2, KCA_loading_3],
              'SS_loadings': [SS_loading_1, SS_loading_2],
              'FCA_loadings': [FCA_loadings1, FCA_loadings2, FCA_loadings3]}

    results_list.append(result)

  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
  cross_cov_mats = torch.tensor(cross_cov_mats)
8it [01:25, 10.67s/it]


In [None]:
# Box plots of loadings, split by whether a neuron has exactly one counted peak.
# Data is read from results_list (one entry per epoch), so each panel shows its
# own epoch instead of whatever variables the last loop iteration left behind.
# One 2x4 figure is drawn per transition type.

# (tick label prefix, results_list key) in the order the boxes are drawn.
methods = [('KCA', 'KCA_loadings'), ('PCA', 'PCA_loadings'),
           ('SS', 'SS_loadings'), ('DCA', 'DCA_loadings')]
box_positions = [1, 3, 5, 7, 9, 11, 13, 15]

for tran_idx in range(2):
    fig, ax = plt.subplots(2, 4, figsize=(16, 8))
    for a, result in zip(ax.flat, results_list):
        # Per-neuron peak counts for this epoch and transition type.
        peak_counts = np.asarray(result['num_peaks'][tran_idx])
        single_peak = peak_counts == 1

        groups = []
        tick_labels = []
        for label, key in methods:
            # assumes each loading vector is 1-D with one entry per neuron, in
            # the same order as peak_counts -- TODO confirm for SS_loadings
            loading = np.asarray(result[key][tran_idx])
            groups.append(loading[single_peak])
            # NOTE(review): this group is "peak count != 1", so it also holds
            # neurons with zero counted peaks even though the label says ">1PF".
            groups.append(loading[~single_peak])
            tick_labels.extend([label + ' 1PF', label + ' >1PF'])

        a.boxplot(groups, positions=box_positions)
        a.set_xticks(box_positions)
        a.set_xticklabels(tick_labels)
        a.set_ylabel("Loadings")
        a.set_title("Distribution of loadings (epoch {0})".format(result['epoch']))

    fig.suptitle("Transition type {0}".format(tran_idx + 1))
    fig.tight_layout()
