In [1]:
import json
import pickle
import os

import numpy as np
import pandas as pd
from sqlalchemy import create_engine

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import xgboost as xgb

from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from hyperopt.pyll import scope

import yfinance as yf

In [2]:
# Dump to pickle
def dump_pickle(obj, filename: str):
    with open(filename, "wb") as f_out:
        return pickle.dump(obj, f_out)


# Load from pickle
def load_pickle(filename: str):
    with open(filename, "rb") as f_in:
        return pickle.load(f_in)


# Download stock price data
def download_stock_data(ticker, start_date, end_date):
    stock_data = yf.download(ticker, start=start_date, end=end_date)
    return stock_data


# Preprocess data and create input sequences
def preprocess_data(data, sequence_length):
    sequences = []
    data = data.reset_index()
    adj_close = data['Adj Close']
    date = data['Date']
    
    for i in range(len(adj_close) - sequence_length):
        adj_close_sequence = adj_close.iloc[i:i+sequence_length+1].values
        target_date = date.iloc[i+sequence_length]
        sequence = np.append(target_date, adj_close_sequence)
        sequences.append(sequence)

    return sequences


# Split data into training and testing sets
def split_data(data, test_size=0.2):
    train_data, test_data = train_test_split(data, test_size=test_size, shuffle=False)
    return train_data, test_data


# Hyperparameter tuning
def run_optimization(X_train, y_train, X_test, y_test, num_trials: int = 30):
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_test, label=y_test)

    def objective(params):
        model = xgb.train(
            params=params,
            dtrain=train,
            num_boost_round=1000,
            evals=[(valid, 'validation')],
            early_stopping_rounds=50
        )
        y_pred = model.predict(valid)
        rmse = mean_squared_error(y_test, y_pred, squared=False)
        return {'loss': rmse, 'status': STATUS_OK}    
    
    search_space = {
        'max_depth': scope.int(hp.quniform('max_depth', 4, 100, 1)),
        'learning_rate': hp.loguniform('learning_rate', -3, 0),
        'reg_alpha': hp.loguniform('reg_alpha', -5, -1),
        'reg_lambda': hp.loguniform('reg_lambda', -6, -1),
        'min_child_weight': hp.loguniform('min_child_weight', -1, 3),
        'objective': 'reg:linear',
        'seed': 42
    }

    rstate = np.random.default_rng(42)  # for reproducible results
    best_params = fmin(
        fn=objective,
        space=search_space,
        algo=tpe.suggest,
        max_evals=num_trials,
        trials=Trials(),
        rstate=rstate
    )
    return best_params


# Train XGBoost regression model
def train_model(X_train, y_train, X_test, y_test, best_params):
    train = xgb.DMatrix(X_train, label=y_train)
    valid = xgb.DMatrix(X_test, label=y_test)
    
    best_params['max_depth']=int(best_params['max_depth'])

    model = xgb.train(
        params=best_params,
        dtrain=train,
        num_boost_round=1000,
        evals=[(valid, 'validation')],
        early_stopping_rounds=50,
    )
    return model


# Evaluate model on test data
def evaluate_model(model, X_test, y_test):
    valid = xgb.DMatrix(X_test, label=y_test)
    
    predictions = model.predict(valid)
    mse = mean_squared_error(y_test, predictions)
    rmse = mean_squared_error(y_test, predictions, squared=False)
    return mse, rmse


# Make predictions
def predict_model(model, data):
    valid = xgb.DMatrix(data)
    predictions = model.predict(valid)
    return predictions
    

In [39]:
# Generate the test.json for pytest with the goog stock data from 2024-01-02
ticker = 'goog'
start_date = '2024-01-02'
end_date = '2024-01-03'

stock_data = download_stock_data(ticker, start_date, end_date)
stock_data.reset_index(inplace=True)
stock_data['Date'] = stock_data['Date'].astype(str)
stock_data.to_json('test.json', orient = 'records')

[*********************100%%**********************]  1 of 1 completed


In [3]:
# Download stock data
# amzn / goog / msft
ticker = 'goog'
start_date = '2024-01-02'
end_date = '2024-10-01'

stock_data = download_stock_data(ticker, start_date, end_date)

[*********************100%%**********************]  1 of 1 completed


In [4]:
df = (
    pd.DataFrame(stock_data)
    .reset_index()
    .rename(
        columns={
            "Date": "date",
            "Open": "open",
            "High": "high",
            "Low": "low",
            "Close": "close",
            "Adj Close": "adj_close",
            "Volume": "volume",
        }
    )
)
df["symbol"] = ticker
df = df[["date", "symbol", "open", "high", "low", "close", "adj_close", "volume"]]

In [5]:
# Display the DataFrame
df.head()

