In [1]:
import pickle
import pandas as pd
import numpy as np
import itertools
import matplotlib.pyplot as plt
import glob as glob
import os
import scipy
import sys; sys.path.append('/home/marcush/projects/neural_control/analysis_scripts/turnkey')

from dca.cov_util import form_lag_matrix, calc_cross_cov_mats_from_data
from config import PATH_DICT; sys.path.append(PATH_DICT['repo'])
from region_select import *
from loaders import *
from utils import calc_loadings
from collections import defaultdict
from sklearn.cross_decomposition import CCA
from sklearn.model_selection import KFold


# New
import numpy as np
import matplotlib.pyplot as plt
import time
import sys
import pickle
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from sklearn.model_selection import KFold
from dca.cov_util import form_lag_matrix, calc_cross_cov_mats_from_data
import glob
import pdb
from statsmodels.tsa import stattools
from mpi4py import MPI

### Load consolidated decoding dataframe

In [2]:
# Read the consolidated decoding results from disk and wrap them in a DataFrame.
# NOTE: pickle.load executes arbitrary code — only point this at trusted files.
decoding_glom_path = '/clusterfs/NSDS_data/FCCA/postprocessed/tsao_decode_df.pkl'
with open(decoding_glom_path, mode='rb') as pkl_file:
    dat = pickle.load(pkl_file)
df_decode = pd.DataFrame(dat)

In [16]:
# Each MPI rank handles only its own contiguous slice of the (sorted) sessions.
comm = MPI.COMM_WORLD
fold_splits = np.unique(df_decode['fold_idx'].values)
regions = np.unique(df_decode['loader_args'].map(lambda loader_args: loader_args.get('region')))
sessions = np.array_split(np.unique(df_decode['data_file']), comm.size)[comm.rank]

In [18]:
# Load spike rates for every (session, region) pair handled by this rank.
# all_spikes[session][region] -> the 'spike_rates' entry returned by load_tsao.
all_spikes = {}
for session in sessions:
    all_spikes[session] = {}
    # Restrict to this session first: filtering on region alone matches rows from
    # every session, so each session would silently load the same recording.
    df_session = df_decode[df_decode['data_file'] == session]
    for region in regions:
        df_ = apply_df_filters(df_session, **{'loader_args': {'region': region}})
        if len(df_) == 0:
            raise ValueError(f"No decoding rows for session {session!r} and region {region!r}")
        # Positional access: the filtered frame need not contain index label 0.
        dat = load_tsao(**dict(df_['full_arg_tuple'].iloc[0]))
        all_spikes[session][region] = dat['spike_rates']

Begin Loading Data...
Done Loading Data
Begin getting spike times...
Done getting spike times
Begin filtering spike times into spike rates...
FILTERING SPIKE RATES!
Done filtering spike times into spike rates
Begin Loading Data...
Done Loading Data
Begin getting spike times...
Done getting spike times
Begin filtering spike times into spike rates...
FILTERING SPIKE RATES!
Done filtering spike times into spike rates
Begin Loading Data...
Done Loading Data
Begin getting spike times...
Done getting spike times
Begin filtering spike times into spike rates...
FILTERING SPIKE RATES!
Done filtering spike times into spike rates
Begin Loading Data...
Done Loading Data
Begin getting spike times...
Done getting spike times
Begin filtering spike times into spike rates...
FILTERING SPIKE RATES!
Done filtering spike times into spike rates


# Fit cross-validated CCA between ML and AM

In [21]:
RELOAD = True
max_cca_dim_check = 50

tmp_path = PATH_DICT['tmp']
# Under MPI each rank holds a different subset of sessions; give each rank its own
# file so concurrent ranks do not overwrite one another's results.
if comm.size == 1:
    save_path_cca_corrs = f"{tmp_path}/CCA_Analysis_Tsao.pkl"
else:
    save_path_cca_corrs = f"{tmp_path}/CCA_Analysis_Tsao_rank{comm.rank}.pkl"

