Add support for generic data sets to SliceGPT pass #1145
Conversation
dataloader = data_config.to_data_container().create_dataloader(data_root)
dataset = [
    {
        "input_ids": data[0]["input_ids"].squeeze(),
Is the squeeze there to remove the batch dimension so that you can apply config.calibration_batch_size later? Does this mean the data config must use a batch size of 1?
Yes, the batch size has to be 1. However, even with a batch size of 1, the generated output needs to be squeezed to drop the extra dimension.
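As a minimal sketch of the behaviour being described (using a plain PyTorch DataLoader rather than Olive's, with made-up toy tensors): even at batch size 1 the loader yields tensors shaped (1, seq_len), and .squeeze() drops that leading dimension.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 4 samples of 8 token ids each.
input_ids = torch.randint(0, 100, (4, 8))
loader = DataLoader(TensorDataset(input_ids), batch_size=1)

batch = next(iter(loader))[0]
print(batch.shape)            # torch.Size([1, 8]) -- batch dim of 1 is still present
print(batch.squeeze().shape)  # torch.Size([8])    -- squeeze drops it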
Thanks! That makes sense. The default_dataloader from the data config always inserts a batch dimension, even if it is 1.
You could try this dataloader instead:
Olive/olive/data/component/dataloader.py, line 22 in 40845a3:
def no_auto_batch_dataloader(dataset, **kwargs):
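For context, the standard way to disable automatic batching in PyTorch is to pass batch_size=None to DataLoader. A sketch of a dataloader along those lines is below; the signature matches the referenced helper, but the body is an assumption rather than Olive's actual implementation.

from torch.utils.data import DataLoader

def no_auto_batch_dataloader(dataset, **kwargs):
    # batch_size=None turns off PyTorch's automatic batching, so the loader
    # yields individual samples without an extra leading batch dimension.
    return DataLoader(dataset, batch_size=None, **kwargs)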
]

torch.manual_seed(config.seed)
sampler = SubsetRandomSampler(torch.randperm(len(dataset))[: config.calibration_nsamples])
This looks okay for now, but we can probably add an option to the dataloader section of the data config in the future for random sampling, so that this extra work of rewrapping the data with a dataloader is not needed. That would also remove the potential confusion between batch_size / max_samples in the data config and calibration_batch_size / calibration_nsamples here.
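To illustrate the rewrapping being discussed, a self-contained sketch is below: the squeezed samples sit in a plain list and are rewrapped in a DataLoader with a SubsetRandomSampler, so that calibration_nsamples random samples are drawn and batched at calibration_batch_size. The toy values and data are hypothetical; only the pattern mirrors the snippet above.

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

# Hypothetical stand-ins for config.seed, config.calibration_nsamples,
# and config.calibration_batch_size.
seed, calibration_nsamples, calibration_batch_size = 0, 2, 1

# Toy dataset in the same shape as the snippet above: a list of dicts
# holding already-squeezed input_ids.
dataset = [{"input_ids": torch.randint(0, 100, (8,))} for _ in range(4)]

torch.manual_seed(seed)
sampler = SubsetRandomSampler(torch.randperm(len(dataset))[:calibration_nsamples])
calibration_loader = DataLoader(dataset, batch_size=calibration_batch_size, sampler=sampler)

for batch in calibration_loader:
    print(batch["input_ids"].shape)  # torch.Size([1, 8]) per calibration batch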
Add support for generic data sets to SliceGPT pass
The SliceGPT implementation supported only a handful of specific datasets. This change widens support to any generic dataset via the data_config configuration.
Checklist before requesting a review
lintrunner -a
(Optional) Issue link