# Stock Analysis and Forecast AI Agent System

# ⚠️ IMPORTANT DISCLAIMER ⚠️

This project is a DEMONSTRATION of how to apply AI and traditional machine learning techniques to stock analysis. It is NOT intended for actual investment decisions. Please note:

- This is an educational project demonstrating AI/ML applications
- It does NOT provide investment advice or recommendations
- The forecasts and analyses are for demonstration purposes only
- Do NOT use this for actual investment decisions
- The developers take NO responsibility for any investment decisions made using this tool
- This project is designed to help beginners understand how AI can be applied to traditional ML problems

# Introduction

This notebook implements a comprehensive stock analysis system with the following components:

1. [Stock Data Management](#stock-data)
2. [Holdout Validation Model](#holdout)
3. [Hyperparameter Optimization](#hyperopt)
4. [AI-Powered Analysis Agent](#agent)

Each component is designed to work together to provide a complete stock analysis solution.

## Reference
https://github.com/brightlee6/Stock-Analysis-Forecasting-Agent


# Setup

## Uninstall and install packages

In [1]:
# Remove conflicting packages from the Kaggle base environment.
!pip uninstall -qqy kfp jupyterlab libpysal thinc spacy fastai ydata-profiling google-cloud-bigquery google-generativeai yfinance
# Install langgraph and the packages used in this lab.
!pip install -qU 'langgraph==0.3.21' 'langchain-google-genai==2.1.2' 'langgraph-prebuilt==0.1.7' 'yfinance==0.2.55'

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.5/43.5 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.0/138.0 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.8/109.8 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m37.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m433.9/433.9 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.2/47.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# check yfinance installed with right version
# !pip list | grep yfinance

## Set up your API key

The `GOOGLE_API_KEY` environment variable can be set to automatically configure the underlying API. This works for both the official Gemini Python SDK and for LangChain/LangGraph. 

To run the following cell, your API key must be stored it in a [Kaggle secret](https://www.kaggle.com/discussions/product-feedback/114053) named `GOOGLE_API_KEY`.

If you don't already have an API key, you can grab one from [AI Studio](https://aistudio.google.com/app/apikey). You can find [detailed instructions in the docs](https://ai.google.dev/gemini-api/docs/api-key).

To make the key available through Kaggle secrets, choose `Secrets` from the `Add-ons` menu and follow the instructions to add your key or enable it for this notebook.

In [3]:
import os
from kaggle_secrets import UserSecretsClient

GOOGLE_API_KEY = UserSecretsClient().get_secret("GOOGLE_API_KEY")
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

## 1. Stock Data Management <a name="stock-data"></a>

The `StockData` class handles the fetching and management of stock price data using the yfinance library.

### Features:
- Fetch historical stock closing prices
- Save data to CSV files
- Visualize price trends and daily returns
- Handle data validation and error cases

### Dependencies:
- yfinance: For fetching stock data
- pandas: For data manipulation
- matplotlib & seaborn: For visualization

In [4]:
import yfinance as yf
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

class StockData:
    def __init__(self, ticker, start_date, end_date):
        """
        Initialize a StockData object with ticker and date range.
        
        Args:
            ticker (str): Stock ticker symbol (e.g., 'AAPL' for Apple)
            start_date (str): Start date in 'YYYY-MM-DD' format
            end_date (str): End date in 'YYYY-MM-DD' format
        """
        self.ticker = ticker
        self.start_date = start_date
        self.end_date = end_date
        self.dataframe = None
        
    def fetch_closing_prices(self):
        """
        Fetch closing prices for the stock and store in dataframe.
        """
        try:
            # Convert string dates to datetime objects
            start = datetime.strptime(self.start_date, '%Y-%m-%d')
            end = datetime.strptime(self.end_date, '%Y-%m-%d')
            
            # Fetch stock data using yf.download
            data = yf.download(self.ticker, start=start, end=end)
            
            # Extract closing prices and reset index to make Date a column
            self.dataframe = data[['Close']].reset_index()
            
            # Rename columns for clarity
            self.dataframe.columns = ['Date', 'Close']
            
            print(f"Successfully fetched closing prices for {self.ticker}")
            
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            
    def save_to_csv(self, output_file):
        """
        Save the closing prices to a CSV file.
        
        Args:
            output_file (str): Path to save the CSV file
        """
        if self.dataframe is not None:
            self.dataframe.to_csv(output_file, index=False)
            print(f"Successfully saved closing prices to {output_file}")
        else:
            print("No data available. Please call fetch_closing_prices() first.")
            
    def visualize_data(self, save_path=None):
        """
        Create visualizations for the stock data.
        
        Args:
            save_path (str, optional): Path to save the visualization. If None, the plot will be displayed.
        """
        if self.dataframe is None:
            print("No data available. Please call fetch_closing_prices() first.")
            return
            
        try:
            # Set the style
            plt.style.use('seaborn-v0_8')  # Use a valid style name
            
            # Create a figure with subplots
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
            
            # Plot 1: Closing Price Over Time
            sns.lineplot(data=self.dataframe, x='Date', y='Close', ax=ax1)
            ax1.set_title(f'{self.ticker} Closing Price Over Time')
            ax1.set_xlabel('Date')
            ax1.set_ylabel('Price ($)')
            ax1.grid(True)
            
            # Calculate daily returns
            self.dataframe['Daily_Return'] = self.dataframe['Close'].pct_change()
            
            # Plot 2: Daily Returns Distribution
            sns.histplot(data=self.dataframe, x='Daily_Return', bins=50, ax=ax2)
            ax2.set_title(f'{self.ticker} Daily Returns Distribution')
            ax2.set_xlabel('Daily Return')
            ax2.set_ylabel('Frequency')
            ax2.grid(True)
            
            # Add some statistics
            mean_return = self.dataframe['Daily_Return'].mean()
            std_return = self.dataframe['Daily_Return'].std()
            ax2.axvline(mean_return, color='r', linestyle='--', label=f'Mean: {mean_return:.4f}')
            ax2.axvline(mean_return + std_return, color='g', linestyle='--', label=f'Std Dev: {std_return:.4f}')
            ax2.axvline(mean_return - std_return, color='g', linestyle='--')
            ax2.legend()
            
            plt.tight_layout()
            
            if save_path:
                plt.savefig(save_path)
                print(f"Visualization saved to {save_path}")
            else:
                plt.show()
                
            plt.close()
            
        except Exception as e:
            print(f"Error creating visualization: {str(e)}")

In [5]:
# # Example usage and testing for the module
# stock_data_test = StockData(
#         ticker='GOOG',
#         start_date='2024-01-01',
#         end_date='2025-04-19'
#     )
    
# # Fetch the closing prices
# stock_data_test.fetch_closing_prices()

# # Display
# print(stock_data_test.dataframe.head())

# # Save to CSV
# stock_data_test.save_to_csv("google_stock_prices.csv")
    
# # Create visualizations
# stock_data_test.visualize_data() 
# stock_data_test.visualize_data("google_stock_analysis.png") 

## 2. Holdout Validation Model <a name="holdout"></a>

The `StockModelHoldout` class implements a holdout validation approach for stock price forecasting using Prophet.

### Features:
- Split data into training (80%) and testing (20%) sets
- Train Prophet model on historical data
- Make predictions on test data
- Calculate performance metrics (MAE, MSE, RMSE, R²)
- Visualize actual vs predicted values

### Dependencies:
- prophet: For time series forecasting
- scikit-learn: For performance metrics

In [6]:
import pandas as pd
import numpy as np
from prophet import Prophet
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns

class StockModelHoldout:
    def __init__(self, stock_data):
        """
        Initialize StockModelHoldout with StockData object.
        
        Args:
            stock_data (StockData): StockData object containing the stock data
        """
        if not isinstance(stock_data, StockData):
            raise ValueError("Input must be a StockData object")
            
        self.stock_data = stock_data
        self.train_data = None
        self.test_data = None
        self.model = None
        self.forecast = None
        self.metrics = None
        
    def split_data(self, test_size=0.2):
        """
        Split the data into training and testing sets.
        
        Args:
            test_size (float): Proportion of data to use for testing (default: 0.2)
        """
        if self.stock_data.dataframe is None:
            raise ValueError("No data available. Please fetch data first.")
            
        # Sort data by date
        df = self.stock_data.dataframe.sort_values('Date')
        
        # Calculate split index
        split_idx = int(len(df) * (1 - test_size))
        
        # Split the data
        self.train_data = df.iloc[:split_idx].copy()
        self.test_data = df.iloc[split_idx:].copy()
        
        # Prepare data for Prophet
        self.train_data = self.train_data.rename(columns={'Date': 'ds', 'Close': 'y'})
        self.test_data = self.test_data.rename(columns={'Date': 'ds', 'Close': 'y'})
        
    def train_model(self):
        """
        Train the Prophet model on the training data.
        """
        if self.train_data is None:
            raise ValueError("No training data available. Please split data first.")
            
        # Initialize and fit the model
        self.model = Prophet()
        self.model.fit(self.train_data)
        
    def make_forecast(self):
        """
        Make forecasts on the test data.
        """
        if self.model is None:
            raise ValueError("No trained model available. Please train the model first.")
            
        # Create future dataframe for test dates
        future = self.test_data[['ds']]
        
        # Make predictions
        self.forecast = self.model.predict(future)
        
    def calculate_metrics(self):
        """
        Calculate performance metrics for the forecast.
        """
        if self.forecast is None:
            raise ValueError("No forecast available. Please make forecast first.")
            
        # Extract actual and predicted values
        y_true = self.test_data['y'].values
        y_pred = self.forecast['yhat'].values
        
        # Calculate metrics
        self.metrics = {
            'MAE': mean_absolute_error(y_true, y_pred),
            'MSE': mean_squared_error(y_true, y_pred),
            'RMSE': np.sqrt(mean_squared_error(y_true, y_pred)),
            'R2': r2_score(y_true, y_pred)
        }
        
    def visualize_forecast(self, save_path=None):
        """
        Visualize the actual vs predicted values over the test period.
        
        Args:
            save_path (str, optional): Path to save the visualization. If None, the plot will be displayed.
        """
        if self.forecast is None:
            raise ValueError("No forecast available. Please make forecast first.")
            
        try:
            # Set the style
            plt.style.use('seaborn-v0_8')
            
            # Create the plot
            plt.figure(figsize=(12, 6))
            
            # Plot actual values
            plt.plot(self.test_data['ds'], self.test_data['y'], 
                    label='Actual', color='blue', linewidth=2)
            
            # Plot predicted values
            plt.plot(self.test_data['ds'], self.forecast['yhat'], 
                    label='Predicted', color='red', linestyle='--', linewidth=2)
            
            # Add confidence intervals
            plt.fill_between(self.test_data['ds'], 
                           self.forecast['yhat_lower'], 
                           self.forecast['yhat_upper'],
                           color='gray', alpha=0.2, label='Confidence Interval')
            
            # Customize the plot
            plt.title(f'{self.stock_data.ticker} Stock Price: Actual vs Predicted')
            plt.xlabel('Date')
            plt.ylabel('Price ($)')
            plt.legend()
            plt.grid(True)
            
            # Rotate x-axis labels for better readability
            plt.xticks(rotation=45)
            
            # Adjust layout
            plt.tight_layout()
            
            if save_path:
                plt.savefig(save_path)
                print(f"Visualization saved to {save_path}")
            else:
                plt.show()
                
            plt.close()
            
        except Exception as e:
            print(f"Error creating visualization: {str(e)}")
        
    def run_analysis(self, test_size=0.2):
        """
        Run the complete analysis pipeline.
        
        Args:
            test_size (float): Proportion of data to use for testing (default: 0.2)
        """
        self.split_data(test_size)
        self.train_model()
        self.make_forecast()
        self.calculate_metrics()
        
        return self.metrics

In [7]:
# # Example usage and testing for StockModelHoldout module
# # Create StockData object
# stock_data = StockData(
#     ticker="GOOG",
#     start_date="2020-01-01",
#     end_date="2025-04-19"
# )

# # Fetch the data
# stock_data.fetch_closing_prices()
# print(stock_data.dataframe.head())

# # Create and run the model
# model = StockModelHoldout(stock_data)
# metrics = model.run_analysis()

# # Print the results
# print("\nModel Performance Metrics:")
# for metric, value in metrics.items():
#     print(f"{metric}: {value:.4f}")
    
# # Create visualization
# model.visualize_forecast() 
# model.visualize_forecast("stock_forecast.png") 

## 3. Hyperparameter Optimization <a name="hyperopt"></a>

The `StockHyperopt` class implements hyperparameter optimization for Prophet models using hyperopt.

### Features:
- Define hyperparameter search space
- Optimize Prophet model parameters
- Train final model with best parameters
- Make future predictions
- Visualize forecast results

### Dependencies:
- hyperopt: For Bayesian optimization
- prophet: For time series forecasting

In [8]:
import pandas as pd
import numpy as np
from prophet import Prophet
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import seaborn as sns

class StockHyperopt:
    def __init__(self, stock_data):
        """
        Initialize StockHyperopt with StockData object.
        
        Args:
            stock_data (StockData): StockData object containing the stock data
        """
        if not isinstance(stock_data, StockData):
            raise ValueError("Input must be a StockData object")
            
        self.stock_data = stock_data
        self.model = None
        self.best_params = None
        self.forecast = None
        
    def prepare_data(self):
        """
        Prepare data for Prophet model.
        """
        if self.stock_data.dataframe is None:
            raise ValueError("No data available. Please fetch data first.")
            
        # Sort data by date
        self.df = self.stock_data.dataframe.sort_values('Date')
        
        # Prepare data for Prophet
        self.df = self.df.rename(columns={'Date': 'ds', 'Close': 'y'})
        
    def objective(self, params):
        """
        Objective function for hyperparameter optimization.
        
        Args:
            params (dict): Hyperparameters to evaluate
            
        Returns:
            dict: Dictionary containing loss and status
        """
        # Create Prophet model with current parameters
        model = Prophet(
            changepoint_prior_scale=params['changepoint_prior_scale'],
            seasonality_prior_scale=params['seasonality_prior_scale'],
            holidays_prior_scale=params['holidays_prior_scale'],
            seasonality_mode=params['seasonality_mode']
        )
        
        # Fit the model
        model.fit(self.df)
        
        # Make predictions
        future = model.make_future_dataframe(periods=30)
        forecast = model.predict(future)
        
        # Calculate RMSE on the last 30 days
        y_true = self.df['y'].values[-30:]
        y_pred = forecast['yhat'].values[-30:]
        rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
        
        return {'loss': rmse, 'status': STATUS_OK}
        
    def optimize_hyperparameters(self, max_evals=50):
        """
        Optimize hyperparameters using hyperopt.
        
        Args:
            max_evals (int): Maximum number of evaluations (default: 50)
        """
        if self.df is None:
            raise ValueError("No data available. Please prepare data first.")
            
        # Define the search space
        space = {
            'changepoint_prior_scale': hp.loguniform('changepoint_prior_scale', -5, 0),
            'seasonality_prior_scale': hp.loguniform('seasonality_prior_scale', -5, 0),
            'holidays_prior_scale': hp.loguniform('holidays_prior_scale', -5, 0),
            'seasonality_mode': hp.choice('seasonality_mode', ['additive', 'multiplicative'])
        }
        
        # Run optimization
        trials = Trials()
        best = fmin(
            fn=self.objective,
            space=space,
            algo=tpe.suggest,
            max_evals=max_evals,
            trials=trials
        )
        
        # Get the best parameters
        self.best_params = {
            'changepoint_prior_scale': best['changepoint_prior_scale'],
            'seasonality_prior_scale': best['seasonality_prior_scale'],
            'holidays_prior_scale': best['holidays_prior_scale'],
            'seasonality_mode': ['additive', 'multiplicative'][best['seasonality_mode']]
        }
        
    def train_best_model(self):
        """
        Train the Prophet model with the best hyperparameters.
        """
        if self.best_params is None:
            raise ValueError("No optimized parameters available. Please run optimize_hyperparameters first.")
            
        # Create Prophet model with best parameters
        self.model = Prophet(
            changepoint_prior_scale=self.best_params['changepoint_prior_scale'],
            seasonality_prior_scale=self.best_params['seasonality_prior_scale'],
            holidays_prior_scale=self.best_params['holidays_prior_scale'],
            seasonality_mode=self.best_params['seasonality_mode']
        )
        
        # Fit the model
        self.model.fit(self.df)
        
    def forecast_next_year(self):
        """
        Forecast stock prices for the next year.
        """
        if self.model is None:
            raise ValueError("No trained model available. Please train the model first.")
            
        # Create future dataframe for next year
        future = self.model.make_future_dataframe(periods=365)
        
        # Make predictions
        self.forecast = self.model.predict(future)
        
    def visualize_forecast(self, save_path=None):
        """
        Visualize the forecast for the next year.
        
        Args:
            save_path (str, optional): Path to save the visualization. If None, the plot will be displayed.
        """
        if self.forecast is None:
            raise ValueError("No forecast available. Please run forecast_next_year first.")
            
        try:
            # Set the style
            plt.style.use('seaborn-v0_8')
            
            # Create the plot
            plt.figure(figsize=(12, 6))
            
            # Plot historical data
            plt.plot(self.df['ds'], self.df['y'], 
                    label='Historical', color='blue', linewidth=2)
            
            # Plot forecast
            plt.plot(self.forecast['ds'], self.forecast['yhat'], 
                    label='Forecast', color='red', linestyle='--', linewidth=2)
            
            # Add confidence intervals
            plt.fill_between(self.forecast['ds'], 
                           self.forecast['yhat_lower'], 
                           self.forecast['yhat_upper'],
                           color='gray', alpha=0.2, label='Confidence Interval')
            
            # Customize the plot
            plt.title(f'{self.stock_data.ticker} Stock Price Forecast for Next Year')
            plt.xlabel('Date')
            plt.ylabel('Price ($)')
            plt.legend()
            plt.grid(True)
            
            # Rotate x-axis labels for better readability
            plt.xticks(rotation=45)
            
            # Adjust layout
            plt.tight_layout()
            
            if save_path:
                plt.savefig(save_path)
                print(f"Visualization saved to {save_path}")
            else:
                plt.show()
                
            plt.close()
            
        except Exception as e:
            print(f"Error creating visualization: {str(e)}")
            
    def run_analysis(self, max_evals=50):
        """
        Run the complete analysis pipeline.
        
        Args:
            max_evals (int): Maximum number of evaluations for hyperparameter optimization
        """
        self.prepare_data()
        self.optimize_hyperparameters(max_evals)
        self.train_best_model()
        self.forecast_next_year()
        
        return self.best_params

In [9]:
# # Example run and testing StockHyperopt module
# # Create StockData object
# stock_data = StockData(
#     ticker="GOOG",
#     start_date="2020-01-01",
#     end_date="2025-04-19"
# )

# # Fetch the data
# stock_data.fetch_closing_prices()

# # Create and run the model
# model = StockHyperopt(stock_data)
# best_params = model.run_analysis()

# # Print the best parameters
# print("\nBest Hyperparameters:")
# for param, value in best_params.items():
#     print(f"{param}: {value}")
    
# # Create visualization
# model.visualize_forecast()
# model.visualize_forecast("stock_forecast_next_year.png")

## 4. AI-Powered Analysis Agent <a name="agent"></a>

The `StockAgent` class implements an AI-powered agent that uses the Gemini API to assist with stock analysis tasks.

### Features:
- Natural language understanding using Gemini API
- Automated stock data fetching and visualization
- Holdout validation analysis
- Optimized forecasting
- Interactive assistance

### Dependencies:
- langchain_google_genai: For Gemini API integration
- langchain, langgraph
- All previous components (StockData, StockModelHoldout, StockHyperopt)

In [10]:
import os
from typing import Dict, List, TypedDict, Annotated, Sequence
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
import matplotlib
matplotlib.use('Agg')  # Use Agg backend to avoid GUI issues
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import re
from IPython.display import Image, display

# Load environment variables
# load_dotenv()

# Define the state type
class AgentState(TypedDict):
    messages: Annotated[Sequence[HumanMessage | AIMessage], "The conversation history"]
    stock_data: StockData | None
    holdout_model: StockModelHoldout | None
    hyperopt_model: StockHyperopt | None
    last_action: str | None

# Initialize the LLM
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    google_api_key=os.getenv("GOOGLE_API_KEY"),
    temperature=0.7
)

# Define the system prompt
system_prompt = """You are a helpful stock market analysis assistant. Your role is to:
1. Understand user requests about stock data
2. Extract stock tickers and date ranges from user input
3. Perform appropriate stock analysis (historical data, forecasting, or hyperparameter tuning)
4. Provide clear and informative responses

When analyzing stocks, you can:
- Show historical price data
- Create forecasts using Prophet
- Tune hyperparameters for better predictions
- Visualize results

Always be clear about what you're doing and explain the results in a way that's easy to understand."""

# Create the prompt template
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="messages"),
])

def parse_relative_date(date_str: str) -> str:
    """
    Convert relative date expressions to YYYY-MM-DD format.
    
    Args:
        date_str (str): Date string that might contain relative expressions
        
    Returns:
        str: Date in YYYY-MM-DD format
    """
    today = datetime.now()
    
    # Handle "today"
    if date_str.lower() == "today":
        return today.strftime('%Y-%m-%d')
    
    # Handle "X years ago"
    match = re.match(r'(\d+)\s+years?\s+ago', date_str.lower())
    if match:
        years = int(match.group(1))
        return (today - timedelta(days=years*365)).strftime('%Y-%m-%d')
    
    # Handle "X months ago"
    match = re.match(r'(\d+)\s+months?\s+ago', date_str.lower())
    if match:
        months = int(match.group(1))
        return (today - timedelta(days=months*30)).strftime('%Y-%m-%d')
    
    # Handle "X days ago"
    match = re.match(r'(\d+)\s+days?\s+ago', date_str.lower())
    if match:
        days = int(match.group(1))
        return (today - timedelta(days=days)).strftime('%Y-%m-%d')
    
    # If it's already in YYYY-MM-DD format, return as is
    try:
        datetime.strptime(date_str, '%Y-%m-%d')
        return date_str
    except ValueError:
        # If we can't parse it, return today's date
        return today.strftime('%Y-%m-%d')

def create_visualization(data, title, xlabel, ylabel, save_path=None):
    """
    Create and save a visualization without displaying it.
    
    Args:
        data: DataFrame containing the data to plot
        title: Title of the plot
        xlabel: Label for x-axis
        ylabel: Label for y-axis
        save_path: Path to save the plot (optional)
    """
    try:
        plt.style.use('seaborn-v0_8')
        fig, ax = plt.subplots(figsize=(12, 6))
        
        # Plot the data
        ax.plot(data.index, data.values, label=ylabel, color='blue', linewidth=2)
        
        # Customize the plot
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)
        plt.xticks(rotation=45)
        plt.tight_layout()
        
        # Save the plot if path is provided
        if save_path:
            plt.savefig(save_path)
        
        # Close the figure to free memory
        plt.close(fig)
        
        return True
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        return False

