-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: transform onehot
encoder outputs to float32 tensor
#3242
Conversation
def test_onehot_category_encoder(): | ||
config = { | ||
"defaults": {"category": {"encoder": {"type": "onehot"}}}, | ||
"input_features": [ | ||
{"name": "MSSubClass", "type": "category"}, | ||
{"name": "MSZoning", "type": "category"}, | ||
{"name": "Street", "type": "category"}, | ||
{"name": "Neighborhood", "type": "category"}, | ||
], | ||
"model_type": "ecd", | ||
"output_features": [{"name": "SalePrice", "type": "number"}], | ||
"trainer": {"train_steps": 1}, | ||
"combiner": {"type": "concat", "num_fc_layers": 2}, | ||
} | ||
ModelConfig.from_dict(config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@abidwael does just doing .from_dict()
repro the error? I would image this would require at least 1-2 steps of training for the error in your PR comment to be reproed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, it doesn't. Will remove this test as it will be captured in #2991
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would probably add a test here as well since these will run on every commit, whereas the tests in #2991 will run on merge to master? or is that not the case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will run on every commit
When we specify a
onehot
encoder for category features and have dense layers down the line, we get the following errorThis PR makes sure to convert the outputs of the
onehot
encoder tofloat32