In [10]:
import argparse
import time
import json

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import AutoTokenizer, AutoModelForCausalLM

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

import numpy as np
from transformers import pipeline, AutoTokenizer
from transformers import GPTNeoForCausalLM
from mamba_ssm import MambaLMHeadModel
from tqdm.notebook import tqdm


def get_model_and_tokenizer(model_name):
    device = "cuda"
    dtype = torch.float16

    print(f"Loading model {model_name}")
    is_mamba = model_name.startswith("state-spaces/mamba-")
    if is_mamba:
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        model = MambaLMHeadModel.from_pretrained(model_name, device=device, dtype=dtype)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map={"": device}, torch_dtype=dtype
        )
    model.eval()
    print(
        f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    )

    return model, tokenizer


def generate_text(
    model,
    tokenizer,
    model_name,
    prompt=None,
    promptlen=100,
    genlen=100,
    temperature=1.0,
    topk=1,
    topp=1.0,
    minp=0.0,
    repetition_penalty=1.0,
    batch=1,
):
    is_mamba = model_name.startswith("state-spaces/mamba-")

    repeats = 3
    device = "cuda"
    dtype = torch.float16

    torch.random.manual_seed(0)
    if prompt is None:
        input_ids = torch.randint(
            1, 1000, (batch, promptlen), dtype=torch.long, device="cuda"
        )
        attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
    else:
        tokens = tokenizer(prompt, return_tensors="pt")
        input_ids = tokens.input_ids.to(device=device)
        attn_mask = tokens.attention_mask.to(device=device)
    max_length = input_ids.shape[1] + genlen

    if is_mamba:
        fn = lambda: model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            temperature=temperature,
            top_k=topk,
            top_p=topp,
            min_p=minp,
            repetition_penalty=repetition_penalty,
        )
    else:
        fn = lambda: model.generate(
            input_ids=input_ids,
            attention_mask=attn_mask,
            max_length=max_length,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=temperature,
            top_k=topk,
            top_p=topp,
            repetition_penalty=repetition_penalty,
        )
    out = fn()
    if prompt is not None:
        res = tokenizer.batch_decode(out.sequences.tolist())

        return res

    # torch.cuda.synchronize()
    # start = time.time()
    # for _ in range(repeats):
    #     fn()
    # torch.cuda.synchronize()
    # print(
    #     f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}"
    # )
    # print(
    #     f"{model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms"
    # )

In [11]:
gptneo = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
gptneo.eval()
mamba = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")
mamba.eval()

tokenizer_gptneo = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer_mamba = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

In [6]:
prompt = "Tri is"

model_name = "state-spaces/mamba-130m"
#  model_name = "EleutherAI/gpt-neo-125m"

model, tokenizer = get_model_and_tokenizer(model_name=model_name)


n = 5

for i in tqdm(range(n)):

    text = generate_text(
        model=model, tokenizer=tokenizer, model_name=model_name, prompt=prompt, batch=1
    )

    print(text)

Loading model state-spaces/mamba-130m
Number of parameters: 129135360


  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [7]:
generator = pipeline("text-generation", model=gptneo, tokenizer=tokenizer_gptneo)

for i in tqdm(range(10_000)):
    generator(prompt, do_sample=True, min_length=20)

  0%|          | 0/10000 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


KeyboardInterrupt: 