# Define the nodes in the graph
def extract_stock_info(state: AgentState) -> AgentState:
    """Extract stock information from user input."""
    try:
        # Get the last user message
        last_message = state["messages"][-1].content
        
        # Use LLM to extract stock info
        response = llm.invoke([
            HumanMessage(content=f"""Extract the stock ticker symbol and date range from this text: {last_message}
            Return the information in this format:
            TICKER: [ticker]
            START_DATE: [start_date]
            END_DATE: [end_date]
            If any information is missing, use defaults:
            - Default start_date: 3 years ago
            - Default end_date: today""")
        ])
        
        # Parse the response
        info = {}
        for line in response.content.split('\n'):
            if ':' in line:
                key, value = line.split(':', 1)
                info[key.strip()] = value.strip()
        
        # Parse dates
        ticker = info.get('TICKER', '')
        start_date = parse_relative_date(info.get('START_DATE', '3 years ago'))
        end_date = parse_relative_date(info.get('END_DATE', 'today'))
        
        if ticker:
            state["stock_data"] = StockData(ticker, start_date, end_date)
            state["stock_data"].fetch_closing_prices()
            
        return state
    except Exception as e:
        state["last_action"] = f"Error extracting stock info: {str(e)}"
        return state

def analyze_historical_data(state: AgentState) -> AgentState:
    """Analyze and visualize historical stock data."""
    try:
        if state["stock_data"] and state["stock_data"].dataframe is not None:
            # Create visualization using the helper function
            success = create_visualization(
                data=state["stock_data"].dataframe['Close'],
                title=f'{state["stock_data"].ticker} Historical Stock Price',
                xlabel='Date',
                ylabel='Price ($)',
                save_path=f"{state['stock_data'].ticker}_historical.png"
            )
            
            if success:
                state["last_action"] = f"Historical analysis completed. Plot saved as {state['stock_data'].ticker}_historical.png"
            else:
                state["last_action"] = "Error creating historical visualization"
        else:
            state["last_action"] = "No stock data available for analysis"
            
        return state
    except Exception as e:
        state["last_action"] = f"Error in historical analysis: {str(e)}"
        return state

