# SFT Training for Convex Optimization Exercises (v2)

This notebook trains a language model on convex optimization proof problems using Supervised Fine-Tuning (SFT).

**Dataset**: Boyd & Vandenberghe's "Convex Optimization" exercises (`exercises.jsonl`)

**Key Changes in v2:**
- ✅ Train/Test Split: 10 smallest examples for testing, rest for training
- ✅ Base vs SFT Comparison: Compare base model and fine-tuned model side-by-side

**Approach**: 
- Train directly on optimization exercises (proof-based problems)
- Use reasoning tags: `<start_working_out>...<end_working_out><SOLUTION>...</SOLUTION>`
- Multiple epochs for small dataset
- Based on Unsloth's Qwen GRPO notebook structure

## 1. Setup and Model Loading

In [1]:
# BASE_MODEL = "unsloth/Qwen2.5-1.5B"  # earlier choice, superseded by the line below
BASE_MODEL = "unsloth/Qwen3-4B-Base"


In [2]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048  # Can increase for longer proofs
lora_rank = 32  # Larger rank = smarter, but slower

# Load base model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = BASE_MODEL,
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9,
)

# Apply LoRA
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank * 2,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

print("✅ Model loaded successfully")
print(f"🔥 CUDA available: {torch.cuda.is_available()}")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.

  from .autonotebook import tqdm as notebook_tqdm

🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 12-02 23:32:50 [vllm_utils.py:700] Unsloth: Patching vLLM v1 graph capture
==((====))==  Unsloth 2025.11.3: Fast Qwen3 patching. Transformers: 4.57.1. vLLM: 0.11.2.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 8. Max memory: 39.494 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Qwen3-4B-Base with actual GPU utilization = 88.97%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.49 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 320.
Unsloth: vLLM's KV Cache can use up to 28.19 GB. Also swap space = 6 GB.
Unsloth: FAILED getting compilation_config with erro

2025-12-02 23:32:54,842	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 12-02 23:32:54 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 12-02 23:32:55 [core.py:93] Initializing a V1 LLM engine (v0.11.2) with config: model='unsloth/Qwen3-4B-Base', speculative_config=None, tokenizer='unsloth/Qwen3-4B-Base', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.23s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.01s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.04s/it]


INFO 12-02 23:32:59 [default_loader.py:314] Loading weights took 2.13 seconds
INFO 12-02 23:32:59 [punica_selector.py:20] Using PunicaWrapperGPU.





INFO 12-02 23:32:59 [gpu_model_runner.py:3338] Model loading took 7.6334 GiB memory and 2.899625 seconds
INFO 12-02 23:33:14 [backends.py:631] Using cache directory: /home/ec2-user/.cache/vllm/torch_compile_cache/f88e8d602b/rank_0_0/backbone for vLLM's torch.compile
INFO 12-02 23:33:14 [backends.py:647] Dynamo bytecode transform time: 14.25 s
INFO 12-02 23:33:19 [backends.py:210] Directly load the compiled graph(s) for dynamic shape from the cache, took 3.975 s
INFO 12-02 23:33:23 [monitor.py:34] torch.compile takes 18.22 s in total
INFO 12-02 23:33:24 [gpu_worker.py:359] Available KV cache memory: 26.83 GiB
INFO 12-02 23:33:24 [kv_cache_utils.py:1229] GPU KV cache size: 195,360 tokens
INFO 12-02 23:33:24 [kv_cache_utils.py:1234] Maximum concurrency for 2,048 tokens per request: 95.39x
INFO 12-02 23:33:25 [kernel_warmup.py:65] Warming up FlashInfer attention.
INFO 12-02 23:33:25 [vllm_utils.py:705] Unsloth: Running patched vLLM v1 `capture_model`.


Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   0%|          | 0/102 [00:00<?, ?it/s]



Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 102/102 [00:09<00:00, 10.51it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 78/78 [00:07<00:00, 10.63it/s]

INFO 12-02 23:33:42 [gpu_model_runner.py:4244] Graph capturing finished in 17 secs, took 1.62 GiB
INFO 12-02 23:33:42 [vllm_utils.py:712] Unsloth: Patched vLLM v1 graph capture finished in 17 secs.





INFO 12-02 23:33:43 [core.py:250] init engine (profile, create kv cache, warmup model) took 43.30 seconds
INFO 12-02 23:33:44 [llm.py:352] Supported tasks: ('generate',)
Unsloth: Just some info: will skip parsing ['post_attention_layernorm', 'k_norm', 'norm1', 'layer_norm1', 'post_feedforward_layernorm', 'ffn_norm', 'post_layernorm', 'pre_feedforward_layernorm', 'layer_norm2', 'input_layernorm', 'norm', 'norm2', 'attention_norm', 'q_norm']
Performing substitution for additional_keys=set()
Unsloth: Just some info: will skip parsing ['post_attention_layernorm', 'k_norm', 'norm1', 'layer_norm1', 'post_feedforward_layernorm', 'ffn_norm', 'cross_attn_input_layernorm', 'post_layernorm', 'pre_feedforward_layernorm', 'layer_norm2', 'input_layernorm', 'norm', 'cross_attn_post_attention_layernorm', 'norm2', 'attention_norm', 'q_norm']


Unsloth 2025.11.3 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.

✅ Model loaded successfully
🔥 CUDA available: True


## 2. Configure Chat Template with Reasoning Tags

In [3]:
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = \
f"""You are given an optimization problem.
Think about the problem and provide your working out (proof steps).
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""

print(system_prompt)

You are given an optimization problem.
Think about the problem and provide your working out (proof steps).
Place it between <start_working_out> and <end_working_out>.
Then, provide your solution between <SOLUTION></SOLUTION>


In [4]:
# Create chat template
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ '{system_prompt}' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"\
    "{% endif %}"

# Replace with our specific template
chat_template = chat_template\
    .replace("'{system_prompt}'", f"'{system_prompt}'")\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")
tokenizer.chat_template = chat_template

print("✅ Chat template configured")

✅ Chat template configured
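
To make the template's behavior easier to inspect, here is a minimal pure-Python mirror of the Jinja logic above. It is a sketch for illustration, not the code the tokenizer runs: `render_chat` is a hypothetical helper, and the `EOS` value is an assumed placeholder (the real value comes from `tokenizer.eos_token`).

```python
def render_chat(messages, system_prompt, reasoning_start, eos_token,
                add_generation_prompt=False):
    """Pure-Python mirror of the Jinja chat template (illustrative only)."""
    parts = []
    if messages and messages[0]["role"] == "system":
        # An explicit system message is emitted first, followed by EOS
        parts.append(messages[0]["content"] + eos_token)
        loop_messages = messages[1:]
    else:
        # Otherwise the default system prompt is injected
        parts.append(system_prompt + eos_token)
        loop_messages = messages
    for message in loop_messages:
        if message["role"] == "user":
            parts.append(message["content"])
        elif message["role"] == "assistant":
            parts.append(message["content"] + eos_token)
    if add_generation_prompt:
        # Generation starts inside the reasoning block
        parts.append(reasoning_start)
    return "".join(parts)


EOS = "<|endoftext|>"  # assumed placeholder, not read from the real tokenizer
prompt = render_chat(
    [{"role": "user", "content": "Show that a halfspace is convex."}],
    system_prompt="You are given an optimization problem.",
    reasoning_start="<start_working_out>",
    eos_token=EOS,
    add_generation_prompt=True,
)
print(prompt)
```

Because the generation prompt already ends with `<start_working_out>`, a generated response is expected to begin inside the reasoning block and will not repeat that opening tag.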


## 3. Load and Format Optimization Exercises

In [5]:
import json
import pandas as pd
from datasets import Dataset

# Load exercises.jsonl
exercises = []
with open("exercises.jsonl", 'r', encoding='utf-8') as f:
    for line in f:
        exercises.append(json.loads(line))

dataset = pd.DataFrame(exercises)

print(f"📚 Loaded {len(dataset)} optimization exercises")
print(f"\n📋 Dataset columns: {dataset.columns.tolist()}")
dataset.head()

📚 Loaded 340 optimization exercises

📋 Dataset columns: ['exercise_number', 'exercise_text', 'solution_text', 'text']

| | exercise_number | exercise_text | solution_text | text |
|---|---|---|---|---|
| 0 | 2.1 | Let C ⊆ Rn be a convex set, with x1, . . . , x... | This is readily shown by induction from the de... | 2.1 Let C ⊆ Rn be a convex set, with x1, . . .... |
| 1 | 2.2 | Show that a set is convex if and only if its i... | We prove the ﬁrst part. The intersection of tw... | 2.2 Show that a set is convex if and only if i... |
| 2 | 2.3 | Midpoint convexity. A set C is midpoint convex... | We have to show that θx + (1 − θ)y ∈ C for all... | 2.3 Midpoint convexity. A set C is midpoint co... |
| 3 | 2.4 | Show that the convex hull of a set S is the in... | Let H be the convex hull of S and let D be the... | 2.4 Show that the convex hull of a set S is th... |
| 4 | 2.5 | What is the distance between two parallel hype... | The distance between the two hyperplanes is \|b... | 2.5 What is the distance between two parallel ... |
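
The loading loop above trusts every line of `exercises.jsonl`. A slightly more defensive reader could skip blank lines and records with an empty required field. The sketch below uses a hypothetical helper, `read_exercises`, and a made-up two-record sample in place of the real file; the field names are the ones shown in the column listing.

```python
import io
import json

REQUIRED_KEYS = ("exercise_number", "exercise_text", "solution_text")


def read_exercises(lines):
    """Parse JSONL lines, skipping blanks and records missing a required field."""
    kept, skipped = [], 0
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate empty lines between records
        record = json.loads(line)
        if all(record.get(key) for key in REQUIRED_KEYS):
            kept.append(record)
        else:
            skipped += 1  # missing or empty field
    return kept, skipped


# Hypothetical sample standing in for exercises.jsonl (not real dataset rows)
sample = io.StringIO(
    '{"exercise_number": "2.1", "exercise_text": "Show that ...", "solution_text": "By induction ..."}\n'
    "\n"
    '{"exercise_number": "2.2", "exercise_text": "Show that ...", "solution_text": ""}\n'
)
records, skipped = read_exercises(sample)
print(len(records), skipped)
```

With the real file, `read_exercises(open("exercises.jsonl", encoding="utf-8"))` would return the same kind of list that the notebook passes to `pd.DataFrame`.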


### Format Dataset with Reasoning Tags

In [6]:
def format_dataset(x):
    """Format exercise with reasoning tags."""
    problem = x["exercise_text"]
    solution = x["solution_text"]
    
    # Wrap solution with reasoning tags
    final_prompt = \
        reasoning_start + solution + reasoning_end + \
        solution_start + "Proven." + solution_end
    
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": problem},
        {"role": "assistant", "content": final_prompt},
    ]

dataset["Messages"] = dataset.apply(format_dataset, axis=1)
print("✅ Dataset formatted with reasoning tags")

✅ Dataset formatted with reasoning tags


### Calculate Lengths and Filter by max_seq_length

In [7]:
# Calculate token lengths
dataset["N"] = dataset["Messages"].apply(
    lambda x: len(tokenizer.apply_chat_template(x, tokenize=True))
)

print(f"\n📊 Length statistics:")
print(f"   Min: {dataset['N'].min()} tokens")
print(f"   Max: {dataset['N'].max()} tokens")
print(f"   Mean: {dataset['N'].mean():.0f} tokens")
print(f"   Median: {dataset['N'].median():.0f} tokens")

# Filter to examples that fit
original_count = len(dataset)
dataset = dataset.loc[dataset["N"] <= max_seq_length].copy()

print(f"\n✅ Filtered dataset:")
print(f"   Original: {original_count} examples")
print(f"   Kept: {len(dataset)} examples")
print(f"   Removed: {original_count - len(dataset)} examples (too long)")


📊 Length statistics:
   Min: 108 tokens
   Max: 2794 tokens
   Mean: 629 tokens
   Median: 506 tokens

✅ Filtered dataset:
   Original: 340 examples
   Kept: 336 examples
   Removed: 4 examples (too long)


### Split into Train and Test Sets

**Strategy**: 
- Sort by length (shortest first)
- Take 10 smallest examples for testing
- Use remaining examples for training

In [8]:
# Sort by length (ascending)
dataset_sorted = dataset.sort_values(by="N").reset_index(drop=True)

# Split: first 10 for test, rest for train
test_dataset = dataset_sorted.iloc[:10].copy()
train_dataset = dataset_sorted.iloc[10:].copy()

print(f"\n📊 Train/Test Split:")
print(f"   Test set: {len(test_dataset)} examples (10 smallest)")
print(f"   Train set: {len(train_dataset)} examples")
print(f"\n   Test set length range: {test_dataset['N'].min()}-{test_dataset['N'].max()} tokens")
print(f"   Train set length range: {train_dataset['N'].min()}-{train_dataset['N'].max()} tokens")

# Show test set examples
print(f"\n📝 Test Set Examples:")
for i, row in test_dataset.iterrows():
    print(f"   {i+1}. [{row['N']} tokens] {row['exercise_text'][:80]}...")


📊 Train/Test Split:
   Test set: 10 examples (10 smallest)
   Train set: 326 examples

   Test set length range: 108-155 tokens
   Train set length range: 155-2021 tokens

📝 Test Set Examples:
   1. [108 tokens] Find the dual cone of {Ax | x ⪰ 0}, where A ∈ Rm×n....
   2. [113 tokens] Show that the maximum volume ellipsoid enclosed in a set is unique. Show that th...
   3. [120 tokens] Give an example of two closed convex sets that are disjoint but cannot be strict...
   4. [124 tokens] Show that the function f(X) = X−1 is matrix convex on Sn ++....
   5. [126 tokens] Suppose x and y are independent random vectors in Rn, with log-concave probabili...
   6. [131 tokens] Assumptions for infeasible start Newton method. Consider the set of assumptions ...
   7. [141 tokens] and 4.58....
   8. [144 tokens] Functions and epigraphs. When is the epigraph of a function a halfspace? When is...
   9. [144 tokens] Linear measurements with exponentially distributed noise. Show how to 
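
The split strategy can be sketched on plain lists, which makes one property easy to see: with ties at the boundary, the longest test example and the shortest training example can share the same length (155 tokens in the run above). `split_by_length` is a hypothetical helper and the token counts are made up.

```python
def split_by_length(lengths, n_test=10):
    """Return (test_idx, train_idx); the n_test shortest examples form the test set."""
    # sorted() is stable, so equal lengths keep their original order
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return order[:n_test], order[n_test:]


# Hypothetical token counts, not the real dataset
lengths = [506, 108, 2021, 155, 120, 155, 629]
test_idx, train_idx = split_by_length(lengths, n_test=3)
print(test_idx, train_idx)
print(max(lengths[i] for i in test_idx), min(lengths[i] for i in train_idx))
```

Since the held-out examples are the shortest ones by construction, results on this test set may say little about longer proofs; a seeded random hold-out would be one alternative if that matters.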

### Prepare Training Dataset for HuggingFace

In [9]:
# Apply chat template to create "text" field for training set
train_dataset["text"] = tokenizer.apply_chat_template(
    train_dataset["Messages"].values.tolist(), 
    tokenize=False
)

# Convert to HuggingFace Dataset
train_dataset_hf = Dataset.from_pandas(train_dataset)

print(f"\n✅ Training dataset prepared")
print(f"   Total examples: {len(train_dataset_hf)}")
print(f"   Columns: {train_dataset_hf.column_names}")


✅ Training dataset prepared
   Total examples: 326
   Columns: ['exercise_number', 'exercise_text', 'solution_text', 'text', 'Messages', 'N']


## 4. Configure and Run SFT Training

In [10]:
from trl import SFTTrainer, SFTConfig

# Calculate training steps
import math

num_epochs = 15
batch_size = 1
gradient_accumulation = 4
# Round up: a final partial accumulation group still counts as an optimizer step,
# which matches the "Total steps = 1,230" reported by the trainer
steps_per_epoch = math.ceil(len(train_dataset_hf) / (batch_size * gradient_accumulation))
total_steps = steps_per_epoch * num_epochs

print(f"📊 Training configuration:")
print(f"   Examples: {len(train_dataset_hf)}")
print(f"   Epochs: {num_epochs}")
print(f"   Batch size: {batch_size} x {gradient_accumulation} = {batch_size * gradient_accumulation} (effective)")
print(f"   Steps per epoch: ~{steps_per_epoch}")
print(f"   Total steps: ~{total_steps}")
print(f"   Estimated time: ~{total_steps * 0.5 / 60:.0f}-{total_steps * 1.0 / 60:.0f} minutes on modern GPU")

📊 Training configuration:
   Examples: 326
   Epochs: 15
   Batch size: 1 x 4 = 4 (effective)
   Steps per epoch: ~82
   Total steps: ~1230
   Estimated time: ~10-20 minutes on modern GPU


In [11]:
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset_hf,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = batch_size,
        gradient_accumulation_steps = gradient_accumulation,
        warmup_steps = 50,
        num_train_epochs = num_epochs,
        learning_rate = 2e-4,
        logging_steps = 10,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs/optimization_sft_v2",
        save_steps = 100,
        save_total_limit = 3,
        report_to = "none",
    ),
)

print("✅ Trainer configured")

Unsloth: Tokenizing ["text"] (num_proc=64): 100%|██████████| 326/326 [00:11<00:00, 28.02 examples/s]

✅ Trainer configured
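
The combination of `warmup_steps = 50` and `lr_scheduler_type = "cosine"` can be pictured with a small stand-alone function. This is a sketch that assumes the usual shape of a linear warmup followed by a half-cosine decay to zero; `lr_at` is a hypothetical helper, and the default of 1,230 total steps is taken from the trainer banner in this notebook.

```python
import math


def lr_at(step, base_lr=2e-4, warmup_steps=50, total_steps=1230):
    """Learning rate at a given optimizer step: linear warmup, then half-cosine decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 25, 50, 640, 1230):
    print(step, f"{lr_at(step):.2e}")
```

Under these assumptions the rate peaks at `2e-4` on step 50, is back to half of that around step 640, and reaches zero at the final step.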





In [12]:
# Start training
print("🚀 Starting training...\n")
trainer.train()
print("\n✅ Training completed!")

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 326 | Num Epochs = 15 | Total steps = 1,230
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 66,060,288 of 4,088,528,384 (1.62% trained)


🚀 Starting training...

Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|---|---|
| 10 | 1.2466 |
| 20 | 1.1018 |
| 30 | 0.9578 |
| 40 | 0.9496 |
| 50 | 0.8555 |
| 60 | 0.8552 |
| 70 | 0.8605 |
| 80 | 0.8285 |
| 90 | 0.7397 |
| 100 | 0.6889 |

✅ Training completed!
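
The trainer banner above reports 66,060,288 trainable parameters. That figure can be reproduced by hand, since each LoRA-adapted linear layer adds `r * (in_features + out_features)` weights. The sketch below assumes the Qwen3-4B layer dimensions listed in the comments; they are not printed anywhere in this notebook, so treat them as assumptions that happen to reproduce the logged count.

```python
# Assumed Qwen3-4B dimensions (not shown in the notebook output)
HIDDEN = 2560
INTERMEDIATE = 9728
Q_OUT = 32 * 128    # attention heads x head dimension
KV_OUT = 8 * 128    # key/value heads x head dimension
NUM_LAYERS = 36     # matches "patched 36 layers" in the load log
RANK = 32           # lora_rank from the setup cell

# (in_features, out_features) for each module in target_modules
SHAPES = {
    "q_proj": (HIDDEN, Q_OUT),
    "k_proj": (HIDDEN, KV_OUT),
    "v_proj": (HIDDEN, KV_OUT),
    "o_proj": (Q_OUT, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(shapes, rank, num_layers):
    """Each adapted layer adds A (rank x in) and B (out x rank) matrices."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers


total = lora_params(SHAPES, RANK, NUM_LAYERS)
print(f"{total:,}")
print(f"{100 * total / 4_088_528_384:.2f}% of the reported total")
```

Doubling `RANK` doubles this count, which is the cost side of the "larger rank" trade-off noted in the setup cell.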


## 5. Save the Trained Model

In [13]:
# Save LoRA adapter
model.save_pretrained(f"{BASE_MODEL}_sft_model_v2")
tokenizer.save_pretrained(f"{BASE_MODEL}_sft_model_v2")

print(f"Model saved to: {BASE_MODEL}_sft_model_v2/")

Model saved to: unsloth/Qwen3-4B-Base_sft_model_v2/


## 6. Load Base Model for Comparison

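We load the base model (without LoRA) to compare it with the fine-tuned model. In the recorded run below, this load fails with a CUDA out-of-memory error: the training model and its vLLM engine, created with `gpu_memory_utilization = 0.9`, are still resident on the GPU, so the cache-clearing cell does not free enough memory for a second copy.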

In [14]:
torch.cuda.empty_cache()
import gc
gc.collect()

2222

In [15]:
# Load base model for comparison (no LoRA)
base_model, base_tokenizer = FastLanguageModel.from_pretrained(
    model_name = BASE_MODEL,
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = True,
)

# Apply same chat template to base model
base_tokenizer.chat_template = chat_template

# Enable inference mode
FastLanguageModel.for_inference(base_model)

print("✅ Base model loaded for comparison")

INFO 12-03 00:04:24 [vllm_utils.py:700] Unsloth: Patching vLLM v1 graph capture
==((====))==  Unsloth 2025.11.3: Fast Qwen3 patching. Transformers: 4.57.1. vLLM: 0.11.2.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 8. Max memory: 39.494 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Your GPU cannot handle sequence lengths of 256 due to limited GPU memory.
Unsloth: Your GPU can only handle approximately the maximum sequence length of 256.
Unsloth: vLLM loading unsloth/Qwen3-4B-Base with actual GPU utilization = 2.1%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.49 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 256. Num Sequences = 128.
Unsloth: vLLM's KV Cache can 

RuntimeError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 39.49 GiB of which 22.56 MiB is free. Including non-PyTorch memory, this process has 39.46 GiB memory in use. Of the allocated memory 37.56 GiB is allocated by PyTorch, with 57.88 MiB allocated in private pools (e.g., CUDA Graphs), and 165.02 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
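
A rough budget, using only numbers from the logs in this notebook, suggests why the second load could not fit. This is a back-of-the-envelope sketch: it ignores training-time allocations (which shrink the headroom further) and treats the logged GB and GiB figures as interchangeable.

```python
GPU_TOTAL_GIB = 39.49            # "VRAM = 39.49 GB" in the load log
FIRST_ENGINE_FRACTION = 0.8897   # "actual GPU utilization = 88.97%"
MODEL_WEIGHTS_GIB = 7.6334       # "Model loading took 7.6334 GiB memory"

reserved = GPU_TOTAL_GIB * FIRST_ENGINE_FRACTION  # held by the first engine
headroom = GPU_TOTAL_GIB - reserved               # what a second load could use
fits = headroom >= MODEL_WEIGHTS_GIB

print(f"reserved ~ {reserved:.2f} GiB, headroom ~ {headroom:.2f} GiB")
print(f"second copy needs at least {MODEL_WEIGHTS_GIB:.2f} GiB -> fits: {fits}")
```

Under these numbers the headroom is a little over 4 GiB against at least 7.6 GiB of weights. Possible ways around this, none of which were tried in this notebook, include restarting the kernel and reloading the saved adapter, or generating base outputs from the same model with the LoRA adapter temporarily disabled.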

## 7. Enable Inference Mode for Fine-tuned Model

In [None]:
# Enable inference mode for fine-tuned model
FastLanguageModel.for_inference(model)

print("✅ Fine-tuned model ready for inference")

## 8. Base vs SFT Comparison Function

This function will:
1. Take a test problem
2. Generate response from base model
3. Generate response from SFT model
4. Display both side-by-side for comparison

In [None]:
def compare_base_vs_sft(problem: str, max_tokens: int = 1024, show_problem: bool = True):
    """
    Compare base model and SFT model responses on a given problem.
    
    Args:
        problem: The optimization problem text
        max_tokens: Maximum tokens to generate
        show_problem: Whether to display the problem statement
    """
    # Prepare messages
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": problem}
    ]
    
    # Format prompt
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    
    # Show problem if requested
    if show_problem:
        print("="*80)
        print("📝 PROBLEM")
        print("="*80)
        print(problem)
        print()
    
    # Generate with BASE model
    print("="*80)
    print("🔵 BASE MODEL (No Fine-tuning)")
    print("="*80)
    base_inputs = base_tokenizer(text, return_tensors="pt").to("cuda")
    base_outputs = base_model.generate(
        **base_inputs,
        max_new_tokens=max_tokens,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    base_response = base_tokenizer.decode(
        base_outputs[0][base_inputs['input_ids'].shape[1]:],
        skip_special_tokens=True
    )
    print(base_response)
    
    # Check format for base model
    base_has_reasoning_end = reasoning_end in base_response
    base_has_solution_start = solution_start in base_response
    base_has_solution_end = solution_end in base_response
    base_format_ok = all([base_has_reasoning_end, base_has_solution_start, base_has_solution_end])
    
    print(f"\n📊 Format Check: {'✅ PASS' if base_format_ok else '❌ FAIL'}")
    print(f"   {reasoning_end}: {base_has_reasoning_end}")
    print(f"   {solution_start}: {base_has_solution_start}")
    print(f"   {solution_end}: {base_has_solution_end}")
    
    print()
    
    # Generate with SFT model
    print("="*80)
    print("🟢 SFT MODEL (Fine-tuned)")
    print("="*80)
    sft_inputs = tokenizer(text, return_tensors="pt").to("cuda")
    sft_outputs = model.generate(
        **sft_inputs,
        max_new_tokens=max_tokens,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    sft_response = tokenizer.decode(
        sft_outputs[0][sft_inputs['input_ids'].shape[1]:],
        skip_special_tokens=True
    )
    print(sft_response)
    
    # Check format for SFT model
    sft_has_reasoning_end = reasoning_end in sft_response
    sft_has_solution_start = solution_start in sft_response
    sft_has_solution_end = solution_end in sft_response
    sft_format_ok = all([sft_has_reasoning_end, sft_has_solution_start, sft_has_solution_end])
    
    print(f"\n📊 Format Check: {'✅ PASS' if sft_format_ok else '❌ FAIL'}")
    print(f"   {reasoning_end}: {sft_has_reasoning_end}")
    print(f"   {solution_start}: {sft_has_solution_start}")
    print(f"   {solution_end}: {sft_has_solution_end}")
    
    print("\n" + "="*80)
    print("📈 SUMMARY")
    print("="*80)
    print(f"Base Model Format: {'✅ CORRECT' if base_format_ok else '❌ INCORRECT'}")
    print(f"SFT Model Format:  {'✅ CORRECT' if sft_format_ok else '❌ INCORRECT'}")
    
    if sft_format_ok and not base_format_ok:
        print("\n🎉 SFT model successfully learned the format!")
    elif base_format_ok and sft_format_ok:
        print("\n✅ Both models follow the format (compare proof quality manually)")
    elif not base_format_ok and not sft_format_ok:
        print("\n⚠️ Neither model follows the format correctly")
    
    print("\n")

print("✅ Comparison function ready")
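
The format check in `compare_base_vs_sft` only tests whether each tag appears somewhere in the response, so a response with the tags in the wrong order would still pass. A stricter variant could require each tag exactly once and in order. The sketch below is one way to do that; `parse_response` is a hypothetical helper and the two sample responses are made up.

```python
import re

REASONING_END = "<end_working_out>"
SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"

# The response starts after the generation prompt, so the opening
# <start_working_out> tag is not expected inside it.
FORMAT_RE = re.compile(
    r"^(?P<working>.*?)" + re.escape(REASONING_END)
    + r"\s*" + re.escape(SOLUTION_START)
    + r"(?P<solution>.*?)" + re.escape(SOLUTION_END) + r"\s*$",
    flags=re.DOTALL,
)


def parse_response(response):
    """Return (working, solution) if each tag appears once and in order, else None."""
    for tag in (REASONING_END, SOLUTION_START, SOLUTION_END):
        if response.count(tag) != 1:
            return None
    match = FORMAT_RE.match(response)
    if match is None:
        return None
    return match.group("working").strip(), match.group("solution").strip()


good = "The set is an intersection of halfspaces.<end_working_out><SOLUTION>Proven.</SOLUTION>"
bad = "<SOLUTION>Proven.</SOLUTION> because ...<end_working_out>"
print(parse_response(good))
print(parse_response(bad))
```

Here `bad` contains all three tags and would pass the substring checks, but `parse_response` rejects it because the solution block comes before the end of the working-out.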

## 9. Test on All Test Set Examples

The cells below compare base vs SFT on the first three test examples; the remaining ones can be run with `test_by_index` from the next section.

In [None]:
# Test on first example from test set
test_problem = test_dataset.iloc[0]['exercise_text']
print(f"\n🧪 Testing on: {test_dataset.iloc[0]['exercise_number']}\n")
compare_base_vs_sft(test_problem, max_tokens=1024)

In [None]:
# Test on second example from test set
test_problem = test_dataset.iloc[1]['exercise_text']
print(f"\n🧪 Testing on: {test_dataset.iloc[1]['exercise_number']}\n")
compare_base_vs_sft(test_problem, max_tokens=1024)

In [None]:
# Test on third example from test set
test_problem = test_dataset.iloc[2]['exercise_text']
print(f"\n🧪 Testing on: {test_dataset.iloc[2]['exercise_number']}\n")
compare_base_vs_sft(test_problem, max_tokens=1024)

## 10. Interactive Testing Function

Use this to test any specific problem from the test set by index.

In [None]:
def test_by_index(idx: int, max_tokens: int = 1024):
    """
    Test a specific example from the test set by index (0-9).
    
    Args:
        idx: Index in test set (0-9)
        max_tokens: Maximum tokens to generate
    """
    if idx < 0 or idx >= len(test_dataset):
        print(f"❌ Invalid index. Must be between 0 and {len(test_dataset)-1}")
        return
    
    problem = test_dataset.iloc[idx]['exercise_text']
    exercise_num = test_dataset.iloc[idx]['exercise_number']
    
    print(f"\n🧪 Testing on Test Example #{idx+1}: Exercise {exercise_num}")
    print(f"   Length: {test_dataset.iloc[idx]['N']} tokens\n")
    
    compare_base_vs_sft(problem, max_tokens=max_tokens)

print("✅ Interactive testing function ready")
print("\nUsage: test_by_index(0)  # Test first example")
print("       test_by_index(5)  # Test sixth example")

In [None]:
# Example usage - test any specific index
# Uncomment and run to test:
# test_by_index(0)  # Test first example
# test_by_index(5)  # Test sixth example
# test_by_index(9)  # Test last example

## 11. Show All Test Set Problems

In [None]:
print("📋 All Test Set Problems:\n")
print("="*80)
# enumerate gives the positional index expected by test_by_index
for idx, (_, row) in enumerate(test_dataset.iterrows()):
    print(f"\n[{idx}] Exercise {row['exercise_number']} ({row['N']} tokens)")
    print("-" * 80)
    print(row['exercise_text'])
    print()

## 12. Summary

### What We Accomplished
- ✅ Split dataset: 10 smallest for testing, rest for training
- ✅ Trained SFT model on 326 examples
- ✅ Loaded base model for comparison
- ✅ Created comparison function to test base vs SFT
- ✅ Tested on multiple examples

### Key Functions
- `compare_base_vs_sft(problem)`: Compare base and SFT models on any problem
- `test_by_index(idx)`: Test specific example from test set (0-9)

### Model Locations
- **SFT model**: `unsloth/Qwen3-4B-Base_sft_model_v2/` (LoRA adapter and tokenizer)
- **Checkpoints**: `outputs/optimization_sft_v2/checkpoint-*/`

### Expected Results
- **Base model**: May not follow format, generates generic text
- **SFT model**: Should consistently follow the reasoning tag format and produce structured proofs

### Next Steps
1. Test on all 10 test examples using `test_by_index()`
2. Manually evaluate proof quality and correctness
3. If quality needs improvement:
   - Train for more epochs
   - Try a larger model than the Qwen3-4B-Base used here (for example a 7B model)
   - Increase max_seq_length for longer proofs
4. Consider GRPO training for further optimization