##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This notebook provides a practical guide to working with Gemma 2 for Japan, the latest variant of Google's open language models.

The model itself is available on both Kaggle and Hugging Face.

As the model is currently only available in a 2B size, no special hardware is required.

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/Gemma_2_for_Japan_using_Transformers_and_PyTorch.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

### Downloading prerequisites
The **KaggleHub** library will be used to retrieve the model from Kaggle, and the **Transformers** framework from Hugging Face will be used with **PyTorch** for inference.

Note: the `%%capture` magic keyword suppresses output of a given cell.

In [1]:
%%capture
!pip install kagglehub --upgrade
!pip install transformers --upgrade
!pip install torch --upgrade

### Authenticating with Kaggle

Next, Kaggle will need to be authenticated with.

Note: the call to `login()` can be skipped if running in a Kaggle notebook.

In [2]:
import kagglehub

# skip if in a Kaggle notebook
kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


### Downloading the model

This next step ensures that you have the necessary model files readily available, ready for loading and inference.

`kagglehub.model_download()` fetches the model from Kaggle.

In [3]:
# Download latest version
path = kagglehub.model_download("google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it")

print("Path to model files:", path)

Downloading 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/model.safetensors.index.json...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/tokenizer_config.json...
Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/tokenizer.model...




  0%|          | 0.00/1.06k [00:00<?, ?B/s][A[A
100%|██████████| 1.06k/1.06k [00:00<00:00, 166kB/s]


100%|██████████| 23.7k/23.7k [00:00<00:00, 3.01MB/s]

Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/model-00001-of-00002.safetensors...




  0%|          | 0.00/4.65G [00:00<?, ?B/s]

Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/model-00002-of-00002.safetensors...


[A


  0%|          | 0.00/230M [00:00<?, ?B/s][A[A[A

Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/special_tokens_map.json...






100%|██████████| 555/555 [00:00<00:00, 39.6kB/s]

Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/generation_config.json...





Downloading from https://www.kaggle.com/api/v1/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1/download/config.json...






100%|██████████| 168/168 [00:00<00:00, 103kB/s]





100%|██████████| 805/805 [00:00<00:00, 496kB/s]

  0%|          | 1.00M/4.65G [00:01<1:46:34, 780kB/s][A

 25%|██▍       | 1.00M/4.04M [00:01<00:04, 771kB/s][A[A


  0%|          | 1.00M/230M [00:01<05:13, 765kB/s][A[A[A


  1%|          | 2.00M/230M [00:01<02:56, 1.35MB/s][A[A[A
  0%|          | 2.00M/4.65G [00:01<1:02:06, 1.34MB/s][A

100%|██████████| 4.04M/4.04M [00:01<00:00, 2.31MB/s]



  2%|▏         | 4.00M/230M [00:02<01:12, 3.24MB/s][A[A[A
  0%|          | 5.00M/4.65G [00:02<20:28, 4.06MB/s]  [A


  2%|▏         | 5.00M/230M [00:02<01:02, 3.77MB/s][A[A[A
  0%|          | 7.00M/4.65G [00:02<14:29, 5.73MB/s][A


  4%|▍         | 9.00M/230M [00:02<00:26, 8.59MB/s][A[A[A
  0%|          | 10.0M/4.65G [00:02<09:29, 8.74MB/s][A
  0%|          | 12.0M/4.65G [00:02<08:03, 10.3MB/s][A
  0%|          | 14.0M/4.65G [00:02<06:59, 11.9MB/s][A


  5%|▍         | 11.0M/230M [00:02<00:28, 8.03MB/s][A[A[A
  0%| 

Path to model files: /root/.cache/kagglehub/models/google/gemma-2-2b-jpn-it/transformers/gemma-2-2b-jpn-it/1





### Importing PyTorch

For faster inference, it's good to use an accelerator such as a TPU or GPU - and also very important to make sure that PyTorch is set up to use it.

The next cell imports PyTorch, then checks whether CUDA is available for use with a GPU.

In [4]:
import torch
torch.cuda.is_available()

True

Now you can use Transformers. Here you set up the tokenizer and model, and make sure that the model is on the correct device.

In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(
    path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    local_files_only=True
).to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Here are some helper functions - `format_gemma_instruction()` is very simple (and maybe doesn't even need to be a function!), wrapping the user's instruction in Gemma format.

`unwrap_gemma_response()` takes care of stripping the response you get from the model, removing tokens as well as the prompt.

For more information on Gemma formatting, see https://ai.google.dev/gemma/docs/formatting

In [9]:
def format_gemma_instruction(instruction: str) -> str:
  return f"<start_of_turn>user {instruction}<end_of_turn><start_of_turn>model"

def unwrap_gemma_response(query: str, response: str) -> str:
  end_sequence = '<end_of_turn>'

  start_idx = 0
  query_idx = response.find(query)

  if query_idx >= 0:
    start_idx = query_idx + len(query)

  trim = response[start_idx:]

  end_idx = len(trim) - 1
  endseq_idx = trim.find(end_sequence)

  if endseq_idx >= 0:
    end_idx = endseq_idx

  return trim[:end_idx].strip()

First, you will send a prompt in Japanese to the model, asking it to write us a poem in Japanese.

In [10]:
# Our prompt will be, "Write me a poem about machine learning." in Japanese.
input_text = "マシーンラーニングについての詩を書いてください。"

input_formatted = format_gemma_instruction(input_text)

input_ids = tokenizer(input_formatted, return_tensors="pt").to(device)

outputs = model.generate(**input_ids, max_new_tokens=1024)
formatted = unwrap_gemma_response(input_formatted, tokenizer.decode(outputs[0]))

print(formatted)

The 'max_batch_size' argument of HybridCache is deprecated and will be removed in v4.46. Use the more precisely named 'batch_size' argument instead.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


機械学習の波、広がる未来の光、
データの海、複雑な知識を導く。
複雑なパターン、隠された法則を見つける、
予測と改善、未来を形作る。

ニューラルネットワーク、複雑な脳の像、
学習と進化、無限の可能性を秘める。
教師あり学習、教師なし学習、
データの力、人間の知恵を融合する。

機械学習の力、無限の可能性を秘める、
複雑な問題を解き明かす、未来を拓く。
AIの進化、人類の未来を左右する、
新たな時代を迎えよう。


Next, you can feed that poem back into the model, asking it to translate it into English.

In [16]:
translation_input_text = "Translate the following poem from Japanese to English. \n\n" + formatted
translation_input_formatted = format_gemma_instruction(translation_input_text)

translation_input_ids = tokenizer(translation_input_formatted, return_tensors="pt").to(device)

translation_output = model.generate(**translation_input_ids, max_new_tokens=2048)
translation_formatted = unwrap_gemma_response(translation_input_formatted, tokenizer.decode(translation_output[0]))

print(translation_formatted)

Here's the translation of the Japanese poem:

**The wave of machine learning, spreading the light of the future,
A vast ocean of data, guiding complex knowledge.
Finding complex patterns, uncovering hidden laws,
Prediction and improvement, shaping the future.

A neural network, a complex image of the brain,
Learning and evolution, holding infinite possibilities.
Supervised and unsupervised learning,
The power of data, fusing human wisdom.

The power of machine learning, holding infinite possibilities,
Unraveling complex problems, opening up the future.
The evolution of AI, shaping the future of humanity,
We are ushering in a new era.**



**Explanation of the Poem's Themes:**

* **Machine Learning's Impact:** The poem highlights the growing influence of machine learning, emphasizing its ability to shape the future.
* **Data as the Foundation:**  The poem emphasizes the role of data as the foundation for machine learning, suggesting that it is the key to unlocking its potential.
* **Com

And you're all done!