# Simple Calculator with `treelang`

This cookbook shows how to implement a simple calculator that computes arithmetic expressions using a small set of operations (`add`, `subtract`, `multiply`, `divide`, `power`, `squareroot`).

With `treelang`, evaluating each expression takes a *single call* to the LLM!
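Roughly, the idea is that the LLM answers once with the whole expression tree. Evaluating that tree then needs only tool calls, with no further LLM round-trips. The sketch below illustrates that second phase by walking a nested dictionary shaped like the AST printed in the output further below. It is not `treelang`'s own evaluator, and several of its details are assumptions inferred from that printout:

- the node-naming rule (operation name plus a `_N` suffix);
- the argument ordering;
- the `OPS` table.

Local lambdas stand in for the MCP tools.

```python
import math
from typing import Any, Callable, Dict

# Hypothetical operation table: names follow the node labels in the printed AST
# (`sqrt` there, where the operation list says `squareroot`).
OPS: Dict[str, Callable[..., float]] = {
    "add": lambda a, b: a + b,
    "subtract": lambda a, b: a - b,
    "multiply": lambda a, b: a * b,
    "divide": lambda a, b: a / b,
    "power": lambda a, b: a ** b,
    "sqrt": math.sqrt,
}


def evaluate(node: Dict[str, Any]) -> float:
    """Evaluate a single-root tree such as {"add_1": {"a": [1], "b": [2]}}."""
    (label, args), = node.items()          # exactly one root per node
    op = OPS[label.rsplit("_", 1)[0]]      # "add_2" -> "add"
    values = []
    for key, value in args.items():        # arguments in insertion order
        if isinstance(value, list):        # literal argument, e.g. "a": [25]
            values.append(value[0])
        else:                              # nested call, e.g. "add_2": {...}
            values.append(evaluate({key: value}))
    return op(*values)


# AST of the first expression, copied from the notebook output below.
tree = {"subtract_1": {"add_1": {"sqrt_1": {"multiply_1": {"add_2": {"a": [25], "b": [10]}, "b": [4]}}, "power_1": {"a": [3], "b": [2]}}, "b": [8]}}
print(evaluate(tree))
```

In this sketch, the single LLM response fixes the whole shape of the computation up front. Each node then becomes one tool invocation.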

The MCP server, which defines one tool per operation, lives in `calculator.py` in this directory. We communicate with it over the `stdio` transport (see the [Python SDK documentation](https://github.com/modelcontextprotocol/python-sdk)).
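`calculator.py` itself is not reproduced here. The following is only a guess at what such a server might look like, built on the Python SDK's `FastMCP` class. The function names follow the operation list above, and the parameter names `a`/`b` follow the printed AST. Both are assumptions: the AST labels the square-root node `sqrt`, so the real tool names may differ. `build_server` is a hypothetical helper.

```python
import math


# Hypothetical stand-ins for the tools in `calculator.py`.
def add(a: float, b: float) -> float:
    """Return a + b."""
    return a + b


def subtract(a: float, b: float) -> float:
    """Return a - b."""
    return a - b


def multiply(a: float, b: float) -> float:
    """Return a * b."""
    return a * b


def divide(a: float, b: float) -> float:
    """Return a / b, rejecting division by zero."""
    if b == 0:
        raise ValueError("division by zero")
    return a / b


def power(a: float, b: float) -> float:
    """Return a raised to the power b."""
    return a ** b


def squareroot(a: float) -> float:
    """Return the square root of a non-negative number."""
    return math.sqrt(a)


def build_server():
    """Register every operation as a tool on a FastMCP server."""
    # Imported lazily so the arithmetic above works without the SDK installed.
    from mcp.server.fastmcp import FastMCP

    server = FastMCP("calculator")
    for fn in (add, subtract, multiply, divide, power, squareroot):
        server.add_tool(fn)  # name, description and schema come from the function
    return server
```

Run as a script, such a file would end with `build_server().run()`. `FastMCP.run()` defaults to the `stdio` transport, which is what the client below connects with.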


## Imports


In [1]:
import asyncio

# mcp components for stdio communication
from mcp import ClientSession, StdioServerParameters, stdio_client

# our helpful `Arborist` implementation
from treelang.ai.arborist import OpenAIArborist

## Client
Most of the work here is setting up the MCP client-server communication channel and defining the math problems we will ask the `Arborist` to solve.

In [None]:
import os
import math

from treelang.ai.arborist import EvalType
from treelang.ai.provider import MCPToolProvider
from treelang.trees.tree import AST

# the asynchronous `main` function
async def main():
    # server parameters for running the calculator server
    path = os.path.join(os.getcwd(), "calculator.py")
    server_params = StdioServerParameters(
        # launch `python calculator.py` as a subprocess and talk to it over stdio
        command="python", args=[path], env=None
    )
    
    # establish a communication channel with the server
    async with stdio_client(server_params) as (read, write):
        # create a client session for interaction
        async with ClientSession(read, write) as session:
            # initialize the session
            await session.initialize()
            provider = MCPToolProvider(session)
            # create and configure the Arborist
            arborist = OpenAIArborist(model="gpt-4o-2024-11-20", provider=provider)
            # we will evaluate the following math expressions
            expressions = [
                "sqrt( ( 25 + 10 ) * 4 ) + 3^2 - 8",
                "( 15 / 3 ) + ( 2^4 - 6 ) * ( 9 - 5 )",
                "( 50 - 8 ) / 2 + sqrt( 64 ) * 3^2",
                "( 7^2 - 10 ) / 5 + sqrt( 49 ) * ( 6 - 2 )"
            ]
            # expected results for the above expressions
            expected_results = [
                math.sqrt((25 + 10) * 4) + 3**2 - 8,
                (15 / 3) + (2**4 - 6) * (9 - 5),
                (50 - 8) / 2 + math.sqrt(64) * 3**2,
                (7**2 - 10) / 5 + math.sqrt(49) * (6 - 2)
            ]
            # evaluate each expression and print the results
            for idx, expr in enumerate(expressions):
                # ask the arborist...
                response = await arborist.eval(f"can you please calculate {expr}?")
                # ... and ye shall receive
                print(f"{expr} = {response.content}")
                # but check that results are correct (up to floating-point rounding)
                assert math.isclose(response.content, expected_results[idx]), f"Expected {expected_results[idx]}, got {response.content}"
            # let's see what the AST for the first expression looks like
            response = await arborist.eval(f"what is {expressions[0]}?", EvalType.TREE)
            print(f"\nAST for {expressions[0]} is:\n {AST.repr(response.content)}")
            # let's see a chatty response
            question = f"can you please calculate {expressions[1]}?"
            response = await arborist.eval(question, EvalType.WALK)
            print(f"\nQUESTION: {question}")
            print(f"\nANSWER: {response.explain()}")
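The tools return floats, so an exact `==` comparison against Python's own result can fail on rounding alone. This can happen, for example, when the server evaluates terms in a different order. `math.isclose` compares within a relative tolerance instead. The `check` helper below is hypothetical and only illustrates the pattern. It uses the first result printed in the output further below.

```python
import math

# Floating-point sums can differ in the last bits from the "exact" decimal value.
print(0.1 + 0.2 == 0.3)              # False
print(math.isclose(0.1 + 0.2, 0.3))  # True


def check(expr: str, got: float, expected: float) -> None:
    """Hypothetical helper: assert `got` matches `expected` within a relative tolerance."""
    assert math.isclose(got, expected, rel_tol=1e-9), f"{expr}: expected {expected}, got {got}"


check("sqrt( ( 25 + 10 ) * 4 ) + 3^2 - 8", 12.83215956619923, math.sqrt((25 + 10) * 4) + 3**2 - 8)
```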

## Main
The MCP client and the `Arborist` are asynchronous. We therefore hook into the Jupyter notebook's already-running event loop and schedule the async `main` on it.
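Outside a notebook, no event loop is running yet, so `loop.create_task` alone would never execute `main`. The sketch below handles both cases. The `run` helper and the placeholder `main` are hypothetical stand-ins, not part of `treelang`.

```python
import asyncio


async def main() -> str:
    # Placeholder for the notebook's `main` coroutine (hypothetical stand-in).
    await asyncio.sleep(0)
    return "done"


def run(coro_fn):
    """Run `coro_fn()` whether or not an event loop is already running."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Plain Python script: no loop yet, so create one and run to completion.
        return asyncio.run(coro_fn())
    # Inside Jupyter: a loop is already running, so schedule a Task on it.
    return loop.create_task(coro_fn())


if __name__ == "__main__":
    print(run(main))
```

In a notebook, IPython also accepts top-level `await`, so `await main()` in a cell is another option. It waits for `main` to finish and shows any failed assertion directly in the cell output, instead of leaving it on a background task.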

In [3]:
loop = asyncio.get_event_loop()
loop.create_task(main())


sqrt( ( 25 + 10 ) * 4 ) + 3^2 - 8 = 12.83215956619923
( 15 / 3 ) + ( 2^4 - 6 ) * ( 9 - 5 ) = 45.0
( 50 - 8 ) / 2 + sqrt( 64 ) * 3^2 = 93.0
( 7^2 - 10 ) / 5 + sqrt( 49 ) * ( 6 - 2 ) = 35.8

AST for sqrt( ( 25 + 10 ) * 4 ) + 3^2 - 8 is:
 {"subtract_1": {"add_1": {"sqrt_1": {"multiply_1": {"add_2": {"a": [25], "b": [10]}, "b": [4]}}, "power_1": {"a": [3], "b": [2]}}, "b": [8]}}

QUESTION: can you please calculate ( 15 / 3 ) + ( 2^4 - 6 ) * ( 9 - 5 )?

ANSWER: Alright, let's break this down in a clear and logical way!

The question asks us to calculate the result of this expression:  
**(15 ÷ 3) + (2⁴ - 6) × (9 - 5)**

Here's how it works step-by-step:

1. **First part: (15 ÷ 3):**  
   Divide 15 by 3, which gives you **5**.

2. **Second part: (2⁴ - 6):**  
   The "2⁴" means 2 raised to the power of 4, or 2 × 2 × 2 × 2, which equals **16**.  
   Subtract 6 from 16, and you get **10**.

3. **Third part: (9 - 5):**  
   Subtract 5 from 9, and you get **4**.

4. **Final step: Combine it all!*