In [1]:
import os
import time

import torch
from tensorizer import TensorDeserializer, TensorSerializer, stream_io
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.llama import LlamaConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mistral import MistralForCausalLM
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM

from vllm.model_executor.parallel_utils.parallel_state import \
    initialize_model_parallel


  from .autonotebook import tqdm as notebook_tqdm
2024-01-10 20:30:49,498	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Create a mapping between a model reference to the corresponding `vllm` module. This is *not* an exhaustive mapping, and just a demonstration of tensorizing a `vllm` model.

In [2]:
mistral = ["mistralai/Mistral-7B-v0.1", 
           "mistralai/Mistral-7B-Instruct-v0.1",
           MistralForCausalLM]

llama = ["meta-llama/Llama-2-13b-hf", 
"meta-llama/Llama-2-70b-hf", 
"openlm-research/open_llama_13b", 
"lmsys/vicuna-13b-v1.3", 
"young-geng/koala",
LlamaForCausalLM]

gptj = ["EleutherAI/gpt-j-6b", "nomic-ai/gpt4all-j", GPTJForCausalLM]

gptneox = ["EleutherAI/gpt-neox-20b", 
"EleutherAI/pythia-12b", 
"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", 
"databricks/dolly-v2-12b", 
"stabilityai/stablelm-tuned-alpha-7b", 
GPTNeoXForCausalLM]

modelref_to_module = {}
for lists in [mistral, llama, gptj, gptneox]:
    module = lists[-1]
    for ref in lists[0:-1]:
        modelref_to_module.update({ref:module})

In [3]:
!pip install accelerate --quiet

MODEL_REF = "EleutherAI/gpt-j-6b"
BUCKET = "tensorized-ssteel"
dtype = None

if MODEL_REF not in modelref_to_module.keys():
    raise KeyError(f"{MODEL_REF} not in supported model list given.")

MODEL_NAME = MODEL_REF.split("/")[1]
S3_URI = f"s3://{BUCKET}/{MODEL_NAME}-vllm.tensors"

MODEL_PATH = f"/tmp/{MODEL_NAME}"

if dtype:
    torch.set_default_dtype(dtype)

def make_model_contiguous(model):
    # Ensure tensors are saved in memory contiguously
    for param in model.parameters():
        param.data = param.data.contiguous()

def serialize():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REF, device_map="auto", torch_dtype="auto"
    )
    
    make_model_contiguous(model)
    model.save_pretrained(MODEL_PATH)
    config = AutoConfig.from_pretrained(MODEL_REF)
    model = modelref_to_module[MODEL_REF](config)
    model.load_weights(MODEL_PATH)

    stream = stream_io.open_stream(S3_URI, "wb")
    serializer = TensorSerializer(stream)
    print(f"Writing serialized tensors for model {MODEL_REF} using module {model} to {S3_URI}.") 
    serializer.write_module(model)
    serializer.close()
    print("Serialization complete. It is recommended you restart the kernel if deserializing.")


def deserialize():
    config = AutoConfig.from_pretrained(MODEL_REF)

    with no_init_or_tensor():
        model = modelref_to_module[MODEL_REF](config)
        if dtype: 
            model.to(dtype)

    before_mem = get_mem_usage()
    # Lazy load the tensors from S3 into the model.
    start = time.time()
    stream = stream_io.open_stream(S3_URI, "rb")
    deserializer = TensorDeserializer(stream, plaid_mode=True)
    deserializer.load_into_module(model)
    end = time.time()

    # Brag about how fast we are.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    deserializer.close()
    print(f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s")
    print(f"Memory usage before: {before_mem}")
    print(f"Memory usage after: {after_mem}")

    return model


[0m

In [4]:
print(f"{MODEL_NAME}\n{MODEL_PATH}\n{S3_URI}")

gpt-j-6b
/tmp/gpt-j-6b
s3://tensorized-ssteel/gpt-j-6b-vllm.tensors


In [5]:
os.environ["MASTER_ADDR"] = "0.0.0.0"
os.environ["MASTER_PORT"] = "8080"

torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()


In [6]:
print("Serializing...")
serialize()

Serializing...
Writing serialized tensors for model EleutherAI/gpt-j-6b using module GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): VocabParallelEmbedding()
    (h): ModuleList(
      (0-27): 28 x GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (qkv_proj): QKVParallelLinear()
          (out_proj): RowParallelLinear()
          (rotary_emb): RotaryEmbedding()
          (attn): PagedAttention()
        )
        (mlp): GPTJMLP(
          (fc_in): ColumnParallelLinear()
          (fc_out): RowParallelLinear()
          (act): NewGELU()
        )
      )
    )
    (ln_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): ParallelLMHead()
  (sampler): Sampler()
) to s3://tensorized-ssteel/gpt-j-6b-vllm.tensors.
Serialization complete. It is recommended you restart the kernel if deserializing.


In [6]:
print("Deserializing...")
model = deserialize()

Deserializing...
Deserialized 24.2 GB in 292.81s, 82.7 MB/s
Memory usage before: CPU: (maxrss: 4,081MiB F: 95,826MiB) GPU: (U: 867MiB F: 47,809MiB T: 48,676MiB) TORCH: (R: 2MiB/2MiB, A: 0MiB/1MiB)
Memory usage after: CPU: (maxrss: 6,114MiB F: 95,813MiB) GPU: (U: 23,971MiB F: 24,705MiB T: 48,676MiB) TORCH: (R: 23,086MiB/23,086MiB, A: 23,083MiB/23,084MiB)


Tensorizing a model in bfloat16 forces you to specify `dtype="bfloat16"` in the api_server.py CLI, as it gets defaults from Llama 2's HF config, which goes for float16