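This notebook dequantizes the 4-bit (QLoRA) base model back to 16-bit weights before loading and merging the LoRA adapter into it. This approach is not currently implemented by default, and it can improve generation quality. Source: https://twitter.com/Tim_Dettmers/status/1695352747694919931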

In [1]:
!pip install trl transformers accelerate git+https://github.com/huggingface/peft.git -Uqqq
!pip install datasets bitsandbytes einops -Uqqq

In [3]:
import torch
import peft
import json
import shutil
from peft.utils import _get_submodules
import os
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
import gc
import copy

Based on: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930

In [4]:
def _dequantize_linear4bit(module, dtype, device):
    """Return a plain ``torch.nn.Linear`` equivalent to a ``bnb.nn.Linear4bit`` layer.

    'module': the ``bnb.nn.Linear4bit`` layer to convert.
    'dtype': dtype of the dequantized weight (and bias, if the layer has one).
    'device': device the new layer is moved to.
    """
    # Deep-copy so that overriding the output dtype does not mutate the quantized layer's own state.
    quant_state = copy.deepcopy(module.weight.quant_state)
    if isinstance(quant_state, list):
        # List-style quant state (older bitsandbytes): index 2 holds the output dtype.
        quant_state[2] = dtype
    else:
        # Object-style quant state (newer bitsandbytes ``QuantState``) exposes ``dtype`` as an attribute.
        quant_state.dtype = dtype

    # Use the layer's own 4-bit format rather than assuming NF4; fall back to "nf4" if it is not recorded.
    quant_type = getattr(module.weight, "quant_type", "nf4")
    weights = dequantize_4bit(module.weight.data, quant_state=quant_state, quant_type=quant_type).to(dtype)

    has_bias = module.bias is not None
    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=has_bias, dtype=dtype)
    new_module.weight = torch.nn.Parameter(weights)
    if has_bias:
        # Carry the existing bias over instead of dropping it.
        new_module.bias = torch.nn.Parameter(module.bias.data.to(dtype))
    new_module.to(device=device, dtype=dtype)
    return new_module


def _save_dequantized_model(model, tokenizer, save_to):
    """Save model + tokenizer to ``save_to`` and strip quantization settings from the saved config.json."""
    print("Saving dequantized model...")
    model.save_pretrained(save_to)
    tokenizer.save_pretrained(save_to)

    config_path = os.path.join(save_to, 'config.json')
    with open(config_path, 'r') as config_file:
        config_data = json.load(config_file)
    # The weights are no longer quantized, so the quantization config must not be reloaded.
    # `pretraining_tp` is removed as well (kept from the source gist).
    config_data.pop("quantization_config", None)
    config_data.pop("pretraining_tp", None)
    with open(config_path, 'w') as config_file:
        json.dump(config_data, config_file, indent=2)


def dequantize_model(model, tokenizer, save_to='./dequantized_model', dtype=torch.float16, device="cuda"):
    """
    Replace every ``bnb.nn.Linear4bit`` layer in ``model`` with a plain ``torch.nn.Linear``
    holding the dequantized weights, optionally saving the result.

    'model': a model loaded in 4-bit with bitsandbytes (e.g. a QLoRA base model); modified in place.
    'tokenizer': the model's corresponding hf tokenizer (only used when saving).
    'save_to': directory to save the dequantized model to. Any existing directory at this
               path is deleted first. Pass None to skip saving.
    'dtype': dtype to dequantize the weights into (e.g. the dtype the model was trained with).
    'device': device the new layers are moved to.

    Returns the same ``model`` object.
    """
    if save_to is not None:
        # Start from an empty output directory.
        if os.path.exists(save_to):
            shutil.rmtree(save_to)
        os.makedirs(save_to, exist_ok=True)

    with torch.no_grad():
        # Snapshot the module list first: layers are swapped out while we walk the tree.
        for name, module in list(model.named_modules()):
            if isinstance(module, bnb.nn.Linear4bit):
                print(f"Dequantizing `{name}`...")
                new_module = _dequantize_linear4bit(module, dtype, device)
                parent, _, target_name = _get_submodules(model, name)
                setattr(parent, target_name, new_module)

        # a hack, setting this to avoid hf's saving error because hf
        # itself does not support saving a model that is registered to be loaded in 4bit.
        model.is_loaded_in_4bit = False

        if save_to is not None:
            _save_dequantized_model(model, tokenizer, save_to)

        return model

In [5]:
model_path = "abhishek/llama-2-7b-hf-small-shards"
adapter_path = "nihiluis/finadv100"

# NOTE(review): weights are loaded as float16 (torch_dtype below) but the 4-bit compute dtype is
# bfloat16 — confirm this matches the setup the adapter was trained with.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

try:
    print(f"Starting to load the model {model_path} into memory")

    # 4-bit loading is configured solely via `quantization_config`; newer transformers releases
    # raise an error if `load_in_4bit` is also passed as a separate kwarg.
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        quantization_config=quantization_config,
        device_map={"": 0},
    )
    print(model)
    tok = LlamaTokenizer.from_pretrained(model_path)

    # Dequantize in memory (no save) to the same dtype the weights were loaded in, then merge the adapter.
    model = dequantize_model(model, tok, save_to=None, dtype=torch.float16)
    print(model)
    model = PeftModel.from_pretrained(model=model, model_id=adapter_path)
    print(model)
    model = model.merge_and_unload()
    print(model)

    print(f"Successfully loaded the model {model_path} into memory")

except Exception as e:
    print(f"An error occurred: {e}")

    # Delete the model object if it exists
    if 'model' in locals():
        del model

    # Collect garbage first so freed tensors can actually be released from the CUDA cache.
    gc.collect()
    torch.cuda.empty_cache()

    print("Model, GPU cache, and garbage have been cleared.")
    # Re-raise so "Run All" stops here instead of failing later on a missing `model`.
    raise

Starting to load the model abhishek/llama-2-7b-hf-small-shards into memory


Downloading (…)lve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/10 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00010.bin:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00010.bin:   0%|          | 0.00/2.88G [00:00<?, ?B/s]

Downloading (…)l-00003-of-00010.bin:   0%|          | 0.00/2.99G [00:00<?, ?B/s]

Downloading (…)l-00004-of-00010.bin:   0%|          | 0.00/2.86G [00:00<?, ?B/s]

Downloading (…)l-00005-of-00010.bin:   0%|          | 0.00/2.88G [00:00<?, ?B/s]

Downloading (…)l-00006-of-00010.bin:   0%|          | 0.00/2.97G [00:00<?, ?B/s]

Downloading (…)l-00007-of-00010.bin:   0%|          | 0.00/2.88G [00:00<?, ?B/s]

Downloading (…)l-00008-of-00010.bin:   0%|          | 0.00/2.99G [00:00<?, ?B/s]

Downloading (…)l-00009-of-00010.bin:   0%|          | 0.00/2.86G [00:00<?, ?B/s]

Downloading (…)l-00010-of-00010.bin:   0%|          | 0.00/705M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/174 [00:00<?, ?B/s]



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )


Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

Dequantizing `model.layers.0.self_attn.q_proj`...
Dequantizing `model.layers.0.self_attn.k_proj`...
Dequantizing `model.layers.0.self_attn.v_proj`...
Dequantizing `model.layers.0.self_attn.o_proj`...
Dequantizing `model.layers.0.mlp.gate_proj`...
Dequantizing `model.layers.0.mlp.up_proj`...
Dequantizing `model.layers.0.mlp.down_proj`...
Dequantizing `model.layers.1.self_attn.q_proj`...
Dequantizing `model.layers.1.self_attn.k_proj`...
Dequantizing `model.layers.1.self_attn.v_proj`...
Dequantizing `model.layers.1.self_attn.o_proj`...
Dequantizing `model.layers.1.mlp.gate_proj`...
Dequantizing `model.layers.1.mlp.up_proj`...
Dequantizing `model.layers.1.mlp.down_proj`...
Dequantizing `model.layers.2.self_attn.q_proj`...
Dequantizing `model.layers.2.self_attn.k_proj`...
Dequantizing `model.layers.2.self_attn.v_proj`...
Dequantizing `model.layers.2.self_attn.o_proj`...
Dequantizing `model.layers.2.mlp.gate_proj`...
Dequantizing `model.layers.2.mlp.up_proj`...
Dequantizing `model.layers.2.m