diff --git a/mriqc/__about__.py b/mriqc/__about__.py index 7c0bd90af..cf7947082 100644 --- a/mriqc/__about__.py +++ b/mriqc/__about__.py @@ -3,18 +3,13 @@ """MRIQC.""" from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -__copyright__ = ('Copyright 2009, Center for Reproducible Neuroscience, ' - 'Stanford University') -__credits__ = 'Oscar Esteban' -__download__ = ('https://github.com/poldracklab/mriqc/archive/' - '{}.tar.gz'.format(__version__)) +__version__ = get_versions()["version"] +del get_versions -__all__ = [ - '__version__', - '__copyright__', - '__credits__', - '__download__' -] +__copyright__ = ( + "Copyright 2020, Center for Reproducible Neuroscience, Stanford University" +) +__credits__ = "Oscar Esteban" +__download__ = f"https://github.com/poldracklab/mriqc/archive/{__version__}.tar.gz" +__all__ = ["__version__", "__copyright__", "__credits__", "__download__"] diff --git a/mriqc/__init__.py b/mriqc/__init__.py index 8428e7e38..9b7d84f1b 100644 --- a/mriqc/__init__.py +++ b/mriqc/__init__.py @@ -14,7 +14,7 @@ __all__ = [ - '__copyright__', - '__credits__', - '__version__', + "__copyright__", + "__credits__", + "__version__", ] diff --git a/mriqc/_warnings.py b/mriqc/_warnings.py index 62b5c4840..4357aa31c 100644 --- a/mriqc/_warnings.py +++ b/mriqc/_warnings.py @@ -12,9 +12,7 @@ def _warn(message, category=None, stacklevel=1, source=None): category = type(category).__name__ category = category.replace("type", "WARNING") - logging.getLogger("py.warnings").warning( - f"{category or 'WARNING'}: {message}" - ) + logging.getLogger("py.warnings").warning(f"{category or 'WARNING'}: {message}") def _showwarning(message, category, filename, lineno, file=None, line=None): diff --git a/mriqc/bin/abide2bids.py b/mriqc/bin/abide2bids.py index 27b9c728d..58cfd1238 100644 --- a/mriqc/bin/abide2bids.py +++ b/mriqc/bin/abide2bids.py @@ -25,30 +25,34 @@ def main(): """Entry point""" - parser = ArgumentParser(description='ABIDE2BIDS downloader', - formatter_class=RawTextHelpFormatter) - g_input = parser.add_argument_group('Inputs') - g_input.add_argument('-i', '--input-abide-catalog', action='store', - required=True) - g_input.add_argument('-n', '--dataset-name', action='store', - default='ABIDE Dataset') - g_input.add_argument('-u', '--nitrc-user', action='store', - default=os.getenv('NITRC_USER')) - g_input.add_argument('-p', '--nitrc-password', action='store', - default=os.getenv('NITRC_PASSWORD')) - - g_outputs = parser.add_argument_group('Outputs') - g_outputs.add_argument('-o', '--output-dir', action='store', - default='ABIDE-BIDS') + parser = ArgumentParser( + description="ABIDE2BIDS downloader", formatter_class=RawTextHelpFormatter + ) + g_input = parser.add_argument_group("Inputs") + g_input.add_argument("-i", "--input-abide-catalog", action="store", required=True) + g_input.add_argument( + "-n", "--dataset-name", action="store", default="ABIDE Dataset" + ) + g_input.add_argument( + "-u", "--nitrc-user", action="store", default=os.getenv("NITRC_USER") + ) + g_input.add_argument( + "-p", "--nitrc-password", action="store", default=os.getenv("NITRC_PASSWORD") + ) + + g_outputs = parser.add_argument_group("Outputs") + g_outputs.add_argument("-o", "--output-dir", action="store", default="ABIDE-BIDS") opts = parser.parse_args() if opts.nitrc_user is None or opts.nitrc_password is None: - raise RuntimeError('NITRC user and password are required') + raise RuntimeError("NITRC user and password are required") - dataset_desc = {'BIDSVersion': '1.0.0rc3', 
- 'License': 'CC Attribution-NonCommercial-ShareAlike 3.0 Unported', - 'Name': opts.dataset_name} + dataset_desc = { + "BIDSVersion": "1.0.0rc3", + "License": "CC Attribution-NonCommercial-ShareAlike 3.0 Unported", + "Name": opts.dataset_name, + } out_dir = op.abspath(opts.output_dir) try: @@ -57,19 +61,18 @@ def main(): if exc.errno != errno.EEXIST: raise exc - with open(op.join(out_dir, 'dataset_description.json'), 'w') as dfile: + with open(op.join(out_dir, "dataset_description.json"), "w") as dfile: json.dump(dataset_desc, dfile) catalog = et.parse(opts.input_abide_catalog).getroot() - urls = [el.get('URI') for el in catalog.iter() if el.get('URI') is not None] + urls = [el.get("URI") for el in catalog.iter() if el.get("URI") is not None] pool = Pool() - args_list = [(url, opts.nitrc_user, opts.nitrc_password, out_dir) - for url in urls] + args_list = [(url, opts.nitrc_user, opts.nitrc_password, out_dir) for url in urls] res = pool.map(fetch, args_list) - tsv_data = np.array([('subject_id', 'site_name')] + res) - np.savetxt(op.join(out_dir, 'participants.tsv'), tsv_data, fmt='%s', delimiter='\t') + tsv_data = np.array([("subject_id", "site_name")] + res) + np.savetxt(op.join(out_dir, "participants.tsv"), tsv_data, fmt="%s", delimiter="\t") def fetch(args): @@ -86,55 +89,57 @@ def fetch(args): else: out_dir = op.abspath(out_dir) - pkg_id = [u[9:] for u in url.split('/') if u.startswith('NITRC_IR_')][0] - sub_file = op.join(tmpdir, '%s.zip' % pkg_id) + pkg_id = [u[9:] for u in url.split("/") if u.startswith("NITRC_IR_")][0] + sub_file = op.join(tmpdir, "%s.zip" % pkg_id) - cmd = ['curl', '-s', '-u', '%s:%s' % (user, password), '-o', sub_file, url] + cmd = ["curl", "-s", "-u", "%s:%s" % (user, password), "-o", sub_file, url] sp.check_call(cmd) - sp.check_call(['unzip', '-qq', '-d', tmpdir, '-u', sub_file]) + sp.check_call(["unzip", "-qq", "-d", tmpdir, "-u", sub_file]) - abide_root = op.join(tmpdir, 'ABIDE') + abide_root = op.join(tmpdir, "ABIDE") files = [] for root, path, fname in os.walk(abide_root): - if fname and (fname[0].endswith('nii') or fname[0].endswith('nii.gz')): + if fname and (fname[0].endswith("nii") or fname[0].endswith("nii.gz")): if path: root = op.join(root, path[0]) files.append(op.join(root, fname[0])) - site_name, sub_str = files[0][len(abide_root) + 1:].split('/')[0].split('_') - subject_id = 'sub-' + sub_str + site_name, sub_str = files[0][len(abide_root) + 1:].split("/")[0].split("_") + subject_id = "sub-" + sub_str for i in files: - ext = '.nii.gz' - if i.endswith('.nii'): - ext = '.nii' - if 'mprage' in i: - bids_dir = op.join(out_dir, subject_id, 'anat') + ext = ".nii.gz" + if i.endswith(".nii"): + ext = ".nii" + if "mprage" in i: + bids_dir = op.join(out_dir, subject_id, "anat") try: os.makedirs(bids_dir) except OSError as exc: if exc.errno != errno.EEXIST: raise exc - shutil.copy(i, op.join(bids_dir, subject_id + '_T1w' + ext)) + shutil.copy(i, op.join(bids_dir, subject_id + "_T1w" + ext)) - if 'rest' in i: - bids_dir = op.join(out_dir, subject_id, 'func') + if "rest" in i: + bids_dir = op.join(out_dir, subject_id, "func") try: os.makedirs(bids_dir) except OSError as exc: if exc.errno != errno.EEXIST: raise exc - shutil.copy(i, op.join(bids_dir, subject_id + '_rest_bold' + ext)) + shutil.copy(i, op.join(bids_dir, subject_id + "_rest_bold" + ext)) shutil.rmtree(tmpdir, ignore_errors=True, onerror=_myerror) - print('Successfully processed subject %s from site %s' % (subject_id[4:], site_name)) + print( + "Successfully processed subject %s from site %s" % 
(subject_id[4:], site_name) + ) return subject_id[4:], site_name def _myerror(msg): - print('WARNING: Error deleting temporal files: %s' % msg) + print("WARNING: Error deleting temporal files: %s" % msg) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/bin/dfcheck.py b/mriqc/bin/dfcheck.py index cfad3f403..7014e9650 100644 --- a/mriqc/bin/dfcheck.py +++ b/mriqc/bin/dfcheck.py @@ -20,15 +20,34 @@ def main(): """Entry point""" from ..classifier.data import read_iqms - parser = ArgumentParser(description='compare two pandas dataframes', - formatter_class=RawTextHelpFormatter) - g_input = parser.add_argument_group('Inputs') - g_input.add_argument('-i', '--input-csv', action='store', type=Path, - required=True, help='input data frame') - g_input.add_argument('-r', '--reference-csv', action='store', type=Path, - required=True, help='reference dataframe') - g_input.add_argument('--tolerance', type=float, default=1.e-5, - help='relative tolerance for comparison') + + parser = ArgumentParser( + description="compare two pandas dataframes", + formatter_class=RawTextHelpFormatter, + ) + g_input = parser.add_argument_group("Inputs") + g_input.add_argument( + "-i", + "--input-csv", + action="store", + type=Path, + required=True, + help="input data frame", + ) + g_input.add_argument( + "-r", + "--reference-csv", + action="store", + type=Path, + required=True, + help="reference dataframe", + ) + g_input.add_argument( + "--tolerance", + type=float, + default=1.0e-5, + help="relative tolerance for comparison", + ) opts = parser.parse_args() @@ -39,17 +58,19 @@ def main(): tst_df.set_index(tst_bids) if sorted(ref_bids) != sorted(tst_bids): - sys.exit('Dataset has different BIDS bits w.r.t. reference') + sys.exit("Dataset has different BIDS bits w.r.t. reference") if sorted(ref_names) != sorted(tst_names): - sys.exit('Output CSV file changed number of columns') + sys.exit("Output CSV file changed number of columns") ref_df = ref_df.sort_values(by=ref_bids) tst_df = tst_df.sort_values(by=tst_bids) if len(ref_df) != len(tst_df): - print('Input datases have different lengths (input %d, reference %d).' % ( - len(ref_df), len(tst_df))) + print( + "Input datases have different lengths (input %d, reference %d)." 
+ % (len(ref_df), len(tst_df)) + ) tst_rows = tst_df[tst_bids] ref_rows = ref_df[ref_bids] @@ -58,7 +79,9 @@ def main(): tst_keep = np.sum(tst_rows.isin(ref_rows).values.ravel().tolist()) print(tst_keep) - diff = ~np.isclose(ref_df[ref_names].values, tst_df[tst_names].values, rtol=opts.tolerance) + diff = ~np.isclose( + ref_df[ref_names].values, tst_df[tst_names].values, rtol=opts.tolerance + ) if np.any(diff): # ne_stacked = pd.DataFrame(data=diff, columns=ref_names).stack() # ne_stacked = np.isclose(ref_df[ref_names], tst_df[ref_names]).stack() @@ -69,26 +92,32 @@ def main(): changed_to = tst_df[ref_names].values[difference_locations] cols = [ref_names[v] for v in difference_locations[1]] bids_df = ref_df.loc[difference_locations[0], ref_bids].reset_index() - chng_df = pd.DataFrame({'iqm': cols, 'from': changed_from, 'to': changed_to}) + chng_df = pd.DataFrame({"iqm": cols, "from": changed_from, "to": changed_to}) table = pd.concat([bids_df, chng_df], axis=1) - print(table[ref_bids + ['iqm', 'from', 'to']].to_string(index=False)) + print(table[ref_bids + ["iqm", "from", "to"]].to_string(index=False)) corr = pd.DataFrame() - corr['iqms'] = ref_names - corr['cc'] = [float(np.corrcoef(ref_df[[var]].values.ravel(), - tst_df[[var]].values.ravel(), - rowvar=False)[0, 1]) - for var in ref_names] + corr["iqms"] = ref_names + corr["cc"] = [ + float( + np.corrcoef( + ref_df[[var]].values.ravel(), + tst_df[[var]].values.ravel(), + rowvar=False, + )[0, 1] + ) + for var in ref_names + ] if np.any(corr.cc < 0.95): - print('IQMs with Pearson correlation < 0.95:') + print("IQMs with Pearson correlation < 0.95:") print(corr[corr.cc < 0.95]) - sys.exit('Output CSV file changed one or more values') + sys.exit("Output CSV file changed one or more values") else: - print('All IQMs show a Pearson correlation >= 0.95') + print("All IQMs show a Pearson correlation >= 0.95") sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/bin/fs2gif.py b/mriqc/bin/fs2gif.py index 964fa480f..a89ff37fb 100644 --- a/mriqc/bin/fs2gif.py +++ b/mriqc/bin/fs2gif.py @@ -25,18 +25,20 @@ def main(): """Entry point""" - parser = ArgumentParser(description='Batch export freesurfer results to animated gifs', - formatter_class=RawTextHelpFormatter) - g_input = parser.add_argument_group('Inputs') - g_input.add_argument('-s', '--subject-id', action='store') - g_input.add_argument('-t', '--temp-dir', action='store') - g_input.add_argument('--keep-temp', action='store_true', default=False) - g_input.add_argument('--zoom', action='store_true', default=False) - g_input.add_argument('--hist-eq', action='store_true', default=False) - g_input.add_argument('--use-xvfb', action='store_true', default=False) - - g_outputs = parser.add_argument_group('Outputs') - g_outputs.add_argument('-o', '--output-dir', action='store', default='fs2gif') + parser = ArgumentParser( + description="Batch export freesurfer results to animated gifs", + formatter_class=RawTextHelpFormatter, + ) + g_input = parser.add_argument_group("Inputs") + g_input.add_argument("-s", "--subject-id", action="store") + g_input.add_argument("-t", "--temp-dir", action="store") + g_input.add_argument("--keep-temp", action="store_true", default=False) + g_input.add_argument("--zoom", action="store_true", default=False) + g_input.add_argument("--hist-eq", action="store_true", default=False) + g_input.add_argument("--use-xvfb", action="store_true", default=False) + + g_outputs = parser.add_argument_group("Outputs") + g_outputs.add_argument("-o", 
"--output-dir", action="store", default="fs2gif") opts = parser.parse_args() @@ -57,15 +59,18 @@ def main(): if exc.errno != EEXIST: raise exc - subjects_dir = os.getenv('SUBJECTS_DIR', op.abspath('subjects')) + subjects_dir = os.getenv("SUBJECTS_DIR", op.abspath("subjects")) subject_list = [opts.subject_id] if opts.subject_id is None: - subject_list = [op.basename(name) for name in os.listdir(subjects_dir) - if op.isdir(os.path.join(subjects_dir, name))] + subject_list = [ + op.basename(name) + for name in os.listdir(subjects_dir) + if op.isdir(os.path.join(subjects_dir, name)) + ] environ = os.environ.copy() - environ['SUBJECTS_DIR'] = subjects_dir + environ["SUBJECTS_DIR"] = subjects_dir if opts.use_xvfb: - environ['doublebufferflag'] = 1 + environ["doublebufferflag"] = 1 # tcl_file = pkgr.resource_filename('mriqc', 'data/fsexport.tcl') tcl_contents = """ @@ -85,10 +90,11 @@ def main(): if exc.errno != EEXIST: raise exc - niifile = op.join(tmp_sub, '%s.nii.gz') % subid - ref_file = op.join(sub_path, 'mri', 'T1.mgz') - sp.call(['mri_convert', op.join(sub_path, 'mri', 'norm.mgz'), niifile], - cwd=tmp_sub) + niifile = op.join(tmp_sub, "%s.nii.gz") % subid + ref_file = op.join(sub_path, "mri", "T1.mgz") + sp.call( + ["mri_convert", op.join(sub_path, "mri", "norm.mgz"), niifile], cwd=tmp_sub + ) data = nb.load(niifile).get_data() data[data > 0] = 1 @@ -99,115 +105,173 @@ def main(): center = np.average([bbox_min, bbox_max], axis=0) if opts.hist_eq: - modnii = op.join(tmp_sub, '%s.nii.gz' % subid) - ref_file = op.join(tmp_sub, '%s.mgz' % subid) + modnii = op.join(tmp_sub, "%s.nii.gz" % subid) + ref_file = op.join(tmp_sub, "%s.mgz" % subid) img = nb.load(niifile) data = exposure.equalize_adapthist(img.get_data(), clip_limit=0.03) nb.Nifti1Image(data, img.affine, img.header).to_filename(modnii) - sp.call(['mri_convert', modnii, ref_file], cwd=tmp_sub) + sp.call(["mri_convert", modnii, ref_file], cwd=tmp_sub) if not opts.zoom: # Export tiffs for left hemisphere - tcl_file = op.join(tmp_sub, '%s.tcl' % subid) - with open(tcl_file, 'w') as tclfp: + tcl_file = op.join(tmp_sub, "%s.tcl" % subid) + with open(tcl_file, "w") as tclfp: tclfp.write(tcl_contents) - tclfp.write('for { set slice %d } { $slice < %d } { incr slice } {' % ( - bbox_min[2], bbox_max[2])) - tclfp.write(' SetSlice $slice\n') - tclfp.write(' RedrawScreen\n') - tclfp.write(' SaveTIFF [format "%s/%s-' % (tmp_sub, subid) + '%03d.tif" $i]\n') - tclfp.write(' incr i\n') - tclfp.write('}\n') - tclfp.write('QuitMedit\n') - cmd = ['tkmedit', subid, 'T1.mgz', 'lh.pial', '-aux-surface', 'rh.pial', '-tcl', - tcl_file] + tclfp.write( + "for { set slice %d } { $slice < %d } { incr slice } {" + % (bbox_min[2], bbox_max[2]) + ) + tclfp.write(" SetSlice $slice\n") + tclfp.write(" RedrawScreen\n") + tclfp.write( + ' SaveTIFF [format "%s/%s-' % (tmp_sub, subid) + + '%03d.tif" $i]\n' + ) + tclfp.write(" incr i\n") + tclfp.write("}\n") + tclfp.write("QuitMedit\n") + cmd = [ + "tkmedit", + subid, + "T1.mgz", + "lh.pial", + "-aux-surface", + "rh.pial", + "-tcl", + tcl_file, + ] if opts.use_xvfb: cmd = _xvfb_run() + cmd - print('Running tkmedit: %s' % ' '.join(cmd)) + print("Running tkmedit: %s" % " ".join(cmd)) sp.call(cmd, env=environ) # Convert to animated gif - print('Stacking coronal slices') - sp.call(['convert', '-delay', '10', '-loop', '0', '%s/%s-*.tif' % (tmp_sub, subid), - '%s/%s.gif' % (out_dir, subid)]) + print("Stacking coronal slices") + sp.call( + [ + "convert", + "-delay", + "10", + "-loop", + "0", + "%s/%s-*.tif" % (tmp_sub, subid), + 
"%s/%s.gif" % (out_dir, subid), + ] + ) else: # Export tiffs for left hemisphere - tcl_file = op.join(tmp_sub, 'lh-%s.tcl' % subid) - with open(tcl_file, 'w') as tclfp: + tcl_file = op.join(tmp_sub, "lh-%s.tcl" % subid) + with open(tcl_file, "w") as tclfp: tclfp.write(tcl_contents) - tclfp.write('SetZoomLevel 2') - tclfp.write('for { set slice %d } { $slice < %d } { incr slice } {' % ( - bbox_min[2], bbox_max[2])) - tclfp.write(' SetZoomCenter %d %d $slice\n' % (center[0] + 30, center[1] - 10)) - tclfp.write(' SetSlice $slice\n') - tclfp.write(' RedrawScreen\n') - tclfp.write(' SaveTIFF [format "{}/{}-lh-%03d.tif" $i]\n'.format( - tmp_sub, subid)) - tclfp.write(' incr i\n') - tclfp.write('}\n') - tclfp.write('QuitMedit\n') - cmd = ['tkmedit', subid, 'norm.mgz', 'lh.white', '-tcl', tcl_file] + tclfp.write("SetZoomLevel 2") + tclfp.write( + "for { set slice %d } { $slice < %d } { incr slice } {" + % (bbox_min[2], bbox_max[2]) + ) + tclfp.write( + " SetZoomCenter %d %d $slice\n" + % (center[0] + 30, center[1] - 10) + ) + tclfp.write(" SetSlice $slice\n") + tclfp.write(" RedrawScreen\n") + tclfp.write( + ' SaveTIFF [format "{}/{}-lh-%03d.tif" $i]\n'.format( + tmp_sub, subid + ) + ) + tclfp.write(" incr i\n") + tclfp.write("}\n") + tclfp.write("QuitMedit\n") + cmd = ["tkmedit", subid, "norm.mgz", "lh.white", "-tcl", tcl_file] if opts.use_xvfb: cmd = _xvfb_run() + cmd - print('Running tkmedit: %s' % ' '.join(cmd)) + print("Running tkmedit: %s" % " ".join(cmd)) sp.call(cmd, env=environ) # Convert to animated gif - print('Stacking coronal slices') + print("Stacking coronal slices") # Export tiffs for right hemisphere - tcl_file = op.join(tmp_sub, 'rh-%s.tcl' % subid) - with open(tcl_file, 'w') as tclfp: + tcl_file = op.join(tmp_sub, "rh-%s.tcl" % subid) + with open(tcl_file, "w") as tclfp: tclfp.write(tcl_contents) - tclfp.write('SetZoomLevel 2') - tclfp.write('for { set slice %d } { $slice < %d } { incr slice } {' % ( - bbox_min[2], bbox_max[2])) - tclfp.write(' SetZoomCenter %d %d $slice\n' % (center[0] - 30, center[1] - 10)) - tclfp.write(' SetSlice $slice\n') - tclfp.write(' RedrawScreen\n') - tclfp.write(' SaveTIFF [format "{}/{}-rh-%03d.tif" $slice]\n'.format( - tmp_sub, subid)) - tclfp.write(' incr i\n') - tclfp.write('}\n') - tclfp.write('QuitMedit\n') - cmd = ['tkmedit', subid, 'norm.mgz', 'rh.white', '-tcl', tcl_file] + tclfp.write("SetZoomLevel 2") + tclfp.write( + "for { set slice %d } { $slice < %d } { incr slice } {" + % (bbox_min[2], bbox_max[2]) + ) + tclfp.write( + " SetZoomCenter %d %d $slice\n" + % (center[0] - 30, center[1] - 10) + ) + tclfp.write(" SetSlice $slice\n") + tclfp.write(" RedrawScreen\n") + tclfp.write( + ' SaveTIFF [format "{}/{}-rh-%03d.tif" $slice]\n'.format( + tmp_sub, subid + ) + ) + tclfp.write(" incr i\n") + tclfp.write("}\n") + tclfp.write("QuitMedit\n") + cmd = ["tkmedit", subid, "norm.mgz", "rh.white", "-tcl", tcl_file] if opts.use_xvfb: cmd = _xvfb_run() + cmd - print('Running tkmedit: %s' % ' '.join(cmd)) + print("Running tkmedit: %s" % " ".join(cmd)) sp.call(cmd, env=environ) # Convert to animated gif - print('Stacking coronal slices') - sp.call(['convert', '-delay', '10', '-loop', '0', '%s/%s-lh-*.tif' % (tmp_sub, subid), - '%s/%s-lh.gif' % (out_dir, subid)]) - sp.call(['convert', '-delay', '10', '-loop', '0', '%s/%s-rh-*.tif' % (tmp_sub, subid), - '%s/%s-rh.gif' % (out_dir, subid)]) + print("Stacking coronal slices") + sp.call( + [ + "convert", + "-delay", + "10", + "-loop", + "0", + "%s/%s-lh-*.tif" % (tmp_sub, subid), + "%s/%s-lh.gif" % (out_dir, 
subid), + ] + ) + sp.call( + [ + "convert", + "-delay", + "10", + "-loop", + "0", + "%s/%s-rh-*.tif" % (tmp_sub, subid), + "%s/%s-rh.gif" % (out_dir, subid), + ] + ) if not opts.keep_temp: rmtree(tmp_sub, ignore_errors=True, onerror=_myerror) -def _xvfb_run(wait=5, server_args='-screen 0, 1600x1200x24', logs=None): +def _xvfb_run(wait=5, server_args="-screen 0, 1600x1200x24", logs=None): """ Wrap command with xvfb-run. Copied from: https://github.com/VUIIS/seam/blob/1dabd9ca5b1fc7d66ef7d41c34ea8d42d668a484/seam/util.py """ if logs is None: - logs = op.join(mkdtemp(), 'fs2gif_xvfb') + logs = op.join(mkdtemp(), "fs2gif_xvfb") - return ['xvfb-run', - '-a', # automatically get a free server number - '-f {}.out'.format(logs), - '-e {}.err'.format(logs), - '--wait={:d}'.format(wait), - '--server-args="{}"'.format(server_args)] + return [ + "xvfb-run", + "-a", # automatically get a free server number + "-f {}.out".format(logs), + "-e {}.err".format(logs), + "--wait={:d}".format(wait), + '--server-args="{}"'.format(server_args), + ] def _myerror(msg): - print('WARNING: Error deleting temporal files: %s' % msg) + print("WARNING: Error deleting temporal files: %s" % msg) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/bin/labeler.py b/mriqc/bin/labeler.py index ceeca39f3..96e3f36e9 100644 --- a/mriqc/bin/labeler.py +++ b/mriqc/bin/labeler.py @@ -24,13 +24,13 @@ def num_rows(data): def main(): """read the input file""" - print('Reading file sinfo.csv') - csvfile = open('sinfo.csv', 'rb') + print("Reading file sinfo.csv") + csvfile = open("sinfo.csv", "rb") csvreader = csv.reader(csvfile) file = list(csvreader) # display statistics - finished = [0., 0., 0.] + finished = [0.0, 0.0, 0.0] hold = np.zeros((3, len(file) - 1)) hold[:] = np.nan total = 601 @@ -53,32 +53,32 @@ def main(): curEnt = num_rows(file[row]) if curEnt <= 1: # if less than 1, run the row - print('Check participant #' + file[row][0]) - fname = os.getcwd() + '/abide/' + file[row][0] + print("Check participant #" + file[row][0]) + fname = os.getcwd() + "/abide/" + file[row][0] if os.path.isfile(fname): - webbrowser.open('file://' + fname) + webbrowser.open("file://" + fname) quality = input("Quality? [-1/0/1/e/c] ") - if quality == 'e': + if quality == "e": break - if quality == 'c': - print('Current comment: ' + file[row][4]) + if quality == "c": + print("Current comment: " + file[row][4]) comment = input("Comment: ") if len(comment) > 0: file[row][4] = comment quality = input("Quality? 
[-1/0/1/e] ") - if quality == 'e': + if quality == "e": break file[row][curEnt] = quality else: - print('File does not exist') + print("File does not exist") - print('Writing file sinfo.csv') - outfile = open('sinfo.csv', 'wb') + print("Writing file sinfo.csv") + outfile = open("sinfo.csv", "wb") csvwriter = csv.writer(outfile) csvwriter.writerows(file) - print('Ending') + print("Ending") -if __name__ == '__main__': +if __name__ == "__main__": main() sys.exit(0) diff --git a/mriqc/bin/mriqc_clf.py b/mriqc/bin/mriqc_clf.py index 3f490a395..ab00a22c5 100644 --- a/mriqc/bin/mriqc_clf.py +++ b/mriqc/bin/mriqc_clf.py @@ -7,17 +7,18 @@ from pkg_resources import resource_filename as pkgrf import matplotlib -matplotlib.use('Agg') + +matplotlib.use("Agg") try: from sklearn.metrics.base import UndefinedMetricWarning except ImportError: from sklearn.exceptions import UndefinedMetricWarning -LOG_FORMAT = '%(asctime)s %(name)s:%(levelname)s %(message)s' +LOG_FORMAT = "%(asctime)s %(name)s:%(levelname)s %(message)s" warnings.simplefilter("once", UndefinedMetricWarning) -LOGGER = logging.getLogger('mriqc.classifier') +LOGGER = logging.getLogger("mriqc.classifier") _handler = logging.StreamHandler(stream=sys.stdout) _handler.setFormatter(logging.Formatter(fmt=LOG_FORMAT, datefmt="%y%m%d-%H:%M:%S")) LOGGER.addHandler(_handler) @@ -27,60 +28,122 @@ def warn_redirect(message, category, filename, lineno, file=None, line=None): if category not in cached_warnings: - LOGGER.debug('captured warning (%s): %s', category, message) + LOGGER.debug("captured warning (%s): %s", category, message) cached_warnings.append(category) def get_parser(): from argparse import ArgumentParser from argparse import RawTextHelpFormatter - parser = ArgumentParser(description='MRIQC model selection and held-out evaluation', - formatter_class=RawTextHelpFormatter) + + parser = ArgumentParser( + description="MRIQC model selection and held-out evaluation", + formatter_class=RawTextHelpFormatter, + ) g_clf = parser.add_mutually_exclusive_group() - g_clf.add_argument('--train', nargs='*', - help='training data tables, X and Y, leave empty for ABIDE.') - g_clf.add_argument('--load-classifier', nargs="?", type=str, default='', - help='load a previously saved classifier') - - parser.add_argument('--test', nargs='*', - help='test data tables, X and Y, leave empty for DS030.') - parser.add_argument('-X', '--evaluation-data', help='classify this CSV table of IQMs') - - parser.add_argument('--train-balanced-leaveout', action='store_true', default=False, - help='leave out a balanced, random, sample of training examples') - parser.add_argument('--multiclass', '--ms', action='store_true', default=False, - help='do not binarize labels') - - g_input = parser.add_argument_group('Options') - g_input.add_argument('-P', '--parameters', action='store') - g_input.add_argument('-M', '--model', action='store', default='rfc', - choices=['rfc', 'xgb', 'svc_lin', 'svc_rbf'], - help='model under test') - g_input.add_argument('--nested_cv', action='store_true', default=False, - help='run nested cross-validation before held-out') - g_input.add_argument('--nested_cv_kfold', action='store_true', default=False, - help='run nested cross-validation before held-out, ' - 'using 10-fold split in the outer loop') - g_input.add_argument('--perm', action='store', default=0, type=int, - help='permutation test: number of permutations') - - g_input.add_argument('-S', '--scorer', action='store', default='roc_auc') - g_input.add_argument('--cv', action='store', default='loso', - 
choices=['kfold', 'loso', 'balanced-kfold', 'batch']) - g_input.add_argument('--debug', action='store_true', default=False) - - g_input.add_argument('--log-file', nargs="?", action='store', default='', - help='write log to this file, leave empty for a default log name') - - g_input.add_argument("-v", "--verbose", dest="verbose_count", - action="count", default=0, - help="increases log verbosity for each occurence.") - g_input.add_argument('--njobs', action='store', default=-1, type=int, - help='number of jobs') - - g_input.add_argument('-t', '--threshold', action='store', default=0.5, type=float, - help='decision threshold of the classifier') + g_clf.add_argument( + "--train", + nargs="*", + help="training data tables, X and Y, leave empty for ABIDE.", + ) + g_clf.add_argument( + "--load-classifier", + nargs="?", + type=str, + default="", + help="load a previously saved classifier", + ) + + parser.add_argument( + "--test", nargs="*", help="test data tables, X and Y, leave empty for DS030." + ) + parser.add_argument( + "-X", "--evaluation-data", help="classify this CSV table of IQMs" + ) + + parser.add_argument( + "--train-balanced-leaveout", + action="store_true", + default=False, + help="leave out a balanced, random, sample of training examples", + ) + parser.add_argument( + "--multiclass", + "--ms", + action="store_true", + default=False, + help="do not binarize labels", + ) + + g_input = parser.add_argument_group("Options") + g_input.add_argument("-P", "--parameters", action="store") + g_input.add_argument( + "-M", + "--model", + action="store", + default="rfc", + choices=["rfc", "xgb", "svc_lin", "svc_rbf"], + help="model under test", + ) + g_input.add_argument( + "--nested_cv", + action="store_true", + default=False, + help="run nested cross-validation before held-out", + ) + g_input.add_argument( + "--nested_cv_kfold", + action="store_true", + default=False, + help="run nested cross-validation before held-out, " + "using 10-fold split in the outer loop", + ) + g_input.add_argument( + "--perm", + action="store", + default=0, + type=int, + help="permutation test: number of permutations", + ) + + g_input.add_argument("-S", "--scorer", action="store", default="roc_auc") + g_input.add_argument( + "--cv", + action="store", + default="loso", + choices=["kfold", "loso", "balanced-kfold", "batch"], + ) + g_input.add_argument("--debug", action="store_true", default=False) + + g_input.add_argument( + "--log-file", + nargs="?", + action="store", + default="", + help="write log to this file, leave empty for a default log name", + ) + + g_input.add_argument( + "-v", + "--verbose", + dest="verbose_count", + action="count", + default=0, + help="increases log verbosity for each occurence.", + ) + g_input.add_argument( + "--njobs", action="store", default=-1, type=int, help="number of jobs" + ) + + g_input.add_argument( + "-t", + "--threshold", + action="store", + default=0.5, + type=float, + help="decision threshold of the classifier", + ) return parser @@ -102,19 +165,21 @@ def main(): LOGGER.setLevel(log_level) - base_name = 'mclf_run-%s_mod-%s_ver-%s_class-%d_cv-%s' % ( - datetime.now().strftime('%Y%m%d-%H%M%S'), opts.model, - re.sub(r'[\+_@]', '.', __version__), - 3 if opts.multiclass else 2, opts.cv, + base_name = "mclf_run-%s_mod-%s_ver-%s_class-%d_cv-%s" % ( + datetime.now().strftime("%Y%m%d-%H%M%S"), + opts.model, + re.sub(r"[\+_@]", ".", __version__), + 3 if opts.multiclass else 2, + opts.cv, ) if opts.nested_cv_kfold: - base_name += '_ncv-kfold' + base_name += "_ncv-kfold" elif 
opts.nested_cv: - base_name += '_ncv-loso' + base_name += "_ncv-loso" if opts.log_file is None or len(opts.log_file) > 0: - log_file = opts.log_file if opts.log_file else base_name + '.log' + log_file = opts.log_file if opts.log_file else base_name + ".log" fhl = logging.FileHandler(log_file) fhl.setFormatter(fmt=logging.Formatter(LOG_FORMAT)) fhl.setLevel(log_level) @@ -124,7 +189,7 @@ def main(): if opts.train is not None: # Initialize model selection helper - train_path = _parse_set(opts.train, default='abide') + train_path = _parse_set(opts.train, default="abide") cvhelper = CVHelper( X=train_path[0], Y=train_path[1], @@ -143,8 +208,8 @@ def main(): permutation_test=opts.perm, ) - if opts.cv == 'batch' or opts.perm: - test_path = _parse_set(opts.test, default='ds030') + if opts.cv == "batch" or opts.perm: + test_path = _parse_set(opts.test, default="ds030") # Do not set x_test unless we are going to run batch exp. cvhelper.setXtest(test_path[0], test_path[1]) @@ -152,57 +217,66 @@ def main(): cvhelper.fit() # Pickle if required - cvhelper.save(suffix='data-train_estimator') + cvhelper.save(suffix="data-train_estimator") # If no training set is given, need a classifier else: load_classifier = opts.load_classifier if load_classifier is None: load_classifier = pkgrf( - 'mriqc', - 'data/mclf_run-20170724-191452_mod-rfc_ver-0.9.7-rc8_class-2_cv-' - 'loso_data-all_estimator.pklz') + "mriqc", + "data/mclf_run-20170724-191452_mod-rfc_ver-0.9.7-rc8_class-2_cv-" + "loso_data-all_estimator.pklz", + ) if not isfile(load_classifier): - msg = 'was not provided' - if load_classifier != '': + msg = "was not provided" + if load_classifier != "": msg = '("%s") was not found' % load_classifier raise RuntimeError( - 'No training samples were given, and the --load-classifier ' - 'option %s.' % msg) + "No training samples were given, and the --load-classifier " + "option %s." 
% msg + ) - cvhelper = CVHelper(load_clf=load_classifier, n_jobs=opts.njobs, - rate_label=['rater_1'], basename=base_name) + cvhelper = CVHelper( + load_clf=load_classifier, + n_jobs=opts.njobs, + rate_label=["rater_1"], + basename=base_name, + ) clf_loaded = True - test_path = _parse_set(opts.test, default='ds030') - if test_path and opts.cv != 'batch': + test_path = _parse_set(opts.test, default="ds030") + if test_path and opts.cv != "batch": # Set held-out data cvhelper.setXtest(test_path[0], test_path[1]) # Evaluate - cvhelper.evaluate(matrix=True, scoring=[opts.scorer, 'accuracy'], - save_pred=True) + cvhelper.evaluate( + matrix=True, scoring=[opts.scorer, "accuracy"], save_pred=True + ) # Pickle if required if not clf_loaded: cvhelper.fit_full() - cvhelper.save(suffix='data-all_estimator') + cvhelper.save(suffix="data-all_estimator") if opts.evaluation_data: - cvhelper.predict_dataset(opts.evaluation_data, save_pred=True, - thres=opts.threshold) + cvhelper.predict_dataset( + opts.evaluation_data, save_pred=True, thres=opts.threshold + ) - LOGGER.info('Results saved as %s', abspath(cvhelper._base_name + '*')) + LOGGER.info("Results saved as %s", abspath(cvhelper._base_name + "*")) def _parse_set(arg, default): if arg is not None and len(arg) == 0: - return [pkgrf('mriqc', 'data/csv/%s' % name) for name in ( - 'x_%s.csv' % default, - 'y_%s.csv' % default)] + return [ + pkgrf("mriqc", "data/csv/%s" % name) + for name in ("x_%s.csv" % default, "y_%s.csv" % default) + ] if arg is not None and len(arg) not in (0, 2): - raise RuntimeError('Wrong number of parameters.') + raise RuntimeError("Wrong number of parameters.") if arg is None: return None @@ -210,13 +284,17 @@ def _parse_set(arg, default): if len(arg) == 2: train_exists = [isfile(fname) for fname in arg] if len(train_exists) > 0 and not all(train_exists): - errors = ['file "%s" not found' % fname - for fexists, fname in zip(train_exists, arg) - if not fexists] - raise RuntimeError('Errors (%d) loading training set: %s.' % ( - len(errors), ', '.join(errors))) + errors = [ + 'file "%s" not found' % fname + for fexists, fname in zip(train_exists, arg) + if not fexists + ] + raise RuntimeError( + "Errors (%d) loading training set: %s." 
+ % (len(errors), ", ".join(errors)) + ) return arg -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/bin/mriqc_plot.py b/mriqc/bin/mriqc_plot.py index 22d1a473a..861b822b3 100644 --- a/mriqc/bin/mriqc_plot.py +++ b/mriqc/bin/mriqc_plot.py @@ -21,44 +21,57 @@ def main(): """Entry point""" - parser = ArgumentParser(description='MRI Quality Control', - formatter_class=RawTextHelpFormatter) - - g_input = parser.add_argument_group('Inputs') - g_input.add_argument('-d', '--data-type', action='store', nargs='*', - choices=['anat', 'func'], default=['anat', 'func']) - g_input.add_argument('-v', '--version', action='store_true', default=False, - help='Show current mriqc version') - - g_input.add_argument('--nthreads', action='store', default=0, - type=int, help='number of threads') - - g_outputs = parser.add_argument_group('Outputs') - g_outputs.add_argument('-o', '--output-dir', action='store') - g_outputs.add_argument('-w', '--work-dir', action='store', - default=op.join(os.getcwd(), 'work')) + parser = ArgumentParser( + description="MRI Quality Control", formatter_class=RawTextHelpFormatter + ) + + g_input = parser.add_argument_group("Inputs") + g_input.add_argument( + "-d", + "--data-type", + action="store", + nargs="*", + choices=["anat", "func"], + default=["anat", "func"], + ) + g_input.add_argument( + "-v", + "--version", + action="store_true", + default=False, + help="Show current mriqc version", + ) + + g_input.add_argument( + "--nthreads", action="store", default=0, type=int, help="number of threads" + ) + + g_outputs = parser.add_argument_group("Outputs") + g_outputs.add_argument("-o", "--output-dir", action="store") + g_outputs.add_argument( + "-w", "--work-dir", action="store", default=op.join(os.getcwd(), "work") + ) opts = parser.parse_args() if opts.version: - print('mriqc version ' + __version__) + print("mriqc version " + __version__) exit(0) - settings = {'output_dir': os.getcwd(), - 'nthreads': opts.nthreads} + settings = {"output_dir": os.getcwd(), "nthreads": opts.nthreads} if opts.output_dir: - settings['output_dir'] = op.abspath(opts.output_dir) + settings["output_dir"] = op.abspath(opts.output_dir) - if not op.exists(settings['output_dir']): - os.makedirs(settings['output_dir']) + if not op.exists(settings["output_dir"]): + os.makedirs(settings["output_dir"]) - settings['work_dir'] = op.abspath(opts.work_dir) - if not op.exists(settings['work_dir']): - raise RuntimeError('Work directory of a previous MRIQC run was not found.') + settings["work_dir"] = op.abspath(opts.work_dir) + if not op.exists(settings["work_dir"]): + raise RuntimeError("Work directory of a previous MRIQC run was not found.") for dtype in opts.data_type: workflow_report(dtype, settings) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/bin/mriqcwebapi_test.py b/mriqc/bin/mriqcwebapi_test.py index 3a930eca7..8f2a00f0a 100644 --- a/mriqc/bin/mriqcwebapi_test.py +++ b/mriqc/bin/mriqcwebapi_test.py @@ -9,15 +9,28 @@ def get_parser(): from argparse import ArgumentParser from argparse import RawTextHelpFormatter - parser = ArgumentParser(description='MRIQCWebAPI: Check entries', - formatter_class=RawTextHelpFormatter) - parser.add_argument('modality', action='store', choices=['T1w', 'bold'], - help='number of expected items in the database') - parser.add_argument('expected', action='store', type=int, - help='number of expected items in the database') + parser = ArgumentParser( + description="MRIQCWebAPI: Check entries", 
formatter_class=RawTextHelpFormatter + ) parser.add_argument( - '--webapi-url', action='store', default='https://mriqc.nimh.nih.gov/api/v1/T1w', type=str, - help='IP address where the MRIQC WebAPI is listening') + "modality", + action="store", + choices=["T1w", "bold"], + help="number of expected items in the database", + ) + parser.add_argument( + "expected", + action="store", + type=int, + help="number of expected items in the database", + ) + parser.add_argument( + "--webapi-url", + action="store", + default="https://mriqc.nimh.nih.gov/api/v1/T1w", + type=str, + help="IP address where the MRIQC WebAPI is listening", + ) return parser @@ -29,11 +42,11 @@ def main(): # Run parser MRIQC_LOG = logging.getLogger(__name__) opts = get_parser().parse_args() - MRIQC_LOG.info('Sending GET to %s', opts.webapi_url) + MRIQC_LOG.info("Sending GET to %s", opts.webapi_url) resp = get(opts.webapi_url).json() - MRIQC_LOG.info('There are %d records in database', resp['_meta']['total']) - assert opts.expected == resp['_meta']['total'] + MRIQC_LOG.info("There are %d records in database", resp["_meta"]["total"]) + assert opts.expected == resp["_meta"]["total"] -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/bin/nib_hash.py b/mriqc/bin/nib_hash.py index 16a26a958..74e482ee6 100644 --- a/mriqc/bin/nib_hash.py +++ b/mriqc/bin/nib_hash.py @@ -17,9 +17,12 @@ def get_parser(): """ A trivial parser """ from argparse import ArgumentParser, RawTextHelpFormatter - parser = ArgumentParser(description='compare two pandas dataframes', - formatter_class=RawTextHelpFormatter) - parser.add_argument('input_file', action='store', help='input nifti file') + + parser = ArgumentParser( + description="compare two pandas dataframes", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("input_file", action="store", help="input nifti file") return parser @@ -34,8 +37,8 @@ def main(): """Entry point""" fname = get_parser().parse_args().input_file sha = get_hash(fname) - print('%s %s' % (sha, fname)) + print("%s %s" % (sha, fname)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/bin/subject_wrangler.py b/mriqc/bin/subject_wrangler.py index 97d374134..e293a3be5 100644 --- a/mriqc/bin/subject_wrangler.py +++ b/mriqc/bin/subject_wrangler.py @@ -1,19 +1,13 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Author: oesteban -# @Date: 2015-11-19 16:44:27 -# @Last Modified by: oesteban -# @Last Modified time: 2018-03-12 11:51:32 - -""" -BIDS-Apps subject wrangler - -""" +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +"""BIDS-Apps subject wrangler.""" from builtins import range # pylint: disable=W0622 import os.path as op import glob from random import shuffle + # from lockfile import LockFile from argparse import ArgumentParser @@ -25,50 +19,83 @@ def main(): """Entry point""" - parser = ArgumentParser(formatter_class=RawTextHelpFormatter, description=dedent("""\ + parser = ArgumentParser( + formatter_class=RawTextHelpFormatter, + description=dedent( + """\ BIDS-Apps participants wrangler tool ------------------------------------ This command arranges the participant labels in groups for computation, and checks that the \ requested participants have the corresponding folder in the bids_dir.\ -""")) - - parser.add_argument('-v', '--version', action='version', - version='mriqc v{}'.format(__version__)) - - parser.add_argument('bids_dir', action='store', - help='The directory with the input 
dataset ' - 'formatted according to the BIDS standard.') - parser.add_argument('output_dir', action='store', - help='The directory where the output files ' - 'should be stored. If you are running group level analysis ' - 'this folder should be prepopulated with the results of the' - 'participant level analysis.') - parser.add_argument('--participant_label', '--subject_list', '-S', action='store', - help='The label(s) of the participant(s) that should be analyzed. ' - 'The label corresponds to sub- from the ' - 'BIDS spec (so it does not include "sub-"). If this parameter ' - 'is not provided all subjects should be analyzed. Multiple ' - 'participants can be specified with a space separated list.', - nargs="*") - parser.add_argument('--group-size', default=1, action='store', type=int, - help='parallelize participants in groups') - parser.add_argument('--no-randomize', default=False, action='store_true', - help='do not randomize participants list before grouping') - parser.add_argument('--log-groups', default=False, action='store_true', - help='append logging output') - parser.add_argument('--multiple-workdir', default=False, action='store_true', - help='split work directories by jobs') - parser.add_argument('--bids-app-name', default='mriqc', action='store', - help='BIDS app to call') - parser.add_argument('--args', default='', action='store', help='append arguments') +""" + ), + ) + + parser.add_argument( + "-v", "--version", action="version", version="mriqc v{}".format(__version__) + ) + + parser.add_argument( + "bids_dir", + action="store", + help="The directory with the input dataset " + "formatted according to the BIDS standard.", + ) + parser.add_argument( + "output_dir", + action="store", + help="The directory where the output files " + "should be stored. If you are running group level analysis " + "this folder should be prepopulated with the results of the" + "participant level analysis.", + ) + parser.add_argument( + "--participant_label", + "--subject_list", + "-S", + action="store", + help="The label(s) of the participant(s) that should be analyzed. " + "The label corresponds to sub- from the " + 'BIDS spec (so it does not include "sub-"). If this parameter ' + "is not provided all subjects should be analyzed. 
Multiple " + "participants can be specified with a space separated list.", + nargs="*", + ) + parser.add_argument( + "--group-size", + default=1, + action="store", + type=int, + help="parallelize participants in groups", + ) + parser.add_argument( + "--no-randomize", + default=False, + action="store_true", + help="do not randomize participants list before grouping", + ) + parser.add_argument( + "--log-groups", default=False, action="store_true", help="append logging output" + ) + parser.add_argument( + "--multiple-workdir", + default=False, + action="store_true", + help="split work directories by jobs", + ) + parser.add_argument( + "--bids-app-name", default="mriqc", action="store", help="BIDS app to call" + ) + parser.add_argument("--args", default="", action="store", help="append arguments") opts = parser.parse_args() # Build settings dict bids_dir = op.abspath(opts.bids_dir) - all_subjects = sorted([op.basename(subj)[4:] - for subj in glob.glob(op.join(bids_dir, 'sub-*'))]) + all_subjects = sorted( + [op.basename(subj)[4:] for subj in glob.glob(op.join(bids_dir, "sub-*"))] + ) subject_list = opts.participant_label if subject_list is None or not subject_list: @@ -76,14 +103,16 @@ def main(): else: # remove sub- prefix, get unique for i, subj in enumerate(subject_list): - subject_list[i] = subj[4:] if subj.startswith('sub-') else subj + subject_list[i] = subj[4:] if subj.startswith("sub-") else subj subject_list = sorted(list(set(subject_list))) if list(set(subject_list) - set(all_subjects)): non_exist = list(set(subject_list) - set(all_subjects)) - raise RuntimeError('Participant label(s) not found in the ' - 'BIDS root directory: {}'.format(' '.join(non_exist))) + raise RuntimeError( + "Participant label(s) not found in the " + "BIDS root directory: {}".format(" ".join(non_exist)) + ) if not opts.no_randomize: shuffle(subject_list) @@ -91,29 +120,41 @@ def main(): gsize = opts.group_size if gsize < 0: - raise RuntimeError('group size should be at least 0 ' - '(all participants assigned to same group') + raise RuntimeError( + "group size should be at least 0 " + "(all participants assigned to same group" + ) if gsize == 0: gsize = len(subject_list) - groups = [subject_list[i:i + gsize] - for i in range(0, len(subject_list), gsize)] + groups = [subject_list[i:i + gsize] for i in range(0, len(subject_list), gsize)] - log_arg = ''.format + log_arg = "".format if opts.log_groups: - log_arg = '>> log/mriqc-{:04d}.log'.format + log_arg = ">> log/mriqc-{:04d}.log".format - cmdline = ('{exec} {bids_dir} {out_dir} participant --participant_label {labels}' - '{work_dir} {arguments} {logfile}').format + cmdline = ( + "{exec} {bids_dir} {out_dir} participant --participant_label {labels}" + "{work_dir} {arguments} {logfile}" + ).format for i, part_group in enumerate(groups): - workdir = '' + workdir = "" if opts.multiple_workdir: - workdir = ' -w work/sjob-{:04d}'.format(i) - print(cmdline(**{ - 'exec': opts.bids_app_name, 'bids_dir': bids_dir, 'out_dir': opts.output_dir, - 'labels': ' '.join(part_group), 'work_dir': workdir, 'arguments': opts.args, - 'logfile': log_arg(i)})) - - -if __name__ == '__main__': + workdir = " -w work/sjob-{:04d}".format(i) + print( + cmdline( + **{ + "exec": opts.bids_app_name, + "bids_dir": bids_dir, + "out_dir": opts.output_dir, + "labels": " ".join(part_group), + "work_dir": workdir, + "arguments": opts.args, + "logfile": log_arg(i), + } + ) + ) + + +if __name__ == "__main__": main() diff --git a/mriqc/classifier/__init__.py b/mriqc/classifier/__init__.py index 
864de2366..03352dc33 100644 --- a/mriqc/classifier/__init__.py +++ b/mriqc/classifier/__init__.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """ diff --git a/mriqc/classifier/data.py b/mriqc/classifier/data.py index 8ec903525..fe7693b14 100644 --- a/mriqc/classifier/data.py +++ b/mriqc/classifier/data.py @@ -1,8 +1,5 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Author: oesteban -# @Date: 2015-11-19 16:44:27 - +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: """ =================== Data handler module @@ -22,23 +19,27 @@ from ..utils.misc import BIDS_COMP -def get_groups(X, label='site'): +def get_groups(X, label="site"): """Generate the index of sites""" groups = X[label].values.ravel().tolist() gnames = sorted(list(set(groups))) return [gnames.index(g) for g in groups], gnames -def combine_datasets(inputs, rating_label='rater_1'): +def combine_datasets(inputs, rating_label="rater_1"): mdata = [] for dataset_x, dataset_y, sitename in inputs: sitedata, _ = read_dataset( - dataset_x, dataset_y, rate_label=rating_label, - binarize=True, site_name=sitename) - sitedata['database'] = [sitename] * len(sitedata) + dataset_x, + dataset_y, + rate_label=rating_label, + binarize=True, + site_name=sitename, + ) + sitedata["database"] = [sitename] * len(sitedata) - if 'site' not in sitedata.columns.ravel().tolist(): - sitedata['site'] = [sitename] * len(sitedata) + if "site" not in sitedata.columns.ravel().tolist(): + sitedata["site"] = [sitename] * len(sitedata) mdata.append(sitedata) @@ -50,7 +51,7 @@ def combine_datasets(inputs, rating_label='rater_1'): bids_comps_present = list(set(mdata.columns.ravel().tolist()) & set(bids_comps)) bids_comps_present = [bit for bit in bids_comps if bit in bids_comps_present] - ordered_cols = bids_comps_present + ['database', 'site', 'rater_1'] + ordered_cols = bids_comps_present + ["database", "site", "rater_1"] ordered_cols += sorted(list(set(all_cols) - set(ordered_cols))) return mdata[ordered_cols] @@ -66,16 +67,17 @@ def read_iqms(feat_file): """ Reads in the features """ feat_file = Path(feat_file) - if feat_file.suffix == '.csv': + if feat_file.suffix == ".csv": bids_comps = list(BIDS_COMP.keys()) - x_df = pd.read_csv(feat_file, index_col=False, - dtype={col: str for col in bids_comps}) + x_df = pd.read_csv( + feat_file, index_col=False, dtype={col: str for col in bids_comps} + ) # Find present bids bits and sort by them bids_comps_present = list(set(x_df.columns.ravel().tolist()) & set(bids_comps)) bids_comps_present = [bit for bit in bids_comps if bit in bids_comps_present] x_df = x_df.sort_values(by=bids_comps_present) # Remove sub- prefix in subject_id - x_df.subject_id = x_df.subject_id.str.lstrip('sub-') + x_df.subject_id = x_df.subject_id.str.lstrip("sub-") # Remove columns that are not IQMs feat_names = list(x_df._get_numeric_data().columns.ravel()) @@ -85,26 +87,31 @@ def read_iqms(feat_file): except ValueError: pass else: - bids_comps_present = ['subject_id'] - x_df = pd.read_csv(feat_file, index_col=False, sep='\t', - dtype={'bids_name': str}) - x_df = x_df.sort_values(by=['bids_name']) - x_df['subject_id'] = x_df.bids_name.str.lstrip('sub-') - x_df = x_df.drop(columns=['bids_name']) - x_df.subject_id = ['_'.join(v.split('_')[:-1]) - for v in x_df.subject_id.ravel()] + bids_comps_present = ["subject_id"] + x_df = pd.read_csv( + feat_file, 
index_col=False, sep="\t", dtype={"bids_name": str} + ) + x_df = x_df.sort_values(by=["bids_name"]) + x_df["subject_id"] = x_df.bids_name.str.lstrip("sub-") + x_df = x_df.drop(columns=["bids_name"]) + x_df.subject_id = ["_".join(v.split("_")[:-1]) for v in x_df.subject_id.ravel()] feat_names = list(x_df._get_numeric_data().columns.ravel()) for col in feat_names: - if col.startswith(('size_', 'spacing_', 'Unnamed')): + if col.startswith(("size_", "spacing_", "Unnamed")): feat_names.remove(col) return x_df, feat_names, bids_comps_present -def read_labels(label_file, rate_label='rater_1', binarize=True, - site_name=None, rate_selection='random', - collapse=True): +def read_labels( + label_file, + rate_label="rater_1", + binarize=True, + site_name=None, + rate_selection="random", + collapse=True, +): """ Reads in the labels. Massage labels table to have the appropriate format @@ -115,63 +122,79 @@ def read_labels(label_file, rate_label='rater_1', binarize=True, output_labels = rate_label bids_comps = list(BIDS_COMP.keys()) - y_df = pd.read_csv(label_file, index_col=False, - dtype={col: str for col in bids_comps}) + y_df = pd.read_csv( + label_file, index_col=False, dtype={col: str for col in bids_comps} + ) # Find present bids bits and sort by them bids_comps_present = get_bids_cols(y_df) y_df = y_df.sort_values(by=bids_comps_present) - y_df.subject_id = y_df.subject_id.str.lstrip('sub-') - y_df[rate_label] = y_df[rate_label].apply(pd.to_numeric, errors='raise') + y_df.subject_id = y_df.subject_id.str.lstrip("sub-") + y_df[rate_label] = y_df[rate_label].apply(pd.to_numeric, errors="raise") if len(rate_label) == 2: np.random.seed(42) ratermask_1 = ~np.isnan(y_df[[rate_label[0]]].values.ravel()) ratermask_2 = ~np.isnan(y_df[[rate_label[1]]].values.ravel()) - all_rated = (ratermask_1 & ratermask_2) + all_rated = ratermask_1 & ratermask_2 mergey = np.array(y_df[[rate_label[0]]].values.ravel().tolist()) mergey[ratermask_2] = y_df[[rate_label[1]]].values.ravel()[ratermask_2] - subsmpl = np.random.choice(np.where(all_rated)[0], int(0.5 * np.sum(all_rated)), - replace=False) + subsmpl = np.random.choice( + np.where(all_rated)[0], int(0.5 * np.sum(all_rated)), replace=False + ) all_rated[subsmpl] = False mergey[all_rated] = y_df[[rate_label[0]]].values.ravel()[all_rated] - y_df['merged_ratings'] = mergey.astype(int) + y_df["merged_ratings"] = mergey.astype(int) # Set default name if collapse: - cols = [('indv_%s' % c) if c.startswith('rater') else - c for c in y_df.columns.ravel().tolist()] - cols[y_df.columns.get_loc('merged_ratings')] = rate_label[0] + cols = [ + ("indv_%s" % c) if c.startswith("rater") else c + for c in y_df.columns.ravel().tolist() + ] + cols[y_df.columns.get_loc("merged_ratings")] = rate_label[0] y_df.columns = cols output_labels = [rate_label[0]] else: output_labels = rate_label - output_labels.insert(0, 'merged_ratings') + output_labels.insert(0, "merged_ratings") if binarize: mask = y_df[output_labels[0]] >= 0 y_df.loc[mask, output_labels[0]] = 0 y_df.loc[~mask, output_labels[0]] = 1 - if 'site' in y_df.columns.ravel().tolist(): - output_labels.insert(0, 'site') + if "site" in y_df.columns.ravel().tolist(): + output_labels.insert(0, "site") elif site_name is not None: - y_df['site'] = [site_name] * len(y_df) - output_labels.insert(0, 'site') + y_df["site"] = [site_name] * len(y_df) + output_labels.insert(0, "site") return y_df[bids_comps_present + output_labels] -def read_dataset(feat_file, label_file, merged_name=None, - binarize=True, site_name=None, rate_label='rater_1', - 
rate_selection='random'): +def read_dataset( + feat_file, + label_file, + merged_name=None, + binarize=True, + site_name=None, + rate_label="rater_1", + rate_selection="random", +): """ Reads in the features and labels """ x_df, feat_names, _ = read_iqms(feat_file) - y_df = read_labels(label_file, rate_label, binarize, collapse=True, - site_name=site_name, rate_selection=rate_selection) + y_df = read_labels( + label_file, + rate_label, + binarize, + collapse=True, + site_name=site_name, + rate_selection=rate_selection, + ) if isinstance(rate_label, (list, tuple)): rate_label = rate_label[0] @@ -183,24 +206,24 @@ def read_dataset(feat_file, label_file, merged_name=None, bids_comps_y = [bit for bit in bids_comps if bit in bids_comps_y] if bids_comps_x != bids_comps_y: - raise RuntimeError('Labels and features cannot be merged') + raise RuntimeError("Labels and features cannot be merged") - x_df['bids_ids'] = x_df.subject_id.values.copy() - y_df['bids_ids'] = y_df.subject_id.values.copy() + x_df["bids_ids"] = x_df.subject_id.values.copy() + y_df["bids_ids"] = y_df.subject_id.values.copy() for comp in bids_comps_x[1:]: - x_df['bids_ids'] = x_df.bids_ids.str.cat(x_df.loc[:, comp].astype(str), sep='_') - y_df['bids_ids'] = y_df.bids_ids.str.cat(y_df.loc[:, comp].astype(str), sep='_') + x_df["bids_ids"] = x_df.bids_ids.str.cat(x_df.loc[:, comp].astype(str), sep="_") + y_df["bids_ids"] = y_df.bids_ids.str.cat(y_df.loc[:, comp].astype(str), sep="_") # Remove failed cases from Y, append new columns to X - y_df = y_df[y_df['bids_ids'].isin(list(x_df.bids_ids.values.ravel()))] + y_df = y_df[y_df["bids_ids"].isin(list(x_df.bids_ids.values.ravel()))] # Drop indexing column - del x_df['bids_ids'] - del y_df['bids_ids'] + del x_df["bids_ids"] + del y_df["bids_ids"] # Merge Y dataframe into X - x_df = pd.merge(x_df, y_df, on=bids_comps_x, how='left') + x_df = pd.merge(x_df, y_df, on=bids_comps_x, how="left") if merged_name is not None: x_df.to_csv(merged_name, index=False) @@ -209,7 +232,7 @@ def read_dataset(feat_file, label_file, merged_name=None, nan_labels = x_df[x_df[rate_label].isnull()].index.ravel().tolist() if nan_labels: config.loggers.interface.info( - f'Dropping {len(nan_labels)} samples for having non-numerical labels,' + f"Dropping {len(nan_labels)} samples for having non-numerical labels," ) x_df = x_df.drop(nan_labels) @@ -226,15 +249,16 @@ def read_dataset(feat_file, label_file, merged_name=None, ldist.append(int(np.sum(x_df[rate_label] == l))) config.loggers.interface.info( - 'Ratings distribution: %s (%s, %s)', - '/'.join(['%d' % x for x in ldist]), - '/'.join(['%.2f%%' % (100 * x / nsamples) for x in ldist]), - 'accept/exclude' if len(ldist) == 2 else 'exclude/doubtful/accept') + "Ratings distribution: %s (%s, %s)", + "/".join(["%d" % x for x in ldist]), + "/".join(["%.2f%%" % (100 * x / nsamples) for x in ldist]), + "accept/exclude" if len(ldist) == 2 else "exclude/doubtful/accept", + ) return x_df, feat_names -def balanced_leaveout(dataframe, site_column='site', rate_label='rater_1'): +def balanced_leaveout(dataframe, site_column="site", rate_label="rater_1"): sites = list(set(dataframe[[site_column]].values.ravel())) pos_draw = [] neg_draw = [] @@ -254,12 +278,11 @@ def balanced_leaveout(dataframe, site_column='site', rate_label='rater_1'): return dataframe, left_out -def zscore_dataset(dataframe, excl_columns=None, by='site', - njobs=-1): +def zscore_dataset(dataframe, excl_columns=None, by="site", njobs=-1): """ Returns a dataset zscored by the column given as argument """ from 
multiprocessing import Pool, cpu_count - config.loggers.interface.info('z-scoring dataset ...') + config.loggers.interface.info("z-scoring dataset ...") if njobs <= 0: njobs = cpu_count() @@ -294,7 +317,8 @@ def zscore_dataset(dataframe, excl_columns=None, by='site', if nan_columns: config.loggers.interface.warning( - f'Columns {", ".join(nan_columns)} contain NaNs after z-scoring.') + f'Columns {", ".join(nan_columns)} contain NaNs after z-scoring.' + ) zs_df[nan_columns] = dataframe[nan_columns].values return zs_df @@ -303,6 +327,6 @@ def zscore_dataset(dataframe, excl_columns=None, by='site', def zscore_site(args): """ z-scores only one site """ from scipy.stats import zscore + dataframe, columns, site = args - return zscore(dataframe.loc[dataframe.site == site, columns].values, - ddof=1, axis=0) + return zscore(dataframe.loc[dataframe.site == site, columns].values, ddof=1, axis=0) diff --git a/mriqc/classifier/helper.py b/mriqc/classifier/helper.py index 941d9b1d6..19c8d7020 100644 --- a/mriqc/classifier/helper.py +++ b/mriqc/classifier/helper.py @@ -1,8 +1,5 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Author: oesteban -# @Date: 2015-11-19 16:44:27 - +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: """ Cross-validation helper ^^^^^^^^^^^^^^^^^^^^^^^ @@ -21,11 +18,16 @@ from sklearn.pipeline import Pipeline from sklearn.preprocessing import LabelBinarizer from sklearn.metrics.scorer import check_scoring -from sklearn.model_selection import (RepeatedStratifiedKFold, GridSearchCV, RandomizedSearchCV, - PredefinedSplit) +from sklearn.model_selection import ( + RepeatedStratifiedKFold, + GridSearchCV, + RandomizedSearchCV, + PredefinedSplit, +) from sklearn.ensemble import RandomForestClassifier as RFC from sklearn.svm import SVC, LinearSVC from sklearn.multiclass import OneVsRestClassifier + # xgboost from xgboost import XGBClassifier @@ -35,18 +37,46 @@ from builtins import object -LOG = logging.getLogger('mriqc.classifier') +LOG = logging.getLogger("mriqc.classifier") LOG.setLevel(logging.INFO) FEATURE_NORM = [ - 'cjv', 'cnr', 'efc', 'fber', 'fwhm_avg', 'fwhm_x', 'fwhm_y', 'fwhm_z', - 'snr_csf', 'snr_gm', 'snr_total', 'snr_wm', 'snrd_csf', 'snrd_gm', 'snrd_total', 'snrd_wm', - 'summary_csf_mad', 'summary_csf_mean', 'summary_csf_median', - 'summary_csf_p05', 'summary_csf_p95', 'summary_csf_stdv', - 'summary_gm_k', 'summary_gm_mad', 'summary_gm_mean', 'summary_gm_median', - 'summary_gm_p05', 'summary_gm_p95', 'summary_gm_stdv', - 'summary_wm_k', 'summary_wm_mad', 'summary_wm_mean', 'summary_wm_median', - 'summary_wm_p05', 'summary_wm_p95', 'summary_wm_stdv' + "cjv", + "cnr", + "efc", + "fber", + "fwhm_avg", + "fwhm_x", + "fwhm_y", + "fwhm_z", + "snr_csf", + "snr_gm", + "snr_total", + "snr_wm", + "snrd_csf", + "snrd_gm", + "snrd_total", + "snrd_wm", + "summary_csf_mad", + "summary_csf_mean", + "summary_csf_median", + "summary_csf_p05", + "summary_csf_p95", + "summary_csf_stdv", + "summary_gm_k", + "summary_gm_mad", + "summary_gm_mean", + "summary_gm_median", + "summary_gm_p05", + "summary_gm_p95", + "summary_gm_stdv", + "summary_wm_k", + "summary_wm_mad", + "summary_wm_mean", + "summary_wm_median", + "summary_wm_p05", + "summary_wm_p95", + "summary_wm_stdv", ] @@ -55,9 +85,20 @@ class CVHelperBase(object): A base helper to build cross-validation schemes """ - def __init__(self, X, Y, param_file=None, n_jobs=-1, site_label='site', - rate_label=None, rate_selection='random', - scorer='roc_auc', multiclass=False, 
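As an aside, the zscore_dataset / zscore_site hunks in data.py above parallelize per-site standardization with a multiprocessing Pool. A minimal, single-process sketch of the same idea follows; the helper name `demo_zscore_by_site` is invented for illustration and this is not the project's code.

# Minimal sketch of per-site z-scoring, mirroring the
# zscore(values, ddof=1, axis=0) call used by zscore_site above.
import numpy as np
import pandas as pd
from scipy.stats import zscore

def demo_zscore_by_site(df, columns, by="site"):
    """Return a copy of *df* where *columns* are z-scored within each site."""
    out = df.copy()
    for site in out[by].unique():
        mask = out[by] == site
        # ddof=1 matches the sample standard deviation used in the diff.
        out.loc[mask, columns] = zscore(out.loc[mask, columns].values, ddof=1, axis=0)
    return out

if __name__ == "__main__":
    demo = pd.DataFrame({
        "site": ["A", "A", "A", "B", "B", "B"],
        "cjv": [0.41, 0.39, 0.44, 0.52, 0.49, 0.55],
    })
    print(demo_zscore_by_site(demo, columns=["cjv"]))
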
verbosity=0, debug=False): + def __init__( + self, + X, + Y, + param_file=None, + n_jobs=-1, + site_label="site", + rate_label=None, + rate_selection="random", + scorer="roc_auc", + multiclass=False, + verbosity=0, + debug=False, + ): # Initialize some values self._param_file = param_file self.n_jobs = n_jobs @@ -67,12 +108,16 @@ def __init__(self, X, Y, param_file=None, n_jobs=-1, site_label='site', self._debug = debug if rate_label is None: - rate_label = ['rater_1', 'rater_2'] + rate_label = ["rater_1", "rater_2"] self._rate_column = rate_label[0] self._Xtrain, self._ftnames = read_dataset( - X, Y, rate_label=rate_label, rate_selection=rate_selection, - binarize=not self._multiclass) + X, + Y, + rate_label=rate_label, + rate_selection=rate_selection, + binarize=not self._multiclass, + ) self.sites = list(set(self._Xtrain[site_label].values.ravel())) self._scorer = scorer self._balanced_leaveout = True @@ -97,14 +142,30 @@ def predict(self, X, thres=0.5, return_proba=True): class CVHelper(CVHelperBase): - def __init__(self, X=None, Y=None, load_clf=None, param_file=None, n_jobs=-1, - site_label='site', rate_label=None, scorer='roc_auc', - b_leaveout=False, multiclass=False, verbosity=0, split='kfold', - debug=False, model='rfc', basename=None, nested_cv=False, - nested_cv_kfold=False, permutation_test=0): + def __init__( + self, + X=None, + Y=None, + load_clf=None, + param_file=None, + n_jobs=-1, + site_label="site", + rate_label=None, + scorer="roc_auc", + b_leaveout=False, + multiclass=False, + verbosity=0, + split="kfold", + debug=False, + model="rfc", + basename=None, + nested_cv=False, + nested_cv_kfold=False, + permutation_test=0, + ): if (X is None or Y is None) and load_clf is None: - raise RuntimeError('Either load_clf or X & Y should be supplied') + raise RuntimeError("Either load_clf or X & Y should be supplied") self._estimator = None self._Xtest = None @@ -126,9 +187,17 @@ def __init__(self, X=None, Y=None, load_clf=None, param_file=None, n_jobs=-1, self._base_name = basename[:24] else: super(CVHelper, self).__init__( - X, Y, param_file=param_file, n_jobs=n_jobs, - site_label=site_label, rate_label=rate_label, scorer=scorer, - multiclass=multiclass, verbosity=verbosity, debug=debug) + X, + Y, + param_file=param_file, + n_jobs=n_jobs, + site_label=site_label, + rate_label=rate_label, + scorer=scorer, + multiclass=multiclass, + verbosity=verbosity, + debug=debug, + ) @property def estimator(self): @@ -139,33 +208,34 @@ def Xtest(self): return self._Xtest def setXtest(self, X, Y): - self._Xtest, _ = read_dataset(X, Y, rate_label=self._rate_column, - binarize=not self._multiclass) - if 'site' not in self._Xtest.columns.ravel().tolist(): - self._Xtest['site'] = ['TestSite'] * len(self._Xtest) + self._Xtest, _ = read_dataset( + X, Y, rate_label=self._rate_column, binarize=not self._multiclass + ) + if "site" not in self._Xtest.columns.ravel().tolist(): + self._Xtest["site"] = ["TestSite"] * len(self._Xtest) def _gen_fname(self, suffix=None, ext=None): if ext is None: - ext = '' + ext = "" if suffix is None: - suffix = '' + suffix = "" - if not ext.startswith('.'): - ext = '.' + ext + if not ext.startswith("."): + ext = "." 
+ ext - if not suffix.startswith('_'): - suffix = '_' + suffix + if not suffix.startswith("_"): + suffix = "_" + suffix return self._base_name + suffix + ext def _get_model(self): - if self._model == 'xgb': + if self._model == "xgb": return XGBClassifier() - if self._model == 'svc_rbf': + if self._model == "svc_rbf": return SVC() - if self._model == 'svc_lin': + if self._model == "svc_lin": return LinearSVC() return RFC() @@ -175,71 +245,81 @@ def fit(self): Fits the cross-validation helper """ from .sklearn import preprocessing as mcsp - from .sklearn._split import (RobustLeavePGroupsOut as LeavePGroupsOut, - RepeatedBalancedKFold, RepeatedPartiallyHeldOutKFold) + from .sklearn._split import ( + RobustLeavePGroupsOut as LeavePGroupsOut, + RepeatedBalancedKFold, + RepeatedPartiallyHeldOutKFold, + ) if self._pickled: - LOG.info('Classifier was loaded from file, cancelling fitting.') + LOG.info("Classifier was loaded from file, cancelling fitting.") return if self._leaveout: raise NotImplementedError - LOG.info('CV [Setting up pipeline] - scorer: %s', self._scorer) + LOG.info("CV [Setting up pipeline] - scorer: %s", self._scorer) - feat_sel = self._ftnames + ['site'] + feat_sel = self._ftnames + ["site"] steps = [ - ('std', mcsp.BatchRobustScaler( - by='site', columns=[ft for ft in self._ftnames if ft in FEATURE_NORM])), - ('sel_cols', mcsp.PandasAdaptor(columns=self._ftnames + ['site'])), - ('ft_sites', mcsp.SiteCorrelationSelector()), - ('ft_noise', mcsp.CustFsNoiseWinnow()), - (self._model, self._get_model()) + ( + "std", + mcsp.BatchRobustScaler( + by="site", + columns=[ft for ft in self._ftnames if ft in FEATURE_NORM], + ), + ), + ("sel_cols", mcsp.PandasAdaptor(columns=self._ftnames + ["site"])), + ("ft_sites", mcsp.SiteCorrelationSelector()), + ("ft_noise", mcsp.CustFsNoiseWinnow()), + (self._model, self._get_model()), ] if self._multiclass: # If multiclass: binarize labels and wrap classifier - steps.insert(3, ('bin', LabelBinarizer())) + steps.insert(3, ("bin", LabelBinarizer())) steps[-1] = (steps[-1][0], OneVsRestClassifier(steps[-1][1])) pipe = Pipeline(steps) # Prepare data splits for CV fit_args = {} - if self._split == 'kfold': - kf_params = {} if not self._debug else {'n_splits': 2, 'n_repeats': 1} + if self._split == "kfold": + kf_params = {} if not self._debug else {"n_splits": 2, "n_repeats": 1} splits = RepeatedStratifiedKFold(**kf_params) - elif self._split == 'loso': + elif self._split == "loso": splits = LeavePGroupsOut(n_groups=1) - elif self._split == 'balanced-kfold': - kf_params = {'n_splits': 10, 'n_repeats': 3} + elif self._split == "balanced-kfold": + kf_params = {"n_splits": 10, "n_repeats": 3} if self._debug: - kf_params = {'n_splits': 3, 'n_repeats': 1} + kf_params = {"n_splits": 3, "n_repeats": 1} splits = RepeatedBalancedKFold(**kf_params) - elif self._split == 'batch': + elif self._split == "batch": # Get test label test_site = list(set(self._Xtest.site.values.ravel().tolist()))[0] # Merge test and train self._Xtrain = pd.concat((self._Xtrain, self._Xtest), axis=0) test_mask = self._Xtrain.site.values.ravel() == test_site - kf_params = {'n_splits': 5, 'n_repeats': 1} + kf_params = {"n_splits": 5, "n_repeats": 1} if self._debug: - kf_params = {'n_splits': 3, 'n_repeats': 1} - kf_params['groups'] = test_mask.astype(int).tolist() + kf_params = {"n_splits": 3, "n_repeats": 1} + kf_params["groups"] = test_mask.astype(int).tolist() splits = RepeatedPartiallyHeldOutKFold(**kf_params) train_y = self._Xtrain[[self._rate_column]].values.ravel().tolist() grid = 
RandomizedSearchCV( - pipe, self._get_params_dist(), + pipe, + self._get_params_dist(), n_iter=1 if self._debug else 50, error_score=0.5, refit=True, scoring=check_scoring(pipe, scoring=self._scorer), n_jobs=self.n_jobs, cv=splits, - verbose=self._verbosity) + verbose=self._verbosity, + ) if self._nestedcv or self._nestedcv_kfold: from .sklearn._validation import cross_val_score @@ -248,90 +328,121 @@ def fit(self): if self._nestedcv_kfold: outer_cv = RepeatedStratifiedKFold(n_repeats=1, n_splits=10) - n_iter = 32 if self._model in ['svc_lin', 'xgb'] else 50 + n_iter = 32 if self._model in ["svc_lin", "xgb"] else 50 grid = RandomizedSearchCV( - pipe, self._get_params_dist(), + pipe, + self._get_params_dist(), n_iter=n_iter if not self._debug else 1, error_score=0.5, refit=True, scoring=check_scoring(pipe, scoring=self._scorer), n_jobs=self.n_jobs, cv=splits, - verbose=self._verbosity) + verbose=self._verbosity, + ) nested_score, group_order = cross_val_score( grid, X=self._Xtrain, y=train_y, cv=outer_cv, - scoring=['roc_auc', 'accuracy', 'recall'], + scoring=["roc_auc", "accuracy", "recall"], ) nested_means = np.average(nested_score, axis=0) nested_std = np.std(nested_score, axis=0) - LOG.info('Nested CV [avg] %s=%.3f (+/-%.3f), accuracy=%.3f (+/-%.3f), ' - 'recall=%.3f (+/-%.3f).', self._scorer, nested_means[0], nested_std[0], - nested_means[1], nested_std[1], nested_means[2], nested_std[2]) - LOG.info('Nested CV %s=%s.', self._scorer, - ', '.join('%.3f' % v for v in nested_score[:, 0].tolist())) - LOG.info('Nested CV accuracy=%s.', - ', '.join('%.3f' % v for v in nested_score[:, 1].tolist())) - LOG.info('Nested CV groups=%s', group_order) + LOG.info( + "Nested CV [avg] %s=%.3f (+/-%.3f), accuracy=%.3f (+/-%.3f), " + "recall=%.3f (+/-%.3f).", + self._scorer, + nested_means[0], + nested_std[0], + nested_means[1], + nested_std[1], + nested_means[2], + nested_std[2], + ) + LOG.info( + "Nested CV %s=%s.", + self._scorer, + ", ".join("%.3f" % v for v in nested_score[:, 0].tolist()), + ) + LOG.info( + "Nested CV accuracy=%s.", + ", ".join("%.3f" % v for v in nested_score[:, 1].tolist()), + ) + LOG.info("Nested CV groups=%s", group_order) else: grid = GridSearchCV( - pipe, self._get_params(), + pipe, + self._get_params(), error_score=0.5, refit=True, scoring=check_scoring(pipe, scoring=self._scorer), n_jobs=self.n_jobs, cv=splits, - verbose=self._verbosity) + verbose=self._verbosity, + ) grid.fit(self._Xtrain, train_y, **fit_args) - np.savez(os.path.abspath(self._gen_fname(suffix='cvres', ext='npz')), - cv_results=grid.cv_results_) + np.savez( + os.path.abspath(self._gen_fname(suffix="cvres", ext="npz")), + cv_results=grid.cv_results_, + ) - best_pos = np.argmin(grid.cv_results_['rank_test_score']) + best_pos = np.argmin(grid.cv_results_["rank_test_score"]) # Save estimator and get its parameters self._estimator = grid.best_estimator_ cvparams = self._estimator.get_params() - LOG.info('CV [Best model] %s=%s, mean=%.3f, std=%.3f.', - self._scorer, grid.best_score_, - grid.cv_results_['mean_test_score'][best_pos], - grid.cv_results_['std_test_score'][best_pos], - ) - LOG.log(18, 'CV [Best model] parameters\n%s', cvparams) + LOG.info( + "CV [Best model] %s=%s, mean=%.3f, std=%.3f.", + self._scorer, + grid.best_score_, + grid.cv_results_["mean_test_score"][best_pos], + grid.cv_results_["std_test_score"][best_pos], + ) + LOG.log(18, "CV [Best model] parameters\n%s", cvparams) - if cvparams.get(self._model + '__oob_score', False): - LOG.info('CV [Best model] OOB %s=%.3f', self._scorer, - 
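For orientation, the fit() hunks above assemble a scikit-learn Pipeline (site-wise robust scaling, column selection, two feature-selection steps, then the classifier) and tune it with RandomizedSearchCV over repeated stratified folds. The stripped-down sketch below keeps that shape using only stock scikit-learn pieces: RobustScaler stands in for the project's BatchRobustScaler, the custom feature selectors are omitted, and the parameter distributions are invented for illustration.

# Minimal sketch of the pipeline + randomized-search pattern used in fit().
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV, RepeatedStratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import RobustScaler

pipe = Pipeline([
    ("std", RobustScaler()),                 # stand-in for BatchRobustScaler
    ("rfc", RandomForestClassifier()),
])

param_dist = {
    "std__with_centering": [True, False],
    "std__with_scaling": [True, False],
    "rfc__n_estimators": [100, 200, 400],    # illustrative values only
    "rfc__max_depth": [None, 5, 10],
}

search = RandomizedSearchCV(
    pipe,
    param_dist,
    n_iter=10,
    scoring="roc_auc",
    cv=RepeatedStratifiedKFold(n_splits=5, n_repeats=2),
    refit=True,
    error_score=0.5,
)

if __name__ == "__main__":
    rng = np.random.RandomState(0)
    X = rng.normal(size=(120, 6))
    y = (X[:, 0] + 0.5 * rng.normal(size=120) > 0).astype(int)
    search.fit(X, y)
    print(search.best_score_, search.best_params_)
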
self._estimator.named_steps[self._model].oob_score_) + if cvparams.get(self._model + "__oob_score", False): + LOG.info( + "CV [Best model] OOB %s=%.3f", + self._scorer, + self._estimator.named_steps[self._model].oob_score_, + ) # Report preprocessing selections - prep_msg = ' * Robust scaling (centering): %s.\n' % ( - 'enabled' if cvparams['std__with_centering'] else 'disabled') - prep_msg += ' * Robust scaling (scaling): %s.\n' % ( - 'enabled' if cvparams['std__with_scaling'] else 'disabled') - prep_msg += ' * SiteCorrelation feature selection: %s.\n' % ( - 'disabled' if cvparams['ft_sites__disable'] else 'enabled') - prep_msg += ' * Winnow feature selection: %s.\n' % ( - 'disabled' if cvparams['ft_noise__disable'] else 'enabled') + prep_msg = " * Robust scaling (centering): %s.\n" % ( + "enabled" if cvparams["std__with_centering"] else "disabled" + ) + prep_msg += " * Robust scaling (scaling): %s.\n" % ( + "enabled" if cvparams["std__with_scaling"] else "disabled" + ) + prep_msg += " * SiteCorrelation feature selection: %s.\n" % ( + "disabled" if cvparams["ft_sites__disable"] else "enabled" + ) + prep_msg += " * Winnow feature selection: %s.\n" % ( + "disabled" if cvparams["ft_noise__disable"] else "enabled" + ) selected = np.array(feat_sel).copy() - if not cvparams['ft_sites__disable']: - sitesmask = self._estimator.named_steps['ft_sites'].mask_ + if not cvparams["ft_sites__disable"]: + sitesmask = self._estimator.named_steps["ft_sites"].mask_ selected = self._Xtrain[feat_sel].columns.ravel()[sitesmask] - if not cvparams['ft_noise__disable']: - winnowmask = self._estimator.named_steps['ft_noise'].mask_ + if not cvparams["ft_noise__disable"]: + winnowmask = self._estimator.named_steps["ft_noise"].mask_ selected = selected[winnowmask] selected = selected.tolist() - if 'site' in selected: - selected.remove('site') + if "site" in selected: + selected.remove("site") - LOG.info('CV [Preprocessing]:\n%s * Features selected: %s.', - prep_msg, ', '.join(['"%s"' % f for f in selected])) + LOG.info( + "CV [Preprocessing]:\n%s * Features selected: %s.", + prep_msg, + ", ".join(['"%s"' % f for f in selected]), + ) # If leaveout, test and refit if self._leaveout: @@ -342,23 +453,24 @@ def fit(self): def _fit_leaveout(self, leaveout_x, leaveout_y): - target_names = ['accept', 'exclude'] + target_names = ["accept", "exclude"] if self._multiclass: - target_names = ['exclude', 'doubtful', 'accept'] + target_names = ["exclude", "doubtful", "accept"] - LOG.info('Testing on left-out, balanced subset ...') + LOG.info("Testing on left-out, balanced subset ...") # Predict _, pred_y = self.predict(leaveout_x) - LOG.info('Classification report:\n%s', - slm.classification_report(leaveout_y, pred_y, - target_names=target_names)) + LOG.info( + "Classification report:\n%s", + slm.classification_report(leaveout_y, pred_y, target_names=target_names), + ) score = self._score(leaveout_x, leaveout_y) - LOG.info('Performance on balanced left-out (%s=%f)', self._scorer, score) + LOG.info("Performance on balanced left-out (%s=%f)", self._scorer, score) # Rewrite clf - LOG.info('Fitting full model (train + balanced left-out) ...') + LOG.info("Fitting full model (train + balanced left-out) ...") # Features may change the robust normalization # self._estimator.rfc__warm_start = True test_yall = self._Xtrain[[self._rate_column]].values.ravel().tolist() @@ -366,13 +478,14 @@ def _fit_leaveout(self, leaveout_x, leaveout_y): test_yall = LabelBinarizer().fit_transform(test_yall) self._estimator = self._estimator.fit(self._Xtrain, 
test_yall) - LOG.info('Testing on left-out with full model, balanced subset ...') + LOG.info("Testing on left-out with full model, balanced subset ...") _, pred_y = self.predict(leaveout_x) - LOG.info('Classification report:\n%s', - slm.classification_report(leaveout_y, pred_y, - target_names=target_names)) + LOG.info( + "Classification report:\n%s", + slm.classification_report(leaveout_y, pred_y, target_names=target_names), + ) score = self._score(leaveout_x, leaveout_y) - LOG.info('Performance on balanced left-out (%s=%f)', self._scorer, score) + LOG.info("Performance on balanced left-out (%s=%f)", self._scorer, score) def fit_full(self): """ @@ -380,7 +493,7 @@ def fit_full(self): from the left-out dataset """ if self._estimator is None: - raise RuntimeError('Model should be fit first') + raise RuntimeError("Model should be fit first") target_names = ["accept", "exclude"] X = pd.concat([self._Xtrain, self._Xtest], axis=0) @@ -390,27 +503,27 @@ def fit_full(self): labels_y = LabelBinarizer().fit_transform(labels_y) target_names = ["exclude", "doubtful", "accept"] - LOG.info('Fitting full model ...') + LOG.info("Fitting full model ...") self._estimator = self._estimator.fit(X, labels_y) - LOG.info('Testing on left-out with full model') + LOG.info("Testing on left-out with full model") pred_y = self._estimator.predict(X) - LOG.info('Classification report:\n%s', - slm.classification_report(labels_y, pred_y, - target_names=target_names)) + LOG.info( + "Classification report:\n%s", + slm.classification_report(labels_y, pred_y, target_names=target_names), + ) score = self._score(X, labels_y) - LOG.info('Full model performance on left-out (%s=%f)', self._scorer, score) + LOG.info("Full model performance on left-out (%s=%f)", self._scorer, score) - def evaluate(self, scoring=None, matrix=False, save_roc=False, - save_pred=False): + def evaluate(self, scoring=None, matrix=False, save_roc=False, save_pred=False): """ Evaluate the internal estimator on the test data """ if scoring is None: - scoring = ['accuracy'] + scoring = ["accuracy"] - LOG.info('Testing on evaluation (left-out) dataset ...') + LOG.info("Testing on evaluation (left-out) dataset ...") test_y = self._Xtest[[self._rate_column]].values.ravel() target_names = ["accept", "exclude"] @@ -421,31 +534,36 @@ def evaluate(self, scoring=None, matrix=False, save_roc=False, prob_y, pred_y = self.predict(self._Xtest) scores = [self._score(self._Xtest, test_y, scoring=s) for s in scoring] - LOG.info('Performance on evaluation set (%s)', - ', '.join(['%s=%.3f' % (n, s) for n, s in zip(scoring, scores)])) + LOG.info( + "Performance on evaluation set (%s)", + ", ".join(["%s=%.3f" % (n, s) for n, s in zip(scoring, scores)]), + ) pred_totals = np.sum(pred_y, 0).tolist() if prob_y.shape[1] <= 2: pred_totals = [len(pred_y) - pred_totals, pred_totals] - LOG.info('Predictions: %s', ' / '.join(( - '%d (%s)' % (n, c) for n, c in zip(pred_totals, target_names)))) + LOG.info( + "Predictions: %s", + " / ".join(("%d (%s)" % (n, c) for n, c in zip(pred_totals, target_names))), + ) if matrix: + LOG.info("Confusion matrix:\n%s", slm.confusion_matrix(test_y, pred_y)) LOG.info( - 'Confusion matrix:\n%s', slm.confusion_matrix( - test_y, pred_y)) - LOG.info( - 'Classification report:\n%s', slm.classification_report( - test_y, pred_y, target_names=target_names)) + "Classification report:\n%s", + slm.classification_report(test_y, pred_y, target_names=target_names), + ) if save_pred: - self._save_pred_table(self._Xtest, prob_y, pred_y, - suffix='data-test_pred') + 
self._save_pred_table(self._Xtest, prob_y, pred_y, suffix="data-test_pred") if save_roc: - plot_roc_curve(self._Xtest[[self._rate_column]].values.ravel(), prob_y, - self._gen_fname(suffix='data-test_roc', ext='png')) + plot_roc_curve( + self._Xtest[[self._rate_column]].values.ravel(), + prob_y, + self._gen_fname(suffix="data-test_roc", ext="png"), + ) # Run a permutation test if self._permutation_test: @@ -457,15 +575,25 @@ def evaluate(self, scoring=None, matrix=False, save_roc=False, test_fold = [-1] * len(self._Xtrain) + [0] * len(self._Xtest) permutation_scores = permutation_test_score( - self._estimator, concatenated_x, concatenated_y, - scoring='accuracy', cv=PredefinedSplit(test_fold), - n_permutations=self._permutation_test, n_jobs=1) + self._estimator, + concatenated_x, + concatenated_y, + scoring="accuracy", + cv=PredefinedSplit(test_fold), + n_permutations=self._permutation_test, + n_jobs=1, + ) - score = scores[scoring.index('accuracy')] - pvalue = (np.sum(permutation_scores - >= score) + 1.0) / (self._permutation_test + 1) - LOG.info('Permutation test (N=%d) for accuracy score %f (pvalue=%f)', - self._permutation_test, score, pvalue) + score = scores[scoring.index("accuracy")] + pvalue = (np.sum(permutation_scores >= score) + 1.0) / ( + self._permutation_test + 1 + ) + LOG.info( + "Permutation test (N=%d) for accuracy score %f (pvalue=%f)", + self._permutation_test, + score, + pvalue, + ) return scores @@ -480,11 +608,13 @@ def predict(self, X, thres=0.5, return_proba=True): """ - if self._model == 'svc_lin': + if self._model == "svc_lin": from sklearn.base import clone from sklearn.calibration import CalibratedClassifierCV - clf = CalibratedClassifierCV(clone(self._estimator).set_param( - **self._estimator.get_param())) + + clf = CalibratedClassifierCV( + clone(self._estimator).set_param(**self._estimator.get_param()) + ) train_y = self._Xtrain[[self._rate_column]].values.ravel().tolist() self._estimator = clf.fit(self._Xtrain, train_y) @@ -502,26 +632,26 @@ def predict(self, X, thres=0.5, return_proba=True): def predict_dataset(self, data, thres=0.5, save_pred=False, site=None): from .data import read_iqms + _xeval, _, _ = read_iqms(data) if site is None: - site = 'unseen' + site = "unseen" columns = _xeval.columns.ravel().tolist() - if 'site' not in columns: - _xeval['site'] = [site] * len(_xeval) - columns.append('site') + if "site" not in columns: + _xeval["site"] = [site] * len(_xeval) + columns.append("site") # Classifier is trained with rate_1 as last column - if 'rate_1' not in columns: - _xeval['rate_1'] = [np.nan] * len(_xeval) - columns.append('rate_1') + if "rate_1" not in columns: + _xeval["rate_1"] = [np.nan] * len(_xeval) + columns.append("rate_1") prob_y, pred_y = self.predict(_xeval[columns]) if save_pred: - self._save_pred_table(_xeval, prob_y, pred_y, - suffix='data-%s_pred' % site) + self._save_pred_table(_xeval, prob_y, pred_y, suffix="data-%s_pred" % site) return pred_y def _save_pred_table(self, sample, prob_y, pred_y, suffix): @@ -529,24 +659,23 @@ def _save_pred_table(self, sample, prob_y, pred_y, suffix): predf = sample[bidts].copy() if self._multiclass: - probs = ['proba_%d' % i - for i in list(range(prob_y.shape[1]))] - predf['pred_y'] = (np.argmax(pred_y, axis=1) - 1).astype(int) + probs = ["proba_%d" % i for i in list(range(prob_y.shape[1]))] + predf["pred_y"] = (np.argmax(pred_y, axis=1) - 1).astype(int) for i, col in enumerate(probs): predf[col] = prob_y[:, i] - cols = probs + ['pred_y'] + cols = probs + ["pred_y"] else: - cols = ['prob_y', 
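The evaluate() hunk above computes the permutation-test p-value with the usual add-one correction, (#{permuted scores >= observed} + 1) / (N + 1). A tiny self-contained illustration of just that arithmetic (the scores below are made up):

# Permutation-test p-value with the add-one correction used in evaluate().
import numpy as np

def permutation_pvalue(observed_score, permutation_scores):
    permutation_scores = np.asarray(permutation_scores)
    n_permutations = len(permutation_scores)
    return (np.sum(permutation_scores >= observed_score) + 1.0) / (n_permutations + 1)

if __name__ == "__main__":
    rng = np.random.RandomState(0)
    null_scores = rng.uniform(0.4, 0.6, size=1000)   # accuracy under shuffled labels
    print(permutation_pvalue(0.71, null_scores))     # small p-value: better than chance
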
'pred_y'] - predf['prob_y'] = prob_y[:, 1] - predf['pred_y'] = pred_y + cols = ["prob_y", "pred_y"] + predf["prob_y"] = prob_y[:, 1] + predf["pred_y"] = pred_y predf[bidts + cols].to_csv( - self._gen_fname(suffix=suffix, ext='csv'), - index=False) + self._gen_fname(suffix=suffix, ext="csv"), index=False + ) - def save(self, suffix='estimator', compress=3): + def save(self, suffix="estimator", compress=3): """ Pickle the estimator, adding the feature names http://scikit-learn.org/stable/modules/model_persistence.html @@ -555,15 +684,14 @@ def save(self, suffix='estimator', compress=3): from sklearn.externals.joblib import dump as savepkl # Store ftnames - setattr(self._estimator, '_ftnames', self._ftnames) + setattr(self._estimator, "_ftnames", self._ftnames) # Store normalization medians - setattr(self._estimator, '_batch_effect', self._batch_effect) + setattr(self._estimator, "_batch_effect", self._batch_effect) - filehandler = os.path.abspath( - self._gen_fname(suffix=suffix, ext='pklz')) + filehandler = os.path.abspath(self._gen_fname(suffix=suffix, ext="pklz")) - LOG.info('Saving classifier to: %s', filehandler) + LOG.info("Saving classifier to: %s", filehandler) savepkl(self._estimator, filehandler, compress=compress) def load(self, filehandler): @@ -573,9 +701,10 @@ def load(self, filehandler): """ from sklearn.externals.joblib import load as loadpkl + self._estimator = loadpkl(filehandler) - self._ftnames = getattr(self._estimator, '_ftnames') - self._batch_effect = getattr(self._estimator, '_batch_effect', None) + self._ftnames = getattr(self._estimator, "_ftnames") + self._batch_effect = getattr(self._estimator, "_batch_effect", None) self._pickled = True def _score(self, X, y, scoring=None, clf=None): @@ -593,41 +722,46 @@ def _get_params(self): # Some baseline parameters baseparam = { - 'std__by': ['site'], - 'std__columns': [[ft for ft in self._ftnames if ft in FEATURE_NORM]], - 'sel_cols__columns': [self._ftnames + ['site']], + "std__by": ["site"], + "std__columns": [[ft for ft in self._ftnames if ft in FEATURE_NORM]], + "sel_cols__columns": [self._ftnames + ["site"]], } # Load in classifier parameters clfparams = _load_parameters( - (pkgrf('mriqc', 'data/classifier_settings.yml') - if self._param_file is None else self._param_file) + ( + pkgrf("mriqc", "data/classifier_settings.yml") + if self._param_file is None + else self._param_file + ) ) # Read preprocessing parameters - if 'preproc' in clfparams: + if "preproc" in clfparams: preparams = [] - for el in clfparams['preproc']: + for el in clfparams["preproc"]: pcombination = {} for pref, subel in list(el.items()): for k, v in list(subel.items()): - pcombination[pref + '__' + k] = v + pcombination[pref + "__" + k] = v preparams.append(pcombination) else: - preparams = [{ - 'std__with_centering': [True], - 'std__with_scaling': [True], - 'ft_sites__disable': [False], - 'ft_noise__disable': [False], - }] + preparams = [ + { + "std__with_centering": [True], + "std__with_scaling": [True], + "ft_sites__disable": [False], + "ft_noise__disable": [False], + } + ] # Set base parameters preparams = [{**baseparam, **prep} for prep in preparams] # Extract this model parameters - prefix = self._model + '__' + prefix = self._model + "__" if self._multiclass: - prefix += 'estimator__' + prefix += "estimator__" modparams = {prefix + k: v for k, v in list(clfparams[self._model][0].items())} # Merge model parameters + preprocessing @@ -641,33 +775,36 @@ def _get_params(self): def _get_params_dist(self): preparams = { - 'std__by': ['site'], - 
'std__with_centering': [True, False], - 'std__with_scaling': [True, False], - 'std__columns': [[ft for ft in self._ftnames if ft in FEATURE_NORM]], - 'sel_cols__columns': [self._ftnames + ['site']], - 'ft_sites__disable': [False, True], - 'ft_noise__disable': [False, True], + "std__by": ["site"], + "std__with_centering": [True, False], + "std__with_scaling": [True, False], + "std__columns": [[ft for ft in self._ftnames if ft in FEATURE_NORM]], + "sel_cols__columns": [self._ftnames + ["site"]], + "ft_sites__disable": [False, True], + "ft_noise__disable": [False, True], } - prefix = self._model + '__' + prefix = self._model + "__" if self._multiclass: - prefix += 'estimator__' + prefix += "estimator__" clfparams = _load_parameters( - (pkgrf('mriqc', 'data/model_selection.yml') - if self._param_file is None else self._param_file) + ( + pkgrf("mriqc", "data/model_selection.yml") + if self._param_file is None + else self._param_file + ) ) modparams = {prefix + k: v for k, v in list(clfparams[self._model][0].items())} if self._debug: preparams = { - 'std__by': ['site'], - 'std__with_centering': [True], - 'std__with_scaling': [True], - 'std__columns': [[ft for ft in self._ftnames if ft in FEATURE_NORM]], - 'sel_cols__columns': [self._ftnames + ['site']], - 'ft_sites__disable': [True], - 'ft_noise__disable': [True], + "std__by": ["site"], + "std__with_centering": [True], + "std__with_scaling": [True], + "std__columns": [[ft for ft in self._ftnames if ft in FEATURE_NORM]], + "sel_cols__columns": [self._ftnames + ["site"]], + "ft_sites__disable": [True], + "ft_noise__disable": [True], } modparams = {k: [v[0]] for k, v in list(modparams.items())} @@ -678,6 +815,7 @@ def _load_parameters(param_file): """Load parameters from file""" import yaml from io import open + with open(param_file) as paramfile: parameters = yaml.load(paramfile) return parameters diff --git a/mriqc/classifier/sklearn/__init__.py b/mriqc/classifier/sklearn/__init__.py index 0c8e0f7ba..af85ec7a6 100644 --- a/mriqc/classifier/sklearn/__init__.py +++ b/mriqc/classifier/sklearn/__init__.py @@ -6,7 +6,7 @@ from ._split import RobustLeavePGroupsOut __all__ = [ - 'ModelParameterGrid', - 'ModelAndGridSearchCV', - 'RobustLeavePGroupsOut', + "ModelParameterGrid", + "ModelAndGridSearchCV", + "RobustLeavePGroupsOut", ] diff --git a/mriqc/classifier/sklearn/_split.py b/mriqc/classifier/sklearn/_split.py index ba393680c..f8508625e 100644 --- a/mriqc/classifier/sklearn/_split.py +++ b/mriqc/classifier/sklearn/_split.py @@ -1,10 +1,11 @@ import numpy as np from sklearn.utils import indexable -from sklearn.model_selection import (LeavePGroupsOut, StratifiedKFold) +from sklearn.model_selection import LeavePGroupsOut, StratifiedKFold from sklearn.model_selection._split import _RepeatedSplits import logging -LOG = logging.getLogger('mriqc.classifier') + +LOG = logging.getLogger("mriqc.classifier") class RobustLeavePGroupsOut(LeavePGroupsOut): @@ -28,11 +29,13 @@ def split(self, X, y=None, groups=None): if groups is None: from ..data import get_groups + groups, _ = get_groups(X) self._groups = groups - self._splits = list(super(RobustLeavePGroupsOut, self).split( - X, y=y, groups=groups)) + self._splits = list( + super(RobustLeavePGroupsOut, self).split(X, y=y, groups=groups) + ) rmfold = [] for i, (_, test_idx) in enumerate(self._splits): @@ -40,10 +43,14 @@ def split(self, X, y=None, groups=None): rmfold.append(i) if rmfold: - self._splits = [split for i, split in enumerate(self._splits) - if i not in rmfold] - LOG.warning('Some splits (%d) were 
dropped because one or more classes' - ' are totally missing', len(rmfold)) + self._splits = [ + split for i, split in enumerate(self._splits) if i not in rmfold + ] + LOG.warning( + "Some splits (%d) were dropped because one or more classes" + " are totally missing", + len(rmfold), + ) return self._splits @@ -73,8 +80,7 @@ def split(self, X, y, groups=None): for cls in classes_y: cls_index = test_index[split_y == cls] if len(cls_index) > min_y: - cls_index = np.random.choice( - cls_index, size=min_y, replace=False) + cls_index = np.random.choice(cls_index, size=min_y, replace=False) new_index[cls * min_y:(cls + 1) * min_y] = cls_index yield train_index, new_index @@ -87,7 +93,8 @@ class RepeatedBalancedKFold(_RepeatedSplits): def __init__(self, n_splits=5, n_repeats=10, random_state=None): super(RepeatedBalancedKFold, self).__init__( - BalancedKFold, n_repeats, random_state, n_splits=n_splits) + BalancedKFold, n_repeats, random_state, n_splits=n_splits + ) class PartiallyHeldOutKFold(StratifiedKFold): @@ -95,11 +102,13 @@ class PartiallyHeldOutKFold(StratifiedKFold): A K-Fold split on the test set where the train splits are augmented with the original train set (in whole). """ + def __init__(self, n_splits=3, shuffle=False, random_state=None, groups=None): self._splits = None self._groups = groups super(PartiallyHeldOutKFold, self).__init__( - n_splits=n_splits, shuffle=shuffle, random_state=random_state) + n_splits=n_splits, shuffle=shuffle, random_state=random_state + ) def split(self, X, y, groups=None): if groups is None: @@ -117,8 +126,7 @@ def split(self, X, y, groups=None): test_x = X[test_idx, :] test_y = np.array(y)[test_idx] - split = super(PartiallyHeldOutKFold, self).split( - test_x, test_y) + split = super(PartiallyHeldOutKFold, self).split(test_x, test_y) offset = test_idx[0] for test_train, test_test in split: @@ -133,5 +141,9 @@ class RepeatedPartiallyHeldOutKFold(_RepeatedSplits): def __init__(self, n_splits=5, n_repeats=10, random_state=None, groups=None): super(RepeatedPartiallyHeldOutKFold, self).__init__( - PartiallyHeldOutKFold, n_repeats, random_state, n_splits=n_splits, - groups=groups) + PartiallyHeldOutKFold, + n_repeats, + random_state, + n_splits=n_splits, + groups=groups, + ) diff --git a/mriqc/classifier/sklearn/_validation.py b/mriqc/classifier/sklearn/_validation.py index dce93a7ce..173ea8bdc 100644 --- a/mriqc/classifier/sklearn/_validation.py +++ b/mriqc/classifier/sklearn/_validation.py @@ -9,6 +9,7 @@ import time import numpy as np + # import scipy.sparse as sp from sklearn.base import is_classifier, clone @@ -20,15 +21,26 @@ from sklearn.exceptions import FitFailedWarning from sklearn.model_selection._split import check_cv from sklearn.model_selection._validation import _index_param_value + # from sklearn.preprocessing import LabelEncoder import logging -LOG = logging.getLogger('mriqc.classifier') - -def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, - n_jobs=1, verbose=0, fit_params=None, - pre_dispatch='2*n_jobs'): +LOG = logging.getLogger("mriqc.classifier") + + +def cross_val_score( + estimator, + X, + y=None, + groups=None, + scoring=None, + cv=None, + n_jobs=1, + verbose=0, + fit_params=None, + pre_dispatch="2*n_jobs", +): """ Evaluate a score by cross-validation """ @@ -42,38 +54,51 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, scorer = [check_scoring(estimator, scoring=s) for s in scoring] # We clone the estimator to make sure that all the folds are # independent, and that it is 
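RobustLeavePGroupsOut above wraps leave-one-site-out splitting and discards folds whose test set is missing a class entirely (hence the warning in the hunk). The exact membership test sits outside the hunk's context lines, so the sketch below only illustrates the general idea with plain scikit-learn pieces; the helper name `robust_group_splits` is invented and this is not the class's real implementation.

# Sketch: leave-one-group-out splits, keeping only folds whose test set
# still contains every class.
import numpy as np
from sklearn.model_selection import LeavePGroupsOut

def robust_group_splits(X, y, groups, n_groups=1):
    y = np.asarray(y)
    classes = set(y)
    splitter = LeavePGroupsOut(n_groups=n_groups)
    kept = []
    for train_idx, test_idx in splitter.split(X, y, groups=groups):
        if set(y[test_idx]) == classes:      # drop folds missing a class
            kept.append((train_idx, test_idx))
    return kept

if __name__ == "__main__":
    X = np.zeros((8, 2))
    y = [0, 1, 0, 1, 0, 0, 0, 0]             # site C has no positive labels
    groups = ["A", "A", "B", "B", "C", "C", "C", "C"]
    print(len(robust_group_splits(X, y, groups)))  # only sites A and B survive as test folds
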
pickle-able. - parallel = Parallel(n_jobs=n_jobs, verbose=verbose, - pre_dispatch=pre_dispatch) - scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer, - train, test, verbose, None, - fit_params) - for train, test in splits) + parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch) + scores = parallel( + delayed(_fit_and_score)( + clone(estimator), X, y, scorer, train, test, verbose, None, fit_params + ) + for train, test in splits + ) group_order = [] - if hasattr(cv, 'groups'): + if hasattr(cv, "groups"): group_order = [np.array(cv.groups)[test].tolist()[0] for _, test in splits] return np.squeeze(np.array(scores)), group_order -def _fit_and_score(estimator, X, y, scorer, train, test, verbose, - parameters, fit_params, return_train_score=False, - return_parameters=False, return_n_test_samples=False, - return_times=False, error_score='raise'): +def _fit_and_score( + estimator, + X, + y, + scorer, + train, + test, + verbose, + parameters, + fit_params, + return_train_score=False, + return_parameters=False, + return_n_test_samples=False, + return_times=False, + error_score="raise", +): """ Fit estimator and compute scores for a given dataset split. """ if verbose > 1: if parameters is None: - msg = '' + msg = "" else: - msg = '%s' % (', '.join('%s=%s' % (k, v) - for k, v in parameters.items())) - LOG.info("[CV] %s %s", msg, (64 - len(msg)) * '.') + msg = "%s" % (", ".join("%s=%s" % (k, v) for k, v in parameters.items())) + LOG.info("[CV] %s %s", msg, (64 - len(msg)) * ".") # Adjust length of sample weights fit_params = fit_params if fit_params is not None else {} - fit_params = dict([(k, _index_param_value(X, v, train)) - for k, v in fit_params.items()]) + fit_params = dict( + [(k, _index_param_value(X, v, train)) for k, v in fit_params.items()] + ) if parameters is not None: estimator.set_params(**parameters) @@ -93,34 +118,38 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, # Note fit time as time until error fit_time = time.time() - start_time score_time = 0.0 - if error_score == 'raise': + if error_score == "raise": raise elif isinstance(error_score, numbers.Number): test_score = error_score if return_train_score: train_score = error_score - warnings.warn("Classifier fit failed. The score on this train-test" - " partition for these parameters will be set to %f. " - "Details: \n%r" % (error_score, e), FitFailedWarning) + warnings.warn( + "Classifier fit failed. The score on this train-test" + " partition for these parameters will be set to %f. " + "Details: \n%r" % (error_score, e), + FitFailedWarning, + ) else: - raise ValueError("error_score must be the string 'raise' or a" - " numeric value. (Hint: if using 'raise', please" - " make sure that it has been spelled correctly.)") + raise ValueError( + "error_score must be the string 'raise' or a" + " numeric value. 
(Hint: if using 'raise', please" + " make sure that it has been spelled correctly.)" + ) else: fit_time = time.time() - start_time test_score = [_score(estimator, X_test, y_test, s) for s in scorer] score_time = time.time() - start_time - fit_time if return_train_score: - train_score = [_score(estimator, X_train, y_train, s) - for s in scorer] + train_score = [_score(estimator, X_train, y_train, s) for s in scorer] if verbose > 2: - msg += ", score=".join(('%f' % ts for ts in test_score)) + msg += ", score=".join(("%f" % ts for ts in test_score)) if verbose > 1: total_time = score_time + fit_time end_msg = "%s, total=%s" % (msg, logger.short_format_time(total_time)) - LOG.info("[CV] %s %s", (64 - len(end_msg)) * '.', end_msg) + LOG.info("[CV] %s %s", (64 - len(end_msg)) * ".", end_msg) ret = [train_score, test_score] if return_train_score else [test_score] @@ -139,7 +168,7 @@ def _score(estimator, X_test, y_test, scorer): score = scorer(estimator, X_test) else: score = scorer(estimator, X_test, y_test) - if hasattr(score, 'item'): + if hasattr(score, "item"): try: # e.g. unwrap memmapped scalars score = score.item() @@ -147,14 +176,25 @@ def _score(estimator, X_test, y_test, scorer): # non-scalar? pass if not isinstance(score, numbers.Number): - raise ValueError("scoring must return a number, got %s (%s) instead." - % (str(score), type(score))) + raise ValueError( + "scoring must return a number, got %s (%s) instead." + % (str(score), type(score)) + ) return score -def permutation_test_score(estimator, X, y, groups=None, cv=None, - n_permutations=100, n_jobs=1, random_state=0, - verbose=0, scoring=None): +def permutation_test_score( + estimator, + X, + y, + groups=None, + cv=None, + n_permutations=100, + n_jobs=1, + random_state=0, + verbose=0, + scoring=None, +): """ Evaluate the significance of a cross-validated score with permutations, as in test 1 of [Ojala2010]_. @@ -179,9 +219,10 @@ def permutation_test_score(estimator, X, y, groups=None, cv=None, # independent, and that it is pickle-able. 
permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)( delayed(_permutation_test_score)( - clone(estimator), X, _shuffle(y, groups, random_state), - groups, cv, scorer) - for _ in range(n_permutations)) + clone(estimator), X, _shuffle(y, groups, random_state), groups, cv, scorer + ) + for _ in range(n_permutations) + ) permutation_scores = np.array(permutation_scores) return permutation_scores @@ -204,6 +245,6 @@ def _shuffle(y, groups, random_state): else: indices = np.arange(len(groups)) for group in np.unique(groups): - this_mask = (groups == group) + this_mask = groups == group indices[this_mask] = random_state.permutation(indices[this_mask]) return safe_indexing(y, indices) diff --git a/mriqc/classifier/sklearn/cv_nested.py b/mriqc/classifier/sklearn/cv_nested.py index 8dbcf8ad2..a0575b902 100644 --- a/mriqc/classifier/sklearn/cv_nested.py +++ b/mriqc/classifier/sklearn/cv_nested.py @@ -23,23 +23,34 @@ from sklearn.model_selection._split import check_cv from sklearn.model_selection._search import ( - BaseSearchCV, check_scoring, indexable, - Parallel, delayed, defaultdict, rankdata + BaseSearchCV, + check_scoring, + indexable, + Parallel, + delayed, + defaultdict, + rankdata, ) from sklearn.model_selection._validation import ( - _score, _num_samples, _index_param_value, _safe_split, - FitFailedWarning, logger) + _score, + _num_samples, + _index_param_value, + _safe_split, + FitFailedWarning, + logger, +) import logging from .parameters import ModelParameterGrid from builtins import object, zip + try: from sklearn.utils.fixes import MaskedArray except ImportError: from numpy.ma import MaskedArray -LOG = logging.getLogger('mriqc.classifier') +LOG = logging.getLogger("mriqc.classifier") class ModelAndGridSearchCV(BaseSearchCV): @@ -47,15 +58,33 @@ class ModelAndGridSearchCV(BaseSearchCV): Adds model selection to the GridSearchCV """ - def __init__(self, param_grid, scoring=None, fit_params=None, - n_jobs=1, iid=True, refit=True, cv=None, verbose=0, - pre_dispatch='2*n_jobs', error_score='raise', - return_train_score=True): + def __init__( + self, + param_grid, + scoring=None, + fit_params=None, + n_jobs=1, + iid=True, + refit=True, + cv=None, + verbose=0, + pre_dispatch="2*n_jobs", + error_score="raise", + return_train_score=True, + ): super(ModelAndGridSearchCV, self).__init__( - estimator=None, scoring=scoring, fit_params=fit_params, - n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose, - pre_dispatch=pre_dispatch, error_score=error_score, - return_train_score=return_train_score) + estimator=None, + scoring=scoring, + fit_params=fit_params, + n_jobs=n_jobs, + iid=iid, + refit=refit, + cv=cv, + verbose=verbose, + pre_dispatch=pre_dispatch, + error_score=error_score, + return_train_score=return_train_score, + ) self.param_grid = param_grid self.best_model_ = None # _check_param_grid(param_grid) @@ -75,32 +104,53 @@ def _fit(self, X, y, groups, parameter_iterable): if self.verbose > 0 and isinstance(parameter_iterable, Sized): n_candidates = len(parameter_iterable) - LOG.info("Fitting %d folds for each of %d candidates, totalling" - " %d fits", n_splits, n_candidates, n_candidates * n_splits) + LOG.info( + "Fitting %d folds for each of %d candidates, totalling" " %d fits", + n_splits, + n_candidates, + n_candidates * n_splits, + ) pre_dispatch = self.pre_dispatch cv_iter = list(cv.split(X, y, groups)) out = Parallel( - n_jobs=self.n_jobs, verbose=self.verbose, - pre_dispatch=pre_dispatch - )(delayed(_model_fit_and_score)( - estimator, X, y, self.scoring, train, test, 
self.verbose, parameters, - fit_params=self.fit_params, - return_train_score=self.return_train_score, - return_n_test_samples=True, - return_times=True, return_parameters=True, - error_score=self.error_score) - for estimator, parameters in parameter_iterable - for train, test in cv_iter) + n_jobs=self.n_jobs, verbose=self.verbose, pre_dispatch=pre_dispatch + )( + delayed(_model_fit_and_score)( + estimator, + X, + y, + self.scoring, + train, + test, + self.verbose, + parameters, + fit_params=self.fit_params, + return_train_score=self.return_train_score, + return_n_test_samples=True, + return_times=True, + return_parameters=True, + error_score=self.error_score, + ) + for estimator, parameters in parameter_iterable + for train, test in cv_iter + ) # if one choose to see train score, "out" will contain train score info if self.return_train_score: - (train_scores, test_scores, test_sample_counts, - fit_time, score_time, parameters) = zip(*out) + ( + train_scores, + test_scores, + test_sample_counts, + fit_time, + score_time, + parameters, + ) = zip(*out) else: - (test_scores, test_sample_counts, - fit_time, score_time, parameters) = zip(*out) + (test_scores, test_sample_counts, fit_time, score_time, parameters) = zip( + *out + ) candidate_params = parameters[::n_splits] n_candidates = len(candidate_params) @@ -109,36 +159,41 @@ def _fit(self, X, y, groups, parameter_iterable): def _store(key_name, array, weights=None, splits=False, rank=False): """A small helper to store the scores/times to the cv_results_""" - array = np.array(array, dtype=np.float64).reshape(n_candidates, - n_splits) + array = np.array(array, dtype=np.float64).reshape(n_candidates, n_splits) if splits: for split_i in range(n_splits): - results["split%d_%s" - % (split_i, key_name)] = array[:, split_i] + results["split%d_%s" % (split_i, key_name)] = array[:, split_i] array_means = np.average(array, axis=1, weights=weights) - results['mean_%s' % key_name] = array_means + results["mean_%s" % key_name] = array_means # Weighted std is not directly available in numpy - array_stds = np.sqrt(np.average(( - array - array_means[:, np.newaxis]) ** 2, - axis=1, weights=weights)) - results['std_%s' % key_name] = array_stds + array_stds = np.sqrt( + np.average( + (array - array_means[:, np.newaxis]) ** 2, axis=1, weights=weights + ) + ) + results["std_%s" % key_name] = array_stds if rank: results["rank_%s" % key_name] = np.asarray( - rankdata(-array_means, method='min'), dtype=np.int32) + rankdata(-array_means, method="min"), dtype=np.int32 + ) # Computed the (weighted) mean and std for test scores alone # NOTE test_sample counts (weights) remain the same for all candidates - test_sample_counts = np.array(test_sample_counts[:n_splits], - dtype=np.int) - - _store('test_score', test_scores, splits=True, rank=True, - weights=test_sample_counts if self.iid else None) + test_sample_counts = np.array(test_sample_counts[:n_splits], dtype=np.int) + + _store( + "test_score", + test_scores, + splits=True, + rank=True, + weights=test_sample_counts if self.iid else None, + ) if self.return_train_score: - _store('train_score', train_scores, splits=True) - _store('fit_time', fit_time) - _store('score_time', score_time) + _store("train_score", train_scores, splits=True) + _store("fit_time", fit_time) + _store("score_time", score_time) best_index = np.flatnonzero(results["rank_test_score"] == 1)[0] best_parameters = candidate_params[best_index][1] @@ -146,10 +201,9 @@ def _store(key_name, array, weights=None, splits=False, rank=False): # Use one 
MaskedArray and mask all the places where the param is not # applicable for that candidate. Use defaultdict as each candidate may # not contain all the params - param_results = defaultdict(partial(MaskedArray, - np.empty(n_candidates,), - mask=True, - dtype=object)) + param_results = defaultdict( + partial(MaskedArray, np.empty(n_candidates,), mask=True, dtype=object) + ) for cand_i, params in enumerate(candidate_params): _, param_values = params for name, value in param_values.items(): @@ -161,7 +215,7 @@ def _store(key_name, array, weights=None, splits=False, rank=False): results.update(param_results) # Store a list of param dicts at the key 'params' - results['params'] = candidate_params + results["params"] = candidate_params self.cv_results_ = results self.best_index_ = best_index @@ -180,26 +234,38 @@ def _store(key_name, array, weights=None, splits=False, rank=False): return self -def _model_fit_and_score(estimator_str, X, y, scorer, train, test, verbose, - parameters, fit_params, return_train_score=False, - return_parameters=False, return_n_test_samples=False, - return_times=False, error_score='raise'): +def _model_fit_and_score( + estimator_str, + X, + y, + scorer, + train, + test, + verbose, + parameters, + fit_params, + return_train_score=False, + return_parameters=False, + return_n_test_samples=False, + return_times=False, + error_score="raise", +): """ """ if verbose > 1: - msg = '[CV model=%s]' % estimator_str.upper() + msg = "[CV model=%s]" % estimator_str.upper() if parameters is not None: - msg += ' %s' % (', '.join('%s=%s' % (k, v) - for k, v in parameters.items())) - LOG.info("%s %s", msg, (89 - len(msg)) * '.') + msg += " %s" % (", ".join("%s=%s" % (k, v) for k, v in parameters.items())) + LOG.info("%s %s", msg, (89 - len(msg)) * ".") estimator = _clf_build(estimator_str) # Adjust length of sample weights fit_params = fit_params if fit_params is not None else {} - fit_params = dict([(k, _index_param_value(X, v, train)) - for k, v in fit_params.items()]) + fit_params = dict( + [(k, _index_param_value(X, v, train)) for k, v in fit_params.items()] + ) if parameters is not None: estimator.set_params(**parameters) @@ -219,19 +285,24 @@ def _model_fit_and_score(estimator_str, X, y, scorer, train, test, verbose, # Note fit time as time until error fit_time = time.time() - start_time score_time = 0.0 - if error_score == 'raise': + if error_score == "raise": raise elif isinstance(error_score, numbers.Number): test_score = error_score if return_train_score: train_score = error_score - warnings.warn("Classifier fit failed. The score on this train-test" - " partition for these parameters will be set to %f. " - "Details: \n%r" % (error_score, e), FitFailedWarning) + warnings.warn( + "Classifier fit failed. The score on this train-test" + " partition for these parameters will be set to %f. " + "Details: \n%r" % (error_score, e), + FitFailedWarning, + ) else: - raise ValueError("error_score must be the string 'raise' or a" - " numeric value. (Hint: if using 'raise', please" - " make sure that it has been spelled correctly.)") + raise ValueError( + "error_score must be the string 'raise' or a" + " numeric value. 
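The _store helper reshaped above summarizes per-fold scores into a weighted mean, a weighted standard deviation, and a rank (rank 1 = best mean). A small numeric sketch of those three computations, with invented scores and fold weights:

# Sketch of the per-candidate summary done by _store. Scores are invented.
import numpy as np
from scipy.stats import rankdata

scores = np.array([[0.80, 0.78, 0.82],    # candidate 0, three folds
                   [0.85, 0.79, 0.90],    # candidate 1
                   [0.70, 0.72, 0.71]])   # candidate 2
weights = np.array([40, 35, 45])          # e.g. test-set sizes per fold (iid weighting)

means = np.average(scores, axis=1, weights=weights)
stds = np.sqrt(np.average((scores - means[:, np.newaxis]) ** 2, axis=1, weights=weights))
ranks = rankdata(-means, method="min").astype(int)   # negate so the largest mean gets rank 1

print(means, stds, ranks)                 # candidate 1 should come out rank 1
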
(Hint: if using 'raise', please" + " make sure that it has been spelled correctly.)" + ) else: fit_time = time.time() - start_time @@ -260,9 +331,19 @@ def _model_fit_and_score(estimator_str, X, y, scorer, train, test, verbose, def nested_fit_and_score( - estimator, X, y, scorer, train, test, verbose=1, - parameters=None, fit_params=None, return_train_score=False, - return_times=False, error_score='raise'): + estimator, + X, + y, + scorer, + train, + test, + verbose=1, + parameters=None, + fit_params=None, + return_train_score=False, + return_times=False, + error_score="raise", +): """ """ @@ -270,8 +351,9 @@ def nested_fit_and_score( # Adjust length of sample weights fit_params = fit_params if fit_params is not None else {} - fit_params = dict([(k, _index_param_value(X, v, train)) - for k, v in fit_params.items()]) + fit_params = dict( + [(k, _index_param_value(X, v, train)) for k, v in fit_params.items()] + ) if parameters is not None: estimator.set_params(**parameters) @@ -282,9 +364,15 @@ def nested_fit_and_score( X_test, y_test = _safe_split(estimator, X, y, test, train) if verbose > 1: - LOG.info('CV iteration: Xtrain=%d, Ytrain=%d/%d -- Xtest=%d, Ytest=%d/%d.', - len(X_train), len(X_train) - sum(y_train), sum(y_train), - len(X_test), len(X_test) - sum(y_test), sum(y_test)) + LOG.info( + "CV iteration: Xtrain=%d, Ytrain=%d/%d -- Xtest=%d, Ytest=%d/%d.", + len(X_train), + len(X_train) - sum(y_train), + sum(y_train), + len(X_test), + len(X_test) - sum(y_test), + sum(y_test), + ) try: if y_train is None: @@ -296,19 +384,25 @@ def nested_fit_and_score( # Note fit time as time until error fit_time = time.time() - start_time score_time = 0.0 - if error_score == 'raise': + if error_score == "raise": raise elif isinstance(error_score, numbers.Number): test_score = error_score if return_train_score: train_score = error_score - LOG.warning("Classifier fit failed. The score on this train-test" - " partition for these parameters will be set to %f. " - "Details: \n%r", error_score, e) + LOG.warning( + "Classifier fit failed. The score on this train-test" + " partition for these parameters will be set to %f. " + "Details: \n%r", + error_score, + e, + ) else: - raise ValueError("error_score must be the string 'raise' or a" - " numeric value. (Hint: if using 'raise', please" - " make sure that it has been spelled correctly.)") + raise ValueError( + "error_score must be the string 'raise' or a" + " numeric value. (Hint: if using 'raise', please" + " make sure that it has been spelled correctly.)" + ) else: fit_time = time.time() - start_time @@ -319,33 +413,41 @@ def nested_fit_and_score( test_score = _score(estimator, X_test, y_test, scorer) score_time = time.time() - start_time - fit_time else: - LOG.warning('Test set has no positive labels, scoring has been skipped ' - 'in this loop.') + LOG.warning( + "Test set has no positive labels, scoring has been skipped " + "in this loop." 
+ ) if return_train_score: train_score = _score(estimator, X_train, y_train, scorer) - acc_score = _score(estimator, X_test, y_test, - check_scoring(estimator, scoring='accuracy')) + acc_score = _score( + estimator, X_test, y_test, check_scoring(estimator, scoring="accuracy") + ) if verbose > 0: total_time = score_time + fit_time if test_score is not None: - LOG.info('Iteration took %s, score=%f, accuracy=%f.', - short_format_time(total_time), test_score, acc_score) + LOG.info( + "Iteration took %s, score=%f, accuracy=%f.", + short_format_time(total_time), + test_score, + acc_score, + ) else: - LOG.info('Iteration took %s, score=None, accuracy=%f.', - short_format_time(total_time), acc_score) + LOG.info( + "Iteration took %s, score=None, accuracy=%f.", + short_format_time(total_time), + acc_score, + ) - ret = { - 'test': {'score': test_score, 'accuracy': acc_score} - } + ret = {"test": {"score": test_score, "accuracy": acc_score}} if return_train_score: - ret['train'] = {'score': train_score} + ret["train"] = {"score": train_score} if return_times: - ret['times'] = [fit_time, score_time] + ret["times"] = [fit_time, score_time] return ret, estimator @@ -353,9 +455,10 @@ def nested_fit_and_score( def _clf_build(clf_type): from sklearn import svm from sklearn.ensemble import RandomForestClassifier as RFC - if clf_type == 'svc_linear': + + if clf_type == "svc_linear": return svm.LinearSVC(C=1) - elif clf_type == 'svc_rbf': + elif clf_type == "svc_rbf": return svm.SVC(C=1) - elif clf_type == 'rfc': + elif clf_type == "rfc": return RFC() diff --git a/mriqc/classifier/sklearn/parameters.py b/mriqc/classifier/sklearn/parameters.py index 0b94db698..fa9e8b64b 100644 --- a/mriqc/classifier/sklearn/parameters.py +++ b/mriqc/classifier/sklearn/parameters.py @@ -24,8 +24,7 @@ def _len(indict): product = partial(reduce, operator.mul) - return sum(product(len(v) for v in p.values()) if p else 1 - for p in indict) + return sum(product(len(v) for v in p.values()) if p else 1 for p in indict) class ModelParameterGrid(object): @@ -105,7 +104,9 @@ def __iter__(self): def __len__(self): """Number of points on the grid.""" # Product function that can handle iterables (np.product can't). 
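ModelParameterGrid.__len__ (reformatted just below) counts candidate points by summing, over each estimator's list of parameter blocks, the product of the lengths of the value lists in each block, with an empty block counting as one point. A worked version of that count follows; the demo grid is hypothetical, not MRIQC's shipped settings.

# Number of grid points = sum over parameter blocks of the product of the
# lengths of each parameter's value list (an empty block counts as 1 point).
import operator
from functools import partial, reduce

def _n_points(settings):
    product = partial(reduce, operator.mul)
    return sum(product(len(v) for v in p.values()) if p else 1 for p in settings)

def grid_size(param_grid):
    return sum(_n_points(points) for entry in param_grid for points in entry.values())

demo_grid = [
    {"rfc": [{"n_estimators": [100, 200], "max_depth": [None, 5, 10]}, {}]},
    {"svc_lin": [{"C": [0.1, 1.0]}]},
]
print(grid_size(demo_grid))   # (2*3 + 1) + 2 = 9
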
- return sum(_len(points) for p in self.param_grid for estim, points in list(p.items())) + return sum( + _len(points) for p in self.param_grid for estim, points in list(p.items()) + ) def __getitem__(self, ind): """Get the parameters that would be ``ind``th in iteration @@ -145,4 +146,4 @@ def __getitem__(self, ind): out[key] = v_list[offset] return (estimator, out) - raise IndexError('ModelParameterGrid index out of range') + raise IndexError("ModelParameterGrid index out of range") diff --git a/mriqc/classifier/sklearn/preprocessing.py b/mriqc/classifier/sklearn/preprocessing.py index 982d5e70a..cdd6b0838 100644 --- a/mriqc/classifier/sklearn/preprocessing.py +++ b/mriqc/classifier/sklearn/preprocessing.py @@ -16,7 +16,8 @@ from sklearn.preprocessing import LabelBinarizer import logging -LOG = logging.getLogger('mriqc.classifier') + +LOG = logging.getLogger("mriqc.classifier") class PandasAdaptor(BaseEstimator, TransformerMixin): @@ -87,8 +88,7 @@ def transform(self, X, y=None): columns = self._numeric_cols(X) col_order = X.columns - scaled_x = pd.DataFrame(self._scaler.transform( - X[columns]), columns=columns) + scaled_x = pd.DataFrame(self._scaler.transform(X[columns]), columns=columns) unscaled_x = X.ix[:, ~X.columns.isin(columns)] return pd.concat([unscaled_x, scaled_x], axis=1)[col_order] @@ -107,7 +107,7 @@ class GroupsScaler(BaseEstimator, TransformerMixin): """ - def __init__(self, scaler, by='site'): + def __init__(self, scaler, by="site"): self.by = by self._base_scaler = scaler self._scalers = {} @@ -125,16 +125,14 @@ def fit(self, X, y=None): # Convert groups to IDs glist = list(set(groups)) - self._groups = np.array([glist.index(group) - for group in groups]) + self._groups = np.array([glist.index(group) for group in groups]) for gid, batch in enumerate(list(set(groups))): scaler = clone(self._base_scaler) mask = self._groups == gid if not np.any(mask): continue - self._scalers[batch] = scaler.fit( - X.ix[mask, self._colmask], y) + self._scalers[batch] = scaler.fit(X.ix[mask, self._colmask], y) return self @@ -142,7 +140,7 @@ def transform(self, X, y=None): if self.by in X.columns.ravel().tolist(): groups = X[[self.by]].values.ravel().tolist() else: - groups = ['Unknown'] * X.shape[0] + groups = ["Unknown"] * X.shape[0] glist = list(set(groups)) groups = np.array([glist.index(group) for group in groups]) @@ -154,15 +152,15 @@ def transform(self, X, y=None): continue scaler = self._scalers[batch] new_x.ix[mask, self._colmask] = scaler.transform( - X.ix[mask, self._colmask]) + X.ix[mask, self._colmask] + ) else: colmask = self._colmask if self.by in self._colnames and len(colmask) == len(self._colnames): del colmask[self._colnames.index(self.by)] scaler = clone(self._base_scaler) - new_x.ix[:, colmask] = scaler.fit_transform( - X.ix[:, colmask]) + new_x.ix[:, colmask] = scaler.fit_transform(X.ix[:, colmask]) return new_x @@ -180,7 +178,7 @@ class BatchScaler(GroupsScaler, TransformerMixin): """ - def __init__(self, scaler, by='site', columns=None): + def __init__(self, scaler, by="site", columns=None): super(BatchScaler, self).__init__(scaler, by=by) self.columns = columns self.ftmask_ = None @@ -206,24 +204,37 @@ def transform(self, X, y=None): columns = self.columns if self.by not in columns: - new_x[self.by] = ['Unknown'] * new_x.shape[0] + new_x[self.by] = ["Unknown"] * new_x.shape[0] new_x.ix[:, self.ftmask_] = super(BatchScaler, self).transform( - new_x[new_x.columns[self.ftmask_]], y) + new_x[new_x.columns[self.ftmask_]], y + ) return new_x class 
BatchRobustScaler(BatchScaler, TransformerMixin): - def __init__(self, by='site', columns=None, with_centering=True, with_scaling=True, - quantile_range=(25.0, 75.0), copy=True): + def __init__( + self, + by="site", + columns=None, + with_centering=True, + with_scaling=True, + quantile_range=(25.0, 75.0), + copy=True, + ): self.with_centering = with_centering self.with_scaling = with_scaling self.quantile_range = quantile_range self.copy = True super(BatchRobustScaler, self).__init__( - RobustScaler(with_centering=with_centering, with_scaling=with_scaling, - quantile_range=quantile_range), - by=by, columns=columns) + RobustScaler( + with_centering=with_centering, + with_scaling=with_scaling, + quantile_range=quantile_range, + ), + by=by, + columns=columns, + ) class CustFsNoiseWinnow(BaseEstimator, TransformerMixin): @@ -290,12 +301,12 @@ def fit(self, X, y, n_jobs=1): if clf_flag: clf = ExtraTreesClassifier( n_estimators=n_estimators, - criterion='gini', + criterion="gini", max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, - max_features='sqrt', + max_features="sqrt", max_leaf_nodes=None, min_impurity_decrease=1e-07, bootstrap=True, @@ -304,28 +315,41 @@ def fit(self, X, y, n_jobs=1): random_state=None, verbose=0, warm_start=False, - class_weight='balanced' + class_weight="balanced", ) else: clf = ExtraTreesRegressor( - n_estimators=n_estimators, criterion='mse', max_depth=None, - min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, - max_features='auto', max_leaf_nodes=None, min_impurity_decrease=1e-07, - bootstrap=False, oob_score=False, n_jobs=1, random_state=None, verbose=0, - warm_start=False) + n_estimators=n_estimators, + criterion="mse", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="auto", + max_leaf_nodes=None, + min_impurity_decrease=1e-07, + bootstrap=False, + oob_score=False, + n_jobs=1, + random_state=None, + verbose=0, + warm_start=False, + ) clf.fit(X[:, idx_keep], y) - LOG.debug('done fitting once') + LOG.debug("done fitting once") importances = clf.feature_importances_ k = 1 if np.all(importances[0:-1] > k * importances[-1]): - LOG.log(19, 'All features (%d) are better than noise', len(idx_keep) - 1) + LOG.log( + 19, "All features (%d) are better than noise", len(idx_keep) - 1 + ) # all features better than noise # comment out to force counter renditions of winnowing # noise_flag = False elif np.all(k * importances[-1] > importances[0:-1]): - LOG.warning('No features are better than noise') + LOG.warning("No features are better than noise") # noise better than all features aka no feature better than noise # Leave as separate if clause in case want to do something different than # when all feat > noise. 
Comment out to force counter renditions of winnowing @@ -348,8 +372,11 @@ def fit(self, X, y, n_jobs=1): self.importances_snr_ = importances[:-1] / importances[-1] self.idx_keep_ = idx_keep[:-1] self.mask_[self.idx_keep_] = True - LOG.info('Feature selection: %d of %d features better than noise feature', - self.mask_.astype(int).sum(), len(self.mask_)) + LOG.info( + "Feature selection: %d of %d features better than noise feature", + self.mask_.astype(int).sum(), + len(self.mask_), + ) return self def fit_transform(self, X, y=None): @@ -380,7 +407,8 @@ def transform(self, X, y=None): """ from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted - check_is_fitted(self, ['mask_'], all_or_any=all) + + check_is_fitted(self, ["mask_"], all_or_any=all) X = check_array(X) return X[:, self.mask_] @@ -392,8 +420,9 @@ class SiteCorrelationSelector(BaseEstimator, TransformerMixin): """ - def __init__(self, target_auc=0.6, disable=False, - max_iter=None, max_remove=0.7, site_col=-1): + def __init__( + self, target_auc=0.6, disable=False, max_iter=None, max_remove=0.7, site_col=-1 + ): self.disable = disable self.target_auc = target_auc self.mask_ = None @@ -435,7 +464,8 @@ def fit(self, X, y, n_jobs=1): y_input = LabelBinarizer().fit_transform(sites) X_train, X_test, y_train, y_test = train_test_split( - X_input, y_input, test_size=0.33, random_state=42) + X_input, y_input, test_size=0.33, random_state=42 + ) max_remove = n_feature - 5 if self.max_remove < 1.0: @@ -448,12 +478,12 @@ def fit(self, X, y, n_jobs=1): while True: clf = ExtraTreesClassifier( n_estimators=1000, - criterion='gini', + criterion="gini", max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, - max_features='sqrt', + max_features="sqrt", max_leaf_nodes=None, min_impurity_decrease=1e-07, bootstrap=True, @@ -462,13 +492,15 @@ def fit(self, X, y, n_jobs=1): random_state=None, verbose=0, warm_start=False, - class_weight='balanced' + class_weight="balanced", ).fit(X_train[:, self.mask_], y_train) score = roc_auc_score( - y_test, clf.predict(X_test[:, self.mask_]), - average='macro', - sample_weight=None) + y_test, + clf.predict(X_test[:, self.mask_]), + average="macro", + sample_weight=None, + ) if score < self.target_auc: break @@ -488,8 +520,9 @@ def fit(self, X, y, n_jobs=1): i += 1 - LOG.info('Feature selection: kept %d of %d features', - np.sum(self.mask_), n_feature) + LOG.info( + "Feature selection: kept %d of %d features", np.sum(self.mask_), n_feature + ) return self def fit_transform(self, X, y=None, n_jobs=1): @@ -520,8 +553,9 @@ def transform(self, X, y=None): """ from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted - check_is_fitted(self, ['mask_'], all_or_any=all) - if hasattr(X, 'columns'): + + check_is_fitted(self, ["mask_"], all_or_any=all) + if hasattr(X, "columns"): X = X.values X = check_array(X[:, self.mask_]) return X @@ -534,24 +568,25 @@ def _generate_noise(n_sample, y, clf_flag=True): if classification """ if clf_flag: - noise_feature = np.random.normal( - loc=0, scale=10.0, size=(n_sample, 1)) + noise_feature = np.random.normal(loc=0, scale=10.0, size=(n_sample, 1)) noise_score = roc_auc_score( - y, noise_feature, average='macro', sample_weight=None) + y, noise_feature, average="macro", sample_weight=None + ) while (noise_score > 0.6) or (noise_score < 0.4): - noise_feature = np.random.normal( - loc=0, scale=10.0, size=(n_sample, 1)) + noise_feature = np.random.normal(loc=0, scale=10.0, size=(n_sample, 1)) 
noise_score = roc_auc_score( - y, noise_feature, average='macro', sample_weight=None) + y, noise_feature, average="macro", sample_weight=None + ) else: - noise_feature = np.random.normal( - loc=0, scale=10.0, size=(n_sample, 1)) - while np.abs(np.corrcoef(noise_feature, y[:, np.newaxis], rowvar=0)[0][1]) > 0.05: - noise_feature = np.random.normal( - loc=0, scale=10.0, size=(n_sample, 1)) + noise_feature = np.random.normal(loc=0, scale=10.0, size=(n_sample, 1)) + while ( + np.abs(np.corrcoef(noise_feature, y[:, np.newaxis], rowvar=0)[0][1]) > 0.05 + ): + noise_feature = np.random.normal(loc=0, scale=10.0, size=(n_sample, 1)) return noise_feature + # DEPRECATED CODE # def find_gmed(dataframe, by='site', excl_columns=None): # sites = list(set(dataframe[[by]].values.ravel().tolist())) diff --git a/mriqc/cli/parser.py b/mriqc/cli/parser.py index 9fa714576..335cb5a5d 100644 --- a/mriqc/cli/parser.py +++ b/mriqc/cli/parser.py @@ -122,7 +122,7 @@ def _bids_filter(value): action="store", type=int, nargs="*", - help="filter input dataset by run id " "(only integer run ids are valid)", + help="filter input dataset by run id (only integer run ids are valid)", ) g_bids.add_argument( "--task-id", @@ -468,23 +468,27 @@ def parse_args(args=None, namespace=None): } config.workflow.inputs = { mod: files - for mod, files in collect_bids_data(config.execution.layout, **bids_filters).items() + for mod, files in collect_bids_data( + config.execution.layout, **bids_filters + ).items() if files } # Check the query is not empty if not list(config.workflow.inputs.values()): - _j = '\n *' - parser.error(f"""\ + _j = "\n *" + parser.error( + f"""\ Querying BIDS dataset at <{config.execution.bids_dir}> got an empty result. Please, check out your currently set filters: -{_j.join([''] + [': '.join((k, str(v))) for k, v in bids_filters.items()])}""") +{_j.join([''] + [': '.join((k, str(v))) for k, v in bids_filters.items()])}""" + ) # Check no DWI or others are sneaked into MRIQC - unknown_mods = set(config.workflow.inputs.keys()) - set(('T1w', 'T2w', 'bold')) + unknown_mods = set(config.workflow.inputs.keys()) - set(("T1w", "T2w", "bold")) if unknown_mods: parser.error( - 'MRIQC is unable to process the following modalities: ' + "MRIQC is unable to process the following modalities: " f'{", ".join(unknown_mods)}.' ) @@ -496,6 +500,7 @@ def parse_args(args=None, namespace=None): def _get_biggest_file_size_gb(files): import os + max_size = 0 for file in files: size = os.path.getsize(file) / (1024 ** 3) diff --git a/mriqc/cli/run.py b/mriqc/cli/run.py index ecc343a8d..489f0c5a2 100644 --- a/mriqc/cli/run.py +++ b/mriqc/cli/run.py @@ -22,12 +22,15 @@ def main(): # Set up participant level if "participant" in config.workflow.analysis_level: - config.loggers.cli.log(25, f""" + config.loggers.cli.log( + 25, + f""" Running MRIQC version {config.environment.version}: * BIDS dataset path: {config.execution.bids_dir}. * Output folder: {config.execution.output_dir}. * Analysis levels: {config.workflow.analysis_level}. -""") +""", + ) # CRITICAL Call build_workflow(config_file, retval) in a subprocess. # Because Python on Linux does not ever free virtual memory (VM), running the # workflow construction jailed within a process preempts excessive VM buildup. diff --git a/mriqc/cli/version.py b/mriqc/cli/version.py index 6b92193ea..ced4fd681 100644 --- a/mriqc/cli/version.py +++ b/mriqc/cli/version.py @@ -8,7 +8,7 @@ from .. 
import __version__ RELEASE_EXPIRY_DAYS = 14 -DATE_FMT = '%Y%m%d' +DATE_FMT = "%Y%m%d" def check_latest(): @@ -18,11 +18,11 @@ def check_latest(): latest = None date = None outdated = None - cachefile = Path.home() / '.cache' / 'mriqc' / 'latest' + cachefile = Path.home() / ".cache" / "mriqc" / "latest" cachefile.parent.mkdir(parents=True, exist_ok=True) try: - latest, date = cachefile.read_text().split('|') + latest, date = cachefile.read_text().split("|") except Exception: pass else: @@ -37,12 +37,12 @@ def check_latest(): if latest is None or outdated is True: try: - response = requests.get(url='https://pypi.org/pypi/mriqc/json', timeout=1.0) + response = requests.get(url="https://pypi.org/pypi/mriqc/json", timeout=1.0) except Exception: response = None if response and response.status_code == 200: - versions = [Version(rel) for rel in response.json()['releases'].keys()] + versions = [Version(rel) for rel in response.json()["releases"].keys()] versions = [rel for rel in versions if not rel.is_prerelease] if versions: latest = sorted(versions)[-1] @@ -51,7 +51,9 @@ def check_latest(): if latest is not None: try: - cachefile.write_text('|'.join(('%s' % latest, datetime.now().strftime(DATE_FMT)))) + cachefile.write_text( + "|".join(("%s" % latest, datetime.now().strftime(DATE_FMT))) + ) except Exception: pass @@ -63,13 +65,16 @@ def is_flagged(): # https://raw.githubusercontent.com/poldracklab/mriqc/master/.versions.json flagged = tuple() try: - response = requests.get(url="""\ -https://raw.githubusercontent.com/poldracklab/mriqc/master/.versions.json""", timeout=1.0) + response = requests.get( + url="""\ +https://raw.githubusercontent.com/poldracklab/mriqc/master/.versions.json""", + timeout=1.0, + ) except Exception: response = None if response and response.status_code == 200: - flagged = response.json().get('flagged', {}) or {} + flagged = response.json().get("flagged", {}) or {} if __version__ in flagged: return True, flagged[__version__] diff --git a/mriqc/cli/workflow.py b/mriqc/cli/workflow.py index f01bcdc53..9dc6a406e 100644 --- a/mriqc/cli/workflow.py +++ b/mriqc/cli/workflow.py @@ -15,8 +15,8 @@ def build_workflow(config_file, retval): from ..workflows.core import init_mriqc_wf config.load(config_file) - retval['return_code'] = 1 - retval['workflow'] = None + retval["return_code"] = 1 + retval["workflow"] = None retval["workflow"] = init_mriqc_wf() retval["return_code"] = int(retval["workflow"] is None) diff --git a/mriqc/config.py b/mriqc/config.py index bd310ce88..912bf7bae 100644 --- a/mriqc/config.py +++ b/mriqc/config.py @@ -71,6 +71,7 @@ try: from multiprocessing import set_start_method + set_start_method("forkserver") except RuntimeError: pass # context has been already set @@ -90,11 +91,13 @@ if not hasattr(sys, "_is_pytest_session"): sys._is_pytest_session = False # Trick to avoid sklearn's FutureWarnings # Disable all warnings in main and children processes only on production versions -if not any(( - "+" in __version__, - __version__.endswith(".dirty"), - os.getenv("MRIQC_DEV", "0").lower() in ("1", "on", "true", "y", "yes") -)): +if not any( + ( + "+" in __version__, + __version__.endswith(".dirty"), + os.getenv("MRIQC_DEV", "0").lower() in ("1", "on", "true", "y", "yes"), + ) +): os.environ["PYTHONWARNINGS"] = "ignore" logging.addLevelName(25, "IMPORTANT") # Add a new level between INFO and WARNING @@ -474,15 +477,19 @@ def init(cls): """ from nipype import config as ncfg + if not cls._init: from nipype import logging as nlogging + cls.workflow = 
nlogging.getLogger("nipype.workflow")
             cls.interface = nlogging.getLogger("nipype.interface")
             cls.utils = nlogging.getLogger("nipype.utils")

             if not len(cls.cli.handlers):
                 _handler = logging.StreamHandler(stream=sys.stdout)
-                _handler.setFormatter(logging.Formatter(fmt=cls._fmt, datefmt=cls._datefmt))
+                _handler.setFormatter(
+                    logging.Formatter(fmt=cls._fmt, datefmt=cls._datefmt)
+                )
                 cls.cli.addHandler(_handler)
             cls._init = True

diff --git a/mriqc/data/__init__.py b/mriqc/data/__init__.py
index 359eb3a79..ab080f73a 100644
--- a/mriqc/data/__init__.py
+++ b/mriqc/data/__init__.py
@@ -3,8 +3,4 @@
 """MRIQC resources."""
 from .config import Template, IndividualTemplate, GroupTemplate

-__all__ = [
-    'Template',
-    'IndividualTemplate',
-    'GroupTemplate'
-]
+__all__ = ["Template", "IndividualTemplate", "GroupTemplate"]
diff --git a/mriqc/data/config.py b/mriqc/data/config.py
index 185002a4d..2504edef9 100644
--- a/mriqc/data/config.py
+++ b/mriqc/data/config.py
@@ -1,10 +1,6 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
 # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
 # vi: set ft=python sts=4 ts=4 sw=4 et:
-"""
-Utilities: Jinja2 templates
-"""
+"""Utilities: Jinja2 templates."""
 from io import open  # pylint: disable=W0622

 import jinja2
@@ -16,11 +12,14 @@ class Template(object):
     Utility class for generating a config file from a jinja template.
     https://github.com/oesteban/endofday/blob/f2e79c625d648ef45b08cc1f11fd0bd84342d604/endofday/core/template.py
     """
+
     def __init__(self, template_str):
         self.template_str = template_str
         self.env = jinja2.Environment(
-            loader=jinja2.FileSystemLoader(searchpath='/'),
-            trim_blocks=True, lstrip_blocks=True)
+            loader=jinja2.FileSystemLoader(searchpath="/"),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        )

     def compile(self, configs):
         """Generates a string with the replacements"""
@@ -30,7 +29,7 @@ def compile(self, configs):
     def generate_conf(self, configs, path):
         """Saves the outcome after replacement on the template to file"""
         output = self.compile(configs)
-        with open(path, 'w+') as output_file:
+        with open(path, "w+") as output_file:
             output_file.write(output)


@@ -38,11 +37,13 @@ class IndividualTemplate(Template):
     """Specific template for the individual report"""

     def __init__(self):
-        super(IndividualTemplate, self).__init__(pkgrf('mriqc', 'data/reports/individual.html'))
+        super(IndividualTemplate, self).__init__(
+            pkgrf("mriqc", "data/reports/individual.html")
+        )


 class GroupTemplate(Template):
     """Specific template for the group report"""

     def __init__(self):
-        super(GroupTemplate, self).__init__(pkgrf('mriqc', 'data/reports/group.html'))
+        super(GroupTemplate, self).__init__(pkgrf("mriqc", "data/reports/group.html"))
diff --git a/mriqc/data/csv/raters_merge.py b/mriqc/data/csv/raters_merge.py
index 979a1a891..a7be864bc 100644
--- a/mriqc/data/csv/raters_merge.py
+++ b/mriqc/data/csv/raters_merge.py
@@ -11,39 +11,47 @@ def get_parser():
     """Entry point"""
     from argparse import ArgumentParser
     from argparse import RawTextHelpFormatter
-    parser = ArgumentParser(description='Merge ratings from two raters',
-                            formatter_class=RawTextHelpFormatter)
-    g_input = parser.add_argument_group('Inputs')
-    g_input.add_argument('rater_1', action='store')
-    g_input.add_argument('rater_2', action='store')
-    g_input.add_argument('--mapping-file', action='store')
-
-    g_outputs = parser.add_argument_group('Outputs')
-    g_outputs.add_argument('-o', '--output', action='store', default='merged.csv')
+
+    parser = ArgumentParser(
+        description="Merge 
ratings from two raters", + formatter_class=RawTextHelpFormatter, + ) + g_input = parser.add_argument_group("Inputs") + g_input.add_argument("rater_1", action="store") + g_input.add_argument("rater_2", action="store") + g_input.add_argument("--mapping-file", action="store") + + g_outputs = parser.add_argument_group("Outputs") + g_outputs.add_argument("-o", "--output", action="store", default="merged.csv") return parser def main(): opts = get_parser().parse_args() - rater_1 = pd.read_csv(opts.rater_1)[['participant_id', 'check-1']] - rater_2 = pd.read_csv(opts.rater_2)[['participant_id', 'check-1']] + rater_1 = pd.read_csv(opts.rater_1)[["participant_id", "check-1"]] + rater_2 = pd.read_csv(opts.rater_2)[["participant_id", "check-1"]] - rater_1.columns = ['participant_id', 'rater_1'] - rater_2.columns = ['participant_id', 'rater_2'] - merged = pd.merge(rater_1, rater_2, on='participant_id', how='outer') + rater_1.columns = ["participant_id", "rater_1"] + rater_2.columns = ["participant_id", "rater_2"] + merged = pd.merge(rater_1, rater_2, on="participant_id", how="outer") - idcol = 'participant_id' + idcol = "participant_id" if opts.mapping_file: - idcol = 'subject_id' + idcol = "subject_id" name_mapping = pd.read_csv( - opts.mapping_file, sep=' ', header=None, usecols=[0, 1]) - name_mapping.columns = ['subject_id', 'participant_id'] - name_mapping['participant_id'] = name_mapping.participant_id.astype(str) + '.gif' - merged = pd.merge(name_mapping, merged, on='participant_id', how='outer') + opts.mapping_file, sep=" ", header=None, usecols=[0, 1] + ) + name_mapping.columns = ["subject_id", "participant_id"] + name_mapping["participant_id"] = ( + name_mapping.participant_id.astype(str) + ".gif" + ) + merged = pd.merge(name_mapping, merged, on="participant_id", how="outer") - merged[[idcol, 'rater_1', 'rater_2']].sort_values(by=idcol).to_csv(opts.output, index=False) + merged[[idcol, "rater_1", "rater_2"]].sort_values(by=idcol).to_csv( + opts.output, index=False + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mriqc/interfaces/__init__.py b/mriqc/interfaces/__init__.py index 309a30999..dcf735235 100644 --- a/mriqc/interfaces/__init__.py +++ b/mriqc/interfaces/__init__.py @@ -1,11 +1,8 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -""" mriqc nipype interfaces """ +"""mriqc nipype interfaces """ -from .anatomical import \ - StructuralQC, ArtifactMask, ComputeQI2, Harmonize, RotationMask +from .anatomical import StructuralQC, ArtifactMask, ComputeQI2, Harmonize, RotationMask from .functional import FunctionalQC, Spikes from .bids import IQMFileSink from .viz import PlotMosaic, PlotContours, PlotSpikes @@ -14,18 +11,18 @@ __all__ = [ - 'ArtifactMask', - 'ComputeQI2', - 'ConformImage', - 'EnsureSize', - 'FunctionalQC', - 'Harmonize', - 'IQMFileSink', - 'PlotContours', - 'PlotMosaic', - 'PlotSpikes', - 'RotationMask', - 'Spikes', - 'StructuralQC', - 'UploadIQMs', + "ArtifactMask", + "ComputeQI2", + "ConformImage", + "EnsureSize", + "FunctionalQC", + "Harmonize", + "IQMFileSink", + "PlotContours", + "PlotMosaic", + "PlotSpikes", + "RotationMask", + "Spikes", + "StructuralQC", + "UploadIQMs", ] diff --git a/mriqc/interfaces/anatomical.py b/mriqc/interfaces/anatomical.py index 9d38576db..c74fda00f 100644 --- a/mriqc/interfaces/anatomical.py +++ b/mriqc/interfaces/anatomical.py @@ -10,41 +10,59 @@ from nipype.utils.filemanip import fname_presuffix from 
nipype.interfaces.base import ( - traits, TraitedSpec, File, isdefined, InputMultiPath, BaseInterfaceInputSpec, - SimpleInterface + traits, + TraitedSpec, + File, + isdefined, + InputMultiPath, + BaseInterfaceInputSpec, + SimpleInterface, ) from ..utils.misc import _flatten_dict -from ..qc.anatomical import (snr, snr_dietrich, cnr, fber, efc, art_qi1, - art_qi2, volume_fraction, rpve, summary_stats, - cjv, wm2max) +from ..qc.anatomical import ( + snr, + snr_dietrich, + cnr, + fber, + efc, + art_qi1, + art_qi2, + volume_fraction, + rpve, + summary_stats, + cjv, + wm2max, +) class StructuralQCInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='file to be plotted') - in_noinu = File(exists=True, mandatory=True, desc='image after INU correction') - in_segm = File(exists=True, mandatory=True, desc='segmentation file from FSL FAST') - in_bias = File(exists=True, mandatory=True, desc='bias file') - head_msk = File(exists=True, mandatory=True, desc='head mask') - air_msk = File(exists=True, mandatory=True, desc='air mask') - rot_msk = File(exists=True, mandatory=True, desc='rotation mask') - artifact_msk = File(exists=True, mandatory=True, desc='air mask') - in_pvms = InputMultiPath(File(exists=True), mandatory=True, - desc='partial volume maps from FSL FAST') - in_tpms = InputMultiPath(File(), desc='tissue probability maps from FSL FAST') - mni_tpms = InputMultiPath(File(), desc='tissue probability maps from FSL FAST') - in_fwhm = traits.List(traits.Float, mandatory=True, - desc='smoothness estimated with AFNI') + in_file = File(exists=True, mandatory=True, desc="file to be plotted") + in_noinu = File(exists=True, mandatory=True, desc="image after INU correction") + in_segm = File(exists=True, mandatory=True, desc="segmentation file from FSL FAST") + in_bias = File(exists=True, mandatory=True, desc="bias file") + head_msk = File(exists=True, mandatory=True, desc="head mask") + air_msk = File(exists=True, mandatory=True, desc="air mask") + rot_msk = File(exists=True, mandatory=True, desc="rotation mask") + artifact_msk = File(exists=True, mandatory=True, desc="air mask") + in_pvms = InputMultiPath( + File(exists=True), mandatory=True, desc="partial volume maps from FSL FAST" + ) + in_tpms = InputMultiPath(File(), desc="tissue probability maps from FSL FAST") + mni_tpms = InputMultiPath(File(), desc="tissue probability maps from FSL FAST") + in_fwhm = traits.List( + traits.Float, mandatory=True, desc="smoothness estimated with AFNI" + ) class StructuralQCOutputSpec(TraitedSpec): - summary = traits.Dict(desc='summary statistics per tissue') - icvs = traits.Dict(desc='intracranial volume (ICV) fractions') - rpve = traits.Dict(desc='partial volume fractions') - size = traits.Dict(desc='image sizes') - spacing = traits.Dict(desc='image sizes') - fwhm = traits.Dict(desc='full width half-maximum measure') - inu = traits.Dict(desc='summary statistics of the bias field') + summary = traits.Dict(desc="summary statistics per tissue") + icvs = traits.Dict(desc="intracranial volume (ICV) fractions") + rpve = traits.Dict(desc="partial volume fractions") + size = traits.Dict(desc="image sizes") + spacing = traits.Dict(desc="image sizes") + fwhm = traits.Dict(desc="full width half-maximum measure") + inu = traits.Dict(desc="summary statistics of the bias field") snr = traits.Dict snrd = traits.Dict cnr = traits.Float @@ -53,8 +71,8 @@ class StructuralQCOutputSpec(TraitedSpec): qi_1 = traits.Float wm2max = traits.Float cjv = traits.Float - out_qc = traits.Dict(desc='output 
flattened dictionary with all measures') - out_noisefit = File(exists=True, desc='plot of background noise and chi fitting') + out_qc = traits.Dict(desc="output flattened dictionary with all measures") + out_noisefit = File(exists=True, desc="plot of background noise and chi fitting") tpm_overlap = traits.Dict @@ -64,13 +82,13 @@ class StructuralQC(SimpleInterface): structural image given as input """ + input_spec = StructuralQCInputSpec output_spec = StructuralQCOutputSpec def _run_interface(self, runtime): # pylint: disable=R0914,E1101 imnii = nb.load(self.inputs.in_noinu) - erode = np.all(np.array(imnii.header.get_zooms()[:3], - dtype=np.float32) < 1.9) + erode = np.all(np.array(imnii.header.get_zooms()[:3], dtype=np.float32) < 1.9) # Load image corrected for INU inudata = np.nan_to_num(imnii.get_data()) @@ -93,111 +111,127 @@ def _run_interface(self, runtime): # pylint: disable=R0914,E1101 # Summary stats stats = summary_stats(inudata, pvmdata, airdata, erode=erode) - self._results['summary'] = stats + self._results["summary"] = stats # SNR snrvals = [] - self._results['snr'] = {} - for tlabel in ['csf', 'wm', 'gm']: - snrvals.append(snr(stats[tlabel]['median'], stats[tlabel]['stdv'], stats[tlabel]['n'])) - self._results['snr'][tlabel] = snrvals[-1] - self._results['snr']['total'] = float(np.mean(snrvals)) + self._results["snr"] = {} + for tlabel in ["csf", "wm", "gm"]: + snrvals.append( + snr(stats[tlabel]["median"], stats[tlabel]["stdv"], stats[tlabel]["n"]) + ) + self._results["snr"][tlabel] = snrvals[-1] + self._results["snr"]["total"] = float(np.mean(snrvals)) snrvals = [] - self._results['snrd'] = { - tlabel: snr_dietrich(stats[tlabel]['median'], stats['bg']['mad']) - for tlabel in ['csf', 'wm', 'gm']} - self._results['snrd']['total'] = float( - np.mean([val for _, val in list(self._results['snrd'].items())])) + self._results["snrd"] = { + tlabel: snr_dietrich(stats[tlabel]["median"], stats["bg"]["mad"]) + for tlabel in ["csf", "wm", "gm"] + } + self._results["snrd"]["total"] = float( + np.mean([val for _, val in list(self._results["snrd"].items())]) + ) # CNR - self._results['cnr'] = cnr( - stats['wm']['median'], stats['gm']['median'], - sqrt(sum(stats[k]['stdv'] ** 2 for k in ['bg', 'gm', 'wm'])) + self._results["cnr"] = cnr( + stats["wm"]["median"], + stats["gm"]["median"], + sqrt(sum(stats[k]["stdv"] ** 2 for k in ["bg", "gm", "wm"])), ) # FBER - self._results['fber'] = fber(inudata, headdata, rotdata) + self._results["fber"] = fber(inudata, headdata, rotdata) # EFC - self._results['efc'] = efc(inudata, rotdata) + self._results["efc"] = efc(inudata, rotdata) # M2WM - self._results['wm2max'] = wm2max(inudata, stats['wm']['median']) + self._results["wm2max"] = wm2max(inudata, stats["wm"]["median"]) # Artifacts - self._results['qi_1'] = art_qi1(airdata, artdata) + self._results["qi_1"] = art_qi1(airdata, artdata) # CJV - self._results['cjv'] = cjv( + self._results["cjv"] = cjv( # mu_wm, mu_gm, sigma_wm, sigma_gm - stats['wm']['median'], - stats['gm']['median'], - stats['wm']['mad'], - stats['gm']['mad'] + stats["wm"]["median"], + stats["gm"]["median"], + stats["wm"]["mad"], + stats["gm"]["mad"], ) # FWHM - fwhm = np.array(self.inputs.in_fwhm[:3]) / np.array(imnii.header.get_zooms()[:3]) - self._results['fwhm'] = { - 'x': float(fwhm[0]), 'y': float(fwhm[1]), 'z': float(fwhm[2]), - 'avg': float(np.average(fwhm))} + fwhm = np.array(self.inputs.in_fwhm[:3]) / np.array( + imnii.header.get_zooms()[:3] + ) + self._results["fwhm"] = { + "x": float(fwhm[0]), + "y": float(fwhm[1]), + "z": 
float(fwhm[2]), + "avg": float(np.average(fwhm)), + } # ICVs - self._results['icvs'] = volume_fraction(pvmdata) + self._results["icvs"] = volume_fraction(pvmdata) # RPVE - self._results['rpve'] = rpve(pvmdata, segdata) + self._results["rpve"] = rpve(pvmdata, segdata) # Image specs - self._results['size'] = {'x': int(inudata.shape[0]), - 'y': int(inudata.shape[1]), - 'z': int(inudata.shape[2])} - self._results['spacing'] = { - i: float(v) for i, v in zip( - ['x', 'y', 'z'], imnii.header.get_zooms()[:3])} + self._results["size"] = { + "x": int(inudata.shape[0]), + "y": int(inudata.shape[1]), + "z": int(inudata.shape[2]), + } + self._results["spacing"] = { + i: float(v) for i, v in zip(["x", "y", "z"], imnii.header.get_zooms()[:3]) + } try: - self._results['size']['t'] = int(inudata.shape[3]) + self._results["size"]["t"] = int(inudata.shape[3]) except IndexError: pass try: - self._results['spacing']['tr'] = float(imnii.header.get_zooms()[3]) + self._results["spacing"]["tr"] = float(imnii.header.get_zooms()[3]) except IndexError: pass # Bias bias = nb.load(self.inputs.in_bias).get_data()[segdata > 0] - self._results['inu'] = { - 'range': float(np.abs(np.percentile(bias, 95.) - np.percentile(bias, 5.))), - 'med': float(np.median(bias))} # pylint: disable=E1101 + self._results["inu"] = { + "range": float( + np.abs(np.percentile(bias, 95.0) - np.percentile(bias, 5.0)) + ), + "med": float(np.median(bias)), + } # pylint: disable=E1101 mni_tpms = [nb.load(tpm).get_data() for tpm in self.inputs.mni_tpms] in_tpms = [nb.load(tpm).get_data() for tpm in self.inputs.in_pvms] overlap = fuzzy_jaccard(in_tpms, mni_tpms) - self._results['tpm_overlap'] = { - 'csf': overlap[0], - 'gm': overlap[1], - 'wm': overlap[2] + self._results["tpm_overlap"] = { + "csf": overlap[0], + "gm": overlap[1], + "wm": overlap[2], } # Flatten the dictionary - self._results['out_qc'] = _flatten_dict(self._results) + self._results["out_qc"] = _flatten_dict(self._results) return runtime class ArtifactMaskInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='File to be plotted') - head_mask = File(exists=True, mandatory=True, desc='head mask') - rot_mask = File(exists=True, desc='a rotation mask') - nasion_post_mask = File(exists=True, mandatory=True, - desc='nasion to posterior of cerebellum mask') + in_file = File(exists=True, mandatory=True, desc="File to be plotted") + head_mask = File(exists=True, mandatory=True, desc="head mask") + rot_mask = File(exists=True, desc="a rotation mask") + nasion_post_mask = File( + exists=True, mandatory=True, desc="nasion to posterior of cerebellum mask" + ) class ArtifactMaskOutputSpec(TraitedSpec): out_hat_msk = File(exists=True, desc='output "hat" mask') - out_art_msk = File(exists=True, desc='output artifacts mask') + out_art_msk = File(exists=True, desc="output artifacts mask") out_air_msk = File(exists=True, desc='output "hat" mask, without artifacts') @@ -205,6 +239,7 @@ class ArtifactMask(SimpleInterface): """ Computes the artifact mask using the method described in [Mortamet2009]_. 
""" + input_spec = ArtifactMaskInputSpec output_spec = ArtifactMaskOutputSpec @@ -239,42 +274,46 @@ def _run_interface(self, runtime): qi1_img = artifact_mask(imdata, airdata, dist) fname, ext = op.splitext(op.basename(self.inputs.in_file)) - if ext == '.gz': + if ext == ".gz": fname, ext2 = op.splitext(fname) ext = ext2 + ext - self._results['out_hat_msk'] = op.abspath('{}_hat{}'.format(fname, ext)) - self._results['out_art_msk'] = op.abspath('{}_art{}'.format(fname, ext)) - self._results['out_air_msk'] = op.abspath('{}_air{}'.format(fname, ext)) + self._results["out_hat_msk"] = op.abspath("{}_hat{}".format(fname, ext)) + self._results["out_art_msk"] = op.abspath("{}_art{}".format(fname, ext)) + self._results["out_air_msk"] = op.abspath("{}_air{}".format(fname, ext)) hdr = imnii.header.copy() hdr.set_data_dtype(np.uint8) nb.Nifti1Image(qi1_img, imnii.affine, hdr).to_filename( - self._results['out_art_msk']) + self._results["out_art_msk"] + ) nb.Nifti1Image(airdata, imnii.affine, hdr).to_filename( - self._results['out_hat_msk']) + self._results["out_hat_msk"] + ) airdata[qi1_img > 0] = 0 nb.Nifti1Image(airdata, imnii.affine, hdr).to_filename( - self._results['out_air_msk']) + self._results["out_air_msk"] + ) return runtime class ComputeQI2InputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='File to be plotted') - air_msk = File(exists=True, mandatory=True, desc='air (without artifacts) mask') + in_file = File(exists=True, mandatory=True, desc="File to be plotted") + air_msk = File(exists=True, mandatory=True, desc="air (without artifacts) mask") class ComputeQI2OutputSpec(TraitedSpec): - qi2 = traits.Float(desc='computed QI2 value') - out_file = File(desc='output plot: noise fit') + qi2 = traits.Float(desc="computed QI2 value") + out_file = File(desc="output plot: noise fit") class ComputeQI2(SimpleInterface): """ Computes the artifact mask using the method described in [Mortamet2009]_. """ + input_spec = ComputeQI2InputSpec output_spec = ComputeQI2OutputSpec @@ -282,25 +321,28 @@ def _run_interface(self, runtime): imdata = nb.load(self.inputs.in_file).get_data() airdata = nb.load(self.inputs.air_msk).get_data() qi2, out_file = art_qi2(imdata, airdata) - self._results['qi2'] = qi2 - self._results['out_file'] = out_file + self._results["qi2"] = qi2 + self._results["out_file"] = out_file return runtime class HarmonizeInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='input data (after bias correction)') - wm_mask = File(exists=True, mandatory=True, desc='white-matter mask') - erodemsk = traits.Bool(True, usedefault=True, desc='erode mask') + in_file = File( + exists=True, mandatory=True, desc="input data (after bias correction)" + ) + wm_mask = File(exists=True, mandatory=True, desc="white-matter mask") + erodemsk = traits.Bool(True, usedefault=True, desc="erode mask") class HarmonizeOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='input data (after intensity harmonization)') + out_file = File(exists=True, desc="input data (after intensity harmonization)") class Harmonize(SimpleInterface): """ Computes the artifact mask using the method described in [Mortamet2009]_. 
""" + input_spec = HarmonizeInputSpec output_spec = HarmonizeOutputSpec @@ -321,28 +363,29 @@ def _run_interface(self, runtime): data = in_file.get_data() data *= 1000.0 / np.median(data[wm_mask > 0]) - out_file = fname_presuffix(self.inputs.in_file, - suffix='_harmonized', newpath='.') - in_file.__class__(data, in_file.affine, in_file.header).to_filename( - out_file) + out_file = fname_presuffix( + self.inputs.in_file, suffix="_harmonized", newpath="." + ) + in_file.__class__(data, in_file.affine, in_file.header).to_filename(out_file) - self._results['out_file'] = out_file + self._results["out_file"] = out_file return runtime class RotationMaskInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='input data') + in_file = File(exists=True, mandatory=True, desc="input data") class RotationMaskOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='rotation mask (if any)') + out_file = File(exists=True, desc="rotation mask (if any)") class RotationMask(SimpleInterface): """ Computes the artifact mask using the method described in [Mortamet2009]_. """ + input_spec = RotationMaskInputSpec output_spec = RotationMaskOutputSpec @@ -352,12 +395,11 @@ def _run_interface(self, runtime): mask = data <= 0 # Pad one pixel to control behavior on borders of binary_opening - mask = np.pad(mask, pad_width=(1,), mode='constant', constant_values=1) + mask = np.pad(mask, pad_width=(1,), mode="constant", constant_values=1) # Remove noise struc = nd.generate_binary_structure(3, 2) - mask = nd.binary_opening(mask, structure=struc).astype( - np.uint8) + mask = nd.binary_opening(mask, structure=struc).astype(np.uint8) # Remove small objects label_im, nb_labels = nd.label(mask) @@ -377,20 +419,19 @@ def _run_interface(self, runtime): out_img = in_file.__class__(mask, in_file.affine, in_file.header) out_img.header.set_data_dtype(np.uint8) - out_file = fname_presuffix(self.inputs.in_file, - suffix='_rotmask', newpath='.') + out_file = fname_presuffix(self.inputs.in_file, suffix="_rotmask", newpath=".") out_img.to_filename(out_file) - self._results['out_file'] = out_file + self._results["out_file"] = out_file return runtime -def artifact_mask(imdata, airdata, distance, zscore=10.): +def artifact_mask(imdata, airdata, distance, zscore=10.0): """Computes a mask of artifacts found in the air region""" from statsmodels.robust.scale import mad if not np.issubdtype(airdata.dtype, np.integer): - airdata[airdata < .95] = 0 - airdata[airdata > 0.] = 1 + airdata[airdata < 0.95] = 0 + airdata[airdata > 0.0] = 1 bg_img = imdata * airdata if np.sum((bg_img > 0).astype(np.uint8)) < 100: @@ -407,7 +448,7 @@ def artifact_mask(imdata, airdata, distance, zscore=10.): # contributing artifacts. qi1_img = np.zeros_like(bg_img) qi1_img[bg_img > zscore] = 1 - qi1_img[distance < .10] = 0 + qi1_img[distance < 0.10] = 0 # Create a structural element to be used in an opening operation. struc = nd.generate_binary_structure(3, 1) diff --git a/mriqc/interfaces/bids.py b/mriqc/interfaces/bids.py index bb95dffa3..573eda45e 100644 --- a/mriqc/interfaces/bids.py +++ b/mriqc/interfaces/bids.py @@ -4,28 +4,35 @@ import re import simplejson as json from nipype.interfaces.base import ( - traits, isdefined, TraitedSpec, DynamicTraitedSpec, BaseInterfaceInputSpec, - File, Undefined, Str, SimpleInterface + traits, + isdefined, + TraitedSpec, + DynamicTraitedSpec, + BaseInterfaceInputSpec, + File, + Undefined, + Str, + SimpleInterface, ) from .. 
import config from ..utils.misc import BIDS_COMP class IQMFileSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec): - in_file = Str(mandatory=True, desc='path of input file') - subject_id = Str(mandatory=True, desc='the subject id') - modality = Str(mandatory=True, desc='the qc type') + in_file = Str(mandatory=True, desc="path of input file") + subject_id = Str(mandatory=True, desc="the subject id") + modality = Str(mandatory=True, desc="the qc type") session_id = traits.Either(None, Str, usedefault=True) task_id = traits.Either(None, Str, usedefault=True) acq_id = traits.Either(None, Str, usedefault=True) rec_id = traits.Either(None, Str, usedefault=True) run_id = traits.Either(None, traits.Int, usedefault=True) - dataset = Str(desc='dataset identifier') + dataset = Str(desc="dataset identifier") metadata = traits.Dict() provenance = traits.Dict() - root = traits.Dict(desc='output root dictionary') - out_dir = File(desc='the output directory') + root = traits.Dict(desc="output root dictionary") + out_dir = File(desc="the output directory") _outputs = traits.Dict(value={}, usedefault=True) def __setattr__(self, key, value): @@ -40,13 +47,13 @@ def __setattr__(self, key, value): class IQMFileSinkOutputSpec(TraitedSpec): - out_file = File(desc='the output JSON file containing the IQMs') + out_file = File(desc="the output JSON file containing the IQMs") class IQMFileSink(SimpleInterface): input_spec = IQMFileSinkInputSpec output_spec = IQMFileSinkOutputSpec - expr = re.compile('^root[0-9]+$') + expr = re.compile("^root[0-9]+$") def __init__(self, fields=None, force_run=True, **inputs): super(IQMFileSink, self).__init__(**inputs) @@ -78,17 +85,16 @@ def _gen_outfile(self): # Crawl back to the BIDS root path = Path(self.inputs.in_file) for i in range(1, 4): - if str(path.parents[i].name).startswith('sub-'): + if str(path.parents[i].name).startswith("sub-"): bids_root = path.parents[i + 1] break in_file = str(path.relative_to(bids_root)) # Build path and ensure directory exists - bids_path = out_dir / in_file.replace( - ''.join(Path(in_file).suffixes), '.json') + bids_path = out_dir / in_file.replace("".join(Path(in_file).suffixes), ".json") bids_path.parent.mkdir(parents=True, exist_ok=True) - self._results['out_file'] = str(bids_path) - return self._results['out_file'] + self._results["out_file"] = str(bids_path) + return self._results["out_file"] def _run_interface(self, runtime): out_file = self._gen_outfile() @@ -98,7 +104,7 @@ def _run_interface(self, runtime): root_adds = [] for key, val in list(self.inputs._outputs.items()): - if not isdefined(val) or key == 'trait_added': + if not isdefined(val) or key == "trait_added": continue if not self.expr.match(key) is None: @@ -115,7 +121,10 @@ def _run_interface(self, runtime): else: config.loggers.interface.warning( 'Output "%s" is not a dictionary (value="%s"), ' - 'discarding output.', root_key, str(val)) + "discarding output.", + root_key, + str(val), + ) # Fill in the "bids_meta" key id_dict = {} @@ -123,17 +132,17 @@ def _run_interface(self, runtime): comp_val = getattr(self.inputs, comp, None) if isdefined(comp_val) and comp_val is not None: id_dict[comp] = comp_val - id_dict['modality'] = self.inputs.modality + id_dict["modality"] = self.inputs.modality if isdefined(self.inputs.metadata) and self.inputs.metadata: id_dict.update(self.inputs.metadata) - if self._out_dict.get('bids_meta') is None: - self._out_dict['bids_meta'] = {} - self._out_dict['bids_meta'].update(id_dict) + if self._out_dict.get("bids_meta") is None: + 
self._out_dict["bids_meta"] = {} + self._out_dict["bids_meta"].update(id_dict) if isdefined(self.inputs.dataset): - self._out_dict['bids_meta']['dataset'] = self.inputs.dataset + self._out_dict["bids_meta"]["dataset"] = self.inputs.dataset # Fill in the "provenance" key # Predict QA from IQMs and add to metadata @@ -141,20 +150,21 @@ def _run_interface(self, runtime): if isdefined(self.inputs.provenance) and self.inputs.provenance: prov_dict.update(self.inputs.provenance) - if self._out_dict.get('provenance') is None: - self._out_dict['provenance'] = {} - self._out_dict['provenance'].update(prov_dict) + if self._out_dict.get("provenance") is None: + self._out_dict["provenance"] = {} + self._out_dict["provenance"].update(prov_dict) - with open(out_file, 'w') as f: - f.write(json.dumps(self._out_dict, sort_keys=True, indent=2, - ensure_ascii=False)) + with open(out_file, "w") as f: + f.write( + json.dumps(self._out_dict, sort_keys=True, indent=2, ensure_ascii=False) + ) return runtime def _process_name(name, val): - if '.' in name: - newkeys = name.split('.') + if "." in name: + newkeys = name.split(".") name = newkeys.pop(0) nested_dict = {newkeys.pop(): val} diff --git a/mriqc/interfaces/common.py b/mriqc/interfaces/common.py index 268d3bbaa..dea5dc3cd 100644 --- a/mriqc/interfaces/common.py +++ b/mriqc/interfaces/common.py @@ -8,23 +8,25 @@ import nibabel as nb from nipype.interfaces.base import ( - traits, TraitedSpec, BaseInterfaceInputSpec, File, isdefined, - SimpleInterface + traits, + TraitedSpec, + BaseInterfaceInputSpec, + File, + isdefined, + SimpleInterface, ) from nipype.interfaces.ants import ApplyTransforms from .. import config class ConformImageInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='input image') - check_ras = traits.Bool(True, usedefault=True, - desc='check that orientation is RAS') - check_dtype = traits.Bool(True, usedefault=True, - desc='check data type') + in_file = File(exists=True, mandatory=True, desc="input image") + check_ras = traits.Bool(True, usedefault=True, desc="check that orientation is RAS") + check_dtype = traits.Bool(True, usedefault=True, desc="check data type") class ConformImageOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='output conformed file') + out_file = File(exists=True, desc="output conformed file") class ConformImage(SimpleInterface): @@ -86,6 +88,7 @@ class ConformImage(SimpleInterface): NIFTI_TYPE_COMPLEX256 2048 /! 
256 bit complex = 2 128 bit floats / """ + input_spec = ConformImageInputSpec output_spec = ConformImageOutputSpec @@ -98,12 +101,14 @@ def _run_interface(self, runtime): if self.inputs.check_dtype: changed = True - datatype = int(hdr['datatype']) + datatype = int(hdr["datatype"]) if datatype == 1: config.loggers.interface.warning( 'Input image %s has a suspicious data type "%s"', - self.inputs.in_file, hdr.get_data_dtype()) + self.inputs.in_file, + hdr.get_data_dtype(), + ) # signed char and bool to uint8 if datatype == 1 or datatype == 2 or datatype == 256: @@ -125,30 +130,28 @@ def _run_interface(self, runtime): if changed: hdr.set_data_dtype(dtype) - nii = nb.Nifti1Image(nii.get_data().astype(dtype), - nii.affine, hdr) + nii = nb.Nifti1Image(nii.get_data().astype(dtype), nii.affine, hdr) # Generate name out_file, ext = op.splitext(op.basename(self.inputs.in_file)) - if ext == '.gz': + if ext == ".gz": out_file, ext2 = op.splitext(out_file) ext = ext2 + ext - self._results['out_file'] = op.abspath('{}_conformed{}'.format(out_file, ext)) - nii.to_filename(self._results['out_file']) + self._results["out_file"] = op.abspath("{}_conformed{}".format(out_file, ext)) + nii.to_filename(self._results["out_file"]) return runtime class EnsureSizeInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, copyfile=False, mandatory=True, desc='input image') - in_mask = File(exists=True, copyfile=False, desc='input mask') - pixel_size = traits.Float(2.0, usedefault=True, - desc='desired pixel size (mm)') + in_file = File(exists=True, copyfile=False, mandatory=True, desc="input image") + in_mask = File(exists=True, copyfile=False, desc="input mask") + pixel_size = traits.Float(2.0, usedefault=True, desc="desired pixel size (mm)") class EnsureSizeOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='output image') - out_mask = File(exists=True, desc='output mask') + out_file = File(exists=True, desc="output image") + out_mask = File(exists=True, desc="output mask") class EnsureSize(SimpleInterface): @@ -157,6 +160,7 @@ class EnsureSize(SimpleInterface): have `pixel_size` """ + input_spec = EnsureSizeInputSpec output_spec = EnsureSizeOutputSpec @@ -165,17 +169,23 @@ def _run_interface(self, runtime): zooms = nii.header.get_zooms() size_diff = np.array(zooms[:3]) - (self.inputs.pixel_size - 0.1) if np.all(size_diff >= -1e-3): - config.loggers.interface.info('Voxel size is large enough') - self._results['out_file'] = self.inputs.in_file + config.loggers.interface.info("Voxel size is large enough") + self._results["out_file"] = self.inputs.in_file if isdefined(self.inputs.in_mask): - self._results['out_mask'] = self.inputs.in_mask + self._results["out_mask"] = self.inputs.in_mask return runtime config.loggers.interface.info( - 'One or more voxel dimensions (%f, %f, %f) are smaller than ' - 'the requested voxel size (%f) - diff=(%f, %f, %f)', zooms[0], - zooms[1], zooms[2], self.inputs.pixel_size, size_diff[0], - size_diff[1], size_diff[2]) + "One or more voxel dimensions (%f, %f, %f) are smaller than " + "the requested voxel size (%f) - diff=(%f, %f, %f)", + zooms[0], + zooms[1], + zooms[2], + self.inputs.pixel_size, + size_diff[0], + size_diff[1], + size_diff[2], + ) # Figure out new matrix # 1) Get base affine @@ -195,7 +205,9 @@ def _run_interface(self, runtime): new_size = np.array(extent_mm / self.inputs.pixel_size, dtype=int) # 5) Initialize new base affine - new_base = aff_base[:3, :3] * np.abs(aff_base_inv[:3, :3]) * self.inputs.pixel_size + new_base = ( + aff_base[:3, :3] * 
np.abs(aff_base_inv[:3, :3]) * self.inputs.pixel_size + ) # 6) Find new center new_center_idx = (new_size - 1) * 0.5 @@ -210,47 +222,49 @@ def _run_interface(self, runtime): # 8) Generate new reference image hdr = nii.header.copy() hdr.set_data_shape(new_size) - ref_file = 'resample_ref.nii.gz' - nb.Nifti1Image(np.zeros(new_size, dtype=nii.get_data_dtype()), - new_affine, hdr).to_filename(ref_file) + ref_file = "resample_ref.nii.gz" + nb.Nifti1Image( + np.zeros(new_size, dtype=nii.get_data_dtype()), new_affine, hdr + ).to_filename(ref_file) out_prefix, ext = op.splitext(op.basename(self.inputs.in_file)) - if ext == '.gz': + if ext == ".gz": out_prefix, ext2 = op.splitext(out_prefix) ext = ext2 + ext - out_file = op.abspath('%s_resampled%s' % (out_prefix, ext)) + out_file = op.abspath("%s_resampled%s" % (out_prefix, ext)) # 9) Resample new image ApplyTransforms( dimension=3, input_image=self.inputs.in_file, reference_image=ref_file, - interpolation='LanczosWindowedSinc', - transforms=[pkgrf('mriqc', 'data/itk_identity.tfm')], + interpolation="LanczosWindowedSinc", + transforms=[pkgrf("mriqc", "data/itk_identity.tfm")], output_image=out_file, ).run() - self._results['out_file'] = out_file + self._results["out_file"] = out_file if isdefined(self.inputs.in_mask): hdr = nii.header.copy() hdr.set_data_shape(new_size) hdr.set_data_dtype(np.uint8) - ref_mask = 'mask_ref.nii.gz' - nb.Nifti1Image(np.zeros(new_size, dtype=np.uint8), - new_affine, hdr).to_filename(ref_mask) + ref_mask = "mask_ref.nii.gz" + nb.Nifti1Image( + np.zeros(new_size, dtype=np.uint8), new_affine, hdr + ).to_filename(ref_mask) - out_mask = op.abspath('%s_resmask%s' % (out_prefix, ext)) + out_mask = op.abspath("%s_resmask%s" % (out_prefix, ext)) ApplyTransforms( dimension=3, input_image=self.inputs.in_mask, reference_image=ref_mask, - interpolation='NearestNeighbor', - transforms=[pkgrf('mriqc', 'data/itk_identity.tfm')], + interpolation="NearestNeighbor", + transforms=[pkgrf("mriqc", "data/itk_identity.tfm")], output_image=out_mask, ).run() - self._results['out_mask'] = out_mask + self._results["out_mask"] = out_mask return runtime diff --git a/mriqc/interfaces/functional.py b/mriqc/interfaces/functional.py index 8cbd3db7b..1e6e486f3 100644 --- a/mriqc/interfaces/functional.py +++ b/mriqc/interfaces/functional.py @@ -7,8 +7,12 @@ from builtins import zip from nipype.interfaces.base import ( - traits, TraitedSpec, File, isdefined, BaseInterfaceInputSpec, - SimpleInterface + traits, + TraitedSpec, + File, + isdefined, + BaseInterfaceInputSpec, + SimpleInterface, ) from ..utils.misc import _flatten_dict @@ -17,17 +21,29 @@ class FunctionalQCInputSpec(BaseInterfaceInputSpec): - in_epi = File(exists=True, mandatory=True, desc='input EPI file') - in_hmc = File(exists=True, mandatory=True, desc='input motion corrected file') - in_tsnr = File(exists=True, mandatory=True, desc='input tSNR volume') - in_mask = File(exists=True, mandatory=True, desc='input mask') - direction = traits.Enum('all', 'x', 'y', '-x', '-y', usedefault=True, - desc='direction for GSR computation') - in_fd = File(exists=True, mandatory=True, desc='motion parameters for FD computation') - fd_thres = traits.Float(0.2, usedefault=True, desc='motion threshold for FD computation') - in_dvars = File(exists=True, mandatory=True, desc='input file containing DVARS') - in_fwhm = traits.List(traits.Float, mandatory=True, - desc='smoothness estimated with AFNI') + in_epi = File(exists=True, mandatory=True, desc="input EPI file") + in_hmc = File(exists=True, mandatory=True, 
desc="input motion corrected file") + in_tsnr = File(exists=True, mandatory=True, desc="input tSNR volume") + in_mask = File(exists=True, mandatory=True, desc="input mask") + direction = traits.Enum( + "all", + "x", + "y", + "-x", + "-y", + usedefault=True, + desc="direction for GSR computation", + ) + in_fd = File( + exists=True, mandatory=True, desc="motion parameters for FD computation" + ) + fd_thres = traits.Float( + 0.2, usedefault=True, desc="motion threshold for FD computation" + ) + in_dvars = File(exists=True, mandatory=True, desc="input file containing DVARS") + in_fwhm = traits.List( + traits.Float, mandatory=True, desc="smoothness estimated with AFNI" + ) class FunctionalQCOutputSpec(TraitedSpec): @@ -38,12 +54,12 @@ class FunctionalQCOutputSpec(TraitedSpec): tsnr = traits.Float dvars = traits.Dict fd = traits.Dict - fwhm = traits.Dict(desc='full width half-maximum measure') + fwhm = traits.Dict(desc="full width half-maximum measure") size = traits.Dict spacing = traits.Dict summary = traits.Dict - out_qc = traits.Dict(desc='output flattened dictionary with all measures') + out_qc = traits.Dict(desc="output flattened dictionary with all measures") class FunctionalQC(SimpleInterface): @@ -52,6 +68,7 @@ class FunctionalQC(SimpleInterface): structural image given as input """ + input_spec = FunctionalQCInputSpec output_spec = FunctionalQCOutputSpec @@ -77,93 +94,107 @@ def _run_interface(self, runtime): # Summary stats stats = summary_stats(epidata, mskdata, erode=True) - self._results['summary'] = stats + self._results["summary"] = stats # SNR - self._results['snr'] = snr(stats['fg']['median'], stats['fg']['stdv'], stats['fg']['n']) + self._results["snr"] = snr( + stats["fg"]["median"], stats["fg"]["stdv"], stats["fg"]["n"] + ) # FBER - self._results['fber'] = fber(epidata, mskdata) + self._results["fber"] = fber(epidata, mskdata) # EFC - self._results['efc'] = efc(epidata) + self._results["efc"] = efc(epidata) # GSR - self._results['gsr'] = {} - if self.inputs.direction == 'all': - epidir = ['x', 'y'] + self._results["gsr"] = {} + if self.inputs.direction == "all": + epidir = ["x", "y"] else: epidir = [self.inputs.direction] for axis in epidir: - self._results['gsr'][axis] = gsr(epidata, mskdata, direction=axis) + self._results["gsr"][axis] = gsr(epidata, mskdata, direction=axis) # DVARS - dvars_avg = np.loadtxt(self.inputs.in_dvars, skiprows=1, - usecols=list(range(3))).mean(axis=0) - dvars_col = ['std', 'nstd', 'vstd'] - self._results['dvars'] = { + dvars_avg = np.loadtxt( + self.inputs.in_dvars, skiprows=1, usecols=list(range(3)) + ).mean(axis=0) + dvars_col = ["std", "nstd", "vstd"] + self._results["dvars"] = { key: float(val) for key, val in zip(dvars_col, dvars_avg) } # tSNR tsnr_data = nb.load(self.inputs.in_tsnr).get_data() - self._results['tsnr'] = float(np.median(tsnr_data[mskdata > 0])) + self._results["tsnr"] = float(np.median(tsnr_data[mskdata > 0])) # FD fd_data = np.loadtxt(self.inputs.in_fd, skiprows=1) num_fd = np.float((fd_data > self.inputs.fd_thres).sum()) - self._results['fd'] = { - 'mean': float(fd_data.mean()), - 'num': int(num_fd), - 'perc': float(num_fd * 100 / (len(fd_data) + 1)) + self._results["fd"] = { + "mean": float(fd_data.mean()), + "num": int(num_fd), + "perc": float(num_fd * 100 / (len(fd_data) + 1)), } # FWHM - fwhm = np.array(self.inputs.in_fwhm[:3]) / np.array(hmcnii.header.get_zooms()[:3]) - self._results['fwhm'] = { - 'x': float(fwhm[0]), 'y': float(fwhm[1]), 'z': float(fwhm[2]), - 'avg': float(np.average(fwhm))} + fwhm = 
np.array(self.inputs.in_fwhm[:3]) / np.array( + hmcnii.header.get_zooms()[:3] + ) + self._results["fwhm"] = { + "x": float(fwhm[0]), + "y": float(fwhm[1]), + "z": float(fwhm[2]), + "avg": float(np.average(fwhm)), + } # Image specs - self._results['size'] = {'x': int(hmcdata.shape[0]), - 'y': int(hmcdata.shape[1]), - 'z': int(hmcdata.shape[2])} - self._results['spacing'] = { - i: float(v) for i, v in zip(['x', 'y', 'z'], - hmcnii.header.get_zooms()[:3])} + self._results["size"] = { + "x": int(hmcdata.shape[0]), + "y": int(hmcdata.shape[1]), + "z": int(hmcdata.shape[2]), + } + self._results["spacing"] = { + i: float(v) for i, v in zip(["x", "y", "z"], hmcnii.header.get_zooms()[:3]) + } try: - self._results['size']['t'] = int(hmcdata.shape[3]) + self._results["size"]["t"] = int(hmcdata.shape[3]) except IndexError: pass try: - self._results['spacing']['tr'] = float(hmcnii.header.get_zooms()[3]) + self._results["spacing"]["tr"] = float(hmcnii.header.get_zooms()[3]) except IndexError: pass - self._results['out_qc'] = _flatten_dict(self._results) + self._results["out_qc"] = _flatten_dict(self._results) return runtime class SpikesInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, desc='input fMRI dataset') - in_mask = File(exists=True, desc='brain mask') - invert_mask = traits.Bool(False, usedefault=True, desc='invert mask') - no_zscore = traits.Bool(False, usedefault=True, desc='do not zscore') - detrend = traits.Bool(True, usedefault=True, desc='do detrend') - spike_thresh = traits.Float(6., usedefault=True, - desc='z-score to call one timepoint of one axial slice a spike') - skip_frames = traits.Int(0, usedefault=True, - desc='number of frames to skip in the beginning of the time series') - out_tsz = File('spikes_tsz.txt', usedefault=True, desc='output file name') - out_spikes = File( - 'spikes_idx.txt', usedefault=True, desc='output file name') + in_file = File(exists=True, mandatory=True, desc="input fMRI dataset") + in_mask = File(exists=True, desc="brain mask") + invert_mask = traits.Bool(False, usedefault=True, desc="invert mask") + no_zscore = traits.Bool(False, usedefault=True, desc="do not zscore") + detrend = traits.Bool(True, usedefault=True, desc="do detrend") + spike_thresh = traits.Float( + 6.0, + usedefault=True, + desc="z-score to call one timepoint of one axial slice a spike", + ) + skip_frames = traits.Int( + 0, + usedefault=True, + desc="number of frames to skip in the beginning of the time series", + ) + out_tsz = File("spikes_tsz.txt", usedefault=True, desc="output file name") + out_spikes = File("spikes_idx.txt", usedefault=True, desc="output file name") class SpikesOutputSpec(TraitedSpec): - out_tsz = File( - desc='slice-wise z-scored timeseries (Z x N), inside brainmask') - out_spikes = File(desc='indices of spikes') - num_spikes = traits.Int(desc='number of spikes found (total)') + out_tsz = File(desc="slice-wise z-scored timeseries (Z x N), inside brainmask") + out_spikes = File(desc="indices of spikes") + num_spikes = traits.Int(desc="number of spikes found (total)") class Spikes(SimpleInterface): @@ -173,6 +204,7 @@ class Spikes(SimpleInterface): https://github.com/cni/nims/blob/master/nimsproc/qa_report.py """ + input_spec = SpikesInputSpec output_spec = SpikesOutputSpec @@ -186,16 +218,20 @@ def _run_interface(self, runtime): if self.inputs.detrend: from nilearn.signal import clean + data = func_data.reshape(-1, ntsteps) clean_data = clean(data[:, nskip:].T, t_r=tr, standardize=False).T new_shape = ( - func_shape[0], func_shape[1], 
func_shape[2], clean_data.shape[-1]) + func_shape[0], + func_shape[1], + func_shape[2], + clean_data.shape[-1], + ) func_data = np.zeros(func_shape) func_data[..., nskip:] = clean_data.reshape(new_shape) if not isdefined(self.inputs.in_mask): - _, mask_data, _ = auto_mask( - func_data, nskip=self.inputs.skip_frames) + _, mask_data, _ = auto_mask(func_data, nskip=self.inputs.skip_frames) else: mask_data = nb.load(self.inputs.in_mask).get_data() mask_data[..., :nskip] = 0 @@ -204,25 +240,24 @@ def _run_interface(self, runtime): if not self.inputs.invert_mask: brain = np.ma.array(func_data, mask=(mask_data != 1)) else: - mask_data[..., :self.inputs.skip_frames] = 1 + mask_data[..., : self.inputs.skip_frames] = 1 brain = np.ma.array(func_data, mask=(mask_data == 1)) if self.inputs.no_zscore: ts_z = find_peaks(brain) total_spikes = [] else: - total_spikes, ts_z = find_spikes( - brain, self.inputs.spike_thresh) + total_spikes, ts_z = find_spikes(brain, self.inputs.spike_thresh) total_spikes = list(set(total_spikes)) out_tsz = op.abspath(self.inputs.out_tsz) - self._results['out_tsz'] = out_tsz + self._results["out_tsz"] = out_tsz np.savetxt(out_tsz, ts_z) out_spikes = op.abspath(self.inputs.out_spikes) - self._results['out_spikes'] = out_spikes + self._results["out_spikes"] = out_spikes np.savetxt(out_spikes, total_spikes) - self._results['num_spikes'] = len(total_spikes) + self._results["num_spikes"] = len(total_spikes) return runtime @@ -251,12 +286,16 @@ def find_spikes(data, spike_thresh): def auto_mask(data, raw_d=None, nskip=3, mask_bad_end_vols=False): from dipy.segment.mask import median_otsu + mn = data[:, :, :, nskip:].mean(3) _, mask = median_otsu(mn, 3, 2) # oesteban: masked_data was not used - mask = np.concatenate(( - np.tile(True, (data.shape[0], data.shape[1], data.shape[2], nskip)), - np.tile(np.expand_dims(mask == 0, 3), (1, 1, 1, data.shape[3] - nskip))), - axis=3) + mask = np.concatenate( + ( + np.tile(True, (data.shape[0], data.shape[1], data.shape[2], nskip)), + np.tile(np.expand_dims(mask == 0, 3), (1, 1, 1, data.shape[3] - nskip)), + ), + axis=3, + ) mask_vols = np.zeros((mask.shape[-1]), dtype=int) if mask_bad_end_vols: # Some runs have corrupt volumes at the end (e.g., mux scans that are stopped @@ -286,7 +325,6 @@ def auto_mask(data, raw_d=None, nskip=3, mask_bad_end_vols=False): def _robust_zscore(data): - return (( - data - np.atleast_2d(np.median(data, axis=1)).T) - / np.atleast_2d(data.std(axis=1)).T - ) + return (data - np.atleast_2d(np.median(data, axis=1)).T) / np.atleast_2d( + data.std(axis=1) + ).T diff --git a/mriqc/interfaces/reports.py b/mriqc/interfaces/reports.py index f4bf9c7a7..a8d8b9d8d 100644 --- a/mriqc/interfaces/reports.py +++ b/mriqc/interfaces/reports.py @@ -4,8 +4,13 @@ import numpy as np import nibabel as nb from nipype.interfaces.base import ( - traits, TraitedSpec, File, isdefined, InputMultiObject, - BaseInterfaceInputSpec, SimpleInterface + traits, + TraitedSpec, + File, + isdefined, + InputMultiObject, + BaseInterfaceInputSpec, + SimpleInterface, ) from .. 
import config from ..reports.individual import individual_html @@ -55,31 +60,34 @@ class AddProvenance(SimpleInterface): def _run_interface(self, runtime): from nipype.utils.filemanip import hash_infile + self._results["out_prov"] = { - 'md5sum': hash_infile(self.inputs.in_file), - 'version': config.environment.version, - 'software': 'mriqc', - 'webapi_url': config.execution.webapi_url, - 'webapi_port': config.execution.webapi_port, - 'settings': { - 'testing': config.execution.debug, - }, + "md5sum": hash_infile(self.inputs.in_file), + "version": config.environment.version, + "software": "mriqc", + "webapi_url": config.execution.webapi_url, + "webapi_port": config.execution.webapi_port, + "settings": {"testing": config.execution.debug, }, } if self.inputs.modality in ("T1w", "T2w"): - air_msk_size = np.asanyarray(nb.load(self.inputs.air_msk).dataobj).astype( - bool).sum() - rot_msk_size = np.asanyarray(nb.load(self.inputs.rot_msk).dataobj).astype( - bool).sum() + air_msk_size = ( + np.asanyarray(nb.load(self.inputs.air_msk).dataobj).astype(bool).sum() + ) + rot_msk_size = ( + np.asanyarray(nb.load(self.inputs.rot_msk).dataobj).astype(bool).sum() + ) self._results["out_prov"]["warnings"] = { - 'small_air_mask': bool(air_msk_size < 5e5), - 'large_rot_frame': bool(rot_msk_size > 500), + "small_air_mask": bool(air_msk_size < 5e5), + "large_rot_frame": bool(rot_msk_size > 500), } if self.inputs.modality == "bold": - self._results["out_prov"]["settings"].update({ - 'fd_thres': config.workflow.fd_thres, - 'hmc_fsl': config.workflow.hmc_fsl, - }) + self._results["out_prov"]["settings"].update( + { + "fd_thres": config.workflow.fd_thres, + "hmc_fsl": config.workflow.hmc_fsl, + } + ) return runtime diff --git a/mriqc/interfaces/transitional.py b/mriqc/interfaces/transitional.py index e04b2088e..e9e2d6b72 100644 --- a/mriqc/interfaces/transitional.py +++ b/mriqc/interfaces/transitional.py @@ -2,33 +2,41 @@ # vi: set ft=python sts=4 ts=4 sw=4 et: from nipype.interfaces.base import ( - File, traits, CommandLine, TraitedSpec, CommandLineInputSpec + File, + traits, + CommandLine, + TraitedSpec, + CommandLineInputSpec, ) class GCORInputSpec(CommandLineInputSpec): in_file = File( - desc='input dataset to compute the GCOR over', - argstr='-input %s', + desc="input dataset to compute the GCOR over", + argstr="-input %s", position=-1, mandatory=True, exists=True, - copyfile=False) + copyfile=False, + ) mask = File( - desc='mask dataset, for restricting the computation', - argstr='-mask %s', + desc="mask dataset, for restricting the computation", + argstr="-mask %s", exists=True, - copyfile=False) + copyfile=False, + ) - nfirst = traits.Int(0, argstr='-nfirst %d', - desc='specify number of initial TRs to ignore') - no_demean = traits.Bool(False, argstr='-no_demean', - desc='do not (need to) demean as first step') + nfirst = traits.Int( + 0, argstr="-nfirst %d", desc="specify number of initial TRs to ignore" + ) + no_demean = traits.Bool( + False, argstr="-no_demean", desc="do not (need to) demean as first step" + ) class GCOROutputSpec(TraitedSpec): - out = traits.Float(desc='global correlation value') + out = traits.Float(desc="global correlation value") class GCOR(CommandLine): @@ -50,17 +58,20 @@ class GCOR(CommandLine): """ - _cmd = '@compute_gcor' + _cmd = "@compute_gcor" input_spec = GCORInputSpec output_spec = GCOROutputSpec def _run_interface(self, runtime): runtime = super(GCOR, self)._run_interface(runtime) - gcor_line = [line.strip() for line in runtime.stdout.split('\n') - if 
line.strip().startswith('GCOR = ')][-1] - setattr(self, '_gcor', float(gcor_line[len('GCOR = '):])) + gcor_line = [ + line.strip() + for line in runtime.stdout.split("\n") + if line.strip().startswith("GCOR = ") + ][-1] + setattr(self, "_gcor", float(gcor_line[len("GCOR = "):])) return runtime def _list_outputs(self): - return {'out': getattr(self, '_gcor')} + return {"out": getattr(self, "_gcor")} diff --git a/mriqc/interfaces/viz.py b/mriqc/interfaces/viz.py index 788bf28da..f7a5f034a 100644 --- a/mriqc/interfaces/viz.py +++ b/mriqc/interfaces/viz.py @@ -1,47 +1,57 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# -# @Author: oesteban -# @Date: 2016-01-05 11:29:40 -# @Email: code@oscaresteban.es -# @Last modified by: oesteban -""" Visualization interfaces """ +"""Visualization interfaces.""" from pathlib import Path import numpy as np from nipype.interfaces.base import ( - traits, TraitedSpec, File, BaseInterfaceInputSpec, isdefined, - SimpleInterface) + traits, + TraitedSpec, + File, + BaseInterfaceInputSpec, + isdefined, + SimpleInterface, +) from io import open # pylint: disable=W0622 -from ..viz.utils import (plot_mosaic, plot_segmentation, plot_spikes) +from ..viz.utils import plot_mosaic, plot_segmentation, plot_spikes class PlotContoursInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, - desc='File to be plotted') - in_contours = File(exists=True, mandatory=True, - desc='file to pick the contours from') - cut_coords = traits.Int(8, usedefault=True, desc='number of slices') - levels = traits.List([.5], traits.Float, usedefault=True, - desc='add a contour per level') - colors = traits.List(['r'], traits.Str, usedefault=True, - desc='colors to be used for contours') - display_mode = traits.Enum('ortho', 'x', 'y', 'z', 'yx', 'xz', 'yz', usedefault=True, - desc='visualization mode') - saturate = traits.Bool(False, usedefault=True, desc='saturate background') - out_file = traits.File(exists=False, desc='output file name') - vmin = traits.Float(desc='minimum intensity') - vmax = traits.Float(desc='maximum intensity') + in_file = File(exists=True, mandatory=True, desc="File to be plotted") + in_contours = File( + exists=True, mandatory=True, desc="file to pick the contours from" + ) + cut_coords = traits.Int(8, usedefault=True, desc="number of slices") + levels = traits.List( + [0.5], traits.Float, usedefault=True, desc="add a contour per level" + ) + colors = traits.List( + ["r"], traits.Str, usedefault=True, desc="colors to be used for contours" + ) + display_mode = traits.Enum( + "ortho", + "x", + "y", + "z", + "yx", + "xz", + "yz", + usedefault=True, + desc="visualization mode", + ) + saturate = traits.Bool(False, usedefault=True, desc="saturate background") + out_file = traits.File(exists=False, desc="output file name") + vmin = traits.Float(desc="minimum intensity") + vmax = traits.Float(desc="maximum intensity") class PlotContoursOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='output svg file') + out_file = File(exists=True, desc="output svg file") class PlotContours(SimpleInterface): """ Plot contours """ + input_spec = PlotContoursInputSpec output_spec = PlotContoursOutputSpec @@ -51,10 +61,9 @@ def _run_interface(self, runtime): if isdefined(self.inputs.out_file): in_file_ref = Path(self.inputs.out_file) - fname = in_file_ref.name.rstrip( - ''.join(in_file_ref.suffixes)) - out_file = (Path(runtime.cwd) / 
('plot_%s_contours.svg' % fname)).resolve() - self._results['out_file'] = str(out_file) + fname = in_file_ref.name.rstrip("".join(in_file_ref.suffixes)) + out_file = (Path(runtime.cwd) / ("plot_%s_contours.svg" % fname)).resolve() + self._results["out_file"] = str(out_file) vmax = None if not isdefined(self.inputs.vmax) else self.inputs.vmax vmin = None if not isdefined(self.inputs.vmin) else self.inputs.vmin @@ -68,31 +77,32 @@ def _run_interface(self, runtime): levels=self.inputs.levels, colors=self.inputs.colors, saturate=self.inputs.saturate, - vmin=vmin, vmax=vmax) + vmin=vmin, + vmax=vmax, + ) return runtime class PlotBaseInputSpec(BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, - desc='File to be plotted') - title = traits.Str(desc='a title string for the plot') - annotate = traits.Bool(True, usedefault=True, desc='annotate left/right') + in_file = File(exists=True, mandatory=True, desc="File to be plotted") + title = traits.Str(desc="a title string for the plot") + annotate = traits.Bool(True, usedefault=True, desc="annotate left/right") figsize = traits.Tuple( - (11.69, 8.27), traits.Float, traits.Float, usedefault=True, - desc='Figure size') - dpi = traits.Int(300, usedefault=True, desc='Desired DPI of figure') - out_file = File('mosaic.svg', usedefault=True, desc='output file name') - cmap = traits.Str('Greys_r', usedefault=True) + (11.69, 8.27), traits.Float, traits.Float, usedefault=True, desc="Figure size" + ) + dpi = traits.Int(300, usedefault=True, desc="Desired DPI of figure") + out_file = File("mosaic.svg", usedefault=True, desc="output file name") + cmap = traits.Str("Greys_r", usedefault=True) class PlotMosaicInputSpec(PlotBaseInputSpec): - bbox_mask_file = File(exists=True, desc='brain mask') - only_noise = traits.Bool(False, desc='plot only noise') + bbox_mask_file = File(exists=True, desc="brain mask") + only_noise = traits.Bool(False, desc="plot only noise") class PlotMosaicOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='output pdf file') + out_file = File(exists=True, desc="output pdf file") class PlotMosaic(SimpleInterface): @@ -100,6 +110,7 @@ class PlotMosaic(SimpleInterface): """ Plots slices of a 3D volume into a pdf file """ + input_spec = PlotMosaicInputSpec output_spec = PlotMosaicOutputSpec @@ -119,40 +130,44 @@ def _run_interface(self, runtime): only_plot_noise=self.inputs.only_noise, bbox_mask_file=mask, cmap=self.inputs.cmap, - annotate=self.inputs.annotate) - self._results['out_file'] = str((Path(runtime.cwd) / self.inputs.out_file).resolve()) + annotate=self.inputs.annotate, + ) + self._results["out_file"] = str( + (Path(runtime.cwd) / self.inputs.out_file).resolve() + ) return runtime class PlotSpikesInputSpec(PlotBaseInputSpec): - in_spikes = File(exists=True, mandatory=True, desc='tsv file of spikes') - in_fft = File(exists=True, mandatory=True, desc='nifti file with the 4D FFT') + in_spikes = File(exists=True, mandatory=True, desc="tsv file of spikes") + in_fft = File(exists=True, mandatory=True, desc="nifti file with the 4D FFT") class PlotSpikesOutputSpec(TraitedSpec): - out_file = File(exists=True, desc='output svg file') + out_file = File(exists=True, desc="output svg file") class PlotSpikes(SimpleInterface): """ Plot slices of a dataset with spikes """ + input_spec = PlotSpikesInputSpec output_spec = PlotSpikesOutputSpec def _run_interface(self, runtime): out_file = str((Path(runtime.cwd) / self.inputs.out_file).resolve()) - self._results['out_file'] = out_file + self._results["out_file"] = out_file 
spikes_list = np.loadtxt(self.inputs.in_spikes, dtype=int).tolist() # No spikes if not spikes_list: - with open(out_file, 'w') as f: - f.write('
<p>No high-frequency spikes were found in this dataset</p>
') + with open(out_file, "w") as f: + f.write("
<p>No high-frequency spikes were found in this dataset</p>
") return runtime spikes_list = [tuple(i) for i in np.atleast_2d(spikes_list).tolist()] plot_spikes( - self.inputs.in_file, self.inputs.in_fft, spikes_list, - out_file=out_file) + self.inputs.in_file, self.inputs.in_fft, spikes_list, out_file=out_file + ) return runtime diff --git a/mriqc/interfaces/webapi.py b/mriqc/interfaces/webapi.py index 63cbac0a0..cb0f9f3fc 100644 --- a/mriqc/interfaces/webapi.py +++ b/mriqc/interfaces/webapi.py @@ -1,108 +1,112 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: from nipype.interfaces.base import ( - Bunch, traits, isdefined, TraitedSpec, BaseInterfaceInputSpec, File, Str, - SimpleInterface + Bunch, + traits, + isdefined, + TraitedSpec, + BaseInterfaceInputSpec, + File, + Str, + SimpleInterface, ) from urllib.parse import urlparse from .. import config -SECRET_KEY = '' +SECRET_KEY = "" # metadata whitelist META_WHITELIST = [ - 'AccelNumReferenceLines', - 'AccelerationFactorPE', - 'AcquisitionMatrix', - 'CogAtlasID', - 'CogPOID', - 'CoilCombinationMethod', - 'ContrastBolusIngredient', - 'ConversionSoftware', - 'ConversionSoftwareVersion', - 'DelayTime', - 'DeviceSerialNumber', - 'EchoTime', - 'EchoTrainLength', - 'EffectiveEchoSpacing', - 'FlipAngle', - 'GradientSetType', - 'HardcopyDeviceSoftwareVersion', - 'ImageType', - 'ImagingFrequency', - 'InPlanePhaseEncodingDirection', - 'InstitutionAddress', - 'InstitutionName', - 'Instructions', - 'InversionTime', - 'MRAcquisitionType', - 'MRTransmitCoilSequence', - 'MagneticFieldStrength', - 'Manufacturer', - 'ManufacturersModelName', - 'MatrixCoilMode', - 'MultibandAccelerationFactor', - 'NumberOfAverages', - 'NumberOfPhaseEncodingSteps', - 'NumberOfVolumesDiscardedByScanner', - 'NumberOfVolumesDiscardedByUser', - 'NumberShots', - 'ParallelAcquisitionTechnique', - 'ParallelReductionFactorInPlane', - 'PartialFourier', - 'PartialFourierDirection', - 'PatientPosition', - 'PercentPhaseFieldOfView', - 'PercentSampling', - 'PhaseEncodingDirection', - 'PixelBandwidth', - 'ProtocolName', - 'PulseSequenceDetails', - 'PulseSequenceType', - 'ReceiveCoilName', - 'RepetitionTime', - 'ScanOptions', - 'ScanningSequence', - 'SequenceName', - 'SequenceVariant', - 'SliceEncodingDirection', - 'SoftwareVersions', - 'TaskDescription', - 'TaskName', - 'TotalReadoutTime', - 'TotalScanTimeSec', - 'TransmitCoilName', - 'VariableFlipAngleFlag', - 'acq_id', - 'modality', - 'run_id', - 'subject_id', - 'task_id', - 'session_id', + "AccelNumReferenceLines", + "AccelerationFactorPE", + "AcquisitionMatrix", + "CogAtlasID", + "CogPOID", + "CoilCombinationMethod", + "ContrastBolusIngredient", + "ConversionSoftware", + "ConversionSoftwareVersion", + "DelayTime", + "DeviceSerialNumber", + "EchoTime", + "EchoTrainLength", + "EffectiveEchoSpacing", + "FlipAngle", + "GradientSetType", + "HardcopyDeviceSoftwareVersion", + "ImageType", + "ImagingFrequency", + "InPlanePhaseEncodingDirection", + "InstitutionAddress", + "InstitutionName", + "Instructions", + "InversionTime", + "MRAcquisitionType", + "MRTransmitCoilSequence", + "MagneticFieldStrength", + "Manufacturer", + "ManufacturersModelName", + "MatrixCoilMode", + "MultibandAccelerationFactor", + "NumberOfAverages", + "NumberOfPhaseEncodingSteps", + "NumberOfVolumesDiscardedByScanner", + "NumberOfVolumesDiscardedByUser", + "NumberShots", + "ParallelAcquisitionTechnique", + "ParallelReductionFactorInPlane", + "PartialFourier", + "PartialFourierDirection", + "PatientPosition", + "PercentPhaseFieldOfView", + "PercentSampling", + 
"PhaseEncodingDirection", + "PixelBandwidth", + "ProtocolName", + "PulseSequenceDetails", + "PulseSequenceType", + "ReceiveCoilName", + "RepetitionTime", + "ScanOptions", + "ScanningSequence", + "SequenceName", + "SequenceVariant", + "SliceEncodingDirection", + "SoftwareVersions", + "TaskDescription", + "TaskName", + "TotalReadoutTime", + "TotalScanTimeSec", + "TransmitCoilName", + "VariableFlipAngleFlag", + "acq_id", + "modality", + "run_id", + "subject_id", + "task_id", + "session_id", ] -PROV_WHITELIST = [ - 'version', - 'md5sum', - 'software', - 'settings' -] +PROV_WHITELIST = ["version", "md5sum", "software", "settings"] -HASH_BIDS = ['subject_id', 'session_id'] +HASH_BIDS = ["subject_id", "session_id"] class UploadIQMsInputSpec(BaseInterfaceInputSpec): - in_iqms = File(exists=True, mandatory=True, desc='the input IQMs-JSON file') - url = Str(mandatory=True, desc='URL (protocol and name) listening') - port = traits.Int(desc='MRIQCWebAPI service port') - path = Str(desc='MRIQCWebAPI endpoint root path') - email = Str(desc='set sender email') - strict = traits.Bool(False, usedefault=True, - desc='crash if upload was not succesfull') + in_iqms = File(exists=True, mandatory=True, desc="the input IQMs-JSON file") + url = Str(mandatory=True, desc="URL (protocol and name) listening") + port = traits.Int(desc="MRIQCWebAPI service port") + path = Str(desc="MRIQCWebAPI endpoint root path") + email = Str(desc="set sender email") + strict = traits.Bool( + False, usedefault=True, desc="crash if upload was not succesfull" + ) class UploadIQMsOutputSpec(TraitedSpec): - api_id = traits.Either(None, traits.Str, desc="Id for report returned by the web api") + api_id = traits.Either( + None, traits.Str, desc="Id for report returned by the web api" + ) class UploadIQMs(SimpleInterface): @@ -120,13 +124,12 @@ def _run_interface(self, runtime): email = self.inputs.email rawurl = self.inputs.url - if '://' not in rawurl: - rawurl = 'http://' + if "://" not in rawurl: + rawurl = "http://" url = urlparse(rawurl) - if not url.scheme.startswith('http'): - raise RuntimeError( - 'Tried an unknown protocol "%s"' % url.scheme) + if not url.scheme.startswith("http"): + raise RuntimeError('Tried an unknown protocol "%s"' % url.scheme) port = url.port if isdefined(self.inputs.port): @@ -136,26 +139,35 @@ def _run_interface(self, runtime): if isdefined(self.inputs.path): path = self.inputs.path - self._results['api_id'] = None + self._results["api_id"] = None response = upload_qc_metrics( - self.inputs.in_iqms, url.netloc, path=path, - scheme=url.scheme, port=port, email=email) + self.inputs.in_iqms, + url.netloc, + path=path, + scheme=url.scheme, + port=port, + email=email, + ) try: - self._results['api_id'] = response.json()['_id'] + self._results["api_id"] = response.json()["_id"] except (AttributeError, KeyError, ValueError): # response did not give us an ID - errmsg = ('QC metrics upload failed to create an ID for the record ' - 'uplOADED. rEsponse from server follows: {}'.format(response.text)) + errmsg = ( + "QC metrics upload failed to create an ID for the record " + "uplOADED. rEsponse from server follows: {}".format(response.text) + ) config.loggers.interface.warning(errmsg) if response.status_code == 201: - config.loggers.interface.info('QC metrics successfully uploaded.') + config.loggers.interface.info("QC metrics successfully uploaded.") return runtime - errmsg = 'QC metrics failed to upload. Status %d: %s' % ( - response.status_code, response.text) + errmsg = "QC metrics failed to upload. 
Status %d: %s" % ( + response.status_code, + response.text, + ) config.loggers.interface.warning(errmsg) if self.inputs.strict: raise RuntimeError(response.text) @@ -163,8 +175,7 @@ def _run_interface(self, runtime): return runtime -def upload_qc_metrics(in_iqms, loc, path='', scheme='http', - port=None, email=None): +def upload_qc_metrics(in_iqms, loc, path="", scheme="http", port=None, email=None): """ Upload qc metrics to remote repository. @@ -185,54 +196,58 @@ def upload_qc_metrics(in_iqms, loc, path='', scheme='http', from copy import deepcopy if port is None: - port = 443 if scheme == 'https' else 80 + port = 443 if scheme == "https" else 80 in_data = loads(Path(in_iqms).read_text()) # Extract metadata and provenance - meta = in_data.pop('bids_meta') + meta = in_data.pop("bids_meta") # For compatibility with WebAPI. Shold be rolled back to int - if meta.get('run_id', None) is not None: - meta['run_id'] = '%d' % meta.get('run_id') + if meta.get("run_id", None) is not None: + meta["run_id"] = "%d" % meta.get("run_id") - prov = in_data.pop('provenance') + prov = in_data.pop("provenance") # At this point, data should contain only IQMs data = deepcopy(in_data) # Check modality - modality = meta.get('modality', 'None') - if modality not in ('T1w', 'bold', 'T2w'): - errmsg = ('Submitting to MRIQCWebAPI: image modality should be "bold", "T1w", or "T2w", ' - '(found "%s")' % modality) + modality = meta.get("modality", "None") + if modality not in ("T1w", "bold", "T2w"): + errmsg = ( + 'Submitting to MRIQCWebAPI: image modality should be "bold", "T1w", or "T2w", ' + '(found "%s")' % modality + ) return Bunch(status_code=1, text=errmsg) # Filter metadata values that aren't in whitelist - data['bids_meta'] = {k: meta[k] for k in META_WHITELIST if k in meta} + data["bids_meta"] = {k: meta[k] for k in META_WHITELIST if k in meta} # Filter provenance values that aren't in whitelist - data['provenance'] = {k: prov[k] for k in PROV_WHITELIST if k in prov} + data["provenance"] = {k: prov[k] for k in PROV_WHITELIST if k in prov} # Hash fields that may contain personal information - data['bids_meta'] = _hashfields(data['bids_meta']) + data["bids_meta"] = _hashfields(data["bids_meta"]) if email: - data['provenance']['email'] = email + data["provenance"]["email"] = email - if path and not path.endswith('/'): - path += '/' - if path.startswith('/'): + if path and not path.endswith("/"): + path += "/" + if path.startswith("/"): path = path[1:] - headers = {'Authorization': SECRET_KEY, "Content-Type": "application/json"} + headers = {"Authorization": SECRET_KEY, "Content-Type": "application/json"} - webapi_url = '{}://{}:{}/{}{}'.format(scheme, loc, port, path, modality) - config.loggers.interface.info('MRIQC Web API: submitting to <%s>', webapi_url) + webapi_url = "{}://{}:{}/{}{}".format(scheme, loc, port, path, modality) + config.loggers.interface.info("MRIQC Web API: submitting to <%s>", webapi_url) try: # if the modality is bold, call "bold" endpoint response = requests.post(webapi_url, headers=headers, data=dumps(data)) except requests.ConnectionError as err: - errmsg = 'QC metrics failed to upload due to connection error shown below:\n%s' % err + errmsg = ( + "QC metrics failed to upload due to connection error shown below:\n%s" % err + ) return Bunch(status_code=1, text=errmsg) return response diff --git a/mriqc/qc/__init__.py b/mriqc/qc/__init__.py index 2dc559630..415dcad39 100644 --- a/mriqc/qc/__init__.py +++ b/mriqc/qc/__init__.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- 
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """ diff --git a/mriqc/qc/anatomical.py b/mriqc/qc/anatomical.py index bfd8bb6ee..0a4188f65 100644 --- a/mriqc/qc/anatomical.py +++ b/mriqc/qc/anatomical.py @@ -1,9 +1,5 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# pylint: disable=no-member - r""" Measures based on noise measurements @@ -199,7 +195,7 @@ DIETRICH_FACTOR = 1.0 / sqrt(2 / (4 - pi)) -FSL_FAST_LABELS = {'csf': 1, 'gm': 2, 'wm': 3, 'bg': 0} +FSL_FAST_LABELS = {"csf": 1, "gm": 2, "wm": 3, "bg": 0} PY3 = version_info[0] > 2 @@ -248,8 +244,10 @@ def snr_dietrich(mu_fg, sigma_air): """ if sigma_air < 1.0: from .. import config + config.loggers.interface.warning( - f'SNRd - background sigma is too small ({sigma_air})') + f"SNRd - background sigma is too small ({sigma_air})" + ) sigma_air += 1.0 return float(DIETRICH_FACTOR * mu_fg / sigma_air) @@ -367,15 +365,19 @@ def efc(img, framemask=None): n_vox = np.sum(1 - framemask) # Calculate the maximum value of the EFC (which occurs any time all # voxels have the same value) - efc_max = 1.0 * n_vox * (1.0 / np.sqrt(n_vox)) * \ - np.log(1.0 / np.sqrt(n_vox)) + efc_max = 1.0 * n_vox * (1.0 / np.sqrt(n_vox)) * np.log(1.0 / np.sqrt(n_vox)) # Calculate the total image energy - b_max = np.sqrt((img[framemask == 0]**2).sum()) + b_max = np.sqrt((img[framemask == 0] ** 2).sum()) # Calculate EFC (add 1e-16 to the image data to keep log happy) - return float((1.0 / efc_max) * np.sum((img[framemask == 0] / b_max) * np.log( - (img[framemask == 0] + 1e-16) / b_max))) + return float( + (1.0 / efc_max) + * np.sum( + (img[framemask == 0] / b_max) + * np.log((img[framemask == 0] + 1e-16) / b_max) + ) + ) def wm2max(img, mu_wm): @@ -444,21 +446,21 @@ def art_qi2(img, airmask, min_voxels=int(1e3), max_voxels=int(3e5), save_plot=Tr data = data[data > 0] # Write out figure of the fitting - out_file = op.abspath('error.svg') - with open(out_file, 'w') as ofh: - ofh.write('
<p>Background noise fitting could not be plotted.</p>
') + out_file = op.abspath("error.svg") + with open(out_file, "w") as ofh: + ofh.write("
<p>Background noise fitting could not be plotted.</p>
") if len(data) < min_voxels: return 0.0, out_file - modelx = data if len(data) < max_voxels else np.random.choice( - data, size=max_voxels) + modelx = data if len(data) < max_voxels else np.random.choice(data, size=max_voxels) x_grid = np.linspace(0.0, np.percentile(data, 99), 1000) # Estimate data pdf with KDE on a random subsample - kde_skl = KernelDensity(bandwidth=0.05 * np.percentile(data, 98), - kernel='gaussian').fit(modelx[:, np.newaxis]) + kde_skl = KernelDensity( + bandwidth=0.05 * np.percentile(data, 98), kernel="gaussian" + ).fit(modelx[:, np.newaxis]) kde = np.exp(kde_skl.score_samples(x_grid[:, np.newaxis])) # Find cutoff @@ -518,14 +520,16 @@ def rpve(pvms, seg): if lid == 0: continue pvmap = pvms[lid - 1] - pvmap[pvmap < 0.] = 0. - pvmap[pvmap >= 1.] = 1. + pvmap[pvmap < 0.0] = 0.0 + pvmap[pvmap >= 1.0] = 1.0 totalvol = np.sum(pvmap > 0.0) upth = np.percentile(pvmap[pvmap > 0], 98) loth = np.percentile(pvmap[pvmap > 0], 2) pvmap[pvmap < loth] = 0 pvmap[pvmap > upth] = 0 - pvfs[k] = (pvmap[pvmap > 0.5].sum() + (1.0 - pvmap[pvmap <= 0.5]).sum()) / totalvol + pvfs[k] = ( + pvmap[pvmap > 0.5].sum() + (1.0 - pvmap[pvmap <= 0.5]).sum() + ) / totalvol return {k: float(v) for k, v in list(pvfs.items())} @@ -559,15 +563,16 @@ def summary_stats(img, pvms, airmask=None, erode=True): elif dims == 3: stats_pvms = [np.ones_like(pvms) - pvms, pvms] else: - raise RuntimeError('Incorrect image dimensions ({0:d})'.format( - np.array(pvms).ndim)) + raise RuntimeError( + "Incorrect image dimensions ({0:d})".format(np.array(pvms).ndim) + ) if airmask is not None: stats_pvms[0] = airmask labels = list(FSL_FAST_LABELS.items()) if len(stats_pvms) == 2: - labels = list(zip(['bg', 'fg'], list(range(2)))) + labels = list(zip(["bg", "fg"], list(range(2)))) output = {} for k, lid in labels: @@ -576,48 +581,48 @@ def summary_stats(img, pvms, airmask=None, erode=True): if erode: struc = nd.generate_binary_structure(3, 2) - mask = nd.binary_erosion( - mask, structure=struc).astype(np.uint8) + mask = nd.binary_erosion(mask, structure=struc).astype(np.uint8) nvox = float(mask.sum()) if nvox < 1e3: config.loggers.interface.warning( 'calculating summary stats of label "%s" in a very small ' - 'mask (%d voxels)', k, int(nvox)) - if k == 'bg': + "mask (%d voxels)", + k, + int(nvox), + ) + if k == "bg": continue output[k] = { - 'mean': float(img[mask == 1].mean()), - 'stdv': float(img[mask == 1].std()), - 'median': float(np.median(img[mask == 1])), - 'mad': float(mad(img[mask == 1])), - 'p95': float(np.percentile(img[mask == 1], 95)), - 'p05': float(np.percentile(img[mask == 1], 5)), - 'k': float(kurtosis(img[mask == 1])), - 'n': nvox, + "mean": float(img[mask == 1].mean()), + "stdv": float(img[mask == 1].std()), + "median": float(np.median(img[mask == 1])), + "mad": float(mad(img[mask == 1])), + "p95": float(np.percentile(img[mask == 1], 95)), + "p05": float(np.percentile(img[mask == 1], 5)), + "k": float(kurtosis(img[mask == 1])), + "n": nvox, } - if 'bg' not in output: - output['bg'] = { - 'mean': 0., - 'median': 0., - 'p95': 0., - 'p05': 0., - 'k': 0., - 'stdv': sqrt(sum(val['stdv']**2 - for _, val in list(output.items()))), - 'mad': sqrt(sum(val['mad']**2 - for _, val in list(output.items()))), - 'n': sum(val['n'] for _, val in list(output.items())) + if "bg" not in output: + output["bg"] = { + "mean": 0.0, + "median": 0.0, + "p95": 0.0, + "p05": 0.0, + "k": 0.0, + "stdv": sqrt(sum(val["stdv"] ** 2 for _, val in list(output.items()))), + "mad": sqrt(sum(val["mad"] ** 2 for _, val in list(output.items()))), 
+ "n": sum(val["n"] for _, val in list(output.items())), } - if 'bg' in output and output['bg']['mad'] == 0.0 and output['bg']['stdv'] > 1.0: + if "bg" in output and output["bg"]["mad"] == 0.0 and output["bg"]["stdv"] > 1.0: config.loggers.interface.warning( - 'estimated MAD in the background was too small (MAD=%f)', - output['bg']['mad'] + "estimated MAD in the background was too small (MAD=%f)", + output["bg"]["mad"], ) - output['bg']['mad'] = output['bg']['stdv'] / DIETRICH_FACTOR + output["bg"]["mad"] = output["bg"]["stdv"] / DIETRICH_FACTOR return output @@ -631,8 +636,8 @@ def _prepare_mask(mask, label, erode=True): fgmask[fgmask != label] = 0 fgmask[fgmask == label] = 1 else: - fgmask[fgmask > .95] = 1. - fgmask[fgmask < 1.] = 0 + fgmask[fgmask > 0.95] = 1.0 + fgmask[fgmask < 1.0] = 0 if erode: # Create a structural element to be used in an opening operation. diff --git a/mriqc/qc/functional.py b/mriqc/qc/functional.py index 5264f4444..b0a85acc1 100644 --- a/mriqc/qc/functional.py +++ b/mriqc/qc/functional.py @@ -1,6 +1,5 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# pylint: disable=no-member r""" MRIQC: image quality metrics for functional MRI. @@ -197,7 +196,7 @@ import os.path as op import numpy as np -RAS_AXIS_ORDER = {'x': 0, 'y': 1, 'z': 2} +RAS_AXIS_ORDER = {"x": 0, "y": 1, "z": 2} def gsr(epi_data, mask, direction="y", ref_file=None, out_file=None): @@ -229,22 +228,22 @@ def gsr(epi_data, mask, direction="y", ref_file=None, out_file=None): """ direction = direction.lower() - if direction[-1] not in ['x', 'y', 'all']: - raise Exception("Unknown direction {}, should be one of x, -x, y, -y, all".format( - direction)) + if direction[-1] not in ["x", "y", "all"]: + raise Exception( + "Unknown direction {}, should be one of x, -x, y, -y, all".format(direction) + ) - if direction == 'all': + if direction == "all": result = [] - for newdir in ['x', 'y']: + for newdir in ["x", "y"]: ofile = None if out_file is not None: fname, ext = op.splitext(ofile) - if ext == '.gz': + if ext == ".gz": fname, ext2 = op.splitext(fname) ext = ext2 + ext - ofile = '{0}_{1}{2}'.format(fname, newdir, ext) - result += [gsr(epi_data, mask, newdir, - ref_file=ref_file, out_file=ofile)] + ofile = "{0}_{1}{2}".format(fname, newdir, ext) + result += [gsr(epi_data, mask, newdir, ref_file=ref_file, out_file=ofile)] return result # Roll data of mask through the appropriate axis diff --git a/mriqc/qc/tests/test_anatomical.py b/mriqc/qc/tests/test_anatomical.py index 50995f22b..cce468b84 100644 --- a/mriqc/qc/tests/test_anatomical.py +++ b/mriqc/qc/tests/test_anatomical.py @@ -17,13 +17,13 @@ import pytest from scipy.stats import rice from builtins import object + # from numpy.testing import allclose from ..anatomical import art_qi2 class GroundTruth(object): - - def get_data(self, sigma, noise='normal'): + def get_data(self, sigma, noise="normal"): """Generates noisy 3d data""" size = (50, 50, 50) test_data = np.ones(size) @@ -37,12 +37,14 @@ def get_data(self, sigma, noise='normal'): test_data[bgdata > 0] = bg_mean test_data[wmdata > 0] = wm_mean - if noise == 'rice': + if noise == "rice": test_data += rice.rvs(0.77, scale=sigma * wm_mean, size=test_data.shape) - elif noise == 'rayleigh': + elif noise == "rayleigh": test_data += np.random.rayleigh(scale=sigma * wm_mean, size=test_data.shape) else: - test_data += np.random.normal(0., scale=sigma * wm_mean, size=test_data.shape) + test_data += np.random.normal( + 0.0, scale=sigma * wm_mean, 
size=test_data.shape + ) return test_data, wmdata, bgdata @@ -96,4 +98,4 @@ def test_qi2(gtruth, sigma): data, _, bgdata = gtruth.get_data(sigma, rice) value, _ = art_qi2(data, bgdata, save_plot=False) rmtree(tmpdir) - assert value > .0 and value < 0.04 + assert value > 0.0 and value < 0.04 diff --git a/mriqc/reports/__init__.py b/mriqc/reports/__init__.py index 1f332377b..508cf00d4 100644 --- a/mriqc/reports/__init__.py +++ b/mriqc/reports/__init__.py @@ -49,31 +49,31 @@ from .group import gen_html as group_html REPORT_TITLES = { - 'bold': [ - ('BOLD average', 'bold-avg'), - ('Standard deviation map', 'std-map'), - ('FMRI summary plot', 'fmri-summary'), - ('Zoomed-in BOLD average', 'zoomed-avg'), - ('Background noise', 'bg-noise'), - ('Calculated brain mask', 'brain-msk'), - ('Approximate spatial normalization', 'normalization'), + "bold": [ + ("BOLD average", "bold-avg"), + ("Standard deviation map", "std-map"), + ("FMRI summary plot", "fmri-summary"), + ("Zoomed-in BOLD average", "zoomed-avg"), + ("Background noise", "bg-noise"), + ("Calculated brain mask", "brain-msk"), + ("Approximate spatial normalization", "normalization"), ], - 'T1w': [ - ('Zoomed-in (brain mask)', 'zoomed-view'), - ('Background noise', 'bg-noise'), - ('Approximate spatial normalization', 'normalization'), - ('Brain mask', 'brain-msk'), - ('Brain tissue segmentation', 'brain-seg'), - ('Artifacts in background', 'bg-arts'), - ('Head outline', 'head-msk'), - ('"Hat" mask', 'hat-msk'), - ('Distribution of the noise in the background', 'qi2-fitting'), + "T1w": [ + ("Zoomed-in (brain mask)", "zoomed-view"), + ("Background noise", "bg-noise"), + ("Approximate spatial normalization", "normalization"), + ("Brain mask", "brain-msk"), + ("Brain tissue segmentation", "brain-seg"), + ("Artifacts in background", "bg-arts"), + ("Head outline", "head-msk"), + ('"Hat" mask', "hat-msk"), + ("Distribution of the noise in the background", "qi2-fitting"), ], } -REPORT_TITLES['T2w'] = deepcopy(REPORT_TITLES['T1w']) +REPORT_TITLES["T2w"] = deepcopy(REPORT_TITLES["T1w"]) __all__ = [ - 'individual_html', - 'group_html', + "individual_html", + "group_html", ] diff --git a/mriqc/reports/group.py b/mriqc/reports/group.py index 9044987ac..ad9d7d9e1 100644 --- a/mriqc/reports/group.py +++ b/mriqc/reports/group.py @@ -25,88 +25,179 @@ def gen_html(csv_file, mod, csv_failed=None, out_file=None): from io import BytesIO as TextIO QCGROUPS = { - 'T1w': [ - (['cjv'], None), - (['cnr'], None), - (['efc'], None), - (['fber'], None), - (['wm2max'], None), - (['snr_csf', 'snr_gm', 'snr_wm'], None), - (['snrd_csf', 'snrd_gm', 'snrd_wm'], None), - (['fwhm_avg', 'fwhm_x', 'fwhm_y', 'fwhm_z'], 'vox'), - (['qi_1', 'qi_2'], None), - (['inu_range', 'inu_med'], None), - (['icvs_csf', 'icvs_gm', 'icvs_wm'], None), - (['rpve_csf', 'rpve_gm', 'rpve_wm'], None), - (['tpm_overlap_csf', 'tpm_overlap_gm', 'tpm_overlap_wm'], None), - (['summary_bg_mean', 'summary_bg_median', 'summary_bg_stdv', 'summary_bg_mad', - 'summary_bg_k', 'summary_bg_p05', 'summary_bg_p95'], None), - (['summary_csf_mean', 'summary_csf_median', 'summary_csf_stdv', 'summary_csf_mad', - 'summary_csf_k', 'summary_csf_p05', 'summary_csf_p95'], None), - (['summary_gm_mean', 'summary_gm_median', 'summary_gm_stdv', 'summary_gm_mad', - 'summary_gm_k', 'summary_gm_p05', 'summary_gm_p95'], None), - (['summary_wm_mean', 'summary_wm_median', 'summary_wm_stdv', 'summary_wm_mad', - 'summary_wm_k', 'summary_wm_p05', 'summary_wm_p95'], None) + "T1w": [ + (["cjv"], None), + (["cnr"], None), + (["efc"], None), + 
(["fber"], None), + (["wm2max"], None), + (["snr_csf", "snr_gm", "snr_wm"], None), + (["snrd_csf", "snrd_gm", "snrd_wm"], None), + (["fwhm_avg", "fwhm_x", "fwhm_y", "fwhm_z"], "vox"), + (["qi_1", "qi_2"], None), + (["inu_range", "inu_med"], None), + (["icvs_csf", "icvs_gm", "icvs_wm"], None), + (["rpve_csf", "rpve_gm", "rpve_wm"], None), + (["tpm_overlap_csf", "tpm_overlap_gm", "tpm_overlap_wm"], None), + ( + [ + "summary_bg_mean", + "summary_bg_median", + "summary_bg_stdv", + "summary_bg_mad", + "summary_bg_k", + "summary_bg_p05", + "summary_bg_p95", + ], + None, + ), + ( + [ + "summary_csf_mean", + "summary_csf_median", + "summary_csf_stdv", + "summary_csf_mad", + "summary_csf_k", + "summary_csf_p05", + "summary_csf_p95", + ], + None, + ), + ( + [ + "summary_gm_mean", + "summary_gm_median", + "summary_gm_stdv", + "summary_gm_mad", + "summary_gm_k", + "summary_gm_p05", + "summary_gm_p95", + ], + None, + ), + ( + [ + "summary_wm_mean", + "summary_wm_median", + "summary_wm_stdv", + "summary_wm_mad", + "summary_wm_k", + "summary_wm_p05", + "summary_wm_p95", + ], + None, + ), ], - 'T2w': [ - (['cjv'], None), - (['cnr'], None), - (['efc'], None), - (['fber'], None), - (['wm2max'], None), - (['snr_csf', 'snr_gm', 'snr_wm'], None), - (['snrd_csf', 'snrd_gm', 'snrd_wm'], None), - (['fwhm_avg', 'fwhm_x', 'fwhm_y', 'fwhm_z'], 'mm'), - (['qi_1', 'qi_2'], None), - (['inu_range', 'inu_med'], None), - (['icvs_csf', 'icvs_gm', 'icvs_wm'], None), - (['rpve_csf', 'rpve_gm', 'rpve_wm'], None), - (['tpm_overlap_csf', 'tpm_overlap_gm', 'tpm_overlap_wm'], None), - (['summary_bg_mean', 'summary_bg_stdv', 'summary_bg_k', - 'summary_bg_p05', 'summary_bg_p95'], None), - (['summary_csf_mean', 'summary_csf_stdv', 'summary_csf_k', - 'summary_csf_p05', 'summary_csf_p95'], None), - (['summary_gm_mean', 'summary_gm_stdv', 'summary_gm_k', - 'summary_gm_p05', 'summary_gm_p95'], None), - (['summary_wm_mean', 'summary_wm_stdv', 'summary_wm_k', - 'summary_wm_p05', 'summary_wm_p95'], None) + "T2w": [ + (["cjv"], None), + (["cnr"], None), + (["efc"], None), + (["fber"], None), + (["wm2max"], None), + (["snr_csf", "snr_gm", "snr_wm"], None), + (["snrd_csf", "snrd_gm", "snrd_wm"], None), + (["fwhm_avg", "fwhm_x", "fwhm_y", "fwhm_z"], "mm"), + (["qi_1", "qi_2"], None), + (["inu_range", "inu_med"], None), + (["icvs_csf", "icvs_gm", "icvs_wm"], None), + (["rpve_csf", "rpve_gm", "rpve_wm"], None), + (["tpm_overlap_csf", "tpm_overlap_gm", "tpm_overlap_wm"], None), + ( + [ + "summary_bg_mean", + "summary_bg_stdv", + "summary_bg_k", + "summary_bg_p05", + "summary_bg_p95", + ], + None, + ), + ( + [ + "summary_csf_mean", + "summary_csf_stdv", + "summary_csf_k", + "summary_csf_p05", + "summary_csf_p95", + ], + None, + ), + ( + [ + "summary_gm_mean", + "summary_gm_stdv", + "summary_gm_k", + "summary_gm_p05", + "summary_gm_p95", + ], + None, + ), + ( + [ + "summary_wm_mean", + "summary_wm_stdv", + "summary_wm_k", + "summary_wm_p05", + "summary_wm_p95", + ], + None, + ), + ], + "bold": [ + (["efc"], None), + (["fber"], None), + (["fwhm", "fwhm_x", "fwhm_y", "fwhm_z"], "mm"), + (["gsr_%s" % a for a in ["x", "y"]], None), + (["snr"], None), + (["dvars_std", "dvars_vstd"], None), + (["dvars_nstd"], None), + (["fd_mean"], "mm"), + (["fd_num"], "# timepoints"), + (["fd_perc"], "% timepoints"), + (["spikes_num"], "# slices"), + (["dummy_trs"], "# TRs"), + (["gcor"], None), + (["tsnr"], None), + (["aor"], None), + (["aqi"], None), + ( + [ + "summary_bg_mean", + "summary_bg_stdv", + "summary_bg_k", + "summary_bg_p05", + "summary_bg_p95", + ], + 
None, + ), + ( + [ + "summary_fg_mean", + "summary_fg_stdv", + "summary_fg_k", + "summary_fg_p05", + "summary_fg_p95", + ], + None, + ), ], - 'bold': [ - (['efc'], None), - (['fber'], None), - (['fwhm', 'fwhm_x', 'fwhm_y', 'fwhm_z'], 'mm'), - (['gsr_%s' % a for a in ['x', 'y']], None), - (['snr'], None), - (['dvars_std', 'dvars_vstd'], None), - (['dvars_nstd'], None), - (['fd_mean'], 'mm'), - (['fd_num'], '# timepoints'), - (['fd_perc'], '% timepoints'), - (['spikes_num'], '# slices'), - (['dummy_trs'], '# TRs'), - (['gcor'], None), - (['tsnr'], None), - (['aor'], None), - (['aqi'], None), - (['summary_bg_mean', 'summary_bg_stdv', 'summary_bg_k', - 'summary_bg_p05', 'summary_bg_p95'], None), - (['summary_fg_mean', 'summary_fg_stdv', 'summary_fg_k', - 'summary_fg_p05', 'summary_fg_p95'], None), - ] } - if csv_file.suffix == '.csv': + if csv_file.suffix == ".csv": def_comps = list(BIDS_COMP.keys()) - dataframe = pd.read_csv(csv_file, index_col=False, - dtype={comp: object for comp in def_comps}) + dataframe = pd.read_csv( + csv_file, index_col=False, dtype={comp: object for comp in def_comps} + ) id_labels = list(set(def_comps) & set(dataframe.columns.ravel().tolist())) - dataframe['label'] = dataframe[id_labels].apply(_format_labels, args=(id_labels,), - axis=1) + dataframe["label"] = dataframe[id_labels].apply( + _format_labels, args=(id_labels,), axis=1 + ) else: - dataframe = pd.read_csv(csv_file, index_col=False, sep='\t', - dtype={'bids_name': object}) - dataframe = dataframe.rename(index=str, columns={'bids_name': 'label'}) + dataframe = pd.read_csv( + csv_file, index_col=False, sep="\t", dtype={"bids_name": object} + ) + dataframe = dataframe.rename(index=str, columns={"bids_name": "label"}) nPart = len(dataframe) @@ -127,43 +218,54 @@ def gen_html(csv_file, mod, csv_failed=None, out_file=None): csv_groups = [] datacols = dataframe.columns.ravel().tolist() for group, units in QCGROUPS[mod]: - dfdict = {'iqm': [], 'value': [], 'label': [], 'units': []} + dfdict = {"iqm": [], "value": [], "label": [], "units": []} for iqm in group: if iqm in datacols: values = dataframe[[iqm]].values.ravel().tolist() if values: - dfdict['iqm'] += [iqm] * nPart - dfdict['units'] += [units] * nPart - dfdict['value'] += values - dfdict['label'] += dataframe[['label']].values.ravel().tolist() + dfdict["iqm"] += [iqm] * nPart + dfdict["units"] += [units] * nPart + dfdict["value"] += values + dfdict["label"] += dataframe[["label"]].values.ravel().tolist() # Save only if there are values - if dfdict['value']: + if dfdict["value"]: csv_df = pd.DataFrame(dfdict) csv_str = TextIO() - csv_df[['iqm', 'value', 'label', 'units']].to_csv(csv_str, index=False) + csv_df[["iqm", "value", "label", "units"]].to_csv(csv_str, index=False) csv_groups.append(csv_str.getvalue()) if out_file is None: - out_file = op.abspath('group.html') + out_file = op.abspath("group.html") tpl = GroupTemplate() - tpl.generate_conf({ - 'modality': mod, - 'timestamp': datetime.datetime.now().strftime("%Y-%m-%d, %H:%M"), - 'version': ver, - 'csv_groups': csv_groups, - 'failed': failed, - 'boxplots_js': open(pkgrf('mriqc', op.join('data', 'reports', - 'embed_resources', - 'boxplots.js'))).read(), - 'd3_js': open(pkgrf('mriqc', op.join('data', 'reports', - 'embed_resources', - 'd3.min.js'))).read(), - 'boxplots_css': open(pkgrf('mriqc', op.join('data', 'reports', - 'embed_resources', - 'boxplots.css'))).read() - }, out_file) + tpl.generate_conf( + { + "modality": mod, + "timestamp": datetime.datetime.now().strftime("%Y-%m-%d, %H:%M"), + "version": 
ver, + "csv_groups": csv_groups, + "failed": failed, + "boxplots_js": open( + pkgrf( + "mriqc", + op.join("data", "reports", "embed_resources", "boxplots.js"), + ) + ).read(), + "d3_js": open( + pkgrf( + "mriqc", op.join("data", "reports", "embed_resources", "d3.min.js") + ) + ).read(), + "boxplots_css": open( + pkgrf( + "mriqc", + op.join("data", "reports", "embed_resources", "boxplots.css"), + ) + ).read(), + }, + out_file, + ) return out_file @@ -174,5 +276,5 @@ def _format_labels(row, id_labels): for col_id, prefix in list(BIDS_COMP.items()): if col_id in id_labels: - crow.append('%s-%s' % (prefix, row[[col_id]].values[0])) - return '_'.join(crow) + crow.append("%s-%s" % (prefix, row[[col_id]].values[0])) + return "_".join(crow) diff --git a/mriqc/reports/individual.py b/mriqc/reports/individual.py index 15ede8652..771aeb419 100644 --- a/mriqc/reports/individual.py +++ b/mriqc/reports/individual.py @@ -14,43 +14,48 @@ def individual_html(in_iqms, in_plots=None, api_id=None): from ..data import IndividualTemplate def _get_details(in_iqms, modality): - in_prov = in_iqms.pop('provenance', {}) - warn_dict = in_prov.pop('warnings', None) - sett_dict = in_prov.pop('settings', None) + in_prov = in_iqms.pop("provenance", {}) + warn_dict = in_prov.pop("warnings", None) + sett_dict = in_prov.pop("settings", None) wf_details = [] - if modality == 'bold': - bold_exclude_index = in_iqms.get('dumb_trs') + if modality == "bold": + bold_exclude_index = in_iqms.get("dumb_trs") if bold_exclude_index is None: - config.loggers.cli.warning('Building bold report: no exclude index was found') + config.loggers.cli.warning( + "Building bold report: no exclude index was found" + ) elif bold_exclude_index > 0: msg = """\ Non-steady state (strong T1 contrast) has been detected in the \ first {} volumes. They were excluded before generating any QC measures and plots.""" wf_details.append(msg.format(bold_exclude_index)) - hmc_fsl = sett_dict.pop('hmc_fsl') + hmc_fsl = sett_dict.pop("hmc_fsl") if hmc_fsl is not None: - msg = 'Framewise Displacement was computed using ' + msg = "Framewise Displacement was computed using " if hmc_fsl: - msg += 'FSL mcflirt' + msg += "FSL mcflirt" else: - msg += 'AFNI 3dvolreg' + msg += "AFNI 3dvolreg" wf_details.append(msg) - fd_thres = sett_dict.pop('fd_thres') + fd_thres = sett_dict.pop("fd_thres") if fd_thres is not None: wf_details.append( - 'Framewise Displacement threshold was defined at %f mm' % fd_thres) - elif modality in ('T1w', 'T2w'): - if warn_dict.pop('small_air_mask', False): + "Framewise Displacement threshold was defined at %f mm" % fd_thres + ) + elif modality in ("T1w", "T2w"): + if warn_dict.pop("small_air_mask", False): wf_details.append( - 'Detected hat mask was too small') + 'Detected hat mask was too small' + ) - if warn_dict.pop('large_rot_frame', False): + if warn_dict.pop("large_rot_frame", False): wf_details.append( 'Detected a zero-filled frame, has the original ' - 'image been rotated?') + "image been rotated?" 
+ ) return in_prov, wf_details, sett_dict @@ -62,50 +67,52 @@ def _get_details(in_iqms, modality): out_file = str(Path(in_iqms.with_suffix(".html").name).resolve()) # Extract and prune metadata - metadata = iqms_dict.pop('bids_meta', None) - mod = metadata.pop('modality', None) + metadata = iqms_dict.pop("bids_meta", None) + mod = metadata.pop("modality", None) prov, wf_details, _ = _get_details(iqms_dict, mod) - file_id = [metadata.pop(k, None) - for k in list(BIDS_COMP.keys())] + file_id = [metadata.pop(k, None) for k in list(BIDS_COMP.keys())] file_id = [comp for comp in file_id if comp is not None] if in_plots is None: in_plots = [] else: - if any(('melodic_reportlet' in k for k in in_plots)): - REPORT_TITLES['bold'].insert(3, ('ICA components', 'ica-comps')) - if any(('plot_spikes' in k for k in in_plots)): - REPORT_TITLES['bold'].insert(3, ('Spikes', 'spikes')) + if any(("melodic_reportlet" in k for k in in_plots)): + REPORT_TITLES["bold"].insert(3, ("ICA components", "ica-comps")) + if any(("plot_spikes" in k for k in in_plots)): + REPORT_TITLES["bold"].insert(3, ("Spikes", "spikes")) - in_plots = [(REPORT_TITLES[mod][i] + (read_report_snippet(v), )) - for i, v in enumerate(in_plots)] + in_plots = [ + (REPORT_TITLES[mod][i] + (read_report_snippet(v),)) + for i, v in enumerate(in_plots) + ] pred_qa = None # metadata.pop('mriqc_pred', None) _config = { - 'modality': mod, - 'dataset': metadata.pop('dataset', None), - 'bids_name': in_iqms.with_suffix("").name, - 'timestamp': datetime.datetime.now().strftime("%Y-%m-%d, %H:%M"), - 'version': config.environment.version, - 'imparams': iqms2html(iqms_dict, 'iqms-table'), - 'svg_files': in_plots, - 'workflow_details': wf_details, - 'webapi_url': prov.pop('webapi_url'), - 'webapi_port': prov.pop('webapi_port'), - 'provenance': iqms2html(prov, 'provenance-table'), - 'md5sum': prov['md5sum'], - 'metadata': iqms2html(metadata, 'metadata-table'), - 'pred_qa': pred_qa + "modality": mod, + "dataset": metadata.pop("dataset", None), + "bids_name": in_iqms.with_suffix("").name, + "timestamp": datetime.datetime.now().strftime("%Y-%m-%d, %H:%M"), + "version": config.environment.version, + "imparams": iqms2html(iqms_dict, "iqms-table"), + "svg_files": in_plots, + "workflow_details": wf_details, + "webapi_url": prov.pop("webapi_url"), + "webapi_port": prov.pop("webapi_port"), + "provenance": iqms2html(prov, "provenance-table"), + "md5sum": prov["md5sum"], + "metadata": iqms2html(metadata, "metadata-table"), + "pred_qa": pred_qa, } - if _config['metadata'] is None: - _config['workflow_details'].append( + if _config["metadata"] is None: + _config["workflow_details"].append( 'File has no metadata ' - '(sidecar JSON file missing or empty)') + "(sidecar JSON file missing or empty)" + ) tpl = IndividualTemplate() tpl.generate_conf(_config, out_file) - config.loggers.cli.info('Generated individual log (%s)', out_file) + config.loggers.cli.info("Generated individual log (%s)", out_file) return out_file diff --git a/mriqc/reports/utils.py b/mriqc/reports/utils.py index 029264dba..5c3d5680e 100644 --- a/mriqc/reports/utils.py +++ b/mriqc/reports/utils.py @@ -1,14 +1,6 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# -# @Author: oesteban -# @Date: 2016-01-05 11:33:39 -# @Email: code@oscaresteban.es -# @Last modified by: oesteban -# @Last Modified time: 2017-05-25 13:41:58 -""" Helpers in report generation""" +"""Helpers in report generation""" def 
iqms2html(indict, table_id): @@ -20,19 +12,19 @@ def iqms2html(indict, table_id): depth = max([len(col) for col in columns]) result_str = '\n' % table_id - td = '{0}'.format + td = "{0}".format for line in columns: - result_str += '' + result_str += "" ncols = len(line) for i, col in enumerate(line): colspan = 0 - colstring = '' + colstring = "" if (depth - ncols) > 0 and i == ncols - 2: colspan = (depth - ncols) + 1 - colstring = ' colspan=%d' % colspan + colstring = " colspan=%d" % colspan result_str += td(col, colstring) - result_str += '\n' - result_str += '
\n' + result_str += "\n" + result_str += "\n" return result_str @@ -45,7 +37,7 @@ def unfold_columns(indict, prefix=None): data = [] subdict = {} for key in keys: - col = key.split('_', 1) + col = key.split("_", 1) if len(col) == 1: value = indict[col[0]] data.append(prefix + [col[0], value]) @@ -59,11 +51,10 @@ def unfold_columns(indict, prefix=None): sskeys = list(subdict[skey].keys()) if len(sskeys) == 1: value = subdict[skey][sskeys[0]] - newkey = '_'.join([skey] + sskeys) + newkey = "_".join([skey] + sskeys) data.append(prefix + [newkey, value]) else: - data += unfold_columns( - subdict[skey], prefix=prefix + [skey]) + data += unfold_columns(subdict[skey], prefix=prefix + [skey]) return data @@ -74,23 +65,23 @@ def read_report_snippet(in_file): import re from io import open # pylint: disable=W0622 - is_svg = (op.splitext(op.basename(in_file))[1] == '.svg') + is_svg = op.splitext(op.basename(in_file))[1] == ".svg" with open(in_file) as thisfile: if not is_svg: return thisfile.read() svg_tag_line = 0 - content = thisfile.read().split('\n') + content = thisfile.read().split("\n") corrected = [] for i, line in enumerate(content): if "[a-zA-Z0-9]+)(_ses-(?P[a-zA-Z0-9]+))?\ @@ -51,12 +56,12 @@ def reorder_csv(csv_file, out_file=None): dataframe = pd.read_csv(csv_file) cols = dataframe.columns.tolist() # pylint: disable=no-member try: - cols.remove('Unnamed: 0') + cols.remove("Unnamed: 0") except ValueError: # The column does not exist pass - for col in ['scan', 'session', 'subject']: + for col in ["scan", "session", "subject"]: cols.remove(col) cols.insert(0, col) @@ -71,35 +76,36 @@ def rotate_files(fname): import os.path as op name, ext = op.splitext(fname) - if ext == '.gz': + if ext == ".gz": name, ext2 = op.splitext(fname) ext = ext2 + ext if not op.isfile(fname): return - prev = glob.glob('{}.*{}'.format(name, ext)) + prev = glob.glob("{}.*{}".format(name, ext)) prev.insert(0, fname) - prev.append('{0}.{1:d}{2}'.format(name, len(prev) - 1, ext)) + prev.append("{0}.{1:d}{2}".format(name, len(prev) - 1, ext)) for i in reversed(list(range(1, len(prev)))): os.rename(prev[i - 1], prev[i]) -def bids_path(subid, sesid=None, runid=None, prefix=None, out_path=None, ext='json'): +def bids_path(subid, sesid=None, runid=None, prefix=None, out_path=None, ext="json"): import os.path as op - fname = '{}'.format(subid) + + fname = "{}".format(subid) if prefix is not None: - if not prefix.endswith('_'): - prefix += '_' + if not prefix.endswith("_"): + prefix += "_" fname = prefix + fname if sesid is not None: - fname += '_ses-{}'.format(sesid) + fname += "_ses-{}".format(sesid) if runid is not None: - fname += '_run-{}'.format(runid) + fname += "_run-{}".format(runid) if out_path is not None: fname = op.join(out_path, fname) - return op.abspath(fname + '.' + ext) + return op.abspath(fname + "." 
+ ext) def generate_pred(derivatives_dir, output_dir, mod): @@ -108,21 +114,20 @@ def generate_pred(derivatives_dir, output_dir, mod): generates a corresponding prediction CSV table """ - if mod != 'T1w': + if mod != "T1w": return None # If some were found, generate the CSV file and group report - jsonfiles = list(output_dir.glob( - 'sub-*/**/%s/sub-*_%s.json' % (IMTYPES[mod], mod))) + jsonfiles = list(output_dir.glob("sub-*/**/%s/sub-*_%s.json" % (IMTYPES[mod], mod))) if not jsonfiles: return None - headers = list(BIDS_COMP.keys()) + ['mriqc_pred'] + headers = list(BIDS_COMP.keys()) + ["mriqc_pred"] predictions = {k: [] for k in headers} for jsonfile in jsonfiles: - with open(jsonfile, 'r') as jsondata: - data = json.load(jsondata).pop('bids_meta', None) + with open(jsonfile, "r") as jsondata: + data = json.load(jsondata).pop("bids_meta", None) if data is None: continue @@ -130,21 +135,18 @@ def generate_pred(derivatives_dir, output_dir, mod): for k in headers: predictions[k].append(data.pop(k, None)) - dataframe = pd.DataFrame( - predictions).sort_values(by=list(BIDS_COMP.keys())) + dataframe = pd.DataFrame(predictions).sort_values(by=list(BIDS_COMP.keys())) # Drop empty columns - dataframe.dropna(axis='columns', how='all', inplace=True) + dataframe.dropna(axis="columns", how="all", inplace=True) bdits_cols = list(set(BIDS_COMP.keys()) & set(dataframe.columns.ravel())) # Drop duplicates - dataframe.drop_duplicates(bdits_cols, - keep='last', inplace=True) + dataframe.drop_duplicates(bdits_cols, keep="last", inplace=True) - out_csv = Path(output_dir) / ('%s_predicted_qa_csv' % mod) - dataframe[bdits_cols + ['mriqc_pred']].to_csv( - str(out_csv), index=False) + out_csv = Path(output_dir) / ("%s_predicted_qa_csv" % mod) + dataframe[bdits_cols + ["mriqc_pred"]].to_csv(str(out_csv), index=False) return out_csv @@ -154,8 +156,8 @@ def generate_tsv(output_dir, mod): """ # If some were found, generate the CSV file and group report - out_tsv = output_dir / ('group_%s.tsv' % mod) - jsonfiles = list(output_dir.glob('sub-*/**/%s/sub-*_%s.json' % (IMTYPES[mod], mod))) + out_tsv = output_dir / ("group_%s.tsv" % mod) + jsonfiles = list(output_dir.glob("sub-*/**/%s/sub-*_%s.json" % (IMTYPES[mod], mod))) if not jsonfiles: return None, out_tsv @@ -165,21 +167,21 @@ def generate_tsv(output_dir, mod): if dfentry is not None: bids_name = str(Path(jsonfile.name).stem) - dfentry.pop('bids_meta', None) - dfentry.pop('provenance', None) - dfentry['bids_name'] = bids_name + dfentry.pop("bids_meta", None) + dfentry.pop("provenance", None) + dfentry["bids_name"] = bids_name datalist.append(dfentry) dataframe = pd.DataFrame(datalist) cols = dataframe.columns.tolist() # pylint: disable=no-member - dataframe = dataframe.sort_values(by=['bids_name']) + dataframe = dataframe.sort_values(by=["bids_name"]) # Drop duplicates - dataframe.drop_duplicates(['bids_name'], keep='last', inplace=True) + dataframe.drop_duplicates(["bids_name"], keep="last", inplace=True) # Set filename at front - cols.insert(0, cols.pop(cols.index('bids_name'))) - dataframe[cols].to_csv(str(out_tsv), index=False, sep='\t') + cols.insert(0, cols.pop(cols.index("bids_name"))) + dataframe[cols].to_csv(str(out_tsv), index=False, sep="\t") return dataframe, out_tsv @@ -188,7 +190,7 @@ def _read_and_save(in_file): return data if data else None -def _flatten(in_dict, parent_key='', sep='_'): +def _flatten(in_dict, parent_key="", sep="_"): items = [] for k, val in list(in_dict.items()): new_key = parent_key + sep + k if parent_key else k @@ -207,8 +209,8 @@ 
def _flatten_dict(indict): else: for subk, subval in list(value.items()): if not isinstance(subval, dict): - out_qc['_'.join([k, subk])] = subval + out_qc["_".join([k, subk])] = subval else: for ssubk, ssubval in list(subval.items()): - out_qc['_'.join([k, subk, ssubk])] = ssubval + out_qc["_".join([k, subk, ssubk])] = ssubval return out_qc diff --git a/mriqc/viz/__init__.py b/mriqc/viz/__init__.py index 7f877fb02..e69de29bb 100644 --- a/mriqc/viz/__init__.py +++ b/mriqc/viz/__init__.py @@ -1,4 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- -# vi: set ft=python sts=4 ts=4 sw=4 et: diff --git a/mriqc/viz/misc.py b/mriqc/viz/misc.py index 4be30fc03..4dc949737 100644 --- a/mriqc/viz/misc.py +++ b/mriqc/viz/misc.py @@ -1,13 +1,6 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# -# @Author: oesteban -# @Date: 2016-01-05 11:32:01 -# @Email: code@oscaresteban.es -# @Last modified by: oesteban -""" Helper functions for the figures in the paper """ +"""Helper functions for the figures in the paper.""" import os.path as op import numpy as np import pandas as pd @@ -19,33 +12,42 @@ def plot_qi2(x_grid, ref_pdf, fit_pdf, ref_data, cutoff_idx, out_file=None): fig, ax = plt.subplots() - ax.plot(x_grid, ref_pdf, linewidth=2, alpha=0.5, label='background', color='dodgerblue') + ax.plot( + x_grid, ref_pdf, linewidth=2, alpha=0.5, label="background", color="dodgerblue" + ) refmax = np.percentile(ref_data, 99.95) x_max = x_grid[-1] - ax.hist(ref_data, 40 * max(int(refmax / x_max), 1), - fc='dodgerblue', histtype='stepfilled', - alpha=0.2, normed=True) + ax.hist( + ref_data, + 40 * max(int(refmax / x_max), 1), + fc="dodgerblue", + histtype="stepfilled", + alpha=0.2, + normed=True, + ) fit_pdf[fit_pdf > 1.0] = np.nan - ax.plot(x_grid, fit_pdf, linewidth=2, alpha=0.5, label='chi2', color='darkorange') + ax.plot(x_grid, fit_pdf, linewidth=2, alpha=0.5, label="chi2", color="darkorange") ylims = ax.get_ylim() - ax.axvline(x_grid[-cutoff_idx], ymax=ref_pdf[-cutoff_idx] / ylims[1], color='dodgerblue') + ax.axvline( + x_grid[-cutoff_idx], ymax=ref_pdf[-cutoff_idx] / ylims[1], color="dodgerblue" + ) plt.xlabel('Intensity within "hat" mask') - plt.ylabel('Frequency') + plt.ylabel("Frequency") ax.set_xlim([0, x_max]) plt.legend() if out_file is None: - out_file = op.abspath('qi2_plot.svg') + out_file = op.abspath("qi2_plot.svg") - fig.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=300) + fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300) return out_file -def plot_batches(fulldata, cols=None, out_file=None, site_labels='left'): - fulldata = fulldata.sort_values(by=['database', 'site']).copy() +def plot_batches(fulldata, cols=None, out_file=None, site_labels="left"): + fulldata = fulldata.sort_values(by=["database", "site"]).copy() sites = fulldata.site.values.ravel().tolist() if cols is None: numdata = fulldata.select_dtypes([np.number]) @@ -54,49 +56,53 @@ def plot_batches(fulldata, cols=None, out_file=None, site_labels='left'): numdata = numdata[cols] colmin = numdata.min() - numdata = (numdata - colmin) + numdata = numdata - colmin colmax = numdata.max() numdata = numdata / colmax fig, ax = plt.subplots(figsize=(20, 10)) - ax.imshow(numdata.values, cmap=plt.cm.viridis, interpolation='nearest', aspect='auto') + ax.imshow( + numdata.values, cmap=plt.cm.viridis, interpolation="nearest", 
aspect="auto" + ) locations = [] spines = [] - fulldata['index'] = range(len(fulldata)) + fulldata["index"] = range(len(fulldata)) for site in list(set(sites)): - indices = fulldata.loc[fulldata.site == site, 'index'].values.ravel().tolist() + indices = fulldata.loc[fulldata.site == site, "index"].values.ravel().tolist() locations.append(int(np.average(indices))) spines.append(indices[0]) - if site_labels == 'right': + if site_labels == "right": ax.yaxis.tick_right() ax.yaxis.set_label_position("right") - plt.xticks(range(numdata.shape[1]), numdata.columns.ravel().tolist(), rotation='vertical') + plt.xticks( + range(numdata.shape[1]), numdata.columns.ravel().tolist(), rotation="vertical" + ) plt.yticks(locations, list(set(sites))) for line in spines[1:]: - plt.axhline(y=line, color='w', linestyle='-') - ax.spines['right'].set_visible(False) - ax.spines['top'].set_visible(False) + plt.axhline(y=line, color="w", linestyle="-") + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) # ax.spines['left'].set_visible(False) - ax.spines['bottom'].set_visible(False) + ax.spines["bottom"].set_visible(False) ax.grid(False) ticks_font = FontProperties( - family='FreeSans', style='normal', size=14, - weight='normal', stretch='normal') + family="FreeSans", style="normal", size=14, weight="normal", stretch="normal" + ) for label in ax.get_yticklabels(): label.set_fontproperties(ticks_font) ticks_font = FontProperties( - family='FreeSans', style='normal', size=12, - weight='normal', stretch='normal') + family="FreeSans", style="normal", size=12, weight="normal", stretch="normal" + ) for label in ax.get_xticklabels(): label.set_fontproperties(ticks_font) if out_file is not None: - fig.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=300) + fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300) return fig @@ -106,19 +112,19 @@ def plot_roc_curve(true_y, prob_y, out_file=None): fpr, tpr, _ = roc_curve(true_y, prob_y) fig = plt.figure() - plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve') - plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--') + plt.plot(fpr, tpr, color="darkorange", lw=2, label="ROC curve") + plt.plot([0, 1], [0, 1], color="navy", lw=1, linestyle="--") plt.xlim([-0.025, 1.025]) plt.ylim([-0.025, 1.025]) - plt.xlabel('False Positive Rate') - plt.ylabel('True Positive Rate') - plt.title('RoC Curve') + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.title("RoC Curve") if out_file is not None: fig.savefig(out_file) return fig -def fill_matrix(matrix, width, value='n/a'): +def fill_matrix(matrix, width, value="n/a"): if matrix.shape[0] < width: nraters = matrix.shape[1] nas = np.chararray((1, nraters), itemsize=len(value)) @@ -129,8 +135,8 @@ def fill_matrix(matrix, width, value='n/a'): def plot_raters(dataframe, ax=None, width=101, size=0.40): raters = sorted(dataframe.columns.ravel().tolist()) - dataframe['notnan'] = np.any(np.isnan(dataframe[raters]), axis=1).astype(int) - dataframe = dataframe.sort_values(by=['notnan'] + raters, ascending=True) + dataframe["notnan"] = np.any(np.isnan(dataframe[raters]), axis=1).astype(int) + dataframe = dataframe.sort_values(by=["notnan"] + raters, ascending=True) for rater in raters: dataframe[rater] = dataframe[[rater]].astype(str) @@ -144,7 +150,7 @@ def plot_raters(dataframe, ax=None, width=101, size=0.40): nblocks = (matrix.shape[0] // width) + 1 nas = np.chararray((width, 1), itemsize=3) - nas[:] = 'n/a' + nas[:] = "n/a" for i in range(nblocks): if i > 0: 
matrices.append(nas) @@ -153,54 +159,73 @@ def plot_raters(dataframe, ax=None, width=101, size=0.40): matrices[-1] = fill_matrix(matrices[-1], width) matrix = np.hstack(tuple(matrices)) - palette = {'1.0': 'limegreen', '0.0': 'dimgray', '-1.0': 'tomato', 'n/a': 'w'} + palette = {"1.0": "limegreen", "0.0": "dimgray", "-1.0": "tomato", "n/a": "w"} ax = ax if ax is not None else plt.gca() # ax.patch.set_facecolor('gray') - ax.set_aspect('equal', 'box') + ax.set_aspect("equal", "box") ax.xaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator()) nrows = ((nsamples - 1) // width) + 1 xlims = (-14.0, width) - ylims = (-0.07 * nraters, nrows * nraters + nraters * .07 + (nrows - 1)) + ylims = (-0.07 * nraters, nrows * nraters + nraters * 0.07 + (nrows - 1)) ax.set_xlim(xlims) ax.set_ylim(ylims) - offset = 0.5 * (size / .40) + offset = 0.5 * (size / 0.40) for (x, y), w in np.ndenumerate(matrix): if w not in list(palette.keys()): - w = 'n/a' + w = "n/a" color = palette[w] - rect = plt.Circle([x + offset, y + offset], size, - facecolor=color, edgecolor=color) + rect = plt.Circle( + [x + offset, y + offset], size, facecolor=color, edgecolor=color + ) ax.add_patch(rect) # text_x = ((nsamples - 1) % width) + 6.5 text_x = -8.5 for i, rname in enumerate(raters): - nsamples = sum(dataframe[rname] != 'n/a') - good = 100 * sum(dataframe[rname] == '1.0') / nsamples - bad = 100 * sum(dataframe[rname] == '-1.0') / nsamples + nsamples = sum(dataframe[rname] != "n/a") + good = 100 * sum(dataframe[rname] == "1.0") / nsamples + bad = 100 * sum(dataframe[rname] == "-1.0") / nsamples text_y = 1.5 * i + (nrows - 1) * 2.0 - ax.text(text_x, text_y, '%2.0f%%' % good, - color='limegreen', weight=1000, size=16, - horizontalalignment='right', - verticalalignment='center', - transform=ax.transData) - ax.text(text_x + 3.50, text_y, '%2.0f%%' % max((0.0, 100 - good - bad)), - color='dimgray', weight=1000, size=16, - horizontalalignment='right', - verticalalignment='center', - transform=ax.transData) - ax.text(text_x + 7.0, text_y, '%2.0f%%' % bad, - color='tomato', weight=1000, size=16, - horizontalalignment='right', - verticalalignment='center', - transform=ax.transData) + ax.text( + text_x, + text_y, + "%2.0f%%" % good, + color="limegreen", + weight=1000, + size=16, + horizontalalignment="right", + verticalalignment="center", + transform=ax.transData, + ) + ax.text( + text_x + 3.50, + text_y, + "%2.0f%%" % max((0.0, 100 - good - bad)), + color="dimgray", + weight=1000, + size=16, + horizontalalignment="right", + verticalalignment="center", + transform=ax.transData, + ) + ax.text( + text_x + 7.0, + text_y, + "%2.0f%%" % bad, + color="tomato", + weight=1000, + size=16, + horizontalalignment="right", + verticalalignment="center", + transform=ax.transData, + ) # ax.autoscale_view() ax.invert_yaxis() @@ -209,28 +234,34 @@ def plot_raters(dataframe, ax=None, width=101, size=0.40): # Remove and redefine spines for side in ["top", "right", "bottom"]: # Toggle the spine objects - ax.spines[side].set_color('none') + ax.spines[side].set_color("none") ax.spines[side].set_visible(False) ax.spines["left"].set_linewidth(1.5) - ax.spines["left"].set_color('dimgray') + ax.spines["left"].set_color("dimgray") # ax.spines["left"].set_position(('data', xlims[0])) ax.set_yticks([0.5 * (ylims[0] + ylims[1])]) - ax.tick_params(axis='y', which='major', pad=15) + ax.tick_params(axis="y", which="major", pad=15) ticks_font = FontProperties( - family='FreeSans', style='normal', size=20, - weight='normal', stretch='normal') 
+ family="FreeSans", style="normal", size=20, weight="normal", stretch="normal" + ) for label in ax.get_yticklabels(): label.set_fontproperties(ticks_font) return ax -def raters_variability_plot(mdata, figsize=(22, 22), width=101, out_file=None, - raters=('rater_1', 'rater_2', 'rater_3'), only_overlap=True, - rater_names=('Rater 1', 'Rater 2a', 'Rater 2b')): +def raters_variability_plot( + mdata, + figsize=(22, 22), + width=101, + out_file=None, + raters=("rater_1", "rater_2", "rater_3"), + only_overlap=True, + rater_names=("Rater 1", "Rater 2a", "Rater 2b"), +): if only_overlap: mdata = mdata[np.all(~np.isnan(mdata[raters]), axis=1)] # Swap raters 2 and 3 @@ -247,12 +278,18 @@ def raters_variability_plot(mdata, figsize=(22, 22), width=101, out_file=None, blocks = [(slen - 1) // width + 1 for slen in sites_len] fig = plt.figure(figsize=figsize) - gs = GridSpec(len(sites_list), 1, width_ratios=[1], height_ratios=blocks, hspace=0.05) + gs = GridSpec( + len(sites_list), 1, width_ratios=[1], height_ratios=blocks, hspace=0.05 + ) for s, gsel in zip(sites_list, gs): ax = plt.subplot(gsel) - plot_raters(mdata.loc[mdata.site == s, raters], ax=ax, width=width, - size=.40 if len(raters) == 3 else .80) + plot_raters( + mdata.loc[mdata.site == s, raters], + ax=ax, + width=width, + size=0.40 if len(raters) == 3 else 0.80, + ) ax.set_yticklabels([s]) # ax.add_line(Line2D([0.0, width], [8.0, 8.0], color='k')) @@ -266,7 +303,7 @@ def raters_variability_plot(mdata, figsize=(22, 22), width=101, out_file=None, # fontsize=20, ha='center', va='top', # arrowprops=dict(arrowstyle='-[, widthB=3.0, lengthB=0.2', lw=1.0)) - newax = plt.axes([0.6, 0.65, .25, .16]) + newax = plt.axes([0.6, 0.65, 0.25, 0.16]) newax.grid(False) newax.set_xticklabels([]) newax.set_xticks([]) @@ -279,56 +316,91 @@ def raters_variability_plot(mdata, figsize=(22, 22), width=101, out_file=None, good = 100 * sum(mdata[rater] == 1.0) / nsamples bad = 100 * sum(mdata[rater] == -1.0) / nsamples - text_x = .92 - text_y = .5 - 0.17 * i - newax.text(text_x - .36, text_y, '%2.1f%%' % good, - color='limegreen', weight=1000, size=25, - horizontalalignment='right', - verticalalignment='center', - transform=newax.transAxes) - newax.text(text_x - .18, text_y, '%2.1f%%' % max((0.0, 100 - good - bad)), - color='dimgray', weight=1000, size=25, - horizontalalignment='right', - verticalalignment='center', - transform=newax.transAxes) - newax.text(text_x, text_y, '%2.1f%%' % bad, - color='tomato', weight=1000, size=25, - horizontalalignment='right', - verticalalignment='center', - transform=newax.transAxes) - - newax.text(1 - text_x, text_y, rater_names[i], - color='k', size=25, - horizontalalignment='left', - verticalalignment='center', - transform=newax.transAxes) - - newax.text(0.5, 0.95, 'Imbalance of ratings', - color='k', size=25, - horizontalalignment='center', - verticalalignment='top', - transform=newax.transAxes) - newax.text(0.5, 0.85, '(ABIDE, aggregated)', - color='k', size=25, - horizontalalignment='center', - verticalalignment='top', - transform=newax.transAxes) + text_x = 0.92 + text_y = 0.5 - 0.17 * i + newax.text( + text_x - 0.36, + text_y, + "%2.1f%%" % good, + color="limegreen", + weight=1000, + size=25, + horizontalalignment="right", + verticalalignment="center", + transform=newax.transAxes, + ) + newax.text( + text_x - 0.18, + text_y, + "%2.1f%%" % max((0.0, 100 - good - bad)), + color="dimgray", + weight=1000, + size=25, + horizontalalignment="right", + verticalalignment="center", + transform=newax.transAxes, + ) + newax.text( + text_x, 
+ text_y, + "%2.1f%%" % bad, + color="tomato", + weight=1000, + size=25, + horizontalalignment="right", + verticalalignment="center", + transform=newax.transAxes, + ) + + newax.text( + 1 - text_x, + text_y, + rater_names[i], + color="k", + size=25, + horizontalalignment="left", + verticalalignment="center", + transform=newax.transAxes, + ) + + newax.text( + 0.5, + 0.95, + "Imbalance of ratings", + color="k", + size=25, + horizontalalignment="center", + verticalalignment="top", + transform=newax.transAxes, + ) + newax.text( + 0.5, + 0.85, + "(ABIDE, aggregated)", + color="k", + size=25, + horizontalalignment="center", + verticalalignment="top", + transform=newax.transAxes, + ) if out_file is None: - out_file = 'raters.svg' + out_file = "raters.svg" fname, ext = op.splitext(out_file) - if ext[1:] not in ['pdf', 'svg', 'png']: - ext = '.svg' - out_file = fname + '.svg' + if ext[1:] not in ["pdf", "svg", "png"]: + ext = ".svg" + out_file = fname + ".svg" - fig.savefig(op.abspath(out_file), format=ext[1:], - bbox_inches='tight', pad_inches=0, dpi=300) + fig.savefig( + op.abspath(out_file), format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=300 + ) return fig -def plot_abide_stripplots(inputs, figsize=(15, 2), out_file=None, - rating_label='rater_1', dpi=100): +def plot_abide_stripplots( + inputs, figsize=(15, 2), out_file=None, rating_label="rater_1", dpi=100 +): import seaborn as sn from ..classifier.helper import FEATURE_NORM from ..classifier.data import read_dataset @@ -340,12 +412,13 @@ def plot_abide_stripplots(inputs, figsize=(15, 2), out_file=None, pp_cols = [] for X, Y, sitename in inputs: - sitedata, cols = read_dataset(X, Y, rate_label=rating_label, - binarize=False, site_name=sitename) - sitedata['database'] = [sitename] * len(sitedata) + sitedata, cols = read_dataset( + X, Y, rate_label=rating_label, binarize=False, site_name=sitename + ) + sitedata["database"] = [sitename] * len(sitedata) - if sitename == 'DS030': - sitedata['site'] = [sitename] * len(sitedata) + if sitename == "DS030": + sitedata["site"] = [sitename] * len(sitedata) mdata.append(sitedata) pp_cols.append(cols) @@ -354,28 +427,27 @@ def plot_abide_stripplots(inputs, figsize=(15, 2), out_file=None, pp_cols = pp_cols[0] for col in mdata.columns.ravel().tolist(): - if col.startswith('rater_') and col != rating_label: + if col.startswith("rater_") and col != rating_label: del mdata[col] mdata = mdata.loc[mdata[rating_label].notnull()] - for col in ['size_x', 'size_y', 'size_z', 'spacing_x', 'spacing_y', 'spacing_z']: + for col in ["size_x", "size_y", "size_z", "spacing_x", "spacing_y", "spacing_z"]: del mdata[col] try: pp_cols.remove(col) except ValueError: pass - zscored = BatchRobustScaler( - by='site', columns=FEATURE_NORM).fit_transform(mdata) + zscored = BatchRobustScaler(by="site", columns=FEATURE_NORM).fit_transform(mdata) sites = list(set(mdata.site.values.ravel())) nsites = len(sites) # palette = ['dodgerblue', 'darkorange'] - palette = ['limegreen', 'tomato'] + palette = ["limegreen", "tomato"] if len(set(mdata[[rating_label]].values.ravel().tolist())) == 3: - palette = ['tomato', 'gold', 'limegreen'] + palette = ["tomato", "gold", "limegreen"] # pp_cols = pp_cols[:5] nrows = len(pp_cols) @@ -391,15 +463,51 @@ def plot_abide_stripplots(inputs, figsize=(15, 2), out_file=None, ax_zsc = plt.subplot(gs[i, 3]) # plots - sn.stripplot(x='site', y=colname, data=mdata, hue=rating_label, jitter=0.18, alpha=.6, - split=True, palette=palette, ax=ax_nzs) - sn.stripplot(x='site', y=colname, data=zscored, 
hue=rating_label, jitter=0.18, alpha=.6, - split=True, palette=palette, ax=ax_zsc) + sn.stripplot( + x="site", + y=colname, + data=mdata, + hue=rating_label, + jitter=0.18, + alpha=0.6, + split=True, + palette=palette, + ax=ax_nzs, + ) + sn.stripplot( + x="site", + y=colname, + data=zscored, + hue=rating_label, + jitter=0.18, + alpha=0.6, + split=True, + palette=palette, + ax=ax_zsc, + ) - sn.stripplot(x='database', y=colname, data=mdata, hue=rating_label, jitter=0.18, alpha=.6, - split=True, palette=palette, ax=axg_nzs) - sn.stripplot(x='database', y=colname, data=zscored, hue=rating_label, jitter=0.18, - alpha=.6, split=True, palette=palette, ax=axg_zsc) + sn.stripplot( + x="database", + y=colname, + data=mdata, + hue=rating_label, + jitter=0.18, + alpha=0.6, + split=True, + palette=palette, + ax=axg_nzs, + ) + sn.stripplot( + x="database", + y=colname, + data=zscored, + hue=rating_label, + jitter=0.18, + alpha=0.6, + split=True, + palette=palette, + ax=axg_zsc, + ) ax_nzs.legend_.remove() ax_zsc.legend_.remove() @@ -417,17 +525,17 @@ def plot_abide_stripplots(inputs, figsize=(15, 2), out_file=None, axg_nzs.set_xticklabels([]) axg_zsc.set_xticklabels([]) - ax_nzs.set_xlabel('', visible=False) - ax_zsc.set_xlabel('', visible=False) - ax_zsc.set_ylabel('', visible=False) + ax_nzs.set_xlabel("", visible=False) + ax_zsc.set_xlabel("", visible=False) + ax_zsc.set_ylabel("", visible=False) ax_zsc.yaxis.tick_right() axg_nzs.set_yticklabels([]) - axg_nzs.set_xlabel('', visible=False) - axg_nzs.set_ylabel('', visible=False) + axg_nzs.set_xlabel("", visible=False) + axg_nzs.set_ylabel("", visible=False) axg_zsc.set_yticklabels([]) - axg_zsc.set_xlabel('', visible=False) - axg_zsc.set_ylabel('', visible=False) + axg_zsc.set_xlabel("", visible=False) + axg_zsc.set_ylabel("", visible=False) for yt in ax_nzs.yaxis.get_major_ticks()[1:-1]: yt.label1.set_visible(False) @@ -435,31 +543,35 @@ def plot_abide_stripplots(inputs, figsize=(15, 2), out_file=None, for yt in axg_nzs.yaxis.get_major_ticks()[1:-1]: yt.label1.set_visible(False) - for yt in zip(ax_zsc.yaxis.get_majorticklabels(), axg_zsc.yaxis.get_majorticklabels()): + for yt in zip( + ax_zsc.yaxis.get_majorticklabels(), axg_zsc.yaxis.get_majorticklabels() + ): yt[0].set_visible(False) yt[1].set_visible(False) if out_file is None: - out_file = 'stripplot.svg' + out_file = "stripplot.svg" fname, ext = op.splitext(out_file) - if ext[1:] not in ['pdf', 'svg', 'png']: - ext = '.svg' - out_file = fname + '.svg' + if ext[1:] not in ["pdf", "svg", "png"]: + ext = ".svg" + out_file = fname + ".svg" - fig.savefig(op.abspath(out_file), format=ext[1:], - bbox_inches='tight', pad_inches=0, dpi=dpi) + fig.savefig( + op.abspath(out_file), format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=dpi + ) return fig def plot_corrmat(in_csv, out_file=None): import seaborn as sn + sn.set(style="whitegrid") - dataframe = pd.read_csv(in_csv, index_col=False, na_values='n/a', na_filter=False) + dataframe = pd.read_csv(in_csv, index_col=False, na_values="n/a", na_filter=False) colnames = dataframe.columns.ravel().tolist() - for col in ['subject_id', 'site', 'modality']: + for col in ["subject_id", "site", "modality"]: try: colnames.remove(col) except ValueError: @@ -467,7 +579,7 @@ def plot_corrmat(in_csv, out_file=None): # Correlation matrix corr = dataframe[colnames].corr() - corr = corr.dropna((0, 1), 'all') + corr = corr.dropna((0, 1), "all") # Generate a mask for the upper triangle mask = np.zeros_like(corr, dtype=np.bool) @@ -477,34 +589,38 @@ def plot_corrmat(in_csv, 
out_file=None): cmap = sn.diverging_palette(220, 10, as_cmap=True) # Draw the heatmap with the mask and correct aspect ratio - corrplot = sn.clustermap(corr, cmap=cmap, center=0., - method='average', square=True, linewidths=.5) - plt.setp(corrplot.ax_heatmap.yaxis.get_ticklabels(), rotation='horizontal') + corrplot = sn.clustermap( + corr, cmap=cmap, center=0.0, method="average", square=True, linewidths=0.5 + ) + plt.setp(corrplot.ax_heatmap.yaxis.get_ticklabels(), rotation="horizontal") # , mask=mask, square=True, linewidths=.5, cbar_kws={"shrink": .5}) if out_file is None: - out_file = 'corr_matrix.svg' + out_file = "corr_matrix.svg" fname, ext = op.splitext(out_file) - if ext[1:] not in ['pdf', 'svg', 'png']: - ext = '.svg' - out_file = fname + '.svg' + if ext[1:] not in ["pdf", "svg", "png"]: + ext = ".svg" + out_file = fname + ".svg" - corrplot.savefig(out_file, format=ext[1:], bbox_inches='tight', pad_inches=0, dpi=100) + corrplot.savefig( + out_file, format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=100 + ) return corrplot -def plot_histograms(X, Y, rating_label='rater_1', out_file=None): +def plot_histograms(X, Y, rating_label="rater_1", out_file=None): import re import seaborn as sn from ..classifier.data import read_dataset + sn.set(style="whitegrid") mdata, pp_cols = read_dataset(X, Y, rate_label=rating_label) - mdata['rater'] = mdata[[rating_label]].values.ravel() + mdata["rater"] = mdata[[rating_label]].values.ravel() for col in mdata.columns.ravel().tolist(): - if col.startswith('rater_'): + if col.startswith("rater_"): del mdata[col] mdata = mdata.loc[mdata.rater.notnull()] @@ -513,7 +629,7 @@ def plot_histograms(X, Y, rating_label='rater_1', out_file=None): # zscored = zscore_dataset( # mdata, excl_columns=['rater', 'size_x', 'size_y', 'size_z', # 'spacing_x', 'spacing_y', 'spacing_z']) - pat = re.compile(r'^(spacing|summary|size)') + pat = re.compile(r"^(spacing|summary|size)") colnames = [col for col in sorted(pp_cols) if pat.match(col)] nrows = len(colnames) @@ -526,16 +642,36 @@ def plot_histograms(X, Y, rating_label='rater_1', out_file=None): ax_nzs = plt.subplot(gs[i, 0]) ax_zsd = plt.subplot(gs[i, 1]) - sn.distplot(mdata.loc[(mdata.rater == 0), col], norm_hist=False, - label='Accept', ax=ax_nzs, color='dodgerblue') - sn.distplot(mdata.loc[(mdata.rater == 1), col], norm_hist=False, - label='Reject', ax=ax_nzs, color='darkorange') + sn.distplot( + mdata.loc[(mdata.rater == 0), col], + norm_hist=False, + label="Accept", + ax=ax_nzs, + color="dodgerblue", + ) + sn.distplot( + mdata.loc[(mdata.rater == 1), col], + norm_hist=False, + label="Reject", + ax=ax_nzs, + color="darkorange", + ) ax_nzs.legend() - sn.distplot(zscored.loc[(zscored.rater == 0), col], norm_hist=False, - label='Accept', ax=ax_zsd, color='dodgerblue') - sn.distplot(zscored.loc[(zscored.rater == 1), col], norm_hist=False, - label='Reject', ax=ax_zsd, color='darkorange') + sn.distplot( + zscored.loc[(zscored.rater == 0), col], + norm_hist=False, + label="Accept", + ax=ax_zsd, + color="dodgerblue", + ) + sn.distplot( + zscored.loc[(zscored.rater == 1), col], + norm_hist=False, + label="Reject", + ax=ax_zsd, + color="darkorange", + ) alldata = mdata[[col]].values.ravel().tolist() minv = np.percentile(alldata, 0.2) @@ -548,33 +684,34 @@ def plot_histograms(X, Y, rating_label='rater_1', out_file=None): ax_zsd.set_xlim([minv, maxv]) if out_file is None: - out_file = 'histograms.svg' + out_file = "histograms.svg" fname, ext = op.splitext(out_file) - if ext[1:] not in ['pdf', 'svg', 'png']: - ext = '.svg' - 
out_file = fname + '.svg' + if ext[1:] not in ["pdf", "svg", "png"]: + ext = ".svg" + out_file = fname + ".svg" - fig.savefig(out_file, format=ext[1:], bbox_inches='tight', pad_inches=0, dpi=100) + fig.savefig(out_file, format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=100) return fig -def inter_rater_variability(y1, y2, figsize=(4, 4), normed=True, - raters=None, labels=None, out_file=None): +def inter_rater_variability( + y1, y2, figsize=(4, 4), normed=True, raters=None, labels=None, out_file=None +): plt.rcParams["font.family"] = "sans-serif" plt.rcParams["font.sans-serif"] = "FreeSans" - plt.rcParams['font.size'] = 25 - plt.rcParams['axes.labelsize'] = 20 - plt.rcParams['axes.titlesize'] = 25 - plt.rcParams['xtick.labelsize'] = 15 - plt.rcParams['ytick.labelsize'] = 15 + plt.rcParams["font.size"] = 25 + plt.rcParams["axes.labelsize"] = 20 + plt.rcParams["axes.titlesize"] = 25 + plt.rcParams["xtick.labelsize"] = 15 + plt.rcParams["ytick.labelsize"] = 15 # fig = plt.figure(figsize=(3.5, 3)) if raters is None: - raters = ['Rater 1', 'Rater 2'] + raters = ["Rater 1", "Rater 2"] if labels is None: - labels = ['exclude', 'doubtful', 'accept'] + labels = ["exclude", "doubtful", "accept"] fig, ax = plt.subplots(figsize=figsize) ax.set_aspect("equal") @@ -594,155 +731,192 @@ def inter_rater_variability(y1, y2, figsize=(4, 4), normed=True, ycenters = (ybins[:-1] + ybins[1:]) * 0.5 total = np.sum(hist.reshape(-1)) - celfmt = '%d%%' if normed else '%d' + celfmt = "%d%%" if normed else "%d" for i, x in enumerate(xcenters): for j, y in enumerate(ycenters): val = hist[i, j] if normed: val = 100 * hist[i, j] / total - ax.text(x, y, celfmt % val, - ha="center", va="center", fontweight="bold", - color='w' if hist[i, j] < 15 else 'k') + ax.text( + x, + y, + celfmt % val, + ha="center", + va="center", + fontweight="bold", + color="w" if hist[i, j] < 15 else "k", + ) # plt.colorbar(pad=0.10) plt.grid(False) plt.xticks(xcenters, xlabels) - plt.yticks(ycenters, ylabels, rotation='vertical', va='center') + plt.yticks(ycenters, ylabels, rotation="vertical", va="center") plt.xlabel(raters[0]) plt.ylabel(raters[1]) ax.yaxis.tick_right() ax.xaxis.set_label_position("top") if out_file is not None: - fig.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=300) + fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300) return fig -def plot_artifact(image_path, figsize=(20, 20), vmax=None, cut_coords=None, display_mode='ortho', - size=None): +def plot_artifact( + image_path, + figsize=(20, 20), + vmax=None, + cut_coords=None, + display_mode="ortho", + size=None, +): import nilearn.plotting as nplt fig = plt.figure(figsize=figsize) nplt_disp = nplt.plot_anat( - image_path, display_mode=display_mode, cut_coords=cut_coords, - vmax=vmax, figure=fig, annotate=False) + image_path, + display_mode=display_mode, + cut_coords=cut_coords, + vmax=vmax, + figure=fig, + annotate=False, + ) if size is None: size = figsize[0] * 6 - bg_color = 'k' - fg_color = 'w' + bg_color = "k" + fg_color = "w" ax = fig.gca() ax.text( - .1, .95, 'L', + 0.1, + 0.95, + "L", transform=ax.transAxes, - horizontalalignment='left', - verticalalignment='top', + horizontalalignment="left", + verticalalignment="top", size=size, bbox=dict(boxstyle="square,pad=0", ec=bg_color, fc=bg_color, alpha=1), - color=fg_color) + color=fg_color, + ) ax.text( - .9, .95, 'R', + 0.9, + 0.95, + "R", transform=ax.transAxes, - horizontalalignment='right', - verticalalignment='top', + horizontalalignment="right", + verticalalignment="top", size=size, 
bbox=dict(boxstyle="square,pad=0", ec=bg_color, fc=bg_color), - color=fg_color) + color=fg_color, + ) return nplt_disp, ax -def figure1_a(image_path, display_mode='y', vmax=300, cut_coords=None, figsize=(20, 20)): +def figure1_a( + image_path, display_mode="y", vmax=300, cut_coords=None, figsize=(20, 20) +): import matplotlib.patches as patches if cut_coords is None: cut_coords = [15] - disp, ax = plot_artifact(image_path, display_mode=display_mode, vmax=vmax, - cut_coords=cut_coords, figsize=figsize) + disp, ax = plot_artifact( + image_path, + display_mode=display_mode, + vmax=vmax, + cut_coords=cut_coords, + figsize=figsize, + ) ax.add_patch( patches.Arrow( - 0.2, # x - 0.2, # y - 0.1, # dx - 0.6, # dy - width=.25, - color='tomato', - transform=ax.transAxes + 0.2, # x + 0.2, # y + 0.1, # dx + 0.6, # dy + width=0.25, + color="tomato", + transform=ax.transAxes, ) ) ax.add_patch( patches.Arrow( - 0.8, # x - 0.2, # y - -0.1, # dx - 0.6, # dy - width=.25, - color='tomato', - transform=ax.transAxes + 0.8, # x + 0.2, # y + -0.1, # dx + 0.6, # dy + width=0.25, + color="tomato", + transform=ax.transAxes, ) ) return disp -def figure1_b(image_path, display_mode='z', vmax=400, cut_coords=None, figsize=(20, 20)): +def figure1_b( + image_path, display_mode="z", vmax=400, cut_coords=None, figsize=(20, 20) +): import matplotlib.patches as patches if cut_coords is None: cut_coords = [-24] - disp, ax = plot_artifact(image_path, display_mode=display_mode, vmax=vmax, - cut_coords=cut_coords, figsize=figsize) + disp, ax = plot_artifact( + image_path, + display_mode=display_mode, + vmax=vmax, + cut_coords=cut_coords, + figsize=figsize, + ) ax.add_patch( patches.Arrow( - 0.02, # x - 0.55, # y - 0.1, # dx - 0.0, # dy - width=.10, - color='tomato', - transform=ax.transAxes + 0.02, # x + 0.55, # y + 0.1, # dx + 0.0, # dy + width=0.10, + color="tomato", + transform=ax.transAxes, ) ) ax.add_patch( patches.Arrow( - 0.98, # x - 0.55, # y - -0.1, # dx - 0.0, # dy - width=.10, - color='tomato', - transform=ax.transAxes + 0.98, # x + 0.55, # y + -0.1, # dx + 0.0, # dy + width=0.10, + color="tomato", + transform=ax.transAxes, ) ) ax.add_patch( patches.Arrow( - 0.02, # x - 0.80, # y - 0.1, # dx - 0.0, # dy - width=.10, - color='limegreen', - transform=ax.transAxes + 0.02, # x + 0.80, # y + 0.1, # dx + 0.0, # dy + width=0.10, + color="limegreen", + transform=ax.transAxes, ) ) ax.add_patch( patches.Arrow( - 0.98, # x - 0.80, # y - -0.1, # dx - 0.0, # dy - width=.10, - color='limegreen', - transform=ax.transAxes + 0.98, # x + 0.80, # y + -0.1, # dx + 0.0, # dy + width=0.10, + color="limegreen", + transform=ax.transAxes, ) ) return disp @@ -750,8 +924,7 @@ def figure1_b(image_path, display_mode='z', vmax=400, cut_coords=None, figsize=( def figure1(artifact1, artifact2, out_file): from .svg import svg2str, combine_svg - combine_svg([ - svg2str(figure1_b(artifact2)), - svg2str(figure1_a(artifact1)) - ], - axis='vertical').save(out_file) + + combine_svg( + [svg2str(figure1_b(artifact2)), svg2str(figure1_a(artifact1))], axis="vertical" + ).save(out_file) diff --git a/mriqc/viz/svg.py b/mriqc/viz/svg.py index 279ae445b..57a9ec591 100644 --- a/mriqc/viz/svg.py +++ b/mriqc/viz/svg.py @@ -1,13 +1,6 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# -# @Author: oesteban -# @Date: 2016-01-05 11:32:01 -# @Email: code@oscaresteban.es -# @Last modified by: oesteban -""" SVG handling utilities """ +"""SVG handling utilities 
""" def svg2str(display_object, dpi=300): @@ -15,15 +8,16 @@ def svg2str(display_object, dpi=300): Serializes a nilearn display object as a string """ from io import StringIO + image_buf = StringIO() display_object.frame_axes.figure.savefig( - image_buf, dpi=dpi, format='svg', - facecolor='k', edgecolor='k') + image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k" + ) image_buf.seek(0) return image_buf.getvalue() -def combine_svg(svg_list, axis='vertical'): +def combine_svg(svg_list, axis="vertical"): """ Composes the input svgs into one standalone svg """ @@ -31,13 +25,13 @@ def combine_svg(svg_list, axis='vertical'): import svgutils.transform as svgt # Read all svg files and get roots - svgs = [svgt.fromstring(f.encode('utf-8')) for f in svg_list] + svgs = [svgt.fromstring(f.encode("utf-8")) for f in svg_list] roots = [f.getroot() for f in svgs] # Query the size of each sizes = [(int(f.width[:-2]), int(f.height[:-2])) for f in svgs] - if axis == 'vertical': + if axis == "vertical": # Calculate the scale to fit all widths scales = [1.0] * len(svgs) if not all([width[0] == sizes[0][0] for width in sizes[1:]]): @@ -45,11 +39,12 @@ def combine_svg(svg_list, axis='vertical'): for i, els in enumerate(sizes): scales[i] = ref_size[0] / els[0] - newsizes = [tuple(size) - for size in np.array(sizes) * np.array(scales)[..., np.newaxis]] + newsizes = [ + tuple(size) for size in np.array(sizes) * np.array(scales)[..., np.newaxis] + ] totalsize = [newsizes[0][0], np.sum(newsizes, axis=0)[1]] - elif axis == 'horizontal': + elif axis == "horizontal": # Calculate the scale to fit all heights scales = [1.0] * len(svgs) if not all([height[0] == sizes[0][1] for height in sizes[1:]]): @@ -57,22 +52,23 @@ def combine_svg(svg_list, axis='vertical'): for i, els in enumerate(sizes): scales[i] = ref_size[1] / els[1] - newsizes = [tuple(size) - for size in np.array(sizes) * np.array(scales)[..., np.newaxis]] + newsizes = [ + tuple(size) for size in np.array(sizes) * np.array(scales)[..., np.newaxis] + ] totalsize = [np.sum(newsizes, axis=0)[0], newsizes[0][1]] # Compose the views panel: total size is the width of # any element (used the first here) and the sum of heights fig = svgt.SVGFigure(totalsize[0], totalsize[1]) - if axis == 'vertical': + if axis == "vertical": yoffset = 0 for i, r in enumerate(roots): size = newsizes[i] r.moveto(0, yoffset, scale=scales[i]) yoffset += size[1] fig.append(r) - elif axis == 'horizontal': + elif axis == "horizontal": xoffset = 0 for i, r in enumerate(roots): size = newsizes[i] @@ -88,6 +84,6 @@ def extract_svg(display_object, dpi=300): Removes the preamble of the svg files generated with nilearn """ image_svg = svg2str(display_object, dpi) - start_idx = image_svg.find('') + start_idx = image_svg.find("") return image_svg[start_idx:end_idx] diff --git a/mriqc/viz/utils.py b/mriqc/viz/utils.py index 4a03b8574..64acb28fc 100644 --- a/mriqc/viz/utils.py +++ b/mriqc/viz/utils.py @@ -1,12 +1,5 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# -# @Author: oesteban -# @Date: 2016-01-05 11:32:01 -# @Email: code@oscaresteban.es -# @Last modified by: oesteban """ Visualization utilities """ import math @@ -24,8 +17,16 @@ DINA4_PORTRAIT = (8.27, 11.69) -def plot_slice(dslice, spacing=None, cmap='Greys_r', label=None, - ax=None, vmax=None, vmin=None, annotate=False): +def plot_slice( + dslice, + spacing=None, + cmap="Greys_r", + label=None, + ax=None, + vmax=None, + 
vmin=None, + annotate=False, +): from matplotlib.cm import get_cmap if isinstance(cmap, (str, bytes)): @@ -44,36 +45,74 @@ def plot_slice(dslice, spacing=None, cmap='Greys_r', label=None, spacing = [1.0, 1.0] phys_sp = np.array(spacing) * dslice.shape - ax.imshow(np.swapaxes(dslice, 0, 1), vmin=vmin, vmax=vmax, cmap=cmap, - interpolation='nearest', origin='lower', - extent=[0, phys_sp[0], 0, phys_sp[1]]) + ax.imshow( + np.swapaxes(dslice, 0, 1), + vmin=vmin, + vmax=vmax, + cmap=cmap, + interpolation="nearest", + origin="lower", + extent=[0, phys_sp[0], 0, phys_sp[1]], + ) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.grid(False) - ax.axis('off') + ax.axis("off") bgcolor = cmap(min(vmin, 0.0)) fgcolor = cmap(vmax) if annotate: - ax.text(.95, .95, 'R', color=fgcolor, transform=ax.transAxes, - horizontalalignment='center', verticalalignment='top', - size=18, bbox=dict(boxstyle="square,pad=0", ec=bgcolor, fc=bgcolor)) - ax.text(.05, .95, 'L', color=fgcolor, transform=ax.transAxes, - horizontalalignment='center', verticalalignment='top', - size=18, bbox=dict(boxstyle="square,pad=0", ec=bgcolor, fc=bgcolor)) + ax.text( + 0.95, + 0.95, + "R", + color=fgcolor, + transform=ax.transAxes, + horizontalalignment="center", + verticalalignment="top", + size=18, + bbox=dict(boxstyle="square,pad=0", ec=bgcolor, fc=bgcolor), + ) + ax.text( + 0.05, + 0.95, + "L", + color=fgcolor, + transform=ax.transAxes, + horizontalalignment="center", + verticalalignment="top", + size=18, + bbox=dict(boxstyle="square,pad=0", ec=bgcolor, fc=bgcolor), + ) if label is not None: - ax.text(.98, .01, label, color=fgcolor, transform=ax.transAxes, - horizontalalignment='right', verticalalignment='bottom', - size=18, bbox=dict(boxstyle="square,pad=0", ec=bgcolor, fc=bgcolor)) + ax.text( + 0.98, + 0.01, + label, + color=fgcolor, + transform=ax.transAxes, + horizontalalignment="right", + verticalalignment="bottom", + size=18, + bbox=dict(boxstyle="square,pad=0", ec=bgcolor, fc=bgcolor), + ) return ax -def plot_slice_tern(dslice, prev=None, post=None, - spacing=None, cmap='Greys_r', label=None, ax=None, - vmax=None, vmin=None): +def plot_slice_tern( + dslice, + prev=None, + post=None, + spacing=None, + cmap="Greys_r", + label=None, + ax=None, + vmax=None, + vmin=None, +): from matplotlib.cm import get_cmap if isinstance(cmap, (str, bytes)): @@ -101,26 +140,36 @@ def plot_slice_tern(dslice, prev=None, post=None, post = np.ones_like(dslice) combined = np.swapaxes(np.vstack((prev, dslice, post)), 0, 1) - ax.imshow(combined, vmin=vmin, vmax=vmax, cmap=cmap, - interpolation='nearest', origin='lower', - extent=[0, phys_sp[1] * 3, 0, phys_sp[0]]) + ax.imshow( + combined, + vmin=vmin, + vmax=vmax, + cmap=cmap, + interpolation="nearest", + origin="lower", + extent=[0, phys_sp[1] * 3, 0, phys_sp[0]], + ) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.grid(False) if label is not None: - ax.text(.5, .05, label, - transform=ax.transAxes, - horizontalalignment='center', - verticalalignment='top', - size=14, - bbox=dict(boxstyle="square,pad=0", ec='k', fc='k'), - color='w') - - -def plot_spikes(in_file, in_fft, spikes_list, cols=3, - labelfmt='t={0:.3f}s (z={1:d})', - out_file=None): + ax.text( + 0.5, + 0.05, + label, + transform=ax.transAxes, + horizontalalignment="center", + verticalalignment="top", + size=14, + bbox=dict(boxstyle="square,pad=0", ec="k", fc="k"), + color="w", + ) + + +def plot_spikes( + in_file, in_fft, spikes_list, cols=3, labelfmt="t={0:.3f}s (z={1:d})", out_file=None +): from mpl_toolkits.axes_grid1 import 
make_axes_locatable nii = nb.as_closest_canonical(nb.load(in_file)) @@ -159,28 +208,52 @@ def plot_spikes(in_file, in_fft, spikes_list, cols=3, ax2 = divider.new_vertical(size="100%", pad=0.1) fig.add_axes(ax2) - plot_slice_tern(data[..., z, t], prev=prev, post=post, spacing=zooms, - ax=ax2, - label=labelfmt.format(t * tstep, z)) - - plot_slice_tern(fft[..., z, t], prev=pvft, post=psft, vmin=-5, vmax=5, - cmap=get_parula(), ax=ax1) + plot_slice_tern( + data[..., z, t], + prev=prev, + post=post, + spacing=zooms, + ax=ax2, + label=labelfmt.format(t * tstep, z), + ) + + plot_slice_tern( + fft[..., z, t], + prev=pvft, + post=psft, + vmin=-5, + vmax=5, + cmap=get_parula(), + ax=ax1, + ) plt.tight_layout() if out_file is None: fname, ext = op.splitext(op.basename(in_file)) - if ext == '.gz': + if ext == ".gz": fname, _ = op.splitext(fname) - out_file = op.abspath('%s.svg' % fname) + out_file = op.abspath("%s.svg" % fname) - fig.savefig(out_file, format='svg', dpi=300, bbox_inches='tight') + fig.savefig(out_file, format="svg", dpi=300, bbox_inches="tight") return out_file -def plot_mosaic(img, out_file=None, ncols=8, title=None, overlay_mask=None, - bbox_mask_file=None, only_plot_noise=False, annotate=True, - vmin=None, vmax=None, cmap='Greys_r', plot_sagittal=True, - fig=None, zmax=128): +def plot_mosaic( + img, + out_file=None, + ncols=8, + title=None, + overlay_mask=None, + bbox_mask_file=None, + only_plot_noise=False, + annotate=True, + vmin=None, + vmax=None, + cmap="Greys_r", + plot_sagittal=True, + fig=None, + zmax=128, +): if isinstance(img, (str, bytes)): nii = nb.as_closest_canonical(nb.load(img)) @@ -189,7 +262,7 @@ def plot_mosaic(img, out_file=None, ncols=8, title=None, overlay_mask=None, else: img_data = img zooms = [1.0, 1.0, 1.0] - out_file = 'mosaic.svg' + out_file = "mosaic.svg" # Remove extra dimensions img_data = np.squeeze(img_data) @@ -201,8 +274,7 @@ def plot_mosaic(img, out_file=None, ncols=8, title=None, overlay_mask=None, img_data = _bbox(img_data, mask_file) if bbox_mask_file is not None: - bbox_data = nb.as_closest_canonical( - nb.load(bbox_mask_file)).get_data() + bbox_data = nb.as_closest_canonical(nb.load(bbox_mask_file)).get_data() img_data = _bbox(img_data, bbox_data) z_vals = np.array(list(range(0, img_data.shape[2]))) @@ -229,15 +301,13 @@ def plot_mosaic(img, out_file=None, ncols=8, title=None, overlay_mask=None, nrows += 1 if overlay_mask: - overlay_data = nb.as_closest_canonical( - nb.load(overlay_mask)).get_data() + overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_data() # create figures if fig is None: fig = plt.figure(figsize=(22, nrows * 3)) - est_vmin, est_vmax = _get_limits(img_data, - only_plot_noise=only_plot_noise) + est_vmin, est_vmax = _get_limits(img_data, only_plot_noise=only_plot_noise) if not vmin: vmin = est_vmin if not vmax: @@ -249,18 +319,32 @@ def plot_mosaic(img, out_file=None, ncols=8, title=None, overlay_mask=None, if overlay_mask: ax.set_rasterized(True) - plot_slice(img_data[:, :, z_val], vmin=vmin, vmax=vmax, - cmap=cmap, ax=ax, spacing=zooms[:2], - label='%d' % z_val, annotate=annotate) + plot_slice( + img_data[:, :, z_val], + vmin=vmin, + vmax=vmax, + cmap=cmap, + ax=ax, + spacing=zooms[:2], + label="%d" % z_val, + annotate=annotate, + ) if overlay_mask: from matplotlib import cm + msk_cmap = cm.Reds # @UndefinedVariable msk_cmap._init() alphas = np.linspace(0, 0.75, msk_cmap.N + 3) msk_cmap._lut[:, -1] = alphas - plot_slice(overlay_data[:, :, z_val], vmin=0, vmax=1, - cmap=msk_cmap, ax=ax, spacing=zooms[:2]) + 
plot_slice( + overlay_data[:, :, z_val], + vmin=0, + vmax=1, + cmap=msk_cmap, + ax=ax, + spacing=zooms[:2], + ) naxis += 1 if plot_sagittal: @@ -276,26 +360,32 @@ def plot_mosaic(img, out_file=None, ncols=8, title=None, overlay_mask=None, for x_val in list(range(start, stop, step))[:ncols]: ax = fig.add_subplot(nrows, ncols, naxis) - plot_slice(img_data[x_val, ...], vmin=vmin, vmax=vmax, - cmap=cmap, ax=ax, label='%d' % x_val, - spacing=[zooms[0], zooms[2]]) + plot_slice( + img_data[x_val, ...], + vmin=vmin, + vmax=vmax, + cmap=cmap, + ax=ax, + label="%d" % x_val, + spacing=[zooms[0], zooms[2]], + ) naxis += 1 fig.subplots_adjust( - left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, - hspace=0.05) + left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05 + ) if title: - fig.suptitle(title, fontsize='10') + fig.suptitle(title, fontsize="10") fig.subplots_adjust(wspace=0.002, hspace=0.002) if out_file is None: fname, ext = op.splitext(op.basename(img)) if ext == ".gz": fname, _ = op.splitext(fname) - out_file = op.abspath(fname + '_mosaic.svg') + out_file = op.abspath(fname + "_mosaic.svg") - fig.savefig(out_file, format='svg', dpi=300, bbox_inches='tight') + fig.savefig(out_file, format="svg", dpi=300, bbox_inches="tight") return out_file @@ -328,15 +418,20 @@ def plot_fd(fd_file, fd_radius, mean_fd_dist=None, figsize=DINA4_LANDSCAPE): sns.distplot(mean_fd_dist, ax=ax) ax.set_xlabel("Mean Frame Displacement (over all subjects) [mm]") mean_fd = fd_power.mean() - label = r'$\overline{{\text{{FD}}}}$ = {0:g}'.format(mean_fd) + label = r"$\overline{{\text{{FD}}}}$ = {0:g}".format(mean_fd) plot_vline(mean_fd, label, ax=ax) return fig def plot_dist( - main_file, mask_file, xlabel, distribution=None, xlabel2=None, - figsize=DINA4_LANDSCAPE): + main_file, + mask_file, + xlabel, + distribution=None, + xlabel2=None, + figsize=DINA4_LANDSCAPE, +): data = _get_values_inside_a_mask(main_file, mask_file) fig = plt.Figure(figsize=figsize) @@ -363,8 +458,15 @@ def plot_vline(cur_val, label, ax): vloc = (ylim[0] + ylim[1]) / 2.0 xlim = ax.get_xlim() pad = (xlim[0] + xlim[1]) / 100.0 - ax.text(cur_val - pad, vloc, label, color="blue", rotation=90, - verticalalignment='center', horizontalalignment='right') + ax.text( + cur_val - pad, + vloc, + label, + color="blue", + rotation=90, + verticalalignment="center", + horizontalalignment="right", + ) def _calc_rows_columns(ratio, n_images): @@ -381,15 +483,17 @@ def _calc_rows_columns(ratio, n_images): def _calc_fd(fd_file, fd_radius): from math import pi - lines = open(fd_file, 'r').readlines() + + lines = open(fd_file, "r").readlines() rows = [[float(x) for x in line.split()] for line in lines] cols = np.array([list(col) for col in zip(*rows)]) translations = np.transpose(np.abs(np.diff(cols[0:3, :]))) rotations = np.transpose(np.abs(np.diff(cols[3:6, :]))) - fd_power = np.sum(translations, axis=1) + \ - (fd_radius * pi / 180) * np.sum(rotations, axis=1) + fd_power = np.sum(translations, axis=1) + (fd_radius * pi / 180) * np.sum( + rotations, axis=1 + ) # FD is zero for the first time point fd_power = np.insert(fd_power, 0, 0) @@ -418,14 +522,13 @@ def _get_values_inside_a_mask(main_file, mask_file): return data -def plot_segmentation(anat_file, segmentation, out_file, - **kwargs): +def plot_segmentation(anat_file, segmentation, out_file, **kwargs): from nilearn.plotting import plot_anat - vmax = kwargs.get('vmax') - vmin = kwargs.get('vmin') + vmax = kwargs.get("vmax") + vmin = kwargs.get("vmin") - if kwargs.get('saturate', False): + if 
kwargs.get("saturate", False): vmax = np.percentile(nb.load(anat_file).get_data().reshape(-1), 70) if vmax is None and vmin is None: @@ -435,14 +538,15 @@ def plot_segmentation(anat_file, segmentation, out_file, disp = plot_anat( anat_file, - display_mode=kwargs.get('display_mode', 'ortho'), - cut_coords=kwargs.get('cut_coords', 8), - title=kwargs.get('title'), - vmax=vmax, vmin=vmin) + display_mode=kwargs.get("display_mode", "ortho"), + cut_coords=kwargs.get("cut_coords", 8), + title=kwargs.get("title"), + vmax=vmax, + vmin=vmin, + ) disp.add_contours( - segmentation, - levels=kwargs.get('levels', [1]), - colors=kwargs.get('colors', 'r')) + segmentation, levels=kwargs.get("levels", [1]), colors=kwargs.get("colors", "r") + ) disp.savefig(out_file) disp.close() disp = None @@ -542,6 +646,7 @@ def get_parula(): [0.9588714286, 0.8949, 0.1132428571], [0.9598238095, 0.9218333333, 0.0948380952], [0.9661, 0.9514428571, 0.0755333333], - [0.9763, 0.9831, 0.0538]] + [0.9763, 0.9831, 0.0538], + ] - return LinearSegmentedColormap.from_list('parula', cm_data) + return LinearSegmentedColormap.from_list("parula", cm_data) diff --git a/mriqc/workflows/utils.py b/mriqc/workflows/utils.py index 4b72d19a2..09cfb1b67 100644 --- a/mriqc/workflows/utils.py +++ b/mriqc/workflows/utils.py @@ -1,12 +1,5 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -# -# @Author: oesteban -# @Date: 2016-01-05 17:15:12 -# @Email: code@oscaresteban.es -# @Last modified by: oesteban """Helper functions for the workflows""" from distutils.version import StrictVersion from builtins import range diff --git a/setup.py b/setup.py index b42469842..ef55c5f56 100644 --- a/setup.py +++ b/setup.py @@ -8,13 +8,14 @@ # Give setuptools a hint to complain if it's too old a version # 30.3.0 allows us to put most metadata in setup.cfg # Should match pyproject.toml -SETUP_REQUIRES = ['setuptools >= 30.3.0'] +SETUP_REQUIRES = ["setuptools >= 30.3.0"] # This enables setuptools to install wheel on-the-fly -SETUP_REQUIRES += ['wheel'] if 'bdist_wheel' in sys.argv else [] +SETUP_REQUIRES += ["wheel"] if "bdist_wheel" in sys.argv else [] -if __name__ == '__main__': - setup(name='mriqc', - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - setup_requires=SETUP_REQUIRES, - ) +if __name__ == "__main__": + setup( + name="mriqc", + version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), + setup_requires=SETUP_REQUIRES, + )