In [24]:
# !pip install git+https://github.com/uiuc-focal-lab/syncode.git lark huggingface_hub

import os
import json
import re
from dataclasses import dataclass
from typing import List, Tuple, Optional
from syncode import Syncode
from transformers import AutoTokenizer

MODEL_NAME = "meta-llama/Llama-3.2-1B"

# Lark grammar that constrains generation to one JSON object shaped like
#   {"operands": [<number>, <number>], "operator": "<op>"}
#
# Design notes:
#   * Whitespace between tokens is optional and handled solely by
#     `%ignore WS`; it is not also demanded explicitly inside `start`.
#   * Each quoted JSON string ("operands", "operator", and the operator
#     value) is a single terminal, so ignorable whitespace cannot end up
#     inside the quotes and change the key/value text.
#   * The operator value is quoted so the constrained output is valid JSON.
#   * Numbers accept an optional leading minus sign for negative operands.
MATH_GRAMMAR = r"""
start: "{" "\"operands\"" ":" "[" float "," float "]" "," "\"operator\"" ":" operator "}"

float: /-?[0-9]+(\.[0-9]+)?/

operator: "\"+\"" | "\"-\"" | "\"*\"" | "\"/\""

WS: /[ \t\n\r]+/

%ignore WS
"""

@dataclass
class MATHOperation:
    """Record of one evaluated binary arithmetic operation."""

    # The two numeric inputs, in the order they are applied to the operator.
    operands: Tuple[float, float]
    # Operator symbol as a string (e.g. "+").
    operator: str
    # Value obtained by applying the operator to the operands.
    result: float

class MATHProcessor:
    """Turn a textual arithmetic query into operands, an operator and a result.

    A grammar-constrained model call (via ``Syncode``) produces a small JSON
    object naming two operands and an operator; the arithmetic itself is then
    carried out in Python.
    """

    def __init__(self, model_name: str = MODEL_NAME, hf_token: Optional[str] = None):
        """Load the tokenizer, build the Syncode wrapper and the few-shot prompt.

        Args:
            model_name: Model identifier passed to the tokenizer and Syncode.
            hf_token: Optional access token. When given, it is exported as
                ``HF_TOKEN`` and ``huggingface-cli login`` is invoked.
        """
        if hf_token:
            os.environ["HF_TOKEN"] = hf_token
            # The shell expands $HF_TOKEN from the environment set just above.
            os.system("huggingface-cli login --token $HF_TOKEN")

        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.syn_llm = self._initialize_syncode()
        # Few-shot prompt prefix. Each example answer must mirror the numbers
        # in its query exactly, otherwise the model is shown a wrong mapping.
        self.EXAMPLES_QUERY = """
        These are samples for extracting floating point operands and the operator from a query.
        Learn this way and follow this formatting for your answer.

        Example 1:
        Query: "What is 10.5 times 10.0?"
        {"operands": [10.5, 10.0], "operator": "*"}

        Example 2:
        Query: "What is 10.5 minus 10.0?"
        {"operands": [10.5, 10.0], "operator": "-"}

        Example 3:
        Query: "What is 999.0 divided by 999.0?"
        {"operands": [999.0, 999.0], "operator": "/"}
        """

    def _initialize_syncode(self) -> Syncode:
        """Build the Syncode instance configured with ``MATH_GRAMMAR``."""
        return Syncode(
            model=self.model_name,
            grammar=MATH_GRAMMAR,
            parse_output_only=True,
            max_new_tokens=200,
            mode='grammar_mask',
            # The tokenizer's EOS id doubles as the padding id.
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=False,
        )

    def _extract_json(self, text: str) -> dict:
        """Parse the first ``{...}`` span found in *text* as JSON.

        Args:
            text: Raw model output.

        Returns:
            The decoded JSON value.

        Raises:
            ValueError: If no brace-delimited span exists or it is not valid
                JSON.
        """
        # DOTALL lets the span cross line breaks, which the grammar's
        # whitespace rule permits between tokens.
        match = re.search(r"\{.*?\}", text, re.DOTALL)
        if not match:
            raise ValueError("No valid JSON object found in the output.")

        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError as e:
            # Chain the decoder error so the original position info survives.
            raise ValueError(f"Invalid JSON format: {e}") from e

    def _compute_result(self, operands: Tuple[float, float], operator: str) -> float:
        """Apply *operator* to the two *operands*.

        Args:
            operands: Exactly two numbers (left, right).
            operator: One of ``"+"``, ``"-"``, ``"*"``, ``"/"``.

        Returns:
            The arithmetic result. Division by zero yields ``float('inf')``
            instead of raising.

        Raises:
            ValueError: If *operator* is not supported, or *operands* does not
                contain exactly two items (raised by the unpacking).
        """
        op1, op2 = operands
        operations = {
            "+": lambda x, y: x + y,
            "-": lambda x, y: x - y,
            "*": lambda x, y: x * y,
            "/": lambda x, y: x / y if y != 0 else float('inf'),
        }

        if operator not in operations:
            raise ValueError(f"Unsupported operator: {operator}")

        return operations[operator](op1, op2)

    def process_query(self, query: str) -> MATHOperation:
        """Extract operands/operator from *query* with the model and evaluate.

        Args:
            query: Natural-language arithmetic question.

        Returns:
            A ``MATHOperation`` holding the operands, operator and result.

        Raises:
            ValueError: Propagated from JSON extraction or evaluation.
            KeyError: If the decoded object lacks ``"operands"`` or
                ``"operator"``.
        """
        prompt = f'{self.EXAMPLES_QUERY}\nQuery: "{query.strip()}"\n'

        # Only the first element of the inference result is used.
        extraction = self.syn_llm.infer(prompt)[0].strip()

        data = self._extract_json(extraction)
        operands = tuple(map(float, data["operands"]))
        operator = data["operator"]

        result = self._compute_result(operands, operator)
        return MATHOperation(operands, operator, result)

