In [None]:
import glob
import pickle
import shutil
import pandas as pd
from database.db_setup import *
import preprocessing.data_preprocessing.binning as binning
import preprocessing.data_preprocessing.create_vectors_from_time_points as create

import utils.helper_func as hf

In [9]:
# --- run configuration -------------------------------------------------------
# Root folder holding both the pipeline inputs and the generated figures.
top_dir = '/home/anastasia/epiphyte/anastasia/output'
patient_id = 53

# Derived from top_dir (instead of repeating the literal) so that relocating
# the output root cannot leave the figure folder pointing at the old place.
PATH_TO_PLOTS = f'{top_dir}/10-descriptive_figures_{patient_id}'

# Per-channel metadata; later code reads its 'recording_site' column.
df_patient_info = pd.read_csv(f'{top_dir}/{patient_id}_channel_info.csv')

# Session used for the database restrictions, and the p-value threshold
# (pval <= alpha) used to select unit/stimulus pairs further down.
session = 1
alpha = 0.01

In [None]:
"""
Functions for plotting the screening results. 
Adapted from Johannes Niediek's original standalone code to interface with the database.
"""

import warnings
import scipy.signal
import numpy as np
import matplotlib.pyplot as mpl
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec
from scipy.ndimage import gaussian_filter1d

# local imports
from database.db_setup import *

from visualization.scr_utils import *
import analysis.stats_code.compute_pvalues as cp

# Suppress every warning category for the rest of the kernel session.
warnings.filterwarnings("ignore")

# --- constants ---------------------------------------------------------------

# Not referenced by the plotting code visible in this notebook.
EMPTY = np.array([])
OUTFOLDER = 'pvalues'
PVAL_CUTOFF = 0.1
MIN_ACTIVE_TRIALS = 3

# Figure geometry and export resolution.
FIGSIZE = (6, 4)
DPI = 200

# Peri-event window and histogram binning.
# Presumably milliseconds: bin counts are scaled by 1000 / HIST_BIN_WIDTH and
# the resulting axis is labelled 'Hz' -- TODO confirm against the spike clock.
T_PRE, T_POST = 500, 1000
HIST_BIN_WIDTH = 100
HIST_BINS = np.arange(-T_PRE, T_POST + 1, HIST_BIN_WIDTH)

# Presentation positions and the colour each one is drawn in (paired by index).
POSITIONS = ('pre', 'post')
COLORS = ('blue', 'green')

# Normalisation variant and frequency bands; both are path components of the
# spectrogram / envelope files loaded by plot_one_stimulus.
norm_type = 'zscore'
filtering_types = [
    'theta',
    'slow_gamma',
    'fast_gamma',
]