def run_holdout_analysis(state: AgentState) -> AgentState:
    """Run holdout analysis and create forecast."""
    try:
        if state["stock_data"]:
            state["holdout_model"] = StockModelHoldout(state["stock_data"])
            metrics = state["holdout_model"].run_analysis()
            
            # Create visualization using the helper function
            success = create_visualization(
                data=state["holdout_model"].forecast['yhat'],
                title=f'{state["stock_data"].ticker} Forecast',
                xlabel='Date',
                ylabel='Predicted Price ($)',
                save_path=f"{state['stock_data'].ticker}_holdout_forecast.png"
            )
            
            if success:
                metrics_msg = "\n".join([f"{metric}: {value:.4f}" for metric, value in metrics.items()])
                state["last_action"] = f"Holdout analysis completed. Metrics:\n{metrics_msg}\nForecast saved as {state['stock_data'].ticker}_holdout_forecast.png"
            else:
                state["last_action"] = "Error creating forecast visualization"
        else:
            state["last_action"] = "No stock data available for holdout analysis"
            
        return state
    except Exception as e:
        state["last_action"] = f"Error in holdout analysis: {str(e)}"
        return state

def run_hyperopt_analysis(state: AgentState) -> AgentState:
    """Run hyperopt analysis and create optimized forecast."""
    try:
        if state["stock_data"]:
            state["hyperopt_model"] = StockHyperopt(state["stock_data"])
            best_params = state["hyperopt_model"].run_analysis()
            
            # Create visualization using the helper function
            success = create_visualization(
                data=state["hyperopt_model"].forecast['yhat'],
                title=f'{state["stock_data"].ticker} Optimized Forecast',
                xlabel='Date',
                ylabel='Predicted Price ($)',
                save_path=f"{state['stock_data'].ticker}_hyperopt_forecast.png"
            )
            
            if success:
                params_msg = "\n".join([f"{param}: {value}" for param, value in best_params.items()])
                state["last_action"] = f"Hyperopt analysis completed. Best parameters:\n{params_msg}\nForecast saved as {state['stock_data'].ticker}_hyperopt_forecast.png"
            else:
                state["last_action"] = "Error creating optimized forecast visualization"
        else:
            state["last_action"] = "No stock data available for hyperopt analysis"
            
        return state
    except Exception as e:
        state["last_action"] = f"Error in hyperopt analysis: {str(e)}"
        return state

