In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd, eigh, pinv
from typing import List, Dict, Callable, Optional
import sys
from my_dpca_module import myDPCA


### Include path for my customized dPCA class (/utils/myDPCA)

In [None]:
# Make the customized dPCA package importable from this notebook.
# NOTE(review): ROOTDIR is not defined in any visible cell; it must be set
# (with a trailing path separator) before this cell is run.
_mydpca_dir = f'{ROOTDIR}2021-22_Attention/NP 2023-12/myDPCA'
if _mydpca_dir not in sys.path:  # keep the cell idempotent across re-runs
    sys.path.append(_mydpca_dir)

In [None]:
import os  # needed by os.listdir below; it was never imported in the original cells

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from sklearn.decomposition import PCA
from myDPCA import dpca, dpca_plot, dpca_explainedVariance, dpca_perMarginalization, dpca_optimizeLambda, dpca_getNoiseCovariance
# These helpers are called in later cells but were missing from the import list.
# NOTE(review): assumes myDPCA exports them under these names -- confirm.
from myDPCA import dpca_plot_default, dpca_classificationAccuracy, dpca_classificationPlot

# Directory holding the per-session .mat files.
# NOTE(review): ROOTDIR is not defined in any visible cell; set it before running.
data_dir = f'{ROOTDIR}2021-22_Attention/NP 2023-12/data for dpca/saline'
# Sorted so that positional indices into data_files are stable between runs.
data_files = sorted([f"{data_dir}/{file}" for file in os.listdir(data_dir)])

# Marginalization grouping passed to the dpca_* calls, with display names and RGB colours.
# NOTE(review): margNames/margColours must follow the same order as combinedParams -- confirm.
combinedParams = [{'stimulus', 'stimulus/time'}, {'decision'}]
margNames = ['Decision', 'Condition-independent']
margColours = np.array([[187, 20, 25], [150, 150, 150]]) / 256

# Load first data file to inspect structure and initialize arrays
session_data = loadmat(data_files[0])
# NOTE(review): index 71 is hard-coded and depends on the sorted directory listing -- confirm.
session_data_punish = loadmat(data_files[71])



### Testing data

In [None]:
def _unit_trains(mat_var):
    """Return a sequence with one per-unit array from a ``loadmat`` variable.

    ``scipy.io.loadmat`` returns MATLAB cell arrays as 2-D object arrays
    (1 x N or N x 1). Flattening those lets ``seq[j]`` address unit ``j`` and
    makes ``len(seq)`` the number of units. Non-object arrays are returned
    unchanged.
    """
    arr = np.asarray(mat_var)
    return arr.ravel() if arr.dtype == object else arr


# One (trials, time) array per unit for each of the two conditions.
Session_PriorPunish_Train = _unit_trains(session_data['Session_PriorPunish_Train'])
Session_PriorSuccess_Train = _unit_trains(session_data['Session_PriorSuccess_Train'])

n_units = len(Session_PriorPunish_Train)
n_timepoints = Session_PriorPunish_Train[0].shape[1]

# Time axis for plotting.
# NOTE(review): presumably 20 samples per second with the window starting at -5 -- confirm.
time = np.arange(n_timepoints) / 20 - 5
timeEvents = [time[-1]]

# Trials per unit: column 0 = prior-punish, column 1 = prior-success.
trialNum = np.zeros((n_units, 2))
trialNum[:, 0] = [train.shape[0] for train in Session_PriorPunish_Train]
trialNum[:, 1] = [train.shape[0] for train in Session_PriorSuccess_Train]

# trialNum is a float array, so its maximum must be cast to int before being
# used as an array dimension (np.full rejects float shape entries).
max_trials = int(trialNum.max())

# (unit, condition, time, trial); unused trial slots stay NaN so nanmean ignores them.
firingRates = np.full((n_units, 2, n_timepoints, max_trials), np.nan)
for j in range(n_units):
    punish = Session_PriorPunish_Train[j]
    success = Session_PriorSuccess_Train[j]
    firingRates[j, 0, :, :punish.shape[0]] = punish.T
    firingRates[j, 1, :, :success.shape[0]] = success.T

# Trial-averaged rates: (unit, condition, time).
firingRatesAverage = np.nanmean(firingRates, axis=3)

