#### Imports and script-wide statements

In [1]:
import os
import h5py
from scipy.io import loadmat
import numpy as np
from preproc import *
from sklearn.preprocessing import MinMaxScaler
from itertools import combinations
from scipy.special import factorial
import statsmodels.api as sm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import StrMethodFormatter
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
from matplotlib.widgets import Slider
from IPython.display import HTML
import textwrap
import pickle

In [2]:
%matplotlib inline
# Default font size for all matplotlib figures drawn in this notebook
plt.rcParams['font.size'] = 12

In [3]:
# Choose which days to include
# day_list = ['20181105', '20181102', '20181101']

# Or, read in list of days (one day directory name per line) from txt file.
# Blank lines (e.g. a trailing empty line at the end of the file) are skipped so
# that no empty day name ends up in day_list.
with open('data/combined/days.txt', 'r') as file:
    day_list = [line.strip() for line in file if line.strip()]

In [4]:
# Some common constants; the session count follows from the selected days
num_goals, tbin_size = 6, 0.1
num_sess = len(day_list)

In [6]:
# Read in predefined list of cells.
# Each non-blank line is a '/'-separated path: field 5 is the day, field 8 is
# presumably 'channelNNN' and field 9 'cellNN'. They are combined into labels of
# the form '<day>ch<channel>c<cell>', with leading zeros removed via int().
# NOTE: good_cell_labels is re-read from data/cell_list_pseudopop.txt further below.
good_cell_labels = list()
with open('data/cell_list_hm.txt', 'r') as file:
    for line in file:
        if not line.strip():
            continue  # skip blank lines instead of failing with an IndexError
        fields = line.strip().split('/')
        good_cell_labels.append(f'{fields[5]}ch{int(fields[8][7:])}c{int(fields[9][4:])}')

#### Preprocessing of session-level navigation data

In [8]:
def group_by_pos_bins(timeseries: np.array, pos_bins: np.array, num_bins: int = 1600) -> list:
    """Group the rows of a 2-D timeseries by the position bin of each row.

    Args:
        timeseries: 2-D array; row ``i`` is grouped under ``pos_bins[i]``.
        pos_bins: 1-indexed position bin of each row. Bin 0 means "no bin" and
            that row is dropped.
        num_bins: number of position bins (default 1600, the 40 x 40 grid).

    Returns:
        List of ``num_bins`` arrays. Entry ``b - 1`` stacks the rows that fell in
        bin ``b`` and has shape ``(k, timeseries.shape[1])``; ``k`` may be 0.

    Raises:
        IndexError: if a bin lies outside ``0..num_bins``. Negative bins would
            otherwise silently wrap around to the end of the list.
    """
    rows_per_bin = [[] for _ in range(num_bins)]
    for row_idx, raw_bin in enumerate(pos_bins):
        place_bin = int(raw_bin)
        if place_bin == 0:
            continue
        if not 1 <= place_bin <= num_bins:
            raise IndexError(f'position bin {place_bin} outside 1..{num_bins}')
        rows_per_bin[place_bin - 1].append(row_idx)
    # One vstack per bin instead of one per row, which made grouping quadratic.
    # Stacking onto an empty float array keeps the same output dtype as before.
    empty = np.empty((0, timeseries.shape[1]))
    return [np.vstack((empty, timeseries[rows, :])) for rows in rows_per_bin]

def place_response_distribution(timeseries: np.array, pos_bins: np.array) -> dict:
    """Collect the samples of ``timeseries`` that fall in each position bin.

    ``timeseries[i]`` is assigned to bin ``int(pos_bins[i])``; samples in bin 0
    are ignored. Returns a dict ``{bin: sorted np.array of samples}``, with keys
    in order of first appearance.
    """
    grouped = {}
    for sample_idx, raw_bin in enumerate(pos_bins):
        place_bin = int(raw_bin)
        if place_bin != 0:
            grouped.setdefault(place_bin, []).append(timeseries[sample_idx])
    return {place_bin: np.array(sorted(samples)) for place_bin, samples in grouped.items()}

