In [7]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DateType, DoubleType, ArrayType
import matplotlib.pyplot as plt
import seaborn as sns
import os
from dotenv import load_dotenv
from pyspark.sql.window import Window
from pyspark.sql import functions as F
import uuid

%load_ext dotenv
%dotenv ../.env

# Load environment variables from the .env file in the parent directory
# dotenv_path = os.path.join(os.path.dirname(__file__), "..", ".env")
# load_dotenv(dotenv_path)

The dotenv extension is already loaded. To reload it, use:
  %reload_ext dotenv


# PySpark Data Cleaning for PostgreSQL Stock Data

This PySpark script is designed for data cleaning and schema customization of stock data retrieved from a PostgreSQL database. The code covers the following steps:

1. **Setting Up Spark Session:**
   - Specifies the path to the PostgreSQL JDBC driver JAR file.
   - Creates a Spark session with specific configurations, including the JDBC driver.

2. **Reading Database Properties:**
   - Reads database connection properties from environment variables such as user, password, driver, and URL.

3. **Reading Data from PostgreSQL:**
   - Attempts to read stock data from the specified PostgreSQL table using the provided database properties.

4. **Defining Custom Schema:**
   - Defines a custom schema for the stock data, specifying data types for each column.

5. **Applying Custom Schema:**
   - Converts the data types of columns in the DataFrame to match the custom schema.

6. **Displaying the Result:**
   - Shows the cleaned stock data with the custom schema.

7. **Exception Handling:**
   - Catches and handles any exceptions that may occur during the data reading process, providing detailed error information.

Note: Make sure to replace placeholder values such as schema name, table name, and environment variables with actual values specific to your PostgreSQL setup.

The code aims to ensure that the stock data adheres to a defined schema for downstream analysis or processing.


In [8]:
# Specify the path to the PostgreSQL JDBC driver JAR file
postgres_jar_path = "./drivers/postgresql-42.6.0.jar"

# Create a Spark session with the PostgreSQL JDBC driver
spark = SparkSession.builder \
    .appName("StockDataCleaning") \
    .config("spark.executor.memory", "4g") \
    .config("spark.jars", postgres_jar_path) \
    .getOrCreate()

# Read database properties from environment variables
db_properties = {
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
    "driver": os.getenv("DB_DRIVER"),
    "url": os.getenv("DB_URL"),
}

# Schema and table name
schema_name = "public"  # Replace with your actual schema name
table_name = f'{schema_name}."Stocks"'

# Attempt to read data from PostgreSQL
try:
    stockData = spark.read.jdbc(url=db_properties["url"],
                                table=table_name,
                                properties=db_properties)

    # Define your custom schema
    custom_schema = StructType([
        StructField("transaction_id", StringType(), True),
        StructField("stock_id", StringType(), True),
        StructField("ticker_symbol", StringType(), True),
        StructField("date", DateType(), True),
        StructField("low", FloatType(), True),
        StructField("open", FloatType(), True),
        StructField("high", FloatType(), True),
        StructField("volume", IntegerType(), True),
        StructField("close", FloatType(), True)
    ])

    # Apply the custom schema to the DataFrame
    stockData = stockData \
        .withColumn("transaction_id", stockData["transaction_id"].cast(StringType())) \
        .withColumn("stock_id", stockData["stock_id"].cast(StringType())) \
        .withColumn("ticker_symbol", stockData["ticker_symbol"].cast(StringType())) \
        .withColumn("date", stockData["date"].cast(DateType())) \
        .withColumn("low", stockData["low"].cast(FloatType())) \
        .withColumn("open", stockData["open"].cast(FloatType())) \
        .withColumn("high", stockData["high"].cast(FloatType())) \
        .withColumn("volume", stockData["volume"].cast(IntegerType())) \
        .withColumn("close", stockData["close"].cast(FloatType()))

    # Show the DataFrame with the custom schema
    stockData.show()

except Exception as e:
    print("Error reading data from PostgreSQL:")
    print(e)
    # Print the full stack trace for debugging
    import traceback
    traceback.print_exc()


