In [1]:
import os
import numpy as np
from scipy.io import loadmat
import h5py
from preproc import *
import matplotlib.pyplot as plt
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

In [2]:
# Root data directory and the recording days (one session per day) to load and merge.
# NOTE(review): `prefix` is an absolute, machine-specific path; adjust it before running elsewhere.
# `import_data` concatenates `prefix + day_dir`, so the trailing slash is required.
prefix = "/Volumes/Hippocampus/Data/picasso-misc/"
sessions = ["20181102", "20181101"]

In [3]:
def import_data(day_dir: str):
    """Load trial markers, timestamps and per-cell spike times for one recording day.

    Reads from ``prefix + day_dir + "/session01/"``, where ``prefix`` is the
    module-level data root.

    Args:
        day_dir: Name of the day directory under ``prefix`` (e.g. "20181102").

    Returns:
        Tuple ``(trial_markers, trial_timestamps, session_start_time,
        session_end_time, spike_times, cell_labels)`` where ``spike_times`` is a
        list with one 1-D array of spike times (in seconds) per cell and
        ``cell_labels`` is the matching list of ``"<day>/ch<channel>/c<cell>"``
        strings.
    """
    session_dir = prefix + day_dir + "/session01/"

    # Get list of cells under the day directory.
    # The helper script is expected to write one cell directory per line into
    # "cell_list.txt" in the current working directory (that is where it is read from below).
    # NOTE(review): day_dir is interpolated into a shell command; only pass trusted values.
    os.system(f"sh ~/Documents/neural_decoding/Hippocampus_Decoding/get_cells.sh {day_dir}")
    with open("cell_list.txt", "r") as file:
        cell_list = [line.strip() for line in file]
    os.remove("cell_list.txt")

    # Load data from rplparallel.mat object, extract trial markers, time stamps and session start timestamp.
    # The file is opened read-only and closed as soon as the arrays are copied into memory.
    with h5py.File(session_dir + "rplparallel.mat", "r") as rp_file:
        rp = rp_file.get('rp').get('data')
        trial_markers = np.array(rp.get('markers'))
        trial_timestamps = np.array(rp.get('timeStamps'))
        session_start_time = np.round(np.array(rp.get('session_start_sec'))[0,0], 3)

    # Load data and extract spike times from all spiketrain.mat objects
    spike_times = list()
    cell_labels = list()
    for cell_dir in cell_list:
        spk = loadmat(session_dir + cell_dir + "/spiketrain.mat")
        spk = spk.get('timestamps').flatten() # spike timestamps is loaded in as a column vector
        spk = spk / 1000 # convert spike timestamps from msec to sec
        spike_times.append(spk)

        # cell_dir is expected to be "<array dir>/<channel dir>/<cell dir>"; the slices
        # drop fixed-width name prefixes -- presumably "channel" and "cell0", TODO confirm.
        cell_name = cell_dir.split('/')
        channel, cell = cell_name[1][7:], cell_name[2][5:]
        if channel.startswith('0'):
            channel = channel[1:]  # drop a single leading zero
        cell_labels.append(f'{day_dir}/ch{channel}/c{cell}')

    # Load data from the vmpv object, extract session end timestamp
    with h5py.File(session_dir + "1vmpv.mat", "r") as pv_file:
        pv = pv_file.get('pv').get('data')
        session_end_time = np.round(np.array(pv.get('rplmaxtime'))[0,0], 3)

    return trial_markers, trial_timestamps, session_start_time, session_end_time, spike_times, cell_labels


