# Figure 2 - Similarities

## Setup

In [84]:
import pandas as pd
from convergence.plotting import plot_faverage_parcelation, get_hcp_labels, add_area_labels
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as ticker

# Draw all Matplotlib text with the Arial sans-serif font.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
})
#plt.rcParams['svg.fonttype'] = 'none'

## Functions

In [71]:
def save_brain_views(
    df,
    name,
    hemispheres=("lh", "rh"),
    views=("lateral", "medial", "rostral", "caudal", "dorsal", "ventral"),
    hcp=None,
    area_ids=None,
    size=(2*800, 2*600),
    **kwargs,
):
    """Render one brain per hemisphere and save a PNG for every requested view.

    Args:
        df: Data forwarded unchanged to ``plot_faverage_parcelation``.
        name: Output path prefix; images are written to
            ``{name}_{hemi}_{view}.png``.
        hemispheres: Hemisphere codes to render, one brain each. The camera
            tilt for lateral/medial views is keyed on ``"lh"``; any other
            code gets the mirrored tilt.
        views: View names passed to ``brain.show_view``.
        hcp: When not ``None``, area labels are added with
            ``add_area_labels`` before the views are saved.
        area_ids: Forwarded to ``add_area_labels``.
        size: Forwarded to ``plot_faverage_parcelation``.
        **kwargs: Extra keyword arguments for ``plot_faverage_parcelation``.
    """
    for hemi in hemispheres:
        brain = plot_faverage_parcelation(df, hemisphere=hemi, size=size, **kwargs)
        # Close the brain even when labelling or saving fails, so a broken
        # view does not leave a renderer open.
        try:
            if hcp is not None:
                add_area_labels(brain, hcp=hcp, area_ids=area_ids, hemispheres=[hemi])

            for view in views:
                if view in ("lateral", "medial"):
                    # lateral/lh and medial/rh share one camera tilt;
                    # lateral/rh and medial/lh use the mirrored one.
                    first_tilt = (view == "lateral") == (hemi == "lh")
                    azimuth, elevation = (20, -100) if first_tilt else (-20, 100)
                    brain.show_view(view, azimuth=azimuth, elevation=elevation)
                else:
                    brain.show_view(view)
                brain.save_image(f"{name}_{hemi}_{view}.png", mode="rgba")
        finally:
            brain.close()

# Brain surfaces

## Brain - Cross-Participant similarities

In [86]:
# Load the pairwise similarity scores and the HCP parcellation table.
filename = "data/cross_subject_pairwise_similarities_1_separated.parquet"
df = pd.read_parquet(filename)
hcp = pd.read_csv("data/hcp2.csv")


# Same-ROI comparisons between different subjects: mean per (ROI, subject),
# then mean per ROI.
df_g = (
    df.query("roi_x == roi_y and subject_i != subject_j")
    .groupby(["roi_x", "subject_i"])
    .agg({"score": "mean"})
    .reset_index()
)
df_g = (
    df_g.groupby("roi_x")
    .agg({"score": "mean"})
    .reset_index()
    .rename(columns={"roi_x": "roi"})
)
df_g = df_g.merge(hcp[["mne_name", "roi"]], on="roi")
# Keep scores inside [0, 1] before plotting.
df_g["score"] = df_g["score"].clip(0, 1)
vlims = (0, 0.18)

# folder = Path("cross_subject_brain") 
# folder.mkdir(exist_ok=True)

# Save without labels
# save_brain_views(
#     df_g,
#     str(folder / "mean_cka_brain_cross_subject"),
#     hemispheres=["lh", "rh"],
#     views=["lateral", "medial", "rostral", "caudal", "dorsal", "ventral"],
    
#     normalize=vlims,
# )

# # Save with labels
#areas = df_g[["roi", "score"]].merge(hcp, on="roi").query("score>0.05").area_id.unique()
# areas = [0,1,2,3,4,5,7]
# save_brain_views(
#     df_g,
#     str(folder / "mean_cka_brain_cross_subject_labels"),
#     normalize=vlims,
#     hcp=hcp,
#     area_ids=areas,
# )

In [88]:
# Show the distinct values of the `shift` column in the loaded table.
df["shift"].unique()

array([1])

In [81]:
# List every (area_id, area) pair once, ordered by area_id.
hcp[['area_id', 'area']].drop_duplicates().sort_values('area_id')

