# Ghost Attack

In [7]:
import matplotlib
# matplotlib.use("Qt5Agg")
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy import signal, stats
import mat73
import re
from neurodsp.timefrequency import compute_wavelet_transform
import os
import mne
import IPython
import seaborn as sns
import scipy
import joblib
import pickle

import statsmodels
from statsmodels import stats
from statsmodels.stats import multitest

# Import required code for visualizing example models
from fooof import FOOOF
from fooof.sim.gen import gen_power_spectrum
from fooof.sim.utils import set_random_seed
from fooof.plts.spectra import plot_spectra
from fooof.plts.annotate import plot_annotated_model
from neurodsp.utils import create_times
from neurodsp.plts.time_series import plot_time_series
from neurodsp.spectral import compute_spectrum, rotate_powerlaw
from neurodsp.plts.spectral import plot_power_spectra

## import custom functions
import sys
sys.path.append('/home/brooke/pacman/preprocessing/scripts')
from preproc_functions import *
import average_tfr_functions



In [3]:
# folders
raw_dir = '/home/brooke/pacman/raw_data'
preproc_dir = '/home/brooke/pacman/preprocessing'
tfr_dir = '/home/brooke/knight_server/remote/bstavel/pacman/preprocessing'

# subjects
sub_list = ['BJH021', 'BJH025', 'BJH016', 'SLCH002', 'BJH026', 'BJH027', 'BJH029', 'BJH039', 'BJH041', 'LL10', 'LL12', 'LL13', 'LL14', 'LL17', 'LL19']

# conditions
conditions = ['TrialType <= 16']


## Functions

In [4]:
def calculate_subregion_ghost_attack_average(sub_list, string_filters, roi, sub_roi):
    """
    Calculates the average TFRs for the GHOST ATTACK condition across subjects,
    handling potential differences in sampling rates and saving progress for
    efficiency.

    Args:
        sub_list (list): A list of subject IDs to process.
        string_filters (list): A list of strings to filter TFR cases by.
        roi (str): The name of the region of interest.

    Returns:
        list: A list of lists containing average TFRs for each string filter.

    Steps:
        1. Iterates through subjects:
            - Checks for TFR file existence.
            - Loads and preprocesses TFR data (log and zscore).
            - Filters TFR cases based on string_filters.
            - Calculates mean TFRs for each case and appends to a list.
            - Handles exceptions and reports any errors.
        2. Saves intermediate progress to a pickle file.
        3. Invert list structure for easier processing.
        4. Calculates average TFRs for each string filter:
            - Identifies subjects with high or low sampling rates.
            - Calculates separate means for high and low rate TFRs.
            - Combines and averages TFRs from different sampling rates if applicable.
        5. Returns a list of average TFRs for each string filter.
    """    
    tfrs = []
    used_subs = []
    for subject in sub_list:

        # try:
        if os.path.exists(f"{tfr_dir}/{subject}/ieeg/ghost_attack/{roi}-tfr.h5"):
            # load data
            used_subs.append(subject)
            
            # load data
            tmp_TFR = mne.time_frequency.read_tfrs(f"{tfr_dir}/{subject}/ieeg/ghost_attack/{roi}-tfr.h5")

            # zscore and log
            tmp_TFR = log_and_zscore_TFR(tmp_TFR[0], baseline = (-1,3), logflag=True)

            ## MONICA TODO
            # need to calculate mfg_tfr and replace tmp_TFR with mfg_tfr

            tfr_cases = []
            for case in string_filters:            
                # filter
                tfr_case = mfg_tfr[case]
                # append
                tfr_cases.append(tfr_case.data.mean(axis = 0).mean(axis = 0))

            # get mean and append
            tfrs.append(tfr_cases)

        # except Exception as e:
        #     print(f"Failed to load {subject}")
        #     print(e)
        #     used_subs.remove(subject)
        #     continue

        print(f"currently used subs: {used_subs}")

    # save progress cuz it is so long to load these dang things       
    with open(f'../ieeg/ghost_attack_average_{roi}.pkl', 'wb') as f:
        pickle.dump(tfrs, f)                
        
    # invert list so the outer list is the string filter
    tfrs_cases = [[tfrs[j][i] for j in range(len(tfrs))] for i in range(len(tfrs[0]))]

    all_subs_averages = []
    for tfr_case in tfrs_cases:

        if any("LL" in subject for subject in used_subs):

            # get indicies of high/low samp rate subs
            first_ll_sub = [subject for subject in used_subs if "LL" in subject][0]
            ll_begin = used_subs.index(first_ll_sub)

            # high sampling rate
            washu_tfrs = np.asarray(tfr_case[0:ll_begin])
            washu_tfrs_mean = washu_tfrs.mean(axis = 0)

            # Low sampling rate
            ll_tfrs = np.asarray(tfr_case[ll_begin:])
            ll_tfrs_mean = ll_tfrs.mean(axis = 0)

            # combine
            all_subs_tfrs = np.stack((washu_tfrs_mean[:, ::2], ll_tfrs_mean[:, 0:2001]))
        
            # mean
            all_subs_average = all_subs_tfrs.mean(axis = 0)
            all_subs_averages.append(all_subs_average)
            
        else:
            
            # high sampling rate
            washu_tfrs = np.asarray(tfr_case)
            washu_tfrs_mean = washu_tfrs.mean(axis = 0)

            # mean
            all_subs_average = washu_tfrs_mean
            all_subs_averages.append(all_subs_average)    

    return all_subs_averages