if RELOAD:

    # Relative lags (in samples) and window lengths to sweep; currently a single setting.
    lags = np.array([0])
    windows = np.array([1])
    results = []

    for session in sessions:

        # ML is treated as the X view and AM as the Y view throughout this notebook.
        X = all_spikes[session]['ML']
        Y = all_spikes[session]['AM']

        # Collapse all leading axes so rows are samples and columns are units.
        X = X.reshape(-1, X.shape[-1])
        Y = Y.reshape(-1, Y.shape[-1])

        for lag in lags:
            for window in windows:
                for fold_idx, (train_idxs, _test_idxs) in enumerate(KFold(n_splits=len(fold_splits)).split(X)):

                    x = X[train_idxs]
                    y = Y[train_idxs]

                    # Shift Y relative to X so that x[t] is paired with y[t + lag].
                    if lag != 0:
                        x = x[:-lag, :]
                        y = y[lag:, :]

                    if window > 1:
                        x = form_lag_matrix(x, window)
                        y = form_lag_matrix(y, window)

                    # CCA cannot have more components than the smaller view has features.
                    n_components = min(max_cca_dim_check, x.shape[-1], y.shape[-1])
                    ccamodel = CCA(n_components=n_components)
                    ccamodel.fit(x, y)
                    X_c, Y_c = ccamodel.transform(x, y)
                    # In-sample (training-fold) correlation of each canonical variate pair.
                    canonical_correlations = [scipy.stats.pearsonr(X_c[:, i], Y_c[:, i])[0]
                                              for i in range(n_components)]

                    r = {
                        'dfile': session,
                        'lag': lag,
                        'win': window,
                        'fold_idx': fold_idx,
                        'ccamodel': ccamodel,
                        'canonical_correlations': canonical_correlations,
                    }
                    results.append(r)
                    print(f"Done with fold {fold_idx+1}")

    df_results = pd.DataFrame(results)
    with open(save_path_cca_corrs, 'wb') as f:
        pickle.dump(df_results, f)
else:

    with open(save_path_cca_corrs, 'rb') as f:
        df_results = pickle.load(f)
    print("Loading previous CCA fit to split data.")



Done with fold 1




KeyboardInterrupt: 

### Find saturating CCA dimensionality

In [None]:
# p <= q
def _cc_information_criterion(cc_coefs, N, p, q, penalty_weight):
    """Score candidate CCA dimensionalities with a penalised likelihood-ratio criterion.

    Parameters
    ----------
    cc_coefs : array-like of float
        Canonical correlation coefficients, in any order (sorted descending here).
    N : int
        Number of samples the correlations were estimated from.
    p, q : int
        Dimensionalities of the two views, with ``p <= q``.
    penalty_weight : float
        Multiplier on the ``(p - d) * (q - d)`` degrees-of-freedom term
        (2 for AIC, ``log(N)`` for BIC).

    Returns
    -------
    np.ndarray
        ``A[k]`` scores a model that keeps ``d = k + 1`` canonical dimensions:
        ``-N * sum(log(1 - r_j**2) over the discarded r_j) - penalty_weight * (p - d) * (q - d)``.
        ``np.argmin(A) + 1`` is therefore the selected dimensionality.
    """
    r_sorted = np.sort(np.asarray(cc_coefs, dtype=float))[::-1]
    n_kept = np.arange(1, r_sorted.size + 1)
    # Both the residual sum and the penalty are evaluated at the same d = k + 1,
    # including d = len(cc_coefs) (nothing discarded, residual term is zero).
    residual = np.array([-N * np.sum(np.log(1 - r_sorted[d:] ** 2)) for d in n_kept])
    return residual - penalty_weight * (p - n_kept) * (q - n_kept)


def CC_AIC(cc_coefs, N, p, q):
    """AIC-style criterion over CCA dimensionality; entry k corresponds to k + 1 dims (p <= q)."""
    return _cc_information_criterion(cc_coefs, N, p, q, 2.0)


