In [2]:
import pandas as pd
import altair as alt
from AFQ.viz.utils import FORMAL_BUNDLE_NAMES, COLOR_DICT
from AFQ.viz.altair import altair_color_dict

In [3]:
# Lift Altair's limit on how many rows may be embedded in a chart spec, so the
# full per-phenotype summary tables loaded below can be plotted without error.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [5]:
# Phenotypes drawn together in one saved figure, keyed by the suffix used in
# the output filename; each phenotype becomes one row of that figure.
figure_groups = {
    "part1": ['CogCrystalComp_Unadj', 'CogFluidComp_Unadj', 'CogTotalComp_Unadj', 'ReadEng_Unadj', 'VSPLOT_TC'],
    "part2": ['Age_in_Yrs', 'DDisc_AUC_200', 'Endurance_Unadj', 'IWRD_TOT', 'SCPT_SEN'],
}

# Human-readable y-axis titles for each phenotype column name. Every phenotype
# listed in `figure_groups` must have an entry here.
pheno_formal_labels = {
    'Age_in_Yrs': 'Age',
    'CogCrystalComp_Unadj': 'Crystallized Intelligence',
    'CogFluidComp_Unadj': 'Fluid Intelligence',
    'CogTotalComp_Unadj': 'Global Intelligence',
    'DDisc_AUC_200': 'Impulsivity',
    'Endurance_Unadj': 'Endurance',
    'IWRD_TOT': 'Verbal Memory',
    'ReadEng_Unadj': 'Reading Ability',
    'SCPT_SEN': 'Attention',
    'VSPLOT_TC': 'Spatial Orientation',
}

# Base font size for headers and axis/legend titles; tick and legend labels
# are drawn at `font_size - 10`.
font_size = 40

# Build and save one stacked figure per phenotype group. Each row shows, for a
# single phenotype, the `value` line plus a shaded `lower`-`upper` band along
# every bundle's positions, with one facet column per model.
for figure_group_name, figure_group_ptypes in figure_groups.items():
    layers = []
    # Only the bottom row draws an x axis; derive its index from the group size
    # rather than hard-coding it, so groups of any length are handled.
    last_row = len(figure_group_ptypes) - 1
    for ii, phenotype in enumerate(figure_group_ptypes):
        combined_dataframe = pd.read_csv(f'summary_df_{phenotype}.csv')

        # Split "<tractID>_<position>" on the last underscore only, so any
        # underscores inside the tract ID itself are preserved.
        rowname_parts = combined_dataframe["rowname"].str.rsplit(pat="_", n=1, expand=True)
        combined_dataframe["tractID"] = rowname_parts[0]
        # Store positions as numbers so the quantitative x encoding (and the
        # line's ordering along x) does not rely on implicit string parsing.
        combined_dataframe["Position"] = pd.to_numeric(rowname_parts[1])

        # Bundle colors keyed by display name (IDs with no formal name are kept
        # as-is), matching the "Bundle" column built below.
        this_cd = altair_color_dict(combined_dataframe["tractID"].unique())
        this_cd = {FORMAL_BUNDLE_NAMES.get(key, key): value for key, value in this_cd.items()}

        combined_dataframe["Model"] = combined_dataframe["variable"].str.upper()
        combined_dataframe["Bundle"] = combined_dataframe.tractID.replace(FORMAL_BUNDLE_NAMES)

        y_encoding = alt.Y(
            'value:Q',
            axis=alt.Axis(title=pheno_formal_labels[phenotype]))
        # Pin the color scale to the colors returned by altair_color_dict.
        color_encoding = alt.Color(
            "Bundle",
            scale=alt.Scale(
                domain=list(this_cd.keys()),
                range=list(this_cd.values())))
        # axis=None hides the x axis on every row except the bottom one.
        x_encoding_axis = alt.Axis() if ii == last_row else None
        x_encoding = alt.X('Position:Q', axis=x_encoding_axis)

        # `value` line layered with a translucent band spanning lower..upper.
        this_chart = alt.Chart(combined_dataframe).mark_line().encode(
            y=y_encoding,
            color=color_encoding,
            x=x_encoding)
        this_chart = this_chart + alt.Chart(combined_dataframe).mark_area(opacity=0.2).encode(
            color=color_encoding,
            x=x_encoding,
            y='lower:Q',
            y2='upper:Q'
        )

        # Only the top row shows the model names as column headers; lower rows
        # blank out their header labels and titles.
        if ii == 0:
            column_header_encoding = alt.Header(
                labelFontSize=font_size,
                titleFontSize=font_size)
        else:
            column_header_encoding = alt.Header(
                labelExpr="''",
                title=None)

        this_chart = this_chart.facet(
            column=alt.Column(
                "Model",
                header=column_header_encoding))
        layers.append(this_chart)

    # Stack the phenotype rows vertically and apply shared font/legend styling.
    this_chart = alt.VConcatChart(vconcat=layers).configure_axis(
        labelFontSize=font_size - 10,
        titleFontSize=font_size,
        labelLimit=0
    ).configure_title(
        fontSize=font_size
    ).configure_legend(
        labelFontSize=font_size - 10,
        titleFontSize=font_size,
        titleLimit=0,
        labelLimit=0,
        orient="right",
        columns=1
    )

    this_chart.save(f'ModelWeightComparison_{figure_group_name}.png', ppi=300)
# Display the last group's figure as the cell output.
this_chart