# Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
import os, json
import xarray as xr
from datetime import datetime
from sklearn.neighbors import KernelDensity

# Basic functions

In [2]:
def read_timestamp_series(s):
    """Convert a Series of 'HH:MM:SS.ffffff' clock strings into seconds.

    A float64 Series is taken to already hold numeric times and its underlying
    ``.values`` array is returned unchanged. Otherwise each entry is stringified,
    stripped, and parsed; if the parser reports trailing unconverted characters
    (e.g. more fractional digits than ``%f`` accepts), those characters are cut
    off and parsing is retried. Entries that still cannot be parsed become NaN.

    Returns
    -------
    numpy.ndarray
        Seconds elapsed since 00:00:00.000000 for each entry.
    """
    fmt = '%H:%M:%S.%f'
    if s.dtype == np.float64:
        return s.values

    # Reference point: parsing the zero clock time gives the epoch to subtract.
    midnight = datetime.strptime('00:00:00.000000', '%H:%M:%S.%f')
    seconds = []
    for raw in s:
        text = str(raw).strip()
        try:
            datetime.strptime(text, fmt)
        except ValueError as err:
            # strptime reports extra characters as "unconverted data remains: <rest>"
            leftover = len(err.args[0].partition('unconverted data remains: ')[2])
            if leftover:
                text = text[:-leftover]
        try:
            parsed = datetime.strptime(text, '%H:%M:%S.%f')
        except ValueError:
            seconds.append(np.nan)
        else:
            seconds.append((parsed - midnight).total_seconds())
    return np.array(seconds)

def read_timestamp_file(path):
    """Read a headerless, single-column timestamp CSV and convert it to seconds.

    If the first value in the file equals 0 it is dropped before conversion.
    Conversion is delegated to ``read_timestamp_series``.

    Parameters
    ----------
    path : str or path-like
        CSV file with one timestamp per row and no header.

    Returns
    -------
    numpy.ndarray
        Output of ``read_timestamp_series`` for the (possibly trimmed) column.
    """
    # squeeze('columns') turns the one-column frame into a Series even when the
    # file has a single row; a bare .squeeze() would collapse a 1x1 frame to a
    # scalar, and indexing it below would fail.
    s = pd.read_csv(path, encoding='utf-8', engine='c', header=None).squeeze('columns')
    if s.iloc[0] == 0:
        s = s.iloc[1:]
    camT = read_timestamp_series(s)
    return camT

