Skip to content

Commit

Permalink
int: Rename original combiner_registry to `combiner_config_registry…
Browse files Browse the repository at this point in the history
…`, update decorator name (#3516)
  • Loading branch information
ksbrar authored and arnavgarg1 committed Aug 11, 2023
1 parent eb57534 commit 20f34e9
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 27 deletions.
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("comparator")
@register_combiner_config("comparator")
@ludwig_dataclass
class ComparatorCombinerConfig(BaseCombinerConfig):
"""Parameters for comparator combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("concat")
@register_combiner_config("concat")
@ludwig_dataclass
class ConcatCombinerConfig(BaseCombinerConfig):
"""Parameters for concat combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/project_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("project_aggregate")
@register_combiner_config("project_aggregate")
@ludwig_dataclass
class ProjectAggregateCombinerConfig(BaseCombinerConfig):
type: str = schema_utils.ProtectedString(
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.sequence_concat import MAIN_SEQUENCE_FEATURE_DESCRIPTION
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.encoders.base import BaseEncoderConfig
from ludwig.schema.encoders.utils import EncoderDataclassField
from ludwig.schema.metadata import COMBINER_METADATA
Expand All @@ -19,7 +19,7 @@


@DeveloperAPI
@register_combiner("sequence")
@register_combiner_config("sequence")
@ludwig_dataclass
class SequenceCombinerConfig(BaseCombinerConfig):
"""Parameters for sequence combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/sequence_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass

Expand All @@ -19,7 +19,7 @@


@DeveloperAPI
@register_combiner("sequence_concat")
@register_combiner_config("sequence_concat")
@ludwig_dataclass
class SequenceConcatCombinerConfig(BaseCombinerConfig):
"""Parameters for sequence concat combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/tab_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.common_transformer_options import CommonTransformerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("tabtransformer")
@register_combiner_config("tabtransformer")
@ludwig_dataclass
class TabTransformerCombinerConfig(BaseCombinerConfig, CommonTransformerConfig):
"""Parameters for tab transformer combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/tabnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from ludwig.api_annotations import DeveloperAPI
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("tabnet")
@register_combiner_config("tabnet")
@ludwig_dataclass
class TabNetCombinerConfig(BaseCombinerConfig):
"""Parameters for tabnet combiner."""
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/combiners/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.common_transformer_options import CommonTransformerConfig
from ludwig.schema.combiners.utils import register_combiner
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.metadata import COMBINER_METADATA
from ludwig.schema.utils import ludwig_dataclass


@DeveloperAPI
@register_combiner("transformer")
@register_combiner_config("transformer")
@ludwig_dataclass
class TransformerCombinerConfig(BaseCombinerConfig, CommonTransformerConfig):
"""Parameters for transformer combiner."""
Expand Down
18 changes: 9 additions & 9 deletions ludwig/schema/combiners/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,28 @@
DEFAULT_VALUE = "concat"
DESCRIPTION = "Select the combiner type."

combiner_registry = Registry[Type[BaseCombinerConfig]]()
combiner_config_registry = Registry[Type[BaseCombinerConfig]]()


@DeveloperAPI
def register_combiner(name: str):
def register_combiner_config(name: str):
def wrap(cls: Type[BaseCombinerConfig]):
combiner_registry[name] = cls
combiner_config_registry[name] = cls
return cls

return wrap


@DeveloperAPI
def get_combiner_registry():
return combiner_registry
return combiner_config_registry


@DeveloperAPI
def get_combiner_jsonschema():
"""Returns a JSON schema structured to only require a `type` key and then conditionally apply a corresponding
combiner's field constraints."""
combiner_types = sorted(list(combiner_registry.keys()))
combiner_types = sorted(list(combiner_config_registry.keys()))
parameter_metadata = convert_metadata_to_json(
ParameterMetadata.from_dict(
{
Expand Down Expand Up @@ -72,17 +72,17 @@ def get_combiner_descriptions():
Returns:
dict: A dictionary of combiner descriptions.
"""
return {k: convert_metadata_to_json(v[TYPE]) for k, v in COMBINER_METADATA.items() if k in combiner_registry}
return {k: convert_metadata_to_json(v[TYPE]) for k, v in COMBINER_METADATA.items() if k in combiner_config_registry}


@DeveloperAPI
def get_combiner_conds() -> List[Dict[str, Any]]:
"""Returns a list of if-then JSON clauses for each combiner type in `combiner_registry` and its properties'
constraints."""
combiner_types = sorted(list(combiner_registry.keys()))
combiner_types = sorted(list(combiner_config_registry.keys()))
conds = []
for combiner_type in combiner_types:
combiner_cls = combiner_registry[combiner_type]
combiner_cls = combiner_config_registry[combiner_type]
schema_cls = combiner_cls
combiner_schema = schema_utils.unload_jsonschema_from_marshmallow_class(schema_cls)
combiner_props = combiner_schema["properties"]
Expand All @@ -97,7 +97,7 @@ def __init__(self):
# For registration of all combiners
import ludwig.combiners.combiners # noqa

super().__init__(registry=combiner_registry, default_value=DEFAULT_VALUE, description=DESCRIPTION)
super().__init__(registry=combiner_config_registry, default_value=DEFAULT_VALUE, description=DESCRIPTION)

def get_schema_from_registry(self, key: str) -> Type[schema_utils.BaseMarshmallowConfig]:
return self.registry[key]
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/test_custom_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ludwig.modules.metric_modules import LossMetric, register_metric
from ludwig.schema import utils as schema_utils
from ludwig.schema.combiners.base import BaseCombinerConfig
from ludwig.schema.combiners.utils import register_combiner as register_combiner_schema
from ludwig.schema.combiners.utils import register_combiner_config
from ludwig.schema.decoders.base import BaseDecoderConfig
from ludwig.schema.decoders.utils import register_decoder_config
from ludwig.schema.encoders.base import BaseEncoderConfig
Expand Down Expand Up @@ -55,7 +55,7 @@ class CustomLossConfig(BaseLossConfig):
type: str = "custom_loss"


@register_combiner_schema("custom_combiner")
@register_combiner_config("custom_combiner")
@dataclass
class CustomTestCombinerConfig(BaseCombinerConfig):
type: str = "custom_combiner"
Expand Down

0 comments on commit 20f34e9

Please sign in to comment.