<a href="https://colab.research.google.com/github/Chienlovecode/Apple_stocks_predict/blob/main/Apple_Predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#1. Cài đặt thư viện
import pandas as pd # Import pandas and give it alias pd
import numpy as np
import matplotlib.pyplot as plt
import yfinance as yf

from pyspark.sql import SparkSession
from pyspark.sql.functions import to_date, col, avg, lag, when
from pyspark.sql.window import Window
from pyspark.ml.feature import VectorAssembler, MinMaxScaler

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from sklearn.metrics import mean_absolute_error, r2_score

# 2. Create (or reuse) the SparkSession shared by every cell below.
#    getOrCreate() hands back an already-running session if one exists; the
#    memory settings presumably only take effect when this call launches it.
SPARK_CONF = {
    "spark.executor.memory": "4g",
    "spark.driver.memory": "4g",
}

spark_builder = SparkSession.builder.appName("StockLSTM_PySpark_AAPL")
for conf_key, conf_value in SPARK_CONF.items():
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()


In [6]:
# 3. Download AAPL daily prices via yfinance and load them into Spark
START = "2015-01-01"
TODAY = pd.to_datetime("today").strftime("%Y-%m-%d")
EXPECTED_COLS = {"Date", "Open", "High", "Low", "Close", "Volume"}

# Download into pandas. yfinance treats `end` as exclusive, so today's bar is not included.
# auto_adjust is pinned because its default has changed between yfinance releases;
# True matches the schema shown below (adjusted "Close", no "Adj Close" column).
pdf = yf.download("AAPL", start=START, end=TODAY, auto_adjust=True).reset_index()

if pdf.empty:
    raise ValueError("The downloaded data is empty. Check your date range or internet connection.")

# yfinance may return MultiIndex columns (field, ticker); keep only the field
# level (Date, Open, High, Low, Close, Volume).
if isinstance(pdf.columns, pd.MultiIndex):
    pdf.columns = pdf.columns.get_level_values(0)

# Fail early if the flattening above did not yield the columns later cells use.
missing_cols = EXPECTED_COLS - set(pdf.columns)
if missing_cols:
    raise ValueError(f"Downloaded data is missing expected columns: {sorted(missing_cols)}")

df_spark = spark.createDataFrame(pdf)

# "Date" comes from a pandas datetime column, so Spark presumably loads it as a
# timestamp; convert it to a calendar date directly rather than formatting it as
# a string and re-parsing it (which required the session-wide LEGACY time parser).
df_spark = (
    df_spark
    .withColumn("Date", to_date(col("Date")))
    .orderBy("Date")
)

df_spark.printSchema()
df_spark.show(5)


[*********************100%***********************]  1 of 1 completed


root
 |-- Date: date (nullable = true)
 |-- Close: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)
 |-- Open: double (nullable = true)
 |-- Volume: long (nullable = true)

+----------+------------------+------------------+------------------+------------------+---------+
|      Date|             Close|              High|               Low|              Open|   Volume|
+----------+------------------+------------------+------------------+------------------+---------+
|2015-01-02|24.288583755493164| 24.75733822078495|23.848709278537115|24.746229620306494|212818400|
|2015-01-05|23.604337692260742|24.137518365002265| 23.41772538901526|24.057541179344305|257142000|
|2015-01-06| 23.60655403137207|23.866478974309715|23.244434724352086| 23.66875811738099|263188400|
|2015-01-07|23.937572479248047|24.037543101651647| 23.70430543372708|  23.8153846563686|160423600|
|2015-01-08| 24.85731315612793|24.915074837750936|24.148627035050126| 24.26637245776407

In [7]:
# 3b. Feature engineering in Spark: 20-row moving average (MA20) and 14-row RSI.
# NOTE: these windows have no partitionBy, so Spark moves every row into a single
# partition (and logs a warning). Acceptable for one ticker's daily history.
# NOTE(review): the first 19 MA20 values and the first 13 RSI values are averaged
# over partial windows (fewer than 20 / 14 rows); null them out if that matters.
by_date = Window.orderBy("Date")
w = by_date.rowsBetween(-19, 0)    # current row + 19 preceding rows
w14 = by_date.rowsBetween(-13, 0)  # current row + 13 preceding rows

df_spark = df_spark.withColumn("MA20", avg("Close").over(w))

