### Import modules

In [None]:
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import json
from openai import OpenAI

### Import tools (API client)

In [None]:
from tools import (
    get_top_selling_products,
    get_top_categories,
    get_sales_trends,
    get_revenue_by_category
)

### Initialize Cloud LLM

In [None]:
# Model name sent with every cloud chat-completion request (used by map_tools).
cloud_model = "gpt-4o"

# Client built with no explicit arguments, so the OpenAI SDK's own default
# configuration applies (presumably credentials come from the environment,
# e.g. OPENAI_API_KEY -- confirm for your deployment).
cloud_llm = OpenAI()

### Initialize Edge LLM

In [None]:
# Model name sent with every edge chat-completion request (used by generate_response).
edge_model = "phi3:mini"

# Same OpenAI client class, pointed at a different base URL.
# NOTE(review): host 10.0.0.125 / port 11434 looks like a local
# OpenAI-compatible server (e.g. Ollama) -- confirm. The API key is a
# placeholder value; presumably the endpoint does not validate it.
edge_llm = OpenAI(
    base_url="http://10.0.0.125:11434/v1/",
    api_key="_",
)

### Initialize Vector DB

In [None]:
# Chroma client backed by the ./data directory.
chroma_client = chromadb.PersistentClient(path="./data")

# Sentence-transformer embedder; retriever() calls it directly to embed queries,
# and the collection is configured with the same function.
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2",
)

# Open the "products" collection, creating it if it does not exist yet.
collection = chroma_client.get_or_create_collection(
    name="products",
    embedding_function=embedding_function,
)

### Helper Functions

#### Map tools to prompt (Cloud LLM)

