In [0]:
%pip install databricks_langchain

In [0]:
%sql
CREATE OR REPLACE FUNCTION alexander_genser.default.get_train_stations(
  search_location STRING COMMENT 'Free-text station or place name to search for, e.g. "Zurich HB". May be NULL when searching by coordinates only.',
  location_lat FLOAT COMMENT 'Latitude of the search point (sent as the API x parameter). May be NULL.',
  location_lon FLOAT COMMENT 'Longitude of the search point (sent as the API y parameter). May be NULL.'
)
RETURNS STRING
COMMENT 'Searches transport.opendata.ch /v1/locations for SBB train stations matching a name and/or coordinates. Returns the API response as a JSON string, or an error message if the request fails.'
LANGUAGE PYTHON
AS $$
import json
import requests

# Plain URL only: an HTTP-method prefix such as "GET http://..." is not a valid
# URL and makes requests raise before anything is sent.
url = "http://transport.opendata.ch/v1/locations"
params = {
    "query": search_location,
    # The API names coordinates x (latitude) and y (longitude).
    "x": location_lat,
    "y": location_lon,
    "type": "station",
}
# requests drops None-valued params, so NULL arguments are simply omitted.

try:
    response = requests.get(url, params=params, timeout=10)
except requests.RequestException as exc:
    # Report network failures to the agent as text instead of failing the tool call.
    return f"Failed to retrieve stations: {exc}"

if response.status_code != 200:
    return f"Failed to retrieve stations. Status code: {response.status_code}"

# The function RETURNS STRING, so serialize the parsed payload as JSON
# rather than returning a Python dict.
return json.dumps(response.json())
$$;

In [0]:
%sql
CREATE OR REPLACE FUNCTION alexander_genser.default.get_next_connection(
  from_station STRING COMMENT 'Departure location, e.g. a station name such as "Zurich HB".',
  to_station STRING COMMENT 'Arrival location, e.g. a station name such as "Bern".'
)
RETURNS STRING
COMMENT 'Looks up the next train connections between two stations via transport.opendata.ch /v1/connections. Returns the list of connections as a JSON string, or an error message if the request fails.'
LANGUAGE PYTHON
AS $$
import json
import requests

url = "http://transport.opendata.ch/v1/connections"
params = {
    "from": from_station,
    "to": to_station,
}

try:
    response = requests.get(url, params=params, timeout=10)
except requests.RequestException as exc:
    # Report network failures to the agent as text instead of failing the tool call.
    return f"Failed to retrieve connection: {exc}"

if response.status_code != 200:
    return f"Failed to retrieve connection. Status code: {response.status_code}"

# Only the "connections" array is returned; .get avoids a KeyError if the payload
# lacks it. Serialize because the function RETURNS STRING, not a Python list.
return json.dumps(response.json().get("connections", []))
$$;

In [0]:
%sql
CREATE OR REPLACE FUNCTION alexander_genser.default.get_station_board(
  station STRING COMMENT 'Name of the station whose board should be fetched, e.g. "Bern".'
)
RETURNS STRING
COMMENT 'Fetches up to 10 upcoming entries of the station board for a station from transport.opendata.ch /v1/stationboard. Returns the API response as a JSON string, or an error message if the request fails.'
LANGUAGE PYTHON
AS $$
import json
import requests

url = "http://transport.opendata.ch/v1/stationboard"
params = {
    "station": station,
    # Cap the number of entries to keep the tool output small for the LLM.
    "limit": 10,
}

try:
    response = requests.get(url, params=params, timeout=10)
except requests.RequestException as exc:
    # Report network failures to the agent as text instead of failing the tool call.
    return f"Failed to retrieve station board: {exc}"

if response.status_code != 200:
    return f"Failed to retrieve station board. Status code: {response.status_code}"

# The function RETURNS STRING, so serialize the parsed payload as JSON
# rather than returning a Python dict.
return json.dumps(response.json())
$$;

In [0]:
from databricks_langchain.uc_ai import (
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
import pandas as pd

def display_tools(tools):
    """Show the attributes of each tool as one row of a DataFrame.

    Every attribute value is converted with ``str`` so that nested objects
    (schemas, clients, callables) render as plain text in the table.
    """
    rows = []
    for tool in tools:
        attributes = vars(tool)
        rows.append({attr: str(value) for attr, value in attributes.items()})
    table = pd.DataFrame(rows)
    display(table)

# Create the Databricks client used to execute Unity Catalog functions and
# register it with set_uc_function_client for UC tool calls.
client = DatabricksFunctionClient()
set_uc_function_client(client)

# Expose the functions of the alexander_genser.default schema as agent tools.
# NOTE(review): the wildcard also picks up any other function that lives in this
# schema, not only the three created above; consider listing them explicitly.
uc_toolkit = UCFunctionToolkit(function_names=["alexander_genser.default.*"])
tools = uc_toolkit.tools

In [0]:
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatDatabricks


# Chat model served through a Databricks Foundation Model API endpoint.
# NOTE(review): langchain_community's ChatDatabricks is deprecated in favour of
# databricks_langchain.ChatDatabricks (the package installed at the top); consider switching.
LLM_ENDPOINT = "databricks-meta-llama-3-1-70b-instruct"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT)

# Prompt for the tool-calling agent. The system message steers the model towards
# the UC tools. The optional chat_history placeholder must appear exactly once,
# otherwise prior turns are injected into the prompt twice.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant for helping clients plan their train journey. Make sure to use the available tools for information retrieval. Refer the tools description and make a decision of the tools to call for each user query.",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        # Filled by the tool-calling agent with intermediate tool calls and results.
        ("placeholder", "{agent_scratchpad}"),
    ]
)

In [0]:
from langchain.agents import AgentExecutor, create_tool_calling_agent


# Combine the model, the UC tools and the prompt into a tool-calling agent, then
# wrap it in an executor that runs the agent/tool loop (verbose=True logs each step).
agent = create_tool_calling_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
)

In [0]:
# Run the agent end to end on a sample request; the result is the cell's output.
journey_request = {"input": "plan me a journey from Zurich to Bern"}
agent_executor.invoke(journey_request)