In [None]:
import os
import json
import huggingface_hub

In [None]:
from syncode import Syncode

# Few-shot prompt for the arithmetic-parsing task. Each example maps a
# natural-language query to the canonical
# "Operands: <a> and <b> with Operator: <op>" answer form. Operands in the
# answers are always written with a decimal point (e.g. "0.5", "327.0"),
# matching the `operand: integer "." integer` rule of the grammar used for
# constrained decoding. The "divides" example is read as left / right
# (327 divides 11 -> 327.0 / 11.0).
# `{query}` is substituted with str.format, so any literal braces added to
# the examples would need to be doubled.
prompt_template = """
Query: What is 327. multiplied by 11.0?

Output: Operands: 327.0 and 11.0 with Operator: *

Query: What is 327 divided by 11?

Output: Operands: 327.0 and 11.0 with Operator: /

Query: What is 326. plus 11.000

Output: Operands: 326.0 and 11.0 with Operator: +

Query: What is 326 minus .5?

Output: Operands: 326.0 and 0.5 with Operator: -

Query: What is 327 divides 11?

Output: Operands: 327.0 and 11.0 with Operator: /

Query: What is 326 subtracted by 11.5?

Output: Operands: 326.0 and 11.5 with Operator: -

Query: {query}

Output: 
"""

def parse_output(output):
    """Parse a constrained completion into ``(operand1, operand2, operator)``.

    Expects text shaped like ``"Operands: 327.0 and 11.0 with Operator: *"``.
    Whitespace around the labels is tolerated. If a label ("Operands:" or
    "Operator:") is missing, the whole corresponding segment is used instead
    of slicing at a bogus offset.

    Args:
        output: Completion string returned by the model.

    Returns:
        Tuple ``(operand1, operand2, operator)``. The two operands are floats
        and the operator is the stripped operator text.

    Raises:
        ValueError: if there is no ``"with"`` separator, fewer than two
            operands are present, or an operand is not a valid float.
    """

    def _after_label(segment, label):
        # Text following `label`, stripped. When the label is absent, keep the
        # whole segment. Blindly slicing at find() + len(label) would start at
        # offset len(label) - 1 and silently drop leading characters.
        idx = segment.find(label)
        if idx != -1:
            segment = segment[idx + len(label):]
        return segment.strip()

    operands_part, sep, operator_part = output.partition("with")
    if not sep:
        raise ValueError(f"Missing 'with' separator in model output: {output!r}")

    operands = _after_label(operands_part, "Operands:").split("and")
    if len(operands) < 2:
        raise ValueError(f"Expected two operands in model output: {output!r}")
    operand1 = float(operands[0].strip())
    operand2 = float(operands[1].strip())

    operator = _after_label(operator_part, "Operator:")
    return operand1, operand2, operator

def run_operation(operand1, operand2, operator):
    """Apply the binary ``operator`` to ``operand1`` and ``operand2``.

    Supported symbols are "+", "-", "*" and "/". Division propagates
    ``ZeroDivisionError`` for a zero divisor.

    Raises:
        ValueError: if ``operator`` is not one of the supported symbols.
    """
    # Checked in order with ==, exactly like an if/elif chain would be.
    handlers = (
        ("+", lambda lhs, rhs: lhs + rhs),
        ("-", lambda lhs, rhs: lhs - rhs),
        ("*", lambda lhs, rhs: lhs * rhs),
        ("/", lambda lhs, rhs: lhs / rhs),
    )
    for symbol, apply in handlers:
        if operator == symbol:
            return apply(operand1, operand2)
    raise ValueError(f"Invalid operator: {operator}")
# Lark grammar passed to Syncode. It constrains completions to the exact
# "Operands: <a> and <b> with Operator: <op>" form used by the few-shot
# prompt. Operands must be <digits>.<digits> (no sign, no exponent), and the
# operator is one of + - * /.
# The separator is a single " with Operator: " literal. Two adjacent literals
# (" with " " Operator: ") would concatenate to "with  Operator:" with two
# spaces, which does not match the prompt's examples.
grammar = """start: "Operands: " operand " and " operand " with Operator: " OPERATOR

operand: integer "." integer

integer: DIGIT+

DIGIT: /[0-9]/

OPERATOR: "+" | "-" | "*" | "/"
"""
model_name = "meta-llama/Llama-3.2-1B"

