In [1]:
import os

import torch
import transformers
from torch import cuda, bfloat16

  from .autonotebook import tqdm as notebook_tqdm


#### Load the model

In [2]:
# Hugging Face repo of the checkpoint; shared by the model and the tokenizer
# so the two can never drift apart.
model_id = 'meta-llama/Llama-2-70b-chat-hf'

# Release cached, unused GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()
# Optional scaled-dot-product-attention kernel toggles, left disabled:
# torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_flash_sdp(False)

# Default device for tensors created in this notebook. The model itself is
# placed by `device_map='auto'` below and may span several devices.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# Quantization configuration to load the large model with less GPU memory.
# Requires the `bitsandbytes` library (author's note: about 35 GB of memory).
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)

# The checkpoint is gated, so an access token is needed. Never commit a token
# to the notebook: read it from the HF_TOKEN environment variable. When the
# variable is unset this is None, and transformers falls back to the
# credentials cached by `huggingface-cli login`.
# SECURITY: a token that was ever stored in this notebook must be revoked.
hf_auth = os.environ.get('HF_TOKEN')

# NOTE(review): `trust_remote_code=True` lets code from the model repo run
# locally; it is probably unnecessary for this checkpoint -- confirm and drop.
# Do not call `model.to(device)` afterwards: bitsandbytes 4-bit/8-bit models
# raise ValueError on `.to`, and `device_map='auto'` already places the weights.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map='auto',
    token=hf_auth
)

# Report where the weights actually ended up rather than the locally computed
# `device`, which is wrong whenever the model is sharded across devices.
placement = getattr(model, 'hf_device_map', None)
used_devices = sorted({str(d) for d in placement.values()}) if placement else [device]
print(f"Model loaded on {', '.join(used_devices)}")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, token=hf_auth)

Loading checkpoint shards: 100%|██████████| 15/15 [04:28<00:00, 17.93s/it]
You shouldn't move a model when it is dispatched on multiple devices.


ValueError: `.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.

##### Run inference

In [None]:
# Generation settings forwarded through the pipeline to the model.
generation_kwargs = {
    # Sampling temperature: lower values make the output less random.
    'temperature': 0.1,
    # Upper bound on the number of tokens generated per call.
    'max_new_tokens': 512,
    # Mild penalty on repeated tokens; without it the output begins repeating.
    'repetition_penalty': 1.1,
}

# Text-generation pipeline wrapping the already-loaded model and tokenizer.
generate_text = transformers.pipeline(
    task='text-generation',
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the prompt plus the completion
    **generation_kwargs
)

In [None]:
# Smoke-test the pipeline with a single question and show the generated text.
prompt = "What is the difference between Lasso regression and Ridge regression?"
res = generate_text(prompt)
first_completion = res[0]
print(first_completion["generated_text"])