Unnamed: 0,area_id,area
0,0,Primary Visual
3,1,Early Visual (V2-4)
6,2,Ventral Visual
2,3,Dorsal Visual
1,4,MT+ Visual Areas
117,5,Medial Temporal
130,6,Lateral Temporal
24,7,TPO
7,8,Somatomotor
35,9,Mid Cingulate


## Brain - Within-participant (self) similarities

In [None]:
# Same-ROI similarities of each subject with itself, rendered on the cortex.
filename = "data/cross_subject_pairwise_similarities_1_separated.parquet"
df = pd.read_parquet(filename)
hcp = pd.read_csv("data/hcp2.csv")


# Median per (ROI, subject), then mean per ROI, then attach the mne label names.
df_g = (
    df.query("roi_x == roi_y and subject_i == subject_j")
    .groupby(["roi_x", "subject_i"])
    .agg({"score": "median"})
    .reset_index()
    .groupby("roi_x")
    .agg({"score": "mean"})
    .reset_index()
    .rename(columns={"roi_x": "roi"})
    .merge(hcp[["mne_name", "roi"]], on="roi")
)

vlims = (0, 0.25)

folder = Path("self_subject_brain")
folder.mkdir(exist_ok=True)

# Save without labels
save_brain_views(
    df_g,
    str(folder / "mean_cka_brain_self_subject"),
    normalize=vlims,
    hemispheres=["lh", "rh"],
    views=["lateral", "medial", "rostral", "caudal", "dorsal", "ventral"],
)

## Brain - Participant - Model similarities

In [None]:
# Subject-vs-model CKA scores rendered on the cortex, one figure set per modality.
filename = "data/subject_model_similarities_cka.parquet"
df = pd.read_parquet(filename)
hcp2 = pd.read_csv("data/hcp2.csv")

# Successive reductions on the non-excluded rows:
#   1. max within each (subject, model_name, roi, session, modality) group
#   2. median over sessions
#   3. median over models
#   4. mean over subjects
df_m = (
    df.query("not excluded")
    .groupby(["subject", "model_name", "roi", "session", "modality"], observed=True)
    .agg({"score": "max"})
    .reset_index()
    .groupby(["roi", "modality", "subject", "model_name"], observed=True)
    .agg({"score": "median"})
    .reset_index()
    .groupby(["roi", "modality", "subject"], observed=True)
    .agg({"score": "median"})
    .reset_index()
    .groupby(["roi", "modality"], observed=True)
    .agg({"score": "mean"})
    .reset_index()
    .merge(hcp2[["mne_name", "roi"]], on="roi")
)


# Vision models
folder = Path("subject_model_vision")
folder.mkdir(exist_ok=True)
save_brain_views(
    name=str(folder / "cka_brain_vision"),
    df=df_m.query("modality == 'vision'"),
    cmap="hot",
    normalize=(0, 0.18),
)

# Language models
folder = Path("subject_model_language")
folder.mkdir(exist_ok=True)
save_brain_views(
    name=str(folder / "cka_brain_language"),
    df=df_m.query("modality == 'language'"),
    cmap="hot",
    normalize=(0, 0.18),
)

## Brain - Cross participants - Paired

In [None]:
# Paired cross-subject similarities (unbiased CKA, partition "all", no repetition shift).
df = pd.read_parquet("data/cross_subject_partitions.parquet")
hcp = pd.read_csv("data/hcp2.csv")
df_paired = df.query("metric=='unbiased_cka' and partition=='all' and repetition_shift==0")
# Append a copy of the rows with subject_i and subject_j swapped.
df_paired2 = df_paired.assign(
    subject_i=df_paired["subject_j"],
    subject_j=df_paired["subject_i"],
)
# Median per (subject, ROI), mean per ROI, then attach the mne label names.
df_paired = (
    pd.concat([df_paired, df_paired2])
    .groupby(["subject_i", "roi"])
    .agg({"score": "median"})
    .reset_index()
    .groupby("roi")
    .agg({"score": "mean"})
    .reset_index()
    .merge(hcp[["mne_name", "roi"]], on="roi")
)


folder = Path("cross_subject_partitions_paired")
folder.mkdir(exist_ok=True)

# Unlabelled views
save_brain_views(
    name=str(folder / "cka_brain_paired"),
    df=df_paired,
    cmap="hot",
    normalize=(0, 0.18),
)