def session_preproc(session_data):
    """Derive per-trial labels and time intervals, and drop unusable trials.

    Args:
        session_data: Tuple ``(trial_markers, trial_timestamps, session_start_time,
            session_end_time, spike_times, cell_labels)`` as returned by ``import_data``.
            Row 0 of ``trial_markers`` holds the marker codes; rows 0-1 of
            ``trial_timestamps`` are used as the cue interval and rows 1 onward as the
            navigation interval (assumes three rows: cue start, cue end, navigation
            end -- TODO confirm).

    Returns:
        Tuple ``(trial_markers, trial_trajectories, cue_intervals, nav_intervals,
        trial_intervals, spike_times, cell_labels)`` restricted to the retained trials.
        ``trial_trajectories[:, 0]`` is the previous trial's poster and
        ``trial_trajectories[:, 1]`` the current one.
    """
    # Unpack session data
    trial_markers, trial_timestamps, session_start_time, session_end_time, spike_times, cell_labels = session_data

    # Get poster numbers from trial markers, cue phase time intervals
    trial_markers = (trial_markers[0,:] % 10).astype(int)
    cue_intervals = trial_timestamps[0:2,:].T

    # Get durations of each navigation phase, truncated to whole seconds.
    # int64 (not int8) so that durations above 127 s cannot overflow and slip past the filter.
    nav_intervals = trial_timestamps[1:,:].T
    nav_durations = (nav_intervals[:,1] - nav_intervals[:,0]).astype(np.int64)

    # Generate time intervals for each trial: from the end of its cue phase to the
    # start of the next trial's cue phase (or the session end for the last trial)
    trial_intervals = np.empty_like(cue_intervals)
    trial_intervals[:,0] = cue_intervals[:,1]
    trial_intervals[:-1,1] = cue_intervals[1:,0]
    trial_intervals[-1:,1] = session_end_time  # slice keeps this a no-op when there are no trials

    # Further differentiate trial markers into trial trajectories (start poster, end poster)
    trial_trajectories = np.zeros((trial_markers.shape[0], 2))
    trial_trajectories[:,1] = trial_markers
    trial_trajectories[1:,0] = trial_markers[:-1]

    # Filter out trials that are too long (> 25 seconds) or have repeated goal from previous trial
    max_dur = 25  # maximum duration of trials (in seconds) to filter out
    too_long = nav_durations > max_dur
    repeated_goal = np.zeros(trial_markers.shape, dtype=bool)
    repeated_goal[1:] = trial_markers[1:] == trial_markers[:-1]
    good_trials = ~(too_long | repeated_goal)
    good_trials[:1] = False  # Discard the first trial also (it has no previous poster)

    # A boolean mask indexes the first axis directly, so no extra axis is introduced
    trial_markers = trial_markers[good_trials]
    trial_trajectories = trial_trajectories[good_trials]
    cue_intervals = cue_intervals[good_trials]
    nav_intervals = nav_intervals[good_trials]
    trial_intervals = trial_intervals[good_trials]

    return trial_markers, trial_trajectories, cue_intervals, nav_intervals, trial_intervals, spike_times, cell_labels


def spiketrain_preproc(session_data, timebin_window: int):
    """Compute raw and binned spike rates for one 250 ms slice of every cue phase.

    Args:
        session_data: Tuple ``(trial_markers, trial_trajectories, cue_intervals,
            nav_intervals, trial_intervals, spike_times, cell_labels)`` as returned by
            ``session_preproc``.
        timebin_window: Index of the 250 ms cue sub-bin to use, counted from the cue
            start: 0 -> 0-250 ms, 1 -> 250-500 ms, 2 -> 500-750 ms, 3 -> 750 ms-1 s.
            Negative indices count from the end; out-of-range values raise ``IndexError``.

    Returns:
        Tuple ``(trial_markers, trial_trajectories, spikerates_cue, binned_spikes_cue,
        cell_labels)``.
    """
    # Unpack session data
    trial_markers, trial_trajectories, cue_intervals, nav_intervals, trial_intervals, spike_times, cell_labels = session_data

    # Bin every trial interval into 250 ms time bins, aligned to the start of that interval.
    # Only bins that start before (trial_end - delta) are kept.
    delta = 0.25  # Size of time bin (in seconds)
    session_intervals = np.array([
        np.array([time, time + delta])
        for trial_start, trial_end in trial_intervals
        for time in np.arange(trial_start, trial_end - delta, delta)
    ])

    # Choose which 250 ms timebin of the cue phase is fitted to the model.
    # Indexing a range gives list-style index semantics (negative indices, IndexError).
    num_prds = int(1/delta)
    prd = range(num_prds)[timebin_window]
    cue_starts = cue_intervals[:,0].astype(np.float64)
    cue_intervals = np.column_stack((cue_starts + delta * prd, cue_starts + delta * (prd + 1)))

    # Slot spikes into cue phase intervals for each trial and session time intervals
    # (spike_rates_per_observation comes from the preproc module; indexed below as
    # one row per interval and one column per cell)
    spikerates_cue = spike_rates_per_observation(cue_intervals, spike_times)
    spikerates_session = spike_rates_per_observation(session_intervals, spike_times)

    # Bin spike rates within each cell for the entire session, and keep the per-cell
    # binning statistics so the cue-phase rates can be binned with the same statistics
    binned_spikes_session = np.empty_like(spikerates_session)
    binning_stats = list()
    for col in range(spikerates_session.shape[1]):
        binned_spikes_session[:,col] = bin_firing_rates_4(spikerates_session[:,col])
        binning_stats.append(get_binning_stats_4(spikerates_session[:,col]))

    # Bin spike rates within each cell for cue phases, reusing the session-wide statistics
    binned_spikes_cue = np.empty_like(spikerates_cue)
    for col in range(spikerates_cue.shape[1]):
        binned_spikes_cue[:,col] = bin_firing_rates_4(spikerates_cue[:,col], stats=binning_stats[col])

    return trial_markers, trial_trajectories, spikerates_cue, binned_spikes_cue, cell_labels


