
Llama-inference-jax: Accelerated inference with Llama Models in JAX for high-speed, pure JAX implementation.


erfanzar/Llama-Inference-JAX

Llama-inference-jax

Note

This project supports only Llama models (at least for now) and targets local machines. It is meant as an example of how to implement your own model in pure JAX; if you want to use it for other purposes, I suggest checking out EasyDeL.
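As an illustration of what "implementing your own model in pure JAX" looks like, here is a minimal sketch of RMSNorm, the normalization layer used in Llama blocks. The function name `rms_norm` and the default `eps` are assumptions for illustration, not lijax's API; real Llama checkpoints ship their own `rms_norm_eps` value.

```python
import jax
import jax.numpy as jnp


def rms_norm(x: jax.Array, weight: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Hypothetical helper: root-mean-square normalization over the last axis.

    Divides each hidden vector by its RMS, then applies a learned per-channel scale.
    """
    # Compute in float32 for numerical stability, then cast back to the input dtype.
    x32 = x.astype(jnp.float32)
    variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(variance + eps)
    return (normed * weight).astype(x.dtype)


# Usage: the function is plain JAX, so it can be jit-compiled directly.
hidden = jnp.ones((2, 4, 8), dtype=jnp.float32)  # (batch, seq_len, hidden_size)
scale = jnp.ones((8,), dtype=jnp.float32)
out = jax.jit(rms_norm)(hidden, scale)
print(out.shape)
```

Unlike LayerNorm, RMSNorm does not subtract the mean, which is why only a single reduction is needed.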

Overview

Llama-inference-jax is a library for accelerated inference with Llama models, written entirely in JAX. Llama models are a family of transformer language models released by Meta; implementing them in JAX means the same Python code is compiled by XLA for CPUs, GPUs, and TPUs without backend-specific changes.
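To show what backend-independent deployment means in practice, the toy sketch below jit-compiles a Llama-style SwiGLU feed-forward block and runs it on whatever device JAX selected. It is not code from lijax: `swiglu_mlp` and the tiny dimensions are hypothetical, chosen only so the example runs anywhere.

```python
import jax
import jax.numpy as jnp


@jax.jit
def swiglu_mlp(x, w_gate, w_up, w_down):
    # Llama-style feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x)).
    return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down


# JAX picks the best available backend ("cpu", "gpu" or "tpu") at startup.
print("backend:", jax.default_backend(), "devices:", jax.devices())

hidden, intermediate = 16, 32  # toy sizes, not Llama's real dimensions
k_x, k_g, k_u, k_d = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (2, hidden))
w_gate = jax.random.normal(k_g, (hidden, intermediate)) * 0.1
w_up = jax.random.normal(k_u, (hidden, intermediate)) * 0.1
w_down = jax.random.normal(k_d, (intermediate, hidden)) * 0.1

# The first call traces and compiles for the current backend; later calls with the
# same shapes and dtypes reuse the compiled executable.
out = swiglu_mlp(x, w_gate, w_up, w_down)
print(out.shape)
```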

Features

  • Accelerated inference with Llama models.
  • Pure JAX implementation for high-speed execution.
  • Deployment on GPUs and TPUs.
  • Custom Pallas kernels.
  • Parameter quantization.
  • Standalone weights.
  • Flash Attention support on CPU/GPU/TPU.
  • Model blocks that are PyTrees and compatible with JAX transformations.
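Two of these features, parameter quantization and PyTree-compatible blocks, fit together naturally. The sketch below is a hypothetical `QuantizedLinear` layer using symmetric per-output-channel int8 quantization; the README does not describe lijax's actual quantization scheme, so treat the class, its methods and the int8 choice as assumptions. Registering it as a PyTree is what lets it be passed straight into `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class QuantizedLinear:
    """Hypothetical int8 linear layer stored as a PyTree (not lijax's actual class)."""

    def __init__(self, q_weight, scale):
        self.q_weight = q_weight  # int8, shape (in_features, out_features)
        self.scale = scale        # float32, shape (out_features,)

    @classmethod
    def from_float(cls, weight):
        # Symmetric per-output-channel quantization: map max |w| of each column to 127.
        max_abs = jnp.max(jnp.abs(weight), axis=0)
        scale = jnp.where(max_abs == 0, 1.0, max_abs / 127.0)
        q_weight = jnp.clip(jnp.round(weight / scale), -127, 127).astype(jnp.int8)
        return cls(q_weight, scale.astype(jnp.float32))

    def __call__(self, x):
        # Dequantize on the fly, then run an ordinary matmul.
        w = self.q_weight.astype(x.dtype) * self.scale.astype(x.dtype)
        return x @ w

    # PyTree protocol: both arrays are leaves; there is no static auxiliary data.
    def tree_flatten(self):
        return (self.q_weight, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def apply(layer, x):
    # Because QuantizedLinear is a PyTree, jit sees its arrays as ordinary inputs.
    return layer(x)


weight = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
layer = QuantizedLinear.from_float(weight)
print(apply(layer, jnp.ones((3, 8))).shape)
```

Storing int8 weights cuts their memory to a quarter of float32; the trade-off is a per-weight rounding error of at most half the column's scale.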

Usage

Converting your own Llama model to LiJAX is as easy as possible:
from lijax.covertors import convert_llama_model
import pickle as pkl

lijax_model = convert_llama_model(
    pre_trained_model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
    extra_loading_options_for_model=dict(),  # extra kwargs forwarded to the Hugging Face model loader
    quantize_mlp=True,
    quantize_embed=True,
    quantize_lm_head=True,
    quantize_self_attn=True,
)
lijax_model.shard()  # optional: shard the model across available GPUs/TPUs

print(lijax_model)

# Save the converted model (the context manager closes the file handle)
with open("lijax_llama_3_8b", "wb") as f:
    pkl.dump(lijax_model, f)

# Load the saved model
with open("lijax_llama_3_8b", "rb") as f:
    new_lijax_model = pkl.load(f)
new_lijax_model.shard()  # optional: shard the model across available GPUs/TPUs
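The README does not say how `shard()` splits parameters across devices. For readers who want to do something similar in their own model, here is a generic sketch built on `jax.sharding`; it is not lijax's implementation, and the helpers `shard_columns` and `replicate` and the single "model" mesh axis are assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D device mesh over every visible device (a single device on a CPU-only install).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))
n_devices = devices.size


def shard_columns(weight):
    """Hypothetical helper: split a 2-D weight along its output axis across the mesh."""
    return jax.device_put(weight, NamedSharding(mesh, P(None, "model")))


def replicate(x):
    """Hypothetical helper: place a full copy of x on every device in the mesh."""
    return jax.device_put(x, NamedSharding(mesh, P()))


# The sharded axis must be divisible by the number of devices.
w = jnp.arange(16 * 8 * n_devices, dtype=jnp.float32).reshape(16, 8 * n_devices) / 100.0
x = jnp.ones((4, 16), dtype=jnp.float32)

w_sharded = shard_columns(w)
x_repl = replicate(x)

# jit partitions the matmul according to the input shardings, so each device
# computes its own slice of the output columns.
y = jax.jit(lambda a, b: a @ b)(x_repl, w_sharded)
print(y.shape, y.sharding)
```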

Generation Process

import jax.numpy
from transformers import AutoTokenizer
from lijax.model import llama_generate
from lijax.covertors import convert_llama_model

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
lijax_model = convert_llama_model("meta-llama/Meta-Llama-3-8B-Instruct")
lijax_model.shard()
generated_ids = None
printed_length = 0
for token in llama_generate(
        block=lijax_model,
        input_ids=tokenizer.apply_chat_template(
            [
                {"role": "user", "content": "hi"}
            ],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="np"
        ),
        use_flash_attention=False,
        # runtime_kernel="pallas",
        runtime_kernel="normal",
        max_length=2048,
        max_new_tokens=32,
        eos_token_id=tokenizer.eos_token_id,
        temperature=1.6,
        # do_sample=True,
        top_k=20,
        top_p=0.95,
):
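    # Each step yields the newly generated token ids; append them to the running sequence.
    generated_ids = jax.numpy.concatenate([generated_ids, token], -1) if generated_ids is not None else token
    # Re-decode the whole sequence and print only the text that has not been printed yet.
    stream = tokenizer.decode(generated_ids[0].tolist(), skip_special_tokens=False)
    print(stream[printed_length:], end="", flush=True)
    printed_length = len(stream)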
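The README does not document how `llama_generate` combines `temperature`, `top_k` and `top_p`. A common formulation is sketched below: scale logits by temperature, keep the top-k candidates, then keep the smallest prefix of those whose probability mass reaches `top_p`. The function `sample_next_token` is hypothetical and is not lijax's sampler.

```python
import jax
import jax.numpy as jnp


def sample_next_token(logits, key, temperature=1.0, top_k=20, top_p=0.95):
    """Hypothetical sampler: one token id per row from logits of shape (batch, vocab)."""
    logits = logits / temperature
    # Top-k: keep only the k largest logits (jax.lax.top_k returns them sorted descending).
    top_logits, top_indices = jax.lax.top_k(logits, top_k)
    # Top-p: keep a candidate while the mass of the candidates before it is below top_p,
    # so the most likely token is always kept.
    probs = jax.nn.softmax(top_logits, axis=-1)
    cumulative = jnp.cumsum(probs, axis=-1)
    keep = (cumulative - probs) < top_p
    masked = jnp.where(keep, top_logits, -jnp.inf)
    # Sample an index into the top-k slice, then map it back to a vocabulary id.
    choice = jax.random.categorical(key, masked, axis=-1)
    return jnp.take_along_axis(top_indices, choice[:, None], axis=-1)  # (batch, 1)


# top_k sets an array shape, so it must be static under jit.
sample = jax.jit(sample_next_token, static_argnames=("top_k",))
logits = jax.random.normal(jax.random.PRNGKey(0), (1, 32))
next_token = sample(logits, jax.random.PRNGKey(1), temperature=1.6, top_k=20, top_p=0.95)
print(next_token.shape)
```

Returning shape `(batch, 1)` matches how the loop above concatenates each new token onto `generated_ids` along the last axis.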

License

This project is licensed under the MIT License - see the LICENSE file for details.
