# Fine-tuning a Gemma model with Hugging Face

*   This Colab gets you started with fine-tuning language models for game-playing via Hugging Face.
*   we generate a dataset of actions recommended by an MCTS bot via self-play.
*   MCTS is short for [Monte Carlo tree search](https://en.wikipedia.org/wiki/Monte_Carlo_tree_search).
*   OpenSpiel is a framework for reinforcement learning in games.
*   Gemma is an open language model.
*   Hugging Face provides an open training and model ecosystem that enables flexible model loading and fine-tuning.








## Install
**Install core dependencies**: Hugging Face ecosystem libraries, Accelerate, and OpenSpiel.

*Optional*: bitsandbytes and peft enable parameter-efficient fine-tuning of quantized models (e.g., QLoRA) to reduce memory footprint and support training under practical GPU constraints.

In [4]:
!pip -q install -U open_spiel "transformers>=4.44" "accelerate>=0.33" "huggingface_hub>=0.24" bitsandbytes peft

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m48.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import torch

print("cuda:", torch.cuda.is_available())

import bitsandbytes as bnb

print("bnb version:", bnb.__version__)

cuda: True
bnb version: 0.49.0


In [25]:
!nvidia-smi

Fri Dec 26 13:42:59 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   49C    P0             26W /   70W |    2550MiB /  15360MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [42]:
import math
import random
from typing import Any, Dict, List, Optional

# Hugging Face imports
from datasets import Dataset
import numpy as np
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
import torch
import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

from open_spiel.python.algorithms import mcts
import pyspiel

## Create the fine-tuning dataset via self-play MCTS

In [20]:
def set_seeds(seed: int) -> None:
  random.seed(seed)
  np.random.seed(seed)


def player_mark(player: int) -> str:
  return "x" if player == 0 else "o"


def expand_state_str(state_str: str) -> str:
  """Insert spaces so tokens like 'x', 'o', '.' are less likely to merge."""
  return state_str.replace("x", " x").replace("o", " o").replace(".", " .")


def make_mcts_bot(game, uct_c, max_simulations, seed):
  rng = random.Random(seed)
  evaluator = mcts.RandomRolloutEvaluator(random_state=rng)
  bot = mcts.MCTSBot(
      game=game,
      uct_c=uct_c,
      max_simulations=max_simulations,
      evaluator=evaluator,
      random_state=rng,
  )
  return bot


def epsilon_greedy_action(
    legal_actions: List[int],
    greedy_action: int,
    num_distinct_actions: int,
    epsilon: float,
    all_actions: np.ndarray,
) -> int:
  """epsilon-greedy"""
  greedy_policy = np.zeros(num_distinct_actions, dtype=float)
  uniform_policy = np.zeros(num_distinct_actions, dtype=float)

  uniform_policy[legal_actions] = 1.0
  uniform_policy /= len(legal_actions)

  greedy_policy[greedy_action] = 1.0
  epsilon_greedy_policy = (
      epsilon * uniform_policy + (1.0 - epsilon) * greedy_policy
  )

  action_to_take = int(np.random.choice(all_actions, p=epsilon_greedy_policy))
  assert action_to_take in legal_actions
  return action_to_take


def generate_mcts_imitation_dataset(
    game_name: str,
    uct_c: float,
    mcts_sims_per_decision: int,
    epsilon: float,
    num_episodes: int,
    seed: int,
) -> List[Dict[str, Any]]:
  """Generates a dataset via MCTS self-play."""
  set_seeds(seed)

  game = pyspiel.load_game(game_name)
  bot = make_mcts_bot(
      game=game,
      uct_c=uct_c,
      max_simulations=mcts_sims_per_decision,
      seed=seed,
  )

  num_distinct_actions = game.num_distinct_actions()
  all_actions = np.arange(num_distinct_actions)

  records: List[Dict[str, Any]] = []

  print(
      f"Generating data using {num_episodes} episodes of self-play MCTS...\n"
      f"game={game_name}, sims/decision={mcts_sims_per_decision}, "
      f"epsilon={epsilon}, uct_c={uct_c:.4f}, seed={seed}"
  )

  for _ in tqdm.tqdm(range(num_episodes)):
    state = game.new_initial_state()

    while not state.is_terminal():
      player = state.current_player()
      legal_actions = state.legal_actions()

      # Teacher action (label): greedy MCTS action
      greedy_action = int(bot.step(state))

      # Record training example (label = greedy_action)
      records.append({
          "state_str": expand_state_str(str(state)),
          "action": greedy_action,
          "action_str": state.action_to_string(greedy_action),
          "player_mark": player_mark(player),
          "legal_actions": legal_actions,
      })

      # Behavior action (for exploration): epsilon-greedy
      action_to_take = epsilon_greedy_action(
          legal_actions=legal_actions,
          greedy_action=greedy_action,
          num_distinct_actions=num_distinct_actions,
          epsilon=epsilon,
          all_actions=all_actions,
      )
      state.apply_action(action_to_take)

  return records, num_distinct_actions

In [21]:
NUM_EPISODES = 1000
MCTS_SIMS_PER_DECISION = 1000
records, num_labels = generate_mcts_imitation_dataset(
    game_name='tic_tac_toe',
    uct_c=(2.0 * math.sqrt(2)),
    mcts_sims_per_decision=MCTS_SIMS_PER_DECISION,
    epsilon=0.1,
    num_episodes=NUM_EPISODES,
    seed=42,
)

Generating data using 1000 episodes of self-play MCTS...
game=tic_tac_toe, sims/decision=1000, epsilon=0.1, uct_c=2.8284, seed=42


100%|██████████| 1000/1000 [02:41<00:00,  6.19it/s]


## Convert records to HuggingFace Dataset and format prompt

In [51]:
ds = Dataset.from_list(records)

PROMPT_TEMPLATE = """\
You are playing a game of Tic-Tac-Toe. The current state is:

{state_str}

You need to give your move in the following format: mark(row,col)
Where mark is either "x" or "o" and row and col coordinates are 0-indexed.

You are playing as player {player_mark!r}.
What is your next move?
Respond with: YOUR_MOVE
"""


def add_prompt_response(ex):
  ex["prompt"] = PROMPT_TEMPLATE.format(
      state_str=ex["state_str"],
      player_mark=ex["player_mark"],
  )
  # Directly use OpenSpiel's string form, e.g., "o(0,2)"
  ex["response"] = ex["action_str"]
  return ex


ds2 = ds.map(add_prompt_response)
print(ds2[0]["prompt"])
print("response:", ds2[0]["response"])

# train/test split
ds_split = ds2.train_test_split(test_size=0.2, seed=0)
train_ds = ds_split["train"]
eval_ds = ds_split["test"]

print(len(train_ds), len(eval_ds))

Map:   0%|          | 0/8419 [00:00<?, ? examples/s]

You are playing a game of Tic-Tac-Toe. The current state is:

 . . .
 . . .
 . . .

You need to give your move in the following format: mark(row,col)
Where mark is either "x" or "o" and row and col coordinates are 0-indexed.

You are playing as player 'x'.
What is your next move?
Respond with: YOUR_MOVE

response: x(1,1)
6735 1684


## Drop-in Gemma model via Hugging Face

In [23]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) Y
Token is valid (permission: read).
The token `read-only` has been saved to /root/.cache/huggingface/stored_tokens
[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authentic

In [63]:
# Specify LM

MODEL_ID = "google/gemma-2b"
SEED = 42
MAX_LENGTH = 256  # to align with Kauldron, use max_length=512（token）
BATCH_SIZE = 8  # to align with Kauldron, use batch_size=64
TRAINING_STEPS = 50  # to align with Kauldron, use 500 steps

In [64]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token

print(
    "pad_token:", tokenizer.pad_token, "pad_token_id:", tokenizer.pad_token_id
)

pad_token: <pad> pad_token_id: 0


In [65]:
# Tokenize + labels mask


def tokenize_seq2seq_style(batch):
  prompts = batch["prompt"]
  responses = batch["response"]

  input_ids_list, attention_mask_list, labels_list = [], [], []

  for p, r in zip(prompts, responses):
    # input data: [prompt] + "\n" + [response]
    p_text = p + "\n"
    r_text = r

    p_ids = tokenizer(p_text, add_special_tokens=False).input_ids
    r_ids = tokenizer(r_text, add_special_tokens=False).input_ids

    input_ids = (p_ids + r_ids)[:MAX_LENGTH]
    labels = ([-100] * len(p_ids) + r_ids)[:MAX_LENGTH]
    attention_mask = [1] * len(input_ids)

    # pad
    pad_len = MAX_LENGTH - len(input_ids)
    input_ids += [tokenizer.pad_token_id] * pad_len
    attention_mask += [0] * pad_len
    labels += [-100] * pad_len

    input_ids_list.append(input_ids)
    attention_mask_list.append(attention_mask)
    labels_list.append(labels)

  return {
      "input_ids": input_ids_list,
      "attention_mask": attention_mask_list,
      "labels": labels_list,
  }


train_tok = train_ds.map(
    tokenize_seq2seq_style, batched=True, remove_columns=train_ds.column_names
)
eval_tok = eval_ds.map(
    tokenize_seq2seq_style, batched=True, remove_columns=eval_ds.column_names
)

train_tok.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"]
)
eval_tok.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"]
)

