In [10]:
import streamlit as st
import yfinance as yf
import pandas as pd
from statsmodels.tsa.statespace.sarimax import SARIMAX
from PIL import Image

In [11]:
# Share of the rows handed to the model for fitting (the rest is held out).
DEFAULT_TRAIN_SIZE = 0.8

# Default SARIMA settings, keyed by the names used throughout this notebook.
DEFAULT_SARIMA_HYPERPARAMETERS = dict(
    p=2,
    d=0,
    q=2,
    seasonal_p=2,
    seasonal_d=0,
    seasonal_q=2,
    s=5,
)


In [14]:
# Pull the AMZN price history from Yahoo Finance, beginning 2008-01-01.
data = yf.download("AMZN", start="2008-01-01")

# Wrap the download in a DataFrame and preview its first five rows
# (five is the default row count of head()).
df = pd.DataFrame(data)
df.head()

[*********************100%%**********************]  1 of 1 completed


Unnamed: 0_level_0,Open,High,Low,Close,Adj Close,Volume
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2008-01-02,4.7675,4.8715,4.735,4.8125,4.8125,277174000
2008-01-03,4.803,4.8625,4.726,4.7605,4.7605,182450000
2008-01-04,4.663,4.67,4.425,4.4395,4.4395,205400000
2008-01-07,4.431,4.5285,4.2735,4.441,4.441,199632000
2008-01-08,4.3775,4.5915,4.3465,4.394,4.394,245666000


