In [1]:
import seaborn as sns
import scipy
from src.dataset import SESSION_SUBJECT_RECORDING, SUBJECT_PROBE_DATASET
import pandas as pd
import numpy as np

# Plot styling for the notebook: select seaborn's 'paper' plotting context.
sns.set_theme(context='paper')

def load_ds_data(session_id):
    """Load the type-1 and type-2 DS event times of one session.

    Reads ``data/<session_id>/<session_id>_DS_TYPE12.mat`` and splits the
    event samples stored in ``samplesDS`` using the ``kType1sup`` and
    ``kType2sup`` flag vectors.

    Returns
    -------
    (ds1_sample, ds2_sample) : tuple of np.ndarray
        Zero-based event times, in LFP samples, for type 1 and type 2.
    """
    mat = scipy.io.loadmat(f'data/{session_id}/{session_id}_DS_TYPE12.mat')
    event_samples = mat['samplesDS'].ravel()
    is_type1 = mat['kType1sup'].ravel().astype(bool)
    is_type2 = mat['kType2sup'].ravel().astype(bool)
    # The file stores 1-based MATLAB indices; shift to 0-based Python indices.
    return event_samples[is_type1] - 1, event_samples[is_type2] - 1

def ca3_cell_ids(session_id):
    """Return the ids of the cells labelled 'CA3' for one session.

    The session is mapped to its (subject, recording) pair via
    ``SESSION_SUBJECT_RECORDING`` and to the dataset of the subject's 'AP2'
    probe via ``SUBJECT_PROBE_DATASET``.  Rows of
    ``ref/features_ca1_ca3dg_good.csv`` matching that dataset, that recording
    and ``location == 'CA3'`` are kept.

    Returns
    -------
    np.ndarray
        Values of the ``id`` column for the selected rows, in file order.
    """
    subject_id, recording_id = SESSION_SUBJECT_RECORDING[session_id]
    ca3_dataset = SUBJECT_PROBE_DATASET[(subject_id, 'AP2')]
    features = pd.read_csv('ref/features_ca1_ca3dg_good.csv')
    selected = (
        (features['dataset'] == ca3_dataset)
        & (features['recording'] == recording_id)
        & (features['location'] == 'CA3')
    )
    return np.array(features.loc[selected, 'id'])

def load_spike_data(session_id, cell_ids, downsample_factor=12):
    """Load the spikes of the requested cells from the session's AP2 files.

    Reads ``<session_id>_AP2_spike_clusters.npy`` and
    ``<session_id>_AP2_spike_times.npy`` from ``data/<session_id>/``, keeps
    only spikes whose cluster id is in ``cell_ids``, and integer-divides the
    spike times by ``downsample_factor``.

    Parameters
    ----------
    session_id
        Session identifier used to build the file paths.
    cell_ids : array-like
        Cluster ids to keep.
    downsample_factor : int, default 12
        Divisor applied to the stored spike times.  ``make_psth`` indexes a
        per-sample array with the result alongside the DS event times, so
        this is presumably the AP-to-LFP sampling-rate ratio — confirm
        against the recording setup before changing it.

    Returns
    -------
    (spike_cell_ids, spike_times) : tuple of np.ndarray
        Cluster id and downsampled time of every retained spike, aligned
        element-wise.

    Raises
    ------
    ValueError
        If ``downsample_factor`` is smaller than 1.
    """
    if downsample_factor < 1:
        raise ValueError(f'downsample_factor must be >= 1, got {downsample_factor}')

    session_dir = f'data/{session_id}'
    spike_cell_ids = np.load(f'{session_dir}/{session_id}_AP2_spike_clusters.npy').squeeze()
    spike_times = np.load(f'{session_dir}/{session_id}_AP2_spike_times.npy').squeeze()

    keep = np.isin(spike_cell_ids, cell_ids)
    return spike_cell_ids[keep], spike_times[keep] // downsample_factor

