Skip to content

Commit

Permalink
Add dicts to initializer field types for combiners (#1476)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksbrar committed Nov 13, 2021
1 parent 8279d22 commit 4a97138
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 13 deletions.
20 changes: 8 additions & 12 deletions ludwig/combiners/combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ class ConcatCombinerConfig:
num_fc_layers: int = schema.NonNegativeInteger(default=0)
fc_size: int = schema.PositiveInteger(default=256)
use_bias: bool = True
weights_initializer: str = schema.InitializerOptions(
default='xavier_uniform')
bias_initializer: str = schema.InitializerOptions(default='zeros')
weights_initializer: Union[str, Dict] = schema.InitializerOrDict(default='xavier_uniform')
bias_initializer: Union[str, Dict] = schema.InitializerOrDict(default='zeros')
norm: Optional[str] = schema.StringOptions(['batch', 'layer'])
norm_params: Optional[dict] = schema.Dict()
activation: str = 'relu'
Expand Down Expand Up @@ -583,9 +582,8 @@ class TransformerCombinerConfig:
num_fc_layers: int = schema.NonNegativeInteger(default=0)
fc_size: int = schema.PositiveInteger(default=256)
use_bias: bool = True
weights_initializer: str = schema.InitializerOptions(
default='xavier_uniform')
bias_initializer: str = schema.InitializerOptions(default='zeros')
weights_initializer: Union[str, Dict] = schema.InitializerOrDict(default='xavier_uniform')
bias_initializer: Union[str, Dict] = schema.InitializerOrDict(default='zeros')
norm: Optional[str] = schema.StringOptions(['batch', 'layer'])
norm_params: Optional[dict] = schema.Dict()
fc_activation: str = 'relu'
Expand Down Expand Up @@ -722,9 +720,8 @@ class TabTransformerCombinerConfig:
num_fc_layers: int = schema.NonNegativeInteger(default=0)
fc_size: int = schema.PositiveInteger(default=256)
use_bias: bool = True
weights_initializer: str = schema.InitializerOptions(
default='xavier_uniform')
bias_initializer: str = schema.InitializerOptions(default='zeros')
weights_initializer: Union[str, Dict] = schema.InitializerOrDict(default='xavier_uniform')
bias_initializer: Union[str, Dict] = schema.InitializerOrDict(default='zeros')
norm: Optional[str] = schema.StringOptions(['batch', 'layer'])
norm_params: Optional[dict] = schema.Dict()
fc_activation: str = 'relu'
Expand Down Expand Up @@ -963,9 +960,8 @@ class ComparatorCombinerConfig:
num_fc_layers: int = schema.NonNegativeInteger(default=1)
fc_size: int = schema.PositiveInteger(default=256)
use_bias: bool = True
weights_initializer: str = schema.InitializerOptions(
default='xavier_uniform')
bias_initializer: str = schema.InitializerOptions(default='zeros')
weights_initializer: Union[str, Dict] = schema.InitializerOrDict(default='xavier_uniform')
bias_initializer: Union[str, Dict] = schema.InitializerOrDict(default='zeros')
norm: Optional[str] = schema.StringOptions(['batch', 'layer'])
norm_params: Optional[dict] = schema.Dict()
activation: str = 'relu'
Expand Down
45 changes: 45 additions & 0 deletions ludwig/utils/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import marshmallow_dataclass
from marshmallow import fields, validate, ValidationError
from torch.nn import init

from ludwig.utils.torch_utils import initializer_registry
from ludwig.modules.reduction_modules import reduce_mode_registry
Expand Down Expand Up @@ -90,6 +91,12 @@ def Embed():
_embed_options = ['add']


def InitializerOrDict(default='xavier_uniform'):
    """Create a dataclass field that accepts either an initializer name or a dict.

    A string value must name a registered initializer; a dict value must carry
    a valid 'type' key (validation is delegated to
    InitializerOptionsOrCustomDictField).
    """
    custom_field = InitializerOptionsOrCustomDictField(allow_none=False)
    return field(default=default, metadata={'marshmallow_field': custom_field})


class EmbedInputFeatureNameField(fields.Field):
def _deserialize(self, value, attr, data, **kwargs):
if value is None:
Expand All @@ -116,6 +123,44 @@ def _jsonschema_type_mapping(self):
]
}

class InitializerOptionsOrCustomDictField(fields.Field):
    """Marshmallow field accepting either a registered initializer name (str)
    or a dict whose 'type' key names a registered initializer; any extra dict
    keys are passed through as initializer parameters.
    """

    def _deserialize(self, value, attr, data, **kwargs):
        initializers = list(initializer_registry.keys())
        if isinstance(value, str):
            if value not in initializers:
                raise ValidationError(
                    f"Expected one of: {initializers}, found: {value}"
                )
            return value

        if isinstance(value, dict):
            # A custom-initializer dict must at least name which initializer
            # to construct; other keys are forwarded untouched.
            if 'type' not in value:
                raise ValidationError("Dict must contain 'type'")
            if value['type'] not in initializers:
                raise ValidationError(
                    f"Dict expected key 'type' to be one of: "
                    f"{initializers}, found: {value}"
                )
            return value

        raise ValidationError('Field should be str or dict')

    def _jsonschema_type_mapping(self):
        initializers = list(initializer_registry.keys())
        return {
            'oneOf': [
                # Either a bare initializer name...
                {'type': 'string', 'enum': initializers},
                # ...or an object whose 'type' names the initializer;
                # additional properties are allowed so initializer kwargs
                # (e.g. 'stddev') validate.
                {
                    'type': 'object',
                    'properties': {
                        'type': {'type': 'string', 'enum': initializers},
                    },
                    'required': ['type'],
                    'additionalProperties': True,
                },
            ]
        }

def load_config(cls, **kwargs):
schema = marshmallow_dataclass.class_schema(cls)()
Expand Down
18 changes: 17 additions & 1 deletion tests/ludwig/utils/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,30 @@ def test_config_bad_combiner_types_enums():
config['combiner']['weights_initializer'] = 'fail'
with pytest.raises(ValidationError, match=r"'fail' is not of*"):
validate_config(config)

config['combiner']['weights_initializer'] = {}
with pytest.raises(ValidationError, match=r"Failed validating 'type'"):
validate_config(config)
config['combiner']['weights_initializer'] = {'type':'fail'}
with pytest.raises(ValidationError, match=r"'fail' is not one of*"):
validate_config(config)
config['combiner']['weights_initializer'] = {'type':'normal', 'stddev': 0}
validate_config(config)

# Test bias initializer:
del config['combiner']['weights_initializer']
config['combiner']['bias_initializer'] = 'kaiming_uniform'
validate_config(config)
config['combiner']['bias_initializer'] = 'fail'
with pytest.raises(ValidationError, match=r"'fail' is not of*"):
validate_config(config)
config['combiner']['bias_initializer'] = {}
with pytest.raises(ValidationError, match=r"Failed validating 'type'"):
validate_config(config)
config['combiner']['bias_initializer'] = {'type':'fail'}
with pytest.raises(ValidationError, match=r"'fail' is not one of*"):
validate_config(config)
config['combiner']['bias_initializer'] = {'type':'zeros', 'stddev': 0}
validate_config(config)

# Test norm:
del config['combiner']['bias_initializer']
Expand Down

0 comments on commit 4a97138

Please sign in to comment.