In [None]:
import os
import glob
import numpy as np
import pandas as pd
import argparse
import matplotlib.pyplot as plt

from matplotlib.backends.backend_pdf import PdfPages
from tfsplt_utils import read_folder
from tfsplt_brainmap import Colorbar, make_brainmap
from tfsplt_brainmap_cat import make_brainmap_cat

## Data Loading

In [None]:
# Load significant-electrode tables.
# Whisper table: read as-is; later cells use its sid / elec_1 / roi_1 columns and
# its boolean "<sig_tag>-comp" / "<sig_tag>-prod" significance columns.
sig_df_whisper = pd.read_csv("../data/plotting/paper-whisper/data/base_df.csv")

# GPT-2 table: drop the leading column, align column names with the Whisper
# table, then keep only the join keys and the two significance flags.
sig_df_gpt2 = (
    pd.read_csv("../data/plotting/sig-elecs/20230723-tfs-sig-file/tfs-sig-file.csv")
    .iloc[:, 1:]
    .rename(
        columns={
            "patient": "sid",
            "electrode": "elec_1",
            "prod_significant": "gpt2-prod",
            "comp_significant": "gpt2-comp",
        }
    )
)
# NOTE(review): the "gpt2-*" flags are taken from rows where model == "glove";
# confirm that this is the intended electrode selection.
sig_df_gpt2 = sig_df_gpt2.loc[
    sig_df_gpt2["model"] == "glove",
    ["sid", "elec_1", "gpt2-prod", "gpt2-comp"],
]

# Left-merge keeps every Whisper electrode; electrodes missing from the GPT-2
# table are counted as not significant.
sig_df = (
    sig_df_whisper
    .merge(sig_df_gpt2, how="left", on=["sid", "elec_1"])
    .fillna({"gpt2-comp": False, "gpt2-prod": False})
)
# "<sid>_<elec_1>" key, used to join against the encoding frame's electrode column.
sig_df["electrode"] = sig_df["sid"].astype(str) + "_" + sig_df["elec_1"]

In [None]:
class Args(argparse.Namespace):
    """Configuration for loading per-layer encoding results for every subject/model/key."""

    sid = [625, 676, 7170, 798]  # subjects
    # Encoding-result glob patterns, each filled with (subject, layer "00".."24", key).
    formats = [
        "../data/encoding/tfs/20230520-whisper-medium/kw-tfs-full-%s-whisper-medium.en-encoder-lag5k-25-all-%s/*/*%s.csv",
        "../data/encoding/tfs/20230520-whisper-medium/kw-tfs-full-%s-whisper-medium.en-decoder-lag5k-25-all-%s/*/*%s.csv",
        "../data/encoding/tfs/20230701-gpt2-medium-70-n/kw-tfs-full-%s-gpt2-medium-lag5k-25-all-shift-emb-%s/*/*%s.csv",
    ]
    # One label per entry of `formats` (zipped together below).
    labels = [
        "whisper-en",
        "whisper-de",
        "gpt2",
    ]
    keys = ["comp", "prod"]  # comprehension and/or production
    layers = np.arange(0, 25)
    lags_plot = np.arange(-5000, 5001, 25)  # encoding lags (ms)
    lags_show = np.arange(-2000, 2001, 25)  # lags kept for the plots (ms)
    sigelecs = {}  # forwarded to read_folder; NOTE(review): meaning defined in tfsplt_utils


# Aggregate data: one read_folder call per (subject, model, key, layer).
args = Args()

data = []
for load_sid in args.sid:  # subjects
    # `path_fmt`, not `format`: a notebook-level loop variable named `format`
    # would shadow the builtin format() in every later cell.
    for label, path_fmt in zip(args.labels, args.formats):  # models
        for key in args.keys:  # comp/prod
            for layer in args.layers:  # layers
                layer_tag = f"{layer:02}"  # zero-padded so string sort == numeric sort
                data = read_folder(
                    data,
                    path_fmt % (load_sid, layer_tag, key),
                    args.sigelecs,
                    (load_sid, key),
                    load_sid,
                    key,
                    layer_tag,
                    label,
                    True,
                )