def generate_response(state: AgentState) -> AgentState:
    """Generate a response based on the last action."""
    try:
        # Get the conversation history
        messages = state["messages"]
        
        # Add the last action to the context
        context = f"Last action: {state['last_action']}\n\nUser's last message: {messages[-1].content}"
        
        # Generate response
        response = llm.invoke([
            HumanMessage(content=context)
        ])
        
        # Add the response to the conversation history
        state["messages"].append(AIMessage(content=response.content))
        
        return state
    except Exception as e:
        state["messages"].append(AIMessage(content=f"Error generating response: {str(e)}"))
        return state

# Create the graph
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("extract_stock_info", extract_stock_info)
workflow.add_node("analyze_historical", analyze_historical_data)
workflow.add_node("run_holdout", run_holdout_analysis)
workflow.add_node("run_hyperopt", run_hyperopt_analysis)
workflow.add_node("generate_response", generate_response)

# Define edges
def route_based_on_intent(state: AgentState) -> str:
    """Route to the appropriate node based on user intent."""
    last_message = state["messages"][-1].content.lower()
    
    # Check for specific keywords in the message
    if "forecast" in last_message and ("tune" in last_message or "optimize" in last_message or "hyperparameter" in last_message):
        return "run_hyperopt"
    elif "forecast" in last_message or "predict" in last_message:
        return "run_holdout"
    elif "historical" in last_message or "price" in last_message or "show" in last_message:
        return "analyze_historical"
    else:
        # If no specific intent is detected, ask the LLM to determine the intent
        try:
            response = llm.invoke([
                HumanMessage(content=f"""Based on this user message, determine which analysis to perform:
                {last_message}
                
                Return ONLY one of these exact options:
                - analyze_historical
                - run_holdout
                - run_hyperopt
                - generate_response
                
                Choose based on the user's intent to:
                - analyze_historical: for showing historical data or prices
                - run_holdout: for forecasting or predicting future prices
                - run_hyperopt: for optimizing or tuning forecast models
                - generate_response: for general questions or clarifications""")
            ])
            return response.content.strip()
        except Exception as e:
            print(f"Error determining intent: {str(e)}")
            return "generate_response"