def downsample_pos_binning(pos_bins: np.array, num_bins: int = 5, coord_min: float = -12.5,
                           size: float = 25) -> np.array:
    """Re-bin fine position bins onto a coarser ``num_bins`` x ``num_bins`` grid.

    Each non-zero bin is converted to (x, y) coordinates with
    ``pos_bins_to_coords`` (from ``preproc``). It is then re-assigned to the
    coarse bin containing that point, on a square maze spanning
    ``[coord_min, coord_min + size]`` along both axes. Coarse bins are 1-indexed,
    row-major, starting at the minimum (x, y) corner.

    Bin 0 ("no position", as treated by the other helpers in this notebook) is
    kept as 0 instead of being mapped into a real coarse bin. Coordinates on or
    beyond the maze edge are clamped into the outermost row/column rather than
    wrapping into a neighbouring row.

    Args:
        pos_bins: array of fine bin numbers.
        num_bins: number of coarse bins per axis (default 5).
        coord_min: lower maze coordinate on both axes (default -12.5).
        size: maze side length (default 25).

    Returns:
        Array like ``pos_bins`` holding the coarse bin numbers.
    """
    bin_width = size / num_bins

    def pos_coords_to_bins(coords) -> int:
        # Convert an (x, y) coordinate to its coarse bin number (1-indexed)
        x, y = coords
        # Column (h) / row (v) index, clamped so edge coordinates stay on the grid
        h = min(max(int(np.floor((x - coord_min) / bin_width)), 0), num_bins - 1)
        v = min(max(int(np.floor((y - coord_min) / bin_width)), 0), num_bins - 1)
        return v * num_bins + h + 1

    new_pos_bins = np.zeros_like(pos_bins)
    for i, fine_bin in enumerate(pos_bins):
        if int(fine_bin) == 0:
            continue  # keep "no position" as 0
        x, y = pos_bins_to_coords(fine_bin).flatten()
        new_pos_bins[i] = pos_coords_to_bins((x, y))
    return new_pos_bins

In [9]:
# Root directory holding the raw per-day recordings. It must end with a path
# separator because paths below are built by string concatenation. Set the
# PICASSO_DATA_DIR environment variable to override this machine-specific
# absolute default without editing the notebook.
prefix = os.path.join(os.environ.get('PICASSO_DATA_DIR', "/Volumes/Hippocampus/Data/picasso-misc/"), "")
# Save directory for data files
save_dir = "data/placedist"
# Whether to overwrite preexisting files
overwrite = True

In [10]:
import subprocess  # run the helper script with an argument list instead of a shell string

# Helper script that writes the cell directories of one day to ./cell_list.txt
get_cells_script = os.path.expanduser('~/Documents/neural_decoding/Hippocampus_Decoding/get_cells.sh')
os.makedirs(save_dir, exist_ok=True)

