In [1]:
# Imports
import os
import json
import getpass
import base64
from typing import TypedDict, Annotated, Sequence

import operator
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END
from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools import GooglePlacesTool
from langchain_core.messages import HumanMessage

# Set API keys. They are entered interactively so they are never stored in the notebook.
# A key that is already present in the environment (e.g. exported before launching
# Jupyter, or entered by an earlier run of this cell in the same kernel) is reused
# instead of prompting again, which also lets the notebook run non-interactively.
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
for env_var, service in (
    ('LANGCHAIN_API_KEY', 'Langchain'),
    ('TAVILY_API_KEY', 'Tavily'),
    ('OPENAI_API_KEY', 'OpenAI'),
    ('GPLACES_API_KEY', 'Google Places'),
):
    if not os.environ.get(env_var):
        os.environ[env_var] = getpass.getpass(f"Enter your {service} API key: ")

## Define Prompts


In [18]:
# Shared instruction for prompts whose reply must be machine-parseable JSON.
json_prompt = "Only return a valid JSON string (RFC 8259). Do not provide any other commentary. Do not wrap the JSON in markdown such as ```json. Only use the data from the provided content."

# Geolocation prompt. `{json_prompt}` is filled in by .format() below; the doubled
# braces `{{ }}` are escapes that render as literal JSON braces in the example.
# The example coordinates match the example city (Orland Park, IL), so the model
# is not shown an inconsistent city/coordinate pair.
prompt_template = """USER: Given a set of streetview images from a vehicle, your task is to determine the
coordinates from which the picture was taken. It can be anywhere in the world.

Return JSON with the city and coordinates, following the example below. {json_prompt}
output={{"city": "Orland Park, IL, 60467, USA", "latitude": "41.6303", "longitude": "-87.8539"}}

AGENT: output="""

text_prompt = prompt_template.format(json_prompt=json_prompt)

## Define Tools

In [10]:
# Tools exposed to the agent: Tavily web search (capped at 3 results) and a
# Google Places lookup. The individual tools stay reachable by name as well.
tools = [
    TavilySearchResults(max_results=3),
    GooglePlacesTool(),
]
search_tool, places_tool = tools

## Define LLM

In [11]:
# Init model: GPT-4o at temperature 0 (to minimise sampling variation), with the
# schemas of `tools` attached so the model can emit tool calls.
model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tools)

## Agent

In [12]:
# Define Agent State.
# The graph state is a dict with a single "messages" key. The operator.add
# annotation makes each node's returned list get appended to the existing
# history rather than replacing it.
AgentState = TypedDict(
    "AgentState",
    {"messages": Annotated[Sequence[BaseMessage], operator.add]},
)

# Define nodes
# Router evaluated after the agent node; its return value selects the next edge.
def should_continue(state):
    """Pick the next step after the agent has replied.

    Returns "continue" when the most recent message carries tool calls (so the
    tools must run next) and "end" when it carries none (the reply is final).
    """
    newest = state["messages"][-1]
    return "continue" if newest.tool_calls else "end"


# Graph node that queries the LLM.
def call_model(state):
    """Invoke the tool-bound model on the full conversation so far.

    Returns a partial state update holding only the new reply; the list form
    lets it be appended to the existing message history.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}


# Graph node that executes the tool calls requested by the model, built from the
# same tool list the model was bound to.
tool_node = ToolNode(tools=tools)

In [13]:
# Define a new graph whose state schema is AgentState
workflow = StateGraph(AgentState)

# Register the two nodes the graph cycles between:
# "agent" calls the LLM (call_model); "action" runs the requested tools (tool_node).
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")

# Conditional routing out of the `agent` node
workflow.add_conditional_edges(
    # Source node: these edges are evaluated after the `agent` node runs.
    "agent",
    # Router: should_continue returns either "continue" or "end".
    should_continue,
    # Mapping from the router's return value to the next node.
    # END is a special node marking that the graph should finish.
    {
        # The model requested tool calls: run them in the `action` node.
        "continue": "action",
        # No tool calls: the agent's reply is final, so finish.
        "end": END,
    },
)

# Unconditional edge from `action` back to `agent`, so that after the tools run
# the model is called again with their results in the message history.
workflow.add_edge("action", "agent")

# Compile the graph into the runnable `app` (driven below via app.stream).
app = workflow.compile()

## Run

In [20]:
def encode_image(image_path):
    """Read the file at `image_path` and return its raw bytes as a base64 string."""
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
  
# Build the multimodal request for one sample: the text prompt followed by the four
# street-view images (north/south/east/west) from ./data/<image_id>/, each inlined
# as a base64 data URL.
image_id = "103"  # sample directory under ./data/ (not named `id`, to avoid shadowing the builtin)
image_dir = f'./data/{image_id}/'
directions = ["north", "south", "east", "west"]
image_ext = "png"
# The data-URL MIME type must describe the actual file format. The files read here
# are PNGs, so they are labelled image/png (previously mislabelled as image/jpeg).
image_mime = f"image/{image_ext}"
image_inputs = []
for direction in directions:
    base64_img = encode_image(f"{image_dir}{direction}.{image_ext}")
    image_input = {"type": "image_url", "image_url": {"url": f"data:{image_mime};base64,{base64_img}"}}
    image_inputs.append(image_input)

# Define input: one HumanMessage whose content is a list of parts (text, then images)
text_input = [{"type": "text", "text": text_prompt}]
inputs = {"messages": [HumanMessage(content=text_input + image_inputs)]}

# Run app with streaming.
# NOTE: after the loop, `output` holds the last streamed chunk; the next cell reads it.
for output in app.stream(inputs):
    # stream() yields dictionaries with output keyed by node name
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("---")
        print(value)
    print("\n---\n")

Output from node 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_8Kk25mBIUADuc8vkCYyhyqKh', 'function': {'arguments': '{"query": "La Poste, Bouygues Immobilier"}', 'name': 'google_places'}, 'type': 'function'}, {'id': 'call_MJNH6IjenHEMnmpWOqIDLRJE', 'function': {'arguments': '{"query": "La Poste, Bouygues Immobilier"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 62, 'prompt_tokens': 6029, 'total_tokens': 6091}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_3196d36131', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-d109e3b9-6b8e-403d-8175-76d242a4e806-0', tool_calls=[{'name': 'google_places', 'args': {'query': 'La Poste, Bouygues Immobilier'}, 'id': 'call_8Kk25mBIUADuc8vkCYyhyqKh'}, {'name': 'tavily_search_results_json', 'args': {'query': 'La Poste, Bouygues Immobilier'}, 'id': 'call_MJNH6IjenHEMnmpWOqIDLRJE'}])]}

---

Output from node 'action':
--

In [34]:
# Parse the agent's final answer as JSON.
# `output` is the last chunk yielded by app.stream() in the cell above. The graph can
# only finish via the "agent" node (should_continue -> END), so that chunk is keyed
# by "agent" and holds the model's final reply message.
reply_text = output["agent"]["messages"][0].content.strip()
# The prompt asks for bare JSON, but the model may still wrap it in a ```json fence;
# strip such a fence so json.loads does not fail on it. (Valid JSON never starts
# with "json", so removing that tag is safe.)
if reply_text.startswith("```"):
    reply_text = reply_text[3:]
    if reply_text.lower().startswith("json"):
        reply_text = reply_text[4:]
    if reply_text.endswith("```"):
        reply_text = reply_text[:-3]
    reply_text = reply_text.strip()
result = json.loads(reply_text)
print(result)

{'city': 'Cergy, France', 'latitude': '49.0365', 'longitude': '2.0626'}