# Plain PCA baseline, run in unit space so that W is (units x components).
# NOTE(review): assumes myDPCA expects decoder/encoder matrices with one row
# per unit, like the W, V returned by dpca() -- confirm.
X = firingRatesAverage.reshape(firingRatesAverage.shape[0], -1)  # units x (condition*time)
X = X - X.mean(axis=1, keepdims=True)  # centre each unit over all conditions and time points
pca = PCA(n_components=20)
pca.fit(X.T)  # samples = condition/time points, features = units
W = pca.components_.T  # principal axes as columns: (units, 20)

dpca_plot(firingRatesAverage, W, W, plot_function=dpca_plot_default)

# Explained variance calculation
explVar = dpca_explainedVariance(firingRatesAverage, W, W, combinedParams=combinedParams)

dpca_plot(firingRatesAverage, W, W, plot_function=dpca_plot_default, explainedVar=explVar,
          time=time, timeEvents=timeEvents, marginalizationNames=margNames, marginalizationColours=margColours)

# dPCA with a small fixed penalty (lambda_=1e-4) and no noise covariance.
W, V, whichMarg = dpca(
    firingRatesAverage,
    n_components=20,
    combinedParams=combinedParams,
    lambda_=1e-4,
)
explVar = dpca_explainedVariance(
    firingRatesAverage, W, V,
    combinedParams=combinedParams,
)

# Component overview, coloured by marginalization.
# NOTE(review): timeMarginalization=3 with only two marginalization names -- confirm the index.
dpca_plot(
    firingRatesAverage, W, V,
    plot_function=dpca_plot_default,
    explainedVar=explVar,
    marginalizationNames=margNames,
    marginalizationColours=margColours,
    whichMarg=whichMarg,
    time=time,
    timeEvents=timeEvents,
    timeMarginalization=3,
    legendSubplot=16,
)

# dPCA with a tuned penalty and a noise covariance estimated from single trials.
# NOTE(review): ifSimultaneousRecording is not defined in any visible cell; set it first.
optimalLambda = dpca_optimizeLambda(
    firingRatesAverage, firingRates, trialNum,
    combinedParams=combinedParams,
    simultaneous=ifSimultaneousRecording,
    numRep=10,
    filename='tmp_optimalLambdas.mat',
)
Cnoise = dpca_getNoiseCovariance(
    firingRatesAverage, firingRates, trialNum,
    simultaneous=ifSimultaneousRecording,
)

W, V, whichMarg = dpca(
    firingRatesAverage,
    n_components=20,
    combinedParams=combinedParams,
    lambda_=optimalLambda,
    Cnoise=Cnoise,
)
explVar = dpca_explainedVariance(
    firingRatesAverage, W, V,
    combinedParams=combinedParams,
)

# Same overview plot as for the fixed-lambda fit, now for the regularized components.
dpca_plot(
    firingRatesAverage, W, V,
    plot_function=dpca_plot_default,
    explainedVar=explVar,
    marginalizationNames=margNames,
    marginalizationColours=margColours,
    whichMarg=whichMarg,
    time=time,
    timeEvents=timeEvents,
    timeMarginalization=3,
    legendSubplot=16,
)

# Optional - Decoding (classification accuracy)
# One entry per marginalization. This must be a list: a set literal raises
# TypeError here because lists and ndarrays are unhashable, and a set would
# also lose the ordering.
# NOTE(review): S is not defined in any visible cell; define it before running.
# NOTE(review): the layout of each entry must match what
# dpca_classificationAccuracy expects -- confirm against myDPCA.
decodingClasses = [
    (np.arange(S), np.arange(S)),
    np.repeat([1, 2], S),
    [],
    (np.arange(S), np.arange(S) + S),
]

# NOTE(review): ifSimultaneousRecording is not defined in any visible cell either.
accuracy = dpca_classificationAccuracy(firingRatesAverage, firingRates, trialNum, lambda_=optimalLambda,
                                       combinedParams=combinedParams, decodingClasses=decodingClasses,
                                       simultaneous=ifSimultaneousRecording, numRep=5, filename='tmp_classification_accuracy.mat')

# Plot classification accuracy
dpca_classificationPlot(accuracy, decodingClasses=decodingClasses)