In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import sys
import joblib
from scipy.special import softmax
import sasc.config
import numpy as np
from collections import defaultdict
from copy import deepcopy
import pandas as pd
import story_helper
from sasc.modules.fmri_module import convert_module_num_to_voxel_num

In [None]:
# Maps each generated-story directory name (under STORIES_DIR) to the filename
# of its response array; the value is later used as the key into the loaded
# response files.
story_mapping = {
    "uts02_pilot_gpt4_mar28___ver=v4_noun___seed=1": "GenStory1_resps.npy",
    "uts02_pilot_gpt4_mar28___ver=v4_noun___seed=3": "GenStory2_resps.npy",
    "uts02_pilot_gpt4_mar28___ver=v4_noun___seed=4": "GenStory3_resps.npy",
    "uts02_pilot_gpt4_mar28___ver=v5_noun___seed=1": "GenStory4_resps.npy",
    "uts02_pilot_gpt4_mar28___ver=v5_noun___seed=2": "GenStory5_resps.npy",
    "uts02_pilot_gpt4_mar28___ver=v5_noun___seed=4": "GenStory6_resps.npy",
}

STORIES_DIR = join(sasc.config.RESULTS_DIR, "pilot_v1")
story_names = story_mapping.keys()  # os.listdir(STORIES_DIR)


def _read_text(path):
    """Return the full contents of the text file at `path`.

    The file is opened in a `with` block so the handle is closed right away
    instead of lingering until garbage collection.
    """
    with open(path, "r") as f:
        return f.read()


# Every key of story_data holds a list with one entry per story, in the
# iteration order of story_mapping.
story_data = defaultdict(list)
for story_name in story_names:
    story_dir = join(STORIES_DIR, story_name)

    story_data["timing"].append(pd.read_csv(join(story_dir, "timings_processed.csv")))
    story_data["story_name_original"].append(story_name)
    story_data["story_name_new"].append(story_mapping[story_name])
    story_data["story_text"].append(_read_text(join(story_dir, "story.txt")))
    # prompts.txt is split on blank lines into a list of prompt strings
    story_data["prompts"].append(
        _read_text(join(story_dir, "prompts.txt")).split("\n\n")
    )

    # rows: add the voxel number derived from (module_num, subject), then keep
    # only the columns used by the rest of the notebook
    rows = pd.read_csv(join(story_dir, "rows.csv"))
    rows["voxel_num"] = rows.apply(
        lambda row: convert_module_num_to_voxel_num(row["module_num"], row["subject"]),
        axis=1,
    )
    rows = rows[
        [
            "expl",
            "module_num",
            "top_explanation_init_strs",
            "subject",
            "fmri_test_corr",
            "top_score_synthetic",
            "roi_anat",
            "roi_func",
            "voxel_num",
        ]
    ]
    story_data["rows"].append(rows)

# Persist the assembled metadata next to the other results
joblib.dump(story_data, "../results/pilot_story_data.pkl")

In [None]:
# Load every response file found in the pilot data directory, keyed by filename
pilot_data_dir = '/home/chansingh/mntv1/deep-fMRI/story_data/20230504'
resp_np_files = sorted(os.listdir(pilot_data_dir))

resps_dict = {}
for resp_fname in tqdm(resp_np_files):
    resps_dict[resp_fname] = np.load(join(pilot_data_dir, resp_fname))

# Look at heatmaps

In [None]:
# One (voxel x paragraph) matrix of mean responses per story
mats = []

