Add support for generic data sets to SliceGPT pass #1145

Merged: 1 commit into main on May 9, 2024
Conversation

@shaahji (Contributor) commented May 9, 2024


The SliceGPT implementation supported only a handful of specific datasets. This change widens support to any generic dataset via the data_config configuration.

Checklist before requesting a review

  • Add unit tests for this change.
  • Make sure all tests can pass.
  • Update documents if necessary.
  • Lint and apply fixes to your code by running `lintrunner -a`
  • Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
  • Is this PR including examples changes? If yes, please remember to update example documentation in a follow-up PR.

(Optional) Issue link

The first review thread is on this part of the diff:

```python
dataloader = data_config.to_data_container().create_dataloader(data_root)
dataset = [
    {
        "input_ids": data[0]["input_ids"].squeeze(),
```
Contributor:

is the squeeze to remove the batch dimension so that you can apply the config.calibration_batch_size later? Does this mean the data config must use batch size 1?

@shaahji (Author):

Yes, the batch size has to be 1. However, even with batch size of 1, the generated output needs to be squeezed to drop the extra dimension.
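
For illustration, here is a minimal standalone sketch (plain PyTorch, not Olive code) of the extra dimension being discussed: with batch_size=1, the default collate still stacks samples along a new leading batch dimension, which squeeze() then drops.

```python
import torch
from torch.utils.data import DataLoader

# Two toy samples; the default collate stacks them along a new batch dim.
samples = [{"input_ids": torch.arange(8)} for _ in range(2)]
loader = DataLoader(samples, batch_size=1)

batch = next(iter(loader))
print(batch["input_ids"].shape)            # torch.Size([1, 8]) -- extra leading dim
print(batch["input_ids"].squeeze().shape)  # torch.Size([8])    -- after squeeze
```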

Contributor:

thanks! that makes sense. the default_dataloader from data config always inserts a batch dimension even if it is 1.

You could try this dataloader:

```python
def no_auto_batch_dataloader(dataset, **kwargs):
```

if you don't want the dataloader to batch the data.
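
For reference, one way such a dataloader could be written; this is a hedged sketch, not necessarily how no_auto_batch_dataloader is implemented in the repo:

```python
from torch.utils.data import DataLoader

# Sketch only: batch_size=None disables PyTorch's automatic batching,
# so each sample is yielded exactly as stored, with no leading batch dim.
def no_auto_batch_dataloader(dataset, **kwargs):
    kwargs.pop("batch_size", None)  # drop any requested batch size
    return DataLoader(dataset, batch_size=None, **kwargs)
```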

The second review thread is on this part of the diff:

```python
]

torch.manual_seed(config.seed)
sampler = SubsetRandomSampler(torch.randperm(len(dataset))[: config.calibration_nsamples])
```
@jambayk (Contributor) commented May 9, 2024:

This looks okay for now, but in the future we can probably add an option in the dataloader section of the data config for random sampling, so that this extra work of rewrapping the data with a dataloader is not needed.

That would also remove the potential confusion between batch_size/max_samples in the data config and calibration_batch_size/calibration_nsamples here.
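
For context, the rewrapping presumably amounts to something like the following (a hypothetical sketch using the names from the excerpt above, not necessarily the exact code in the pass):

```python
from torch.utils.data import DataLoader

# Hypothetical sketch of the rewrapping step: the squeezed samples are
# re-batched at the calibration batch size, drawing only the randomly
# selected calibration subset via the sampler built above.
calibration_loader = DataLoader(
    dataset,
    batch_size=config.calibration_batch_size,
    sampler=sampler,
)
```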

@shaahji shaahji merged commit 2b361b5 into main May 9, 2024
35 checks passed
@shaahji shaahji deleted the shaahji/slicegpt branch May 9, 2024 22:27