In [13]:
import os
import h5py
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import seaborn as sns
import pdb
from scipy.io import loadmat
from scipy import optimize
from math import pi, log2
import warnings

In [14]:
mat_files = [m for m in os.listdir("data") if m.endswith('mat')]
session_list = pd.read_excel('data/CacheRetrieveSessionList.xlsx', index_col=0)
fps = 20
cmap = cm.get_cmap('viridis')

## Helper Functions

In [15]:
def estimate_center(x, y):
    method_2 = "leastsq"

    def calc_R(xc, yc):
        """ calculate the distance of each 2D points from the center (xc, yc) """
        return np.sqrt((x-xc)**2 + (y-yc)**2)

    def f_2(c):
        """ calculate the algebraic distance between the data points and the mean circle centered at c=(xc, yc) """
        Ri = calc_R(*c)
        return Ri - Ri.mean()

    center_estimate = np.mean(x), np.mean(y)
    center_2, ier = optimize.leastsq(f_2, center_estimate)
    return center_2

In [16]:
def get_xy(f, in_bound=True):
    x = np.squeeze(np.array(f['X']))
    y = np.squeeze(np.array(f['Y']))
    x_c, y_c = estimate_center(x, y)
    x -= x_c; y -= y_c
    length = np.sqrt(np.square(x) + np.square(y))
    frames = np.arange(x.size)
    if in_bound:
        oob = np.logical_or(length <= 145, length >= 215)
        x = x[np.logical_not(oob)]
        y = y[np.logical_not(oob)]
        frames = frames[np.logical_not(oob)]
    return x, y, frames

In [17]:
def get_theta(f):
    x, y, frames = get_xy(f, in_bound=True)
    theta = np.arctan2(y, x)
    theta = np.mod(theta, 2*pi)
    boundaries = np.linspace(0, 2*pi, 16, endpoint=False)
    boundaries = np.append(boundaries, [2*pi])
    theta = np.digitize(theta, boundaries)
    return theta, frames

In [18]:
def get_velocity(f):
    x, y, frames = get_xy(f, in_bound=False)
    delta_x = x[1:] - x[:-1]
    delta_y = y[1:] - y[:-1]
    frames = frames[1:]
    velocity = np.sqrt(np.square(delta_x) + np.square(delta_y)) # pixels/frame
    velocity = velocity*fps # pixels/s
    smoothing_kernel = np.ones(fps)/fps
    velocity = np.convolve(velocity, smoothing_kernel, "valid")
    frames = frames[:velocity.size]
    return velocity, frames

In [19]:
def get_wedges(f):
    x, y, frames = get_xy(f, in_bound=True)
    theta = np.mod(np.arctan2(y, x), 2*pi)
    boundaries = np.linspace(0, 2*pi, 16, endpoint=False)
    boundaries = np.append(boundaries, [2*pi])
    wedges = np.digitize(theta, boundaries)
    wedges = np.mod(16-wedges, 16) + 1
    return wedges, frames

In [20]:
def get_fr(spikes):
    smoothing_kernel = np.ones(fps+1)/(fps+1) # One sec smoothing
    fr = np.convolve(spikes, smoothing_kernel, "valid")
    fr_frames = np.arange(spikes.size)[
        smoothing_kernel.size//2:-smoothing_kernel.size//2+1
        ]
    return fr, fr_frames

In [21]:
def circular_shuffle(spikes):
    spikes = spikes.copy()
    shift = np.random.choice(np.arange(1, spikes.size))
    return np.roll(spikes, shift)

## Calculating cache index matrix

In [32]:
def get_cache_index(cache_frames, noncache_frames, neur_fr, percentile=0.8):
    noncache_frs = []
    cache_fr = np.mean(neur_fr[cache_frames])
    for nc in noncache_frames:
        noncache_fr = np.sort(neur_fr[nc])
        noncache_fr = noncache_fr[int(noncache_fr.size*percentile):]
        noncache_fr = np.mean(noncache_fr)
        noncache_frs.append(noncache_fr)
    return np.sum(noncache_frs < cache_fr)/len(noncache_frs)

