# Imports

In [24]:
import unsloth
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
import torch
from typing import List, Dict, Tuple, Any
import json
import re
from tqdm import tqdm
from grpo import SYSTEM_PROMPT
from tsp_llm import load_tsp_dataset, coordinates_to_tsp
from tsp import calculate_tsp_distance

# Load Model

In [25]:
# Load the 4-bit Qwen2.5-3B-Instruct checkpoint together with its tokenizer
load_kwargs = dict(
    model_name="unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
    max_seq_length=4096,
    load_in_4bit=True,
    dtype=torch.bfloat16,
)
model, tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)

# Prepare the model for inference; being the last expression of the cell, the value
# returned here is what the notebook displays as the cell output
FastLanguageModel.for_inference(model)

==((====))==  Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.50.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.381 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu118. CUDA: 8.0. CUDA Toolkit: 11.8. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.05G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/271 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/7.36k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 2048, padding_idx=151654)
    (layers): ModuleList(
      (0-35): 36 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=True)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=True)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((

# Generate

In [26]:
def extract_assistant_response_regex(text: str) -> str:
    """
    Extract the assistant's response from the decoded text output of Qwen.

    The patterns below are tried in order and the first one that matches wins. All
    matching is case-insensitive and lets `.` span newlines.

    Args:
        text: The decoded text output from the model

    Returns:
        str: The extracted assistant response with surrounding whitespace removed. If no
             assistant marker is found, the text after the last "user:" line is returned;
             if that is missing too, the whole input (stripped) is returned.
    """
    flags = re.DOTALL | re.IGNORECASE

    patterns = [
        # Plain-text role header at the start of a line: "assistant" or "assistant:".
        # `:?` consumes the optional colon so it is not captured as part of the reply
        # (an `assistant|assistant:` alternation always stops at the bare word and leaks
        # the colon into the captured group). The reply ends at the next line that starts
        # with a user/system header or an <|im_end|> marker, or at the end of the text.
        r"(?:^|\n)assistant:?(.*?)(?:$|\n\s*(?:user|system|<\|im_end\|>))",
        # ChatML markers that were left in the decoded text
        r"<\|im_start\|>assistant\s*(.*?)(?:<\|im_end\|>|$)",
    ]

    # Try each pattern until we find a match
    for pattern in patterns:
        match = re.search(pattern, text, flags)
        if match:
            return match.group(1).strip()

    # If no structured format is found, return everything after the last "user:" line
    last_user = re.search(r"(?:^|\n)user:(?!.*\nuser:).*?\n(.*?)$", text, flags)
    if last_user:
        return last_user.group(1).strip()

    # If all else fails, return the original text (this is a fallback)
    return text.strip()
    

def generate(prompt: str, max_new_tokens: int = 2048, num_samples: int = 3, temperature: float = 0.7, top_p: float = 0.9) -> Dict[str, Any]:
    """
    Sample one or more responses from the model for a single prompt.

    The prompt is wrapped in a system + user chat message, tokenized with the tokenizer's
    chat template, and `num_samples` completions are sampled in one `model.generate` call.

    Args:
        prompt: The input prompt
        max_new_tokens: Max number of tokens to generate
        num_samples: Number of samples per prompt
        temperature: Model temperature
        top_p: Model top_p

    Returns:
        Dict[str, Any]: Summary dict with keys
            "input token count": number of prompt tokens after applying the chat template,
            "average output token count": mean number of generated tokens per sample,
                excluding trailing padding,
            "responses": list of decoded completions (prompt excluded, whitespace stripped)
    """

    # Format chat message for Qwen
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt}
    ]

    # Tokenize input and move the tensor to wherever the model lives
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Get input token count
    input_token_count = inputs.shape[1]

    # Generate outputs. Per-step scores are not requested: nothing reads them and keeping
    # them for every generated position only costs memory. The prompt is a single unpadded
    # sequence, so its attention mask is all ones.
    outputs = model.generate(
        input_ids=inputs,
        attention_mask=torch.ones_like(inputs),
        max_new_tokens=max_new_tokens,
        num_return_sequences=num_samples,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        return_dict_in_generate=True
    )

    # Keep only the newly generated tokens; the prompt occupies the first
    # `input_token_count` positions of every returned sequence
    generated_sequences = outputs.sequences[:, input_token_count:]

    # Samples that finish early are padded so the batch stays rectangular; those padding
    # tokens must not be counted as model output
    pad_token_ids = {
        token_id
        for token_id in (
            tokenizer.pad_token_id,
            getattr(getattr(model, "generation_config", None), "pad_token_id", None),
        )
        if token_id is not None
    }

    def _unpadded_length(token_ids: List[int]) -> int:
        """Length of `token_ids` once trailing padding tokens are dropped."""
        end = len(token_ids)
        while end > 0 and token_ids[end - 1] in pad_token_ids:
            end -= 1
        return end

    # NOTE: if the pad token and the end-of-sequence token share an id, the final
    # end-of-sequence token is treated as padding and not counted
    output_token_counts = [_unpadded_length(seq.tolist()) for seq in generated_sequences]
    avg_output_token_count = sum(output_token_counts) / len(output_token_counts)

    # Decode only the generated part. This yields the assistant reply directly, without
    # having to locate it inside the decoded prompt via role-name heuristics.
    responses = [
        tokenizer.decode(seq, skip_special_tokens=True).strip()
        for seq in generated_sequences
    ]

    # Return token counts and responses
    out = {
        "input token count": input_token_count,
        "average output token count": avg_output_token_count,
        "responses": responses
    }

    return out