+--------------------+--------------------+-------------+----------+--------+--------+--------+---------+--------+
|      transaction_id|            stock_id|ticker_symbol|      date|     low|    open|    high|   volume|   close|
+--------------------+--------------------+-------------+----------+--------+--------+--------+---------+--------+
|35e4605e-6c3f-432...|3e772c8a-6ae0-428...|         TSLA|2010-06-29|1.169333|1.266667|1.666667|281494500|1.592667|
|8ab623dc-f18e-409...|3e772c8a-6ae0-428...|         TSLA|2010-06-30|1.553333|1.719333|   2.028|257806500|1.588667|
|c8731b55-94c0-4a3...|3e772c8a-6ae0-428...|         TSLA|2010-07-01|1.351333|1.666667|   1.728|123282000|   1.464|
|6b6d50f0-b7de-43c...|3e772c8a-6ae0-428...|         TSLA|2010-07-02|1.247333|1.533333|    1.54| 77097000|    1.28|
|7760d575-5c6c-465...|3e772c8a-6ae0-428...|         TSLA|2010-07-06|1.055333|1.333333|1.333333|103003500|   1.074|
|7443eec4-98b7-4da...|3e772c8a-6ae0-428...|         TSLA|2010-07-07|0.998667|1.0

## Data Preprocessing: Selecting Latest Stock Data

This cell performs data preprocessing on the stock data. It first converts the 'date' column to a date type if it's not already. Then, it creates a window specification for each ticker symbol, ordering the data by date in descending order.

A row number is added to the DataFrame based on the window specification, and rows with a row number equal to 1 (representing the latest date) for each group are filtered. The unnecessary 'row_number' column is dropped, and only the essential columns (ticker symbol, date, low, open, high, volume, and close) are selected in the final result.

The processed result is displayed using the `show()` function.


In [9]:
# Convert the 'date' column to a date type if it's not already
stockData = stockData.withColumn("date", F.to_date(stockData["date"]))

# Create a window specification for each group, ordered by the 'date' column in descending order
windowSpec = Window().partitionBy("ticker_symbol").orderBy(F.desc("date"))

# Add a row number to the DataFrame based on the window specification
rankedData = stockData.withColumn("row_number", F.row_number().over(windowSpec))

# Filter the rows with row number equal to 1 (latest date) for each group
latestData = rankedData.filter("row_number = 1").drop("row_number")

# Select only the necessary columns
result = latestData.select("ticker_symbol", "date", "low", "open", "high", "volume", "close")

# Show the result
result.show()

+-------------+----------+--------+--------+--------+---------+--------+
|ticker_symbol|      date|     low|    open|    high|   volume|   close|
+-------------+----------+--------+--------+--------+---------+--------+
|         AAPL|1986-03-12|0.110491|0.111049|0.112165| 85680000|0.110491|
|         MSFT|2010-06-28|   24.12|   24.51|   24.61| 73784800|   24.31|
|         TSLA|2023-11-17|  226.54|   232.0|  237.39|142140688|   234.3|
+-------------+----------+--------+--------+--------+---------+--------+



## Data Quality Check: Missing Values

This cell performs a check for missing values in the stock data. It calculates the sum of null values for each column and prints the results.

In [10]:
# Missing Values
missing_values = stockData.select([func.sum(func.col(c).isNull().cast("int")).alias(c + '_missing') for c in stockData.columns]).collect()

print("Missing Values:")
for row in missing_values[0].asDict():
    print(f"{row}: {missing_values[0][row]}")


Missing Values:
transaction_id_missing: 0
stock_id_missing: 0
ticker_symbol_missing: 0
date_missing: 0
low_missing: 0
open_missing: 0
high_missing: 0
volume_missing: 0
close_missing: 0


## Data Cleaning: Handling Missing Values and Deduplication

This cell addresses data quality by filling missing values in the 'volume' column with 0. Additionally, it removes duplicate rows based on the 'date' and 'close' columns, ensuring data integrity.

In [11]:
# Fill missing values in the 'volume' column with 0
stockData = stockData.na.fill(0, subset=['volume'])

# Drop duplicate rows based on 'date' and 'close' columns
cleanedStockData = stockData.dropDuplicates(['date', 'close'])

# Check for duplicate values in the 'date' column again
duplicate_rows = cleanedStockData.groupBy('date', 'close').count().filter('count > 1')

# Show the duplicate dates and close prices, if any
if duplicate_rows.count() > 0:
    print("Duplicate dates and close prices found after deduplication:")
    duplicate_rows.show()
else:
    print("No duplicate dates and close prices found.")

cleanedStockData.orderBy(func.desc("date")).show()


