In [16]:
# %pip install pandas
# %pip install numpy
# %pip install matplotlib
# %pip install scikit-learn
# %pip install tensorflow
# %pip install -U imbalanced-learn

In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout, Dropout, LayerNormalization, MultiHeadAttention, Input
from tensorflow.keras.layers import Attention, Reshape
from tensorflow.keras.models import Model

In [18]:
import sys
import os

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import models
import utils
import data_processing


In [19]:
COMMODITY = 'cobalt'

DATE_COLUMN = 'Date'
VALUE_COLUMN = 'Value'  
QUANTITY_COLUMN = 'Std. Quantity (KG)'
UNIT_RATE_COLUMN = 'Std. Unit Rate ($/KG)'
BRENT_OIL_COLUMN = 'Brent Oil Value'
WTI_OIL_COLUMN = 'WTI Oil Value'
SHIP_COUNT_COLUMN = 'ship_count'
PORT_COUNT_COLUMN = 'popular_port_count'

VALUE_SPIKES_COLUMN = 'Value Spikes'  
QUANTITY_SPIKES_COLUMN = 'Std. Quantity (KG) Spikes'
UNIT_RATE_SPIKES_COLUMN = 'Std. Unit Rate ($/KG) Spikes'
BRENT_OIL_SPIKES_COLUMN = 'Brent Oil Value Spikes'
WTI_OIL_SPIKES_COLUMN = 'WTI Oil Value Spikes'
SHIP_COUNT_SPIKES_COLUMN = 'Ship Count Spikes'
PORT_COUNT_SPIKES_COLUMN = 'Port Count Spikes'

FEATURE_COLUMNS = [VALUE_COLUMN, QUANTITY_COLUMN, UNIT_RATE_COLUMN,  WTI_OIL_COLUMN, BRENT_OIL_COLUMN, SHIP_COUNT_COLUMN, PORT_COUNT_COLUMN]
# FEATURE_COLUMNS = [VALUE_COLUMN, QUANTITY_COLUMN, UNIT_RATE_COLUMN,  WTI_OIL_SPIKES_COLUMN, BRENT_OIL_SPIKES_COLUMN]
# FEATURE_COLUMNS = [VALUE_COLUMN, QUANTITY_COLUMN, UNIT_RATE_COLUMN,  WTI_OIL_COLUMN, BRENT_OIL_COLUMN, SHIP_COUNT_SPIKES_COLUMN, PORT_COUNT_SPIKES_COLUMN]
TARGET_COLUMN = 'Price'

ORIGIN_COUNTRY_COLUMN = 'Country of Origin'
DEST_COUNTRY_COLUMN = 'Country of Destination'

PETROL_FILE_PATH = '../../volza/petroleum/petrol_crude_oil_spot_price.csv'
VOLZA_FILE_PATH = f'../../volza/{COMMODITY}/{COMMODITY}.csv'
PRICE_FILE_PATH = f"../../volza/{COMMODITY}/{COMMODITY}_prices.csv"
AIS_POPULAR_FILE_PATH = f'../../ais/ais_ml_features.csv' 

NB_OUTPUT_PATH = f"{COMMODITY}/{COMMODITY}_model_performance (No Balancing).csv"
RUS_OUTPUT_PATH = f"{COMMODITY}/{COMMODITY}_model_performance (Random Under Sampling).csv"
ROS_OUTPUT_PATH = f"{COMMODITY}/{COMMODITY}_model_performance (Random Over Sampling).csv"


SPIKES_THRESHOLD = 2
SPIKES_WINDOW_SIZE = 20
BIN_COUNT = 10
FILL_METHOD = 'ffill'

RANDOM_STATE = 42

## Dataframe Prep

In [20]:
from datetime import datetime

#Formatting the date and price for Volza data
volza_pd = pd.read_csv(VOLZA_FILE_PATH)
volza_pd = volza_pd[(volza_pd["Country of Origin"].notnull()) & (volza_pd["Country of Destination"].notnull())]
volza_pd = volza_pd.rename(columns={'Unnamed: 0': 'ID'})
volza_pd['Date'] = volza_pd['Date'].apply(lambda x: x.split(' ')[0])
volza_pd['Date'] = pd.to_datetime(volza_pd['Date'], errors='raise', format='%Y-%m-%d')
volza_pd = utils.convert_to_kg(volza_pd)
volza_pd.head(3)

Unnamed: 0,ID,Date,HS Code,Product Description,Consignee,Notify Party Name,Shipper,Std. Quantity,Std. Unit,Standard Unit Rate INR,...,Freight Term,Marks Number,HS Product Description,Gross Weight,Consignee Address,Shipper Address,Notify Party Address,Country Name,Std. Quantity (KG),Std. Unit Rate ($/KG)
0,0,2020-03-03,81052029000K,BRAND: HANRUI COBALT|EFS 1 GRANULATE COBA,TANTAL ARGENTINA SRL,,,200.0,KGS,-,...,-,-,OTHERS,0.0,,,,Argentina T2 Import,200.0,42.805
1,1,2020-11-25,8105200000,DO 3202000987-001 DECLARATION (1-1) INVOICE: 2...,INVERSIONES RINCON MEDINA LTDA,,GE ADDITIVE,11.0,KGS,-,...,-,-,COBALT MATTES AND OTHER INTERMEDIATE PRODUCTS ...,11.0,CRA 47 79 234,101 NORTH CAMPUS DRIVR IMPERIAL PA15126,,Columbia T3+ Import,11.0,251.360909
3,3,2020-12-14,81052000,"COBALT DUST: ""BEGO WIROBOND C+""-10PCS*5K",BELADENT SRL MOLDOVA OR NISPORENI,,OOO SIMPLAND 121353 OR MOSCOVA SK,50.0,KGS,-,...,-,-,,0.0,MOLDOVA OR NISPORENI,,,Moldova T3 Import,50.0,384.0744


In [21]:
#Preprocessing the AIS data
ais_popular_pd = pd.read_csv(AIS_POPULAR_FILE_PATH)
ais_popular_pd['Date'] = pd.to_datetime(ais_popular_pd['Date'])
ais_popular_pd.head(3)


Unnamed: 0,Date,ship_count,popular_port,popular_port_count
0,2020-11-10,8,LTKLJ,18
1,2020-11-12,20,IDSKP,8
2,2020-11-29,9,CNSHA,2


In [22]:
#Preprocessing the price data
prices_pd = pd.read_csv(PRICE_FILE_PATH)
prices_pd['Date'] = prices_pd['Date'].apply(lambda x: datetime.strptime(x, "%b %d, %Y").strftime("%Y-%m-%d"))
prices_pd['Date'] = pd.to_datetime(prices_pd['Date'], errors='raise', format='%Y-%m-%d')
prices_pd['Price'] = prices_pd['Price'].str.replace(',', '').astype(float)
prices_pd = prices_pd[['Date','Price']]
prices_pd.head(3)

Unnamed: 0,Date,Price
0,2024-02-23,28288.0
1,2024-02-22,28288.0
2,2024-02-21,28288.0


In [23]:
#Aggregate volza data by day
date_wise_volza = volza_pd.groupby("Date")[[VALUE_COLUMN,QUANTITY_COLUMN,'Gross Weight']].sum()

In [24]:
# Avg of Commodity Price in Volza
avg_price_volza = volza_pd.groupby('Date')[UNIT_RATE_COLUMN].mean()
date_wise_volza = date_wise_volza.join(avg_price_volza, how='left')
date_wise_volza

Unnamed: 0_level_0,Value,Std. Quantity (KG),Gross Weight,Std. Unit Rate ($/KG)
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2020-01-01,1.217433e+08,4.147777e+06,42055.0,92.545195
2020-01-02,9.020977e+04,2.300000e+03,0.0,39.903707
2020-01-03,4.110707e+02,1.000000e+01,0.0,41.107069
2020-01-04,5.637466e+04,1.000000e+03,0.0,56.374664
2020-01-05,4.051522e+04,1.140000e+03,0.0,35.644705
...,...,...,...,...
2023-12-25,8.913675e+04,3.000000e+03,0.0,30.255888
2023-12-26,1.160515e+05,3.000000e+03,0.0,38.338660
2023-12-27,5.611000e+04,1.841780e+03,0.0,46.549676
2023-12-28,2.256061e+05,4.446500e+03,0.0,76.015813


In [25]:
# Petroleum data prep
petrol_df = pd.read_csv(PETROL_FILE_PATH, delimiter=';', on_bad_lines='warn')
petrol_df['Date'] = pd.to_datetime(petrol_df['Date'])

# Split based on types of oil
brent_df = petrol_df[petrol_df['product-name']=='UK Brent Crude Oil']
wti_df = petrol_df[petrol_df['product-name']=='WTI Crude Oil']