def run_syncode(textual_query: str) -> Tuple[List[float], str, float]:
    """Evaluate one textual arithmetic query.

    The access token is read from the ``HF_TOKEN`` environment variable
    rather than being embedded in source; when it is unset, no login is
    attempted. The ``MATHProcessor`` is created on first use and reused on
    later calls so the model is not reloaded for every query.

    Args:
        textual_query: Natural-language arithmetic question.

    Returns:
        ``(operands, operator, result)`` where ``operands`` is a two-item
        list.

    Raises:
        ValueError: If processing the query fails for any reason; the
            original exception is attached as the cause.
    """
    processor = getattr(run_syncode, "_processor", None)
    if processor is None:
        # Never commit credentials: take the token from the environment.
        processor = MATHProcessor(hf_token=os.environ.get("HF_TOKEN"))
        run_syncode._processor = processor

    try:
        operation = processor.process_query(textual_query)
    except Exception as e:
        raise ValueError(f"Error processing query: {str(e)}") from e

    return (list(operation.operands), operation.operator, operation.result)

if __name__ == "__main__":
    # Smoke-test queries covering a parenthesised negative operand, a
    # negative second operand, and a division.
    test_queries = [
        "What is (-327.0) multiplied by 11.0?",
        "What is 45.1 plus -23.54?",
        "What is 120.4 divided by 4.0?",
    ]

    for query in test_queries:
        # Keep the try body to the one call that can raise.
        try:
            operands, operator, result = run_syncode(query)
        except ValueError as e:
            print(f"Error: {e}\n")
            continue

        # Echo the query so each result can be matched to its input.
        print(f"Query: {query}")
        print(f"Operands: {operands}")
        print(f"Operator: {operator}")
        print(f"Result: {result}")
        print()

Query: What is (-327.0) multiplied by 11.0?
Operands: [-327.0, 11.0]
Operator: *
Result: -3597.0

Query: What is 45.1 plus -23.54?
Operands: [45.1, -23.54]
Operator: +
Result: 21.560000000000002

Query: What is 120.4 divided by 4.0?
Operands: [120.4, 4.0]
Operator: /
Result: 30.1

