In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/gemma/pytorch/2b/2/config.json
/kaggle/input/gemma/pytorch/2b/2/gemma-2b.ckpt
/kaggle/input/gemma/pytorch/2b/2/tokenizer.model


In [2]:
# Setup the environment
!pip install -q -U immutabledict sentencepiece 
!git clone https://github.com/google/gemma_pytorch.git
!mkdir -p /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 71, done.
remote: Counting objects: 100% (16/16), done.
remote: Compressing objects: 100% (8/8), done.
remote: Total 71 (delta 12), reused 8 (delta 8), pack-reused 55
Unpacking objects: 100% (71/71), 2.13 MiB | 5.19 MiB/s, done.


In [6]:
!ls /kaggle/working/gemma_pytorch

CONTRIBUTING.md  README.md  gemma	      scripts	tokenizer
LICENSE		 docker     requirements.txt  setup.py


In [7]:
import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch


In [3]:
# Load the model
VARIANT = "2b" 
MACHINE_TYPE = "cpu" 
weights_dir = '/kaggle/input/gemma/pytorch/2b/2' 

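@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Temporarily sets the default torch dtype, restoring the previous one on exit."""
  previous = torch.get_default_dtype()
  torch.set_default_dtype(dtype)
  try:
    yield
  finally:
    # Restore even if model construction or weight loading raises.
    torch.set_default_dtype(previous)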

model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
  model.load_weights(ckpt_path)
  model = model.to(device).eval()



A PyTorch implementation of the Gemma 2B model. It is a 2B (two-billion-parameter) base model that has not been instruction-tuned.

The instruction-tuned (IT) variant is available on the same Kaggle model page: https://www.kaggle.com/models/google/gemma/frameworks/pyTorch/variations/2b-it

That page summarizes the difference: "Pre-trained (PT) models can be used as base models for further development, while instruction-tuned (IT) variants can be used for chatting and following prompts."

More background on instruction tuning: https://www.linkedin.com/pulse/generative-ai-executives-10-minute-deep-dive-amit-gupta/

Note that I had to attach the IT variant's weights to this notebook (only the 2b version was originally added) through the Notebook pane on the right, using the 'Add Input' button.

In [19]:
# Load the model
VARIANT = "2b-it" 
# Set this to "cuda" (not "gpu") when using Kaggle's T4 GPU accelerator.
# As expected, generation is much faster than with "cpu".
MACHINE_TYPE = "cuda" 
weights_dir = '/kaggle/input/gemma/pytorch/2b-it/2' 

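@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Temporarily sets the default torch dtype, restoring the previous one on exit."""
  previous = torch.get_default_dtype()
  torch.set_default_dtype(dtype)
  try:
    yield
  finally:
    # Restore even if model construction or weight loading raises.
    torch.set_default_dtype(previous)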

model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
  model.load_weights(ckpt_path)
  model = model.to(device).eval()



In [20]:
!ls /kaggle/input/gemma/pytorch/2b-it/2

config.json  gemma-2b-it.ckpt  tokenizer.model
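Because the 2b-it weights only appear after attaching them with 'Add Input', a quick pre-flight check of the weights directory can give a clearer error than a failed `load_weights` call. The sketch below uses a hypothetical helper, `missing_weight_files`, and assumes the file layout shown in the listing above (`config.json`, `gemma-{VARIANT}.ckpt`, `tokenizer.model`).

```python
import os
from typing import List


def missing_weight_files(weights_dir: str, variant: str) -> List[str]:
    """Return the expected Gemma files that are absent from weights_dir.

    Hypothetical helper; the expected names mirror the directory listing
    (config.json, gemma-<variant>.ckpt, tokenizer.model).
    """
    expected = ["config.json", f"gemma-{variant}.ckpt", "tokenizer.model"]
    return [name for name in expected
            if not os.path.isfile(os.path.join(weights_dir, name))]


# Usage before building the model:
# missing = missing_weight_files('/kaggle/input/gemma/pytorch/2b-it/2', "2b-it")
# if missing:
#     raise FileNotFoundError(f"Attach the model via 'Add Input'; missing: {missing}")
```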


I had to follow the advice in this discussion thread to get even partially reasonable results:

https://www.kaggle.com/models/google/gemma/discussion/478675

The output is still hard to interpret. I will have to dig into Gemma more.

In [22]:
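# Use the model

# Gemma's instruction-tuned checkpoints expect each conversational turn to be
# wrapped in <start_of_turn>/<end_of_turn> control tokens.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
# Only needed when replaying earlier model replies in a multi-turn conversation.
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt="Give me a workout for beginners"
    )
    # Open the model's turn so that generation continues as its reply.
    + "<start_of_turn>model\n"
)

model.generate(
    prompt,
    device=device,
    output_len=100,  # maximum number of new tokens to generate
)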

'**Beginner Workout**\n\n**Warm-up (5 minutes)**\n\n* Light cardio, such as brisk walking or cycling\n* Dynamic stretching, such as arm circles, torso twists, and leg swings\n\n**Core workout (15 minutes)**\n\n* Crunches: 3 sets of 10-12 repetitions\n* Plank: 3 sets of 30-60 seconds\n* Squats: 3 sets of 10-12 repetitions\n* Side'
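`MODEL_CHAT_TEMPLATE` is defined above but not needed for a single question. A minimal sketch of how the two templates could be combined for a multi-turn conversation, assuming earlier model replies are replayed with `MODEL_CHAT_TEMPLATE` and the prompt ends with an open model turn as in the cell above. `build_chat_prompt` is a hypothetical helper, not part of gemma_pytorch.

```python
from typing import List, Tuple

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"


def build_chat_prompt(history: List[Tuple[str, str]], user_message: str) -> str:
    """Format earlier (user, model) exchanges plus a new user message.

    Hypothetical helper: the result ends with an open model turn so that
    generation continues as the model's reply.
    """
    parts = []
    for user_text, model_text in history:
        parts.append(USER_CHAT_TEMPLATE.format(prompt=user_text))
        parts.append(MODEL_CHAT_TEMPLATE.format(prompt=model_text))
    parts.append(USER_CHAT_TEMPLATE.format(prompt=user_message))
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


# Usage with the loaded model:
# prompt = build_chat_prompt(
#     [("Give me a workout for beginners", "**Beginner Workout** ...")],
#     "Make it shorter",
# )
# model.generate(prompt, device=device, output_len=100)
```

With an empty history, the result is identical to the single-turn prompt built in the cell above.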

Results range from bad to OK, and I'm not sure they are entirely correct. I'll need to dig into the model documentation to understand how generation can be tuned and what the separator (or filler) words are. Sometimes the output consists mostly of these seemingly random filler words. Maybe the model can't think of anything else to say?
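Part of the rough look of the sample output is simply truncation: `output_len=100` caps generation at 100 new tokens, so the response above stops mid-line at `* Side`. A small, hypothetical post-processing helper, `trim_incomplete_line`, can drop an incomplete final line before displaying a response. It is purely cosmetic and does not change what the model generates; raising `output_len` is the other option.

```python
def trim_incomplete_line(text: str) -> str:
    """Drop a trailing partial line from a length-capped generation.

    Hypothetical helper: if the text does not end with a newline, everything
    after the last newline is assumed to be cut off and is removed. Text with
    no newline at all is returned unchanged, since there is nothing to keep.
    """
    if text.endswith("\n") or "\n" not in text:
        return text
    return text[: text.rfind("\n") + 1]


# Usage with the loaded model:
# print(trim_incomplete_line(model.generate(prompt, device=device, output_len=100)))
```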

Update: Switching to the 2b-it variant gives much better results. This is not surprising, since the IT variant is designed for this kind of prompt-and-response pattern. The odd thing is that the base 2b model receives exactly the same prompts but handles them poorly. Worth investigating further!
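Part of the explanation is likely that only the IT checkpoints are trained on the `<start_of_turn>` chat format, while a base (PT) model simply continues whatever text it is given. A minimal sketch, assuming you want to send the same question to both variants, of a hypothetical `format_for_variant` helper that applies the chat template only to `-it` variants and gives the base model a plain completion-style prompt. The `Question:`/`Answer:` framing is an arbitrary choice for illustration, not something the Gemma documentation prescribes.

```python
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"


def format_for_variant(variant: str, question: str) -> str:
    """Build a prompt suited to a Gemma variant name such as "2b" or "2b-it".

    Hypothetical helper: instruction-tuned ("-it") variants get the chat
    template with an open model turn; base variants get plain text to continue.
    """
    if variant.endswith("-it"):
        return USER_CHAT_TEMPLATE.format(prompt=question) + "<start_of_turn>model\n"
    return f"Question: {question}\nAnswer:"


# Usage: prompt = format_for_variant(VARIANT, "Give me a workout for beginners")
```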