brent_df.rename(columns={'Value':'Brent Oil Value'}, inplace=True)
wti_df.rename(columns={'Value':'WTI Oil Value'}, inplace=True)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  brent_df.rename(columns={'Value':'Brent Oil Value'}, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  wti_df.rename(columns={'Value':'WTI Oil Value'}, inplace=True)


In [26]:
# Combining dataframes
prices_pd = prices_pd.set_index('Date')
ais_popular_pd = ais_popular_pd.set_index('Date')
date_wise_volza = date_wise_volza.join(ais_popular_pd, how="left").fillna(method=FILL_METHOD)
aggregated_df = date_wise_volza.join(prices_pd, how="left").fillna(method=FILL_METHOD)
aggregated_df = aggregated_df.merge(brent_df[[DATE_COLUMN, BRENT_OIL_COLUMN]], on='Date', how='left').fillna(method=FILL_METHOD)
aggregated_df = aggregated_df.merge(wti_df[[DATE_COLUMN, WTI_OIL_COLUMN]], on='Date', how='left').fillna(method=FILL_METHOD)
aggregated_df

  date_wise_volza = date_wise_volza.join(ais_popular_pd, how="left").fillna(method=FILL_METHOD)
  aggregated_df = date_wise_volza.join(prices_pd, how="left").fillna(method=FILL_METHOD)
  aggregated_df = aggregated_df.merge(brent_df[[DATE_COLUMN, BRENT_OIL_COLUMN]], on='Date', how='left').fillna(method=FILL_METHOD)
  aggregated_df = aggregated_df.merge(wti_df[[DATE_COLUMN, WTI_OIL_COLUMN]], on='Date', how='left').fillna(method=FILL_METHOD)


Unnamed: 0,Date,Value,Std. Quantity (KG),Gross Weight,Std. Unit Rate ($/KG),ship_count,popular_port,popular_port_count,Price,Brent Oil Value,WTI Oil Value
0,2020-01-01,1.217433e+08,4.147777e+06,42055.0,92.545195,,,,,67.77,61.14
1,2020-01-02,9.020977e+04,2.300000e+03,0.0,39.903707,,,,32355.0,67.05,61.17
2,2020-01-03,4.110707e+02,1.000000e+01,0.0,41.107069,,,,31850.0,69.08,63.00
3,2020-01-04,5.637466e+04,1.000000e+03,0.0,56.374664,,,,31850.0,69.08,63.00
4,2020-01-05,4.051522e+04,1.140000e+03,0.0,35.644705,,,,31850.0,69.08,63.00
...,...,...,...,...,...,...,...,...,...,...,...
1302,2023-12-25,8.913675e+04,3.000000e+03,0.0,30.255888,8925.0,USMSY,2673.0,28819.0,78.89,72.16
1303,2023-12-26,1.160515e+05,3.000000e+03,0.0,38.338660,8925.0,USMSY,2673.0,28819.0,78.89,72.16
1304,2023-12-27,5.611000e+04,1.841780e+03,0.0,46.549676,8925.0,USMSY,2673.0,28792.5,78.89,72.16
1305,2023-12-28,2.256061e+05,4.446500e+03,0.0,76.015813,8925.0,USMSY,2673.0,28787.5,78.89,72.16


In [27]:
def detect_spikes(df, column, window_size):
    ## Detecting spikes
    moving_avg = df[column].rolling(window=window_size).mean()
    std_dev = df[column].rolling(window=window_size).std()

    # Set a threshold to identify spikes
    return (abs(aggregated_df[column] - moving_avg) > SPIKES_THRESHOLD * std_dev).astype(int)

# aggregated_df['spikes'] = detect_spikes(aggregated_df, 'Price')
# print("SPIKES : NON SPIKES = ")
# print(aggregated_df['spikes'].value_counts())
# print("PERCENT OF SPIKES", aggregated_df['spikes'].value_counts()[1]/len(aggregated_df))

# **Detect spikes**

In [28]:
# aggregated_df[VALUE_SPIKES_COLUMN] = detect_spikes(aggregated_df, VALUE_COLUMN)
# aggregated_df[QUANTITY_SPIKES_COLUMN] = detect_spikes(aggregated_df, QUANTITY_COLUMN)
# aggregated_df[UNIT_RATE_SPIKES_COLUMN] = detect_spikes(aggregated_df, UNIT_RATE_COLUMN)
# aggregated_df[WTI_OIL_SPIKES_COLUMN] = detect_spikes(aggregated_df, WTI_OIL_COLUMN)
# aggregated_df[BRENT_OIL_SPIKES_COLUMN] = detect_spikes(aggregated_df, BRENT_OIL_COLUMN)
# aggregated_df[SHIP_COUNT_SPIKES_COLUMN] = detect_spikes(aggregated_df, SHIP_COUNT_COLUMN)
# aggregated_df[PORT_COUNT_SPIKES_COLUMN] = detect_spikes(aggregated_df, PORT_COUNT_COLUMN)

# #Visualise Dataset

# # Plotting the graph
# fig, ax1 = plt.subplots(figsize=(12, 6))

# # Plotting 'Value', 'Quantity', and 'Gross Weight' on the left y-axis
# ax1.plot(aggregated_df.index, aggregated_df[VALUE_SPIKES_COLUMN], label='Value Spikes', color='b')
# ax1.plot(aggregated_df.index, aggregated_df[QUANTITY_SPIKES_COLUMN], label='Quantity Spikes', color='g')
# ax1.plot(aggregated_df.index, aggregated_df[UNIT_RATE_SPIKES_COLUMN], label='Unit Rate Spikes', color='k')
# ax1.plot(aggregated_df.index, aggregated_df[BRENT_OIL_SPIKES_COLUMN], label='Brent Oil Value Spikes', color='m')
# ax1.plot(aggregated_df.index, aggregated_df[WTI_OIL_SPIKES_COLUMN], label='WTI Oil Value Spikes', color='c')
# ax1.plot(aggregated_df.index, aggregated_df[SHIP_COUNT_COLUMN], label='Ship Count Spikes', color='darkorange')
# ax1.plot(aggregated_df.index, aggregated_df[PORT_COUNT_COLUMN], label='Port Count Value Spikes', color='violet')

# ax1.set_xlabel('Date')
# ax1.set_ylabel('Value / Quantity / Gross Weight / Brent Oil Value / WTI Oil Value / Ship Count / Port Count ', color='b')
# ax1.tick_params('y', colors='b')

# # Creating a second y-axis for 'Price'
# ax2 = ax1.twinx()
# ax2.plot(aggregated_df.index, aggregated_df['Price'], label='Price', color='orange')
# ax2.set_ylabel('Price', color='orange')
# ax2.tick_params('y', colors='orange')

# # Display legend
# fig.tight_layout()
# fig.legend(loc='upper left', bbox_to_anchor=(0.1, 0.9))

# # Display the graph
# # plt.show()

In [29]:
#remove date 2020-01-01
aggregated_df = aggregated_df[aggregated_df.index != '2020-01-01']
aggregated_df

Unnamed: 0,Date,Value,Std. Quantity (KG),Gross Weight,Std. Unit Rate ($/KG),ship_count,popular_port,popular_port_count,Price,Brent Oil Value,WTI Oil Value
0,2020-01-01,1.217433e+08,4.147777e+06,42055.0,92.545195,,,,,67.77,61.14
1,2020-01-02,9.020977e+04,2.300000e+03,0.0,39.903707,,,,32355.0,67.05,61.17
2,2020-01-03,4.110707e+02,1.000000e+01,0.0,41.107069,,,,31850.0,69.08,63.00
3,2020-01-04,5.637466e+04,1.000000e+03,0.0,56.374664,,,,31850.0,69.08,63.00
4,2020-01-05,4.051522e+04,1.140000e+03,0.0,35.644705,,,,31850.0,69.08,63.00
...,...,...,...,...,...,...,...,...,...,...,...
1302,2023-12-25,8.913675e+04,3.000000e+03,0.0,30.255888,8925.0,USMSY,2673.0,28819.0,78.89,72.16
1303,2023-12-26,1.160515e+05,3.000000e+03,0.0,38.338660,8925.0,USMSY,2673.0,28819.0,78.89,72.16
1304,2023-12-27,5.611000e+04,1.841780e+03,0.0,46.549676,8925.0,USMSY,2673.0,28792.5,78.89,72.16
1305,2023-12-28,2.256061e+05,4.446500e+03,0.0,76.015813,8925.0,USMSY,2673.0,28787.5,78.89,72.16


In [30]:
# #Visualise Dataset
# # Plotting the graph
# fig, ax1 = plt.subplots(figsize=(12, 6))

# # Plotting 'Value', 'Quantity', and 'Gross Weight' on the left y-axis
# ax1.plot(aggregated_df.index, aggregated_df[VALUE_COLUMN], label='Value', color='b')
# ax1.plot(aggregated_df.index, aggregated_df[QUANTITY_COLUMN], label='Quantity', color='g')
# ax1.plot(aggregated_df.index, aggregated_df[UNIT_RATE_COLUMN], label='Unit Rate', color='k')
# ax1.plot(aggregated_df.index, aggregated_df[BRENT_OIL_COLUMN], label='Brent Oil Value', color='m')
# ax1.plot(aggregated_df.index, aggregated_df[WTI_OIL_COLUMN], label='WTI Oil Value', color='c')
# ax1.plot(aggregated_df.index, aggregated_df[SHIP_COUNT_COLUMN], label='Ship Count Value', color='darkorange')
# ax1.plot(aggregated_df.index, aggregated_df[PORT_COUNT_COLUMN], label='Port Count Value', color='violet')

# ax1.set_xlabel('Date')
# ax1.set_ylabel('Value / Quantity / Gross Weight', color='b')
# ax1.tick_params('y', colors='b')

# # Creating a second y-axis for 'Price'
# ax2 = ax1.twinx()
# ax2.plot(aggregated_df.index, aggregated_df['Price'], label='Price', color='orange')
# ax2.set_ylabel('Price', color='orange')
# ax2.tick_params('y', colors='orange')

# # Display legend
# fig.tight_layout()
# fig.legend(loc='upper left', bbox_to_anchor=(0.1, 0.9))

# # Display the graph
# # plt.show()

In [31]:
# # Plotting the price data
# plt.figure(figsize=(10, 6))  # Adjust the figure size as needed
# plt.plot(aggregated_df.index, aggregated_df['Price'], label='Price', color='blue')

# # Highlighting spikes
# spike_indices = aggregated_df[aggregated_df['spikes'] == 1].index
# spike_prices = aggregated_df.loc[spike_indices, 'Price']
# plt.scatter(spike_indices, spike_prices, color='red', marker='^', label='Spikes')

# # Adding labels and title
# plt.xlabel('Date')
# plt.ylabel('Price')
# plt.title('Price Data with Spikes')
# plt.legend()

# # Display the plot
# # plt.show()

## Baseline

In [32]:
# # Count % of spikes 
# total_spikes = aggregated_df['spikes'].sum()
# total_data_points = len(aggregated_df)
# percentage_of_spikes = (total_spikes / total_data_points) * 100

# print(f"Percentage of Spikes: {percentage_of_spikes:.2f}%")

In [33]:
# from sklearn.metrics import precision_score, recall_score

# # Probability of spike
# spike_prob = aggregated_df['spikes'].mean()

# # Random baseline predictions
# random_predictions = np.random.choice([0, 1], size=len(aggregated_df), p=[1-spike_prob, spike_prob])

# # Calculate precision and recall for the random baseline
# random_precision = precision_score(aggregated_df['spikes'], random_predictions)
# random_recall = recall_score(aggregated_df['spikes'], random_predictions)

# print(f"Random Guessing Precision: {random_precision}")
# print(f"Random Guessing Recall: {random_recall}")


## Data Prep for Classification

In [34]:
COLUMNS = FEATURE_COLUMNS + [TARGET_COLUMN]
print(COLUMNS)

['Value', 'Std. Quantity (KG)', 'Std. Unit Rate ($/KG)', 'WTI Oil Value', 'Brent Oil Value', 'ship_count', 'popular_port_count', 'Price']


In [35]:
# # Discretize
# from sklearn.preprocessing import KBinsDiscretizer

# def discretize(df, columns, bins):
#     est = KBinsDiscretizer(n_bins=bins, encode='ordinal', strategy='kmeans')
#     df[columns] = est.fit_transform(df[columns])
#     return df

# # FEATURES_1 = [VALUE_COLUMN, QUANTITY_COLUMN, UNIT_RATE_COLUMN]
# # FEATURES_2 = [WTI_OIL_COLUMN, BRENT_OIL_COLUMN]

# test_df = aggregated_df.copy()
# test_df[FEATURE_COLUMNS] = test_df[FEATURE_COLUMNS].fillna(0)

# # test_df = discretize(test_df, FEATURES_1, 2)
# # test_df = discretize(test_df, FEATURES_2, BIN_COUNT)
# discretized_df = discretize(test_df, FEATURE_COLUMNS, BIN_COUNT)
# # discretized_df = test_df.copy()
# test_df.head(2)


In [36]:
# fig, ax1 = plt.subplots(figsize=(12, 6))

# # Plotting 'Value', 'Quantity', and 'Gross Weight' on the left y-axis
# ax1.plot(test_df.index, test_df[VALUE_COLUMN], label='Value', color='b')
# ax1.plot(test_df.index, test_df[QUANTITY_COLUMN], label='Quantity', color='g')
# ax1.plot(test_df.index, test_df[UNIT_RATE_COLUMN], label='Unit Rate', color='k')
# ax1.plot(test_df.index, test_df[BRENT_OIL_COLUMN], label='Brent Oil Value', color='m')
# ax1.plot(test_df.index, test_df[WTI_OIL_COLUMN], label='WTI Oil Value', color='c')
# ax1.plot(test_df.index, test_df[SHIP_COUNT_COLUMN], label='Ship Count Value', color='darkorange')
# ax1.plot(test_df.index, test_df[PORT_COUNT_COLUMN], label='Port Count Value', color='violet')

# ax1.set_xlabel('Date')
# ax1.set_ylabel('Value / Quantity / Gross Weight', color='b')
# ax1.tick_params('y', colors='b')

# # Creating a second y-axis for 'Price'
# ax2 = ax1.twinx()
# ax2.plot(test_df.index, test_df['Price'], label='Price', color='orange')
# ax2.set_ylabel('Price', color='orange')
# ax2.tick_params('y', colors='orange')

# # Display legend
# fig.tight_layout()
# fig.legend(loc='upper left', bbox_to_anchor=(0.1, 0.9))

# # Display the graph
# plt.show()

In [37]:
# # Convert the discretized data into a DataFrame
# discretized_df = pd.DataFrame(discretized_df, columns=FEATURE_COLUMNS)

# unique_values = discretized_df[VALUE_COLUMN].fillna(method=FILL_METHOD).unique()
# print(unique_values)

# bin_counts = {col: discretized_df[col].value_counts() for col in FEATURE_COLUMNS}

# # Plotting
# plt.figure(figsize=(15, len(FEATURE_COLUMNS) * 5))

# for i, column in enumerate(FEATURE_COLUMNS):
#     plt.subplot(len(FEATURE_COLUMNS), 1, i + 1)
#     bin_counts[column].sort_index().plot(kind='bar', ax=plt.gca())

#     plt.title(f'{column} Distribution')
#     plt.ylabel('Frequency')
#     plt.xlabel('Bins')

# plt.tight_layout()
# plt.show()


In [38]:
# Clean up before passing to Arima
initial_row_count = aggregated_df.shape[0]

columns_of_interest = ['Price']  # Add other columns as necessary

aggregated_df = aggregated_df.dropna(subset=columns_of_interest)

rows_dropped = initial_row_count - aggregated_df.shape[0]

print(f"Rows dropped due to NaN values: {rows_dropped}")

Rows dropped due to NaN values: 1


## Train / Test Set Up

In [39]:
from sklearn.model_selection import train_test_split
import pmdarima as pm

# Split the aggregated DataFrame into training and testing sets first
train_size = int(len(aggregated_df) * 0.8)
train_df = aggregated_df[:train_size]
test_df = aggregated_df[train_size:]

# Fit an Auto ARIMA model to the 'Price' series of the training data
model = pm.auto_arima(train_df['Price'], seasonal=True, m=12, suppress_warnings=True, stepwise=True, error_action='ignore')

# Forecast the training series using the model (in-sample prediction) to calculate training residuals
train_forecast = model.predict_in_sample()
train_residuals = train_df['Price'] - train_forecast

# For the test set, use the model to forecast test-sample and calculate residuals
test_forecast = model.predict(n_periods=len(test_df))
test_residuals = test_df['Price'] - test_forecast

# Append residuals to the respective DataFrame as a new feature for anomaly detection
train_df = train_df.copy()
train_df['ARIMA_Residuals'] = train_residuals

test_df = test_df.copy()
test_df['ARIMA_Residuals'] = test_residuals

# Combine train and test data back if necessary for further steps
# Note: Ensure to keep track of which part is train and test when using this combined data for modeling
aggregated_df = pd.concat([train_df, test_df])

  return get_prediction_index(
  return get_prediction_index(


In [40]:
# Trying out different window sizes
SPIKE_WINDOW_SIZES = [20, 30, 40]
results_dfs = []

for window_size in SPIKE_WINDOW_SIZES:
    print(f"Evaluating window size: {window_size}")

    aggregated_df['spikes'] = detect_spikes(aggregated_df, 'Price', window_size)

    # Prepare features and target
    FEATURE_COLUMNS = [TARGET_COLUMN, 'ARIMA_Residuals']  # Adjust as needed
    X, y = data_processing.prepare_features_and_target(aggregated_df, FEATURE_COLUMNS, 'spikes')

    # Split data 
    X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_STATE, shuffle=False)

    # Scale features
    X_train_scaled, X_test_scaled = data_processing.scale_features(X_train_raw, X_test_raw)

    # Create sequences
    X_train, y_train = data_processing.create_sequences(X_train_scaled, y_train, window_size)
    X_test, y_test = data_processing.create_sequences(X_test_scaled, y_test, window_size)

    print(f'X train shape: {X_train.shape}')

    output_file_path = f'{COMMODITY} + arima h/results_{window_size}.csv'
    pred_file_path = f'{COMMODITY} + arima h/predictions/{window_size}'
    print(pred_file_path)
    results_df = models.evaluate_all(X_train, y_train, X_test, y_test, output_file_path, pred_file_path)
    results_df['Window Size'] = window_size
    results_dfs.append(results_df)

Evaluating window size: 20
X train shape: (1025, 20, 2)
cobalt + arima h/predictions/20



Predictions saved to CSV file: cobalt + arima h/predictions/20/LSTM_250_layers_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/LSTM_250_layers_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/20/LSTM_200_layers_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/LSTM_200_layers_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/20/LSTM_100_layers_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/LSTM_100_layers_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/20/LSTM_50_layers_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/LSTM_50_layers_predictions.npy


Predictions saved to CSV file: cobalt + arima h/predictions/20/CNN_Attention_32_filters_7_kernels_predictions.csv
Predictions saved to NPY file: coba

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/20/RNN_150_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/RNN_150_units_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/20/RNN_100_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/RNN_100_units_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/20/RNN_50_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/RNN_50_units_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/20/CNN_32_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/CNN_32_filters_7_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/20/CNN_32_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/CNN_32_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/20/CNN_32_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/20/CNN_32_filters_3_kernels_predictions.npy
Predic

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_32_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_32_filters_5_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_32_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_32_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_64_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_64_filters_7_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_64_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_64_filters_5_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_64_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_64_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_128_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_128_filters_7_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_128_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_128_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_128_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_128_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_256_f

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_Attention_256_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_Attention_256_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/RNN_200_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/RNN_200_units_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/RNN_150_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/RNN_150_units_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/RNN_100_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/RNN_100_units_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/RNN_50_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/RNN_50_units_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_32_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_32_filters_7_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_32_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_32_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_32_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_32_filters_3_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_64_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_64_filters_7_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_64_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_64_filters_5_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_64_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_64_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_128_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_128_filters_7_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_128_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_128_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_128_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_128_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_256_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictio

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_256_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_256_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/30/CNN_256_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/30/CNN_256_filters_3_kernels_predictions.npy
Evaluating window size: 40
X train shape: (1005, 40, 2)
cobalt + arima h/predictions/40
Predictions saved to CSV file: cobalt + arima h/predictions/40/LSTM_250_layers_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/LSTM_250_layers_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/LSTM_200_layers_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/LSTM_200_layers_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/LSTM_100_layers_predictions.csv
Predictions saved

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_Attention_128_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_Attention_128_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_Attention_128_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_Attention_128_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_Attention_256_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_Attention_256_filters_7_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_Attention_256_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_Attention_256_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_Attention_256_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_Attention_256_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/RNN_200_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/RNN_200_units_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/RNN_150_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/RNN_150_units_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/40/RNN_100_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/RNN_100_units_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/40/RNN_50_units_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/RNN_50_units_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_32_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_32_filters_7_kernels_predictions.npy


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_32_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_32_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_32_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_32_filters_3_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_64_filters_7_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_64_filters_7_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_64_filters_5_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/CNN_64_filters_5_kernels_predictions.npy
Predictions saved to CSV file: cobalt + arima h/predictions/40/CNN_64_filters_3_kernels_predictions.csv
Predictions saved to NPY file: cobalt + arima h/predictions/40/C

## Evaluation

In [41]:
# Display all results df
for idx, df in enumerate(results_dfs):
    print(f"Results for window size: {SPIKE_WINDOW_SIZES[idx]}")
    display(df) 

Results for window size: 20


Unnamed: 0,Name,Params,Accuracy,Precision (0),Recall (0),F1 (0),Precision (1),Recall (1),F1 (1),Window Size
0,LSTM,250 layers,0.111111,1.0,0.004608,0.009174,0.107438,1.0,0.19403,20
1,LSTM,200 layers,0.111111,1.0,0.004608,0.009174,0.107438,1.0,0.19403,20
2,LSTM,100 layers,0.111111,1.0,0.004608,0.009174,0.107438,1.0,0.19403,20
3,LSTM,50 layers,0.111111,1.0,0.004608,0.009174,0.107438,1.0,0.19403,20
4,CNN with Attention,"32 filters, kernel size 7",0.868313,0.900433,0.958525,0.928571,0.25,0.115385,0.157895,20
5,CNN with Attention,"32 filters, kernel size 5",0.860082,0.899563,0.949309,0.923767,0.214286,0.115385,0.15,20
6,CNN with Attention,"32 filters, kernel size 3",0.148148,0.708333,0.078341,0.141079,0.086758,0.730769,0.155102,20
7,CNN with Attention,"64 filters, kernel size 7",0.547325,0.884892,0.56682,0.691011,0.096154,0.384615,0.153846,20
8,CNN with Attention,"64 filters, kernel size 5",0.864198,0.9,0.953917,0.926174,0.230769,0.115385,0.153846,20
9,CNN with Attention,"64 filters, kernel size 3",0.839506,0.900901,0.921659,0.911162,0.190476,0.153846,0.170213,20


Results for window size: 30


Unnamed: 0,Name,Params,Accuracy,Precision (0),Recall (0),F1 (0),Precision (1),Recall (1),F1 (1),Window Size
0,LSTM,250 layers,0.098712,1.0,0.004739,0.009434,0.094828,1.0,0.173228,30
1,LSTM,200 layers,0.098712,1.0,0.004739,0.009434,0.094828,1.0,0.173228,30
2,LSTM,100 layers,0.098712,1.0,0.004739,0.009434,0.094828,1.0,0.173228,30
3,LSTM,50 layers,0.098712,1.0,0.004739,0.009434,0.094828,1.0,0.173228,30
4,CNN with Attention,"32 filters, kernel size 7",0.905579,0.905579,1.0,0.95045,0.0,0.0,0.0,30
5,CNN with Attention,"32 filters, kernel size 5",0.905579,0.905579,1.0,0.95045,0.0,0.0,0.0,30
6,CNN with Attention,"32 filters, kernel size 3",0.896996,0.904762,0.990521,0.945701,0.0,0.0,0.0,30
7,CNN with Attention,"64 filters, kernel size 7",0.905579,0.905579,1.0,0.95045,0.0,0.0,0.0,30
8,CNN with Attention,"64 filters, kernel size 5",0.905579,0.905579,1.0,0.95045,0.0,0.0,0.0,30
9,CNN with Attention,"64 filters, kernel size 3",0.875536,0.902655,0.966825,0.933638,0.0,0.0,0.0,30


Results for window size: 40


Unnamed: 0,Name,Params,Accuracy,Precision (0),Recall (0),F1 (0),Precision (1),Recall (1),F1 (1),Window Size
0,LSTM,250 layers,0.098655,1.0,0.00495,0.009852,0.094595,1.0,0.17284,40
1,LSTM,200 layers,0.098655,1.0,0.00495,0.009852,0.094595,1.0,0.17284,40
2,LSTM,100 layers,0.098655,1.0,0.00495,0.009852,0.094595,1.0,0.17284,40
3,LSTM,50 layers,0.098655,1.0,0.00495,0.009852,0.094595,1.0,0.17284,40
4,CNN with Attention,"32 filters, kernel size 7",0.686099,0.888235,0.747525,0.811828,0.037736,0.095238,0.054054,40
5,CNN with Attention,"32 filters, kernel size 5",0.704036,0.886364,0.772277,0.825397,0.021277,0.047619,0.029412,40
6,CNN with Attention,"32 filters, kernel size 3",0.690583,0.884393,0.757426,0.816,0.02,0.047619,0.028169,40
7,CNN with Attention,"64 filters, kernel size 7",0.152466,0.69697,0.113861,0.195745,0.057895,0.52381,0.104265,40
8,CNN with Attention,"64 filters, kernel size 5",0.901345,0.905405,0.99505,0.948113,0.0,0.0,0.0,40
9,CNN with Attention,"64 filters, kernel size 3",0.789238,0.893401,0.871287,0.882206,0.0,0.0,0.0,40


### **Random Under Sampling**

In [42]:
# from imblearn.under_sampling import RandomUnderSampler
# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import StandardScaler


# time_series_df = aggregated_df.copy()

# # Drop rows with NaN in the 'spikes' column
# time_series_df = time_series_df.dropna(subset=['spikes'])
# discretized_df = discretize(time_series_df[FEATURE_COLUMNS], FEATURE_COLUMNS, BIN_COUNT)
# time_series_df[FEATURE_COLUMNS] = discretized_df

# # Extract features and target variable BEFORE creating sequences
# X = time_series_df[FEATURE_COLUMNS].values
# y = time_series_df['spikes'].values

# # Feature scaling using StandardScaler
# scaler = StandardScaler()
# X_scaled = scaler.fit_transform(X)

# # Apply RandomOverSampler BEFORE creating sequences
# random_under_sampler = RandomUnderSampler(random_state=RANDOM_STATE)
# X_scaled_resampled, y_resampled = random_under_sampler.fit_resample(X_scaled, y)

# # Recreate sequences with resampled data
# X_sequences_resampled, y_sequences_resampled = [], []
# for i in range(len(X_scaled_resampled) - SPIKES_WINDOW_SIZE + 1):
#     X_sequences_resampled.append(X_scaled_resampled[i:i + SPIKES_WINDOW_SIZE, :])
#     y_sequences_resampled.append(y_resampled[i + SPIKES_WINDOW_SIZE - 1])

# X_sequences_resampled, y_sequences_resampled = np.array(X_sequences_resampled), np.array(y_sequences_resampled)

# # Split the resampled sequences into training and testing sets
# X_train, X_test, y_train, y_test = train_test_split(X_sequences_resampled, y_sequences_resampled, test_size=0.2, random_state=50)

# # Proceed with model training and evaluation
# evaluate_all_models(RUS_OUTPUT_PATH)


### **Random Over Sampling**

In [43]:
# from imblearn.over_sampling import RandomOverSampler

# time_series_df = aggregated_df.copy()

# # Drop rows with NaN in the 'spikes' column
# time_series_df = time_series_df.dropna(subset=['spikes'])
# discretized_df = discretize(time_series_df[FEATURE_COLUMNS], FEATURE_COLUMNS, BIN_COUNT)
# time_series_df[FEATURE_COLUMNS] = discretized_df

# # Extract features and target variable BEFORE creating sequences
# X = time_series_df[FEATURE_COLUMNS].values
# y = time_series_df['spikes'].values

# # Feature scaling using StandardScaler
# scaler = StandardScaler()
# X_scaled = scaler.fit_transform(X)

# # Apply RandomUnderSampler BEFORE creating sequences
# random_over_sampler = RandomOverSampler(random_state=RANDOM_STATE)
# X_scaled_resampled, y_resampled = random_over_sampler.fit_resample(X_scaled, y)

# # Recreate sequences with resampled data
# X_sequences_resampled, y_sequences_resampled = [], []
# for i in range(len(X_scaled_resampled) - SPIKES_WINDOW_SIZE + 1):
#     X_sequences_resampled.append(X_scaled_resampled[i:i + SPIKES_WINDOW_SIZE, :])
#     y_sequences_resampled.append(y_resampled[i + SPIKES_WINDOW_SIZE - 1])

# X_sequences_resampled, y_sequences_resampled = np.array(X_sequences_resampled), np.array(y_sequences_resampled)

# # Split the resampled sequences into training and testing sets
# X_train, X_test, y_train, y_test = train_test_split(X_sequences_resampled, y_sequences_resampled, test_size=0.2, random_state=50)

# evaluate_all_models(ROS_OUTPUT_PATH)