Pack/Unpack to Different Dtypes for FSDP #14

Closed
KeremTurgutlu opened this issue Feb 23, 2024 · 13 comments

@KeremTurgutlu (Contributor) commented Feb 23, 2024:

Thanks for this great package!

I've noticed that the existing packing/unpacking only works with certain dtypes. FSDP requires all params to be a float dtype for sharding, so are there any plans to extend packing/unpacking to other dtypes?

@mobicham (Collaborator) commented:

Hi @KeremTurgutlu, thank you for your question!
Can you please provide an example of the desired behavior vs. what is not working right now?
Thanks!

@KeremTurgutlu (Contributor, Author) commented Feb 23, 2024:

You can find the related PR for bnb here: TimDettmers/bitsandbytes#970

The primary blocker for FSDP QLoRA finetuning is the quantized storage type of uint8. FSDP can only shard float data types.

Also, BnB packing/unpacking is dtype-agnostic. I think this is due to the way they implemented the packing/unpacking logic in their kernels, for example:

import torch
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit

W = torch.randn(128, 128)  # float32 by default

# Quantize with uint8 storage
param = Params4bit(W, quant_storage=torch.uint8)
w = param.data.contiguous().cuda(0)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=param.blocksize, compress_statistics=param.compress_statistics,
                                                    quant_type=param.quant_type, quant_storage=param.quant_storage)
w_dq_uint8 = bnb.functional.dequantize_4bit(w_4bit, quant_state)

# Quantize with bfloat16 storage
param = Params4bit(W, quant_storage=torch.bfloat16)
w = param.data.contiguous().cuda(0)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=param.blocksize, compress_statistics=param.compress_statistics,
                                                    quant_type=param.quant_type, quant_storage=param.quant_storage)
w_dq_bf16 = bnb.functional.dequantize_4bit(w_4bit, quant_state)

# The dequantized values match regardless of the storage dtype
assert torch.equal(w_dq_uint8, w_dq_bf16)

I am not an expert in quantization but this might be the difference: https://github.com/TimDettmers/bitsandbytes/blob/e820409c095ea7cbb5ce156992307b84352cbf90/csrc/kernels.cu#L827

@mobicham (Collaborator) commented:

Understood! I think we can use int32 instead of uint8 for packing, then cast to float32. I did a few checks regarding the range of the values and it looks like it could work.

Since 3-bit already uses int32, we can quickly test it: replace the functions at https://github.com/mobiusml/hqq/blob/master/hqq/core/bitpack.py#L68 with:

@staticmethod
def pack_3bit_32(W_q_in):
	W_q = torch.zeros([int(10*np.ceil(W_q_in.shape[0]/10.)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32)
	W_q[:len(W_q_in)] = W_q_in
	_step = int(len(W_q)/10)
	W_q = (W_q[:_step] << 27) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 21) | (W_q[_step*3:_step*4] << 18) | (W_q[_step*4:_step*5] << 15) | (W_q[_step*5:_step*6] << 12) | (W_q[_step*6:_step*7] << 9) | (W_q[7*_step:_step*8] << 6) | (W_q[_step*8:_step*9] << 3) | (W_q[_step*9:]) 
	return W_q.to(torch.float32) #Now the stored quantized weights are float32

@staticmethod
def unpack_3bit_32(W_q):
	W_q                  = W_q.to(torch.int32)
	_step                = W_q.shape[0]
	tmp                  = torch.empty([10*_step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
	tmp[:_step]          = ((W_q & 0b00111000000000000000000000000000) >> 27)
	tmp[1*_step:2*_step] = ((W_q & 0b00000111000000000000000000000000) >> 24)
	tmp[2*_step:3*_step] = ((W_q & 0b00000000111000000000000000000000) >> 21)
	tmp[3*_step:4*_step] = ((W_q & 0b00000000000111000000000000000000) >> 18)
	tmp[4*_step:5*_step] = ((W_q & 0b00000000000000111000000000000000) >> 15)
	tmp[5*_step:6*_step] = ((W_q & 0b00000000000000000111000000000000) >> 12)
	tmp[6*_step:7*_step] = ((W_q & 0b00000000000000000000111000000000) >> 9)
	tmp[7*_step:8*_step] = ((W_q & 0b00000000000000000000000111000000) >> 6)
	tmp[8*_step:9*_step] = ((W_q & 0b00000000000000000000000000111000) >> 3)
	tmp[9*_step:]        = ((W_q & 0b00000000000000000000000000000111))
	return tmp

Example:

W_q = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [7, 0, 1, 2], [7, 0, 5, 3], [1, 0, 1, 0]]) # 3-bit range [0,7]
W_q_packed = pack_3bit_32(W_q)
# tensor([[8.3657e+07, 2.1810e+08, 3.7254e+08, 5.2507e+08]]) -- packed weights now stored as float32
W_q_unpacked = unpack_3bit_32(W_q_packed)
assert torch.mean(1.*(W_q_unpacked[:len(W_q)]==W_q)  ) == 1. # Works!

So then we would need to implement bitpacking with int32 instead of uint8 for 8/4/2/1 bits and add their corresponding CUDA kernels.
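
For illustration, a 2-bit variant with the same int32 layout might look something like this (pack_2bit_32 / unpack_2bit_32 are hypothetical names here, not existing hqq functions; 16 rows of 2-bit values go into one int32 row):

import torch

def pack_2bit_32(W_q_in):
    W_q = W_q_in.to(torch.int32)
    _step = W_q.shape[0] // 16  # assumes the number of rows is a multiple of 16
    out = torch.zeros([_step, W_q.shape[1]], dtype=torch.int32, device=W_q.device)
    for i in range(16):
        out |= W_q[i*_step:(i+1)*_step] << (30 - 2*i)  # group i lands in bits (30-2i)..(31-2i)
    return out

def unpack_2bit_32(W_q):
    _step = W_q.shape[0]
    tmp = torch.empty([16*_step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
    for i in range(16):
        tmp[i*_step:(i+1)*_step] = (W_q >> (30 - 2*i)) & 0b11  # shift back and mask out the 2 bits
    return tmp

W_q = torch.randint(0, 2**2, (64, 64))
assert torch.equal(unpack_2bit_32(pack_2bit_32(W_q)), W_q.to(torch.uint8))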

@KeremTurgutlu (Contributor, Author) commented Feb 23, 2024:

Great, thanks! I will give it a try and try to update the CUDA kernels as well. torch.float16 and torch.bfloat16 will still be a problem if we cast to int32, but FSDP will probably be fine with keeping mixed float dtypes. The user will need to make sure they cast only the LoRA weights and not the quantized weights, or perhaps handle it with torch.distributed.fsdp.MixedPrecision(..., _module_classes_to_ignore=[HQQLinear]).
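
For reference, a rough sketch of that last idea (untested; _module_classes_to_ignore is a private field of MixedPrecision, so its exact name and behavior may differ across PyTorch versions):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from hqq.core.quantize import HQQLinear

# Exclude the quantized modules from mixed-precision casting so their packed
# storage dtype is left untouched; only the float (e.g. LoRA) params get cast.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
    _module_classes_to_ignore=[HQQLinear],
)
# model = FSDP(model, mixed_precision=mp_policy, ...)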

@mobicham (Collaborator) commented:

Yes, once the quantized weights are packed to float32 they should not be touched.

Let me know if the trick for the 3-bit case works with FSDP. Meanwhile, I will run more tests to make sure the casting doesn't create any issues.

If it works, I can add new bitpacking with the same logic + their CUDA kernels. The bitpacking logic with int32 will be quite different but not too difficult to add.

@KeremTurgutlu (Contributor, Author) commented Feb 23, 2024:

import torch
def pack_4bit_32(W_q):
    _step = int(len(W_q)/8)
    W_q = (W_q[:_step] << 28) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 20) | (W_q[_step*3:_step*4] << 16) | (W_q[_step*4:_step*5] << 12) | (W_q[_step*5:_step*6] << 8) | (W_q[_step*6:_step*7] << 4) | (W_q[_step*7:]) 
    return W_q

def unpack_4bit_32_cat(W_q):
    return torch.cat([((W_q & 0b11110000000000000000000000000000) >> 28),
                      ((W_q & 0b00001111000000000000000000000000) >> 24),
                      ((W_q & 0b00000000111100000000000000000000) >> 20),
                      ((W_q & 0b00000000000011110000000000000000) >> 16),
                      ((W_q & 0b00000000000000001111000000000000) >> 12),
                      ((W_q & 0b00000000000000000000111100000000) >> 8),
                      ((W_q & 0b00000000000000000000000011110000) >> 4),
                      ((W_q & 0b00000000000000000000000000001111))], axis=0)


#A bit faster than _cat version
def unpack_4bit_32(W_q):
    _step                = W_q.shape[0]
    tmp                  = torch.empty([8*_step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
    tmp[:_step]          = ((W_q & 0b11110000000000000000000000000000) >> 28)
    tmp[1*_step:2*_step] = ((W_q & 0b00001111000000000000000000000000) >> 24)
    tmp[2*_step:3*_step] = ((W_q & 0b00000000111100000000000000000000) >> 20)
    tmp[3*_step:4*_step] = ((W_q & 0b00000000000011110000000000000000) >> 16)
    tmp[4*_step:5*_step] = ((W_q & 0b00000000000000001111000000000000) >> 12)
    tmp[5*_step:6*_step] = ((W_q & 0b00000000000000000000111100000000) >> 8)
    tmp[6*_step:7*_step] = ((W_q & 0b00000000000000000000000011110000) >> 4)
    tmp[7*_step:8*_step] = (W_q & 0b00000000000000000000000000001111)    
    return tmp
    
for i in range(100):
    x = torch.randint(0,2**4,(32,)); #random 4-bit quantized 1D weights
    pack_4bit_32(x)
    assert torch.equal(unpack_4bit_32_cat(pack_4bit_32(x)), x)
    
for i in range(100):
    x = torch.randint(0,2**4,(32,32)); #random 4-bit quantized 2D weights
    assert torch.equal(unpack_4bit_32(pack_4bit_32(x)), x)

Using int32 does seem to work! I will test FSDP training with 3-bit HQQ LoRA now.

Edit: Actually casting back and forth breaks it.

# Assertion error
for i in range(100):
    print(i)
    x = torch.randint(0,2**4,(32,)); #random 4-bit quantized 1D weights
    pack_4bit_32(x)
    assert torch.equal(unpack_4bit_32_cat(pack_4bit_32(x).float().to(torch.int32)), x)

The problem is that the first bit is used for the sign. Maybe instead of packing 8 groups of 4-bit values we could pack 7 groups; that gives 28/32 = 87.5% bit utilization (i.e., 12.5% extra memory), but it wouldn't have issues with the back-and-forth casting.

@mobicham (Collaborator) commented:

So actually it only works for rows up to 8; the 2 rows after that are not properly decoded. I will play with some toy examples and see if there's another way.

@mobicham (Collaborator) commented:

The problem seems to be the cast between float32 and int32: float32 has only a 24-bit mantissa, so large packed values get rounded:

In [105]: torch.tensor([365215118], dtype=torch.int32).to(torch.float32).to(torch.int32)
Out[105]: tensor([365215104], dtype=torch.int32)

@KeremTurgutlu (Contributor, Author) commented Feb 23, 2024:

This might be helpful? https://discuss.pytorch.org/t/bitwise-operation-on-float-tensor/170863

torch.tensor([365215118], dtype=torch.int32).view(torch.float32).view(torch.int32)

We can probably use `view()` without needing to change the existing packing logic, e.g. `pack_4bit_u8(x).view(torch.bfloat16)` or `pack_4bit_u8(x).view(torch.float16)`.
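
A quick round-trip sanity check of that view() idea (using a simplified stand-in for the existing uint8 4-bit packer, just for illustration):

import torch

# Simplified stand-in for the existing uint8 4-bit packing (illustration only):
# first half of the rows goes into the high nibble, second half into the low nibble.
def pack_4bit_u8(W_q):
    _step = W_q.shape[0] // 2
    return (W_q[:_step] << 4) | W_q[_step:]

def unpack_4bit_u8(W_q):
    return torch.cat([(W_q >> 4) & 0b1111, W_q & 0b1111], dim=0)

x = torch.randint(0, 2**4, (64, 64), dtype=torch.uint8)
packed  = pack_4bit_u8(x)               # uint8 storage, as today
as_bf16 = packed.view(torch.bfloat16)   # reinterpret the same bytes as bfloat16 (no value cast)
assert torch.equal(unpack_4bit_u8(as_bf16.view(torch.uint8)), x)  # lossless round trip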

@mobicham (Collaborator) commented:

Actually that was very helpful, now it works:

import torch
import numpy as np

#Float32 cast
def pack_3bit_32(W_q_in):
	W_q = torch.zeros([int(10*np.ceil(W_q_in.shape[0]/10.)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.float32).view(torch.int32)
	W_q[:len(W_q_in)] = W_q_in
	_step = int(len(W_q)/10)
	W_q = (W_q[:_step] << 27) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 21) | (W_q[_step*3:_step*4] << 18) | (W_q[_step*4:_step*5] << 15) | (W_q[_step*5:_step*6] << 12) | (W_q[_step*6:_step*7] << 9) | (W_q[7*_step:_step*8] << 6) | (W_q[_step*8:_step*9] << 3) | (W_q[_step*9:]) 
	return W_q.view(torch.float32)


def unpack_3bit_32(W_q):
	W_q                  = (W_q).view(torch.int32)
	_step                = W_q.shape[0]
	tmp                  = torch.empty([10*_step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
	tmp[:_step]          = ((W_q & 0b00111000000000000000000000000000) >> 27)
	tmp[1*_step:2*_step] = ((W_q & 0b00000111000000000000000000000000) >> 24)
	tmp[2*_step:3*_step] = ((W_q & 0b00000000111000000000000000000000) >> 21)
	tmp[3*_step:4*_step] = ((W_q & 0b00000000000111000000000000000000) >> 18)
	tmp[4*_step:5*_step] = ((W_q & 0b00000000000000111000000000000000) >> 15)
	tmp[5*_step:6*_step] = ((W_q & 0b00000000000000000111000000000000) >> 12)
	tmp[6*_step:7*_step] = ((W_q & 0b00000000000000000000111000000000) >> 9)
	tmp[7*_step:8*_step] = ((W_q & 0b00000000000000000000000111000000) >> 6)
	tmp[8*_step:9*_step] = ((W_q & 0b00000000000000000000000000111000) >> 3)
	tmp[9*_step:]        = ((W_q & 0b00000000000000000000000000000111))
	return tmp
#######################################################################################################
W_q  = torch.randint(low=0, high=(2**3), size=(4096, 4096))

W_q_packed   = pack_3bit_32(W_q)
W_q_unpacked = unpack_3bit_32(W_q_packed)

assert torch.mean(1.*(W_q_unpacked[:len(W_q)]==W_q)  ) == 1. # Works!

I am not sure how that would work on the CUDA side.

@KeremTurgutlu (Contributor, Author) commented Feb 23, 2024:

We can probably keep the existing packing logic and CUDA kernels in uint8 and just use torch.view based on an argument the user specifies, such as quant_storage; during quantization and dequantization it would view back and forth. This would make sure the parameters end up in the desired dtype, I suppose.
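
Something along these lines, as a sketch of that idea (not hqq's actual API; quant_storage is the hypothetical user-facing argument):

import torch

def to_quant_storage(packed_u8: torch.Tensor, quant_storage=torch.bfloat16) -> torch.Tensor:
    # Reinterpret the packed uint8 bytes in the dtype FSDP should see (no value cast);
    # assumes the last dimension's byte count is divisible by the storage dtype's size.
    return packed_u8.view(quant_storage)

def from_quant_storage(stored: torch.Tensor) -> torch.Tensor:
    # View back to uint8 right before handing the buffer to the existing kernels.
    return stored.view(torch.uint8)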

@mobicham (Collaborator) commented:

Yeah, do W_q.view(torch.int32) on the PyTorch side and the rest on CUDA.

It only works with float32; it didn't work with float16/int16.

@mobicham (Collaborator) commented:

By the way, for your 3-bit test you can still use the existing 3-bit CUDA kernels; you just need to replace this:
https://github.com/mobiusml/hqq/blob/master/hqq/core/quantize.py#L446
with:
W_est = self.dequantize_Wq_aten(W_q.view(torch.int32), meta)

Hope it works!