for story_num in range(6):
    rows = story_data["rows"][story_num]
    voxel_nums = rows["voxel_num"].values

    # split the story's responses into one chunk per paragraph
    resp_story = resps_dict[story_data["story_name_new"][story_num]].T  # (voxels, time)
    timing = story_data["timing"][story_num]
    paragraphs = story_data["story_text"][story_num].split("\n\n")
    assert len(paragraphs) == len(rows)
    resp_chunks = story_helper.get_resp_chunks(timing, paragraphs, resp_story)

    # mat[v, p] = response of selected voxel v averaged over paragraph p's chunk
    n_paragraphs = len(paragraphs)
    mat = np.zeros((len(rows), n_paragraphs))
    for par_idx in range(n_paragraphs):
        chunk = resp_chunks[par_idx]
        mat[:, par_idx] = chunk[voxel_nums].mean(axis=1)
    mat[:, 0] = np.nan  # ignore the first column

    # reorder both axes by voxel_num so the matrices line up across stories
    args = np.argsort(voxel_nums)
    mat = mat[np.ix_(args, args)]
    mats.append(mat.copy())

    # plt.imshow(mat)
    # plt.colorbar(label="Mean response")
    # plt.xlabel("Corresponding paragraph\n(Ideally, diagonal should be brighter)")
    # plt.ylabel("Voxel")
    # plt.title(f"{story_data['story_name_new'][story_num][3:-10]}")
    # plt.show()

# explanation labels in the same voxel_num order (taken from the last story's rows)
expls = rows.sort_values(by="voxel_num")["expl"].values

mats = np.array(mats)  # (6, 17, 17)
# average over stories, skipping the NaN entries
m = np.nanmean(mats, axis=0)

## Make average plot

In [None]:
# Summarise the story-averaged matrix `m`: the diagonal holds each voxel's
# response to its own (driving) paragraph, the rest of the row its response to
# the other (baseline) paragraphs.
n = m.shape[0]
diag_means = np.diag(m)
diag_mean = np.nanmean(diag_means)

# Mean of each row excluding the diagonal. The diagonal is masked out and the
# remaining entries averaged, so the divisor is the number of off-diagonal
# entries actually present (n - 1 when nothing is NaN). Subtracting
# diag_means / n from the full row mean would divide by n instead. nanmean
# keeps this consistent with how diag_mean treats NaN entries.
off_diag_mask = ~np.eye(*m.shape, dtype=bool)
off_diag_means = np.nanmean(np.where(off_diag_mask, m, np.nan), axis=1)
off_diag_mean = np.nanmean(off_diag_means)

In [None]:
# Fixed horizontal offsets that spread the per-voxel dots around each bar
x = np.arange(n) - n / 2

# Driving (diagonal) responses: bar = mean, error bar = standard error, dots = voxels
plt.bar(1, diag_mean, width=0.5, label='Diagonal', alpha=0.1, color='C0')
plt.errorbar(1, diag_mean, yerr=diag_means.std() / np.sqrt(len(diag_means)), fmt='.', label='Diagonal', ms=0, color='black', elinewidth=3, capsize=5, lw=1)
plt.plot(1 + x/50, diag_means, '.', color='C0', alpha=0.5)

# Baseline (off-diagonal) responses, drawn the same way
plt.bar(2, off_diag_mean, width=0.5, label='Off-diagonal', alpha=0.1, color='C1')
plt.errorbar(2, off_diag_mean, yerr=off_diag_means.std() / np.sqrt(len(off_diag_means)), fmt='.', label='Off-diagonal', ms=0, color='black', elinewidth=3, capsize=5)
plt.plot(2 + x/50, off_diag_means, '.', color='C1')

plt.xticks([1, 2], ['Driving paragraph', 'Baseline paragraphs'])
# raw string: '\s' is not a valid escape sequence in a normal string literal
plt.ylabel(r'Mean voxel response ($\sigma_f$)')
plt.grid(axis='y')

# Label three driving points with their explanation text
kwargs = dict(
    arrowprops=dict(arrowstyle='->', color='#333'), fontsize='x-small', color='#333'
)
annotated_points = [
    (np.argmax(diag_means), 0.1),  # highest mean
    (np.argsort(diag_means)[-2], 0.1),  # second highest mean
    (np.argmin(diag_means), 0.0),  # lowest mean (label kept level with the point)
]
for idx, text_y_offset in annotated_points:
    plt.annotate(
        f"{expls[idx]}",
        (1 + x[idx]/50, diag_means[idx]),
        xytext=(1.1, diag_means[idx] + text_y_offset),
        **kwargs,
    )

plt.tight_layout()
print('mean', diag_mean - off_diag_mean)
plt.savefig('../results/pilot_means.pdf')

## Relationship between different voxels

