# Importing libs and reading csv

In [45]:
from pyspark.sql import SparkSession

# Start (or reuse) the Spark session used throughout this notebook.
spark = SparkSession.builder.appName('pyspark_test').getOrCreate()

# Load the diamonds CSV: the first line is the header and column types are inferred.
df = spark.read.csv('ggplot2_diamonds.csv', header=True, inferSchema=True)

# Passing nullValue='NA' would make the reader treat 'NA' cells as null:
#   df = spark.read.csv('ggplot2_diamonds.csv', sep=',', header=True, inferSchema=True,
#                       nullValue='NA')
# It is left out on purpose so the later sections can show how to deal with
# 'NA' values by hand.


# Printing Options

In [46]:
df.printSchema()

root
 |-- carat: string (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: string (nullable = true)
 |-- table: double (nullable = true)
 |-- price: string (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)



In [47]:
df.head(5)

[Row(carat='0.23', cut='Ideal', color='NA', clarity='SI2', depth='61.5', table=55.0, price='326', x=3.95, y=3.98, z=2.43),
 Row(carat='0.21', cut='Premium', color='E', clarity='SI1', depth='59.8', table=61.0, price='NA', x=3.89, y=3.84, z=2.31),
 Row(carat='0.23', cut='Good', color='E', clarity='VS1', depth='56.9', table=65.0, price='327', x=4.05, y=4.07, z=2.31),
 Row(carat='0.29', cut='Premium', color='I', clarity='VS2', depth='62.4', table=58.0, price='334', x=4.2, y=4.23, z=2.63),
 Row(carat='0.31', cut='Good', color='J', clarity='SI2', depth='NA', table=58.0, price='335', x=4.34, y=4.35, z=2.75)]

In [48]:
df.show(10)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|    Ideal|   NA|    SI2| 61.5| 55.0|  326|3.95|3.98|2.43|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0|   NA|3.89|3.84|2.31|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|  327|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|
| 0.31|     Good|    J|    SI2|   NA| 58.0|  335|4.34|4.35|2.75|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|  336|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|  336|3.95|3.98|2.47|
|   NA|Very Good|    H|    SI1| 61.9| 55.0|  337|4.07|4.11|2.53|
| 0.22|     Fair|    E|     NA| 65.1| 61.0|  337|3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|  338| 4.0|4.05|2.39|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 10 rows



In [49]:
df.describe().show()



+-------+-------------------+---------+-----+-------+------------------+------------------+------------------+------------------+------------------+------------------+
|summary|              carat|      cut|color|clarity|             depth|             table|             price|                 x|                 y|                 z|
+-------+-------------------+---------+-----+-------+------------------+------------------+------------------+------------------+------------------+------------------+
|  count|              53940|    53940|53940|  53940|             53940|             53940|             53940|             53940|             53940|             53940|
|   mean|  0.797949720981092|     null| null|   null| 61.74937614712838| 57.45718390804603|3932.8665900368937| 5.731157211716609| 5.734525954764462|3.5387337782723316|
| stddev|0.47400997915312215|     null| null|   null|1.4326190412759938|2.2344905628213247|3989.4464914438868|1.1217607467924915|1.1421346741235616|0.7056988469



# Operating with columns

## Selecting columns

In [50]:
df['carat']

Column<'carat'>

In [110]:
# Subset of column names; the same list is reused by the pivoting example further down.
asd_cols = ['color','clarity','table']

# Unpack the list so that each name is passed to select() as its own argument.
df.select(*asd_cols)

DataFrame[color: string, clarity: string, table: float]

## Add new column

In [52]:
df = df.withColumn('table_doubled',2 * df['table'])
df.show(5)

+-----+-------+-----+-------+-----+-----+-----+----+----+----+-------------+
|carat|    cut|color|clarity|depth|table|price|   x|   y|   z|table_doubled|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+-------------+
| 0.23|  Ideal|   NA|    SI2| 61.5| 55.0|  326|3.95|3.98|2.43|        110.0|
| 0.21|Premium|    E|    SI1| 59.8| 61.0|   NA|3.89|3.84|2.31|        122.0|
| 0.23|   Good|    E|    VS1| 56.9| 65.0|  327|4.05|4.07|2.31|        130.0|
| 0.29|Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|        116.0|
| 0.31|   Good|    J|    SI2|   NA| 58.0|  335|4.34|4.35|2.75|        116.0|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+-------------+
only showing top 5 rows



## Drop column

In [53]:
df = df.drop('table_doubled')
df.show(5)

+-----+-------+-----+-------+-----+-----+-----+----+----+----+
|carat|    cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|  Ideal|   NA|    SI2| 61.5| 55.0|  326|3.95|3.98|2.43|
| 0.21|Premium|    E|    SI1| 59.8| 61.0|   NA|3.89|3.84|2.31|
| 0.23|   Good|    E|    VS1| 56.9| 65.0|  327|4.05|4.07|2.31|
| 0.29|Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|
| 0.31|   Good|    J|    SI2|   NA| 58.0|  335|4.34|4.35|2.75|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
only showing top 5 rows



## Rename column

In [54]:
df.withColumnRenamed('x','asdasd').show()

+-----+---------+-----+-------+-----+-----+-----+------+----+----+
|carat|      cut|color|clarity|depth|table|price|asdasd|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+------+----+----+
| 0.23|    Ideal|   NA|    SI2| 61.5| 55.0|  326|  3.95|3.98|2.43|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0|   NA|  3.89|3.84|2.31|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|  327|  4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|  334|   4.2|4.23|2.63|
| 0.31|     Good|    J|    SI2|   NA| 58.0|  335|  4.34|4.35|2.75|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|  336|  3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|  336|  3.95|3.98|2.47|
|   NA|Very Good|    H|    SI1| 61.9| 55.0|  337|  4.07|4.11|2.53|
| 0.22|     Fair|    E|     NA| 65.1| 61.0|  337|  3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|  338|   4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1|   64| 55.0|  339|  4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0|  340|  3.93| 3.9|2

## Casting column types

### Option 01: pass the type name as a string

In [55]:
df.withColumn("carat",df["carat"].cast('float')).printSchema()

root
 |-- carat: float (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: string (nullable = true)
 |-- table: double (nullable = true)
 |-- price: string (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)



### Option 02: use pyspark datatypes (preferable)

In [56]:
from pyspark.sql.types import StringType, DateType, FloatType

# Columns to convert to float. In the inferred schema carat, depth and price
# are strings, and table is a double.
float_cols = ['carat','depth','table','price']

# The loop variable is deliberately NOT named `col`: a later cell imports
# `col` from pyspark.sql.functions, and a loop variable with that name would
# rebind it to a plain string whenever this cell is re-run, breaking every
# subsequent `col(...)` call.
for col_name in float_cols:
    df = df.withColumn(col_name, df[col_name].cast(FloatType()))

df.printSchema()

root
 |-- carat: float (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: float (nullable = true)
 |-- table: float (nullable = true)
 |-- price: float (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)



In [57]:
# Categorical columns. They are already strings in the schema, so this cast
# leaves the schema unchanged and only demonstrates StringType().
str_cols = ['color','cut','clarity']

# The loop variable is deliberately NOT named `col`: a later cell imports
# `col` from pyspark.sql.functions, and a loop variable with that name would
# rebind it to a plain string whenever this cell is re-run, breaking every
# subsequent `col(...)` call.
for col_name in str_cols:
    df = df.withColumn(col_name, df[col_name].cast(StringType()))

df.printSchema()

root
 |-- carat: float (nullable = true)
 |-- cut: string (nullable = true)
 |-- color: string (nullable = true)
 |-- clarity: string (nullable = true)
 |-- depth: float (nullable = true)
 |-- table: float (nullable = true)
 |-- price: float (nullable = true)
 |-- x: double (nullable = true)
 |-- y: double (nullable = true)
 |-- z: double (nullable = true)



# Removing NAs


Notice that we have `NA`s (or `null`s, in the columns already cast to float) in rows 1, 2, 5, 8 and 9.

In [58]:
df.show(15)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|    Ideal|   NA|    SI2| 61.5| 55.0|326.0|3.95|3.98|2.43|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0| null|3.89|3.84|2.31|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.31|     Good|    J|    SI2| null| 58.0|335.0|4.34|4.35|2.75|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| null|Very Good|    H|    SI1| 61.9| 55.0|337.0|4.07|4.11|2.53|
| 0.22|     Fair|    E|     NA| 65.1| 61.0|337.0|3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0| 4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1| 64.0| 55.0|339.0|4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0|340.0|3.93| 3.9|2.46|
| 0.22|  Premium|    F|  

In [59]:
df.na.drop().show(15)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|    Ideal|   NA|    SI2| 61.5| 55.0|326.0|3.95|3.98|2.43|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| 0.22|     Fair|    E|     NA| 65.1| 61.0|337.0|3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0| 4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1| 64.0| 55.0|339.0|4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0|340.0|3.93| 3.9|2.46|
| 0.22|  Premium|    F|    SI1| 60.4| 61.0|342.0|3.88|3.84|2.33|
| 0.31|    Ideal|    J|    SI2| 62.2| 54.0|344.0|4.35|4.37|2.71|
|  0.2|  Premium|    E|    SI2| 60.2| 62.0|345.0|3.79|3.75|2.27|
| 0.32|  Premium|    E|  

The `df.na.drop()` method did drop rows, but only those whose missing value sits in a float column, not those where it sits in a string column. Casting to float already turned the `NA` text into `null`, whereas in the string columns `pyspark` sees `NA` as an ordinary string and not as a null, so we must convert the `NA` strings into `null` ourselves.

Lets do it for one column (`color`) to see how this can be done.

In [60]:
from pyspark.sql.functions import col,when


df.withColumn("color", when(col("color")=="NA" ,None).otherwise(col("color"))).show()



+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|    Ideal| null|    SI2| 61.5| 55.0|326.0|3.95|3.98|2.43|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0| null|3.89|3.84|2.31|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.31|     Good|    J|    SI2| null| 58.0|335.0|4.34|4.35|2.75|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| null|Very Good|    H|    SI1| 61.9| 55.0|337.0|4.07|4.11|2.53|
| 0.22|     Fair|    E|     NA| 65.1| 61.0|337.0|3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0| 4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1| 64.0| 55.0|339.0|4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0|340.0|3.93| 3.9|2.46|
| 0.22|  Premium|    F|  

In [119]:
import pyspark.sql.functions as f
# Count the nulls in every column with a single aggregation.
# when(<is null>, name) yields the column name as a string literal for null
# cells and null otherwise (there is no otherwise()), and count() skips
# nulls, so each count equals the number of null cells in that column.
# NOTE(review): the recorded output (one null each in color and clarity)
# appears to come from running this cell after the 'NA' strings had been
# converted to null; on a top-to-bottom run those two counts would
# presumably be 0 at this point - confirm.
df.agg(
    *(
        f.count(f.when(f.col(name).isNull(), name)).alias(name)
        for name in df.columns
    )
).show()

+-----+---+-----+-------+-----+-----+-----+---+---+---+
|carat|cut|color|clarity|depth|table|price|  x|  y|  z|
+-----+---+-----+-------+-----+-----+-----+---+---+---+
|    1|  0|    1|      1|    1|    0|    1|  0|  0|  0|
+-----+---+-----+-------+-----+-----+-----+---+---+---+



Now we loop through the `str` columns.

In [61]:
df.select([when(col(c)=="NA",None).otherwise(col(c)).alias(c) for c in str_cols]).show(10)

+-----+---------+-------+
|color|      cut|clarity|
+-----+---------+-------+
| null|    Ideal|    SI2|
|    E|  Premium|    SI1|
|    E|     Good|    VS1|
|    I|  Premium|    VS2|
|    J|     Good|    SI2|
|    J|Very Good|   VVS2|
|    I|Very Good|   VVS1|
|    H|Very Good|    SI1|
|    E|     Fair|   null|
|    H|Very Good|    VS1|
+-----+---------+-------+
only showing top 10 rows



In [62]:
# Replace the literal string "NA" with a real null, keeping the column order.
# Only string-typed columns can contain the text "NA"; comparing a numeric
# column with "NA" forces an implicit string-to-number cast, which is wasted
# work at best and, as far as I can tell, raises under ANSI SQL mode.
# Numeric columns are therefore passed through unchanged.
string_cols = {name for name, dtype in df.dtypes if dtype == "string"}

df = df.select([
    when(col(c) == "NA", None).otherwise(col(c)).alias(c) if c in string_cols else col(c)
    for c in df.columns
])
df.show(10)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|    Ideal| null|    SI2| 61.5| 55.0|326.0|3.95|3.98|2.43|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0| null|3.89|3.84|2.31|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.31|     Good|    J|    SI2| null| 58.0|335.0|4.34|4.35|2.75|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| null|Very Good|    H|    SI1| 61.9| 55.0|337.0|4.07|4.11|2.53|
| 0.22|     Fair|    E|   null| 65.1| 61.0|337.0|3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0| 4.0|4.05|2.39|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 10 rows



In [63]:
df.na.drop().show(15)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|338.0| 4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1| 64.0| 55.0|339.0|4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0|340.0|3.93| 3.9|2.46|
| 0.22|  Premium|    F|    SI1| 60.4| 61.0|342.0|3.88|3.84|2.33|
| 0.31|    Ideal|    J|    SI2| 62.2| 54.0|344.0|4.35|4.37|2.71|
|  0.2|  Premium|    E|    SI2| 60.2| 62.0|345.0|3.79|3.75|2.27|
| 0.32|  Premium|    E|     I1| 60.9| 58.0|345.0|4.38|4.42|2.68|
|  0.3|    Ideal|    I|    SI2| 62.0| 54.0|348.0|4.31|4.34|2.68|
|  0.3|     Good|    J|  

# Filling missing values

In [65]:
df.na.fill('Missing Values',str_cols).show()

+-----+---------+--------------+--------------+-----+-----+-----+----+----+----+
|carat|      cut|         color|       clarity|depth|table|price|   x|   y|   z|
+-----+---------+--------------+--------------+-----+-----+-----+----+----+----+
| 0.23|    Ideal|Missing Values|           SI2| 61.5| 55.0|326.0|3.95|3.98|2.43|
| 0.21|  Premium|             E|           SI1| 59.8| 61.0| null|3.89|3.84|2.31|
| 0.23|     Good|             E|           VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|
| 0.29|  Premium|             I|           VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|
| 0.31|     Good|             J|           SI2| null| 58.0|335.0|4.34|4.35|2.75|
| 0.24|Very Good|             J|          VVS2| 62.8| 57.0|336.0|3.94|3.96|2.48|
| 0.24|Very Good|             I|          VVS1| 62.3| 57.0|336.0|3.95|3.98|2.47|
| null|Very Good|             H|           SI1| 61.9| 55.0|337.0|4.07|4.11|2.53|
| 0.22|     Fair|             E|Missing Values| 65.1| 61.0|337.0|3.87|3.78|2.49|
| 0.23|Very Good|           

# Imputing missing values

In [66]:
from pyspark.ml.feature import Imputer

# Median imputer: every column in float_cols gets a companion
# "<name>_imputed" column, leaving the original column untouched.
imputer = Imputer(
    strategy="median",
    inputCols=float_cols,
    outputCols=[name + "_imputed" for name in float_cols],
)

In [68]:
imputer.fit(df).transform(df).show(15)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+-------------+-------------+-------------+-------------+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|carat_imputed|depth_imputed|table_imputed|price_imputed|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+-------------+-------------+-------------+-------------+
| 0.23|    Ideal| null|    SI2| 61.5| 55.0|326.0|3.95|3.98|2.43|         0.23|         61.5|         55.0|        326.0|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0| null|3.89|3.84|2.31|         0.21|         59.8|         61.0|       2400.0|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|327.0|4.05|4.07|2.31|         0.23|         56.9|         65.0|        327.0|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|334.0| 4.2|4.23|2.63|         0.29|         62.4|         58.0|        334.0|
| 0.31|     Good|    J|    SI2| null| 58.0|335.0|4.34|4.35|2.75|         0.31|         61.8|         58.0|        335.0|
| 0.24|Very Good|    J|   VVS2| 

In [69]:
# Same median imputation, but the output names equal the input names, so the
# imputed values replace the original columns instead of being appended.
(
    Imputer(inputCols=float_cols, outputCols=list(float_cols))
    .setStrategy("median")
    .fit(df)
    .transform(df)
    .show(15)
)

+-----+---------+-----+-------+-----+-----+------+----+----+----+
|carat|      cut|color|clarity|depth|table| price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+------+----+----+----+
| 0.23|    Ideal| null|    SI2| 61.5| 55.0| 326.0|3.95|3.98|2.43|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0|2400.0|3.89|3.84|2.31|
| 0.23|     Good|    E|    VS1| 56.9| 65.0| 327.0|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0| 334.0| 4.2|4.23|2.63|
| 0.31|     Good|    J|    SI2| 61.8| 58.0| 335.0|4.34|4.35|2.75|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0| 336.0|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0| 336.0|3.95|3.98|2.47|
|  0.7|Very Good|    H|    SI1| 61.9| 55.0| 337.0|4.07|4.11|2.53|
| 0.22|     Fair|    E|   null| 65.1| 61.0| 337.0|3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0| 338.0| 4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1| 64.0| 55.0| 339.0|4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0| 340.0|3.93| 3.9|2.46|
| 0.22|  P

# Filtering rows

In [72]:
df.filter("depth <= 58").show(5)

+-----+-------+-----+-------+-----+-----+------+----+----+----+
|carat|    cut|color|clarity|depth|table| price|   x|   y|   z|
+-----+-------+-----+-------+-----+-----+------+----+----+----+
| 0.23|   Good|    E|    VS1| 56.9| 65.0| 327.0|4.05|4.07|2.31|
| 0.26|   Good|    E|   VVS1| 57.9| 60.0| 554.0|4.22|4.25|2.45|
| 0.86|   Fair|    E|    SI2| 55.1| 69.0|2757.0|6.45|6.33|3.52|
|  0.7|   Good|    E|    VS2| 57.5| 58.0|2759.0|5.85| 5.9|3.38|
|  0.7|Premium|    D|    VS2| 58.0| 62.0|2773.0|5.87|5.78|3.38|
+-----+-------+-----+-------+-----+-----+------+----+----+----+
only showing top 5 rows



In [73]:
df.filter(df['depth'] <= 58).show(5)

+-----+-------+-----+-------+-----+-----+------+----+----+----+
|carat|    cut|color|clarity|depth|table| price|   x|   y|   z|
+-----+-------+-----+-------+-----+-----+------+----+----+----+
| 0.23|   Good|    E|    VS1| 56.9| 65.0| 327.0|4.05|4.07|2.31|
| 0.26|   Good|    E|   VVS1| 57.9| 60.0| 554.0|4.22|4.25|2.45|
| 0.86|   Fair|    E|    SI2| 55.1| 69.0|2757.0|6.45|6.33|3.52|
|  0.7|   Good|    E|    VS2| 57.5| 58.0|2759.0|5.85| 5.9|3.38|
|  0.7|Premium|    D|    VS2| 58.0| 62.0|2773.0|5.87|5.78|3.38|
+-----+-------+-----+-------+-----+-----+------+----+----+----+
only showing top 5 rows



In [77]:
df.filter((df['depth'] <= 58) | (df['depth'] >= 70)).show(5)

+-----+-------+-----+-------+-----+-----+------+----+----+----+
|carat|    cut|color|clarity|depth|table| price|   x|   y|   z|
+-----+-------+-----+-------+-----+-----+------+----+----+----+
| 0.23|   Good|    E|    VS1| 56.9| 65.0| 327.0|4.05|4.07|2.31|
| 0.26|   Good|    E|   VVS1| 57.9| 60.0| 554.0|4.22|4.25|2.45|
| 0.86|   Fair|    E|    SI2| 55.1| 69.0|2757.0|6.45|6.33|3.52|
|  0.7|   Good|    E|    VS2| 57.5| 58.0|2759.0|5.85| 5.9|3.38|
|  0.7|Premium|    D|    VS2| 58.0| 62.0|2773.0|5.87|5.78|3.38|
+-----+-------+-----+-------+-----+-----+------+----+----+----+
only showing top 5 rows



In [80]:
df.filter((df['depth'] <= 65) & (df['price'] >= 2800)).show(5)

+-----+---------+-----+-------+-----+-----+------+----+----+----+
|carat|      cut|color|clarity|depth|table| price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+------+----+----+----+
|  0.6|Very Good|    G|     IF| 61.6| 56.0|2800.0|5.43|5.46|3.35|
|  0.9|     Good|    I|    SI2| 62.2| 59.0|2800.0|6.07|6.11|3.79|
|  0.7|  Premium|    E|    VS1| 62.2| 58.0|2800.0| 5.6|5.66| 3.5|
|  0.9|Very Good|    I|    SI2| 61.3| 56.0|2800.0|6.17|6.23| 3.8|
| 0.83|    Ideal|    G|    SI1| 62.3| 57.0|2800.0|5.99|6.08|3.76|
+-----+---------+-----+-------+-----+-----+------+----+----+----+
only showing top 5 rows



In [82]:
df.filter((df['depth'] <= 65) & (df['price'] >= 2800) & (df["color"] == "G")).show(5)

+-----+---------+-----+-------+-----+-----+------+----+----+----+
|carat|      cut|color|clarity|depth|table| price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+------+----+----+----+
|  0.6|Very Good|    G|     IF| 61.6| 56.0|2800.0|5.43|5.46|3.35|
| 0.83|    Ideal|    G|    SI1| 62.3| 57.0|2800.0|5.99|6.08|3.76|
| 0.83|    Ideal|    G|    SI1| 61.8| 57.0|2800.0|6.03|6.07|3.74|
| 0.74|  Premium|    G|    VS1| 62.9| 60.0|2800.0|5.74|5.68|3.59|
| 0.61|    Ideal|    G|     IF| 62.3| 56.0|2800.0|5.43|5.45|3.39|
+-----+---------+-----+-------+-----+-----+------+----+----+----+
only showing top 5 rows



## Negating booleans with the `~` operator

In [83]:
df.filter((df['depth'] <= 65) & ~(df['price'] >= 2800) & (df["color"] == "G")).show(5)

+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|Very Good|    G|   VVS2| 60.4| 58.0|354.0|3.97|4.01|2.41|
| 0.23|    Ideal|    G|    VS1| 61.9| 54.0|404.0|3.93|3.95|2.44|
| 0.28|    Ideal|    G|   VVS2| 61.4| 56.0|553.0|4.19|4.22|2.58|
| 0.31|Very Good|    G|    SI1| 63.3| 57.0|553.0|4.33| 4.3|2.73|
| 0.31|  Premium|    G|    SI1| 61.8| 58.0|553.0|4.35|4.32|2.68|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 5 rows



# Aggregation operations

Notice that when aggregating without specifying a column, `pyspark` recognizes only the numeric ones and computes the statistical operation for only those.

In [84]:
df.groupBy('cut').mean().show()



+---------+------------------+-----------------+------------------+------------------+-----------------+-----------------+------------------+
|      cut|        avg(carat)|       avg(depth)|        avg(table)|        avg(price)|           avg(x)|           avg(y)|            avg(z)|
+---------+------------------+-----------------+------------------+------------------+-----------------+-----------------+------------------+
|  Premium|0.8919548964733383| 61.2646726240803| 58.74609527936416|4584.5664974619285| 5.97388731781595|5.944878543977969|3.6471242114422924|
|    Ideal|0.7028369900618932|61.70940094928742|  55.9516681366512| 3457.541970210199|5.507450698343484|5.520079810681653|3.4014481926593225|
|     Good|0.8491846683696732|62.36568799655615| 58.69463921992865| 3928.864451691806|5.838785161027321|5.850743986954751|3.6395067264573817|
|     Fair|1.0461366415301465| 64.0416770484877| 59.05378881892803| 4358.757763975155|6.246894409937905| 6.18265217391304|3.9827701863353964|
|Very 



In [85]:
df.groupBy('cut').count().show()

+---------+-----+
|      cut|count|
+---------+-----+
|  Premium|13791|
|    Ideal|21551|
|     Good| 4906|
|     Fair| 1610|
|Very Good|12082|
+---------+-----+



In [89]:
df.agg(
    {'carat':'sum', 'cut':'count','table':'std','x':'max','y':'min'}
    ).show()

+------+------+----------+-----------------+------------------+
|max(x)|min(y)|count(cut)|       sum(carat)|     stddev(table)|
+------+------+----------+-----------------+------------------+
| 10.74|   0.0|     53940|43040.60989624262|2.2344905638396657|
+------+------+----------+-----------------+------------------+



# Pivoting

One point that I found intriguing in `pyspark` is that there's no simple implementation for pivoting without aggregating, like `R`'s `pivot_wider` or `pivot_longer` functions. So here I'll just show how to do it with an aggregation function, mostly because at this point I just don't know how to pivot without aggregating.

In [111]:
df2 = df.select(*asd_cols)
df2.show(5)

+-----+-------+-----+
|color|clarity|table|
+-----+-------+-----+
| null|    SI2| 55.0|
|    E|    SI1| 61.0|
|    E|    VS1| 65.0|
|    I|    VS2| 58.0|
|    J|    SI2| 58.0|
+-----+-------+-----+
only showing top 5 rows



In [113]:
# Drop rows with a null in any of the three selected columns, then build one
# row per clarity and one column per color value. mean() is called without a
# column name, so it averages the only numeric column of df2, `table`.
df3 = df2.dropna().groupBy('clarity').pivot('color').mean()
df3.show()

+-------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+
|clarity|                 D|                 E|                 F|                G|                 H|                 I|                 J|
+-------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+
|   VVS2|  57.1432188009914| 57.13451060228704| 56.99897435506185|56.92224531642943|56.921217115301836| 56.94657535291698| 57.72290074006292|
|    SI1| 57.55448871819491| 57.63235778997284|57.783481936228775|57.62525303836776| 57.64918681218074|  57.7197331444601|57.746800008138024|
|     IF| 57.36986301369863|56.946835433380514| 56.57896104292436|56.31791483113069|56.319397993311036| 56.49930070997118| 57.01960784313726|
|     I1|58.333333333333336| 59.25588233798158| 57.93356643356643|58.00600001017253|58.238271595519265| 58.28260869565217|             58.54|
|   VV

## Unpivoting

From [sparkbyexamples](https://sparkbyexamples.com/pyspark/pyspark-pivot-and-unpivot-dataframe/)

> PySpark SQL doesn’t have unpivot function hence will use the stack() function. 

In [117]:
from pyspark.sql.functions import expr

# stack(7, 'D', D, ...) turns the seven pivoted columns of df3 back into
# (label, value) rows: each quoted letter is the label, the bare identifier
# is the column that supplies the value.
# NOTE(review): the labels D-J are the values of the `color` pivot column,
# so the 'Cut' alias is a misnomer and 'Mean' is the mean of `table`. The
# names are kept as-is because other cells may depend on them - confirm
# before renaming.
unpivotExpr = "stack(7, 'D', D, 'E', E, 'F', F,'G',G, 'H', H,'I',I,'J',J) as (Cut,Mean)"
unPivotDF = df3.select("clarity", expr(unpivotExpr))
unPivotDF.show()

+-------+---+------------------+
|clarity|Cut|              Mean|
+-------+---+------------------+
|   VVS2|  D|  57.1432188009914|
|   VVS2|  E| 57.13451060228704|
|   VVS2|  F| 56.99897435506185|
|   VVS2|  G| 56.92224531642943|
|   VVS2|  H|56.921217115301836|
|   VVS2|  I| 56.94657535291698|
|   VVS2|  J| 57.72290074006292|
|    SI1|  D| 57.55448871819491|
|    SI1|  E| 57.63235778997284|
|    SI1|  F|57.783481936228775|
|    SI1|  G| 57.62525303836776|
|    SI1|  H| 57.64918681218074|
|    SI1|  I|  57.7197331444601|
|    SI1|  J|57.746800008138024|
|     IF|  D| 57.36986301369863|
|     IF|  E|56.946835433380514|
|     IF|  F| 56.57896104292436|
|     IF|  G| 56.31791483113069|
|     IF|  H|56.319397993311036|
|     IF|  I| 56.49930070997118|
+-------+---+------------------+
only showing top 20 rows

