In [1]:
import findspark

# Register the PostgreSQL JDBC driver jar with findspark before init();
# the JDBC write at the end of the notebook uses org.postgresql.Driver.
POSTGRES_JDBC_JAR = '/app/postgresql-42.1.4.jar'

findspark.add_jars(POSTGRES_JDBC_JAR)
findspark.init()

In [2]:
from pyspark.sql import SparkSession

# Session settings: 512 MB / 1 core for both driver and executor, and only
# 2 shuffle partitions.
spark_settings = (
    ("spark.driver.memory", "512m"),
    ("spark.driver.cores", "1"),
    ("spark.executor.memory", "512m"),
    ("spark.executor.cores", "1"),
    ("spark.sql.shuffle.partitions", "2"),
)

session_builder = SparkSession.builder.appName("Stocks:ETL")
for setting_key, setting_value in spark_settings:
    session_builder = session_builder.config(setting_key, setting_value)

# Reuses an already running session if there is one, otherwise starts a new one.
spark = session_builder.getOrCreate()

In [None]:
# Display the version of the running Spark session.
spark.version

In [None]:
# Directory holding the stock CSV files (one file per symbol, e.g. ibm.us.txt).
stocks_dir = '/dataset/stocks-small'

In [None]:
import getpass
import os
import sys

from pyspark.sql import SparkSession

# UDF
from pyspark.sql.types import StringType
#
from pyspark.sql import functions as F
from pyspark.sql.window import Window

In [None]:
# First look at the raw data: read every CSV file under stocks_dir, taking
# column names from the header line and letting Spark infer the column types.
df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(stocks_dir)
)

In [None]:
# printSchema() prints as a side effect, whereas count() only *returns* the
# row count. A notebook displays just the value of a cell's last expression,
# so count() has to come last or its result is computed and then thrown away.
df.printSchema()
df.count()

In [None]:
# Preview the loaded rows (show() prints the first 20 rows by default).
df.show()

In [None]:
# Tag every row with the path of the file it was read from; the ticker symbol
# is not a column in the data and has to be derived from that file name.
df = df.withColumn('filename', F.input_file_name())

In [None]:
# truncate=False prints full cell values, so the long file paths are not cut off.
df.show(truncate=False)

In [None]:
# The lookup file starts with a header line (it has "Ticker" and
# "Category Name" columns). Without header=True Spark names the columns
# _c0, _c1, ... and returns that header line as an ordinary data row.
df_lookup = (
    spark.read
    .option("header", True)
    .csv('/dataset/yahoo-symbols-201709.csv')
)

In [None]:
# Preview the symbol lookup table.
df_lookup.show()

In [None]:
def extract_symbol_from(filename):
    """Derive the upper-cased ticker symbol from a stock file path.

    The symbol is the part of the last path component (split on '/') that
    precedes its first dot, e.g.
    'file:///dataset/stocks-small/ibm.us.txt' -> 'IBM'.

    Args:
        filename: Path or URI of the source file, or None.

    Returns:
        The upper-cased symbol; '' for an empty path; None when ``filename``
        is None (so a null input yields a null instead of raising when this
        function is used inside a Spark UDF).
    """
    if filename is None:
        return None
    basename = filename.rsplit('/', 1)[-1]
    return basename.split('.', 1)[0].upper()

In [None]:
# Sanity check on a sample path: the symbol is the file name up to its first
# dot, upper-cased, so this cell should display 'IBM'.
extract_symbol_from('file:///dataset/stocks-small/ibm.us.txt')

In [None]:
# Spark UDF wrapper around extract_symbol_from; produces a string column.
extract_symbol = F.udf(extract_symbol_from, StringType())

In [None]:
stocks_folder = stocks_dir

# Re-read the stock files and add a `name` column holding the ticker symbol
# derived from each row's source file name.
df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(stocks_folder)
    .withColumn("name", extract_symbol(F.input_file_name()))
)

In [None]:
# Check the first 5 rows, including the derived `name` column.
df.show(5)

In [None]:
# Final load of the stock data: symbol column added, price columns renamed,
# Volume and OpenInt dropped.
column_renames = (
    ("Date", "dateTime"),
    ("Open", "open"),
    ("High", "high"),
    ("Low", "low"),
    ("Close", "close"),
)

df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(stocks_folder)
    .withColumn("name", extract_symbol(F.input_file_name()))
)
for old_name, new_name in column_renames:
    df = df.withColumnRenamed(old_name, new_name)
df = df.drop("Volume", "OpenInt")

In [None]:
# Give the cleaned frame a descriptive name; both names refer to the same DataFrame.
df_stocks = df

