# Constraint-Guided Decoding for Small Language Models

**Author:** Nolan W. Platt  
**Project:** Lightweight Neurosymbolic Approach for SLMs



## Setup and Installation

In [None]:
# Install required packages. The leading `!` makes IPython/Jupyter run this
# line as a shell command; it is not valid in a plain Python script.
!pip install -q transformers accelerate bitsandbytes z3-solver torch sentencepiece

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from z3 import *
import re
# NOTE: typing names are imported after the z3 wildcard import so they take
# precedence over any same-named z3 exports.
from typing import Any, List, Dict, Tuple, Set, Optional
from dataclasses import dataclass
import numpy as np
from collections import defaultdict
import time

# Report the runtime: PyTorch build, CUDA availability and, if present, the GPU model.
print("PyTorch version: " + str(torch.__version__))
print("CUDA available: " + str(torch.cuda.is_available()))
if torch.cuda.is_available():
    print("GPU: " + torch.cuda.get_device_name(0))

PyTorch version: 2.8.0+cu126
CUDA available: True
GPU: Tesla T4


## Z3-Based Constraint Checker

Implementation of the constraint engine with the following features (the current checks are regex-based; Z3 is imported but not yet invoked):
1. Arithmetic constraint checking
2. Logical consistency verification
3. Variable tracking and extraction
4. Constraint caching for efficiency

In [None]:
@dataclass
class Constraint:
    """Represents a logical constraint extracted from text.

    Requires ``Any`` from ``typing`` to be imported explicitly rather than
    relying on it leaking out of ``from z3 import *``.
    """

    # Category of the constraint: 'arithmetic', 'logical' or 'syntax'.
    constraint_type: str
    # Source text of the constraint, e.g. "x + y = 15".
    expression: str
    # Variable names mentioned by the expression.
    variables: Set[str]
    # Optional solver formula built from the expression; None until constructed.
    z3_formula: Optional[Any] = None

