In [None]:
import pandas as pd
import numpy as np
import altair as alt
from AFQ.viz.utils import COLOR_DICT
from AFQ.viz.altair import altair_color_dict

In [None]:
# Load the per-node heritability table. The cells below read its `tractID`,
# `nodeID` and `<scalar>_h2` / `<scalar>_lb` / `<scalar>_ub` columns.
profiles = pd.read_csv("hcp_heritability_profiles.csv")
def get_hemi(cc):
    """Translate a one-character hemisphere code into a hemisphere label.

    Parameters
    ----------
    cc : str
        The code to translate; in this notebook it is the last character of
        a tract ID (e.g. ``"L"`` from ``"ATR_L"``).

    Returns
    -------
    str
        ``"Left"`` for ``"L"``, ``"Right"`` for ``"R"``, and ``"Callosal"``
        for anything else.
    """
    for code, label in (("L", "Left"), ("R", "Right")):
        if cc == code:
            return label
    return "Callosal"

def get_bname(s):
    """Remove a leading ``"Left "`` or ``"Right "`` hemisphere prefix.

    Parameters
    ----------
    s : str
        A bundle name, possibly prefixed with its hemisphere.

    Returns
    -------
    str
        ``s`` without the prefix, or ``s`` unchanged when neither prefix is
        present.
    """
    for prefix in ("Left ", "Right "):
        if s.startswith(prefix):
            return s[len(prefix):]
    return s
# Hemisphere label derived from the final character of each tract ID
# ("L" -> Left, "R" -> Right, anything else -> Callosal).
profiles["Hemi"] = profiles["tractID"].apply(lambda tract: get_hemi(tract[-1]))

# Human-readable bundle name shared by the left ("_L") and right ("_R") copy
# of each tract. Tract IDs that are not listed (e.g. the callosal segments)
# are left unchanged by `replace`.
profiles["Bundle Name"] = profiles["tractID"].replace({
    f"{abbr}_{side}": formal_name
    for abbr, formal_name in {
        "ATR": "Anterior Thalamic",
        "CST": "Corticospinal",
        "CGC": "Cingulum Cingulate",
        "IFO": "Inferior Fronto-Occipital",
        "ILF": "Inferior Longitudinal",
        "SLF": "Superior Longitudinal",
        "UNC": "Uncinate",
        "ARC": "Arcuate",
        "VOF": "Vertical Occipital",
        "pARC": "Posterior Arcuate",
    }.items()
    for side in ("L", "R")
})
profiles

In [None]:
# Plot layout / styling parameters used by the figure-building loop.
position_domain = (20, 80)  # nodeID range kept for plotting, half-open [start, stop)
column_count = 1  # number of legend columns
font_size = 35  # font size for axis labels/titles, legend text and panel titles
line_size = 5  # stroke width of the main profile line (bounds use line_size - 2)
legend_line_size = 5  # scales legend symbol stroke width and symbol size

# Lift Altair's default row limit so the full profile table can be embedded.
alt.data_transformers.disable_max_rows()

# Keep only nodes with position_domain[0] <= nodeID < position_domain[1].
profiles = profiles[
    (profiles["nodeID"] >= position_domain[0])
    & (profiles["nodeID"] < position_domain[1])
]

# Unit suffix appended to each column title, keyed by the formal scalar name
# (e.g. "dki_md" -> "DKI MD").
tp_units = {
    "DKI AWF": "",
    "DKI FA": "",
    "DKI MD": " (µm²/ms)",
    "DKI MK": ""}

# Bundles plotted in each output figure: one figure per key, one row per
# bundle, in the listed order. Entries must match values of "Bundle Name".
bundle_org = {
    "STANDARD": [
        'Arcuate',
        'Anterior Thalamic',
        'Cingulum Cingulate',
        'Corticospinal',
        'Inferior Fronto-Occipital',
        'Inferior Longitudinal',
        'Superior Longitudinal',
        'Uncinate'],
    "CALLOSAL": [
        'Orbital', 'AntFrontal', 'SupFrontal', 'Motor', 'SupParietal',
        'PostParietal', 'Temporal', 'Occipital']}

