In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, lpad, lit, regexp_extract, date_format, lag
from pyspark.sql.window import Window
import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType

In [None]:
# Shared Spark session for this notebook; capped at 2 cores and 4 GB per executor.
spark = (
    SparkSession.builder
    .appName("machungkhoan")
    .config("spark.cores.max", "2")
    .config("spark.executor.memory", "4g")
    .getOrCreate()
)

In [None]:
# Load the source Iceberg tables: daily price history and the stock-symbol master list.
df_lichsugia = spark.read.format("iceberg").load("stock_db.datn_lichsugia")
df_machungkhoan = spark.read.format("iceberg").load("stock_db.datn_machungkhoan")

# Normalise the raw price history:
#   * 'ngay' (a dd/MM/yyyy string) is replaced in place by a DateType value;
#   * 'thaydoi' is split into the absolute change and the percentage change
#     (the percentage is the number found inside "(...%)");
#   * 'dateid' is the date rendered as ddMMyyyy and cast to int, so days 01-09
#     lose their leading zero (e.g. 05/03/2024 -> 5032024).
df_lichsugia = (
    df_lichsugia
    .withColumn("ngay", to_date(col("ngay"), "dd/MM/yyyy"))
    .withColumn("thaydoi_value", regexp_extract(col("thaydoi"), r'([\d\.-]+)', 1).cast("float"))
    .withColumn("thaydoi_percent", regexp_extract(col("thaydoi"), r'\(([\d\.-]+)%\)', 1).cast("float"))
    .withColumn("dateid", date_format(col("ngay"), "ddMMyyyy").cast("int"))
)

# Stock dimension: only symbols with a non-empty category, one row per symbol.
dim_stock_df = df_machungkhoan.filter(col('categoryname').isNotNull() & (col('categoryname') != ""))
df_stock_drop_duplicate = dim_stock_df.dropDuplicates(['symbol'])
dim_stock_data = (
    df_stock_drop_duplicate
    .select(col("symbol").alias("stocksymbol"), "companyname", "categoryid")
    .distinct()
)

# Symbols present in the stock dimension, collected to the driver for the isin() filter below.
valid_stocksymbols = [r.stocksymbol for r in dim_stock_data.select("stocksymbol").distinct().collect()]

# Rename and cast the price history into the 'fact_price_history' column layout.
df_fact_price_history = df_lichsugia.select(
    col("symbol").alias("stocksymbol"),
    col("dateid"),
    *[
        col(source_name).cast("decimal(18, 2)").alias(target_name)
        for source_name, target_name in (
            ("giamocua", "openprice"),
            ("giadongcua", "closeprice"),
            ("giacaonhat", "highprice"),
            ("giathapnhat", "lowprice"),
        )
    ],
    col("khoiluongkhoplenh").cast("bigint").alias("volume"),
)

# Keep only rows whose symbol exists in the stock dimension (this is a membership
# filter, not a null filter), then rebuild a real DateType 'date' from dateid:
# left-padding to 8 characters restores the leading zero lost by the int cast.
df_fact_price_history_filtered = (
    df_fact_price_history
    .filter(col("stocksymbol").isin(valid_stocksymbols))
    .withColumn('dateid_padded', lpad(col('dateid').cast('string'), 8, '0'))
    .withColumn('date', to_date(col('dateid_padded'), 'ddMMyyyy'))
)

In [None]:
def calculate_obv(df):
    """
    Add On-Balance Volume (OBV) columns, computed per stock in ascending 'date' order.

    For each row, OBV_change is +volume when the close rose versus the previous
    session, -volume when it fell, and 0 when it is unchanged or on the first
    session of the stock. OBV is the running sum of OBV_change.

    :param df: DataFrame with 'stocksymbol', 'date', 'closeprice' and 'volume' columns
    :return: `df` plus 'prev_close', 'prev_volume', 'OBV_change' and 'OBV'
             ('prev_volume' is not used by the computation; it is kept for
             compatibility with existing consumers of this frame)
    """
    # Window giving access to the previous session's close/volume for each stock.
    window_spec = Window.partitionBy("stocksymbol").orderBy(col("date").asc())
    # Explicit ROWS frame so the cumulative sum is a true running total even if a
    # date is repeated (the default RANGE frame would lump tied rows together).
    running_spec = window_spec.rowsBetween(Window.unboundedPreceding, Window.currentRow)

    df = (
        df.withColumn("prev_close", lag("closeprice").over(window_spec))
          .withColumn("prev_volume", lag("volume").over(window_spec))
    )

    # F.when is required here: `when` is not imported on its own in this notebook.
    df = df.withColumn(
        "OBV_change",
        F.when(col("prev_close").isNull(), 0)  # first session of the stock
         .when(col("closeprice") > col("prev_close"), col("volume"))
         .when(col("closeprice") < col("prev_close"), -col("volume"))
         .otherwise(0),
    )

    # F.sum (Spark aggregate) — the bare builtin `sum` cannot take a column name.
    df = df.withColumn("OBV", F.sum("OBV_change").over(running_spec))

    return df