# Grammar-constrained generator built from the grammar above.
# parse_output_only=True presumably makes infer() return only the
# grammar-conforming part of the completion. TODO: confirm against the Syncode docs.
# quantize/device are hard-coded here; a CUDA device is required as written.
syn_llm = Syncode(
    model=model_name,
    grammar=grammar,
    parse_output_only=True,
    quantize=True,
    device="cuda",
)
    
def run_syncode(textual_query):
    """Answer a natural-language arithmetic question with the constrained LLM.

    Steps:
    1. Fill ``prompt_template`` with ``textual_query``.
    2. Take the first completion returned by ``syn_llm.infer``.
    3. Parse it with ``parse_output`` and evaluate it with ``run_operation``.

    Returns:
        ``([operand1, operand2], operator, result)``, where ``operator`` is a
        ``str`` and ``result`` is a ``float``.
    """
    prompt = prompt_template.format(query=textual_query)
    completions = syn_llm.infer(prompt)
    lhs, rhs, op = parse_output(completions[0])
    value = run_operation(lhs, rhs, op)
    return ([lhs, rhs], str(op), float(value))

# Smoke test on the first few-shot query; the cell displays the returned
# ([operand1, operand2], operator, result) tuple.
query = "What is 327. multiplied by 11.0?"
run_syncode(query)

([327.0, 11.0], '*', 3597.0)

In [20]:
class SyncodeTester:
    """Evaluate a query -> answer function over a list of arithmetic test cases.

    ``run_syncode`` is called with each query and must return
    ``([operand1, operand2], operator, result)``. A case passes when:

    - both operands and ``result`` are floats,
    - ``operator`` is a str, and
    - ``result`` equals the expected answer within ``abs_tol``.
    """

    def __init__(self, run_syncode, verbose=False, abs_tol=5e-4):
        """
        Args:
            run_syncode: callable mapping a query string to
                ``([operand1, operand2], operator, result)``.
            verbose: print each query, parsed output, expected value and result.
            abs_tol: maximum ``|result - expected|`` still counted as correct.
                The default of 5e-4 is roughly "agrees to 3 decimals", without
                the boundary flakiness of comparing ``round(x, 3)`` values
                (e.g. 1.00050001 vs 1.00049999).
        """
        self.run_syncode = run_syncode
        self.test_cases = []
        self.answers = []
        self.verbose = verbose
        self.abs_tol = abs_tol

    def add_test_case(self, query, answer):
        """Register ``query`` together with its expected numeric ``answer``."""
        self.test_cases.append(query)
        self.answers.append(answer)

    def run_tests(self):
        """Run every registered case, print a summary and return ``(correct, total)``."""
        total = len(self.test_cases)
        correct = sum(int(self.run_test(i)) for i in range(total))

        print(f"Correct Acc: {correct}/{total}")
        print(f"Correct: {correct}, Total: {total}")
        return correct, total

    def run_test(self, idx):
        """Run test case ``idx`` and return True iff it passes.

        Any exception raised while calling or unpacking ``run_syncode`` counts
        as a failure instead of aborting the run. Examples are model errors,
        unparsable output, or division by zero. With ``verbose`` the exception
        is printed.
        """
        assert idx < len(self.test_cases), f"Index {idx} out of range"

        query = self.test_cases[idx]
        expected_result = self.answers[idx]
        try:
            ([operand1, operand2], operator, result) = self.run_syncode(query)
        except Exception as exc:  # deliberate best-effort: record as a failure
            if self.verbose:
                print(f"Query: {query}")
                print(f"Error: {type(exc).__name__}: {exc}")
            return False

        if self.verbose:
            print(f"Query: {query}")
            print(f"Output: Operands: {operand1} and {operand2} with Operator: {operator}")
            print(f"Expected: {expected_result}")
            print(f"Result: {result}")

        # Explicit checks rather than `assert`. Asserts are stripped under
        # `python -O`, which would make every case pass.
        if not (isinstance(operand1, float) and isinstance(operand2, float)):
            return False
        if not isinstance(operator, str) or not isinstance(result, float):
            return False
        # Exact equality first so matching infinities also count as correct.
        return result == expected_result or abs(result - expected_result) <= self.abs_tol
    

