Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HQQ FSDP #17

Merged
merged 8 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
160 changes: 160 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
65 changes: 49 additions & 16 deletions hqq/core/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ class Quantizer:
'2bit_u8':BitPack.unpack_2bit_u8,
'1bit_u8':BitPack.unpack_1bit_u8}

unpack_view_dtype = {'8bit_u8':torch.uint8,
'4bit_u8':torch.uint8,
'3bit_32':torch.int32,
'2bit_u8':torch.uint8,
'1bit_u8':torch.uint8}

@classmethod
def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True):
def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False):
assert nbits in Quantizer.SUPPORTED_BITS, "nbits=" + str(nbits) + " not supported."
assert axis in [0, 1], "axis should be either 0 or 1"
if(group_size is not None):
Expand Down Expand Up @@ -70,10 +76,13 @@ def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=Fa

#Store meta-data (we invert the scale for dequantization)
meta = {'nbits':nbits, 'group_size':group_size, 'shape':shape, 'scale':1./scale, 'zero':zero, 'axis':axis, 'packing':Quantizer.bit_to_packing[nbits]}

meta['unpack_view_dtype'] = Quantizer.unpack_view_dtype[meta['packing']]

#Pack bits
meta['view_as_float'] = view_as_float
if(bitpack):
W_q = Quantizer.pack[meta['packing']](W_q)
if view_as_float: W_q = W_q.view(torch.float32 if compute_dtype is None else compute_dtype) # store quantized weights as compute_dtype
else:
W_q = W_q.to(tensor.dtype)
meta['packing'] = None
Expand All @@ -89,6 +98,7 @@ def quantize(cls, tensor, nbits=4, channel_wise=True, group_size=64, optimize=Fa
def dequantize(cls, W_q, meta):
compute_dtype = meta['compute_dtype'] if ('compute_dtype' in meta) else torch.float16
if(meta['packing']):
if meta['view_as_float']: W_q = W_q.view(meta['unpack_view_dtype'])
W_r = Quantizer.unpack[meta['packing']](W_q).to(compute_dtype)
if((meta['group_size'] is not None) and (meta['nbits']==3)):
W_r = W_r[:meta['group_size']] if (meta['axis']==0) else W_r[:,:meta['group_size']]
Expand Down Expand Up @@ -248,7 +258,7 @@ def backward(ctx, grad_output):
class HQQLinear(torch.nn.Module):
backend = HQQBackend.PYTORCH #Default

def __init__(self, linear_layer, quant_config, del_orig=True, compute_dtype=torch.float16, device_n=0):
def __init__(self, linear_layer, quant_config, del_orig=True, compute_dtype=torch.float16, device_n=0, initialize=True):
super().__init__()
self.ready = False
self.in_gpu = False
Expand All @@ -257,17 +267,23 @@ def __init__(self, linear_layer, quant_config, del_orig=True, compute_dtype=torc
self.device_n = device_n
self.compute_dtype = compute_dtype
self.quant_config = copy.deepcopy(quant_config)
self.del_orig = del_orig
self.offload_meta = self.quant_config.pop('offload_meta') if (self.quant_config is not None) else None

self.set_backend(HQQLinear.backend) #Default backend

if(linear_layer is not None):
self.bias = None if (linear_layer.bias==None) else linear_layer.bias.to(self.compute_dtype).cuda()
self.quantize(linear_layer.weight.data, **self.quant_config)
self.linear_layer = linear_layer

if(del_orig): del linear_layer
torch.cuda.empty_cache()
if(initialize): self.initialize()

def initialize(self):
if(self.linear_layer is not None):
self.quantize(self.linear_layer.weight.data, **self.quant_config)
self.bias = None if (self.linear_layer.bias==None) else self.linear_layer.bias.to(self.compute_dtype).cuda()

if(self.del_orig): del self.linear_layer
torch.cuda.empty_cache()

#Set backends
@classmethod
def set_backend(cls, backend: HQQBackend):
Expand Down Expand Up @@ -312,13 +328,30 @@ def to(self, *args, **kwargs):
def half(self, *args, **kwargs):
return self

def state_dict(self):
def state_dict(self, *args, **kwargs):
return {'W_q':self.W_q, 'meta':self.meta, 'bias':self.bias}

def load_state_dict(self, state_dict):
self.W_q = state_dict['W_q']
self.meta = state_dict['meta']
self.bias = state_dict['bias'] if ('bias' in state_dict) else None

#Float view settings
if('unpack_view_dtype' not in self.meta):
self.meta['unpack_view_dtype'] = Quantizer.unpack_view_dtype[self.meta['packing']]

if('view_as_float' not in self.meta):
self.meta['view_as_float'] = False

if('meta_scale' in self.meta):
if('view_as_float' not in self.meta['meta_scale']):
self.meta['meta_scale']['view_as_float'] = False

if('meta_zero' in self.meta):
if('view_as_float' not in self.meta['meta_zero']):
self.meta['meta_zero']['view_as_float'] = False

#Check GPU
self.in_gpu = self.W_q.device.type == 'cuda'
if(self.in_gpu):
if('zero' in self.meta):
Expand All @@ -340,15 +373,15 @@ def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params
self.in_features, self.out_features = W.t().shape

#Quantize
W_q , meta = Quantizer.quantize(W, **weight_quant_params)
W_q , meta = Quantizer.quantize(W, compute_dtype=self.compute_dtype, **weight_quant_params)
meta.update({'quant_scale':quant_scale, 'quant_zero':quant_zero})

if(meta['quant_zero']):
meta['zero_q'], meta['meta_zero'] = Quantizer.quantize(meta['zero'], **zero_quant_params); del meta['zero']
meta['zero_q'], meta['meta_zero'] = Quantizer.quantize(meta['zero'], view_as_float=False, **zero_quant_params); del meta['zero']
meta['meta_zero']['compute_dtype'] = self.compute_dtype

if(meta['quant_scale']):
meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], **scale_quant_params); del meta['scale']
meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], view_as_float=False, **scale_quant_params); del meta['scale']
meta['meta_scale']['compute_dtype'] = self.compute_dtype

