In [1]:
import os
import glob
import numpy as np
import pandas as pd
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

from scipy.ndimage import gaussian_filter1d

from scipy import stats
import statsmodels.formula.api as smf

# Large "poster" fonts/lines by default (refined later by set_pub_plots()).
sns.set_context('poster')
# Font type 42 (TrueType) for PDF/PS export so figure text stays editable.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# import Python3_OpenOE_AC_map_functions_v1_08_30s as oem
import mz_LFP_functions as mz_LFP

# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Automatically reload edited modules (e.g. mz_LFP_functions) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# for publication quality plots
def set_pub_plots(pal=None):
    """Apply the notebook's publication-quality seaborn/matplotlib style.

    Parameters
    ----------
    pal : sequence of colors, optional
        Palette installed as the default color cycle. When None (default), a
        5-color blend of gray, crimson, cyan, magenta and purple is used.
    """
    if pal is None:
        # Built per call instead of once at definition time.
        pal = sns.blend_palette(["gray", "crimson", 'cyan', 'magenta', 'purple'], 5)

    # A single "ticks" style: a preceding set_style("white") was fully
    # overridden by it. Tick *sizes* are context keys, so set_style() silently
    # drops them; they are passed to set_context() below instead.
    sns.set_style("ticks")
    sns.set_palette(pal)
    sns.set_context("poster", font_scale=1.5,
                    rc={"lines.linewidth": 2.5, "axes.linewidth": 2.5,
                        "xtick.major.size": 8, "ytick.major.size": 8})
    # Explicit override kept from the original as a belt-and-braces setting.
    plt.rcParams['axes.linewidth'] = 2.5

# Publication rc overrides: 25 pt text, 2.5 pt lines, black text/ticks/edges.
# Not referenced by the cells shown here; to use it, e.g. plt.rcParams.update(rc_pub).
# To restore the defaults, call plt.rcdefaults().
rc_pub = {
    # font sizes
    'font.size': 25,
    'axes.labelsize': 25,
    'legend.fontsize': 25.0,
    'axes.titlesize': 25,
    'xtick.labelsize': 25,
    'ytick.labelsize': 25,
    # line widths
    'axes.linewidth': 2.5,
    'lines.linewidth': 2.5,
    # colors
    'xtick.color': 'black',
    'ytick.color': 'black',
    'axes.edgecolor': 'black',
    'axes.labelcolor': 'black',
    'text.color': 'black',
}

set_pub_plots()

In [None]:
# Two-color gray/crimson palette: preview it, then make it the default cycle.
pal = sns.blend_palette(['gray', 'crimson'], n_colors=2)
sns.palplot(pal)
sns.set_palette(pal)

# Define recording constants and probe channel landmarks

In [None]:
# Probe geometry: insertion depth and landmark depths below the brain surface.
# Units are presumably micrometres (TODO confirm).
insert_depth = 3100  # change this as appropriate

# Spacing between channels: 20 / 2, presumably a 20 um row pitch with two
# channels per row.
sp_bw_ch = 20/2

# Landmark depths below the surface (same units as insert_depth).
V1_HIP_DEPTH = 1100             # V1 / hippocampus boundary
CA1_DG_DEPTH = 1100 + 600       # CA1 / dentate gyrus boundary
HIP_THAL_DEPTH = 1100 + 1200    # hippocampus / thalamus boundary

# Channel index of each landmark. The surface maps to insert_depth / sp_bw_ch,
# so indices presumably count up from channel 0 at the inserted end of the probe.
# Cast to int: np.round returns a float, which cannot be used as an array index.
surface_ch = int(np.round(insert_depth / sp_bw_ch))
V1_hip_ch = int(np.round((insert_depth - V1_HIP_DEPTH) / sp_bw_ch))
Hip_thal_ch = int(np.round((insert_depth - HIP_THAL_DEPTH) / sp_bw_ch))
CA1_DG_ch = int(np.round((insert_depth - CA1_DG_DEPTH) / sp_bw_ch))

print(surface_ch, V1_hip_ch, Hip_thal_ch, CA1_DG_ch)

In [5]:
# Acquisition constants.
sr = 2500           # sampling rate; traces below are converted to seconds via /sr
n_chan = 384        # presumably the number of recording channels
samples_tr = 7350   # samples per trial, based on the shortest trial
rec_length = 3.0    # how long the arduino trigger lasts (presumably seconds)

---

# First, load the .npy LFP arrays and the pickled `et_ls` lists
These were created and saved by the "1_saving_LFP_arrays" Jupyter notebook.

In [6]:
# One array per condition, written by the "1_saving_LFP_arrays" notebook;
# loaded in the order pre, post, novel.
all_pre_arr, all_post_arr, all_novel_arr = [
    np.load(path)
    for path in (
        r"D:\mz_Data\saved_dfs\Multi_brain_regions\LFPs\pre_all.npy",
        r"D:\mz_Data\saved_dfs\Multi_brain_regions\LFPs\post_all.npy",
        r"D:\mz_Data\saved_dfs\Multi_brain_regions\LFPs\novel_all.npy",
    )
]

In [None]:
pkl_file = r"D:\mz_Data\saved_dfs\Multi_brain_regions\LFPs\pre_et_ls"

# `with` guarantees the file is closed even if unpickling raises.
# pickle.load can execute arbitrary code: only load files produced by the
# companion "1_saving_LFP_arrays" notebook.
with open(pkl_file, "rb") as open_file:
    et_ls_pre = pickle.load(open_file)

et_ls_pre

In [None]:
pkl_file = r"D:\mz_Data\saved_dfs\Multi_brain_regions\LFPs\post_et_ls"

# `with` guarantees the file is closed even if unpickling raises.
# pickle.load can execute arbitrary code: only load files produced by the
# companion "1_saving_LFP_arrays" notebook.
with open(pkl_file, "rb") as open_file:
    et_ls_post = pickle.load(open_file)

et_ls_post

In [None]:
pkl_file = r"D:\mz_Data\saved_dfs\Multi_brain_regions\LFPs\novel_et_ls"

# `with` guarantees the file is closed even if unpickling raises.
# pickle.load can execute arbitrary code: only load files produced by the
# companion "1_saving_LFP_arrays" notebook.
with open(pkl_file, "rb") as open_file:
    et_ls_novel = pickle.load(open_file)

et_ls_novel

---

# Plot the VEP traces of the selected V1 and HPC channels

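### First, define a function that averages across recordings and finds the channel with the most negative value in the V1 and HPC channel windows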

In [None]:
# V1/HPC and HPC/thalamus boundary channels, one per line.
print(V1_hip_ch, Hip_thal_ch, sep='\n')

In [15]:
def _min_channel(mean_arr, start, stop):
    """Return the row of mean_arr, within rows [start, stop), holding the smallest non-NaN value."""
    window = mean_arr[start:stop, :]
    row = np.unravel_index(np.nanargmin(window), window.shape)[0]
    # Offset by the window start itself (not start - 1) so the index refers
    # back to a row of mean_arr; the original "+ 249"/"+ 79" picked the
    # neighbouring channel.
    return int(start + row)


def VEP_lines(data_array, et_ls, v1_range=(250, 275), hpc_range=(80, 200)):
    """Average recordings and locate the most negative channel in V1 and in HPC.

    Parameters
    ----------
    data_array : numpy.ndarray
        Stacked recordings, averaged over axis 0. The mean must have channels
        on its first axis (assumed (n_recordings, n_channels, n_samples) --
        TODO confirm).
    et_ls : list
        Accepted for call-site compatibility; not used.
    v1_range, hpc_range : (int, int)
        Half-open [start, stop) channel windows searched for the V1 and
        hippocampal minima. Defaults reproduce the original hard-coded windows.

    Returns
    -------
    all_mean : numpy.ndarray
        data_array averaged over axis 0.
    min_ch_v1, min_ch_HPC : int
        Rows of all_mean containing the minimum value inside each window
        (NaNs ignored).
    """
    all_mean = data_array.mean(axis=0)

    print('Group array: {0}'.format(data_array.shape))
    print('Group mean: {0}'.format(all_mean.shape))

    min_ch_v1 = _min_channel(all_mean, *v1_range)
    print(min_ch_v1)

    min_ch_HPC = _min_channel(all_mean, *hpc_range)
    print(min_ch_HPC)

    return all_mean, min_ch_v1, min_ch_HPC

### Second, apply the function to the three conditions (pre-training, post-training, novel)

In [None]:
# Pre-training condition
pre_mean, pre_ch_v1, pre_ch_HPC = VEP_lines(data_array=all_pre_arr, et_ls=et_ls_pre)

In [None]:
# Post-training condition
post_mean, post_ch_v1, post_ch_HPC = VEP_lines(data_array=all_post_arr, et_ls=et_ls_post)

In [None]:
# Novel condition
novel_mean, novel_ch_v1, novel_ch_HPC = VEP_lines(data_array=all_novel_arr, et_ls=et_ls_novel)

### Finally, plot the V1 and HPC traces for each condition

In [None]:
# Overlay the V1 and HPC channel traces for the pre-training condition.
pre_V1_trace = pre_mean[pre_ch_v1, :]
pre_HPC_trace = pre_mean[pre_ch_HPC, :]

# Sample k is at k / sr seconds. np.linspace(0, N/sr, N) stretched the axis
# (step N/(N-1)/sr); np.arange gives the exact 1/sr step.
time_arr2_A = np.arange(pre_V1_trace.shape[0]) / sr
time_arr2_B = np.arange(pre_HPC_trace.shape[0]) / sr

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(time_arr2_A, pre_V1_trace, label='v1', color='grey')
ax.plot(time_arr2_B, pre_HPC_trace, label='hpc', color='crimson')
ax.legend(loc="lower right")

ax.set_title('Pre-training')
ax.set_xlabel('Time (s)')
ax.set_ylabel('uV')
ax.set_ylim([-400, 300])
# Shaded 0.5-0.7 s window -- presumably the stimulus period (TODO confirm).
ax.axvspan(0.5, 0.7, alpha=0.2, facecolor='b')

# Check the output path before enabling the export.
# fig.savefig(r"D:\mz_Data\DATA_Figs\HDAC\LFP\pre_VEP_trace.pdf", transparent=True)

plt.show()

In [None]:
# Overlay the V1 and HPC channel traces for the post-training condition.
post_V1_trace = post_mean[post_ch_v1, :]
post_HPC_trace = post_mean[post_ch_HPC, :]

# Sample k is at k / sr seconds. np.linspace(0, N/sr, N) stretched the axis
# (step N/(N-1)/sr); np.arange gives the exact 1/sr step.
time_arr2_A = np.arange(post_V1_trace.shape[0]) / sr
time_arr2_B = np.arange(post_HPC_trace.shape[0]) / sr

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(time_arr2_A, post_V1_trace, label='v1', color='grey')
ax.plot(time_arr2_B, post_HPC_trace, label='hpc', color='crimson')
ax.legend(loc="lower right")

ax.set_title('Post-training')
ax.set_xlabel('Time (s)')
ax.set_ylabel('uV')
ax.set_ylim([-400, 300])
# Shaded 0.5-0.7 s window -- presumably the stimulus period (TODO confirm).
ax.axvspan(0.5, 0.7, alpha=0.2, facecolor='b')

# Check the output path before enabling the export (previously pointed at
# pre_VEP_trace.pdf, which would overwrite the pre-training figure).
# fig.savefig(r"D:\mz_Data\DATA_Figs\HDAC\LFP\post_VEP_trace.pdf", transparent=True)

plt.show()

In [None]:
# Overlay the V1 and HPC channel traces for the novel condition.
novel_V1_trace = novel_mean[novel_ch_v1, :]
novel_HPC_trace = novel_mean[novel_ch_HPC, :]

# Sample k is at k / sr seconds. np.linspace(0, N/sr, N) stretched the axis
# (step N/(N-1)/sr); np.arange gives the exact 1/sr step.
time_arr2_A = np.arange(novel_V1_trace.shape[0]) / sr
time_arr2_B = np.arange(novel_HPC_trace.shape[0]) / sr

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(time_arr2_A, novel_V1_trace, label='v1', color='grey')
ax.plot(time_arr2_B, novel_HPC_trace, label='hpc', color='crimson')
ax.legend(loc="lower right")

ax.set_title('Novel')
ax.set_xlabel('Time (s)')
ax.set_ylabel('uV')
ax.set_ylim([-400, 300])
# Shaded 0.5-0.7 s window -- presumably the stimulus period (TODO confirm).
ax.axvspan(0.5, 0.7, alpha=0.2, facecolor='b')

# Check the output path before enabling the export (previously pointed at
# pre_VEP_trace.pdf, which would overwrite the pre-training figure).
# fig.savefig(r"D:\mz_Data\DATA_Figs\HDAC\LFP\novel_VEP_trace.pdf", transparent=True)

plt.show()

In [None]:
# Print each selected trace's shape so the conditions can be compared.
for _trace in (pre_V1_trace, pre_HPC_trace,
               post_V1_trace, post_HPC_trace,
               novel_V1_trace, novel_HPC_trace):
    print(_trace.shape)

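# Quantify the V1-HPC phase relationship
Note: `hilphase` returns the *angle* (in radians) of the normalized cross-product of the two Hilbert analytic signals, i.e. a phase offset in (-π, π]. It is not the phase-locking-value magnitude.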

In [41]:
import numpy as np
import scipy.signal as sig

def hilphase(y1, y2):
    """Phase offset (radians) between two signals, from their analytic signals.

    Computes the Hilbert analytic signal h1, h2 of each input. For 1-D inputs it
    forms the normalized cross-product
        sum(h1 * conj(h2)) / sqrt(sum(h1 * conj(h1)) * sum(h2 * conj(h2)))
    and returns its angle in (-pi, pi]. Identical inputs give 0.
    """
    analytic_1 = sig.hilbert(y1)
    analytic_2 = sig.hilbert(y2)
    cross = np.inner(analytic_1, np.conj(analytic_2))
    norm_sq_1 = np.inner(analytic_1, np.conj(analytic_1))
    norm_sq_2 = np.inner(analytic_2, np.conj(analytic_2))
    return np.angle(cross / np.sqrt(norm_sq_1 * norm_sq_2))

def PLV_row(fixed_arr, ls_other_arrs):
    """Return hilphase(fixed_arr, other) for every array in ls_other_arrs.

    Parameters
    ----------
    fixed_arr : array_like
        Reference signal.
    ls_other_arrs : sequence of array_like
        Signals compared against fixed_arr, in order.

    Returns
    -------
    numpy.ndarray
        1-D float array of phase offsets (radians), one per entry of
        ls_other_arrs.
    """
    # Previously only the first two entries were used (a third call was
    # commented out) and any extra arrays were silently ignored.
    return np.array([hilphase(fixed_arr, other) for other in ls_other_arrs], dtype=float)

In [42]:
def _phase_rows(v1_trace, hpc_trace):
    """Return (V1 row, HPC row, stacked matrix) of PLV_row values for one condition.

    Each row compares one region's trace against [V1, HPC].
    """
    pair = [v1_trace, hpc_trace]
    v1_row = PLV_row(v1_trace, pair)
    hpc_row = PLV_row(hpc_trace, pair)
    return v1_row, hpc_row, np.vstack((v1_row, hpc_row))

# NOTE(review): these matrices hold phase angles (np.angle in hilphase), not
# PLV magnitudes. abs() of an angle would only drop its sign; a magnitude-style
# measure would need the modulus of hilphase's normalized product instead.
v1_pre, hip_pre, pre_PLV = _phase_rows(pre_V1_trace, pre_HPC_trace)
v1_post, hip_post, post_PLV = _phase_rows(post_V1_trace, post_HPC_trace)
v1_novel, hip_novel, novel_PLV = _phase_rows(novel_V1_trace, novel_HPC_trace)

In [None]:
# Phase-offset matrices (radians, from hilphase) for each condition.
# Row = reference trace, column = compared trace; self-comparisons give 0.
# np.angle() is bounded by +/-pi, so the color scale spans exactly that range
# (the previous +/-5 limits left part of the colormap unused).
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [('Pre', pre_PLV), ('Post', post_PLV), ('Novel', novel_PLV)]
for idx, (ax, (title, mat)) in enumerate(zip(axes, panels)):
    sns.heatmap(mat,
                center=0, vmin=-np.pi, vmax=np.pi, annot=True,
                xticklabels=['V1', 'HPC'],
                # Only the left-most panel carries row labels.
                yticklabels=['V1', 'HPC'] if idx == 0 else ['', ''],
                cbar=False, ax=ax)
    ax.set_title(title)
ax3 = axes[2]  # kept: the original cell bound the third panel to this name

plt.show()