In [None]:
import pandas as pd
import altair as alt
import numpy as np
import os.path as op
import boto3
from tqdm import tqdm

In [None]:
# Build combined.csv once; later runs reuse it. The cell stacks the
# per-subject benchmark tables in csvs/<subject>.csv and annotates every row
# with the size (S3 ContentLength, in bytes) of that subject's .trk and .trx
# tractography files.
S3_BUCKET = "open-neurodata"
S3_AFQ_PREFIX = "rokem/hcp1200/afq"


def _s3_size_first_match(s3_client, keys):
    """Return the ContentLength of the first key in `keys` found in S3_BUCKET.

    Only the client's ``ClientError`` (e.g. object not found / access denied)
    falls through to the next candidate key; any other failure (credentials,
    connectivity) propagates immediately. If no candidate exists, the error
    from the last one is re-raised.
    """
    for i, key in enumerate(keys):
        try:
            return s3_client.head_object(Bucket=S3_BUCKET, Key=key)["ContentLength"]
        except s3_client.exceptions.ClientError:
            if i == len(keys) - 1:
                raise


if not op.exists("combined.csv"):
    # ndmin=1: a file holding a single subject ID would otherwise load as a
    # 0-d array, which cannot be iterated.
    hcp_sub_list = np.loadtxt("./hcp_subs.txt", dtype=int, ndmin=1)
    s3 = boto3.client('s3')
    per_subject = []  # one frame per subject; concatenated once after the loop

    for subject_id in tqdm(hcp_sub_list):
        # The tractography is stored with either a lowercase "afq" or an
        # uppercase "AFQ" desc tag; the lowercase spelling is tried first.
        trk_stem = f"sub-{subject_id}/ses-01/sub-{subject_id}_dwi_space-RASMM_model-CSD"
        fname_candidates = [
            f"{trk_stem}_desc-prob-afq-clean_tractography",
            f"{trk_stem}_desc-prob-AFQ-clean_tractography",
        ]
        subject_csv = f"csvs/{subject_id}.csv"
        if not op.exists(subject_csv):
            # No benchmark output for this subject: report its ID and skip it.
            print(subject_id)
            continue
        this_csv = pd.read_csv(subject_csv)
        this_csv["SubjectID"] = subject_id
        # Adds the "TRK size" and "TRX size" columns (same value on every row).
        for ext in ("trk", "trx"):
            this_csv[f"{ext.upper()} size"] = _s3_size_first_match(
                s3, [f"{S3_AFQ_PREFIX}/{fname}.{ext}" for fname in fname_candidates])
        per_subject.append(this_csv)

    if not per_subject:
        # An empty combined.csv would be cached, and every later run would
        # skip this rebuild, so fail loudly instead.
        raise RuntimeError("No per-subject CSVs found in csvs/; combined.csv not written.")
    results = pd.concat(per_subject)
    results.to_csv("combined.csv", index=False)

In [None]:
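# Optional long-form variant: placed inside the per-subject loop above, this
# builds one row per (bundle, file type) instead of one row per bundle.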
#     for ftype in ["TRK", "TRX"]:
#         sub_results = pd.DataFrame()
#         sub_results["Bundle Name"] = this_csv["Bundle Name"]
#         sub_results["Avg. Absolute Err. (um)"] = this_csv["Avg. Absolute Err. (um)"]
#         sub_results["Profile error (%)"] = this_csv["Profile error (%)"]
    
#         sub_results["Peak Memory usage (MiB)"] = this_csv[f"Peak Memory usage {ftype} (MiB)"]
#         sub_results["Time (s)"] = this_csv[f"Time {ftype} (s)"]
#         sub_results["mean FA"] = this_csv[f"{ftype} mean FA"]
        
#         sub_results["SubjectID"] = subject_id
#         sub_results["File Type"] = ftype
    
#         results = pd.concat([results, sub_results])

In [None]:
# Lift Altair's default cap on how many rows a chart may embed; `results`
# presumably holds one row per (subject, bundle), which can be large.
alt.data_transformers.disable_max_rows()

In [None]:
# Reload the cached combined table and prepare what the plots below need.
results = pd.read_csv("combined.csv")
# Sizes come from S3 in bytes; add megabyte (10**6 bytes) columns.
results = results.assign(**{
    "TRK size (MB)": results["TRK size"] / 1e6,
    "TRX size (MB)": results["TRX size"] / 1e6,
})

# Chart title -> (TRK column, TRX column, ratio-axis min, ratio-axis max).
chart_info = {
    "Time": ('Time TRK (s)', 'Time TRX (s)', 0.0, 2.0),
    "Memory": ('Peak Memory usage TRK (MiB)', 'Peak Memory usage TRX (MiB)', 0.9, 1.1),
    "FA": ('TRK mean FA', 'TRX mean FA', 0.999, 1.001),
}

from AFQ.viz.altair import altair_color_dict
from AFQ.viz.utils import FORMAL_BUNDLE_NAMES
# Colors are assigned from the raw bundle names, then re-keyed by the formal
# display names (names without a formal entry are kept as-is) so they match
# the renamed "Bundle Name" column.
this_cd = {
    FORMAL_BUNDLE_NAMES.get(raw_name, raw_name): color
    for raw_name, color in altair_color_dict(results["Bundle Name"].unique()).items()
}
results["Bundle Name"] = results["Bundle Name"].replace(FORMAL_BUNDLE_NAMES)

In [None]:
from altair_saver import save
from PIL import Image
font_size=30

# The color legend is rendered once, from this chart. It must match a
# chart_info key exactly (the comparison is case-sensitive).
LEGEND_CHART = "Time"

color_order = list(this_cd.keys())
c_order_kwargs = dict(
    scale=alt.Scale(domain=color_order))
# For each metric in chart_info, draw per-bundle box plots of the TRX/TRK
# ratio with a red reference line at ratio 1 and save them as <chart_name>.png.
# The LEGEND_CHART metric is also rendered with a color legend; that image is
# cropped and saved as legend.png for the composite figure.
for chart_name, chart_details in chart_info.items():
    trk_col, trx_col, x_min, x_max = chart_details
    # Note: this column is written into `results` itself, so after the loop
    # it holds the ratio for the last chart_info entry.
    results["TRX/TRK"] = results[trx_col] / results[trk_col]
    bounds_ls = [x_min, x_max]
    for legend_desc in [alt.Legend(columns=3), None]:
        if legend_desc is not None and chart_name != LEGEND_CHART:
            continue
        chart = alt.Chart(results, title=chart_name).mark_boxplot(
            size=10, extent=1.5, outliers=dict(size=30)).encode(
            y=alt.Y(
                "Bundle Name",
                **c_order_kwargs,
                axis=alt.Axis(labels=False),
                title=""),
            x=alt.X("TRX/TRK:Q", title="TRX/TRK Ratio", scale=alt.Scale(domain=bounds_ls)),
            color=alt.Color(
                'Bundle Name',
                scale=alt.Scale(
                    domain=color_order,
                    range=list(this_cd.values())),
                legend=legend_desc))
        # Vertical red line at ratio 1 (TRX == TRK), from the first to the
        # last bundle on the y axis.
        xy_line_data = pd.DataFrame({"TRX/TRK": (1, 1), "Bundle Name": (color_order[0], color_order[-1])})
        chart = alt.Chart(xy_line_data).mark_line(color="red").encode(
            x=alt.X("TRX/TRK", title="TRX/TRK Ratio", scale=alt.Scale(domain=bounds_ls)),
            y=alt.Y("Bundle Name", **c_order_kwargs, title="")) + chart
        chart = chart.configure_axis(
            labelFontSize=font_size,
            titleFontSize=font_size,
            labelLimit=0
        ).configure_legend(
            labelFontSize=font_size/2,
            titleFontSize=font_size/2,
            titleLimit=0,
            labelLimit=0,
        ).configure_title(
            fontSize=font_size)
        if legend_desc is not None:
            chart.save("legend.png", ppi=300)
            # Keep only the right-hand 63% of the render; presumably this is
            # where the legend sits (Altair places legends on the right by
            # default). The 0.37 split is tuned to this layout.
            # The context manager closes the source file before the cropped
            # image is written back to the same path.
            with Image.open("legend.png") as full_img:
                width, height = full_img.size
                leg_img = full_img.crop((0.37 * width, 0, width, height))
            display(leg_img)
            leg_img.save("legend.png")
        else:
            chart.display(dpi=300)
            chart.save(f"{chart_name}.png", ppi=300)

In [None]:
# Rich (HTML) table display of the combined results, which pandas truncates
# for large frames, instead of a plain-text print.
display(results)

In [None]:
# Per-subject file-size comparison: TRX/TRK size ratio against TRK size (MB),
# with reference lines at ratio 1 (red) and 0.5 (blue); saved as SZ.png.
# The size columns hold the same value on every row of a subject, so selecting
# a single bundle ("Left Corticospinal") gives one point per subject.
SZ_X_DOMAIN = [20, 260]   # TRK size (MB) axis range
SZ_Y_DOMAIN = [0.2, 1.8]  # TRX/TRK size ratio axis range

# .copy(): the new column goes into an independent frame rather than a slice
# of `results`, which avoids pandas' SettingWithCopyWarning.
sz_results = results[results["Bundle Name"] == "Left Corticospinal"].copy()
display(sz_results)
sz_results["TRX/TRK size"] = sz_results['TRX size'] / sz_results['TRK size']


def _sz_reference_line(ratio, color):
    """Horizontal line at TRX/TRK size == `ratio` spanning SZ_X_DOMAIN."""
    line_data = pd.DataFrame({'TRK size (MB)': SZ_X_DOMAIN, 'TRX/TRK size': [ratio, ratio]})
    return alt.Chart(line_data).mark_line(color=color).encode(
        x=alt.X('TRK size (MB)', scale=alt.Scale(domain=SZ_X_DOMAIN)),
        y=alt.Y('TRX/TRK size', scale=alt.Scale(domain=SZ_Y_DOMAIN)))


SZ_chart = alt.Chart(sz_results).mark_circle(size=60, color="green").encode(
    x=alt.X('TRK size (MB)', scale=alt.Scale(domain=SZ_X_DOMAIN)),
    y=alt.Y('TRX/TRK size', scale=alt.Scale(domain=SZ_Y_DOMAIN)),
    detail='SubjectID',
)
# Layer order, bottom to top: blue line, red line, points.
SZ_chart = _sz_reference_line(1, "red") + SZ_chart
SZ_chart = _sz_reference_line(0.5, "blue") + SZ_chart

SZ_chart = SZ_chart.configure_axis(
    labelFontSize=font_size,
    titleFontSize=font_size,
    labelLimit=0
).configure_legend(
    labelFontSize=font_size,
    titleFontSize=font_size,
    titleLimit=0,
    labelLimit=0,
).configure_title(
    fontSize=font_size)

SZ_chart.save("SZ.png", ppi=300)

In [None]:
# Mean of each accuracy metric over all rows of `results`, one value per line.
for err_col in ("Avg. Absolute Err. (um)", "Profile error (%)"):
    print(np.mean(results[err_col]))

In [None]:
from AFQ.viz.utils import PanelFigure

# Compose the saved chart images into fig2.png: four labelled panels plus the
# shared legend image, which spans positions slice(0, 2) at index 2 and gets
# no panel label.
pf = PanelFigure(3, 2, 6, 9, panel_label_kwargs={"color": "black"})
for img_path, pos_a, pos_b, label_pos in (
        ("Time.png", 0, 0, (0.0, 1.0)),
        ("Memory.png", 0, 1, (0.0, 1.0)),
        ("FA.png", 1, 0, (0.0, 1.0)),
        ("SZ.png", 1, 1, (0.2, 1.0)),
):
    pf.add_img(img_path, pos_a, pos_b, subplot_label_pos=label_pos)
pf.add_img("legend.png", slice(0, 2, None), 2, add_panel_label=False)
pf.format_and_save_figure("fig2.png", trim_final=True)