In [None]:
# Preview the cleaned stock rows.
df_stocks.show(5)

In [None]:
# CSV file that provides the Ticker and Category Name columns for the symbols.
lookup_file = '/dataset/yahoo-symbols-201709.csv'

In [None]:
# Lookup table symbol -> category: keep only the two needed columns and give
# them the names used on the stock side of the join.
symbols_lookup = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv(lookup_file)
    .select(
        F.col("Ticker").alias("symbol"),
        F.col("Category Name").alias("category"),
    )
)

In [None]:
# Look at a few rows of both frames that take part in the join.
df_stocks.show(3)
symbols_lookup.show(3)

In [None]:
# Keep rows dated 2017-09-01 or later, derive year/month/day from the date
# (later used as Parquet partition columns) and attach each symbol's category.
# The join on "symbol" is an inner join, so symbols missing from the lookup
# are dropped.
joined_df = (
    df_stocks
    .withColumnRenamed("dateTime", "full_date")
    .withColumnRenamed("name", "symbol")
    .filter('full_date >= "2017-09-01"')
    .withColumn("year", F.year("full_date"))
    .withColumn("month", F.month("full_date"))
    .withColumn("day", F.dayofmonth("full_date"))
    .join(symbols_lookup, on="symbol")
)

In [None]:
# Preview the joined result.
joined_df.show(3)

In [None]:
# Trailing row windows per symbol, ordered by date, for the N-row moving
# averages. rowsBetween() bounds are inclusive, so a window of N rows ending
# at the current row starts at -(N - 1): (-20, 0) would span 21 rows, not 20.
symbol_by_date = Window.partitionBy("symbol").orderBy("full_date")

window20 = symbol_by_date.rowsBetween(-19, Window.currentRow)
window50 = symbol_by_date.rowsBetween(-49, Window.currentRow)
window100 = symbol_by_date.rowsBetween(-99, Window.currentRow)

In [None]:
# Append the moving averages of the closing price, one column per window.
# NOTE(review): joined_df only contains rows from 2017-09-01 onwards, so the
# first rows of every symbol are averaged over fewer rows than the window
# size suggests. Confirm this is intended.
stocks_moving_avg_df = joined_df.select(
    "*",
    F.avg("close").over(window20).alias("ma20"),
    F.avg("close").over(window50).alias("ma50"),
    F.avg("close").over(window100).alias("ma100"),
)

In [None]:
# Moving average: compare the closing price with its ma20 value for 25 rows.
stocks_moving_avg_df.select('symbol', 'close', 'ma20').show(25)

In [None]:
# Location of the Parquet output of this ETL.
output_dir = '/dataset/output.parquet'

In [None]:
# Persist the result as Parquet, one directory level per year/month/day.
# 'overwrite' replaces whatever is already stored at output_dir.
parquet_writer = (
    stocks_moving_avg_df.write
    .partitionBy("year", "month", "day")
    .mode('overwrite')
)
parquet_writer.parquet(output_dir)

In [None]:
# Read the Parquet dataset back from output_dir.
df_parquet = spark.read.parquet(output_dir)

In [None]:
# Number of rows in the Parquet dataset.
df_parquet.count()

In [None]:
# Register the frame as temporary view `stocks` so it can be queried through spark.sql().
df_parquet.createOrReplaceTempView("stocks")

In [None]:
# Highest closing price per symbol for September 2017, selected through a
# range on full_date, which is a regular column and not one of the
# year/month/day partition columns. explain() prints the physical plan.
full_date_range_query = "SELECT symbol, MAX(close) AS price FROM stocks WHERE full_date >= '2017-09-01' AND full_date < '2017-10-01' GROUP BY symbol"
badHighestClosingPrice = spark.sql(full_date_range_query)
badHighestClosingPrice.explain()

In [None]:
# Same aggregation, but filtered on the partition columns year and month.
# explain() prints the physical plan.
partition_filter_query = "SELECT symbol, MAX(close) AS price FROM stocks WHERE year=2017 AND month=9 GROUP BY symbol"
highestClosingPrice = spark.sql(partition_filter_query)
highestClosingPrice.explain()

In [None]:
# Write to Postgres
stocks_moving_avg_df \
    .drop("year", "month", "day") \
    .write \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://postgres/workshop") \
    .option("dbtable", "workshop.stocks") \
    .option("user", "workshop") \
    .option("password", "w0rkzh0p") \
    .option("driver", "org.postgresql.Driver") \
    .mode('append') \
    .save()

In [None]:
# Stop the Spark session; Spark cells above need a new session before they can run again.
spark.stop()