def plot_one_raster(plot, plot_hist, sp_times, event_data, pvalues, filtering_type, amplitude_envelope_pre, amplitude_envelope_post, std_pre, std_post, i):
    """
    Draw the spike raster, the peri-event rate histogram and the band-power
    envelopes for one stimulus.

    plot: axes receiving the raster (one row per event)
    plot_hist: axes receiving the histogram; a twin y-axis is created on it
        for the two envelopes
    sp_times: numpy array of spike times (indexed with a boolean mask), on the
        same clock as event_data.time
    event_data: frame with 'position' and 'time' columns
    pvalues, std_pre, std_post: accepted for interface compatibility; not used
    filtering_type: band name, only used in the twin-axis label
    amplitude_envelope_pre, amplitude_envelope_post: 1-D envelopes, each drawn
        across the whole [-T_PRE, T_POST] window
    i: row counter supplied by the caller; the twin axis keeps its x axis only
        when i == 3
    """
    # Imported locally so the function does not depend on a star import
    # happening to export defaultdict.
    from collections import defaultdict

    time_data = []
    color_data = []
    hist_data = defaultdict(list)

    for name, color in zip(POSITIONS, COLORS):
        events = event_data.loc[event_data.position == name, 'time'].values
        for event in events:
            idx_sp = (sp_times >= event - T_PRE) & (sp_times <= event + T_POST)
            if idx_sp.any():
                temp = sp_times[idx_sp] - event

                time_data.append(temp)
                t_hist, _ = np.histogram(temp, HIST_BINS)

                # counts per bin -> rate (bin width presumably in ms)
                hist_data[name].append(t_hist * 1000 / HIST_BIN_WIDTH)
            else:
                # eventplot needs an entry per row: park a dummy spike outside
                # the visible x range so the empty trial still occupies a row
                time_data.append([-2*T_PRE])

            color_data.append(color)

    # NOTE(review): trials without any spike are not part of the mean below,
    # so the histogram averages active trials only -- confirm this is intended.
    for name, color in zip(POSITIONS, COLORS):
        if len(hist_data[name]):
            hist = np.vstack(hist_data[name]).mean(0)
            plot_hist.bar(HIST_BINS[:-1], hist, width=HIST_BIN_WIDTH, facecolor=color,
                lw=0, alpha=.5)

    plot.eventplot(time_data, color=color_data, lw=1)
    plot.tick_params(axis='both', labelsize=7)
    plot.margins(x=1)

    # Ticks are set on plot_hist explicitly instead of through pyplot's
    # "current axes", so the result does not depend on axes creation order.
    plot_hist.set_xticks((0, 1000, -500, 250, 500, 750))
    plot_hist.margins(x=0)
    plot_hist.tick_params(axis='both', labelsize=7)

    # Band-power envelopes on a secondary y axis. The time base is derived from
    # the window constants and the envelope length rather than a fixed 1500.
    plot_hist2 = plot_hist.twinx()
    for envelope, color in ((amplitude_envelope_pre, 'blue'),
                            (amplitude_envelope_post, 'green')):
        times = np.linspace(-T_PRE, T_POST, len(envelope))
        plot_hist2.plot(times, envelope, color=color, linewidth=0.5)
    plot_hist2.set_ylabel(f'power {filtering_type}', fontsize=5)
    plot_hist2.tick_params(axis='both', labelsize=7)
    plot_hist2.margins(x=0)

    for pos in ('bottom', 'top'):
        plot.spines[pos].set_visible(False)
        plot_hist.spines[pos].set_visible(False)
        plot_hist2.spines[pos].set_visible(False)
    if i != 3:
        plot_hist2.get_xaxis().set_ticks([])
        plot_hist2.get_xaxis().set_visible(False)

    # dashed markers at 0 and 1000 on both the raster and the histogram
    for pos in (0, 1000):
        for pl in (plot, plot_hist):
            pl.axvline(pos, ls='--', color='k', lw=1, alpha=.8)

    plot.set_xlim((-T_PRE, T_POST))
    plot.set_ylim((-.5, len(time_data) + .5))
    plot.set_yticks([])
    plot.set_xticks((0, 1000))
    plot.set_xticklabels([])

    return


