# In [1]:
from langchain_groq import ChatGroq
import os
import yfinance as yf
import pandas as pd
from langchain_core.tools import tool
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, ToolMessage
from datetime import date
import plotly.graph_objects as go
from typing import List, Dict
import logging

# Module-wide logger for this notebook; the root logger is configured to
# emit INFO-level records and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

@tool
def get_stock_info(symbol: str, key: str) -> str:
    '''Return the correct stock info value given the appropriate symbol and key.'''
    # The docstring above is used by @tool as the description handed to the
    # LLM, so its wording is runtime behavior; change it deliberately.
    # Returns the `key` entry of yfinance's `Ticker.info` as a string, a
    # "Key ... not found" message when absent, or "Error: ..." on failure.
    try:
        info = yf.Ticker(symbol).info
        value = info.get(key, f"Key '{key}' not found")
        return str(value)
    except Exception as exc:
        logger.error(f"Error fetching stock info for {symbol}: {exc}")
        return f"Error: {str(exc)}"

@tool
def get_historical_price(symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
    """Fetches historical stock prices for a given symbol from 'start_date' to 'end_date'."""
    # The docstring above is used by @tool as the LLM-facing tool description,
    # so it is intentionally left unchanged.
    #
    # Returns a DataFrame with a 'Date' column and one column named after
    # `symbol` holding the 'Close' price from yfinance's history. On failure
    # the error is logged and an EMPTY frame with those same two columns is
    # returned, so callers can always read the symbol name from columns[1].
    try:
        hist = yf.Ticker(symbol).history(start=start_date, end=end_date).reset_index()
        hist[symbol] = hist['Close']
        return hist[['Date', symbol]]
    except Exception as e:
        logger.error(f"Error fetching historical price for {symbol}: {e}")
        # A column-less pd.DataFrame() here used to make callers that index
        # columns[1] raise IndexError; keep the documented shape instead.
        return pd.DataFrame(columns=['Date', symbol])

def _combine_price_frames(historical_price_dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """Align per-symbol price frames on their 'Date' column.

    Each input is expected to hold a 'Date' column plus one price column named
    after its symbol. Frames that are empty or have no 'Date' column are
    skipped. The result has 'Date' as its first column followed by one column
    per distinct symbol, with rows sorted by date. A symbol that has no value
    on a given date gets NaN there. Returns an empty DataFrame if nothing
    usable was supplied.
    """
    frames = [
        df.set_index('Date')
        for df in historical_price_dfs
        if not df.empty and 'Date' in df.columns
    ]
    if not frames:
        return pd.DataFrame()
    # Concatenating on columns with 'Date' as the index aligns rows by date
    # (outer join). Concatenating the raw frames instead paired rows by
    # position and produced duplicate 'Date' columns.
    combined = pd.concat(frames, axis=1).sort_index()
    # Requesting the same symbol twice yields identical columns; keep one.
    combined = combined.loc[:, ~combined.columns.duplicated()]
    return combined.rename_axis('Date').reset_index()


def plot_price_over_time(historical_price_dfs: List[pd.DataFrame]) -> None:
    """Plot closing prices for one or more symbols on a single line chart.

    Args:
        historical_price_dfs: frames shaped like get_historical_price output
            ('Date' column plus one column per symbol).

    The chart is written to "plot.png" (relative to the working directory) and
    a Markdown image reference to it is printed. Any failure is logged with
    its traceback rather than raised.
    """
    try:
        full_df = _combine_price_frames(historical_price_dfs)
        if full_df.empty:
            logger.warning("No historical price data to plot.")
            return

        symbols = full_df.columns.tolist()[1:]
        fig = go.Figure()
        for column in symbols:
            fig.add_trace(go.Scatter(x=full_df['Date'], y=full_df[column], mode='lines+markers', name=column))

        fig.update_layout(
            title='Stock Price Over Time: ' + ', '.join(symbols),
            xaxis_title='Date',
            yaxis_title='Stock Price (USD)',
            yaxis_tickprefix='$',
            yaxis_tickformat=',.2f',
            xaxis=dict(tickangle=-45, nticks=20, tickfont=dict(size=10)),
            yaxis=dict(showgrid=True, gridcolor='lightgrey'),
            legend_title_text='Stock Symbol',
            plot_bgcolor='white',
            paper_bgcolor='white',
            legend=dict(bgcolor='white', bordercolor='black')
        )

        fig.write_image("plot.png")
        print("![Plot](plot.png)")
    except Exception as e:
        logger.exception(f"Error plotting price over time: {e}")

def _ensure_recommendation(response: str, user_prompt: str) -> str:
    """Append a fallback Buy / Do Not Buy verdict when a buying question got none.

    The fallback only applies when the prompt looks like a buying/investing
    question, so factual lookups (e.g. a market-cap query) come back unchanged.
    The verdict is a keyword heuristic on the model's reply: "positive" or
    "growth" anywhere in it yields "Buy", otherwise "Do Not Buy".
    """
    prompt_lower = user_prompt.lower()
    if not any(word in prompt_lower for word in ("buy", "invest", "purchase")):
        return response
    response_lower = response.lower()
    # "do not buy" contains "buy", so this single check covers both verdicts.
    if "buy" in response_lower:
        return response
    verdict = "Buy" if "positive" in response_lower or "growth" in response_lower else "Do Not Buy"
    return response + "\n\nBased on the analysis, my definitive recommendation is: " + verdict


def call_functions(llm_with_tools, user_prompt: str) -> str:
    """Answer one user prompt, executing any tool calls the model requests.

    Flow:
    1. Send the system prompt and the user prompt to the model.
    2. If it asked for tools, run each one and answer every tool call with a
       ToolMessage carrying that call's id. Historical-price results are
       plotted rather than sent back verbatim.
    3. Ask the model for its final reply.

    Args:
        llm_with_tools: a chat model with the tools bound (see main()).
        user_prompt: the user's question.

    Returns:
        The model's final text, with a fallback Buy / Do Not Buy verdict
        appended when a buying question received none.
    """
    system_prompt = f'''You are a decisive finance assistant that analyzes stocks and stock prices. Today is {date.today()}.
    When asked about buying or not buying a stock, you MUST provide a clear "Buy" or "Do Not Buy" recommendation.
    Avoid diplomatic or uncertain answers. Base your decision on the available data and current market trends.
    Explain your reasoning briefly, but always conclude with a definitive recommendation.'''
    tools_by_name = {"get_stock_info": get_stock_info, "get_historical_price": get_historical_price}

    messages = [SystemMessage(system_prompt), HumanMessage(user_prompt)]
    ai_msg = llm_with_tools.invoke(messages)
    messages.append(ai_msg)

    if not ai_msg.tool_calls:
        # The model already answered. Re-invoking with its own reply as the
        # last turn would only ask it to continue that reply.
        return _ensure_recommendation(ai_msg.content, user_prompt)

    chart_call_ids = []
    chart_dfs = []
    chart_symbols = []

    for tool_call in ai_msg.tool_calls:
        # Normalize once so dispatch and the branch below agree.
        name = tool_call["name"].lower()
        selected_tool = tools_by_name.get(name)
        if selected_tool is None:
            messages.append(ToolMessage(f"Error: unknown tool '{tool_call['name']}'", tool_call_id=tool_call["id"]))
            continue
        tool_output = selected_tool.invoke(tool_call["args"])
        if name == "get_historical_price":
            chart_call_ids.append(tool_call["id"])
            # Failed or empty lookups cannot be plotted and may lack a symbol
            # column, so only keep frames that actually carry data.
            if isinstance(tool_output, pd.DataFrame) and not tool_output.empty and len(tool_output.columns) > 1:
                chart_dfs.append(tool_output)
                chart_symbols.append(str(tool_output.columns[1]))
        else:
            messages.append(ToolMessage(tool_output, tool_call_id=tool_call["id"]))

    if chart_call_ids:
        if chart_dfs:
            plot_price_over_time(chart_dfs)
            symbols = ' and '.join(chart_symbols)
            note = f'Tell the user that a historical stock price chart for {symbols} has been generated.'
        else:
            note = "No historical price data could be retrieved for the requested symbol(s)."
        # Every tool call in ai_msg must be answered by a ToolMessage with its
        # own id. A single message with tool_call_id=0 matched none of them.
        for call_id in chart_call_ids:
            messages.append(ToolMessage(note, tool_call_id=call_id))

    response = llm_with_tools.invoke(messages).content
    return _ensure_recommendation(response, user_prompt)

def main():
    """Run an interactive prompt/answer loop with the finance assistant.

    Reads prompts from stdin until EOF or Ctrl-C and skips blank lines.
    Errors raised while answering a prompt are logged with a traceback, and
    the loop continues.

    Environment:
        GROQ_API_KEY: API key passed to ChatGroq.
        GROQ_MODEL: optional model id; defaults to 'llama-3.1-70b-versatile'.
    """
    # NOTE(review): Groq retires model ids over time; set GROQ_MODEL if the
    # default is rejected by the API.
    llm = ChatGroq(
        groq_api_key=os.getenv('GROQ_API_KEY'),
        model=os.getenv('GROQ_MODEL', 'llama-3.1-70b-versatile'),
    )
    tools = [get_stock_info, get_historical_price]
    llm_with_tools = llm.bind_tools(tools)

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # EOFError subclasses Exception. Under the old catch-all, a closed
            # stdin produced an endless "An error occurred" loop.
            print()
            break
        if not user_input.strip():
            continue
        try:
            response = call_functions(llm_with_tools, user_input)
            print("Assistant:", response)
        except Exception as e:
            logger.exception(f"An error occurred: {e}")
            print("An error occurred. Please try again.")

if __name__ == "__main__":
    main()


# In [None]: