<a href="https://colab.research.google.com/github/chrislouis86/Kaggle-/blob/main/Gemma_7B_pirate_upload_to_Kaggle_and_Hugging_Face.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This colab: [bit.ly/fine-tuned-gemma-upload](https://bit.ly/fine-tuned-gemma-upload)
##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Installation

Install JAX with TPU support, Keras 3, KerasNLP (which provides the Gemma model), and the Hugging Face and Kaggle client libraries.

In [2]:
!pip install "jax[tpu]" --user -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -U keras-nlp
!pip install -U "keras>=3"  # quoted so the shell does not treat ">" as an output redirect
!pip install -U tensorflow-cpu # for tf.data only
!pip install huggingface_hub
!pip install kaggle


**How to run this demo:**
* On Kaggle, you can use a TPUv3-8, which has 8 cores with 16GB of memory each.
* On GCP, you can [provision a TPUv3-8 backend for Colab](https://docs.google.com/document/d/13DjWGxqbqAyoEcJjtmenmaMya9lxTsFn0jJP-UUH1jc/edit?usp=sharing).
* On GCP, you can [provision a multi-GPU 8xV100 backend for Colab](https://docs.google.com/document/d/1MNINJCk6vp0gCr4XqvmaTiwN8GKJH9_SFceXNvKjRw0/edit?usp=sharing).

In [3]:
!pip install keras_nlp



In [4]:
!pip install --upgrade --force-reinstall tensorflow-text


In [None]:
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 100% of TPU memory to minimize memory fragmentation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

import keras
import keras_nlp

# for reproducibility
keras.utils.set_random_seed(42)

# for saving smaller weights
keras.config.set_floatx("bfloat16")

# check what accelerators are available
import jax
jax.devices()

In [None]:
# formatting utility
from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

# Chat helper to drive a turn-by-turn dialog with Gemma
class ChatState():

  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
    self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
    self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
    return "".join(self.history)

  def get_full_prompt(self):
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")
    self.add_to_history_as_model(result)
    return result
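To see exactly what `ChatState.get_full_prompt()` sends to the model, here is a minimal standalone sketch that reproduces the same string layout with a plain function. `build_gemma_prompt` is a hypothetical helper written for illustration; it mirrors the control tokens used by the class above (`<start_of_turn>`, `<end_of_turn>`) but is not part of KerasNLP.

```python
START_TURN_USER = "<start_of_turn>user\n"
START_TURN_MODEL = "<start_of_turn>model\n"
END_TURN = "<end_of_turn>\n"


def build_gemma_prompt(turns, system=""):
    """Build a Gemma-style prompt from (role, message) pairs.

    Hypothetical helper mirroring ChatState.get_full_prompt(): every
    completed turn is wrapped in start/end markers, then an open model
    turn is appended so the model generates the next reply.
    """
    parts = []
    for role, message in turns:
        start = START_TURN_USER if role == "user" else START_TURN_MODEL
        parts.append(start + message + END_TURN)
    prompt = "".join(parts) + START_TURN_MODEL
    if system:
        prompt = system + "\n" + prompt
    return prompt


prompt = build_gemma_prompt([("user", "Hello there")])
print(prompt)
```

The class above relies on `generate()` returning the prompt followed by the completion, which is why `send_message` strips the prompt with `str.replace` before storing the model's reply in the history.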


In [None]:
# Access to Gemma checkpoints on Kaggle

# * On Colab, set up your Kaggle credentials as Colab secrets and use this:

# from google.colab import userdata
# os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
# os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

# * On Kaggle, set up your Kaggle credentials as Kaggle secrets and use this:

# from kaggle_secrets import UserSecretsClient
# os.environ["KAGGLE_USERNAME"] = UserSecretsClient().get_secret("KAGGLE_USERNAME")
# os.environ["KAGGLE_KEY"] = UserSecretsClient().get_secret("KAGGLE_KEY")

# * If using Jupyter locally, set the environment variables
#   KAGGLE_USERNAME and KAGGLE_KEY in your local environment before launching Jupyter.

# * To access models on HuggingFace, create an access token at https://huggingface.co/settings/tokens,
#   then load it into the environment variable HF_TOKEN using one of the three methods above.
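Before downloading checkpoints or uploading models, it can help to fail early with a clear message when a credential is missing. A minimal sketch, assuming the credentials were exported as the environment variables named above; `missing_credentials` is a hypothetical helper, not part of the Kaggle or Hugging Face libraries.

```python
import os

# Variables this notebook expects, per the instructions above.
REQUIRED_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY", "HF_TOKEN")


def missing_credentials(env=os.environ, names=REQUIRED_VARS):
    """Return the names of required credentials that are unset or empty."""
    return [name for name in names if not env.get(name)]


missing = missing_credentials()
if missing:
    print("Missing credentials:", ", ".join(missing))
else:
    print("All credentials found.")
```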

# Load the Gemma 7B-instruct model

The model is loaded with ModelParallel partitioning. The code below specifies how the weights of the model are split across multiple accelerators, so as to fit in memory.

Detailed documentation is available in the [Keras 3 distribution API guide](https://keras.io/guides/distribution/).
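A back-of-the-envelope estimate shows why the weights are split. The sketch below assumes roughly 8.5 billion parameters for Gemma 7B (an approximate total that includes the embedding table), 2 bytes per parameter in bfloat16 as set by `keras.config.set_floatx("bfloat16")` earlier, 16 GiB per TPUv3 core, and an even split across 8 cores. It counts weights only; activations and the cache used during generation come on top.

```python
def weight_memory_gib(num_params, bytes_per_param=2, num_shards=1):
    """Approximate weight memory per device in GiB, assuming an even split."""
    return num_params * bytes_per_param / num_shards / 2**30


NUM_PARAMS = 8.5e9    # assumption: approximate Gemma 7B parameter count
CORE_MEMORY_GIB = 16  # per TPUv3 core

single = weight_memory_gib(NUM_PARAMS)
sharded = weight_memory_gib(NUM_PARAMS, num_shards=8)
print(f"weights on 1 core:  {single:.1f} of {CORE_MEMORY_GIB} GiB")
print(f"weights on 8 cores: {sharded:.1f} GiB per core")
```

With about 15.8 GiB of weights, a single core would have essentially no room left to run the model, whereas an 8-way split leaves roughly 2 GiB of weights per core.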

In [None]:
# Create a device mesh of shape (1, number of devices) to partition weights
# across all TPU cores (8 on a TPUv3-8), then apply a layout map.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh((1, len(devices)), ["batch", "model"], devices)

# The default layout map is provided by the model
#layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(device_mesh) # default map

# Here it is written out explicitly for readability
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' are sharded along the "model" axis (8 ways on a TPUv3-8)
layout_map["token_embedding/embeddings"] = ("model", None)
# Regex to match against the query, key and value matrices in attention layers
layout_map["decoder_block.*attention.*(query|key|value).kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output.kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear.kernel"] = ("model", None)

# Pass the layout map to ModelParallel; the device mesh is taken from the layout map
model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")

# Make this the default model-parallel layout
keras.distribution.set_distribution(model_parallel)

# load the model
gemma = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_7b_en")
# or reload your variation (please change the URL)
#gemma = keras_nlp.models.GemmaCausalLM.from_preset("kaggle://mgornergoogle/gemma/keras/gemma1.1_instruct_pirate_7b")
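The `LayoutMap` keys are regular expressions matched against each weight's path. The standalone sketch below is modeled on how, as we understand it, Keras resolves a path (an exact key first, otherwise a single pattern matched with `re.search`) and shows how a `"model"` axis of size 8 divides a weight's shape. The weight paths and the embedding shape (256000 × 3072) are illustrative assumptions, not values read from the checkpoint, and `lookup_spec` and `shard_shape` are hypothetical helpers rather than the Keras implementation; `shard_shape` simply rejects uneven splits.

```python
import re

# Patterns and specs copied from the layout map above.
LAYOUT = {
    "token_embedding/embeddings": ("model", None),
    "decoder_block.*attention.*(query|key|value).kernel": ("model", None, None),
    "decoder_block.*attention_output.kernel": ("model", None, None),
    "decoder_block.*ffw_gating.kernel": (None, "model"),
    "decoder_block.*ffw_linear.kernel": ("model", None),
}


def lookup_spec(path, layout=LAYOUT):
    """Return the sharding spec for `path`, or None if it stays replicated."""
    if path in layout:
        return layout[path]
    matches = [pattern for pattern in layout if re.search(pattern, path)]
    if len(matches) > 1:
        raise ValueError(f"{path!r} matches several patterns: {matches}")
    return layout[matches[0]] if matches else None


def shard_shape(shape, spec, mesh_axes):
    """Per-device shape when each named mesh axis splits its dimension evenly."""
    if spec is None:
        return tuple(shape)
    result = []
    for dim, axis in zip(shape, spec):
        size = 1 if axis is None else mesh_axes[axis]
        if dim % size:
            raise ValueError(f"dimension {dim} is not divisible by {size}")
        result.append(dim // size)
    return tuple(result)


MESH_AXES = {"batch": 1, "model": 8}  # assumption: the 8-core TPUv3-8 mesh
for path in [
    "decoder_block_0/attention/query/kernel",  # illustrative paths
    "decoder_block_0/attention/attention_output/kernel",
    "decoder_block_0/ffw_linear/kernel",
    "decoder_block_0/pre_attention_norm/scale",
]:
    print(path, "->", lookup_spec(path))

# Illustrative embedding shape (vocabulary x hidden size).
print(shard_shape((256000, 3072), lookup_spec("token_embedding/embeddings"), MESH_AXES))
```

Weights that match no pattern (such as the normalization scale above) get `None` and are replicated on every device.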


# Fine-tune the model

Actually, we will just load a checkpoint fine-tuned previously. See [bit.ly/gemma-pirate-demo](https://bit.ly/gemma-pirate-demo) for the actual fine-tuning code.

In [None]:
# Load a previously trained LoRA checkpoint from GCS
!gsutil cp gs://gemma-pirate/pirate.gemma1.1_7bi.lora.h5 .
gemma.backbone.enable_lora(rank=8)
gemma.backbone.load_lora_weights("pirate.gemma1.1_7bi.lora.h5")
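`enable_lora(rank=8)` freezes the original weights and adds a pair of small trainable matrices to selected layers, which is presumably why the checkpoint file only needs to hold adapter weights. For a single d_in × d_out kernel, a rank-r adapter with matrices A (d_in × r) and B (r × d_out) adds r·(d_in + d_out) parameters. The sketch below does that arithmetic; `lora_param_count` is a hypothetical helper and the 3072 × 4096 shape is an illustrative assumption, not a specific Gemma layer.

```python
def lora_param_count(d_in, d_out, rank):
    """Parameters added by a rank-`rank` adapter on a d_in x d_out kernel."""
    # A is d_in x rank and B is rank x d_out; the weight update is A @ B.
    return rank * (d_in + d_out)


d_in, d_out, rank = 3072, 4096, 8  # illustrative shape, not read from Gemma
full = d_in * d_out
lora = lora_param_count(d_in, d_out, rank)
print(f"full kernel: {full:,} params")
print(f"LoRA rank {rank}: {lora:,} params ({lora / full:.2%} of the kernel)")
```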

# Save the model as a KerasNLP preset

In [None]:
gemma.save_to_preset('./gemma-pirate')

In [None]:
!ls -al ./gemma-pirate
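`save_to_preset` writes a directory of configuration, tokenizer and weight files, and that directory is what `upload_preset` sends later. A minimal sketch for checking its contents and total size from Python before uploading; `summarize_preset` is a hypothetical helper, and it makes no assumption about which file names a given KerasNLP version writes.

```python
from pathlib import Path


def summarize_preset(preset_dir):
    """Return (sorted relative file names, total size in bytes) for a directory tree."""
    root = Path(preset_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"no such preset directory: {root}")
    files = sorted(p for p in root.rglob("*") if p.is_file())
    total = sum(p.stat().st_size for p in files)
    return [p.relative_to(root).as_posix() for p in files], total


PRESET_DIR = "./gemma-pirate"
if Path(PRESET_DIR).is_dir():
    names, total_bytes = summarize_preset(PRESET_DIR)
    print("\n".join(names))
    print(f"total: {total_bytes / 2**30:.2f} GiB")
```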

# Upload the fine-tuned model to Kaggle

In [None]:
#saving to GCS manually
#!gsutil -m cp -r ./gemma-pirate gs://gemma-pirate/saved-preset/

In [None]:
kaggle_username = "chrismorgan86"
model_name = "gemma-pirate"
variation_name = "gemma1.1_instruct_pirate_7b"

# original model URI (looked up on Kaggle)
# https://www.kaggle.com/models/keras/gemma/keras/gemma_1.1_instruct_7b_en

uri = f"kaggle://{kaggle_username}/{model_name}/keras/{variation_name}"
print(uri)
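The Kaggle handle has four parts: owner, model name, framework and variation. A small sketch that builds the handle and checks each part before uploading; `kaggle_model_uri` is a hypothetical helper, and the allowed-character rule it applies (letters, digits, `_`, `-`, `.`) is an assumption made for illustration, not Kaggle's documented naming policy.

```python
import re

# Assumption: a conservative character set for each handle component.
_PART = re.compile(r"[A-Za-z0-9._-]+")


def kaggle_model_uri(owner, model, variation, framework="keras"):
    """Build a kaggle://owner/model/framework/variation handle (hypothetical helper)."""
    parts = [owner, model, framework, variation]
    for part in parts:
        if not _PART.fullmatch(part):
            raise ValueError(f"invalid handle component: {part!r}")
    return "kaggle://" + "/".join(parts)


print(kaggle_model_uri("chrismorgan86", "gemma-pirate", "gemma1.1_instruct_pirate_7b"))
```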

In [None]:
keras_nlp.upload_preset(uri, './gemma-pirate')

# Upload the model to HuggingFace

In [None]:
hf_username = "ChrisMorgan86"
model_variant_name = "gemma_pirate_instruct_7b-keras"

# original model URI (looked up on HuggingFace)
# https://huggingface.co/google/gemma-1.1-7b-it-keras

uri = f"hf://{hf_username}/{model_variant_name}"
print(uri)

# Authenticate with HuggingFace
import huggingface_hub
huggingface_hub.login(token=os.environ["HF_TOKEN"])

In [None]:
keras_nlp.upload_preset(uri, './gemma-pirate')

# Let's chat as pirates. Arrrr!

In [None]:
# delete the model to free up memory
del gemma
# reload the model from the uploaded variant on Kaggle
gemma = keras_nlp.models.GemmaCausalLM.from_preset("kaggle://chrismorgan86/gemma-pirate/keras/gemma1.1_instruct_pirate_7b")

In [None]:
chat = ChatState(gemma)
message = "Hello there"
display_chat(message, chat.send_message(message))

In [None]:
message = "Prime numbers, for sure!?"
display_chat(message, chat.send_message(message))

In [None]:
message = "Give me Python code computing them primes up to 1000!"
display_chat(message, chat.send_message(message))

In [None]:
print('Primes between 1 and 1000:\n')
for i in range(2, 1001):
    is_prime = True
    for j in range(2, i):
        if i % j == 0:
            is_prime = False
            break
    if is_prime:
        print(i)
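The trial-division loop above checks every smaller divisor, so its work grows roughly quadratically with the limit. As a cross-check (not something the model produced), here is a Sieve of Eratosthenes that should yield the same list; there are 168 primes below 1000.

```python
def primes_up_to(limit):
    """Return all primes <= limit using the Sieve of Eratosthenes."""
    if limit < 2:
        return []
    is_prime = [True] * (limit + 1)
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(limit ** 0.5) + 1):
        if is_prime[p]:
            # Smaller multiples of p were already crossed out by smaller primes.
            for multiple in range(p * p, limit + 1, p):
                is_prime[multiple] = False
    return [n for n, flag in enumerate(is_prime) if flag]


primes = primes_up_to(1000)
print(len(primes), primes[:5], primes[-1])  # 168 [2, 3, 5, 7, 11] 997
```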


# What's next

In this tutorial, you learned how to load a Gemma 7B model fine-tuned with LoRA to speak like a pirate, save it as a KerasNLP preset, upload it to Kaggle and Hugging Face, and chat with it, using Keras on JAX. You also learned how to load the large model in a distributed manner, on powerful TPUs, using model parallelism.

Here are a few suggestions for what else to learn about Keras and JAX:
* [Distributed training with Keras 3](https://keras.io/guides/distribution/).
* [Writing a custom training loop for a Keras model in JAX](https://keras.io/guides/writing_a_custom_training_loop_in_jax/).

And a couple of more basic Gemma tutorials:

* [Get started with Keras Gemma](https://ai.google.dev/gemma/docs/get_started).
* [Finetune the Gemma model on GPU](https://ai.google.dev/gemma/docs/lora_tuning).