def plot_one_stimulus(fig, title, stim_frame, spikes, unit_pvals, stim_data, ch, ch_site, unit, unit_type, save_folder, hilbert_data, spgram_data):
    """
    Save one summary figure per stimulus for a single unit.

    Each figure holds the stimulus image (or its name), the mean pre / post /
    difference spectrograms up to ~100 Hz, the spike raster, and one
    histogram-plus-envelope panel per entry of `filtering_types`.

    fig: not used; every stimulus gets its own new figure (kept for interface
        compatibility)
    title: text written at the top of every figure
    stim_frame: frame with 'filename', 'stim_num', 'position', 'time' columns
    spikes: spike times of the unit, forwarded to plot_one_raster
    unit_pvals: frame with a 'stim_num' column, forwarded to plot_one_raster
    stim_data: filename -> (a, b, c, d) dictionary, where
        a: stim_num
        b: stim_name
        c: paradigm
        d: image
    ch, ch_site, unit, unit_type: used to build input and output file names
    save_folder: directory the .jpeg files are written to
    hilbert_data, spgram_data: sub-folders of `top_dir` holding the envelope
        and spectrogram arrays

    Also reads the module-level `top_dir`, `patient_id`, `norm_type` and
    `filtering_types`.

    Raises Exception when the number of distinct stimulus files is neither 42
    nor 35.
    """
    num_stimuli = len(np.unique(stim_frame["filename"]))

    if num_stimuli == 42:
        img_order = IMAGE_ORDER_ALL
    elif num_stimuli == 35:
        img_order = IMAGE_ORDER_NO_TEXT
    else:
        raise Exception("Irregular number of stimuli for patient.")

    spgram_dir = f'{top_dir}/{spgram_data}/spectrograms'
    freq = np.load(f'{spgram_dir}/freq.npy')
    t = np.load(f'{spgram_dir}/times.npy')
    # we will work only with frequencies below 100
    # (hf.find_nearest presumably returns the element of freq closest to 100)
    idx = np.where(freq == hf.find_nearest(freq, 100))[0][0]
    freq = freq[0:idx+1]

    # Smoothing kernel, built once. It is divided by its sum so that smoothing
    # keeps the envelope on its original scale; an unnormalised window would
    # multiply every value by the window sum.
    gauss_window = scipy.signal.windows.gaussian(20, std=10)
    gauss_window = gauss_window / gauss_window.sum()

    # vertical placement of the colour legend next to the image panel
    legend_ypos = {'pre': 0.65, 'post': 0.4}

    for img in img_order:
        for img_fname in img:
            stim_num, stim_name, paradigm, image = stim_data[img_fname]

            # data files use 'unknown' where the stimulus name is '???'
            stim_label = 'unknown' if stim_name == '???' else stim_name
            file_stem = f'CSC{ch}_{ch_site}_{stim_num}_{stim_label}'

            stim_fig = mpl.figure(figsize=FIGSIZE)
            try:
                grid = GridSpec(4, 2, top=.9, bottom=.05, wspace=0.2, width_ratios=[1,0.85])

                info_plot = stim_fig.add_axes([0, .9, .8, .1])
                info_plot.axis('off')

                # --- stimulus image / name -------------------------------
                plot = stim_fig.add_subplot(grid[0, 0])
                plot.axis('off')
                if paradigm == 'scr' :
                    plot.imshow(image)
                    plot.text(.5, 1, stim_name, transform=plot.transAxes,
                        va='bottom', ha='center', size=6)
                else:
                    plot.text(.5, .5, stim_name, transform=plot.transAxes,
                        va='center', ha='center')
                for name, color in zip(POSITIONS, COLORS):
                    plot.text(1.05, legend_ypos[name], name, color=color, transform=plot.transAxes)

                # --- spectrograms: pre, post, post - pre -----------------
                spgram_pre = np.load(f'{spgram_dir}/normalized/{norm_type}/{file_stem}_pre_{norm_type}.npy')
                spgram_post = np.load(f'{spgram_dir}/normalized/{norm_type}/{file_stem}_post_{norm_type}.npy')
                spgram_pre = np.mean(spgram_pre, axis=0)
                spgram_post = np.mean(spgram_post, axis=0)

                panels = (
                    (1, spgram_pre, 'Hz, pre'),
                    (2, spgram_post, 'Hz, post'),
                    (3, spgram_post - spgram_pre, 'Hz, diff'),
                )
                for row, spgram, ylabel in panels:
                    plot = stim_fig.add_subplot(grid[row, 0])
                    im = plot.pcolormesh(t, freq, spgram[:idx+1,:], shading='auto', cmap='viridis')
                    if row == 3:
                        # only the bottom panel shows the time axis
                        plot.tick_params(axis='both', labelsize=7)
                    else:
                        plot.axes.get_xaxis().set_visible(False)
                        plot.tick_params(axis='y', labelsize=7)
                    plot.set_ylabel(ylabel, fontsize=7)
                    cbar = stim_fig.colorbar(im, ax=plot)
                    cbar.ax.tick_params(labelsize=3)

                # --- raster + one histogram/envelope row per band --------
                raster_plot = stim_fig.add_subplot(grid[0, -1])
                stim_events = stim_frame.loc[stim_frame.stim_num == stim_num, ['position', 'time']]
                stim_pvals = unit_pvals.loc[unit_pvals.stim_num == stim_num, :]

                for i, filtering_type in enumerate(filtering_types, start=1):
                    envelope_dir = f'{top_dir}/{hilbert_data}/{filtering_type}/power/{norm_type}/CSC{ch}_{ch_site}'
                    ampl_envelope_pre = np.load(f'{envelope_dir}/{file_stem}_amplitude_envelope_{norm_type}_pre.npy')
                    ampl_envelope_post = np.load(f'{envelope_dir}/{file_stem}_amplitude_envelope_{norm_type}_post.npy')

                    std_pre = np.std(ampl_envelope_pre, axis=0)
                    std_post = np.std(ampl_envelope_post, axis=0)

                    # average along axis 0, then resample to 1500 points
                    amplitude_envelope_pre = scipy.signal.resample(np.mean(ampl_envelope_pre, axis=0), 1500)
                    amplitude_envelope_post = scipy.signal.resample(np.mean(ampl_envelope_post, axis=0), 1500)

                    # mode='same' keeps the length; samples near both ends are
                    # attenuated because the convolution zero-pads there
                    amplitude_envelope_pre_smooth = np.convolve(amplitude_envelope_pre, gauss_window, mode='same')
                    amplitude_envelope_post_smooth = np.convolve(amplitude_envelope_post, gauss_window, mode='same')

                    # created immediately before the call so it is also the
                    # current pyplot axes when plot_one_raster runs
                    plot_hist = stim_fig.add_subplot(grid[i, -1])
                    plot_hist.set_ylabel('Hz', fontsize=7)

                    plot_one_raster(raster_plot, plot_hist, spikes, stim_events, stim_pvals,
                            filtering_type, amplitude_envelope_pre_smooth, amplitude_envelope_post_smooth,
                            std_pre, std_post, i)

                info_plot.text(.5, .5, title, va='center', ha='center', size=12)

                fname = '{:03d}mv1_CSC{:02d}_{}_{}_unit{:03d}_stim{}.jpeg'.format(patient_id, ch, ch_site, unit_type, unit, stim_num)

                stim_fig.subplots_adjust(wspace=0)
                stim_fig.savefig(os.path.join(save_folder, fname), dpi=DPI, transparent=False)
            finally:
                # always release the figure, even when a data file is missing
                mpl.close(stim_fig)

    return

