In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import MultipleLocator # for minor ticks
from matplotlib import gridspec
from matplotlib import patches

import numpy as np

from scipy import signal # for convolution
import scipy.stats as scstats
import scipy.special as sps

import pickle
import os,sys,inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)

import sys
sys.path.append("..")

from neuroprob.utils import stats, tools, neural_utils, field_passes
import neuroprob.models as mdl

from tqdm.notebook import tqdm

dev = tools.PyTorch()

plt.style.use(['paper.mplstyle'])

### Load data

In [None]:
### Real data ###
# Unpack the preprocessed session saved as a single pickled tuple (order matters).
# NOTE: pickle.load can execute arbitrary code - only load this file if it comes
# from a trusted source (e.g. a save you generated yourself).
with open('./saves/datasets/hc5_13.p', 'rb') as f:  # `with` closes the file handle
    (sample_bin, track_samples, x_t, y_t, s_t, dir_t, hd_t, eeg_t, theta_t,
     hilbert_amp, hilbert_theta, pause_ind, pause_size,
     sep_t_spike, clu_id, t_spike, spike_samples, units,
     shank_id, local_clu, FR_waveshape, SpkWidthC,
     refract_viol, sess_avg_rate, isolation_dist, LV, ISI,
     left_x, right_x, bottom_y, top_y) = pickle.load(f)

In [None]:
# Session-level quantities derived from the loaded data.
max_speed = s_t.max()
wrap_theta_t = tools.WrapPi(theta_t, True)  # wrapped theta phase; see tools.WrapPi for the range convention
arena_width = right_x - left_x
arena_height = top_y - bottom_y

# Spatial binning: one grid bin per `grid_bin_width` position units
# (mm, if positions are in mm as the axis labels below suggest - confirm).
grid_bin_width = 2
grid_size = (int(arena_width/grid_bin_width), int(arena_height/grid_bin_width))
grid_shape = (left_x, right_x, bottom_y, top_y)  # arena extent as (left, right, bottom, top)

### Visualize and compute trajectories

In [None]:
# Animate position and head direction over one unit's rate map during a single
# pass, and save the frames as a GIF.
# NOTE(review): unit_used / neurpass_* / smth_used_rate / delta_bin_x are not
# defined anywhere in this notebook - they must come from an earlier (deleted) cell.
unit = 21
traj = 9

frame_stride = 250  # recording samples between consecutive GIF frames
gif_fps = 25        # GIF playback frame rate
# Playback speed = frame_stride * sample_bin * gif_fps times real time
# (5x if 125 samples span 0.1 s, i.e. sample_bin = 1/1250 s - confirm).
arrow_len = 1.5     # head-direction arrow length, in the plot's (bin) coordinates

start = neurpass_start[unit][traj]
tsize = neurpass_len[unit][traj]
end = start+tsize

rate_map = smth_used_rate[unit]
images = []
for tt in range(start, end, frame_stride):
    fig, ax, im = tools.draw_2d(np.transpose(rate_map), origin='lower', aspect='equal',
                                cmap='plasma', vmax=rate_map.max())
    tools.decorate_ax(ax)
    ax.set_anchor('N')

    # current position rescaled by the bin widths to match the rate-map image
    bx, by = x_t[tt]/delta_bin_x, y_t[tt]/delta_bin_y
    ax.arrow(bx, by, arrow_len*np.cos(hd_t[tt]), arrow_len*np.sin(hd_t[tt]), color='w',
             width=0.2, head_width=0.8, head_length=1.0)
    ax.scatter(bx, by, marker='o', color='lime')

    images.append(tools.render_image(fig))
    plt.close(fig)  # don't accumulate one open figure per frame

tools.generate_gif(images, './out.gif', fps=gif_fps)

In [None]:
# Pool the spikes of `unit` (set in the cell above) that occur while dir_t lies in
# one chosen direction bin, across all of that unit's passes. For each kept spike,
# record its position, its wrapped theta phase and the distance travelled since
# the start of its pass.
dir_bins = 4
# Bin edges over [0, 2π]; the +1e-3 keeps a value of exactly 2π inside the last bin.
dir_bin = np.linspace(0, 2*np.pi+1e-3, dir_bins+1)
tg_dir = np.digitize(dir_t, dir_bin)-1  # 0-based direction-bin index per sample

