In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.stats import sem, zscore
from sklearn.mixture import GaussianMixture

In [None]:
Behavior_files, Phot_files, Pupil_files, MetaData_files, MetaDataX_files, ANIMAL_IDs, ANIMAL_VARs = DirectoryAlloc_testedit(ROOTDIR, 200, 0)
loopIDX = range(3)
# Session indices for each of the first three subjects.
subjIDX = [ANIMAL_VARs[ANIMAL_IDs[i]] for i in loopIDX]
OFCIDX = list(range(len(Behavior_files)))

# Define time vector: sample offsets -41..40 for the mean waveforms.
# NOTE(review): assumes every saved waveform has exactly 82 samples -- confirm.
t_waveform = np.arange(-41, 41)
METAMATRIX = MetaSPK_test_1

# --- Analysis parameters -------------------------------------------------------
WAVEFORM_DIR = f"{ROOTDIR}2021-22_Attention/NP 2023-12/waveforms/longer IHB/"
PREVIEW_SESSION = 0          # session drawn in the waveform overview figure
N_SESSIONS_PER_SUBJECT = 5   # number of most recent sessions pooled per subject
FEATURE_ARG = 20             # second argument of ExtractWaveformFeatures_corr (meaning defined there)
# Base of the channel ids stored in Unit_map column 1. Python indexing is 0-based, so
# 0-based ids index the channel axis directly (adding 1 would select the neighbour).
# NOTE(review): ids assumed 0-based; set to 1 if Unit_map stores 1-based channel ids.
CHANNEL_ID_BASE = 0


def _format_name_part(value):
    """Render one file-name component; whole-number floats print without a trailing '.0'."""
    if isinstance(value, (float, np.floating)) and float(value).is_integer():
        return str(int(value))
    return str(value)


def load_session_waveforms(session_idx):
    """Load the mean and SEM waveform arrays for one session.

    The waveform file name is built from the session's behaviour file
    (``AnimalID`` and the first four entries of ``date``).
    NOTE(review): confirm the produced name matches the files on disk.

    Returns ``(mean, sem)`` as stored under ``waveFormsMeanSEM``; they are indexed
    below as ``[unit, channel, sample]``.
    """
    # simplify_cells turns MATLAB char arrays into str and squeezes 1xN vectors,
    # so AnimalID/date format cleanly into the file name.
    behavior = loadmat(Behavior_files[session_idx], simplify_cells=True)
    animal = _format_name_part(behavior['AnimalID'])
    date_str = '-'.join(_format_name_part(part) for part in behavior['date'][:4])
    fname = f"Session_unit_waveform_{animal}_{date_str}.mat"
    # Without simplify_cells a MATLAB struct loads as a (1, 1) record array; take its
    # single element so MEAN/SEM keep their saved rank even for single-unit sessions.
    record = loadmat(WAVEFORM_DIR + fname)['waveFormsMeanSEM'][0, 0]
    return record['MEAN'], record['SEM']


def good_unit_channels(session_idx):
    """Return the channel index of each good unit in ``session_idx``.

    A unit is good when its Unit_type is non-zero and its Unit_level is above 1;
    its channel id is read from column 1 of Unit_map and shifted by CHANNEL_ID_BASE.
    """
    is_good = ((np.asarray(METAMATRIX['Unit_type'][session_idx]) != 0)
               & (np.asarray(METAMATRIX['Unit_level'][session_idx]) > 1))
    # flatnonzero yields a flat row-index array; np.where returns a tuple, which
    # made the fancy index 2-D so that channel_ids[j] was an entire row.
    rows = np.flatnonzero(is_good)
    channel_ids = np.asarray(METAMATRIX['Unit_map'][session_idx])[rows, 1]
    return channel_ids.astype(int) - CHANNEL_ID_BASE