# String Parsing Functions

In [27]:
def extract_trace(response: str) -> List[int]:
    """
    Return the last comma-separated integer list the response wrapped in trace tags.

    Both the properly closed form <trace>...</trace> and the malformed form
    <trace>...<trace> (second tag missing its slash) are recognised.

    Args:
        response: The model response

    Returns:
        List[int]: The integers of the last trace found, or an empty list when the
                   response contains no trace
    """
    trace_pattern = r"<trace>\s*([\d]+(?:\s*,\s*[\d]+)*)\s*</trace>|<trace>\s*([\d]+(?:\s*,\s*[\d]+)*)\s*<trace>"

    # Walk through every occurrence, remembering only the final one
    final_hit = None
    for final_hit in re.finditer(trace_pattern, response):
        pass

    if final_hit is None:
        return []

    # Exactly one of the two alternatives captured something
    closed_form, unclosed_form = final_hit.groups()
    raw_numbers = closed_form or unclosed_form

    # int() tolerates the whitespace that may surround each number
    return [int(piece) for piece in raw_numbers.split(",")]

# Reward Functions

In [28]:
def optimal_solution_reward_func(traces: List[List[int]], distances: List[float], optimal_distance: float, trace_length: int) -> List[float]:
    """
    Score = 2.0 if the trace is a valid response and its distance is within 0.1 of the
    optimal distance (or below it), 0.0 otherwise
    """
    rewards: List[float] = []
    for idx, validity in enumerate(valid_response_reward_func(traces, trace_length)):
        # The distance is only looked at for valid traces
        if validity == 1.0 and distances[idx] -.1 <= optimal_distance:
            rewards.append(2.0)
        else:
            rewards.append(0.0)
    return rewards

def improvement_reward_func(traces: List[List[int]], distances: List[float], reference_distance: int, trace_length: int) -> List[float]:
    """
    Score = 2.0 if the trace is a valid response and its rounded distance does not exceed
    the reference distance, 0.0 otherwise
    """
    rewards: List[float] = []
    for idx, validity in enumerate(valid_response_reward_func(traces, trace_length)):
        # The distance is only looked at for valid traces
        if validity == 1.0 and round(distances[idx]) <= reference_distance:
            rewards.append(2.0)
        else:
            rewards.append(0.0)
    return rewards

def valid_response_reward_func(traces: List[List[int]], trace_length: int) -> List[float]:
    """
    Score = 1.0 if the trace has exactly `trace_length` entries, starts and ends at node 0,
    and visits exactly the nodes 0 .. trace_length - 2, 0.0 otherwise
    """

    def _is_valid(trace: List[int]) -> bool:
        # Checks run in this order so that indexing only happens on traces of the right length
        if len(trace) != trace_length:
            return False
        if trace[0] != 0:
            return False
        if trace[-1] != 0:
            return False
        return set(trace) == set(range(trace_length - 1))

    rewards: List[float] = []
    for trace in traces:
        rewards.append(1.0 if _is_valid(trace) else 0.0)
    return rewards

def strict_format_reward_func(responses: List[str]) -> List[float]:
    """
    Score = 0.5 if the whole response is a <reasoning> block followed by a <trace> block,
    with every tag on its own line, 0.0 otherwise
    """
    strict_layout = re.compile(r"^<reasoning>\n.*?\n</reasoning>\n<trace>\n.*?\n</trace>$", re.DOTALL)

    rewards: List[float] = []
    for text in responses:
        rewards.append(0.5 if strict_layout.match(text) is not None else 0.0)
    return rewards

def soft_format_reward_func(responses: List[str]) -> List[float]:
    """
    Score = 0.5 if the response begins with a <reasoning> block that is followed (optionally
    after whitespace) by a <trace> block, 0.0 otherwise. Text after the trace is tolerated;
    text before the reasoning block is not, because the pattern is matched from the start.
    """
    loose_layout = re.compile(r"<reasoning>.*?</reasoning>\s*<trace>.*?</trace>", re.DOTALL)

    rewards: List[float] = []
    for text in responses:
        rewards.append(0.5 if loose_layout.match(text) is not None else 0.0)
    return rewards

