In [None]:
"""
# test stock trades data from Mockaroo, limited to 1k rows
# Questions to answer
# What is the average stock price for each stock?
# Which stock had the most trades on a single date (transacted in a single day)?
# Which stock saw the most volatility in total number of stocks traded in a 2-day window?
# Which stock was the most sold and which one was the most bought?
"""

In [5]:
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.window import Window

In [6]:
# Relative path of the Mockaroo-generated trades CSV that the loader cell reads.
stocks = 'MOCK_DATA.csv'

In [7]:
# Obtain (or reuse) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("sample_stock_data")
    .getOrCreate()
)

23/05/15 22:44:19 WARN Utils: Your hostname, Ravis-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.6.23 instead (on interface en0)
23/05/15 22:44:19 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/05/15 22:44:20 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [15]:
# Load the CSV with a header row, letting Spark infer column types.
stocks_df = (
    spark.read
    .format("csv")
    .options(header="true", inferSchema="true")
    .load(stocks)
)

In [26]:
# Number of rows currently held in stocks_df (displayed as the cell result).
stocks_row_count = stocks_df.count()
stocks_row_count

1000

In [17]:
# Print the column names and types Spark inferred from the CSV (inferSchema is enabled in the loader).
stocks_df.printSchema()

root
 |-- trade_id: integer (nullable = true)
 |-- stock_symbol: string (nullable = true)
 |-- trade_type: string (nullable = true)
 |-- trade_date: string (nullable = true)
 |-- trade_time: string (nullable = true)
 |-- trade_price: double (nullable = true)
 |-- trade_volume: integer (nullable = true)
 |-- trade_currency: string (nullable = true)
 |-- trade_fee: double (nullable = true)
 |-- trade_notes: string (nullable = true)



In [20]:
def removeDuplicates(df):
    """Return ``df`` with exact duplicate rows removed.

    If ``df`` has no duplicate rows it is returned unchanged (same object);
    otherwise the result of ``df.dropDuplicates()`` is returned.

    Args:
        df: A Spark DataFrame (anything exposing ``count``, ``distinct`` and
            ``dropDuplicates``).

    Returns:
        ``df`` itself, or a de-duplicated DataFrame.
    """
    # Inspect the argument, not a global frame: the check must describe the
    # DataFrame that is actually being de-duplicated. Note this runs two jobs.
    total_cnt, distinct_cnt = df.count(), df.distinct().count()
    if total_cnt != distinct_cnt:
        return df.dropDuplicates()
    return df

In [21]:
# Replace stocks_df with its de-duplicated version.
stocks_df = removeDuplicates(df=stocks_df)

In [25]:
# drop rows that have any NULL values
# The column list is taken from the frame itself so it cannot drift out of sync
# with the CSV schema (it previously repeated every column name by hand).
# NOTE(review): this also drops rows where `trade_notes` is NULL; if that column
# is an optional notes field, consider excluding it from `columns` — confirm.
columns = stocks_df.columns
stocks_df = stocks_df.na.drop(how='any', subset=columns)

In [27]:
# Expose the cleaned DataFrame to Spark SQL under the view name the queries use.
stocks_view_name = "stocks"
stocks_df.createOrReplaceTempView(stocks_view_name)

In [30]:
# What is the average stock price for each stock?
# Spark SQL: mean trade_price per symbol, rounded to a whole number (ROUND with
# no scale argument), ordered by symbol.
# NOTE(review): trade_currency is not part of the grouping, so if the data mixes
# currencies their prices are averaged together — confirm.
avg_stock_sql = """
            SELECT 
            stock_symbol,
            ROUND(AVG(trade_price)) AS avg_stock_price
            FROM stocks
            GROUP BY 1
            ORDER BY 1
            """
avg_stock = spark.sql(avg_stock_sql)
avg_stock.show()

+------------+---------------+
|stock_symbol|avg_stock_price|
+------------+---------------+
|        AAPL|       492361.0|
|        AMZN|       518753.0|
|          FB|       466061.0|
|       GOOGL|       498964.0|
|        TSLA|       499001.0|
+------------+---------------+



