Skip to content

Commit

Permalink
Small edits
Browse files Browse the repository at this point in the history
  • Loading branch information
asdataminer committed Jun 15, 2023
1 parent 1662827 commit 4860265
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ludwig/schema/decoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
NUMBER,
SET,
TIMESERIES,
VECTOR
VECTOR,
)
from ludwig.schema import common_fields
from ludwig.schema import utils as schema_utils
Expand Down
4 changes: 2 additions & 2 deletions ludwig/schema/encoders/text_encoders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import MODEL_ECD, MODEL_GBM, TEXT
from ludwig.constants import MODEL_ECD, MODEL_GBM, MODEL_RWD, TEXT
from ludwig.error import ConfigValidationError
from ludwig.schema import utils as schema_utils
from ludwig.schema.encoders.sequence_encoders import SequenceEncoderConfig
Expand Down Expand Up @@ -599,7 +599,7 @@ def module_name():


@DeveloperAPI
@register_encoder_config("bert", TEXT)
@register_encoder_config("bert", TEXT, model_types=[MODEL_RWD])
@ludwig_dataclass
class BERTConfig(HFEncoderConfig):
"""This dataclass configures the schema used for an BERT encoder."""
Expand Down
2 changes: 1 addition & 1 deletion ludwig/schema/features/text_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class RWDTextInputFeatureConfig(TextInputFeatureConfig):
encoder: BaseEncoderConfig = EncoderDataclassField(
MODEL_RWD,
feature_type=TEXT,
default="parallel_cnn",
default="bert",
)


Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_rlhf_reward_model_trainer(tmpdir):
"transcript_column": transcript_column,
}
}
config["model_type"] = "reward_model"
config["model_type"] = "rwd"

# Train Ludwig model with the dataset
ludwig_model = LudwigModel(config, backend=backend)
Expand Down

0 comments on commit 4860265

Please sign in to comment.