# Getting started with PySpark

This notebook gives a short introduction to PySpark. It is a starting point for exploring core features such as PySpark dataframes.

In [None]:
from datetime import datetime, date

import pandas as pd
from pyspark.sql import SparkSession, Row

## SparkSession

`SparkSession` is the main entry point for a PySpark application. Before Spark 2.0 unified them, there were three separate entry points: `SparkContext`, `SQLContext` and `HiveContext`. A session object can be initialized as shown in the following cell.

In [None]:
spark = SparkSession.builder \
    .appName('TestApp') \
    .getOrCreate()

## DataFrames

A PySpark `DataFrame` can be created in different ways, for example from a `pd.DataFrame`, from a list of `Row` objects, or from a list of tuples together with an explicit schema.

In [None]:
# create pandas dataframe
pandas_df = pd.DataFrame(
    {
        'a': [1, 2, 3],
        'b': [2., 3., 4.],
        'c': ['string1', 'string2', 'string3'],
        'd': [date(2000, 1, 1), date(2000, 2, 1), date(2000, 3, 1)],
        'e': [datetime(2000, 1, 1, 12, 0), datetime(2000, 1, 2, 12, 0), datetime(2000, 1, 3, 12, 0)]
    }
)

# create dataframe from pandas dataframe
spark_df = spark.createDataFrame(pandas_df)

# transform back to pandas
# pandas_df = spark_df.toPandas()

In [None]:
# create spark dataframe from list of rows
spark_df = spark.createDataFrame(
    [
        Row(a=1, b=2., c='string1', d=date(2000, 1, 1), e=datetime(2000, 1, 1, 12, 0)),
        Row(a=2, b=3., c='string2', d=date(2000, 2, 1), e=datetime(2000, 1, 2, 12, 0)),
        Row(a=4, b=5., c='string3', d=date(2000, 3, 1), e=datetime(2000, 1, 3, 12, 0))
    ]
)

In [None]:
# create spark dataframe with an explicit schema
spark_df = spark.createDataFrame(
    [
        (1, 2., 'string1', date(2000, 1, 1), datetime(2000, 1, 1, 12, 0)),
        (2, 3., 'string2', date(2000, 2, 1), datetime(2000, 1, 2, 12, 0)),
        (3, 4., 'string3', date(2000, 3, 1), datetime(2000, 1, 3, 12, 0))
    ],
    schema='a long, b double, c string, d date, e timestamp'
)

In [None]:
spark_df.show()
spark_df.printSchema()

## Accessing data

In [None]:
# select columns (note that dataframes are lazily evaluated)
one_col = spark_df.a  # attribute access returns a Column, not a DataFrame

two_cols = spark_df.select('a', 'b')
two_cols = spark_df['a', 'b']  # equivalent to the select above

print(one_col)
print(two_cols)

In [None]:
# return the first and last rows as lists of Row objects
list_of_first_rows = spark_df.head(2)  # equivalent: spark_df.take(2)
list_of_last_rows = spark_df.tail(2)

print(list_of_first_rows)
print(list_of_last_rows)

In [None]:
# collect distributed data to the driver (note that this may cause an out-of-memory error)
list_of_all_rows = spark_df.collect()

print(len(list_of_all_rows))

In [None]:
# filter rows of dataframe
filtered_df = spark_df.filter(spark_df.a == 1)

filtered_df.show()

## Grouping data

In [None]:
df = spark.createDataFrame(
    [
        ['red', 'banana', 1, 10],
        ['blue', 'banana', 2, 20],
        ['red', 'carrot', 3, 30],
        ['blue', 'grape', 4, 40],
        ['red', 'carrot', 5, 50],
        ['black', 'carrot', 6, 60],
        ['red', 'banana', 7, 70],
        ['red', 'grape', 8, 80]
    ],
    schema=['color', 'fruit', 'v1', 'v2']
)

df.show()
df.printSchema()

In [None]:
df.groupby('color').avg().show()

## SQL queries

In [None]:
# register dataframe as a temporary SQL view
df.createOrReplaceTempView('tableA')

# run SQL-style query
spark.sql('SELECT count(*) FROM tableA').show()

## Close session

In [None]:
# close session
spark.stop()