# SFT Model Format Verification

Test the merged SFT model to verify that its responses follow the expected format: reasoning inside `<think>...</think>`, followed by the next solution step.

## 1. Mount Drive & Find Model Path

In [None]:
# Mount Google Drive at /content/drive; the model-path search below looks
# for the model directory under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# "Shared with me" folders don't appear directly in Colab.
# You need to add a shortcut to your Drive first:
#
# 1. Go to drive.google.com
# 2. Click "Shared with me" in the left sidebar
# 3. Right-click on "srl_outputs" folder
# 4. Click "Organize" → "Add shortcut" → "My Drive"
#
# Then the path will be: /content/drive/MyDrive/srl_outputs/merged_sft_4b

import os

# Candidate locations for the merged model, checked in order; the first one
# that exists becomes MODEL_PATH.
possible_paths = [
    "/content/drive/MyDrive/srl_outputs/merged_sft_4b",
    "/content/drive/My Drive/srl_outputs/merged_sft_4b",
    "/content/drive/Shareddrives/srl_outputs/merged_sft_4b",
]

MODEL_PATH = None
for path in possible_paths:
    if os.path.exists(path):
        MODEL_PATH = path
        print(f"✓ Found model at: {MODEL_PATH}")
        break

if MODEL_PATH is None:
    print("❌ Model not found at expected paths.")
    print("\nLet's search your Drive...\n")

    # Show the first entries of MyDrive to help locate the folder by hand.
    mydrive = "/content/drive/MyDrive"
    if os.path.exists(mydrive):
        # List the directory once and reuse the result for both the preview
        # and the "more files" check.
        entries = sorted(os.listdir(mydrive))
        print(f"Contents of {mydrive}:")
        for item in entries[:20]:
            print(f"  {item}")
        if len(entries) > 20:
            print("  ... (more files)")
    else:
        # Without this branch the cell would end silently when Drive is not
        # mounted, leaving no hint about what to do next.
        print(f"{mydrive} does not exist. Is Drive mounted? Run the mount cell first.")

In [None]:
# Manual override: if the search above found nothing, add the Drive shortcut
# and uncomment the next line with the correct location.
# MODEL_PATH = "/content/drive/MyDrive/srl_outputs/merged_sft_4b"

# List the model directory, but only when a path is set and exists on disk.
model_dir_ready = bool(MODEL_PATH) and os.path.exists(MODEL_PATH)
if model_dir_ready:
    print("Model directory contents:")
    for entry in sorted(os.listdir(MODEL_PATH)):
        print("  " + entry)

## 2. Install & Load Model

In [None]:
!pip install -q transformers accelerate torch

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# MODEL_PATH stays None when the search cell found nothing and no manual
# override was set. None is not a usable model location, so stop here with an
# actionable message instead of failing inside from_pretrained.
if MODEL_PATH is None:
    raise ValueError(
        "MODEL_PATH is not set. Add the 'srl_outputs' shortcut to My Drive "
        "or set MODEL_PATH manually in the cell above, then re-run this cell."
    )

print(f"Loading from: {MODEL_PATH}")

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    padding_side="left",
)
# generate() passes pad_token_id explicitly, so make sure one is defined;
# fall back to the EOS token when the tokenizer ships without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✓ Tokenizer loaded")

In [None]:
# Load the merged checkpoint from MODEL_PATH in bfloat16 and switch it to
# evaluation mode. device_map="auto" leaves device placement to the library.
print("Loading model...")

load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **load_kwargs)
model.eval()

# Report the device of the first parameter.
first_param_device = next(model.parameters()).device
print(f"✓ Model loaded on: {first_param_device}")
if torch.cuda.is_available():
    # Memory currently allocated on CUDA device 0, reported in units of 1e9 bytes.
    allocated_gb = torch.cuda.memory_allocated(0) / 1e9
    print(f"  GPU memory: {allocated_gb:.2f} GB")

## 3. Prompt & Parsing Functions

In [None]:
from typing import List, Optional, Tuple

# Instruction text that build_prompt places at the top of every prompt.
# It asks the model to write its reasoning inside <think>...</think> and then
# give only the single next step of the solution.
# NOTE(review): presumably this wording must match the prompt used during SFT
# training; confirm before editing any part of the string.
PROMPT_PREAMBLE = (
    "You are a helpful assistant for solving mathematical problems. "
    "A user will provide a math problem, which may include a partial solution. "
    "Your task is to continue the solution by providing the very next logical step. "
    "A user will ask you to solve a task. You should first draft your thinking process "
    "(inner monologue). Then, generate the solution. "
    "Your response format must follow the template below:\n"
    "<think> Your thoughts or/and draft, like working through an exercise on scratch paper. "
    "Be as casual and as long as you want until you are confident to generate a correct solution. </think>\n"
    "Provide only the single, next step to continue the solution. Do not solve the entire problem."
)

def build_prompt(problem: str, previous_steps: Optional[List[str]] = None) -> str:
    """Assemble the text prompt for one next-step request.

    The prompt consists of PROMPT_PREAMBLE, a blank line, the line
    "Problem:", the stripped problem text, a blank line, and then each
    previous step (stripped) on its own line.

    Args:
        problem: The problem statement; surrounding whitespace is removed.
        previous_steps: Solution steps already taken, in order. ``None``
            (the default) means no steps have been taken yet.

    Returns:
        The prompt as a single newline-joined string.
    """
    header = [PROMPT_PREAMBLE, "", "Problem:", problem.strip(), ""]
    steps = [] if previous_steps is None else [s.strip() for s in previous_steps]
    return "\n".join(header + steps)