In [5]:
def get_roi_elec_lists(epochs, roi):

    # prep lists
    roi_list = []
    roi_names = []
    roi_indices = []

    # exclude bad ROI from list
    pairs_long_name = [ch.split('-') for ch in epochs.info['ch_names']]
    bidx = len(epochs.info['bads']) +1
    pairs_name = pairs_long_name[bidx:len(pairs_long_name)]

    # sort ROI into lists
    for ix in range(0, len(pairs_name)):
        if pairs_name[ix][0] in ROIs[roi] or pairs_name[ix][1] in ROIs[roi]:
            roi_list.append(epochs.info['ch_names'][ix + bidx])
            roi_names.append(pairs_name[ix])
            roi_indices.append(ix)

    return roi_list, roi_names, roi_indices

# Create Average TFRs

## dlPFC

### Medial Frontal Gyrus

In [12]:
all_subs_average_dlpfcs = calculate_subregion_ghost_attack_average(sub_list, conditions, 'dlpfc', 'sfg', 'mfg')
all_subs_average_dlpfc_conflict = all_subs_average_dlpfcs[0]


TypeError: calculate_subregion_ghost_attack_average() takes 4 positional arguments but 5 were given

In [None]:
plot_allsub_averages(all_subs_average_dlpfc_conflict, "Average dlPFC TFR During Ghost Attack", 'average_mfg_ghost_attack_all_subs.png', -1, 3)

### Superior Frontal Gyrus

In [None]:
all_subs_average_dlpfcs = calculate_subregion_ghost_attack_average(sub_list, conditions, 'dlpfc')
all_subs_average_dlpfc_conflict = all_subs_average_dlpfcs[0]


In [None]:
plot_allsub_averages(all_subs_average_dlpfc_conflict, "Average dlPFC TFR During Ghost Attack", 'average_sfg_ghost_attack_all_subs.png', -1, 3)

### Examples

In [9]:
subject = 'LL12'
roi = 'dlpfc'
tmp_TFR = mne.time_frequency.read_tfrs(f"{tfr_dir}/{subject}/ieeg/ghost_attack/{roi}-tfr.h5")

Reading /home/brooke/knight_server/remote/bstavel/pacman/preprocessing/LL12/ieeg/ghost_attack/dlpfc-tfr.h5 ...
Adding metadata with 5 columns


In [10]:
# you always have to run this step after loading a tfr from a file
tmp_TFR = tmp_TFR[0]