def run_one_channel(fig, save_folder, frame, channel, ch_site, patient_id, session_nr, stim_data, hilbert_data, spgram_data):
    """
    Plot every unit recorded on one channel.

    For each unit id returned by get_unit_ids_in_channel, the spike times, unit
    type and screening statistics are fetched and handed to plot_one_stimulus,
    which writes the figures into `save_folder`.

    fig: figure handle that is cleared before each unit and passed on
    save_folder: output directory for this channel
    frame: stimulus frame shared by all units
    channel, ch_site: channel number and recording site, used in titles and
        file names
    patient_id, session_nr: identify the recording in the database helpers
    stim_data, hilbert_data, spgram_data: forwarded unchanged
    """
    # Event times and brain region are not needed for plotting, so they are
    # not fetched here.
    for unit in get_unit_ids_in_channel(patient_id, session_nr, channel):
        spikes = get_spiking_activity(patient_id, session_nr, unit)
        unit_type = get_unit_type(patient_id, session_nr, unit)
        unit_pvals = get_scr_stats_as_df(patient_id, session_nr, unit)

        title = "{:03d}mv1 Unit #{}, Channel: {} ({})".format(patient_id, unit, channel, unit_type)

        fig.clf()

        plot_one_stimulus(fig, title, frame, spikes, unit_pvals, stim_data, channel, ch_site, unit, unit_type, save_folder, hilbert_data, spgram_data)

            
def run_session(df_patient_info, patient_id, first_channel_index=23):
    """
    Load the stimulus frame and channel list for one session and plot every
    selected channel.

    df_patient_info: frame with a 'recording_site' column
    patient_id: patient whose single session is processed
    first_channel_index: position in the channel list at which processing
        starts; earlier channels are skipped. Defaults to 23, the value that
        used to be hard-coded. Pass 0 to process all channels.

    Raises AssertionError when get_session_info does not return a single int.
    """
    hilbert_data = f'05-Hilbert_transform_{patient_id}'
    spgram_data = f'04-spectrogram_wavelet_{patient_id}'
    session_nr = get_session_info(patient_id)

    assert isinstance(session_nr, int), "More than one session for patient {}. Code currently not set up for automatically running multiple sessions from a single patient.".format(patient_id)

    position, stim_id, filename, stim_name, is_500_days, paradigm, time = get_screening_data(patient_id, session_nr)
    frame = cp.make_dataframe(position, stim_id, filename, stim_name, is_500_days, paradigm, time)

    # filename -> (stim_num, stim_name, paradigm, image array), taken from the
    # first row of each stimulus number
    stim_data = {}
    for stim_num in frame.stim_num.unique():
        meta = frame.loc[frame.stim_num == stim_num, :].iloc[0]
        stim_fname = meta["filename"]
        stim_data[stim_fname] = (stim_num,
                                meta["stim_name"],
                                meta["paradigm"],
                                mpl.imread(os.path.join(PATH_TO_IMAGES, stim_fname))) ## read image file into an array

    all_channels = get_cscs_for_patient(patient_id, session_nr)[first_channel_index:]

    # output root, computed once instead of on every loop iteration
    plots_root = f'/home/anastasia/epiphyte/anastasia/output/10-descriptive_figures_{patient_id}'

    # Figure handle handed down to run_one_channel. Taking pyplot's current
    # figure (created on demand) means this function no longer needs a global
    # named `fig` to exist before it is called.
    scratch_fig = mpl.gcf()

    for channel in all_channels:
        print("Running channel {}...".format(channel))
        # assumes row label (channel - 1) of df_patient_info describes this
        # channel -- TODO confirm against the CSV layout
        ch_site = df_patient_info.loc[channel-1,'recording_site']
        save_folder = os.path.join(plots_root, f'CSC{channel}_{ch_site}')
        os.makedirs(save_folder, exist_ok=True)
        run_one_channel(scratch_fig, save_folder, frame, channel, ch_site, patient_id, session_nr, stim_data, hilbert_data, spgram_data)