def groupby_trial_trajectories(timeseries: np.ndarray, trial_trajectories: np.ndarray) -> dict:
    """Group per-trial response rows by trial trajectory.

    Args:
        timeseries: 2-D array with one row of responses per trial.
        trial_trajectories: Array with one row per trial; each row is turned into a
            tuple and used as the group key.

    Returns:
        Dict mapping each trajectory tuple to a 2-D array of shape
        ``(num_trials_with_that_trajectory, timeseries.shape[1])``, rows in their
        original order. Groups with a single trial are also 2-D (one row), so
        ``.shape[0]`` is always the trial count.
    """
    # Collect the rows of each trajectory first and stack once at the end:
    # this avoids re-copying the growing array on every trial.
    rows_by_traj = dict()
    for idx, trial in enumerate(timeseries):
        traj = tuple(trial_trajectories[idx])
        rows_by_traj.setdefault(traj, []).append(trial)
    return {traj: np.vstack(rows) for traj, rows in rows_by_traj.items()}


def random_pairings(list_lengths: list) -> np.array:
    """Draw random index pairings across several lists of differing lengths.

    For every length ``l`` in ``list_lengths`` a random permutation of ``range(l)`` is
    drawn from the global ``np.random`` state and truncated to the shortest length.

    Returns:
        2-D integer array of shape ``(min(list_lengths), len(list_lengths))``; column
        ``j`` holds distinct indices into the ``j``-th list.
    """
    shortest = min(list_lengths)
    columns = []
    for length in list_lengths:
        shuffled = np.random.permutation(range(length))
        columns.append(shuffled[:shortest])
    return np.column_stack(columns)


def merge_sessions(sessions_data: list):
    """Merge cue-phase responses from several sessions into pseudo-simultaneous trials.

    For every trial trajectory, trials from the different sessions are randomly paired
    (via ``random_pairings``) and their responses concatenated along the cell axis.

    Args:
        sessions_data: List with one entry per session; each entry is the tuple
            (0) trial_markers, (1) trial_trajectories, (2) spikerates_cue,
            (3) binned_spikes_cue, (4) cell_labels.

    Returns:
        Tuple ``(merged_raw_responses, merged_binned_responses, cell_labels)``. The two
        dicts map a trajectory tuple to an array with one row per paired trial and one
        column per cell (cells of all sessions, in session order). ``cell_labels`` is
        the concatenated list of cell labels in the same column order.
        Only trajectories that occur in every session are included.
    """
    num_sessions = len(sessions_data)
    cell_labels = [cell for sess in sessions_data for cell in sess[4]]

    # raw/binned_responses_grouped is a list of dictionaries, each dictionary corresponds
    # to the grouped responses for one session. np.atleast_2d guarantees a
    # (num_trials, num_cells) matrix even if a group holds a single 1-D row.
    raw_responses_grouped = list()
    binned_responses_grouped = list()
    for sess in sessions_data:
        traj, raw_resp, binned_resp = sess[1], sess[2], sess[3]
        raw_grouped = groupby_trial_trajectories(raw_resp, traj)
        binned_grouped = groupby_trial_trajectories(binned_resp, traj)
        raw_responses_grouped.append({key: np.atleast_2d(val) for key, val in raw_grouped.items()})
        binned_responses_grouped.append({key: np.atleast_2d(val) for key, val in binned_grouped.items()})

    # Keep only trajectories present in every session (in the first session's order);
    # a trajectory missing from any session cannot be paired across sessions.
    trial_types = [
        traj for traj in raw_responses_grouped[0]
        if all(traj in sess for sess in raw_responses_grouped[1:])
    ]

    merged_raw_responses = dict()
    merged_binned_responses = dict()
    for traj in trial_types:
        # Get number of trials from each session for the given trial trajectory
        num_trials = [sess[traj].shape[0] for sess in raw_responses_grouped]
        pairings = random_pairings(num_trials)  # Generate pairings between sessions for the given trial trajectory
        # Concatenate responses across sessions according to generated pairings
        raw_resp = [sess[traj] for sess in raw_responses_grouped]
        binned_resp = [sess[traj] for sess in binned_responses_grouped]
        merged_raw_responses[traj] = np.hstack(
            [raw_resp[sess][pairings[:,sess],:] for sess in range(num_sessions)]
        )
        merged_binned_responses[traj] = np.hstack(
            [binned_resp[sess][pairings[:,sess],:] for sess in range(num_sessions)]
        )

    return merged_raw_responses, merged_binned_responses, cell_labels

In [4]:
# Import and preprocess data for every day listed in `sessions`
sessions_data = list()
for sess in sessions:
    sess_data = import_data(sess)
    sess_data = session_preproc(sess_data)
    sess_data = spiketrain_preproc(sess_data, 1)  # timebin_window=1 is an index: it selects the 250-500 ms cue sub-bin
    sessions_data.append(sess_data)

In [5]:
# Merge data across sessions.
# Result is the tuple (merged_raw_responses, merged_binned_responses, cell_labels).
# NOTE(review): trials are paired across sessions using np.random, and no seed is set in
# the cells shown here, so the merged trials differ from run to run.
merged_sessions_data = merge_sessions(sessions_data)