# Dataframe Operations

## Import Libraries

In [1]:
from pyspark.sql import SparkSession, Row
from pyspark.sql import functions as F
from pyspark.sql.types import StructField, StructType, StringType, LongType

## Spark Session

In [2]:
spark = SparkSession\
    .builder\
    .master("local[*]")\
    .appName("Test App")\
    .getOrCreate()
spark.catalog.clearCache()

In [3]:
spark

## Load Data

In [4]:
manual_schema = StructType([
    StructField("some", StringType(), True),
    StructField("column", StringType(), True),
    StructField("names", LongType(), True)
])

In [5]:
my_row = Row("Hello", "World", 42)

In [6]:
my_df = spark.createDataFrame([my_row], schema=manual_schema)

In [7]:
my_df.show()

+-----+------+-----+
| some|column|names|
+-----+------+-----+
|Hello| World|   42|
+-----+------+-----+



In [8]:
df = spark.read\
    .format("csv")\
    .option("inferSchema", True)\
    .option("header", True)\
    .load("../../data/flight-data/csv/*.csv")

## Operations

### Select

In [9]:
df.select("DEST_COUNTRY_NAME").show(2)

+-----------------+
|DEST_COUNTRY_NAME|
+-----------------+
|    United States|
|    United States|
+-----------------+
only showing top 2 rows



In [10]:
df.select("DEST_COUNTRY_NAME", "ORIGIN_COUNTRY_NAME").show(3)

+-----------------+-------------------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|
+-----------------+-------------------+
|    United States|            Romania|
|    United States|            Ireland|
|    United States|              India|
+-----------------+-------------------+
only showing top 3 rows



In [11]:
df.select(F.expr("DEST_COUNTRY_NAME AS DESTINATION"), F.expr("ORIGIN_COUNTRY_NAME AS ORIGIN")).show(3)

+-------------+-------+
|  DESTINATION| ORIGIN|
+-------------+-------+
|United States|Romania|
|United States|Ireland|
|United States|  India|
+-------------+-------+
only showing top 3 rows



In [12]:
df.selectExpr("*", "(DEST_COUNTRY_NAME = ORIGIN_COUNTRY_NAME) AS LOCAL_FLIGHT", "(count * 10) AS MULTI_FLIGHT" )

DataFrame[DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string, count: int, LOCAL_FLIGHT: boolean, MULTI_FLIGHT: int]

In [13]:
df.selectExpr("AVG(count)", "MAX(count)", "MIN(count)", "COUNT(DISTINCT ORIGIN_COUNTRY_NAME)").show()

+------------------+----------+----------+-----------------------------------+
|        avg(count)|max(count)|min(count)|count(DISTINCT ORIGIN_COUNTRY_NAME)|
+------------------+----------+----------+-----------------------------------+
|1718.3189081225032|    370002|         1|                                154|
+------------------+----------+----------+-----------------------------------+



#### Renaming Colums

In [14]:
df.selectExpr("DEST_COUNTRY_NAME AS DEST", "ORIGIN_COUNTRY_NAME AS ORIGIN", "count AS FLIGHTS").show(3)

df\
    .withColumnRenamed("DEST_COUNTRY_NAME", "DEST")\
    .withColumnRenamed("ORIGIN_COUNTRY_NAME", "ORIGIN")\
    .withColumnRenamed("count", "FLIGHTS")\
    .show(5)

+-------------+-------+-------+
|         DEST| ORIGIN|FLIGHTS|
+-------------+-------+-------+
|United States|Romania|      1|
|United States|Ireland|    264|
|United States|  India|     69|
+-------------+-------+-------+
only showing top 3 rows

+-----------------+-------------+-------+
|             DEST|       ORIGIN|FLIGHTS|
+-----------------+-------------+-------+
|    United States|      Romania|      1|
|    United States|      Ireland|    264|
|    United States|        India|     69|
|            Egypt|United States|     24|
|Equatorial Guinea|United States|      1|
+-----------------+-------------+-------+
only showing top 5 rows



