In [None]:
import pandas as pd
import altair as alt
from AFQ.viz.utils import FORMAL_BUNDLE_NAMES, COLOR_DICT

In [None]:
# Load the per-phenotype model-reliability tables and stack them into one
# long-format frame. Each CSV is tagged with its phenotype so the chart cell
# can facet on it.
PHENOTYPES = ['endurance']

phenotype_frames = []
for phenotype in PHENOTYPES:
    this_csv = pd.read_csv(f'model_reliability_{phenotype}.csv')
    this_csv["Phenotype"] = phenotype
    phenotype_frames.append(this_csv)
# Concatenate once instead of growing a frame inside the loop, and give the
# result a fresh 0..N-1 index so rows from different CSVs don't share labels.
combined_dataframe = pd.concat(phenotype_frames, ignore_index=True)

# Colour-domain order: follow COLOR_DICT's key order, keep only bundles that
# appear in the data, and translate them to their formal display names.
present_tracts = set(combined_dataframe.tractID.unique())
tract_ordering = [
    FORMAL_BUNDLE_NAMES.get(tractID, tractID)
    for tractID in COLOR_DICT.keys()
    if tractID in present_tracts
]

# Display-oriented columns consumed by the chart encodings.
combined_dataframe["Position"] = combined_dataframe["nodeID"]
combined_dataframe["Model"] = combined_dataframe["variable"].str.upper()
combined_dataframe["Bundle"] = combined_dataframe.tractID.replace(FORMAL_BUNDLE_NAMES)

combined_dataframe.head()

In [None]:
# Model weights along each bundle: one line per bundle plus a shaded
# lower/upper band, faceted by phenotype (rows) and model (columns).
# The y-axis title is deliberately phenotype-agnostic: the row facet header
# already names the phenotype, and the title must not depend on a loop
# variable left over from the loading cell.
Y_DOMAIN = [-1.2, 1.2]  # same fixed y-range for every panel

y_encoding = alt.Y(
    'value:Q',
    scale=alt.Scale(domain=Y_DOMAIN),
    axis=alt.Axis(title='Model Weight'))
color_encoding = alt.Color("Bundle", scale=alt.Scale(domain=tract_ordering))
x_encoding = 'Position:Q'

# Both layers are built from the same DataFrame object, so the layered chart
# can be faceted over that shared data.
base_chart = alt.Chart(combined_dataframe)
line_chart = base_chart.mark_line().encode(
    x=x_encoding,
    y=y_encoding,
    color=color_encoding,
)
# Band spanning the 'lower' and 'upper' columns, coloured like its line.
ribbon_chart = base_chart.mark_area(opacity=0.2).encode(
    x=x_encoding,
    y='lower:Q',
    y2='upper:Q',
    color=color_encoding,
)

this_chart = (line_chart + ribbon_chart).facet(
    row=alt.Row("Phenotype"),
    column=alt.Column("Model"),
)

this_chart.save('ModelWeightComparison.png', ppi=300)
this_chart

In [None]:
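# NOTE(review): this cell is entirely commented-out code and appears to be an
# earlier per-tract layout superseded by the faceted chart in the previous
# cell. Delete it, or describe the approach in a markdown cell, so the
# notebook reads cleanly under Restart & Run All.
# for phenotype in ['endurance']: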
#     plot_obj = []
    
#     summary_df = pd.read_csv(f'model_reliability_{phenotype}.csv')
#     print(summary_df)
    
#     tract_ordering = []
#     for tractID in COLOR_DICT.keys():
#         if tractID in summary_df.tractID.unique():
#             tract_ordering.append(FORMAL_BUNDLE_NAMES.get(tractID, tractID))
    
#     summary_df.tractID = summary_df.tractID.replace(FORMAL_BUNDLE_NAMES)
#     summary_df["Position"] = summary_df["nodeID"]
    
#     for tract in tract_ordering:
#         # Filter the dataframe for the current tract
#         df_tract = summary_df[summary_df['tractID'] == tract]

#         y_encoding = alt.Y(
#             'value:Q',
#             scale=alt.Scale(domain=[-1.2, 1.2]),
#             axis=alt.Axis(title=f'Model Weight, {phenotype}'))
        
#         # Line plot for 'sgl'
#         line_sgl = alt.Chart(df_tract[df_tract['variable'] == 'sgl']).mark_line().encode(
#             x='Position:Q',
#             y=y_encoding
#         )
    
#         # Ribbon plot for 'sgl'
#         ribbon_sgl = alt.Chart(df_tract[df_tract['variable'] == 'sgl']).mark_area(
#             color='red',
#             opacity=0.5
#         ).encode(
#             x='Position:Q',
#             y='lower:Q',
#             y2='upper:Q'
#         )
    
#         # Line plot for 'lasso'
#         line_lasso = alt.Chart(df_tract[df_tract['variable'] == 'lasso']).mark_line().encode(
#             x='Position:Q',
#             y=y_encoding
#         )
    
#         # Ribbon plot for 'lasso'
#         ribbon_lasso = alt.Chart(df_tract[df_tract['variable'] == 'lasso']).mark_area(
#             opacity=0.5
#         ).encode(
#             x='Position:Q',
#             y='lower:Q',
#             y2='upper:Q'
#         )
    
#         # Combine plots
#         plot = (line_sgl + ribbon_sgl + line_lasso + ribbon_lasso).properties(
#             title=tract
#         )
    
#         plot_obj.append(plot)
    
#     # plot_obj[0]
#     # Combine all plots into a single chart
#     combined_plot = alt.vconcat(*[alt.hconcat(*plot_obj[i:i+4]) for i in range(0, len(plot_obj), 4)])
#     combined_plot = combined_plot.configure_axis(
#         labelFontSize=20,
#         titleFontSize=20,
#         labelLimit=0
#     ).configure_title(
#         fontSize=20
#     ).resolve_scale(y='shared')
    
#     # Save the plot
#     combined_plot.save(f'ConnectomicsComparison_{phenotype}.png', ppi=300)
# combined_plot
