-
Notifications
You must be signed in to change notification settings - Fork 1
/
__init__.py
93 lines (78 loc) · 2.92 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import tensorflow as tf
from chia import components
from chia.components.base_models.keras import (
keras_basemodel,
keras_dataaugmentation,
keras_featureextractor,
keras_learningrateschedule,
keras_preprocessor,
keras_trainer,
)
def _get_input_img_np(sample):
    """Return the resource stored under ``"input_img_np"`` on *sample*.

    The lookup is delegated to ``sample.get_resource``; whatever that call
    returns (or raises) is passed straight through.
    """
    resource_key = "input_img_np"
    return sample.get_resource(resource_key)
class KerasOptimizerFactory(components.Factory):
    """Factory for the Keras optimizers that can be selected by name.

    Known names (keys of ``name_to_class_mapping``):

    * ``"adam"`` -> ``tf.keras.optimizers.Adam``
    * ``"sgd"`` -> ``tf.keras.optimizers.SGD``
    """

    # Optimizer name -> tf.keras optimizer class it resolves to.
    name_to_class_mapping = dict(
        adam=tf.keras.optimizers.Adam,
        sgd=tf.keras.optimizers.SGD,
    )
    # Section name handed to components.Factory; presumably the config section
    # used when none is given explicitly — confirm against components.Factory.
    default_section = "keras_optimizer"
    # Flag consumed by components.Factory; presumably acknowledges that *args
    # are not forwarded to the created optimizer — TODO confirm.
    i_know_that_var_args_are_not_supported = True
class KerasBaseModelContainer:
    """Build and hold the Keras components that make up a base model.

    Each component is created through its factory from a section of
    ``config``, in dependency order. Every component is stored as an attribute:

    1. learning rate schedule and optimizer;
    2. data augmentation;
    3. preprocessor, which receives the augmentation;
    4. feature extractor, which receives the preprocessor;
    5. trainer;
    6. the base model that ties them together.

    Args:
        config: Mapping with the required sections ``"learning_rate_schedule"``,
            ``"optimizer"``, ``"feature_extractor"`` and ``"trainer"``. The
            ``"trainer"`` section must contain ``"batch_size"``. The sections
            ``"augmentation"`` and ``"preprocessor"`` are optional; an empty
            dict is passed to the factory when either is missing.
        classifier: Classifier handed to both the trainer and the base model.
        observers: Observers forwarded to every factory call.

    Raises:
        KeyError: If a required section (or ``config["trainer"]["batch_size"]``)
            is missing.
    """

    def __init__(self, config, classifier, observers=()):
        self.learning_rate_schedule = (
            keras_learningrateschedule.KerasLearningRateScheduleFactory.create(
                config["learning_rate_schedule"], observers=observers
            )
        )
        self.optimizer = KerasOptimizerFactory.create(
            config["optimizer"], observers=observers
        )

        self.augmentation = keras_dataaugmentation.KerasDataAugmentationFactory.create(
            self._optional_section(config, "augmentation"),
            observers=observers,
        )
        self.preprocessor = keras_preprocessor.KerasPreprocessorFactory.create(
            self._optional_section(config, "preprocessor"),
            observers=observers,
            augmentation=self.augmentation,
        )

        # Create the feature extractor exactly once. Building a second one with
        # identical arguments would only repeat the construction work and throw
        # the first instance away.
        self.feature_extractor = (
            keras_featureextractor.KerasFeatureExtractorFactory.create(
                config["feature_extractor"],
                observers=observers,
                preprocessor=self.preprocessor,
            )
        )

        self.trainer = keras_trainer.KerasTrainerFactory.create(
            config["trainer"],
            observers=observers,
            feature_extractor=self.feature_extractor,
            preprocessor=self.preprocessor,
            classifier=classifier,
            learning_rate_schedule=self.learning_rate_schedule,
            optimizer=self.optimizer,
        )

        self.base_model = keras_basemodel.KerasBaseModelFactory.create(
            # An empty config is passed instead of ``config``: ``config`` has a
            # "trainer" field, which would conflict with the ``trainer``
            # keyword argument given below.
            dict(),
            observers=observers,
            batch_size=config["trainer"]["batch_size"],
            classifier=classifier,
            feature_extractor=self.feature_extractor,
            trainer=self.trainer,
            preprocessor=self.preprocessor,
        )

    @staticmethod
    def _optional_section(config, key):
        """Return ``config[key]``, or an empty dict if *key* is absent."""
        try:
            return config[key]
        except KeyError:
            return dict()