# NOTE(review): `df_g` is not assigned in this cell; the labelled areas depend
# on whatever an earlier cell left in `df_g`. Confirm this is intended rather
# than thresholding `df_paired`.
areas = df_g[["roi", "score"]].merge(hcp, on="roi").query("score>0.05")["area_id"].unique()


# Same views with area labels
save_brain_views(
    name=str(folder / "cka_brain_paired_labels"),
    df=df_paired,
    cmap="hot",
    normalize=(0, 0.18),
    area_ids=areas,
    hcp=hcp,
)

# Boxplot - ROIs

### Prepare data

In [None]:
# Participants data
df = pd.read_parquet("data/cross_subject_partitions.parquet")
hcp = pd.read_csv("data/hcp2.csv")
hcp = hcp[["name", "roi", "area", "area_id", "area_color", "roi_order"]]
df_paired = df.query("metric=='unbiased_cka' and partition=='all' and repetition_shift==0")
# Append a copy of the rows with subject_i and subject_j swapped.
df_paired2 = df_paired.assign(
    subject_i=df_paired["subject_j"],
    subject_j=df_paired["subject_i"],
)
df_paired = (
    pd.concat([df_paired, df_paired2])
    .groupby(["subject_i", "roi"])
    .agg({"score": "median"})
    .reset_index()
)
# Map ROI ids above 180 onto 1..180, then average rows sharing (roi, subject).
df_paired.loc[df_paired["roi"] > 180, "roi"] -= 180
df_paired = (
    df_paired.groupby(["roi", "subject_i"])
    .agg({"score": "mean"})
    .reset_index()
    .merge(hcp, on="roi")
)


# Models data
filename = "data/subject_model_similarities_cka.parquet"
df = pd.read_parquet(filename)
hcp2 = pd.read_csv("data/hcp2.csv")

# max per (subject, model, roi, session, modality), median over sessions,
# then median over models.
df_m = (
    df.query("not excluded")
    .groupby(["subject", "model_name", "roi", "session", "modality"], observed=True)
    .agg({"score": "max"})
    .reset_index()
    .groupby(["roi", "modality", "subject", "model_name"], observed=True)
    .agg({"score": "median"})
    .reset_index()
    .groupby(["roi", "modality", "subject"], observed=True)
    .agg({"score": "median"})
    .reset_index()
)
# Same ROI folding as for the participants data.
df_m.loc[df_m["roi"] > 180, "roi"] -= 180
df_m = (
    df_m.groupby(["roi", "modality", "subject"], observed=True)
    .agg({"score": "mean"})
    .reset_index()
    .merge(hcp, on="roi")
)


df_language = df_m.query("modality == 'language'")
df_vision = df_m.query("modality == 'vision'")
df_cross = df_paired.copy()

### All ROIs

In [None]:
# ROI names ordered by area and within-area order; areas and their colours ordered by area id.
order = df_cross.sort_values(["area_id", "roi_order"])["name"].unique()
hue_order = df_cross.sort_values("area_id")["area"].unique()
palette = list(df_cross.sort_values("area_id")["area_color"].unique())
# Number of distinct ROI names per area (read as a global by plot_all_rois below).
counts = df_cross.drop_duplicates("name")["area"].value_counts().to_dict()


# Restore default matplotlib rcParams
# plt.rcParams.update(plt.rcParamsDefault)

def plot_all_rois(df, ax, order, hue_order, palette, title=None, legend=True):
    """Draw one boxplot of ``score`` per ROI ``name`` on ``ax``, coloured by ``area``.

    Args:
        df: Table with ``name``, ``area`` and ``score`` columns.
        ax: Matplotlib axes to draw on.
        order: ROI names in x-axis order.
        hue_order: Area names in legend/colour order.
        palette: Colours matching ``hue_order``.
        title: Optional axes title.
        legend: Whether to draw the area legend.

    Note:
        Vertical separators are positioned with the module-level ``counts``
        mapping (area -> number of ROIs), not with an argument.
    """
    sns.boxplot(
        data=df,
        x="name",
        y="score",
        hue="area",
        order=order,
        hue_order=hue_order,
        palette=palette,
        legend=legend,
        ax=ax,
    )
    if legend:
        ax.legend(loc="upper right", ncol=2, fontsize=9)

    ax.set_ylabel("Similarity (CKA)")
    ax.set_xlabel("")
    y_axis = ax.yaxis
    y_axis.set_major_locator(ticker.MultipleLocator(0.05))
    y_axis.set_major_formatter(ticker.PercentFormatter(1, 0))
    ax.set_xticks(range(len(order)))
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha="center", fontsize=7)

    # Horizontal guides every 0.05; the zero line is highlighted.
    for step in range(6):
        guide_color = "gray" if step != 0 else "maroon"
        ax.axhline(step * 0.05, color=guide_color, linestyle="--", lw=0.5, zorder=-1)

    # Vertical separator after the last ROI of each area.
    boundary = 0
    for area in hue_order:
        boundary += counts[area]
        ax.axvline(boundary - 0.5, color="gray", ls="--", alpha=0.5, lw=0.3, zorder=-10)

    margin = 1
    ax.set_xlim(-margin, len(order) + margin)
    ax.set_ylim(-0.01, 0.25)
    sns.despine(ax=ax)
    if title is not None:
        ax.set_title(title, fontsize=14)


