In [1]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer
import jax
import jax.numpy as jnp

# Pretrained Mamba checkpoint to load. Supported choices:
#   'state-spaces/mamba-130m'   <- selected below (smallest)
#   'state-spaces/mamba-370m'
#   'state-spaces/mamba-790m'
#   'state-spaces/mamba-1.4b'
#   'state-spaces/mamba-2.8b'
#   'state-spaces/mamba-2.8b-slimpj'
pretrained_model_name = 'state-spaces/mamba-130m'

# The same GPT-NeoX-20B tokenizer is paired with every checkpoint listed above.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='EleutherAI/gpt-neox-20b',
)

# Mamba.from_pretrained yields the module together with its parameters; both are
# later passed to `model.apply(params, input_ids)` during generation.
model, params = Mamba.from_pretrained(pretrained_model_name)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  return self.fget.__get__(instance, owner)()


In [2]:
# Report which backend JAX selected (e.g. 'cpu' / 'gpu' / 'tpu') and the kind of
# every visible device. `jax.default_backend()` is the public replacement for the
# deprecated internal `jax.lib.xla_bridge.get_backend().platform`.
print(jax.default_backend())
print('JAX Devices:', '\n'.join([d.device_kind for d in jax.devices()]))

gpu
JAX Devices: NVIDIA GeForce RTX 4090


In [3]:
def generate(model,
             params,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40,
             rng = jax.random.PRNGKey(177013)):
    """Autoregressively extend ``prompt`` by ``n_tokens_to_gen`` tokens.

    Every step re-runs ``model`` on the whole token sequence so far (no state
    cache), keeps only the last position's logits, optionally restricts them to
    the ``top_k`` most likely tokens, then either samples from the resulting
    distribution or takes the argmax.

    Args:
        model: object exposing ``apply(params, input_ids)``; its output is
            indexed as ``[:, -1]`` -- assumed (batch, seq_len, vocab) logits.
        params: parameters forwarded to ``model.apply``.
        tokenizer: tokenizer callable with ``return_tensors='jax'`` that
            returns an object with ``input_ids`` and provides ``decode``.
        prompt: text to continue.
        n_tokens_to_gen: number of tokens to append.
        sample: if True draw from the (top-k filtered) distribution,
            otherwise pick the most likely token greedily.
        top_k: keep only the k most probable tokens (ties at the k-th value
            are kept). ``None`` or ``<= 0`` disables filtering; values larger
            than the vocabulary are clamped to the vocabulary size.
        rng: JAX PRNG key used when ``sample`` is True. It is split once per
            step so successive draws are independent; the default key is a
            fixed constant, so default sampling is reproducible.

    Returns:
        The decoded text (prompt + generated tokens) of the first sequence.
    """
    # Encode prompt to token ids; indexed below as (batch, seq_len).
    input_ids = tokenizer(prompt, return_tensors='jax').input_ids

    for _ in range(n_tokens_to_gen):
        # Only the final position's logits describe the next token.
        next_token_logits = model.apply(params, input_ids)[:, -1]
        probs = jax.nn.softmax(next_token_logits, axis=-1)

        if top_k is not None and top_k > 0:
            # lax.top_k requires k <= vocab size; clamping means an oversized
            # top_k simply keeps every token.
            k = min(top_k, probs.shape[-1])
            values, _ = jax.lax.top_k(probs, k=k)
            # Zero everything below the k-th largest probability, then
            # renormalise. jnp.where avoids a boolean-mask scatter.
            probs = jnp.where(probs < values[..., -1:], 0.0, probs)
            probs = probs / jnp.sum(probs, axis=-1, keepdims=True)

        if sample:
            # Fresh subkey each step: reusing one key yields correlated draws.
            rng, step_key = jax.random.split(rng)
            # categorical takes log-probabilities (logits), not probabilities;
            # filtered tokens become log(0) = -inf and are never drawn.
            next_indices = jax.random.categorical(step_key, jnp.log(probs), axis=-1)[:, None]
        else:
            # Pick the most likely next token.
            next_indices = jnp.argmax(probs, axis=-1, keepdims=True)

        # Append the new token column; cast so the id dtype stays stable.
        input_ids = jnp.concatenate(
            [input_ids, next_indices.astype(input_ids.dtype)], axis=1)

    # Only the first sequence is returned, so only it is decoded.
    return tokenizer.decode(input_ids[0].tolist())

In [None]:
# Greedy decoding demo: with sample=False the argmax token is taken at every step.
print(
    generate(model, params, tokenizer, prompt='Mamba is the', sample=False)
)