def calc_PSTH(spikeT, eventT, bandwidth=10, resample_size=1, edgedrop=15, win=1000):
    """Compute a kernel-density PSTH for a single cell.

    Spike times are re-referenced to every event in ``eventT`` and pooled
    across events. The pooled times are smoothed with a Gaussian kernel density
    estimate, which is evaluated on a regular time grid. The grid is padded by
    ``edgedrop`` on both sides, and the padding is trimmed off afterwards to
    remove edge artifacts of the smoothing.

    Parameters
    ----------
    spikeT : array-like of float
        Spike times of one cell, on the same clock as ``eventT``. Presumably in
        seconds, because the millisecond parameters below are divided by 1000.
    eventT : array-like of float
        Event times to align to.
    bandwidth : float
        Gaussian kernel bandwidth, in ms.
    resample_size : float
        Spacing of the output time grid, in ms.
    edgedrop : float
        Padding, in ms, added at each end of the window before smoothing and
        dropped afterwards.
    win : float
        Half-width of the analysis window, in ms (output spans -win..+win).

    Returns
    -------
    numpy.ndarray
        1-D PSTH sampled every ``resample_size`` ms over -win..+win (2001
        samples with the defaults). Each value is the pooled spike density
        multiplied by (#pooled spikes / #events). The result is all zeros when
        no spike falls inside any event window, including when ``eventT`` is
        empty.
    """
    spikeT = np.asarray(spikeT)
    eventT = np.asarray(eventT)

    # Sample counts are derived from the millisecond inputs. This way integer
    # ms values give exact sizes, whereas a float-step np.arange in seconds can
    # gain or lose a sample through rounding.
    # len(grid) == ceil(2*(win+edgedrop)/resample_size) + 1, matching
    # np.arange(-(win+edgedrop), win+edgedrop+resample_size, resample_size).
    n_bins = int(np.ceil(2 * (win + edgedrop) / resample_size)) + 1
    edgedrop_ind = int(edgedrop / resample_size)
    n_out = n_bins - 2 * edgedrop_ind

    bandwidth = bandwidth / 1000
    resample_size = resample_size / 1000
    win = win / 1000
    edgedrop = edgedrop / 1000

    lo = -win - edgedrop
    hi = win + edgedrop
    bins = lo + np.arange(n_bins) * resample_size

    # Spike times relative to each event, keeping only those inside the padded
    # window, pooled over all events.
    rel_chunks = []
    for t in eventT:
        sp = spikeT - t
        rel_chunks.append(sp[(sp <= hi) & (sp >= lo)])
    sps = np.concatenate(rel_chunks) if rel_chunks else np.empty(0)

    if sps.size == 0:
        # KernelDensity cannot be fit on zero samples; no spikes -> zero rate.
        return np.zeros(n_out)

    kernel = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(sps[:, np.newaxis])
    # score_samples returns the log-density at each grid point.
    density = kernel.score_samples(bins[:, np.newaxis])

    # Scale the unit-area density by the number of pooled spikes (to get a
    # spike count per point), then divide by the number of events (for a rate
    # per event).
    psth = np.exp(density) * (sps.size / eventT.size)

    # Drop the padding at both ends. The explicit stop index keeps edgedrop=0
    # working, since psth[0:-0] would be an empty array.
    return psth[edgedrop_ind:n_bins - edgedrop_ind]

In [3]:
def list_subdirs(rootdir, name_only=False):
    """List the non-hidden subdirectories directly inside ``rootdir``.

    Entries whose name starts with '.' are skipped. The search is not
    recursive, and results follow ``os.scandir`` order, which is arbitrary.

    Parameters
    ----------
    rootdir : str or path-like
        Directory to scan.
    name_only : bool
        If False (default), return each subdirectory's full path, including
        ``rootdir``. If True, return only the subdirectory names.

    Returns
    -------
    list of str
    """
    paths = []
    names = []
    # Use scandir as a context manager so the directory handle is closed
    # promptly, rather than whenever the iterator is garbage-collected.
    with os.scandir(rootdir) as entries:
        for entry in entries:
            if os.path.isdir(entry) and not entry.name.startswith('.'):
                paths.append(entry.path)
                names.append(entry.name)

    return names if name_only else paths
    
import os, fnmatch

def find(pattern, path):
    """Recursively find files under ``path`` whose names match a glob pattern.

    Only file names are matched; directories are never returned, even when
    their names match.

    Parameters
    ----------
    pattern : str
        Shell-style pattern (``*``, ``?``, ``[seq]``) matched against each
        file's base name, e.g. ``'*world.nc'``.
    path : str
        Root directory; all of its subdirectories are searched too.

    Returns
    -------
    result : list of str
        Full paths (``root`` joined with the file name) of every matching
        file, in ``os.walk`` order.
    """
    result = []
    for root, _, files in os.walk(path):
        # fnmatch.filter applies the same case normalisation as fnmatch.fnmatch
        # and keeps the input order.
        result.extend(os.path.join(root, name) for name in fnmatch.filter(files, pattern))
    return result

# Paths

In [None]:
# Exploratory: list the recording sub-directories for session 060922 / animal J611RN.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
list_subdirs('/home/niell_lab/Mounts/Goeppert/nlab-nas/freely_moving_ephys/ephys_recordings/060922/J611RN/')

In [4]:
# Directory of the hf4_sparsenoiseflashRAND recording analysed in the rest of this notebook.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
recpath = '/home/niell_lab/Mounts/Goeppert/nlab-nas/freely_moving_ephys/ephys_recordings/060922/J611RN/hf4_sparsenoiseflashRAND/'

In [None]:
# Exploratory: show the files present in the recording directory.
os.listdir(recpath)

# Ephys during playback

