In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print every file under the read-only Kaggle input directory, one path per line.
input_paths = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
# Setup the environment
!pip install -q -U immutabledict sentencepiece 
!git clone https://github.com/google/gemma_pytorch.git
!mkdir /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 71, done.[K
remote: Counting objects: 100% (16/16), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 71 (delta 12), reused 8 (delta 8), pack-reused 55[K
Unpacking objects: 100% (71/71), 2.13 MiB | 5.34 MiB/s, done.


In [2]:
import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch


In [3]:
# Load the model
# Model variant ("2b" or "7b"). It selects both the config factory and the weights directory below.
VARIANT = "2b"
MACHINE_TYPE = "cpu"  # torch device string passed to torch.device, e.g. "cpu" or "cuda"
# Derive the weights directory from VARIANT so the config and checkpoint stay in sync
# when switching variants. Previously "2b" was hard-coded here.
# NOTE(review): the trailing "2" is presumably the Kaggle model version used for 2b;
# confirm the same version exists for other variants.
weights_dir = f'/kaggle/input/gemma/pytorch/{VARIANT}/2'

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

# Choose the config factory matching the variant, then point the config at the tokenizer.
config_factory = get_config_for_2b if "2b" in VARIANT else get_config_for_7b
model_config = config_factory()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
device = torch.device(MACHINE_TYPE)

# Construct and load the model with the config's dtype as the torch default, so
# tensors created during construction use that dtype.
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  model.load_weights(ckpt_path)
  model = model.to(device).eval()

  return self.fget.__get__(instance, owner)()


This notebook uses the PyTorch implementation of the Gemma 2B model: a 2-billion-parameter base (pre-trained) model that has not been instruction-tuned.

Pre-trained (PT) models can be used as base models for further development, while instruction-tuned (IT) variants can be used for chatting and following prompts.

Background on instruction tuning: https://www.linkedin.com/pulse/generative-ai-executives-10-minute-deep-dive-amit-gupta/

To get even partially reasonable results, I had to follow the advice in this discussion thread:

https://www.kaggle.com/models/google/gemma/discussion/478675

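The output is still hard to interpret. Instead of stopping after its answer, the model keeps inventing new user/model turns and emits stray tokens. This is likely because the base (non-instruction-tuned) model was not trained on the chat template. I need to dig into Gemma further.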

In [5]:
# Use the model

# Gemma chat format: each turn is wrapped in <start_of_turn>{role}\n ... <end_of_turn>\n.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"


def build_chat_prompt(turns):
    """Render a conversation into a single Gemma chat prompt string.

    Args:
        turns: iterable of ``(role, text)`` pairs; ``role`` is ``"user"`` or
            ``"model"``.

    Returns:
        The formatted turns followed by an opened ``model`` turn, so whatever
        the model generates next is its reply.

    Raises:
        KeyError: if a turn uses a role other than ``"user"`` or ``"model"``.
    """
    templates = {"user": USER_CHAT_TEMPLATE, "model": MODEL_CHAT_TEMPLATE}
    rendered = "".join(templates[role].format(prompt=text) for role, text in turns)
    return rendered + "<start_of_turn>model\n"


prompt = build_chat_prompt([
    ("user", "What is a good place for travel in the US?"),
    ("model", "California."),
    ("user", "What can I do in California?"),
])

OUTPUT_LEN = 100  # generation length in tokens, passed as output_len
# NOTE(review): assumes generate() samples from torch's global RNG. Seeding it
# makes Restart & Run All reproduce the same text; confirm against gemma_pytorch.
GENERATION_SEED = 0

torch.manual_seed(GENERATION_SEED)
model.generate(
    prompt,
    device=device,
    output_len=OUTPUT_LEN,
)

'Go to a park, see the redwoods.LMAO\n coscienzauser\nWhat should I do in Oregon? piacevolemodel\nWhat can I do in Oregon? piacevole coscienzauser\nWalk the beach, see the redwoods? piacevole coscienzauser\nWhat can I do in Oregon? piacevole coscienzauser\nWhat can I do at the beach? piacevole coscienzauser\nWhat should I do in Oregon? piacevole coscienzauser\nWhat color is Oregon? piacevole coscienzauser\nWhat is Oregon like? piacevole coscienzauser'