# One full-width figure per data source: (data, title, show legend, output file).
panels = [
    (df_cross, "Cross-subject", True, "cross_subject_rois_all.svg"),
    (df_vision, "Brain - Vision Models", False, "vision_rois_all.svg"),
    (df_language, "Brain - Language Models", False, "language_rois_all.svg"),
]
for panel_df, panel_title, panel_legend, panel_path in panels:
    fig, ax = plt.subplots(1, 1, figsize=(20, 4), dpi=300)
    plot_all_rois(
        panel_df, ax, order, hue_order, palette, title=panel_title, legend=panel_legend
    )
    fig.savefig(panel_path, bbox_inches="tight", transparent=True)

plt.show()

## Selected ROIs

In [None]:

def plot_all_rois(df, ax, order, hue_order, palette, title=None, legend=True, plot_zero=False, tickfontsize=10):
    """Draw one boxplot of ``score`` per ROI ``name`` on ``ax``, coloured by ``area``.

    Args:
        df: Table with ``name``, ``area`` and ``score`` columns.
        ax: Matplotlib axes to draw on.
        order: ROI names in x-axis order.
        hue_order: Area names in legend/colour order.
        palette: Colours matching ``hue_order``.
        title: Optional axes title.
        legend: Whether to draw the area legend.
        plot_zero: When true the y-axis starts at -0.01 instead of 0.
        tickfontsize: Font size of the rotated ROI tick labels.
    """
    sns.boxplot(
        data=df,
        x="name",
        y="score",
        hue="area",
        order=order,
        hue_order=hue_order,
        palette=palette,
        legend=legend,
        ax=ax,
    )
    if legend:
        ax.legend(loc="upper center", ncol=2, fontsize=9)

    ax.set_ylabel("Similarity (CKA)")
    ax.set_xlabel("")
    y_axis = ax.yaxis
    y_axis.set_major_locator(ticker.MultipleLocator(0.05))
    y_axis.set_major_formatter(ticker.PercentFormatter(1, 0))
    ax.set_xticks(range(len(order)))
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha="center", fontsize=tickfontsize)

    # Horizontal guides every 0.05; the zero line is highlighted.
    for step in range(6):
        guide_color = "gray" if step != 0 else "maroon"
        ax.axhline(step * 0.05, color=guide_color, linestyle="--", lw=0.5, zorder=-1)

    # Separators between areas (none after the last one), positioned from the
    # number of distinct ROI names each area has in `df`.
    rois_per_area = df.drop_duplicates("name").area.value_counts().to_dict()
    boundary = 0
    for area in hue_order[:-1]:
        boundary += rois_per_area[area]
        ax.axvline(boundary - 0.5, color="gray", ls="--", alpha=0.5, lw=0.3, zorder=-10)

    left_margin = 0.8
    ax.set_xlim(-left_margin, len(order))
    ax.set_ylim(-0.01 if plot_zero else 0, 0.25)

    sns.despine(ax=ax)
    if title is not None:
        ax.set_title(title, fontsize=14)

# Keep the n ROIs with the highest mean cross-subject score.
n = 19
rois = (
    df_cross.groupby("roi")
    .agg({"score": "mean"})
    .reset_index()
    .sort_values("score", ascending=False)
    .head(n)["roi"]
    .to_list()
)

# Restore default matplotlib rcParams
plt.rcParams.update(plt.rcParamsDefault)

df_cross_selected = df_cross.query("roi in @rois")
df_vision_selected = df_vision.query("roi in @rois")
df_language_selected = df_language.query("roi in @rois")

