# Imports

In [1]:
import unsloth
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
import torch
from transformers import AutoTokenizer
from typing import List, Dict, Tuple, Any
import json
import re
from tqdm import tqdm
from grpo import SYSTEM_PROMPT
from tsp_llm import load_tsp_dataset
from tsp import calculate_tsp_distance

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Standard import failed for UnslothXPOTrainer: No module named 'UnslothXPOTrainer'. Using tempfile instead!


# Load Model

In [2]:
# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
  model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
  max_seq_length=4096,
  load_in_4bit=True,
  dtype=torch.bfloat16
)

# Prepare model for inference
FastLanguageModel.for_inference(model)

# Use 'chatml' chat template for Llama 3.1
tokenizer = get_chat_template(
  tokenizer,
  chat_template="chatml",
  mapping={
    "role":"from",
    "content":"value",
    "user":"human",
    "assistant":"gpt"
  }
)

==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.50.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.381 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu118. CUDA: 8.0. CUDA Toolkit: 11.8. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth: Will map <|im_end|> to EOS = <|eot_id|>.


# Generate

In [3]:
def extract_assistant_response_regex(decoded_text: str) -> str:
    """Extract assistant response using regex pattern matching."""
    # Pattern to match assistant's response between markers
    pattern = r"<\|im_start\|>assistant\n(.*?)(?:<\|im_end\|>|$)"
    
    # Search with DOTALL to match across newlines
    match = re.search(pattern, decoded_text, re.DOTALL)
    
    if match:
        return match.group(1).strip()
    else:
        # Return the original (might be direct generation without markers)
        return decoded_text.strip()
    
def generate(prompt: str, max_new_tokens: int = 2048, num_samples: int = 1, temperature: float = 0.7, top_p: float = 0.9) -> Dict[str, Any]:
  """
  Generate a response from the model based on the prompt.

  Args:
    prompt: The input prompt
    max_new_tokens: Max number of tokenx to generate
    num_samples: Number of samples per prompt
    temperature: Model temperature
    top_p: Model top_p
  
  Returns:
    Dict[str, Any]: Summary dict containing input token count, average output token count, and
                    model responses 
  """

  # Format chat message
  messages = [
    {"from": "system", "value": SYSTEM_PROMPT},
    {"from": "human", "value": prompt}
  ]

  # Tokenize input and move input tensors to GPU
  inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
  ).to("cuda")

  # Get input token count
  input_token_count = inputs.shape[1]

  # Generate outputs
  outputs = model.generate(
    input_ids=inputs,
    max_new_tokens=max_new_tokens,
    num_return_sequences=num_samples,
    do_sample=True,
    temperature=temperature,
    top_p=top_p,
    return_dict_in_generate=True,
    output_scores=True
  )

  # Get output token counts
  output_sequences = outputs.sequences
  avg_output_token_count = sum([len(output) - input_token_count for output in output_sequences]) / len(output_sequences)

  # Process outputs
  responses = [tokenizer.decode(output, skip_special_tokens=True) for output in output_sequences]
  responses = [extract_assistant_response_regex(response) for response in responses]

  # Return token counts and responses
  out = {
    "input token count": input_token_count,
    "average output token count": avg_output_token_count,
    "responses": responses
  }

  return out

# String Parsing Functions

In [4]:
def extract_trace(response: str) -> List[int]:
    """
    Extracts the last occurrence of a list of ints inside <trace></trace> or <trace><trace> brackets.

    Args:
        response: The model response
    
    Returns:
        List[int]: The last trace provided by the model
    """

    # Regex pattern to match <trace>...</trace> and <trace>...<trace>
    matches = re.findall(r"<trace>\s*([\d]+(?:\s*,\s*[\d]+)*)\s*</trace>|<trace>\s*([\d]+(?:\s*,\s*[\d]+)*)\s*<trace>", response)
    if not matches:
        return []
    
    # Extract the last non-empty match
    last_match = next(filter(None, matches[-1]))

    # Convert to list of integers
    return [int(num) for num in re.split(r"\s*,\s*", last_match)]