class ConstraintChecker:
    """Consistency checker for generated text (arithmetic and logical checks).

    The checks are implemented with regular expressions and float arithmetic;
    no Z3 solver is invoked by this class. Arithmetic results are memoised per
    exact input text in a bounded cache.
    """

    # An operator adjacent (ignoring spaces/tabs) to a matched span means the
    # span is only a fragment of a larger expression, e.g. "y = 7" inside
    # "x + y = 7 + 3", so it must not be read as a standalone statement.
    _OPERATORS = '+-*/'
    # Absolute tolerance for comparing numbers parsed from text.
    _TOLERANCE = 0.001

    # "v = 5" with a single-letter variable. The lookbehind rejects a letter
    # glued to a preceding word/number character, so the "x" of "2x = 10" or
    # the "h" of "Sarah = 15" is not mistaken for an assignment.
    _ASSIGN_RE = re.compile(r'(?<![\w.])([a-zA-Z])\s*=\s*(\d+(?:\.\d+)?)')
    # "x + y = 15"
    _SUM_RE = re.compile(
        r'(?<![\w.])([a-zA-Z_]\w*)\s*\+\s*([a-zA-Z_]\w*)\s*=\s*(\d+(?:\.\d+)?)'
    )
    # "2x = 10" or "2 * x = 10"; the lookbehind stops "12x" being read as "2x".
    _MULT_RE = re.compile(
        r'(?<![\w.])(\d+)\s*\*?\s*([a-zA-Z_]\w*)\s*=\s*(\d+(?:\.\d+)?)'
    )
    # "P is true" / "P is false". The trailing \b stops e.g. "is correctly"
    # from counting as "is correct".
    _POSITIVE_RE = re.compile(r'([A-Z]\w*)\s+is\s+(true|correct)\b', re.IGNORECASE)
    _NEGATIVE_RE = re.compile(
        r'([A-Z]\w*)\s+is\s+(false|incorrect|not\s+true|not\s+correct)\b',
        re.IGNORECASE,
    )

    def __init__(self, max_cache_size: int = 100_000):
        """
        Args:
            max_cache_size: Maximum number of memoised arithmetic results. The
                cache is emptied when this is reached so long decoding runs
                (one new text per candidate token) do not grow memory unboundedly.
        """
        # Maps exact text -> (is_consistent, violations) for arithmetic checks.
        self.constraint_cache = {}
        self.max_cache_size = max_cache_size

    def clear(self):
        """Reset solver state (drops all memoised results)."""
        self.constraint_cache = {}

    @classmethod
    def _is_standalone(cls, text: str, start: int, end: int) -> bool:
        """Return True if text[start:end] is not an operand of a wider expression.

        Rejects spans whose trailing number runs straight into a word character
        (e.g. "5y") and spans whose nearest non-blank neighbour on either side
        (same line) is an arithmetic operator.
        """
        if end < len(text) and (text[end].isalnum() or text[end] == '_'):
            return False
        left = start - 1
        while left >= 0 and text[left] in ' \t':
            left -= 1
        if left >= 0 and text[left] in cls._OPERATORS:
            return False
        right = end
        while right < len(text) and text[right] in ' \t':
            right += 1
        return not (right < len(text) and text[right] in cls._OPERATORS)

    def extract_arithmetic_statements(self, text: str) -> List[Dict]:
        """
        Extract standalone single-variable assignments from text.

        Examples that are extracted: "x = 5", "y = 10".
        Examples that are NOT: the "x = 10" inside "2x = 10", the "y = 7"
        inside "x + y = 7 + 3", and the "h = 15" inside "Sarah = 15".

        Returns:
            List of dicts with 'variable' (str), 'value' (float) and
            'full_text' (the matched span).
        """
        statements = []
        for match in self._ASSIGN_RE.finditer(text):
            if not self._is_standalone(text, match.start(), match.end()):
                continue
            statements.append({
                'variable': match.group(1),
                'value': float(match.group(2)),
                'full_text': match.group(0),
            })
        return statements

    def _check_sums(self, text: str, var_assignments: Dict[str, float]) -> List[str]:
        """Return violations for standalone "a + b = n" claims over known variables."""
        violations = []
        for match in self._SUM_RE.finditer(text):
            if not self._is_standalone(text, match.start(), match.end()):
                continue
            var1, var2 = match.group(1), match.group(2)
            if var1 in var_assignments and var2 in var_assignments:
                expected = float(match.group(3))
                actual = var_assignments[var1] + var_assignments[var2]
                if abs(actual - expected) > self._TOLERANCE:
                    violations.append(
                        f"Arithmetic error: {var1} + {var2} should be {actual} "
                        f"but text claims {expected}"
                    )
        return violations

    def _check_products(self, text: str, var_assignments: Dict[str, float]) -> List[str]:
        """Return violations for standalone "k*v = n" claims over known variables."""
        violations = []
        for match in self._MULT_RE.finditer(text):
            if not self._is_standalone(text, match.start(), match.end()):
                continue
            var_name = match.group(2)
            if var_name in var_assignments:
                multiplier = float(match.group(1))
                expected = float(match.group(3))
                actual = multiplier * var_assignments[var_name]
                if abs(actual - expected) > self._TOLERANCE:
                    violations.append(
                        f"Arithmetic error: {multiplier}*{var_name} should be {actual} "
                        f"but text claims {expected}"
                    )
        return violations

    def check_arithmetic_consistency(self, text: str) -> Tuple[bool, List[str]]:
        """
        Check if all arithmetic statements in text are consistent.

        Detects a variable assigned two different values, "a + b = n" where the
        assigned values of a and b do not sum to n, and "k*v = n" where k times
        the assigned value of v is not n.

        Returns: (is_consistent, list_of_violations)
        """
        cached = self.constraint_cache.get(text)
        if cached is not None:
            # Hand out a copy so callers cannot mutate the cached list.
            return cached[0], list(cached[1])

        var_assignments: Dict[str, float] = {}
        violations: List[str] = []

        for stmt in self.extract_arithmetic_statements(text):
            var_name = stmt['variable']
            value = stmt['value']
            if var_name not in var_assignments:
                var_assignments[var_name] = value
            elif abs(var_assignments[var_name] - value) > self._TOLERANCE:
                violations.append(
                    f"Contradiction: {var_name} = {var_assignments[var_name]} "
                    f"but later {var_name} = {value}"
                )

        # Derived-expression checks only apply to variables with a known value.
        if var_assignments:
            violations.extend(self._check_sums(text, var_assignments))
            violations.extend(self._check_products(text, var_assignments))

        result = (len(violations) == 0, violations)
        if len(self.constraint_cache) >= self.max_cache_size:
            self.constraint_cache.clear()
        # Keyed by the text itself (not hash(text)) so collisions cannot
        # return another text's verdict.
        self.constraint_cache[text] = result
        return result[0], list(violations)

    def check_logical_consistency(self, text: str) -> Tuple[bool, List[str]]:
        """
        Check for logical contradictions: the same proposition stated both as
        true/correct and as false/incorrect/not true/not correct.
        Proposition names are compared case-insensitively.
        """
        positive_props = {m.group(1).lower() for m in self._POSITIVE_RE.finditer(text)}
        negative_props = {m.group(1).lower() for m in self._NEGATIVE_RE.finditer(text)}

        contradictions = positive_props & negative_props
        if contradictions:
            return (False, [f"Logical contradiction for propositions: {contradictions}"])

        return (True, [])

    def check_constraints(self, text: str, constraint_types: List[str] = None) -> Tuple[bool, List[str]]:
        """
        Main constraint checking interface.

        Args:
            text: Text to check.
            constraint_types: Subset of ['arithmetic', 'logical'];
                None runs both.

        Returns: (is_consistent, all_violations)
        """
        if constraint_types is None:
            constraint_types = ['arithmetic', 'logical']

        all_violations = []
        is_consistent = True

        if 'arithmetic' in constraint_types:
            arith_consistent, arith_violations = self.check_arithmetic_consistency(text)
            if not arith_consistent:
                is_consistent = False
                all_violations.extend(arith_violations)

        if 'logical' in constraint_types:
            logic_consistent, logic_violations = self.check_logical_consistency(text)
            if not logic_consistent:
                is_consistent = False
                all_violations.extend(logic_violations)

        return (is_consistent, all_violations)

