
HQQ FSDP #17

Merged
merged 8 commits into mobiusml:master on Mar 1, 2024

Conversation

@KeremTurgutlu (Contributor) commented Feb 29, 2024

  • Store quantized weights in the same dtype as the compute dtype for FSDP training (a minimal sketch of the idea is at the end of this comment).
  • Added tests.
  • Tested 4-bit FSDP llama-7b training and compared it against BNB; the loss plots are similar, and HQQ requires 50% less memory than BNB. Used the alpaca dataset with a context length of 512.

[Screenshot: loss curves for 4-bit FSDP llama-7b training, HQQ vs BNB]

Related issue: #14

Note: 1-bit and 8-bit HQQLinear are failing in the test_hqq_linear test. Otherwise, they work with Quantizer.quantize and Quantizer.dequantize.
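
To illustrate the first bullet, here is a minimal sketch of the float-view idea (illustrative shapes and dtypes only, not the exact HQQ implementation): the bit-packed uint8 weights are reinterpreted under the compute dtype so FSDP sees a float parameter, and the view is reversed losslessly whenever the raw packed bits are needed.

import torch

# Bit-packed quantized weights are normally stored as uint8.
W_q_int = torch.randint(0, 256, (4096, 2048), dtype=torch.uint8)

# FSDP requires parameters to be a float type, so the packed buffer is
# *viewed* (not cast) as the compute dtype for training...
compute_dtype = torch.bfloat16
W_q_float = W_q_int.view(compute_dtype)

# ...and viewed back to uint8 whenever the kernel needs the raw packed bits.
assert torch.equal(W_q_float.view(torch.uint8), W_q_int)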

@jph00 commented Feb 29, 2024

@mobicham - by way of background, this PR is needed to allow HQQ to be used with FSDP for multi-GPU QLoRA training. We'll be releasing a blog post from Answer.AI soon showing folks how to use this functionality (we've also done it for bitsandbytes, BTW). Let us know if you have any questions.

@mobicham (Collaborator) commented Feb 29, 2024

Thank you very much for your contribution @KeremTurgutlu @jph00! Using multi-GPU to train quantized models is super valuable!

The PR looks mostly fine; I just have a few things, if you don't mind. Fortunately, these are easy fixes. In fact, I have already included them here (via a gist, since the PR is coming from your repo): https://gist.github.com/mobicham/af0b7676c587ff36c0607affc00795eb

Bugs

(B1): Error when loading quantized models via from_quantized - FIXED

The code breaks when trying to load a quantized model via model = HQQModelForCausalLM.from_quantized(model_id). That's because when loading a quantized model, the HQQLinear layer is initialized with linear_layer=None. So I moved the bias copy to the initialize(self) function.
https://gist.github.com/mobicham/af0b7676c587ff36c0607affc00795eb#file-quantize_fsdp_fix-py-L282
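
Roughly, the shape of the fix (a sketch only; the gist above is the authoritative version):

# Inside HQQLinear.initialize() rather than __init__(), so that from_quantized,
# which constructs the layer with linear_layer=None, no longer breaks:
if self.linear_layer is not None:
    self.bias = None if self.linear_layer.bias is None else self.linear_layer.bias.to(self.compute_dtype)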

(B2): Error when predicting with models loaded via from_quantized - FIXED

The code breaks when trying to do prediction with a saved quantized model because unpack_view_dtype is not defined in the meta-data. I put a check in load_state_dict(): if unpack_view_dtype is not defined in meta, it gets added.
https://gist.github.com/mobicham/af0b7676c587ff36c0607affc00795eb#file-quantize_fsdp_fix-py-L340
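
Roughly (a sketch; the exact key names are assumed from the gist and may differ):

# Backward-compatibility guard in load_state_dict(): older checkpoints were
# saved before unpack_view_dtype existed in the meta-data.
if 'unpack_view_dtype' not in meta:
    # Recover the view dtype from the packing format recorded at quantization time.
    meta['unpack_view_dtype'] = Quantizer.unpack_view_dtype[meta['packing']]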

(B3): view_as_float as an argument of dequantize creates some issues - FIXED

I moved view_as_float to meta because this might create a couple of issues with some settings.
I also added default values for scale/zero when loading an (older) quantized model so it doesn't break.

(B4): torch.compile issues - OPEN

torch.compile fails with view_as_float=True. This means that both model=torch.compile(model) and the HQQLinear.set_backend(HQQBackend.PYTORCH_BACKPROP_COMPILE) backend break.
Luckily, this doesn't happen when view_as_float is set to False.
I am not sure why that happens. I tried moving the view call inside @torch.jit.ignore and it still breaks.
Using torch.compile could also make training faster. I will try to take a look at it again.
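
For reference, a hypothetical minimal repro of the failure (constructor signatures taken from the tests below; the import path is an assumption):

import torch
import torch.nn as nn
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

linear = nn.Linear(4096, 4096, bias=False).cuda()
cfg    = BaseQuantizeConfig(nbits=4, group_size=64, view_as_float=True)
layer  = HQQLinear(linear, cfg, compute_dtype=torch.float16, del_orig=False)

x = torch.randn(1, 4096, device='cuda', dtype=torch.float16)
y = torch.compile(layer)(x)   # breaks with view_as_float=True, works with view_as_float=False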

Suggestions

(S1): Set default as view_as_float=False

I suggest setting view_as_float=False by default in the code and adding it as a quant param in hqq_base_quant_config().
This way, the original int bitpacking is used by default, which is consistent with the format used by other packages like oobabooga, as well as with the models already published on Hugging Face. Also, there's the torch.compile issue.
So now you can simply pass it as an argument like this:

quant_config = BaseQuantizeConfig(nbits=4, group_size=64, offload_meta=True, view_as_float=True);
model.quantize_model(quant_config=quant_config, compute_dtype=torch.float16)

(S2): Slower forward pass with view_as_float

There's a slight overhead in the forward pass with the float view (~0.10 sec for Llama2-7B and ~0.25 sec for Llama2-70B on a Titan RTX), plus there's the issue of torch.compile breaking. It is possible to revert back to int bitpacking after training for faster inference. This can be done via the patching functions:

def patch_linear_int_bitpacking(layer, patch_params):
    # If you are using HQQ's peft, this should use layer.linear_layer instead.
    if hasattr(layer, 'W_q'):
        # Reinterpret the float-viewed packed weights back to their integer dtype.
        layer.W_q = layer.W_q.view(layer.meta['unpack_view_dtype'])
        layer.meta['view_as_float'] = False
    return layer

model.base_class.patch_linearlayers(model, patch_linear_int_bitpacking, dict([(linear_tag, None) for linear_tag in model.base_class.linear_tags]))

I am not sure where we could put this; maybe in hqq/models/base.py. But no worries, this is not a blocker and we can take a look at it later.

(S3): Scale/Zero still treated as int

I see that view_as_float is set to False for the scale/zero point:
  • Does the training work fine when the scale/zero point are quantized?
  • Does the training work fine with offload_meta=True?

If the scale/zero work fine with FSDP even though they are int-packed when quantized, then making W_q a standard tensor instead of an nn.Parameter should work the same way, with no need for the float view. However, I think this would break a few things with the current implementation. I am just curious whether, if it's done that way, it might make training faster and use even less memory.

(S4): Tests

Regarding the tests, I am not sure why test_hqq_linear() would fail for 8-bit but not the rest. Luckily, there are other ways to check that the float view is not causing any issues:

This test checks the view dtype and compares it against the original bitpacking logic. It passes for all settings, which means everything is fine:

def test_floatview_bitpacking(self):
    shapes = [[32, 32], [256, 256], [512, 512], [1024, 1024], [2048, 2048], [4096, 4096], [8192, 8192], [8192, 4096], [8192, 128], [32, 4096]]
    for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:
        for nbits in [8, 4, 3, 2, 1]:
            for shape in shapes:
                bit_pack      = Quantizer.bit_to_packing[nbits]
                view_dtype    = Quantizer.unpack_view_dtype[bit_pack]
                W             = torch.randint(0, 2**nbits, shape, device='cuda').to(view_dtype)
                W_packed_orig = Quantizer.pack[bit_pack](W)
                W_packed_view = W_packed_orig.clone().view(compute_dtype)
                assert W_packed_view.dtype == compute_dtype
                assert torch.abs(W_packed_orig - W_packed_view.view(view_dtype)).max() < 1e-5

The second test I would like to suggest makes sure the forward pass produces the same results as the original bitpacking logic:

def test_hqq_linear_forward(self):
    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
    context_size = 4096
    batch_size   = 1
    for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:
        for nbits in [4, 3, 2]:
            for quant_zero in [False, True]:
                for quant_scale in [False, True]:
                    for offload_meta in [True, False]:
                        # Maybe loop over group_sizes as well: [8, 16, 32, 64, 128, 256]
                        quant_config_int   = BaseQuantizeConfig(nbits=nbits, group_size=64, quant_zero=quant_zero, quant_scale=quant_scale, offload_meta=offload_meta, view_as_float=False)
                        quant_config_float = BaseQuantizeConfig(nbits=nbits, group_size=64, quant_zero=quant_zero, quant_scale=quant_scale, offload_meta=offload_meta, view_as_float=True)

                        hqq_linear_int   = HQQLinear(self.m, quant_config_int,   compute_dtype=compute_dtype, del_orig=False)
                        hqq_linear_float = HQQLinear(self.m, quant_config_float, compute_dtype=compute_dtype, del_orig=False)

                        x = torch.randn([batch_size, context_size, self.m.weight.shape[1]], device='cuda').to(compute_dtype)
                        with torch.no_grad():
                            y_int   = hqq_linear_int.forward(x)
                            y_float = hqq_linear_float.forward(x)

                        assert torch.allclose(y_int, y_float, rtol=1e-5)

I tried with a couple of combinations and it worked fine.

I also ran some lm-evaluation-harness benchmarks and the numbers look OK.

(S5): Please provide an example

Can you please provide a full script example in the examples/fsdp folder so it's easy for people to get started? Something like this:
https://github.com/mobiusml/hqq/blob/master/examples/lora/train_hqq_lora_example.py

If you can review the gist I shared and check that it works fine with your training code, you can simply do another commit with the changes and we can merge. We can take a look at the torch.compile issue and the repacking to int later.

Looking forward to your answers and excited to see the blog post on Answer.AI!

@KeremTurgutlu (Contributor, Author) commented Mar 1, 2024

@mobicham thanks a lot for the detailed response and suggestions!

I incorporated the changes and also added new tests. FSDP training runs fine as before. Fixing torch.compile would definitely be great, I agree.

(S3): Scale/Zero still treated as int

FSDP only requires parameters and buffers to be a float type, so tensors in meta can be kept in torch.uint8. If we make W_q a non-parameter, it won't get sharded by FSDP when using the FULL_SHARD strategy, which means losing the memory savings of large-model training, the main reason we use FSDP in the first place. So the current approach looks good to me; a rough illustration of the distinction is below.
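
(A toy module to illustrate the point, not the HQQ code:)

import torch
import torch.nn as nn

class QuantLayer(nn.Module):
    def __init__(self):
        super().__init__()
        packed = torch.randint(0, 256, (4096, 2048), dtype=torch.uint8)
        # Registered as an nn.Parameter (viewed as the compute dtype), so FSDP
        # FULL_SHARD includes it in the flat parameter and shards it across ranks.
        self.W_q = nn.Parameter(packed.view(torch.float16), requires_grad=False)
        # Tensors kept in meta (plain attributes, not parameters or buffers) are
        # ignored by the flat-parameter sharding and can stay in torch.uint8,
        # replicated on every rank.
        self.meta = {'zero_q': torch.randint(0, 256, (4096, 1), dtype=torch.uint8)}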

(S5): Please provide an example

The training script can be found here: https://github.com/AnswerDotAI/fsdp_qlora/blob/scaling_experiments/train.py. We are still actively working on it, but it should be finalized by the time the blog post is ready.

I also tested weights before and after training and made sure only the desired params are updated.

FSDP model saving issue with HQQLinear

HQQLinear's custom state_dict() method is causing issues when saving the model with FSDP. So I added custom model-saving logic which only saves the updated LoRA weights: https://github.com/AnswerDotAI/fsdp_qlora/blob/6b2e126de531c369392e28916e432dc064159842/train.py#L1021-L1040
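
For reference, one way to do this kind of LoRA-only save with the public FSDP API (a sketch under the assumption that the LoRA weights are the only trainable parameters; the linked train.py is the authoritative version):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# `model` is assumed to be the FSDP-wrapped model from the training script.
# Temporarily gather the full (unsharded) parameters on rank 0, offloaded to CPU,
# without going through state_dict(), so HQQLinear's custom state_dict() is never hit.
with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True, rank0_only=True):
    if dist.get_rank() == 0:
        # Keep only the trainable (LoRA) parameters and persist them.
        lora_sd = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}
        torch.save(lora_sd, 'lora_weights.pt')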

@mobicham merged commit d18c324 into mobiusml:master on Mar 1, 2024
@mobicham (Collaborator) commented Mar 1, 2024

Thanks for the update @KeremTurgutlu! I ran more tests and it looks fine to me, so I already merged it.
I also created a new release (https://github.com/mobiusml/hqq/releases/tag/0.1.5), updated the docs, and made it available via pip.

Regarding saving FSDP models: LoRA weights should be saved separately anyway; that's how we do it as well: https://github.com/mobiusml/hqq/blob/master/hqq/core/peft.py#L316

We can create a separate issue for torch.compile.

@KeremTurgutlu (Contributor, Author)

SGTM!

@jph00 commented Mar 6, 2024

@mobicham FYI, there's a draft post that we'll be announcing tomorrow; it includes a discussion of the great results we're seeing with HQQ. Let us know if you see any issues with it: https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html

@mobicham (Collaborator) commented Mar 6, 2024

@jph00 Thank you very much for sharing the draft! It is a very nice read!

The only suggestion on our side would be to use the official website for Mobius Labs (https://www.mobiuslabs.com/); the one mentioned in the article (https://mobiusml.github.io/) is very old!

Looking forward to seeing and sharing the final version!

@appoose (Contributor) commented Mar 6, 2024

@jph00 This reads really nicely. I hope that future significant models will be trained using less powerful GPUs in remote parts of the world, potentially helping to level the playing field in AI. Once this is live, would it be acceptable to link to and feature it on our blog, https://mobiusml.github.io/blog/, with a note expressing our delight at having contributed to this work?

@jph00 commented Mar 6, 2024 via email

@warner-benjamin (Contributor)

@mobicham Our HQQ+FSDP changes have been merged into main, so you'll want to update the HQQ README link from the scaling_experiments branch (https://github.com/AnswerDotAI/fsdp_qlora/blob/scaling_experiments/train.py) to https://github.com/AnswerDotAI/fsdp_qlora.

@jph00 commented Mar 8, 2024

@appoose FYI this is launched now. https://twitter.com/jeremyphoward/status/1765868543235805232

@appoose (Contributor) commented Mar 8, 2024

@jph00 It is now linked from our blog at https://mobiusml.github.io/blog/. Looking forward to future collaborations!