df = pd.concat(data)

In [None]:
# Trim the encoding frame to the display window: keep the metadata columns plus
# the lag columns whose lag (args.lags_plot[i]) lies in args.lags_show. As the
# .loc below assumes, lag columns are labelled by their position i.
columns = ["electrode", "key", "label1", "label2"]
chosen_lag_idx = np.flatnonzero(np.isin(args.lags_plot, args.lags_show)).tolist()
df_trimmed = df.loc[:, columns + chosen_lag_idx]
plot_lags = args.lags_show / 1000  # ms -> s for the x axis

In [None]:
# Areas: plot name -> roi_1 labels pooled into that area
# (matched with sig_df.roi_1.isin(...)).
ROIS = {
    "preCG": ["preCG"],
    "postCG": ["postCG"],
    "SM": ["preCG", "postCG"],
    "TP": ["TP"],
    "STG": ["STG"],
    "IFG": ["IFG"],
    "pMTG": ["pMTG"],
    "AG": ["AG"],
}
# Sensorimotor sub-areas (presumably dorsal / middle / ventral SM; verify).
ROIS2 = {
    "dSM": ["dSM"],
    "mSM": ["mSM"],
    "vSM": ["vSM"],
}
RIOS2 = ROIS2  # backward-compatible alias for the original misspelled name

# Sig dictionary: model label (label2 in the encoding frame) -> prefix of its
# significance columns in sig_df ("<prefix>-comp" / "<prefix>-prod").
SIG_DICT = {
    "whisper-en": "whisper-en-last-0.01",
    "whisper-de": "whisper-de-best-0.01",
    "gpt2": "gpt2",
}

In [None]:
# Report how many electrodes are significant per model for each key.
for model_label, sig_prefix in SIG_DICT.items():
    for key_suffix, key_name in (("comp", "Comp"), ("prod", "Prod")):
        n_sig = sig_df[f"{sig_prefix}-{key_suffix}"].sum()
        print(f"{model_label} {key_name}:", n_sig)

## Area Average Encoding

In [None]:
# One viridis shade per layer index; the layer-wise max plot also uses `colors`.
cmap = plt.get_cmap("viridis")  # plt.cm.get_cmap was removed in Matplotlib 3.9
colors = [cmap(i / len(args.layers)) for i in args.layers]

# One figure per (key, model, area): the electrode-averaged encoding curve for
# every layer, coloured by layer.
# `args` here is the loading Args, which defines `keys` (not `unique_keys`).
for key in args.keys:  # comp/prod
    for sig_label, sig_tag in SIG_DICT.items():  # models
        for area_name, area in ROIS.items():  # areas
            # Electrodes significant for this model/key inside the area
            sig_df_filter = sig_df[sig_df[f"{sig_tag}-{key}"]]
            sig_df_filter = sig_df_filter[sig_df_filter.roi_1.isin(area)]

            # Matching encoding rows (all layers)
            df_filter = df_trimmed[df_trimmed.key == key]
            df_filter = df_filter[df_filter.label2 == sig_label]
            df_filter = df_filter.merge(sig_df_filter["electrode"], how="inner", on=["electrode"])
            print(f"{key}, {sig_label}, {area_name}, Sig {len(sig_df_filter)}, Plotting {len(df_filter)}")

            # Plotting
            fig, ax = plt.subplots()
            n_elecs = 0
            for layer, layer_df in df_filter.groupby("label1"):
                # New frame of lag columns only (no in-place edit of a groupby slice)
                lag_corrs = layer_df.drop(columns=["electrode", "key", "label1", "label2"])
                ax.plot(plot_lags, lag_corrs.mean(axis=0), color=colors[int(layer)])
                n_elecs = len(lag_corrs)
            ax.set(
                xlabel="Lags(s)",
                ylabel="Correlation (r)",
                title=f"{area_name}-{key} ({n_elecs})",
            )
            ax.set_ylim(0, 0.3)
            # Default tick labels, so each label shows its real position
            # (0.15 and 0.3 must not be labelled 0.2 and 0.4).
            ax.set_yticks([0, 0.15, 0.3])
            fig.savefig(f"../enc_{area_name}_{sig_label}_{key}.png")
            plt.close(fig)

