-
Notifications
You must be signed in to change notification settings - Fork 62
/
core.py
504 lines (420 loc) · 22.6 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
from __future__ import absolute_import, division, print_function
import numpy as np
import time
import glob
import sys
from os import makedirs, remove
from os.path import exists, join as pjoin
import nibabel
import pickle
import amico.scheme
from amico.preproc import debiasRician
import amico.lut
import amico.models
from amico.lut import is_valid
from amico.progressbar import ProgressBar
from dipy.core.gradients import gradient_table
import dipy.reconst.dti as dti
from amico.util import LOG, NOTE, WARNING, ERROR
def setup( lmax = 12, ndirs = 32761 ) :
    """Initialise the AMICO framework by precomputing the rotation matrices.

    Parameters
    ----------
    lmax : int
        Highest spherical-harmonic order used by the rotation step (default : 12)
    ndirs : int
        How many directions, over one half of the sphere, sample the candidate
        orientations of the response functions (default : 32761)
    """
    ndirs_supported = is_valid( ndirs )
    if not ndirs_supported :
        ERROR( 'Unsupported value for ndirs.\nNote: supported values for ndirs are [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, 32761 (default)]' )
    # the heavy lifting lives in amico.lut
    amico.lut.precompute_rotation_matrices( lmax, ndirs )
