In [None]:
import boto3
import AFQ.utils.streamlines as aus
from AFQ.viz.plot import BrainAxes
import os.path as op
import pandas as pd
import AFQ.viz.utils as vut 
import numpy as np
import configparser


In [None]:
subject = 550439

# Read the AWS keys stored under the "hcp" profile of the standard AWS
# credentials file. The `with` block closes the file handle once parsed.
CP = configparser.ConfigParser()
with open(op.join(op.expanduser('~'), '.aws', 'credentials')) as cred_file:
    CP.read_file(cred_file)
aws_access_key = CP.get('hcp', 'AWS_ACCESS_KEY_ID')
aws_secret_key = CP.get('hcp', 'AWS_SECRET_ACCESS_KEY')
client = boto3.client(
    's3',
    aws_access_key_id=aws_access_key,
    aws_secret_access_key=aws_secret_key)

# Local filenames (also used as the object names inside the AFQ derivatives
# prefix on S3, except for the T1, which lives in the HCP bucket).
clean_fname = (
    f"sub-{subject}_dwi_space-RASMM_model-CSD"
    "_desc-prob-afq-clean_tractography")
prof_fname = f"sub-{subject}_dwi_space-RASMM_model-CSD_desc-prob-afq_profiles.csv"
t1_fname = f"sub-{subject}_T1w_acpc_dc_restore.nii.gz"
apm_fname = f"sub-{subject}_dwi_model-CSD_APM.nii.gz"

afq_prefix = f"rokem/hcp1200/afq/sub-{subject}/ses-01"


def _download_if_missing(bucket, key, local_fname):
    """Download s3://<bucket>/<key> to `local_fname` unless it already exists.

    Skipping existing files keeps re-runs of this cell cheap.
    """
    if not op.exists(local_fname):
        client.download_file(bucket, key, local_fname)


# pyAFQ outputs: tractography (TRX + JSON sidecar), tract profiles, APM map.
for afq_fname in (clean_fname + ".trx", clean_fname + ".json",
                  prof_fname, apm_fname):
    _download_if_missing("open-neurodata", f"{afq_prefix}/{afq_fname}", afq_fname)

# Anatomical T1 from the HCP open-access bucket.
_download_if_missing(
    "hcp-openaccess",
    f"HCP_1200/{subject}/T1w/T1w_acpc_dc_restore.nii.gz",
    t1_fname)

In [None]:
# ba = BrainAxes()
# profs = pd.read_csv(prof_fname)
# prof_tract_ids = profs.tractID.unique()
# for b_name in prof_tract_ids:
#     this_bundle_profs = profs[np.logical_and(np.logical_and(
#         profs.tractID == b_name,
#         profs.nodeID >= 20),
#         profs.nodeID < 80)]
#     if b_name == "AntFrontal":
#         b_name = "FA"
#     if b_name == "Occipital":
#         b_name = "FP"
#     b_color = vut.COLOR_DICT.get(b_name, None)
#     plot_kwargs = {"linewidth": 6, "color": b_color}
#     if b_color:
#         ba.get_axis(b_name).plot(this_bundle_profs.nodeID, this_bundle_profs.dki_fa, '-', **plot_kwargs)
# #         ba.get_axis(b_name).plot(this_bundle_profs.nodeID, this_bundle_profs.dki_md*1000, '--', **plot_kwargs)
# #         ba.get_axis(b_name).plot(this_bundle_profs.nodeID, this_bundle_profs.dki_mk, '-.', **plot_kwargs)
# #         ba.get_axis(b_name).plot(this_bundle_profs.nodeID, this_bundle_profs.dki_awf, ':', **plot_kwargs)
# _ = ba.format(disable_y=False, disable_x=False)
# ba.fig.savefig("fig1.png", dpi = 300)

# from PIL import Image
# fig1 = Image.open("fig1.png")
# head_ref = Image.open("head_ref.png")

# # scale head_ref
# scale_factor = 3
# width, height = head_ref.size
# new_width = width * 3
# new_height = height * 3
# head_ref = head_ref.resize((new_width, new_height))

