# In [4]:
import os
import nibabel as nib
import numpy as np
import time
from copy import deepcopy
from utils import get_fixel_data, load_image, time_to_string
from fixel import fixel_comparison
from fod import fod_comparison

def print_progress(prefix='', step=0, n_steps=1, t_init=None, bar_width=25):
    """Print a single-line progress bar that overwrites itself.

    Both prints end with a carriage return, so repeated calls inside a loop
    redraw the bar on the same terminal line.

    :param prefix: Text printed before the bar.
    :param step: 0-based index of the step that has just finished; the
        counter, percentage, bar and ETA all report ``step + 1`` completed.
    :param n_steps: Total number of steps. Values below 1 are treated as 1
        so that the percentage never divides by zero.
    :param t_init: Optional start timestamp (``time.time()`` value). When
        given, the elapsed time and an ETA are appended via
        ``time_to_string``.
    :param bar_width: Number of character cells in the bar (default 25).
    """
    n_steps = max(n_steps, 1)
    done = step + 1  # number of completed steps
    if t_init is not None:
        t_out = time.time() - t_init
        # Linear extrapolation from the average time per completed step.
        t_eta = (t_out / done) * (n_steps - done)
        time_s = '<{:} - ETA: {:}>'.format(
            time_to_string(t_out), time_to_string(t_eta)
        )
    else:
        time_s = ''
    # Clamp so an overshooting step never draws past the end of the bar.
    filled = min(max(bar_width * done // n_steps, 0), bar_width)
    progress_s = '█' * filled
    remainder_s = ' ' * (bar_width - filled)
    # Blank out any longer text left on the line by a previous print.
    print(' '.join([' '] * 300), end='\r')
    print(
        '\033[K{:} [{:}{:}] {:06d}/{:06d} - {:05.2f}% {:}'.format(
            prefix, progress_s, remainder_s,
            done, n_steps, 100 * done / n_steps,
            time_s
        ),
        end='\r'
    )


# Name of the method directory treated as ground truth; every other method
# directory is compared against it.
gt = '64dir'
main_path = '/media/transcend/Data/Fudan/'

# Each sub-directory of main_path is one subject. Plain files are skipped so
# that stray notes or summaries in main_path are not treated as subjects.
subjects = sorted(
    s for s in os.listdir(main_path)
    if os.path.isdir(os.path.join(main_path, s))
)
if not subjects:
    raise FileNotFoundError(
        'No subject directories found in {:}'.format(main_path)
    )

# Method names come from the first subject's sub-directories only; this
# assumes every subject was processed with the same set of methods.
methods = sorted(
    f for f in os.listdir(os.path.join(main_path, subjects[0]))
    if os.path.isdir(os.path.join(main_path, subjects[0], f))
)
if gt not in methods:
    # Fail early: otherwise every subject gets an empty 'gt' entry and the
    # comparison loop breaks later with a KeyError.
    raise ValueError(
        'Ground-truth method {!r} not found among {}'.format(gt, methods)
    )

# data_dict maps subject -> {'brain', 'wm', 'gt', 'methods'}, where 'gt' and
# each 'methods'[name] hold the 'fod', 'fixel' and 'connectome' paths for
# that method. Paths are only built here; nothing checks that they exist.
data_dict = {}
for sub in subjects:
    sub_path = os.path.join(main_path, sub)
    sub_dict = {
        'brain': os.path.join(sub_path, '{:}_brainmask.nii.gz'.format(sub)),
        'wm': os.path.join(sub_path, '{:}_wm_mask.mif.gz'.format(sub)),
        'gt': {},
        'methods': {},
    }
    for method in methods:
        method_path = os.path.join(sub_path, method)
        method_dict = {
            'fod': os.path.join(
                method_path, '{:}_wmfod_norm.mif.gz'.format(sub)
            ),
            'fixel': os.path.join(method_path, 'fixels'),
            # NOTE(review): the connectome file name always carries '_32dir',
            # even for the ground-truth method — confirm this is intended.
            'connectome': os.path.join(
                method_path, '{:}_connectome_DK_32dir.csv'.format(sub)
            ),
        }
        if method == gt:
            sub_dict['gt'] = method_dict
        else:
            sub_dict['methods'][method] = method_dict
    data_dict[sub] = sub_dict
    
    

def _mean_per_method(errors):
    """Return, for each method, the mean of its concatenated error arrays.

    ``errors`` holds one entry per method; each entry must be accepted by
    ``np.concatenate`` (presumably a list of per-item arrays as returned by
    ``fixel_comparison`` / ``fod_comparison`` — TODO confirm).
    """
    return [np.mean(np.concatenate(e)) for e in errors]


# Compare every method against the ground truth, one subject at a time.
n_subjects = len(data_dict)
for i, (sub, sub_data) in enumerate(data_dict.items()):
    # Wipe the current terminal line, then show which subject is running.
    print(' '.join([' '] * 300), end='\r')
    print('Subject {:} [{:02d}/{:02d}]'.format(
        sub, i + 1, n_subjects
    ), end='\r')

    # All per-method lists below (paths, errors, metrics) follow the
    # insertion order of sub_data['methods'].
    method_names = list(sub_data['methods'])
    fixel_paths = [m['fixel'] for m in sub_data['methods'].values()]
    fod_paths = [m['fod'] for m in sub_data['methods'].values()]

    gt_index, gt_afd, gt_peak, gt_dir = get_fixel_data(
        sub_data['gt']['fixel']
    )
    m_index, m_afd, m_peak, m_dir = zip(
        *[get_fixel_data(m) for m in fixel_paths]
    )

    gt_fod = load_image(sub_data['gt']['fod'])
    m_fod = [load_image(m) for m in fod_paths]

    # The white-matter mask is the ROI for both the fixel and FOD comparison.
    # NOTE(review): sub_data['brain'] is not used; switch to it here if the
    # FODs should be compared over the whole brain instead.
    roi = load_image(sub_data['wm']).astype(bool)

    valid_gt = gt_index[roi]
    valid_m = [m[roi] for m in m_index]
    (
        angular_errors,
        (afd_e, extra_afd_e, miss_afd_e),
        (peak_e, extra_peak_e, miss_peak_e),
    ) = fixel_comparison(
        valid_gt, gt_peak, gt_afd, gt_dir,
        valid_m, m_peak, m_afd, m_dir
    )

    mse_list, mae_list, psnr_list = fod_comparison(gt_fod, m_fod, roi)

    # Print method names (not full paths) so the metric columns can be
    # matched to methods.
    print(sub, method_names)
    print('Fixels')
    print('AFD', _mean_per_method(afd_e))
    print('Extra AFD', _mean_per_method(extra_afd_e))
    print('Peaks', _mean_per_method(peak_e))
    print('Extra peaks', _mean_per_method(extra_peak_e))
    print('Angular errors', angular_errors)
    print('FODs')
    # Each label is paired with the list of the same name unpacked above.
    print('MSE', _mean_per_method(mse_list))
    print('PSNR', _mean_per_method(psnr_list))
    print('MAE', _mean_per_method(mae_list))

# Output of the cell above:
#   Subject 061801001 [01/10]
#
#   ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()