This is a modular notebook that trains a supervised fine-tuning (SFT) baseline to compare against GRPO results. We use the same data, hyperparameters, and system prompt as the GRPO experiments.

The system prompt (currently sp-struct) can be swapped for sp-base, sp-declare, or sp-reflect; when you do, also switch to the dataset variant that pairs with the chosen system prompt.

### Install Unsloth

In [None]:
%%capture
# Install Unsloth; the %%capture cell magic above hides the installer output.
import os
# Colab is detected by checking whether any environment variable name contains "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    # Not running in Colab: plain install, pip resolves the dependencies.
    !pip install unsloth
else:
    # [NOTE] Do the below ONLY in Colab!
    # --no-deps installs unsloth alone; its dependencies are installed
    # explicitly by the "Colab Extra Install" cell that follows.
    !pip install --no-deps unsloth

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
# Same Colab detection as the previous cell: look for "COLAB_" in the env var names.
if "COLAB_" not in "".join(os.environ.keys()):
    # Not running in Colab: plain install, pip resolves the dependencies.
    !pip install unsloth
else:
    # Skip restarting message in Colab
    # Snapshot the loaded module names first so sys.modules can be mutated while looping.
    import sys, re, requests; modules = list(sys.modules.keys())
    # Drop every already-imported module whose name contains "PIL" or "google".
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    # Training stack installed without dependency resolution; xformers and trl are pinned.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    # Tokenizer / dataset / Hub helpers, installed with normal dependency resolution.
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    # transformers is pinned to an exact version.
    !pip install transformers==4.51.3

### Import wandb

In [None]:
import wandb

# Authenticate this session with Weights & Biases.
wandb.login()

# Open the run that training metrics are logged to
# (the trainer config below sets report_to="wandb").
wandb.init(project="gsm8k-prolog-prover", name="sft-sp-struct")

### Load model

In [None]:
from unsloth import is_bfloat16_supported, FastLanguageModel
import torch
# Sequence length (in tokens) passed to Unsloth when loading the model.
max_seq_length = 2048

# Load Qwen2.5-3B-Instruct in 4-bit together with its tokenizer.
# NOTE(review): max_lora_rank and gpu_memory_utilization are presumably only
# consumed by the vLLM path (fast_inference=True) - confirm before relying on them.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    fast_inference=False,  # SFT doesn't need vLLM
    max_lora_rank=64,
    gpu_memory_utilization=0.7,
)

# Wrap the base model with LoRA adapters (rank 32, alpha 64) on the seven
# projection modules listed below.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

### System prompt

In [None]:
# System prompt for the sp-struct variant: asks for a <reasoning> section
# followed by an <answer> section containing CLP(Q) Prolog code with a solve(X) goal.
# NOTE(review): no code visible in this notebook references SYSTEM_PROMPT;
# presumably the dataset's "prompt" column already embeds this text - confirm.
SYSTEM_PROMPT = """
You are a specialized Prolog code-generating assistant.

Your task is to solve math problems by providing a structured answer in two clearly defined sections:

1. <reasoning>
   - Provide a clear, concise step-by-step explanation of how you arrive at the solution.

2. <answer>
   - Provide executable Prolog code using constraint logic programming to compute the numeric answer.
   - Always start with: ':- use_module(library(clpq)).'
   - Define any necessary numeric constants or intermediate values using predicates.
   - Final answer should be unified explicitly in solve(X) using curly-brace constraints, without printing commands.

Use this XML format strictly:
<reasoning>
(Your step-by-step reasoning here)
</reasoning>
<answer>
:- use_module(library(clpq)).

(Any predicates/constants defined here)

solve(X) :-
    (Intermediate computations using curly braces)
    {X = final constraint logic}.
</answer>
"""

### Load and split dataset for SFT

In [None]:
from datasets import load_dataset, DatasetDict

def get_gsm8k_split(subset_size=2500, seed=42):
    """
    Load the dataset, keep a shuffled subset, and split it into
    70% train, 15% validation, 15% test.

    Same split as GRPO experiments (for the default arguments, as long as
    the dataset holds at least ``subset_size`` rows).

    Args:
        subset_size: Number of shuffled rows to keep before splitting.
            Values larger than the dataset are clamped to the dataset size
            instead of failing inside ``select``; ``None`` keeps every row.
        seed: Seed used for the shuffle and for both train/test splits.

    Returns:
        DatasetDict with the keys "train", "validation" and "test".

    Raises:
        ValueError: If ``subset_size`` is zero or negative.
    """
    dataset = load_dataset("niklasm222/gsm8k-prolog-prover-sp_struct-v4", split="train")

    available = len(dataset)
    if subset_size is None:
        subset_size = available
    elif subset_size <= 0:
        raise ValueError(f"subset_size must be positive, got {subset_size}")
    # Never ask for more rows than exist; select() would raise on out-of-range indices.
    subset_size = min(subset_size, available)

    subset = dataset.shuffle(seed=seed).select(range(subset_size))

    # Split off 15% for test
    split_1 = subset.train_test_split(test_size=0.15, seed=seed)
    train_val = split_1["train"]
    test = split_1["test"]

    # From remaining 85%, split off 15% (of the original subset) for validation
    val_ratio = 0.15 / 0.85
    split_2 = train_val.train_test_split(test_size=val_ratio, seed=seed)
    train = split_2["train"]
    val = split_2["test"]

    return DatasetDict({"train": train, "validation": val, "test": test})

