-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add multiclass support #29
base: main
Are you sure you want to change the base?
Conversation
… into add-multiclass-support
@@ -1,5 +1,11 @@ | |||
# Version History | |||
|
|||
## 1.1.0 Feb 26, 2024 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This date should be consistent with the release date, so I mark it with a comment as a reminder :)
Summary of issues we have discussed today on the call:
|
Thanks for the review :)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, I really like how it looks now! 🙂
I agree that the loss and optimizer feature is not part of the multiclass support. I have just created the related issue - #43.
@@ -75,7 +78,7 @@ def __init__( | |||
self._params.update(additional_params) | |||
|
|||
self.device = device | |||
self.collate_fn = BertClassifierWithPooling.collate_fn_pooled_tokens | |||
self.collate_fn = BertBaseWithPooling.collate_fn_pooled_tokens |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.collate_fn_pooled_tokens
would also work and use the method overwritten in a derived class (if done so).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But this is is a static method, it seems natural to me to use with the class and not an instance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You may be right in general, but this is a specific case, where the collate_fn_pooled_tokens
method is called in an instance method of the same class.
If you create a new class that derives from the BertBaseWithPooling
class and overwrite the collate_fn_pooled_tokens
method, you would probably expect the overwritten method to be used in the __init__
method and that's why I think it would be safer to call it as self.collate_fn_pooled_tokens
.
However, in this very specific case, this scenario may be unlikely, so I just want to present you my point of view, but I leave the decision to you.
I would also love to see the same basic tests for the Regressor classes. |
from pathlib import Path | ||
from shutil import rmtree | ||
|
||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy
is not used
from belt_nlp.bert_regressor_truncated import BertRegressorTruncated | ||
MODEL_PARAMS = {"batch_size": 1, "learning_rate": 5e-5, "epochs": 1, "device": "cpu"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should be a blank line in between.
|
||
x_test = ["nice"] * 99 + ["bad"] * 1 | ||
|
||
model.fit(x_train, y_train) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now, y_train
expects list[bool]
, so the type hint should be updated.
assert scores.shape == torch.Size([2, 1]) | ||
|
||
|
||
def test_regression_order(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that for both classifiers and regressors prediction_order
is the appropriate name.
predictions
is used in the docstring anyway ;)
No description provided.