In [41]:
def get_cache_index_shuffled(cache_frames, noncache_frames, neur_fr, percentile=0.8):
    # EpisodeIndex(ni,ei) = sum(FR_othervisits<FR_episodei)/length(FR_othervisits)
    relevant_frames = [cache_frames]
    for nc in noncache_frames:
        relevant_frames.append(nc)
    relevant_frames = np.concatenate(relevant_frames)
    np.random.shuffle(relevant_frames)
    shuff_cache_frames = relevant_frames[:cache_frames.size]
    idx = cache_frames.size
    shuff_noncache_frames = []
    for nc in noncache_frames:
        shuff_nc = relevant_frames[idx:idx+nc.size]
        idx += nc.size
        shuff_noncache_frames.append(shuff_nc)
    return get_cache_index(
        shuff_cache_frames, shuff_noncache_frames, neur_fr, percentile
        )

In [42]:
def get_cache_frames(
    window, wedges, wedge_frames,
    cache_site_idx, cache_sites,
    cache_frames_poke, cache_frames_enter, cache_frames_exit
    ):
    
    cache_site = cache_sites[cache_site_idx]
    event_idxs = np.argwhere(cache_sites == cache_site).squeeze()
    cache_poke = cache_frames_poke[cache_site_idx]
    cache_enter = cache_frames_enter[cache_site_idx]
    cache_exit = cache_frames_exit[cache_site_idx]
    cache_frames = np.arange(cache_poke-window, cache_poke+window+1)
    all_cache_exit = cache_frames_exit[cache_sites==cache_site]
    all_cache_enter = cache_frames_enter[cache_sites==cache_site]
    visit_frames = [
        np.arange(c, all_cache_exit[i]) for i, c in enumerate(all_cache_enter)
        ]
    visit_frames = np.concatenate(visit_frames)
    wedge_frames = wedge_frames[wedges == cache_site]
    wedge_frames = wedge_frames[np.logical_not(np.isin(wedge_frames, visit_frames))]
    wedge_frames = np.split(
        wedge_frames, np.where(np.diff(wedge_frames) != 1)[0]+1
        )
    return cache_frames, wedge_frames


In [45]:
def find_specific_cache(find_func, window):
    results = {}
    for mat_file in mat_files:
        f = h5py.File("data/" + mat_file, 'r')
        _, wedge_frames = get_wedges(f)
        wedges = np.array(f['whichWedge']).squeeze()
        wedges = wedges[np.isin(np.arange(wedges.size), wedge_frames)]
        cache_sites = np.array(f['CacheSites']).squeeze()
        cache_frames_poke = np.array(f['CacheFrames']).squeeze().astype(int) - 1
        cache_frames_enter = np.array(f['CacheFramesEnter']).squeeze().astype(int) - 1
        cache_frames_exit = np.array(f['CacheFramesExit']).squeeze().astype(int) - 1
        was_retrieval = np.array(f['ThisWasRetrieval']).squeeze().astype(bool)
        spikes = np.array(f['S'])
        results[mat_file] = np.zeros((spikes.shape[1], cache_sites.size))
        for cache_site_idx in range(len(cache_sites)):
            cache_frames, noncache_frames = find_func(
                window, wedges, wedge_frames,
                cache_site_idx, cache_sites,
                cache_frames_poke, cache_frames_enter, cache_frames_exit
                )
            for neur in np.arange(spikes.shape[1]):
                neur_spikes = spikes[:, neur]
                neur_fr, fr_frames = get_fr(neur_spikes)
                cf = np.argwhere(np.isin(fr_frames, cache_frames))
                ncf = [np.argwhere(np.isin(fr_frames, ncf)) for ncf in noncache_frames]
                cache_info = get_cache_index(cf, ncf, neur_fr)
                shuffled_info = []
                shuffled_peak_fr = []
                for _ in range(110):
                    shuffled_info.append(
                        get_cache_index_shuffled(cf, ncf, neur_fr)
                        )
                shuffled_info = np.array(shuffled_info)
                high_cache_info = np.sum(shuffled_info < cache_info) > 0.99*shuffled_info.size
                if high_cache_info:
                    results[mat_file][neur, cache_site_idx] = cache_info
    return results

