In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

# Create the Spark session for this notebook (or reuse one that already exists).
spark = (
    SparkSession.builder
    .appName("BTC_Prediction")
    .getOrCreate()
)

# Restrict Spark's own logging to ERROR level so cell outputs stay readable.
spark.sparkContext.setLogLevel("ERROR")


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/21 10:03:58 WARN Utils: Your hostname, chaima-ThinkPad-T470-W10DG, resolves to a loopback address: 127.0.1.1; using 192.168.1.120 instead (on interface wlp4s0)
26/01/21 10:03:58 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/21 10:04:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Read the silver dataset (Parquet, path relative to the notebook) and
# preview its first five rows.
silver_path = "../data/silver/silver_dataset"
df = spark.read.format("parquet").load(silver_path)
df.show(n=5)

                                                                                

+-------------------+--------+--------+--------+--------+-------+--------------------+------------------+----------------+---------------------+----------------------+------+---------------+--------------------+-----------------+-----------------+-------------------+
|          open_time|    open|    high|     low|   close| volume|          close_time|quote_asset_volume|number_of_trades|taker_buy_base_volume|taker_buy_quote_volume|ignore|close_t_plus_10|           return_1m|             MA_5|            MA_10|        taker_ratio|
+-------------------+--------+--------+--------+--------+-------+--------------------+------------------+----------------+---------------------+----------------------+------+---------------+--------------------+-----------------+-----------------+-------------------+
|2026-01-19 09:53:00|93046.35|93046.35|92997.65| 93001.8|9.10532|2026-01-19 09:53:...|    847027.2616431|            2024|              2.08878|        194272.7849138|     0|       92965.51|      

In [3]:
# Rows without a target value cannot be used for training or evaluation.
df = df.dropna(subset=["close_t_plus_10"])

# Feature candidates: every int/bigint/double/float column except the target.
numeric_cols = [
    name
    for name, dtype in df.dtypes
    if name != "close_t_plus_10"
    and dtype in ("int", "bigint", "double", "float")
]

# Discard rows with a missing (null/NaN) value in any feature or in the target.
df = df.dropna(subset=[*numeric_cols, "close_t_plus_10"])


In [4]:
# Packs all numeric feature columns into a single vector column named
# "features", which is the input format the regression model reads.
assembler = VectorAssembler(inputCols=numeric_cols, outputCol="features")

In [5]:
# Build the modelling frame: add the "features" vector, then keep only
# the timestamp (used for ordering), the features and the target.
df_ml = (
    assembler
    .transform(df)
    .select(
        F.col("open_time"),
        F.col("features"),
        F.col("close_t_plus_10"),
    )
)

In [6]:
# Size of the training portion: the first 80% of the rows.
# int() truncates, so any fractional row goes to the test portion.
total = df_ml.count()
train_limit = int(0.8 * total)

In [7]:
# index by time
window = Window.orderBy("open_time")
df_ml = df_ml.withColumn("rn", F.row_number().over(window))

In [8]:
# split
train_data = df_ml.filter(F.col("rn") <= train_limit) \
                  .drop("rn", "open_time")

test_data = df_ml.filter(F.col("rn") > train_limit) \
                 .drop("rn", "open_time")

In [9]:

# Linear regression from the assembled feature vector to the target column.
lr = LinearRegression(
    featuresCol="features",
    labelCol="close_t_plus_10",
)

# Fit on the training rows only, then score the held-out test rows;
# transform() adds the model output as a "prediction" column.
model = lr.fit(train_data)
predictions = model.transform(test_data)



                                                                                

In [12]:
# Evaluate the model on the held-out test predictions.
# One evaluator per metric; each compares the "prediction" column against the
# "close_t_plus_10" label column.
rmse_evaluator = RegressionEvaluator(
    labelCol="close_t_plus_10", predictionCol="prediction", metricName="rmse"
)
mae_evaluator = RegressionEvaluator(
    labelCol="close_t_plus_10", predictionCol="prediction", metricName="mae"
)
r2_evaluator = RegressionEvaluator(
    labelCol="close_t_plus_10", predictionCol="prediction", metricName="r2"
)

# Metric values as plain floats. The evaluators carry their own names so that
# `rmse`, `mae` and `r2` always refer to numbers, never to evaluator objects.
rmse = rmse_evaluator.evaluate(predictions)
mae = mae_evaluator.evaluate(predictions)
r2 = r2_evaluator.evaluate(predictions)

In [15]:
# Report the test-set error metrics, one per line.
print(
    f"root mean square error (RMSE): {rmse}",
    f"mean absolute error (MAE): {mae}",
    sep="\n",
)

root mean square error (RMSE): 62.80099893403466
mean absolute error (MAE): 50.74993129601723
