In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
import logging

# Start Spark
spark = SparkSession.builder \
    .appName("BTC_Prediction") \
    .getOrCreate()
spark.sparkContext.setLogLevel("ERROR")  
logging.getLogger("py4j").setLevel(logging.ERROR)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/22 11:37:49 WARN Utils: Your hostname, chaima-ThinkPad-T470-W10DG, resolves to a loopback address: 127.0.1.1; using 192.168.1.120 instead (on interface wlp4s0)
26/01/22 11:37:49 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/22 11:37:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
# Load silver data
silver_path = "../data/silver/silver_dataset"
df = spark.read.parquet(silver_path)
df.show(5)
df.printSchema()

+-------------------+--------+--------+--------+--------+-------+--------------------+------------------+----------------+---------------------+----------------------+---------------+--------------------+-----------------+-----------------+-------------------+
|          open_time|    open|    high|     low|   close| volume|          close_time|quote_asset_volume|number_of_trades|taker_buy_base_volume|taker_buy_quote_volume|close_t_plus_10|           return_1m|             MA_5|            MA_10|        taker_ratio|
+-------------------+--------+--------+--------+--------+-------+--------------------+------------------+----------------+---------------------+----------------------+---------------+--------------------+-----------------+-----------------+-------------------+
|2026-01-19 09:54:00| 93001.8|93022.23|93001.79|93014.94|5.62556|2026-01-19 09:54:...|    523245.8306426|            1064|              3.42885|        318907.7899883|       92975.79|1.412875879821618...|         9300

In [5]:
median_value=df.approxQuantile("close_t_plus_10", [0.5], 0.1)[0]
print("median of target:", median_value)

                                                                                

median of target: 93042.34


In [6]:
# identify numeric columns
numeric_cols = [
    c for c, t in df.dtypes
    if t in ["int", "bigint", "double", "float"]
    and c != "close_t_plus_10"
]

In [None]:
#assemble features
assembler = VectorAssembler(
    inputCols=numeric_cols,
    outputCol="feature s"
)

In [8]:
df_ml = assembler.transform(df).select(
    "open_time",
    "features",
    "close_t_plus_10"
)

In [9]:
# count rows
total = df_ml.count()
train_limit = int(total * 0.8)

In [10]:
# index by time
window = Window.orderBy("open_time")
df_ml = df_ml.withColumn("rn", F.row_number().over(window))

In [11]:
# split
train_data = df_ml.filter(F.col("rn") <= train_limit) \
                  .drop("rn", "open_time")

test_data = df_ml.filter(F.col("rn") > train_limit) \
                 .drop("rn", "open_time")

In [12]:

lr = LinearRegression(featuresCol="features", labelCol="close_t_plus_10")
model = lr.fit(train_data)
predictions = model.transform(test_data)



                                                                                

In [13]:
#evaluate model
rmse = RegressionEvaluator(labelCol="close_t_plus_10", predictionCol="prediction", metricName="rmse")
mae = RegressionEvaluator(labelCol="close_t_plus_10",predictionCol="prediction",metricName="mae")
r2= RegressionEvaluator(labelCol="close_t_plus_10", predictionCol="prediction", metricName="r2")
rmse = rmse.evaluate(predictions)
mae = mae.evaluate(predictions)
r2 = r2.evaluate(predictions)

In [14]:
#printing results
print(f"root mean square error (RMSE): {rmse}")
print(f"mean absolute error (MAE): {mae}")
print(f"R squared (R2): {r2}")

root mean square error (RMSE): 62.80099893403466
mean absolute error (MAE): 50.74993129601723
R squared (R2): 0.019616593165821428