# Add edges
workflow.add_conditional_edges(
    "extract_stock_info",
    route_based_on_intent,
    {
        "analyze_historical": "analyze_historical",
        "run_holdout": "run_holdout",
        "run_hyperopt": "run_hyperopt",
        "generate_response": "generate_response"
    }
)

workflow.add_edge("analyze_historical", "generate_response")
workflow.add_edge("run_holdout", "generate_response")
workflow.add_edge("run_hyperopt", "generate_response")
workflow.add_edge("generate_response", END)

# Set the entry point
workflow.set_entry_point("extract_stock_info")

# Compile the graph
app = workflow.compile()


# Image(app.get_graph().draw_mermaid_png())

class StockAgent:
    def __init__(self):
        """Initialize the StockAgent."""
        self.state = {
            "messages": [],
            "stock_data": None,
            "holdout_model": None,
            "hyperopt_model": None,
            "last_action": None
        }
        
    def process_user_input(self, user_input: str) -> str:
        """
        Process user input and return a response.
        
        Args:
            user_input (str): The user's input message
            
        Returns:
            str: The agent's response
        """
        try:
            # Add user message to state
            self.state["messages"].append(HumanMessage(content=user_input))
            
            # Run the workflow
            self.state = app.invoke(self.state)
            
            # Return the last AI message
            return self.state["messages"][-1].content
            
        except Exception as e:
            return f"Error processing request: {str(e)}"

