HQQ v.1.0
mobicham committed Nov 20, 2023
1 parent a310f10 commit 0e7b6ca
Showing 13 changed files with 962 additions and 0 deletions.
Empty file added code/hqq/__init__.py
Empty file.
Empty file added code/hqq/models/__init__.py
Empty file.
187 changes: 187 additions & 0 deletions code/hqq/models/base.py
@@ -0,0 +1,187 @@
#Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################

import torch
import gc, os
from tqdm import tqdm
from abc import abstractmethod

from huggingface_hub import snapshot_download
from ..quantize.core import HQQLinear

def cleanup():
    torch.cuda.empty_cache()
    gc.collect()

def fix_path(path):
    if(len(path)==0): return path
    return path + '/' if (path[-1]!='/') else path

#Base patching class. Patching defines how nn.Linear and other layers are replaced via a patching function.
class BasePatch():
    #Override these OR override the main patch_model() function
    ############################################
    #This method iterates through the layers of the model that are NOT nn.Linear and processes them via new_module = patch_fct(module, params)
    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        pass

    #This method iterates through the layers of the model that are nn.Linear and processes them via new_module = patch_fct(module, params)
    @classmethod
    def patch_linearlayers(cls, base_model, patch_fct, patch_params, verbose=True):
        pass
    ############################################
    #These tags are used to specify the parameters of the patching in patch_linearlayers()
    @classmethod
    def get_linear_tags(cls):
        return []

    #Automatically name modules. This is very important for saving/loading the weights.
    @classmethod
    def autoname_modules(cls, model):
        for name, module in model.named_modules():
            module.name = name

    #Freeze all layers
    @classmethod
    def freeze_model(cls, model):
        for param in model.parameters():
            param.requires_grad = False
        try:
            for param in model.model.parameters():
                param.requires_grad = False
        except:
            pass

    #Main patching function
    @classmethod
    def patch_model(cls, model, patch_nonlinear_fct, patch_linear_fct, patch_params, verbose=True):
        model.eval()
        cls.freeze_model(model)
        cls.patch_nonlinearlayers(model, patch_nonlinear_fct, verbose=verbose)
        cls.patch_linearlayers(model, patch_linear_fct, patch_params, verbose=verbose)
        cls.autoname_modules(model)
        cleanup()


class BaseHQQModel:
    #Override these
    ############################################
    #This method creates an empty model based on the specified architecture
    @abstractmethod
    def create_model(self):
        pass

    #This method saves the model architecture only, without including the weights (for example, to a config.json)
    @abstractmethod
    def cache_model(cls, model, save_dir):
        pass
    ############################################

    @classmethod
    def get_config_file(cls, save_dir):
        return fix_path(save_dir) + 'config.json'

    @classmethod
    def get_weight_file(cls, save_dir):
        return fix_path(save_dir) + 'qmodel.pt'

    @classmethod
    def get_ignore_layers(cls, model):
        return []

    @classmethod
    def save_weights(cls, weights, save_dir):
        torch.save(weights, cls.get_weight_file(save_dir))

    @classmethod
    def load_weights(cls, save_dir):
        return torch.load(cls.get_weight_file(save_dir))

    @classmethod
    def quantize_model(cls, model, quant_config):
        #Use the same quantization config for all linear layers. Use None to skip quantizing a specific layer.
        patch_params = dict([(k, quant_config) for k in cls.get_linear_tags()])

        #We replace the nn.Linear layers with HQQLinear
        def _patch_linear(linear_layer, quant_config):
            return HQQLinear(linear_layer, quant_config) if (quant_config is not None) else linear_layer

        cls.patch_model(model, lambda l: l.half().cuda(), _patch_linear, patch_params)

    @classmethod
    def save_quantized(cls, model, save_dir, verbose=False):
        #Save config
        cls.cache_model(model, save_dir)

        #Save weights
        weights     = {}
        ignore_keys = cls.get_ignore_layers(model)
        for name, module in model.named_modules():
            if(name in ignore_keys): continue
            try:
                state_dict = module.state_dict()
                if(len(state_dict)>0):
                    weights[name] = dict(state_dict)
            except Exception as error:
                if(verbose):
                    print('Skipping', name)

        cls.save_weights(weights, save_dir)

    @classmethod
    def try_snapshot_download(cls, save_dir_or_hub, cache_dir=''):
        save_dir = fix_path(cache_dir) + save_dir_or_hub

        if(os.path.exists(save_dir)==False):
            save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir)
            save_dir = fix_path(save_dir)

        #Check
        if(os.path.exists(cls.get_weight_file(save_dir))==False):
            raise Exception('Weight file missing. Check your cache directory.')
        if(os.path.exists(cls.get_config_file(save_dir))==False):
            raise Exception('Config file missing. Check your cache directory.')

        return save_dir

    @classmethod
    def from_quantized(cls, save_dir_or_hub, cache_dir=''):
        #Get directory path
        save_dir = cls.try_snapshot_download(save_dir_or_hub, cache_dir)

        #Load model from config
        model = cls.create_model(save_dir)

        #Name the layers
        cls.autoname_modules(model)

        #Load weights
        try:
            weights = cls.load_weights(save_dir)
        except Exception as error:
            print("Failed to load the weights", error)
            return

        #load_state_dict() doesn't work with modules initialized with init_empty_weights(), so we need to do this manually
        @torch.no_grad()
        def _load_module(module, params=None):
            if(module.name not in weights):
                return module.half().cuda()

            state_dict = weights[module.name]
            if(('W_q' in state_dict) and ('meta' in state_dict)):
                module = HQQLinear(linear_layer=None, quant_config=None)
                module.load_state_dict(state_dict)
            else:
                for key in state_dict:
                    setattr(module, key, torch.nn.Parameter(state_dict[key]))

            return module

        cls.patch_model(model, _load_module, _load_module, dict([(k, None) for k in cls.get_linear_tags()]))

        return model
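
The base classes above are meant to be subclassed per architecture: a patch class lists its linear-layer tags and walks the model to apply the patching callbacks, while BaseHQQModel adds quantize/save/load on top of that. Below is a minimal, hypothetical sketch of how the pieces fit together; nothing in it is part of this commit (the TinyModel/TinyPatch/TinyHQQ names, the toy architecture, and the `hqq.models.base` import path are assumptions based on the code/hqq/ layout). The llama.py and vit.py files that follow do the same for real architectures.

#Hypothetical illustration, not part of this commit. Assumes the package is
#importable as `hqq` (the files above live under code/hqq/).
import os, json
import torch
from hqq.models.base import BasePatch, BaseHQQModel

class TinyModel(torch.nn.Module):
    #Toy two-layer network, used only to illustrate the patching contract
    def __init__(self, dim=128):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim, dim)
        self.act = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(dim, dim)

class TinyPatch(BasePatch):
    #One tag per nn.Linear that may be quantized
    @classmethod
    def get_linear_tags(cls):
        return ['fc1', 'fc2']

    #Layers that are not nn.Linear are simply passed through patch_fct (e.g. half().cuda())
    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        model.act = patch_fct(model.act)

    #Each nn.Linear is replaced by whatever patch_fct returns (an HQQLinear when quantizing)
    @classmethod
    def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True):
        model.fc1 = patch_fct(model.fc1, patch_params['fc1'])
        model.fc2 = patch_fct(model.fc2, patch_params['fc2'])

class TinyHQQ(TinyPatch, BaseHQQModel):
    #Store only what is needed to rebuild the (empty) architecture
    @classmethod
    def cache_model(cls, model, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        with open(cls.get_config_file(save_dir), 'w') as f:
            json.dump({'dim': model.fc1.in_features}, f)

    #Rebuild the architecture from the saved config; the weights are loaded separately
    @classmethod
    def create_model(cls, save_dir):
        with open(cls.get_config_file(save_dir), 'r') as f:
            config = json.load(f)
        return TinyModel(dim=config['dim'])

With such a subclass in place, the quantize_model(), save_quantized() and from_quantized() entry points defined above work unchanged.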



65 changes: 65 additions & 0 deletions code/hqq/models/llama.py
@@ -0,0 +1,65 @@
from .base import *

from tqdm import tqdm
from accelerate import init_empty_weights
import transformers

#Patch LLama functions
class LLamaPatch(BasePatch):
    #These tags are used to specify the parameters of each layer type, for example if you want to give different quantization parameters to different layers
    @classmethod
    def get_linear_tags(cls):
        return ['self_attn.q_proj',
                'self_attn.k_proj',
                'self_attn.v_proj',
                'self_attn.o_proj',
                'mlp.gate_proj',
                'mlp.up_proj',
                'mlp.down_proj']

    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        base_model              = model.model
        model.lm_head           = patch_fct(model.lm_head)
        base_model.embed_tokens = patch_fct(base_model.embed_tokens)
        base_model.norm         = patch_fct(base_model.norm)

        layers = base_model.layers
        for i in tqdm(range(len(base_model.layers)), disable=not verbose):
            layers[i].self_attn.rotary_emb         = patch_fct(layers[i].self_attn.rotary_emb)
            layers[i].mlp.act_fn                   = patch_fct(layers[i].mlp.act_fn)
            layers[i].input_layernorm              = patch_fct(layers[i].input_layernorm)
            layers[i].post_attention_layernorm     = patch_fct(layers[i].post_attention_layernorm)

    @classmethod
    def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True):
        base_model = model.model
        layers     = base_model.layers
        for i in tqdm(range(len(layers)), disable=not verbose):
            layers[i].self_attn.q_proj = patch_fct(layers[i].self_attn.q_proj, patch_params['self_attn.q_proj'])
            layers[i].self_attn.k_proj = patch_fct(layers[i].self_attn.k_proj, patch_params['self_attn.k_proj'])
            layers[i].self_attn.v_proj = patch_fct(layers[i].self_attn.v_proj, patch_params['self_attn.v_proj'])
            layers[i].self_attn.o_proj = patch_fct(layers[i].self_attn.o_proj, patch_params['self_attn.o_proj'])
            layers[i].mlp.gate_proj    = patch_fct(layers[i].mlp.gate_proj, patch_params['mlp.gate_proj'])
            layers[i].mlp.up_proj      = patch_fct(layers[i].mlp.up_proj, patch_params['mlp.up_proj'])
            layers[i].mlp.down_proj    = patch_fct(layers[i].mlp.down_proj, patch_params['mlp.down_proj'])


class LlamaHQQ(LLamaPatch, BaseHQQModel):
    #Layers to ignore when saving the weights
    @classmethod
    def get_ignore_layers(cls, model):
        return ['', 'model', 'model.layers'] + ['model.layers.' + str(i) for i in range(len(model.model.layers))]

    #Save model architecture
    @classmethod
    def cache_model(cls, model, save_dir):
        model.config.save_pretrained(save_dir)

    #Create empty model
    @classmethod
    def create_model(cls, save_dir):
        config = transformers.AutoConfig.from_pretrained(cls.get_config_file(save_dir))
        with init_empty_weights():
            model = transformers.LlamaForCausalLM(config)
        return model
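
A hedged usage sketch of the Llama wrapper above: the checkpoint name and save directory are placeholders, and quant_config must be built with the helpers from hqq.quantize.core, which is part of this commit but not expanded in this view.

#Hypothetical usage sketch; model name, save path, and quant_config are placeholders.
import transformers
from hqq.models.llama import LlamaHQQ

model_id = 'meta-llama/Llama-2-7b-hf'  #placeholder checkpoint
model    = transformers.AutoModelForCausalLM.from_pretrained(model_id)

quant_config = ...  #build with the helpers from hqq.quantize.core (not expanded above)

#Replace every tagged nn.Linear with an HQQLinear; other layers are cast to fp16 and moved to CUDA
LlamaHQQ.quantize_model(model, quant_config=quant_config)

#Write config.json + qmodel.pt to disk, then restore the model later
LlamaHQQ.save_quantized(model, save_dir='llama2-7b-hqq/')
model = LlamaHQQ.from_quantized('llama2-7b-hqq/')
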
65 changes: 65 additions & 0 deletions code/hqq/models/vit.py
@@ -0,0 +1,65 @@
from .base import *

from tqdm import tqdm
import timm, json, os

#Patch ViT functions
class VitPatch(BasePatch):
    #These tags are used to specify the parameters of each layer type, for example if you want to give different quantization parameters to different layers
    @classmethod
    def get_linear_tags(cls):
        return ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']

    @classmethod
    def freeze_model(cls, model):
        for param in model.parameters():
            param.requires_grad = False

    @classmethod
    def patch_nonlinearlayers(cls, model, patch_fct, verbose=True):
        model.patch_embed.proj = patch_fct(model.patch_embed.proj)
        model.patch_embed.norm = patch_fct(model.patch_embed.norm)
        model.norm_pre         = patch_fct(model.norm_pre)
        model.norm             = patch_fct(model.norm)
        model.head             = patch_fct(model.head)
        model.cls_token.data   = patch_fct(model.cls_token.data)
        model.pos_embed.data   = patch_fct(model.pos_embed.data)

        for i in tqdm(range(len(model.blocks)), disable=not verbose):
            model.blocks[i].norm1 = patch_fct(model.blocks[i].norm1)
            model.blocks[i].norm2 = patch_fct(model.blocks[i].norm2)

    @classmethod
    def patch_linearlayers(cls, model, patch_fct, patch_params, verbose=True):
        for i in tqdm(range(len(model.blocks)), disable=not verbose):
            model.blocks[i].attn.qkv  = patch_fct(model.blocks[i].attn.qkv, patch_params['attn.qkv'])
            model.blocks[i].attn.proj = patch_fct(model.blocks[i].attn.proj, patch_params['attn.proj'])
            model.blocks[i].mlp.fc1   = patch_fct(model.blocks[i].mlp.fc1, patch_params['mlp.fc1'])
            model.blocks[i].mlp.fc2   = patch_fct(model.blocks[i].mlp.fc2, patch_params['mlp.fc2'])


class ViTHQQ(VitPatch, BaseHQQModel):
    #Layers to ignore when saving the weights
    @classmethod
    def get_ignore_layers(cls, model):
        return ['', 'model', 'model.blocks'] + ['model.blocks.' + str(i) for i in range(len(model.blocks))]

    #Save model architecture
    @classmethod
    def cache_model(cls, model, save_dir):
        try:
            os.makedirs(save_dir, exist_ok=True)
        except Exception as error:
            print(error)

        with open(cls.get_config_file(save_dir), "w") as file:
            json.dump(model.default_cfg, file)

    #Create empty model
    @classmethod
    def create_model(cls, save_dir):
        with open(cls.get_config_file(save_dir), "r") as file:
            config = json.load(file)

        model = timm.create_model(config['architecture'] + '.' + config['tag'], pretrained=True)
        return model
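
The ViT wrapper follows the same pattern. Another hedged sketch: the timm model name and save path are placeholders, and quant_config is again left as a placeholder for the helpers in hqq.quantize.core.

#Hypothetical usage sketch; the timm model name, save path, and quant_config are placeholders.
import timm
from hqq.models.vit import ViTHQQ

model = timm.create_model('vit_base_patch16_224.augreg_in21k', pretrained=True)

quant_config = ...  #build with the helpers from hqq.quantize.core (not expanded above)

ViTHQQ.quantize_model(model, quant_config=quant_config)
ViTHQQ.save_quantized(model, save_dir='vit-base-hqq/')
model = ViTHQQ.from_quantized('vit-base-hqq/')
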
Empty file added code/hqq/quantize/__init__.py
Empty file.