# Build the three splits once and expose each under its own name.
splits = get_gsm8k_split()
train_dataset, val_dataset, test_dataset = (
    splits["train"],
    splits["validation"],
    splits["test"],
)

# Report how many rows landed in each split.
print(
    f"Training samples: {len(train_dataset)}",
    f"Validation samples: {len(val_dataset)}",
    f"Test samples: {len(test_dataset)}",
    sep="\n",
)

### Format dataset for SFT

In [None]:
def formatting_func(example):
    """
    Turn one dataset row into a single chat-formatted training string.

    The row's existing ``prompt`` messages are extended with an assistant
    turn holding the reference Prolog code (``output``) wrapped in
    <answer> tags. No <reasoning> section is included in the target,
    because the dataset has no ground-truth reasoning steps.

    Returns:
        dict with one key, "text", containing the rendered conversation.
    """
    # Assistant reply: the reference program inside <answer> tags.
    assistant_turn = {
        "role": "assistant",
        "content": "<answer>\n{}\n</answer>".format(example["output"]),
    }

    # Build a new message list; the row's own prompt list is left untouched.
    conversation = [*example["prompt"], assistant_turn]

    # Render to plain text; no generation prompt is appended because the
    # assistant turn is already part of the conversation.
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}

# Render the training split to text. Only this split is formatted here;
# the original columns are dropped so that just "text" remains.
original_columns = train_dataset.column_names
train_dataset_formatted = train_dataset.map(formatting_func, remove_columns=original_columns)

# Preview the first 500 characters of the first rendered sample.
print("\nExample formatted training sample:")
print(train_dataset_formatted[0]["text"][:500] + "...")

### SFTConfig and SFTTrainer

In [None]:
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    # --- optimisation ---
    optim="adamw_8bit",
    learning_rate=5e-6,
    weight_decay=0.1,
    max_grad_norm=0.1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # --- precision: bf16 when the GPU supports it, otherwise fp16 ---
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    # --- batching / schedule ---
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    seed=42,
    # --- data ---
    dataset_text_field="text",  # SFT-specific parameter
    max_seq_length=2048,  # keep in sync with the max_seq_length used at model load
    # --- logging / checkpoints ---
    logging_steps=1,
    report_to="wandb",
    output_dir="outputs_sft",
    # NOTE(review): with the default 2500-row subset (about 1750 training rows)
    # and batch size 8 on a single device, one epoch is roughly 219 steps, so a
    # 250-step interval would presumably never write a checkpoint - confirm.
    save_steps=250,
)

In [None]:
import os

# Export switches (both default to the original always-on behaviour).
SAVE_MERGED_16BIT = True
PUSH_MERGED_TO_HUB = True

# Hugging Face token is read from the environment rather than hard-coded.
# None lets the Hub client fall back to a cached login, if there is one.
HF_TOKEN = os.environ.get("HF_TOKEN") or None

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset_formatted,
)

# Train
trainer.train()

# Save the LoRA adapter.
# NOTE(review): Unsloth appears to attach `save_lora` only when a vLLM engine
# is loaded (fast_inference=True); this notebook loads with fast_inference=False,
# so fall back to save_pretrained, which writes the adapter weights for a PEFT model.
if hasattr(model, "save_lora"):
    model.save_lora("sft_saved_lora")
else:
    model.save_pretrained("sft_saved_lora")

# Merge the adapter into the base weights and save locally in 16-bit.
if SAVE_MERGED_16BIT:
    model.save_pretrained_merged(
        "qwen2.5-3b-sft-1.75k-gsm8k-sp-struct",
        tokenizer,
        save_method="merged_16bit",
    )

# Upload the merged 16-bit model to the Hugging Face Hub.
if PUSH_MERGED_TO_HUB:
    model.push_to_hub_merged(
        "niklasm222/qwen2.5-3b-sft-1.75k-gsm8k-sp-struct",
        tokenizer,
        save_method="merged_16bit",
        token=HF_TOKEN,
    )