# Research Notebook: Development of SAC Reinforcement Learning Model for Stock Trading

## Team Members:
- Javon Kitson
- Nathan Metheny
- Kobe Mensah

---

## Notebook Introduction

This Jupyter notebook documents our research and development process for a Soft Actor-Critic (SAC) reinforcement learning model tailored to stock trading, and doubles as the workspace for that work. It is structured to guide the reader through each stage of model development, from data handling to final evaluation.


## Notebook Objectives

1. **Dataset Handling and Preparation**: To detail the process of acquiring, cleaning, and preparing the stock price data for model training.
2. **Feature Engineering**: To explore and define the set of features that will be used by the SAC model, enhancing its ability to make informed trading decisions.
3. **Model Development**: To document the iterative process of designing, training, and tuning the SAC model.
4. **Evaluation and Testing**: To assess the model's performance using historical data, focusing on its ability to make profitable trading decisions.

## Dataset
First Rate Data: https://firstratedata.com/cb/4/complete-stocks-etf \
Volume: 262GB \
Unique Tickers: 10120 (~7k Stocks + ~3k ETFs)

---

### Data Engineering

- Overview of the data source and its characteristics.
- Steps taken for data cleaning and preprocessing.
- Strategies for data storage and retrieval.

### Feature Engineering

- Identification and justification of chosen features for the model.
- Detailed explanation of the feature engineering process.
- Analysis of the impact of these features on the model’s performance.


### Model Development

- Detailed description of the SAC algorithm and its suitability for stock trading.
- Architecture of the Actor and Critic networks within the SAC framework.
- Hyperparameter selection and optimization process.
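
For reference, SAC optimizes a maximum-entropy objective. The standard form from the SAC literature (not something specific to this notebook) adds a policy-entropy bonus, weighted by a temperature $\alpha$, to the expected return. Here $\rho_\pi$ denotes the state-action distribution induced by the policy $\pi$, and $\mathcal{H}$ is the entropy:

```latex
J(\pi) = \sum_{t=0}^{T} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}
         \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]
```

A larger $\alpha$ favors exploration (more random actions); as $\alpha \to 0$ the objective reduces to the usual expected-return objective.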

### Training Process

- Methodology for training the SAC model using the prepared dataset.
- Techniques employed for improving model performance and avoiding overfitting.
- Continuous evaluation during the training phase.

### Model Evaluation and Testing

- Criteria and metrics for model evaluation, including profitability and Sharpe Ratio.
- Testing process using the validation and test datasets.
- Comparison of model performance against traditional trading strategies.
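
As a point of reference for the Sharpe Ratio metric listed above, the following is a minimal sketch of an annualized Sharpe ratio computed from per-period simple returns. The function name `sharpe_ratio`, the per-period `risk_free_rate`, and the default of 252 periods per year (daily returns) are assumptions made for illustration; the notebook does not yet define how the metric will be computed, and 1-minute bars would need a different annualization factor.

```python
import math
import statistics


def sharpe_ratio(returns, risk_free_rate=0.0, periods_per_year=252):
    """Annualized Sharpe ratio from per-period simple returns.

    risk_free_rate is expressed per period. periods_per_year=252 assumes
    daily returns (an assumption; adjust it for other bar sizes).
    """
    excess = [r - risk_free_rate for r in returns]
    mean_excess = statistics.mean(excess)
    std_excess = statistics.stdev(excess)  # sample standard deviation
    if std_excess == 0:
        raise ValueError("Sharpe ratio is undefined for zero-variance returns")
    return mean_excess / std_excess * math.sqrt(periods_per_year)


# Illustrative returns only, not results from the model.
daily_returns = [0.01, -0.005, 0.002, 0.007, -0.003]
print(round(sharpe_ratio(daily_returns), 2))
```

The zero-variance guard matters in practice: a strategy that never trades produces constant returns, for which the ratio is undefined rather than infinite.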

## Environment Setup

In [1]:
# AWS CLI and AWS Python SDK (boto3)
!pip install --disable-pip-version-check -q awscli==1.18.216 boto3==1.16.56 botocore==1.19.56

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
aiobotocore 2.5.2 requires botocore<1.29.162,>=1.29.161, but you have botocore 1.19.56 which is incompatible.

In [2]:
# SageMaker
!pip install --disable-pip-version-check -q sagemaker==2.29.0
!pip install --disable-pip-version-check -q smdebug==1.0.1
!pip install --disable-pip-version-check -q sagemaker-experiments==0.1.26

In [3]:
# PyAthena
!pip install --disable-pip-version-check -q PyAthena==2.1.0

In [4]:
# AWS Data Wrangler
!pip install --disable-pip-version-check -q awswrangler

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
aiobotocore 2.5.2 requires botocore<1.29.162,>=1.29.161, but you have botocore 1.34.64 which is incompatible.
awscli 1.18.216 requires botocore==1.19.56, but you have botocore 1.34.64 which is incompatible.
awscli 1.18.216 requires s3transfer<0.4.0,>=0.3.0, but you have s3transfer 0.10.1 which is incompatible.

In [5]:
# Zip
!conda install -y zip

Retrieving notices: ...working... done
Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 23.7.2
  latest version: 24.1.2

Please update conda by running

    $ conda update -n base -c conda-forge conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.1.2



## Package Plan ##

  environment location: /Users/npc/miniforge3/envs/trading

  added / updated specs:
    - zip


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    zip-3.0                    |       hb547adb_3         170 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         170 KB

The following NEW packages will be INSTALLED:

  zip                conda-forge/osx-arm64::zip-3.0-hb547adb_3 

The following packages will be UPDATED:

  ca-certific

In [6]:
!python -m pip install TA-Lib



In [7]:
!conda install -y -c conda-forge ta-lib

Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 23.7.2
  latest version: 24.1.2

Please update conda by running

    $ conda update -n base -c conda-forge conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.1.2



## Package Plan ##

  environment location: /Users/npc/miniforge3/envs/trading

  added / updated specs:
    - ta-lib


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    libta-lib-0.4.0            |       hb547adb_2         236 KB  conda-forge
    ta-lib-0.4.28              |   py39hf4a74a7_0         302 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         538 KB

The following NEW packages will be INSTALLED:

  libta-lib          conda-forge/osx-arm64::libta-lib-0.4.0-hb547adb_2 
  ta-

In [8]:
!conda install -y -c conda-forge libta-lib

Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 23.7.2
  latest version: 24.1.2

Please update conda by running

    $ conda update -n base -c conda-forge conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.1.2



# All requested packages already installed.



In [9]:
# Matplotlib
!pip install --disable-pip-version-check -q matplotlib==3.1.3

  error: subprocess-exited-with-error

  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [146 lines of output]
      Edit setup.cfg to change the build options

      BUILDING MATPLOTLIB
        matplotlib: yes [3.1.3]
            python: yes [3.9.15 | packaged by conda-forge | (main, Nov 22 2022,
                        08:48:25)  [Clang 14.0.6 ]]
          platform: yes [darwin]

      OPTIONAL SUBPACKAGES
       sample_data: yes [installing]
             tests: no  [skipping due to configuration]

      OPTIONAL BACKEND EXTENSIONS
               agg: yes [installing]
             tkagg: yes [installing; run-time loading from Python Tcl/Tk]
            macosx: yes [installing, darwin]

      OPTION

In [10]:
# Seaborn
!pip install --disable-pip-version-check -q seaborn==0.10.0

In [11]:
!pip install protobuf==3.20.0 holidays

Collecting protobuf==3.20.0
  Obtaining dependency information for protobuf==3.20.0 from https://files.pythonhosted.org/packages/ae/80/9eaa62a2afcc5407a6b7d2652c208f073df3a5c83b5bff90bf99553fbcf2/protobuf-3.20.0-py2.py3-none-any.whl.metadata
  Downloading protobuf-3.20.0-py2.py3-none-any.whl.metadata (720 bytes)
Collecting holidays
  Obtaining dependency information for holidays from https://files.pythonhosted.org/packages/91/7a/2c5c043e4a7cff3dbab6b0f3a79b492e76c7dc1a06f309897509c9d467f5/holidays-0.44-py3-none-any.whl.metadata
  Downloading holidays-0.44-py3-none-any.whl.metadata (22 kB)
Downloading protobuf-3.20.0-py2.py3-none-any.whl (162 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 162.1/162.1 kB 5.0 MB/s eta 0:00:00
Downloading holidays-0.44-py3-none-any.whl (922 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 922.7/922.7 kB 9.6 MB/s eta 0:00:00

In [14]:
!python --version
!pip list

Python 3.10.6
Package                              Version
------------------------------------ --------------------
alabaster                            0.7.12
anaconda-client                      1.11.0
anaconda-project                     0.11.1
anyio                                3.5.0
appdirs                              1.4.4
argon2-cffi                          21.3.0
argon2-cffi-bindings                 21.2.0
arrow                                1.2.2
astroid                              3.0.1
astropy                              6.0.0
astropy-iers-data                    0.2023.12.4.0.30.20
asttokens                            2.4.1
atomicwrites                         1.4.0
attrs                                23.1.0
Automat                              20.2.0
autopep8                             1.6.0
autovizwidget                        0.20.4
awscli                               1.18.216
awswrangler                          3.6.0
Babel                                2.9.

## Imports

In [1]:
import os
import boto3
import talib
import warnings
import holidays
import sagemaker
import pandas as pd
import awswrangler as wr
from pyathena import connect
from sagemaker.session import Session
from time import time, gmtime, strftime
from sagemaker.feature_store.feature_group import FeatureGroup
from matplotlib import pyplot as plt

warnings.filterwarnings("ignore")

ModuleNotFoundError: No module named 'boto3'

## Data Engineering

In [None]:
sess = sagemaker.Session()
bucket = "stockdata90210"
role = sagemaker.get_execution_role()
region = boto3.Session().region_name
account_id = boto3.client("sts").get_caller_identity().get("Account")

sm = boto3.Session().client(service_name="sagemaker", region_name=region)

Couldn't call 'get_role' to get Role ARN from role name SageMaker-ExecutionRole-20240203T174094 to get Role path.


In [None]:
print("Default bucket: {}".format(bucket))

Default bucket: stockdata90210


In [6]:
!aws s3 ls "stockdata90210"

                           PRE 1day/
                           PRE 1min/
                           PRE athena/
                           PRE company_profile/
                           PRE stock_dividends/
                           PRE universe/


In [None]:
database_name = "stockdata"

In [8]:
# Set S3 staging directory -- this is a temporary directory used for Athena queries
s3_staging_dir = "s3://{0}/athena/staging".format(bucket)

In [None]:
conn = connect(region_name=region, s3_staging_dir=s3_staging_dir)

### Create Athena Database

In [61]:
statement = "CREATE DATABASE IF NOT EXISTS {}".format(database_name)
print(statement)

CREATE DATABASE IF NOT EXISTS stockdata


In [None]:
pd.read_sql(statement, conn)

  pd.read_sql(statement, conn)


### Verify The Database Has Been Created Successfully

In [None]:
statement = "SHOW DATABASES"

df_show = pd.read_sql(statement, conn)
df_show.head(5)

  df_show = pd.read_sql(statement, conn)


Unnamed: 0,database_name
0,default
1,stockdata


### Create Meta Data Tables

#### Stock Universe

##### Create Athena Table

In [None]:
database_name = "stockdata"
table_name_tsv = "universe"
s3_private_path_tsv = "s3://{}/universe/".format(bucket)
print(s3_private_path_tsv)

s3://stockdata90210/universe/


In [24]:
# SQL statement to execute
statement = """CREATE EXTERNAL TABLE IF NOT EXISTS {}.{}(
         Ticker string,
         Name string,
         First_Date string,
         Last_Date string
) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' LOCATION '{}'
TBLPROPERTIES ('skip.header.line.count'='1')""".format(
    database_name, table_name_tsv, s3_private_path_tsv
)

print(statement)

CREATE EXTERNAL TABLE IF NOT EXISTS stockdata.universe(
         Ticker string,
         Name string,
         First_Date string,
         Last_Date string
) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' LOCATION 's3://stockdata90210/universe/'
TBLPROPERTIES ('skip.header.line.count'='1')


In [25]:
pd.read_sql(statement, conn)

  pd.read_sql(statement, conn)


##### Verify The Table Has Been Created Successfully

In [85]:
statement = "SHOW TABLES in {}".format(database_name)

df_show = pd.read_sql(statement, conn)
df_show.head(5)

  df_show = pd.read_sql(statement, conn)


Unnamed: 0,tab_name
0,universe


##### Run A Sample Query

In [86]:
statement = """SELECT * FROM {}.{} LIMIT 100""".format(
    database_name, table_name_tsv
)

In [87]:
print(statement)
df = pd.read_sql(statement, conn)
df.head(5)

SELECT * FROM stockdata.universe LIMIT 100


  df = pd.read_sql(statement, conn)


Unnamed: 0,ticker,name,first_date,last_date
0,A,Agilent Technologies Inc,2005-01-03,2023-04-21
1,AA,Alcoa Corporation,2016-10-18,2023-04-21
2,AACG,Ata Creativity Global American Depositary Shares,2008-01-29,2023-04-21
3,AADI,Aadi Bioscience,2017-08-08,2023-04-21
4,AAIC,Arlington Asset Investment Class A,2009-06-10,2023-04-21


#### Company Profiles

##### Create Athena Table

In [None]:
database_name = "stockdata"
table_name_tsv = "company_profile"
s3_private_path_tsv = "s3://{}/company_profile/".format(bucket)
print(s3_private_path_tsv)

In [112]:
# SQL statement to execute
statement = """CREATE EXTERNAL TABLE IF NOT EXISTS {}.{}(
         `Ticker` string,
         `Company Name` string,
         `Country` string,
         `State` string,
         `Exchange` string,
         `Sector` string,
         `Industry` string,
         `Ipo Date` string
) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' LOCATION '{}'
TBLPROPERTIES ('skip.header.line.count'='1')""".format(
    database_name, table_name_tsv, s3_private_path_tsv
)

print(statement)

CREATE EXTERNAL TABLE IF NOT EXISTS stockdata.company_profile(
         `Ticker` string,
         `Company Name` string,
         `Country` string,
         `State` string,
         `Exchange` string,
         `Sector` string,
         `Industry` string,
         `Ipo Date` string
) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' LOCATION 's3://stockdata90210/company_profile/'
TBLPROPERTIES ('skip.header.line.count'='1')


In [113]:
pd.read_sql(statement, conn)

  pd.read_sql(statement, conn)


##### Verify The Table Has Been Created Successfully

In [None]:
statement = "SHOW TABLES in {}".format(database_name)

df_show = pd.read_sql(statement, conn)
df_show.head(5)

##### Run A Sample Query

In [99]:
statement = """SELECT * FROM {}.{} LIMIT 100""".format(
    database_name, table_name_tsv
)

In [100]:
print(statement)
df = pd.read_sql(statement, conn)
df.head(5)

SELECT * FROM stockdata.universe LIMIT 100


  df = pd.read_sql(statement, conn)


Unnamed: 0,ticker,name,first_date,last_date
0,A,Agilent Technologies Inc,2005-01-03,2023-04-21
1,AA,Alcoa Corporation,2016-10-18,2023-04-21
2,AACG,Ata Creativity Global American Depositary Shares,2008-01-29,2023-04-21
3,AADI,Aadi Bioscience,2017-08-08,2023-04-21
4,AAIC,Arlington Asset Investment Class A,2009-06-10,2023-04-21


#### Stock Tables

##### Move tickers to individual folders

In [202]:
s3_client = boto3.client('s3')
original_prefix = '1min/Stocks/'

def move_files_to_own_folder():
    paginator = s3_client.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=original_prefix)

    for page in page_iterator:
        if 'Contents' in page:  # Check if the page has contents
            for item in page['Contents']:
                file_key = item['Key']
                if file_key.endswith('/'):  # Skip directories
                    continue

                # Skip files that are already in the correct format
                path_parts = file_key.split('/')
                if len(path_parts) == 4 and path_parts[-1].startswith(path_parts[-2]):
                    print(f"Skipping already processed file: {file_key}")
                    continue

                print(f"Processing: {file_key}")
                
                # Extract the file name
                file_name = os.path.basename(file_key)
                # Create a new folder name by stripping off '_full_1min_adjsplitdiv.txt'
                new_folder_base = file_name.replace('_full_1min_adjsplitdiv.txt', '')
                new_folder_path = f"{original_prefix}{new_folder_base}/"
                # Define the new key for the file within its new folder
                new_file_key = f"{new_folder_path}{new_folder_base}.txt"
                print(f"New file path: {new_file_key}")

                # Copy the file to the new location
                copy_source = {'Bucket': bucket, 'Key': file_key}
                s3_client.copy_object(Bucket=bucket, CopySource=copy_source, Key=new_file_key)
                
                # Delete the original file
                s3_client.delete_object(Bucket=bucket, Key=file_key)
                
                print(f"Moved {file_key} to {new_file_key}")

move_files_to_own_folder()

Skipping already processed file: 1min/Stocks/A/A.txt
Skipping already processed file: 1min/Stocks/AA/AA.txt
Skipping already processed file: 1min/Stocks/AABA-DELISTED/AABA-DELISTED.txt
Skipping already processed file: 1min/Stocks/AACG/AACG.txt
Skipping already processed file: 1min/Stocks/AADI/AADI.txt
Skipping already processed file: 1min/Stocks/AAIC/AAIC.txt
Skipping already processed file: 1min/Stocks/AAL/AAL.txt
Skipping already processed file: 1min/Stocks/AAMC/AAMC.txt
Skipping already processed file: 1min/Stocks/AAME/AAME.txt
Skipping already processed file: 1min/Stocks/AAN/AAN.txt
Skipping already processed file: 1min/Stocks/AAOI/AAOI.txt
Skipping already processed file: 1min/Stocks/AAON/AAON.txt
Skipping already processed file: 1min/Stocks/AAP/AAP.txt
Skipping already processed file: 1min/Stocks/AAPL/AAPL.txt
Skipping already processed file: 1min/Stocks/AAT/AAT.txt
Skipping already processed file: 1min/Stocks/AATC-DELISTED/AATC-DELISTED.txt
Skipping already processed file: 1min/
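
The key-renaming rule used by `move_files_to_own_folder` can be checked without touching S3. The sketch below isolates it into two pure functions; the names `is_already_moved` and `new_key_for` are hypothetical helpers introduced for illustration, and the prefix and suffix constants mirror the values in the cell above.

```python
import os

ORIGINAL_PREFIX = "1min/Stocks/"
SUFFIX = "_full_1min_adjsplitdiv.txt"


def is_already_moved(file_key):
    """True when the key already looks like <prefix><TICKER>/<TICKER>.txt."""
    parts = file_key.split("/")
    return len(parts) == 4 and parts[-1].startswith(parts[-2])


def new_key_for(file_key, prefix=ORIGINAL_PREFIX, suffix=SUFFIX):
    """Map a flat key to its per-ticker folder key."""
    ticker = os.path.basename(file_key).replace(suffix, "")
    return f"{prefix}{ticker}/{ticker}.txt"


print(new_key_for("1min/Stocks/AAPL_full_1min_adjsplitdiv.txt"))
# prints: 1min/Stocks/AAPL/AAPL.txt
print(is_already_moved("1min/Stocks/A/A.txt"))
# prints: True
```

Separating the path logic from the `copy_object` and `delete_object` calls makes it possible to dry-run the mapping over a listing before any object is deleted.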

##### Create Athena Tables

In [142]:
# Reuses s3_client, bucket, and conn from the cells above
original_prefix = '1min/Stocks/'  # Include trailing slash
database_name = "stockdata"

def auto_create_tables():
    total_items = 0
    paginator = s3_client.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=original_prefix)


    for page in page_iterator:
        if 'Contents' in page:  # Check if the page has contents
            for item in page['Contents']:
                file_key = item['Key']
                if file_key.endswith('/'):  # Skip directories
                    continue
                
                
                # Extract directory and table name
                directory, _ = os.path.split(file_key)
                path_parts = file_key.split('/')
                if len(path_parts) < 3:  # Ensure there's enough parts
                    continue

                table_name = path_parts[2] + "_1min"
                table_name = table_name.replace("-", "_").replace(".", "_")
                #print(table_name)
                s3_private_path_tsv = "s3://{}/{}".format(bucket, directory)
                #print(s3_private_path_tsv)
                total_items +=1

                # Drop the existing table
                drop_statement = f"DROP TABLE IF EXISTS {database_name}.{table_name}"
                pd.read_sql(drop_statement, conn)
                
                # Corrected SQL statement
                statement = f"""CREATE EXTERNAL TABLE IF NOT EXISTS {database_name}.{table_name} (
                        Timestamp timestamp,
                        Open float,
                        High float,
                        Low float,
                        Close float,
                        Volume int
                ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' LOCATION '{s3_private_path_tsv}/'"""
                
                # Print statement for debugging
                print(statement)
                
                pd.read_sql(statement, conn)

In [None]:
auto_create_tables()
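
The two string-building steps inside `auto_create_tables` (deriving a table name from an S3 key, and assembling the DDL) can be sketched as standalone functions that are easy to inspect before any statement is sent to Athena. The helper names `athena_table_name` and `build_create_table_ddl` and the bucket `example-bucket` are hypothetical; the column list and the `-`/`.` replacement rule are taken from the cell above.

```python
COLUMNS = [
    ("Timestamp", "timestamp"),
    ("Open", "float"),
    ("High", "float"),
    ("Low", "float"),
    ("Close", "float"),
    ("Volume", "int"),
]


def athena_table_name(file_key, suffix="_1min"):
    """Derive a table name from a key such as '1min/Stocks/A/A.txt'."""
    ticker_folder = file_key.split("/")[2]
    # '-' and '.' are replaced with '_' in table names, as in the cell above.
    return (ticker_folder + suffix).replace("-", "_").replace(".", "_")


def build_create_table_ddl(database, table, s3_location, columns=COLUMNS):
    """Assemble the CREATE EXTERNAL TABLE statement for one ticker folder."""
    column_sql = ",\n    ".join(f"{name} {dtype}" for name, dtype in columns)
    return (
        f"CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table} (\n"
        f"    {column_sql}\n"
        f") ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' "
        f"LINES TERMINATED BY '\\n' LOCATION '{s3_location}/'"
    )


key = "1min/Stocks/AABA-DELISTED/AABA-DELISTED.txt"
print(athena_table_name(key))
# prints: AABA_DELISTED_1min
print(build_create_table_ddl("stockdata", "a_1min", "s3://example-bucket/1min/Stocks/A"))
```

Athena reports table names in lowercase, which is why the `SHOW TABLES` output lists `aaba_delisted_1min` rather than the mixed-case name built here.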

##### Verify The Tables Have Been Created Successfully (expected: 7017 tickers + 1 universe + 1 company profile = 7019)

In [21]:
import pandas as pd

statement = "SHOW TABLES in {}".format(database_name)

df_show = pd.read_sql(statement, conn)

total_tables = len(df_show)

print("Total number of tables:", total_tables)

  df_show = pd.read_sql(statement, conn)


Total number of tables: 7022


In [10]:
statement = "SHOW TABLES in {}".format(database_name)

df_show = pd.read_sql(statement, conn)
df_show.head(5)

Unnamed: 0,tab_name
0,a_1min
1,aa_1min
2,aaba_delisted_1min
3,aacg_1min
4,aadi_1min


##### Run A Sample Query

In [45]:
print(table_name_tsv)

universe


In [46]:
table_name_tsv = table_name_tsv.lower()
# table_name_tsv = table_name_tsv + "_1min"

statement = """SELECT * FROM {}.{} LIMIT 100""".format(
    database_name, table_name_tsv
)

In [47]:
print(statement)
df = pd.read_sql(statement, conn)
df.head(5)

SELECT * FROM stockdata.universe LIMIT 100


  df = pd.read_sql(statement, conn)


Unnamed: 0,ticker,name,first_date,last_date
0,A,Agilent Technologies Inc,2005-01-03,2023-04-21
1,AA,Alcoa Corporation,2016-10-18,2023-04-21
2,AACG,Ata Creativity Global American Depositary Shares,2008-01-29,2023-04-21
3,AADI,Aadi Bioscience,2017-08-08,2023-04-21
4,AAIC,Arlington Asset Investment Class A,2009-06-10,2023-04-21


## Feature Engineering

In [11]:
region = boto3.Session().region_name

boto_session = boto3.Session(region_name=region)

sagemaker_client = boto_session.client(service_name="sagemaker", region_name=region)

In [12]:
featurestore_runtime = boto_session.client(
    service_name="sagemaker-featurestore-runtime", region_name=region
)

In [13]:
feature_store_session = Session(
    boto_session=boto_session,
    sagemaker_client=sagemaker_client,
    sagemaker_featurestore_runtime_client=featurestore_runtime,
)

In [14]:
current_time_sec = int(time())
original_prefix = "1min/Stocks/"
offline_prefix = 'featurestore/Stocks'

stock_feature_group_name = "stock-feature-group-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f"Creating feature group: {stock_feature_group_name}")

# Initialize the feature group
stock_feature_group = FeatureGroup(
    name=stock_feature_group_name,
    sagemaker_session=feature_store_session
)   

Creating feature group: stock-feature-group-2024-02-21-00-00-58


In [15]:
from time import sleep  # `time` is bound to the time() function by the imports cell, so time.sleep() would fail


def wait_for_feature_group_creation_complete(feature_group):
    status = feature_group.describe().get("FeatureGroupStatus")
    while status == "Creating":
        print("Waiting for Feature Group Creation")
        sleep(5)
        status = feature_group.describe().get("FeatureGroupStatus")
    if status != "Created":
        raise RuntimeError(f"Failed to create feature group {feature_group.name}")
    print(f"FeatureGroup {feature_group.name} successfully created.")
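
A polling loop like the one above runs forever if the status never leaves `Creating`. The following is a minimal sketch of the same idea with a deadline. The function name `wait_for_status`, its parameters, and the `fake_describe` stand-in are assumptions for illustration; only the `FeatureGroupStatus` key and the `Creating`/`Created` values come from the cell above.

```python
from time import monotonic, sleep


def wait_for_status(describe, pending="Creating", done="Created",
                    poll_seconds=5.0, timeout_seconds=600.0):
    """Poll describe() until FeatureGroupStatus leaves `pending` or time runs out."""
    deadline = monotonic() + timeout_seconds
    status = describe().get("FeatureGroupStatus")
    while status == pending:
        if monotonic() >= deadline:
            raise TimeoutError(f"Still {pending!r} after {timeout_seconds} s")
        sleep(poll_seconds)
        status = describe().get("FeatureGroupStatus")
    if status != done:
        raise RuntimeError(f"Unexpected status: {status!r}")
    return status


# Hypothetical stand-in for feature_group.describe: two pending polls, then done.
responses = iter(["Creating", "Creating", "Created"])


def fake_describe():
    return {"FeatureGroupStatus": next(responses)}


print(wait_for_status(fake_describe, poll_seconds=0.01))
# prints: Created
```

With a real feature group, the call would be `wait_for_status(stock_feature_group.describe)`. Importing `monotonic` and `sleep` by name avoids shadowing the `time` function imported at the top of the notebook.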

In [16]:
# Constants
SMA5_PERIOD = 5  # Period (in bars) for the 5-period Simple Moving Average
ROLLING_WINDOW = 20  # Period for the rolling window used for statistics like Rolling Mean and Rolling Standard Deviation
RSI_PERIOD = 14  # Period for Relative Strength Index (RSI) calculation
EMA_PERIOD = 12  # Period for Exponential Moving Average (EMA) calculation
MACD_FASTPERIOD = 12  # Fast period for Moving Average Convergence Divergence (MACD) calculation
MACD_SLOWPERIOD = 26  # Slow period for MACD calculation
MACD_SIGNALPERIOD = 9  # Signal period for MACD calculation
SMA10_PERIOD = 10  # Period (in bars) for the 10-period Simple Moving Average
SMA20_PERIOD = 20  # Period (in bars) for the 20-period Simple Moving Average
SMA50_PERIOD = 50  # Period (in bars) for the 50-period Simple Moving Average
BBANDS_PERIOD = 20  # Period for Bollinger Bands calculation
ROC_PERIOD = 12  # Period for Rate of Change (ROC) calculation
ATR_PERIOD = 14  # Period for Average True Range (ATR) calculation
CCI_PERIOD = 20  # Period for Commodity Channel Index (CCI) calculation
WILLR_PERIOD = 14  # Period for Williams %R (WILLR) calculation
STOCH_FASTK_PERIOD = 5  # Fast period for Stochastic Oscillator calculation
STOCH_SLOWK_PERIOD = 3  # SlowK period for Stochastic Oscillator calculation
STOCH_SLOWD_PERIOD = 3  # SlowD period for Stochastic Oscillator calculation
MFI_PERIOD = 14  # Period for Money Flow Index (MFI) calculation
us_holidays = holidays.UnitedStates()

In [17]:
def process_data(df):
    # Convert 'Timestamp' to datetime format and sort the data by it
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df.sort_values(by='timestamp', inplace=True)
    
    # Calculate day of the week
    df['day_of_week'] = df['timestamp'].dt.dayofweek  # Monday=0, Sunday=6
    
    df['is_holiday'] = df['timestamp'].dt.date.apply(lambda x: int(x in us_holidays))
    
    # Compute Unix timestamp
    df['timestamp'] = df['timestamp'].astype('int64') // 10**9
    
    # Calculate simple one-bar returns - doesn't look right (TODO: verify)
    df['returns'] = (df['close'] / df['close'].shift(1) - 1)
    
    # Calculate general statistics
    df['avg_Price'] = df["close"].mean()
    df['avg_Returns'] = df['returns'].mean()
    df['volatility'] = df['returns'].std()
    df['volume_Volatility'] = df['volume'].std()

    # Calculate Moving Average and Price and Volume change
    df['moving_Avg'] = df['close'].rolling(window=SMA5_PERIOD).mean()
    df['price_Change'] = df['close'].pct_change() * 100  # Price change percentage
    df['volume_Change'] = df['volume'].diff()  # Difference in volume
    
    # Calculate rolling window statistics
    df['Rolling_Mean'] = df['close'].rolling(window=ROLLING_WINDOW).mean()
    df['Rolling_Std'] = df['close'].rolling(window=ROLLING_WINDOW).std()

    # Technical Indicators
    df['RSI'] = talib.RSI(df['close'], timeperiod=RSI_PERIOD)
    df['EMA'] = talib.EMA(df['close'], timeperiod=EMA_PERIOD)
    df['MACD'] = talib.MACD(df['close'], fastperiod=MACD_FASTPERIOD, slowperiod=MACD_SLOWPERIOD, signalperiod=MACD_SIGNALPERIOD)[0]
    df['SMA_5'] = talib.SMA(df['close'], timeperiod=SMA5_PERIOD)
    df['SMA_10'] = talib.SMA(df['close'], timeperiod=SMA10_PERIOD)
    df['SMA_20'] = talib.SMA(df['close'], timeperiod=SMA20_PERIOD)
    df['SMA_50'] = talib.SMA(df['close'], timeperiod=SMA50_PERIOD)
    upper, middle, lower = talib.BBANDS(df['close'], timeperiod=BBANDS_PERIOD)
    df['BBANDS_Upper'] = upper
    df['BBANDS_Middle'] = middle
    df['BBANDS_Lower'] = lower
    df['VWAP'] = (df['volume'] * (df['high'] + df['low'] + df['close']) / 3).cumsum() / df['volume'].cumsum()
    df['ROC'] = talib.ROC(df['close'], timeperiod=ROC_PERIOD)
    df['ATR'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=ATR_PERIOD)
    df['CCI'] = talib.CCI(df['high'], df['low'], df['close'], timeperiod=CCI_PERIOD)
    df['WilliamsR'] = talib.WILLR(df['high'], df['low'], df['close'], timeperiod=WILLR_PERIOD)
    slowk, slowd = talib.STOCH(df['high'], df['low'], df['close'], fastk_period=STOCH_FASTK_PERIOD, slowk_period=STOCH_SLOWK_PERIOD, slowk_matype=0, slowd_period=STOCH_SLOWD_PERIOD, slowd_matype=0)
    df['Stochastic_SlowK'] = slowk
    df['Stochastic_SlowD'] = slowd
    df['MFI'] = talib.MFI(df['high'], df['low'], df['close'], df['volume'], timeperiod=MFI_PERIOD)
    
    # Handle missing data: forward-fill, back-fill leading gaps, then zero-fill.
    # No NaNs remain after this, so a separate dropna() is unnecessary.
    df = df.ffill().bfill().fillna(0)
    
    return df
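
The `returns` column in `process_data` is a one-bar simple return, `close_t / close_{t-1} - 1`. The sketch below restates that formula in plain Python so it can be checked by hand; `simple_returns` is a hypothetical helper, not part of the notebook. For a series with no missing values, `pct_change()` computes the same quantity, so `price_Change` should equal `returns * 100`.

```python
def simple_returns(closes):
    """One-bar simple returns r_t = c_t / c_{t-1} - 1.

    The first bar has no previous close, so its return is None
    (pandas reports NaN in the same position).
    """
    return [None] + [curr / prev - 1 for prev, curr in zip(closes, closes[1:])]


# Illustrative prices: +10% followed by -10%.
closes = [100.0, 110.0, 99.0]
for close, ret in zip(closes, simple_returns(closes)):
    print(close, ret)
```

Note that a +10% move followed by a -10% move leaves the price at 99, not 100. Simple returns do not add across bars, which is one reason log returns are sometimes preferred for multi-bar aggregation.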

In [19]:
# Worker for parallel execution: loads one Athena table and returns its feature DataFrame
def process_ticker(ticker):
    print(ticker)

    try:
        df = wr.athena.read_sql_query(sql=f"SELECT * FROM \"{ticker}\"", database="stockdata")
        df = process_data(df)

        # Derive the stock symbol from the table name, e.g. "aapl_1min" -> "aapl"
        stock_identifier = ticker.split('_')[0]
        df["Ticker"] = stock_identifier
        df["EventTime"] = pd.Series([current_time_sec] * len(df), dtype="float64")
    except Exception as e:
        print(f"Error processing ticker {ticker}: {e}")
        return pd.DataFrame()  # Return an empty DataFrame in case of an error

    return df

In [None]:
# 96 vCPU + 768GiB

from concurrent.futures import ThreadPoolExecutor
import uuid

# Collect one DataFrame per ticker
dfs = []
chunk_size = 20

# Parallel execution for Athena queries, submitted in chunks of chunk_size tables
with ThreadPoolExecutor(max_workers=80) as executor:  # Adjust the number of workers based on available CPU
    for i in range(0, len(df_show['tab_name']), chunk_size):
        chunk = df_show['tab_name'][i:i + chunk_size]
        partial_results = list(executor.map(process_ticker, chunk))
        dfs.extend(partial_results)

# Concatenate the per-chunk results; each ticker is processed exactly once
df_main = pd.concat(dfs, ignore_index=True)

print("All tickers processed.")

a_1min
aa_1min
aaba_delisted_1min
aacg_1min
aadi_1min
aaic_1min
aal_1min
aamc_1min
aame_1min
aan_1min
aaoi_1min
aaon_1min
aap_1min
aapl_1min
aat_1min
aatc_delisted_1min
aau_1min
aaww_delisted_1min
ab_1min
abb_1min
abbv_1minabc_1min
abcb_1min

abcl_1min
abcm_1min
abeo_1min
abev_1min
abg_1min
abi_delisted_1min
abio_1min
abk_delisted_1min
abm_1min
abmd_delisted_1min
abnb_1min
abos_1min
abr_1min
abs_delisted_1min
absi_1min
abst_1min
abt_1min
abtx_delisted_1minabus_1min

abvc_1min
ac_1min
aca_1min
acac_1min
acad_1min
acas_delisted_1min
acax_1min
acb_1min
acbi_delisted_1min
accd_1min
acco_1min
acdc_1min
acel_1min
acer_1min
acet_1min
acgl_1min
acgn_1min
achc_1min
achl_1minachr_1min
achv_1min
aci_1min

acii_delisted_1min
aciu_1min
aciw_1min
acls_1min
aclx_1min
acm_1min
acmr_1min
acn_1min
acnb_1min
acnt_1min
acon_1min
aconw_1min
acor_1min
acp_1min
acqr_delisted_1min
acqru_1min
acqrw_1minacr_1min
acre_1min

acrs_1min
acrv_1min
acrx_1min
acst_1min
actg_1min
acu_1min
acv_1min
acva_1min
acxm_delist

In [None]:
# 96 vCPU + 768GiB

from concurrent.futures import ThreadPoolExecutor
import uuid
# List to store results
dfs = []

# Parallel execution for Athena queries
with ThreadPoolExecutor(max_workers=84) as executor:  # Adjust the number of workers based on available CPU
    results = list(executor.map(process_ticker, df_show['tab_name']))

# Concatenate results
df_main = pd.concat(results, ignore_index=True)

print("All tickers processed.")
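
The chunked submission pattern used for the Athena queries can be exercised locally with a stub worker. This is a minimal sketch: `run_in_chunks` and `fake_process_ticker` are hypothetical names, and the stub only mimics the table-name-to-ticker step rather than querying Athena.

```python
from concurrent.futures import ThreadPoolExecutor


def run_in_chunks(worker, items, chunk_size=20, max_workers=8):
    """Apply worker to items chunk by chunk, preserving input order."""
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for start in range(0, len(items), chunk_size):
            chunk = items[start:start + chunk_size]
            # executor.map yields results in the order of its inputs
            results.extend(executor.map(worker, chunk))
    return results


# Hypothetical stand-in for process_ticker: no Athena access needed.
def fake_process_ticker(table_name):
    return table_name.split("_")[0]


tables = ["a_1min", "aa_1min", "aaba_delisted_1min"]
print(run_in_chunks(fake_process_ticker, tables, chunk_size=2))
# prints: ['a', 'aa', 'aaba']
```

Because each chunk is fully consumed before the next is submitted, at most `chunk_size` tasks are in flight at a time, regardless of `max_workers`. With `chunk_size=20`, a pool of 80 workers would therefore never be more than a quarter busy.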

### Create Dataframe

### Create FeatureGroups in SageMaker FeatureStore

In [None]:
df_main.head()

In [None]:
# Load feature definitions to the feature group
# Auto-detect the schema based on the dataframe
stock_feature_group.load_feature_definitions(data_frame=df_main)  # This should be done once, schema should be consistent across files

In [None]:
# Create the feature group
stock_feature_group.create(
    s3_uri=f"s3://{bucket}/{offline_prefix}",
    record_identifier_name="Ticker",
    event_time_feature_name="EventTime",
    role_arn=role,
    enable_online_store=True
)

In [None]:
wait_for_feature_group_creation_complete(feature_group=stock_feature_group)

In [None]:
stock_feature_group.describe()

In [None]:
sagemaker_client.list_feature_groups()

### Put Records in FeatureStore

In [None]:
stock_feature_group.ingest(data_frame=df_main, max_workers=5, wait=True)

In [None]:
# Ticker values are derived from the lowercase Athena table names, e.g. "aapl_1min" -> "aapl"
record_identifier_value = "aapl"

featurestore_runtime.get_record(
    FeatureGroupName=stock_feature_group_name,
    RecordIdentifierValueAsString=record_identifier_value,
)
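
A `get_record` response carries the feature values as a list of `FeatureName`/`ValueAsString` pairs under the `Record` key. The sketch below flattens that list into a dictionary. `record_to_dict` is a hypothetical helper, and the sample response is hand-written with illustrative values rather than captured from the feature store.

```python
def record_to_dict(response):
    """Flatten a GetRecord response into {feature_name: value_as_string}.

    A response without a "Record" key is treated as an empty record.
    """
    return {
        item["FeatureName"]: item["ValueAsString"]
        for item in response.get("Record", [])
    }


# Hand-written sample in the GetRecord response shape (values are illustrative).
sample_response = {
    "Record": [
        {"FeatureName": "Ticker", "ValueAsString": "aapl"},
        {"FeatureName": "close", "ValueAsString": "150.0"},
    ]
}
print(record_to_dict(sample_response))
# prints: {'Ticker': 'aapl', 'close': '150.0'}
```

All values come back as strings, so numeric features need an explicit conversion (for example `float(...)`) before use.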

In [None]:
print(stock_feature_group.as_hive_ddl())

## Data Research

### Heatmap

In [None]:
import seaborn as sns

# Generate Global Heatmap
def global_heatmap():
    # Get the feature names
    feature_definitions = stock_feature_group.describe()
    feature_names = [feature['FeatureName'] for feature in feature_definitions['FeatureDefinitions']]
    
    # Correlate numeric features only (string columns such as Ticker are excluded)
    numeric_features = df_main[feature_names].select_dtypes(include="number")
    
    # Generate a heatmap
    plt.figure(figsize=(20, 10))
    sns.heatmap(numeric_features.corr(), annot=True, fmt=".2f", cmap='coolwarm', center=0)
    plt.title('Global Heatmap')
    plt.show()
