In [1]:
import time

import torch
from transformers import AutoTokenizer

# from mamba2 import Mamba2LMHeadModel

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.models.simple_mamba2 import Mamba2LMHeadModel

In [3]:
# The 130M checkpoint is loaded here; "state-spaces/mamba2-1.3b" was the
# alternative checkpoint previously used with the same call.
model = Mamba2LMHeadModel.from_pretrained(
    "state-spaces/mamba2-130m",
    device=device,
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Padding reuses the end-of-sequence token id.
tokenizer.pad_token_id = tokenizer.eos_token_id



In [5]:
def generate(prompt: str, seed: int = 0, show_perf: bool = True, config: "dict | None" = None):
    """Stream a completion of `prompt` to stdout, optionally followed by timing stats.

    Args:
        prompt: Text to complete. It is echoed first, then each generated token is
            decoded and printed as soon as `model.generate` yields it.
        seed: Passed to `torch.manual_seed` before generation for reproducibility.
        show_perf: If True, print prompt-evaluation and generation throughput.
        config: Keyword arguments forwarded to `model.generate`. When None, the
            notebook-level `generation_config` dict is used (looked up at call time).

    Timing is wall-clock (`time.perf_counter`). The "prompt eval" phase lasts until
    the first token is yielded; the "generation" phase covers the tokens after it.
    """
    if config is None:
        config = generation_config
    torch.manual_seed(seed)

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)[0]
    n_prompt = input_ids.shape[0]
    print(prompt, end="")

    # perf_counter is wall-clock time; process_time only counts this process's CPU
    # time, which misreports elapsed time (and tok/s) for accelerator-bound work.
    start = time.perf_counter()
    first_token_at = None
    n_generated = 0
    for token_id, _hidden_state in model.generate(input_ids, **config):
        if first_token_at is None:
            first_token_at = time.perf_counter()
        else:
            n_generated += 1
        print(tokenizer.decode([token_id]), end="", flush=True)
    end = time.perf_counter()

    if show_perf:
        # If nothing was yielded, attribute all elapsed time to prompt evaluation
        # rather than failing on an unset timestamp.
        if first_token_at is None:
            first_token_at = end
        prompt_eval_elapsed = first_token_at - start
        elapsed = end - first_token_at
        print('\n\n---')
        print(f'Prompt eval | tokens: {n_prompt} | elapsed: {prompt_eval_elapsed:.2f}s | tok/s: {_format_rate(n_prompt, prompt_eval_elapsed)}')
        print(f'Generation | tokens: {n_generated} | elapsed: {elapsed:.2f}s | tok/s: {_format_rate(n_generated, elapsed)}')


def _format_rate(n_tokens: int, seconds: float) -> str:
    """Format tokens/second with two decimals, or 'n/a' when no time elapsed."""
    return f'{n_tokens / seconds:.2f}' if seconds > 0 else 'n/a'

In [15]:
# Keyword arguments that generate() forwards to model.generate.
# An earlier run used temperature=1.0; 0.1 is the current setting.
generation_config = {
    "max_new_length": 10,
    "temperature": 0.1,
    "top_k": 30,
    "top_p": 1.0,
}

In [19]:
generate(prompt="What is the twin city of Lyon? It is ")  # factual-recall prompt


What is the twin city of Lyon? It is 
the city of Lyon, France. It is

---
Prompt eval | tokens: 11 | elapsed: 1.14s | tok/s: 9.67
Generation | tokens: 9 | elapsed: 0.51s | tok/s: 17.74


In [21]:
generate(prompt="The meaning of life is")  # open-ended prompt


The meaning of life is a complex and multifaceted concept. It is

---
Prompt eval | tokens: 5 | elapsed: 3.85s | tok/s: 1.30
Generation | tokens: 9 | elapsed: 1.09s | tok/s: 8.23


In [17]:
generate(prompt="CUDA is Nvidia's biggest most")  # mid-sentence continuation prompt

CUDA is Nvidia's biggest most popular GPU, and it's a big reason why

---
Prompt eval | tokens: 8 | elapsed: 0.94s | tok/s: 8.51
Generation | tokens: 9 | elapsed: 0.43s | tok/s: 21.00


In [16]:
generate(prompt="1 2 3 4 ")  # pattern-continuation prompt

1 2 3 4 
1 2 3 4 
1 2 3

---
Prompt eval | tokens: 5 | elapsed: 0.78s | tok/s: 6.37
Generation | tokens: 9 | elapsed: 0.71s | tok/s: 12.61