print(train_tok[0].keys())

print("train size:", len(train_tok), "eval size:", len(eval_tok))
print("sample keys:", train_tok[0].keys())

Map:   0%|          | 0/6735 [00:00<?, ? examples/s]

Map:   0%|          | 0/1684 [00:00<?, ? examples/s]

dict_keys(['input_ids', 'attention_mask', 'labels'])
train size: 6735 eval size: 1684
sample keys: dict_keys(['input_ids', 'attention_mask', 'labels'])


In [66]:
# Initialize Model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
)

# Training-time optimizations
model.config.use_cache = False
model.gradient_checkpointing_enable()

# Prepare for k-bit training (important for QLoRA stability)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print("model loaded. first param device:", next(model.parameters()).device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 3,686,400 || all params: 2,509,858,816 || trainable%: 0.1469
model loaded. first param device: cuda:0


In [67]:
# Cast trainable adapter parameters to float32 for numerical stability
# and to avoid mixed-precision gradient scaling issues.
for n, p in model.named_parameters():
  if p.requires_grad:
    p.data = p.data.float()

dtypes = {}
for n, p in model.named_parameters():
  if p.requires_grad:
    dtypes[str(p.dtype)] = dtypes.get(str(p.dtype), 0) + p.numel()

print("trainable dtype counts:", dtypes)