def extract_total_length(prompt: str) -> List[int]:
    """
    Extracts reference distances provided in prompt.

    Args:
      prompt: The model prompt

    Returns:
      List[int]: The reference distances
    """
    return [int(match) for match in re.findall(r'total length:\s*(\d+)', prompt)]

def extract_tsp(prompt: str) -> Dict[int, Tuple[int, int]]:
    """
    Extracts the tsp nodes from the prompt.

    Args:
        prompt: The user prompt

    Returns:
        Dict[int, Tuple[int, int]]: Nodes in i: x, y format where i is the node index, x is the x
                                    cooridnate, and y is the y coordinate
    """
    # Find the line that starts with "Given"
    lines = prompt.split('\n')
    node_lines = []
    
    # Flag to track when we're in the nodes section
    capturing = False
    
    # Find and collect the node lines
    for line in lines:
        if line.startswith("Given"):
            capturing = True
            continue
        
        if capturing and line.startswith("Node:"):
            node_lines.append(line)
        
        # Stop capturing when we hit an empty line after finding nodes
        if capturing and line.strip() == "":
            break
    
    # If we didn't find any nodes through the first approach, try direct pattern matching
    if not node_lines:
        import re
        node_pattern = r"Node: (\d+): \((-?\d+), (-?\d+)\)"
        node_lines = re.findall(node_pattern, prompt)
        
        # Convert regexp results to dictionary directly if we found matches
        if node_lines:
            return {int(node): (int(x), int(y)) for node, x, y in node_lines}
    
    # Process the node lines
    tsp = {}
    for line in node_lines:
        # Extract node number and coordinates
        parts = line.split(": ")
        node = int(parts[1].strip())
        
        # Extract coordinates - handling the parentheses
        coords = parts[2].strip()
        coords = coords.strip("()")
        x, y = map(int, coords.split(", "))
        
        tsp[node] = (x, y)
    
    return tsp

# Reward Functions

In [5]:
def optimal_solution_reward_func(responses: List[str], optimal_distance: float, tsp: Dict[int, Tuple[int, int]]) -> List[float]:
    """ Score = 2.0 if response solution matches optimal solution, 0.0 otherwise """
    valid_response_rewards = valid_response_reward_func(responses, len(tsp) + 1)
    traces = [extract_trace(r) for r in responses]
    distances = [calculate_tsp_distance(tsp, t) for t in traces]
    return [2.0 if ((valid_response_rewards[i] == 1.0) and (distances[i] -.1 <= optimal_distance)) else 0.0 for i in range(len(responses))]

def improvement_reward_func(responses: List[str], reference_distance: int, tsp: Dict[int, Tuple[int, int]]) -> List[float]:
    """ Score = 2.0 if response solution improves on provided solutions, 0.0 otherwise """
    valid_response_rewards = valid_response_reward_func(responses, len(tsp) + 1)
    traces = [extract_trace(r) for r in responses]
    distances = [calculate_tsp_distance(tsp, t) for t in traces]
    return [2.0 if ((valid_response_rewards[i] == 1.0) and (distances[i] < reference_distance)) else 0.0 for i in range(len(responses))]

def valid_response_reward_func(responses: List[str], trace_length: int) -> List[float]:
    """ Score = 1.0 if response solution contains trace with correct length, start node end node, and node set, 0.0 otherwise """
    traces = [extract_trace(r) for r in responses]
    return [
        1.0 if len(trace) == trace_length
        and trace[0] == 0
        and trace[-1] == 0
        and set(trace) == set(range(trace_length - 1)) else 0.0 for trace in traces]

def strict_format_reward_func(responses: List[str]) -> List[float]:
    """ Score = 0.5 if response solution matches requested solution format, 0.0 otherwise """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<trace>\n.*?\n</trace>\n$"
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(responses: List[str]) -> List[float]:
    """ Score = 0.5 if response solution loosely matches request solution format, 0.0 otherwise """
    pattern = r"<reasoning>.*?</reasoning>\s*<trace>.*?</trace>"
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]