In [None]:
results = find_specific_cache(get_cache_frames, window=30)

In [63]:
num_neurs = 0
for key in results.keys():
    print(key)
    cache_idx_matrix = results[key]
    sig_neurs = np.mean(cache_idx_matrix, axis=1)
    print(np.argwhere(sig_neurs > 0).squeeze())
    print()
    num_neurs += np.sum(sig_neurs>0)
print(num_neurs)

ExtractedWithXY_Cleaned184713_09102019.mat
[ 0  1  4  5  6  7  8  9 13 14 15 16 17 18 21 22 23 25 29 32 33 34 37 38
 39 41]

ExtractedWithXY_Cleaned184430_09102019.mat
[ 5  6  7  8  9 10 11 13 14 17 20 23 24 25 27 29 30 31 32 33 46 48]

ExtractedWithXY_Cleaned184526_09102019.mat
[ 3  4  5  7  8 11 13 15 17 19 20 21 22 23 28 29 31 34 35 37 38 39 42 43
 45]

ExtractedWithXY_Cleaned184946_09102019.mat
[ 3  5  6  8 10 14 15 17 18 19 20 21 22 27 28 29 30 31 32 33 34 35 36 37
 38 40 41 44 46 47 48 49 50 51 52 53 54 56 61]

ExtractedWithXY_Cleaned185033_09102019.mat
[ 0  2  5  7 10 11 12 16 17 18 19 20 21 22 24 26 27 28 29 31 32 35 36 38
 39 40 41 42 43 44 46 47 48 49 50]

ExtractedWithXY_Cleaned184331_09102019.mat
[ 6  7  9 11 14 18 23 24 26 27 29 31 32 33 34 36 38 39 40 43]

ExtractedWithXY_Cleaned144233_09112019.mat
[ 2  7  8  9 11 12 13 14 18 19 20 21 22 23 24 27 29 30 31 32 33 34 36 38
 41 44 45 46 49 50 51 53]

199


## Calculating cache index matrix with visit frames

In [65]:
def get_visit_frames(
    window, wedges, wedge_frames,
    cache_site, cache_sites,
    cache_frames_poke, cache_frames_enter, cache_frames_exit
    ):
    
    event_idxs = np.argwhere(cache_sites == cache_site).squeeze()
    cache_frames_enter = cache_frames_enter[event_idxs]
    cache_frames_exit = cache_frames_exit[event_idxs]
    cache_frames_poke = cache_frames_poke[event_idxs]
    wedge_frames = wedge_frames[wedges == cache_site]
    if event_idxs.size == 1:
        cache_frames_enter = [cache_frames_enter]
        cache_frames_exit = [cache_frames_exit]
        cache_frames_poke = [cache_frames_poke]
    visit_frames = [
        np.arange(enter, cache_frames_exit[i] + 1) for i, enter in enumerate(cache_frames_enter)
        ]
    visit_frames = np.concatenate(visit_frames)
    window_frames = []
    for i, enter in enumerate(cache_frames_enter):
        poke = cache_frames_poke[i]
        exit = cache_frames_exit[i]
        prepoke_time = poke - enter
        postpoke_time = exit - poke
        total_time = exit - enter
        if prepoke_time >= window and postpoke_time >= window:
            _window_frames = np.arange(poke-window, poke+window+1)
        elif prepoke_time < 30 and total_time > (window*2 + 1):
            _window_frames = np.arange(enter, enter+window*2+1)
        elif postpoke_time < 30 and total_time > (window*2 + 1):
            _window_frames = np.arange(exit-window*2, exit+1)
        else:
            _window_frames = np.arange(enter, exit+1)
        window_frames.append(_window_frames)
    if len(window_frames) == 0:
        return np.array([]), np.array([])
    window_frames = np.concatenate(window_frames)
    nonvisit = np.logical_not(np.isin(wedge_frames, visit_frames))
    nonvisit_frames = wedge_frames[nonvisit]
    return window_frames, nonvisit_frames


