# Fine-tuning Ministral-3B on Pokémon Showdown
Upload `dataset.jsonl` to the Colab runtime before running.

In [None]:
# Check GPU: run nvidia-smi in the runtime's shell to confirm a GPU is attached
# before the later cells load the model with device_map="auto".
!nvidia-smi

In [None]:
!pip install -q --upgrade "transformers>=5.0.0.dev0" trl peft accelerate bitsandbytes "mistral-common>=1.8.6"
!pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
from huggingface_hub import login
from google.colab import userdata

# Authenticate against the Hugging Face Hub with the "HF_TOKEN" entry read
# from Colab's user secrets.
hf_token = userdata.get("HF_TOKEN")
login(token=hf_token)

In [None]:
import torch
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer

# Hub id of the base checkpoint that gets fine-tuned.
MODEL = "mistralai/Ministral-3-3B-Instruct-2512-BF16"

# H100 - no quantization needed, so the weights are loaded in full BF16 and
# placed automatically by device_map="auto".
load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
model = Mistral3ForConditionalGeneration.from_pretrained(MODEL, **load_kwargs)
model.tie_weights()

# The tokenizer reuses the end-of-sequence token for padding and pads on the
# right-hand side of each sequence.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

In [None]:
from peft import LoraConfig, get_peft_model

# Names of the linear sub-modules that receive LoRA adapters.
lora_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Rank-32 adapters with alpha 32, no dropout and no trainable bias terms.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=lora_targets,
    r=32,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
)

# NOTE: `model` is rebound to the wrapped model here, so re-running this cell
# without re-running the model-loading cell wraps an already-wrapped model.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
import json
import random
from datasets import Dataset

# Seed used only for the train/validation shuffle, so the split is identical on
# every Restart & Run All (a private Random instance leaves the global RNG alone).
SPLIT_SEED = 42
# Fraction of the samples that goes to the training set; the rest is validation.
TRAIN_FRACTION = 0.95

# Each non-blank line of dataset.jsonl is a JSON object with "prompt" and
# "completion" keys; it is converted into a two-turn chat (user -> assistant).
samples = []
with open("dataset.jsonl", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            # JSONDecodeError is a ValueError subclass; re-raise with the line
            # number so a broken record is easy to locate.
            raise ValueError(
                f"dataset.jsonl line {line_no} is not valid JSON: {err}"
            ) from err
        samples.append({
            "messages": [
                {"role": "user", "content": record["prompt"]},
                {"role": "assistant", "content": record["completion"]},
            ]
        })

if not samples:
    raise ValueError("dataset.jsonl contains no samples")

random.Random(SPLIT_SEED).shuffle(samples)
split = int(len(samples) * TRAIN_FRACTION)
train_data = Dataset.from_list(samples[:split])
val_data = Dataset.from_list(samples[split:])

print(f"Train: {len(train_data)} | Val: {len(val_data)}")

In [None]:
def format_sample(sample):
    """Render one sample's chat messages into a single string.

    Args:
        sample: A dataset row with a "messages" list of role/content dicts.

    Returns:
        A dict with one key, "text", holding the messages rendered through the
        tokenizer's chat template (not tokenized, no generation prompt added).
    """
    rendered = tokenizer.apply_chat_template(
        sample["messages"],
        add_generation_prompt=False,
        tokenize=False,
    )
    return {"text": rendered}


# Add the rendered "text" column to both splits (train first, then validation).
train_data, val_data = (ds.map(format_sample) for ds in (train_data, val_data))

In [None]:
from trl import SFTTrainer, SFTConfig

# Hyperparameters for the supervised fine-tuning run, grouped by concern.
sft_args = SFTConfig(
    # Data handling: use the pre-rendered "text" column, cap sequences at 512
    # tokens, and keep one sample per sequence (no packing).
    dataset_text_field="text",
    max_length=512,
    packing=False,
    # Optimisation: 1000 steps with a per-device batch of 16 and no gradient
    # accumulation; 50 warmup steps; BF16 training.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    max_steps=1000,
    learning_rate=1e-4,
    warmup_steps=50,
    optim="adamw_torch",
    bf16=True,
    seed=42,
    # Bookkeeping: log every 50 steps; evaluate and checkpoint every 200 steps
    # into the "output" directory.
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    output_dir="output",
)

trainer = SFTTrainer(
    model=model,
    args=sft_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    processing_class=tokenizer,
)

trainer.train()

In [None]:
# Quick inference test: generate a short completion for one hand-written prompt
# with the fine-tuned model to sanity-check its output.
model.eval()

prompt = "Turn 1. Weather: none. Your pokemon: Garchomp (100/100 HP, healthy) | Type: dragon/ground | Atk: 130 SpA: 80 Spe: 102. Opponent: Kingambit (100/100 HP, healthy) | Type: dark/steel | Def: 100 SpD: 60 Spe: 50. What move do you use?"

encoded = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
)

