# Sharding

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sharding.ipynb)

Sharding for Gemma models. This example runs inference on Gemma 2 27B on a TPU v2 with 8 devices.

In [None]:
!pip install -q gemma

In [None]:
# Common imports
import jax
import jax.numpy as jnp

# Gemma imports
from gemma import gm
from kauldron import kd

For this colab, connect to a TPU kernel by selecting `Change runtime type` > `v2-8 TPU` to access multiple accelerators. JAX should then report multiple devices.

In [None]:
jax.device_count()

8
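To check what the runtime actually exposes, a short sketch like the following lists each device. It only uses standard JAX calls; the printed values depend on the runtime, so no output is shown here.

```python
import jax

# List every device visible to JAX in this process.
for device in jax.devices():
    # `platform` is e.g. 'tpu', 'gpu' or 'cpu'; `device_kind` names the hardware.
    print(device.id, device.platform, device.device_kind)

# On a single-host runtime, the local and global device counts match.
print(jax.local_device_count(), jax.device_count())
```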

Load the model, the params and the tokenizer. Here we load the 27B model.

In [None]:
tokenizer = gm.text.Gemma2Tokenizer()

model = gm.nn.Gemma2_27B()

When restoring the weights, you can pass a `sharding=` argument to `gm.ckpts.load_params`. Here we use the naive `kd.sharding.FSDPSharding` heuristic.

In [None]:
params = gm.ckpts.load_params(
    gm.ckpts.CheckpointPath.GEMMA2_27B_IT,
    sharding=kd.sharding.FSDPSharding(),
)
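To give an intuition for what an FSDP-style layout does, here is a minimal sketch in plain JAX. It is not the `kd.sharding.FSDPSharding` implementation: the `fsdp_axes` helper, the `'fsdp'` axis name, the toy parameter tree, and the rule "shard the largest dimension divisible by the device count, otherwise replicate" are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all available devices. The axis name 'fsdp' is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=('fsdp',))
n_devices = mesh.devices.size


def fsdp_axes(shape, n):
  """Hypothetical heuristic: shard the largest dim divisible by `n`.

  Returns one entry per dimension: 'fsdp' for the sharded dimension,
  None for the others. All None means the array is replicated.
  """
  candidates = [(dim, i) for i, dim in enumerate(shape) if dim % n == 0]
  axes = [None] * len(shape)
  if candidates:
    _, axis = max(candidates)
    axes[axis] = 'fsdp'
  return tuple(axes)


# Toy parameter tree standing in for real model weights.
toy_params = {
    'embedding': jnp.zeros((n_devices * 4, 16)),
    'scale': jnp.zeros((3,)),
}

# Build one NamedSharding per array, then place the arrays accordingly.
shardings = jax.tree_util.tree_map(
    lambda x: NamedSharding(mesh, P(*fsdp_axes(x.shape, n_devices))),
    toy_params,
)
sharded = jax.device_put(toy_params, shardings)

for name, array in sharded.items():
  print(name, array.shape, array.sharding.spec)
```

With this kind of layout, each device holds only a slice of the large arrays, which is what lets a model that does not fit on one accelerator be loaded across several.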

## Single token

Here's an example of predicting a single token, directly calling the model.

In [None]:
# Encode the prompt
prompt = tokenizer.encode('My name is', add_bos=True)  # /!\ Don't forget to add the BOS token
prompt = jnp.asarray(prompt)

# Here, we replicate the prompt across devices. Usually, the prompts will
# be batched and padded, then sharded using `kd.sharding.FIRST_DIM`.
prompt = kd.sharding.with_sharding_constraint(prompt, kd.sharding.REPLICATED)


# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Sample a token from the predicted logits
next_token = jax.random.categorical(
    jax.random.key(1),
    out.logits
)
tokenizer.decode(next_token)

' Mary'
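The cell above replicates a single prompt. For the batched case mentioned in its comment, a rough plain-JAX sketch of "pad, batch, then shard along the first dimension" might look like the following. The `pad_batch` helper, the `'batch'` axis name, the padding id `0`, and the token ids are all made up for illustration; this is not how `kd.sharding.FIRST_DIM` is implemented.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def pad_batch(token_lists, batch_multiple, pad_id=0):
  """Right-pads prompts to a common length and pads the batch size.

  The batch size is rounded up to a multiple of `batch_multiple` so the
  first dimension divides evenly across devices. `pad_id=0` is an assumption.
  """
  max_len = max(len(tokens) for tokens in token_lists)
  n_rows = -(-len(token_lists) // batch_multiple) * batch_multiple  # Ceil.
  batch = np.full((n_rows, max_len), pad_id, dtype=np.int32)
  for i, tokens in enumerate(token_lists):
    batch[i, :len(tokens)] = tokens
  return batch


mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
n_devices = mesh.devices.size

# Made-up token ids standing in for `tokenizer.encode(...)` outputs.
prompts = [[2, 10, 11], [2, 12], [2, 13, 14, 15]]

batch = pad_batch(prompts, batch_multiple=n_devices)
# Shard the batch dimension across devices, keep the sequence dimension whole.
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
print(batch.shape, batch.sharding.spec)
```

Sharding the batch dimension lets each device process its own subset of prompts, whereas replication sends the same prompt to every device.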

You can also plot the next-token probabilities from the predicted logits.

In [None]:
tokenizer.plot_logits(out.logits)
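If you only need the numbers rather than a plot, logits can be turned into probabilities with a softmax. The sketch below uses a tiny made-up logits vector instead of `out.logits`, so the ids do not correspond to real vocabulary entries.

```python
import jax
import jax.numpy as jnp

# Made-up logits over a 4-token vocabulary (stand-in for `out.logits`).
logits = jnp.array([2.0, 0.5, -1.0, 0.0])

# Softmax normalizes the logits into a probability distribution.
probs = jax.nn.softmax(logits, axis=-1)

# Keep the two most likely token ids and their probabilities.
top_probs, top_ids = jax.lax.top_k(probs, k=2)
print(top_ids, top_probs)
```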

## Multiple tokens

In practice, Gemma provides a `gm.text.Sampler` to perform efficient sampling (with kv-caching, early stopping, ...).

In [None]:
sampler = gm.text.Sampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.sample('My name is', max_new_tokens=30)

" Sarah, and I'm a 25-year-old woman living in a bustling city. I work as a graphic designer, a job"

## Training

To use sharding during training, set the `sharding=` attribute of the trainer:

```python
trainer = kd.train.Trainer(
    ...,
    sharding=kd.sharding.ShardingStrategy(
        params=kd.sharding.FSDPSharding(),
    ),
    ...,
)
```

See a full example at: https://github.com/google-deepmind/gemma/tree/main/examples/sharding.py