class Evaluation :
    """Class to hold all the information (data and parameters) when performing an
    evaluation with the AMICO framework.

    Typical workflow: ``load_data()`` -> ``set_model()`` -> ``generate_kernels()``
    -> ``load_kernels()`` -> ``fit()`` -> ``save_results()``.
    """

    @staticmethod
    def _nib_header( img ) :
        """Return the header of a nibabel image, independently of the nibabel version.

        Feature detection is used instead of comparing ``nibabel.__version__``
        as strings, which is lexicographic (e.g. '10.0.0' < '2.0.0').
        """
        return img.header if hasattr( img, 'header' ) else img.get_header()

    @staticmethod
    def _nib_affine( img ) :
        """Return the affine of a nibabel image, independently of the nibabel version."""
        return img.affine if hasattr( img, 'affine' ) else img.get_affine()

    @staticmethod
    def _nib_data( img ) :
        """Return the voxel array of a nibabel image, independently of the nibabel version.

        ``get_data()`` was deprecated in nibabel 3.0 and removed in 5.0;
        ``np.asanyarray(img.dataobj)`` is its documented replacement, returning
        the same (scaled) data with the same dtype.
        """
        if hasattr( img, 'dataobj' ) :
            return np.asanyarray( img.dataobj )
        return img.get_data()

    def __init__( self, study_path, subject, output_path=None ) :
        """Setup the data structure with default values.

        Parameters
        ----------
        study_path : string
            The path to the folder containing all the subjects from one study
        subject : string
            The path (relative to previous folder) to the subject folder
        output_path : string
            Optionally sets a custom full path for the output. Leave as None
            for default behaviour - output in study_path/subject/AMICO/<MODEL>
        """
        self.htable      = None # set by "load_kernels" method
        self.niiDWI      = None # set by "load_data" method
        self.niiDWI_img  = None # voxel data of niiDWI, after preprocessing
        self.scheme      = None # set by "load_data" method
        self.niiMASK     = None # set by "load_data" method (None if no mask file)
        self.niiMASK_img = None # set by "load_data" method
        self.model       = None # set by "set_model" method
        self.KERNELS     = None # set by "load_kernels" method
        self.RESULTS     = None # set by "fit" method
        self.mean_b0s    = None # set by "load_data" method (only when normalizing)

        # store all the parameters of an evaluation with AMICO
        self.CONFIG = {}
        self.set_config('study_path', study_path)
        self.set_config('subject', subject)
        self.set_config('DATA_path', pjoin( study_path, subject ))
        self.set_config('OUTPUT_path', output_path)

        self.set_config('peaks_filename', None)
        self.set_config('doNormalizeSignal', True)
        self.set_config('doKeepb0Intact', False)  # copy the original b0 volumes into the FW-corrected DWI
        self.set_config('doComputeNRMSE', False)
        self.set_config('doSaveCorrectedDWI', False)
        self.set_config('doMergeB0', False)       # Merge b0 volumes
        self.set_config('doDebiasSignal', False)  # Flag to remove Rician bias
        self.set_config('DWI-SNR', None)          # SNR of DWI image: SNR = b0/sigma (needed for debiasing)

    def set_config( self, key, value ) :
        """Store ``value`` under ``key`` in the configuration dictionary."""
        self.CONFIG[ key ] = value

    def get_config( self, key ) :
        """Return the configuration value for ``key`` (None if it was never set)."""
        return self.CONFIG.get( key )

    def load_data( self, dwi_filename = 'DWI.nii', scheme_filename = 'DWI.scheme', mask_filename = None, b0_thr = 0 ) :
        """Load the diffusion signal and its corresponding acquisition scheme.

        After loading, the signal is preprocessed according to the configuration:
        Rician debiasing ("doDebiasSignal"), normalization to the mean b0
        ("doNormalizeSignal") and merging of the b0 volumes ("doMergeB0").

        Parameters
        ----------
        dwi_filename : string
            The file name of the DWI data, relative to the subject folder (default : 'DWI.nii')
        scheme_filename : string
            The file name of the corresponding acquisition scheme (default : 'DWI.scheme')
        mask_filename : string
            The file name of the (optional) binary mask (default : None)
        b0_thr : float
            The threshold below which a b-value is considered a b0 (default : 0)
        """
        # Loading data, acquisition scheme and mask (optional)
        LOG( '\n-> Loading data:' )
        tic = time.time()

        print('\t* DWI signal')
        self.set_config('dwi_filename', dwi_filename)
        self.niiDWI = nibabel.load( pjoin( self.get_config('DATA_path'), dwi_filename) )
        self.niiDWI_img = self._nib_data( self.niiDWI ).astype(np.float32)
        hdr = self._nib_header( self.niiDWI )
        self.set_config('dim', self.niiDWI_img.shape[:3])
        self.set_config('pixdim', tuple( hdr.get_zooms()[:3] ))
        print('\t\t- dim = %d x %d x %d x %d' % self.niiDWI_img.shape)
        print('\t\t- pixdim = %.3f x %.3f x %.3f' % self.get_config('pixdim'))

        # Scale signal intensities (if necessary): only for a finite, non-zero,
        # non-identity slope/intercept pair in the header.
        # NOTE(review): nibabel usually applies scl_slope/scl_inter itself when
        # reading the data and leaves them as NaN in the loaded header, in which
        # case this branch is skipped -- confirm for the nibabel version in use.
        if ( np.isfinite(hdr['scl_slope']) and np.isfinite(hdr['scl_inter']) and hdr['scl_slope'] != 0 and
             ( hdr['scl_slope'] != 1 or hdr['scl_inter'] != 0 ) ):
            print('\t\t- rescaling data ', end='')
            self.niiDWI_img = self.niiDWI_img * hdr['scl_slope'] + hdr['scl_inter']
            print('[OK]')

        print('\t* Acquisition scheme')
        self.set_config('scheme_filename', scheme_filename)
        self.set_config('b0_thr', b0_thr)
        self.scheme = amico.scheme.Scheme( pjoin( self.get_config('DATA_path'), scheme_filename), b0_thr )
        print('\t\t- %d samples, %d shells' % ( self.scheme.nS, len(self.scheme.shells) ))
        print('\t\t- %d @ b=0' % ( self.scheme.b0_count ), end=' ')
        for shell in self.scheme.shells :
            print(', %d @ b=%.1f' % ( len(shell['idx']), shell['b'] ), end=' ')
        print()

        if self.scheme.nS != self.niiDWI_img.shape[3] :
            ERROR( 'Scheme does not match with DWI data' )

        print('\t* Binary mask')
        if mask_filename is not None :
            self.niiMASK = nibabel.load( pjoin( self.get_config('DATA_path'), mask_filename) )
            self.niiMASK_img = self._nib_data( self.niiMASK ).astype(np.uint8)
            niiMASK_hdr = self._nib_header( self.niiMASK )
            # validate the dimensionality before printing a 3-component shape
            if self.niiMASK_img.ndim != 3 :
                ERROR( 'The provided MASK is %dD, but a 3D dataset is expected' % self.niiMASK_img.ndim )
            print('\t\t- dim = %d x %d x %d' % self.niiMASK_img.shape[:3])
            print('\t\t- pixdim = %.3f x %.3f x %.3f' % niiMASK_hdr.get_zooms()[:3])
            if self.get_config('dim') != self.niiMASK_img.shape[:3] :
                ERROR( 'MASK geometry does not match with DWI data' )
        else :
            self.niiMASK = None
            self.niiMASK_img = np.ones( self.get_config('dim') )
            print('\t\t- not specified')
        print('\t\t- voxels = %d' % np.count_nonzero(self.niiMASK_img))

        LOG( ' [ %.1f seconds ]' % ( time.time() - tic ) )

        # Preprocessing
        LOG( '\n-> Preprocessing:' )
        tic = time.time()

        if self.get_config('doDebiasSignal') :
            print('\t* Debiasing signal... ', end='')
            sys.stdout.flush()
            if self.get_config('DWI-SNR') is None :
                ERROR( "Set the SNR of the DWI data for debiasing (eg. ae.set_config('DWI-SNR', snr))" )
            self.niiDWI_img = debiasRician(self.niiDWI_img,self.get_config('DWI-SNR'),self.niiMASK_img,self.scheme)
            print(' [OK]')

        if self.get_config('doNormalizeSignal') :
            print('\t* Normalizing to b0... ', end='')
            sys.stdout.flush()
            if self.scheme.b0_count <= 0 :
                ERROR( 'No b0 volume to normalize signal with' )
            self.mean_b0s = np.mean( self.niiDWI_img[:,:,:,self.scheme.b0_idx], axis=3 )
            # 1/mean_b0 per voxel; voxels with a non-positive mean b0 get a factor 0
            norm_factor = self.mean_b0s.copy()
            idx = self.mean_b0s <= 0
            norm_factor[ idx ] = 1
            norm_factor = 1 / norm_factor
            norm_factor[ idx ] = 0
            # scale every volume in place (broadcast over the 4th axis)
            self.niiDWI_img *= norm_factor[ :, :, :, np.newaxis ]
            print('[ min=%.2f, mean=%.2f, max=%.2f ]' % ( self.niiDWI_img.min(), self.niiDWI_img.mean(), self.niiDWI_img.max() ))

        if self.get_config('doMergeB0') :
            print('\t* Merging multiple b0 volume(s)')
            # the single averaged b0 becomes volume 0, followed by the DWI volumes
            mean = np.expand_dims( np.mean( self.niiDWI_img[:,:,:,self.scheme.b0_idx], axis=3 ), axis=3 )
            self.niiDWI_img = np.concatenate( (mean, self.niiDWI_img[:,:,:,self.scheme.dwi_idx]), axis=3 )
        else :
            print('\t* Keeping all b0 volume(s)')

        LOG( ' [ %.1f seconds ]' % ( time.time() - tic ) )

    def set_model( self, model_name ) :
        """Set the model to use to describe the signal contributions in each voxel.

        Also sets the "ATOMS_path" config entry (<study_path>/kernels/<model id>)
        and initializes the solver with its default parameters.

        Parameters
        ----------
        model_name : string
            The name of the model (must match a class name in "amico.models" module)
        """
        # Call the specific model constructor
        if hasattr(amico.models, model_name ) :
            self.model = getattr(amico.models,model_name)()
        else :
            ERROR( 'Model "%s" not recognized' % model_name )

        self.set_config('ATOMS_path', pjoin( self.get_config('study_path'), 'kernels', self.model.id ))

        # setup default parameters for fitting the model (can be changed later on)
        self.set_solver()

    def set_solver( self, **params ) :
        """Setup the specific parameters of the solver to fit the model.

        Dispatch to the proper function, depending on the model; a model should
        provide a "set_solver" function to set these parameters. Its return value
        is stored in the "solver_params" config entry and passed to the model's
        fit() for every voxel.
        """
        if self.model is None :
            ERROR( 'Model not set; call "set_model()" method first' )
        self.set_config('solver_params', self.model.set_solver( **params ))

    def generate_kernels( self, regenerate = False, lmax = 12, ndirs = 32761 ) :
        """Generate the high-resolution response functions for each compartment.
        Dispatch to the proper function, depending on the model.

        Kernels are written to the "ATOMS_path" folder; if files matching
        'A_*.npy' already exist there and ``regenerate`` is False, nothing is done.

        Parameters
        ----------
        regenerate : boolean
            Regenerate kernels if they already exist (default : False)
        lmax : int
            Maximum SH order to use for the rotation procedure (default : 12)
        ndirs : int
            Number of directions on the half of the sphere representing the possible orientations of the response functions (default : 32761)
        """
        if self.scheme is None :
            ERROR( 'Scheme not loaded; call "load_data()" first' )
        if self.model is None :
            ERROR( 'Model not set; call "set_model()" method first' )
        if not is_valid(ndirs):
            ERROR( 'Unsupported value for ndirs.\nNote: Supported values for ndirs are [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, 32761 (default)]' )

        # store some values for later use (read back by "load_kernels")
        self.set_config('lmax', lmax)
        self.set_config('ndirs', ndirs)
        self.model.scheme = self.scheme
        LOG( '\n-> Creating LUT for "%s" model:' % self.model.name )

        # check if kernels were already generated
        tmp = glob.glob( pjoin(self.get_config('ATOMS_path'),'A_*.npy') )
        if len(tmp)>0 and not regenerate :
            LOG( ' [ LUT already computed. USe option "regenerate=True" to force regeneration ]' )
            return

        # create folder or delete existing files (if any)
        if not exists( self.get_config('ATOMS_path') ) :
            makedirs( self.get_config('ATOMS_path') )
        else :
            for f in glob.glob( pjoin(self.get_config('ATOMS_path'),'*') ) :
                remove( f )

        # auxiliary data structures
        aux = amico.lut.load_precomputed_rotation_matrices( lmax, ndirs )
        idx_IN, idx_OUT = amico.lut.aux_structures_generate( self.scheme, lmax )

        # Dispatch to the right handler for each model
        tic = time.time()
        self.model.generate( self.get_config('ATOMS_path'), aux, idx_IN, idx_OUT, ndirs )
        LOG( ' [ %.1f seconds ]' % ( time.time() - tic ) )

    def load_kernels( self ) :
        """Load rotated kernels and project to the specific gradient scheme of this subject.
        Dispatch to the proper function, depending on the model.

        Uses the "lmax" and "ndirs" config entries, which are set by
        "generate_kernels()".
        """
        if self.model is None :
            ERROR( 'Model not set; call "set_model()" method first' )
        if self.scheme is None :
            ERROR( 'Scheme not loaded; call "load_data()" first' )

        tic = time.time()
        LOG( '\n-> Resampling LUT for subject "%s":' % self.get_config('subject') )

        # auxiliary data structures
        idx_OUT, Ylm_OUT = amico.lut.aux_structures_resample( self.scheme, self.get_config('lmax') )

        # hash table
        self.htable = amico.lut.load_precomputed_hash_table( self.get_config('ndirs') )

        # Dispatch to the right handler for each model
        self.KERNELS = self.model.resample( self.get_config('ATOMS_path'), idx_OUT, Ylm_OUT, self.get_config('doMergeB0'), self.get_config('ndirs') )

        LOG( ' [ %.1f seconds ]' % ( time.time() - tic ) )

    def fit( self ) :
        """Fit the model to the data iterating over all voxels (in the mask) one after the other.
        Call the appropriate fit() method of the actual model used.

        Results are stored in ``self.RESULTS`` under 'DIRs' and 'MAPs', plus
        'NRMSE' / 'DWI_corrected' when "doComputeNRMSE" / "doSaveCorrectedDWI"
        are enabled; the elapsed time is stored in the "fit_time" config entry.
        """
        if self.niiDWI is None :
            ERROR( 'Data not loaded; call "load_data()" first' )
        if self.model is None :
            ERROR( 'Model not set; call "set_model()" first' )
        if self.KERNELS is None :
            ERROR( 'Response functions not generated; call "generate_kernels()" and "load_kernels()" first' )
        if self.KERNELS['model'] != self.model.id :
            ERROR( 'Response functions were not created with the same model' )

        self.set_config('fit_time', None)
        totVoxels = np.count_nonzero(self.niiMASK_img)
        LOG( '\n-> Fitting "%s" model to %d voxels:' % ( self.model.name, totVoxels ) )

        # configuration is constant during the fit: read it once, not per voxel
        dim                = self.get_config('dim')
        doMergeB0          = self.get_config('doMergeB0')
        doComputeNRMSE     = self.get_config('doComputeNRMSE')
        doSaveCorrectedDWI = self.get_config('doSaveCorrectedDWI')
        doNormalizeSignal  = self.get_config('doNormalizeSignal')
        doKeepb0Intact     = self.get_config('doKeepb0Intact')
        solver_params      = self.get_config('solver_params')

        # setup fitting directions
        peaks_filename = self.get_config('peaks_filename')
        if peaks_filename is None :
            # no peaks given: estimate one direction per voxel with DTI
            DIRs = np.zeros( [dim[0], dim[1], dim[2], 3], dtype=np.float32 )
            if doMergeB0 :
                gtab = gradient_table( np.hstack((0,self.scheme.b[self.scheme.dwi_idx])), np.vstack((np.zeros((1,3)),self.scheme.raw[self.scheme.dwi_idx,:3])) )
            else :
                gtab = gradient_table( self.scheme.b, self.scheme.raw[:,:3] )
            DTI = dti.TensorModel( gtab )
        else :
            niiPEAKS = nibabel.load( pjoin( self.get_config('DATA_path'), peaks_filename) )
            DIRs = self._nib_data( niiPEAKS ).astype(np.float32)
            print('\t* peaks dim = %d x %d x %d x %d' % DIRs.shape[:4])
            if DIRs.shape[:3] != self.niiMASK_img.shape[:3] :
                ERROR( 'PEAKS geometry does not match with DWI data' )

        # setup other output files
        MAPs = np.zeros( [dim[0], dim[1], dim[2], len(self.model.maps_name)], dtype=np.float32 )
        if doComputeNRMSE :
            NRMSE = np.zeros( [dim[0], dim[1], dim[2]], dtype=np.float32 )
        if doSaveCorrectedDWI :
            # sized on the preprocessed signal: with "doMergeB0" it has fewer
            # volumes than the DWI file, and each voxel's corrected signal must fit
            DWI_corrected = np.zeros( self.niiDWI_img.shape, dtype=np.float32 )
            # position of the b0 samples in the preprocessed signal
            # (merging puts the single averaged b0 volume first)
            b0_idx = [0] if doMergeB0 else self.scheme.b0_idx

        # fit the model to the data
        # =========================
        t = time.time()
        progress = ProgressBar( n=totVoxels, prefix=" ", erase=False )
        for iz in range(self.niiMASK_img.shape[2]) :
            for iy in range(self.niiMASK_img.shape[1]) :
                for ix in range(self.niiMASK_img.shape[0]) :
                    if self.niiMASK_img[ix,iy,iz]==0 :
                        continue

                    # prepare the signal
                    y = self.niiDWI_img[ix,iy,iz,:].astype(np.float64)
                    y[ y < 0 ] = 0 # [NOTE] this should not happen!

                    # fitting directions
                    if peaks_filename is None :
                        dirs = DTI.fit( y ).directions[0]
                    else :
                        dirs = DIRs[ix,iy,iz,:]

                    # dispatch to the right handler for each model
                    MAPs[ix,iy,iz,:], DIRs[ix,iy,iz,:], x, A = self.model.fit( y, dirs.reshape(-1,3), self.KERNELS, solver_params, self.htable )

                    # compute fitting error
                    if doComputeNRMSE :
                        y_est = np.dot( A, x )
                        den = np.sum(y**2)
                        NRMSE[ix,iy,iz] = np.sqrt( np.sum((y-y_est)**2) / den ) if den > 1e-16 else 0

                    if doSaveCorrectedDWI :
                        if self.model.name == 'Free-Water' :
                            n_iso = len(self.model.d_isos)

                            # zero all but the last n_iso coefficients, i.e. keep only
                            # the FW (isotropic) components of the estimate
                            x[0:x.shape[0]-n_iso] = 0

                            # y_fw_part is the signal predicted by the FW part only
                            y_fw_part = np.dot( A, x )

                            # remove it from the original signal y
                            y_fw_corrected = y - y_fw_part
                            y_fw_corrected[ y_fw_corrected < 0 ] = 0 # [NOTE] this should not happen!

                            if doNormalizeSignal and self.scheme.b0_count > 0 :
                                # undo the b0 normalization applied in load_data()
                                y_fw_corrected = y_fw_corrected * self.mean_b0s[ix,iy,iz]

                            if doKeepb0Intact and self.scheme.b0_count > 0 :
                                # put original b0 data back in; de-normalize only when
                                # the signal was normalized (mean_b0s is unset otherwise)
                                y_b0 = y[b0_idx]
                                if doNormalizeSignal :
                                    y_b0 = y_b0 * self.mean_b0s[ix,iy,iz]
                                y_fw_corrected[b0_idx] = y_b0

                            DWI_corrected[ix,iy,iz,:] = y_fw_corrected

                    progress.update()

        self.set_config('fit_time', time.time()-t)
        LOG( ' [ %s ]' % ( time.strftime("%Hh %Mm %Ss", time.gmtime(self.get_config('fit_time')) ) ) )

        # store results
        self.RESULTS = {}
        self.RESULTS['DIRs'] = DIRs
        self.RESULTS['MAPs'] = MAPs
        if doComputeNRMSE :
            self.RESULTS['NRMSE'] = NRMSE
        if doSaveCorrectedDWI :
            self.RESULTS['DWI_corrected'] = DWI_corrected

    def _save_nifti( self, img, affine, results_path, filename, cal_min, cal_max, descrip = None, reset_scaling = True ) :
        """Write ``img`` to ``results_path/filename`` as a NIfTI-1 image.

        Sets the display range (cal_min/cal_max), the description when given,
        and, if ``reset_scaling`` is True, scl_slope=1 / scl_inter=0.
        """
        niiMAP = nibabel.Nifti1Image( img, affine )
        niiMAP_hdr = self._nib_header( niiMAP )
        if descrip is not None :
            niiMAP_hdr['descrip'] = descrip
        niiMAP_hdr['cal_min'] = cal_min
        niiMAP_hdr['cal_max'] = cal_max
        if reset_scaling :
            niiMAP_hdr['scl_slope'] = 1
            niiMAP_hdr['scl_inter'] = 0
        nibabel.save( niiMAP, pjoin(results_path, filename) )

    def save_results( self, path_suffix = None ) :
        """Save the output (directions, maps etc).

        The output folder is "OUTPUT_path" if set, otherwise
        <DATA_path>/AMICO/<model id>; any files already in it are deleted.

        Parameters
        ----------
        path_suffix : string
            Text to be appended to the output path (default : None)
        """
        if self.RESULTS is None :
            ERROR( 'Model not fitted to the data; call "fit()" first' )

        output_path = self.get_config('OUTPUT_path')
        if output_path is None :
            RESULTS_path = pjoin( 'AMICO', self.model.id )
        else :
            RESULTS_path = output_path
        if path_suffix :
            RESULTS_path = RESULTS_path +'_'+ path_suffix
        self.RESULTS['RESULTS_path'] = RESULTS_path
        LOG( '\n-> Saving output to "%s/*":' % RESULTS_path )
        if output_path is None :
            # the default output folder lives inside the subject folder
            RESULTS_path = pjoin( self.get_config('DATA_path'), RESULTS_path )

        # create the output folder, or delete previous output
        if not exists( RESULTS_path ) :
            makedirs( RESULTS_path )
        else :
            for f in glob.glob( pjoin(RESULTS_path,'*') ) :
                remove( f )

        # configuration
        print('\t- configuration', end=' ')
        with open( pjoin(RESULTS_path,'config.pickle'), 'wb+' ) as fid :
            pickle.dump( self.CONFIG, fid, protocol=2 )
        print(' [OK]')

        # estimated orientations
        print('\t- FIT_dir.nii.gz', end=' ')
        affine = self._nib_affine( self.niiDWI )
        self._save_nifti( self.RESULTS['DIRs'], affine, RESULTS_path, 'FIT_dir.nii.gz', -1, 1 )
        print(' [OK]')

        # fitting error
        if self.get_config('doComputeNRMSE') :
            print('\t- FIT_nrmse.nii.gz', end=' ')
            self._save_nifti( self.RESULTS['NRMSE'], affine, RESULTS_path, 'FIT_nrmse.nii.gz', 0, 1 )
            print(' [OK]')

        if self.get_config('doSaveCorrectedDWI') :
            if self.model.name == 'Free-Water' :
                print('\t- dwi_fw_corrected.nii.gz', end=' ')
                self._save_nifti( self.RESULTS['DWI_corrected'], affine, RESULTS_path, 'dwi_fw_corrected.nii.gz', 0, 1, reset_scaling=False )
                print(' [OK]')
            else :
                WARNING( '"doSaveCorrectedDWI" option not supported for "%s" model' % self.model.name )

        # voxelwise maps
        for i, map_name in enumerate( self.model.maps_name ) :
            print('\t- FIT_%s.nii.gz' % map_name, end=' ')
            niiMAP_img = self.RESULTS['MAPs'][:,:,:,i]
            self._save_nifti( niiMAP_img, affine, RESULTS_path, 'FIT_%s.nii.gz' % map_name,
                              niiMAP_img.min(), niiMAP_img.max(), descrip=self.model.maps_descr[i] )
            print(' [OK]')

        LOG( ' [ DONE ]' )