In [5]:
# Load the merged ephys JSON for this recording into a DataFrame.
# Columns used later: 'ch', 'group', 'spikeT', 't0'. Presumably one row per sorted unit; confirm.
ephys_path = os.path.join(recpath,'060922_J611RN_control_Rig2_hf4_sparsenoiseflashRAND_ephys_merge.json')
ephys_data = pd.read_json(ephys_path)

In [6]:
# Sort units by channel and renumber the rows 0..n-1.
# reset_index(drop=True) discards the old index directly. Materialising it as a
# column and then dropping 'index' would go wrong if the JSON already had an
# 'index' column: pandas would name the new column 'level_0' and the real
# 'index' data column would be the one dropped.
ephys_data = ephys_data.sort_values(by='ch', axis=0, ascending=True).reset_index(drop=True)

In [7]:
# good cells
# Keep only units whose 'group' label is 'good'. The row labels from ephys_data are
# preserved, so ephys.index is generally NOT 0..n-1 (later cells use both labels and positions).
ephys = ephys_data.loc[ephys_data['group']=='good']

In [8]:
# timing correction
# Linear clock correction applied to every spike time t:
#     t_corrected = t - (offset + drift * t)
# NOTE(review): offset/drift look hand-tuned for this recording; confirm where they come from.
offset = 0.1
drift = -0.000114

# spike_times maps the *position* of each good unit within ephys.index (0, 1, 2, ...)
# to its corrected spike-time array. Later cells index spike_times and all_psth by this position.
spike_times = {}
for i, ind in enumerate(ephys.index.values):
    sps = np.array(ephys.loc[ind,'spikeT'].copy()).astype(float)
    new_sps = sps - (offset + sps * drift)
    # new_sps = new_sps - t0
    spike_times[i] = new_sps
    

In [9]:
# Ephys start time, taken from the 't0' column of the first row.
# NOTE(review): assumes every row carries the same t0; confirm.
ephysT0 = ephys_data['t0'].iloc[0].copy()

# Flip times

In [None]:
# Locate and load the world-camera recording. The dataset is opened in a `with`
# block so its file handle is released once the needed arrays are in memory
# (.values materialises them as NumPy arrays).
world_matches = find('*world.nc', recpath)
if not world_matches:
    raise FileNotFoundError(f"no '*world.nc' file found under {recpath}")
world_path = world_matches[0]
with xr.open_dataset(world_path) as world_data:
    # Cast to float so the frame-to-frame differences computed below cannot
    # wrap around, as they would with uint8 arithmetic.
    worldVid = world_data.WORLD_video.values.astype(np.uint8).astype(float)
    worldT_ = world_data.timestamps.values

In [None]:
# World-camera frame timestamps re-expressed relative to the ephys start time t0.
worldT = worldT_.copy() - ephysT0

In [None]:
# Per-frame stimulus change: summed absolute pixel difference between consecutive frames.
# dStim[k] compares frame k+1 with frame k, so len(dStim) == len(worldVid) - 1.
# (Assumes worldVid is (frames, height, width), given the axis=(1,2) reduction; confirm.)
dStim = np.sum(np.abs(np.diff(worldVid, axis=0)), axis=(1,2))
# Discard the first difference. NaN compares False, so index 0 can never be detected as a flip.
dStim[0] = np.nan

In [None]:
# Threshold on dStim above which a frame change counts as a stimulus flip.
# NOTE(review): presumably chosen by eye from the plot in the next cell; confirm.
dStim_thresh = 1e5

In [None]:
# Sanity check: dStim over the first 30 s with the flip threshold overlaid.
# worldT[:-1] is used so the x-axis has the same length as dStim.
plt.plot(worldT[:-1], dStim)
plt.xlim([0, 30])
plt.hlines(dStim_thresh, 0, 30, 'k')

In [None]:
# Rising-edge detector: indices j where dStim[j] is below the threshold and the next
# value dStim[j+1] is above it (strict comparisons on both sides; NaN never matches).
flips = np.flatnonzero((dStim[:-1] < dStim_thresh) & (dStim[1:] > dStim_thresh))