In [32]:
# pyspark
# DataFrame-API version of the average-price query: whole-number mean price
# per symbol, sorted by symbol.
mean_price_col = F.round(F.avg('trade_price')).alias('avg_stock_price')
avg_stock2 = (
    stocks_df
    .groupBy('stock_symbol')
    .agg(mean_price_col)
    .orderBy('stock_symbol')
)
avg_stock2.show()

+------------+---------------+
|stock_symbol|avg_stock_price|
+------------+---------------+
|        AAPL|       492361.0|
|        AMZN|       518753.0|
|          FB|       466061.0|
|       GOOGL|       498964.0|
|        TSLA|       499001.0|
+------------+---------------+



In [36]:
# Which stock had the most trades on a single date (transacted in a single day)?
# Every (stock, date) pair whose trade count equals the overall maximum is kept,
# so tied stocks are all reported instead of ORDER BY ... LIMIT 1 picking one
# arbitrarily.
most_trades_in_a_day = spark.sql(
                    """
                    WITH daily_counts AS
                    (
                    SELECT
                    stock_symbol,
                    trade_date,
                    COUNT(trade_id) AS trades_cnt
                    FROM stocks
                    GROUP BY 1, 2
                    )
                    SELECT
                    DISTINCT stock_symbol
                    FROM daily_counts
                    WHERE trades_cnt = (SELECT MAX(trades_cnt) FROM daily_counts)
                    ORDER BY stock_symbol
                    """
                    )
most_trades_in_a_day.show()

+------------+
|stock_symbol|
+------------+
|          FB|
+------------+



In [39]:
# pyspark
# Keep every stock that reaches the maximum daily trade count, rather than an
# arbitrary single row from orderBy(...).limit(1) when counts tie.
daily_counts = (
    stocks_df.groupBy('stock_symbol', 'trade_date')
    .agg(F.count('trade_id').alias('trades_cnt'))
)
# One-row frame holding the maximum; the inner join keeps only the pairs that
# match it (and yields no rows, rather than failing, for empty input).
max_daily_count = daily_counts.agg(F.max('trades_cnt').alias('trades_cnt'))
most_trades_in_a_day2 = (
    daily_counts
    .join(max_daily_count, on='trades_cnt', how='inner')
    .select('stock_symbol')
    .distinct()
    .orderBy('stock_symbol')
)

most_trades_in_a_day2.show()

+------------+
|stock_symbol|
+------------+
|          FB|
+------------+



In [44]:
# Which stock saw the most volatility in total number of stocks traded in a 2-day window?
# Per stock and date, total shares traded are compared with the previous trade
# date (LAG). Volatility is ranked by the size of the change in either direction
# (ABS), so a large drop counts as much as a large rise. The signed difference
# is still shown. Ties are broken by symbol and date so the result is
# deterministic.
# NOTE(review): trade_date is read as a string (see printSchema), so ORDER BY
# trade_date is lexicographic. That is chronological only for zero-padded
# year-first formats such as yyyy-MM-dd. If the CSV uses another format, parse
# it with to_date(trade_date, '<format>') first — TODO confirm. Also, LAG
# compares consecutive dates on which the stock traded, which may be more than
# one calendar day apart.
most_volatile = spark.sql("""
WITH CTE AS 
(
SELECT 
stock_symbol,
trade_date,
SUM(trade_volume) as total_shares
FROM stocks
GROUP BY 1,2
)

SELECT 
stock_symbol,
diff_shares_cnt
FROM 
(
SELECT 
stock_symbol,
trade_date,
(total_shares - LAG(total_shares, 1) OVER (PARTITION BY stock_symbol ORDER BY trade_date)) AS diff_shares_cnt
FROM CTE
) temp
ORDER BY ABS(diff_shares_cnt) DESC NULLS LAST, stock_symbol, trade_date
LIMIT 1
""")
most_volatile.show()

+------------+---------------+
|stock_symbol|diff_shares_cnt|
+------------+---------------+
|        AAPL|        2174607|
+------------+---------------+