# Fixed colours for the callosal segments, in the same order as the
# "CALLOSAL" bundle list they are paired with below.
callosal_colors = [
    "rgb(51, 33, 136)", "rgb(18, 120, 51)", "rgb(69, 172, 154)",
    "rgb(136, 205, 238)", "rgb(223, 205, 120)", "rgb(205, 102, 120)",
    "rgb(136, 33, 84)", "rgb(172, 69, 154)"]
# DKI scalars shown as panel columns (alphabetical order).
scalars = sorted(["dki_fa", "dki_md", "dki_mk", "dki_awf"])

# Build and save one figure per bundle category: one row per bundle and one
# column per DKI scalar. Each panel layers the heritability estimate
# (`<scalar>_h2`), its `<scalar>_lb` / `<scalar>_ub` bounds (presumably a
# confidence interval -- TODO confirm against the CSV producer) and a dashed
# red reference line at 0.
for bundle_category, bundle_list in bundle_org.items():
    if bundle_category == "STANDARD":
        # Left and right copies of a bundle share a row; colour by hemisphere.
        color_encoding = alt.Color("Hemi")
    else:
        # Each callosal segment gets its own fixed colour.
        color_encoding = alt.Color(
            "Bundle Name",
            scale=alt.Scale(domain=bundle_list, range=callosal_colors))

    # Only the bottom row of *this* figure shows x tick labels and title.
    # This must be compared with the number of rows in this figure, not the
    # number of distinct bundle names in the whole table.
    last_row = len(bundle_list) - 1
    row_charts = []
    for jj, b_name in enumerate(bundle_list):
        this_dataframe = profiles[profiles["Bundle Name"] == b_name]
        is_bottom_row = jj == last_row
        charts = []
        for ii, tp in enumerate(scalars):
            tp_formal = tp.replace("_", " ").upper()
            # Column header (scalar name + unit) only on the top row.
            title_name = tp_formal + tp_units[tp_formal] if jj == 0 else ""
            y_kwargs = dict(
                scale=alt.Scale(zero=False, domain=[-0.4, 1.2]),
                # Row label (bundle name) only in the left-most column.
                title=b_name if ii == 0 else "")
            x_kwargs = dict(axis=alt.Axis(
                title="Position (%)" if is_bottom_row else "",
                labels=is_bottom_row))

            # Main estimate: solid line; this layer carries the panel title.
            prof_chart = alt.Chart(
                this_dataframe, title=title_name).mark_line(
                    size=line_size).encode(
                color=color_encoding,
                y=alt.Y(f'{tp}_h2', **y_kwargs),
                x=alt.X('nodeID', **x_kwargs))
            # Lower and upper bounds: thinner, faded, dotted lines.
            for bound in ("lb", "ub"):
                prof_chart = prof_chart + alt.Chart(this_dataframe).mark_line(
                    size=line_size - 2, opacity=0.5, strokeDash=[1, 1]).encode(
                    color=color_encoding,
                    y=alt.Y(f'{tp}_{bound}', **y_kwargs),
                    x=alt.X('nodeID', **x_kwargs))
            # Dashed red reference line at y = 0.
            prof_chart = prof_chart + alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
                stroke='red',
                strokeWidth=2,
                strokeDash=[5, 5]
            ).encode(y='y:Q')
            charts.append(prof_chart)
        row_charts.append(alt.HConcatChart(hconcat=charts))
    chart = alt.VConcatChart(vconcat=row_charts).configure_axis(
        labelFontSize=font_size,
        titleFontSize=font_size,
        labelLimit=0
    ).configure_legend(
        labelFontSize=font_size,
        titleFontSize=font_size,
        titleLimit=0,
        labelLimit=0,
        columns=column_count,
        symbolStrokeWidth=legend_line_size * 10,
        symbolSize=legend_line_size * 100,
        orient='right'
    ).configure_title(
        fontSize=font_size
    )
    chart.save(f"Heritability_{bundle_category}.png", dpi=300)

# Display the last figure built (the CALLOSAL one).
chart