In [None]:
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.integrate import trapezoid
from scipy.stats import permutation_test
from scipy.stats import ttest_1samp
from scipy.stats import ttest_ind
from statsmodels.stats import multitest

from tfsplt_encoding import get_cmap_smap, aggregate_data, organize_data
from tfsplt_utils import read_folder
from tfsplt_brainmap import get_sigelecs, Colorbar, make_brainmap
from tfsplt_brainmap_cat import make_brainmap_cat

## Data Loading

In [None]:
# Load Sig Elec Files
# `sig_df`: rows of the Whisper base table, plus boolean "gpt2-prod"/"gpt2-comp"
# flags from the tfs sig file and an "electrode" id of the form "<sid>_<elec_1>".
sig_df_whisper = pd.read_csv("../data/plotting/paper-whisper/data/base_df.csv")

# NOTE(review): the "gpt2-*" flags are taken from the rows with model == "glove";
# confirm this naming/selection is intended.
sig_df_gpt2 = (
    pd.read_csv("../data/plotting/sig-elecs/20230723-tfs-sig-file/tfs-sig-file.csv")
    .iloc[:, 1:]  # drop the first column (presumably a saved index)
    .rename(
        columns={
            "patient": "sid",
            "electrode": "elec_1",
            "prod_significant": "gpt2-prod",
            "comp_significant": "gpt2-comp",
        }
    )
)
sig_df_gpt2 = sig_df_gpt2.loc[
    sig_df_gpt2.model == "glove", ["sid", "elec_1", "gpt2-prod", "gpt2-comp"]
]

sig_df = sig_df_whisper.merge(sig_df_gpt2, how="left", on=["sid", "elec_1"])
# Electrodes without a GPT-2 row count as "not significant". After the left
# merge the flags can hold NaN (object dtype); `.eq(True)` maps NaN/False to
# False and always yields a bool column, without relying on fillna's implicit
# object->bool downcast (deprecated in pandas 2.2).
gpt2_flags = ["gpt2-comp", "gpt2-prod"]
sig_df[gpt2_flags] = sig_df[gpt2_flags].eq(True)
sig_df["electrode"] = sig_df.sid.astype(str) + "_" + sig_df.elec_1.astype(str)

In [None]:
class Args(argparse.Namespace):
    """Settings for loading and plotting the prob vs. improb encodings.

    The attributes are read by the tfsplt_* helpers called below.
    """

    sid = [625, 676, 7170, 798]  # subjects
    project = "tfs"
    formats = [  # encoding folder (paths pair positionally with `labels`)
        "../data/encoding/tfs/20230227-gpt2-preds/kw-tfs-full-%s-glove50-lag10k-25-gpt2-xl-prob/*/*%s.csv",
        "../data/encoding/tfs/20230227-gpt2-preds/kw-tfs-full-%s-glove50-lag10k-25-gpt2-xl-improb/*/*%s.csv",
    ]
    labels = ["prob", "improb"]
    sig_elec_file = [
        "../data/plotting/sig-elecs/20230723-tfs-sig-file/tfs-sig-file-glove-%(sid)s-%(key)s.csv"
    ]
    keys = ["comp", "prod"]  # comprehension and/or production
    layers = np.arange(0, 25)
    lags_plot = np.arange(-10000, 10001, 25)  # encoding lags (ms)
    lags_show = np.arange(-2000, 2001, 25)  # lags for the effect (ms)
    lc_by = "labels"
    ls_by = "keys"


# Aggregate Data
args = Args()
# order-preserving de-duplication of labels and keys
args.unique_labels, args.unique_keys = (
    list(dict.fromkeys(values)) for values in (args.labels, args.keys)
)
# ms -> s, matching the "Lag (s)" axis used by the plots
args.lags_show, args.lags_plot = (lags / 1000 for lags in (args.lags_show, args.lags_plot))
args = get_sigelecs(get_cmap_smap(args))  # color/style maps, then significant electrodes
df = organize_data(args, aggregate_data(args))  # aggregate, then trim if necessary


## Encoding Plots

Significance Tests

In [None]:
# Sig tests
def fdr(pvals):
    """Benjamini-Hochberg FDR correction of a sequence of p-values.

    Returns the corrected p-values, i.e. element 1 of the result of
    statsmodels' ``multipletests`` with ``method="fdr_bh"``.
    """
    correction = multitest.multipletests(pvals, method="fdr_bh", is_sorted=False)
    return correction[1]