def apply_reward_functions(problem: Dict, responses: List[str]) -> Dict:
    """
    Score every response to one benchmark problem with all reward functions.

    Args:
        problem: Benchmark entry; the keys read here are 'coordinates',
                 'solution' (its 'distance' field) and 'reference_distance'
        responses: The model responses sampled for this problem

    Returns:
        Dict: The average of each reward over all responses, plus one "samples_<i>" entry
              per response holding the response text, the extracted path with its
              distance, and the individual rewards
    """
    n = len(responses)
    tsp = coordinates_to_tsp(problem['coordinates'])
    # Expected trace length; presumably one entry per node plus the return to the start
    # node, which matches what valid_response_reward_func checks
    trace_length = len(tsp) + 1

    # Pull the proposed tour out of every response and measure it
    traces = []
    distances = []
    for response in responses:
        trace = extract_trace(response)
        traces.append(trace)
    for trace in traces:
        distances.append(calculate_tsp_distance(tsp, trace))

    # Score the responses
    optimal_scores = optimal_solution_reward_func(traces, distances, problem['solution']['distance'], trace_length)
    improvement_scores = improvement_reward_func(traces, distances, problem['reference_distance'], trace_length)
    validity_scores = valid_response_reward_func(traces, trace_length)
    strict_scores = strict_format_reward_func(responses)
    soft_scores = soft_format_reward_func(responses)

    def _mean(scores: List[float]) -> float:
        return sum(scores) / n

    # Per-problem averages come first in the summary
    problem_summary = {}
    problem_summary["average optimal solution reward"] = _mean(optimal_scores)
    problem_summary["average improvement reward"] = _mean(improvement_scores)
    problem_summary["average valid response reward"] = _mean(validity_scores)
    problem_summary["average strict format reward"] = _mean(strict_scores)
    problem_summary["average soft format reward"] = _mean(soft_scores)

    # Followed by the detailed record of every sample
    for idx, response in enumerate(responses):
        problem_summary[f"samples_{idx}"] = {
            "response": response,
            "solution": {
                "path": traces[idx],
                "distance": distances[idx]
            },
            "optimal solution reward": optimal_scores[idx],
            "improvement reward": improvement_scores[idx],
            "valid response reward": validity_scores[idx],
            "strict format reward": strict_scores[idx],
            "soft format reward": soft_scores[idx],
        }

    return problem_summary

# Load Benchmark Dataset and Solve

In [31]:
# Load benchmark prompt dataset
dataset = load_tsp_dataset("tsp_benchmark_prompt_dataset.json")

results = {}

# Count the problems actually present instead of assuming that every size group holds
# as many problems as one particular group
total_problems = sum(len(dataset[size_key]) for size_key in dataset)

# Solve benchmark dataset. The context manager closes the progress bar even when the run
# is interrupted.
with tqdm(total=total_problems, desc="Solving TSP Benchmark") as pbar:
    for size_key in dataset:
        for problem in dataset[size_key]:
            prompt = problem['prompt']

            # Generate 3 completions per prompt
            generate_out = generate(prompt, num_samples=3)

            # Calculate average rewards across 3 samples
            reward_out = apply_reward_functions(problem, generate_out['responses'])

            # Remove completions from gen out (each response is kept in its per-sample
            # entry of the reward summary)
            del generate_out['responses']

            # Merge token counts and rewards into one problem summary
            generate_out.update(reward_out)
            problem_summary = generate_out

            # Store the summary right away so that an interrupted run keeps every problem
            # solved so far, including those of the size group in progress. A size group
            # only appears in `results` once it has at least one solved problem.
            results.setdefault(size_key, []).append(problem_summary)

            # Update progress bar
            pbar.update(1)

Solving TSP Benchmark:  95%|█████████▍| 104/110 [1:07:47<05:22, 53.78s/it]

KeyboardInterrupt: 

# Process Results and Save

In [32]:
summary = {}

# (key in the per-size summary, key read from each problem summary)
summary_fields = [
    ('average input token count', 'input token count'),
    ('average output token count', 'average output token count'),
    ('average optimal solution reward', 'average optimal solution reward'),
    ('average improvement reward', 'average improvement reward'),
    ('average valid response reward', 'average valid response reward'),
    ('average strict format reward', 'average strict format reward'),
    ('average soft format reward', 'average soft format reward'),
]

for size_key, size_results in results.items():
    # Skip the aggregate entry written by an earlier run of this cell (keeps the cell
    # re-runnable) and size groups without any solved problem (nothing to average)
    if size_key == 'summary' or not size_results:
        continue

    n = len(size_results)

    # Calculate summary per size: the mean of every field over the problems of this size
    summary[size_key] = {
        summary_key: sum(problem[problem_key] for problem in size_results) / n
        for summary_key, problem_key in summary_fields
    }

# Append summary to results
results['summary'] = summary

# Save benchmark results
with open("tsp_benchmark_results.json", "w") as file:
    json.dump(results, file, indent=4)