# Ordering, colours and per-area ROI counts restricted to the selected ROIs.
order_selected = df_cross_selected.sort_values(["area_id", "roi_order"])["name"].unique()
hue_order_selected = df_cross_selected.sort_values("area_id")["area"].unique()
palette_selected = list(df_cross_selected.sort_values("area_id")["area_color"].unique())
counts_selected = df_cross_selected.drop_duplicates("name")["area"].value_counts().to_dict()


fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 4), dpi=300, sharey=True)
fig.subplots_adjust(wspace=0.07)

# Three panels sharing the y-axis; only the middle one carries the legend.
selected_panels = [
    (df_cross_selected, ax1, "Cross-subject", False),
    (df_vision_selected, ax2, "Vision Models", True),
    (df_language_selected, ax3, "Language Models", False),
]
for panel_df, panel_ax, panel_title, panel_legend in selected_panels:
    plot_all_rois(
        panel_df,
        panel_ax,
        order_selected,
        hue_order_selected,
        palette_selected,
        title=panel_title,
        legend=panel_legend,
    )
fig.savefig("selected_rois_all.svg", bbox_inches="tight", transparent=True)
plt.show()


### Scatter

In [None]:
# Per-ROI aggregate of each score table, renamed so the three can sit side by side.
agg = "mean"
df_cross_agg = (
    df_cross.groupby(["roi"])
    .agg({"score": agg})
    .reset_index()
    .rename(columns={"score": "cross_subject"})
)
df_vision_agg = (
    df_vision.groupby(["roi"])
    .agg({"score": agg})
    .reset_index()
    .rename(columns={"score": "vision"})
)
df_language_agg = (
    df_language.groupby(["roi"])
    .agg({"score": agg})
    .reset_index()
    .rename(columns={"score": "language"})
)

# One row per ROI with all three scores plus the ROI metadata from `hcp`.
df_agg = (
    df_cross_agg.merge(df_vision_agg, on="roi")
    .merge(df_language_agg, on="roi")
    .merge(hcp, on="roi")
)


def plot_scatter(
    df_agg, x="cross_subject", y="vision", ax=None, legend=False, xlabel=None, ylabel=None, ylim=(-0.005, 0.155), ymult=0.05, xlim=(-0.005, 0.205), xmult=0.05, kind="scatter", add_text=True, **kwargs
):
    """Scatter two per-ROI score columns against each other, coloured by area.

    Args:
        df_agg: Table with the ``x`` and ``y`` columns plus ``area``,
            ``area_id``, ``area_color`` and ``name``.
        x: Column plotted on the x-axis.
        y: Column plotted on the y-axis.
        ax: Axes to draw on; the current axes are used when ``None``.
        legend: Forwarded to ``sns.scatterplot`` (``False`` hides the legend).
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
        ylim: y-axis limits.
        ymult: Spacing of the major y ticks.
        xlim: x-axis limits.
        xmult: Spacing of the major x ticks.
        kind: Only ``"scatter"`` draws points; any other value leaves the
            axes without data but still applies the formatting.
        add_text: Whether to write the ROI name next to the labelled points.
        **kwargs: Extra keyword arguments for ``sns.scatterplot``.
    """
    # The default ax=None used to fail on the first `ax.` call; fall back to
    # the current axes, which is also what seaborn does for ax=None.
    if ax is None:
        ax = plt.gca()

    hue_order = df_agg.sort_values("area_id").area.unique()
    palette = list(df_agg.sort_values("area_id").area_color.unique())
    if kind == "scatter":
        sns.scatterplot(
            data=df_agg,
            x=x,
            y=y,
            hue="area",
            palette=palette,
            ax=ax,
            hue_order=hue_order,
            # Honour the caller's choice instead of always hiding the legend.
            legend=legend,
            **kwargs,
        )

    sns.despine(ax=ax)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

    ax.grid(axis="both", linestyle="--", alpha=0.5)
    ax.set_xlim(xlim)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(xmult))
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(1, 0))
    ax.set_ylim(ylim)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(ymult))
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(1, 0))

    # Label a point when both scores exceed 0.04 or either one exceeds 0.06.
    if add_text:
        for _, row in df_agg.iterrows():
            if (row[x] > 0.04 and row[y] > 0.04) or (row[x] > 0.06 or row[y] > 0.06):
                ax.text(row[x], row[y], row["name"], fontsize=8, ha="center", va="center")