for day_dir in day_list:
    out_path = f'{save_dir}/{day_dir}_data.pkl'
    if os.path.exists(out_path) and not overwrite:
        continue

    # Get list of cells under the day directory. Passing the arguments as a list
    # (no shell) means day_dir is never interpreted by a shell. A non-zero exit
    # status is not treated as an error here; a missing cell_list.txt fails below.
    subprocess.run(['sh', get_cells_script, day_dir], check=False)
    with open("cell_list.txt", "r") as file:
        cell_list = [line.strip() for line in file if line.strip()]
    os.remove("cell_list.txt")

    # Load data and extract spike times from all spiketrain.mat objects
    spike_times = list()
    cell_labels = list()
    for cell_dir in cell_list:
        spk_path = prefix + day_dir + "/session01/" + cell_dir + "/spiketrain.mat"
        try:
            spk_file = loadmat(spk_path)
        except NotImplementedError:
            # scipy's loadmat cannot read MATLAB v7.3 (HDF5-based) files; fall back to h5py
            spk_file = h5py.File(spk_path, 'r')
        except FileNotFoundError:
            continue
        try:
            # Spike timestamps are loaded as a column vector in msec: flatten and convert to sec.
            # Out-of-place division also works if the timestamps are stored as integers.
            spk = np.array(spk_file.get('timestamps')).flatten() / 1000
        finally:
            # Close the HDF5 handle even if reading the timestamps fails
            if isinstance(spk_file, h5py.File):
                spk_file.close()
        spike_times.append(spk)

        # cell_dir is expected to look like 'arrayNN/channelNNN/cellNN'. Parse the
        # channel and cell numbers as integers, the same way cell_list_hm.txt is
        # parsed above, so that multi-digit cells such as 'cell10' are not
        # truncated to their last digit. Labels have the form '<day>ch<channel>c<cell>'.
        cell_name = cell_dir.split('/')
        channel, cell = int(cell_name[1][7:]), int(cell_name[2][4:])
        cell_labels.append(f'{day_dir}ch{channel}c{cell}')

    # Load data from vmpv.mat object and extract session time and position bins
    with h5py.File(prefix + day_dir + "/session01/1vmpv.mat", 'r') as pv_file:
        pv = pv_file.get('pv').get('data')
        place_intervals = get_place_intervals(pv)
    # Convert spike timestamps to spike rates per observation interval
    spikerates = spike_rates_per_observation(place_intervals, spike_times)

    # Set mindur and threshvel filter
    min_dur = 0.05 # Minimum duration for an observation set to 50ms
    thresh_vel = 1 # Set minimum velocity threshold to 1 unit/s

    dur_intervals = place_intervals[:,1] - place_intervals[:,0]
    # Longest allowed duration: time needed to cross the diagonal of one fine
    # position bin at thresh_vel (0.625 = 25 / 40, presumably the side length of
    # a bin in the 40 x 40 grid)
    max_dur = np.sqrt(2) * 0.625 / thresh_vel
    valid_obs = np.where((dur_intervals > min_dur) & (dur_intervals <= max_dur))

    # Apply mindur and threshvel filter
    place_intervals = place_intervals[valid_obs]
    spikerates = spikerates[valid_obs]
    dur_intervals = dur_intervals[valid_obs]

    # Downsample position binning from (40 x 40) to (5 x 5)
    place_intervals[:,2] = downsample_pos_binning(place_intervals[:,2])

    # Generate spiking distribution per position bin for each individual cell
    place_responses_per_cell = list()
    for cell in range(spikerates.shape[1]):
        responses_per_bin = place_response_distribution(spikerates[:,cell], place_intervals[:,2])
        place_responses_per_cell.append(responses_per_bin)
    durations_per_place = place_response_distribution(dur_intervals, place_intervals[:,2])

    ### Save processed data to pkl file ###
    data = {'place_responses_per_cell': place_responses_per_cell, 'durations_per_place': durations_per_place, 'cell_labels': cell_labels}
    with open(out_path, 'wb') as file:
        pickle.dump(data, file)

#### Pseudopopulation place responses

In [11]:
# Pool the saved per-session results into pseudopopulation-wide lists,
# with one entry per cell in every list.
all_place_responses, all_place_durations, all_cell_labels = list(), list(), list()

for day in day_list:
    with open(f'{save_dir}/{day}_data.pkl', 'rb') as file:
        # NOTE: pickle.load can execute arbitrary code; only load files written by this notebook
        data = pickle.load(file)
    num_sess_cells = len(data['cell_labels'])
    all_cell_labels.extend(data['cell_labels'])
    all_place_responses.extend(data['place_responses_per_cell'])
    # Every cell of a session shares that session's duration-per-place dictionary
    all_place_durations.extend([data['durations_per_place']] * num_sess_cells)