In [21]:
import random
def make_testcases(num=100, verbose=False, seed=None):
    """Generate ``num`` random natural-language arithmetic queries.

    Each query looks like ``"What is 12.5 plus 3.0 ?"``. The two operands are
    drawn uniformly from [0, 1000] and rounded to 0-2 decimals, so they always
    print with a decimal point. They are joined by a randomly chosen phrasing
    of "+", "-", "*" or "/".

    Args:
        num: number of test cases to generate.
        verbose: if True, print how many cases were generated per operator.
        seed: optional seed for a private ``random.Random``. ``None`` (the
            default) draws from the global ``random`` module as before.

    Returns:
        ``(test_cases, answers)``: parallel lists of query strings and the
        expected numeric results computed with ``run_operation``.
    """
    rng = random if seed is None else random.Random(seed)

    # Phrasings per operator index. "divides" is read as left / right,
    # matching the few-shot prompt.
    idx2operator = {0: "+", 1: "-", 2: "*", 3: "/"}
    idx2operator_str = {
        0: ["plus", "add", "added to"],
        1: ["minus", "subtracted by"],
        2: ["multiply", "multiplied by"],
        3: ["divides", "divided by"],
    }
    idxnums = {0: 0, 1: 0, 2: 0, 3: 0}

    test_cases = []
    answers = []
    for _ in range(num):
        operator_idx = rng.choice([0, 1, 2, 3])
        operator = idx2operator[operator_idx]
        operand1 = round(rng.uniform(0, 1000), rng.choice([0, 1, 2]))
        operand2 = round(rng.uniform(0, 1000), rng.choice([0, 1, 2]))

        operator_embed = rng.choice(idx2operator_str[operator_idx])

        # Never emit a zero divisor, whichever phrasing of "/" was picked.
        # Resample with the same rounding so the operand still prints as
        # <digits>.<digits>, never in scientific notation.
        while operator == "/" and operand2 == 0:
            operand2 = round(rng.uniform(0, 1000), rng.choice([0, 1, 2]))

        query = f"{rng.choice(['What is ', 'What '])}{operand1} {operator_embed} {operand2} {rng.choice(['?', ''])}"
        test_cases.append(query)
        answers.append(run_operation(operand1, operand2, operator))
        idxnums[operator_idx] += 1

    if verbose:
        print(f"Number of test cases for each operator: {idxnums}")

    return test_cases, answers

In [22]:
# Silence Python warnings and reduce transformers' logging to errors only,
# so the verbose test output below stays readable.
# NOTE(review): the warnings filter is global for the rest of the kernel
# session, and it runs after `syn_llm` was already constructed in an earlier
# cell, so it does not affect model-loading messages.
import warnings
warnings.filterwarnings("ignore")
import transformers
transformers.logging.set_verbosity_error()

In [23]:
# Build a small random test suite and evaluate the constrained model on it.
# Seeding the global `random` module makes the generated queries, and hence
# the reported accuracy, reproducible across Restart & Run All.
TEST_SEED = 0
random.seed(TEST_SEED)

syncode_tester = SyncodeTester(run_syncode, verbose=True)
testcases = make_testcases(10, verbose=True)
for query, answer in zip(*testcases):
    syncode_tester.add_test_case(query, answer)

syncode_tester.run_tests()

Number of test cases for each operator: {0: 5, 1: 1, 2: 2, 3: 2}
Query: What 458.05 plus 241.8 
Output: Operands: 458.05 and 241.8 with Operator: +
Expected: 699.85
Result: 699.85
Query: What is 739.0 multiply 244.9 ?
Output: Operands: 739.0 and 244.9 with Operator: *
Expected: 180981.1
Result: 180981.1
Query: What 879.69 divided by 558.0 
Output: Operands: 879.69 and 558.0 with Operator: /
Expected: 1.5765053763440862
Result: 1.5765053763440862
Query: What is 36.0 added to 962.0 
Output: Operands: 36.0 and 962.0 with Operator: +
Expected: 998.0
Result: 998.0
Query: What is 930.03 divides 674.0 ?
Output: Operands: 930.03 and 674.0 with Operator: /
Expected: 1.37986646884273
Result: 1.37986646884273
Query: What 996.3 multiplied by 57.3 
Output: Operands: 996.3 and 57.3 with Operator: *
Expected: 57087.99
Result: 57087.99
Query: What 393.98 added to 681.6 ?
Output: Operands: 393.98 and 681.6 with Operator: +
Expected: 1075.58
Result: 1075.58
Query: What is 67.0 added to 210.5 ?
Output: O

(10, 10)