def apply_reward_functions(problem: Dict[str, Any], responses: str) -> Dict[str, float]:
    n = len(responses)
    optimal_solution_rewards = optimal_solution_reward_func(responses, problem['solution']['distance'], extract_tsp(problem['prompt']))
    improvement_rewards = improvement_reward_func(responses, extract_total_length(problem['prompt'])[-1], extract_tsp(problem['prompt']))
    valid_response_rewards = valid_response_reward_func(responses, problem['size'] + 1)
    print(optimal_solution_rewards)
    print(improvement_rewards)
    print(valid_response_rewards)
    strict_format_rewards = strict_format_reward_func(responses)
    soft_format_rewards = soft_format_reward_func(responses)

    return {
        "average optimal solution reward": sum(optimal_solution_rewards) / n,
        "average improvement reward": sum(improvement_rewards) / n,
        "average valid response reward": sum(valid_response_rewards) / n,
        "average strict format reward": sum(strict_format_rewards) / n,
        "average soft format reward": sum(soft_format_rewards) / n
    }

# Load Benchmark Dataset and Solve

In [None]:
# Load benchmark prompt dataset
dataset = load_tsp_dataset("tsp_benchmark_prompt_dataset.json")

pbar = tqdm(total=60, desc="Solving TSP Benchmark")

results = {
  "size_5": [],
  "size_10": [],
  "size_15": []
}

# Solve benchmark dataset
for i in range(len(dataset)):
  if i == 10: i += 20
  if i == 40: i += 20
  if i == 70: break

  problem = dataset[i]
  size_key = f"size_{problem['size']}"
  prompt = problem['prompt']
  solution = problem['solution']

  # Generate 3 completions per prompt
  generate_out = generate(prompt, num_samples=3)

  # Calculate average rewards across 3 samples
  rewards = apply_reward_functions(problem, generate_out['responses'])

  # Remove completions from gen out
  del generate_out['responses']

  # Append problem summary to results
  generate_out.update(rewards)
  problem_summary = generate_out
  results[size_key].append(problem_summary)

  # Update progress bar
  pbar.update(1)

pbar.close()

Solving TSP Benchmark:   0%|          | 0/60 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Solving TSP Benchmark:   2%|▏         | 1/60 [01:31<1:29:33, 91.07s/it]

