In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "microsoft/NatureLM-8x7B"

# `trust_remote_code=True` executes tokenizer code downloaded from the Hub
# (nlm_tokenizer.py, science_tokens.py). "main" is the library default; replace it
# with a reviewed commit hash so an upstream change cannot silently swap in new code.
MODEL_REVISION = "main"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    revision=MODEL_REVISION,
)

A new version of the following files was downloaded from https://huggingface.co/microsoft/NatureLM-8x7B:
- science_tokens.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/microsoft/NatureLM-8x7B:
- nlm_tokenizer.py
- science_tokens.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


In [3]:
# The model may not fit on a single GPU, so you may need to use model parallelism.
# Alternatively, load the model in 8-bit (see 8x7b_inst_model.ipynb for more details).


def create_device_map(n_parallel=2, start_device=0, n_layers=32):
    """Build an explicit ``device_map`` that splits the decoder layers across devices.

    Decoder layers are assigned to devices in contiguous chunks. The token
    embedding goes on the first device. The final norm and the LM head go on
    the device that holds the last decoder layer.

    Args:
        n_parallel: Number of devices to spread the layers over. Must be >= 1.
            At most ``n_parallel`` devices are used; when ``n_parallel`` exceeds
            ``n_layers``, some trailing devices may receive nothing.
        start_device: Index of the first device. Devices used lie in
            ``start_device .. start_device + n_parallel - 1``.
        n_layers: Number of decoder layers. Defaults to 32, which is assumed to
            match this checkpoint (TODO confirm against
            ``config.num_hidden_layers``).

    Returns:
        dict mapping module/parameter names to integer device indices, in the
        form accepted by the ``device_map`` argument of ``from_pretrained`` and
        ``load_checkpoint_and_dispatch``.

    Raises:
        ValueError: if ``n_parallel`` or ``n_layers`` is less than 1.
    """
    if n_parallel < 1:
        raise ValueError(f"n_parallel must be >= 1, got {n_parallel}")
    if n_layers < 1:
        raise ValueError(f"n_layers must be >= 1, got {n_layers}")

    # Ceiling division. With floor division, a layer count that n_parallel does
    # not divide (e.g. 32 layers on 3 devices) spilled the remainder onto an
    # extra device `start_device + n_parallel`, and n_parallel > n_layers
    # divided by zero.
    layers_per_device = (n_layers + n_parallel - 1) // n_parallel
    last_device = (n_layers - 1) // layers_per_device + start_device

    device_map = {"model.embed_tokens.weight": start_device}
    for i in range(n_layers):
        device_map[f"model.layers.{i}"] = i // layers_per_device + start_device

    # Keep the final norm and LM head with the last decoder layer so the last
    # hidden state does not need an extra cross-device hop.
    device_map["model.norm.weight"] = last_device
    device_map["lm_head.weight"] = last_device

    return device_map

In [None]:
# Load the checkpoint and place each module on the device given by the explicit
# map, in a single call.
# - `load_checkpoint_and_dispatch` expects a *local* checkpoint file or folder. A
#   Hub repo id like `model_name` only works if a local directory with that exact
#   relative path exists.
# - Calling `from_pretrained` inside `init_empty_weights()` still downloads the
#   full checkpoint just to build an empty skeleton.
# `no_split_module_classes` only matters when a device map is inferred
# automatically, so it is not needed with an explicit map.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=create_device_map(n_parallel=2, start_device=0),
    torch_dtype=torch.bfloat16,
)

In [None]:
# Sampling is stochastic. Fix the seed so re-running the notebook reproduces the output.
torch.manual_seed(0)

# Prompt: a `<mol>` tag followed by "C". Move the tensors to the device that holds
# the embedding layer (`model.device`).
tokens = tokenizer("<mol>C", return_tensors="pt").to(model.device)

output = model.generate(
    input_ids=tokens.input_ids,
    # Pass the mask explicitly rather than letting generate() infer it (and warn).
    attention_mask=tokens.attention_mask,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
)

# Strip "<m>" markers from the decoded text. These are presumably emitted by the
# science tokenizer inside molecule spans (TODO confirm).
print(tokenizer.decode(output[0]).replace("<m>", ""))