### Test the Constraint Checker

## Phi-2 Integration with Constraint-Guided Beam Search

Implementation of Algorithm 1 from the paper with:
1. 4-bit quantized Phi-2 model loading
2. Modified beam search with constraint verification
3. Incremental constraint checking
4. Early pruning for efficiency

In [None]:
# Load Phi-2 with bitsandbytes 4-bit NF4 weights (double quantization, fp16 compute).
print("Loading Phi-2 model with 4-bit quantization...")

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model_name = "microsoft/phi-2"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse EOS as the padding token (presumably the tokenizer ships without one).
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
    trust_remote_code=True,
)

print("✓ Model loaded successfully!")
print("  Model size: {:.2f} GB".format(model.get_memory_footprint() / 1e9))

Loading Phi-2 model with 4-bit quantization...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Model loaded successfully!
  Model size: 1.78 GB


In [None]:
@dataclass
class BeamState:
    """One hypothesis tracked by the constraint-guided beam search.

    Positional field order: sequence, text, score, violations, is_finished.
    """

    # Prompt + generated token IDs, fed back into the model at each step.
    sequence: List[int]
    # Decoded form of ``sequence`` (the raw prompt for the initial beam).
    text: str
    # Cumulative log-probability of the generated tokens.
    score: float
    # Constraint-violation messages attributed to this hypothesis.
    violations: Set[str]
    # True once the hypothesis has emitted the EOS token.
    is_finished: bool = False