sel_dir_bin = 0  # direction bin whose spikes are kept
kept_spike_chunks = []
dist_chunks = []
for pass_idx in range(len(neurpass_centre[unit])):
    pass_spikes = neurpass_spikes[unit][pass_idx]
    kept = pass_spikes[tg_dir[pass_spikes] == sel_dir_bin]

    p_start = neurpass_start[unit][pass_idx]
    p_stop = p_start + neurpass_len[unit][pass_idx]
    # running sum of s_t * sample_bin over the pass, read out at each kept spike
    travelled = np.cumsum(s_t[p_start:p_stop]*sample_bin)

    kept_spike_chunks.append(kept)
    dist_chunks.append(travelled[kept - p_start])

spiketimes = np.concatenate(kept_spike_chunks)
traj_len = np.concatenate(dist_chunks)
posx, posy = x_t[spiketimes], y_t[spiketimes]
phases = wrap_theta_t[spiketimes]

In [None]:
# Two-panel figure:
#  (A) rate map of one unit with one of its passes (flagged in neurpass_centre) overlaid;
#  (B) the EEG trace eeg_t during that pass, theta-cycle boundaries and the unit's spikes.
fig = plt.figure(figsize=(10, 3))
widths = [2, 3]
heights = [1]
spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=widths,
                        height_ratios=heights, wspace=0.3)

# --- Panel A: rate map + trajectory ---
ax = fig.add_subplot(spec[0, 0])
ax.text(-0.21, 1.16, 'A', transform=ax.transAxes, size=15)
unit = 8 # for processed containers
traj_c = np.where(neurpass_centre[unit] == True)[0]
traj = traj_c[0]

grid_shape = (left_x, right_x, bottom_y, top_y)
rate_map = smth_used_rate[unit]
# Draw the rate map in arena coordinates so that x_t / y_t overlay directly.
# NOTE(review): this replaces `vs.visualize_field_`, whose module `vs` is never
# imported in this notebook. Orientation follows the animation cell
# (transposed map, origin='lower') - confirm it matches the intended figure.
ax.imshow(np.transpose(rate_map), origin='lower', extent=grid_shape, aspect='equal',
          cmap='viridis', vmax=rate_map.max())

start = neurpass_start[unit][traj]
end = start+neurpass_len[unit][traj]

ax.scatter(x_t[start], y_t[start], s=14, color='lime', label='start')
ax.scatter(x_t[end-1], y_t[end-1], s=14, color='r', label='end')
ax.plot(x_t[start:end], y_t[start:end], color='w')
lg = ax.legend(frameon=True, framealpha=0.7)

# --- Panel B: EEG, theta cycles and spikes ---
pass_spikes = neurpass_spikes[unit][traj]  # separate name: keeps the pooled `spiketimes` intact
# Ends of theta cycles: samples where the wrapped phase decreases relative to the
# previous sample. Start at index >= 1 so the look-back sample exists even when
# the pass begins at sample 0.
lo = max(start, 1)
inds = np.where(wrap_theta_t[lo:end] - wrap_theta_t[lo-1:end-1] < 0)[0] + lo
time = np.arange(start, end)*sample_bin

ax = fig.add_subplot(spec[0, 1])
ax.text(-0.05, 1.05, 'B', transform=ax.transAxes, size=15)

eeg_pass = eeg_t[start:end]
ax.plot(time, eeg_pass, color='b')
for xc in inds*sample_bin:
    ax.axvline(x=xc, color='k', linestyle='-', alpha=0.5)
# spike raster drawn as tick marks above the trace
ax.scatter(pass_spikes*sample_bin, np.full(len(pass_spikes), eeg_pass.max()*1.5),
           marker='|', s=400, color='r')
ax.set_xlabel("$t$ (s)")
# Show samples 2000-8000 of the pass, clipped so short passes don't raise IndexError.
view_window = (2000, 8000)
i0, i1 = (min(i, len(time) - 1) for i in view_window)
ax.set_xlim([time[i0], time[i1]])
ax.set_frame_on(False)
ax.get_yaxis().set_visible(False)

# Vertical scale rod for the voltage axis.
# NOTE(review): the rod spans a fixed 0.3 of the axes height, not a data interval,
# so the '5 mV' label is only correct for one particular y-range - confirm.
ax.annotate("",
    xy=(1.03, 0.2), xycoords='axes fraction',
    xytext=(1.03, 0.5), textcoords='axes fraction',
    arrowprops=dict(arrowstyle="-", lw=2,
              connectionstyle="arc3, rad=0"),
    )
ax.text(1.04, 0.33, '5 mV', transform=ax.transAxes, fontsize=14)

plt.show()

In [None]:

