In [1]:
#@title ##### License { display-mode: "form" }
# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tuning a Gemma model with Kauldron

* This Colab gets you started with fine-tuning language models for game-playing.
* Specifically, it generates a dataset of tic-tac-toe moves recommended by an MCTS bot, fine-tunes Gemma on that dataset, and then lets the fine-tuned model play a game.
* MCTS is short for [Monte Carlo tree search](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search).
* OpenSpiel is a framework for reinforcement learning in games.
* Gemma is an open language model.
* Kauldron is an open training library.
* We strongly suggest using a TPU runtime (TPU v6e1 or above is recommended); otherwise, wait times are too long.
* This example has been adapted from [the Gemma GitHub page](https://github.com/google-deepmind/gemma).
* For more training examples, see the [Kauldron GitHub page](https://github.com/google-research/kauldron).

## Install

Install OpenSpiel, Gemma, and Kauldron via pip:


In [None]:
!pip install --upgrade open_spiel gemma kauldron


In [None]:
# @title Imports

import etils
from etils import ecolab
from etils import epy
import grain.python as pygrain
import json
import math
import numpy as np
import os
import optax
import pyspiel
import sys
import tqdm
import treescope

from open_spiel.python.algorithms import mcts

# with ecolab.adhoc():
with etils.epy.lazy_imports():
  from kauldron import kd
  from gemma import gm

# Let the XLA Python client use the full accelerator memory (fraction 1.00).
# NOTE(review): this only has an effect if it is set before JAX initializes its
# backend; kd/gm are lazily imported above, which presumably defers that.
os.environ.update(XLA_PYTHON_CLIENT_MEM_FRACTION="1.00")


In [None]:
# @title Create a tokenizer

# Gemma 3 tokenizer. As a quick sanity check, the cell displays the token ids
# of a sample sentence with a BOS token prepended.
tokenizer = gm.text.Gemma3Tokenizer()
sample_sentence = 'This is an example sentence'
tokenizer.encode(sample_sentence, add_bos=True)

In [None]:
# @title Create a model and sampler

# This colab uses the 1B model by default; 270M and 4B are other options.
# If you switch sizes, also update the checkpoint loaded by the trainer's
# `init_transform` further down.
BASE_CHECKPOINT = gm.ckpts.CheckpointPath.GEMMA3_1B_IT

# The model takes its `tokens` from the `batch.input` key path, i.e. the
# tokenized `input` field produced by the dataset pipeline.
model = gm.nn.Gemma3_1B(tokens="batch.input")
params = gm.ckpts.load_params(BASE_CHECKPOINT)

# Single-turn sampler over the base (not yet fine-tuned) parameters that
# prints its output as it is generated.
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    multi_turn=False,
    print_stream=True,
)


In [None]:
# @title Test the sampler

# Smoke test of the base model. The first invocation takes a while; if it
# takes over a minute on a TPU runtime, something is probably wrong.
question = "Is a tomato a vegetable or a fruit?"
output = sampler.chat(question)


In [None]:
# @title Create the fine-tuning dataset via self-play MCTS (2-3 minutes)

NUM_EPISODES = 1000
MCTS_SIMS_PER_DECISION = 1000
# Seed for dataset generation, so that re-running this cell and the generation
# loop reproduces the same dataset.
DATA_SEED = 42

# Seed NumPy's global RNG, which drives the epsilon-greedy sampling
# (np.random.choice) in the generation loop. Also give the MCTS bot and its
# rollout evaluator an explicitly seeded RandomState; OpenSpiel otherwise
# creates a fresh, unseeded RandomState for each of them.
np.random.seed(DATA_SEED)
mcts_random_state = np.random.RandomState(DATA_SEED)

game = pyspiel.load_game('tic_tac_toe')
mcts_bot = mcts.MCTSBot(
    game=game,
    uct_c=(2.0 * math.sqrt(2)),
    max_simulations=MCTS_SIMS_PER_DECISION,
    evaluator=mcts.RandomRolloutEvaluator(random_state=mcts_random_state),
    random_state=mcts_random_state,
)

num_distinct_actions = game.num_distinct_actions()
all_actions = np.arange(num_distinct_actions)

# Use an epsilon-greedy exploration sampling strategy to ensure that we
# sufficiently cover the space. The behaviour policy mixes a uniform
# distribution over legal moves (weight epsilon) with the MCTS recommendation
# (weight 1 - epsilon).
epsilon = 0.1
dataset = []

def player_mark(player: int):
  """Returns the board mark for `player`: "x" for player 0, "o" for anyone else."""
  if player == 0:
    return "x"
  return "o"

def expand_state_str(state_str: str):
  """Prefixes every board-cell character ("x", "o" or ".") with a space.

  Spreading the cells out like this keeps several marks from being grouped
  into a single token. All other characters (e.g. newlines) are kept as-is.
  """
  return "".join(" " + ch if ch in "xo." else ch for ch in state_str)

def _epsilon_greedy_policy(legal_actions, greedy_action):
  """Behaviour policy for data generation, over all `num_distinct_actions`.

  Returns epsilon * uniform(legal_actions) + (1 - epsilon) * one_hot(greedy).
  """
  uniform = np.zeros(num_distinct_actions, dtype=float)
  uniform[legal_actions] = 1.0
  uniform /= len(legal_actions)
  one_hot = np.zeros(num_distinct_actions, dtype=float)
  one_hot[greedy_action] = 1.0
  return epsilon * uniform + (1 - epsilon) * one_hot


print(f"Generating data using {NUM_EPISODES} episodes of self-play MCTS...")
for _episode in tqdm.tqdm(range(NUM_EPISODES)):
  state = game.new_initial_state()
  while not state.is_terminal():
    legal_actions = state.legal_actions()
    # The MCTS recommendation is recorded as the training label for this state.
    greedy_action = mcts_bot.step(state)
    dataset.append({
        "state_str": expand_state_str(str(state)),
        "action": greedy_action,
        "action_str": state.action_to_string(greedy_action),
        "player_mark": player_mark(state.current_player()),
        "legal_actions": legal_actions,
    })
    # The move actually played is sampled epsilon-greedily so that the
    # generated episodes cover the state space more broadly.
    action_to_take = np.random.choice(
        all_actions, p=_epsilon_greedy_policy(legal_actions, greedy_action))
    assert action_to_take in legal_actions
    state.apply_action(action_to_take)


In [None]:
# @title Save dataset as JSON to disk
DATASET_FILE = "/tmp/dataset.json"
# Serialize the list of example dicts as one JSON array; this is the file read
# back by kd.data.py.Json when building the training dataset.
with open(DATASET_FILE, 'w') as f:
  json.dump(dataset, f)
# Report success only after the `with` block has closed (and flushed) the file.
print(f"Successfully saved {DATASET_FILE}")

In [9]:
# @title Prompts and transform

# Prompt shown to the model. It is filled in from a dataset record's
# `state_str` and `player_mark` fields (PromptTransform) or from a live game
# state (make_prompt). `!r` renders the mark quoted, e.g. 'x'. The requested
# <move>...</move> format is what parse_action extracts at play time.
PROMPT_TEMPLATE = """\
You are playing a game of Tic-Tac-Toe. The current state is:

{state_str}

You need to give your move in the following format: mark(row,col)
Where mark is either "x" or "o" and row and col coordinates are 0-indexed.

You are playing as player {player_mark!r}.
What is your next move?
Respond with: <move>YOUR_MOVE</move>
"""

def make_prompt(state: pyspiel.State) -> str:
  """Builds the model prompt for the player to move in `state`.

  Uses the same field values as the dataset records (expanded board string and
  the current player's mark), so prompts at play time match training prompts.
  """
  fields = {
      "state_str": expand_state_str(str(state)),
      "player_mark": player_mark(state.current_player()),
  }
  return PROMPT_TEMPLATE.format(**fields)

class PromptTransform(pygrain.MapTransform):
  """Turns a raw dataset record into a {'prompt', 'response'} pair.

  The prompt is `prompt_template` formatted with the record's fields (fields
  the template does not reference are ignored by `str.format`). The response
  is the record's `action_str` wrapped in <move>...</move> tags.
  """

  def __init__(self, prompt_template: str):
    self._prompt_template = prompt_template

  def map(self, element):
    """Map a single element."""
    prompt_text = self._prompt_template.format(**element)
    target_text = f'<move>{element["action_str"]}</move>'
    return {'prompt': prompt_text, 'response': target_text}


In [10]:
# @title Create the Kauldron dataset and trainer

BATCH_SIZE = 64
SEED = 42

# Note! Checkpoints are saved in WORKDIR. If you change something and restart
# training at some later point, training may resume from a saved checkpoint
# instead of starting at step 0. To start training over, clear this workdir
# directory first.
WORKDIR = "/tmp/kauldron"

TRAINING_STEPS = 500

# Step 1: format each JSON record into a prompt/response pair.
prompt_transform = PromptTransform(prompt_template=PROMPT_TEMPLATE)

# Step 2: tokenize the pair into the `input`, `target` and `loss_mask` fields
# that the model (`batch.input`) and the loss (`batch.target`,
# `batch.loss_mask`) read. Sequences are padded/truncated to `max_length`.
tokenize_transform = gm.data.Seq2SeqTask(
    in_prompt='prompt',
    in_response='response',
    out_input='input',
    out_target='target',
    out_target_mask='loss_mask',
    tokenizer=tokenizer,
    max_length=512,  # In number of tokens.
    truncate=True,
)

kd_dataset = kd.data.py.Json(
    path=DATASET_FILE,
    shuffle=True,
    batch_size=BATCH_SIZE,
    transforms=[prompt_transform, tokenize_transform],
)

# Cross-entropy between the model's logits and the tokenized targets, masked
# by `loss_mask` (presumably so that only response tokens contribute).
train_losses = {
    'xentropy': kd.losses.SoftmaxCrossEntropyWithIntLabels(
        logits='preds.logits',
        labels='batch.target',
        mask='batch.loss_mask',
    ),
}

trainer = kd.train.Trainer(
    seed=SEED,
    workdir=WORKDIR,
    train_ds=kd_dataset,
    model=model,
    # Start from the same GEMMA3_1B_IT checkpoint that was loaded into `params`.
    init_transform=gm.ckpts.LoadCheckpoint(
        path=gm.ckpts.CheckpointPath.GEMMA3_1B_IT,
    ),
    num_train_steps=TRAINING_STEPS,
    train_losses=train_losses,
    optimizer=optax.adafactor(learning_rate=1e-3),
    # Checkpoint every 100 steps and keep at most 2 checkpoints in WORKDIR.
    checkpointer=kd.ckpts.Checkpointer(
        save_interval_steps=100,
        max_to_keep=2,
    ),
)

In [None]:
# @title Run the training

# Takes about 20min on a TPU v6e1.

# Note: the first few steps take quite long, but subsequent steps go faster.
# Also note that the progress bar does not always update at every iteration,
# so it may appear stuck after the first step (and update again at 10%).
# There is also a slowdown at multiples of 100 iterations while checkpoints
# are saved; the first save is also the longest.
#
# `trainer_state.params` holds the fine-tuned weights used for sampling below;
# `aux` is not used elsewhere in this notebook.

trainer_state, aux = trainer.train()

In [12]:
# @title Make a new sampler from the trained params

# Use a nonzero temperature to get some variation in the sampled moves.
TEMPERATURE = 1.0

fine_tuned_sampling = gm.text.RandomSampling(temperature=TEMPERATURE)

# Rebind `sampler` to the fine-tuned parameters. Output is returned from
# `chat` rather than streamed to stdout.
sampler = gm.text.ChatSampler(
    model=model,
    params=trainer_state.params,
    sampling=fine_tuned_sampling,
    multi_turn=False,
    print_stream=False,
)

In [None]:
# @title Play an episode with the fine-tuned model

import re

def parse_action(state: "pyspiel.State",
                 response: str) -> int:
  """Extracts the model's move from `response` and maps it to a legal action.

  The move is expected as <move>MARK(ROW,COL)</move>, the format used for the
  training targets. The text between the first pair of tags, with surrounding
  whitespace stripped, is compared against `state.action_to_string(a)` for
  every legal action `a`.

  Args:
    state: The current game state; must have at least one legal action.
    response: Raw text produced by the sampler.

  Returns:
    The matching legal action as an int. If no <move>...</move> tag is found,
    or its contents match no legal action, a uniformly random legal action.
  """
  legal_actions = state.legal_actions()
  # Non-greedy capture: take only the first <move>...</move> pair even if the
  # model emits several. DOTALL allows newlines inside the tags.
  match = re.search(r"<move>(.*?)</move>", response, flags=re.DOTALL)
  if not match:
    print(f"Incorrect format: {response}, returning uniform legal")
    return int(np.random.choice(legal_actions))
  # Strip the captured text itself (not the whole match including the tags).
  action_str = match.group(1).strip()
  print(f"Found action_str = {action_str}")
  for action in legal_actions:
    if state.action_to_string(action) == action_str:
      print(f"Returning corresponding action: {action}")
      return action
  print("Action not found in legal actions, returning a random legal action")
  return int(np.random.choice(legal_actions))


# The fine-tuned model plays both sides of one game.
state = game.new_initial_state()

while not state.is_terminal():
  print()
  print(state)
  prompt = make_prompt(state)
  print("Generating move...")
  output = sampler.chat(prompt)
  print(f"Response: {output}")
  action = parse_action(state, output)
  print(f"Playing action: {state.action_to_string(action)}")
  state.apply_action(action)

# The loop prints the board only before each move, so show the final board and
# the outcome explicitly (player 0 plays "x", player 1 plays "o").
print()
print(state)
print(f"Returns (x, o): {state.returns()}")