In [8]:
# Scratch figure: this is the handle that gets cleared between units; the
# per-stimulus figures themselves are created, saved and closed separately,
# which is why this one is still empty when the cell finishes.
fig = mpl.figure(figsize=FIGSIZE)
# Generate and save the figures for the configured patient (long-running).
run_session(df_patient_info, patient_id=patient_id)

Running channel 62...
Running channel 73...
Running channel 74...
Running channel 78...
Running channel 81...
Running channel 82...
Running channel 83...
Running channel 84...
Running channel 85...
Running channel 86...
Running channel 87...
Running channel 88...


<Figure size 432x288 with 0 Axes>

In [10]:
electrode_unit = pd.DataFrame((ElectrodeUnit & f"patient_id='{patient_id}'" & f"session_nr='{session}'"))

# Restrict once and pull p-values, stimulus ids and unit ids in a single fetch,
# so the three arrays are guaranteed to be aligned row by row (separate
# fetches are only aligned if the database returns rows in the same order
# every time).
scr_stats = ScreeningStats & f"patient_id='{patient_id}'" & f"session_nr='{session}'"
pvals, stim_ids, unit_ids = scr_stats.fetch("pval_scr", "stim_id", "unit_id")

# keep the (unit, stimulus) combinations whose p-value is at most alpha
is_significant = pvals <= alpha
interesting_stim_ids = stim_ids[is_significant]
interesting_unit_ids = unit_ids[is_significant]

unit_stim_pairs = list(zip(interesting_unit_ids, interesting_stim_ids))

#unit_stim_pairs
electrode_unit.to_csv(f'{PATH_TO_PLOTS}/electrode_unit.csv')
# NOTE: despite the .txt extension this file is a binary pickle; read it back
# with pickle.load(open(..., "rb")).
with open(f'{PATH_TO_PLOTS}/unit_stim_pairs.txt', "wb") as fp:   #Pickling
    pickle.dump(unit_stim_pairs, fp)

In [11]:
# Collect the figures of all significant unit/stimulus pairs in one folder.
save_folder = os.path.join(PATH_TO_PLOTS, 'interesting_units')
os.makedirs(save_folder, exist_ok=True)

unit_stims = []
for unit, stim_num in unit_stim_pairs:
    unit_stims.append(stim_num)

    # NOTE(review): .loc selects by row label, so this is only correct while
    # the default 0..n-1 index of electrode_unit coincides with the unit ids
    # -- confirm against the table contents.
    ch = electrode_unit.loc[unit,'csc']
    ch_site = electrode_unit.loc[unit,'brain_region']
    # presumably 'brain_region' matches the 'recording_site' value used when
    # the per-channel folders were created -- verify
    ch_folder = f'CSC{ch}_{ch_site}'

    # Anchor the pattern on the zero-padded "_unitNNN_" part of the file name;
    # a bare "*{unit}_stim" would also pick up e.g. unit 13 when looking for
    # unit 3.
    pattern = f'{PATH_TO_PLOTS}/{ch_folder}/*_unit{int(unit):03d}_stim{stim_num}.jpeg'
    for name in glob.glob(pattern):
        shutil.copy(name, save_folder)