No duplicate dates and close prices found.
+--------------------+--------------------+-------------+----------+------+------+------+---------+------+
|      transaction_id|            stock_id|ticker_symbol|      date|   low|  open|  high|   volume| close|
+--------------------+--------------------+-------------+----------+------+------+------+---------+------+
|688ed242-e621-4a2...|3e772c8a-6ae0-428...|         TSLA|2023-11-17|226.54| 232.0|237.39|142140688| 234.3|
|1f80a373-93d9-4bb...|3e772c8a-6ae0-428...|         TSLA|2023-11-16|230.96|239.49|240.88|136816800|233.59|
|673b4eb1-4aeb-406...|3e772c8a-6ae0-428...|         TSLA|2023-11-15|236.45|239.29| 246.7|150354000|242.84|
|cce2ac8d-55e3-46a...|3e772c8a-6ae0-428...|         TSLA|2023-11-14|230.72|235.03|238.14|149771600|237.41|
|17340da8-e119-406...|3e772c8a-6ae0-428...|         TSLA|2023-11-13|211.61| 215.6| 225.4|140447600|223.71|
|d2d2ac09-22fc-462...|3e772c8a-6ae0-428...|         TSLA|2023-11-10|205.69|210.03|215.38|130994000|21

## Stock Data Moving Averages Analysis

This PySpark script calculates Simple Moving Averages (SMA) and Exponential Moving Averages (EMA) for different periods on stock data. The analysis includes importing libraries, defining functions, setting parameters, and displaying the results.

- **Moving Averages:**
  - Simple Moving Averages (SMA) for periods: 5, 20, 50, 200.
  - Exponential Moving Averages (EMA) with corresponding alpha values.

- **Data Manipulation:**
  - Utilizes PySpark functions and windows for efficient data processing.

- **Result Display:**
  - Presents the DataFrame with date, close price, SMAs, and EMAs in descending order.


In [12]:
round_to_decimal = 2

# Generate a UUID for each row
cleanedStockData = cleanedStockData.withColumn("cal_id", F.lit(str(uuid.uuid4())))

def calculate_ema(data, alpha):
    ema = data[0]
    for i in range(1, len(data)):
        ema = alpha * data[i] + (1 - alpha) * ema
    return ema

calculate_ema_udf = F.udf(lambda data, alpha: float(calculate_ema(data, alpha)), FloatType())

periods = [5, 20, 50, 200]
alpha_values = [2 / (p + 1) for p in periods]

partition_cols = ["stock_id", "ticker_symbol"]

windows = [Window().partitionBy(partition_cols).orderBy(F.desc("date")).rowsBetween(0, p - 1) for p in periods]

# Calculate simple moving averages
for p in periods:
    cleanedStockData = cleanedStockData.withColumn(f"{p}_days_sma", F.round(F.avg("close").over(windows[periods.index(p)]), 2))

# Calculate exponential moving averages using UDF
for p, alpha in zip(periods, alpha_values):
    cleanedStockData = cleanedStockData.withColumn(f"{p}_days_ema", F.round(calculate_ema_udf(F.collect_list("close").over(windows[periods.index(p)]), F.lit(alpha)), round_to_decimal))

# Show the result
cleanedStockData.select(['cal_id','transaction_id','date'] + [f"{p}_days_sma" for p in periods] + [f"{p}_days_ema" for p in periods]).orderBy(F.desc("date")).show()


# Schema and table name
schema_name = "public"  # Replace with your actual schema name
table_name = f'{schema_name}."MovingAverages"'

# Write the DataFrame to the database table
cleanedStockData.write.jdbc(os.getenv("DB_URL"), table_name, mode="overwrite", properties=db_properties)

                                                                                

+--------------------+--------------------+----------+----------+-----------+-----------+------------+----------+-----------+-----------+------------+
|              cal_id|      transaction_id|      date|5_days_sma|20_days_sma|50_days_sma|200_days_sma|5_days_ema|20_days_ema|50_days_ema|200_days_ema|
+--------------------+--------------------+----------+----------+-----------+-----------+------------+----------+-----------+-----------+------------+
|1240fdb8-656d-49a...|688ed242-e621-4a2...|2023-11-17|    234.37|     217.82|     240.04|      222.29|    232.66|     216.37|     247.66|      213.05|
|1240fdb8-656d-49a...|1f80a373-93d9-4bb...|2023-11-16|    230.44|     216.71|     240.33|      222.06|    226.56|     215.85|      247.6|      212.71|
|1240fdb8-656d-49a...|673b4eb1-4aeb-406...|2023-11-15|    225.72|     216.03|     240.68|       221.8|    222.25|     217.51|      249.0|      213.65|
|1240fdb8-656d-49a...|cce2ac8d-55e3-46a...|2023-11-14|    221.57|     216.03|     240.87|     

## Bollinger Bands Calculation Explanation

This Jupiter Notebook cell performs the computation of Bollinger Bands on stock data for volatility analysis. The breakdown includes critical steps and considerations:

- **Decimal Rounding:**
  - All numerical values are rounded to two decimal places for consistency and readability.