In [None]:
def calculate_ad(df):
    """
    Add the per-day Accumulation/Distribution value as column 'A/D'.

    A/D = ((close - low) - (high - close)) * volume / (high - low).
    This is the single-day money-flow volume; it is not accumulated across days.
    When high == low (zero range) the value is set to 0 to avoid dividing by zero;
    when high or low is null the condition is not true, so the value is also 0.

    :param df: DataFrame with 'highprice', 'lowprice', 'closeprice' and 'volume'
    :return: `df` with an extra 'A/D' column
    """
    price_range = col("highprice") - col("lowprice")
    df = df.withColumn(
        "A/D",
        # F.when: the bare name `when` is not imported in this notebook.
        F.when(
            price_range != 0,
            (((col("closeprice") - col("lowprice")) - (col("highprice") - col("closeprice"))) * col("volume")) / price_range
        ).otherwise(0)
    )
    return df

In [None]:
def calculate_ema(df, period, column='closeprice'):
    """
    Add an 'EMA' column: the exponential moving average of `column`, per stock.

    The EMA uses the recursive definition, walking each stock's rows in
    ascending 'date' order (null dates first, like Spark's ascending sort):

        EMA_0 = x_0
        EMA_t = alpha * x_t + (1 - alpha) * EMA_{t-1},   alpha = 2 / (period + 1)

    For a series without nulls this matches pandas
    ``Series.ewm(span=period, adjust=False).mean()``.

    A Spark window expression cannot read its own value from the previous row,
    so the recursion is run sequentially in Python over each stock's grouped rows.

    :param df: DataFrame with at least 'stocksymbol', 'date' and `column`
    :param period: EMA span; must be a positive integer
    :param column: numeric column to smooth (decimal values are converted to float)
    :return: `df` with a nullable DoubleType 'EMA' column appended. A row whose
             `column` is null carries the previous EMA forward (null before the
             first non-null value of the stock).
    :raises ValueError: if `period` < 1
    """
    from pyspark.sql.types import StructField, StructType

    if period < 1:
        raise ValueError("period must be a positive integer")
    alpha = 2.0 / (period + 1)

    # Replace an existing EMA column rather than producing a duplicate name.
    if "EMA" in df.columns:
        df = df.drop("EMA")

    def _ema_for_stock(rows):
        # Sort key puts null dates first without ever comparing None to a date.
        ordered = sorted(rows, key=lambda r: (r["date"] is not None, r["date"]))
        ema = None
        for row in ordered:
            value = row[column]
            if value is not None:
                value = float(value)  # decimal.Decimal cannot be mixed with float
                ema = value if ema is None else alpha * value + (1 - alpha) * ema
            yield tuple(row) + (ema,)

    out_schema = StructType(df.schema.fields + [StructField("EMA", DoubleType(), True)])
    ema_rdd = (
        df.rdd
        .keyBy(lambda r: r["stocksymbol"])
        .groupByKey()
        .flatMap(lambda kv: _ema_for_stock(kv[1]))
    )
    return ema_rdd.toDF(out_schema)

In [None]:

def calculate_macd(df, short_period, long_period, signal_period, column='closeprice'):
    """
    Calculate the Moving Average Convergence Divergence (MACD) per stock.

        MACD           = EMA(column, short_period) - EMA(column, long_period)
        Signal_Line    = EMA(MACD, signal_period)
        MACD_Histogram = MACD - Signal_Line

    All EMAs come from `calculate_ema`, which orders each stock's rows by 'date'
    (the ddMMyyyy integer 'dateid' does not sort chronologically).

    :param df: DataFrame with 'stocksymbol', 'dateid', 'date' and `column`
    :param short_period: span of the fast EMA (typically 12)
    :param long_period: span of the slow EMA (typically 26)
    :param signal_period: span of the signal-line EMA (typically 9)
    :param column: price column the MACD is computed on
    :return: DataFrame with columns stocksymbol, dateid, MACD, Signal_Line, MACD_Histogram
    """
    # 'date' is carried through the join so the signal EMA can be ordered by it.
    keys = ['stocksymbol', 'dateid', 'date']

    short_ema_df = calculate_ema(df, short_period, column).select(*keys, col('EMA').alias('short_ema'))
    long_ema_df = calculate_ema(df, long_period, column).select(*keys, col('EMA').alias('long_ema'))

    macd_df = short_ema_df.join(long_ema_df, on=keys, how='inner')

    # MACD line is fast minus slow (the reverse sign inverts every crossover signal).
    macd_df = macd_df.withColumn('MACD', col('short_ema') - col('long_ema')).select(*keys, 'MACD')

    # Signal line is an EMA of the MACD line.
    macd_df = calculate_ema(macd_df, signal_period, 'MACD').withColumnRenamed('EMA', 'Signal_Line')

    macd_df = macd_df.withColumn('MACD_Histogram', col('MACD') - col('Signal_Line'))

    return macd_df.select('stocksymbol', 'dateid', 'MACD', 'Signal_Line', 'MACD_Histogram')