## Area Max Average Across Layers

In [None]:
# One figure per (key, area). For every model, average the area's significant
# electrodes per layer, then plot the peak over lags of that average against layer.
# `args` here is the loading Args, which defines `keys` (not `unique_keys`).
for key in args.keys:  # comp/prod
    for area_name, area in ROIS.items():  # areas
        fig, ax = plt.subplots()
        for sig_label, sig_tag in SIG_DICT.items():  # models
            # Electrodes significant for this model/key inside the area
            sig_df_filter = sig_df[sig_df[f"{sig_tag}-{key}"]]
            sig_df_filter = sig_df_filter[sig_df_filter.roi_1.isin(area)]

            # Matching encoding rows (all layers)
            df_filter = df_trimmed[df_trimmed.key == key]
            df_filter = df_filter[df_filter.label2 == sig_label]
            df_filter = df_filter.merge(sig_df_filter["electrode"], how="inner", on=["electrode"])
            print(f"{key}, {sig_label}, {area_name}, Sig {len(sig_df_filter)}, Plotting {len(df_filter)}")

            # Mean over electrodes per layer, then max over lags. String columns
            # are dropped first because pandas >= 2.0 raises TypeError in mean().
            max_cor = (
                df_filter.drop(columns=["electrode", "key", "label2"])
                .groupby("label1")
                .mean()
                .max(axis=1)
            )
            if max_cor.empty:
                continue  # no significant electrodes for this model in this area
            # x positions come from the layers actually present ("00" -> 0, ...),
            # so a missing layer cannot misalign or length-mismatch the scatter.
            layer_ids = max_cor.index.astype(int).to_numpy()

            if sig_label == "whisper-en":
                facecol = [colors[i] for i in layer_ids]  # viridis shade per layer
                edgecol = "none"
            elif sig_label == "whisper-de":
                facecol = "none"
                edgecol = "grey"
            elif sig_label == "gpt2":
                facecol = "grey"
                edgecol = "none"
            else:
                raise ValueError(f"No marker style defined for {sig_label!r}")
            ax.scatter(
                layer_ids,
                max_cor.to_numpy(),
                s=100,
                facecolors=facecol,
                edgecolors=edgecol,
                marker="o",
            )
        ax.set(
            xlabel="Layers",
            ylabel="Max Correlation (r)",
            title=f"{area_name}-{key}",
        )
        ax.set_ylim(0, 0.3)
        fig.savefig(f"../max-{area_name}-{key}.png")
        plt.close(fig)


## Max Correlation Brainmaps Across Layers

In [None]:
# Brainmaps of each significant electrode's peak correlation (max over the
# plotted lag window) for a few layers.
# This cell uses its own names instead of rebinding SIG_DICT / Args / args.
# That way the area-level cells above still see every model and the loading
# Args (args.keys, args.layers) if they are re-run after this cell.
BRAINMAP_SIG_DICT = {
    "whisper-en": "whisper-en-last-0.01",
    # "whisper-de": "whisper-de-best-0.01",
    # "gpt2": "gpt2",
}

BRAINMAP_LAYERS = [0, 1, 8, 12, 16, 24]