Unnamed: 0,date,symbol,open,high,low,close,adj_close,volume
0,2024-01-02,goog,139.600006,140.615005,137.740005,139.559998,139.218109,20071900
1,2024-01-03,goog,138.600006,141.089996,138.429993,140.360001,140.016144,18974300
2,2024-01-04,goog,139.850006,140.634995,138.009995,138.039993,137.701828,18253300
3,2024-01-05,goog,138.352005,138.809998,136.850006,137.389999,137.053421,15433200
4,2024-01-08,goog,138.0,140.639999,137.880005,140.529999,140.185715,17645300


In [6]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 188 entries, 0 to 187
Data columns (total 8 columns):
 #   Column     Non-Null Count  Dtype         
---  ------     --------------  -----         
 0   date       188 non-null    datetime64[ns]
 1   symbol     188 non-null    object        
 2   open       188 non-null    float64       
 3   high       188 non-null    float64       
 4   low        188 non-null    float64       
 5   close      188 non-null    float64       
 6   adj_close  188 non-null    float64       
 7   volume     188 non-null    int64         
dtypes: datetime64[ns](1), float64(5), int64(1), object(1)
memory usage: 11.9+ KB


In [45]:
# Save historical data to csv.
df.to_csv(f'stocks_{ticker}_{start_date}_to_{end_date}.csv')

In [9]:
# Create connection to postgres instance
db_uri = 'postgresql://airflow:albeta2023@mlops-rds-instance.crcqa0ua6cb3.us-east-1.rds.amazonaws.com:5432/stocks'
engine = create_engine(db_uri)

In [10]:
# Insert historical data into postgres table 'stock_ohlc'
df.to_sql('stock_ohlc', con=engine, if_exists='append', index=False) 

188

In [35]:
stock_data.head()

Unnamed: 0,Date,Open,High,Low,Close,Adj Close,Volume
0,2024-01-02,139.600006,140.615005,137.740005,139.559998,139.401367,20071900


In [64]:
# Preprocess data
# For every date, build the close price sequence of 10 consecutive days
# Each sequence is conform by 'Date' + last 10 close prices + 'Adj Close'
sequence_length = 10
data_sequences = preprocess_data(stock_data, sequence_length)
data_sequences[-2:]

[array([Timestamp('2024-06-27 00:00:00'), 179.55999755859375,
        176.74000549316406, 178.3699951171875, 178.77999877929688,
        176.4499969482422, 177.7100067138672, 180.25999450683594,
        180.7899932861328, 185.5800018310547, 185.3699951171875,
        186.86000061035156], dtype=object),
 array([Timestamp('2024-06-28 00:00:00'), 176.74000549316406,
        178.3699951171875, 178.77999877929688, 176.4499969482422,
        177.7100067138672, 180.25999450683594, 180.7899932861328,
        185.5800018310547, 185.3699951171875, 186.86000061035156,
        183.4199981689453], dtype=object)]

In [71]:
# Create pandas dataframe from the sequences
data_sequences_df = pd.DataFrame(data=data_sequences, columns=['date','d10','d9','d8','d7','d6','d5','d4','d3','d2','d1','adj_close']) 
data_sequences_df.head()

Unnamed: 0,date,d10,d9,d8,d7,d6,d5,d4,d3,d2,d1,adj_close
0,2024-01-17,139.401367,140.200455,137.883087,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585
1,2024-01-18,140.200455,137.883087,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195
2,2024-01-19,137.883087,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195,147.801804
3,2024-01-22,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195,147.801804,147.542114
4,2024-01-23,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195,147.801804,147.542114,148.511002


In [24]:
# Split data into training and testing sets
train_data, test_data = split_data(data_sequences_df)

In [25]:
train_data.describe()

Unnamed: 0,date,d10,d9,d8,d7,d6,d5,d4,d3,d2,d1,adj_close
count,91,91.0,91.0,91.0,91.0,91.0,91.0,91.0,91.0,91.0,91.0,91.0
mean,2024-03-21 15:33:37.582417664,149.893933,150.237937,150.584466,150.977864,151.395411,151.798908,152.192965,152.585266,152.96209,153.300386,153.654379
min,2024-01-17 00:00:00,132.409317,132.409317,132.409317,132.409317,132.409317,132.409317,132.409317,132.409317,132.409317,132.409317,132.409317
25%,2024-02-18 00:00:00,142.47287,142.637688,143.052208,143.441765,143.571625,143.656525,143.726448,143.846313,143.99614,144.046082,144.470596
50%,2024-03-21 00:00:00,148.311218,148.511002,148.560944,148.570938,149.509857,150.049255,150.179108,150.978195,151.527573,151.597488,151.767303
75%,2024-04-23 12:00:00,156.172287,156.257187,156.511902,156.991356,157.380913,157.625633,157.770462,157.865356,158.484657,159.373642,160.172722
max,2024-05-24 00:00:00,173.492569,173.492569,173.492569,173.682358,175.230591,177.08847,178.257156,179.335922,179.335922,179.335922,179.335922
std,,10.028395,10.19992,10.389032,10.578796,10.778278,11.044948,11.341521,11.655343,11.910105,12.090741,12.282982


In [26]:
test_data.describe()

Unnamed: 0,date,d10,d9,d8,d7,d6,d5,d4,d3,d2,d1,adj_close
count,23,23.0,23.0,23.0,23.0,23.0,23.0,23.0,23.0,23.0,23.0,23.0
mean,2024-06-12 09:23:28.695652096,176.00012,176.333349,176.639672,176.760004,176.867804,177.005697,177.11582,177.387302,177.716533,178.238229,178.555204
min,2024-05-28 00:00:00,170.705734,171.734573,173.362717,173.362717,173.362717,173.362717,173.362717,173.362717,173.362717,173.362717,173.362717
25%,2024-06-04 12:00:00,174.541374,174.895973,175.080765,175.490295,175.939789,175.939789,175.939789,175.939789,175.939789,176.289787,176.540001
50%,2024-06-12 00:00:00,176.630005,176.740005,176.868744,176.868744,177.08847,177.198349,177.198349,177.198349,177.198349,177.710007,177.817657
75%,2024-06-20 12:00:00,177.807663,177.982468,178.16864,178.16864,178.16864,178.223579,178.279999,178.279999,178.574997,179.169998,179.909996
max,2024-06-28 00:00:00,179.559998,179.559998,179.559998,179.559998,179.559998,180.259995,180.789993,185.580002,185.580002,186.860001,186.860001
std,,2.325649,2.067299,1.867195,1.753676,1.731454,1.87053,2.016424,2.649782,3.129998,3.597478,3.722256


In [66]:
# Setup numerical features and the target (there are no categorical features)
num_features = ['d10','d9','d8','d7','d6','d5','d4','d3','d2','d1']
target = ['adj_close']

In [28]:
# Prepare training data
X_train = train_data[num_features]
y_train = train_data[target]

# Prepare test data
X_test = test_data[num_features]
y_test = test_data[target]

In [29]:
dump_pickle((X_train, y_train), "train_goog.pkl")
dump_pickle((X_test, y_test), "test_goog.pkl")

In [19]:
# Find best_params for XGBoost
best_params = run_optimization(X_train, y_train, X_test, y_test)

[0]	validation-rmse:15.31032                          
[1]	validation-rmse:10.61778                          
[2]	validation-rmse:8.12751                           
[3]	validation-rmse:7.03434                           
[4]	validation-rmse:6.18613                           
  0%|          | 0/30 [00:00<?, ?trial/s, best loss=?]




[5]	validation-rmse:5.64713                           
[6]	validation-rmse:5.35651                           
[7]	validation-rmse:5.45521                           
[8]	validation-rmse:5.51340                           
[9]	validation-rmse:5.31227                           
[10]	validation-rmse:5.21957                          
[11]	validation-rmse:5.31525                          
[12]	validation-rmse:5.37918                          
[13]	validation-rmse:5.45131                          
[14]	validation-rmse:5.29302                          
[15]	validation-rmse:5.20689                          
[16]	validation-rmse:5.32144                          
[17]	validation-rmse:5.21815                          
[18]	validation-rmse:5.30336                          
[19]	validation-rmse:5.21216                          
[20]	validation-rmse:5.26786                          
[21]	validation-rmse:5.18847                          
[22]	validation-rmse:5.15024                          
[23]	valid




[1]	validation-rmse:16.12038                                                   
[2]	validation-rmse:13.37033                                                   
[3]	validation-rmse:11.43581                                                   
[4]	validation-rmse:9.98118                                                    
[5]	validation-rmse:8.69913                                                    
[6]	validation-rmse:7.62577                                                    
[7]	validation-rmse:6.94170                                                    
[8]	validation-rmse:6.36489                                                    
[9]	validation-rmse:5.92992                                                    
[10]	validation-rmse:5.62107                                                   
[11]	validation-rmse:5.37781                                                   
[12]	validation-rmse:5.32636                                                   
[13]	validation-rmse:5.17017            




[2]	validation-rmse:11.24868                                                   
[3]	validation-rmse:8.85637                                                    
[4]	validation-rmse:7.11218                                                    
[5]	validation-rmse:5.95012                                                    
[6]	validation-rmse:5.10647                                                    
[7]	validation-rmse:4.56385                                                    
[8]	validation-rmse:4.18238                                                    
[9]	validation-rmse:3.93133                                                    
[10]	validation-rmse:3.86353                                                   
[11]	validation-rmse:3.74233                                                   
[12]	validation-rmse:3.72522                                                   
[13]	validation-rmse:3.66004                                                   
[14]	validation-rmse:3.65846            




[3]	validation-rmse:19.88913
[4]	validation-rmse:18.79666                                                  
[5]	validation-rmse:17.75756                                                  
[6]	validation-rmse:16.74755                                                  
[7]	validation-rmse:15.85015                                                  
[8]	validation-rmse:15.10802                                                  
[9]	validation-rmse:14.32102                                                  
[10]	validation-rmse:13.58475                                                 
[11]	validation-rmse:12.98433                                                 
[12]	validation-rmse:12.34102                                                 
[13]	validation-rmse:11.73896                                                 
[14]	validation-rmse:11.25591                                                 
[15]	validation-rmse:10.72989                                                 
[16]	validation-rmse:10




[4]	validation-rmse:13.13381                                                    
[5]	validation-rmse:11.80568                                                    
[6]	validation-rmse:10.66945                                                    
[7]	validation-rmse:9.70426                                                     
[8]	validation-rmse:8.73003                                                     
[9]	validation-rmse:7.90538                                                     
[10]	validation-rmse:7.19180                                                    
[11]	validation-rmse:6.60322                                                    
[12]	validation-rmse:6.10417                                                    
[13]	validation-rmse:5.74938                                                    
[14]	validation-rmse:5.39504                                                    
[15]	validation-rmse:5.09898                                                    
[16]	validation-rmse:4.84663




[5]	validation-rmse:3.50662                                                     
[6]	validation-rmse:3.51260                                                     
[7]	validation-rmse:3.51821                                                     
[8]	validation-rmse:3.52975                                                     
[9]	validation-rmse:3.53020                                                     
[10]	validation-rmse:3.52876                                                    
[11]	validation-rmse:3.52876                                                    
[12]	validation-rmse:3.52679                                                    
[13]	validation-rmse:3.52739                                                    
[14]	validation-rmse:3.52796                                                    
[15]	validation-rmse:3.52862                                                    
[16]	validation-rmse:3.52831                                                    
[17]	validation-rmse:3.52810




[0]	validation-rmse:20.17781                                                    
[1]	validation-rmse:16.20722                                                    
[2]	validation-rmse:13.15716                                                    
[3]	validation-rmse:10.77144                                                    
[4]	validation-rmse:8.73867                                                     
[5]	validation-rmse:7.45874                                                     
[6]	validation-rmse:6.59117                                                     
[7]	validation-rmse:5.72713                                                     
[8]	validation-rmse:5.21396                                                     
[9]	validation-rmse:4.78859                                                     
[10]	validation-rmse:4.49801                                                    
[11]	validation-rmse:4.43559                                                    
[12]	validation-rmse:4.29355




[7]	validation-rmse:14.80886                                                    
[8]	validation-rmse:13.90603                                                    
[9]	validation-rmse:13.07304                                                    
[10]	validation-rmse:12.32884                                                   
[11]	validation-rmse:11.62094                                                   
[12]	validation-rmse:10.98951                                                   
[13]	validation-rmse:10.38768                                                   
[14]	validation-rmse:9.83557                                                    
[15]	validation-rmse:9.34439                                                    
[16]	validation-rmse:8.87584                                                    
[17]	validation-rmse:8.44616                                                    
[18]	validation-rmse:8.07227                                                    
[19]	validation-rmse:7.70963




[4]	validation-rmse:3.66735                                                     
[5]	validation-rmse:3.66735                                                     
[6]	validation-rmse:3.66735                                                     
[7]	validation-rmse:3.66735                                                     
[8]	validation-rmse:3.66735                                                     
[9]	validation-rmse:3.66735                                                     
[10]	validation-rmse:3.66735                                                    
[11]	validation-rmse:3.66735                                                    
[12]	validation-rmse:3.66735                                                    
[13]	validation-rmse:3.66735                                                    
[14]	validation-rmse:3.66735                                                    
[15]	validation-rmse:3.66735                                                    
[16]	validation-rmse:3.66735




[8]	validation-rmse:4.02131                                                     
[9]	validation-rmse:3.91161                                                     
[10]	validation-rmse:3.87218                                                    
[11]	validation-rmse:3.82919                                                    
[12]	validation-rmse:3.79467                                                    
[13]	validation-rmse:3.79067                                                    
[14]	validation-rmse:3.77472                                                    
[15]	validation-rmse:3.76325                                                    
[16]	validation-rmse:3.76086                                                    
[17]	validation-rmse:3.75466                                                    
[18]	validation-rmse:3.74875                                                    
[19]	validation-rmse:3.74448                                                    
[20]	validation-rmse:3.74448




[1]	validation-rmse:4.54390                                                      
[2]	validation-rmse:4.13179                                                      
[3]	validation-rmse:4.04172                                                      
[4]	validation-rmse:4.03185                                                      
[5]	validation-rmse:4.00180                                                      
[6]	validation-rmse:4.00493                                                      
[7]	validation-rmse:4.00809                                                      
[8]	validation-rmse:4.00788                                                      
[9]	validation-rmse:4.00673                                                      
[10]	validation-rmse:4.00676                                                     
[11]	validation-rmse:4.00676                                                     
[12]	validation-rmse:4.00676                                                     
[13]	validation-




[10]	validation-rmse:9.49723                                                     
[11]	validation-rmse:8.87432                                                     
[12]	validation-rmse:8.42115                                                     
[13]	validation-rmse:8.01814                                                     
[14]	validation-rmse:7.56261                                                     
[15]	validation-rmse:7.07828                                                     
[16]	validation-rmse:6.65179                                                     
[17]	validation-rmse:6.37218                                                     
[18]	validation-rmse:6.03228                                                     
[19]	validation-rmse:5.74324                                                     
[20]	validation-rmse:5.49063                                                     
[21]	validation-rmse:5.27017                                                     
[22]	validation-




[0]	validation-rmse:12.43220                                                     
[1]	validation-rmse:7.19005                                                      
[2]	validation-rmse:4.89774                                                      
[3]	validation-rmse:4.01983                                                      
[4]	validation-rmse:3.88030                                                      
[5]	validation-rmse:3.69442                                                      
[6]	validation-rmse:3.70088                                                      
[7]	validation-rmse:3.78797                                                      
[8]	validation-rmse:3.73981                                                      
[9]	validation-rmse:3.78934                                                      
[10]	validation-rmse:3.83085                                                     
[11]	validation-rmse:3.85580                                                     
[12]	validation-




[3]	validation-rmse:4.57624                                                      
[4]	validation-rmse:3.98673                                                      
[5]	validation-rmse:3.78154                                                      
[6]	validation-rmse:3.75371                                                      
[7]	validation-rmse:3.71946                                                      
[8]	validation-rmse:3.71573                                                      
[9]	validation-rmse:3.71548                                                      
[10]	validation-rmse:3.71497                                                     
[11]	validation-rmse:3.71497                                                     
[12]	validation-rmse:3.71497                                                     
[13]	validation-rmse:3.71497                                                     
[14]	validation-rmse:3.71497                                                     
[15]	validation-




[7]	validation-rmse:14.86476                                                     
[8]	validation-rmse:13.98702                                                     
[9]	validation-rmse:13.10121                                                     
[10]	validation-rmse:12.42429                                                    
[11]	validation-rmse:11.67892                                                    
[12]	validation-rmse:10.97172                                                    
[13]	validation-rmse:10.37194                                                    
[14]	validation-rmse:9.77802                                                     
[15]	validation-rmse:9.23009                                                     
[16]	validation-rmse:8.70725                                                     
[17]	validation-rmse:8.24024                                                     
[18]	validation-rmse:7.85108                                                     
[19]	validation-




[5]	validation-rmse:4.73890                                                      
[6]	validation-rmse:4.73890                                                      
[7]	validation-rmse:4.73890                                                      
[8]	validation-rmse:4.69460                                                      
[9]	validation-rmse:4.69460                                                      
[10]	validation-rmse:4.68294                                                     
[11]	validation-rmse:4.68294                                                     
[12]	validation-rmse:4.68294                                                     
[13]	validation-rmse:4.68233                                                     
[14]	validation-rmse:4.68233                                                     
[15]	validation-rmse:4.68233                                                     
[16]	validation-rmse:4.68233                                                     
[17]	validation-




[1]	validation-rmse:20.35055                                                     
[2]	validation-rmse:18.49480                                                     
[3]	validation-rmse:16.64754                                                     
[4]	validation-rmse:15.20355                                                     
[5]	validation-rmse:13.77999                                                     
[6]	validation-rmse:12.53651                                                     
[7]	validation-rmse:11.43651                                                     
[8]	validation-rmse:10.48725                                                     
[9]	validation-rmse:9.66205                                                      
[10]	validation-rmse:8.92257                                                     
[11]	validation-rmse:8.29178                                                     
[12]	validation-rmse:7.74725                                                     
[13]	validation-




[3]	validation-rmse:19.12826                                                     
[4]	validation-rmse:17.76420                                                     
[5]	validation-rmse:16.50907                                                     
[6]	validation-rmse:15.50528                                                     
[7]	validation-rmse:14.58187                                                     
[8]	validation-rmse:13.61148                                                     
[9]	validation-rmse:12.72908                                                     
[10]	validation-rmse:11.99060                                                    
[11]	validation-rmse:11.33712                                                    
[12]	validation-rmse:10.71619                                                    
[13]	validation-rmse:10.16803                                                    
[14]	validation-rmse:9.64712                                                     
[15]	validation-




[15]	validation-rmse:4.64955
[16]	validation-rmse:4.42961                                                     
[17]	validation-rmse:4.22280                                                     
[18]	validation-rmse:4.08357                                                     
[19]	validation-rmse:3.99015                                                     
[20]	validation-rmse:3.94642                                                     
[21]	validation-rmse:3.85455                                                     
[22]	validation-rmse:3.83929                                                     
[23]	validation-rmse:3.83163                                                     
[24]	validation-rmse:3.80148                                                     
[25]	validation-rmse:3.75984                                                     
[26]	validation-rmse:3.75030                                                     
[27]	validation-rmse:3.75480                                         




[6]	validation-rmse:15.43581                                                     
[7]	validation-rmse:14.71226                                                     
[8]	validation-rmse:13.80991                                                     
[9]	validation-rmse:13.05283                                                     
[10]	validation-rmse:12.35392                                                    
[11]	validation-rmse:11.70385                                                    
[12]	validation-rmse:11.12147                                                    
[13]	validation-rmse:10.60023                                                    
[14]	validation-rmse:10.13339                                                    
[15]	validation-rmse:9.71570                                                     
[16]	validation-rmse:9.34120                                                     
[17]	validation-rmse:9.04544                                                     
[18]	validation-




[1]	validation-rmse:18.02053                                                     
[2]	validation-rmse:15.14775                                                     
[3]	validation-rmse:12.96373                                                     
[4]	validation-rmse:10.99182                                                     
[5]	validation-rmse:9.70236                                                      
[6]	validation-rmse:8.64864                                                      
[7]	validation-rmse:7.59866                                                      
[8]	validation-rmse:6.75536                                                      
[9]	validation-rmse:6.08341                                                      
[10]	validation-rmse:5.50700                                                     
[11]	validation-rmse:5.05787                                                     
[12]	validation-rmse:4.71141                                                     
[13]	validation-




[7]	validation-rmse:4.44732                                                      
[8]	validation-rmse:4.14873                                                      
[9]	validation-rmse:4.07917                                                      
[10]	validation-rmse:3.94377                                                     
[11]	validation-rmse:3.83733                                                     
[12]	validation-rmse:3.76445                                                     
[13]	validation-rmse:3.71262                                                     
[14]	validation-rmse:3.67005                                                     
[15]	validation-rmse:3.62711                                                     
[16]	validation-rmse:3.60424                                                     
[17]	validation-rmse:3.57228                                                     
[18]	validation-rmse:3.55505                                                     
[19]	validation-




[0]	validation-rmse:21.26888                                                     
[1]	validation-rmse:18.31781                                                     
[2]	validation-rmse:15.57048                                                     
[3]	validation-rmse:13.38032                                                     
[4]	validation-rmse:11.54210                                                     
[5]	validation-rmse:10.16610                                                     
[6]	validation-rmse:8.88506                                                      
[7]	validation-rmse:7.89795                                                      
[8]	validation-rmse:7.36746                                                      
[9]	validation-rmse:6.55371                                                      
[10]	validation-rmse:5.96766                                                     
[11]	validation-rmse:5.69732                                                     
[12]	validation-




[8]	validation-rmse:4.12090                                                      
[9]	validation-rmse:4.05970                                                      
[10]	validation-rmse:3.95903                                                     
[11]	validation-rmse:3.88517                                                     
[12]	validation-rmse:3.83370                                                     
[13]	validation-rmse:3.79473                                                     
[14]	validation-rmse:3.76476                                                     
[15]	validation-rmse:3.74243                                                     
[16]	validation-rmse:3.71496                                                     
[17]	validation-rmse:3.68934                                                     
[18]	validation-rmse:3.67658                                                     
[19]	validation-rmse:3.66806                                                     
[20]	validation-




[8]	validation-rmse:7.18741                                                      
[9]	validation-rmse:6.87876                                                      
[10]	validation-rmse:6.61567                                                     
[11]	validation-rmse:6.10346                                                     
[12]	validation-rmse:5.65320                                                     
[13]	validation-rmse:5.47456                                                     
[14]	validation-rmse:5.32870                                                     
[15]	validation-rmse:5.17902                                                     
[16]	validation-rmse:4.94500                                                     
[17]	validation-rmse:4.78059                                                     
[18]	validation-rmse:4.61421                                                     
[19]	validation-rmse:4.47838                                                     
[20]	validation-




[11]	validation-rmse:3.58798                                                     
[12]	validation-rmse:3.55096                                                     
[13]	validation-rmse:3.53407                                                     
[14]	validation-rmse:3.52152                                                     
[15]	validation-rmse:3.50842                                                     
[16]	validation-rmse:3.50103                                                     
[17]	validation-rmse:3.49444                                                     
[18]	validation-rmse:3.48987                                                     
[19]	validation-rmse:3.48592                                                     
[20]	validation-rmse:3.48169                                                     
[21]	validation-rmse:3.47925                                                     
[22]	validation-rmse:3.46932                                                     
[23]	validation-




[0]	validation-rmse:20.57206                                                     
[1]	validation-rmse:16.82168                                                     
[2]	validation-rmse:13.87292                                                     
[3]	validation-rmse:11.50677                                                     
[4]	validation-rmse:9.57272                                                      
[5]	validation-rmse:8.13978                                                      
[6]	validation-rmse:6.93852                                                      
[7]	validation-rmse:6.06904                                                      
[8]	validation-rmse:5.33264                                                      
[9]	validation-rmse:4.81121                                                      
[10]	validation-rmse:4.40657                                                     
[11]	validation-rmse:4.12051                                                     
[12]	validation-




[11]	validation-rmse:4.14272                                                     
[12]	validation-rmse:4.04013                                                     
[13]	validation-rmse:4.04198                                                     
[14]	validation-rmse:4.05254                                                     
[15]	validation-rmse:4.06091                                                     
[16]	validation-rmse:4.02334                                                     
[17]	validation-rmse:4.04299                                                     
[18]	validation-rmse:4.06727                                                     
[19]	validation-rmse:4.03526                                                     
[20]	validation-rmse:4.01544                                                     
[21]	validation-rmse:4.03982                                                     
[22]	validation-rmse:4.06604                                                     
[23]	validation-




[7]	validation-rmse:17.67986                                                     
[8]	validation-rmse:16.93271                                                     
[9]	validation-rmse:16.24224                                                     
[10]	validation-rmse:15.62703                                                    
[11]	validation-rmse:14.98445                                                    
[12]	validation-rmse:14.37491                                                    
[13]	validation-rmse:13.79683                                                    
[14]	validation-rmse:13.24868                                                    
[15]	validation-rmse:12.74314                                                    
[16]	validation-rmse:12.24990                                                    
[17]	validation-rmse:11.77118                                                    
[18]	validation-rmse:11.31755                                                    
[19]	validation-




[2]	validation-rmse:5.41103                                                      
[3]	validation-rmse:4.44852                                                      
[4]	validation-rmse:4.31447                                                      
[5]	validation-rmse:4.08490                                                      
[6]	validation-rmse:4.14645                                                      
[7]	validation-rmse:4.08228                                                      
[8]	validation-rmse:4.00762                                                      
[9]	validation-rmse:4.07435                                                      
[10]	validation-rmse:4.12324                                                     
[11]	validation-rmse:4.06881                                                     
[12]	validation-rmse:4.12413                                                     
[13]	validation-rmse:4.17444                                                     
[14]	validation-

In [20]:
# Show best_params
best_params

{'learning_rate': 0.18236672959420272,
 'max_depth': 11.0,
 'min_child_weight': 1.2534175878035614,
 'reg_alpha': 0.02376073607996226,
 'reg_lambda': 0.3476822703285598}

In [21]:
# Save best_params to json
with open('hpo_best_params_xgboost.json', 'w') as f:
    json.dump(best_params, f)

In [22]:
# Train XGBoost regression model with the best_params
loaded_model = train_model(X_train, y_train, X_test, y_test, best_params)

[0]	validation-rmse:21.26888
[1]	validation-rmse:18.31781
[2]	validation-rmse:15.57048
[3]	validation-rmse:13.38032
[4]	validation-rmse:11.54210
[5]	validation-rmse:10.16610
[6]	validation-rmse:8.88506
[7]	validation-rmse:7.89795
[8]	validation-rmse:7.36746
[9]	validation-rmse:6.55371
[10]	validation-rmse:5.96766
[11]	validation-rmse:5.69732
[12]	validation-rmse:5.47492
[13]	validation-rmse:5.14104
[14]	validation-rmse:4.84689
[15]	validation-rmse:4.63373
[16]	validation-rmse:4.42947
[17]	validation-rmse:4.29384
[18]	validation-rmse:4.18033
[19]	validation-rmse:4.08614
[20]	validation-rmse:4.00042
[21]	validation-rmse:3.98092
[22]	validation-rmse:3.92377
[23]	validation-rmse:3.87568
[24]	validation-rmse:3.82950
[25]	validation-rmse:3.79173
[26]	validation-rmse:3.74982
[27]	validation-rmse:3.72033
[28]	validation-rmse:3.69571
[29]	validation-rmse:3.67376
[30]	validation-rmse:3.64447
[31]	validation-rmse:3.62795
[32]	validation-rmse:3.60782
[33]	validation-rmse:3.59472
[34]	validation-rm

In [23]:
# Evaluate the model peformance
evaluate_model(loaded_model, X_test, y_test)

(11.851912732286703, 3.4426607053682625)

In [24]:
# Save the trained model to pickle
with open(f'xgboost.bin', 'wb') as f:
    pickle.dump(loaded_model, f)

# Calculate predictions using the actual trained model

In [74]:
# Convert dataframe to DMatrix
data_sequences_df_dmtx = xgb.DMatrix(data_sequences_df[num_features])

# Calculate the predictions
data_sequences_df['prediction']=loaded_model.predict(data_sequences_df_dmtx)
data_sequences_df.head()

NameError: name 'loaded_model' is not defined

In [27]:
# Testing one sequence prediction
data = [178.369995, 178.779999, 176.449997,	177.710007, 180.259995, 180.789993, 185.580002, 185.369995, 186.860001, 183.419998]

# Reshape data to a 2D array with 1 row and 10 columns
data_dmtx = xgb.DMatrix(pd.DataFrame([data]), feature_names=("d10", "d9", "d8", "d7", "d6", "d5", "d4", "d3", "d2", "d1"))

# Predict using the model
prediction = loaded_model.predict(data_dmtx)

print(prediction)

[178.09828]


# Calculate predictions loading the "Production" model from S3 bucket

In [67]:
import mlflow
from mlflow.tracking import MlflowClient


mlflow.set_tracking_uri("http://mlflow:5000")
client = MlflowClient()

model_name = f"stock-goog-xgboost"
stage='production'

# Get the information for the latest version of the model in a given stage
latest_version_info = client.get_latest_versions(model_name, stages=[stage])
latest_stage_version = latest_version_info[0].version
 
# Get the model in the stage
model_stage_uri = f"models:/{model_name}/{latest_stage_version}"
print(f"model_stage_uri: {model_stage_uri}")

# Load model as a PyFuncModel.
model = mlflow.pyfunc.load_model(model_stage_uri)

  latest_version_info = client.get_latest_versions(model_name, stages=[stage])


model_stage_uri: models:/stock-goog-xgboost/50


Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

 - numpy (current: 1.24.4, required: numpy==1.26.4)
 - pandas (current: 2.1.1, required: pandas==2.1.4)
 - psutil (current: 5.9.5, required: psutil==6.0.0)
 - scikit-learn (current: 1.3.1, required: scikit-learn==1.5.1)
 - scipy (current: 1.11.3, required: scipy==1.14.0)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.


In [75]:
# Calculate the predictions
data_sequences_df['prediction']=model.predict(data_sequences_df[num_features])
data_sequences_df.head()

Unnamed: 0,date,d10,d9,d8,d7,d6,d5,d4,d3,d2,d1,adj_close,prediction
0,2024-01-17,139.401367,140.200455,137.883087,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,143.711349
1,2024-01-18,140.200455,137.883087,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195,145.445267
2,2024-01-19,137.883087,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195,147.801804,146.763611
3,2024-01-22,137.233841,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195,147.801804,147.542114,148.356293
4,2024-01-23,140.37027,142.397949,143.636551,143.506699,144.07605,143.916229,142.727585,144.825195,147.801804,147.542114,148.511002,147.855743


In [69]:
# Create the dataframe to insert into the stock_prediction table
data_sequences_df['symbol']='goog'
data_sequences_df['model']=model_stage_uri
stockprediction_df = data_sequences_df[['date', 'symbol', 'prediction', 'model']].copy()
stockprediction_df.head()

Unnamed: 0,date,symbol,prediction,model
0,2024-01-17,goog,143.711349,models:/stock-goog-xgboost/50
1,2024-01-18,goog,145.445267,models:/stock-goog-xgboost/50
2,2024-01-19,goog,146.763611,models:/stock-goog-xgboost/50
3,2024-01-22,goog,148.356293,models:/stock-goog-xgboost/50
4,2024-01-23,goog,147.855743,models:/stock-goog-xgboost/50


In [70]:
# Insert historical data into postgres table 'stock_prediction'
stockprediction_df.to_sql('stock_prediction', con=engine, if_exists='append', index=False) 

114

# Unitary prediction getting data from database

In [55]:
sequence_length = 10
prediction_date='2024-08-23'
symbol = 'goog'

sql=f"select t.adj_close from (select date, adj_close from stock_ohlc where symbol='{symbol}' and date < '{prediction_date}' order by date desc limit {sequence_length}) t order by t.date asc"
data = pd.read_sql_query(sql,con=engine)

#print(type(data))
#print(data.shape)
#print(data)

#data = np.array(data)
#print(type(data))
#print(data.shape)
#print(data)

# This is the format that We need to use for prediction
data = np.array(data).reshape((1,-1))
print(type(data))
print(data.shape)
print(data)


<class 'numpy.ndarray'>
(1, 10)
[[165.38999939 163.94999695 165.92999268 162.02999878 163.16999817
  164.74000549 168.3999939  168.96000671 167.63000488 165.49000549]]


# If We are using the actual trained model, convert data list to DMatrix
data_dmtx = xgb.DMatrix(data, feature_names=("d10", "d9", "d8", "d7", "d6", "d5", "d4", "d3", "d2", "d1"))
prediction = loaded_model.predict(data_dmtx)[0]
print(prediction)

In [57]:
# If We are using the Production model loaded from S3
prediction = model.predict(data)[0]
print(prediction)

170.82985


In [38]:
# Insert the new prediction in stock_prediction table
cols = ['date', 'symbol', 'prediction', 'model']
data_all = [prediction_date, symbol, prediction, model_stage_uri]
print(data_all)

df = pd.DataFrame([data_all], columns=cols)
df

['2024-08-16', 'goog', 178.09828, 'models:/stock-goog-xgboost/2']


Unnamed: 0,date,symbol,prediction,model
0,2024-08-16,goog,178.098282,models:/stock-goog-xgboost/2


In [39]:
# Insert unitary prediction into postgres table 'stock_predictions'
df.to_sql('stock_prediction', con=engine, if_exists='append', index=False) 

1