In [None]:
%%capture --no-stderr
# Install the packages this notebook imports. langchain-community is included because the
# Tavily search tool is imported from `langchain_community`.
%pip install --quiet -U langgraph langchain-core langchain-community langchain_openai python-dotenv langsmith pydantic spotipy

In [None]:
# Install the JupyterLab LSP extension and the Python language server.
%pip install --quiet -U jupyterlab-lsp
%pip install --quiet -U "python-lsp-server[all]"

In [6]:
## Setup logging
import logging
import os
from dotenv import load_dotenv

# Load variables from a .env file; override=True lets them replace values already set in the environment.
load_dotenv(override=True)
logger = logging.getLogger(__name__)

# Configure root logging for the notebook.
# force=True replaces any handlers already attached to the root logger; without it
# basicConfig does nothing when handlers exist, e.g. when this cell is run a second time.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',  # timestamp - level - message
    handlers=[logging.StreamHandler()],  # Output to the console
    force=True,
)

# Define Search Tool

In [7]:
from langchain_community.tools.tavily_search import TavilySearchResults

# Tavily web-search tool made available to the agent.
# max_results is set to 20 because the Tavily Search API documents 0-20 as the valid range
# for this parameter — NOTE(review): confirm against the current API reference.
search_tool = TavilySearchResults(
    max_results=20,
    include_answer=True,
    include_raw_content=True,
    include_images=True,
    # Further options, currently left at their defaults:
    # search_depth="advanced",
    # include_domains = []
    # exclude_domains = []
)
name = search_tool.get_name()
desc = search_tool.description
# Display the tool description.
desc

'A search engine optimized for comprehensive, accurate, and trusted results. Useful for when you need to answer questions about current events. Input should be a search query.'

# Create ToolNode

In [8]:
from langgraph.prebuilt import ToolNode
from spotify_tools import get_playlists, get_track_list, create_spotify_playlist, add_tracks_to_playlist, filter_artists, get_artists_from_playlist, find_similar_artist
from plan import validate_plan

# Every tool the agent may call: the Spotify helpers, plan validation and web search.
tools = [
    get_playlists,
    get_track_list,
    create_spotify_playlist,
    add_tracks_to_playlist,
    filter_artists,
    validate_plan,
    get_artists_from_playlist,
    search_tool,
    find_similar_artist,
]

# Graph node built from the tool list above.
tool_node = ToolNode(tools)

# Manual Test of Tool Infra

In [None]:
from langchain_core.messages import AIMessage, HumanMessage
# Hand-built AI message carrying one argument-less call to `get_playlists`,
# used to exercise the tool node without involving the model.
get_playlists_call = {
    "name": "get_playlists",
    "args": {},
    "id": "tool_call_id",
    "type": "tool_call",
}
message_with_single_tool_call = AIMessage(content="", tool_calls=[get_playlists_call])

# Uncomment to run the call through the tool node:
# tool_node.invoke({"messages": [message_with_single_tool_call]})

# Bind Tools to Model

In [9]:
from langchain_openai import ChatOpenAI

# Read the model name from the environment and fail fast with an actionable message
# instead of passing `None` on to ChatOpenAI.
model_name = os.getenv("OPENAI_MODEL_NAME")
if not model_name:
    raise ValueError(
        "OPENAI_MODEL_NAME is not set; define it in the environment or in the .env file."
    )

llm = ChatOpenAI(model=model_name, temperature=0.8)
# Expose the tool list to the model so it can emit tool calls.
llm_with_tools = llm.bind_tools(tools)

# First Message

In [10]:
from langchain_core.messages import HumanMessage
from prompts import Prompts

# Opening user turn: the Spotify prompt, sent to the tool-enabled model.
human_message = HumanMessage(content=Prompts.SPOTIFY)
ai_tool_call_message = llm_with_tools.invoke(
    [human_message],
)

2024-11-12 20:04:03,690 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


In [None]:
# Show the model's first reply in readable form.
ai_tool_call_message.pretty_print()

