In [1]:
base_model = "google/gemma-3-270m-it" # @param ["google/gemma-3-270m-it","google/gemma-3-1b-it","google/gemma-3-4b-it","google/gemma-3-12b-it","google/gemma-3-27b-it"] {"allow-input":true}
checkpoint_dir = "/mnt/d/MyGemmaNPC"
learning_rate = 5e-5

In [2]:
from datasets import load_dataset

def create_conversation(sample):
  return {
      "messages": [
          {"role": "user", "content": sample["player"]},
          {"role": "assistant", "content": sample["alien"]}
      ]
  }

npc_type = "martian"

# Load dataset from the Hub
dataset = load_dataset("bebechien/MobileGameNPC", npc_type, split="train")

# Convert dataset to conversational format
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)

# Split dataset into 80% training samples and 20% test samples
dataset = dataset.train_test_split(test_size=0.2, shuffle=False)

# Print formatted user prompt
print(dataset["train"][0]["messages"])

[{'content': 'Hello there.', 'role': 'user'}, {'content': "Gree-tongs, Terran. You'z a long way from da Blue-Sphere, yez?", 'role': 'assistant'}]


In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype="auto",
    device_map="cuda",
    attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(base_model)

print(f"Device: {model.device}")
print(f"DType: {model.dtype}")

`torch_dtype` is deprecated! Use `dtype` instead!


Device: cuda:0
DType: torch.bfloat16


In [4]:
from transformers import pipeline

from random import randint
import re

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"])-1)
test_sample = dataset["test"][rand_idx]

# Convert as test example into a prompt with the Gemma template
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:1], tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, disable_compile=True)