def good_unit_waveforms(session_idx):
    """Return ``(mean, sem)`` of every good unit's waveform on its own channel.

    Both arrays have shape (n_good_units, n_samples). Row j of the waveform file is
    paired with the j-th good unit from METAMATRIX.
    NOTE(review): that pairing is assumed, not checked against unit ids -- confirm.

    Raises:
        ValueError: if the waveform file and METAMATRIX disagree on the unit count.
    """
    mean_all, sem_all = load_session_waveforms(session_idx)
    channels = good_unit_channels(session_idx)
    if mean_all.shape[0] != channels.size:
        raise ValueError(
            f"session {session_idx}: waveform file has {mean_all.shape[0]} units "
            f"but METAMATRIX lists {channels.size} good units"
        )
    units = np.arange(channels.size)
    return mean_all[units, channels, :], sem_all[units, channels, :]


# Overview: mean +/- SEM waveform of every good unit in the preview session.
preview_mean, preview_sem = good_unit_waveforms(PREVIEW_SESSION)
fig, ax = plt.subplots()
for unit_mean, unit_sem in zip(preview_mean, preview_sem):
    ax.fill_between(t_waveform, unit_mean - unit_sem, unit_mean + unit_sem, color='gray', alpha=0.5)
    ax.plot(t_waveform, unit_mean, 'k')
ax.set_xlabel('Sample offset')
ax.set_ylabel('Mean waveform')
ax.set_title(f'Good-unit waveforms, session {PREVIEW_SESSION}')
plt.show()

# Calculate and collect average waveforms from each session. Each entry of
# lumped_average_waveform is one session's (n_good_units, n_samples) matrix; the
# features are grouped by subject so they can be normalised within each subject.
lumped_average_waveform = []
features_by_subject = []
for subject_sessions in subjIDX:
    subject_features = []
    for session_idx in subject_sessions[-N_SESSIONS_PER_SUBJECT:]:
        session_waveforms, _ = good_unit_waveforms(session_idx)
        if session_waveforms.shape[0] == 0:
            continue  # no good units in this session: nothing to extract
        lumped_average_waveform.append(session_waveforms)
        # z-score each waveform over its samples (removes offset and amplitude scale)
        standardized = zscore(session_waveforms, axis=1)
        # One feature row per unit -> (n_good_units, n_features) per session.
        subject_features.append(
            np.vstack([ExtractWaveformFeatures_corr(waveform, FEATURE_ARG) for waveform in standardized])
        )
    features_by_subject.append(subject_features)

# Normalize features by subject: columns 0 and 1 are z-scored with the NaN-aware
# mean/std pooled over all of that subject's units.
lumped_waveform_features = []
for subject_features in features_by_subject:
    if not subject_features:
        continue  # subject contributed no units
    pool = np.vstack(subject_features)
    center = np.nanmean(pool[:, :2], axis=0)
    spread = np.nanstd(pool[:, :2], axis=0)
    for session_features in subject_features:
        normalized = np.array(session_features, dtype=float)  # copy; leave raw features intact
        normalized[:, :2] = (normalized[:, :2] - center) / spread
        lumped_waveform_features.append(normalized)

if not lumped_waveform_features:
    raise RuntimeError("no good units found in the selected sessions")

# Scatter plot of features
mixed_features = np.vstack(lumped_waveform_features)
fig, ax = plt.subplots()
ax.scatter(mixed_features[:, 1], mixed_features[:, 0])
ax.set_xlabel('Trough to Peak Duration')
ax.set_ylabel('AP Width')
plt.show()

# GMM clustering and BIC evaluation: fit 1..max_clusters Gaussian components to the
# pooled waveform features and record the BIC of each fit (lower is better).
max_clusters = 9
# GaussianMixture raises on NaN/inf, and the features may contain NaN (the upstream
# normalisation uses nan-aware statistics), so fit only fully finite rows.
gmm_features = mixed_features[np.isfinite(mixed_features).all(axis=1)]
# A mixture cannot have more components than samples; cap k accordingly.
cluster_counts = range(1, min(max_clusters, len(gmm_features)) + 1)
bic_scores = []
for k in cluster_counts:
    gmm = GaussianMixture(n_components=k, max_iter=1000, random_state=42).fit(gmm_features)
    bic_scores.append(gmm.bic(gmm_features))

# Plot BIC
fig, ax = plt.subplots()
ax.plot(cluster_counts, bic_scores, '-o')
ax.set_xlabel('Number of Clusters')
ax.set_ylabel('BIC')
ax.set_title('GMM model selection')
plt.show()
