In [None]:
import numpy as np
import scipy
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.fftpack import fft, ifft

import pingouin as pg

from glob import glob
import pickle


In [None]:
sns.set_context('poster')
# Type 42 (TrueType) fonts keep text editable when figures are exported to PDF/PS.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})


def set_pub_plots(pal=sns.blend_palette(["gray", "crimson", 'cyan', 'magenta', 'purple'], 5)):
    """Apply the publication-quality seaborn/matplotlib plotting style.

    Parameters
    ----------
    pal : seaborn palette, optional
        Colour cycle to install. The default blend (gray -> purple, 5 colours)
        is built once, when this function is defined.
    """
    # Plain white style first, then the colour cycle.
    sns.set_style("white")
    sns.set_palette(pal)
    # Big "poster" text with thick lines and a white figure background.
    poster_rc = {"lines.linewidth": 2.5, "axes.linewidth": 2.5, 'figure.facecolor': 'white'}
    sns.set_context("poster", font_scale=1.5, rc=poster_rc)
    # Final style: "ticks" with 8-pt major ticks on both axes.
    sns.set_style("ticks", {"xtick.major.size": 8, "ytick.major.size": 8})
    plt.rcParams['axes.linewidth'] = 2.5


# rc overrides for publication figures: 25-pt text everywhere, 2.5-pt axes and
# lines, all text/ticks/edges black. Not applied in this cell.
rc_pub = {
    'font.size': 25,
    'axes.labelsize': 25,
    'legend.fontsize': 25.0,
    'axes.titlesize': 25,
    'xtick.labelsize': 25,
    'ytick.labelsize': 25,
    'axes.linewidth': 2.5,
    'lines.linewidth': 2.5,
    'xtick.color': 'black',
    'ytick.color': 'black',
    'axes.edgecolor': 'black',
    'axes.labelcolor': 'black',
    'text.color': 'black',
}
# to restore the defaults, call plt.rcdefaults() 

set_pub_plots()

In [None]:
# Two-colour palette (black, royal blue): preview it, then make it the default
# colour cycle. These are the same colours the Sham/NMDA traces use below.
pal = sns.blend_palette(
    ['black', 'royalblue'],
    2,
)
sns.palplot(pal)
sns.set_palette(pal)

---