## Example Usage

Let's see how to use the complete stock analysis system:

In [11]:
agent = StockAgent()
    
# Example interactions
print(agent.process_user_input("Show me the historical stock price of GOOG"))
print(agent.process_user_input("Forecast the stock price of GOOG"))
print(agent.process_user_input("Tune hyperparameters and forecast the stock price of GOOG")) 

YF.download() has changed argument auto_adjust default to True


[*********************100%***********************]  1 of 1 completed


Successfully fetched closing prices for GOOG
Okay, I've already completed the historical analysis of GOOG and saved the plot as "GOOG_historical.png".  Since I can't directly display images in this text-based interface, you'll need to access the "GOOG_historical.png" file to see the plot of the historical stock price of GOOG. Look for it in the location where the code was executed.


[*********************100%***********************]  1 of 1 completed


Successfully fetched closing prices for GOOG


22:38:47 - cmdstanpy - INFO - Chain [1] start processing
22:38:47 - cmdstanpy - INFO - Chain [1] done processing


Okay, I've already completed a holdout analysis and saved a visualization of that forecast as `GOOG_holdout_forecast.png`.  That holdout analysis *is* a forecast of the stock price of GOOG, but it's specifically for the holdout period (the portion of the data I set aside for evaluation).

To provide a more useful response, I need to understand what *kind* of forecast you're looking for.  Specifically:

1.  **Forecast Horizon:** How far into the future do you want the forecast? (e.g., next day, next week, next month, next year?)
2.  **Data to Use:** Do you want me to use only the data I trained on originally, or should I retrain on the *entire* dataset now that I've evaluated on the holdout set?  Retraining on the entire dataset *might* give a slightly better forecast, but it means I can't provide a separate evaluation metric.
3.  **Format:** Do you want a table of predicted values, a textual description of the predicted trend, or are you primarily interested in the visualization I alre

[*********************100%***********************]  1 of 1 completed

Successfully fetched closing prices for GOOG
  0%|          | 0/50 [00:00<?, ?trial/s, best loss=?]


22:38:50 - cmdstanpy - INFO - Chain [1] start processing
22:38:50 - cmdstanpy - INFO - Chain [1] done processing


  2%|▏         | 1/50 [00:00<00:24,  1.98trial/s, best loss: 15.115530471394264]

22:38:50 - cmdstanpy - INFO - Chain [1] start processing
22:38:51 - cmdstanpy - INFO - Chain [1] done processing


  4%|▍         | 2/50 [00:00<00:23,  2.03trial/s, best loss: 9.276450070057484] 

22:38:51 - cmdstanpy - INFO - Chain [1] start processing
22:38:51 - cmdstanpy - INFO - Chain [1] done processing


  6%|▌         | 3/50 [00:01<00:16,  2.81trial/s, best loss: 9.276450070057484]

22:38:51 - cmdstanpy - INFO - Chain [1] start processing
22:38:51 - cmdstanpy - INFO - Chain [1] done processing


  8%|▊         | 4/50 [00:01<00:18,  2.46trial/s, best loss: 9.276450070057484]

22:38:52 - cmdstanpy - INFO - Chain [1] start processing
22:38:52 - cmdstanpy - INFO - Chain [1] done processing


 10%|█         | 5/50 [00:02<00:18,  2.37trial/s, best loss: 9.276450070057484]

22:38:52 - cmdstanpy - INFO - Chain [1] start processing
22:38:52 - cmdstanpy - INFO - Chain [1] done processing


 12%|█▏        | 6/50 [00:02<00:15,  2.93trial/s, best loss: 9.276450070057484]

22:38:52 - cmdstanpy - INFO - Chain [1] start processing
22:38:52 - cmdstanpy - INFO - Chain [1] done processing


 14%|█▍        | 7/50 [00:02<00:12,  3.43trial/s, best loss: 9.276450070057484]

22:38:52 - cmdstanpy - INFO - Chain [1] start processing
22:38:53 - cmdstanpy - INFO - Chain [1] done processing


 16%|█▌        | 8/50 [00:02<00:14,  2.94trial/s, best loss: 9.276450070057484]