In [72]:
results = find_specific_cache(get_visit_frames, window=40)

> <ipython-input-71-1567a5d62fd7>(6)get_cache_index()
-> return np.sum(noncache_fr < mean_cache_fr)/noncache_fr.size
(Pdb) l
  1  	def get_cache_index(cache_frames, noncache_frames, neur_fr, percentile=0.95):
  2  	    mean_cache_fr = np.mean(neur_fr[cache_frames])
  3  	    noncache_fr = np.sort(neur_fr[noncache_frames])
  4  	    noncache_fr = noncache_fr[int(noncache_fr.size*percentile):]
  5  	    import pdb; pdb.set_trace()
  6  ->	    return np.sum(noncache_fr < mean_cache_fr)/noncache_fr.size
[EOF]
(Pdb) noncache_fr.size
139
(Pdb) cache_frames.shape
(124,)
(Pdb) c
> <ipython-input-71-1567a5d62fd7>(6)get_cache_index()
-> return np.sum(noncache_fr < mean_cache_fr)/noncache_fr.size
(Pdb) noncache_fr.siz
*** AttributeError: 'numpy.ndarray' object has no attribute 'siz'
(Pdb) noncache_fr.size
139
(Pdb) exit()


BdbQuit: 

In [78]:
num_neurs = 0
for key in results.keys():
    print(key)
    cache_idx_matrix = results[key]
    sig_neurs = np.mean(cache_idx_matrix, axis=1)
    threshold = 0.05
    print(np.argwhere(sig_neurs > threshold).squeeze())
    print()
    num_neurs += np.sum(sig_neurs > threshold)
print(num_neurs)

ExtractedWithXY_Cleaned184713_09102019.mat
[ 1  4  5  7  8 14 15 18 21 22 23 25 26 27 29 37 38]

ExtractedWithXY_Cleaned184430_09102019.mat
[ 8  9 10 13 23 24 25 27 30 33 36 39 46]

ExtractedWithXY_Cleaned184526_09102019.mat
[ 3  4  5  8 15 17 20 21 23 29 34 35 38 39 42 43]

ExtractedWithXY_Cleaned184946_09102019.mat
[17 18 27 28 29 31 36 37 38 46 48 49 50 53]

ExtractedWithXY_Cleaned185033_09102019.mat
[ 0  7 10 11 17 18 19 20 21 22 24 27 28 29 31 32 35 39 40 42 44 46 47 49
 50]

ExtractedWithXY_Cleaned184331_09102019.mat
[14 18 32 34 36 37 38 40 43]

ExtractedWithXY_Cleaned144233_09112019.mat
[ 9 19 22 27 30 32 33 34 36 44 45 46 53]

107


In [74]:
num_neurs = 0
for key in results.keys():
    print(key)
    cache_idx_matrix = results[key]
    sig_neurs = np.mean(cache_idx_matrix, axis=1)
    print(sig_neurs)
    print()
    num_neurs += np.sum(sig_neurs>0)
print(num_neurs)

ExtractedWithXY_Cleaned184713_09102019.mat
[0.         0.05079329 0.         0.         0.06526559 0.05651578
 0.0398951  0.05690377 0.06913333 0.         0.         0.
 0.         0.         0.07521125 0.06       0.         0.04825662
 0.10237099 0.         0.         0.09576011 0.13013154 0.06666667
 0.         0.0960751  0.05495119 0.06506667 0.         0.05651578
 0.         0.         0.         0.         0.04937238 0.
 0.         0.06666667 0.09772579 0.         0.         0.0430622 ]

ExtractedWithXY_Cleaned184430_09102019.mat
[0.         0.         0.04646714 0.         0.         0.
 0.         0.0381932  0.13994673 0.0608229  0.05635063 0.
 0.03918723 0.10737286 0.03039645 0.         0.         0.
 0.         0.         0.00780379 0.         0.         0.11222594
 0.24456872 0.08977544 0.         0.07692308 0.         0.04065934
 0.07692308 0.03980322 0.         0.05837768 0.         0.
 0.07692308 0.         0.         0.06753131 0.         0.
 0.         0.         0.     