def parse_output(text: str) -> Tuple[Optional[str], Optional[str]]:
    """Split a model response into its thought and action parts.

    The response is split at the first ``</think>`` tag. Everything after
    the tag is the action. The thought is the text between ``<think>`` and
    ``</think>``; if there is no ``<think>`` before the closing tag, the
    thought is everything before the closing tag.

    Args:
        text: Raw generated text.

    Returns:
        ``(thought, action)``, both stripped of surrounding whitespace, or
        ``(None, None)`` when ``text`` is empty or has no ``</think>`` tag.
    """
    open_tag = "<think>"
    close_tag = "</think>"

    if not text:
        return (None, None)
    close_idx = text.find(close_tag)
    if close_idx == -1:
        return (None, None)

    action = text[close_idx + len(close_tag):].strip()

    # Look for the opening tag only in the part before the closing tag. A
    # "<think>" that appears later (inside the action) must not be used as the
    # start of the thought, because that would produce an inverted, empty slice.
    open_idx = text.find(open_tag, 0, close_idx)
    thought_start = open_idx + len(open_tag) if open_idx != -1 else 0
    thought = text[thought_start:close_idx].strip()
    return (thought, action)

print("✓ Functions ready")

## 4. Generation Function

In [None]:
def generate(problem: str, previous_steps=None, max_tokens=512, greedy=True, verbose=True):
    """Run the model on one problem and check the response format.

    Builds the prompt with build_prompt, generates a continuation, decodes
    only the newly generated tokens, and parses them with parse_output.

    Args:
        problem: Problem statement passed to build_prompt.
        previous_steps: Steps already taken; ``None`` is treated as no steps.
        max_tokens: Upper bound on newly generated tokens.
        greedy: When True, sampling is disabled; otherwise the model samples
            with temperature 0.7.
        verbose: When True, print the prompt (first 500 characters), the raw
            output and a format check.

    Returns:
        A dict with keys "output" (decoded text), "thought", "action" and
        "valid" (True when parse_output found a closing think tag).
    """
    rule = "=" * 70

    def section(title, lead=""):
        # Three-line banner: rule, title, rule.
        print(lead + rule)
        print(title)
        print(rule)

    def mark(ok):
        return '✓' if ok else '❌'

    prompt = build_prompt(problem, previous_steps or [])

    if verbose:
        section("PROMPT:")
        shown_prompt = prompt if len(prompt) <= 500 else prompt[:500] + "..."
        print(shown_prompt)

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[1]

    sampling = not greedy
    gen_kwargs = dict(
        max_new_tokens=max_tokens,
        do_sample=sampling,
        temperature=0.7 if sampling else None,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **gen_kwargs)

    # Drop the prompt tokens so only the model's continuation is decoded.
    completion = tokenizer.decode(sequences[0][prompt_len:], skip_special_tokens=True)
    thought, action = parse_output(completion)
    is_valid = action is not None

    if verbose:
        section("OUTPUT:", lead="\n")
        print(completion)
        section("FORMAT CHECK:", lead="\n")
        print(f"  <think>:  {mark('<think>' in completion)}")
        print(f"  </think>: {mark('</think>' in completion)}")
        print(f"  Valid:    {mark(is_valid)}")
        if action:
            print(f"\nACTION: {action}")

    return {"output": completion, "thought": thought, "action": action, "valid": is_valid}

print("✓ Generate function ready")

## 5. Test the Model

In [None]:
# Test 1: a single problem with no previous steps. With the default
# verbose=True this prints the prompt, the raw output and the format check.
r1 = generate("Calculate the derivative of f(x) = x^3")

In [None]:
# Test 2: a problem plus one already-completed step, which build_prompt
# appends after the problem text.
r2 = generate(
    "Solve: 2x + 5 = 13",
    previous_steps=["Step 1: Subtract 5 from both sides: 2x = 8"]
)

In [None]:
# Test 3: another single problem with no previous steps (a percentage question).
r3 = generate("What is 15% of 80?")

## 6. Batch Test

In [None]:
# Short problems used for a quick format check across several prompts.
problems = [
    "What is 2 + 2?",
    "Solve x^2 = 16",
    "Find the area of a circle with radius 5",
    "What is the derivative of sin(x)?",
    "Factor x^2 - 9",
]


def _preview(text, limit=60):
    """Return text cut to limit characters, adding "..." only when it was cut."""
    return f"{text[:limit]}..." if len(text) > limit else text


print("Batch testing...\n")
results = []
for i, p in enumerate(problems, 1):
    r = generate(p, verbose=False)
    results.append(r)
    s = "✓" if r["valid"] else "❌"
    print(f"{s} Test {i}: {p[:40]}")
    # Show the parsed action for valid responses and the raw output otherwise.
    # Both go through the same preview helper so the ellipsis appears only
    # when the text was actually truncated.
    shown = r["action"] if r["valid"] else r["output"]
    print(f"   → {_preview(shown)}")

valid = sum(r["valid"] for r in results)
total = len(results)
# Guard the percentage so an emptied problem list does not divide by zero.
pct = 100 * valid // total if total else 0
print(f"\n{'='*50}")
print(f"RESULT: {valid}/{total} valid ({pct}%)")