If a column name has spaces or reserverd keywords use backticks \`colum name-with reserved+characters\`

In [15]:
df.withColumnRenamed("count", "Number of -flights-")\
    .selectExpr("Number of -flights-") # This will fail

ParseException: 
Possibly unquoted identifier of-flights detected. Please consider quoting it with back-quotes as `of-flights`(line 1, pos 10)

== SQL ==
Number of -flights-
----------^^^


In [16]:
df.withColumnRenamed("count", "Number of -flights-")\
    .selectExpr("`Number of -flights-`") # This won't fail because backtics

DataFrame[Number of -flights-: int]

In [17]:
df.select("dest_country_name").show()  # Case insensitive unless specified with 

+--------------------+
|   dest_country_name|
+--------------------+
|       United States|
|       United States|
|       United States|
|               Egypt|
|   Equatorial Guinea|
|       United States|
|       United States|
|          Costa Rica|
|             Senegal|
|       United States|
|              Guyana|
|       United States|
|               Malta|
|             Bolivia|
|            Anguilla|
|Turks and Caicos ...|
|       United States|
|Saint Vincent and...|
|               Italy|
|       United States|
+--------------------+
only showing top 20 rows



#### Removing Columns

In [18]:
df

DataFrame[DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string, count: int]

In [19]:
df.select("DEST_COUNTRY_NAME", "count").show(1)

+-----------------+-----+
|DEST_COUNTRY_NAME|count|
+-----------------+-----+
|    United States|    1|
+-----------------+-----+
only showing top 1 row



In [20]:
# Same as:
df.drop("ORIGIN_COUNTRY_NAME").show(1)

+-----------------+-----+
|DEST_COUNTRY_NAME|count|
+-----------------+-----+
|    United States|    1|
+-----------------+-----+
only showing top 1 row



### Casting Columns

In [21]:
df.withColumn("count(float)", F.col("count").cast("float")).show(5)

+-----------------+-------------------+-----+------------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|count(float)|
+-----------------+-------------------+-----+------------+
|    United States|            Romania|    1|         1.0|
|    United States|            Ireland|  264|       264.0|
|    United States|              India|   69|        69.0|
|            Egypt|      United States|   24|        24.0|
|Equatorial Guinea|      United States|    1|         1.0|
+-----------------+-------------------+-----+------------+
only showing top 5 rows



### Filtering

.where() and .filter() are the same 

In [22]:
df.filter(F.col("count") < 2).show(5)

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|       United States|            Romania|    1|
|   Equatorial Guinea|      United States|    1|
|               Malta|      United States|    1|
|Saint Vincent and...|      United States|    1|
|            Slovakia|      United States|    1|
+--------------------+-------------------+-----+
only showing top 5 rows



Because of spark optimizations, it's better to chain filters instead of using ANDs inside 1 filter 

In [23]:
df.filter((F.col("count") < 2) & (F.col("ORIGIN_COUNTRY_NAME") != "Romania")).show(5)  # This works but its better to let Spark optimize

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|   Equatorial Guinea|      United States|    1|
|               Malta|      United States|    1|
|Saint Vincent and...|      United States|    1|
|            Slovakia|      United States|    1|
|       United States|             Cyprus|    1|
+--------------------+-------------------+-----+
only showing top 5 rows



In [24]:
df.filter(F.col("count") < 2)\
    .filter(F.col("ORIGIN_COUNTRY_NAME") != "Romania")\
    .show(5)

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|   Equatorial Guinea|      United States|    1|
|               Malta|      United States|    1|
|Saint Vincent and...|      United States|    1|
|            Slovakia|      United States|    1|
|       United States|             Cyprus|    1|
+--------------------+-------------------+-----+
only showing top 5 rows



#### Distinct

In [25]:
df.select("ORIGIN_COUNTRY_NAME").distinct().show(5)

+-------------------+
|ORIGIN_COUNTRY_NAME|
+-------------------+
|             Russia|
|           Anguilla|
|           Paraguay|
|            Senegal|
|             Sweden|
+-------------------+
only showing top 5 rows



In [26]:
# Or a combination by 
df.select("ORIGIN_COUNTRY_NAME", "DEST_COUNTRY_NAME").distinct().show(5)

+-------------------+-----------------+
|ORIGIN_COUNTRY_NAME|DEST_COUNTRY_NAME|
+-------------------+-----------------+
|       Saint Martin|    United States|
|             Guinea|    United States|
|            Romania|    United States|
|            Croatia|    United States|
|            Ireland|    United States|
+-------------------+-----------------+
only showing top 5 rows



### Random Split of Dataframes

#### Random Sampling

In [27]:
seed = 42
withReplacement = False
fraction = 0.1
df.sample(withReplacement, fraction, seed).show()

+-----------------+--------------------+-----+
|DEST_COUNTRY_NAME| ORIGIN_COUNTRY_NAME|count|
+-----------------+--------------------+-----+
|       Costa Rica|       United States|  477|
|    United States|         Afghanistan|    2|
|            Italy|       United States|  390|
|         Colombia|       United States|  785|
|    United States|               Palau|   30|
|    United States|             Finland|   20|
|    United States|              Greece|   61|
|    United States|           Hong Kong|  293|
|    United States|               Egypt|   25|
|    United States|              Turkey|   87|
|    United States|             Estonia|    1|
|    United States|            Thailand|   13|
|    United States|Turks and Caicos ...|  147|
|           Cyprus|       United States|    2|
|            Qatar|       United States|   41|
|    United States|Saint Vincent and...|   16|
|            Aruba|       United States|  359|
|        Singapore|       United States|   25|
|    United S

In [28]:
split_dataframes = df.randomSplit([0.25, 0.75], seed=42)

In [29]:
print(f"First DF Count: {split_dataframes[0].count()}")
print(f"Second DF Count: {split_dataframes[1].count()}")

First DF Count: 368
Second DF Count: 1134


Train test split for Machine Learning

In [30]:
train, test = df.randomSplit([0.75, 0.25], seed=101)

In [31]:
train.count()

1097

In [32]:
test.count()

405

### Concat and Append

Unions are based on location and not schema, so be careful

In [33]:
from pyspark.sql import Row

In [34]:
schema = df.schema

In [35]:
newRows = [
    Row("New Country", "Other Country", 5),
    Row("New Country 2", "Other Country 3", 3)
]

In [36]:
parallelizedRows = spark.sparkContext.parallelize(newRows)

In [37]:
newDF = spark.createDataFrame(parallelizedRows, schema=schema)

In [38]:
newDF.show()

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|      New Country|      Other Country|    5|
|    New Country 2|    Other Country 3|    3|
+-----------------+-------------------+-----+



In [39]:
df.union(newDF)\
    .filter("count = 5")\
    .filter("DEST_COUNTRY_NAME != 'United States'")\
    .show()

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    French Guiana|      United States|    5|
|           Guinea|      United States|    5|
|      Afghanistan|      United States|    5|
|      New Country|      Other Country|    5|
+-----------------+-------------------+-----+



### Sorting Columns

In [40]:
df.sort("count").show(5)

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|       United States|               Mali|    1|
|Saint Helena, Asc...|      United States|    1|
|       United States|           Malaysia|    1|
|       United States|              Niger|    1|
|             Burundi|      United States|    1|
+--------------------+-------------------+-----+
only showing top 5 rows



In [41]:
df.sort("count", "DEST_COUNTRY_NAME").show(5)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|       Azerbaijan|      United States|    1|
|          Belarus|      United States|    1|
|          Belarus|      United States|    1|
|           Brunei|      United States|    1|
|         Bulgaria|      United States|    1|
+-----------------+-------------------+-----+
only showing top 5 rows



Ascending or Descending Order

3 Different ways to do it!  
F.desc(), F.col("x").desc() and F.expr("col desc")

In [42]:
df.orderBy(F.desc("count"), F.col("DEST_COUNTRY_NAME").asc(), F.expr("ORIGIN_COUNTRY_NAME desc")).show()

+-----------------+-------------------+------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME| count|
+-----------------+-------------------+------+
|    United States|      United States|370002|
|    United States|      United States|358354|
|    United States|      United States|352742|
|    United States|      United States|348113|
|    United States|      United States|347452|
|    United States|      United States|343132|
|    United States|             Canada|  8650|
|           Canada|      United States|  8514|
|    United States|             Canada|  8483|
|           Canada|      United States|  8399|
|    United States|             Canada|  8305|
|           Canada|      United States|  8271|
|    United States|             Canada|  8177|
|    United States|             Canada|  8097|
|           Canada|      United States|  8034|
|    United States|             Canada|  7983|
|           Canada|      United States|  7974|
|           Canada|      United States|  7860|
|    United S

For optimization purposes, sometimes it's better to sort within Partitions before doing more transofrmations

In [44]:
df.sortWithinPartitions(F.col("count").desc_nulls_first()).show(5)

+-----------------+-------------------+------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME| count|
+-----------------+-------------------+------+
|    United States|      United States|370002|
|    United States|      United States|352742|
|    United States|      United States|348113|
|    United States|             Canada|  8650|
|           Canada|      United States|  8514|
+-----------------+-------------------+------+
only showing top 5 rows



### Limit

In [55]:
df.orderBy(F.expr("count DESC")).limit(10).show()

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|            Malaysia|      United States|    1|
|       United States|            Algeria|    1|
|       United States|            Romania|    1|
|          Azerbaijan|      United States|    1|
|               Malta|      United States|    1|
|             Liberia|      United States|    1|
|Saint Vincent and...|      United States|    1|
|       United States|            Vietnam|    1|
|            Slovakia|      United States|    1|
|       United States|            Estonia|    1|
+--------------------+-------------------+-----+



### Repartition and Coalesce

Partition calls a full shuffle, so do it only if partitions needs to increase ex:

In [56]:
df.rdd.getNumPartitions()

2

In [57]:
part_df = df.repartition(5)

In [58]:
part_df.rdd.getNumPartitions()

5

You can also repartition based on a specific column if you're going to filter by it a lot:

In [59]:
dest_df = df.repartition(F.col("DEST_COUNTRY_NAME"))

In [60]:
dest_df.rdd.getNumPartitions()

200

Can also repartition by a column and set a number of partitions

In [62]:
small_dest_df = df.repartition(10, F.col("DEST_COUNTRY_NAME"))

In [63]:
small_dest_df.rdd.getNumPartitions()

10

## Collect to Driver

When you colllect/take data, its passed to the driver (make sure driver is big enough for the data collected)

In [70]:
collectDF = df\
    .filter("DEST_COUNTRY_NAME != 'United States'")\
    .filter("count > 1")\
    .filter("count < 1000")\
    .orderBy(F.col("count").desc_nulls_last())\
    .limit(10)

In [71]:
collectDF.collect()

[Row(DEST_COUNTRY_NAME='Brazil', ORIGIN_COUNTRY_NAME='United States', count=995),
 Row(DEST_COUNTRY_NAME='Brazil', ORIGIN_COUNTRY_NAME='United States', count=979),
 Row(DEST_COUNTRY_NAME='The Bahamas', ORIGIN_COUNTRY_NAME='United States', count=975),
 Row(DEST_COUNTRY_NAME='Brazil', ORIGIN_COUNTRY_NAME='United States', count=969),
 Row(DEST_COUNTRY_NAME='South Korea', ORIGIN_COUNTRY_NAME='United States', count=968),
 Row(DEST_COUNTRY_NAME='France', ORIGIN_COUNTRY_NAME='United States', count=966),
 Row(DEST_COUNTRY_NAME='The Bahamas', ORIGIN_COUNTRY_NAME='United States', count=955),
 Row(DEST_COUNTRY_NAME='The Bahamas', ORIGIN_COUNTRY_NAME='United States', count=950),
 Row(DEST_COUNTRY_NAME='France', ORIGIN_COUNTRY_NAME='United States', count=935),
 Row(DEST_COUNTRY_NAME='Brazil', ORIGIN_COUNTRY_NAME='United States', count=927)]

* There is another way of collecting by partition using .toLocalIterator() 

In [81]:
my_local = df.repartition(F.col("DEST_COUNTRY_NAME")).limit(10).toLocalIterator()

In [82]:
my_local  # Is a generator

<generator object _local_iterator_from_socket.<locals>.PyLocalIterable.__iter__ at 0x7fc0108ffb30>

In [83]:
for item in my_test:
    print(item)

Row(DEST_COUNTRY_NAME='Chad', ORIGIN_COUNTRY_NAME='United States', count=1)
Row(DEST_COUNTRY_NAME='Anguilla', ORIGIN_COUNTRY_NAME='United States', count=21)
Row(DEST_COUNTRY_NAME='Russia', ORIGIN_COUNTRY_NAME='United States', count=152)
Row(DEST_COUNTRY_NAME='Paraguay', ORIGIN_COUNTRY_NAME='United States', count=90)
Row(DEST_COUNTRY_NAME='Anguilla', ORIGIN_COUNTRY_NAME='United States', count=41)
Row(DEST_COUNTRY_NAME='Russia', ORIGIN_COUNTRY_NAME='United States', count=176)
Row(DEST_COUNTRY_NAME='Paraguay', ORIGIN_COUNTRY_NAME='United States', count=60)
Row(DEST_COUNTRY_NAME='Anguilla', ORIGIN_COUNTRY_NAME='United States', count=21)
Row(DEST_COUNTRY_NAME='Russia', ORIGIN_COUNTRY_NAME='United States', count=199)
Row(DEST_COUNTRY_NAME='Paraguay', ORIGIN_COUNTRY_NAME='United States', count=85)