In [45]:
# pyspark
# Total shares per stock and date, then the change against the previous trade
# date for that stock. The largest change in either direction is ranked first
# (abs); the first date per stock has a NULL diff and sorts last.
# NOTE(review): ordering by the string trade_date is only chronological for
# year-first, zero-padded formats — TODO confirm the CSV date format.
cte = (
    stocks_df.groupBy('stock_symbol', 'trade_date')
    .agg(F.sum('trade_volume').alias('total_shares'))
)

window_spec = Window.partitionBy('stock_symbol').orderBy('trade_date')
diff_shares = cte.select(
    'stock_symbol',
    'trade_date',
    (F.col('total_shares') - F.lag('total_shares').over(window_spec)).alias('diff_shares_cnt'),
)
result = (
    diff_shares
    .orderBy(F.abs('diff_shares_cnt').desc_nulls_last(), 'stock_symbol', 'trade_date')
    .limit(1)
    .select('stock_symbol', 'diff_shares_cnt')
)
result.show()

+------------+---------------+
|stock_symbol|diff_shares_cnt|
+------------+---------------+
|        AAPL|        2174607|
+------------+---------------+



In [52]:
# Which stock was the most sold and which one was the most bought?
# Counts trades (not shares) per symbol and side.
# NOTE(review): the buy label is 'buy' but the original sell label was 'sold';
# the data's actual sell label is not visible here (likely 'sell'). Both are
# accepted, case-insensitively — confirm with SELECT DISTINCT trade_type.
# A side with zero matching trades yields no row instead of an arbitrary
# "winner", and ties are broken by symbol so the result is deterministic.
most_sold_bought = spark.sql(
"""
WITH trade_cnt AS 
(
SELECT 
stock_symbol,
SUM(CASE WHEN LOWER(trade_type) = 'buy' THEN 1 ELSE 0 END) AS bought,
SUM(CASE WHEN LOWER(trade_type) IN ('sell', 'sold') THEN 1 ELSE 0 END) AS sold
FROM stocks
GROUP BY 1
),
most_sold AS 
(
SELECT stock_symbol, 'most_sold' AS description
FROM trade_cnt
WHERE sold > 0
ORDER BY sold DESC, stock_symbol
LIMIT 1
),
most_bought AS 
(
SELECT stock_symbol, 'most_bought' AS description
FROM trade_cnt
WHERE bought > 0
ORDER BY bought DESC, stock_symbol
LIMIT 1
)
SELECT * FROM most_sold
UNION ALL
SELECT * FROM most_bought
"""
)
most_sold_bought.show()

+------------+-----------+
|stock_symbol|description|
+------------+-----------+
|        AAPL|  most_sold|
|          FB|most_bought|
+------------+-----------+



In [58]:
# pyspark
# DataFrame-API version of the most-sold / most-bought query. Functions are
# referenced through `F.` so the aggregation never depends on the star import
# shadowing Python's builtin `sum`.
# NOTE(review): the sell label is assumed to be 'sell' or 'sold'
# (case-insensitive) — confirm against the distinct trade_type values.
trade_type_lc = F.lower(F.col('trade_type'))
trade_cnt = (
    stocks_df.groupBy('stock_symbol')
    .agg(
        F.sum(F.when(trade_type_lc == 'buy', 1).otherwise(0)).alias('bought'),
        F.sum(F.when(trade_type_lc.isin('sell', 'sold'), 1).otherwise(0)).alias('sold'),
    )
)

# Skip sides with no matching trades; break ties by symbol for determinism.
most_sold = (
    trade_cnt
    .filter(F.col('sold') > 0)
    .orderBy(F.desc('sold'), 'stock_symbol')
    .limit(1)
    .select('stock_symbol', F.lit('most_sold').alias('description'))
)
most_bought = (
    trade_cnt
    .filter(F.col('bought') > 0)
    .orderBy(F.desc('bought'), 'stock_symbol')
    .limit(1)
    .select('stock_symbol', F.lit('most_bought').alias('description'))
)

most_sold_bought = most_sold.unionByName(most_bought)
most_sold_bought.show()

+------------+-----------+
|stock_symbol|description|
+------------+-----------+
|        AAPL|  most_sold|
|          FB|most_bought|
+------------+-----------+

