In [58]:
from pyspark.sql import SparkSession
from pyspark.sql import types as T
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import os

In [5]:
# Create (or reuse) the notebook's Spark session: local master using all
# available cores, with the Spark web UI bound to port 8080.
spark = (
    SparkSession.builder
    .appName("Stock Prediction")
    .master("local[*]")
    .config("spark.ui.port", "8080")
    .getOrCreate()
)

In [43]:
# Explicit schema: one date column followed by five double columns.
# Every field keeps StructType.add's default of nullable=True.
schema = (
    T.StructType()
    .add("datetime", T.DateType())
    .add("open", T.DoubleType())
    .add("high", T.DoubleType())
    .add("low", T.DoubleType())
    .add("close", T.DoubleType())
    .add("volume", T.DoubleType())
)

In [None]:
# Read every CSV matched by the glob into one DataFrame using the explicit
# schema. The files are expected to carry a header row, and FAILFAST makes the
# read raise on a malformed record instead of silently nulling or dropping it.
df = spark.read.load(
    "../data/raw/archive/D1/*.csv",
    format="csv",
    schema=schema,
    header=True,
    mode="FAILFAST",
)

                                                                                

In [51]:
# `udf` is referenced through the `F` alias because only
# `pyspark.sql.functions as F` is imported in this notebook; a bare `@udf`
# raises NameError on a fresh kernel.
@F.udf(returnType=T.StringType())
def get_basename(path):
  """Return the file name in `path` up to (not including) its first dot.

  Example: ".../D1/AAPL.US.csv" -> "AAPL".

  Args:
    path: File path or URI string, or None.

  Returns:
    The leading dot-free part of the base name, or None when `path` is None.
  """
  if path is None:
    # Keep nulls as nulls instead of raising TypeError inside the executor.
    return None
  # Taking everything before the first dot already discards every extension,
  # so a separate os.path.splitext step is unnecessary.
  return os.path.basename(path).split('.')[0]

# Tag every row with the name derived from the file it was read from.
df = df.withColumn("ticket_name", get_basename(F.input_file_name()))

In [59]:
# One window per ticket_name.
# NOTE(review): the ordering column is the partition key itself, so all rows
# of a partition tie in the ordering and aggregates over this window cover the
# whole partition. Confirm whether ordering by "datetime" was intended.
window_spec = Window.partitionBy("ticket_name").orderBy("ticket_name")

In [62]:
# Add two window-derived columns computed over `window_spec`:
#   count      - number of non-null ticket_name values in the window frame
#   dense_rank - dense rank of the row under the window's ordering
df = (
    df
    .withColumn("count", F.count("ticket_name").over(window_spec))
    .withColumn("dense_rank", F.dense_rank().over(window_spec))
)

In [78]:
# Average close over `window_spec`.
# NOTE(review): as `window_spec` is defined in this notebook (ordered by its own
# partition key) this is the mean over the whole ticker partition, not a
# rolling average - confirm that is what "SMA" is meant to be.
df = df.withColumn("SMA", F.avg(F.col("close")).over(window_spec))

# lag() needs a meaningful row order to find the previous close. Ordering by the
# partition key leaves every row tied, so the "previous" row would be arbitrary.
# Assumes one row per (ticket_name, datetime) - TODO confirm.
chronological_window = Window.partitionBy("ticket_name").orderBy("datetime")

df = (
    df
    # Close-to-close change; null on the first row of each ticker.
    .withColumn("delta", F.col("close") - F.lag("close").over(chronological_window))
    # A null delta fails both comparisons, so it counts as 0 gain and 0 loss.
    .withColumn("gain", F.when(F.col("delta") > 0, F.col("delta")).otherwise(0))
    .withColumn("loss", F.when(F.col("delta") < 0, -F.col("delta")).otherwise(0))
    # Averages stay on `window_spec` to preserve the original aggregation scope.
    .withColumn("average_gain", F.avg("gain").over(window_spec))
    .withColumn("average_loss", F.avg("loss").over(window_spec))
    # Only divide when there is a non-zero average loss; otherwise RS is null
    # rather than a division-by-zero null/error.
    .withColumn(
        "RS",
        F.when(F.col("average_loss") != 0, F.col("average_gain") / F.col("average_loss")),
    )
    # Gains with no losses is the RSI upper bound (100). With neither gains nor
    # losses RS is null and RSI stays null, as before.
    .withColumn(
        "RSI",
        F.when(
            (F.col("average_loss") == 0) & (F.col("average_gain") > 0),
            F.lit(100.0),
        ).otherwise(100 - 100 / (F.col("RS") + 1)),
    )
    .drop("delta", "gain", "loss", "average_gain", "average_loss", "RS")
)

# Previously the chained result was computed and thrown away; it is now kept in
# `df`, and the bare expression keeps the cell's displayed schema.
df

DataFrame[datetime: date, open: double, high: double, low: double, close: double, volume: double, ticket_name: string, count: bigint, dense_rank: int, SMA: double, RSI: double]