num_all_cells = len(all_cell_labels)

In [12]:
# Read in predefined list of cells making up the pseudopopulation (one label per line)
good_cell_labels = list()
with open('data/cell_list_pseudopop.txt', 'r') as file:
    for line in file:
        good_cell_labels.append(line.rstrip())

# Filter out cells not in pseudopopulation list. The membership set is built once
# rather than once per cell inside the comprehension.
good_cell_set = set(good_cell_labels)
cell_filter = np.array([idx for idx, cell in enumerate(all_cell_labels) if cell in good_cell_set], dtype=int)
all_place_responses = [all_place_responses[i] for i in cell_filter]
all_place_durations = [all_place_durations[i] for i in cell_filter]

# Update num_all_cells and all_cell_labels to reflect number of cells in pseudopopulation
num_all_cells = cell_filter.shape[0]
all_cell_labels = [all_cell_labels[i] for i in cell_filter]

In [13]:
# Summarise each cell's place response distributions into a (num_bins + 1, 3)
# array of per-bin parameters:
#   column 0: mean firing rate in the bin
#   column 1: sample standard deviation (ddof=1; NaN when a bin has a single
#             observation) -- a standard deviation, not a variance
#   column 2: time spent in the bin, normalised to sum to 1 over all bins
# The row index is the position bin number; bins without observations (normally
# only bin 0) stay all-zero.
num_bins = 25
assert len(all_place_durations) == len(all_place_responses), 'one duration dict per cell expected'
response_params = list()
for place_responses_per_cell, place_durations_per_cell in zip(all_place_responses, all_place_durations):
    response_params_per_cell = np.zeros((num_bins+1, 3))
    for place_bin, dist in place_responses_per_cell.items():
        response_params_per_cell[place_bin,0] = np.mean(dist)
        response_params_per_cell[place_bin,1] = np.std(dist, ddof=1)
        response_params_per_cell[place_bin,2] = np.sum(place_durations_per_cell[place_bin])
    response_params_per_cell[:,2] = response_params_per_cell[:,2] / np.sum(response_params_per_cell[:,2])
    response_params.append(response_params_per_cell)
# Build a new array instead of overwriting list entries while iterating over them
all_place_responses = np.array(response_params)

# Clean up large memory variables
del all_place_durations

In [14]:
# Sanity check: expected (number of pseudopopulation cells, num_bins + 1, 3),
# with the last axis holding (mean, std, time fraction) per position bin
all_place_responses.shape

(233, 26, 3)

#### Place response distribution plots per cell

In [64]:
# Whether to write the per-cell figures to disk, and the directory to write them to
to_save, figsave_dir = True, 'figures/placedist'

In [65]:
# One figure per pseudopopulation cell: a 5 x 5 grid of firing-rate histograms,
# laid out like the place bins in the environment (bin 1 at the bottom-left).
good_cell_set = set(good_cell_labels)  # O(1) membership tests inside the loop
if to_save:
    os.makedirs(figsave_dir, exist_ok=True)