# Theta phase of the direction-filtered spikes (posx, posy, phases) against
# position, with a Pearson correlation test for each spatial axis.
# NOTE(review): posx/posy/phases come from the `unit` chosen in the animation cell
# (21 under Restart & Run All), but the figure cell reassigns `unit = 8`, so the
# place-centre line below uses unit 8 - confirm which unit is intended.
# NOTE(review): place_centre / bins_x / bins_y are not defined in this notebook.

def plot_phase_vs_position(ax, pos, spike_phases, centre, xlabel, title, lower=1.0):
    """Scatter spike phase against position over two cycles and correlate them.

    Phases below `lower` are shifted up by 2π before the Pearson test (presumably
    so a band of phases straddling 0 is not split in two). The scatter draws every
    spike twice, at phase and phase + 2π, and a vertical line marks `centre`.

    Returns the Pearson (r, p-value) of `pos` against the shifted phases.
    """
    shifted = np.where(spike_phases < lower, spike_phases + 2*np.pi, spike_phases)
    r, r_p = scstats.pearsonr(pos, shifted)
    ax.scatter(np.concatenate([pos, pos]),
               np.concatenate([spike_phases, spike_phases + 2*np.pi]), s=1)
    ax.axvline(x=centre, color='k', linestyle='-', alpha=0.5)
    ax.set_xlabel(xlabel)  # Axes has set_xlabel, not xlabel
    ax.set_ylabel(r'$\theta$')
    ax.set_title(title)
    return r, r_p


fig, axes = plt.subplots(1, 2, figsize=(8, 3))
panels = [
    (axes[0], posx, place_centre[unit][0]*arena_width/bins_x, '$x$ (mm)', 'north'),
    (axes[1], posy, place_centre[unit][1]*arena_height/bins_y, '$y$ (mm)', 'west'),
]
# TODO: 'south' and 'east' panels were stubbed in an earlier draft but never computed.
for ax, pos, centre, xlabel, title in panels:
    r, r_p = plot_phase_vs_position(ax, pos, phases, centre, xlabel, title)
    print(r)
    print(r_p)
plt.show()

### Theta-related quantities

In [None]:
# Theta fits: neural_utils.theta_fit on the first 10 entries of ISI, and
# neural_utils.EEG_fit on the EEG trace.
x, y, omega_p, _, _, _, _, modamp_p = neural_utils.theta_fit(sample_bin, ISI[:10], dev=dev)
# Index of the fitted entry to plot; kept separate so the shared `unit` global is not clobbered.
plot_unit = 0
plt.plot(x, y[plot_unit])
plt.show()

corr_eeg, fit_eeg, omega_eeg, tau_eeg, power_decay = \
    neural_utils.EEG_fit(sample_bin, eeg_t, lag_range=1250, time_window=1000000, lr=1e-3, dev=dev)

plt.plot(corr_eeg, label='corr_eeg')
plt.plot(fit_eeg, label='fit_eeg')
plt.legend()
plt.show()

# Period of the fitted EEG oscillation, 2π/ω. A previous run printed
# 0.1368731546016358 (presumably seconds - confirm the units of omega_eeg).
print(2*np.pi/omega_eeg)

In [None]:
# Organization of cell assemblies in the hippocampus
corr, (freq, fourierTransform), theta_period, theta_index = neural_utils.theta_CCG(sample_bin, sep_t_spike[:], track_samples, dev=dev)

plt.plot(corr[:, 2, 1])
plt.show()
print(theta_index)

plt.plot(freq, np.abs(fourierTransform[:, 6]))
plt.show()

In [None]:
# Phase precession: cross-correlograms binned in theta phase for the first 10
# entries of sep_t_spike, via neural_utils.precess_CCG.
lag = 3000
fac = 50
# Number of full 2π cycles needed to cover theta_t's maximum
# (a cycle count if theta_t is the unwrapped phase - confirm).
g = int(np.ceil(theta_t.max()/2./np.pi))
bin_window = g - 101
cbin_theta = np.linspace(0, g*2*np.pi+1e-3, fac*g)  # about `fac` bin edges per 2π
corr, (freq, fourierTransform), precess_index = neural_utils.precess_CCG(sep_t_spike[:10], theta_t, cbin_theta,
                                                                         lag, fac, bin_window, start_points=[0], dev=dev)

n = 0
n_show = 60  # number of leading lag bins to plot
plt.plot(np.arange(lag)[:n_show]/fac, corr[:n_show, n])
plt.grid(True)  # pass a real bool; a string like 'off' would also be truthy
plt.show()

plt.plot(freq[:100], np.abs(fourierTransform[:100, n]))
plt.show()
print(freq)
print(precess_index)