/
dcmmeta.py
1813 lines (1561 loc) · 68.9 KB
/
dcmmeta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
DcmMeta header extension and NiftiWrapper for working with extended Niftis.
"""
from __future__ import print_function
import sys, re, json, warnings
from copy import deepcopy
try:
from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict
import numpy as np
import nibabel as nb
from nibabel.nifti1 import Nifti1Extension
from nibabel.spatialimages import HeaderDataError
with warnings.catch_warnings():
warnings.simplefilter('ignore')
from nibabel.nicom.dicomwrappers import wrapper_from_data
from .utils import iteritems, unicode_str, PY2
dcm_meta_ecode = 0
_meta_version = 0.6
_req_base_keys_map= {0.5 : set(('dcmmeta_affine',
'dcmmeta_slice_dim',
'dcmmeta_shape',
'dcmmeta_version',
'global',
)
),
0.6 : set(('dcmmeta_affine',
'dcmmeta_reorient_transform',
'dcmmeta_slice_dim',
'dcmmeta_shape',
'dcmmeta_version',
'global',
)
),
}
'''Minimum required keys in the base dictionary to be considered valid'''
def is_constant(sequence, period=None):
    '''Returns true if all elements in (each period of) the sequence are equal.

    Parameters
    ----------
    sequence : sequence
        The sequence of elements to check.

    period : int
        If not None then each subsequence of that length is checked.
    '''
    if period is None:
        # Compare every element against the first; an empty sequence is
        # trivially constant (the loop body never runs).
        for item in sequence:
            if item != sequence[0]:
                return False
        return True

    if period <= 1:
        raise ValueError('The period must be greater than one')
    seq_len = len(sequence)
    if seq_len % period != 0:
        raise ValueError('The sequence length is not evenly divisible by '
                         'the period length.')
    # Check each period-sized chunk independently
    for start in range(0, seq_len, period):
        chunk = sequence[start:start + period]
        for item in chunk:
            if item != chunk[0]:
                return False
    return True
def is_repeating(sequence, period):
    '''Returns true if the elements in the sequence repeat with the given
    period.

    Parameters
    ----------
    sequence : sequence
        The sequence of elements to check.

    period : int
        The period over which the elements should repeat.
    '''
    seq_len = len(sequence)
    if not 1 < period < seq_len:
        raise ValueError('The period must be greater than one and less than '
                         'the length of the sequence')
    if seq_len % period != 0:
        raise ValueError('The sequence length is not evenly divisible by the '
                         'period length.')
    # Every subsequent chunk must match the first one exactly
    template = sequence[:period]
    for start in range(period, seq_len, period):
        if sequence[start:start + period] != template:
            return False
    return True
class InvalidExtensionError(Exception):
    '''Exception denoting that a DcmMetaExtension is invalid.'''

    def __init__(self, msg):
        '''Create the exception.

        Parameters
        ----------
        msg : str
            Description of why the extension is invalid.
        '''
        # Call through to Exception so ``args`` is populated and the
        # exception pickles/reprs correctly (the original skipped this).
        super(InvalidExtensionError, self).__init__(msg)
        self.msg = msg

    def __str__(self):
        return 'The extension is not valid: %s' % self.msg
class DcmMetaExtension(Nifti1Extension):
'''Nifti extension for storing a summary of the meta data from the source
DICOM files.
'''
@property
def reorient_transform(self):
'''The transformation due to reorientation of the data array. Can be
used to update directional DICOM meta data (after converting to RAS if
needed) into the same space as the affine.'''
if self.version < 0.6:
return None
if self._content['dcmmeta_reorient_transform'] is None:
return None
return np.array(self._content['dcmmeta_reorient_transform'])
@reorient_transform.setter
def reorient_transform(self, value):
if not value is None and value.shape != (4, 4):
raise ValueError("The reorient_transform must be none or (4,4) "
"array")
if value is None:
self._content['dcmmeta_reorient_transform'] = None
else:
self._content['dcmmeta_reorient_transform'] = value.tolist()
@property
def affine(self):
'''The affine associated with the meta data. If this differs from the
image affine, the per-slice meta data will not be used. '''
return np.array(self._content['dcmmeta_affine'])
@affine.setter
def affine(self, value):
if value.shape != (4, 4):
raise ValueError("Invalid shape for affine")
self._content['dcmmeta_affine'] = value.tolist()
@property
def slice_dim(self):
'''The index of the slice dimension associated with the per-slice
meta data.'''
return self._content['dcmmeta_slice_dim']
@slice_dim.setter
def slice_dim(self, value):
if not value is None and not (0 <= value < 3):
raise ValueError("The slice dimension must be between zero and "
"two")
self._content['dcmmeta_slice_dim'] = value
@property
def shape(self):
'''The shape of the data associated with the meta data. Defines the
number of values for the meta data classifications.'''
return tuple(self._content['dcmmeta_shape'])
@shape.setter
def shape(self, value):
if not (3 <= len(value) < 6):
raise ValueError("The shape must have a length between three and "
"six")
self._content['dcmmeta_shape'][:] = value
@property
def version(self):
'''The version of the meta data extension.'''
return self._content['dcmmeta_version']
@version.setter
def version(self, value):
'''Set the version of the meta data extension.'''
self._content['dcmmeta_version'] = value
@property
def slice_normal(self):
'''The slice normal associated with the per-slice meta data.'''
slice_dim = self.slice_dim
if slice_dim is None:
return None
return np.array(self.affine[slice_dim][:3])
@property
def n_slices(self):
'''The number of slices associated with the per-slice meta data.'''
slice_dim = self.slice_dim
if slice_dim is None:
return None
return self.shape[slice_dim]
classifications = (('global', 'const'),
('global', 'slices'),
('time', 'samples'),
('time', 'slices'),
('vector', 'samples'),
('vector', 'slices'),
)
'''The classifications used to separate meta data based on if and how the
values repeat. Each class is a tuple with a base class and a sub class.'''
def get_valid_classes(self):
'''Return the meta data classifications that are valid for this
extension.
Returns
-------
valid_classes : tuple
The classifications that are valid for this extension (based on its
shape).
'''
shape = self.shape
n_dims = len(shape)
if n_dims == 3:
return self.classifications[:2]
elif n_dims == 4:
return self.classifications[:4]
elif n_dims == 5:
if shape[3] != 1:
return self.classifications
else:
return self.classifications[:2] + self.classifications[-2:]
else:
raise ValueError("There must be 3 to 5 dimensions.")
def get_multiplicity(self, classification):
'''Get the number of meta data values for all meta data of the provided
classification.
Parameters
----------
classification : tuple
The meta data classification.
Returns
-------
multiplicity : int
The number of values for any meta data of the provided
`classification`.
'''
if not classification in self.get_valid_classes():
raise ValueError("Invalid classification: %s" % classification)
base, sub = classification
shape = self.shape
n_vals = 1
if sub == 'slices':
n_vals = self.n_slices
if n_vals is None:
return 0
if base == 'vector':
n_vals *= shape[3]
elif base == 'global':
for dim_size in shape[3:]:
n_vals *= dim_size
elif sub == 'samples':
if base == 'time':
n_vals = shape[3]
if len(shape) == 5:
n_vals *= shape[4]
elif base == 'vector':
n_vals = shape[4]
return n_vals
    def check_valid(self):
        '''Check if the extension is valid.

        Raises
        ------
        InvalidExtensionError
            The extension is missing required meta data or classifications, or
            some element(s) have the wrong number of values for their
            classification.
        '''
        #Check for the required base keys in the json data
        if not _req_base_keys_map[self.version] <= set(self._content):
            raise InvalidExtensionError('Missing one or more required keys')

        #Check the orientation/shape/version
        if self.affine.shape != (4, 4):
            raise InvalidExtensionError('Affine has incorrect shape')
        slice_dim = self.slice_dim
        if slice_dim is not None:
            if not (0 <= slice_dim < 3):
                raise InvalidExtensionError('Slice dimension is not valid')
        if not (3 <= len(self.shape) < 6):
            raise InvalidExtensionError('Shape is not valid')

        #Check all required meta dictionaries, make sure values have correct
        #multiplicity
        valid_classes = self.get_valid_classes()
        for classes in valid_classes:
            if not classes[0] in self._content:
                raise InvalidExtensionError('Missing required base '
                                            'classification %s' % classes[0])
            if not classes[1] in self._content[classes[0]]:
                raise InvalidExtensionError(('Missing required sub '
                                             'classification %s in base '
                                             'classification %s') % classes)
            cls_meta = self.get_class_dict(classes)
            cls_mult = self.get_multiplicity(classes)
            # A multiplicity of zero means the slice dim is undefined, so no
            # per-slice values may exist for this classification
            if cls_mult == 0 and len(cls_meta) != 0:
                raise InvalidExtensionError('Slice dim is None but per-slice '
                                            'meta data is present')
            elif cls_mult > 1:
                # Every key must provide exactly cls_mult values
                for key, vals in iteritems(cls_meta):
                    n_vals = len(vals)
                    if n_vals != cls_mult:
                        msg = (('Incorrect number of values for key %s with '
                                'classification %s, expected %d found %d') %
                               (key, classes, cls_mult, n_vals)
                              )
                        raise InvalidExtensionError(msg)

        #Check that all keys are uniquely classified
        for classes in valid_classes:
            for other_classes in valid_classes:
                if classes == other_classes:
                    continue
                intersect = (set(self.get_class_dict(classes)) &
                             set(self.get_class_dict(other_classes))
                            )
                if len(intersect) != 0:
                    raise InvalidExtensionError("One or more keys have "
                                                "multiple classifications")
def get_keys(self):
'''Get a list of all the meta data keys that are available.'''
keys = []
for base_class, sub_class in self.get_valid_classes():
keys += self._content[base_class][sub_class].keys()
return keys
def get_classification(self, key):
'''Get the classification for the given `key`.
Parameters
----------
key : str
The meta data key.
Returns
-------
classification : tuple or None
The classification tuple for the provided key or None if the key is
not found.
'''
for base_class, sub_class in self.get_valid_classes():
if key in self._content[base_class][sub_class]:
return (base_class, sub_class)
return None
def get_class_dict(self, classification):
'''Get the dictionary for the given classification.
Parameters
----------
classification : tuple
The meta data classification.
Returns
-------
meta_dict : dict
The dictionary for the provided classification.
'''
base, sub = classification
return self._content[base][sub]
def get_values(self, key):
'''Get all values for the provided key.
Parameters
----------
key : str
The meta data key.
Returns
-------
values
The value or values for the given key. The number of values
returned depends on the classification (see 'get_multiplicity').
'''
classification = self.get_classification(key)
if classification is None:
return None
return self.get_class_dict(classification)[key]
def get_values_and_class(self, key):
'''Get the values and the classification for the provided key.
Parameters
----------
key : str
The meta data key.
Returns
-------
vals_and_class : tuple
None for both the value and classification if the key is not found.
'''
classification = self.get_classification(key)
if classification is None:
return (None, None)
return (self.get_class_dict(classification)[key], classification)
def filter_meta(self, filter_func):
'''Filter the meta data.
Parameters
----------
filter_func : callable
Must take a key and values as parameters and return True if they
should be filtered out.
'''
for classes in self.get_valid_classes():
filtered = []
curr_dict = self.get_class_dict(classes)
for key, values in iteritems(curr_dict):
if filter_func(key, values):
filtered.append(key)
for key in filtered:
del curr_dict[key]
def clear_slice_meta(self):
'''Clear all meta data that is per slice.'''
for base_class, sub_class in self.get_valid_classes():
if sub_class == 'slices':
self.get_class_dict((base_class, sub_class)).clear()
    def get_subset(self, dim, idx):
        '''Get a DcmMetaExtension containing a subset of the meta data.

        Parameters
        ----------
        dim : int
            The dimension we are taking the subset along.

        idx : int
            The position on the dimension `dim` for the subset.

        Returns
        -------
        result : DcmMetaExtension
            A new DcmMetaExtension corresponding to the subset.
        '''
        if not 0 <= dim < 5:
            raise ValueError("The argument 'dim' must be in the range [0, 5).")

        shape = self.shape
        valid_classes = self.get_valid_classes()

        #Make an empty extension for the result
        result_shape = list(shape)
        result_shape[dim] = 1
        # Trailing singular dims carry no information; trim to 3-D minimum
        while result_shape[-1] == 1 and len(result_shape) > 3:
            result_shape = result_shape[:-1]
        result = self.make_empty(result_shape,
                                 self.affine,
                                 self.reorient_transform,
                                 self.slice_dim
                                )

        for src_class in valid_classes:
            #Constants remain constant
            if src_class == ('global', 'const'):
                for key, val in iteritems(self.get_class_dict(src_class)):
                    result.get_class_dict(src_class)[key] = deepcopy(val)
                continue

            if dim == self.slice_dim:
                if src_class[1] != 'slices':
                    # Non-slice meta is unaffected by indexing the slice dim
                    for key, vals in iteritems(self.get_class_dict(src_class)):
                        result.get_class_dict(src_class)[key] = deepcopy(vals)
                else:
                    # Per-slice meta collapses to the chosen slice
                    result._copy_slice(self, src_class, idx)
            elif dim < 3:
                # Indexing a non-slice spatial dim leaves all meta intact
                for key, vals in iteritems(self.get_class_dict(src_class)):
                    result.get_class_dict(src_class)[key] = deepcopy(vals)
            elif dim == 3:
                result._copy_sample(self, src_class, 'time', idx)
            else:
                result._copy_sample(self, src_class, 'vector', idx)

        return result
    def to_json(self):
        '''Return the extension encoded as a JSON string.'''
        # Refuse to serialize inconsistent meta data
        self.check_valid()
        return json.dumps(self._content, indent=4)
@classmethod
def from_json(klass, json_str):
'''Create an extension from the JSON string representation.'''
result = klass(dcm_meta_ecode, json_str)
result.check_valid()
return result
    @classmethod
    def make_empty(klass, shape, affine, reorient_transform=None,
                   slice_dim=None):
        '''Make an empty DcmMetaExtension.

        Parameters
        ----------
        shape : tuple
            The shape of the data associated with this extension.

        affine : array
            The RAS affine for the data associated with this extension.

        reorient_transform : array
            The transformation matrix representing any reorientation of the
            data array.

        slice_dim : int
            The index of the slice dimension for the data associated with this
            extension

        Returns
        -------
        result : DcmMetaExtension
            An empty DcmMetaExtension with the required values set to the
            given arguments.
        '''
        result = klass(dcm_meta_ecode, '{}')
        result._content['global'] = OrderedDict()
        result._content['global']['const'] = OrderedDict()
        result._content['global']['slices'] = OrderedDict()

        # 'time' classifications only exist for a non-singular 4th dim
        if len(shape) > 3 and shape[3] != 1:
            result._content['time'] = OrderedDict()
            result._content['time']['samples'] = OrderedDict()
            result._content['time']['slices'] = OrderedDict()

        # 'vector' classifications only exist for 5-D data
        if len(shape) > 4:
            result._content['vector'] = OrderedDict()
            result._content['vector']['samples'] = OrderedDict()
            result._content['vector']['slices'] = OrderedDict()

        # The shape setter assigns into this list in place, so seed it first
        result._content['dcmmeta_shape'] = []
        result.shape = shape
        result.affine = affine
        result.reorient_transform = reorient_transform
        result.slice_dim = slice_dim
        result.version = _meta_version

        return result
@classmethod
def from_runtime_repr(klass, runtime_repr):
'''Create an extension from the Python runtime representation (nested
dictionaries).
'''
result = klass(dcm_meta_ecode, '{}')
result._content = runtime_repr
result.check_valid()
return result
    @classmethod
    def from_sequence(klass, seq, dim, affine=None, slice_dim=None):
        '''Create an extension from a sequence of extensions.

        Parameters
        ----------
        seq : sequence
            The sequence of DcmMetaExtension objects.

        dim : int
            The dimension to merge the extensions along.

        affine : array
            The affine to use in the resulting extension. If None, the affine
            from the first extension in `seq` will be used.

        slice_dim : int
            The slice dimension to use in the resulting extension. If None,
            the slice dimension from the first extension in `seq` will be
            used.

        Returns
        -------
        result : DcmMetaExtension
            The result of merging the extensions in `seq` along the dimension
            `dim`.
        '''
        if not 0 <= dim < 5:
            raise ValueError("The argument 'dim' must be in the range [0, 5).")

        n_inputs = len(seq)
        first_input = seq[0]
        input_shape = first_input.shape

        if len(input_shape) > dim and input_shape[dim] != 1:
            raise ValueError("The dim must be singular or not exist for the "
                             "inputs.")

        # Extend the shape out to the merge dim, then set its new size
        output_shape = list(input_shape)
        while len(output_shape) <= dim:
            output_shape.append(1)
        output_shape[dim] = n_inputs

        if affine is None:
            affine = first_input.affine
        if slice_dim is None:
            slice_dim = first_input.slice_dim

        result = klass.make_empty(output_shape,
                                  affine,
                                  None,
                                  slice_dim)

        #Need to initialize the result with the first extension in 'seq'
        # Per-slice meta only carries over when the slice normals agree
        result_slc_norm = result.slice_normal
        first_slc_norm = first_input.slice_normal
        use_slices = (not result_slc_norm is None and
                      not first_slc_norm is None and
                      np.allclose(result_slc_norm, first_slc_norm))
        for classes in first_input.get_valid_classes():
            if classes[1] == 'slices' and not use_slices:
                continue
            result._content[classes[0]][classes[1]] = \
                deepcopy(first_input.get_class_dict(classes))

        #Adjust the shape to what the extension actually contains
        shape = list(result.shape)
        shape[dim] = 1
        result.shape = shape

        #Initialize reorient transform
        reorient_transform = first_input.reorient_transform

        #Add the other extensions, updating the shape as we go
        for input_ext in seq[1:]:
            #If the affines or reorient_transforms don't match, we set the
            #reorient_transform to None as we can not reliably use it to update
            #directional meta data
            if ((reorient_transform is None or
                 input_ext.reorient_transform is None) or
                not (np.allclose(input_ext.affine, affine) or
                     np.allclose(input_ext.reorient_transform,
                                 reorient_transform)
                    )
               ):
                reorient_transform = None
            result._insert(dim, input_ext)
            shape[dim] += 1
            result.shape = shape

        #Set the reorient transform
        result.reorient_transform = reorient_transform

        #Try simplifying any keys in global slices
        for key in list(result.get_class_dict(('global', 'slices'))):
            result._simplify(key)

        return result
def __str__(self):
return self._mangle(self._content)
def __eq__(self, other):
if not np.allclose(self.affine, other.affine):
return False
if self.shape != other.shape:
return False
if self.slice_dim != other.slice_dim:
return False
if self.version != other.version:
return False
for classes in self.get_valid_classes():
if (dict(self.get_class_dict(classes)) !=
dict(other.get_class_dict(classes))):
return False
return True
    def _unmangle(self, value):
        '''Go from extension data to runtime representation.'''
        # Extension payloads may arrive as raw bytes; decode before parsing
        if not isinstance(value, unicode_str):
            value = value.decode('utf-8')
        #Its not possible to preserve order while loading with python 2.6
        kwargs = {}
        if sys.version_info >= (2, 7):
            kwargs['object_pairs_hook'] = OrderedDict
        return json.loads(value, **kwargs)
    def _mangle(self, value):
        '''Go from runtime representation to extension data (utf-8 bytes).'''
        res = json.dumps(value, indent=4)
        # Python 2 leaves some trailing white-space in the JSON output while
        # python 3 does not. We strip it so output is binary identical across
        # versions
        if PY2:
            res = re.sub('[ \t]+$', '', res, 0, re.M)
        return res.encode('utf-8')
_const_tests = {('global', 'slices') : (('global', 'const'),
('vector', 'samples'),
('time', 'samples')
),
('vector', 'slices') : (('global', 'const'),
('time', 'samples')
),
('time', 'slices') : (('global', 'const'),
),
('time', 'samples') : (('global', 'const'),
('vector', 'samples'),
),
('vector', 'samples') : (('global', 'const'),)
}
'''Classification mapping showing possible reductions in multiplicity for
values that are constant with some period.'''
def _get_const_period(self, src_cls, dest_cls):
'''Get the period over which we test for const-ness with for the
given classification change.'''
if dest_cls == ('global', 'const'):
return None
elif src_cls == ('global', 'slices'):
return int(self.get_multiplicity(src_cls) // self.get_multiplicity(dest_cls))
elif src_cls == ('vector', 'slices'): #implies dest_cls == ('time', 'samples'):
return self.n_slices
elif src_cls == ('time', 'samples'): #implies dest_cls == ('vector', 'samples')
return self.shape[3]
assert False #Should take one of the above branches
_repeat_tests = {('global', 'slices') : (('time', 'slices'),
('vector', 'slices')
),
('vector', 'slices') : (('time', 'slices'),),
}
'''Classification mapping showing possible reductions in multiplicity for
values that are repeating with some period.'''
    def _simplify(self, key):
        '''Try to simplify (reduce the multiplicity) of a single meta data
        element by changing its classification. Return True if the
        classification is changed, otherwise False.

        Looks for values that are constant or repeating with some pattern.
        Constant elements with a value of None will be deleted.
        '''
        values, curr_class = self.get_values_and_class(key)

        #If the class is global const then just delete it if the value is None
        if curr_class == ('global', 'const'):
            if values is None:
                del self.get_class_dict(curr_class)[key]
                return True
            return False

        #Test if the values are constant with some period
        dests = self._const_tests[curr_class]
        for dest_cls in dests:
            # Only consider destinations whose base classification exists
            if dest_cls[0] in self._content:
                period = self._get_const_period(curr_class, dest_cls)
                #If the period is one, the two classifications have the
                #same multiplicity so we are dealing with a degenerate
                #case (i.e. single slice data). Just change the
                #classification to the "simpler" one in this case
                if period == 1 or is_constant(values, period):
                    if period is None:
                        # Destination is ('global', 'const'): keep one value
                        self.get_class_dict(dest_cls)[key] = \
                            values[0]
                    else:
                        # Keep one value per period
                        self.get_class_dict(dest_cls)[key] = \
                            values[::period]
                    break
        else: #Otherwise test if values are repeating with some period
            if curr_class in self._repeat_tests:
                for dest_cls in self._repeat_tests[curr_class]:
                    if dest_cls[0] in self._content:
                        dest_mult = self.get_multiplicity(dest_cls)
                        if is_repeating(values, dest_mult):
                            self.get_class_dict(dest_cls)[key] = \
                                values[:dest_mult]
                            break
                else: #Can't simplify
                    return False
            else:
                return False

        # A destination accepted the values; remove the old classification
        del self.get_class_dict(curr_class)[key]
        return True
_preserving_changes = {None : (('global', 'const'),
('vector', 'samples'),
('time', 'samples'),
('time', 'slices'),
('vector', 'slices'),
('global', 'slices'),
),
('global', 'const') : (('vector', 'samples'),
('time', 'samples'),
('time', 'slices'),
('vector', 'slices'),
('global', 'slices'),
),
('vector', 'samples') : (('time', 'samples'),
('global', 'slices'),
),
('time', 'samples') : (('global', 'slices'),
),
('time', 'slices') : (('vector', 'slices'),
('global', 'slices'),
),
('vector', 'slices') : (('global', 'slices'),
),
('global', 'slices') : tuple(),
}
'''Classification mapping showing allowed changes when increasing the
multiplicity.'''
    def _get_changed_class(self, key, new_class, slice_dim=None):
        '''Get an array of values corresponding to a single meta data
        element with its classification changed by increasing its
        multiplicity. This will preserve all the meta data and allow easier
        merging of values with different classifications.'''
        values, curr_class = self.get_values_and_class(key)
        if curr_class == new_class:
            return values

        if not new_class in self._preserving_changes[curr_class]:
            raise ValueError("Classification change would lose data.")

        if curr_class is None:
            # Unclassified: treat as a single value, not per-slice
            curr_mult = 1
            per_slice = False
        else:
            curr_mult = self.get_multiplicity(curr_class)
            per_slice = curr_class[1] == 'slices'
        if new_class in self.get_valid_classes():
            new_mult = self.get_multiplicity(new_class)
            #Only way we get 0 for mult is if slice dim is undefined
            if new_mult == 0:
                new_mult = self.shape[slice_dim]
        else:
            new_mult = 1
        # Integral expansion factor between old and new multiplicity
        mult_fact = int(new_mult // curr_mult)
        if curr_mult == 1:
            values = [values]

        if per_slice:
            # Per-slice values repeat as whole blocks
            result = values * mult_fact
        else:
            # Other values repeat element-wise
            result = []
            for value in values:
                result.extend([deepcopy(value)] * mult_fact)

        if new_class == ('global', 'const'):
            result = result[0]

        return result
def _change_class(self, key, new_class):
'''Change the classification of the meta data element in place. See
_get_changed_class.'''
values, curr_class = self.get_values_and_class(key)
if curr_class == new_class:
return
self.get_class_dict(new_class)[key] = self._get_changed_class(key,
new_class)
if not curr_class is None:
del self.get_class_dict(curr_class)[key]
    def _copy_slice(self, other, src_class, idx):
        '''Get a copy of the meta data from the 'other' instance with
        classification 'src_class', corresponding to the slice with index
        'idx'.'''
        # Choose the most specific destination classification that is still
        # valid for this (reduced) extension
        if src_class[0] == 'global':
            for classes in (('time', 'samples'),
                            ('vector', 'samples'),
                            ('global', 'const')):
                if classes in self.get_valid_classes():
                    dest_class = classes
                    break
        elif src_class[0] == 'vector':
            for classes in (('time', 'samples'),
                            ('global', 'const')):
                if classes in self.get_valid_classes():
                    dest_class = classes
                    break
        else:
            dest_class = ('global', 'const')

        src_dict = other.get_class_dict(src_class)
        dest_dict = self.get_class_dict(dest_class)
        dest_mult = self.get_multiplicity(dest_class)
        # Per-slice values are stored slice-fastest, so stepping by the slice
        # count picks out one value per volume for the chosen slice
        stride = other.n_slices
        for key, vals in iteritems(src_dict):
            subset_vals = vals[idx::stride]

            if len(subset_vals) < dest_mult:
                # Tile the subset out to the destination multiplicity
                full_vals = []
                for val_idx in range(dest_mult // len(subset_vals)):
                    full_vals += deepcopy(subset_vals)
                subset_vals = full_vals
            if len(subset_vals) == 1:
                subset_vals = subset_vals[0]
            dest_dict[key] = deepcopy(subset_vals)
            self._simplify(key)
    def _global_slice_subset(self, key, sample_base, idx):
        '''Get a subset of the meta data values with the classification
        ('global', 'slices') corresponding to a single sample along the
        time or vector dimension (as specified by 'sample_base' and 'idx').
        '''
        n_slices = self.n_slices
        shape = self.shape
        src_dict = self.get_class_dict(('global', 'slices'))
        if sample_base == 'vector':
            # One contiguous run of (slices x time points) per vector sample
            slices_per_vec = n_slices * shape[3]
            start_idx = idx * slices_per_vec
            end_idx = start_idx + slices_per_vec
            return src_dict[key][start_idx:end_idx]
        else:
            if not ('vector', 'samples') in self.get_valid_classes():
                # 4-D data: a single contiguous run per time sample
                start_idx = idx * n_slices
                end_idx = start_idx + n_slices
                return src_dict[key][start_idx:end_idx]
            else:
                # 5-D data: gather the time sample's slices from every
                # vector sample
                result = []
                slices_per_vec = n_slices * shape[3]
                for vec_idx in range(shape[4]):
                    start_idx = (vec_idx * slices_per_vec) + (idx * n_slices)
                    end_idx = start_idx + n_slices
                    result.extend(src_dict[key][start_idx:end_idx])
                return result
def _copy_sample(self, other, src_class, sample_base, idx):
'''Get a copy of meta data from 'other' instance with classification
'src_class', corresponding to one sample along the time or vector
dimension.'''
assert src_class != ('global', 'const')
src_dict = other.get_class_dict(src_class)
if src_class[1] == 'samples':
#If we are indexing on the same dim as the src_class we need to
#change the classification
if src_class[0] == sample_base:
#Time samples may become vector samples, otherwise const
best_dest = None
for dest_cls in (('vector', 'samples'),
('global', 'const')):
if (dest_cls != src_class and
dest_cls in self.get_valid_classes()
):
best_dest = dest_cls
break
dest_mult = self.get_multiplicity(dest_cls)
if dest_mult == 1:
for key, vals in iteritems(src_dict):
self.get_class_dict(dest_cls)[key] = \
deepcopy(vals[idx])
else: #We must be doing time samples -> vector samples
stride = other.shape[3]
for key, vals in iteritems(src_dict):
self.get_class_dict(dest_cls)[key] = \
deepcopy(vals[idx::stride])
for key in src_dict.keys():
self._simplify(key)
else: #Otherwise classification does not change
#The multiplicity will change for time samples if splitting
#vector dimension
if src_class == ('time', 'samples'):
dest_mult = self.get_multiplicity(src_class)
start_idx = idx * dest_mult
end_idx = start_idx + dest_mult
for key, vals in iteritems(src_dict):
self.get_class_dict(src_class)[key] = \
deepcopy(vals[start_idx:end_idx])
self._simplify(key)
else: #Otherwise multiplicity is unchanged
for key, vals in iteritems(src_dict):
self.get_class_dict(src_class)[key] = deepcopy(vals)
else: #The src_class is per slice