class BrainmapArgs(argparse.Namespace):
    """Options forwarded to make_brainmap."""

    main_dir = "../data/plotting/brainplot/"  # loads coordinate and brain surface files
    project = "tfs"
    unique_keys = ["prod", "comp"]
    brain_type = "ave"  # average brain
    hemisphere = "left"  # only plot left hemisphere
    outfile = "../%s_max_%s_%s.png"  # filled with (model label, key, layer)


brain_args = BrainmapArgs()
brain_args.color_split = [Colorbar(title="max-cor", colorscale="viridis", bar_min=0, bar_max=0.4)]

for key in brain_args.unique_keys:  # prod/comp
    for sig_label, sig_tag in BRAINMAP_SIG_DICT.items():  # models
        # Electrodes significant for this model/key
        sig_df_filter = sig_df[sig_df[f"{sig_tag}-{key}"]]

        # Matching encoding rows (all layers)
        df_filter = df_trimmed[df_trimmed.key == key]
        df_filter = df_filter[df_filter.label2 == sig_label]
        df_filter = df_filter.merge(sig_df_filter["electrode"], how="inner", on=["electrode"])
        print(f"{key}, {sig_label}, Sig {len(sig_df_filter)}, Plotting {len(df_filter)}")

        # One brainmap png per layer
        for layer in BRAINMAP_LAYERS:
            layer_tag = f"{layer:02}"
            df_plot = df_filter.loc[df_filter.label1 == layer_tag].drop(columns=["key", "label1", "label2"])
            # Peak over the lag columns; "electrode" is excluded by name rather
            # than by assuming it is the first column.
            df_plot["effect"] = df_plot.drop(columns=["electrode"]).max(axis=1)
            make_brainmap(brain_args, df_plot, brain_args.outfile % (sig_label, key, layer_tag))

## Correlation Brainmaps across lags

In [None]:
# Lag-resolved brainmaps: one png per (key, model, layer, lag), coloured by the
# electrode's correlation at that single lag.
SIG_DICT = {
    "whisper-en": "whisper-en-last-0.01",
    # "whisper-de": "whisper-de-best-0.01",
    # "gpt2": "gpt2",
}

LAGS = [-400, -200, 0, 200, 400]  # ms

LAYERS = [24]  # other candidates: 0, 12


class Args(argparse.Namespace):
    main_dir = "../data/plotting/brainplot/"  # loads coordinate and brain surface files
    project = "tfs"
    unique_keys = ["prod", "comp"]
    lags_plot = np.arange(-5000, 5001, 25)  # encoding lags (ms); lag column i <-> lags_plot[i]
    brain_type = "ave"  # average brain
    hemisphere = "left"  # only plot left hemisphere
    outfile = "../%s_%s_%s_%s.png"  # filled with (model label, key, layer, lag)


args = Args()
args.color_split = [Colorbar(title="max-cor", colorscale="viridis", bar_min=0, bar_max=0.4)]

for key in args.unique_keys:  # prod/comp
    for sig_label, sig_tag in SIG_DICT.items():  # models
        # Electrodes significant for this model/key
        sig_elecs = sig_df[sig_df[f"{sig_tag}-{key}"]]

        # Matching encoding rows (all layers)
        enc = df_trimmed[df_trimmed.key == key]
        enc = enc[enc.label2 == sig_label]
        enc = enc.merge(sig_elecs["electrode"], how="inner", on=["electrode"])
        print(f"{key}, {sig_label}, Sig {len(sig_elecs)}, Plotting {len(enc)}")

        for layer in LAYERS:
            layer_tag = f"{layer:02}"
            # Filter once per layer; assign() below gives every lag its own frame.
            layer_enc = enc.loc[enc.label1 == layer_tag, :].drop(columns=["key", "label1", "label2"])
            for lag in LAGS:
                lag_col = np.where(args.lags_plot == lag)[0][0]  # column label of this lag
                fig = make_brainmap(
                    args,
                    layer_enc.assign(effect=layer_enc[lag_col]),
                    args.outfile % (sig_label, key, layer_tag, lag),
                )