# learner_config.py
from os.path import join, isdir
from enum import Enum
import random
import uuid
import logging
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
Optional, Sequence, Tuple, Union)
from typing_extensions import Literal
from pydantic import (PositiveFloat, PositiveInt as PosInt, constr, confloat,
conint)
from pydantic.utils import sequence_like
import albumentations as A
import torch
from torch import (nn, optim)
from torch.optim.lr_scheduler import CyclicLR, MultiStepLR, _LRScheduler
from torch.utils.data import Dataset, ConcatDataset, Subset
from rastervision.pipeline.config import (Config, register_config, ConfigError,
Field, validator, root_validator)
from rastervision.pipeline.file_system import (list_paths, download_if_needed,
unzip, file_exists,
get_local_path, sync_from_dir)
from rastervision.core.data import (ClassConfig, Scene, DatasetConfig as
SceneDatasetConfig)
from rastervision.pytorch_learner.utils import (
color_to_triple, validate_albumentation_transform, MinMaxNormalize,
deserialize_albumentation_transform, get_hubconf_dir_from_cfg,
torch_hub_load_local, torch_hub_load_github, torch_hub_load_uri)
if TYPE_CHECKING:
from rastervision.pytorch_learner.learner import Learner
log = logging.getLogger(__name__)
default_augmentors = ['RandomRotate90', 'HorizontalFlip', 'VerticalFlip']
augmentors = [
'Blur', 'RandomRotate90', 'HorizontalFlip', 'VerticalFlip', 'GaussianBlur',
'GaussNoise', 'RGBShift', 'ToGray'
]
# types
Proportion = confloat(ge=0, le=1)
NonEmptyStr = constr(strip_whitespace=True, min_length=1)
NonNegInt = conint(ge=0)
RGBTuple = Tuple[int, int, int]
ChannelInds = Sequence[NonNegInt]
class Backbone(Enum):
alexnet = 'alexnet'
densenet121 = 'densenet121'
densenet169 = 'densenet169'
densenet201 = 'densenet201'
densenet161 = 'densenet161'
googlenet = 'googlenet'
inception_v3 = 'inception_v3'
mnasnet0_5 = 'mnasnet0_5'
mnasnet0_75 = 'mnasnet0_75'
mnasnet1_0 = 'mnasnet1_0'
mnasnet1_3 = 'mnasnet1_3'
mobilenet_v2 = 'mobilenet_v2'
resnet18 = 'resnet18'
resnet34 = 'resnet34'
resnet50 = 'resnet50'
resnet101 = 'resnet101'
resnet152 = 'resnet152'
resnext50_32x4d = 'resnext50_32x4d'
resnext101_32x8d = 'resnext101_32x8d'
wide_resnet50_2 = 'wide_resnet50_2'
wide_resnet101_2 = 'wide_resnet101_2'
shufflenet_v2_x0_5 = 'shufflenet_v2_x0_5'
shufflenet_v2_x1_0 = 'shufflenet_v2_x1_0'
shufflenet_v2_x1_5 = 'shufflenet_v2_x1_5'
shufflenet_v2_x2_0 = 'shufflenet_v2_x2_0'
squeezenet1_0 = 'squeezenet1_0'
squeezenet1_1 = 'squeezenet1_1'
vgg11 = 'vgg11'
vgg11_bn = 'vgg11_bn'
vgg13 = 'vgg13'
vgg13_bn = 'vgg13_bn'
vgg16 = 'vgg16'
vgg16_bn = 'vgg16_bn'
vgg19_bn = 'vgg19_bn'
vgg19 = 'vgg19'
@staticmethod
def int_to_str(x):
mapping = {
1: 'alexnet',
2: 'densenet121',
3: 'densenet169',
4: 'densenet201',
5: 'densenet161',
6: 'googlenet',
7: 'inception_v3',
8: 'mnasnet0_5',
9: 'mnasnet0_75',
10: 'mnasnet1_0',
11: 'mnasnet1_3',
12: 'mobilenet_v2',
13: 'resnet18',
14: 'resnet34',
15: 'resnet50',
16: 'resnet101',
17: 'resnet152',
18: 'resnext50_32x4d',
19: 'resnext101_32x8d',
20: 'wide_resnet50_2',
21: 'wide_resnet101_2',
22: 'shufflenet_v2_x0_5',
23: 'shufflenet_v2_x1_0',
24: 'shufflenet_v2_x1_5',
25: 'shufflenet_v2_x2_0',
26: 'squeezenet1_0',
27: 'squeezenet1_1',
28: 'vgg11',
29: 'vgg11_bn',
30: 'vgg13',
31: 'vgg13_bn',
32: 'vgg16',
33: 'vgg16_bn',
34: 'vgg19_bn',
35: 'vgg19'
}
return mapping[x]
@register_config('external-module')
class ExternalModuleConfig(Config):
"""Config describing an object to be loaded via Torch Hub."""
uri: Optional[NonEmptyStr] = Field(
None,
        description=('Local URI of a zip file, or local URI of a directory, '
                     'or remote URI of a zip file.'))
github_repo: Optional[constr(
strip_whitespace=True, regex=r'.+/.+')] = Field(
None, description='<repo-owner>/<repo-name>[:tag]')
name: Optional[NonEmptyStr] = Field(
None,
description=
'Name of the folder in which to extract/copy the definition files.')
entrypoint: NonEmptyStr = Field(
...,
description=('Name of a callable present in hubconf.py. '
'See docs for torch.hub for details.'))
entrypoint_args: list = Field(
[],
description='Args to pass to the entrypoint. Must be serializable.')
entrypoint_kwargs: dict = Field(
{},
description=
'Keyword args to pass to the entrypoint. Must be serializable.')
force_reload: bool = Field(
False, description='Force reload of module definition.')
@root_validator(skip_on_failure=True)
def check_either_uri_or_repo(cls, values: dict) -> dict:
has_uri = values.get('uri') is not None
has_repo = values.get('github_repo') is not None
if has_uri == has_repo:
raise ConfigError(
'Must specify one (and only one) of github_repo and uri.')
return values
def build(self, save_dir: str, hubconf_dir: Optional[str] = None) -> Any:
"""Load an external module via torch.hub.
Note: Loading a PyTorch module is the typical use case, but there are
no type restrictions on the object loaded through torch.hub.
Args:
            save_dir (str): The module def will be saved here.
hubconf_dir (str, optional): Path to existing definition.
If provided, the definition will not be fetched from the
external source but instead from this dir. Defaults to None.
Returns:
Any: The module loaded via torch.hub.
"""
if hubconf_dir is not None:
log.info(f'Using existing module definition at: {hubconf_dir}')
module = torch_hub_load_local(
hubconf_dir=hubconf_dir,
entrypoint=self.entrypoint,
*self.entrypoint_args,
**self.entrypoint_kwargs)
return module
hubconf_dir = get_hubconf_dir_from_cfg(self, parent=save_dir)
if self.github_repo is not None:
log.info(f'Fetching module definition from: {self.github_repo}')
module = torch_hub_load_github(
repo=self.github_repo,
hubconf_dir=hubconf_dir,
entrypoint=self.entrypoint,
*self.entrypoint_args,
**self.entrypoint_kwargs)
else:
log.info(f'Fetching module definition from: {self.uri}')
module = torch_hub_load_uri(
uri=self.uri,
hubconf_dir=hubconf_dir,
entrypoint=self.entrypoint,
*self.entrypoint_args,
**self.entrypoint_kwargs)
return module
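# A minimal usage sketch for ExternalModuleConfig (the repo and entrypoint
# below are illustrative; any torch.hub-compatible repo exposing the
# entrypoint in its hubconf.py works the same way):
#
#     cfg = ExternalModuleConfig(
#         github_repo='pytorch/vision',
#         entrypoint='resnet18',
#         entrypoint_kwargs={'pretrained': False})
#     model = cfg.build(save_dir='/tmp/external-defs')
#
# Exactly one of uri and github_repo must be set; setting both (or neither)
# fails validation in check_either_uri_or_repo.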
def model_config_upgrader(cfg_dict, version):
if version == 0:
cfg_dict['backbone'] = Backbone.int_to_str(cfg_dict['backbone'])
return cfg_dict
@register_config('model', upgrader=model_config_upgrader)
class ModelConfig(Config):
"""Config related to models."""
backbone: Backbone = Field(
Backbone.resnet18,
description='The torchvision.models backbone to use.')
pretrained: bool = Field(
True,
description=(
'If True, use ImageNet weights. If False, use random initialization.'
))
init_weights: Optional[str] = Field(
None,
description=('URI of PyTorch model weights used to initialize model. '
                     'If set, this supersedes the pretrained option.'))
load_strict: bool = Field(
True,
description=(
'If True, the keys in the state dict referenced by init_weights '
'must match exactly. Setting this to False can be useful if you '
'just want to load the backbone of a model.'))
external_def: Optional[ExternalModuleConfig] = Field(
None,
description='If specified, the model will be built from the '
'definition from this external source, using Torch Hub.')
extra_args: dict = Field(
{},
description='Other implementation-specific args that might be useful '
'for constructing the default model. This is ignored if using an '
'external model.')
def get_backbone_str(self):
return self.backbone.name
def build(self,
num_classes: int,
in_channels: int,
save_dir: Optional[str] = None,
hubconf_dir: Optional[str] = None,
**kwargs) -> nn.Module:
"""Build and return a model based on the config.
Args:
num_classes (int): Number of classes.
            in_channels (int): Number of channels in the images that
                will be fed into the model.
save_dir (Optional[str], optional): Used for building external_def
if specified. Defaults to None.
hubconf_dir (Optional[str], optional): Used for building
external_def if specified. Defaults to None.
Returns:
nn.Module: a PyTorch nn.Module.
"""
if self.external_def is not None:
return self.build_external_model(
save_dir=save_dir, hubconf_dir=hubconf_dir)
return self.build_default_model(num_classes, in_channels, **kwargs)
def build_default_model(self, num_classes: int, in_channels: int,
**kwargs) -> nn.Module:
"""Build and return the default model.
Args:
num_classes (int): Number of classes.
            in_channels (int): Number of channels in the images that
                will be fed into the model.
Returns:
nn.Module: a PyTorch nn.Module.
"""
raise NotImplementedError()
def build_external_model(self,
save_dir: str,
hubconf_dir: Optional[str] = None) -> nn.Module:
"""Build and return an external model.
Args:
save_dir (str): The module def will be saved here.
hubconf_dir (Optional[str], optional): Path to existing definition.
Defaults to None.
Returns:
nn.Module: a PyTorch nn.Module.
"""
return self.external_def.build(save_dir, hubconf_dir=hubconf_dir)
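# A sketch of the two ways a ModelConfig gets built (values illustrative;
# build_default_model raises NotImplementedError here and is implemented by
# task-specific subclasses, for which SomeModelConfig is a stand-in name):
#
#     # (1) default torchvision backbone, via a task-specific subclass:
#     cfg = SomeModelConfig(backbone=Backbone.resnet50, pretrained=True)
#     model = cfg.build(num_classes=10, in_channels=3)
#
#     # (2) external definition loaded through Torch Hub:
#     cfg = ModelConfig(
#         external_def=ExternalModuleConfig(
#             github_repo='pytorch/vision', entrypoint='resnet18'))
#     model = cfg.build(num_classes=10, in_channels=3, save_dir='/tmp/defs')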
def solver_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version < 4:
# 'ignore_last_class' replaced by 'ignore_class_index' in version 4
ignore_last_class = cfg_dict.get('ignore_last_class')
if ignore_last_class is not None:
if ignore_last_class is not False:
cfg_dict['ignore_class_index'] = -1
del cfg_dict['ignore_last_class']
return cfg_dict
@register_config('solver', upgrader=solver_config_upgrader)
class SolverConfig(Config):
"""Config related to solver aka optimizer."""
lr: PositiveFloat = Field(1e-4, description='Learning rate.')
num_epochs: PosInt = Field(
10,
description=
        'Number of epochs (i.e. sweeps through the whole training set).')
test_num_epochs: PosInt = Field(
2, description='Number of epochs to use in test mode.')
test_batch_sz: PosInt = Field(
4, description='Batch size to use in test mode.')
overfit_num_steps: PosInt = Field(
1, description='Number of optimizer steps to use in overfit mode.')
sync_interval: PosInt = Field(
1, description='The interval in epochs for each sync to the cloud.')
batch_sz: PosInt = Field(32, description='Batch size.')
one_cycle: bool = Field(
True,
description=
('If True, use triangular LR scheduler with a single cycle across all '
'epochs with start and end LR being lr/10 and the peak being lr.'))
multi_stage: List = Field(
[], description=('List of epoch indices at which to divide LR by 10.'))
class_loss_weights: Optional[Sequence[float]] = Field(
None, description=('Class weights for weighted loss.'))
ignore_class_index: Optional[int] = Field(
None,
description='If specified, this index is ignored when computing the '
'loss. See pytorch documentation for nn.CrossEntropyLoss for more '
'details. This can also be negative, in which case it is treated as a '
'negative slice index i.e. -1 = last index, -2 = second-last index, '
'and so on.')
external_loss_def: Optional[ExternalModuleConfig] = Field(
None,
description='If specified, the loss will be built from the definition '
'from this external source, using Torch Hub.')
@root_validator(skip_on_failure=True)
def check_no_loss_opts_if_external(cls, values: dict) -> dict:
has_external_loss_def = values.get('external_loss_def') is not None
has_ignore_class_index = values.get('ignore_class_index') is not None
has_class_loss_weights = values.get('class_loss_weights') is not None
if has_external_loss_def:
if has_ignore_class_index:
raise ConfigError('ignore_class_index is not supported '
'with external_loss_def.')
if has_class_loss_weights:
raise ConfigError('class_loss_weights is not supported '
'with external_loss_def.')
return values
def build_loss(self,
num_classes: int,
save_dir: Optional[str] = None,
hubconf_dir: Optional[str] = None) -> Callable:
"""Build and return a loss function based on the config.
Args:
num_classes (int): Number of classes.
save_dir (Optional[str], optional): Used for building
external_loss_def if specified. Defaults to None.
hubconf_dir (Optional[str], optional): Used for building
external_loss_def if specified. Defaults to None.
Returns:
Callable: Loss function.
"""
if self.external_loss_def is not None:
return self.external_loss_def.build(
save_dir=save_dir, hubconf_dir=hubconf_dir)
args = {}
loss_weights = self.class_loss_weights
if loss_weights is not None:
loss_weights = torch.tensor(loss_weights).float()
args['weight'] = loss_weights
ignore_class_index = self.ignore_class_index
if ignore_class_index is not None:
if ignore_class_index >= 0:
args['ignore_index'] = ignore_class_index
else:
args['ignore_index'] = num_classes + ignore_class_index
loss = nn.CrossEntropyLoss(**args)
return loss
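    # A short sketch (hypothetical values) of how a negative
    # ignore_class_index resolves to a concrete ignore_index above:
    #
    #     cfg = SolverConfig(ignore_class_index=-1)
    #     loss = cfg.build_loss(num_classes=5)
    #     # equivalent to nn.CrossEntropyLoss(ignore_index=4), since
    #     # -1 resolves to num_classes + (-1) = 4, i.e. the last class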
def build_optimizer(self, model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.Adam(model.parameters(), lr=self.lr, **kwargs)
def build_step_scheduler(self,
optimizer: optim.Optimizer,
train_ds_sz: int,
last_epoch: int = -1,
**kwargs) -> Optional[_LRScheduler]:
"""Returns an LR scheduler that changes the LR each step.
This is used to implement the "one cycle" schedule popularized by
fastai.
"""
scheduler = None
if self.one_cycle and self.num_epochs > 1:
steps_per_epoch = max(1, train_ds_sz // self.batch_sz)
total_steps = self.num_epochs * steps_per_epoch
step_size_up = (self.num_epochs // 2) * steps_per_epoch
step_size_down = total_steps - step_size_up
# Note that we don't pass in last_epoch here. See note below.
scheduler = CyclicLR(
optimizer,
base_lr=self.lr / 10,
max_lr=self.lr,
step_size_up=step_size_up,
step_size_down=step_size_down,
cycle_momentum=kwargs.pop('cycle_momentum', False),
**kwargs)
# Note: We need this loop because trying to resume the scheduler by
# just passing last_epoch does not work. See:
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822/2 # noqa
num_past_epochs = last_epoch + 1
for _ in range(num_past_epochs * steps_per_epoch):
scheduler.step()
return scheduler
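    # A worked example (hypothetical sizes) of the one-cycle computation
    # above: with num_epochs=10, train_ds_sz=1000 and batch_sz=32,
    #     steps_per_epoch = max(1, 1000 // 32) = 31
    #     total_steps     = 10 * 31 = 310
    #     step_size_up    = (10 // 2) * 31 = 155
    #     step_size_down  = 310 - 155 = 155
    # so the LR ramps from lr/10 up to lr over the first half of training
    # and back down to lr/10 over the second half.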
def build_epoch_scheduler(self,
optimizer: optim.Optimizer,
last_epoch: int = -1,
**kwargs) -> Optional[_LRScheduler]:
"""Returns an LR scheduler tha changes the LR each epoch.
This is used to divide the LR by 10 at certain epochs.
"""
scheduler = None
if self.multi_stage:
# Note that we don't pass in last_epoch here. See note below.
scheduler = MultiStepLR(
optimizer,
milestones=self.multi_stage,
gamma=kwargs.pop('gamma', 0.1),
**kwargs)
# Note: We need this loop because trying to resume the scheduler by
# just passing last_epoch does not work. See:
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822/2 # noqa
num_past_epochs = last_epoch + 1
for _ in range(num_past_epochs):
scheduler.step()
return scheduler
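# A minimal sketch of how a SolverConfig is typically consumed (the names
# model and train_ds are illustrative; in practice the Learner performs
# these steps internally):
#
#     solver_cfg = SolverConfig(lr=1e-4, num_epochs=10, batch_sz=32)
#     optimizer = solver_cfg.build_optimizer(model)
#     step_sched = solver_cfg.build_step_scheduler(
#         optimizer, train_ds_sz=len(train_ds))
#     epoch_sched = solver_cfg.build_epoch_scheduler(optimizer)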
def get_default_channel_display_groups(
nb_img_channels: int) -> Dict[str, ChannelInds]:
"""Returns the default channel_display_groups object.
See PlotOptions.channel_display_groups.
Displays at most the first 3 channels as RGB.
Args:
nb_img_channels: number of channels in the image that this is for
"""
num_display_channels = min(3, nb_img_channels)
return {'Input': list(range(num_display_channels))}
def validate_channel_display_groups(groups: Optional[Union[Dict[
str, ChannelInds], Sequence[ChannelInds]]]):
"""Validate channel display groups object.
See PlotOptions.channel_display_groups.
"""
if groups is None:
return None
elif len(groups) == 0:
raise ConfigError(
            'channel_display_groups cannot be empty. Set to None instead.')
elif not isinstance(groups, dict):
        # if in list/tuple form, convert to dict s.t.
        # [(0, 1, 2), (4, 3, 5)] --> {
        #     "Channels: [0, 1, 2]": [0, 1, 2],
        #     "Channels: [4, 3, 5]": [4, 3, 5]
        # }
groups = {f'Channels: {[*chs]}': list(chs) for chs in groups}
else:
groups = {k: list(v) for k, v in groups.items()}
if isinstance(groups, dict):
for k, _v in groups.items():
if not (0 < len(_v) <= 3):
raise ConfigError(f'channel_display_groups[{k}]: '
'len(group) must be 1, 2, or 3')
return groups
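# Examples of the normalization performed above (inputs are illustrative):
#
#     validate_channel_display_groups([(0, 1, 2), (3,)])
#     # -> {'Channels: [0, 1, 2]': [0, 1, 2], 'Channels: [3]': [3]}
#     validate_channel_display_groups({'RGB': (0, 1, 2), 'IR': (3,)})
#     # -> {'RGB': [0, 1, 2], 'IR': [3]}
#     validate_channel_display_groups({'bad': [0, 1, 2, 3]})  # ConfigError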
@register_config('plot_options')
class PlotOptions(Config):
"""Config related to plotting."""
transform: Optional[dict] = Field(
A.to_dict(MinMaxNormalize()),
description='An Albumentations transform serialized as a dict that '
'will be applied to each image before it is plotted. Mainly useful '
'for undoing any data transformation that you do not want included in '
'the plot, such as normalization. The default value will shift and scale the '
'image so the values range from 0.0 to 1.0 which is the expected range for '
'the plotting function. This default is useful for cases where the values after '
'normalization are close to zero which makes the plot difficult to see.'
)
channel_display_groups: Optional[Union[Dict[str, ChannelInds], Sequence[
ChannelInds]]] = Field(
None,
description=
('Groups of image channels to display together as a subplot '
'when plotting the data and predictions. '
'Can be a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a '
'dict containing title-to-group mappings '
'(e.g. {"RGB": [0, 1, 2], "IR": [3]}), '
'where each group is a list or tuple of channel indices and '
'title is a string that will be used as the title of the subplot '
'for that group.'))
# validators
_tf = validator(
'transform', allow_reuse=True)(validate_albumentation_transform)
def update(self, **kwargs) -> None:
super().update()
img_channels: Optional[int] = kwargs.get('img_channels')
if self.channel_display_groups is None and img_channels is not None:
self.channel_display_groups = get_default_channel_display_groups(
img_channels)
@validator('channel_display_groups')
def validate_channel_display_groups(
cls, v: Optional[Union[Dict[str, Sequence[NonNegInt]], Sequence[
Sequence[NonNegInt]]]]
) -> Optional[Dict[str, List[NonNegInt]]]:
return validate_channel_display_groups(v)
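# A minimal PlotOptions sketch (channel indices are illustrative) for
# 4-channel imagery where the first three channels are RGB:
#
#     plot_opts = PlotOptions(
#         channel_display_groups={'RGB': [0, 1, 2], 'IR': [3]})
#
# If channel_display_groups is left as None, update() fills in a default
# group that displays at most the first 3 channels (see
# get_default_channel_display_groups above).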
def ensure_class_colors(
class_names: List[str],
class_colors: Optional[List[Union[str, RGBTuple]]] = None):
"""Ensure that class_colors is valid.
    If class_colors is None and class_names is non-empty, fill with random
    colors.
Args:
class_names: see DataConfig.class_names
class_colors: see DataConfig.class_colors
"""
if class_colors is not None:
if len(class_names) != len(class_colors):
raise ConfigError(f'len(class_names) ({len(class_names)}) != '
f'len(class_colors) ({len(class_colors)})\n'
f'class_names: {class_names}\n'
f'class_colors: {class_colors}')
elif len(class_names) > 0:
class_colors = [color_to_triple() for _ in class_names]
return class_colors
def data_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version < 2:
cfg_dict['type_hint'] = 'image_data'
    elif version < 3:
        # ensure the key exists; older configs may not have it
        cfg_dict['img_channels'] = cfg_dict.get('img_channels')
return cfg_dict
@register_config('data', upgrader=data_config_upgrader)
class DataConfig(Config):
"""Config related to dataset for training and testing."""
class_names: List[str] = Field([], description='Names of classes.')
class_colors: Optional[List[Union[str, RGBTuple]]] = Field(
None,
description=('Colors used to display classes. '
'Can be color 3-tuples in list form.'))
img_channels: Optional[PosInt] = Field(
None, description='The number of channels of the training images.')
img_sz: PosInt = Field(
256,
description=
('Length of a side of each image in pixels. This is the size to transform '
'it to during training, not the size in the raw dataset.'))
train_sz: Optional[int] = Field(
None,
description=
('If set, the number of training images to use. If fewer images exist, '
'then an exception will be raised.'))
train_sz_rel: Optional[float] = Field(
None, description='If set, the proportion of training images to use.')
num_workers: int = Field(
4,
description='Number of workers to use when DataLoader makes batches.')
augmentors: List[str] = Field(
default_augmentors,
description='Names of albumentations augmentors to use for training '
f'batches. Choices include: {augmentors}. Alternatively, a custom '
'transform can be provided via the aug_transform option.')
base_transform: Optional[dict] = Field(
None,
description='An Albumentations transform serialized as a dict that '
'will be applied to all datasets: training, validation, and test. '
'This transformation is in addition to the resizing due to img_sz. '
'This is useful for, for example, applying the same normalization to '
'all datasets.')
aug_transform: Optional[dict] = Field(
None,
description='An Albumentations transform serialized as a dict that '
'will be applied as data augmentation to the training dataset. This '
'transform is applied before base_transform. If provided, the '
'augmentors option is ignored.')
plot_options: Optional[PlotOptions] = Field(
PlotOptions(), description='Options to control plotting.')
preview_batch_limit: Optional[int] = Field(
None,
description=
('Optional limit on the number of items in the preview plots produced '
'during training.'))
@property
def num_classes(self):
return len(self.class_names)
# validators
_base_tf = validator(
'base_transform', allow_reuse=True)(validate_albumentation_transform)
_aug_tf = validator(
'aug_transform', allow_reuse=True)(validate_albumentation_transform)
@root_validator(skip_on_failure=True)
def ensure_class_colors(cls, values: dict) -> dict:
class_names = values.get('class_names')
class_colors = values.get('class_colors')
values['class_colors'] = ensure_class_colors(class_names, class_colors)
return values
@validator('augmentors', each_item=True)
def validate_augmentors(cls, v: str) -> str:
if v not in augmentors:
raise ConfigError(f'Unsupported augmentor "{v}"')
return v
@root_validator(skip_on_failure=True)
def validate_plot_options(cls, values: dict) -> dict:
plot_options: Optional[PlotOptions] = values.get('plot_options')
if plot_options is None:
            return values
img_channels: Optional[PosInt] = values.get('img_channels')
if img_channels is not None:
plot_options.update(img_channels=img_channels)
return values
def make_datasets(self) -> Tuple[Dataset, Dataset, Dataset]:
raise NotImplementedError()
def get_custom_albumentations_transforms(self) -> List[dict]:
"""Returns all custom transforms found in this config.
        This should return all serialized albumentations transforms with
        a 'lambda_transforms_path' field contained in this
        config or in any of its members, no matter how deeply nested.
        The purpose is to make it easier to adjust their paths all at once
        while saving to or loading from a bundle.
"""
transforms_all = [
self.base_transform, self.aug_transform,
self.plot_options.transform
]
transforms_with_lambdas = [
tf for tf in transforms_all if (tf is not None) and (
tf.get('lambda_transforms_path') is not None)
]
return transforms_with_lambdas
def get_bbox_params(self) -> Optional[A.BboxParams]:
"""Returns BboxParams used by albumentations for data augmentation."""
return None
def get_data_transforms(self) -> Tuple[A.BasicTransform, A.BasicTransform]:
"""Get albumentations transform objects for data augmentation.
Returns:
1st tuple arg: a transform that doesn't do any data augmentation
2nd tuple arg: a transform with data augmentation
"""
bbox_params = self.get_bbox_params()
base_tfs = [A.Resize(self.img_sz, self.img_sz)]
if self.base_transform is not None:
base_tfs.append(
deserialize_albumentation_transform(self.base_transform))
base_transform = A.Compose(base_tfs, bbox_params=bbox_params)
if self.aug_transform is not None:
aug_transform = deserialize_albumentation_transform(
self.aug_transform)
aug_transform = A.Compose(
[base_transform, aug_transform], bbox_params=bbox_params)
return base_transform, aug_transform
augmentors_dict = {
'Blur': A.Blur(),
'RandomRotate90': A.RandomRotate90(),
'HorizontalFlip': A.HorizontalFlip(),
'VerticalFlip': A.VerticalFlip(),
'GaussianBlur': A.GaussianBlur(),
'GaussNoise': A.GaussNoise(),
'RGBShift': A.RGBShift(),
'ToGray': A.ToGray()
}
aug_transforms = [base_transform]
for augmentor in self.augmentors:
try:
aug_transforms.append(augmentors_dict[augmentor])
except KeyError as k:
log.warning(
f'{k} is an unknown augmentor. Continuing without {k}. '
f'Known augmentors are: {list(augmentors_dict.keys())}')
aug_transform = A.Compose(aug_transforms, bbox_params=bbox_params)
return base_transform, aug_transform
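    # A sketch of what get_data_transforms returns with the default
    # settings (data_cfg stands in for a concrete DataConfig subclass
    # instance):
    #
    #     base_tf, aug_tf = data_cfg.get_data_transforms()
    #     # base_tf ~ A.Compose([A.Resize(img_sz, img_sz)])
    #     # aug_tf  ~ A.Compose([base_tf, A.RandomRotate90(),
    #     #                      A.HorizontalFlip(), A.VerticalFlip()])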
def build(self,
tmp_dir: str,
overfit_mode: bool = False,
test_mode: bool = False) -> Tuple[Dataset, Dataset, Dataset]:
"""Build and return train, val, and test datasets."""
raise NotImplementedError()
def random_subset_dataset(self,
ds: Dataset,
size: Optional[int] = None,
fraction: Optional[Proportion] = None) -> Subset:
if size is None and fraction is None:
return ds
if size is not None and fraction is not None:
raise ValueError('Specify either size or fraction but not both.')
if fraction is not None:
size = int(len(ds) * fraction)
random.seed(1234)
inds = list(range(len(ds)))
random.shuffle(inds)
ds = Subset(ds, inds[:size])
return ds
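# Example of random_subset_dataset semantics (sizes illustrative): for a
# 100-item dataset, size=10 keeps a random 10-item subset and fraction=0.1
# keeps int(100 * 0.1) = 10 items; the RNG seed is fixed, so the same
# subset is selected across runs:
#
#     sub = data_cfg.random_subset_dataset(ds, size=10)
#     sub = data_cfg.random_subset_dataset(ds, fraction=0.1)
#     data_cfg.random_subset_dataset(ds, size=10, fraction=0.1)  # ValueError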
@register_config('image_data')
class ImageDataConfig(DataConfig):
"""Config related to dataset for training and testing."""
data_format: Optional[str] = Field(
None, description='Name of dataset format.')
uri: Optional[Union[str, List[str]]] = Field(
None,
description='One of the following:\n'
'(1) a URI of a directory containing "train", "valid", and '
        '(optionally) "test" subdirectories;\n'
'(2) a URI of a zip file containing (1);\n'
'(3) a list of (2);\n'
'(4) a URI of a directory containing zip files containing (1).')
group_uris: Optional[List[Union[str, List[str]]]] = Field(
None,
description=
'This can be set instead of uri in order to specify groups of chips. '
'Each element in the list is expected to be an object of the same '
'form accepted by the uri field. The purpose of separating chips into '
'groups is to be able to use the group_train_sz field.')
group_train_sz: Optional[Union[int, List[int]]] = Field(
None,
description='If group_uris is set, this can be used to specify the '
'number of chips to use per group. Only applies to training chips. '
'This can either be a single value that will be used for all groups '
'or a list of values (one for each group).')
group_train_sz_rel: Optional[Union[Proportion, List[Proportion]]] = Field(
None,
description='Relative version of group_train_sz. Must be a float '
'in [0, 1]. If group_uris is set, this can be used to specify the '
'proportion of the total chips in each group to use per group. '
'Only applies to training chips. This can either be a single value '
'that will be used for all groups or a list of values '
'(one for each group).')
@root_validator(skip_on_failure=True)
def validate_group_uris(cls, values: dict) -> dict:
group_train_sz = values.get('group_train_sz')
group_train_sz_rel = values.get('group_train_sz_rel')
group_uris = values.get('group_uris')
has_group_train_sz = group_train_sz is not None
has_group_train_sz_rel = group_train_sz_rel is not None
has_group_uris = group_uris is not None
if has_group_train_sz and has_group_train_sz_rel:
raise ConfigError('Only one of group_train_sz and '
'group_train_sz_rel should be specified.')
if has_group_train_sz and not has_group_uris:
raise ConfigError('group_train_sz specified without group_uris.')
if has_group_train_sz_rel and not has_group_uris:
raise ConfigError(
'group_train_sz_rel specified without group_uris.')
if has_group_train_sz and sequence_like(group_train_sz):
if len(group_train_sz) != len(group_uris):
raise ConfigError('len(group_train_sz) != len(group_uris).')
if has_group_train_sz_rel and sequence_like(group_train_sz_rel):
if len(group_train_sz_rel) != len(group_uris):
raise ConfigError(
'len(group_train_sz_rel) != len(group_uris).')
return values
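    # A configuration sketch (URIs are hypothetical) showing how group_uris
    # pairs with group_train_sz to cap the number of training chips drawn
    # from each group:
    #
    #     data_cfg = ImageDataConfig(
    #         group_uris=['s3://bucket/group1.zip', 's3://bucket/group2.zip'],
    #         group_train_sz=[100, 200])  # 100 chips from group 1, 200 from 2
    #
    # A single int (or a single Proportion for group_train_sz_rel) applies
    # the same cap to every group.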
def make_datasets(self,
train_dirs: Iterable[str],
val_dirs: Iterable[str],
test_dirs: Iterable[str],
train_tf: Optional[A.BasicTransform] = None,
val_tf: Optional[A.BasicTransform] = None,
test_tf: Optional[A.BasicTransform] = None
) -> Tuple[Dataset, Dataset, Dataset]:
"""Make training, validation, and test datasets.
Args:
            train_dirs (Iterable[str]): Directories where training data is
                located.
            val_dirs (Iterable[str]): Directories where validation data is
                located.
            test_dirs (Iterable[str]): Directories where test data is
                located.
train_tf (Optional[A.BasicTransform], optional): Transform for the
training dataset. Defaults to None.
val_tf (Optional[A.BasicTransform], optional): Transform for the
validation dataset. Defaults to None.
test_tf (Optional[A.BasicTransform], optional): Transform for the
test dataset. Defaults to None.
Returns:
            Tuple[Dataset, Dataset, Dataset]: PyTorch-compatible training,
validation, and test datasets.
"""
train_ds_list = [self.dir_to_dataset(d, train_tf) for d in train_dirs]
val_ds_list = [self.dir_to_dataset(d, val_tf) for d in val_dirs]
test_ds_list = [self.dir_to_dataset(d, test_tf) for d in test_dirs]
for ds_list in [train_ds_list, val_ds_list, test_ds_list]:
            if len(ds_list) == 0:
                # an empty list acts as a zero-length dataset, so that
                # ConcatDataset below does not fail on an empty input
                ds_list.append([])
train_ds = ConcatDataset(train_ds_list)
val_ds = ConcatDataset(val_ds_list)
test_ds = ConcatDataset(test_ds_list)
return train_ds, val_ds, test_ds
def dir_to_dataset(self, data_dir: str,
transform: A.BasicTransform) -> Dataset:
raise NotImplementedError()
def build(self,
tmp_dir: str,
overfit_mode: bool = False,
test_mode: bool = False) -> Tuple[Dataset, Dataset, Dataset]:
if self.group_uris is None:
return self.get_datasets_from_uri(
self.uri,
tmp_dir=tmp_dir,
overfit_mode=overfit_mode,
test_mode=test_mode)
if self.uri is not None:
            log.warning('Both DataConfig.uri and DataConfig.group_uris '
                        'specified. Only DataConfig.group_uris will be used.')
        train_ds, valid_ds, test_ds = self.get_datasets_from_group_uris(
            self.group_uris,
            tmp_dir=tmp_dir,
            group_train_sz=self.group_train_sz,
            group_train_sz_rel=self.group_train_sz_rel,
            overfit_mode=overfit_mode,
            test_mode=test_mode)
if self.train_sz is not None or self.train_sz_rel is not None:
train_ds = self.random_subset_dataset(
train_ds, size=self.train_sz, fraction=self.train_sz_rel)
return train_ds, valid_ds, test_ds
def get_datasets_from_uri(
self,
uri: Union[str, List[str]],
tmp_dir: str,
overfit_mode: bool = False,
test_mode: bool = False) -> Tuple[Dataset, Dataset, Dataset]:
"""Get image train, validation, & test datasets from a single zip file.
Args:
uri (Union[str, List[str]]): Uri of a zip file containing the
images.
Returns:
Tuple[Dataset, Dataset, Dataset]: Training, validation, and test
dataSets.
"""
data_dirs = self.get_data_dirs(uri, unzip_dir=tmp_dir)
train_dirs = [join(d, 'train') for d in data_dirs if isdir(d)]
val_dirs = [join(d, 'valid') for d in data_dirs if isdir(d)]
test_dirs = [join(d, 'test') for d in data_dirs if isdir(d)]
train_dirs = [d for d in train_dirs if isdir(d)]
val_dirs = [d for d in val_dirs if isdir(d)]
test_dirs = [d for d in test_dirs if isdir(d)]
base_transform, aug_transform = self.get_data_transforms()
train_tf = (aug_transform if not overfit_mode else base_transform)
val_tf, test_tf = base_transform, base_transform
train_ds, val_ds, test_ds = self.make_datasets(
train_dirs=train_dirs,
val_dirs=val_dirs,
test_dirs=test_dirs,
train_tf=train_tf,
val_tf=val_tf,
test_tf=test_tf)
return train_ds, val_ds, test_ds
def get_datasets_from_group_uris(
self,
uris: Union[str, List[str]],
tmp_dir: str,
group_train_sz: Optional[int] = None,
group_train_sz_rel: Optional[float] = None,
overfit_mode: bool = False,
test_mode: bool = False,
) -> Tuple[Dataset, Dataset, Dataset]:
train_ds_lst, valid_ds_lst, test_ds_lst = [], [], []
group_sizes = None
if group_train_sz is not None:
group_sizes = group_train_sz
elif group_train_sz_rel is not None:
group_sizes = group_train_sz_rel
if not sequence_like(group_sizes):
group_sizes = [group_sizes] * len(uris)
for uri, size in zip(uris, group_sizes):
train_ds, valid_ds, test_ds = self.get_datasets_from_uri(
uri,
tmp_dir=tmp_dir,
overfit_mode=overfit_mode,
test_mode=test_mode)
if size is not None:
if isinstance(size, float):
train_ds = self.random_subset_dataset(
train_ds, fraction=size)
else:
train_ds = self.random_subset_dataset(train_ds, size=size)
train_ds_lst.append(train_ds)
valid_ds_lst.append(valid_ds)
test_ds_lst.append(test_ds)
train_ds, valid_ds, test_ds = (ConcatDataset(train_ds_lst),
ConcatDataset(valid_ds_lst),
ConcatDataset(test_ds_lst))
return train_ds, valid_ds, test_ds
def get_data_dirs(self, uri: Union[str, List[str]],
unzip_dir: str) -> List[str]:
"""Extract data dirs from uri.
        Data dirs are directories containing "train", "valid", and
        (optionally) "test" subdirectories.
Args:
uri (Union[str, List[str]]): a URI or a list of URIs of one of the
following:
(1) a URI of a directory containing "train", "valid", and
(optinally) "test" subdirectories
(2) a URI of a zip file containing (1)
(3) a list of (2)
(4) a URI of a directory containing zip files
containing (1)
Returns:
paths to directories that each contain contents of one zip file
"""
def is_data_dir(uri: str) -> bool: