In [1]:
# imports
from google.cloud import bigquery as bq
from datetime import datetime as dt
from time import sleep

# from data import save_gbq

# parameters -- data source, model identity and training window
project_id = 'ioracle'              # GCP project; tables/models live in its `main` dataset
table_name = 'aapl_data'            # BigQuery table holding the downloaded price history
ticker_name = 'aapl'                # Yahoo Finance ticker symbol
model_name = 'aapl_arima_predict'   # BigQuery ML model (re)created by predict()
train_start = '2018-01-01'          # first date of the training window
train_end = dt.today().strftime('%Y-%m-%d')  # train on everything up to today

# model parameters (ARIMA_PLUS with AUTO_ARIMA disabled, so the order is fixed here)
NON_SEASONAL_ORDER = (0, 1, 5)      # (p, d, q)
INCLUDE_DRIFT = False




In [2]:
import yfinance as yf
import pandas as pd
from datetime import datetime as dt
import os #only for jupyter notebook

In [3]:
# only for jupyter notebook: point the Google Cloud client libraries at a local
# service-account key file. setdefault leaves any credentials the user has already
# configured in the environment untouched instead of silently overriding them.
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "../../service-account-file.json")

In [4]:
def _get_dataframe(ticker_name, start, end):
    """
    get_dataframe(ticker_name, start, end)
    Downloads OHLC, adj close and volume from yahoo finance for one ticker.

    Parameters
    ----------
    ticker_name : str
        Yahoo Finance ticker symbol, e.g. 'aapl'.
    start, end : str
        'YYYY-MM-DD' date bounds passed to yfinance (end is not inclusive).

    Returns
    -------
    pandas.DataFrame with flat (single-level) columns such as 'Adj Close';
    empty if nothing was downloaded.
    """
    # auto_adjust=False keeps the 'Adj Close' column that save_gbq/predict rely on;
    # newer yfinance releases default to auto_adjust=True, which drops it.
    df = yf.download(ticker_name, start=start, end=end, auto_adjust=False)
    # Newer yfinance releases return (Price, Ticker) MultiIndex columns even for a
    # single ticker; flatten to the price-field level so to_csv/rename keep working.
    if df.columns.nlevels == 2 and df.columns.get_level_values(1).nunique() == 1:
        df.columns = df.columns.get_level_values(0)
    return df


def _get_start_end(kwargs):
    '''
    get_start_end(kwargs)
    from kwargs, get start, end dates
    if not stated, will return default values
    return start, end dates
    '''
    start = kwargs.get('start', "2017-01-01")
    end = kwargs.get('end', dt.today().strftime('%Y-%m-%d')) #not inclusive
    return start, end
         

def save_local(ticker_name, path_filename, **kwargs):
    """
    save_local(ticker_name, path_filename, **kwargs)
    Download price history for ticker_name from yahoo finance and save it as CSV.

    Parameters
    ----------
    ticker_name : str
        Yahoo Finance ticker symbol.
    path_filename : str or path-like
        Destination CSV file (overwritten if it already exists).
    **kwargs
        start : str, default "2017-01-01"
        end : str, default today (not inclusive)

    Nothing is written when the download comes back empty.
    """
    start, end = _get_start_end(kwargs)

    df = _get_dataframe(ticker_name, start=start, end=end)
    if df.empty:
        # bad ticker or date range -- report it instead of silently doing nothing,
        # and leave any existing file at path_filename alone
        print(f"no data for {ticker_name} from {start} to {end}; nothing saved")
        return
    df.to_csv(path_filename)
    print(f"{ticker_name} from {start} to {end} saved to {path_filename}")
        

def save_gbq(ticker_name, table_name, **kwargs):
    """
    save_gbq(ticker_name, table_name, **kwargs)
    Download price history for ticker_name from yahoo finance and upload it to
    BigQuery as `{project_id}.{dataset}.{table_name}`, replacing any existing table.

    Parameters
    ----------
    ticker_name : str
        Yahoo Finance ticker symbol.
    table_name : str
        Destination table name.
    **kwargs
        start : str, default "2017-01-01"
        end : str, default today (not inclusive)
        project_id : str, default "ioracle"
        dataset : str, default "main"

    Nothing is uploaded when the download comes back empty.
    """
    start, end = _get_start_end(kwargs)
    project_id = kwargs.get('project_id', "ioracle")
    dataset = kwargs.get('dataset', "main")

    prices = _get_dataframe(ticker_name, start=start, end=end)
    if prices.empty:
        print(f"no data for {ticker_name} from {start} to {end}; nothing uploaded")
        return

    # BigQuery column names cannot contain spaces; also move the Date index into
    # a regular column so it is uploaded.
    prices = prices.rename(columns={'Adj Close': 'Adj_Close'}).reset_index()
    prices.to_gbq(
        f'{project_id}.{dataset}.{table_name}',
        project_id=project_id,
        # hard code schema for date from DATETIME to DATE
        table_schema=[{'name': 'Date', 'type': 'DATE'}],
        if_exists='replace',
    )
        

def read_local(path_filename):
    """
    read_local(path_filename)
    Load a price CSV (e.g. one written by save_local).
    The 'Date' column is converted with pd.to_datetime and becomes the index;
    all other columns are returned unchanged.
    """
    raw = pd.read_csv(path_filename)
    return raw.assign(Date=pd.to_datetime(raw['Date'])).set_index('Date')
    

# read from gbq, undoing save_gbq's reset_index (Date becomes the sorted index again)
def read_gbq(table_name, **kwargs):
    """
    read_gbq(table_name, **kwargs)
    Read a whole BigQuery table back into a DataFrame.

    Parameters
    ----------
    table_name : str
        Table to read from `{project_id}.{dataset}`.
    **kwargs
        project_id : str, default "ioracle"
        dataset : str, default "main"

    Returns
    -------
    pandas.DataFrame sorted by and indexed on its 'Date' column.
    """
    project_id = kwargs.get('project_id', "ioracle")
    dataset = kwargs.get('dataset', "main")

    sql = f"SELECT * FROM `{project_id}.{dataset}.{table_name}`"

    df = pd.read_gbq(sql, project_id=project_id)
    return df.sort_values('Date').set_index('Date')

In [5]:
def predict(tgt_date, horizon=50, refresh_data=False):
    """
    Train a BigQuery ML ARIMA_PLUS model on the stored price history and
    return the forecast Adj_Close for tgt_date.

    Uses the notebook-level parameters project_id, table_name, ticker_name,
    model_name, train_start, train_end, NON_SEASONAL_ORDER and INCLUDE_DRIFT.

    Parameters
    ----------
    tgt_date : str
        'YYYY-MM-DD' date to forecast. It must fall within the `horizon` daily
        points that ML.FORECAST produces after the training data ends.
    horizon : int, default 50
        Number of daily points requested from ML.FORECAST.
    refresh_data : bool, default False
        If True, re-download the ticker from yahoo finance and overwrite
        table_name via save_gbq before training.

    Returns
    -------
    The forecast_value for tgt_date.

    Raises
    ------
    ValueError
        If tgt_date is not a 'YYYY-MM-DD' string.
    KeyError
        If tgt_date is not among the forecast dates.
    """
    # parse first so a malformed date fails before the slow training job runs
    target = dt.strptime(tgt_date, '%Y-%m-%d').date()

    if refresh_data:
        # upload latest data from yfinance
        save_gbq(ticker_name, table_name, project_id=project_id)

    client = bq.Client(project=project_id)
    model_ref = f"`{project_id}.main.{model_name}`"

    # train model. CREATE OR REPLACE drops any previous model of the same name in
    # the same statement, and .result() blocks until the job has finished, so the
    # forecast below always reads the freshly trained model.
    train_sql = f"""
            CREATE OR REPLACE MODEL {model_ref}
                OPTIONS(MODEL_TYPE='ARIMA_PLUS',
                         time_series_timestamp_col='Date',
                         time_series_data_col='Adj_Close',
                         DATA_FREQUENCY = 'DAILY',
                         HOLIDAY_REGION = 'GLOBAL',
                         CLEAN_SPIKES_AND_DIPS = FALSE,
                         AUTO_ARIMA = FALSE,
                         NON_SEASONAL_ORDER = {NON_SEASONAL_ORDER},
                         INCLUDE_DRIFT = {INCLUDE_DRIFT}) AS
                SELECT Date, Adj_Close
                FROM `{project_id}.main.{table_name}`
                WHERE Date Between '{train_start}' AND '{train_end}'
                ORDER BY Date ASC
            """
    client.query(train_sql).result()

    # get predicted values; int() keeps the interpolated horizon a plain integer literal
    forecast_sql = f"""SELECT * FROM ML.FORECAST(MODEL {model_ref},
                STRUCT({int(horizon)} AS horizon)
                )"""
    forecast = client.query(forecast_sql).to_dataframe()
    # key rows by the calendar date of forecast_timestamp to match tgt_date
    forecast['Date'] = forecast['forecast_timestamp'].apply(lambda ts: ts.date())
    forecast = forecast.set_index('Date')

    if target not in forecast.index:
        raise KeyError(
            f"{tgt_date} is outside the forecast window "
            f"{forecast.index.min()} to {forecast.index.max()}"
        )
    return forecast.loc[target, 'forecast_value']
    

In [7]:
# Forecast Adj_Close for a single date. tgt_date must lie inside the ML.FORECAST
# window that starts right after train_end (today, set in the parameters cell),
# so this example date only resolves when the cell is run in the days before it;
# otherwise the date is missing from the forecast and no value can be returned.
# NOTE(review): the output shown below comes from such an earlier run.
predict('2022-02-02')

165.61896769005656