In [None]:
import json
# Take the raw JSON argument string of the first tool call from the message's additional_kwargs.
# Guard against a reply without tool calls so the failure is explicit rather than a bare
# KeyError/IndexError.
raw_tool_calls = ai_tool_call_message.additional_kwargs.get("tool_calls") or []
if not raw_tool_calls:
    raise ValueError("The model reply contains no tool calls, so there is no plan to decode.")
plan = raw_tool_calls[0]["function"]["arguments"]
# Decode the arguments and display them as a Python object.
json.loads(plan)

# chat_prompt_template holds all messages (Human, AI, Tool)

In [None]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Start the running conversation from the user prompt and the model's first reply.
chat_prompt_template: ChatPromptTemplate = human_message + ai_tool_call_message
# Materialise the conversation as a list of messages and display it.
messages = chat_prompt_template.format_messages()
messages

# Check for Tools in Message

In [None]:
from langchain_core.messages import ToolMessage

# Resolve the tools requested by the latest message and require exactly one of them.
tools_by_name = {tool.name: tool for tool in tools}
messages = chat_prompt_template.format_messages()
requested_tool_names = [tool_call["name"] for tool_call in messages[-1].tool_calls]

# Report unregistered tool names explicitly instead of failing on a bare dictionary lookup.
unknown_tool_names = [
    tool_name for tool_name in requested_tool_names if tool_name not in tools_by_name
]
if unknown_tool_names:
    raise KeyError(f"The AI message requested tools that are not registered: {unknown_tool_names}")

tools_in_ai_message = [tools_by_name[tool_name] for tool_name in requested_tool_names]
if len(tools_in_ai_message) != 1:
    raise ValueError(
        f"Expected exactly one tool call in the latest AI message, "
        f"got {len(tools_in_ai_message)}: {requested_tool_names}"
    )
print(tools_in_ai_message)

# Tool Call (should be validate_plan())

In [None]:
# Hand the conversation to the tool node and keep the first message it returns.
response = tool_node.invoke({"messages": messages})
tool_message = response["messages"][0]
tool_message.pretty_print()

In [None]:
# Append the tool result to the running conversation, then display the full message list.
# NOTE: `+=` extends the template each time, so re-running this cell appends the message again.
chat_prompt_template += tool_message
chat_prompt_template.format_messages()

# Send tool result to LLM

In [None]:
# Ask the model for its next step given the conversation so far, and show the reply.
ai_tool_call_message = llm_with_tools.invoke(chat_prompt_template.format_messages())
ai_tool_call_message.pretty_print()

# Tool Call (should be get_playlists())

In [None]:
# Record the model's reply, then run the tool node on the updated conversation.
# NOTE: re-running this cell appends the reply a second time.
chat_prompt_template += ai_tool_call_message
messages = chat_prompt_template.format_messages()
response = tool_node.invoke({"messages": messages})

In [None]:
# First message returned by the tool node (expected to be the get_playlists result, per the section heading).
playlist_tool_message: ToolMessage = response["messages"][0]
playlist_tool_message.pretty_print()

In [None]:
# We need to check that the content is valid JSON (i.e. that json.loads can decode it)
# json.loads(playlist_tool_message.content)

# Format Message List for LLM

In [None]:
# Append the playlist tool result and refresh the message list passed to the model.
# NOTE: re-running this cell appends the tool message a second time.
chat_prompt_template += playlist_tool_message
messages = chat_prompt_template.format_messages()

# Send Playlists Results to LLM

In [None]:
# Send the conversation, now including the playlist result, to the model and show its reply.
ai_tool_call_message = llm_with_tools.invoke(messages)
ai_tool_call_message.pretty_print()

In [None]:
# Record the model's reply and refresh the message list.
# NOTE: re-running this cell appends the reply a second time.
chat_prompt_template += ai_tool_call_message
messages = chat_prompt_template.format_messages()

# Tool Call (get_artists_from_playlist())

In [None]:
# Run the tool node on the conversation and keep the first message it returns.
response = tool_node.invoke({"messages": messages})
track_list_tool_message: ToolMessage = response["messages"][0]

# Decode AI Tool Call message

In [None]:
# Show the tool result in readable form.
track_list_tool_message.pretty_print()

In [None]:
# json.loads(track_list_tool_message.content)

# Add Tool message containing playlist tracks to list of messages

In [None]:
# Append the tool result to the conversation, refresh the message list and display it.
# NOTE: re-running this cell appends the tool message a second time.
chat_prompt_template += track_list_tool_message
messages = chat_prompt_template.format_messages()
messages

# Send Tool Result to LLM

In [None]:
# Disabled alternative: request structured output instead of tool calls.
# (`Tracks` is not defined in the visible cells — TODO confirm before re-enabling.)
# llm_with_structured_output = llm.with_structured_output(schema=Tracks, method='function_calling', include_raw=True)
# response = llm_with_structured_output.invoke(chat_prompt_template.format_messages())
# response
# Send the conversation, including the latest tool result, back to the tool-enabled model.
response = llm_with_tools.invoke(messages)

In [None]:
# Show the model's reply in readable form.
response.pretty_print()

In [None]:
# Record the model's reply and refresh the message list.
# NOTE: re-running this cell appends the reply a second time.
chat_prompt_template += response
messages = chat_prompt_template.format_messages()

In [None]:
from langchain_core.messages import ToolMessage

# Resolve the tools requested by the latest message and require at least one of them.
tools_by_name = {tool.name: tool for tool in tools}
requested_tool_names = [tool_call["name"] for tool_call in messages[-1].tool_calls]

# Report unregistered tool names explicitly instead of failing on a bare dictionary lookup.
unknown_tool_names = [
    tool_name for tool_name in requested_tool_names if tool_name not in tools_by_name
]
if unknown_tool_names:
    raise KeyError(f"The AI message requested tools that are not registered: {unknown_tool_names}")

tools_in_ai_message = [tools_by_name[tool_name] for tool_name in requested_tool_names]
print(tools_in_ai_message)
if not tools_in_ai_message:
    raise ValueError(
        "Expected the latest AI message to request at least one tool, but it contains no tool calls."
    )

# Tool Call (find_similar_artist())

In [None]:
# Run the tool node on the current conversation.
response = tool_node.invoke({"messages": messages})

In [None]:
# Keep the first message returned by the tool node and show it.
check_artist_tool_message: ToolMessage = response["messages"][0]
check_artist_tool_message.pretty_print()

In [None]:
# Rebuild the ToolMessage from the same content, name and tool_call_id.
# Per the original note, this is meant to take care of Unicode characters — TODO confirm it is still needed.
tool_message = ToolMessage(content=check_artist_tool_message.content, name=check_artist_tool_message.name, tool_call_id=check_artist_tool_message.tool_call_id)
tool_message.pretty_print()

In [None]:
# Append the rebuilt tool message to the conversation and display the full message list.
# NOTE: re-running this cell appends the message a second time.
chat_prompt_template += tool_message
chat_prompt_template.format_messages()

In [None]:
# Send the full conversation back to the tool-enabled model.
response = llm_with_tools.invoke(chat_prompt_template.format_messages())

In [None]:
# Show the model's reply in readable form.
response.pretty_print()

In [None]:
# Record the model's reply and display the full message list.
# NOTE: re-running this cell appends the reply a second time.
chat_prompt_template += response
chat_prompt_template.format_messages()

In [None]:
# Run the tool node on the conversation, whose last entry is the reply appended in the previous cell.
# The messages are read from chat_prompt_template, not from `response`: this cell rebinds
# `response` to the tool output, so feeding `[response]` back in would fail on a re-run.
response = tool_node.invoke({"messages": chat_prompt_template.format_messages()})

In [None]:
# Display the raw value returned by the tool node.
response