In [15]:
# Print a summary of the downloaded frame: index range, per-column
# non-null counts and dtypes, and memory usage.
data.info()

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 3976 entries, 2008-01-02 to 2023-10-17
Data columns (total 6 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   Open       3976 non-null   float64
 1   High       3976 non-null   float64
 2   Low        3976 non-null   float64
 3   Close      3976 non-null   float64
 4   Adj Close  3976 non-null   float64
 5   Volume     3976 non-null   int64  
dtypes: float64(5), int64(1)
memory usage: 217.4 KB


In [17]:
# Gather the column labels of the downloaded frame and show them.
columns_list = list(data.columns)
columns_list

['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume']

In [20]:
# Start from the defaults, but work on a copy: a plain assignment would make
# both names point at the same dict, so any in-place tweak made while
# experimenting would silently change DEFAULT_SARIMA_HYPERPARAMETERS too.
sarima_hyperparameters = dict(DEFAULT_SARIMA_HYPERPARAMETERS)

In [None]:
# Last date the forecast should reach.
end_date = pd.Timestamp("2024-01-01")

In [18]:
def create_and_predict_sarima_model(data, date_column, target_column, train_size, sarima_hyperparameters, end_date):
    """Fit a SARIMA model on the first part of ``data`` and forecast up to ``end_date``.

    The leading ``train_size`` share of the rows is used for fitting. The
    forecast covers the remaining (held-out) rows plus every business day
    after the last row of ``data`` up to and including ``end_date``.

    Args:
        data: DataFrame containing ``target_column``. Its index is expected
            to hold the observation dates in ascending order.
        date_column: Not used by this function (dates are read from the index
            of ``data``); kept so existing callers keep working.
        target_column: Label of the column to model.
        train_size: Fraction of the rows used for fitting.
        sarima_hyperparameters: Mapping with the keys ``p``, ``d``, ``q``,
            ``seasonal_p``, ``seasonal_d``, ``seasonal_q`` and ``s``. Any other
            entries are passed to ``SARIMAX`` unchanged. The mapping itself is
            not modified.
        end_date: Last date to forecast; anything ``pd.Timestamp`` accepts.

    Returns:
        A DataFrame with a ``Date`` and a ``Predicted`` column, or ``None``
        when fitting or forecasting fails (the error is reported through
        ``st.write``).
    """
    train_end = int(len(data) * train_size)
    train_data = data[:train_end]
    test_data = data[train_end:]

    try:
        # SARIMAX is configured through order=(p, d, q) and
        # seasonal_order=(P, D, Q, s); it has no p/d/q keyword parameters, so
        # the notebook-style keys are folded into those two tuples first.
        model_kwargs = dict(sarima_hyperparameters)
        key_groups = (
            ("order", ("p", "d", "q")),
            ("seasonal_order", ("seasonal_p", "seasonal_d", "seasonal_q", "s")),
        )
        for argument, keys in key_groups:
            if any(key in model_kwargs for key in keys):
                values = tuple(model_kwargs.pop(key, 0) for key in keys)
                # An explicitly supplied order/seasonal_order takes precedence.
                model_kwargs.setdefault(argument, values)

        # Dates the forecast has to cover: the held-out rows' own dates, then
        # the business days after the last observation. freq="B" only skips
        # weekends; exchange holidays are not excluded.
        last_observed = data.index[-1]
        future_dates = pd.date_range(last_observed, pd.Timestamp(end_date), freq="B")
        future_dates = future_dates[future_dates > last_observed]
        forecast_dates = test_data.index.append(future_dates)

        # Trading days have no fixed frequency, so a forecast made from a
        # date-indexed series would not come back labelled with dates. Fit on
        # the bare values and attach the dates computed above instead.
        sarima_model = SARIMAX(train_data[target_column].to_numpy(), **model_kwargs)
        sarima_results = sarima_model.fit()

        # One forecast step per date computed above
        forecast = sarima_results.get_forecast(steps=len(forecast_dates))

        # Get predicted values as a DataFrame
        forecast_df = pd.DataFrame({'Date': forecast_dates, 'Predicted': list(forecast.predicted_mean)})
    except Exception as e:
        st.write("Error: An error occurred during model fitting. Try different model parameters or data preprocessing.")
        # Surface the cause instead of discarding it.
        st.write(f"Details: {type(e).__name__}: {e}")
        return None

    return forecast_df

In [22]:
# The function looks the target up by column label (train_data[target_column]),
# so pass labels rather than the index / Series objects themselves.
# `end_date` must already be defined: run the cell that sets it first.
forecast_df = create_and_predict_sarima_model(
    data,
    data.index.name,
    "Adj Close",
    DEFAULT_TRAIN_SIZE,
    sarima_hyperparameters,
    end_date,
)

NameError: name 'end_date' is not defined

In [None]:
# Define a function to extract and preprocess the data
def load_and_preprocess_data(ticker):
    """Download price history for ``ticker`` and let the user pick the date and target columns.

    Args:
        ticker: Ticker symbol handed to ``yf.download``.

    Returns:
        A tuple ``(data, date_column, target_column)`` where ``data`` is the
        downloaded frame indexed by the chosen date column, and the other two
        are the labels selected in the Streamlit select boxes.
    """
    # Data extraction starts from 2018-01-01
    data = yf.download(ticker, start="2018-01-01")

    # The download carries its dates in the index. Move them into a regular
    # column so they show up among the choices below, and build the choices
    # from this frame rather than from a variable left over in the notebook.
    data = data.reset_index()
    available_columns = data.columns.to_list()

    # Allow the user to specify the date column and target column
    st.write("Select the date column:")
    date_column = st.selectbox("Date Column", available_columns)
    st.write("Select the target column:")
    # The date column cannot double as the value being forecast.
    target_options = [column for column in available_columns if column != date_column]
    target_column = st.selectbox("Target Column", target_options)

    # Set the selected date column as the index
    data = data.set_index(date_column)

    return data, date_column, target_column

# Define a function to create a SARIMA model and make predictions
def create_and_predict_sarima_model(data, date_column, target_column, train_size, sarima_hyperparameters, end_date):
    """Fit a SARIMA model on the first part of ``data`` and forecast up to ``end_date``.

    The leading ``train_size`` share of the rows is used for fitting. The
    forecast covers the remaining (held-out) rows plus every business day
    after the last row of ``data`` up to and including ``end_date``.

    Args:
        data: DataFrame containing ``target_column``. Its index is expected
            to hold the observation dates in ascending order.
        date_column: Not used by this function (dates are read from the index
            of ``data``); kept so existing callers keep working.
        target_column: Label of the column to model.
        train_size: Fraction of the rows used for fitting.
        sarima_hyperparameters: Mapping with the keys ``p``, ``d``, ``q``,
            ``seasonal_p``, ``seasonal_d``, ``seasonal_q`` and ``s``. Any other
            entries are passed to ``SARIMAX`` unchanged. The mapping itself is
            not modified.
        end_date: Last date to forecast; anything ``pd.Timestamp`` accepts.

    Returns:
        A DataFrame with a ``Date`` and a ``Predicted`` column, or ``None``
        when fitting or forecasting fails (the error is reported through
        ``st.write``).
    """
    train_end = int(len(data) * train_size)
    train_data = data[:train_end]
    test_data = data[train_end:]

    try:
        # SARIMAX is configured through order=(p, d, q) and
        # seasonal_order=(P, D, Q, s); it has no p/d/q keyword parameters, so
        # the notebook-style keys are folded into those two tuples first.
        model_kwargs = dict(sarima_hyperparameters)
        key_groups = (
            ("order", ("p", "d", "q")),
            ("seasonal_order", ("seasonal_p", "seasonal_d", "seasonal_q", "s")),
        )
        for argument, keys in key_groups:
            if any(key in model_kwargs for key in keys):
                values = tuple(model_kwargs.pop(key, 0) for key in keys)
                # An explicitly supplied order/seasonal_order takes precedence.
                model_kwargs.setdefault(argument, values)

        # Dates the forecast has to cover: the held-out rows' own dates, then
        # the business days after the last observation. freq="B" only skips
        # weekends; exchange holidays are not excluded.
        last_observed = data.index[-1]
        future_dates = pd.date_range(last_observed, pd.Timestamp(end_date), freq="B")
        future_dates = future_dates[future_dates > last_observed]
        forecast_dates = test_data.index.append(future_dates)

        # Trading days have no fixed frequency, so a forecast made from a
        # date-indexed series would not come back labelled with dates. Fit on
        # the bare values and attach the dates computed above instead.
        sarima_model = SARIMAX(train_data[target_column].to_numpy(), **model_kwargs)
        sarima_results = sarima_model.fit()

        # One forecast step per date computed above
        forecast = sarima_results.get_forecast(steps=len(forecast_dates))

        # Get predicted values as a DataFrame
        forecast_df = pd.DataFrame({'Date': forecast_dates, 'Predicted': list(forecast.predicted_mean)})
    except Exception as e:
        st.write("Error: An error occurred during model fitting. Try different model parameters or data preprocessing.")
        # Surface the cause instead of discarding it.
        st.write(f"Details: {type(e).__name__}: {e}")
        return None

    return forecast_df

# Main Streamlit app
st.title("Stock Price Prediction with SARIMA Model")

# User input for the stock ticker; surrounding whitespace is not part of a
# symbol, and a whitespace-only entry counts as "nothing entered".
ticker = st.text_input("Enter a stock ticker (e.g., AAPL):").strip()

if ticker:
    data, date_column, target_column = load_and_preprocess_data(ticker)
    st.write("Data loaded and preprocessed successfully.")

    # Work on a copy so nothing done to this dict can alter the defaults.
    sarima_hyperparameters = dict(DEFAULT_SARIMA_HYPERPARAMETERS)

    st.write("Select the end date for predictions:")
    # NOTE(review): this default lies before the end of recently downloaded
    # data, so no future dates are forecast until the user picks a later day.
    end_date = st.date_input("End Date", pd.to_datetime("2023-01-01"))

    forecast_df = create_and_predict_sarima_model(
        data,
        date_column,
        target_column,
        DEFAULT_TRAIN_SIZE,
        sarima_hyperparameters,
        end_date,
    )

    if forecast_df is not None:
        st.write("Model trained and predictions generated successfully.")

        # Visualize predictions using a table
        st.subheader(f'Predictions table for {ticker}')
        st.write(forecast_df)
        # x and y are keyword-only parameters of st.line_chart; passing the
        # column names positionally is rejected.
        st.line_chart(forecast_df, x="Date", y="Predicted")