22 changes: 11 additions & 11 deletions keras-llama-32.md
@@ -1,6 +1,6 @@
---
title: "Llama 3.2 in Keras"
-thumbnail: /blog/keras_llama_32/thumbnail.gif
+thumbnail: /blog/assets/keras_llama_32/thumbnail.jpg
authors:
- user: martin-gorner
---
@@ -16,7 +16,7 @@ This is going to be the shortest blog post ever.

Yes, Keras Llama3 can be loaded from any standard (i.e. safetensors) Hugging Face checkpoint, including the 3.2 checkpoints. If a conversion is required, it happens on the fly. Try this:

-```Python
+```py
!pip install keras_hub

from keras_hub.models import Llama3CausalLM
@@ -34,7 +34,7 @@ OK, OK, I'm being told that if I want to publish a blog post, I have to fill the

Keras is the time-tested modeling library for JAX, PyTorch and TensorFlow. You might have noticed this line in the [demo Colab](https://colab.research.google.com/drive/1cnAUQbDfM8lErQ8MD2x9Mo5sfKIqIxEh):

-```Python
+```py
import os
os.environ["KERAS_BACKEND"] = "jax" # or "torch", or "tensorflow"
```
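
One ordering detail that trips people up: the `KERAS_BACKEND` variable has to be set before Keras is imported. A minimal sketch of the sequence (using `keras.config.backend()` as one way to confirm which backend was picked up):

```py
import os

# Choose the backend before importing Keras; changing the variable afterwards has no effect.
os.environ["KERAS_BACKEND"] = "jax"  # or "torch", or "tensorflow"

import keras

print(keras.config.backend())  # confirms the active backend, e.g. "jax"
```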
@@ -48,14 +48,14 @@ Keras is a modeling library and [keras-hub](https://keras.io/keras_hub/) is its
## LLMs in Keras come "batteries included"

I mean, "tokenizer included". `model.generate()` just works on strings:
-```
+```py
model.generate("Hi there!")
> "Hi there! I'm looking for information on how to ...
```
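
Generation also accepts a batch of prompts. A small sketch, assuming the `model` loaded above and that your keras-hub version supports the `max_length` argument:

```py
# Hypothetical prompts, just for illustration.
prompts = [
    "Hi there!",
    "Write a one-line haiku about llamas.",
]

# max_length caps prompt + generated tokens; outputs come back as a list of strings.
outputs = model.generate(prompts, max_length=64)
for text in outputs:
    print(text)
```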

Same thing for training. You can train on a set of strings directly:

-```
+```py
model.fit(strings) # list or dataset of input strings
```
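
To make that concrete, here is a minimal fine-tuning sketch on a made-up list of strings. The task comes pre-compiled with a sensible default loss and optimizer, so `fit` can be called directly; batch size and epoch count are illustrative only:

```py
# Toy training data (made up for illustration).
strings = [
    "Arrr, greetings matey! How be ye on this fine day?",
    "Ahoy! The treasure be buried where X marks the spot.",
]

# The attached preprocessor tokenizes, pads and builds the shifted targets.
model.fit(strings, batch_size=2, epochs=1)
```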

@@ -81,7 +81,7 @@ For convenience, the [demo Colab](https://colab.research.google.com/drive/1cnAUQ

If you don't like "batteries included" and want to get to the underlying tokenizer and model, they are easily accessible:

-```
+```py
# tokenizer
model.preprocessor.tokenizer

@@ -97,7 +97,7 @@ tokenizer = keras_hub.models.Llama3Tokenizer.from_preset("hf://meta-llama/Llama-

The Tokenizer just transforms text into integer vectors. Here "Hello" translates into a single token:

-```
+```py
tokenizer("Hello")
> Array([9906], dtype=int32)
```
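
The mapping goes both ways. A quick round-trip sketch, assuming the `tokenizer` loaded above and that it exposes the usual `detokenize` method:

```py
ids = tokenizer("Hello world")    # text -> token ids
text = tokenizer.detokenize(ids)  # token ids -> text, should give back "Hello world"
print(ids, text)
```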
@@ -107,7 +107,7 @@ The Preprocessor is a catch-all concept for doing all the data transformations a
* padding the token sequences and generating a mask
* generating "expected outputs" for training and fine-tuning. For CausalLM tasks, this is the input sequence shifted by one token.

-```
+```py
tokens = model.preprocessor("Hello")

tokens[0] # 128000 and 128009 are the start and end text tokens
@@ -138,7 +138,7 @@ For a complete example, see the [demo Colab](https://colab.research.google.com/d

Once you are happy with your fine-tuned model, upload it directly from Keras using:

-```
+```py
model.save_to_preset("./pirate-llama")
# Use your own repo here
keras_hub.upload_preset(
@@ -156,7 +156,7 @@ Some of you are wondering, why use Keras when one can already work with LLMs on

Let's pick an 8B-parameter model to demonstrate: meta-llama/Llama-3.1-8B-Instruct ([demo Colab here](https://colab.research.google.com/drive/1WzErEM04rieeCMY6s_wGyTjWcuhAF-3D)). Without quantization, this model is too large for any single accelerator. With Keras, you can load it sharded across multiple accelerators, GPU or TPU. If you are uncertain about the "correct" weight shardings, don't worry: most models come with sensible defaults. Here, call `keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)`:

-```
+```py
devices = keras.distribution.list_devices() # 8 TPU cores: let's do a 2x4 mesh
device_mesh = keras.distribution.DeviceMesh((2, 4), ["batch", "model"], devices)
layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh) # defaults
@@ -169,7 +169,7 @@ model = keras_hub.models.Llama3CausalLM.from_preset("hf://meta-llama/Llama-3.1-8

And if you don't trust the default layout map provided by the model, you can define your own. In this example running on a "small" TPU setup with only 8 cores, the following layout map is a bit faster than the default (54s/epoch rather than 62s/epoch):

-```
+```py
layout_map = keras.distribution.LayoutMap(device_mesh)

layout_map["token_embedding/embeddings"] = ("model", None)