-
Notifications
You must be signed in to change notification settings - Fork 16
/
augmentation.py
2205 lines (1919 loc) · 97.6 KB
/
augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
The :mod:`fatf.utils.data.augmentation` module implements data set augmenters.
"""
# Author: Alex Hepburn <ah13558@bristol.ac.uk>
# Kacper Sokol <k.sokol@bristol.ac.uk>
# Rafael Poyiadzi <rp13102@bristol.ac.uk>
# License: new BSD
# pylint: disable=too-many-lines
from numbers import Number
from typing import Callable, List, Optional, Tuple, Union
from typing import Set # pylint: disable=unused-import
import abc
import logging
import warnings
import scipy.stats
import scipy.spatial
import numpy as np
from fatf.exceptions import IncompatibleModelError, IncorrectShapeError
import fatf.utils.array.tools as fuat
import fatf.utils.array.validation as fuav
import fatf.utils.distances as fud
import fatf.utils.validation as fuv
__all__ = ['Augmentation',
'NormalSampling',
'TruncatedNormalSampling',
'Mixup',
'NormalClassDiscovery',
'DecisionBoundarySphere',
'LocalSphere'] # yapf: disable
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Index = Union[int, str]
def _validate_input(dataset: np.ndarray,
                    ground_truth: Optional[np.ndarray] = None,
                    categorical_indices: Optional[List[Index]] = None,
                    int_to_float: bool = True) -> bool:
    """
    Validates the input parameters of an arbitrary augmentation class.

    For the description of the input parameters and exceptions raised by this
    function, please see the documentation of the
    :class:`fatf.utils.data.augmentation.Augmentation` class.

    Returns
    -------
    is_valid : boolean
        ``True`` if input is valid, ``False`` otherwise.
    """
    # The dataset must be a 2-dimensional numpy array of a base
    # (numerical and/or string) type.
    if not fuav.is_2d_array(dataset):
        raise IncorrectShapeError('The input dataset must be a '
                                  '2-dimensional numpy array.')
    if not fuav.is_base_array(dataset):
        raise TypeError('The input dataset must be of a base type.')

    # The optional ground truth must be a base-typed 1-dimensional array
    # holding exactly one label per dataset row.
    if ground_truth is not None:
        if not fuav.is_1d_array(ground_truth):
            raise IncorrectShapeError('The ground_truth array must be '
                                      '1-dimensional. (Or None if it is not '
                                      'required.)')
        if not fuav.is_base_array(ground_truth):
            raise TypeError('The ground_truth array must be of a base type.')
        if ground_truth.shape[0] != dataset.shape[0]:
            raise IncorrectShapeError('The number of labels in the '
                                      'ground_truth array is not equal to '
                                      'the number of data points in the '
                                      'dataset array.')

    # The optional categorical indices must come in a Python list and every
    # index must be valid for the dataset.
    if categorical_indices is not None:
        if not isinstance(categorical_indices, list):
            raise TypeError('The categorical_indices parameter must be a '
                            'Python list or None.')
        invalid_indices = fuat.get_invalid_indices(
            dataset, np.asarray(categorical_indices))
        if invalid_indices.size:
            raise IndexError('The following indices are invalid for the '
                             'input dataset: {}.'.format(invalid_indices))

    if not isinstance(int_to_float, bool):
        raise TypeError('The int_to_float parameter has to be a boolean.')

    return True
class Augmentation(abc.ABC):
    """
    An abstract class for implementing data augmentation methods.

    An abstract class that all augmentation classes should inherit from. It
    contains abstract ``__init__`` and ``sample`` methods and an input
    validator -- ``_validate_sample_input`` -- for the ``sample`` method. The
    validation of the input parameters to the initialisation method is done
    via the ``fatf.utils.data.augmentation._validate_input`` function.

    .. note::
       The ``_validate_sample_input`` method should be called in all
       implementations of the ``sample`` method in the children classes to
       ensure that all the input parameters of this method are valid.

    Parameters
    ----------
    dataset : numpy.ndarray
        A 2-dimensional numpy array with a dataset to be used for sampling.
    ground_truth : numpy.ndarray, optional (default=None)
        A 1-dimensional numpy array with labels for the supplied dataset.
    categorical_indices : List[column indices], optional (default=None)
        A list of column indices that should be treat as categorical features.
        If ``None`` is given this will be inferred from the data array:
        string-based columns will be treated as categorical features and
        numerical columns will be treated as numerical features.
    int_to_float : boolean
        If ``True``, all of the integer dtype columns in the ``dataset`` will
        be generalised to ``numpy.float64`` type. Otherwise, integer type
        columns will remain integer and floating point type columns will
        remain floating point.

    Warns
    -----
    UserWarning
        If some of the string-based columns in the input data array were not
        indicated to be categorical features by the user (via the
        ``categorical_indices`` parameter) the user is warned that they will
        be added to the list of categorical features.

    Raises
    ------
    IncorrectShapeError
        The input ``dataset`` is not a 2-dimensional numpy array. The
        ``ground_truth`` array is not a 1-dimensional numpy array. The number
        of ground truth annotation is different than the number of rows in
        the data array.
    IndexError
        Some of the column indices given in the ``categorical_indices``
        parameter are not valid for the input ``dataset``.
    TypeError
        The ``categorical_indices`` parameter is neither a list nor ``None``.
        The ``dataset`` or the ``ground_truth`` array (if not ``None``) are
        not of base (numerical and/or string) type. The ``int_to_float``
        parameter is not a boolean.

    Attributes
    ----------
    dataset : numpy.ndarray
        A 2-dimensional numpy array with a dataset to be used for sampling.
    data_points_number : integer
        The number of data points in the ``dataset``.
    is_structured : boolean
        ``True`` if the ``dataset`` is a structured numpy array, ``False``
        otherwise.
    ground_truth : Union[numpy.ndarray, None]
        A 1-dimensional numpy array with labels for the supplied dataset.
    categorical_indices : List[column indices]
        A list of column indices that should be treat as categorical features.
    numerical_indices : List[column indices]
        A list of column indices that should be treat as numerical features.
    features_number : integer
        The number of features (columns) in the input ``dataset``.
    sample_dtype : Union[numpy.dtype, List[Tuple[string, numpy.dtype]]
        A dtype with numerical dtypes (in case of a structured data array)
        generalised to support the assignment of sampled values. For example,
        if the dtype of a numerical feature is ``int`` and the sampling
        generates ``float`` this dtype will generalise the type of that
        column to ``float``.
    """

    # pylint: disable=too-few-public-methods,too-many-instance-attributes

    def __init__(self,
                 dataset: np.ndarray,
                 ground_truth: Optional[np.ndarray] = None,
                 categorical_indices: Optional[List[Index]] = None,
                 int_to_float: bool = True) -> None:
        """
        Constructs an ``Augmentation`` abstract class.
        """
        # pylint: disable=too-many-locals
        assert _validate_input(
            dataset,
            ground_truth=ground_truth,
            categorical_indices=categorical_indices,
            int_to_float=int_to_float), 'Invalid input.'

        self.dataset = dataset
        self.data_points_number = dataset.shape[0]
        self.is_structured = fuav.is_structured_array(dataset)
        self.ground_truth = ground_truth

        # Sort out column indices: split the dataset columns into numerical
        # and categorical ones based on their dtype.
        indices = fuat.indices_by_type(dataset)
        num_indices = set(indices[0])
        cat_indices = set(indices[1])
        all_indices = num_indices.union(cat_indices)
        if categorical_indices is None:
            # Nothing user-specified -- use the dtype-based split as-is.
            categorical_indices = cat_indices
            numerical_indices = num_indices
        else:
            # String-based columns cannot be treated as numerical, so any
            # string column missing from the user's selection is forcibly
            # added to the categorical set (with a warning).
            if cat_indices.difference(categorical_indices):
                msg = ('Some of the string-based columns in the input dataset '
                       'were not selected as categorical features via the '
                       'categorical_indices parameter. String-based columns '
                       'cannot be treated as numerical features, therefore '
                       'they will be also treated as categorical features '
                       '(in addition to the ones selected with the '
                       'categorical_indices parameter).')
                warnings.warn(msg, UserWarning)
                categorical_indices = cat_indices.union(categorical_indices)
            numerical_indices = all_indices.difference(categorical_indices)
        # Sorted lists give the child classes a deterministic column order.
        self.categorical_indices = sorted(list(categorical_indices))
        self.numerical_indices = sorted(list(numerical_indices))
        self.features_number = len(all_indices)

        # Sort out the dtype of the sampled array: numerical columns are
        # generalised towards float64 (or int64 when int_to_float is False)
        # so that sampled values can be assigned without truncation.
        ntype = np.dtype(np.float64) if int_to_float else np.dtype(np.int64)
        if self.is_structured:
            # Build a new structured dtype column by column, generalising
            # only the numerical fields.
            sample_dtype = []
            for column_name in self.dataset.dtype.names:
                if column_name in self.numerical_indices:
                    new_dtype = fuat.generalise_dtype(
                        self.dataset.dtype[column_name], ntype)
                    sample_dtype.append((column_name, new_dtype))
                elif column_name in self.categorical_indices:
                    sample_dtype.append((column_name,
                                         self.dataset.dtype[column_name]))
                else:
                    assert False, 'Unknown column name.'  # pragma: nocover
        else:
            # A classic array has a single dtype; generalise it only when
            # the whole array is numerical.
            if fuav.is_numerical_array(self.dataset):
                sample_dtype = fuat.generalise_dtype(self.dataset.dtype, ntype)
            else:
                sample_dtype = self.dataset.dtype
        self.sample_dtype = sample_dtype

    @abc.abstractmethod
    def sample(self,
               data_row: Optional[Union[np.ndarray, np.void]] = None,
               samples_number: int = 50) -> np.ndarray:
        """
        Samples a given number of data points based on the initialisation
        data.

        This is an abstract method that must be implemented for each child
        object. This method should provide two modes of operation:

        - if ``data_row`` is ``None``, the sample should be from the
          distribution of the whole dataset that was used to initialise this
          class; and
        - if ``data_row`` is a numpy array with a data point, the sample
          should be from the vicinity of this data point.

        Parameters
        ----------
        data_row : Union[numpy.ndarray, numpy.void], optional (default=None)
            A data point. If given, the sample will be generated around that
            point.
        samples_number : integer, optional (default=50)
            The number of samples to be generated.

        Raises
        ------
        NotImplementedError
            This is an abstract method and has not been implemented.

        Returns
        -------
        samples : numpy.ndarray
            Sampled data.
        """
        assert self._validate_sample_input(  # pragma: nocover
            data_row, samples_number), 'Invalid sample method input.'
        raise NotImplementedError(  # pragma: nocover
            'sample method needs to be overwritten.')

    def _validate_sample_input(self,
                               data_row: Union[None, np.ndarray, np.void],
                               samples_number: int) -> bool:
        """
        Validates input parameters of the ``sample`` method.

        This function checks the validity of ``data_row`` and
        ``samples_number`` parameters.

        Raises
        ------
        IncorrectShapeError
            The ``data_row`` is not a 1-dimensional numpy array-like object.
            The number of features (columns) in the ``data_row`` is different
            to the number of features in the data array used to initialise
            this object.
        TypeError
            The dtype of the ``data_row`` is different than the dtype of the
            data array used to initialise this object. The ``samples_number``
            parameter is not an integer.
        ValueError
            The ``samples_number`` parameter is not a positive integer.

        Returns
        -------
        is_valid : boolean
            ``True`` if input parameters are valid, ``False`` otherwise.
        """
        is_valid = False
        if data_row is not None:
            if not fuav.is_1d_like(data_row):
                raise IncorrectShapeError('The data_row must either be a '
                                          '1-dimensional numpy array or numpy '
                                          'void object for structured rows.')
            # Wrapping the row in an array allows a dtype comparison against
            # the (2-dimensional) initialisation dataset.
            are_similar = fuav.are_similar_dtype_arrays(
                self.dataset, np.array([data_row]), strict_comparison=True)
            if not are_similar:
                raise TypeError('The dtype of the data_row is different to '
                                'the dtype of the data array used to '
                                'initialise this class.')
            # If the dataset is structured and the data_row has a different
            # number of features this will be caught by the above dtype check.
            # For classic numpy arrays this has to be done separately.
            if not self.is_structured:
                if data_row.shape[0] != self.dataset.shape[1]:
                    raise IncorrectShapeError('The data_row must contain the '
                                              'same number of features as the '
                                              'dataset used to initialise '
                                              'this class.')
        if isinstance(samples_number, int):
            if samples_number < 1:
                raise ValueError('The samples_number parameter must be a '
                                 'positive integer.')
        else:
            raise TypeError('The samples_number parameter must be an integer.')
        is_valid = True
        return is_valid
class NormalSampling(Augmentation):
    """
    Sampling data from a normal distribution.

    New data points are drawn from per-feature normal distributions. The
    sample is centred either on a user-provided data point (when ``data_row``
    is passed to the ``sample`` method) or on the feature-wise mean of the
    whole ``dataset`` (when ``data_row`` is omitted). In both cases the
    standard deviation of every numerical feature is computed from the whole
    dataset. Categorical features are sampled with replacement, with every
    unique value drawn with a probability equal to its frequency in the
    dataset.

    For additional parameters, attributes, warnings and exceptions raised by
    this class please see the documentation of its parent class:
    :class:`fatf.utils.data.augmentation.Augmentation`.

    Attributes
    ----------
    numerical_sampling_values : Dictionary[column index, Tuple[number, number]]
        Maps every numerical column index to a (*mean*, *standard deviation*)
        pair computed over the whole dataset.
    categorical_sampling_values : Dictionary[column index, \
Tuple[numpy.ndarray, numpy.ndarray]]
        Maps every categorical column index to a pair of 1-dimensional numpy
        arrays: the unique values of that column and their normalised
        (summing up to 1) frequencies.
    """

    # pylint: disable=too-few-public-methods

    def __init__(self,
                 dataset: np.ndarray,
                 categorical_indices: Optional[List[Index]] = None,
                 int_to_float: bool = True) -> None:
        """
        Constructs a ``NormalSampling`` data augmentation class.
        """
        # pylint: disable=too-many-locals,too-many-branches
        super().__init__(
            dataset,
            categorical_indices=categorical_indices,
            int_to_float=int_to_float)

        # Per-feature mean and standard deviation for numerical columns.
        numerical_sampling_values = dict()
        if self.numerical_indices:
            if self.is_structured:
                numerical_block = fuat.as_unstructured(
                    self.dataset[self.numerical_indices])
            else:
                numerical_block = self.dataset[:, self.numerical_indices]
            column_means = numerical_block.mean(axis=0)
            column_stds = numerical_block.std(axis=0)
            for position, column_index in enumerate(self.numerical_indices):
                numerical_sampling_values[column_index] = (
                    column_means[position], column_stds[position])
        self.numerical_sampling_values = numerical_sampling_values

        # Unique values and their (normalised) frequencies for categorical
        # columns.
        categorical_sampling_values = dict()
        for column_index in self.categorical_indices:
            if self.is_structured:
                column = self.dataset[column_index]
            else:
                column = self.dataset[:, column_index]
            unique_values, value_counts = np.unique(
                column, return_counts=True)
            categorical_sampling_values[column_index] = (
                unique_values, value_counts / value_counts.sum())
        self.categorical_sampling_values = categorical_sampling_values

    def sample(self,
               data_row: Optional[Union[np.ndarray, np.void]] = None,
               samples_number: int = 50) -> np.ndarray:
        """
        Samples new data from a normal distribution.

        When ``data_row`` is given, numerical features are centred on the
        corresponding values of that row; otherwise they are centred on the
        dataset mean. The standard deviation always comes from the dataset.
        Categorical features are drawn with replacement according to their
        frequency in the dataset, hence the categorical values of a
        ``data_row`` (if given) are ignored.

        For the documentation of parameters, warnings and errors please see
        the description of the
        :func:`~fatf.utils.data.augmentation.Augmentation.sample` method in
        the parent :class:`fatf.utils.data.augmentation.Augmentation` class.
        """
        assert self._validate_sample_input(data_row,
                                           samples_number), 'Invalid input.'

        # A structured sample array is 1-dimensional; a classic one is 2-d.
        if self.is_structured:
            shape = (samples_number, )  # type: Tuple[int, ...]
        else:
            shape = (samples_number, self.features_number)
        samples = np.zeros(shape, dtype=self.sample_dtype)

        # Draw categorical features proportionally to their frequencies.
        for column_index in self.categorical_indices:
            unique_values, frequencies = \
                self.categorical_sampling_values[column_index]
            drawn = np.random.choice(
                unique_values,
                size=samples_number,
                replace=True,
                p=frequencies)
            if self.is_structured:
                samples[column_index] = drawn
            else:
                samples[:, column_index] = drawn

        # Draw numerical features around either the data_row value or the
        # dataset mean, scaled by the dataset's standard deviation.
        for column_index in self.numerical_indices:
            mean, std = self.numerical_sampling_values[column_index]
            centre = mean if data_row is None else data_row[column_index]
            drawn = np.random.normal(0, 1, samples_number) * std + centre
            if self.is_structured:
                samples[column_index] = drawn
            else:
                samples[:, column_index] = drawn

        return samples
class TruncatedNormalSampling(Augmentation):
    """
    Sampling data from a truncated normal distribution.

    .. versionadded:: 0.0.2

    New data points are drawn from per-feature `truncated normal
    distributions`_. The sample is centred either on a user-provided data
    point (when ``data_row`` is passed to the ``sample`` method) or on the
    feature-wise mean of the whole ``dataset`` (when ``data_row`` is
    omitted). In both cases the standard deviation of every numerical feature
    is computed from the whole ``dataset``, and the per-feature minimum and
    maximum serve as the truncation bounds. Categorical features are sampled
    with replacement, with every unique value drawn with a probability equal
    to its frequency in the dataset.

    For additional parameters, attributes, warnings and exceptions raised by
    this class please see the documentation of its parent class:
    :class:`fatf.utils.data.augmentation.Augmentation`.

    .. _`truncated normal distributions`: https://en.wikipedia.org/wiki/
       Truncated_normal_distribution

    Attributes
    ----------
    numerical_sampling_values : Dictionary[column index, \
Tuple[number, number, number, number]]
        Maps every numerical column index to a tuple of four numbers: the
        column's *mean*, *standard deviation*, *minimum* and *maximum*.
    categorical_sampling_values : Dictionary[column index, \
Tuple[numpy.ndarray, numpy.ndarray]]
        Maps every categorical column index to a pair of 1-dimensional numpy
        arrays: the unique values of that column and their normalised
        (summing up to 1) frequencies.
    """

    # pylint: disable=too-few-public-methods

    def __init__(self,
                 dataset: np.ndarray,
                 categorical_indices: Optional[List[Index]] = None,
                 int_to_float: bool = True) -> None:
        """
        Constructs a ``TruncatedNormalSampling`` data augmentation class.
        """
        # pylint: disable=too-many-locals
        super().__init__(
            dataset=dataset,
            categorical_indices=categorical_indices,
            int_to_float=int_to_float)

        # Per-feature mean, standard deviation and truncation bounds
        # (min/max) for numerical columns.
        numerical_sampling_values = dict()
        if self.numerical_indices:
            if self.is_structured:
                numerical_block = fuat.as_unstructured(
                    self.dataset[self.numerical_indices])
            else:
                numerical_block = self.dataset[:, self.numerical_indices]
            column_means = numerical_block.mean(axis=0)
            column_stds = numerical_block.std(axis=0)
            column_mins = numerical_block.min(axis=0)
            column_maxs = numerical_block.max(axis=0)
            for position, column_index in enumerate(self.numerical_indices):
                numerical_sampling_values[column_index] = (
                    column_means[position], column_stds[position],
                    column_mins[position], column_maxs[position])
        self.numerical_sampling_values = numerical_sampling_values

        # Unique values and their (normalised) frequencies for categorical
        # columns.
        categorical_sampling_values = dict()
        for column_index in self.categorical_indices:
            if self.is_structured:
                column = self.dataset[column_index]
            else:
                column = self.dataset[:, column_index]
            unique_values, value_counts = np.unique(
                column, return_counts=True)
            categorical_sampling_values[column_index] = (
                unique_values, value_counts / value_counts.sum())
        self.categorical_sampling_values = categorical_sampling_values

    def sample(self,
               data_row: Optional[Union[np.ndarray, np.void]] = None,
               samples_number: int = 50) -> np.ndarray:
        """
        Samples new data from a truncated normal distribution.

        When ``data_row`` is given, numerical features are centred on the
        corresponding values of that row; otherwise they are centred on the
        dataset mean. The standard deviation and the truncation bounds
        (minimum and maximum) always come from the dataset. Categorical
        features are drawn with replacement according to their frequency in
        the dataset, hence the categorical values of a ``data_row`` (if
        given) are ignored.

        For the documentation of parameters, warnings and errors please see
        the description of the
        :func:`fatf.utils.data.augmentation.Augmentation.sample` method in
        the parent :class:`fatf.utils.data.augmentation.Augmentation` class.
        """
        assert self._validate_sample_input(data_row,
                                           samples_number), 'Invalid input.'

        # A structured sample array is 1-dimensional; a classic one is 2-d.
        if self.is_structured:
            shape = (samples_number, )  # type: Tuple[int, ...]
        else:
            shape = (samples_number, self.features_number)
        samples = np.zeros(shape, dtype=self.sample_dtype)

        # Draw categorical features proportionally to their frequencies.
        for column_index in self.categorical_indices:
            drawn = np.random.choice(
                self.categorical_sampling_values[column_index][0],
                size=samples_number,
                replace=True,
                p=self.categorical_sampling_values[column_index][1])
            if self.is_structured:
                samples[column_index] = drawn
            else:
                samples[:, column_index] = drawn

        # Draw numerical features from a truncated normal; the a/b arguments
        # of truncnorm are expressed in standard deviations from the centre.
        for column_index in self.numerical_indices:
            mean, std, lower, upper = \
                self.numerical_sampling_values[column_index]
            if data_row is not None:
                mean = data_row[column_index]
            drawn = scipy.stats.truncnorm.rvs(
                (lower - mean) / std, (upper - mean) / std,
                loc=mean,
                scale=std,
                size=samples_number)
            if self.is_structured:
                samples[column_index] = drawn
            else:
                samples[:, column_index] = drawn

        return samples
def _validate_input_mixup(
beta_parameters: Union[None, Tuple[float, float]]) -> bool:
"""
Validates :class:``.Mixup`` class-specific input parameters.
Parameters
----------
beta_parameters : Union[Tuple[number, number], None]
Either ``None`` (for the default values) or a pair of numbers that will
be used as beta distribution parameters.
Raises
------
TypeError
The ``beta_parameters`` parameter is neither ``None`` nor a tuple. One
of the values in the ``beta_parameters`` tuple is not a number.
ValueError
The ``beta_parameters`` tuple is not a pair (2-tuple). One of the
numbers in the ``beta_parameters`` tuple is not positive.
Returns
-------
is_valid : boolean
``True`` if input is valid, ``False`` otherwise.
"""
is_valid = False
# Check beta parameters
if beta_parameters is None:
pass
elif isinstance(beta_parameters, tuple):
if len(beta_parameters) != 2:
raise ValueError('The beta_parameters parameter has to be a '
'2-tuple (a pair) of numbers.')
for index, name in enumerate(['first', 'second']):
if isinstance(beta_parameters[index], Number):
if beta_parameters[index] <= 0:
raise ValueError('The {} beta parameter cannot be a '
'negative number.'.format(name))
else:
raise TypeError('The {} beta parameter has to be a '
'numerical type.'.format(name))
else:
raise TypeError('The beta_parameters parameter has to be a tuple '
'with two numbers or None to use the default '
'parameters value.')
is_valid = True
return is_valid
class Mixup(Augmentation):
"""
Sampling data with the Mixup method.
This object implements the Mixup method introduced by [ZHANG2018MIXUP]_.
For a specific data point it select points at random from the ``dataset``
(making sure that the sample is stratified when the ``ground_truth``
parameter is given), then it draws samples from a Beta distribution and it
forms new data points (samples) according to the convex combination of the
original data pint and the randomly sampled dataset points.
.. note::
Sampling from the ``dataset`` mean is not yet implemented.
For additional parameters, attributes, warnings and exceptions raised by
this class please see the documentation of its parent class:
:class:`fatf.utils.data.augmentation.Augmentation` and the function that
validates the input parameters
``fatf.utils.data.augmentation._validate_input_mixup``.
.. [ZHANG2018MIXUP] Zhang, H., Cisse, M., Dauphin, Y. N. and Lopez-Paz, D.,
2018. mixup: Beyond Empirical Risk Minimization. International
Conference on Learning Representations (ICLR 2018).
Parameters
----------
beta_parameters : Tuple[number, number]], optional (default=None)
A pair of numerical parameters used with the Beta distribution. If
``None``, the beta parameters will be set to ``(2, 5)``.
Raises
------
TypeError
The ``beta_parameters`` parameter is neither ``None`` nor a tuple. One
of the values in the ``beta_parameters`` tuple is not a number.
ValueError
The ``beta_parameters`` tuple is not a pair (2-tuple). One of the
numbers in the ``beta_parameters`` tuple is not positive.
Attributes
----------
threshold : number
A threshold used for mixing the random sample from the ``dataset`` with
the instance used to generate a sample. The threshold value is 0.5.
beta_parameters : Tuple[number, number]
A pair of numbers used with the Beta distribution sampling.
ground_truth_unique : np.ndarray
A sorted array holding all the unique values of the ground truth.
ground_truth_frequencies : np.ndarray
An array holding frequencies of all the unique values in the ground
truth array. The order of the frequencies correspond with the order of
the unique values. The frequencies are normalised and they sum up to 1.
indices_per_label : List[np.ndarray]
A list of arrays holding (``dataset``) row indices corresponding to
each of the unique ground truth values. The order of this list
corresponds with the order of the unique values.
ground_truth_probabilities : np.ndarray
A numpy array of [number of dataset instances, number of unique ground
truth values] shape that holds one-hot encoding (pseudo-probabilities)
of the ground truth labels. The column ordering of this array
corresponds with the order of the unique values.
"""
# pylint: disable=too-few-public-methods
def __init__(self,
dataset: np.ndarray,
ground_truth: Optional[np.ndarray] = None,
categorical_indices: Optional[np.ndarray] = None,
beta_parameters: Optional[Tuple[float, float]] = None,
int_to_float: bool = True) -> None:
"""
Constructs a ``Mixup`` data augmentation class.
"""
# pylint: disable=too-many-arguments
super().__init__(
dataset,
ground_truth=ground_truth,
categorical_indices=categorical_indices,
int_to_float=int_to_float)
assert _validate_input_mixup(beta_parameters), 'Invalid Mixup input.'
self.threshold = 0.50
# Get the distribution of the ground truth and collect row indices per
# label
if ground_truth is None:
ground_truth_unique = None
ground_truth_frequencies = None
indices_per_label = None
ground_truth_probabilities = None
else:
ground_truth_unique, counts = np.unique(
self.ground_truth, return_counts=True)
ground_truth_frequencies = counts / counts.sum()
indices_per_label = [
np.where(self.ground_truth == label)[0]
for label in ground_truth_unique
]
# Get pseudo-probabilities per instance, i.e. 1 indicates the label
ground_truth_probabilities = np.zeros(
(self.data_points_number, ground_truth_unique.shape[0]),
dtype=np.int8) # np.int8 suffices since these are 0s and 1s
for i, indices in enumerate(indices_per_label):
ground_truth_probabilities[indices, i] = 1
self.ground_truth_unique = ground_truth_unique
self.ground_truth_frequencies = ground_truth_frequencies
self.indices_per_label = indices_per_label
self.ground_truth_probabilities = ground_truth_probabilities
# Check beta parameters
if beta_parameters is None:
beta_parameters = (2, 5)
self.beta_parameters = beta_parameters
def _validate_sample_input_mixup(
self, data_row_target: Union[float, str, None],
with_replacement: bool, return_probabilities: bool) -> bool:
"""
Validates ``sample`` method input parameters for the ``Mixup`` class.
This function checks the validity of ``data_row_target``,
``with_replacement`` and ``return_probabilities`` parameters.
Parameters
----------
data_row_target : Union[number, string, None]
Either ``None`` or a label (class) of the data row to sample new
data around.
with_replacement : boolean
A boolean parameter that indicates whether the ``dataset`` row
indices should be sampled with replacements (``True``) or not
(``False``).
return_probabilities : boolean
A boolean parameter that indicates whether the sampled target array
should a class probability matrix (``True``) or a 1-dimensional
array with the labels (``False``).
Warns
-----
UserWarning
The user is warned when the ``data_row_target`` parameter is given
but the ``Mixup`` class was initialised without the ground truth
for the ``dataset``, therefore sampling target values is not
possible and the ``data_row_target`` parameter will be ignored.
Raises
------
TypeError
The ``return_probabilities`` or ``with_replacement`` parameters are
not booleans. The ``data_row_target`` parameter is neither a number
not a string.
ValueError
The ``data_row_target`` parameter has a value that does not appear
in the ground truth vector used to initialise this class.
Returns
-------
is_valid : boolean
``True`` if input parameters are valid, ``False`` otherwise.
"""
is_valid = False
if data_row_target is None:
pass
elif isinstance(data_row_target, (Number, str)):
if self.ground_truth_unique is None:
msg = ('This Mixup class has not been initialised with a '
'ground truth vector. The value of the data_row_target '
'parameter will be ignored, therefore target values '
'samples will not be returned.')
warnings.warn(msg, UserWarning)
else:
if data_row_target not in self.ground_truth_unique:
raise ValueError('The value of the data_row_target '
'parameter is not present in the ground '
'truth labels used to initialise this '
'class. The data row target value is not '
'recognised.')
else:
raise TypeError('The data_row_target parameter should either be '
'None or a string/number indicating the target '
'class.')
if not isinstance(with_replacement, bool):
raise TypeError('with_replacement parameter has to be boolean.')
if not isinstance(return_probabilities, bool):
raise TypeError('return_probabilities parameter has to be '
'boolean.')
is_valid = True
return is_valid
def _get_stratified_indices(self, samples_number: int,
with_replacement: bool) -> np.ndarray:
"""
Selects random row indices from the ``dataset``.
Selects ``samples_number`` number of row indices at random either with
replacements or not (depending on the value of the ``with_replacement``
parameter). The indices selection is stratified according to the ground
truth distribution if ground truth vector was given when this class
was initialised. Otherwise, the indices are generated at random.
Parameters
----------
samples_number : integer
The number of data points to be sampled.
with_replacement : boolean
A boolean parameter that indicates whether the ``dataset`` row
indices should be sampled with replacements (``True``) or not
(``False``).
Warns
-----
UserWarning
The user is warned that the random row indices will not be
stratified according to the ground truth distribution if ground
truth vector was not given when this class was initialised.
Returns
-------
random_indices : numpy.ndarray
A 1-dimensional numpy array of shape [samples_number, ] that holds
randomly selected row indices from the ``dataset``.
"""
assert isinstance(samples_number, int), 'Has to be an integer.'
assert samples_number > 0, 'Has to be positive.'
#
assert isinstance(with_replacement, bool), 'Has to be boolean.'
if self.ground_truth_frequencies is None:
msg = ('Since the ground truth vector was not provided while '
'initialising the Mixup class it is not possible to get a '
'stratified sample of data points. Instead, Mixup will '
'choose data points at random, which is equivalent to '
'assuming that the class distribution is balanced.')
warnings.warn(msg, UserWarning)
random_indices = np.random.choice(
self.data_points_number,
samples_number,
replace=with_replacement)
else:
# Get sample quantities per class -- stratified
samples_per_label = [
int(freq * samples_number)
for freq in self.ground_truth_frequencies
]
# Due to integer casting there may be a sub- or under-sampling
# happening. This gets corrected for below.
samples_per_label_len = len(samples_per_label)
diff = samples_number - sum(samples_per_label)
diff_val = 1 if diff >= 0 else -1
for _ in range(diff):
random_index = np.random.randint(0, samples_per_label_len)
samples_per_label[random_index] += diff_val
assert samples_number == sum(samples_per_label), 'Wrong quantity.'
# Get a sample representative of the original label distribution
random_indices = []
for i, label_sample_quantity in enumerate(samples_per_label):
random_indices_label = np.random.choice(
self.indices_per_label[i], # type: ignore
label_sample_quantity,
replace=with_replacement)
random_indices.append(random_indices_label)
random_indices = np.concatenate(random_indices)
return random_indices
def _get_sample_targets(self, data_row_target: Union[float, str],
return_probabilities: bool,
random_draws_lambda: np.ndarray,
random_draws_lambda_1: np.ndarray,
random_indices: np.ndarray) -> np.ndarray:
"""
Samples target values for the sampled data instance.
The target values can either be represented as a class probability
matrix (``return_probabilities`` set to ``True``) or an array with a
single label per instance selected based on the highest probability
(``return_probabilities`` set to ``False``).
Parameters
----------
data_row_target : Union[number, string]
A label (class) of the data row to sample new data around.