- **Bollinger Bands Periods:**
  - The Bollinger Bands are computed for four distinct periods: 5, 20, 50, and 200 days, providing insights into short-term and long-term volatility.

- **Partitioning for Accuracy:**
  - The data is partitioned by "stock_id" and "ticker_symbol" to ensure accurate calculations for individual stocks. This is crucial for meaningful stock market analysis.

- **Reuse of Exponential Moving Averages (EMAs):**
  - Existing EMA values, previously calculated, are reused in the Bollinger Bands computation. This approach optimizes computational efficiency and maintains consistency with prior analyses.

- **Upper and Lower Band Calculation:**
  - The upper and lower bands are determined by adding and subtracting twice the standard deviation of closing prices from the corresponding EMAs. This methodology aligns with the standard Bollinger Bands formula.

- **Result Presentation:**
  - The final DataFrame includes the date, close price, upper bands, and lower bands for each specified period, providing a comprehensive view of the stock's volatility.

This code enhances the dataset with Bollinger Bands, aiding in the identification of potential market trends and volatility patterns.

In [13]:
# Number of decimal places
round_to_decimal = 2

# Define the Bollinger Bands periods
bollinger_periods = [5, 20, 50, 200]

partition_cols = ["stock_id", "ticker_symbol"]

# Define the windows for Bollinger Bands
windows = [Window().partitionBy(partition_cols).orderBy(F.desc("date")).rowsBetween(0, p - 1) for p in bollinger_periods]

# Reuse the existing EMA values for Bollinger Bands
for p in bollinger_periods:
    upper_band_col = F.col(f"{p}_days_ema") + (2 * F.stddev("close").over(windows[bollinger_periods.index(p)]))
    lower_band_col = F.col(f"{p}_days_ema") - (2 * F.stddev("close").over(windows[bollinger_periods.index(p)]))

    cleanedStockData = cleanedStockData.withColumn(f"{p}_upper_band", F.round(upper_band_col, round_to_decimal))
    cleanedStockData = cleanedStockData.withColumn(f"{p}_lower_band", F.round(lower_band_col, round_to_decimal))

# Show the result
selected_columns = ['cal_id','transaction_id','date'] + [f"{p}_upper_band" for p in bollinger_periods] + [f"{p}_lower_band" for p in bollinger_periods]
cleanedStockData.select(selected_columns).orderBy(F.desc("date")).show()

# Schema and table name
schema_name = "public" 
table_name = f'{schema_name}."BoillingerBands"'

# Write the DataFrame to the database table
cleanedStockData.write.jdbc(os.getenv("DB_URL"), table_name, mode="overwrite", properties=db_properties)


23/11/17 19:12:45 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+--------------------+--------------------+----------+------------+-------------+-------------+--------------+------------+-------------+-------------+--------------+
|              cal_id|      transaction_id|      date|5_upper_band|20_upper_band|50_upper_band|200_upper_band|5_lower_band|20_lower_band|50_lower_band|200_lower_band|
+--------------------+--------------------+----------+------------+-------------+-------------+--------------+------------+-------------+-------------+--------------+
|1240fdb8-656d-49a...|688ed242-e621-4a2...|2023-11-17|      246.63|       240.89|       293.17|        285.45|      218.69|       191.85|       202.15|        140.65|
|1240fdb8-656d-49a...|1f80a373-93d9-4bb...|2023-11-16|      249.07|       239.22|       293.14|        285.25|      204.05|       192.48|       202.06|        140.17|
|1240fdb8-656d-49a...|673b4eb1-4aeb-406...|2023-11-15|      250.61|       239.57|       294.61|         286.4|      193.89|       195.45|       203.39|         140.9

                                                                                

## Relative Strength Index (RSI) Calculation Explanation

This code cell calculates the Relative Strength Index (RSI) for a given stock dataset. The RSI is a momentum oscillator that measures the speed and change of price movements. The breakdown includes critical steps and considerations:

- **Price Changes:**
  - The code calculates the daily price changes by subtracting the previous day's closing price from the current day's closing price.

- **Gains and Losses:**
  - Gains and losses are determined based on whether the price change is positive or negative.

- **Exponential Moving Averages (EMAs):**
  - Exponential Moving Averages (EMAs) are calculated for both gains and losses over the specified RSI period. EMAs give more weight to recent price changes, providing a more responsive indicator.

- **Handling NULL and Zero Values:**
  - NULL values for average gains are handled by replacing them with zero. Additionally, zero values for average losses are replaced with zero to avoid division errors.

- **Relative Strength (RS) Calculation:**
  - The Relative Strength (RS) is calculated as the ratio of average gains to average losses.