# The result is either a bare tensor of token ids or an object exposing
# `.input_ids` (and possibly `.attention_mask`); handle both.
if hasattr(encoded, "input_ids"):
    input_ids = encoded.input_ids
    attention_mask = getattr(encoded, "attention_mask", None)
else:
    input_ids = encoded
    attention_mask = None
if attention_mask is None:
    # A single unpadded prompt: every position is a real token.
    attention_mask = torch.ones_like(input_ids)

# Follow the model's device instead of hard-coding "cuda".
device = model.device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)

# Pass the mask explicitly: the pad token was set equal to the EOS token when
# the tokenizer was loaded, so it should not be inferred from the token ids.
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=32,
        temperature=0.1,
        do_sample=True,
    )

# Decode only the newly generated tokens (everything after the prompt).
print(tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True))

In [None]:
# Merge LoRA into base model and push full model to HuggingFace Hub
import torch
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from peft import PeftModel
from google.colab import userdata
import os

# Hub repository that receives the merged model and the tokenizer.
REPO_NAME = "mistral-hackaton-2026/ministral-3b-pokemon-showdown"
# Base checkpoint to reload for merging (same id as the one used for training).
MODEL = "mistralai/Ministral-3-3B-Instruct-2512-BF16"
# Hugging Face token read from Colab's user secrets; passed to both push_to_hub calls.
HF_TOKEN = userdata.get("HF_TOKEN")

def find_latest_checkpoint(output_dir):
    """Return the path of the highest-step ``checkpoint-<step>`` entry.

    Entries are compared by their integer step, not as strings: a plain
    lexicographic sort places "checkpoint-800" after "checkpoint-1000" and
    would therefore pick an older checkpoint.

    Args:
        output_dir: Directory to scan for ``checkpoint-<step>`` entries.

    Returns:
        ``os.path.join(output_dir, name)`` for the entry with the largest step,
        or ``None`` if the directory is missing or has no such entry.
    """
    if not os.path.isdir(output_dir):
        return None
    best_step, best_name = -1, None
    for name in os.listdir(output_dir):
        prefix, _, step = name.partition("-")
        # Only names of the exact form checkpoint-<digits> are considered.
        if prefix == "checkpoint" and step.isdecimal() and int(step) > best_step:
            best_step, best_name = int(step), name
    if best_name is None:
        return None
    return os.path.join(output_dir, best_name)


# Find latest checkpoint
adapter_path = find_latest_checkpoint("output")
if adapter_path is None:
    # Same fallback as before, but announced instead of silent.
    print("No checkpoint-<step> directory found in 'output'; falling back to 'output'")
    adapter_path = "output"
print(f"Loading adapter from: {adapter_path}")

# Reload a fresh copy of the base model in full BF16; the adapter is merged
# into this copy rather than into the model object used for training.
print("Loading base model in BF16...")
base_load_kwargs = dict(torch_dtype=torch.bfloat16, device_map="auto")
base_model = Mistral3ForConditionalGeneration.from_pretrained(MODEL, **base_load_kwargs)
base_model.tie_weights()

# Attach the LoRA adapter saved under adapter_path by the training run.
print("Loading LoRA adapter...")
peft_model = PeftModel.from_pretrained(base_model, adapter_path)

# Fold the adapter weights into the base weights and drop the PEFT wrapper.
print("Merging weights...")
merged_model = peft_model.merge_and_unload()

# Upload the merged model first, then the tokenizer, to the same Hub repo.
# `tokenizer` is the object created in the model-loading cell earlier in the
# notebook, so that cell must have been run in this session.
print("Pushing merged model to HuggingFace Hub...")
for artifact in (merged_model, tokenizer):
    artifact.push_to_hub(REPO_NAME, token=HF_TOKEN)
print("Done!")