In [None]:
import os
import glob
import numpy as np
import pandas as pd
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

from scipy.ndimage import gaussian_filter1d

from scipy import stats
import statsmodels.formula.api as smf

# # For TF analysis
import scipy.fftpack
from scipy.fftpack import fft, ifft
from scipy import signal

# Start from seaborn's large "poster" context for the figures in this notebook.
sns.set_context('poster')
# Font type 42 embeds TrueType fonts, so text in exported PDF/PS figures stays editable.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# import Python3_OpenOE_AC_map_functions_v1_08_30s as oem
import mz_LFP_functions as mz_LFP


%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# for publication quality plots
def set_pub_plots(pal=sns.blend_palette(["gray","crimson", 'cyan', 'magenta', 'purple'  ],5)):
    """Apply the seaborn/matplotlib styling used for publication figures.

    Parameters
    ----------
    pal : seaborn palette, optional
        Palette installed as the default colour cycle. The default is a
        5-colour blend that is built once, when this function is defined.
    """
    line_weight = 2.5
    context_overrides = {"lines.linewidth": line_weight, "axes.linewidth": line_weight,
                         'figure.facecolor': 'white'}
    tick_lengths = {"xtick.major.size": 8, "ytick.major.size": 8}

    sns.set_style("white")
    sns.set_palette(pal)
    sns.set_context("poster", font_scale=1.5, rc=context_overrides)
    # Applied after the "white" style above, with longer major ticks.
    sns.set_style("ticks", tick_lengths)
    # Set directly on matplotlib as well: the original author noted that the
    # axes.linewidth override did not seem to take effect through seaborn.
    plt.rcParams['axes.linewidth'] = line_weight

# Font-size, line-width and colour overrides collected as an rc dictionary.
# NOTE(review): rc_pub is not passed to anything in the visible part of this
# notebook — confirm whether it is still needed.
rc_pub={'font.size': 25, 'axes.labelsize': 25, 'legend.fontsize': 25.0, 
    'axes.titlesize': 25, 'xtick.labelsize': 25, 'ytick.labelsize': 25, 
    #'axes.color_cycle':pal, # image.cmap - rewrites the default colormap
    'axes.linewidth':2.5, 'lines.linewidth': 2.5,
    'xtick.color': 'black', 'ytick.color': 'black', 'axes.edgecolor': 'black','axes.labelcolor':'black','text.color':'black'}
# to restore the defaults, call plt.rcdefaults() 

#set_pub_bargraphs()
# Apply the publication style with its default palette.
set_pub_plots()

In [None]:
# Build a two-colour palette (black -> royalblue), preview it, and install it
# as the default palette for the plots that follow.
pal=sns.blend_palette(['black','royalblue'],2)
sns.palplot(pal)
sns.set_palette(pal)

# Load some necessary variables

In [None]:
# Recording parameters.
# NOTE(review): none of these names are referenced again in the visible part
# of this notebook — confirm which are still required.
insert_depth = 1000  #change this as appropriate

sp_bw_ch = 20/2  # evaluates to 10.0; presumably the spacing between channels — TODO confirm meaning and units
samples_tr = 7350 #this is based on the shortest #samples in a trial
sr = 2500  # presumably the sampling rate in Hz — TODO confirm
n_chan = 384  # presumably the number of recording channels — TODO confirm
rec_length = 3.0 #how long is the arduino triggered?

---

# First, load in the .npy arrays and CC_ls
These were created and saved using the "1_saving_LFP_arrays" Jupyter notebook.