for day in day_list:
    with open(f'{save_dir}/{day}_data.pkl', 'rb') as file:
        data = pickle.load(file)
    place_responses_per_cell = data['place_responses_per_cell']
    cell_labels = data['cell_labels']

    for cell, label in enumerate(cell_labels):
        if label not in good_cell_set:
            continue

        fig = plt.figure(figsize=(20, 12))
        fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, wspace=0.2, hspace=0.35)
        fig.suptitle(f'Firing rate distribution per place bin for cell {label}', y=0.95, fontsize=14)

        for place_bin, dist in place_responses_per_cell[cell].items():
            # Distribution mean and median are taken over all observations, including 0 Hz
            dist_mean, dist_median = np.mean(dist), np.median(dist)
            # Only non-zero firing rates are shown in the histogram
            dist = dist[dist > 0]
            # Subplot index that puts bin 1 in the bottom-left corner of the 5 x 5 grid
            fig_pos = (4 - (place_bin - 1) // 5) * 5 + ((place_bin - 1) % 5) + 1

            ax = fig.add_subplot(5, 5, fig_pos)
            ax.set_title(f'Bin #{place_bin}', fontsize=12)
            ax.hist(dist, bins=50)
            ax.axvline(x=dist_mean, color='C1', linewidth=1, linestyle='--')
            ax.axvline(x=dist_median, color='green', linewidth=1, linestyle='--')
            ax.set_xlim(0, ax.get_xlim()[1])
            # Shared axis labels go on the bottom-centre and middle-left panels
            if fig_pos == 23:
                ax.set_xlabel('Firing rate (Hz)', fontsize=14)
            if fig_pos == 11:
                ax.set_ylabel('Count', fontsize=14)

        if to_save:
            fig.savefig(f'{figsave_dir}/placedist_{label}.png', bbox_inches='tight')
        plt.close(fig)  # release the figure; many are created in this loop

#### Bayesian place decoder

In [None]:
class BayesDecoder:
    """Per-cell Gaussian Bayesian place decoder combined by majority vote.

    ``dist[cell, bin]`` holds ``(mean, std, prior)`` of that cell's firing rate in
    position bin ``bin``. Bin 0 is treated as "no position" and is never predicted.
    Each cell independently picks the bin that maximises
    ``prior * N(x | mean, std)``. ``predict`` returns the bin chosen by the most
    cells, with ties going to the lowest bin number.
    """

    def __init__(self, dist, min_sig=1e-6):
        """
        Args:
            dist: array of shape (num_cells, num_bins, 3) as described above.
            min_sig: floor applied to each bin's std before evaluating the
                Gaussian. A std of 0 (all observations identical) or NaN (a single
                observation with ddof=1) would otherwise give NaN posteriors, and
                ``np.argmax`` returns the first NaN it encounters.
        """
        self.dist = dist
        self.num_cells = dist.shape[0]
        self.num_bins = dist.shape[1]
        self.min_sig = min_sig

    @staticmethod
    def gaussian_pdf(x, mu, sig):
        """Normal probability density with mean ``mu`` and std ``sig`` evaluated at ``x``."""
        return (1 / np.sqrt(2 * np.pi * sig**2)) * np.exp(-(x - mu)**2 / (2 * sig**2))

    def __likelihood(self, x, cell, bin):
        """Likelihood of firing rate ``x`` for ``cell`` given position ``bin``."""
        mu, sig = self.dist[cell, bin, 0], self.dist[cell, bin, 1]
        # np.fmax ignores NaN, so both NaN and too-small stds are raised to min_sig
        sig = np.fmax(sig, self.min_sig)
        return BayesDecoder.gaussian_pdf(x, mu, sig)

    def __predict_cell(self, x, cell):
        """Return ``(bin, unnormalised posterior of that bin)`` for a single cell."""
        posterior = np.zeros(self.num_bins)
        for bin in range(1, self.num_bins):
            prior = self.dist[cell, bin, 2]
            posterior[bin] = prior * self.__likelihood(x, cell, bin)
        # argmax over bins 1..num_bins-1; +1 turns the slice index back into a bin number
        pred = int(np.argmax(posterior[1:])) + 1
        return pred, posterior[pred]

    def predict(self, x):
        """Decode a position bin from ``x``, a sequence with one firing rate per cell."""
        prediction = np.zeros(self.num_cells, dtype=int)
        for cell in range(self.num_cells):
            prediction[cell], _ = self.__predict_cell(x[cell], cell)
        bins, counts = np.unique(prediction, return_counts=True)
        return int(bins[np.argmax(counts)])

In [None]:
# Build the decoder from the per-cell, per-bin (mean, std, prior) summaries
place_decoder = BayesDecoder(dist=all_place_responses)