In [1]:
import sys

# Make the directory two levels up importable; presumably this is where the
# `transformerx` package imported by the next cell lives (TODO confirm).
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate entry to sys.path.
if '../../' not in sys.path:
    sys.path.append('../../')

In [2]:
import jax
import jax.numpy as jnp
from einshard import einshard
from transformers import AutoTokenizer
from transformerx.models.llama.default import \
    get_tokenize_fn, load_jx_config, load_jx_params
from transformerx.models.llama.modeling import forward_fn, LlamaInputs

In [3]:
# Checkpoint identifier; the same string is handed to the config/param
# loaders here and to the tokenizer factories in a later cell.
model = 'microsoft/Phi-3-mini-4k-instruct'

# Load the configuration first, then the weights (same order as before).
config, params = load_jx_config(model), load_jx_params(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
def _shard_leaf(leaf):
    """Apply the einshard expression '... O -> ... O*' to one parameter leaf.

    NOTE(review): the trailing `*` presumably marks the last axis as the one
    to shard across available devices -- confirm against the einshard docs.
    """
    return einshard(leaf, '... O -> ... O*')


# Re-bind `params` to a pytree of the same structure whose leaves have each
# been passed through einshard.
params = jax.tree_util.tree_map(_shard_leaf, params)

In [5]:
# Text -> model inputs. Configured with a 4096-token maximum length and
# left-side padding.
tokenize_fn = get_tokenize_fn(
    model,
    max_length=4096,
    padding_side='left',
)

# Token ids -> text, using the Hugging Face tokenizer for the same checkpoint.
hf_tokenizer = AutoTokenizer.from_pretrained(model)
detokenize_fn = hf_tokenizer.batch_decode

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Greedy decoding demo: repeatedly run the model, take the arg-max token at
# the last position, append it to the inputs, and print the decoded text
# until the '<|end|>' token is produced.

# Upper bound on generated tokens so the cell always terminates, even if the
# model never emits '<|end|>'.
MAX_NEW_TOKENS = 512

prompt = '<|user|>\n' + 'Hello, who are you?' + '<|end|>\n<|assistant|>'
inputs = LlamaInputs(**tokenize_fn(prompt))

# Id of the '<|end|>' token: the last id of its tokenization. Computed once
# here because it is loop-invariant; looking it up inside the loop would
# re-tokenize the string on every generated token.
end_token_id = int(tokenize_fn('<|end|>')['input_ids'][0, -1])

for _ in range(MAX_NEW_TOKENS):
    # Logits at the final sequence position only.
    logits = forward_fn(params, inputs, config).logits[:, -1]
    # keepdims=True keeps a trailing axis so the token can be hstack-ed below.
    new_token = jnp.argmax(logits, axis=-1, keepdims=True)
    new_shift = new_token.shape[1]
    # Slide the window: drop the leftmost `new_shift` columns (the tokenizer
    # pads on the left) and append the new token on the right, so the array
    # widths stay constant across iterations.
    inputs = LlamaInputs(
        input_ids=jnp.hstack((
            inputs.input_ids[:, new_shift:],
            new_token)),
        attention_mask=jnp.hstack((
            inputs.attention_mask[:, new_shift:],
            jnp.ones_like(new_token))),
        # The new position id is the previous last position plus the shift.
        position_ids=jnp.hstack((
            inputs.position_ids[:, new_shift:],
            inputs.position_ids[:, -new_shift:] + new_shift)))
    print(detokenize_fn(inputs.input_ids, skip_special_tokens=True)[0])
    # Only the first (and, for this single prompt, only) batch row is checked.
    if int(new_token[0, 0]) == end_token_id:
        break

Hello, who are you? I
Hello, who are you? I am
Hello, who are you? I am Ph
Hello, who are you? I am Phi
Hello, who are you? I am Phi,
Hello, who are you? I am Phi, an
Hello, who are you? I am Phi, an A
Hello, who are you? I am Phi, an AI
Hello, who are you? I am Phi, an AI developed
Hello, who are you? I am Phi, an AI developed to
Hello, who are you? I am Phi, an AI developed to provide
Hello, who are you? I am Phi, an AI developed to provide information
Hello, who are you? I am Phi, an AI developed to provide information,
Hello, who are you? I am Phi, an AI developed to provide information, answer
Hello, who are you? I am Phi, an AI developed to provide information, answer questions
Hello, who are you? I am Phi, an AI developed to provide information, answer questions,
Hello, who are you? I am Phi, an AI developed to provide information, answer questions, and
Hello, who are you? I am Phi, an AI developed to provide information, answer questions, and assist
Hello, who are you? I am Phi