In [35]:
# print the tfr object so you know what you're working with
tmp_TFR

<EpochsTFR | time : [-1.000000, 3.000000], freq : [1.000000, 150.000000], epochs : 27, channels : 12, ~405.2 MB>

In [36]:
# print the channels that are included in the tfr
tmp_TFR.info['ch_names']

['LOF7-LOF8',
 'LOF8-LOF9',
 'LOF9-LOF10',
 'LOF10-LOF11',
 'LOF11-LOF12',
 'LAC5-LAC6',
 'LAC6-LAC7',
 'LAC7-LAC8',
 'LMC5-LMC6',
 'LMC6-LMC7',
 'LMC7-LMC8',
 'ROF9-ROF10']

In [44]:
# get the channel names of the subregion you are interested in, in this case 'mfg

%run ../../preprocessing/scripts/roi.py
ROIs = ROIs[subject]

mfg_list, mfg_names, mfg_indices = get_roi_elec_lists(tmp_TFR, 'mfg')
sfg_list, sfg_names, sfg_indices = get_roi_elec_lists(tmp_TFR, 'sfg')

In [42]:
# print the channels in the mfg
mfg_list

[]

In [45]:
# create a new tfr with only the mfg channels
mfg_tfr = tmp_TFR.copy().pick_channels(mfg_list)

In [46]:
mfg_tfr

<EpochsTFR | time : [-1.000000, 3.000000], freq : [1.000000, 150.000000], epochs : 27, channels : 6, ~202.6 MB>

## TODO

## Hippocampus

In [None]:
all_subs_average_hcs = calculate_ghost_attack_average(sub_list, conditions, 'hc')
all_subs_average_hc_conflict = all_subs_average_hcs[0]


In [None]:
plot_allsub_averages(all_subs_average_hc_conflict, "Average Hippocampal TFR During Ghost Attack", 'average_hc_ghost_attack_all_subs.png', -1, 3)

## OFC

In [None]:
all_subs_average_ofcs = calculate_ghost_attack_average(sub_list, conditions, 'ofc')
all_subs_average_ofc_conflict = all_subs_average_ofcs[0]


In [None]:
plot_allsub_averages(all_subs_average_ofc_conflict, "Average OFC TFR During Ghost Attack", 'average_ofc_ghost_attack_all_subs.png', -1, 3)

## Anterior Cingulate

In [None]:
all_subs_average_cings = calculate_ghost_attack_average(sub_list, conditions, 'cing')
all_subs_average_cing_conflict = all_subs_average_cings[0]


In [None]:
plot_allsub_averages(all_subs_average_cing_conflict, "Average Ant. Cingulate TFR During Ghost Attack", 'average_cing_ghost_attack_all_subs.png', -1, 3)

## Amygdala

In [None]:
all_subs_average_amygs = calculate_ghost_attack_average(sub_list, conditions, 'amyg')
all_subs_average_amyg_conflict = all_subs_average_amygs[0]


In [None]:
plot_allsub_averages(all_subs_average_amyg_conflict, "Average Amygdala TFR During Ghost Attack", 'average_amyg_ghost_attack_all_subs.png', -1, 3)

## Insula

In [None]:
all_subs_average_insulas = calculate_ghost_attack_average(sub_list, conditions, 'insula')
all_subs_average_insula_conflict = all_subs_average_insulas[0]


In [None]:
plot_allsub_averages(all_subs_average_insula_conflict, "Average Insula TFR During Ghost Attack", 'average_insula_ghost_attack_all_subs.png', -1, 3)

## EC

In [None]:
all_subs_average_ecs = calculate_ghost_attack_average(sub_list, conditions, 'ec')
all_subs_average_ec_conflict = all_subs_average_ecs[0]


In [None]:
plot_allsub_averages(all_subs_average_ec_conflict, "Average Entorhinal Cortex TFR During Ghost Attack", 'average_ec_ghost_attack_all_subs.png', -1, 3)