In [None]:
def _date_range_tool(name, description):
    """Build one function-tool schema that takes a start_date/end_date pair."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    "start_date": {
                        "type": "string",
                        "description": "The start date for the period (YYYY-MM-DD)",
                    },
                    "end_date": {
                        "type": "string",
                        "description": "The end date for the period (YYYY-MM-DD)",
                    },
                },
                "required": ["start_date", "end_date"],
            },
        },
    }


def map_tools(prompt):
    """Ask the cloud LLM which sales tools (and arguments) fit ``prompt``.

    Every advertised tool takes the same ``start_date`` / ``end_date``
    parameters, so the schemas are generated from one template rather than
    being spelled out four times.

    Args:
        prompt: The user's natural-language question.

    Returns:
        A list of ``{"function_name": str, "arguments": dict}`` entries, one
        per tool call in the model's reply, in the order returned. The list
        is empty when the reply contains no tool calls.

    Raises:
        json.JSONDecodeError: If a tool call's arguments are not valid JSON.
    """
    tools = [
        _date_range_tool(name, description)
        for name, description in (
            (
                "get_top_selling_products",
                "Retrieve top-selling products for a specified period",
            ),
            (
                "get_top_categories",
                "Retrieve top-selling categories for a specified period",
            ),
            (
                "get_sales_trends",
                "Retrieve sales trends over a specified period",
            ),
            (
                "get_revenue_by_category",
                "Retrieve the revenue generated by each category over a specified period",
            ),
        )
    ]

    messages = [{"role": "user", "content": prompt}]
    response = cloud_llm.chat.completions.create(
        model=cloud_model,
        messages=messages,
        tools=tools,
        tool_choice="auto",  # the model decides whether to call a tool at all
    )

    # The reply may carry no tool calls (attribute missing or None/empty);
    # treat every such case as "no tools selected".
    response_message = response.choices[0].message
    tool_calls = getattr(response_message, "tool_calls", None) or []

    # Arguments arrive as a JSON-encoded string and are decoded into a dict.
    return [
        {
            "function_name": call.function.name,
            "arguments": json.loads(call.function.arguments),
        }
        for call in tool_calls
    ]

#### Execute tools locally (Invoke API)

In [None]:
import json

def execute_tools(functions):
    """Run each requested tool locally and merge the outputs into one flat list.

    Args:
        functions: Iterable of dicts with ``"function_name"`` and
            ``"arguments"`` keys (the shape returned by ``map_tools``).

    Returns:
        One list holding all results: a list result is spliced in element by
        element, any other result is added as a single item.

    Raises:
        KeyError: If an entry lacks a required key or names a function that
            is not registered below.
    """
    registry = {
        "get_top_selling_products": get_top_selling_products,
        "get_top_categories": get_top_categories,
        "get_sales_trends": get_sales_trends,
        "get_revenue_by_category": get_revenue_by_category,
    }

    merged = []
    for call in functions:
        name, kwargs = call["function_name"], call["arguments"]
        outcome = registry[name](**kwargs)
        # Flatten list results; keep anything else as one entry.
        merged.extend(outcome if isinstance(outcome, list) else [outcome])
    return merged

#### Retrieve from Context

In [None]:
def retriever(query):
    """Look up the stored documents closest to ``query`` and join them into text.

    Args:
        query: Free-text search string.

    Returns:
        The matched documents (the query requests ``n_results=5``), each
        passed through ``str`` and joined with ``" \\n"``.
    """
    query_vectors = embedding_function([query])
    hits = collection.query(
        query_embeddings=query_vectors,
        n_results=5,
        include=["documents"],
    )
    # A single query was submitted, so only the first result row is relevant.
    matched_documents = hits["documents"][0]
    return " \n".join(map(str, matched_documents))

In [None]:
# Try the retriever on a sample product question.
retriever("connectivity options of Nimbus Book")

#### Generate answer (Edge LLM)

In [None]:
def generate_response(prompt, context):
    """Have the edge LLM answer ``prompt`` using only the supplied ``context``.

    Args:
        prompt: The user's question.
        context: Text placed ahead of the question in the request; the model
            is instructed to answer from it.

    Returns:
        The model's reply with surrounding whitespace removed, or an empty
        string when the reply carries no text content.
    """
    input_text = (
        "Based on the below context, respond with a concise answer in a single sentence. If you don't find the answer within the context, say I do not know. Don't repeat the question\n\n"
        f"{context}\n\n"
        f"{prompt}"
    )
    response = edge_llm.chat.completions.create(
        model=edge_model,
        messages=[
            {"role": "user", "content": input_text},
        ],
        max_tokens=150,
        temperature=0,
    )

    # ``content`` is nullable in the Chat Completions schema; guard against
    # None so an empty reply yields "" instead of an AttributeError.
    answer = response.choices[0].message.content
    return (answer or "").strip()

#### Test helper functions

In [None]:
# Ask the cloud LLM which tool call(s) fit a sample sales question.
tools=map_tools("What was the top selling product in Q2 based on revenue?")

In [None]:
# Inspect the selected tool calls (function names and decoded arguments).
tools

In [None]:
# Run the selected tool call(s) locally and keep the combined results.
tool_output=execute_tools(tools)

In [None]:
#tool_output

### Agent to federate the LLMs

In [None]:
def agent(prompt):
    """Answer ``prompt`` by combining the cloud LLM, local tools and the edge LLM.

    The cloud LLM is asked for tool calls first. If it selects any, their
    JSON-encoded results become the context; otherwise the context comes
    from the vector store. The edge LLM then writes the final answer.

    Args:
        prompt: The user's question.

    Returns:
        The edge LLM's answer string.
    """
    selected_tools = map_tools(prompt)

    if not selected_tools:
        # No tool applies: fall back to document retrieval.
        grounding = retriever(prompt)
    else:
        grounding = json.dumps(execute_tools(selected_tools))

    return generate_response(prompt, grounding)

In [None]:
# Sales question: presumably answered via a tool call, if the cloud LLM selects one.
agent("What was the top selling product in June based on revenue?")

In [None]:
# Product-spec question: presumably no sales tool is selected, so the context
# comes from the vector store -- confirm by running.
agent("What is the CPU of Nimbus Book?")