trainable dtype counts: {'torch.float32': 3686400}


In [68]:
# Configure Trainer

args = TrainingArguments(
    output_dir="./tmp_out",
    max_steps=TRAINING_STEPS,  # align Kauldron TRAINING_STEPS
    per_device_train_batch_size=1,  # micro-batch
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=BATCH_SIZE,  # effective batch (align Kauldron BATCH_SIZE)
    learning_rate=1e-3,  # align Kauldron adafactor lr
    optim="adafactor",  # align optimizer choice
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=20,
    report_to="none",
    fp16=False,
    bf16=False,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    data_collator=default_data_collator,
)

trainer.train()
trainer.evaluate()

  return fn(*args, **kwargs)


Step,Training Loss,Validation Loss


{'eval_loss': 0.3649670481681824,
 'eval_runtime': 291.5964,
 'eval_samples_per_second': 5.775,
 'eval_steps_per_second': 5.775,
 'epoch': 0.05939123979213066}

In [79]:
# Sampling
def infer_move(prompt: str, max_new_tokens: int = 8):
  device = next(model.parameters()).device
  inputs = tokenizer(prompt, return_tensors="pt").to(device)

  model.eval()
  with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy
    )

  full_text = tokenizer.decode(out[0], skip_special_tokens=True)

  gen_only = full_text[len(prompt) :].strip().splitlines()[0]

  return gen_only, full_text


sample_prompt = train_ds[2]["prompt"]
pred, full = infer_move(sample_prompt)

print("=== PROMPT ===")
print(sample_prompt)
print("\n=== PREDICTED MOVE ===")
print(pred)

=== PROMPT ===
You are playing a game of Tic-Tac-Toe. The current state is:

 . . o
 . x .
 . . .

You need to give your move in the following format: mark(row,col)
Where mark is either "x" or "o" and row and col coordinates are 0-indexed.

You are playing as player 'x'.
What is your next move?
Respond with: YOUR_MOVE


=== PREDICTED MOVE ===
x(0,0)