In [None]:
# Load the saved LFP arrays for the three conditions (pre, post, novel).
# NOTE(review): downstream code indexes these along the first axis, one entry
# per ID in the matching et_ls list — presumably (mice, channels, samples); confirm.
all_pre_arr = np.load(r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\pre_all.npy")
all_post_arr = np.load(r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\post_all.npy")
all_novel_arr = np.load(r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\novel_all.npy")

In [None]:
# Pickled list of IDs for the "pre" recordings (later compared against the
# 'et...' IDs to assign each recording to a group).
pkl_file = r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\pre_et_ls"

# The context manager closes the file even if unpickling raises.
# NOTE: unpickling can execute arbitrary code — only load files you created yourself.
with open(pkl_file, "rb") as open_file:
    et_ls_pre = pickle.load(open_file)

len(et_ls_pre)

In [None]:
# Pickled list of IDs for the "post" recordings (later compared against the
# 'et...' IDs to assign each recording to a group).
pkl_file = r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\post_et_ls"

# The context manager closes the file even if unpickling raises.
# NOTE: unpickling can execute arbitrary code — only load files you created yourself.
with open(pkl_file, "rb") as open_file:
    et_ls_post = pickle.load(open_file)

len(et_ls_post)

In [None]:
# Pickled list of IDs for the "novel" recordings (later compared against the
# 'et...' IDs to assign each recording to a group).
pkl_file = r"D:\mz_Data\saved_dfs\HPC_nmda\lfp_npy\novel_et_ls"

# The context manager closes the file even if unpickling raises.
# NOTE: unpickling can execute arbitrary code — only load files you created yourself.
with open(pkl_file, "rb") as open_file:
    et_ls_novel = pickle.load(open_file)

len(et_ls_novel)

---

# Time Frequency plots
### First, make some functions

In [None]:
def regroup(all_array, et_ls, show=1):
    """Split a stack of recordings into sham and NMDA groups by ID.

    Parameters
    ----------
    all_array : numpy.ndarray
        Recordings stacked along the first axis, one entry per ID in ``et_ls``.
    et_ls : sequence of str
        ID of each entry of ``all_array``, in the same order.
    show : int, optional
        When equal to 1 (the default), print the shapes of both outputs.

    Returns
    -------
    tf_sham_arr : numpy.ndarray
        Entries of ``all_array`` whose ID is not in the NMDA list.
    tf_nmda_arr : numpy.ndarray
        Entries of ``all_array`` whose ID is in the NMDA list.
        Both outputs keep the trailing dimensions of ``all_array``, even when
        a group is empty.

    Raises
    ------
    ValueError
        If ``et_ls`` does not hold exactly one ID per recording; pairing them
        up anyway would assign recordings to the wrong group.
    """
    # IDs belonging to the NMDA group; every other ID is treated as sham.
    nmda_ls = ('et1710', 'et1700', 'et1570', 'et1520', 'et171', 'et170', 'et157', 'et152')

    n_recordings = all_array.shape[0]
    if len(et_ls) != n_recordings:
        raise ValueError(
            'et_ls has {} IDs but all_array has {} recordings'.format(len(et_ls), n_recordings)
        )

    # Boolean mask over the first axis: True where the recording is NMDA.
    is_nmda = np.array([et_id in nmda_ls for et_id in et_ls], dtype=bool)

    tf_sham_arr = all_array[~is_nmda]
    tf_nmda_arr = all_array[is_nmda]

    if show == 1:
        print(tf_sham_arr.shape)
        print(tf_nmda_arr.shape)

    return tf_sham_arr, tf_nmda_arr

In [None]:
def make_tf_data(group_arr, n_search_ch=100):
    """Pick each recording's strongest channel and average those traces.

    For every entry along the first axis of ``group_arr``, the channel (row)
    holding the single most negative value within the first ``n_search_ch``
    channels is selected. NaN samples are ignored during that search.

    Parameters
    ----------
    group_arr : numpy.ndarray
        Array indexed as (recording, channel, sample).
    n_search_ch : int, optional
        Number of leading channels searched for the minimum. The default of
        100 matches the previously hard-coded window (named ``V1_region`` in
        the original code).

    Returns
    -------
    tf_plot : numpy.ndarray
        Selected channel trace for each recording, shape (recordings, samples).
    tf_plot_mean : numpy.ndarray
        Mean of ``tf_plot`` across recordings, shape (1, samples).

    Raises
    ------
    ValueError
        If ``group_arr`` holds no recordings, or if a search window contains
        only NaN values (raised by ``numpy.nanargmin``).
    """
    if group_arr.shape[0] == 0:
        raise ValueError('group_arr contains no recordings')

    chs_ls = []
    for recording in group_arr:
        search_region = recording[0:n_search_ch, :]
        # Row index of the first occurrence of the minimum (row-major order).
        min_ch, _ = np.unravel_index(np.nanargmin(search_region), search_region.shape)
        chs_ls.append(recording[min_ch, :])

    tf_plot = np.array(chs_ls)
    # keepdims leaves the mean as a single-row 2-D array.
    tf_plot_mean = np.mean(tf_plot, axis=0, keepdims=True)

    return tf_plot, tf_plot_mean

### Second, reestablish the A and B groups
These are 3d arrays with the following dimensions:
mice x ch x samples

In [None]:
# Split each condition's array into (sham, nmda) groups; show=0 silences the shape printout.
sham_pre, nmda_pre = regroup(all_pre_arr, et_ls_pre, show=0)
sham_post, nmda_post = regroup(all_post_arr, et_ls_post, show=0)
sham_novel, nmda_novel = regroup(all_novel_arr, et_ls_novel, show=0)


### Third, apply the functions to find the strongest ch averaged across all mice
This differs from a previous (now commented-out) cell: it iterates through each mouse and finds that mouse's strongest channel, which is appended to a list. The mean of this list is then used to plot the TF.

The important part is that the strongest channel from each mouse is used and not the overall strongest channel after averaging.
- Previous: average all recordings, then find strongest response
- Current: find strongest response for each mouse, then average

In [None]:
# For every condition/group pair keep both outputs of make_tf_data:
#   *_tf_*  -> the selected channel trace of each recording
#   mean_*  -> the across-recording mean, as a single-row array
pre_tf_sham, mean_pre_sham = make_tf_data(sham_pre)
pre_tf_nmda, mean_pre_nmda = make_tf_data(nmda_pre)

post_tf_sham, mean_post_sham = make_tf_data(sham_post)
post_tf_nmda, mean_post_nmda = make_tf_data(nmda_post)

novel_tf_sham, mean_novel_sham = make_tf_data(sham_novel)
novel_tf_nmda, mean_novel_nmda = make_tf_data(nmda_novel)

# this is just printing an example shape to make sure it worked correctly
print('Example dimension check! \nShould go from (n,7350) to (1,7350)')
print(pre_tf_sham.shape)
print(mean_pre_sham.shape)

### Fourth, plot the individual TF heatmaps for each group
The cell below requires a _`user input`_ for the scenario you want to look at!

In [None]:
# Mean traces of both groups for each scenario, keyed by the name the user types.
scenario_means = {
    'pre': (mean_pre_sham, mean_pre_nmda),
    'post': (mean_post_sham, mean_post_nmda),
    'novel': (mean_novel_sham, mean_novel_nmda),
}

# Surrounding whitespace and letter case are ignored, so "Pre " is accepted.
rew_selection = input('Scenario (pre, post, novel): ').strip().lower()

if rew_selection not in scenario_means:
    raise ValueError('Input is not one of the 3 options')

# Group A is sham, group B is nmda; titles and file names follow the scenario name.
groupA_plot, groupB_plot = scenario_means[rew_selection]
plt_titleA = 'sham - ' + rew_selection
plt_titleB = 'nmda - ' + rew_selection
fnA = rew_selection + "_sham_heat.pdf"
fnB = rew_selection + "_nmda_heat.pdf"

f_path_start = r"C:\Users\AChub_Lab\Desktop\tmp_nmda\tf" #change this file destination!!

In [None]:
# Time-frequency plot for group A, drawn by mz_LFP.tf_cmw onto ax1.
f, ax1 = plt.subplots()
tf_A, time_A, frex_A, tf3d_A = mz_LFP.tf_cmw(ax=ax1, df_res=groupA_plot)
ax1.set_title(plt_titleA)
sns.despine()

# Change the end of this next line with a new file name!!
# f_path = f_path_start + '\\' + fnA
# plt.savefig(f_path, transparent=True)

plt.show()

In [None]:
# Time-frequency plot for group B, drawn by mz_LFP.tf_cmw onto ax2.
f, ax2 = plt.subplots()
tf_B, time_B, frex_B, tf3d_B = mz_LFP.tf_cmw(ax=ax2,df_res=groupB_plot)
ax2.set_title(plt_titleB)
sns.despine()

# Change the end of this next line with a new file name!!
# f_path = f_path_start + '\\' + fnB
# plt.savefig(f_path, transparent=True)

plt.show()

### Fifth, rerun the TF code to extract the freq. band values
I have to rerun it on each mouse to get the confidence intervals

In [None]:

# Two-element window handed to mz_LFP.TF_band_values in the next cell.
# NOTE(review): presumably [start, end] in seconds — confirm against mz_LFP.
time_window = [0.7,2.0]


In [None]:
# Band values per condition/group, computed from the per-recording traces
# (not the group means) so that each recording contributes its own values.
pre_sham_df = mz_LFP.TF_band_values(pre_tf_sham, time_window)
pre_nmda_df = mz_LFP.TF_band_values(pre_tf_nmda, time_window)

post_sham_df = mz_LFP.TF_band_values(post_tf_sham, time_window)
post_nmda_df = mz_LFP.TF_band_values(post_tf_nmda, time_window)

novel_sham_df = mz_LFP.TF_band_values(novel_tf_sham, time_window)
novel_nmda_df = mz_LFP.TF_band_values(novel_tf_nmda, time_window)

### Sixth, combine the two group dfs, maintaining an ID for each group

In [None]:
# Each table paired with the group and stimulus labels it should carry.
labelled_tables = [
    (pre_sham_df, 'sham', 'pre'),
    (pre_nmda_df, 'nmda', 'pre'),
    (post_sham_df, 'sham', 'post'),
    (post_nmda_df, 'nmda', 'post'),
    (novel_sham_df, 'sham', 'novel'),
    (novel_nmda_df, 'nmda', 'novel'),
]

# Tag every table in place with its labels.
for band_table, group_label, stim_label in labelled_tables:
    band_table['group'] = group_label
    band_table['stim_id'] = stim_label

# Stack the tagged tables in the same order as listed above.
overall_tf = pd.concat([entry[0] for entry in labelled_tables])
overall_tf.head()

In [None]:
# Sanity check: list the group and stimulus labels present in the combined table.
print(overall_tf.group.unique())
print(overall_tf.stim_id.unique())

### Seventh, plot the frequency band values
separated out from each other by the different groups

In [None]:
# One subset of the combined table per stimulus condition, each with a display title.
TF_plot1 = overall_tf[overall_tf['stim_id'] == 'pre']
plt_title1 = 'Pre-training'
TF_plot2 = overall_tf[overall_tf['stim_id'] == 'post']
plt_title2 = 'Post-training'
TF_plot3 = overall_tf[overall_tf['stim_id'] == 'novel']
plt_title3 = 'Novel'

# y-axis tick positions used by the bar plot.
plt_yticks = [-4,0,4,8,12,16,20]

In [None]:
# Plotting the barplot of the T-F plot: one facet per stim_id, bars split by group.

# Frequency bands in display order; shared by the bar order and the tick labels.
band_order = ['4-8Hz', '8-12Hz', '12-30Hz', '30-40Hz','50-70Hz','30-70Hz']

# NOTE: seaborn >= 0.12 deprecates ci= in favour of errorbar=('ci', 68);
# ci=68 is kept here so the cell still runs on older seaborn versions.
g = sns.catplot(x='variable', y='value', data=overall_tf, kind = 'bar', col='stim_id',
                hue='group', hue_order=['sham','nmda'],
                legend=False,
                height = 6, aspect=1.2,
                order=band_order,
                ci=68)

g.set_xticklabels(band_order, rotation=40, fontsize=20)

# Label through the FacetGrid so every facet is covered; plt.xlabel/plt.ylabel
# only touch the current (last) axes and left 'variable'/'value' on the others.
g.set_axis_labels('', 'Power (dB)')
g.set(yticks=plt_yticks)
# Single legend, placed on the last facet.
g.axes.flat[-1].legend(loc="upper right")
sns.despine()

# Change the end of this next line with a new file name!!
# plt.savefig(r"C:\Users\AChub_Lab\Desktop\tmp_nmda\tf\tf_bar.pdf", transparent=True)

plt.show()

### Eighth, find the stats for the above plot

In [None]:
def TF_stats(df):
    """Compare sham and nmda values in each frequency band with a Mann-Whitney U test.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with the columns 'variable' (band label), 'group'
        ('sham' or 'nmda') and 'value'.

    Returns
    -------
    list
        Flat list alternating a band label with its ``[U, p]`` pair, i.e.
        ``[band, [U, p], band, [U, p], ...]``, in the fixed band order below.
    """
    tf_groups = ('4-8Hz', '8-12Hz', '12-30Hz', '30-40Hz', '50-70Hz', '30-70Hz')
    stat_result = []
    for band in tf_groups:
        in_band = df['variable'] == band
        foo_A = df[in_band & (df['group'] == 'sham')]['value'].values
        foo_B = df[in_band & (df['group'] == 'nmda')]['value'].values
        # Request the two-sided test explicitly: older SciPy releases defaulted
        # to a one-sided p-value when `alternative` was omitted.
        U, p = stats.mannwhitneyu(foo_A, foo_B, alternative='two-sided')

        stat_result.append(band)
        stat_result.append([U, p])

    return stat_result

In [None]:
# Per-band sham-vs-nmda test results, computed separately for each stimulus condition.
pre_TF_stats = TF_stats(TF_plot1)
post_TF_stats = TF_stats(TF_plot2)
novel_TF_stats = TF_stats(TF_plot3)

In [None]:
# Print each condition's title followed by its flat [band, [U, p], ...] result list.
print(plt_title1)
print(pre_TF_stats)
print(plt_title2)
print(post_TF_stats)
print(plt_title3)
print(novel_TF_stats)

---

---

---