# Cross-subject similarity against vision (left) and language (right) model similarity.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
scatter_panels = [
    (ax1, "vision", "Vision models similarity (CKA)"),
    (ax2, "language", "Language models similarity (CKA)"),
]
for panel_ax, panel_y, panel_ylabel in scatter_panels:
    plot_scatter(
        df_agg,
        ax=panel_ax,
        x="cross_subject",
        y=panel_y,
        xlabel="Cross-subject similarity (CKA)",
        ylabel=panel_ylabel,
    )


fig.savefig("scatter_vision_language.svg", bbox_inches="tight", transparent=True)
plt.show()

In [None]:

def plot_joint(df_agg, x="cross_subject", y="vision"):
    """Draw a joint scatter of two score columns with grey marginal histograms.

    Args:
        df_agg: Table holding the ``x`` and ``y`` columns.
        x: Column plotted on the x-axis of the joint axes.
        y: Column plotted on the y-axis; its title-cased name is used in the
            y-axis label.

    Returns:
        The grid object returned by ``sns.jointplot``.
    """
    # No `palette` here: without a `hue` variable seaborn ignores it, and the
    # previous `palette=palette` made this function depend on a notebook
    # global that may not exist.
    g = sns.jointplot(
        data=df_agg,
        x=x,
        y=y,
        kind="scatter",
        legend=False,
        marginal_kws=dict(bins=25, color="gray"),
    )
    g.ax_joint.grid(axis="both", linestyle="--", alpha=0.5, zorder=-10)
    g.ax_joint.set_xlim(-0.005, 0.205)
    g.ax_joint.set_ylim(-0.005, 0.155)
    g.ax_joint.xaxis.set_major_locator(ticker.MultipleLocator(0.05))
    g.ax_joint.xaxis.set_major_formatter(ticker.PercentFormatter(1, 0))
    g.ax_joint.yaxis.set_major_locator(ticker.MultipleLocator(0.05))
    g.ax_joint.yaxis.set_major_formatter(ticker.PercentFormatter(1, 0))
    g.ax_joint.set_xlabel("Cross-subject similarity (CKA)")
    g.ax_joint.set_ylabel(f"{y.title()} models similarity (CKA)")

    # Same count range (0-150) on both marginal histograms so they are comparable.
    g.ax_marg_x.set_ylim(0, 150)
    g.ax_marg_y.set_xlim(0, 150)

    return g


# One joint plot per model family, each saved to its own SVG.
joint_outputs = (
    ("vision", "joint_vision.svg"),
    ("language", "joint_language.svg"),
)
for model_kind, out_path in joint_outputs:
    g = plot_joint(df_agg, y=model_kind, x="cross_subject")
    g.savefig(out_path, bbox_inches="tight", transparent=True)

plt.show()



## Bilateral

In [None]:
# Participants data
# Same pipeline as the boxplot preparation, but ROI ids are NOT folded onto
# 1..180 here, so ids above 180 stay separate from ids up to 180.
df = pd.read_parquet("data/cross_subject_partitions.parquet")
hcp = pd.read_csv("data/hcp2.csv")
hcp = hcp[["name", "roi", "area", "area_id", "area_color", "roi_order"]]
df_paired = df.query("metric=='unbiased_cka' and partition=='all' and repetition_shift==0")
# Append a copy of the rows with subject_i and subject_j swapped.
df_paired2 = df_paired.copy()
df_paired2["subject_i"] = df_paired["subject_j"]
df_paired2["subject_j"] = df_paired["subject_i"]
df_paired = pd.concat([df_paired, df_paired2])
df_paired = df_paired.groupby(["subject_i", "roi"]).aggregate({"score": "median"}).reset_index()
df_paired = df_paired.groupby(["roi", "subject_i"]).aggregate({"score": "mean"}).reset_index()
df_paired = df_paired.merge(hcp, on="roi")


# Models data
filename = "data/subject_model_similarities_cka.parquet"
df = pd.read_parquet(filename)
# NOTE: hcp2 is loaded but not used in this cell.
hcp2 = pd.read_csv("data/hcp2.csv")

# max per (subject, model_name, roi, session, modality) group
df_m = (
    df.query("not excluded")
    .groupby(["subject", "model_name", "roi", "session", "modality"], observed=True)
    .aggregate({"score": "max"})
    .reset_index()
)
# median over sessions
df_m = (
    df_m.groupby(["roi", "modality", "subject", "model_name"], observed=True)
    .aggregate({"score": "median"})
    .reset_index()
)
# median over models
df_m = (
    df_m.groupby(["roi", "modality", "subject"], observed=True)
    .aggregate({"score": "median"})
    .reset_index()
)