22:38:53 - cmdstanpy - INFO - Chain [1] start processing
22:38:53 - cmdstanpy - INFO - Chain [1] done processing


 18%|█▊        | 9/50 [00:03<00:15,  2.64trial/s, best loss: 9.276450070057484]

22:38:53 - cmdstanpy - INFO - Chain [1] start processing
22:38:53 - cmdstanpy - INFO - Chain [1] done processing


 20%|██        | 10/50 [00:03<00:13,  2.96trial/s, best loss: 9.276450070057484]

22:38:54 - cmdstanpy - INFO - Chain [1] start processing
22:38:54 - cmdstanpy - INFO - Chain [1] done processing


 22%|██▏       | 11/50 [00:04<00:15,  2.53trial/s, best loss: 9.276450070057484]

22:38:54 - cmdstanpy - INFO - Chain [1] start processing
22:38:54 - cmdstanpy - INFO - Chain [1] done processing


 24%|██▍       | 12/50 [00:04<00:12,  3.01trial/s, best loss: 9.276450070057484]

22:38:54 - cmdstanpy - INFO - Chain [1] start processing
22:38:54 - cmdstanpy - INFO - Chain [1] done processing


 26%|██▌       | 13/50 [00:04<00:11,  3.16trial/s, best loss: 9.225258648938782]

22:38:55 - cmdstanpy - INFO - Chain [1] start processing
22:38:55 - cmdstanpy - INFO - Chain [1] done processing


 28%|██▊       | 14/50 [00:05<00:12,  2.88trial/s, best loss: 9.225258648938782]

22:38:55 - cmdstanpy - INFO - Chain [1] start processing
22:38:55 - cmdstanpy - INFO - Chain [1] done processing


 30%|███       | 15/50 [00:05<00:11,  2.97trial/s, best loss: 8.465885285454975]

22:38:55 - cmdstanpy - INFO - Chain [1] start processing
22:38:55 - cmdstanpy - INFO - Chain [1] done processing


 32%|███▏      | 16/50 [00:05<00:10,  3.30trial/s, best loss: 8.465885285454975]

22:38:55 - cmdstanpy - INFO - Chain [1] start processing
22:38:56 - cmdstanpy - INFO - Chain [1] done processing


 34%|███▍      | 17/50 [00:05<00:08,  3.70trial/s, best loss: 8.465885285454975]

22:38:56 - cmdstanpy - INFO - Chain [1] start processing
22:38:56 - cmdstanpy - INFO - Chain [1] done processing


 36%|███▌      | 18/50 [00:06<00:10,  2.93trial/s, best loss: 8.465885285454975]

22:38:56 - cmdstanpy - INFO - Chain [1] start processing
22:38:56 - cmdstanpy - INFO - Chain [1] done processing


 38%|███▊      | 19/50 [00:06<00:09,  3.31trial/s, best loss: 8.465885285454975]

22:38:56 - cmdstanpy - INFO - Chain [1] start processing
22:38:56 - cmdstanpy - INFO - Chain [1] done processing


 40%|████      | 20/50 [00:06<00:08,  3.61trial/s, best loss: 8.465885285454975]

22:38:57 - cmdstanpy - INFO - Chain [1] start processing
22:38:57 - cmdstanpy - INFO - Chain [1] done processing


 42%|████▏     | 21/50 [00:07<00:08,  3.45trial/s, best loss: 8.465885285454975]

22:38:57 - cmdstanpy - INFO - Chain [1] start processing
22:38:57 - cmdstanpy - INFO - Chain [1] done processing


 44%|████▍     | 22/50 [00:07<00:09,  2.96trial/s, best loss: 8.465885285454975]

22:38:57 - cmdstanpy - INFO - Chain [1] start processing
22:38:58 - cmdstanpy - INFO - Chain [1] done processing


 46%|████▌     | 23/50 [00:07<00:09,  2.78trial/s, best loss: 8.465885285454975]

22:38:58 - cmdstanpy - INFO - Chain [1] start processing
22:38:58 - cmdstanpy - INFO - Chain [1] done processing


 48%|████▊     | 24/50 [00:08<00:10,  2.53trial/s, best loss: 8.465885285454975]

22:38:58 - cmdstanpy - INFO - Chain [1] start processing
22:38:58 - cmdstanpy - INFO - Chain [1] done processing


 50%|█████     | 25/50 [00:08<00:09,  2.66trial/s, best loss: 7.929258315721762]

22:38:59 - cmdstanpy - INFO - Chain [1] start processing
22:38:59 - cmdstanpy - INFO - Chain [1] done processing


 52%|█████▏    | 26/50 [00:08<00:07,  3.05trial/s, best loss: 7.929258315721762]

22:38:59 - cmdstanpy - INFO - Chain [1] start processing
22:38:59 - cmdstanpy - INFO - Chain [1] done processing


 54%|█████▍    | 27/50 [00:09<00:07,  2.94trial/s, best loss: 7.929258315721762]

22:38:59 - cmdstanpy - INFO - Chain [1] start processing
22:38:59 - cmdstanpy - INFO - Chain [1] done processing


 56%|█████▌    | 28/50 [00:09<00:07,  2.98trial/s, best loss: 7.521937374942735]

22:38:59 - cmdstanpy - INFO - Chain [1] start processing
22:39:00 - cmdstanpy - INFO - Chain [1] done processing


 58%|█████▊    | 29/50 [00:10<00:07,  2.68trial/s, best loss: 7.521937374942735]

22:39:00 - cmdstanpy - INFO - Chain [1] start processing
22:39:00 - cmdstanpy - INFO - Chain [1] done processing


 60%|██████    | 30/50 [00:10<00:06,  2.88trial/s, best loss: 7.521937374942735]

22:39:00 - cmdstanpy - INFO - Chain [1] start processing
22:39:01 - cmdstanpy - INFO - Chain [1] done processing


 62%|██████▏   | 31/50 [00:10<00:07,  2.45trial/s, best loss: 7.521937374942735]