def get_sig_lags(args, df, threshold=0.01):
    """Find lags where the "prob" and "improb" encodings differ significantly.

    For each key in ``args.keys``, runs a two-sided independent-samples t-test
    per lag column between the "prob" rows and the "improb" rows. It then
    FDR-corrects the p-values with ``fdr`` and keeps the lags whose corrected
    p-value is below ``threshold``.

    Args:
        args: namespace with ``keys`` and ``lags_show`` (the lag of each df column).
        df: encoding frame with "key" and "label" index levels, one column per lag.
        threshold: significance level for the FDR-corrected p-values.

    Returns:
        dict mapping "{key}_prob" / "{key}_improb" to the list of lags where
        that label's values are significantly higher (sign of the t statistic).

    Raises:
        ValueError: if the number of df columns differs from ``len(args.lags_show)``,
            which would otherwise mislabel the lags.
    """
    lags = np.asarray(args.lags_show)
    sig_lags = {}
    for key in args.keys:
        df_key = df[df.index.get_level_values("key") == key]
        labels = df_key.index.get_level_values("label")
        prob = df_key[labels == "prob"].to_numpy()
        improb = df_key[labels == "improb"].to_numpy()
        if prob.shape[1] != len(lags):
            raise ValueError(
                f"df has {prob.shape[1]} lag columns but args.lags_show has {len(lags)} lags"
            )

        # NOTE(review): prob and improb rows may come from the same electrodes;
        # if so, a paired test (ttest_rel on electrode-aligned rows) may be more
        # appropriate than an independent-samples test — confirm.
        # One test per lag column (axis=0).
        t_stats, p_vals = ttest_ind(prob, improb, axis=0, alternative="two-sided")
        significant = np.asarray(fdr(p_vals)) < threshold

        sig_lags[f"{key}_prob"] = list(lags[significant & (t_stats > 0)])
        sig_lags[f"{key}_improb"] = list(lags[significant & (t_stats < 0)])
    return sig_lags

Encoding Plot (Whole Brain and ROIs)

In [None]:
# Plot average encoding plots
def average_encoding(args, df, tag):
    """Plot mean ± SEM encoding curves per label and save them as a PNG.

    Draws one panel per key and one line per label inside each panel. Lags
    where ``get_sig_lags`` finds a significant label difference are marked with
    dots (at y = 0.001) in that label's colour.

    Args:
        args: namespace with ``keys``, ``lags_show`` (x values, one per df
            column), and ``cmap``/``smap`` (colour/linestyle keyed by (label, key)).
        df: encoding frame with "key" and "label" index levels, one column per lag.
        tag: suffix of the output file ``../prob-improb-{tag}.png``.
    """
    sig_lags = get_sig_lags(args, df)
    n_panels = len(args.keys)
    # squeeze=False keeps `axes` 2-D even when there is a single key
    fig, axes = plt.subplots(1, n_panels, figsize=(9 * n_panels, 5), squeeze=False)
    # groupby(level=...) instead of the deprecated groupby(..., axis=0)
    for ax, (plot, subdf) in zip(axes[0], df.groupby(level="key")):
        for line, subsubdf in subdf.groupby(level="label"):
            vals = subsubdf.mean(axis=0)
            err = subsubdf.sem(axis=0)
            map_key = (line, plot)
            color = args.cmap[map_key]
            ax.fill_between(args.lags_show, vals - err, vals + err, alpha=0.2, color=color)
            ax.plot(
                args.lags_show,
                vals,
                label=f"{line} ({len(subsubdf)})",
                color=color,
                ls=args.smap[map_key],
            )
            line_sig_lags = sig_lags[f"{plot}_{line}"]
            ax.scatter(line_sig_lags, np.full(len(line_sig_lags), 0.001), color=color)
        ax.axhline(0, ls="dashed", alpha=0.3, c="k")
        ax.axvline(0, ls="dashed", alpha=0.3, c="k")
        ax.set_title(f"{plot}s global average")
        ax.legend(loc="upper right", frameon=False)
        ax.set(xlabel="Lag (s)", ylabel="Correlation (r)")
    # save/close this figure explicitly rather than pyplot's "current" figure
    fig.savefig(f"../prob-improb-{tag}.png")
    plt.close(fig)

# Plot average encoding plots for whole brain
average_encoding(args, df, "all")

# Plot average encoding plots for ROIs
# keys: output tag; values: roi_1 names pooled into that ROI
ROIS = {
    # "preCG": ["preCG"],
    # "postCG": ["postCG"],
    # "SM": ["preCG","postCG"],
    # "TP": ["TP"],
    "STG": ["STG"],
    "IFG": ["IFG"],
    # "pMTG": ["pMTG"],
    # "AG": ["AG"],
}

for area_name, area in ROIS.items():  # area
    roi_elecs = sig_df.loc[sig_df.roi_1.isin(area), "electrode"]
    # match on the named "electrode" index level rather than a level position
    roi_df = df[df.index.isin(roi_elecs, level="electrode")]
    if roi_df.empty:
        # avoid writing an empty figure for an ROI without electrodes
        print(f"Skipping ROI {area_name}: no matching electrodes")
        continue
    average_encoding(args, roi_df, area_name)

