refactor: Rename, reorganize schema module #1963

Merged 35 commits on May 10, 2022. The diff below shows changes from 24 of the 35 commits.

Commits
b7ce8c4  Remove generated files. (ksbrar, Apr 24, 2022)
6424655  Remove extraction script. (ksbrar, Apr 24, 2022)
74a4d32  Adjust marshmallow_schema_utils. (ksbrar, Apr 24, 2022)
3e166d4  change manifest. (ksbrar, Apr 24, 2022)
bfb2e61  adjust combiners, rename unload method. (ksbrar, Apr 24, 2022)
5ee4eca  adjust trainer (ksbrar, Apr 24, 2022)
664b4c7  convert optimizers. (ksbrar, Apr 25, 2022)
765c783  update FloatRange allow_none (ksbrar, Apr 25, 2022)
11a2857  Add descriptions to optimizer, gradientclipping dataclass fields. (ksbrar, Apr 25, 2022)
4e8a232  update marshmallow tests. (ksbrar, Apr 25, 2022)
2e57cb2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 25, 2022)
b266763  fix (ksbrar, Apr 25, 2022)
39dfdc0  Merge branch 'refactor_remove_marshmallow_extraction' of github.com:k… (ksbrar, Apr 25, 2022)
f7601e7  remove old method refs. (ksbrar, Apr 25, 2022)
f3df813  fix (ksbrar, Apr 25, 2022)
a3406f4  replace default description None with TODO (ksbrar, Apr 25, 2022)
a927865  fix test (ksbrar, Apr 25, 2022)
1affa5b  additionalProperties fix (ksbrar, Apr 25, 2022)
7742720  Rename marshamllow_schema_utils file -> utils (ksbrar, Apr 26, 2022)
d8863b1  rename ludwig.marshmallow folder -> ludwig.schema (ksbrar, Apr 26, 2022)
26008fd  Move ludwig/utils/schema file to new ludwig/schema folder. (ksbrar, Apr 26, 2022)
b3e93dd  style: clean up some import aliases. (ksbrar, Apr 26, 2022)
dd5e36c  Merge in latest. (ksbrar, Apr 26, 2022)
1800687  integration test fix. (ksbrar, Apr 26, 2022)
9279905  add header comments. (ksbrar, Apr 26, 2022)
31b1b58  add header comments. (ksbrar, Apr 26, 2022)
77ea23a  add header comments. (ksbrar, Apr 26, 2022)
09e06e0  Rename import ... utils -> import ... utils as schema_utils. (ksbrar, Apr 27, 2022)
9fb1f48  fix replace-all, move schema.py into schema/__init__.py (ksbrar, Apr 27, 2022)
d9eb45d  Move trainer and combiner schemas over. (still need to fix tests, may… (ksbrar, Apr 28, 2022)
f6dd4e3  tmp - fix tests, move sequence_encoder_registry, fix some imports. (ksbrar, Apr 28, 2022)
4c0a2fc  more import fixes. (ksbrar, Apr 28, 2022)
b6f48af  fix import. (ksbrar, Apr 28, 2022)
3197e31  Merge in latest. (ksbrar, May 9, 2022)
de7769e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 9, 2022)
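Taken together, the commits consolidate Ludwig's scattered marshmallow schema code into a single ludwig.schema package. A summary of the moves, reconstructed from the commit messages and the diffs below:

    ludwig/marshmallow/marshmallow_schema_utils.py -> ludwig/schema/utils.py
    ludwig/utils/schema.py                         -> ludwig/schema/schema.py
    (schema.py is later folded into ludwig/schema/__init__.py in commit 9fb1f48)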
4 changes: 2 additions & 2 deletions ludwig/api.py
@@ -61,7 +61,6 @@
     set_disable_progressbar,
     TRAIN_SET_METADATA_FILE_NAME,
 )
-from ludwig.marshmallow.marshmallow_schema_utils import load_config_with_kwargs
 from ludwig.models.ecd import ECD
 from ludwig.models.inference import InferenceModule
 from ludwig.models.predictor import (
@@ -72,6 +71,8 @@
 )
 from ludwig.models.trainer import Trainer
 from ludwig.modules.metric_modules import get_best_function
+from ludwig.schema.schema import validate_config
+from ludwig.schema.utils import load_config_with_kwargs
 from ludwig.utils import metric_utils
 from ludwig.utils.data_utils import (
     figure_data_format,
@@ -85,7 +86,6 @@
 from ludwig.utils.fs_utils import makedirs, open_file, path_exists, upload_output_directory
 from ludwig.utils.misc_utils import get_file_names, get_output_directory
 from ludwig.utils.print_utils import print_boxed
-from ludwig.utils.schema import validate_config

 logger = logging.getLogger(__name__)
128 changes: 63 additions & 65 deletions ludwig/combiners/combiners.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ludwig/models/ecd.py
@@ -12,7 +12,7 @@
 from ludwig.features.base_feature import InputFeature, OutputFeature
 from ludwig.features.feature_registries import input_type_registry, output_type_registry
 from ludwig.features.feature_utils import LudwigFeatureDict
-from ludwig.marshmallow.marshmallow_schema_utils import load_config_with_kwargs
+from ludwig.schema.utils import load_config_with_kwargs
 from ludwig.utils import output_feature_utils
 from ludwig.utils.algorithms_utils import topological_sort_feature_dependencies
 from ludwig.utils.data_utils import clear_data_cache
66 changes: 33 additions & 33 deletions ludwig/models/trainer.py
@@ -35,7 +35,6 @@
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

-import ludwig.marshmallow.marshmallow_schema_utils as schema
 from ludwig.constants import COMBINED, LOSS, TEST, TRAINING, VALIDATION
 from ludwig.data.dataset.base import Dataset
 from ludwig.globals import (
@@ -56,6 +55,7 @@
     GradientClippingDataclassField,
     OptimizerDataclassField,
 )
+from ludwig.schema import utils
 from ludwig.utils import time_utils
 from ludwig.utils.checkpoint_utils import Checkpoint, CheckpointManager
 from ludwig.utils.defaults import default_random_seed
@@ -103,32 +103,32 @@ def __exit__(self, exc_type, exc_val, exc_tb):


 def get_trainer_jsonschema():
-    return schema.unload_jsonschema_from_marshmallow_class(TrainerConfig)
+    return utils.unload_jsonschema_from_marshmallow_class(TrainerConfig)


 @dataclass
-class TrainerConfig(schema.BaseMarshmallowConfig):
+class TrainerConfig(utils.BaseMarshmallowConfig):
     """TrainerConfig is a dataclass that configures most of the hyperparameters used for model training."""

     optimizer: BaseOptimizerConfig = OptimizerDataclassField(
         default={"type": "adam"}, description="Parameter values for selected torch optimizer."
     )

-    epochs: int = schema.PositiveInteger(
+    epochs: int = utils.PositiveInteger(
         default=100, description="Number of epochs the algorithm is intended to be run over."
     )

-    regularization_lambda: float = schema.FloatRange(
+    regularization_lambda: float = utils.FloatRange(
         default=0.0, min=0, description="Strength of the $L2$ regularization."
     )

-    regularization_type: Optional[str] = schema.RegularizerOptions(default="l2", description="Type of regularization.")
+    regularization_type: Optional[str] = utils.RegularizerOptions(default="l2", description="Type of regularization.")

-    should_shuffle: bool = schema.Boolean(
+    should_shuffle: bool = utils.Boolean(
         default=True, description="Whether to shuffle batches during training when true."
     )

-    learning_rate: float = schema.NumericOrStringOptionsField(
+    learning_rate: float = utils.NumericOrStringOptionsField(
         default=0.001,
         min=0.0,
         max=1.0,
@@ -142,7 +142,7 @@ class TrainerConfig(schema.BaseMarshmallowConfig):
         ),
     )

-    batch_size: Union[int, str] = schema.IntegerOrStringOptionsField(
+    batch_size: Union[int, str] = utils.IntegerOrStringOptionsField(
         default=128,
         options=["auto"],
         default_numeric=128,
@@ -152,7 +152,7 @@ class TrainerConfig(schema.BaseMarshmallowConfig):
         description="Size of batch to pass to the model for training.",
     )

-    eval_batch_size: Union[None, int, str] = schema.IntegerOrStringOptionsField(
+    eval_batch_size: Union[None, int, str] = utils.IntegerOrStringOptionsField(
         default=None,
         options=["auto"],
         default_numeric=None,
@@ -162,7 +162,7 @@ class TrainerConfig(schema.BaseMarshmallowConfig):
         description="Size of batch to pass to the model for evaluation.",
     )

-    early_stop: int = schema.IntegerRange(
+    early_stop: int = utils.IntegerRange(
         default=5,
         min=-1,
         description=(
@@ -171,27 +171,27 @@ class TrainerConfig(schema.BaseMarshmallowConfig):
         ),
     )

-    steps_per_checkpoint: int = schema.NonNegativeInteger(
+    steps_per_checkpoint: int = utils.NonNegativeInteger(
         default=0,
         description=(
             "How often the model is checkpointed. Also dictates maximum evaluation frequency. If 0 the model is "
             "checkpointed after every epoch."
         ),
     )

-    checkpoints_per_epoch: int = schema.NonNegativeInteger(
+    checkpoints_per_epoch: int = utils.NonNegativeInteger(
         default=0,
         description=(
             "Number of checkpoints per epoch. For example, 2 -> checkpoints are written every half of an epoch. Note "
             "that it is invalid to specify both non-zero `steps_per_checkpoint` and non-zero `checkpoints_per_epoch`."
         ),
     )

-    evaluate_training_set: bool = schema.Boolean(
+    evaluate_training_set: bool = utils.Boolean(
         default=True, description="Whether to include the entire training set during evaluation."
     )

-    reduce_learning_rate_on_plateau: float = schema.FloatRange(
+    reduce_learning_rate_on_plateau: float = utils.FloatRange(
         default=0.0,
         min=0.0,
         max=1.0,
@@ -201,65 +201,65 @@ class TrainerConfig(schema.BaseMarshmallowConfig):
         ),
     )

-    reduce_learning_rate_on_plateau_patience: int = schema.NonNegativeInteger(
+    reduce_learning_rate_on_plateau_patience: int = utils.NonNegativeInteger(
         default=5, description="How many epochs have to pass before the learning rate reduces."
     )

-    reduce_learning_rate_on_plateau_rate: float = schema.FloatRange(
+    reduce_learning_rate_on_plateau_rate: float = utils.FloatRange(
         default=0.5, min=0.0, max=1.0, description="Rate at which we reduce the learning rate."
     )

-    reduce_learning_rate_eval_metric: str = schema.String(default=LOSS, description="TODO: Document parameters.")
+    reduce_learning_rate_eval_metric: str = utils.String(default=LOSS, description="TODO: Document parameters.")

-    reduce_learning_rate_eval_split: str = schema.String(default=TRAINING, description="TODO: Document parameters.")
+    reduce_learning_rate_eval_split: str = utils.String(default=TRAINING, description="TODO: Document parameters.")

-    increase_batch_size_on_plateau: int = schema.NonNegativeInteger(
+    increase_batch_size_on_plateau: int = utils.NonNegativeInteger(
         default=0, description="Number to increase the batch size by on a plateau."
     )

-    increase_batch_size_on_plateau_patience: int = schema.NonNegativeInteger(
+    increase_batch_size_on_plateau_patience: int = utils.NonNegativeInteger(
         default=5, description="How many epochs to wait for before increasing the batch size."
     )

-    increase_batch_size_on_plateau_rate: float = schema.NonNegativeFloat(
+    increase_batch_size_on_plateau_rate: float = utils.NonNegativeFloat(
         default=2.0, description="Rate at which the batch size increases."
     )

-    increase_batch_size_on_plateau_max: int = schema.PositiveInteger(
+    increase_batch_size_on_plateau_max: int = utils.PositiveInteger(
         default=512, description="Maximum size of the batch."
     )

-    increase_batch_size_eval_metric: str = schema.String(default=LOSS, description="TODO: Document parameters.")
+    increase_batch_size_eval_metric: str = utils.String(default=LOSS, description="TODO: Document parameters.")

-    increase_batch_size_eval_split: str = schema.String(default=TRAINING, description="TODO: Document parameters.")
+    increase_batch_size_eval_split: str = utils.String(default=TRAINING, description="TODO: Document parameters.")

-    decay: bool = schema.Boolean(default=False, description="Turn on exponential decay of the learning rate.")
+    decay: bool = utils.Boolean(default=False, description="Turn on exponential decay of the learning rate.")

-    decay_steps: int = schema.PositiveInteger(default=10000, description="TODO: Document parameters.")
+    decay_steps: int = utils.PositiveInteger(default=10000, description="TODO: Document parameters.")

-    decay_rate: float = schema.FloatRange(default=0.96, min=0.0, max=1.0, description="TODO: Document parameters.")
+    decay_rate: float = utils.FloatRange(default=0.96, min=0.0, max=1.0, description="TODO: Document parameters.")

-    staircase: bool = schema.Boolean(default=False, description="Decays the learning rate at discrete intervals.")
+    staircase: bool = utils.Boolean(default=False, description="Decays the learning rate at discrete intervals.")

     gradient_clipping: Optional[GradientClippingConfig] = GradientClippingDataclassField(
         description="Parameter values for gradient clipping."
     )

     # TODO(#1673): Need some more logic here for validating against output features
-    validation_field: str = schema.String(
+    validation_field: str = utils.String(
         default=COMBINED,
         description="First output feature, by default it is set as the same field of the first output feature.",
     )

-    validation_metric: str = schema.String(
+    validation_metric: str = utils.String(
         default=LOSS, description="Metric used on `validation_field`, set by default to accuracy."
     )

-    learning_rate_warmup_epochs: float = schema.NonNegativeFloat(
+    learning_rate_warmup_epochs: float = utils.NonNegativeFloat(
         default=1.0, description="Number of epochs to warmup the learning rate for."
     )

-    learning_rate_scaling: str = schema.StringOptions(
+    learning_rate_scaling: str = utils.StringOptions(
         ["constant", "sqrt", "linear"],
         default="linear",
         description=(
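Since TrainerConfig is a marshmallow dataclass, the rename is behavior-preserving: fields still validate on load and fall back to their declared defaults. A minimal usage sketch, modeled on the marshmallow tests further down in this PR; the config values here are illustrative, not from the PR:

    from ludwig.models.trainer import TrainerConfig

    # Load a partial config; unspecified fields keep their declared defaults.
    trainer_config = TrainerConfig.Schema().load({"epochs": 50, "batch_size": "auto"})
    assert trainer_config.epochs == 50
    assert trainer_config.learning_rate == 0.001  # default preserved

    # Out-of-range values are rejected by the field validators, e.g.:
    # TrainerConfig.Schema().load({"epochs": -1})  # raises marshmallow ValidationError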
2 changes: 1 addition & 1 deletion ludwig/modules/optimization_modules.py
@@ -20,7 +20,7 @@
 from marshmallow import fields, ValidationError
 from marshmallow_dataclass import dataclass

-from ludwig.marshmallow.marshmallow_schema_utils import (
+from ludwig.schema.utils import (
     BaseMarshmallowConfig,
     Boolean,
     create_cond,
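The helpers imported here are the same ones used by TrainerConfig above to define new config dataclasses. A hypothetical sketch of that pattern (the class and field below are illustrative, not part of the PR):

    from marshmallow_dataclass import dataclass

    from ludwig.schema import utils

    @dataclass
    class ExampleOptimizerConfig(utils.BaseMarshmallowConfig):
        # FloatRange attaches min/max validation to a plain dataclass field.
        lr: float = utils.FloatRange(default=0.01, min=0.0, max=1.0, description="Illustrative learning rate field.")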
File renamed without changes.
2 changes: 1 addition & 1 deletion ludwig/utils/schema.py → ludwig/schema/schema.py
@@ -21,8 +21,8 @@
 from ludwig.decoders.registry import get_decoder_classes
 from ludwig.encoders.registry import get_encoder_classes
 from ludwig.features.feature_registries import input_type_registry, output_type_registry
-from ludwig.marshmallow.marshmallow_schema_utils import create_cond
 from ludwig.models.trainer import get_trainer_jsonschema
+from ludwig.schema.utils import create_cond


 def get_schema():
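validate_config keeps its behavior under the new path; only the import changes. A minimal sketch, assuming the usual jsonschema-backed validation (the example config is illustrative, not from the PR):

    from ludwig.schema.schema import validate_config

    config = {
        "input_features": [{"name": "text_col", "type": "text"}],
        "output_features": [{"name": "label_col", "type": "category"}],
    }
    validate_config(config)  # raises a validation error on schema violations, otherwise returns None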
File renamed without changes.
4 changes: 3 additions & 1 deletion tests/integration_tests/utils.py
@@ -647,7 +647,9 @@ def filter(stats):
             assert k1 == k2
             for (name1, metric1), (name2, metric2) in zip(v1.items(), v2.items()):
                 assert name1 == name2
-                assert np.isclose(metric1, metric2, rtol=1e-04), f"metric {name1}: {metric1} != {metric2}"
+                assert np.isclose(
+                    metric1, metric2, rtol=1e-04, atol=1e-5
+                ), f"metric {name1}: {metric1} != {metric2}"

         return model
     finally:
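The added atol=1e-5 matters for metrics near zero: np.isclose checks |a - b| <= atol + rtol * |b|, so a purely relative tolerance shrinks to nothing as the values approach zero. A small illustration:

    import numpy as np

    # Default atol is 1e-8, so the tolerance is ~1e-8 + 1e-4 * 2e-7, far below |a - b| = 1e-7.
    np.isclose(1e-7, 2e-7, rtol=1e-4)             # False
    np.isclose(1e-7, 2e-7, rtol=1e-4, atol=1e-5)  # True: the absolute floor absorbs the gap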
2 changes: 1 addition & 1 deletion tests/ludwig/combiners/test_combiners.py
@@ -23,7 +23,7 @@
     TransformerCombiner,
     TransformerCombinerConfig,
 )
-from ludwig.marshmallow.marshmallow_schema_utils import load_config
+from ludwig.schema.utils import load_config

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
36 changes: 18 additions & 18 deletions tests/ludwig/marshmallow/test_fields_misc.py
@@ -4,7 +4,7 @@
 from marshmallow.exceptions import ValidationError as MarshmallowValidationError
 from marshmallow_dataclass import dataclass

-import ludwig.marshmallow.marshmallow_schema_utils as lusutils
+from ludwig.schema import utils


 def get_marshmallow_from_dataclass_field(dfield):
@@ -19,14 +19,14 @@ def test_StringOptions():
     # Test case of default conflicting with allowed options:
     test_options = ["one"]
     with pytest.raises(MarshmallowValidationError):
-        lusutils.StringOptions(test_options, default=None, nullable=False)
+        utils.StringOptions(test_options, default=None, nullable=False)

     # Test creating a schema with simple option, null not allowed:
     test_options = ["one"]

     @dataclass
-    class CustomTestSchema(lusutils.BaseMarshmallowConfig):
-        foo: str = lusutils.StringOptions(test_options, "one", nullable=False)
+    class CustomTestSchema(utils.BaseMarshmallowConfig):
+        foo: str = utils.StringOptions(test_options, "one", nullable=False)

     with pytest.raises(MarshmallowValidationError):
         CustomTestSchema.Schema().load({"foo": None})
@@ -37,14 +37,14 @@ class CustomTestSchema(lusutils.BaseMarshmallowConfig):

 def test_Embed():
     # Test metadata matches expected defaults after field creation (null allowed):
-    default_embed = get_marshmallow_from_dataclass_field(lusutils.Embed())
+    default_embed = get_marshmallow_from_dataclass_field(utils.Embed())
     assert default_embed.default is None
     assert default_embed.allow_none is True

     # Test simple schema creation:
     @dataclass
-    class CustomTestSchema(lusutils.BaseMarshmallowConfig):
-        foo: Union[None, str, int] = lusutils.Embed()
+    class CustomTestSchema(utils.BaseMarshmallowConfig):
+        foo: Union[None, str, int] = utils.Embed()

     # Test null/empty loading cases:
     assert CustomTestSchema.Schema().load({}).foo is None
@@ -63,20 +63,20 @@ class CustomTestSchema(lusutils.BaseMarshmallowConfig):

 def test_InitializerOrDict():
     # Test metadata matches expected defaults after field creation (null allowed):
-    default_initializerordict = get_marshmallow_from_dataclass_field(lusutils.InitializerOrDict())
+    default_initializerordict = get_marshmallow_from_dataclass_field(utils.InitializerOrDict())
     assert default_initializerordict.default == "xavier_uniform"

-    initializerordict = get_marshmallow_from_dataclass_field(lusutils.InitializerOrDict("zeros"))
+    initializerordict = get_marshmallow_from_dataclass_field(utils.InitializerOrDict("zeros"))
     assert initializerordict.default == "zeros"

     # Test default value validation:
     with pytest.raises(MarshmallowValidationError):
-        lusutils.InitializerOrDict("test")
+        utils.InitializerOrDict("test")

     # Test simple schema creation:
     @dataclass
-    class CustomTestSchema(lusutils.BaseMarshmallowConfig):
-        foo: Union[None, str, Dict] = lusutils.InitializerOrDict()
+    class CustomTestSchema(utils.BaseMarshmallowConfig):
+        foo: Union[None, str, Dict] = utils.InitializerOrDict()

     # Test invalid non-dict loads:
     with pytest.raises(MarshmallowValidationError):
@@ -103,17 +103,17 @@ class CustomTestSchema(lusutils.BaseMarshmallowConfig):

 def test_FloatRangeTupleDataclassField():
     # Test metadata matches expected defaults after field creation (null not allowed):
-    default_floatrange_tuple = get_marshmallow_from_dataclass_field(lusutils.FloatRangeTupleDataclassField())
+    default_floatrange_tuple = get_marshmallow_from_dataclass_field(utils.FloatRangeTupleDataclassField())
     assert default_floatrange_tuple.default == (0.9, 0.999)

     # Test dimensional mismatch:
     with pytest.raises(MarshmallowValidationError):
-        lusutils.FloatRangeTupleDataclassField(N=3, default=(1, 1))
+        utils.FloatRangeTupleDataclassField(N=3, default=(1, 1))

     # Test default schema creation:
     @dataclass
-    class CustomTestSchema(lusutils.BaseMarshmallowConfig):
-        foo: Tuple[float, float] = lusutils.FloatRangeTupleDataclassField()
+    class CustomTestSchema(utils.BaseMarshmallowConfig):
+        foo: Tuple[float, float] = utils.FloatRangeTupleDataclassField()

     # Test empty load:
     assert CustomTestSchema.Schema().load({}).foo == (0.9, 0.999)
@@ -128,8 +128,8 @@ class CustomTestSchema(lusutils.BaseMarshmallowConfig):

     # Test non-default schema (N=3, other custom metadata):
     @dataclass
-    class CustomTestSchema(lusutils.BaseMarshmallowConfig):
-        foo: Tuple[float, float] = lusutils.FloatRangeTupleDataclassField(N=3, default=(1, 1, 1), min=-10, max=10)
+    class CustomTestSchema(utils.BaseMarshmallowConfig):
+        foo: Tuple[float, float] = utils.FloatRangeTupleDataclassField(N=3, default=(1, 1, 1), min=-10, max=10)

     assert CustomTestSchema.Schema().load({}).foo == (1, 1, 1)
     assert CustomTestSchema.Schema().load({"foo": [2, 2, 2]}).foo == (2, 2, 2)