In [0]:
from pyspark.sql.functions import col, countDistinct, mean, sum, count
from pyspark.sql import functions as F

In [0]:
# Location of the diamonds CSV inside the Databricks sample datasets
path = '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv'

# Read the CSV, using its first line as the column names
df = spark.read.csv(path, header=True)

# Keep the three categorical columns as read and cast every other
# selected column to float, preserving the original column order
df = df.select(
    *[
        col(name) if name in ('cut', 'color', 'clarity')
        else col(name).cast('float')
        for name in (
            'carat', 'cut', 'color', 'clarity',
            'depth', 'table', 'price', 'x', 'y', 'z',
        )
    ]
)

In [0]:
# Preview a seeded random sample of roughly 1% of the rows
display(df.sample(withReplacement=None, fraction=0.01, seed=123))

In [0]:
# Report the DataFrame's dimensions: row count (with thousands separators) and column count
print(f'Rows: {df.count():,} and Cols: {len(df.columns)}' )

In [0]:
# Gather the distinct values of the 'cut' column into a plain Python list.
# Each collected Row holds a single field, so it unpacks as a 1-tuple.
distinct_cut_rows = df.select('cut').distinct().collect()
vals = [cut for (cut,) in distinct_cut_rows]
vals

In [0]:
# For the first three rows, take the largest of the x, y and z columns
(
    df.limit(3)
    .select(F.greatest(col('x'), col('y'), col('z')).alias("greatest"))
    .collect()
)

In [0]:
# Show x next to its value rounded down (floor) and rounded up (ceiling)
display(
    df.select(
        'x',
        F.floor('x').alias('floor'),
        F.ceiling('x').alias('ceiling'),
    )
)

In [0]:
# From the first 100 rows, gather the carat values of each cut into one list per cut
display(
    df.limit(100)
    .groupBy('cut')
    .agg(F.collect_list('carat'))
)

In [0]:
# List the file-system entry for the dataset path via the Databricks utilities
dbutils.fs.ls(path)

In [0]:
# Average price for every (cut, color) combination -- no caching here.
# count() is the action that forces the aggregation to run.
df2 = (
    df.groupBy('cut', 'color')
    .agg(F.mean('price').alias('avg_price'))
)
df2.count()

In [0]:
# Keep only the aggregated rows whose color is 'E', then count them
df3 = df2.filter(col('color') == 'E')
df3.count()

In [0]:
# Average price for every (cut, color) combination, this time marked for caching.
# cache() is lazy; the count() below is the action that runs the aggregation.
df2 = (
    df.groupBy('cut', 'color')
    .agg(F.mean('price').alias('avg_price'))
    .cache()
)
df2.count()

In [0]:
# Filter the aggregated result down to color 'E' and count the matching rows
df3 = df2.filter(col('color') == 'E')
df3.count()

Notice that caching might not be the best option in every case. Sometimes the dataset is small enough to fit comfortably in memory, and it is processed faster without caching than with it.

The cached version, given that it's parallelized, can become larger than the original data, in which case it leads to worse performance.
For large data, though, it's always a good option to consider.

In [0]:
%sql
-- Remove all cached entries from Spark's cache (this includes the DataFrame cached with .cache() above)
clear cache

In [0]:
# Summarise the price distribution: quartiles plus the 95th percentile
display(
    df.agg(
        F.percentile('price', 0.25).alias('25th'),
        F.percentile('price', 0.5).alias('median_price'),
        F.percentile('price', 0.75).alias('75th'),
        F.percentile('price', 0.95).alias('95th'),
    )
)

25th,median_price,75th,95th
950.0,2401.0,5324.25,13107.099999999991
