Add support causalm finetune #80

Merged
@gkumbhat merged 17 commits into caikit:main on Aug 2, 2023

Conversation

@gkumbhat (Collaborator) commented on Jul 13, 2023

Closes #77

Changes

  • Add support for causal LM fine-tuning
  • Move the trainer and training-argument configuration into the respective resource folders, to allow easy selection and configuration in the tuning module.
  • Fix a unit test issue where the presence of a CUDA device would make accelerate place computations on the GPU while the rest of the data stayed on the CPU, raising an error. This was fixed by adding a set_cpu_device fixture that changes the CUDA environment variable and patches the is_available function in torch.cuda (a minimal sketch of such a fixture follows this list).
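The following is a minimal sketch of a fixture like the one described above, assuming pytest's monkeypatch; the actual fixture in the PR (which takes a `request` argument) may be implemented differently.

```python
import pytest
import torch


@pytest.fixture()
def set_cpu_device(monkeypatch):
    # Hide physical GPUs from CUDA and make torch report that CUDA is
    # unavailable, so accelerate keeps all tensors on the CPU during tests.
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    yield
```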

@gkumbhat force-pushed the add_support_causalm_finetune branch from d934455 to 70dfa5d on July 14, 2023, 21:43
@gkumbhat force-pushed the add_support_causalm_finetune branch from f359909 to 0c2df95 on July 30, 2023, 19:19
@gkumbhat marked this pull request as ready for review on July 31, 2023, 14:07
@alex-jw-brooks (Collaborator) left a comment:
Thanks @gkumbhat, looks good! Just a few things.

Also, a bit of a side note: do you think we should remove support (for now) for HFAutoSequenceClassifier? It seems effectively unusable between the trainer changes and the tokenizer builder stuff, and it's confusing to have it there when we don't enable it for anything.

@@ -81,6 +75,7 @@ def train(
lr: float = 2e-5,
# Directory where model predictions and checkpoints will be written
checkpoint_dir: str = "/tmp",
**training_arguments,
alex-jw-brooks (Collaborator):
This is better! Can you link the trainer args in the docstring, though?
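A hypothetical illustration of the requested docstring change (not the PR's actual text): note in the train() docstring that the extra keyword arguments are forwarded to transformers.TrainingArguments. The parameter names base_model and train_stream are illustrative placeholders, and the function body is omitted.

```python
def train(
    base_model,
    train_stream,
    lr: float = 2e-5,
    # Directory where model predictions and checkpoints will be written
    checkpoint_dir: str = "/tmp",
    **training_arguments,
):
    """Fine-tune a text generation model.

    Args:
        **training_arguments:
            Arbitrary keyword arguments forwarded to
            transformers.TrainingArguments (for example num_train_epochs or
            per_device_train_batch_size); see the transformers documentation
            for the full list of supported options.
    """
```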

@@ -81,6 +75,7 @@ def train(
lr: float = 2e-5,
# Directory where model predictions and checkpoints will be written
checkpoint_dir: str = "/tmp",
**training_arguments,
):
"""
# FIXME: Below is currently configured for Seq2Seq only
alex-jw-brooks (Collaborator):
This should be removed, right?

gkumbhat (Collaborator, Author):
Yep, good catch! Will remove this.

"<NLP39984681E>",
NotImplementedError(
f"Generation on {type(self.model)} not support \
currently! Please try saving and running this model in TGIS."
alex-jw-brooks (Collaborator):
Oof. Does exporting via the trainer save API and reloading give you a transformers model back? I wonder if it would be better to have the first inference call export and reload with a warning, until we find something better or implement a causal LM trainer that does something similar. Slow feels better than completely broken here, IMO.

Or, is there any way we can cast to the seq2seq trainer and leverage its generate API? I guess that probably doesn't handle shifting etc. the same way...

gkumbhat (Collaborator, Author):
Yeah, I think converting to the seq2seq trainer could land us in weird mismatch issues.

Saving and reloading is certainly an option; it would simplify this block of the run function entirely. But it could be less efficient, since the model is already on the appropriate devices at this point; loading it again would lose that distribution, which is mainly what I was trying to preserve here.

But certainly, not having a solution for causal LM would not be great.
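For reference, a rough sketch of the reviewer's suggestion above, assuming a transformers Trainer is available on the module; the helper name and laziness (running only on the first inference call) are illustrative, not part of the PR.

```python
import logging

from transformers import AutoModelForCausalLM, Trainer

log = logging.getLogger(__name__)


def reload_for_generation(trainer: Trainer, checkpoint_dir: str):
    """Export the fine-tuned causal LM via the trainer's save API and reload
    it as a plain transformers model so .generate() can be used; intended to
    run lazily on the first inference call, with a warning about the cost."""
    log.warning(
        "Generation requires exporting and reloading the fine-tuned model; "
        "the first call will be slow."
    )
    trainer.save_model(checkpoint_dir)
    return AutoModelForCausalLM.from_pretrained(checkpoint_dir)
```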

device = PeftPromptTuning._get_device(device)
inputs = {k: v.to(device) for k, v in tok_tensors.items()}

inputs = {k: v.to(self.model.device) for k, v in tok_tensors.items()}
alex-jw-brooks (Collaborator):
FYI @rawkintrevo is actually making this change in a separate PR (it's this issue #3). Can we put it back as part of this PR and use his change when it's ready instead, since this PR is primarily targeting fine-tuning anyway?

gkumbhat (Collaborator, Author):
Ah, true. I had to make this change to get some tests to pass 😄 but yes, I can change it back.

"device_placement": True,
}

accelerator = Accelerator(**accelerator_args)
alex-jw-brooks (Collaborator):
Why build a separate dict here?

gkumbhat (Collaborator, Author):
I was playing with an optional parameter around cpu=True, but that didn't work well, so I removed it; this dict is a leftover from that. I'll switch it back to direct arguments instead of a separate dict.
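A minimal sketch of what the reply describes, assuming the Hugging Face Accelerate API as shown in the excerpt above:

```python
from accelerate import Accelerator

# Pass the option directly rather than expanding an intermediate dict.
accelerator = Accelerator(device_placement=True)
```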

@@ -32,6 +32,20 @@
SEQ2SEQ_LM_MODEL = os.path.join(TINY_MODELS_DIR, "T5ForConditionalGeneration")


@pytest.fixture()
def set_cpu_device(request):
alex-jw-brooks (Collaborator):
Nice - thanks for adding this

2. compute_metrics
3. callbacks
4. preprocess_logits_for_metrics
"""
alex-jw-brooks (Collaborator):
Same question about documenting the kwargs here in the docstring (at least the non-expanded ones). I assume the other one probably needs it as well.

@gkumbhat force-pushed the add_support_causalm_finetune branch from 09f77dc to d3d962c on August 1, 2023, 22:27
@alex-jw-brooks (Collaborator) left a comment:
Looks awesome! Some small typos and stuff, but LGTM

caikit_nlp/modules/text_generation/fine_tuning.py (outdated; resolved)
# eval_steps=1,
# load_best_model_at_end
**training_arguments,
**dtype_based_params,
alex-jw-brooks (Collaborator):
Might make a nice good-first-issue in the future to cleanly ensure there aren't any key collisions in these expanded dicts, but for now we can leave it.

gkumbhat (Collaborator, Author):
Good idea.
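A hypothetical sketch of the suggested guard (not part of this PR): check for key collisions between the two expanded dicts before handing them to TrainingArguments. The helper name is illustrative.

```python
def check_no_collisions(training_arguments: dict, dtype_based_params: dict) -> None:
    """Raise if user-supplied training arguments would silently override
    (or be overridden by) the internally derived dtype-based parameters."""
    collisions = set(training_arguments) & set(dtype_based_params)
    if collisions:
        raise ValueError(
            f"Training arguments {sorted(collisions)} collide with "
            "internally derived dtype-based parameters."
        )
```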

caikit_nlp/resources/pretrained_model/base.py (outdated; resolved)
gkumbhat and others added 2 commits August 2, 2023 17:54
@gkumbhat force-pushed the add_support_causalm_finetune branch from 2d5bd16 to 664a3d5 on August 2, 2023, 22:54
@gkumbhat merged commit b5d29aa into caikit:main on Aug 2, 2023
4 checks passed
@gkumbhat deleted the add_support_causalm_finetune branch on August 2, 2023, 23:05
gkumbhat added a commit to gkumbhat/caikit-nlp that referenced this pull request Aug 24, 2023
Add support causalm finetune
Successfully merging this pull request may close these issues.

Add support for causal lm to fine-tuning
2 participants