df_language = df_m.query("modality == 'language'").rename(columns={"score": "language"})
df_vision = df_m.query("modality == 'vision'").rename(columns={"score": "vision"})
df_cross = df_paired.copy().rename(columns={"score": "cross_subject"})
# NOTE(review): both merges join on "roi" only, so every cross-subject row of an
# ROI is paired with every vision row and every language row of that ROI,
# whatever the subject; the model-side subject columns are then dropped.
# Confirm whether the join was meant to match subjects as well.
df_merge = df_cross.merge(df_vision, on=["roi"])
df_merge = df_merge.merge(df_language, on=["roi"])
df_merge = df_merge.drop(columns=["modality_y", "modality_x", "subject_x", "subject_y"])
# Prefix ROI names with "L " for roi <= 180 and "R " for roi > 180, and store
# the matching LH/RH tag in a `hemisphere` column.
df_merge.loc[df_merge.roi <= 180, "name"] = "L " + df_merge.loc[df_merge.roi <= 180, "name"]
df_merge.loc[df_merge.roi > 180, "name"] = "R " + df_merge.loc[df_merge.roi > 180, "name"]
df_merge["hemisphere"] = (df_merge.roi > 180).replace({True: "RH", False: "LH"})
df_merge = df_merge.rename(columns={"subject_i": "subject"})
# Final column selection and order.
columns = [
    "roi",
    "area",
    "name",
    "subject",
    "cross_subject",
    "vision",
    "language",
    "hemisphere",
    "area_id",
    "area_color",
    "roi_order",
]
df_merge = df_merge[columns]

In [None]:
def plot_all_rois(df, ax, order, hue_order, palette, y="score", title=None, legend=True, vmax=0.25):
    """Draw one boxplot of column ``y`` per ROI ``name`` on ``ax``, coloured by ``area``.

    Args:
        df: Table with ``name``, ``area`` and the ``y`` column.
        ax: Matplotlib axes to draw on.
        order: ROI names in x-axis order.
        hue_order: Area names in legend/colour order.
        palette: Colours matching ``hue_order``.
        y: Column plotted on the y-axis.
        title: Optional axes title.
        legend: Whether to draw the area legend.
        vmax: Upper y-axis limit; above 0.25 one extra horizontal guide is drawn.

    Note:
        Vertical separators are positioned with the module-level ``counts``
        mapping (area -> number of ROIs), not with an argument.
    """
    sns.boxplot(
        data=df,
        x="name",
        y=y,
        hue="area",
        order=order,
        hue_order=hue_order,
        palette=palette,
        legend=legend,
        ax=ax,
    )
    if legend:
        ax.legend(loc="upper right", ncol=2, fontsize=9)

    ax.set_ylabel("Similarity (CKA)")
    ax.set_xlabel("")
    y_axis = ax.yaxis
    y_axis.set_major_locator(ticker.MultipleLocator(0.05))
    y_axis.set_major_formatter(ticker.PercentFormatter(1, 0))
    ax.set_xticks(range(len(order)))
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha="center", fontsize=6)

    # Horizontal guides every 0.05 (zero line highlighted); one more when the
    # axis extends past 0.25.
    n_guides = 7 if vmax > 0.25 else 6
    for step in range(n_guides):
        guide_color = "gray" if step != 0 else "maroon"
        ax.axhline(step * 0.05, color=guide_color, linestyle="--", lw=0.5, zorder=-1)

    # Vertical separator after the last ROI of each area.
    boundary = 0
    for area in hue_order:
        boundary += counts[area]
        ax.axvline(boundary - 0.5, color="gray", ls="--", alpha=0.5, lw=0.3, zorder=-10)

    margin = 1
    ax.set_xlim(-margin, len(order) + margin)
    ax.set_ylim(-0.01, vmax)
    sns.despine(ax=ax)
    if title is not None:
        ax.set_title(title, fontsize=14)



# Restore default matplotlib rcParams
plt.rcParams.update(plt.rcParamsDefault)