## Brainmap Plots

Brainmap plots by subject

In [None]:
# Subjects
class Args(argparse.Namespace):
    main_dir = "../data/plotting/brainplot/"  # loads coordinate and brain surface files
    project = "tfs"
    sid = [625, 676, 7170, 798]  # subjects
    keys = ["comp", "prod"]  # comprehension and/or production
    brain_type = "ave"  # average brain
    hemisphere = "left"  # only plot left hemisphere
    outfile = "../tfs_sids_%s.png"


args = Args()

prop_cycle = plt.rcParams["axes.prop_cycle"]  # get the encoding default colors
color_list = prop_cycle.by_key()["color"]

# Set Up Color Split
args.colors = color_list

for key in args.keys:
    # Use a separate name: reassigning `sig_df` inside this loop would compound
    # the filters across keys (the "prod" map would keep only electrodes that
    # are also significant for "comp") and would permanently shrink the table
    # that the ROI cell reads when it is re-run.
    key_sig_df = sig_df[sig_df[f"gpt2-{key}"].astype(bool)]
    # the subject id is passed as the categorical "effect" of each electrode
    sig_plot = pd.DataFrame({"electrode": key_sig_df.electrode, "effect": key_sig_df.sid})
    fig = make_brainmap_cat(args, sig_plot, args.outfile % key)

Computing the effect (normalized area difference)

In [None]:
def get_part_df(label, source_df=None):  # get partial df
    """Select the rows of one label and drop the "label" index level.

    Args:
        label: value of the "label" index level to keep (e.g. "prob").
        source_df: frame to slice; defaults to the notebook-global ``chosen_df``.

    Returns:
        (part_df, part_df_idx): a copy of the selected rows without the "label"
        level, and the list of their "electrode" index values.
    """
    if source_df is None:
        source_df = chosen_df
    # xs drops the "label" level itself and does not depend on how many
    # index levels the frame has
    part_df = source_df.xs(label, level="label").copy()
    part_df_idx = part_df.index.get_level_values("electrode").tolist()
    return part_df, part_df_idx

# Lags (ms) to integrate over.
chosen_lags = np.arange(-400, -99, 25)
# NOTE(review): this is the full -10 s..10 s grid (same values as the encoding
# Args.lags_plot in ms), not args.lags_show; df column labels are assumed to
# be positions in this grid — confirm against organize_data.
lags_show = np.arange(-10000, 10001, 25)
chosen_lags = [idx for (idx, lag) in enumerate(lags_show) if lag in chosen_lags]
# .copy() so the "area" column below is added to an independent frame, not a slice of df
chosen_df = df.loc[:, chosen_lags].copy()
x_vals = [lags_show[lag] / 1000 for lag in chosen_lags]  # seconds

# Get Effect
# trapezoidal area under each row's encoding curve over the chosen window
chosen_df["area"] = trapezoid(chosen_df.to_numpy(), x=x_vals, axis=1)
df1, _ = get_part_df("prob")  # get first encoding
df2, _ = get_part_df("improb")  # get second encoding
df1["area2"] = df2["area"]  # aligned on the shared (non-label) index levels
# Normalized difference. Dividing by the larger *absolute* area keeps the sign
# of the effect equal to the sign of (area - area2); a signed max would flip
# the sign whenever both areas are negative.
df1["effect"] = (df1["area"] - df1["area2"]) / df1[["area", "area2"]].abs().max(axis=1)
chosen_df = df1.reset_index()

Brainmap plots for area difference

In [None]:
class Args(argparse.Namespace):
    """Brainmap settings for plotting the prob - improb effect per key."""

    main_dir = "../data/plotting/brainplot/"  # loads coordinate and brain surface files
    project = "tfs"
    sid = [625, 676, 7170, 798]  # subjects
    keys = ["comp", "prod"]  # comprehension and/or production
    brain_type = "ave"  # average brain
    hemisphere = "left"  # only plot left hemisphere
    outfile = "../tfs_prob-improb_%s.png"


args = Args()

# Customize Your Color Split Here
# red scale for positive effects (capped at 1), blue scale for negative (floored at -1)
pos_bar = Colorbar(
    title="Δ corr pos",
    colorscale=[[0, "rgb(255,248,240)"], [1, "rgb(255,0,0)"]],
    bar_max=1,
)
neg_bar = Colorbar(
    title="Δ corr neg",
    colorscale=[[0, "rgb(0,0,255)"], [1, "rgb(240,248,255)"]],
    bar_min=-1,
)
args.color_split = [neg_bar, 0, pos_bar]  # split at 0

for key in args.keys:  # comp/prod
    key_rows = chosen_df["key"] == key
    df_plot = chosen_df.loc[key_rows, ["electrode", "effect"]]
    fig = make_brainmap(args, df_plot, args.outfile % key)  # plot png