Add support for causal LM finetune #80
Conversation
Force-pushed from d934455 to 70dfa5d
Force-pushed from f359909 to 0c2df95
Thanks @gkumbhat, looks good! Just a few things.
Also a bit of a side note: do you think we should remove support (for now) for HFAutoSequenceClassifier? It seems effectively unusable between the trainer changes and the tokenizer builder stuff, and it's confusing to have it there when we don't enable it for anything.
@@ -81,6 +75,7 @@ def train(
    lr: float = 2e-5,
    # Directory where model predictions and checkpoints will be written
    checkpoint_dir: str = "/tmp",
    **training_arguments,
This is better! Can you link the trainer args in the docstring though?
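For instance, the forwarded kwargs could be pointed at the Hugging Face TrainingArguments docs from the docstring. A minimal sketch, assuming the surrounding signature (base_model is a stand-in for the parameters not shown in the diff):

```python
def train(
    base_model: str,  # stand-in for the module's real parameters
    lr: float = 2e-5,
    # Directory where model predictions and checkpoints will be written
    checkpoint_dir: str = "/tmp",
    **training_arguments,
):
    """Fine-tune a base model.

    Args:
        base_model: Name or path of the pretrained model to tune.
        lr: Learning rate to use while tuning.
        checkpoint_dir: Directory where model predictions and checkpoints will be written.
        **training_arguments: Remaining keyword arguments are forwarded to
            transformers.TrainingArguments (e.g. num_train_epochs,
            per_device_train_batch_size); see the Hugging Face TrainingArguments
            documentation for the full list of supported options.
    """
```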
@@ -81,6 +75,7 @@ def train(
    lr: float = 2e-5,
    # Directory where model predictions and checkpoints will be written
    checkpoint_dir: str = "/tmp",
    **training_arguments,
):
    """
    # FIXME: Below is currently configured for Seq2Seq only
This should be removed, right?
yep. good catch! Will remove this
"<NLP39984681E>", | ||
NotImplementedError( | ||
f"Generation on {type(self.model)} not support \ | ||
currently! Please try saving and running this model in TGIS." |
oof. Does exporting via the trainer save API + reloading give you a transformer model back? I wonder if it would be better to have the first inference call export and reload with a warning until we find something better / implement a causal LM trainer doing something similar. Slow feels better than completely broken here IMO.
Or, is there any way we can cast to the seq2seq trainer and leverage the generate API for that? I guess that probably doesn't handle shifting etc the same way...
Yeah, I think converting to the seq2seq trainer could land us in weird mismatch issues.
Saving and reloading is certainly an option; it would simplify this block of the run function entirely. But it could be less efficient, since the model is already on the appropriate devices at this point, so loading it again would lose the distribution, which is mainly what I was trying to preserve here.
But certainly, not having a solution for causal LM would not be great.
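For reference, a rough sketch of the export-and-reload fallback discussed above; the helper name, logging, and use of a temporary directory are illustrative, not the PR's actual implementation:

```python
import logging
import tempfile

from transformers import AutoModelForCausalLM

log = logging.getLogger(__name__)


def reload_as_transformers_model(trainer, tokenizer):
    """Export the fine-tuned model through the trainer's save API and reload it
    as a plain transformers causal LM so that .generate() is available.

    Trades speed (an extra save/load, and loss of the existing device placement)
    for a working inference path until a proper causal LM runner exists.
    """
    log.warning("Exporting and reloading the model for generation; this is slow.")
    with tempfile.TemporaryDirectory() as export_dir:
        trainer.save_model(export_dir)
        tokenizer.save_pretrained(export_dir)
        model = AutoModelForCausalLM.from_pretrained(export_dir)
    return model
```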
device = PeftPromptTuning._get_device(device)
inputs = {k: v.to(device) for k, v in tok_tensors.items()}

inputs = {k: v.to(self.model.device) for k, v in tok_tensors.items()}
FYI @rawkintrevo is actually making this change in a separate PR (it's this issue #3). Can we put it back as part of this PR and use his change when it's ready instead, since this PR is primarily targeting fine-tuning anyway?
ah true. I had to make this change to make some tests pass 😄 but yes, can change it back.
"device_placement": True, | ||
} | ||
|
||
accelerator = Accelerator(**accelerator_args) |
why build a separate dict here?
I was playing with an optional parameter regarding cpu=True, but that didn't work well, so I removed it. This is kind of left over from that. Will switch it back to direct arguments instead of a separate dict.
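A sketch of the direct-argument form, assuming device_placement is the only option needed:

```python
from accelerate import Accelerator

# Pass the option directly instead of building a separate accelerator_args dict
accelerator = Accelerator(device_placement=True)
```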
@@ -32,6 +32,20 @@
SEQ2SEQ_LM_MODEL = os.path.join(TINY_MODELS_DIR, "T5ForConditionalGeneration")


@pytest.fixture()
def set_cpu_device(request):
Nice - thanks for adding this
    2. compute_metrics
    3. callbacks
    4. preprocess_logits_for_metrics
    """
Same question about documenting the kwargs here in the docstring (at least the non-expanded ones). I assume the other one probably needs it as well.
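For context, those names correspond to transformers.Trainer constructor arguments, so the docstring could simply point callers there. A minimal sketch, with get_trainer as a hypothetical helper name and placeholder TrainingArguments:

```python
from transformers import Trainer, TrainingArguments


def get_trainer(model, train_dataset, **kwargs):
    """Build a transformers.Trainer for the given model.

    Args:
        model: Model to be fine-tuned.
        train_dataset: Tokenized training dataset.
        **kwargs: Forwarded to transformers.Trainer; the non-expanded ones
            mentioned above (compute_metrics, callbacks,
            preprocess_logits_for_metrics) are documented in the Hugging Face
            Trainer API reference.
    """
    training_args = TrainingArguments(output_dir="/tmp", num_train_epochs=1)
    return Trainer(model=model, args=training_args, train_dataset=train_dataset, **kwargs)
```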
Force-pushed from 09f77dc to d3d962c
Looks awesome! Some small typos and stuff, but LGTM
    # eval_steps=1,
    # load_best_model_at_end
    **training_arguments,
    **dtype_based_params,
Might be a nice good first issue in the future to cleanly make sure there aren't collisions in these expanded dicts, but for now we can leave it
good idea
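A rough sketch of what that collision check could look like; the helper name and error type are illustrative:

```python
def check_kwarg_collisions(*dicts):
    """Raise if a key appears in more than one of the dicts that are about to be
    expanded into the TrainingArguments call, so silent overrides are surfaced."""
    seen = set()
    for d in dicts:
        duplicates = seen & d.keys()
        if duplicates:
            raise ValueError(f"Conflicting training arguments: {sorted(duplicates)}")
        seen |= d.keys()


# Usage sketch:
check_kwarg_collisions({"num_train_epochs": 1}, {"fp16": True})  # fine
# check_kwarg_collisions({"fp16": False}, {"fp16": True})        # would raise ValueError
```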
Force-pushed from 2d5bd16 to 664a3d5
Add support for causal LM finetune
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>

Closes #77

Changes
- Added a set_cpu_device fixture which changes the CUDA environment variable and patches the is_available function in torch.cuda (see the sketch below)
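A minimal sketch of what such a fixture might look like, assuming pytest's monkeypatch is used for the patching (the PR's actual fixture takes request and may differ):

```python
import pytest
import torch


@pytest.fixture()
def set_cpu_device(monkeypatch):
    # Hide GPUs from CUDA and make torch report CUDA as unavailable so that
    # tests exercise the CPU code path regardless of the host hardware.
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    yield
```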