<a href="https://colab.research.google.com/github/cscpsb/AIML_1/blob/main/LLM%20Post%20Training/Gemma_Distributed_Fine_tuning_on_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Distributed tuning with Gemma using Keras

### Kaggle credentials

In [None]:
import os
from google.colab import userdata
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

## Installation

Install Keras and KerasNLP with the Gemma model.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install tensorflow-cpu~=2.16.0 keras-nlp==0.8.2 tensorflow-hub==0.16.1  tensorflow-text==2.16.1 keras==3.0.5

Collecting tensorflow-cpu~=2.16.0
  Downloading tensorflow_cpu-2.16.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting keras-nlp==0.8.2
  Downloading keras_nlp-0.8.2-py3-none-any.whl.metadata (7.0 kB)
Collecting tensorflow-hub==0.16.1
  Downloading tensorflow_hub-0.16.1-py2.py3-none-any.whl.metadata (1.3 kB)
ERROR: Could not find a version that satisfies the requirement tensorflow-text==2.16.1 (from versions: 2.18.1, 2.19.0rc0, 2.19.0)
ERROR: No matching distribution found for tensorflow-text==2.16.1

In [None]:
!pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
!pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
!pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
!pip install -q -U keras

### Set up Keras JAX backend

In [None]:
import jax

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]

In [None]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

## Load model

In [None]:
import keras
import keras_nlp

### Mixed precision training (if applicable)

In [None]:
# Uncomment the line below to enable mixed precision training with bfloat16.
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

### Model Parallelism

To load the model with the weights and tensors distributed across TPUs, create a new `DeviceMesh`.

In [None]:
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

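`LayoutMap` from the distribution API specifies how the weights and tensors should be sharded or replicated. Its string keys, such as `token_embedding/embeddings` below, are treated as regular expressions and matched against variable paths.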

In [None]:
model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (None, model_dim)

# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, model_dim, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
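To see which rule applies to which variable, the keys above can be tried against a few variable paths with Python's `re` module. This is a minimal sketch that does not call Keras: `lookup_layout` is a hypothetical helper, it returns the first matching rule, and it assumes that unmatched variables are left replicated. Keras's own lookup may behave differently, for example when more than one key matches.

```python
import re

# Keys and sharding tuples copied from the layout map above.
layout_rules = {
    "token_embedding/embeddings": (None, "model"),
    "decoder_block.*attention.*(query|key|value).*kernel": (None, "model", None),
    "decoder_block.*attention_output.*kernel": (None, None, "model"),
    "decoder_block.*ffw_gating.*kernel": ("model", None),
    "decoder_block.*ffw_linear.*kernel": (None, "model"),
}


def lookup_layout(path):
    """Return the sharding tuple of the first rule whose regex matches `path`."""
    for pattern, layout in layout_rules.items():
        if re.search(pattern, path):
            return layout
    # Assumption: a variable that matches no rule is replicated on all devices.
    return None


for path in [
    "token_embedding/embeddings",
    "decoder_block_1/attention/query/kernel",
    "decoder_block_1/attention/attention_output/kernel",
    "decoder_block_1/ffw_gating_2/kernel",
    "decoder_block_1/pre_attention_norm/scale",
]:
    print(f"{path:<52} {lookup_layout(path)}")
```

Because `ffw_gating.*kernel` also matches `ffw_gating_2/kernel`, both gating kernels get the same layout, while the norm scales match no rule.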

`ModelParallel` shards model weights or activation tensors across all devices on the `DeviceMesh`.

In [None]:
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
# Download the model weights for the `vault_gemma_1b_en` preset.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("vault_gemma_1b_en")

Downloading from https://www.kaggle.com/api/v1/models/keras/vaultgemma/keras/vault_gemma_1b_en/2/download/config.json...


100%|██████████| 784/784 [00:00<00:00, 2.29MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/vaultgemma/keras/vault_gemma_1b_en/2/download/task.json...


100%|██████████| 2.91k/2.91k [00:00<00:00, 8.21MB/s]


XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 1.10G. That was not possible. There are 536.89M free.; (0x0x0_HBM0)

Now verify that the model has been partitioned correctly. Let's take `decoder_block_1` as an example.

In [None]:
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')

<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (
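The `PartitionSpec` column tells you which tensor dimension is split along which mesh axis. As a rough illustration, the per-device shape can be computed by dividing each sharded dimension by the size of its mesh axis. This sketch assumes the `(1, 8)` mesh defined earlier and even divisibility; `shard_shape` is a hypothetical helper, not a JAX or Keras function.

```python
def shard_shape(shape, spec, mesh_sizes):
    """Per-device shape of a tensor partitioned according to `spec`."""
    result = []
    for dim, axis in zip(shape, spec):
        # None means the dimension is not split; otherwise split over the axis.
        parts = 1 if axis is None else mesh_sizes[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        result.append(dim // parts)
    return tuple(result)


# Mesh axis sizes taken from the DeviceMesh of shape (1, 8) defined earlier.
mesh_sizes = {"batch": 1, "model": 8}

# Shapes and specs taken from the listing above.
print(shard_shape((16, 3072, 256), (None, "model", None), mesh_sizes))
print(shard_shape((3072, 24576), ("model", None), mesh_sizes))
print(shard_shape((3072,), (None,), mesh_sizes))
```

Under these assumptions each device holds one eighth of every attention and feed-forward kernel, while the norm scales are replicated in full.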

## Inference before finetuning

In [None]:
gemma_lm.generate("Best comedy movies: ", max_length=64)

'Best comedy movies: \n\n1. **The Hangover (2009)**\n2. **Bridesmaids (2011)**\n3. **Anchorman (2007)**\n4. **Superbad (2007)**\n5. **The Room (2'

The model generates a list of well-known comedy movies. Next, fine-tune the Gemma model to change the output style.

## Finetune with IMDB

In [None]:
import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()

Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...



Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.


b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."

In [None]:
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

### LoRA training

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
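
LoRA freezes an existing kernel and trains two small low-rank matrices whose product is added to it, which is why a low rank reduces the number of trainable parameters so sharply. The sketch below only counts parameters for a single 2D kernel. The shape `(3072, 24576)` is borrowed from the listing above purely for illustration; it is not a claim about which layers `enable_lora` actually adapts, and `lora_param_counts` is a hypothetical helper.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (frozen kernel params, trainable LoRA params) for one 2D kernel."""
    full = d_in * d_out
    # LoRA adds A with shape (d_in, rank) and B with shape (rank, d_out).
    lora = rank * (d_in + d_out)
    return full, lora


full, lora = lora_param_counts(d_in=3072, d_out=24576, rank=4)
print(f"frozen kernel parameters:  {full:,}")
print(f"trainable LoRA parameters: {lora:,}")
print(f"trainable fraction:        {lora / full:.4%}")
```

Doubling the rank doubles the trainable parameter count, so the rank is a direct trade-off between adapter capacity and memory.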

In [None]:
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)

2000/2000 ━━━━━━━━━━━━━━━━━━━━ 326s 142ms/step - loss: 14.7736 - sparse_categorical_accuracy: 0.4807


<keras.src.callbacks.history.History at 0x7c25d831b280>
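
To make the role of `weight_decay` and `exclude_from_weight_decay` concrete, here is a minimal scalar sketch of one AdamW update. It is written from the standard AdamW formulation, not taken from the Keras source: `adamw_step` is a hypothetical helper, and the `beta_1`, `beta_2`, and `epsilon` values are assumed defaults. The learning rate and weight decay match the cell above.

```python
import math


def adamw_step(w, grad, m, v, step, lr=5e-5, weight_decay=0.01,
               beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """One AdamW update for a scalar weight; returns the new (w, m, v)."""
    # Decoupled weight decay: shrink the weight directly, independent of grad.
    w = w - lr * weight_decay * w
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta_1 * m + (1 - beta_1) * grad
    v = beta_2 * v + (1 - beta_2) * grad * grad
    # Bias correction for the zero-initialised moving averages.
    m_hat = m / (1 - beta_1 ** step)
    v_hat = v / (1 - beta_2 ** step)
    # Adam update using the corrected moments.
    w = w - lr * m_hat / (math.sqrt(v_hat) + epsilon)
    return w, m, v


# A kernel-like weight that is decayed.
decayed, _, _ = adamw_step(w=1.0, grad=1.0, m=0.0, v=0.0, step=1)
# A "bias" or "scale" variable excluded from decay behaves like weight_decay=0.
excluded, _, _ = adamw_step(w=1.0, grad=1.0, m=0.0, v=0.0, step=1,
                            weight_decay=0.0)
print(decayed, excluded)
```

The only difference between the two results is the extra shrink term `lr * weight_decay * w`, which is what the exclusion list switches off for bias and layer norm scale variables.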

## Inference after finetuning

In [None]:
gemma_lm.generate("Best comedy movies: ", max_length=256)

'Best comedy movies: \n\n1. **The Hangover (2009)**\n2. **Bridesmaids (2011)**\n3. **Anchorman (2007)**\n4. **Old School (2003)**\n5. **Superbad (2007)**\n6. **The Room (2003)**\n\nWhat is the common thread between the movies on this list?\n\nThey are all comedies.'

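Fine-tuning is intended to teach the model the style of movie reviews so that it generates output in that style. In the run shown above, however, the output still closely resembles the list produced before fine-tuning, which suggests that one epoch of rank-4 LoRA on 2,000 batches was not enough to shift the style noticeably.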