## Introduction

This notebook is used to generate brain coordinate files for brain maps.

Module 1 generates brain coordinate files covering all electrodes for each individual patient.

Module 2 aggregates the Module 1 files across all patients, optionally keeping only significant electrodes.

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Change to the parent directory. Every data/ and results/ path used below is
# relative, presumably to the repository root one level up — confirm.
%cd ../

## Module 1

Module 1 generates brain coordinate files for all electrodes of each patient.

Input files
- brain coordinate files __[txt]__
- electrode name conversion files __[csv]__
- encoding result __[folder path]__

Output files
- brain coordinate + encoding results for all elecs __[txt]__

#### Functions

In [None]:
def get_base_df(sid, cor, emb_key):
    """Build the per-subject electrode table with placeholder encoding columns.

    Reads ``data/plotting/brainplot/<sid>/<sid>_<cor>.txt`` (space separated,
    no header; column 0 is the electrode label, columns 1-4 are kept), joins it
    on the electrode label with ``<sid>_elecs.csv`` (column ``elec2`` holds the
    label; rows with any missing value are discarded), and adds one column per
    entry of ``emb_key`` filled with -1.

    Subject 7170 reads the files stored under 717.
    """
    file_sid = 717 if sid == 7170 else sid
    subject_dir = f"data/plotting/brainplot/{file_sid}"

    coords = (
        pd.read_csv(f"{subject_dir}/{file_sid}_{cor}.txt", sep=" ", header=None)
        .set_index(0)
        .loc[:, 1:4]
    )
    print(f"\nFor subject {file_sid}:\ntxt has {len(coords.index)} electrodes")

    # Map electrode labels (elec2 -> index) to electrode names (elec column).
    names = (
        pd.read_csv(f"{subject_dir}/{file_sid}_elecs.csv")
        .dropna()
        .rename(columns={"elec2": 0})
        .set_index(0)
    )

    merged = coords.merge(names, left_index=True, right_index=True)
    print(f"Now subject has {len(merged)} electrodes")

    # -1 marks "no encoding result filled in yet".
    for column in emb_key:
        merged[column] = -1
    return merged

In [None]:
def read_file(filename, path):
    """Read one electrode's encoding-correlation CSV.

    The file is located with the glob pattern
    ``data/encoding/<path>/<filename>``; ``path`` may contain wildcards
    (e.g. ``.../*/``).

    Returns:
        The CSV as a DataFrame (read with ``header=None``), or -1 when no file
        matches. Callers test for the sentinel with ``isinstance(..., int)``.

    Raises:
        ValueError: if the pattern matches more than one file. The ambiguity
            is reported rather than picking one run's results arbitrarily.
    """
    pattern = os.path.join("data/encoding/", path, filename)
    matches = glob.glob(pattern)  # evaluate the glob once
    if not matches:
        return -1
    if len(matches) > 1:
        raise ValueError(
            f"Expected at most one file matching {pattern!r}, "
            f"found {len(matches)}: {sorted(matches)}"
        )
    return pd.read_csv(matches[0], header=None)

In [None]:
def get_max(filename, path):
    """Return the largest value in the first row of one electrode's results.

    The file is located via ``read_file(filename, path)``. Returns -1 when
    ``read_file`` finds no matching file.
    """
    results = read_file(filename, path)
    found = not isinstance(results, int)
    return max(results.loc[0]) if found else -1

In [None]:
def get_area(filename, path, lags, chosen_lags):
    """Return the area under one electrode's encoding curve over chosen lags.

    Args:
        filename, path: forwarded to ``read_file`` to locate the CSV.
        lags: sequence of lag values; ``lags[i]`` is the lag of result column i.
        chosen_lags: column positions (indices into ``lags``) to integrate over.

    Returns:
        The trapezoidal integral of the first result row over
        ``x = lags[chosen_lags] / 1000`` as a float, matching ``get_max``, which
        also summarises row 0 as a scalar. Returns -1 when no result file
        exists.
    """
    elec_data = read_file(filename, path)
    if isinstance(elec_data, int):
        return -1
    window = elec_data.loc[:, chosen_lags]
    x_vals = [lags[lag] / 1000 for lag in chosen_lags]  # presumably ms -> s; confirm

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    areas = trapezoid(window, x=x_vals, axis=1)  # one area per result row
    # Return a scalar so the caller's single-cell df.loc assignment is
    # unambiguous, rather than a 1-element array.
    return float(areas[0])

In [None]:
def add_encoding(df, sid, formats, type="max", lags=None, chosen_lags=None):
    """Fill per-electrode encoding summaries into ``df``.

    For every embedding ``name -> results folder`` in ``formats`` and every row
    of ``df``, reads ``<sid>_<elec>_prod.csv`` and ``<sid>_<elec>_comp.csv``
    from that folder (via ``get_max`` / ``get_area``). It writes the summaries
    into columns ``<name>_prod`` and ``<name>_comp``.

    Args:
        df: electrode table with an ``elec`` column; modified in place.
        sid: subject id used in the result file names.
        formats: dict of embedding name -> encoding results folder (relative to
            data/encoding/, may contain glob wildcards).
        type: "max" for peak correlation, "area" for area under the curve.
        lags, chosen_lags: forwarded to ``get_area`` when ``type == "area"``.

    Returns:
        ``df`` (the same object, updated in place).

    Raises:
        ValueError: if ``type`` is not "max" or "area". Previously an unknown
            value silently left the -1 placeholders in place.
    """
    if type not in ("max", "area"):
        raise ValueError(f"Unknown type {type!r}; expected 'max' or 'area'")
    lags = [] if lags is None else lags
    chosen_lags = [] if chosen_lags is None else chosen_lags

    for emb_name, emb_path in formats.items():
        # Output column names depend only on the embedding, not on the row.
        prod_col = emb_name + "_prod"
        comp_col = emb_name + "_comp"
        for row, values in df.iterrows():
            prod_name = f"{sid}_{values['elec']}_prod.csv"
            comp_name = f"{sid}_{values['elec']}_comp.csv"
            if type == "max":
                df.loc[row, prod_col] = get_max(prod_name, emb_path)
                df.loc[row, comp_col] = get_max(comp_name, emb_path)
            else:
                df.loc[row, prod_col] = get_area(prod_name, emb_path, lags, chosen_lags)
                df.loc[row, comp_col] = get_area(comp_name, emb_path, lags, chosen_lags)

    return df

In [None]:
def get_area_diff(df, emb_key, mode="normalized"):
    """Replace paired condition columns with their (normalised) difference.

    For each column in ``emb_key`` that names a "correct" or "top" condition,
    the partner column is found by substituting "incorrect" / "bot". For
    example, "glove-correct5_prod" pairs with "glove-incorrect5_prod".

    Processing steps for each pair:
      1. Negative values in both columns are clipped to 0 in place. This
         includes the -1 "missing result" placeholder.
      2. The difference ``a - b`` is written to a new column named with
         "correct"/"top" removed (e.g. "glove-5_prod").
      3. Both source columns are dropped.

    Args:
        df: table containing the columns; modified in place.
        emb_key: column names to consider. "incorrect"/"bot" columns are
            skipped because they are handled together with their partner.
            Columns without a correct/top marker are left untouched.
        mode:
            "normalized"  -> per row ``(a - b) / max(a, b)``; NaN when both are 0.
            "normalized2" -> ``(a - b)`` divided by the largest absolute
                             difference over all rows; an all-zero difference
                             stays 0 instead of becoming NaN.
            "none"        -> raw ``a - b``.

    Returns:
        ``df`` (the same object, updated in place).

    Raises:
        ValueError: for an unknown ``mode``. The check happens before any
            column is modified.
    """
    if mode not in ("normalized", "normalized2", "none"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'normalized', 'normalized2' or 'none'"
        )

    for col in emb_key:
        if "incorrect" in col or "bot" in col:
            continue  # partner column; processed with its correct/top column
        col2 = col.replace("correct", "incorrect").replace("top", "bot")
        if col2 == col:
            # No correct/top marker, so there is nothing to pair with.
            # Differencing the column with itself and then dropping it would
            # silently delete the data.
            continue
        diff_col = col.replace("correct", "").replace("top", "")

        df.loc[df[col] < 0, col] = 0  # turn negative area to 0
        df.loc[df[col2] < 0, col2] = 0  # turn negative area to 0

        diff = df[col] - df[col2]
        if mode == "normalized":
            diff = diff / df[[col, col2]].max(axis=1)
        elif mode == "normalized2":
            abs_max = diff.abs().max()
            if abs_max > 0:  # avoid 0/0 turning every row into NaN
                diff = diff / abs_max
        df[diff_col] = diff
        df.drop([col, col2], axis=1, inplace=True)  # drop original columns

    return df

In [None]:
def save_file(df, sid, emb_keys, dir, cor, project):
    """Write one whitespace-separated text file per column in ``emb_keys``.

    Side effect: stores ``df.index`` in a new column ``0`` of ``df``.

    For each column ``col``, writes ``<dir>/<sid>_<cor>_<col>.txt``. The file
    holds columns 0, 1, 2, 3, 4, "elec" and ``col``, with rows containing any
    NaN removed, and no index or header. ``project`` is currently unused; it
    is kept for call compatibility.
    """
    df.loc[:, 0] = df.index
    shared_columns = [0, 1, 2, 3, 4, "elec"]

    for col in emb_keys:
        table = df.loc[:, shared_columns + [col]].dropna()
        out_path = os.path.join(dir, f"{sid}_{cor}_{col}.txt")
        with open(out_path, "w") as handle:
            table.to_string(handle, index=False, header=False)

#### Arguments

In [None]:
###### Core Arguments ######
# Project tag, passed through to save_file (which does not currently use it).
PRJ_ID = "tfs"

# Subject IDs to process. Only the last assignment takes effect; reorder or
# comment lines to switch configurations.
SIDS = [625] # for testing / 1 patient
SIDS = [625, 676, 7170, 798]

# Keys combined with each embedding name into column / file names such as
# "<emb>_prod" and "<emb>_comp".
KEYS = ["prod", "comp"]

# Coordinate file variant: data/plotting/brainplot/{sid}/{sid}_{COR_TYPE}.txt.
# As above, the last assignment wins.
COR_TYPE = "ind" # unique brain coordinate + brain map per patient
COR_TYPE = "ave" # average brain coordinates (for several patients)

##### Encoding Results Folder #####
# One dict per subject, aligned with SIDS by position (the Runs cell zips them).
# Each dict maps embedding name -> encoding results folder relative to
# data/encoding/. Folders may contain glob wildcards; read_file expects each
# electrode's file to resolve to a single match. Uncomment entries to
# include more embeddings.
FORMATS = []
for sid in SIDS:
    FORMATS.append(
        {
    # "glove-all" : f"tfs/20230228-all-embs/kw-tfs-full-{sid}-glove50-lag10k-25-all-aligned/*/",
    # "rand-all" : f"tfs/20230228-all-embs/kw-tfs-full-{sid}-glove50-lag10k-25-all-aligned-rand/*/",
    # "arb-all" : f"tfs/20230228-all-embs/kw-tfs-full-{sid}-glove50-lag10k-25-all-aligned-arb/*/",
    # "gptn-1-all" : f"tfs/20230228-all-embs/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-all-aligned/*/",
    # "gptn-all" : f"tfs/20230228-all-embs/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-all-shift-emb-aligned/*/",
    # "gptn-1-all-l30" : f"tfs/20230228-all-embs/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-all-l30-aligned/*/",
    # "gptn-all-l30" : f"tfs/20230228-all-embs/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-all-l30-shift-emb-aligned/*/",
    # "glove-correct5" : f"tfs/20230227-gpt2-preds/kw-tfs-full-{sid}-glove50-lag10k-25-gpt2-xl-correct5/*/",
    # "glove-incorrect5" : f"tfs/20230227-gpt2-preds/kw-tfs-full-{sid}-glove50-lag10k-25-gpt2-xl-incorrect5/*/",
    # "glove-correct1" : f"tfs/20230227-gpt2-preds/kw-tfs-full-{sid}-glove50-lag10k-25-gpt2-xl-correct1/*/",
    # "glove-incorrect1" : f"tfs/20230227-gpt2-preds/kw-tfs-full-{sid}-glove50-lag10k-25-gpt2-xl-incorrect1/*/",
    # "glove-top0.3" : f"tfs/20230227-gpt2-preds/kw-tfs-full-{sid}-glove50-lag10k-25-gpt2-xl-prob/*/",
    # "glove-bot0.3" : f"tfs/20230227-gpt2-preds/kw-tfs-full-{sid}-glove50-lag10k-25-gpt2-xl-improb/*/",
    # "gptn-1-correct5" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-correct5/*/",
    # "gptn-1-incorrect5" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-incorrect5/*/",
    # "gptn-1-correct1" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-correct1/*/",
    # "gptn-1-incorrect1" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-incorrect1/*/",
    # "gptn-1-top0.3" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-prob/*/",
    # "gptn-1-bot0.3" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-improb/*/",
    # "gptn-1-correct5-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-l30-correct5/*/",
    # "gptn-1-incorrect5-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-l30-incorrect5/*/",
    # "gptn-1-correct1-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-l30-correct1/*/",
    # "gptn-1-incorrect1-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-l30-incorrect1/*/",
    # "gptn-1-top0.3-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-l30-prob/*/",
    # "gptn-1-bot0.3-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-l30-improb/*/",
    # "gptn-correct5" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-correct5/*/",
    # "gptn-incorrect5" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-incorrect5/*/",
    # "gptn-correct1" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-correct1/*/",
    # "gptn-incorrect1" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-incorrect1/*/",
    # "gptn-top0.3" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-prob/*/",
    # "gptn-bot0.3" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-improb/*/",
    # "gptn-correct5-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-l30-correct5/*/",
    # "gptn-incorrect5-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-l30-incorrect5/*/",
    # "gptn-correct1-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-l30-correct1/*/",
    # "gptn-incorrect1-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-l30-incorrect1/*/",
    # "gptn-top0.3-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-l30-prob/*/",
    # "gptn-bot0.3-l30" : f"tfs/20230226-gpt2-preds/kw-tfs-full-{sid}-gpt2-xl-lag10k-25-shift-emb-l30-improb/*/",
    # "whisper-en-last" : f"tfs/20230210-whisper-encoder-onset/kw-tfs-full-en-onset-{sid}-whisper-tiny.en-l4-wn1-5/*/",
    # "whisper-de-last" : f"tfs/20230212-whisper-decoder/kw-tfs-full-de-{sid}-whisper-tiny.en-l4/*/",
    # "whisper-de-best" : f"tfs/20230212-whisper-decoder/kw-tfs-full-de-{sid}-whisper-tiny.en-l3/*/",
    # "whisper-last" : f"tfs/20230216-whisper-full/kw-tfs-full-{sid}-whisper-tiny.en-l4/*/",
    # "whisper-best" : f"tfs/20230216-whisper-full/kw-tfs-full-{sid}-whisper-tiny.en-l3/*/",
    # "whisper-en-4-correct5" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-encoder-lag10k-25-correct5/*/",
    # "whisper-en-4-incorrect5" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-encoder-lag10k-25-incorrect5/*/",
    # "whisper-en-4-top0.3" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-encoder-lag10k-25-prob/*/",
    # "whisper-en-4-bot0.3" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-encoder-lag10k-25-improb/*/",
    # "whisper-de-3-correct5" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-decoder-lag10k-25-correct5/*/",
    # "whisper-de-3-incorrect5" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-decoder-lag10k-25-incorrect5/*/",
    # "whisper-de-3-top0.3" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-decoder-lag10k-25-prob/*/",
    # "whisper-de-3-bot0.3" : f"tfs/20230402-whisper-preds/kw-tfs-full-{sid}-whisper-tiny.en-decoder-lag10k-25-improb/*/",
    "prod-comp-flip" : f"tfs/20231016-whisper-pc-flip-best-lag/kw-tfs-full-{sid}-whisper-tiny.en-encoder-var-win-lag10k-25-all-pc-flip-best-lag-4/*/"
        }
    )

# Output directory name. Only the last assignment takes effect. The name also
# selects the Runs mode: containing "cor-tfs-max" -> max correlation,
# containing "cor-tfs-area" -> area-under-the-curve difference.
OUTPUT_DIR = "results/cor-tfs-area-diff-after"
OUTPUT_DIR = "results/cor-tfs-area-before-diff"
OUTPUT_DIR = "results/cor-tfs-area-before-norm-diff"
OUTPUT_DIR = "results/cor-tfs-area-before-norm2-diff"
OUTPUT_DIR = "results/cor-tfs-max"

# Lag grid for the area mode (used by add_encoding(type="area") / get_area):
# -10000 to 10000 in steps of 25. get_area treats position i of this array as
# the lag of column i in each encoding result CSV — presumably milliseconds,
# confirm against the encoding output.
LAGS = np.arange(-10000,10025,25)
# Integration window, inclusive on both ends, in the same units as LAGS.
AREA_START = -500
AREA_END = -100

# AREA_START = 100
# AREA_END = 500

#### Runs

In [None]:
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)


# The loop variable is `sid_formats` rather than `format`, so the builtin
# format() is not shadowed in the notebook namespace for later cells.

##### Max correlation #####
# Per subject: build the electrode table, fill in peak encoding correlations,
# and write one txt file per "<emb>_<key>" column.
if "cor-tfs-max" in OUTPUT_DIR:
    for sid, sid_formats in zip(SIDS, FORMATS):
        emb_key = [f"{emb}_{key}" for emb in sid_formats for key in KEYS]
        df = get_base_df(sid, COR_TYPE, emb_key)  # get all electrodes
        df = add_encoding(df, sid, sid_formats, "max")  # add on the columns from encoding results
        save_file(df, sid, emb_key, OUTPUT_DIR, COR_TYPE, PRJ_ID)  # save txt files


##### Difference in area under the curve #####
elif "cor-tfs-area" in OUTPUT_DIR:
    # Positions in LAGS whose lag lies inside [AREA_START, AREA_END].
    chosen_lag_idx = [
        idx for idx, element in enumerate(LAGS) if (element >= AREA_START) & (element <= AREA_END)
    ]

    for sid, sid_formats in zip(SIDS, FORMATS):
        emb_key = [f"{emb}_{key}" for emb in sid_formats for key in KEYS]
        df = get_base_df(sid, COR_TYPE, emb_key)  # get all electrodes
        df = add_encoding(df, sid, sid_formats, "area", LAGS, chosen_lag_idx)  # add on columns from encoding results
        # NOTE(review): every area run uses "normalized2", although some
        # OUTPUT_DIR names ("...-before-diff", "...-before-norm-diff") suggest
        # other modes — confirm this is intended.
        df = get_area_diff(df, emb_key, "normalized2")  # get area difference

        # Difference columns are named like the incorrect/bot columns with
        # that marker removed; save only those.
        new_emb_key = [col.replace("incorrect", "").replace("bot", "") for col in emb_key if "incorrect" in col or "bot" in col]
        save_file(df, sid, new_emb_key, OUTPUT_DIR, COR_TYPE, PRJ_ID)  # save txt files

else:
    # Previously an unrecognised OUTPUT_DIR silently produced nothing.
    raise ValueError(
        f"OUTPUT_DIR {OUTPUT_DIR!r} contains neither 'cor-tfs-max' nor 'cor-tfs-area'"
    )

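## Module 2

Module 2 combines the per-patient files produced by Module 1 into one file per embedding and key. It keeps only entries whose value is at least 0.08 and, optionally, only significant electrodes.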

Input files
- brain coordinate + encoding results (output of Module 1) __[txt]__
- electrode name conversion files __[csv]__
- significant electrode list __[csv]__

Output files
- brain coordinate + encoding results aggregated across patients (significant elecs only when `SIG_ELECS` is True) __[txt]__

#### Functions

In [63]:
def aggregate_results(input_dir, sids, cor_type, emb_name, key, sig_name, min_value=0.08):
    """Combine per-subject Module 1 outputs into one file across subjects.

    For each subject, reads ``<input_dir>/<sid>_<cor_type>_<emb_name>_<key>.txt``
    (fixed-width, no header; column 5 is the electrode name, column 6 the
    value).

    Args:
        input_dir: folder holding the Module 1 outputs; the result is written
            here too.
        sids: subject ids to combine.
        cor_type, emb_name, key: parts of the input/output file names.
        sig_name: if truthy, keeps only electrodes listed in the ``electrode``
            column of ``data/plotting/tfs-sig-file-<sid>-<sig_name>-<key>.csv``.
            If not None, ``_sig`` is appended to the output name.
        min_value: rows whose column 6 is below this are discarded
            (default 0.08, the previously hard-coded threshold).

    Writes columns 0-4 and 6 (no index, no header) to
    ``<input_dir>/tfs_<cor_type>_<emb_name>_<key>[_sig].txt`` and prints the
    file name.

    Returns:
        The DataFrame that was written.

    Raises:
        ValueError: if ``sids`` is empty.
    """
    frames = []
    for sid in sids:
        # load coordinate file
        cor_filename = os.path.join(input_dir, f"{sid}_{cor_type}_{emb_name}_{key}.txt")
        df = pd.read_fwf(cor_filename, header=None)

        # load significance file
        if sig_name:
            sig_filename = os.path.join("data/plotting/", f"tfs-sig-file-{sid}-{sig_name}-{key}.csv")
            sig_df = pd.read_csv(sig_filename)
            df = pd.merge(df, sig_df, how="inner", left_on=5, right_on="electrode")

        frames.append(df)

    if not frames:
        raise ValueError("aggregate_results needs at least one subject id")

    # Concatenate once instead of growing a DataFrame inside the loop.
    df_all = pd.concat(frames)
    df_all = df_all[df_all[6] >= min_value]
    df_output = df_all.loc[:, [0, 1, 2, 3, 4, 6]]

    sig_str = "" if sig_name is None else "_sig"
    aggre_filename = os.path.join(input_dir, f"tfs_{cor_type}_{emb_name}_{key}{sig_str}.txt")
    print(aggre_filename)
    with open(aggre_filename, "w") as outfile:
        df_output.to_string(outfile, index=False, header=False)

    return df_output

#### Arguments

In [61]:
# Module 2 reads the files Module 1 wrote.
INPUT_DIR = OUTPUT_DIR

# whether to use significance list (only the last assignment takes effect)
SIG_ELECS = True # only sig elecs
SIG_ELECS = False


# SIG_DICT: embedding name (as used in Module 1 output file names) -> tag used
# to locate data/plotting/tfs-sig-file-{sid}-{tag}-{key}.csv when SIG_ELECS is
# True. The keys also decide which embeddings are aggregated.
# NOTE(review): if INPUT_DIR matches neither branch, SIG_DICT is never defined
# and the Runs cell fails with a NameError.
if "cor-tfs-max" in INPUT_DIR: # significance dict for max cor
    SIG_DICT = {
        # "glove-all" : "glove",
        # "rand-all" : "glove",
        # "arb-all" : "glove",
        # "gptn-1-all" : "gpt",
        # "gptn-all" : "gpt",
        # "gptn-1-all-l30" : "gpt",
        # "gptn-all-l30" : "gpt",
        # "glove-correct1": "glove",
        # "glove-incorrect1": "glove",
        # "glove-pred1": "glove",
        # "glove-correct5": "glove",
        # "glove-incorrect5": "glove",
        # "glove-pred5": "glove",
        # "glove-top0.3" : "glove",
        # "glove-bot0.3" : "glove",
        # "glove-pred0.3" : "glove",
        # "gptn-1-correct5": "gpt",
        # "gptn-1-incorrect5": "gpt",
        # "gptn-1-correct1": "gpt",
        # "gptn-1-incorrect1": "gpt",
        # "gptn-1-top0.3": "gpt",
        # "gptn-1-bot0.3": "gpt",
        # "gptn-correct5": "gpt",
        # "gptn-incorrect5": "gpt",
        # "gptn-correct1": "gpt",
        # "gptn-incorrect1": "gpt",
        # "gptn-top0.3": "gpt",
        # "gptn-bot0.3": "gpt",
        "prod-comp-flip": "gpt",
    }
elif "cor-tfs-area" in INPUT_DIR: # significance dict for area
    # Keys here are the difference-column names produced by get_area_diff
    # (correct/top/incorrect/bot markers removed).
    SIG_DICT = {
        # "glove-5": "glove",
        # "glove-1": "glove",
        # "glove-0.3": "glove",
        # "gptn-1-5": "gpt",
        # "gptn-1-1": "gpt",
        # "gptn-1-0.3": "gpt",
        # "gptn-5": "gpt",
        # "gptn-1": "gpt",
        # "gptn-0.3": "gpt",
        # "gptn-1-5-l30": "gpt",
        # "gptn-1-1-l30": "gpt",
        # "gptn-1-0.3-l30": "gpt",
        # "gptn-5-l30": "gpt",
        # "gptn-1-l30": "gpt",
        # "gptn-0.3-l30": "gpt",
        # "whisper-de-3-0.3": "whisper-de-best-0.01",
        # "whisper-de-3-5": "whisper-de-best-0.01",
        # "whisper-en-4-0.3": "whisper-en-last-0.01",
        # "whisper-en-4-5": "whisper-en-last-0.01",
    }

#### Runs

In [64]:
# Aggregate every selected embedding for both keys, restricting to
# significant electrodes only when SIG_ELECS is on.
for emb, sig_label in SIG_DICT.items():
    sig_name = sig_label if SIG_ELECS else None
    for key in KEYS:
        df = aggregate_results(INPUT_DIR, SIDS, COR_TYPE, emb, key, sig_name)


results/cor-tfs-max/tfs_ave_prod-comp-flip_prod.txt
results/cor-tfs-max/tfs_ave_prod-comp-flip_comp.txt
