In [2]:
import os


# Device string handed to the model, speculator, KV cache and prompt tensors
# in the cells below.
device = "cuda"

# Expose only GPU index 1 to this process. This cell runs before the cell that
# imports torch -- presumably so the restriction is in place before CUDA is
# initialised (verify if the cell order ever changes).
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [3]:
from fms.models import get_model
import torch

import axolotl.models.fms_extras.models.paged_llama


# Make float16 the default floating-point dtype before the model is built.
# The KV-cache manager created in a later cell reads this same default via
# torch.get_default_dtype().
torch.set_default_dtype(torch.half)

# Load the base model from the local Hugging Face-format checkpoint onto
# `device`, without sharding or any distributed strategy.
# NOTE(review): the "paged_llama" architecture is presumably registered by the
# `axolotl.models.fms_extras.models.paged_llama` import above, which is
# otherwise unused -- confirm before removing that import.
model = get_model(
    "paged_llama",  # plain literal: the original f-string had no placeholders
    model_path='/models/Meta-Llama-3-8B-Instruct',
    checkpoint_sharding=None,
    device_type=device,
    source='hf',
    distributed_strategy=None,
    group=None,
)

# Last expression of the cell: display the module tree.
model

  from .autonotebook import tqdm as notebook_tqdm


PagedLLaMA(
  (headless_model): PagedLLaMAHeadless(
    (shared): WordEmbedding(
      (emb): Embedding(128256, 4096)
      (head): Linear(in_features=4096, out_features=128256, bias=False)
    )
    (layers): ModuleList(
      (0-31): 32 x PagedLLaMABlock(
        (ln): LayerNormParameterized()
        (ff_ln): LayerNormParameterized()
        (attn): PagedMultiHeadAttention(
          (dense): Linear(in_features=4096, out_features=4096, bias=False)
          (qkv_fused): Linear(in_features=4096, out_features=6144, bias=False)
        )
        (ff_sub_layer): GatedLinearUnit(
          (wg1_fused): Linear(in_features=4096, out_features=28672, bias=False)
          (a): SiLU()
          (w2): Linear(in_features=14336, out_features=4096, bias=False)
        )
      )
    )
    (dec_norm): LayerNormParameterized()
  )
  (head): WordEmbedding(
    (emb): Embedding(128256, 4096)
    (head): Linear(in_features=4096, out_features=128256, bias=False)
  )
)

In [4]:
import torch
from fms.utils import generation, tokenizers


# Inference-only setup: switch the model to eval mode and load the tokenizer
# from the same checkpoint directory the model was loaded from.
model.eval()
tokenizer = tokenizers.get_tokenizer("/models/Meta-Llama-3-8B-Instruct")

# Turn autograd off globally for the rest of the notebook. As the cell's last
# expression, the object returned by this call is what the cell displays.
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7283b9512440>

In [6]:
from axolotl.models.fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel


# Load the pretrained speculator checkpoint 'ibm-fms/llama3-8b-accelerator'
# onto `device` and keep only the wrapper's `.speculator` attribute; that is
# the object later passed to speculative_generate.
speculator = MLPSpeculatorPreTrainedModel.from_pretrained(
    'ibm-fms/llama3-8b-accelerator', device_map=device
).speculator
# Last expression of the cell: display the speculator module tree.
speculator

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.95s/it]


MLPSpeculator(
  (emb): ModuleList(
    (0-3): 4 x Embedding(128256, 3072)
  )
  (proj): ModuleList(
    (0): Linear(in_features=4096, out_features=3072, bias=False)
    (1-3): 3 x Linear(in_features=3072, out_features=3072, bias=False)
  )
  (head): ModuleList(
    (0-3): 4 x Linear(in_features=3072, out_features=128256, bias=False)
  )
  (ln): ModuleList(
    (0-3): 4 x LayerNormParameterized()
  )
  (activation): GELU(approximate='none')
)

In [7]:
from axolotl.models.fms_extras.utils.cache.paged import PagedKVCacheManager


# Build the paged KV-cache manager from the base model's config values
# (nlayers, nheads, emb_dim, kvheads). The same manager instance is passed to
# both the paged_generate and speculative_generate calls below.
kv_cache_manager = PagedKVCacheManager(
    model.config.nlayers,
    model.config.nheads,
    model.config.emb_dim,
    kv_heads=model.config.kvheads,
    tensor_parallel_size=1,
    # Follows whatever default dtype was set with torch.set_default_dtype.
    dtype=torch.get_default_dtype(),
    device=device,
)
# Last expression of the cell: display the manager's repr.
kv_cache_manager

<axolotl.models.fms_extras.utils.cache.paged.PagedKVCacheManager at 0x7283a7aebb20>

In [8]:
# Instruction/response prompt template; `{}` is replaced by the instruction.
# Adjacent literals are concatenated by the parser into one string.
template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{}\n\n"
    "### Response:"
)

# Fill the template with the single instruction used throughout this notebook.
prompt1 = template.format("Provide a list of instructions for preparing chicken soup.")
prompt1

'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nProvide a list of instructions for preparing chicken soup.\n\n### Response:'