[0.0, 0.0, 2.0]
[0.0, 0.0, 2.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:   3%|▎         | 2/60 [02:15<1:01:44, 63.87s/it]

[0.0, 0.0, 0.0]
[2.0, 2.0, 2.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:   5%|▌         | 3/60 [02:49<47:31, 50.02s/it]  

[0.0, 2.0, 0.0]
[0.0, 2.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:   7%|▋         | 4/60 [04:41<1:09:27, 74.41s/it]

[0.0, 0.0, 2.0]
[0.0, 0.0, 2.0]
[1.0, 0.0, 1.0]


Solving TSP Benchmark:   8%|▊         | 5/60 [06:33<1:20:33, 87.88s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 0.0, 0.0]


Solving TSP Benchmark:  10%|█         | 6/60 [07:53<1:16:52, 85.42s/it]

[0.0, 0.0, 0.0]
[0.0, 2.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  12%|█▏        | 7/60 [08:49<1:06:55, 75.76s/it]

[2.0, 0.0, 0.0]
[2.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  13%|█▎        | 8/60 [10:41<1:15:37, 87.26s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 2.0]
[1.0, 0.0, 1.0]


Solving TSP Benchmark:  15%|█▌        | 9/60 [12:33<1:20:41, 94.92s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[0.0, 1.0, 1.0]


Solving TSP Benchmark:  17%|█▋        | 10/60 [13:03<1:02:25, 74.91s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 0.0, 1.0]


Solving TSP Benchmark:  18%|█▊        | 11/60 [13:30<49:14, 60.30s/it]  

[0.0, 0.0, 0.0]
[2.0, 2.0, 0.0]
[1.0, 1.0, 0.0]


Solving TSP Benchmark:  20%|██        | 12/60 [14:33<48:51, 61.08s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  22%|██▏       | 13/60 [15:44<50:17, 64.21s/it]

[0.0, 0.0, 2.0]
[0.0, 0.0, 2.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  23%|██▎       | 14/60 [17:16<55:31, 72.42s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  25%|██▌       | 15/60 [18:48<58:43, 78.29s/it]

[2.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  27%|██▋       | 16/60 [19:36<50:48, 69.29s/it]

[2.0, 0.0, 0.0]
[2.0, 2.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  28%|██▊       | 17/60 [21:28<58:48, 82.06s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[0.0, 1.0, 1.0]


Solving TSP Benchmark:  30%|███       | 18/60 [22:32<53:43, 76.74s/it]

[0.0, 0.0, 2.0]
[0.0, 0.0, 2.0]
[1.0, 0.0, 1.0]


Solving TSP Benchmark:  32%|███▏      | 19/60 [23:07<43:56, 64.30s/it]

[0.0, 2.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  33%|███▎      | 20/60 [24:07<41:59, 63.00s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[0.0, 1.0, 1.0]


Solving TSP Benchmark:  35%|███▌      | 21/60 [25:59<50:28, 77.65s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 0.0, 1.0]


Solving TSP Benchmark:  37%|███▋      | 22/60 [26:44<42:54, 67.74s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  38%|███▊      | 23/60 [27:07<33:37, 54.51s/it]

[2.0, 0.0, 0.0]
[2.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  40%|████      | 24/60 [27:42<29:02, 48.40s/it]

[0.0, 0.0, 0.0]
[0.0, 2.0, 2.0]
[0.0, 1.0, 1.0]


Solving TSP Benchmark:  42%|████▏     | 25/60 [28:07<24:09, 41.42s/it]

[0.0, 0.0, 2.0]
[2.0, 0.0, 2.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  43%|████▎     | 26/60 [29:58<35:26, 62.54s/it]

[0.0, 0.0, 2.0]
[0.0, 0.0, 2.0]
[1.0, 0.0, 1.0]


Solving TSP Benchmark:  45%|████▌     | 27/60 [30:37<30:23, 55.25s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 2.0]
[1.0, 1.0, 1.0]


Solving TSP Benchmark:  47%|████▋     | 28/60 [32:04<34:33, 64.79s/it]

[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[1.0, 1.0, 1.0]


# Process Results and Save

In [15]:
summary = {}

for size in results:
  size_summary = {}
  n = len(results[size]) or 1

  # Calculate summary per size
  average_input_token_count = sum([problem['input token count'] for problem in results[size]]) / n
  average_output_token_count = sum([problem['average output token count'] for problem in results[size]]) / n
  average_optimal_solution_reward = sum([problem['average optimal solution reward'] for problem in results[size]]) / n
  average_improved_reward = sum([problem['average improvement reward'] for problem in results[size]]) / n
  average_valid_response_reward = sum([problem['average valid response reward'] for problem in results[size]]) / n
  average_strict_format_reward = sum([problem['average strict format reward'] for problem in results[size]]) / n
  average_soft_format_reward = sum([problem['average soft format reward'] for problem in results[size]]) / n

  # Load size summary
  size_summary['average input token count'] = average_input_token_count
  size_summary['average output token count'] = average_output_token_count
  size_summary['average optimal solution reward'] = average_optimal_solution_reward
  size_summary['average improvement reward'] = average_improved_reward
  size_summary['average valid response reward'] = average_valid_response_reward
  size_summary['average strict format reward'] = average_strict_format_reward
  size_summary['average soft format reward'] = average_soft_format_reward

  # Append size summary to summary
  summary[size] = size_summary

# Append summary to results
results['summary'] = summary

# Save benchmark results
with open("tsp_benchmark_results.json", "w") as file:
  json.dump(results, file, indent=4)