22:39:01 - cmdstanpy - INFO - Chain [1] start processing
22:39:01 - cmdstanpy - INFO - Chain [1] done processing


 64%|██████▍   | 32/50 [00:11<00:06,  2.89trial/s, best loss: 7.521937374942735]

22:39:01 - cmdstanpy - INFO - Chain [1] start processing
22:39:01 - cmdstanpy - INFO - Chain [1] done processing


 66%|██████▌   | 33/50 [00:11<00:05,  3.33trial/s, best loss: 7.521937374942735]

22:39:01 - cmdstanpy - INFO - Chain [1] start processing
22:39:01 - cmdstanpy - INFO - Chain [1] done processing


 68%|██████▊   | 34/50 [00:11<00:05,  3.18trial/s, best loss: 7.521937374942735]

22:39:02 - cmdstanpy - INFO - Chain [1] start processing
22:39:02 - cmdstanpy - INFO - Chain [1] done processing


 70%|███████   | 35/50 [00:12<00:05,  2.81trial/s, best loss: 7.521937374942735]

22:39:02 - cmdstanpy - INFO - Chain [1] start processing
22:39:02 - cmdstanpy - INFO - Chain [1] done processing


 72%|███████▏  | 36/50 [00:12<00:05,  2.67trial/s, best loss: 7.521937374942735]

22:39:02 - cmdstanpy - INFO - Chain [1] start processing
22:39:02 - cmdstanpy - INFO - Chain [1] done processing


 74%|███████▍  | 37/50 [00:12<00:04,  3.06trial/s, best loss: 7.521937374942735]

22:39:03 - cmdstanpy - INFO - Chain [1] start processing
22:39:03 - cmdstanpy - INFO - Chain [1] done processing


 76%|███████▌  | 38/50 [00:13<00:04,  2.45trial/s, best loss: 7.521937374942735]

22:39:03 - cmdstanpy - INFO - Chain [1] start processing
22:39:03 - cmdstanpy - INFO - Chain [1] done processing


 78%|███████▊  | 39/50 [00:13<00:04,  2.60trial/s, best loss: 7.521937374942735]

22:39:04 - cmdstanpy - INFO - Chain [1] start processing
22:39:04 - cmdstanpy - INFO - Chain [1] done processing


 80%|████████  | 40/50 [00:13<00:03,  3.06trial/s, best loss: 7.521937374942735]

22:39:04 - cmdstanpy - INFO - Chain [1] start processing
22:39:04 - cmdstanpy - INFO - Chain [1] done processing


 82%|████████▏ | 41/50 [00:14<00:02,  3.41trial/s, best loss: 7.521937374942735]

22:39:04 - cmdstanpy - INFO - Chain [1] start processing
22:39:04 - cmdstanpy - INFO - Chain [1] done processing


 84%|████████▍ | 42/50 [00:14<00:02,  3.16trial/s, best loss: 7.521937374942735]

22:39:04 - cmdstanpy - INFO - Chain [1] start processing
22:39:05 - cmdstanpy - INFO - Chain [1] done processing


 86%|████████▌ | 43/50 [00:14<00:02,  2.85trial/s, best loss: 7.521937374942735]

22:39:05 - cmdstanpy - INFO - Chain [1] start processing
22:39:05 - cmdstanpy - INFO - Chain [1] done processing


 88%|████████▊ | 44/50 [00:15<00:02,  2.50trial/s, best loss: 7.521937374942735]

22:39:05 - cmdstanpy - INFO - Chain [1] start processing
22:39:06 - cmdstanpy - INFO - Chain [1] done processing


 90%|█████████ | 45/50 [00:16<00:02,  2.06trial/s, best loss: 7.521937374942735]

22:39:06 - cmdstanpy - INFO - Chain [1] start processing
22:39:06 - cmdstanpy - INFO - Chain [1] done processing


 92%|█████████▏| 46/50 [00:16<00:01,  2.19trial/s, best loss: 7.073963341564592]

22:39:06 - cmdstanpy - INFO - Chain [1] start processing
22:39:06 - cmdstanpy - INFO - Chain [1] done processing


 94%|█████████▍| 47/50 [00:16<00:01,  2.64trial/s, best loss: 7.073963341564592]

22:39:07 - cmdstanpy - INFO - Chain [1] start processing
22:39:07 - cmdstanpy - INFO - Chain [1] done processing


 96%|█████████▌| 48/50 [00:16<00:00,  2.89trial/s, best loss: 7.073963341564592]

22:39:07 - cmdstanpy - INFO - Chain [1] start processing
22:39:07 - cmdstanpy - INFO - Chain [1] done processing


 98%|█████████▊| 49/50 [00:17<00:00,  3.29trial/s, best loss: 7.073963341564592]

22:39:07 - cmdstanpy - INFO - Chain [1] start processing
22:39:07 - cmdstanpy - INFO - Chain [1] done processing


100%|██████████| 50/50 [00:17<00:00,  2.86trial/s, best loss: 7.073963341564592]

22:39:07 - cmdstanpy - INFO - Chain [1] start processing





22:39:08 - cmdstanpy - INFO - Chain [1] done processing


Okay, the Hyperopt analysis has already completed and identified the best parameters for your forecasting model based on the data you provided. The "best" parameters found were:

*   `changepoint_prior_scale`: 0.07688314360359709
*   `seasonality_prior_scale`: 0.016247123278636654
*   `holidays_prior_scale`: 0.02405490338495977
*   `seasonality_mode`: additive

And the forecast using these optimized parameters has already been generated and saved as "GOOG_hyperopt_forecast.png".

**Therefore, the task is already completed.** No further action is needed based on your request "Tune hyperparameters and forecast the stock price of GOOG".  The hyperparameters have been tuned, and the forecast has been generated and saved.

**Possible Next Steps (depending on your intentions):**

*   **Review the Forecast:** Open and examine the "GOOG_hyperopt_forecast.png" file to assess the quality of the forecast. Consider:
    *   Does the forecast look reasonable based on your understanding of GOOG's st