# # Calculate the position to paste head_ref on fig1, so it's centered
# x = (fig1.width - head_ref.width) // 2
# y = (fig1.height - head_ref.height) // 2
# # Paste head_ref onto fig1
# fig1.paste(head_ref, (x, y), head_ref) # using head_ref as the mask to handle possible transparency

# # Save the result or display
# fig1.save("fig1.png")
# # fig1.show()  # This will display the combined image


In [None]:
from AFQ.viz.fury_backend import visualize_volume, visualize_bundles
from dipy.viz import window
from AFQ.viz.utils import BEST_BUNDLE_ORIENTATIONS, trim, get_eye
import numpy as np
import nibabel as nib
from dipy.align import resample

In [None]:
# Load the cleaned, bundle-segmented tractography downloaded above.
seg_sft = aus.SegmentedSFT.fromfile(clean_fname + ".trx")

# Panel name -> (bundles to draw, view, direction); `view`/`direction` are
# passed to get_eye to orient the camera. Each panel is saved as
# "<panel name>.png".
callosal_bundles = [
    "AntFrontal", "Motor", "Occipital", "Orbital",
    "PostParietal", "SupFrontal", "SupParietal",
    "Temporal"]
figure_b_names = {
    "CallosalAxial": (callosal_bundles, "Axial", "Top"),
    "CallosalSagittal": (callosal_bundles, "Sagittal", "Left"),
    "Axial": (["SLF_R", "IFO_L", "ARC_R"], "Axial", "Top"),
    "Sagittal": (["ILF_L", "CST_L", "CGC_L", "UNC_L"], "Sagittal", "Left"),
}

# Resample the T1 onto the voxel grid of the APM image to use as background.
t1_resampled = resample(
        t1_fname,
        apm_fname).get_fdata()

# Spatial shape of the APM grid. This is loop-invariant, so read it once from
# the image header instead of loading the full volume on every panel.
data_shape = np.asarray(nib.load(apm_fname).shape[:3])
volume_center = data_shape // 2

for panel_name, (bundle_names, view, direc) in figure_b_names.items():
    figure = visualize_volume(
        t1_resampled,
        flip_axes=(True, False, False),
        interact=False,
        inline=False,
        opacity=1.0)
    for b_name in bundle_names:
        figure = visualize_bundles(
            seg_sft,
            flip_axes=(True, False, False),
            line_width=2.0,
            n_points=100,
            bundle=b_name,
            figure=figure,
            interact=False,
            inline=False)

    eye = get_eye(view, direc)
    # Top-down panels use +y as the screen "up" direction; the others use +z.
    view_up = (0, 1, 0) if direc == "Top" else (0, 0, 1)

    # assumes get_eye returns a mapping whose values are the x/y/z components
    # of the viewing direction — TODO confirm
    eye_direction = np.fromiter(eye.values(), dtype=int)
    figure.set_camera(
        position=eye_direction * data_shape * 2 + volume_center,
        focal_point=volume_center,
        view_up=view_up)
    figure.zoom(0.5)
    window.snapshot(figure, fname=f"{panel_name}.png", size=(2400, 2400))

In [None]:
from AFQ.viz.utils import PanelFigure

# Combine the four panel snapshots written by the rendering loop into one
# composite image. Each entry is (snapshot file, grid position arguments as
# passed to PanelFigure.add_img).
panel_layout = (
    ("CallosalAxial.png", 0, 0),
    ("CallosalSagittal.png", 0, 1),
    ("Axial.png", 1, 0),
    ("Sagittal.png", 1, 1),
)

pf = PanelFigure(2, 2, 8, 8)
for snapshot_fname, grid_i, grid_j in panel_layout:
    pf.add_img(snapshot_fname, grid_i, grid_j)
pf.format_and_save_figure("fig1.png")

In [None]:
# Debugging aid: `figure` is whatever the last iteration of the rendering loop
# left behind (the "Sagittal" panel, the last key of figure_b_names), so this
# inspects that panel's camera settings. It only works after that cell has run
# in the same kernel; consider removing it from the finished notebook.
figure.camera_info()