self.W_q = W_q
Expand Down Expand Up @@ -414,6 +447,7 @@ def forward_pytorch_compile(self, x):
#Requires building the aten backend
@torch.jit.ignore
def dequantize_Wq_aten(self, W_q, meta):
if meta['view_as_float']: W_q = W_q.view(meta['unpack_view_dtype'])
return hqq_aten.dequantize(W_q, meta['scale'], meta['zero'], meta['shape'], meta['group_size'] if (meta['group_size']) else -1, meta['nbits'], meta['axis'], meta['packing'])

def dequantize_aten(self):
Expand Down Expand Up @@ -444,7 +478,7 @@ def dequantize_aten(self):
meta['zero'] = self.dequantize_Wq_aten(meta['zero_q'], meta['meta_zero']); del_keys.append('zero')
else:
meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero')

W_est = self.dequantize_Wq_aten(W_q, meta)

#Cleanup
Expand Down Expand Up @@ -503,11 +537,11 @@ def forward_aten_backprop(self, x):
# return hqq_aten.forward_with_quant(*args)


def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False, offload_meta=False):
def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False, offload_meta=False, view_as_float=False):
assert nbits in Quantizer.SUPPORTED_BITS, "nbits value not supported. Check Quantizer.SUPPORTED_BITS."
if(group_size is not None):
assert is_divisible(group_size, 8), "Invalid group_size param: the value should be a multiple of 8."
weight_quant_params = {'nbits':nbits,'channel_wise':True, 'group_size':group_size, 'optimize':True, 'round_zero':True if nbits==4 else False}
weight_quant_params = {'nbits':nbits,'channel_wise':True, 'group_size':group_size, 'optimize':True, 'round_zero':True if nbits==4 else False, 'view_as_float':view_as_float}

if(offload_meta):
if((quant_scale!=quant_zero)):
Expand All @@ -521,7 +555,6 @@ def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=F
scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None
zero_quant_params = {'nbits':8, 'channel_wise':False, 'group_size':None, 'optimize':False} if (quant_zero) else None


return {'weight_quant_params':weight_quant_params, 'scale_quant_params':scale_quant_params, 'zero_quant_params':zero_quant_params, 'offload_meta':offload_meta}

#Alias: follow similar Auto-GPTQ naming
Expand Down