In [None]:
# Stimulus-flip times. For each detected j, dStim[j+1] (the difference between frames j+2
# and j+1) crossed the threshold, so worldT[flips+1] is the timestamp of the last frame
# before the change.
eventT = worldT[flips+1]
# Shift by 1/120 s. NOTE(review): presumably half a frame period of a 60 Hz camera,
# i.e. the midpoint between the two frames; confirm the camera rate.
eventT = eventT + (1/120)

In [None]:
# Sanity check: distribution of intervals between consecutive detected flips.
plt.hist(np.diff(eventT), bins=10)

# Responses during playback

In [None]:
# PSTH of every good unit aligned to the flip times; row i corresponds to spike_times[i],
# i.e. the i-th unit in ephys.index.
# 2001 columns = calc_PSTH's default output length (win=1000, resample_size=1, edgedrop=15
# gives 2031 grid points minus 2*15 trimmed); update it if calc_PSTH's arguments change.
all_psth = np.zeros([len(ephys.index.values), 2001])

for i, spT in tqdm(spike_times.items()):
    
    sps = np.array(spT)
    psth = calc_PSTH(sps, eventT)
    
    all_psth[i,:] = psth

# Plots

# Rasters for example cells

In [None]:
# Spike rasters for the first five good units, aligned to each stimulus flip.
# Uses the clock-corrected spike_times so the rasters line up with the PSTHs built from
# the same times. spike_times is keyed by position in ephys.index, so panel `col`
# <-> spike_times[col].
n_movs = 2000   # maximum number of events (raster rows) to draw
xrange = 0.5    # half-width of the raster window, in the units of eventT

fig, axs = plt.subplots(1,5, dpi=300, figsize=(8,6))

for col, ind in enumerate(ephys.index.values[:5]):
    sps = spike_times[col]
    rel_times = []
    trial_rows = []
    for n, t in enumerate(eventT[:n_movs]):
        sp = sps-t
        sp = sp[np.abs(sp)<=xrange]
        rel_times.append(sp)
        trial_rows.append(np.full(sp.size, n))
    # Draw every tick for this unit in one call instead of one call per event.
    if rel_times:
        axs[col].plot(np.concatenate(rel_times), np.concatenate(trial_rows),
                      '|', color='k', markersize=1)

    axs[col].vlines(0, 0, n_movs, color='r', linestyle='dashed')
    axs[col].set_ylim([0, n_movs])
    axs[col].set_title(ind)
    axs[col].set_xticklabels([])
    if col!=0:
        axs[col].set_yticklabels([])
    axs[col].set_xlim([-xrange, xrange])

fig.tight_layout()

In [None]:
# Grid of per-unit PSTHs (constant sparse noise): one panel per good unit, titled
# with the unit's row label in ephys.
n_cells = len(ephys.index.values)
n_samples = all_psth.shape[1]
# Sample offsets centred on the event; with calc_PSTH's defaults (2001 samples at
# 1 ms) this is -1000..1000, i.e. ms relative to the flip.
bins = np.arange(n_samples) - n_samples // 2