In [None]:
def calc_plv(x, y, num_samples=7350,
             Sampling_Rate=2500,
             base_idx=(0, 400),
             min_freq=2,
             max_freq=90,
             num_frex=40,
             range_cycles=(3, 10),
             max_trials=20,
            ):
    """Trial-based phase-locking value (PLV) between two signals via Morlet wavelets.

    Both inputs are convolved with a family of complex Morlet wavelets. The
    convolution is done in the frequency domain, with all trials concatenated
    end to end. At every frequency and sample, the phase difference between
    ``x`` and ``y`` is then summarised across trials.

    Parameters
    ----------
    x, y : array_like, shape (n_trials, num_samples)
        Trial-by-sample signals to compare. Only the first ``max_trials`` rows
        of each are used, and after that selection both must have the same shape.
    num_samples : int
        Samples per trial; must equal ``x.shape[1]``.
    Sampling_Rate : int or float
        Sampling rate in Hz. The wavelet spans -1 s .. +1 s at this rate.
    base_idx : sequence of int
        Unused; kept only for backward compatibility with existing callers.
    min_freq, max_freq : float
        Lowest / highest wavelet frequency in Hz (log-spaced in between).
    num_frex : int
        Number of wavelet frequencies.
    range_cycles : sequence of two numbers
        Wavelet cycles at the lowest and highest frequency (log-spaced in between).
    max_trials : int or None
        Maximum number of leading trials to use. The default of 20 matches the
        previous hard-coded cap; ``None`` uses every trial.

    Returns
    -------
    tf : ndarray, shape (num_frex, num_samples)
        PLV in [0, 1]: length of the trial-averaged unit phase-difference vector.
    phd : ndarray, shape (num_frex, num_samples)
        Circular-mean phase difference (x minus y) across trials, in radians (-pi, pi].

    Raises
    ------
    ValueError
        If the inputs are not 2-D, contain no trials, differ in shape, or their
        per-trial sample count differs from ``num_samples``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError(f"x and y must be 2-D (trials, samples); got {x.shape} and {y.shape}")
    x = x[:max_trials]  # slicing with None keeps every trial
    y = y[:max_trials]
    if x.shape != y.shape:
        raise ValueError(f"x and y must have the same shape after trial selection; got {x.shape} and {y.shape}")
    num_trials, n_cols = x.shape
    if num_trials == 0:
        raise ValueError("x and y contain no trials")
    num_samples = int(num_samples)
    if n_cols != num_samples:
        raise ValueError(f"expected {num_samples} samples per trial, got {n_cols}")

    # Wavelet centre frequencies (Hz) and Gaussian widths (s): n_cycles / (2*pi*f).
    frex = np.logspace(np.log10(min_freq), np.log10(max_freq), num_frex)
    n_cycles = np.logspace(np.log10(range_cycles[0]), np.log10(range_cycles[-1]), num_frex)
    s = n_cycles / (2 * np.pi * frex)

    # Wavelet time axis: -1 s .. +1 s, odd length so it is centred on 0.
    wavtime = np.linspace(-1, 1, 2 * int(Sampling_Rate) + 1)
    n_wave = len(wavtime)
    half_wave = (n_wave - 1) // 2

    n_data = num_trials * num_samples
    n_conv = n_wave + n_data - 1

    # One long convolution per frequency: FFT of all trials concatenated.
    data_x = fft(x.flatten(), n_conv)
    data_y = fft(y.flatten(), n_conv)

    tf = np.zeros((num_frex, num_samples))
    phd = np.zeros((num_frex, num_samples))

    for fi, (freq, sigma) in enumerate(zip(frex, s)):
        wavelet = np.exp(2j * np.pi * freq * wavtime) * np.exp(-wavtime**2 / (2 * sigma**2))
        wavelet_fft = fft(wavelet, n_conv)
        # Normalise by the largest-magnitude bin. This is vectorised; the
        # builtin max() looped in Python and ordered complex values
        # lexicographically. Both signals share this factor, so PLV and
        # phase difference are unaffected by the choice.
        wavelet_fft = wavelet_fft / wavelet_fft[np.argmax(np.abs(wavelet_fft))]

        # Trim half a wavelet of convolution edge from each end -> n_data samples.
        a_sig = ifft(wavelet_fft * data_x, n_conv)[half_wave:half_wave + n_data]
        b_sig = ifft(wavelet_fft * data_y, n_conv)[half_wave:half_wave + n_data]

        # Unit phase-difference vectors, one row per trial.
        phase_diff = np.angle(a_sig) - np.angle(b_sig)
        unit_vectors = np.exp(1j * phase_diff).reshape(num_trials, num_samples)

        # Resultant across trials: its length / n_trials is the PLV, and its
        # angle is the circular mean phase difference (same formula as
        # pingouin.circ_mean).
        resultant = unit_vectors.sum(axis=0)
        tf[fi] = np.abs(resultant) / num_trials
        phd[fi] = np.angle(resultant)

    return tf, phd

---

# First, load the .npy arrays and the mouse-ID lists (et_ls)
These were created and saved using the "1_saving_LFP_arrays" Jupyter notebook.

In [None]:
# Trial arrays for the three sessions, saved by the "1_saving_LFP_arrays" notebook.
all_pre_arr, all_post_arr, all_novel_arr = (
    np.load(path)
    for path in (
        r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\pre_all_trials.npy",
        r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\post_all_trials.npy",
        r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\novel_all_trials.npy",
    )
)

In [None]:
pkl_file = r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\pre_et_ls"

# Mouse-ID list for the "pre" session; split_mice_groups pairs it by index
# with the rows of all_pre_arr.
# `with` closes the file even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code -- only load files you created.
with open(pkl_file, "rb") as open_file:
    et_ls_pre = pickle.load(open_file)

len(et_ls_pre)

In [None]:
pkl_file = r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\post_et_ls"

# Mouse-ID list for the "post" session; split_mice_groups pairs it by index
# with the rows of all_post_arr.
# `with` closes the file even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code -- only load files you created.
with open(pkl_file, "rb") as open_file:
    et_ls_post = pickle.load(open_file)

len(et_ls_post)

In [None]:
pkl_file = r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\novel_et_ls"

# Mouse-ID list for the "novel" session; split_mice_groups pairs it by index
# with the rows of all_novel_arr.
# `with` closes the file even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code -- only load files you created.
with open(pkl_file, "rb") as open_file:
    et_ls_novel = pickle.load(open_file)

len(et_ls_novel)

---

In [None]:
def split_mice_groups(data_array, et_ls, nmda_ids=None):
    """Split per-mouse data into NMDA and sham groups by mouse ID.

    Parameters
    ----------
    data_array : array_like
        Per-mouse data; axis 0 indexes mice in the same order as ``et_ls``.
    et_ls : sequence of str
        Mouse IDs, one per entry along axis 0 of ``data_array``.
    nmda_ids : iterable of str, optional
        IDs assigned to the NMDA group. Defaults to the NMDA cohort used in
        this notebook. Every other ID is treated as sham.

    Returns
    -------
    nmda_arr, sham_arr : ndarray
        Rows of ``data_array`` belonging to each group, in original order.
        An empty group keeps the trailing dimensions, i.e. shape ``(0, ...)``.

    Raises
    ------
    ValueError
        If ``et_ls`` does not have exactly one ID per row of ``data_array``.
    """
    if nmda_ids is None:
        nmda_ids = ('et1710', 'et1700', 'et1570', 'et1520', 'et171', 'et170', 'et157', 'et152')
    data_array = np.asarray(data_array)
    # A length mismatch means the labels cannot be trusted to line up with the data.
    if len(et_ls) != data_array.shape[0]:
        raise ValueError(
            f"et_ls has {len(et_ls)} IDs but data_array has {data_array.shape[0]} rows"
        )

    nmda_set = set(nmda_ids)  # built once; O(1) membership tests
    is_nmda = np.array([mouse_id in nmda_set for mouse_id in et_ls], dtype=bool)
    # Boolean-mask indexing returns copies and preserves trailing dims for empty groups.
    return data_array[is_nmda], data_array[~is_nmda]

In [None]:
def VEP_lines(nmda_arr, sham_arr, nmda_channel=65, sham_channel=65):
    """Average each group over axis 0, report shapes, and pick analysis channels.

    Parameters
    ----------
    nmda_arr, sham_arr : ndarray
        Group arrays (e.g. from ``split_mice_groups``); axis 0 is averaged out.
    nmda_channel, sham_channel : int
        Channel index to analyse for each group. The default of 65 is the
        previously hard-coded value.

    Returns
    -------
    mean_nmda, mean_sham : ndarray
        Group means over axis 0.
    min_ch_nmda, min_ch_sham : int
        The selected channel indices, passed through unchanged.
    """
    mean_nmda = nmda_arr.mean(axis=0)
    mean_sham = sham_arr.mean(axis=0)

    print('Group nmda array: {0}'.format(nmda_arr.shape))
    print('Group nmda mean: {0}'.format(mean_nmda.shape))
    print('Group sham array: {0}'.format(sham_arr.shape))
    print('Group sham mean: {0}'.format(mean_sham.shape))

    # Fixed channel per group. An earlier, commented-out variant (now removed)
    # searched channels 0-99 for the VEP minimum. It took the index from
    # axis 0 of the group mean rather than from the channel axis, which is
    # why it was abandoned in favour of an explicit channel.
    min_ch_nmda, min_ch_sham = nmda_channel, sham_channel

    return mean_nmda, mean_sham, min_ch_nmda, min_ch_sham

---

In [None]:
# Set to 'pre', 'post' or 'novel' to skip the interactive prompt (e.g. for Run All).
session_choice = None
my_input = (session_choice or input('pre/post/novel: ')).strip().lower()

# Session name -> (trial array, mouse-ID list). An unknown name now fails
# loudly instead of leaving my_array/et_ls undefined or stale from a previous run.
_sessions = {
    'pre': (all_pre_arr, et_ls_pre),
    'post': (all_post_arr, et_ls_post),
    'novel': (all_novel_arr, et_ls_novel),
}
if my_input not in _sessions:
    raise ValueError(f"Unknown session {my_input!r}; expected one of {sorted(_sessions)}")
my_array, et_ls = _sessions[my_input]


In [None]:
# Split the chosen session by group and average over axis 0 of each group.
nmda_arr, sham_arr = split_mice_groups(my_array, et_ls)
mean_nmda, mean_sham, min_ch_nmda, min_ch_sham = VEP_lines(nmda_arr, sham_arr)

# (trials, samples) trace of the selected channel for each group.
nmda_channel, sham_channel = mean_nmda[:, min_ch_nmda, :], mean_sham[:, min_ch_sham, :]

# PLV and mean phase difference between the two groups' traces.
tf, phd = calc_plv(nmda_channel, sham_channel)
plvdic = [(tf, phd)]

tmpdf = pd.DataFrame(plvdic, columns=['plv', 'phsdiff'])
print(tf.shape, phd.shape)
tmpdf.head()

In [None]:
# Wavelet / PLV parameters. These mirror calc_plv's defaults (used in the call
# above) and are only needed here to label the PLV plot axes.
base_idx = [0, 400]
min_freq, max_freq = 2, 90  # Hz; alternatives noted for max_freq: 40, 50
num_frex = 40
range_cycles = [3, 10]

# Recording info
Sampling_Rate = 2500.
num_samples = 7350

# Log-spaced wavelet centre frequencies (Hz) and a per-sample x axis.
frex = np.logspace(np.log10(min_freq), np.log10(max_freq), num_frex)
time = np.linspace(0, num_samples, int(num_samples))

# Plot mean VEPs and PLV together

In [None]:
# Collapse axis 0 (the axis calc_plv treats as trials) to get one mean trace per channel.
mean_nmda_trs = mean_nmda.mean(axis=0)
mean_sham_trs = mean_sham.mean(axis=0)

# Plot the same channels that were fed to calc_plv, instead of repeating a
# hard-coded index that could drift out of sync with VEP_lines.
nmda_mean_ch = min_ch_nmda
sham_mean_ch = min_ch_sham

sr = 2500  # sampling rate (Hz) used to convert samples to seconds for plotting

In [None]:
f, ax = plt.subplots(2, 1, figsize=(12, 4 * 2))

# Stimulus window in seconds, shaded in both panels.
stim_on_s, stim_off_s = 0.5, 0.7

# --- Top panel: trial-averaged VEP of the selected channel for each group ---
mean_ch_tracenmda = mean_nmda_trs[nmda_mean_ch, :]
time_arr2_nmda = np.linspace(0, mean_ch_tracenmda.shape[0] / sr, mean_ch_tracenmda.shape[0])
mean_ch_tracesham = mean_sham_trs[sham_mean_ch, :]
time_arr2_sham = np.linspace(0, mean_ch_tracesham.shape[0] / sr, mean_ch_tracesham.shape[0])
ax[0].plot(time_arr2_sham, mean_ch_tracesham, label='Sham', color='black')
ax[0].plot(time_arr2_nmda, mean_ch_tracenmda, label='NMDA', color='royalblue')
ax[0].legend(loc="lower right")
ax[0].set_title(my_input)
ax[0].set_xlabel('Time (s)')
ax[0].set_ylim([-360, 210])
ax[0].set_yticks([-300, -150, 0, 150])
ax[0].set_xlim([0, 3.0])
ax[0].set_ylabel('uV')
ax[0].axvspan(stim_on_s, stim_off_s, alpha=0.2, facecolor='grey')

# --- Bottom panel: PLV time-frequency map (x axis is in samples) ---
all_tmp = np.mean(np.stack(tmpdf.plv.values), axis=0)
tf_plot = ax[1].contourf(time, frex, all_tmp, cmap='jet', extend='both', levels=np.linspace(0., 1.0, 60))
ax[1].set_yscale('log')
# Tick every 0.5 s; positions and labels derive from `sr` instead of hard-coded sample counts.
xtick_samples = np.arange(0, time[-1] + 1, 0.5 * sr)
ax[1].set_xticks(xtick_samples)
ax[1].set_xticklabels(xtick_samples / sr)
ytick_freqs = np.logspace(np.log10(min_freq), np.log10(max_freq), 6)
ax[1].set_yticks(ytick_freqs)
ax[1].set_yticklabels(np.round(ytick_freqs))
ax[1].set_title(f'Sham/NMDA PLV: {my_input}')
ax[1].axvspan(stim_on_s * sr, stim_off_s * sr, ec='black', lw=3, fill=False)
ax[1].set_xlabel('Time (s)')
ax[1].set_ylabel('Frequency (Hz)')
cbar_ax = f.add_axes([.92, 0.1, 0.02, 0.35])
cb_tf = f.colorbar(tf_plot, cax=cbar_ax, ticks=np.arange(0, 1.1, 0.2))
cb_tf.set_label('PLV')

sns.despine()
# Change the end of this next line with a new file name!!
# plt.savefig(r"C:\Users\AChub_Lab\Desktop\zimmer94\PLV_trlbytrl_plots\nmda_post_PLV.pdf", transparent=True)
plt.show()