In [9]:
def ids_for_prompt(prompt):
    """Turn a prompt string into a 1-D tensor of token ids.

    The prompt is tokenized with the notebook-level `tokenizer`, the tokens are
    mapped to ids, and the tokenizer's BOS id is placed in front.

    Args:
        prompt: Text to encode.

    Returns:
        A `torch.long` tensor on the notebook-level `device`, holding the BOS
        id followed by the prompt's token ids.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt))
    return torch.tensor(
        [tokenizer.bos_token_id] + token_ids,
        dtype=torch.long,
        device=device,
    )

# NOTE(review): this rebinds `prompt1` from the prompt string to its tensor of
# token ids, so the cell is not idempotent -- running it a second time passes
# the tensor (not the text) to ids_for_prompt. Re-run the cell that builds the
# prompt string first.
prompt1 = ids_for_prompt(prompt1)
prompt1

tensor([128000,  39314,    374,    459,   7754,    430,  16964,    264,   3465,
            13,   9842,    264,   2077,    430,  36001,  45695,    279,   1715,
           382,  14711,  30151,    512,  61524,    264,   1160,    315,  11470,
           369,  20646,  16553,  19724,    382,  14711,   6075,     25],
       device='cuda:0')

In [10]:
# One-element list holding the encoded prompt; this list is what both
# generate calls below receive as their second positional argument.
ids = [prompt1]

In [11]:
# Flag forwarded as `cudagraphs=` to the paged_generate call below.
cudagraphs = True

# Use the config's `max_expected_seq_len` when that attribute exists, and fall
# back to `max_pos` otherwise. `max_pos` is only read on the fallback path.
if hasattr(model.config, "max_expected_seq_len"):
    max_seq_len = model.config.max_expected_seq_len
else:
    max_seq_len = model.config.max_pos
max_seq_len


8192

In [13]:
from axolotl.models.fms_extras.utils.generation import paged_generate, speculative_generate


In [14]:
# Baseline run: generate with the base model and the paged KV cache only
# (no speculator is passed). Asks for up to 100 new tokens.
result, n_steps, ttft, generated_token_time_out = paged_generate(
    model,
    ids,
    kv_cache_manager,
    max_new_tokens=100,
    max_seq_len=max_seq_len,
    do_sample=False,  # presumably greedy decoding -- confirm against paged_generate
    decode_model=None,
    cudagraphs=cudagraphs,
)

# Last expression of the cell: display everything the call returned.
result, n_steps, ttft, generated_token_time_out

(tensor([[128000,  39314,    374,    459,   7754,    430,  16964,    264,   3465,
              13,   9842,    264,   2077,    430,  36001,  45695,    279,   1715,
             382,  14711,  30151,    512,  61524,    264,   1160,    315,  11470,
             369,  20646,  16553,  19724,    382,  14711,   6075,     25,   4815,
            8586,    374,    264,   1160,    315,  11470,    369,  20646,  16553,
           19724,   1473,    334,   8468,    220,     16,     25,  50095,  52275,
           57277,      9,    220,     16,  31123,  17685,   1752,     11,   6930,
            1752,  16553,  17659,    477,  60611,    198,      9,    220,     19,
           26446,  16553,  45993,    198,      9,    220,     16,   3544,  38427,
              11,  38525,    198,      9,    220,     18,  85388,  31735,     11,
           94927,    198,      9,    220,     17,  62517,     11,  83612,    323,
           38525,    198,      9,    220,     17,  70121,  55972,     82,     11,
           38525

In [15]:
# Speculative-decoding run with the same model, prompt batch and KV-cache
# manager as the paged_generate cell, plus the `speculator`.
# NOTE: this rebinds result / n_steps / ttft / generated_token_time_out from
# that earlier cell, and additionally returns `n_accepts`.
# NOTE: the token budget keyword here is `new_tokens`, whereas paged_generate
# is called with `max_new_tokens`.
result, n_steps, ttft, generated_token_time_out, n_accepts = speculative_generate(
    model,
    ids,
    speculator,
    kv_cache_manager,
    new_tokens=100,
    max_seq_len=max_seq_len,
    decode_model=None,
    # TODO: reduce-overhead mode currently only works when the batch size is 1.
    # NOTE(review): unclear whether this refers to `flattening` or `cudagraphs`;
    # cudagraphs is hard-coded to False here rather than using the notebook's
    # `cudagraphs` variable -- confirm that is intentional.
    flattening=True,
    cudagraphs=False,
    # presumably one threshold per speculator head -- TODO confirm
    threshes=[4,3,2,2],
)

# Last expression of the cell: display everything the call returned.
result, n_steps, ttft, generated_token_time_out, n_accepts

([tensor([128000,  39314,    374,    459,   7754,    430,  16964,    264,   3465,
              13,   9842,    264,   2077,    430,  36001,  45695,    279,   1715,
             382,  14711,  30151,    512,  61524,    264,   1160,    315,  11470,
             369,  20646,  16553,  19724,    382,  14711,   6075,     25,   4815,
            8586,    374,    264,   1160,    315,  11470,    369,  20646,  16553,
           19724,   1473,    334,   8468,    220,     16,     25,  50095,  52275,
           57277,      9,    220,     16,  31123,  17685,   1752,     11,   6930,
            1752,  16553,  17659,    477,  60611,    198,      9,    220,     19,
           26446,  16553,  45993,    198,      9,    220,     16,   3544,  38427,
              11,  38525,    198,      9,    220,     18,  85388,  31735,     11,
           94927,    198,      9,    220,     17,  62517,     11,  83612,    323,
           38525,    198,      9,    220,     17,  70121,  55972,     82,     11,
           38525