<a href="https://colab.research.google.com/github/lucasl02/SyncodeMath/blob/main/syncode_mathematical_programming.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Syncode Inference

In [None]:
!pip install syncode

Collecting syncode
  Downloading syncode-0.4.16.tar.gz (206 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/206.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m206.1/206.1 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting fire (from syncode)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Collecting interegular (from syncode)
  Downloading interegular-0.3.3-py37-none-any.whl.metadata (3.0 kB)
Collecting regex==2024.11.6 (from syncode)
  Downloading regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Collecting transformers==4.53.2 (from syncode)
  Download

In [None]:
from syncode import Syncode
import warnings
# Ignore every Python warning from this point on in the session.
warnings.filterwarnings(action='ignore')

# Model identifier handed to Syncode below.
model_name = "microsoft/phi-2"

# Build the Syncode-augmented model with the 'json' grammar and a cap of
# 400 newly generated tokens per call.
syn_llm = Syncode(
    grammar='json',
    max_new_tokens=400,
    model=model_name,
)

In [None]:
import json
import operator

# Maps each operator symbol the prompt allows to its two-argument implementation.
_ARITHMETIC_OPS = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": operator.truediv,
}


def _extract_first_json_object(text):
    """Decode and return the first JSON object that starts at a '{' in ``text``."""
    start_idx = text.find('{')
    if start_idx == -1:
        raise json.JSONDecodeError("No JSON object found in model output", text, 0)
    # raw_decode stops at the end of the first complete value, so surrounding
    # blank lines or any extra text/objects after it cannot corrupt the parse.
    parsed, _end = json.JSONDecoder().raw_decode(text, start_idx)
    return parsed


def run_syncode(textual_query):
    """
    Parse an arithmetic question with the Syncode model and evaluate it.

    The query is embedded in a few-shot prompt, sent to the module-level
    ``syn_llm``, and the first JSON object in the first returned completion
    is read for its "operands" and "operator" keys.

    Args:
        textual_query: Natural-language arithmetic question, e.g.
            "What is 45.1 plus 23.54?".

    Returns:
        A tuple ``(operands, operator, final_answer)`` where ``operands`` is
        the list of parsed operands converted to float, ``operator`` is the
        value of the "operator" key ("" if absent), and ``final_answer`` is
        the float result of applying the operator to the first two operands.
        ``final_answer`` is 0.0 when fewer than two operands were parsed or
        the operator is not one of "+", "-", "*", "/". Operands beyond the
        second are ignored.

    Raises:
        json.JSONDecodeError: If the model output contains no decodable JSON
            object.
        ZeroDivisionError: If the operator is "/" and the second operand is 0.
    """
    # 1. Build the prompt; the two worked examples act as few-shot guidance.
    prompt = f"""Parse the arithmetic query into a JSON object with keys "operands" (list of numbers) and "operator" (string: "+", "-", "*", "/").

              EXAMPLES:
              Query: What is 327. multiplied by 11.0?
              JSON: {{ "operands": [327.0, 11.0], "operator": "*" }}

              Query: What is 45.1 plus 23.54?
              JSON: {{ "operands": [45.1, 23.54], "operator": "+" }}

              NOW:
              Query: {textual_query}
              JSON: """

    # 2. Run inference and keep the first returned completion.
    output = syn_llm.infer(prompt)[0]

    # 3. The raw output is padded with empty lines; pull out the JSON object.
    data = _extract_first_json_object(output)

    # 4. Read the operands (as floats) and the operator symbol.
    operands = [float(x) for x in data.get("operands", [])]
    op_symbol = data.get("operator", "")

    # 5. Evaluate only when two operands and a supported operator are present;
    #    otherwise fall back to 0.0. The isinstance guard keeps an unexpected
    #    non-string operator on the fallback path instead of failing the lookup.
    answer = 0.0
    apply_op = _ARITHMETIC_OPS.get(op_symbol) if isinstance(op_symbol, str) else None
    if apply_op is not None and len(operands) >= 2:
        answer = apply_op(operands[0], operands[1])

    return (operands, op_symbol, float(answer))

In [None]:
# Sample questions covering multiplication, addition and division.
queries = ["What is 327. multiplied by 11.0?",
    "What is 45.1 plus 23.54?",
    "What is 120.4 divided by 4.0?"]

for q in queries:
    # run_syncode returns (operands, operator, final_answer); unpack it so
    # "Result:" shows the numeric answer rather than the whole tuple.
    _operands, _op_symbol, answer = run_syncode(q)
    print(f"Query: {q}")
    print(f"Result: {answer}")
    print()

Query: What is 327. multiplied by 11.0?
Result: 3597.0

Query: What is 45.1 plus 23.54?
Result: 68.64

Query: What is 120.4 divided by 4.0?
Result: 30.1