# Day-over-day change in Close; null on the first row (no previous close).
df_spark = df_spark.withColumn("delta", col("Close") - lag("Close", 1).over(by_date))
# A null delta fails both comparisons and falls through to 0.
df_spark = df_spark.withColumn("gain", when(col("delta") > 0, col("delta")).otherwise(0))
df_spark = df_spark.withColumn("loss", when(col("delta") < 0, -col("delta")).otherwise(0))

df_spark = (
    df_spark
    .withColumn("avg_gain", avg("gain").over(w14))
    .withColumn("avg_loss", avg("loss").over(w14))
)

# Simple-moving-average RSI. 100 * G / (G + L) equals 100 - 100 / (1 + G/L), but it
# stays defined when L == 0 (a window with no down days -> RSI = 100) instead of
# dividing by zero (null with ANSI mode off, an error with ANSI mode on).
# RSI is left null only when G + L == 0 (no price movement in the window).
df_spark = (
    df_spark
    .withColumn("RS", when(col("avg_loss") != 0, col("avg_gain") / col("avg_loss")))
    .withColumn(
        "RSI",
        when(
            (col("avg_gain") + col("avg_loss")) != 0,
            100 * col("avg_gain") / (col("avg_gain") + col("avg_loss")),
        ),
    )
)

df_spark.select("Date", "Close", "MA20", "RSI").show(5)


+----------+------------------+------------------+------------------+
|      Date|             Close|              MA20|               RSI|
+----------+------------------+------------------+------------------+
|2015-01-02|24.288583755493164|24.288583755493164|              NULL|
|2015-01-05|23.604337692260742|23.946460723876953|               0.0|
|2015-01-06| 23.60655403137207|23.833158493041992|0.3228638748110484|
|2015-01-07|23.937572479248047|23.859261989593506| 32.75096400245195|
|2015-01-08| 24.85731315612793| 24.05887222290039| 64.67899754052037|
+----------+------------------+------------------+------------------+
only showing top 5 rows



In [10]:
# 4. Scale features with Spark ML
from pyspark.sql.functions import isnan, when, count, col

FEATURE_COLS = ["Open", "High", "Low", "Close", "Volume", "MA20", "RSI"]

# Count null/NaN values for every feature in one aggregation (a single Spark job)
# instead of one job per column.
null_counts = (
    df_spark
    .select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in FEATURE_COLS])
    .first()
    .asDict()
)
for column in FEATURE_COLS:
    print(f"Number of nulls in column {column}: {null_counts[column]}")

# Drop rows with a null/NaN in any feature (e.g. rows whose RSI is undefined).
# NOTE: this rebinds df_spark, so re-running the cell reports counts for the
# already-cleaned frame (all zeros) rather than the original data.
df_spark = df_spark.dropna(subset=FEATURE_COLS)

assembler = VectorAssembler(inputCols=FEATURE_COLS, outputCol="features_raw")
df_vec = assembler.transform(df_spark)

# NOTE(review): the scaler is fit on the full history, so each feature's min/max
# include any later test period (look-ahead leakage). Fit on the training rows
# only once the train/test split is known.
scaler = MinMaxScaler(inputCol="features_raw", outputCol="features_scaled")
scaler_model = scaler.fit(df_vec)
df_scaled = scaler_model.transform(df_vec)

df_scaled.select("Date", "features_scaled").show(5)

Number of nulls in column Open: 0
Number of nulls in column High: 0
Number of nulls in column Low: 0
Number of nulls in column Close: 0
Number of nulls in column Volume: 0
Number of nulls in column MA20: 0
Number of nulls in column RSI: 0
+----------+--------------------+
|      Date|     features_scaled|
+----------+--------------------+
|2015-01-05|[0.01471649271360...|
|2015-01-06|[0.01307605044738...|
|2015-01-07|[0.01369473062004...|
|2015-01-08|[0.01559764132614...|
|2015-01-09|[0.01882219492671...|
|2015-01-12|[0.01875660525088...|
|2015-01-13|[0.01765983035581...|
|2015-01-14|[0.01541949961150...|
|2015-01-15|[0.01631938376838...|
|2015-01-16|[0.01353539583343...|
|2015-01-20|[0.01429464712023...|
|2015-01-21|[0.01533513753738...|
|2015-01-22|[0.01656312361042...|
|2015-01-23|[0.01847535327772...|
|2015-01-26|[0.01982520650426...|
|2015-01-27|[0.01858783748185...|
|2015-01-28|[0.02347162329273...|
|2015-01-29|[0.02224361826114...|
|2015-01-30|[0.02419340728778...|
|2015-02-02|[