In [9]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, lpad, lit, regexp_extract, date_format, lag

# Khởi tạo SparkSession
spark = SparkSession.builder \
    .appName("EMA Calculation") \
    .config("spark.cores.max", "2") \
    .config("spark.executor.memory", "4g") \
    .getOrCreate()

# Đọc dữ liệu từ bảng nguồn machungkhoan
df_lichsugia = spark.read.format("iceberg").load("stock_db.datn_lichsugia")
df_machungkhoan = spark.read.format("iceberg").load("stock_db.datn_machungkhoan")

# Convert 'ngay' to DateType
df_lichsugia = df_lichsugia.withColumn("ngay", to_date(col("ngay"), "dd/MM/yyyy"))

# Extract numeric values from 'thaydoi' and split into change and percentage change
df_lichsugia = df_lichsugia.withColumn("thaydoi_value", regexp_extract(col("thaydoi"), r'([\d\.-]+)', 1).cast("float"))
df_lichsugia = df_lichsugia.withColumn("thaydoi_percent", regexp_extract(col("thaydoi"), r'\(([\d\.-]+)%\)', 1).cast("float"))

# Convert 'ngay' to 'dateid' in the format ddMMyyyy
df_lichsugia = df_lichsugia.withColumn("dateid", date_format(col("ngay"), "ddMMyyyy").cast("int"))

dim_stock_df = df_machungkhoan.filter((col('categoryname').isNotNull()) & (col('categoryname') != ""))
df_stock_drop_duplicate = dim_stock_df.dropDuplicates(['symbol'])

# Ghi dữ liệu vào bảng dim_stock
dim_stock_data = df_stock_drop_duplicate.select(
    col("symbol").alias("stocksymbol"),
    "companyname",
    "categoryid"
).distinct()

# Convert dim_stock_df to a list of stocksymbols
valid_stocksymbols = [row.stocksymbol for row in dim_stock_data.select("stocksymbol").distinct().collect()]

# Select and rename columns to match the schema of 'fact_price_history'
df_fact_price_history = df_lichsugia.select(
    col("symbol").alias("stocksymbol"),
    col("dateid"),
    col("giamocua").alias("openprice").cast("decimal(18, 2)"),
    col("giadongcua").alias("closeprice").cast("decimal(18, 2)"),
    col("giacaonhat").alias("highprice").cast("decimal(18, 2)"),
    col("giathapnhat").alias("lowprice").cast("decimal(18, 2)"),
    col("khoiluongkhoplenh").alias("volume").cast("bigint")
)

# Đảm bảo df_fact_price_history đã được lọc để loại bỏ các giá trị null
df_fact_price_history_filtered = df_fact_price_history.filter(col("stocksymbol").isin(valid_stocksymbols))

df_fact_price_history_filtered = df_fact_price_history_filtered.withColumn('dateid_padded', lpad(col('dateid').cast('string'), 8, '0'))
df_fact_price_history_filtered = df_fact_price_history_filtered.withColumn('date', to_date(col('dateid_padded'), 'ddMMyyyy'))


In [10]:
from pyspark.sql.window import Window
import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType

def calculate_ema(df, period, column='closeprice'):
    """
    Calculate the Exponential Moving Average (EMA) for a given period.
    
    :param df: The DataFrame containing stock price data
    :param period: The period for which EMA is calculated
    :param column: The column name for which EMA is calculated
    :return: DataFrame with EMA values
    """
    alpha = 2 / (period + 1)
    
    windowSpec = Window.partitionBy("stocksymbol").orderBy("date")

    # Calculate initial SMA (Simple Moving Average)
    sma_window_spec = Window.partitionBy("stocksymbol").orderBy("date").rowsBetween(-period + 1, 0)
    df = df.withColumn('SMA', F.avg(column).over(sma_window_spec))
    
    # Initialize EMA column
    df = df.withColumn('EMA', F.lit(None).cast(DoubleType()))

    # Calculate EMA using iterative approach
    def calculate_ema_iter(closeprice, previous_ema, alpha):
        if previous_ema is None:
            return closeprice
        return (closeprice * alpha) + (previous_ema * (1 - alpha))

    # Define UDF for EMA calculation
    ema_udf = F.udf(lambda closeprice, previous_ema: calculate_ema_iter(closeprice, previous_ema, alpha), DoubleType())

    # Apply UDF to calculate EMA iteratively
    df = df.withColumn('EMA', F.coalesce(
        ema_udf(col(column), lag(col('EMA')).over(windowSpec)),
        col('SMA')
    ))
    
    return df.drop('SMA')



In [26]:
# Apply the EMA calculation
period = 200
df_with_ema = calculate_ema(df_fact_price_history_filtered, period)

# Select required columns
ema_data = df_with_ema.select(
    col("stocksymbol"),
    col("dateid"),
    col("EMA").alias("indicatorvalue"),
    lit(8).alias("indicatorid")  # Giả sử 5 là ID cho EMA
)


In [27]:
ema_data.show()



+-----------+--------+--------------+-----------+
|stocksymbol|  dateid|indicatorvalue|indicatorid|
+-----------+--------+--------------+-----------+
|        AAT|24032021|          12.7|          8|
|        AAT|25032021|         12.45|          8|
|        AAT|26032021|     12.083333|          8|
|        AAT|29032021|       11.8125|          8|
|        AAT|30032021|          11.8|          8|
|        AAT|31032021|        11.925|          8|
|        AAT| 1042021|     12.135714|          8|
|        AAT| 2042021|      12.24375|          8|
|        AAT| 5042021|     12.283333|          8|
|        AAT| 6042021|          12.4|          8|
|        AAT| 7042021|          12.5|          8|
|        AAT| 8042021|     12.558333|          8|
|        AAT| 9042021|     12.676923|          8|
|        AAT|12042021|     12.846429|          8|
|        AAT|13042021|     13.033333|          8|
|        AAT|14042021|     13.190625|          8|
|        AAT|15042021|         13.35|          8|


                                                                                

In [28]:
# Write result to `fact_stock_indicator`
ema_data.write \
    .format("jdbc") \
    .option("driver", "com.mysql.cj.jdbc.Driver") \
    .option("url", "jdbc:mysql://10.168.6.106:3306/dtm_stock") \
    .option("dbtable", "fact_stock_indicator") \
    .option("user", "acc_etl") \
    .option("password", "Vnpt123456") \
    .mode("append") \
    .save()


                                                                                