# PySpark data analysis using Spark DataFrames
Based on [this post](https://towardsdatascience.com/beginners-guide-to-pyspark-bbe3b553b79f).

PySpark DataFrames cheat sheet [here](https://www.datacamp.com/blog/pyspark-cheat-sheet-spark-dataframes-in-python).

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import col, lit, countDistinct

In [None]:
# Create the SparkContext, the entry point for working with RDDs.
# It is given the same application name that the SparkSession builder uses in this notebook.
sc = SparkContext(appName = "pyspark-data-analysis")

In [None]:
# Create (or reuse) the SparkSession, the entry point for the DataFrame API.
# The master must be a valid Spark master URL: "local[*]" runs Spark locally
# with as many worker threads as logical cores. A bare host name such as
# "localhost" is not a master URL and is rejected when a new context has to
# be started from this builder.
ses = (
    SparkSession.builder
    .master("local[*]")
    .appName("pyspark-data-analysis")
    .getOrCreate()
)

### Loading and analysing data structure

In [None]:
# Read the stock prices CSV into a Spark DataFrame.
# Fields are comma-separated and the first line of the file is used as the header.
df = (
    ses.read
    .options(sep=",", header=True)
    .csv("./data/stocks_price_final.csv")
)

In [None]:
# What is the data type?
# type() reports the Python class of the object bound to `df`; being the last
# expression of the cell, its value is what the notebook displays.
type(df)

In [None]:
# Counting the number of records (rows) in the DataFrame
df.count()

In [None]:
# Printing the columns and the data types identified when the CSV was loaded.
# No schema was passed to the reader at this point, so these are the reader's defaults.
df.printSchema()

In [None]:
# Explicit schema to be applied to the DataFrame while the CSV file is loaded:
# one StructField per expected CSV column, each declared nullable.
schema = [
    StructField(name="_c0", dataType=IntegerType(), nullable=True),
    StructField(name="symbol", dataType=StringType(), nullable=True),
    StructField(name="date", dataType=DateType(), nullable=True),
    StructField(name="open", dataType=DoubleType(), nullable=True),
    StructField(name="high", dataType=DoubleType(), nullable=True),
    StructField(name="low", dataType=DoubleType(), nullable=True),
    StructField(name="close", dataType=DoubleType(), nullable=True),
    StructField(name="volume", dataType=IntegerType(), nullable=True),
    StructField(name="adjusted", dataType=DoubleType(), nullable=True),
    StructField(name="market.cap", dataType=StringType(), nullable=True),
    StructField(name="sector", dataType=StringType(), nullable=True),
    StructField(name="industry", dataType=StringType(), nullable=True),
    StructField(name="exchange", dataType=StringType(), nullable=True),
]

# Wrap the list of fields in the StructType object that is handed to the reader.
structure = StructType(schema)

In [None]:
# Reload the same CSV, this time applying the explicit schema held in `structure`
# (comma-separated fields, first line used as the header).
df = (
    ses.read
    .schema(structure)
    .options(sep=",", header=True)
    .csv("./data/stocks_price_final.csv")
)

In [None]:
# Printing columns and data types identified, you can also use .schema, .dtypes, .columns
# `df` was just reloaded with `structure`, so the output should now reflect that explicit schema.
df.printSchema()

In [None]:
# Printing the first 20 records, you can also use .take(n), .head(n), .first()
# show() prints the rows as a text table, whereas the alternatives return Row objects.
df.show()

### Querying and visualizing data

In [None]:
# Keep only the rows belonging to a single ticker symbol.
# Each filtered result is converted to pandas so it can be plotted with matplotlib.
TSLA = (
    df.where(col("symbol") == lit("TSLA"))
    .toPandas()
)
GME = (
    df.where(col("symbol") == lit("GME"))
    .toPandas()
)

In [None]:
# What is the data type?
# TSLA is the result of .toPandas(), so this shows the pandas-side type rather than a Spark one.
type(TSLA)

In [None]:
# Two stacked panels sharing the date axis: TSLA on top, GME below.
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(30, 15))

# Top panel: TSLA high and low prices; it also carries the figure title.
axes[0].set_title("Stock price time series for TSLA and GME")
axes[0].plot(TSLA["date"], TSLA["high"], label="TSLA highest price")
axes[0].plot(TSLA["date"], TSLA["low"], label="TSLA lowest price")
axes[0].set_ylabel("Price")
axes[0].grid()
axes[0].legend()

# Bottom panel: GME high and low prices; it carries the x-axis label.
axes[1].plot(GME["date"], GME["high"], label="GME highest price")
axes[1].plot(GME["date"], GME["low"], label="GME lowest price")
axes[1].set_ylabel("Price")
axes[1].set_xlabel("Date")
axes[1].grid()
axes[1].legend()

In [None]:
# Count the distinct ticker symbols in each sector, order the sectors by that
# count, and convert the aggregated result to pandas for plotting.
ncomp_by_sector = (
    df.select("symbol", "sector")
    .groupBy("sector")
    .agg(countDistinct("symbol").alias("n_companies"))
    .orderBy("n_companies")
    .toPandas()
)

In [None]:
# Horizontal bar chart: one bar per sector, bar length given by the
# distinct-symbol count stored in the "n_companies" column.
plt.figure(figsize = (15, 10))
plt.barh(ncomp_by_sector["sector"], ncomp_by_sector["n_companies"])
plt.title("Number of companies by sector")

### Storing data

In [None]:
# Storing the TSLA date/high/low series in CSV format, with a header row
(
    df.where(col("symbol") == lit("TSLA"))
    .select("date", "high", "low")
    .write.format("csv")
    .option("header", "true")
    .save("./data/TSLA_timeseries")
)

### Stopping Spark context and session

In [None]:
# Release Spark resources once the analysis is finished: stop the session first,
# then the context that was created explicitly earlier in the notebook.
ses.stop()
sc.stop()