 ### Loading large models on multi GPU with accelerate (up to 30b on 2 GPU)

In [1]:
import torch
import json
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from baukit import nethook
import re
import time
import os
import gc

weights_dir = "/share/projects/engine/weights/"

# all weights must be sharded in .bin
MODEL_PATH = os.path.join(weights_dir, "llama-hf/llama-30b")


## for weights > 30b
if (int(re.match(".+(\d+)(b|B)", MODEL_PATH).group(1)) > 30):
    assert(torch.cuda.device_count() > 2), torch.cuda.device_count()
else:
    assert(torch.cuda.is_available())

In [2]:
def check_dev(n):
    t = torch.cuda.get_device_properties(n).total_memory
    r = torch.cuda.memory_reserved(n)
    a = torch.cuda.memory_allocated(n)
    f = r-a  # free
    print(f"{a} / {t} used for dev {n}, reserved {r}")
    
def check_devs():
    for i in range(torch.cuda.device_count()):
        check_dev(i)

check_devs()

0 / 51041271808 used for dev 0, reserved 0
0 / 51041271808 used for dev 1, reserved 0


## Loading sequential

In [3]:
NO_SPLIT_CLASSES = ["LlamaDecoderLayer"]
torch_dtype=torch.float16
torch.cuda.empty_cache()

# empty init from config
config = AutoConfig.from_pretrained(MODEL_PATH)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch_dtype)

In [4]:
model # used to get no split blocks

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 6656, padding_idx=0)
    (layers): ModuleList(
      (0-59): 60 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=6656, out_features=6656, bias=False)
          (k_proj): Linear(in_features=6656, out_features=6656, bias=False)
          (v_proj): Linear(in_features=6656, out_features=6656, bias=False)
          (o_proj): Linear(in_features=6656, out_features=6656, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=6656, out_features=17920, bias=False)
          (down_proj): Linear(in_features=17920, out_features=6656, bias=False)
          (up_proj): Linear(in_features=6656, out_features=17920, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [5]:
# loading, must tie weights first
model.tie_weights()
model = load_checkpoint_and_dispatch(
   model, MODEL_PATH, device_map='auto', no_split_module_classes = NO_SPLIT_CLASSES 
)
model.hf_device_map

{'model.embed_tokens': 0,
 'model.layers.0': 0,
 'model.layers.1': 0,
 'model.layers.2': 0,
 'model.layers.3': 0,
 'model.layers.4': 0,
 'model.layers.5': 0,
 'model.layers.6': 0,
 'model.layers.7': 0,
 'model.layers.8': 0,
 'model.layers.9': 0,
 'model.layers.10': 0,
 'model.layers.11': 0,
 'model.layers.12': 0,
 'model.layers.13': 0,
 'model.layers.14': 0,
 'model.layers.15': 0,
 'model.layers.16': 0,
 'model.layers.17': 0,
 'model.layers.18': 0,
 'model.layers.19': 0,
 'model.layers.20': 0,
 'model.layers.21': 0,
 'model.layers.22': 0,
 'model.layers.23': 0,
 'model.layers.24': 0,
 'model.layers.25': 0,
 'model.layers.26': 0,
 'model.layers.27': 0,
 'model.layers.28': 0,
 'model.layers.29': 0,
 'model.layers.30': 1,
 'model.layers.31': 1,
 'model.layers.32': 1,
 'model.layers.33': 1,
 'model.layers.34': 1,
 'model.layers.35': 1,
 'model.layers.36': 1,
 'model.layers.37': 1,
 'model.layers.38': 1,
 'model.layers.39': 1,
 'model.layers.40': 1,
 'model.layers.41': 1,
 'model.layers.42'

In [6]:
check_devs()

32639052800 / 51041271808 used for dev 0, reserved 32830914560
32639066112 / 51041271808 used for dev 1, reserved 32830914560


In [7]:
# quick inference test
# Note: accelerate won't appreciate model.cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer("The inside of a kiwi is the color", return_tensors="pt")
inputs = inputs.to(0)

time.process_time_ns()
output = model.generate(inputs["input_ids"])
print(f"generation time: {time.process_time_ns()}ns")

tokenizer.decode(output[0].tolist())

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


generation time: 48557750483ns


' The inside of a kiwi is the color of a sunset.\nThe inside of a'

In [8]:
# # this doesn't work, need restart kernel to empty mem
# del model
# gc.collect()
# torch.cuda.empty_cache()
# check_devs()

## ModelLoader Class

In [7]:
class ModelLoader:
    '''
    TODO: need to check for non-accelerate models
    '''
    def __init__(self, 
                 MODEL_NAME,
                 MODEL_PATH,
                 NO_SPLIT_CLASSES,
                 torch_dtype=torch.float16) -> None:
        
        self.MODEL_NAME = MODEL_NAME
        
        # empty init
        config = AutoConfig.from_pretrained(MODEL_PATH)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config, torch_dtype=torch_dtype)
            
        # load weights
        # must tie weights before loading
        model.tie_weights()
        
        self.model = load_checkpoint_and_dispatch(
            model, MODEL_PATH, device_map='auto', 
            no_split_module_classes = NO_SPLIT_CLASSES 
        )
       
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)


        # check load
        print(f"hf_device_map ==> \n{self.model.hf_device_map}")

        nethook.set_requires_grad(False, self.model)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"{self.MODEL_NAME} ==> devices: {set(self.model.hf_device_map.values())}, memory: {self.model.get_memory_footprint()}" )

        self.layer_names = [
            n
            for n, m in self.model.named_modules()
            if (re.match(r"\w+\.(h|layers)\.\d+$", n))
        ]
        self.num_layers = len(self.layer_names)

In [8]:
model = ModelLoader("llama_7b", MODEL_PATH, ["LlamaDecoderLayer"])

hf_device_map ==> 
{'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0, 'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 1, 'model.layers.19': 1, 'model.layers.20': 1, 'model.layers.21': 1, 'model.layers.22': 1, 'model.layers.23': 1, 'model.layers.24': 1, 'model.layers.25': 1, 'model.layers.26': 1, 'model.layers.27': 1, 'model.layers.28': 1, 'model.layers.29': 1, 'model.layers.30': 1, 'model.layers.31': 1, 'model.norm': 1, 'lm_head': 1}
llama_7b ==> devices: {0, 1}, memory: 13543948288


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


## Inference

In [11]:
# quick inference

inputs = model.tokenizer("The inside of a kiwi is the color", return_tensors="pt")
inputs = inputs.to(0)

output = model.model.generate(inputs["input_ids"])

model.tokenizer.decode(output[0].tolist())

' The inside of a kiwi is the color of a pearl.\nThe kiwi'