class ConstraintGuidedDecoder:
    """Modified beam search with constraint enforcement (Algorithm 1).

    At every step each unfinished beam is expanded with its ``2 * beam_width``
    most likely next tokens. The decoded text of each candidate is re-checked
    with ``constraint_checker``, and candidates that introduce a violation not
    already present in their parent are pruned. The best ``beam_width``
    survivors (by cumulative log-probability) form the next beam set.
    """

    def __init__(self, model, tokenizer, constraint_checker, beam_width=4, lambda_weight=1.0):
        self.model = model
        self.tokenizer = tokenizer
        self.checker = constraint_checker
        self.beam_width = beam_width
        # Weight on the constraint penalty in the beam score (Eq. 1). The
        # penalty is either 0 or -inf, and -inf candidates are pruned before
        # scoring, so with this penalty the weight does not change rankings.
        self.lambda_weight = lambda_weight

    def get_top_k_tokens(self, input_ids: torch.Tensor, k: int, temperature: float = 1.0) -> List[Tuple[int, float]]:
        """
        Get top-k next tokens with their log probabilities.

        Args:
            input_ids: Token IDs with a leading batch dimension; only the first
                batch row of the result is returned.
            k: Number of tokens to return (clamped to the vocabulary size).
            temperature: Logits are divided by this before normalisation;
                1.0 leaves the distribution unchanged.

        Returns:
            List of (token_id, log_prob) tuples, most likely first.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        with torch.no_grad():
            logits = self.model(input_ids).logits[:, -1, :]  # last position only
            # log_softmax in float32 is numerically stable; log(softmax(x)) can
            # underflow to -inf for unlikely tokens, especially in half precision.
            log_probs = torch.log_softmax(logits.float() / temperature, dim=-1)
            k = min(k, log_probs.shape[-1])
            top_log_probs, top_indices = torch.topk(log_probs, k, dim=-1)
        return list(zip(top_indices[0].tolist(), top_log_probs[0].tolist()))

    def calculate_constraint_penalty(self, text: str, prev_violations: Set[str], constraint_types: List[str] = None) -> float:
        """
        Calculate phi(y_t) from Equation 1.

        Args:
            text: Full decoded candidate text.
            prev_violations: Violations already carried by the parent beam.
            constraint_types: Checks to run; None uses the checker's defaults.

        Returns: 0.0 if every violation in ``text`` was already present in
        ``prev_violations``, or -inf if ``text`` introduces a new one (Eq. 2).
        """
        _, violations = self.checker.check_constraints(text, constraint_types)
        if set(violations).issubset(prev_violations):
            return 0.0
        return float('-inf')  # new violation -> prune this candidate

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        constraint_types: List[str] = None,
        temperature: float = 1.0,
    ) -> Dict:
        """
        Main constraint-guided generation (Algorithm 1).

        Args:
            prompt: Text to continue.
            max_new_tokens: Maximum number of expansion steps.
            constraint_types: Checks to enforce (e.g. ['arithmetic']);
                None uses the checker's defaults.
            temperature: Softmax temperature used when scoring next tokens.

        Returns:
            dict with 'text', 'score', 'violations' and 'stats'
            ('tokens_generated', 'candidates_pruned', 'time_seconds',
            'final_beam_count'). 'violations' holds only violations already
            present in the prompt; generated text cannot add new ones.
        """
        start_time = time.time()

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
        # Violations already in the prompt form the baseline. Only violations
        # introduced by generated tokens prune a candidate; otherwise a prompt
        # that already trips a check would get every candidate pruned.
        _, prompt_violations = self.checker.check_constraints(prompt, constraint_types)
        beams = [BeamState(
            sequence=input_ids[0].tolist(),
            text=prompt,
            score=0.0,
            violations=set(prompt_violations),
        )]

        tokens_generated = 0
        candidates_pruned = 0

        for _ in range(max_new_tokens):
            if all(beam.is_finished for beam in beams):
                break

            all_candidates = []
            for beam in beams:
                if beam.is_finished:
                    # Finished beams compete unchanged with new expansions.
                    all_candidates.append(beam)
                    continue

                beam_input_ids = torch.tensor([beam.sequence], device=self.model.device)
                top_tokens = self.get_top_k_tokens(
                    beam_input_ids, k=2 * self.beam_width, temperature=temperature
                )

                for token_id, log_prob in top_tokens:
                    new_sequence = beam.sequence + [token_id]
                    new_text = self.tokenizer.decode(new_sequence, skip_special_tokens=True)

                    # Re-check the whole decoded text with the requested constraint types.
                    constraint_penalty = self.calculate_constraint_penalty(
                        new_text, beam.violations, constraint_types
                    )
                    if constraint_penalty == float('-inf'):
                        candidates_pruned += 1
                        continue

                    all_candidates.append(BeamState(
                        sequence=new_sequence,
                        text=new_text,
                        score=beam.score + log_prob + self.lambda_weight * constraint_penalty,
                        violations=beam.violations.copy(),
                        is_finished=(token_id == self.tokenizer.eos_token_id),
                    ))

            if not all_candidates:
                # Every expansion was pruned: keep the last surviving beams
                # rather than continuing with an empty set (max() would fail).
                break

            all_candidates.sort(key=lambda b: b.score, reverse=True)
            beams = all_candidates[:self.beam_width]
            tokens_generated += 1

        best_beam = max(beams, key=lambda b: b.score)

        end_time = time.time()

        return {
            'text': best_beam.text,
            'score': best_beam.score,
            'violations': list(best_beam.violations),
            'stats': {
                'tokens_generated': tokens_generated,
                'candidates_pruned': candidates_pruned,
                'time_seconds': end_time - start_time,
                'final_beam_count': len(beams)
            }
        }

### Initialize the Constraint-Guided Decoder

In [None]:
# Build the decoder used by the test cells below; `constraint_checker` is also
# used on its own there to score baseline outputs.
constraint_checker = ConstraintChecker()

decoder = ConstraintGuidedDecoder(
    model,
    tokenizer,
    constraint_checker,
    beam_width=6,  # each step expands 2 * 6 tokens per beam; try 2 if you hit OOM
    lambda_weight=1.0,
)

print("✓ Constraint-Guided Decoder initialized!")

✓ Constraint-Guided Decoder initialized!


## Testing and Validation

Test the complete pipeline on arithmetic reasoning examples.

In [None]:


import time

print("="*80)
print("TESTS")
print("="*80)

# Each test's verdict is computed once and reused in the SUMMARY below, so the
# per-test "Status" lines and the summary always agree. A test passes only if
# the expected answer appears AND no constraint violations remain.

# ========================================================================
# TEST 1: Basic Arithmetic Consistency
# ========================================================================
print("\n" + "="*80)
print("TEST 1: Basic Arithmetic Consistency")
print("="*80)
print("Setup: Give model x = 5, ask for 2x")
print("Expected: Should output 10 (2*5=10)")
print("Testing: Does constraint system prevent wrong answers?\n")

prompt1 = "x = 5. Calculate 2x.\nAnswer: 2x ="

# Run constrained generation
print("Running constrained generation...")
start = time.time()
constrained_result = decoder.generate(
    prompt=prompt1,
    max_new_tokens=20,
    constraint_types=['arithmetic']
)
constrained_time = time.time() - start

print(f"\n--- CONSTRAINED OUTPUT ---")
print(f"Full text: {constrained_result['text']}")
print(f"\nConstraint violations: {len(constrained_result['violations'])}")
if constrained_result['violations']:
    for v in constrained_result['violations']:
        print(f"  - {v}")
print(f"Candidates pruned: {constrained_result['stats']['candidates_pruned']}")
print(f"Generation time: {constrained_time:.2f}s")

# Check if answer is correct
constrained_correct = "10" in constrained_result['text']
test1_pass = constrained_correct and len(constrained_result['violations']) == 0

print(f"\n--- RESULT ---")
print(f"Contains '10': {'YES' if constrained_correct else 'NO'}")
print(f"Status: {'PASS' if test1_pass else 'FAIL'}")

# ========================================================================
# TEST 2: Baseline Comparison (No Constraints)
# ========================================================================
print("\n\n" + "="*80)
print("TEST 2: Baseline vs Constrained Comparison")
print("="*80)
print("Setup: x = 7, y = 3, calculate x + y")
print("Expected: Both should output 10")
print("Testing: Does constraint add overhead? Does it help accuracy?\n")

prompt2 = "Given x = 7 and y = 3, what is x + y?\nAnswer: x + y ="

# Baseline (no constraints)
print("Running BASELINE (standard beam search, no constraints)...")
input_ids = tokenizer.encode(prompt2, return_tensors="pt").to(model.device)
start = time.time()
baseline_output = model.generate(
    input_ids,
    max_new_tokens=20,
    num_beams=4,
    early_stopping=True,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False
)
baseline_time = time.time() - start
baseline_text = tokenizer.decode(baseline_output[0], skip_special_tokens=True)

# Check baseline for violations
baseline_consistent, baseline_viols = constraint_checker.check_constraints(baseline_text, ['arithmetic'])

print(f"\n--- BASELINE OUTPUT ---")
print(f"Full text: {baseline_text}")
print(f"Constraint violations: {len(baseline_viols)}")
if baseline_viols:
    for v in baseline_viols:
        print(f"  - {v}")
print(f"Contains '10': {'YES' if '10' in baseline_text else 'NO'}")
print(f"Generation time: {baseline_time:.2f}s")

# Constrained
print("\nRunning CONSTRAINED generation...")
start = time.time()
constrained_result2 = decoder.generate(
    prompt=prompt2,
    max_new_tokens=20,
    constraint_types=['arithmetic']
)
constrained_time2 = time.time() - start

print(f"\n--- CONSTRAINED OUTPUT ---")
print(f"Full text: {constrained_result2['text']}")
print(f"Constraint violations: {len(constrained_result2['violations'])}")
if constrained_result2['violations']:
    for v in constrained_result2['violations']:
        print(f"  - {v}")
print(f"Contains '10': {'YES' if '10' in constrained_result2['text'] else 'NO'}")
print(f"Candidates pruned: {constrained_result2['stats']['candidates_pruned']}")
print(f"Generation time: {constrained_time2:.2f}s")

# Comparison
print(f"\n--- COMPARISON ---")
print(f"Baseline violations: {len(baseline_viols)}")
print(f"Constrained violations: {len(constrained_result2['violations'])}")
if baseline_time > 0:
    print(f"Time overhead: {constrained_time2 - baseline_time:.2f}s ({((constrained_time2/baseline_time - 1)*100):.1f}% slower)")
else:
    # avoid ZeroDivisionError if the baseline timer resolution reports 0s
    print(f"Time overhead: {constrained_time2 - baseline_time:.2f}s")

improvement = len(baseline_viols) - len(constrained_result2['violations'])
if improvement > 0:
    print(f"IMPROVEMENT: Reduced {improvement} violation(s)")
elif improvement < 0:
    print(f"WORSE: Added {abs(improvement)} violation(s)")
else:
    print(f"= SAME: No difference in violations")

test2_pass = len(constrained_result2['violations']) <= len(baseline_viols)

# ========================================================================
# TEST 3: Contradiction Prevention
# ========================================================================
print("\n\n" + "="*80)
print("TEST 3: Preventing Contradictions")
print("="*80)
print("Setup: State a = 6, then ask model to calculate 2a")
print("Expected: Should NOT say things like '2a = 15' (would be wrong)")
print("Testing: Can we prevent the model from generating false arithmetic?\n")

prompt3 = "We know that a = 6. Now calculate 2a.\nSolution: Since a = 6, we have 2a ="

print("Running constrained generation...")
constrained_result3 = decoder.generate(
    prompt=prompt3,
    max_new_tokens=25,
    constraint_types=['arithmetic']
)

print(f"\n--- CONSTRAINED OUTPUT ---")
print(f"Full text: {constrained_result3['text']}")
print(f"\nConstraint violations: {len(constrained_result3['violations'])}")
if constrained_result3['violations']:
    for v in constrained_result3['violations']:
        print(f"  - {v}")
print(f"Candidates pruned: {constrained_result3['stats']['candidates_pruned']}")

# Check correctness (2*6 = 12)
correct_answer = "12" in constrained_result3['text']
test3_pass = correct_answer and len(constrained_result3['violations']) == 0

print(f"\n--- RESULT ---")
print(f"Expected answer (12): {'FOUND ✓' if correct_answer else 'NOT FOUND ✗'}")
print(f"No contradictions: {'YES ✓' if len(constrained_result3['violations'])==0 else 'NO ✗'}")
print(f"Status: {'PASS' if test3_pass else 'FAIL'}")

# ========================================================================
# TEST 4: Multi-Variable Consistency
# ========================================================================
print("\n\n" + "="*80)
print("TEST 4: Multi-Variable Arithmetic")
print("="*80)
print("Setup: m = 4, n = 5, calculate m * n")
print("Expected: Should output 20")
print("Testing: Can handle multiple variables correctly?\n")

prompt4 = "Let m = 4 and n = 5. What is m * n?\nAnswer: m * n ="

constrained_result4 = decoder.generate(
    prompt=prompt4,
    max_new_tokens=20,
    constraint_types=['arithmetic']
)

test4_correct = '20' in constrained_result4['text']
test4_pass = test4_correct and len(constrained_result4['violations']) == 0

print(f"--- CONSTRAINED OUTPUT ---")
print(f"Full text: {constrained_result4['text']}")
print(f"Violations: {len(constrained_result4['violations'])}")
print(f"Contains '20': {'YES ✓' if test4_correct else 'NO ✗'}")
print(f"Status: {'PASS' if test4_pass else 'FAIL'}")

# ========================================================================
# SUMMARY
# ========================================================================
print("\n\n" + "="*80)
print("SUMMARY")
print("="*80)

tests = [
    ("Test 1: Basic arithmetic (x=5, 2x=?)", test1_pass),
    ("Test 2: Reduced violations vs baseline", test2_pass),
    ("Test 3: Contradiction prevention (a=6, 2a=?)", test3_pass),
    ("Test 4: Multi-variable (m=4,n=5,m*n=?)", test4_pass),
]

passed = sum(1 for _, result in tests if result)

for test_name, result in tests:
    status = "PASS" if result else "FAIL"
    print(f"{status}: {test_name}")

print(f"\nOverall: {passed}/{len(tests)} tests passed")


TESTS

TEST 1: Basic Arithmetic Consistency
Setup: Give model x = 5, ask for 2x
Expected: Should output 10 (2*5=10)
Testing: Does constraint system prevent wrong answers?

Running constrained generation...

--- CONSTRAINED OUTPUT ---
Full text: x = 5. Calculate 2x.
Answer: 2x = (2 * 5) = 10.

Exercise 2: Simplify the expression 3(

Constraint violations: 0
Candidates pruned: 10
Generation time: 9.89s

--- RESULT ---
Contains '10': YES
Status: PASS


TEST 2: Baseline vs Constrained Comparison
Setup: x = 7, y = 3, calculate x + y
Expected: Both should output 10
Testing: Does constraint add overhead? Does it help accuracy?

Running BASELINE (standard beam search, no constraints)...

--- BASELINE OUTPUT ---
Full text: Given x = 7 and y = 3, what is x + y?
Answer: x + y = 7 + 3 = 10

Exercise 2: Solve the equation 2x - 5 =
Constraint violations: 2
  - Contradiction: y = 3.0 but later y = 7.0
  - Arithmetic error: x + y should be 10.0 but text claims 7.0
Contains '10': YES
Generation time: 1

## Utility Functions and Analysis Tools

In [None]:
def evaluate_on_dataset(decoder, problems: List[Dict], max_new_tokens=100):
    """
    Run ``decoder.generate`` (arithmetic constraints only) on every problem.

    Each problem dict must provide 'prompt'; any other keys (such as
    'expected' in the sample dataset) are carried through untouched in the
    'problem' entry of the corresponding result.

    Returns:
        One dict per problem with 'problem', 'output', 'violations', 'stats'.
    """
    total = len(problems)
    collected = []

    for number, problem in enumerate(problems, start=1):
        prompt = problem['prompt']
        print(f"\nProblem {number}/{total}")
        print(f"Prompt: {prompt[:100]}...")

        generation = decoder.generate(
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            constraint_types=['arithmetic'],
        )
        stats = generation['stats']

        collected.append(dict(
            problem=problem,
            output=generation['text'],
            violations=generation['violations'],
            stats=stats,
        ))

        print(f"Violations: {len(generation['violations'])}")
        print(f"Time: {stats['time_seconds']:.2f}s")

    return collected

def analyze_results(results: List[Dict]) -> Dict[str, float]:
    """
    Print and return summary statistics for evaluate_on_dataset() output.

    Args:
        results: Entries with 'violations' (a list) and 'stats' containing
            'time_seconds' and 'candidates_pruned'.

    Returns:
        dict with 'total', 'violation_free', 'violation_free_pct', 'avg_time'
        and 'avg_pruned'. For an empty ``results`` the percentage and averages
        are NaN instead of raising ZeroDivisionError.
    """
    total = len(results)
    violation_free = sum(1 for r in results if len(r['violations']) == 0)
    if total:
        violation_free_pct = 100 * violation_free / total
        avg_time = float(np.mean([r['stats']['time_seconds'] for r in results]))
        avg_pruned = float(np.mean([r['stats']['candidates_pruned'] for r in results]))
    else:
        # Nothing to average: NaN avoids ZeroDivisionError and empty-mean warnings.
        violation_free_pct = avg_time = avg_pruned = float('nan')

    print("\n" + "="*80)
    print("EVALUATION SUMMARY")
    print("="*80)
    print(f"Total problems: {total}")
    print(f"Violation-free outputs: {violation_free} ({violation_free_pct:.1f}%)")
    print(f"Average generation time: {avg_time:.2f}s")
    print(f"Average candidates pruned: {avg_pruned:.1f}")
    print("="*80)

    return {
        'total': total,
        'violation_free': violation_free,
        'violation_free_pct': violation_free_pct,
        'avg_time': avg_time,
        'avg_pruned': avg_pruned,
    }

# Small hand-written smoke-test set: each entry pairs a prompt with the
# numeric answer a correct completion should contain ('expected').
sample_problems = [
    dict(prompt=question, expected=answer)
    for question, answer in [
        ("Question: If x = 7, what is 3x?\nAnswer:", 21),
        ("Question: Sarah has 15 cookies. She eats 3. How many are left?\nAnswer:", 12),
        ("Question: A rectangle has length 5 and width 3. What is its area?\nAnswer:", 15),
    ]
]

print("✓ Evaluation utilities ready!")
print("  Sample dataset: %d problems" % len(sample_problems))

✓ Evaluation utilities ready!
  Sample dataset: 3 problems


## Evaluation on Sample Problems

In [None]:
 results = evaluate_on_dataset(decoder, sample_problems, max_new_tokens=50)
 analyze_results(results)

## Optimization and Ablation Studies



In [None]:
def ablation_beam_width(prompt: str, beam_widths=(2, 4, 8), max_new_tokens: int = 50):
    """
    Test impact of beam width on performance.

    Builds a fresh decoder (with a fresh ConstraintChecker, so no cache is
    shared between runs) for every width and generates from ``prompt``.

    Args:
        prompt: Prompt to continue.
        beam_widths: Iterable of beam widths to try. A tuple default is used
            instead of a mutable list default.
        max_new_tokens: Generation budget per run.

    Returns:
        dict mapping each beam width to its ``generate`` result.
    """
    results = {}

    for width in beam_widths:
        print(f"\nTesting beam width = {width}")

        decoder_temp = ConstraintGuidedDecoder(
            model=model,
            tokenizer=tokenizer,
            constraint_checker=ConstraintChecker(),
            beam_width=width,
            lambda_weight=1.0
        )

        result = decoder_temp.generate(prompt, max_new_tokens=max_new_tokens)
        results[width] = result

        print(f"  Time: {result['stats']['time_seconds']:.2f}s")
        print(f"  Violations: {len(result['violations'])}")
        print(f"  Pruned: {result['stats']['candidates_pruned']}")

    return results

def ablation_lambda_weight(prompt: str, lambda_values=(0.5, 1.0, 2.0), max_new_tokens: int = 50):
    """
    Test impact of constraint strictness (lambda).

    NOTE(review): ConstraintGuidedDecoder's penalty is either 0 or -inf, and
    -inf candidates are pruned before scoring, so lambda currently cannot change
    which beams survive. Identical results across lambda values are expected
    until a graded penalty is introduced.

    Args:
        prompt: Prompt to continue.
        lambda_values: Iterable of lambda weights to try (tuple default instead
            of a mutable list default).
        max_new_tokens: Generation budget per run.

    Returns:
        dict mapping each lambda value to its ``generate`` result.
    """
    results = {}

    for lambda_val in lambda_values:
        print(f"\nTesting lambda = {lambda_val}")

        decoder_temp = ConstraintGuidedDecoder(
            model=model,
            tokenizer=tokenizer,
            constraint_checker=ConstraintChecker(),
            beam_width=4,
            lambda_weight=lambda_val
        )

        result = decoder_temp.generate(prompt, max_new_tokens=max_new_tokens)
        results[lambda_val] = result

        print(f"  Score: {result['score']:.4f}")
        print(f"  Violations: {len(result['violations'])}")

    return results

print("ablation study functions ready!")

✓ Ablation study functions ready!


## Save and Export Functions

In [None]:
import json
from datetime import datetime

def _to_json_compatible(value):
    """Recursively convert sets/tuples to lists and NumPy values to Python ones."""
    if isinstance(value, dict):
        return {k: _to_json_compatible(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set, frozenset)):
        return [_to_json_compatible(v) for v in value]
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # NumPy scalars such as np.int64 are not accepted by json.dump.
        return value.item()
    return value


def save_results(results: Dict, filename: str = None):
    """
    Save experimental results to JSON.

    Args:
        results: Results to save: a dict (e.g. from the ablation helpers) or a
            list (e.g. from evaluate_on_dataset). Sets are converted to lists
            and NumPy scalars/arrays to plain Python values at any nesting depth.
        filename: Output path; defaults to ``results_<YYYYmmdd_HHMMSS>.json``.

    Returns:
        The path the results were written to.
    """
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"results_{timestamp}.json"

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(_to_json_compatible(results), f, indent=2)

    print(f"Results saved to {filename}")
    return filename

print("Save functions ready!")

✓ Save functions ready!


## Interactive Testing Cell



In [None]:
# Edit the prompt below and re-run this cell to try the decoder on other inputs.
custom_prompt = (
    "Question: If a train travels 60 miles in 2 hours, what is its speed?\n"
    "Answer: Let me calculate this."
)

custom_result = decoder.generate(
    custom_prompt,
    max_new_tokens=50,
    constraint_types=['arithmetic'],
)

print("Generated Text:")
print(custom_result['text'])
print("\nViolations: {}".format(custom_result['violations']))
print("Generation time: {:.2f}s".format(custom_result['stats']['time_seconds']))
print("Candidates pruned: {}".format(custom_result['stats']['candidates_pruned']))

Generated Text:
Question: If a train travels 60 miles in 2 hours, what is its speed?
Answer: Let me calculate this. To find the speed, we need to divide the distance traveled by the time taken. In this case, the train traveled 60 miles in 2 hours. So, the speed would be 60 miles divided by 2 hours, which equals 30 miles per hour.

Violations: []
Generation time: 26.63s
Candidates pruned: 0