# Figure title per score column.
titles = {
    "cross_subject": "Cross-subject",
    "vision": "Vision Models",
    "language": "Language Models",
}
for y in ["cross_subject", "vision", "language"]:
    for hemisphere in ["LH", "RH"]:
        df_hemisphere = df_merge.query("hemisphere==@hemisphere")
        # Ordering, colours and per-area ROI counts for this hemisphere;
        # `counts` is read as a global by plot_all_rois.
        order = df_hemisphere.sort_values(["area_id", "roi_order"])["name"].unique()
        hue_order = df_hemisphere.sort_values("area_id")["area"].unique()
        palette = list(df_hemisphere.sort_values("area_id")["area_color"].unique())
        counts = df_hemisphere.drop_duplicates("name")["area"].value_counts().to_dict()

        # The right-hemisphere cross-subject panel gets a taller axis (up to 0.3)
        # and a proportionally taller figure.
        taller = y == "cross_subject" and hemisphere == "RH"
        vmax, v = (0.3, 4 * (0.3 / 0.25)) if taller else (0.25, 4)

        fig, ax = plt.subplots(1, 1, figsize=(20, v), dpi=300)
        plot_all_rois(
            df_hemisphere,
            ax,
            order,
            hue_order,
            palette,
            y=y,
            title=titles[y] + f" ({hemisphere})",
            legend=True,
            vmax=vmax,
        )
        fig.savefig(f"{y}_{hemisphere}_rois_all.svg", bbox_inches="tight", transparent=True)
        plt.show()


In [None]:



def plot_joint(df_agg, x="cross_subject", y="vision", **kwargs):
    """Draw a joint scatter of two score columns and label the higher-scoring ROIs.

    Args:
        df_agg: Table holding the ``x`` and ``y`` columns and a ``name`` column.
        x: Column plotted on the x-axis of the joint axes.
        y: Column plotted on the y-axis; its title-cased name is used in the
            y-axis label.
        **kwargs: Extra keyword arguments for ``sns.jointplot``.

    Returns:
        The grid object returned by ``sns.jointplot``.
    """
    grid = sns.jointplot(data=df_agg, x=x, y=y, kind="scatter", legend=False, **kwargs)

    joint = grid.ax_joint
    joint.grid(axis="both", linestyle="--", alpha=0.5, zorder=-10)
    joint.set_xlim(-0.005, 0.205)
    joint.set_ylim(-0.005, 0.155)
    # Percent-formatted major ticks every 0.05 on both axes.
    for axis in (joint.xaxis, joint.yaxis):
        axis.set_major_locator(ticker.MultipleLocator(0.05))
        axis.set_major_formatter(ticker.PercentFormatter(1, 0))
    joint.set_xlabel("Cross-subject similarity (CKA)")
    joint.set_ylabel(f"{y.title()} models similarity (CKA)")

    # Label a point when either score exceeds 0.06 or both exceed 0.04.
    for _, row in df_agg.iterrows():
        x_val, y_val = row[x], row[y]
        both_moderate = x_val > 0.04 and y_val > 0.04
        either_high = x_val > 0.06 or y_val > 0.06
        if both_moderate or either_high:
            joint.text(x_val, y_val, row["name"], fontsize=8, ha="center", va="center")

    return grid
    

# Mean of each score per ROI, keeping the ROI metadata columns.
df_merge_g = (
    df_merge.groupby(["roi", "name", "hemisphere", "area", "area_id", "area_color", "roi_order"])
    .agg({"cross_subject": "mean", "vision": "mean", "language": "mean"})
    .reset_index()
)

import numpy as np

palette = list(df_merge_g.sort_values("area_id")["area_color"].unique())
hue_order = list(df_merge_g.sort_values("area_id")["area"].unique())

for y in ["vision", "language"]:
    for hemisphere in ["RH", "LH"]:
        df_hemi = df_merge_g.query("hemisphere==@hemisphere")

        # Version coloured by area.
        g = plot_joint(
            df_hemi,
            x="cross_subject",
            y=y,
            hue="area",
            hue_order=hue_order,
            palette=palette,
        )
        g.savefig(f"joint_{y}_{hemisphere}.svg", bbox_inches="tight", transparent=True)

        # Uncoloured version with grey marginal histograms on a fixed 0-150 count range.
        g = plot_joint(
            df_hemi,
            x="cross_subject",
            y=y,
            marginal_kws=dict(bins=np.linspace(0, 0.2, 30), color="gray"),
        )
        g.ax_marg_x.set_ylim(0, 150)
        g.ax_marg_y.set_xlim(0, 150)
        g.savefig(f"joint_{y}_{hemisphere}_marginal.svg", bbox_inches="tight", transparent=True)
plt.show()