def load_ds_combos(session_id):
    """Pair every DS event with its predecessor and classify the pair.

    The type-1 and type-2 event times from ``load_ds_data`` are merged into a
    single time-ordered sequence (ties are ordered type 1 before type 2).
    Every event except the first is then described by the interval to the
    previous event and by the (previous type, current type) combination.

    Parameters
    ----------
    session_id
        Session identifier forwarded to ``load_ds_data``.

    Returns
    -------
    ds_time : np.ndarray of int64
        Time of the second event of each consecutive pair, i.e. all merged
        event times except the first, in the units of ``load_ds_data``.
    combo_isi : np.ndarray of int64
        Interval between each event in ``ds_time`` and the event before it.
    combo_type : np.ndarray of int
        Pair code (previous -> current): 0 = 1->1, 1 = 1->2, 2 = 2->1,
        3 = 2->2.

    The three arrays have equal length and are empty when the session has
    fewer than two events.
    """
    ds1_time, ds2_time = load_ds_data(session_id)

    # Cast to int64 so the times can be used as array indices downstream
    # (make_psth slices with them) whatever numeric dtype the file stores.
    event_time = np.concatenate([ds1_time, ds2_time]).astype(np.int64)
    event_type = np.concatenate([
        np.full(len(ds1_time), 1, dtype='int'),
        np.full(len(ds2_time), 2, dtype='int'),
    ])

    # lexsort uses the LAST key as primary: sort by time, break ties by type.
    order = np.lexsort((event_type, event_time))
    event_time = event_time[order]
    event_type = event_type[order]

    combo_isi = np.diff(event_time)
    ds_time = event_time[1:]
    prev_ds_type, ds_type = event_type[:-1], event_type[1:]
    # Types are 1 or 2, so this maps (1,1),(1,2),(2,1),(2,2) to 0,1,2,3.
    combo_type = 2 * (prev_ds_type - 1) + (ds_type - 1)
    return ds_time, combo_isi, combo_type

def make_psth(session_id, n_samples=2500):
    """Build DS-aligned spike rasters for every CA3 cell of a session.

    For each cell returned by ``ca3_cell_ids`` and each event returned by
    ``load_ds_combos``, the ``n_samples`` samples centred on the event time
    are extracted from a binary spike train (1 where the cell has at least
    one spike on that sample, 0 elsewhere).

    Parameters
    ----------
    session_id
        Session identifier forwarded to the loaders.
    n_samples : int, default 2500
        Window length in samples; must be a positive even number so the
        window splits into two equal halves around the event.

    Returns
    -------
    cell_ids : np.ndarray
        Ids of the cells, in the order used along axis 0 of ``ds_psth``.
    spike_counts : list of int
        Per cell, the number of spikes occurring before the last event time
        (0 for every cell when the session has no events).
    ds_time, combo_isi, combo_type : np.ndarray
        Passed through unchanged from ``load_ds_combos``.
    ds_psth : np.ndarray of int, shape (n_cells, n_events, n_samples)
        Event-aligned binary spike windows.  Events whose window would start
        before sample 0 are left all-zero.

    Raises
    ------
    ValueError
        If ``n_samples`` is not a positive even integer.
    """
    if n_samples <= 0 or n_samples % 2:
        raise ValueError(f'n_samples must be a positive even integer, got {n_samples}')

    cell_ids = ca3_cell_ids(session_id)
    ds_time, combo_isi, combo_type = load_ds_combos(session_id)
    spike_ids, spike_times = load_spike_data(session_id, cell_ids)

    half_window = n_samples // 2
    n_events = len(combo_type)
    n_cells = len(cell_ids)
    ds_psth = np.zeros((n_cells, n_events, n_samples), dtype='int')

    last_event = int(np.max(ds_time)) if n_events else 0
    # The spike train extends half a window past the last event so that the
    # windows of the final events are filled from spike data; stopping at the
    # last event time would leave them all-zero.
    train_len = last_event + half_window
    # Window starts depend only on the events, so compute them once.
    window_starts = [int(t) - half_window for t in ds_time]

    spike_counts = []
    for i in range(n_cells):
        cell_spike_times = spike_times[spike_ids == cell_ids[i]]
        spike_counts.append(int(np.count_nonzero(cell_spike_times < last_event)))

        spike_train = np.zeros(train_len, dtype='int')
        in_range = cell_spike_times[cell_spike_times < train_len]
        spike_train[in_range.astype(np.intp)] = 1

        for j, t_start in enumerate(window_starts):
            if t_start < 0:
                # Window would begin before the first sample: keep zeros.
                continue
            ds_psth[i, j, :] = spike_train[t_start:t_start + n_samples]
    return cell_ids, spike_counts, ds_time, ds_psth, combo_isi, combo_type

In [2]:
# Run make_psth on every known session and collect the outputs per session id.
session_list = list(SESSION_SUBJECT_RECORDING.keys())
result_dict = {}
for session_id in session_list:
    cell_ids, spike_count, ds_time, ds_psth, combo_isi, combo_type = make_psth(session_id)
    # Keys are inserted in the same order as before so the saved layout is unchanged.
    result_dict[session_id] = dict(
        cell_id=cell_ids,
        spike_count=spike_count,
        ds_time=ds_time,
        ds_combo=combo_type,
        ds_isi=combo_isi,
        ds_psth=ds_psth,
    )

In [3]:
# Persist the per-session results (a dict keyed by session id) to a pickle
# file in the current working directory.
pd.to_pickle(result_dict, 'ds_aligned_spikes.pkl')