- **RSI Calculation:**
  - The RSI is calculated using the standard formula: \(100 - \frac{100}{1 + RS}\). This value is then rounded to two decimal places.

- **List of RSI Periods:**
  - The code calculates RSI for multiple periods, including 14, 20, 50, and 200 days, providing insights into short-term and long-term momentum.

- **Result Presentation:**
  - The final DataFrame includes columns for the ticker symbol, date, closing price, and RSI for each specified period, offering a comprehensive view of the stock's momentum trends.

- **Data Partitioning:**
  - The data is partitioned by "stock_id" and "ticker_symbol" to ensure accurate calculations for individual stocks. This is crucial for meaningful stock market analysis.

- **Decimal Rounding:**
  - All numerical values, including the calculated RSI, are rounded to two decimal places for consistency and readability.

This code enhances the dataset with RSI values, contributing to the analysis of potential overbought or oversold market conditions.


In [14]:
def calculate_rsi(data, n):
    # Calculate price changes
    price_diff = F.col("close") - F.lag("close", 1).over(Window().partitionBy("stock_id", "ticker_symbol").orderBy("date"))
    
    # Separate gains and losses
    gains = F.when(price_diff > 0, price_diff).otherwise(0)
    losses = F.when(price_diff < 0, -price_diff).otherwise(0)
    
    # Calculate average gains and losses over n periods from the latest day backward
    avg_gains = F.avg(gains).over(Window().partitionBy("stock_id", "ticker_symbol").orderBy(F.desc("date")).rowsBetween(0, n-1))
    avg_losses = F.avg(losses).over(Window().partitionBy("stock_id", "ticker_symbol").orderBy(F.desc("date")).rowsBetween(0, n-1))
    
    # Handle NULL values for average gains
    avg_gains = F.coalesce(avg_gains, F.lit(0))

    # Handle 0 values for average losses
    avg_losses = F.when(avg_losses.isNull() | (avg_losses == 0), 0).otherwise(avg_losses)

    # Calculate RSI
    rs = F.when((avg_losses == 0) & (avg_gains != 0), F.lit(avg_gains)) \
      .when((avg_losses != 0) & (avg_gains == 0), F.lit(0)) \
      .otherwise(F.when(avg_losses == 0, F.lit(float('inf'))) \
                  .otherwise(avg_gains / avg_losses))
    
    # Calculate RSI and round to 2 decimal places
    rsi = 100 - (100 / (1 + rs))
    
    return F.round(rsi, 2)

# List of RSI periods
rsi_periods = [14, 20, 50, 200]

# Calculate and add RSI columns to the DataFrame for each period
for n in rsi_periods:
    column_name = f"{n}_days_rsi"
    cleanedStockData = cleanedStockData.withColumn(column_name, calculate_rsi(cleanedStockData, n))

# Show the result
result_columns = ['cal_id','transaction_id','date'] + [f"{n}_days_rsi" for n in rsi_periods]
cleanedStockData.select(result_columns).orderBy(F.desc("date")).show()

# Schema and table name
schema_name = "public" 
table_name = f'{schema_name}."RelativeIndexes"'

cleanedStockData.write.jdbc(os.getenv("DB_URL"), table_name, mode="overwrite", properties=db_properties)

# Stop the Spark session when you're done
spark.stop()

                                                                                

+--------------------+--------------------+----------+-----------+-----------+-----------+------------+
|              cal_id|      transaction_id|      date|14_days_rsi|20_days_rsi|50_days_rsi|200_days_rsi|
+--------------------+--------------------+----------+-----------+-----------+-----------+------------+
|1240fdb8-656d-49a...|688ed242-e621-4a2...|2023-11-17|      72.74|      60.33|      47.52|        52.1|
|1240fdb8-656d-49a...|1f80a373-93d9-4bb...|2023-11-16|      64.53|      55.84|       46.9|       52.37|
|1240fdb8-656d-49a...|673b4eb1-4aeb-406...|2023-11-15|      72.41|      50.06|      48.38|       53.17|
|1240fdb8-656d-49a...|cce2ac8d-55e3-46a...|2023-11-14|      64.88|      43.56|      46.58|       53.21|
|1240fdb8-656d-49a...|17340da8-e119-406...|2023-11-13|      54.83|      37.69|      46.15|       52.09|
|1240fdb8-656d-49a...|d2d2ac09-22fc-462...|2023-11-10|      51.84|      34.34|      42.26|       52.46|
|1240fdb8-656d-49a...|e2f2d74c-020a-4f8...|2023-11-09|      48.4

                                                                                