# Introducing Spark DataFrames

A Spark DataFrame is conceptually equivalent to DataFrames you may already know, such as pandas or R data frames. The key differences are that a Spark DataFrame has richer optimizations under the hood and is distributed across a cluster.

Under the covers, Spark DataFrames are built on Resilient Distributed Datasets (RDDs), which are immutable, distributed collections of data.

* **Resilient**: They are fault-tolerant: if part of your operation fails, Spark recomputes the lost partitions from their lineage, the recorded sequence of transformations that produced them.
* **Distributed**: RDDs are distributed across networked machines known as a cluster.
* **DataFrame**: A data structure where data is organized into named columns, like a table in a relational database, but with richer optimizations under the hood.

Let's first create a new schema called `IDS` in Databricks and add some tables to it. Load and run this [notebook](https://github.com/happyrabbit/IntroDataScience/blob/master/Python/LoadDatasetSpark.ipynb).

# Read and write data

You can run SQL queries against the schema we just created with `spark.sql("SELECT * FROM myTable")`, which returns the result as a DataFrame.

```python
from pyspark.sql import functions as F
SimDat = spark.sql("select * from ids.segdata_df")
```

```python
# Check the head of the data: show the first 6 rows
SimDat.show(6)
```

# Tidy and Reshape Data

We will illustrate the data manipulations in order:

- Display
- Query
- Summarize
- Create new variable
- Merge
- Impute missing data
- Reshape data

## Display

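The `describe()` function generates descriptive statistics for each column:

* Count – number of non-null values in each column
* Mean – mean value of each column
* Stddev – sample standard deviation of each column
* Min – minimum value of each column
* Max – maximum value of each column

Mean and stddev are left empty for string columns such as `gender`; for those, min and max are the alphabetically first and last values. Because count ignores nulls, it also reveals missing data: `income` has only 816 non-null values out of 1000.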

```python
display(SimDat.describe())
```

| summary | age | gender | income | house | store_exp | online_exp | store_trans | online_trans | Q1 | Q2 | Q3 | Q4 | Q5 | Q6 | Q7 | Q8 | Q9 | Q10 | segment |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1000.0 | 1000 | 816.0 | 1000 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000.0 | 1000 |
| mean | 38.84 |  | 113543.06522194942 |  | 1356.850523048724 | 2120.1811872603525 | 5.35 | 13.546 | 3.101 | 1.823 | 1.992 | 2.763 | 2.945 | 2.448 | 3.434 | 2.396 | 3.085 | 2.32 |  |
| stddev | 16.416817959394926 |  | 49842.28719659571 |  | 2774.39978497596 | 1731.2243079138482 | 3.695559112929619 | 7.956959042071111 | 1.4501385802404894 | 1.1683475664527387 | 1.4021062377373203 | 1.155060675966923 | 1.2843771348885904 | 1.4385288893051 | 1.455940851739439 | 1.1543467921367978 | 1.118492747417714 | 1.1361737943162085 |  |
| min | 16.0 | Female | 41775.637022548 | No | -500.0 | 68.81722750413141 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | Conspicuous |
| max | 300.0 | Male | 319704.337940878 | Yes | 50000.0 | 9479.44230953055 | 20.0 | 36.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | Style |


```python
# Column names
SimDat.columns
```

```python
# Column data types
SimDat.printSchema()
```

## Query

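The DataFrame API is more flexible than SQL strings: queries are built from Python objects, so you can compose them with variables, functions and loops. Here we illustrate general usage patterns of SQL and DataFrames.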

Suppose we have a data set loaded as a table called `myTable` and an equivalent DataFrame called `df`.
It has three fields/columns: `col_1` (numeric type), `col_2` (string type) and `col_3` (timestamp type).
Here are basic SQL operations and their DataFrame equivalents.

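Notice that columns in DataFrames are referenced by `col("<columnName>")`, where `col` (like `year`) comes from `pyspark.sql.functions`; after `from pyspark.sql import functions as F` you write `F.col(...)`.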

| SQL                                       | DataFrame (Python)                   |
| ----------------------------------------- | ------------------------------------ |
| `SELECT col_1 FROM myTable`               | `df.select(col("col_1"))`            |
| `DESCRIBE myTable`                        | `df.printSchema()`                   |
| `SELECT * FROM myTable WHERE col_1 > 0`   | `df.filter(col("col_1") > 0)`        |
| `... GROUP BY col_2`                      | `.groupBy(col("col_2"))`             |
| `... ORDER BY col_2`                      | `.orderBy(col("col_2"))`             |
| `... WHERE year(col_3) > 1990`            | `.filter(year(col("col_3")) > 1990)` |
| `SELECT * FROM myTable LIMIT 10`          | `df.limit(10)`                       |
| `display(myTable)` (text format)          | `df.show()`                          |
| `display(myTable)` (html format)          | `display(df)`                        |

```python
# Select rows that meet logical criteria, e.g., rows with income above 300000
SubDat = SimDat.filter(F.col('income') > 300000)
display(SubDat)
```

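| age | gender | income | house | store_exp | online_exp | store_trans | online_trans | Q1 | Q2 | Q3 | Q4 | Q5 | Q6 | Q7 | Q8 | Q9 | Q10 | segment |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 41 | Male | 317476.198766336 | Yes | 3029.8442719051 | 4179.67069641638 | 11 | 12 | 1 | 4 | 5 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |
| 37 | Female | 315697.156955637 | Yes | 6548.97012066964 | 4284.06452951059 | 13 | 11 | 1 | 4 | 5 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |
| 40 | Male | 301398.026985122 | Yes | 4840.46078736722 | 3618.211998630221 | 10 | 11 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 1 | Conspicuous |
| 33 | Male | 319704.337940878 | Yes | 5998.305280860251 | 4395.923168115151 | 9 | 11 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |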


```python
# Get records with age between 20 and 40 (both ends inclusive)
SubDat = SimDat.filter(SimDat.age.between(20, 40))
display(SubDat)
```

| age | gender | income | house | store_exp | online_exp | store_trans | online_trans | Q1 | Q2 | Q3 | Q4 | Q5 | Q6 | Q7 | Q8 | Q9 | Q10 | segment |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 33 | Male | 174461.035852419 | Yes | 3916.76067388849 | 7322.934531126509 | 7 | 14 | 1 | 4 | 5 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |
| 38 | Female | 247626.234595709 | Yes | 5731.83879848872 | 5340.25441107555 | 12 | 7 | 1 | 4 | 5 | 4 | 4 | 4 | 4 | 1 | 4 | 1 | Conspicuous |
| 37 | Male | 197545.892379841 | Yes | 5287.41480092342 | 4221.309359340789 | 8 | 11 | 1 | 4 | 5 | 4 | 4 | 4 | 4 | 1 | 4 | 1 | Conspicuous |
| 34 | Female | 210712.440204045 | Yes | 5257.71802698796 | 4838.52722735659 | 13 | 10 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 1 | Conspicuous |
| 30 | Male |  | Yes | 5130.69569642776 | 4546.07858685111 | 9 | 7 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |
| 38 | Male | 110136.832624483 | Yes | 4364.83797108174 | 4809.2466659957 | 17 | 10 | 1 | 4 | 5 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |
| 38 | Male | 220745.714756061 | Yes | 5318.03454397684 | 4468.867575897701 | 14 | 18 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |
| 38 | Female | 271749.669436335 | Yes | 5275.00282248712 | 3813.2107303581 | 5 | 15 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 2 | Conspicuous |
| 37 | Male | 139062.795876456 | Yes | 4070.6805018801897 | 7595.57843440283 | 16 | 13 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 1 | Conspicuous |
| 33 | Male | 102389.638269935 | Yes | 4070.76652010764 | 5683.644832397979 | 13 | 11 | 1 | 4 | 4 | 4 | 4 | 4 | 4 | 1 | 4 | 1 | Conspicuous |


```python
# Select columns
SubDat = SimDat.select('age', 'gender', 'income')
SubDat.show(3)
```

```python
# Sort by income in descending order and look at the top 5 records
SimDat.sort(F.col("income").desc()).show(5)
```

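```python
# Delete duplicated rows. DataFrames are immutable, so keep the returned DataFrame
SimDat_dedup = SimDat.dropDuplicates()
print(SimDat.count(), SimDat_dedup.count())
```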

## Summarize

A standard marketing problem is customer segmentation. It usually starts with designing a survey and collecting data, then running a cluster analysis on the data to get customer segments. Once we have the segments, the next step is to understand what each group of customers looks like by summarizing some key metrics. For example, we can do the following data aggregation for different segments of clothes customers.

```python
df = (SimDat
      # F is pyspark.sql.functions, imported in the first cell
      .withColumn('isFemale', F.when(SimDat["gender"] == "Female", 1).otherwise(0))
      .groupby('segment')
      .agg({'age': 'mean',
            'store_trans': 'mean',
            'online_trans': 'mean',
            'isFemale': 'mean'})
      )
df.show()
```

```python
# Change the default avg(...) column names.
# withColumnRenamed silently does nothing if the old name does not exist, so spell it exactly.
df = (df
      .withColumnRenamed('avg(store_trans)', 'avg_store_trans')
      .withColumnRenamed('avg(age)', 'avg_age')
      .withColumnRenamed('avg(online_trans)', 'avg_online_trans')
      .withColumnRenamed('avg(isFemale)', 'female_pct')
      )
df.show()
```

## Create new variable

You can use the `.withColumn()` function to create a new column based on existing columns. We used this function when we summarized the data (`.withColumn('isFemale',  F.when(SimDat["gender"] == "Female", 1).otherwise(0))`). To get a new column that tells whether `store_trans` is larger than `online_trans`, you can do:

```python
SimDat.withColumn('store_lg_online', SimDat.store_trans > SimDat.online_trans).show(5)
```

## Merge
We create two baby data sets to show how merge works.

```python
dfx = spark.createDataFrame([
  ('A', 1.0, 'online'),
  ('B', 2.0, 'store'),
  ('C', 3.0, 'online')
],
  ["ID", "x1", "type"]
)

display(dfx)
```

| ID | x1 | type |
|---|---|---|
| A | 1.0 | online |
| B | 2.0 | store |
| C | 3.0 | online |


```python
dfy = spark.createDataFrame([
  ('B', True),
  ('C', True),
  ('D', False),
  ('E', True)
],
  ["ID", "y1"]
)

display(dfy)
```

| ID | y1 |
|---|---|
| B | True |
| C | True |
| D | False |
| E | True |


```python
# Left join: keep every row of dfx and add matching columns from dfy (null where there is no match)
dfx.join(dfy, how='left', on='ID').show()
```

```python
# Inner join: retain only rows whose ID appears in both sets
dfx.join(dfy, how='inner', on='ID').show()
```

```python
# Full outer join: retain all rows from both sets
df_outer = dfx.join(dfy, how='outer', on='ID')
df_outer.show()
```

## Imputing Null or Missing Data

Null values can represent unknown or missing data as well as irrelevant responses. Strategies for dealing with this scenario include:

* **Dropping these records:** Works when you do not need to use the information for downstream workloads
* **Adding a placeholder (e.g. `-1`):** Allows you to see missing data later on without violating a schema
* **Basic imputing:** Allows you to have a "best guess" of what the data could have been, often by using the mean of non-missing data
* **Advanced imputing:** Determines the "best guess" of what the data should be using more advanced strategies, such as machine learning algorithms (e.g., clustering) or oversampling techniques

Let's use the `df_outer` data frame as an example.

```python
# Drop any records that have null values
DroppedNa = df_outer.dropna("any")
DroppedNa.show()
```

```python
# Impute values with the mean
from pyspark.ml.feature import Imputer

# Imputer only supports numeric columns; imputed values go into the new column out_x1
imputer = Imputer(strategy='mean', inputCols=["x1"], outputCols=["out_x1"])
model = imputer.fit(df_outer)
model.transform(df_outer).show()
```

```python
# Fill nulls with a fixed placeholder value per column
ImputedDF = df_outer.na.fill({"type": "store", "y1": True})
ImputedDF.show()
```

## Reshape data

Take a baby subset of our example clothes-consumer data to illustrate.

```python
sdat = (SimDat
        .sample(False, 0.1)  # roughly 10% of rows, without replacement
        .select('age', 'gender', 'income', 'house', 'store_exp', 'online_exp')
        .dropna('any')
        )

sdat.show(5)
```

For the data `sdat` above, what if we want one variable indicating the purchasing channel (i.e., online or in-store) and another column with the corresponding expense amount, while keeping the rest of the columns the same? This is a task of changing data from “wide” to “long”. There are two general ways to reshape data:

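- Use `melt()` to convert an object into a molten data frame, i.e., from wide to long (a `melt()` helper is defined below; Spark 3.4 and later also provide a built-in `DataFrame.unpivot()`, with `melt()` as an alias)
- Use `pivot()` to cast a molten data frame into the shape you want, i.e., from long to wide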

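```python
from pyspark.sql import functions as F

def melt(df, id_vars, value_vars, var_name, value_name):
    """Convert :class:`DataFrame` from wide to long format.

    The columns in `value_vars` should have compatible data types,
    because they are packed into a single array of structs.
    """
    # One struct per value column: (column name, column value)
    _vars_and_vals = F.array(*[F.struct(F.lit(c).alias(var_name),
                                        F.col(c).alias(value_name))
                               for c in value_vars])

    # Add the array to the DataFrame and explode it: one output row per value column
    _tmp = df.withColumn("_vars_and_vals", F.explode(_vars_and_vals))

    # Keep the id columns and unpack the two struct fields into columns
    cols = id_vars + [F.col("_vars_and_vals")[x].alias(x) for x in [var_name, value_name]]
    return _tmp.select(*cols)
```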

```python
melt_sdat = melt(sdat, ['age', 'gender'], ['store_exp', 'online_exp'], 'Channel', 'Expense')
melt_sdat.show()
```

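We melted `sdat` on two variables, `store_exp` and `online_exp`. The new variable column is `Channel` and the value column is `Expense`. Only the columns listed in `id_vars` (`age` and `gender`) are carried over; add more names to `id_vars` to keep other columns such as `income` or `house`.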

Sometimes we want to convert data from “long” to “wide”. For example, suppose you want to compare online and in-store expenses between male and female customers.

```python
pivot_sdat = melt_sdat.groupby("gender").pivot("Channel").agg({'Expense': 'mean'})
pivot_sdat.show()
```