In [None]:
def calculate_rs(df, period, column='closeprice'):
    """
    Calculate the Relative Strength (RS) used by the RSI, per stock.

    Each row's price change versus the previous row is split into a gain and a
    loss; RS is the ratio of the simple averages of gains and losses over the
    last `period` rows (current row included). Early rows average over fewer rows,
    and the first row of each stock counts as a zero change.

    Rows are ordered by 'date'. If 'date' is absent it is rebuilt from 'dateid'
    (ddMMyyyy int); ordering by 'dateid' directly would sort by day of month.

    :param df: DataFrame with 'stocksymbol', `column`, and 'date' or 'dateid'
    :param period: look-back length in rows; must be >= 1
    :param column: price column used for the changes
    :return: `df` plus price_change, gain, loss, avg_gain, avg_loss and RS (double).
             RS is +Infinity when avg_loss is 0 and avg_gain > 0 (so RSI = 100),
             and null when both are 0 (no price movement, RSI undefined).
    :raises ValueError: if `period` < 1
    """
    if period < 1:
        raise ValueError("period must be a positive integer")

    if "date" in df.columns:
        order_col = col("date")
    else:
        order_col = to_date(lpad(col("dateid").cast("string"), 8, "0"), "ddMMyyyy")

    window_spec = Window.partitionBy("stocksymbol").orderBy(order_col)
    avg_window_spec = window_spec.rowsBetween(-period + 1, 0)

    # Price change versus the previous session (null on a stock's first row).
    df = df.withColumn("price_change", col(column) - lag(col(column)).over(window_spec))

    # Separate gains and losses; a null change counts as neither.
    df = df.withColumn("gain", F.when(col("price_change") > 0, col("price_change")).otherwise(0))
    df = df.withColumn("loss", F.when(col("price_change") < 0, -col("price_change")).otherwise(0))

    df = df.withColumn("avg_gain", F.avg("gain").over(avg_window_spec))
    df = df.withColumn("avg_loss", F.avg("loss").over(avg_window_spec))

    # Division by a zero average loss: only gains -> +inf, no movement -> null.
    # (Returning 0 here would make the RSI report 0 for a purely rising stock.)
    df = df.withColumn(
        "RS",
        F.when(col("avg_loss") != 0, (col("avg_gain") / col("avg_loss")).cast("double"))
         .when(col("avg_gain") > 0, F.lit(float("inf")))
         .otherwise(F.lit(None).cast("double")),
    )

    return df

In [None]:
def calculate_rsi(df, period):
    """
    Add the Relative Strength Index column 'RSI' (0-100).

        RSI = 100 - 100 / (1 + RS) = 100 * avg_gain / (avg_gain + avg_loss)

    When the frame carries 'avg_gain' and 'avg_loss' (as produced by
    `calculate_rs`), RSI is computed from them directly. Then a window with no
    losses gives 100 and a window with no movement gives null, however RS encoded
    its division by zero. Otherwise the 'RS' column is used.

    :param df: DataFrame with an 'RS' column (and ideally 'avg_gain'/'avg_loss')
    :param period: unused; kept for API compatibility (the look-back is already in RS)
    :return: `df` with an 'RSI' column
    """
    if {"avg_gain", "avg_loss"}.issubset(df.columns):
        avg_gain = col("avg_gain").cast("double")
        avg_loss = col("avg_loss").cast("double")
        # No otherwise(): both averages zero (no price movement) leaves RSI null.
        rsi = (
            F.when(avg_loss != 0, 100.0 * avg_gain / (avg_gain + avg_loss))
             .when(avg_gain > 0, F.lit(100.0))
        )
    else:
        rsi = 100 - (100 / (1 + col("RS")))

    return df.withColumn("RSI", rsi)