In [None]:
# cg = sns.clustermap(pd.DataFrame(m, columns=expls, index=expls), method='complete', cmap='viridis', figsize=(10, 10))
# plt.setp(cg.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
# plt.xlabel('Driving paragraph')
plt.figure(figsize=(6, 6))
# m = softmax(m, axis=0)

# Draw a gray square around every diagonal cell (bottom, top, left, right edge)
for k in range(min(m.shape)):
    lo = k - 0.5
    hi = k + 0.5
    cell_edges = (
        ([lo, hi], [lo, lo]),
        ([lo, hi], [hi, hi]),
        ([lo, lo], [lo, hi]),
        ([hi, hi], [lo, hi]),
    )
    for edge_xs, edge_ys in cell_edges:
        plt.plot(edge_xs, edge_ys, color='gray', lw=1)

# Heatmap of the story-averaged responses, both axes labelled by explanation
plt.imshow(m)
tick_positions = np.arange(len(expls))
plt.xlabel("Driving paragraph\n(Ideally, diagonal should be brighter)", fontsize='x-small')
plt.ylabel("Voxel", fontsize='x-small')
plt.yticks(ticks=tick_positions, labels=expls, fontsize='x-small')
plt.xticks(ticks=tick_positions, labels=expls, rotation=90, fontsize='x-small')
plt.show()

# plot correlations across all resps
# resps_voxels = np.concatenate(
#     [resps_dict[story_data["story_name_new"][story_num]].T for story_num in [2, 3, 4]],
#     axis=1,
# )[rw["voxel_num"].values]
# corr = pd.DataFrame(resps_voxels.T, columns=expls).corr().round(2)
# sns.clustermap(corr)

### Story-level differences

In [None]:
# Per-story summary: mean diagonal (driving) vs. mean off-diagonal (baseline) response
d = defaultdict(list)
for i in range(len(mats)):
    # A dedicated name is used here so the notebook-level averaged matrix `m`
    # is not overwritten by the last story's matrix.
    story_mat = mats[i]
    off_diag_mask = ~np.eye(story_mat.shape[0], dtype=bool)
    d['driving'].append(np.nanmean(np.diag(story_mat)))
    d['baseline'].append(np.nanmean(story_mat[off_diag_mask]))
    # strip the leading "Gen" and trailing "_resps.npy" from the response filename
    d['story'].append(story_data['story_name_new'][i][3:-10])

df = pd.DataFrame.from_dict(d)

# make barplot comparing driving and baseline
df = df.melt(id_vars='story', value_vars=['driving', 'baseline'], var_name='condition', value_name='mean')
df = df.sort_values(by='story')
sns.barplot(data=df, x='story', y='mean', hue='condition')
# raw string: '\s' is not a valid escape sequence in a normal string literal
plt.ylabel(r'Mean voxel response ($\sigma_f$)')
plt.show()

### Voxel-level differences

In [None]:
# Per-voxel table ordered by voxel_num, the same ordering used for the rows of
# the response matrices and therefore for diag_means. `rw` was previously read
# before ever being assigned, which fails on a fresh kernel. It is built here
# from the last story's rows table, matching how `expls` is derived.
# NOTE(review): assumes the per-voxel metadata (top_score_synthetic,
# fmri_test_corr) is the same in every story's rows.csv; confirm.
rw = story_data["rows"][-1].sort_values(by="voxel_num")
rw = rw.assign(mean_resp_diff=diag_means)  # - off_diag_means

# ax = sns.pairplot(rw, vars=['mean_resp_diff', 'top_score_synthetic', 'fmri_test_corr'], hue='expl')
# sns.move_legend(ax, "upper left", bbox_to_anchor=(1.01, 1))

plt.figure(figsize=(7, 4))

# left panel: driving response vs. synthetic score
plt.subplot(1, 2, 1)
plt.plot(rw['top_score_synthetic'], rw['mean_resp_diff'], 'o')
plt.ylabel('Mean voxel response')
plt.xlabel('Synthetic score')

# right panel: driving response vs. test correlation
plt.subplot(1, 2, 2)
plt.plot(rw['fmri_test_corr'], rw['mean_resp_diff'], 'o')
plt.xlabel('Predicted test correlation')

plt.tight_layout()
plt.show()