def CC_BIC(cc_coefs, N, p, q):
    """BIC-style criterion over CCA dimensionality; entry k corresponds to k + 1 dims (p <= q)."""
    return _cc_information_criterion(cc_coefs, N, p, q, np.log(N))

In [87]:
n_splits = len(fold_splits)

for session in sessions:

    # This session's fold results, ordered so that row i corresponds to fold i.
    df_results_sess = (
        df_results[df_results['dfile'] == session]
        .sort_values('fold_idx')
        .reset_index(drop=True)
    )
    assert len(df_results_sess) == n_splits, (
        f"expected {n_splits} fold results for {session}, got {len(df_results_sess)}"
    )
    canonical_correlations = np.array(df_results_sess['canonical_correlations'].to_list())
    n_dims = canonical_correlations.shape[1]

    # Rebuild this session's data (not whatever X/Y a previous cell left behind) so the
    # KFold splits, sample counts and unit counts match the fits being scored.
    X = all_spikes[session]['ML']
    Y = all_spikes[session]['AM']
    X = X.reshape(-1, X.shape[-1])
    Y = Y.reshape(-1, Y.shape[-1])
    p = min(X.shape[1], Y.shape[1])
    q = max(X.shape[1], Y.shape[1])

    ### FIGURE 1: spread of each canonical correlation across folds
    fig, ax = plt.subplots(figsize=(4, 4))
    medianprops = dict(linewidth=0)
    ax.boxplot(canonical_correlations, patch_artist=True, medianprops=medianprops, notch=True)
    nTicks = 10
    ax.set_xticks(np.arange(1, n_dims, nTicks))
    ax.set_xticklabels(np.arange(1, n_dims, nTicks))
    ax.set_xlim([0, n_dims])
    ax.set_ylabel('Canonical Correlation Coefficient')
    ax.set_xlabel('Dimension')
    ax.set_title(session)

    #### FIGURE 2: selected dimensionality per fold (columns: AIC, BIC)
    # Criterion index k corresponds to k + 1 retained dimensions, so store argmin + 1.
    cc_dim = np.zeros((n_splits, 2), dtype=int)
    for split, (train_idxs, _test_idxs) in enumerate(KFold(n_splits=n_splits).split(X)):
        N = len(train_idxs)
        cc_dim[split, 0] = np.argmin(CC_AIC(canonical_correlations[split, :], N, p, q)) + 1
        cc_dim[split, 1] = np.argmin(CC_BIC(canonical_correlations[split, :], N, p, q)) + 1

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.boxplot(cc_dim)
    ax.set_xticklabels(['AIC', 'BIC'])
    ax.set_ylabel('Dimension')
    ax.set_title(session)

    # Printed values use the same (1-based) dimension counts as the plot.
    print(f"Optimal AIC Dim for session {session}: {np.median(cc_dim[:,0])}")
    print(f"Optimal BIC Dim for session {session}: {np.median(cc_dim[:,1])}")
    print(f"Optimal CCA Dim for session {session}: {np.mean([np.median(cc_dim[:,0]), np.median(cc_dim[:,1])])}")


### Fit CCA on all data and save the fitted models

In [23]:
# Fit one CCA model per session on all samples (no cross-validation) and pickle it.
# NOTE(review): the dimensionality is hard-coded; presumably chosen by hand from the
# AIC/BIC selection above — confirm before reusing for other datasets.
for session in sessions:

    X = all_spikes[session]['ML']
    X = X.reshape(-1, X.shape[-1])
    Y = all_spikes[session]['AM']
    Y = Y.reshape(-1, Y.shape[-1])

    manual_CCA_dim = 21
    ccamodel = CCA(n_components=manual_CCA_dim).fit(X, Y)

    cca_save_path = f'/clusterfs/NSDS_data/FCCA/postprocessed/CCA_structs/CCA_{session}_{manual_CCA_dim}_dims.pkl'
    with open(cca_save_path, mode='wb') as pkl_file:
        pickle.dump(ccamodel, pkl_file)