ncols = 9
# Keep the 10x9 layout, but add rows when there are more than 90 units
# (a fixed 10x9 grid would raise IndexError).
nrows = max(10, -(-n_cells // ncols))
fig, axs = plt.subplots(nrows, ncols, dpi=300, figsize=(15,15), squeeze=False)

for i, ind in enumerate(ephys.index.values):
    psth = all_psth[i,:]
    row, col = divmod(i, ncols)

    ymax = np.max(psth)*1.1
    if not ymax > 0:
        # Silent (all-zero) or NaN PSTH: avoid a degenerate/invalid y-range.
        ymax = 1.0

    axs[row,col].plot(bins, psth)
    axs[row,col].vlines(0, 0, ymax, color='k', linestyle='dashed')
    axs[row,col].set_ylim([0, ymax])
    axs[row,col].set_title(ind)
    axs[row,col].set_xticklabels([])
    axs[row,col].set_xlim([-500,500])

fig.tight_layout()
fig.savefig('/home/niell_lab/Desktop/060922_constant_sparse_noise_PSTHs.pdf')

In [None]:
# Save the PSTH matrix (one row per good unit) for this constant sparse-noise condition.
# NOTE(review): absolute, machine-specific output path.
np.save('/home/niell_lab/Desktop/constant_sn.npy', all_psth)

In [None]:
# NOTE(review): `saccFr` is not defined anywhere in this notebook. This cell appears to be
# carried over from another analysis and raises NameError as-is.
# Combines saccFr['gazeL'] and saccFr['gazeR'], keeps values below 60*5, and uses them as
# indices into a 0.016-spaced grid starting at 0 (presumably seconds).
# NOTE(review): if the entries are lists, + concatenates them; if they are arrays, + adds
# them element-wise. Confirm which is intended.
tempT = np.array(saccFr['gazeL']+saccFr['gazeR'])
tempT = tempT[(tempT<(60*5))]
tempT = np.arange(0, 5.016, 0.016)[tempT]
tempT

In [None]:
# plot raster
# NOTE(review): `stim_restart_times` is not defined in this notebook (and tempT depends on
# the undefined `saccFr`), so this cell raises NameError as-is.
# Raster of the unit with row label 11 over tlen-long windows starting at each stimulus
# restart, with the tempT times marked in red. Uses raw 'spikeT' (no clock correction).
tlen = 5

fig, ax = plt.subplots(1,1, figsize=(10,5), dpi=300)
sps = np.array(ephys.loc[11, 'spikeT'].copy())
for i, t in enumerate(stim_restart_times):
    # spikes strictly inside (t, t + tlen), re-referenced to t
    sps_i = sps[(sps>t) * (sps<(t+tlen))]
    sps_i = sps_i - t
    ax.plot(sps_i, np.ones(len(sps_i))*i, '|', color='k')
# NOTE(review): y-extent 19 is hard-coded; presumably the number of restarts.
ax.vlines(tempT, 0, 19, 'r')
ax.set_xlim([0,5])

# Plot of constant and random sparse-noise (SN) PSTHs

In [None]:
# Load the PSTH matrix saved from the random ("variable") sparse-noise analysis for comparison.
# NOTE(review): absolute, machine-specific input path.
rand_psth = np.load('/home/niell_lab/Desktop/random_sn.npy')

In [None]:
# Per-unit overlay of the constant sparse-noise PSTH (black) and the random/variable
# one (red).
# NOTE(review): assumes row i of rand_psth is the same unit as row i of all_psth (same
# recording, sort order and 'good' selection). Nothing here checks that; confirm.
n_cells = len(ephys.index.values)
n_samples = all_psth.shape[1]
# Sample offsets centred on the event; with calc_PSTH's defaults (2001 samples at
# 1 ms) this is -1000..1000, i.e. ms relative to the flip.
bins = np.arange(n_samples) - n_samples // 2

ncols = 9
# Keep the 10x9 layout, but add rows when there are more than 90 units
# (a fixed 10x9 grid would raise IndexError).
nrows = max(10, -(-n_cells // ncols))
fig, axs = plt.subplots(nrows, ncols, dpi=300, figsize=(15,15), squeeze=False)

for i, ind in enumerate(ephys.index.values):
    psth = all_psth[i,:]
    row, col = divmod(i, ncols)

    set_max = np.max([psth, rand_psth[i,:]])
    ymax = set_max*1.1
    if not ymax > 0:
        # Both traces silent (all zero) or NaN: avoid a degenerate/invalid y-range.
        ymax = 1.0

    axs[row,col].plot(bins, psth, 'k', label='constant')
    axs[row,col].plot(bins, rand_psth[i,:], 'r', label='variable')
    axs[row,col].vlines(0, 0, ymax, color='k', linestyle='dashed')
    axs[row,col].set_ylim([0, ymax])
    axs[row,col].set_title(ind)
    axs[row,col].set_xlim([-500,500])

    if i == 0:
        axs[row,col].legend()

fig.tight_layout()
fig.savefig('/home/niell_lab/Desktop/060922_both_sparse_noise_PSTHs.pdf')