# SFT Baseline Training

This notebook trains a supervised fine-tuning (SFT) baseline for comparison with the GRPO results.
It uses the same data, hyperparameters, and system prompt as the GRPO experiments.

### Install Unsloth

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # [NOTE] Do the below ONLY in Colab!
    !pip install --no-deps unsloth

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
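    # NOTE: xformers 0.0.29.post3 was built for PyTorch 2.6 (see the warning in the model-loading output);
    # on a newer torch it cannot load its CUDA extensions and Unsloth falls back to PyTorch attention.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo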
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install transformers==4.51.3

### Set up wandb logging

In [None]:
import wandb

wandb.login()

wandb.init(
    project="gsm8k-prolog-prover",
    name="sft-sp-reflect"
)

[34m[1mwandb[0m: Currently logged in as: [33msigo444[0m ([33msigo444-university-of-southern-denmark[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Load model

In [None]:
from unsloth import is_bfloat16_supported, FastLanguageModel
import torch
max_seq_length = 2048

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,
    fast_inference = False,  # SFT doesn't need vLLM
    max_lora_rank = 64,
    gpu_memory_utilization = 0.7,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 32,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 64,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


xFormers can't load C++/CUDA extensions. xFormers was built for:
    PyTorch 2.6.0+cu124 with CUDA 1204 (you have 2.8.0+cu126)
    Python  3.12.9 (you have 3.12.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


Switching to PyTorch attention since your Xformers is broken.

Unsloth: Xformers was not installed correctly.
Please install xformers separately first.
Then confirm if it's correctly installed by running:
python -m xformers.info
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.1: Fast Qwen2 patching. Transformers: 4.51.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/

Unsloth 2025.10.1 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.
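As a sanity check on the LoRA setup above, the trainable-parameter count that Unsloth reports when training starts (59,867,136) can be reproduced by hand. Each adapted projection of shape in × out gains two low-rank matrices, A (r × in) and B (out × r), i.e. r · (in + out) parameters. This sketch assumes the standard Qwen2.5-3B shapes (hidden size 2048, MLP size 11008, 16 query heads and 2 key/value heads of dimension 128); only the 36 layers are confirmed by the Unsloth log in this notebook.

```python
# Assumed Qwen2.5-3B shapes (taken from the public model config, not from this notebook;
# only NUM_LAYERS = 36 is confirmed by the Unsloth log above).
HIDDEN = 2048
INTERMEDIATE = 11008
NUM_HEADS = 16
NUM_KV_HEADS = 2
HEAD_DIM = HIDDEN // NUM_HEADS  # 128
NUM_LAYERS = 36
RANK = 32  # r passed to get_peft_model

# (in_features, out_features) of each targeted projection in one decoder layer
projections = {
    "q_proj": (HIDDEN, NUM_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "o_proj": (NUM_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    # LoRA adds A (rank x in_features) and B (out_features x rank), no bias terms.
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in projections.values())
total = per_layer * NUM_LAYERS
print(f"{per_layer:,} per layer, {total:,} total")  # 1,662,976 per layer, 59,867,136 total
```

Under standard (non-rsLoRA) scaling, `lora_alpha = 64` with `r = 32` scales the adapter update by alpha / r = 2.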


### System prompt

In [None]:
# sp-reflect

SYSTEM_PROMPT = """
You are a specialized Prolog code-generating assistant.

Your task is to solve math problems by providing a structured answer in two clearly defined sections:

1. <reasoning>
   - Provide a clear, concise step-by-step explanation of how you arrive at the solution.
   - Review the reasoning at the end of the <reasoning> section to ensure that all computations and logical deductions are correct.
   - If something is not correct, then try again: Provide a clear, concise step-by-step explanation of how you arrive at the solution.

2. <answer>
   - Provide executable Prolog code using constraint logic programming to compute the numeric answer.
   - Always start with: ':- use_module(library(clpq)).'
   - Define any necessary numeric constants or intermediate values using predicates.
   - Final answer should be unified explicitly in solve(X) using curly-brace constraints, without printing commands.

Use this XML format strictly:
<reasoning>
- Your step-by-step reasoning here
- Your review of the reasoning here
- Your potential further step-by-step reasoning here
</reasoning>
<answer>
:- use_module(library(clpq)).

(Any predicates/constants defined here)

solve(X) :-
    (Intermediate computations using curly braces)
    {X = final constraint logic}.
</answer>
"""

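The system prompt asks for a `<reasoning>` block followed by an `<answer>` block whose Prolog code starts with the `clpq` import and defines `solve(X)`. A minimal, hypothetical checker for that structure is sketched below; `follows_format` and its regex are illustrative assumptions, not the reward functions used in the GRPO experiments. Note that the SFT targets built in this notebook contain only the `<answer>` block, so they would not pass a strict check like this one.

```python
import re

# Hypothetical structural check loosely mirroring the XML format in SYSTEM_PROMPT.
FORMAT_RE = re.compile(
    r"^\s*<reasoning>\s*(.+?)\s*</reasoning>\s*<answer>\s*(.+?)\s*</answer>\s*$",
    re.DOTALL,
)


def follows_format(completion: str) -> bool:
    """Return True if the completion has <reasoning> then <answer>, and the
    answer starts with the clpq import and defines solve(X)."""
    match = FORMAT_RE.match(completion)
    if match is None:
        return False
    answer = match.group(2)
    return answer.startswith(":- use_module(library(clpq)).") and "solve(X)" in answer


good = (
    "<reasoning>\nAdd 2 and 3.\n</reasoning>\n"
    "<answer>\n:- use_module(library(clpq)).\n\nsolve(X) :-\n    {X = 2 + 3}.\n</answer>"
)
answer_only = "<answer>\n:- use_module(library(clpq)).\n\nsolve(X) :-\n    {X = 5}.\n</answer>"
print(follows_format(good), follows_format(answer_only))  # True False
```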
### Load and split dataset

In [None]:
from datasets import load_dataset, DatasetDict

def get_gsm8k_split(subset_size=2500, seed=42):
    """
    Take a seeded random subset of `subset_size` examples and split it into
    70% train, 15% validation, and 15% test.
    Same split as the GRPO experiments.
    """
    dataset = load_dataset("niklasm222/gsm8k-prolog-prover-sp_reflect-v8.2", split="train")
    subset = dataset.shuffle(seed=seed).select(range(subset_size))

    # Split off 15% for test
    split_1 = subset.train_test_split(test_size=0.15, seed=seed)
    train_val = split_1["train"]
    test = split_1["test"]

    # From the remaining 85%, take 0.15 / 0.85 so validation is 15% of the full subset
    val_ratio = 0.15 / 0.85
    split_2 = train_val.train_test_split(test_size=val_ratio, seed=seed)
    train = split_2["train"]
    val = split_2["test"]

    return DatasetDict({"train": train, "validation": val, "test": test})

# Load Data
splits = get_gsm8k_split()
train_dataset = splits["train"]
val_dataset = splits["validation"]
test_dataset = splits["test"]

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 1750
Validation samples: 375
Test samples: 375
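The split sizes printed above follow directly from the ratios in `get_gsm8k_split`: holding out 15% of the 2,500-example subset leaves 85%, so taking 0.15 / 0.85 of that remainder makes the validation set another 15% of the full subset. The arithmetic is checked below with exact fractions; this is an illustrative sketch, since `datasets` itself works with floats and converts the fractional `test_size` to a whole number of rows.

```python
from fractions import Fraction

SUBSET_SIZE = 2500
TEST_FRAC = Fraction(15, 100)

n_test = SUBSET_SIZE * TEST_FRAC           # 15% of the subset
n_train_val = SUBSET_SIZE - n_test         # remaining 85%
val_ratio = TEST_FRAC / (1 - TEST_FRAC)    # 0.15 / 0.85 == 3/17
n_val = n_train_val * val_ratio            # another 15% of the full subset
n_train = n_train_val - n_val              # 70% of the full subset

print(int(n_train), int(n_val), int(n_test))  # 1750 375 375
```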


### Format dataset for SFT

In [None]:
def formatting_func(example):
    """
    Format examples for SFT training.
    Wraps the reference Prolog code in <answer> tags.
    Note: We only supervise the <answer> section since we don't have
    ground-truth reasoning steps in the dataset.
    """
    # Create the complete conversation with assistant response
    messages = example["prompt"] + [
        {
            "role": "assistant",
            "content": f"<answer>\n{example['output']}\n</answer>"
        }
    ]

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False
    )

    return {"text": text}

# Format the training split (validation and test are not passed to the trainer)
train_dataset_formatted = train_dataset.map(
    formatting_func,
    remove_columns=train_dataset.column_names
)

print("\nExample formatted training sample:")
print(train_dataset_formatted[0]["text"])


Example formatted training sample:
<|im_start|>system

You are a specialized Prolog code-generating assistant.

Your task is to solve math problems by providing a structured answer in two clearly defined sections:

1. <reasoning>
   - Provide a clear, concise step-by-step explanation of how you arrive at the solution.
   - Review the reasoning at the end of the <reasoning> section to ensure that all computations and logical deductions are correct.
   - If something is not correct, then try again: Provide a clear, concise step-by-step explanation of how you arrive at the solution.

2. <answer>
   - Provide executable Prolog code using constraint logic programming to compute the numeric answer.
   - Always start with: ':- use_module(library(clpq)).'
   - Define any necessary numeric constants or intermediate values using predicates.
   - Final answer should be unified explicitly in solve(X) using curly-brace constraints, without printing commands.

Use this XML format strictly:
<reasoni
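With the default data collator, `SFTTrainer` computes the loss over every token of the `text` field, so the system prompt and the question are trained on alongside the assistant's `<answer>` turn. If only the assistant turn should be supervised, the prompt positions can be masked with the label value `-100`, which PyTorch's cross-entropy loss ignores by default. The helper below is a hypothetical, framework-free sketch of that masking on token ids; TRL 0.15 ships `DataCollatorForCompletionOnlyLM` for the same purpose, and the toy ids stand in for a tokenized `<|im_start|>assistant\n` marker.

```python
from typing import List, Sequence

# Label value that torch.nn.CrossEntropyLoss skips (its default ignore_index).
IGNORE_INDEX = -100


def mask_prompt_tokens(input_ids: Sequence[int], response_marker: Sequence[int]) -> List[int]:
    """Hypothetical helper: copy input_ids into labels, masking everything up to
    and including the last occurrence of response_marker with IGNORE_INDEX."""
    ids = list(input_ids)
    marker = list(response_marker)
    m = len(marker)
    # Search backwards so the final assistant turn is the one kept.
    for i in range(len(ids) - m, -1, -1):
        if ids[i:i + m] == marker:
            start = i + m
            return [IGNORE_INDEX] * start + ids[start:]
    # No assistant marker found: mask the whole sequence so it adds no loss.
    return [IGNORE_INDEX] * len(ids)


# Toy example: [9, 9] stands in for the tokenized assistant-turn marker.
labels = mask_prompt_tokens([1, 2, 3, 9, 9, 4, 5, 6], [9, 9])
print(labels)  # [-100, -100, -100, -100, -100, 4, 5, 6]
```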

### SFTConfig and SFTTrainer

In [None]:
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    seed=42,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    save_steps=250,
    max_grad_norm=0.1,
    max_seq_length=2048,
    report_to="wandb",
    output_dir="outputs_sft",
    dataset_text_field="text",
)
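The step count Unsloth reports (219) follows from 1,750 examples at an effective batch size of 8 × 1, rounded up to a whole number of batches. With `warmup_ratio=0.1` and `lr_scheduler_type="cosine"`, the learning rate ramps linearly up to 5e-6 and then decays along a half cosine to zero. The sketch below reimplements that shape for illustration only; it follows our reading of `transformers`' `get_cosine_schedule_with_warmup` (warmup steps = ceil(ratio × total)), and the helper name `lr_at` is hypothetical.

```python
import math

# Values taken from the SFTConfig above and the training split size.
NUM_EXAMPLES = 1750
BATCH_SIZE = 8
GRAD_ACCUM = 1
EPOCHS = 1
PEAK_LR = 5e-6
WARMUP_RATIO = 0.1

steps_per_epoch = math.ceil(NUM_EXAMPLES / BATCH_SIZE) // GRAD_ACCUM
total_steps = steps_per_epoch * EPOCHS                 # 219
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)  # 22


def lr_at(step: int) -> float:
    """Hypothetical helper: linear warmup to PEAK_LR, then half-cosine decay to zero."""
    if step < warmup_steps:
        return PEAK_LR * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return PEAK_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


for s in (0, 11, warmup_steps, 120, total_steps):
    print(s, f"{lr_at(s):.2e}")
```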

In [None]:
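import os

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset_formatted,
)

# Train
trainer.train()

# Save the LoRA adapter and tokenizer locally
model.save_pretrained("sft_saved_lora")
tokenizer.save_pretrained("sft_saved_lora")

# Merge the adapter into the base weights and save a 16-bit copy locally
SAVE_MERGED = True
if SAVE_MERGED:
    model.save_pretrained_merged(
        "qwen2.5-3b-sft-1.75k-gsm8k-sp-reflect",
        tokenizer,
        save_method="merged_16bit"
    )

# Push the merged 16-bit model to the Hugging Face Hub.
# Read the token from the environment instead of hard-coding it; if HF_TOKEN
# is unset, None is passed, which should fall back to a cached Hugging Face login.
PUSH_TO_HUB = True
if PUSH_TO_HUB:
    model.push_to_hub_merged(
        "niklasm222/qwen2.5-3b-sft-1.75k-gsm8k-sp-reflect",
        tokenizer,
        save_method="merged_16bit",
        token=os.environ.get("HF_TOKEN"),
    )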

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,750 | Num Epochs = 1 | Total steps = 219
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 59,867,136 of 3,145,805,824 (1.90% trained)


| Step | Training Loss |
|-----:|--------------:|
| 1 | 2.9597 |
| 2 | 2.7279 |
| 3 | 2.7219 |
| 4 | 2.9907 |
| 5 | 2.6245 |
| 6 | 2.9338 |
| 7 | 2.784 |
| 8 | 2.8343 |
| 9 | 2.8308 |
| 10 | 2.8533 |


Unsloth: Will smartly offload gradients to save VRAM!
Found HuggingFace hub cache directory: /root/.cache/huggingface/hub


Checking cache directory for required files...
Cache check failed: model-00001-of-00002.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Preparing safetensor model files: 100%|██████████| 2/2 [00:19<00:00,  9.73s/it]
Unsloth: Merging weights into 16bit: 100%|██████████| 2/2 [00:25<00:00, 12.85s/it]


Unsloth: Merge process complete. Saved to `/content/qwen2.5-3b-sft-1.75k-gsm8k-sp-reflect`


  ...sp-reflect/tokenizer.json: 100%|##########| 11.4MB / 11.4MB            

Found HuggingFace hub cache directory: /root/.cache/huggingface/hub


Checking cache directory for required files...
Cache check failed: model-00001-of-00002.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Preparing safetensor model files: 100%|██████████| 2/2 [00:19<00:00,  9.81s/it]
Unsloth: Merging weights into 16bit: 100%|██████████| 2/2 [01:45<00:00, 52.90s/it]


Unsloth: Merge process complete. Saved to `/content/niklasm222/qwen2.5-3b-sft-1.75k-gsm8k-sp-reflect`