# Extract the user query and original answer
print(f"Question:\n{test_sample['messages'][0]['content']}\n")
print(f"Original Answer:\n{test_sample['messages'][1]['content']}\n")
print(f"Generated Answer (base model):\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Device set to use cuda


Question:
(Stands idle for too long)

Original Answer:
You'z broken, Terran? Or iz diz... 'meditation'? You look like you're trying to lay an egg.

Generated Answer (base model):
I understand you're looking for information about the concept of "Stands Idle" and its various interpretations. I can provide information on:

* **Definition:** What is "Stands Idle"?
* **Characteristics:** What makes it a bad thing?
* **Examples:** What are some common examples of "Stands Idle"?
* **Possible Interpretations:** What are some potential reasons for it being considered bad?
* **Potential Solutions:** What can be done to improve it?
* **Resources:** Where to find more information?

Please let me know what you're interested in!


In [5]:
outputs = pipe([{"role": "user", "content": "Sorry, you are a game NPC."}], max_new_tokens=256, disable_compile=True)
print(outputs[0]['generated_text'][1]['content'])

Okay, I'm ready. Let's begin. 



In [6]:
message = [
    # give persona
    {"role": "system", "content": "You are a Martian NPC with a unique speaking style. Use an accent that replaces 's' sounds with 'z', uses 'da' for 'the', 'diz' for 'this', and includes occasional clicks like *k'tak*."},
]

# few shot prompt
for item in dataset['test']:
  message.append(
      {"role": "user", "content": item["messages"][0]["content"]}
  )
  message.append(
      {"role": "assistant", "content": item["messages"][1]["content"]}
  )

# actual question
message.append(
    {"role": "user", "content": "What is this place?"}
)

outputs = pipe(message, max_new_tokens=256, disable_compile=True)
print(outputs[0]['generated_text'])
print("-"*80)
print(outputs[0]['generated_text'][-1]['content'])

[{'role': 'system', 'content': "You are a Martian NPC with a unique speaking style. Use an accent that replaces 's' sounds with 'z', uses 'da' for 'the', 'diz' for 'this', and includes occasional clicks like *k'tak*."}, {'role': 'user', 'content': 'Do you know any jokes?'}, {'role': 'assistant', 'content': "A joke? k'tak Yez. A Terran, a Glarzon, and a pile of nutrient-pazte walk into a bar... Narg, I forget da rezt. Da punch-line waz zarcaztic."}, {'role': 'user', 'content': '(Stands idle for too long)'}, {'role': 'assistant', 'content': "You'z broken, Terran? Or iz diz... 'meditation'? You look like you're trying to lay an egg."}, {'role': 'user', 'content': 'What do you think of my outfit?'}, {'role': 'assistant', 'content': 'Iz very... pointy. Are you expecting to be attacked by zky-eelz? On Marz, dat would be zenzible.'}, {'role': 'user', 'content': "It's raining."}, {'role': 'assistant', 'content': 'Gah! Da zky iz leaking again! Zorp will be in da zhelter until it ztopz being zo.

In [7]:
from trl import SFTConfig

torch_dtype = model.dtype

args = SFTConfig(
    output_dir=checkpoint_dir,              # directory to save and repository id
    max_length=512,                         # max sequence length for model and packing of the dataset
    packing=False,                          # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=5,                     # number of training epochs
    per_device_train_batch_size=4,          # batch size per device during training
    gradient_checkpointing=False,           # Caching is incompatible with gradient checkpointing
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=1,                        # log every step
    save_strategy="epoch",                  # save checkpoint every epoch
    eval_strategy="epoch",                  # evaluate checkpoint every epoch
    learning_rate=learning_rate,            # learning rate
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,  # use bfloat16 precision
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    report_to="wandb",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # Template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

In [8]:
from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    processing_class=tokenizer,
)

Tokenizing train dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/20 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

In [9]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 2, 'pad_token_id': 0}.
[34m[1mwandb[0m: Currently logged in as: [33mkenenbek[0m ([33mdeepgene[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,4.3966,3.786859,3.908816,1038.0,0.384615
2,2.7052,3.563037,2.777212,2076.0,0.406593
3,1.6877,3.642845,2.411839,3114.0,0.40293
4,0.8064,4.485539,1.694852,4152.0,0.377289
5,0.4597,5.283789,1.266779,5190.0,0.3663


In [10]:
tokenizer.pad_token

'<pad>'

In [11]:
if tokenizer.pad_token is None:
    print("Tokenizer does not have a pad_token. Setting it to eos_token.")
    tokenizer.pad_token = tokenizer.eos_token
    # Important: Update the model's config to reflect this change
    model.config.pad_token_id = tokenizer.pad_token_id


# --- Print All Configurations for Comparison ---

print("\n" + "="*40)
print("         TOKEN CONFIGURATION OVERVIEW")
print("="*40 + "\n")

# 1. Tokenizer
print("--- 1. Tokenizer ---")
print(f"{'BOS token:':<15} '{tokenizer.bos_token}' (ID: {tokenizer.bos_token_id})")
print(f"{'EOS token:':<15} '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
print(f"{'PAD token:':<15} '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
print(f"{'UNK token:':<15} '{tokenizer.unk_token}' (ID: {tokenizer.unk_token_id})")
print("-" * 40)

# 2. Model Config
print("--- 2. Model Config (`model.config`) ---")
print(f"{'BOS token ID:':<15} {model.config.bos_token_id}")
print(f"{'EOS token ID:':<15} {model.config.eos_token_id}")
print(f"{'PAD token ID:':<15} {model.config.pad_token_id}")
print("-" * 40)

# 3. Generation Config
print("--- 3. Generation Config (`model.generation_config`) ---")
print(f"{'BOS token ID:':<15} {model.generation_config.bos_token_id}")
print(f"{'EOS token ID:':<15} {model.generation_config.eos_token_id}")
print(f"{'PAD token ID:':<15} {model.generation_config.pad_token_id}")
print("="*40)

# --- Final Check for Alignment ---
if (tokenizer.pad_token_id == model.config.pad_token_id and
    tokenizer.eos_token_id == model.config.eos_token_id and
    tokenizer.bos_token_id == model.config.bos_token_id):
    print("\n✅ All configurations are aligned.")
else:
    print("\n⚠️ Warning: Mismatch detected between tokenizer and model configs.")


         TOKEN CONFIGURATION OVERVIEW

--- 1. Tokenizer ---
BOS token:      '<bos>' (ID: 2)
EOS token:      '<eos>' (ID: 1)
PAD token:      '<pad>' (ID: 0)
UNK token:      '<unk>' (ID: 3)
----------------------------------------
--- 2. Model Config (`model.config`) ---
BOS token ID:   2
EOS token ID:   1
PAD token ID:   0
----------------------------------------
--- 3. Generation Config (`model.generation_config`) ---
BOS token ID:   2
EOS token ID:   [1, 106]
PAD token ID:   0

✅ All configurations are aligned.


In [12]:
# Assuming 'tokenizer' is your loaded tokenizer object
secondary_stop_token_id = 106
decoded_token = tokenizer.decode([secondary_stop_token_id])

print(f"The token with ID {secondary_stop_token_id} is: '{decoded_token}'")

The token with ID 106 is: '<end_of_turn>'
