### Running Spark on AWS Glue

This notebook was created in SageMaker on a Glue dev endpoint. It contains examples of how to run pure PySpark on AWS Glue, avoiding Amazon's Glue-specific library wherever possible.

Note that a SageMaker notebook is essentially AWS-managed Jupyter.

In [1]:
from awsglue.context import GlueContext
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
import pyspark.sql.functions as F
import pyspark.sql.types as T

import sys
from datetime import datetime, timezone
import dateutil.tz

load_datetime = datetime.now()  # I typically append this to tables for later ref

Starting Spark application

| ID | YARN Application ID | Kind | State | Spark UI | Driver log | Current session? |
|---|---|---|---|---|---|---|
| 41 | application_1597098563813_0042 | pyspark | idle | Link | Link | ✔ |

SparkSession available as 'spark'.

### env

I keep a file called env.py in S3 that contains the details of the AWS environment I need; this cell replicates that file.

In [2]:
prod = {"s3tmp": "s3://company-redshift-prod/glue-temp/",
        "s3src": "s3://company-redshift-prod/data-sources/",
        "rsconn": "company-redshift"
        }

env = prod
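
If more than one environment is needed, the same idea extends to a small lookup. The following is only a sketch of how such an env.py might look: the `dev` entry, its bucket names, and the `get_env` helper are hypothetical and not part of the original file.

```python
# Hypothetical multi-environment version of env.py.
# The "dev" values are placeholders invented for illustration.
ENVIRONMENTS = {
    "prod": {
        "s3tmp": "s3://company-redshift-prod/glue-temp/",
        "s3src": "s3://company-redshift-prod/data-sources/",
        "rsconn": "company-redshift",
    },
    "dev": {
        "s3tmp": "s3://company-redshift-dev/glue-temp/",
        "s3src": "s3://company-redshift-dev/data-sources/",
        "rsconn": "company-redshift-dev",
    },
}


def get_env(name):
    """Return the settings dict for one environment, failing loudly on typos."""
    try:
        return ENVIRONMENTS[name]
    except KeyError:
        raise ValueError(
            f"unknown environment {name!r}; expected one of {sorted(ENVIRONMENTS)}"
        )


env = get_env("prod")
```

Failing on an unknown name is a deliberate choice here, so a typo cannot silently point a job at the wrong bucket.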

### create contexts

SageMaker requires slightly different syntax for the Spark context than a Glue script does. Comment or uncomment the lines below as required.

In [3]:
# create necessary contexts
sc = spark.sparkContext  # Jupyter (use this in Jupyter/Sagemaker)
# sc = SparkContext()  # Glue (use this in actual Glue scripts)
glueContext = GlueContext(sc)
spark = glueContext.spark_session
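
One way to avoid the comment/uncomment step is to check whether the notebook kernel has already provided a session. This is a rough sketch rather than anything Glue prescribes: `resolve_spark_context` is a hypothetical helper, and it assumes the kernel exposes the session under the name `spark`, as the startup output above reports.

```python
def resolve_spark_context(namespace, create_context):
    """Reuse the notebook's session if one exists, otherwise build a context."""
    session = namespace.get("spark")
    if session is not None:
        # Jupyter/SageMaker: the kernel already injected a SparkSession
        return session.sparkContext
    # Glue script: nothing was injected, so create the context ourselves
    return create_context()


# Intended use in the cell above (needs pyspark, so left commented here):
# sc = resolve_spark_context(globals(), SparkContext)
# glueContext = GlueContext(sc)
```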

### get info from Glue connection(s)

This Glue method essentially treats a Glue connection as a credential store, letting you pull in creds for use in connection URLs without hard-coding secrets. This is the only place I use the actual GlueContext; everything else is pure Spark.

In [45]:
# get creds from redshift connection
conf = glueContext.extract_jdbc_conf(env["rsconn"])

### Spark read on S3

Note that we can now use a pure Spark read on S3. Here we read a CSV of public traffic data downloaded from https://catalog.data.gov/dataset?res_format=CSV.

In [4]:
df = spark.read.load(env["s3src"] + "monroe-county-crash-data2003-to-2015.csv",
                    format="com.databricks.spark.csv",
                    header="true",
                    inferSchema="true")

Examining the dataset...

In [5]:
df.head()

Row(Master Record Number=902363382, Year=2015, Month=1, Day=5, Weekend?='Weekday', Hour=0, Collision Type='2-Car', Injury Type='No injury/unknown', Primary Factor='OTHER (DRIVER) - EXPLAIN IN NARRATIVE', Reported_Location='1ST & FESS', Latitude=39.15920668, Longitude=-86.52587356)

In [6]:
df.dtypes

[('Master Record Number', 'int'), ('Year', 'int'), ('Month', 'int'), ('Day', 'int'), ('Weekend?', 'string'), ('Hour', 'int'), ('Collision Type', 'string'), ('Injury Type', 'string'), ('Primary Factor', 'string'), ('Reported_Location', 'string'), ('Latitude', 'double'), ('Longitude', 'double')]

In [7]:
df.show(5)

+--------------------+----+-----+---+--------+----+--------------+------------------+--------------------+--------------------+-----------+------------+
|Master Record Number|Year|Month|Day|Weekend?|Hour|Collision Type|       Injury Type|      Primary Factor|   Reported_Location|   Latitude|   Longitude|
+--------------------+----+-----+---+--------+----+--------------+------------------+--------------------+--------------------+-----------+------------+
|           902363382|2015|    1|  5| Weekday|   0|         2-Car| No injury/unknown|OTHER (DRIVER) - ...|          1ST & FESS|39.15920668|-86.52587356|
|           902364268|2015|    1|  6| Weekday|1500|         2-Car| No injury/unknown|FOLLOWING TOO CLO...|       2ND & COLLEGE|   39.16144|  -86.534848|
|           902364412|2015|    1|  6| Weekend|2300|         2-Car|Non-incapacitating|DISREGARD SIGNAL/...|BASSWOOD & BLOOMF...|39.14978027|-86.56889006|
|           902364551|2015|    1|  7| Weekend| 900|         2-Car|Non-incapacitati

### Spark select

We can select columns like any Spark dataframe...

In [43]:
df_select = df.select(
    F.col("year"),
    F.col("month"),
    F.col("day"),
    F.col("hour"),
    F.col("Collision Type").alias('collision_type'),
    F.col("Injury Type").alias('injury_type'),
    F.lit(load_datetime).alias("DWH_LOAD_TIMESTAMP"))

df_select.show(5)

+----+-----+---+----+--------------+------------------+--------------------+
|year|month|day|hour|collision_type|       injury_type|  DWH_LOAD_TIMESTAMP|
+----+-----+---+----+--------------+------------------+--------------------+
|2015|    1|  5|   0|         2-Car| No injury/unknown|2020-10-12 02:54:...|
|2015|    1|  6|1500|         2-Car| No injury/unknown|2020-10-12 02:54:...|
|2015|    1|  6|2300|         2-Car|Non-incapacitating|2020-10-12 02:54:...|
|2015|    1|  7| 900|         2-Car|Non-incapacitating|2020-10-12 02:54:...|
|2015|    1|  7|1100|         2-Car| No injury/unknown|2020-10-12 02:54:...|
+----+-----+---+----+--------------+------------------+--------------------+
only showing top 5 rows

### Spark SQL

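We can do the same with spark.sql() as shown below. Note that in SageMaker/Jupyter the temp view must be created in the same cell as the spark.sql() that references it, or it will not work.

Note also that we convert the separate date/time fields into a true timestamp. The hour column holds an HHMM integer (for example 900 or 1500), so it is left-padded to four digits before parsing.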

In [42]:
df.createOrReplaceTempView("crash_data")
df_sql = spark.sql("""select
                    to_timestamp(concat(cast(year as string), '-', 
                        cast(month as string), '-', cast(day as string), ' ', 
                            lpad(cast(hour as string),4,'0')), "yyyy-MM-dd HHmm") as datetime
                    ,`Collision Type` as collision_type
                    ,`Injury Type` as injury_type
                    from crash_data""") \
        .withColumn("DWH_LOAD_TIMESTAMP", F.lit(load_datetime))

df_sql.show(5)

+-------------------+--------------+------------------+--------------------+
|           datetime|collision_type|       injury_type|  DWH_LOAD_TIMESTAMP|
+-------------------+--------------+------------------+--------------------+
|2015-01-05 00:00:00|         2-Car| No injury/unknown|2020-10-12 02:54:...|
|2015-01-06 15:00:00|         2-Car| No injury/unknown|2020-10-12 02:54:...|
|2015-01-06 23:00:00|         2-Car|Non-incapacitating|2020-10-12 02:54:...|
|2015-01-07 09:00:00|         2-Car|Non-incapacitating|2020-10-12 02:54:...|
|2015-01-07 11:00:00|         2-Car| No injury/unknown|2020-10-12 02:54:...|
+-------------------+--------------+------------------+--------------------+
only showing top 5 rows
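
To make the padding step in the query concrete, here is the same conversion in plain Python. It is only an illustration of the logic, not part of the Glue job; `crash_timestamp` is a hypothetical name.

```python
from datetime import datetime


def crash_timestamp(year, month, day, hour):
    """Combine date parts and an HHMM integer into a datetime."""
    # f"{hour:04d}" plays the role of lpad(cast(hour as string), 4, '0')
    text = f"{year}-{month}-{day} {hour:04d}"
    return datetime.strptime(text, "%Y-%m-%d %H%M")


print(crash_timestamp(2015, 1, 6, 1500))  # 2015-01-06 15:00:00
print(crash_timestamp(2015, 1, 7, 900))   # 2015-01-07 09:00:00
```

Without the padding, an hour value such as 900 would be read as the text "900", which does not fit a four-digit HHmm pattern.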

### Spark read on a jdbc connection

It is also possible to do a Spark read from a database such as Amazon RDS or Redshift. Here are dummy examples showing the syntax for both using a direct query; .option("dbtable", table) would be used instead for a full table read.

In [None]:
# for an RDS connection; to use a query, wrap it as an aliased subquery in "dbtable"

conf1 = glueContext.extract_jdbc_conf("name_of_rds_connection")

qy = "select column1, column2, column3 from rds_table"

df = spark.read \
    .format("jdbc") \
    .option("url", conf1["url"] + "/database_name") \
    .option("driver", "com.mysql.jdbc.Driver") \
    .option("dbtable",
            f"({qy}) a") \
    .option("user", conf1["user"]) \
    .option("password", conf1["password"]) \
    .load()
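
The "weird syntax" is simply the query wrapped in parentheses and given an alias, so the database sees it as a derived table. A small helper can do the wrapping; this is a sketch, and `as_dbtable` is a hypothetical name rather than a Spark function.

```python
def as_dbtable(query, alias="a"):
    """Wrap a SELECT so it can be passed to the JDBC "dbtable" option."""
    # A trailing semicolon would be a syntax error inside the parentheses
    cleaned = query.strip().rstrip(";").strip()
    return f"({cleaned}) {alias}"


qy = "select column1, column2, column3 from rds_table;"
print(as_dbtable(qy))  # (select column1, column2, column3 from rds_table) a
```

Spark 2.4 and later also accept a "query" option on the plain JDBC source; whether that is available depends on the Spark version behind your endpoint.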

In [50]:
# for Redshift, note there is a "query" option

url = f"{conf['url']}/dwh_test?user={conf['user']}&password={conf['password']}"

qy = "select column1, column2, column3 from redshift_table"

df = spark.read \
    .format("com.databricks.spark.redshift") \
    .option("url", url) \
    .option("query", qy) \
    .option("forward_spark_s3_credentials", 'true') \
    .option("tempdir", env["s3tmp"]) \
    .load()
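
Putting the user and password in the URL query string can break when the password contains characters such as & or ?. A possible alternative, assuming your version of the Redshift connector supports its separate "user" and "password" options, is to pass the credentials as options instead. The helper name `redshift_options` and the sample values below are hypothetical.

```python
def redshift_options(conf, database, tempdir):
    """Build reader/writer options with credentials kept out of the URL."""
    return {
        "url": f"{conf['url']}/{database}",
        "user": conf["user"],
        "password": conf["password"],
        "tempdir": tempdir,
        "forward_spark_s3_credentials": "true",
    }


# Stand-in for glueContext.extract_jdbc_conf(...); these values are made up
fake_conf = {
    "url": "jdbc:redshift://example-host:5439",
    "user": "etl",
    "password": "p&ss?word",
}
opts = redshift_options(fake_conf, "dwh_test", "s3://company-redshift-prod/glue-temp/")

# Intended use (needs a live Spark session, so left commented here):
# df = spark.read.format("com.databricks.spark.redshift") \
#     .options(**opts).option("query", qy).load()
```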

### write to Redshift

Now we can Spark write to Redshift (you know you want to). Note that we reuse the connection creds pulled from the Glue connection above.

In [46]:
url = f"{conf['url']}/dwh_test?user={conf['user']}&password={conf['password']}"
table = 'staging.crash_data'

df_sql.write \
    .format("com.databricks.spark.redshift") \
    .option("url", url) \
    .option("dbtable", table) \
    .option("tempdir", env["s3tmp"]) \
    .option("forward_spark_s3_credentials", "true") \
    .mode("append") \
    .save()

### postactions and preactions

This example does not contain an instance of a preaction or postaction, but these are invaluable options on the Redshift Spark write and worth looking into. They allow you to run SQL before and/or after a write. Below is a dummy version to show the syntax.

In [None]:
table = "my_table"

# using delete instead of truncate empties the table inside a TRANSACTION,
# so it will roll back if the Spark write fails (!)
preactions = f"delete from {table} where 1=1;"

# you can call an sp or do other cleanup afterwards
postactions = "call public.some_stored_procedure();"

df.write \
    .format("com.databricks.spark.redshift") \
    .option("url", url) \
    .option("dbtable", table) \
    .option("tempdir", env["s3tmp"]) \
    .option("preactions", preactions) \
    .option("postactions", postactions) \
    .option("forward_spark_s3_credentials", "true") \
    .mode("append") \
    .save()
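
When more than one statement is needed before or after the write, the statements can be joined into a single string. This sketch assumes the connector accepts a semicolon-separated list for these options, as its documentation describes; `sql_actions` is a hypothetical helper.

```python
def sql_actions(*statements):
    """Join SQL statements into one semicolon-separated string."""
    # Drop stray whitespace and trailing semicolons, then skip empty entries
    cleaned = [s.strip().rstrip(";").strip() for s in statements]
    cleaned = [s for s in cleaned if s]
    return "; ".join(cleaned) + ";" if cleaned else ""


combined = sql_actions(
    "delete from my_table where 1=1;",
    "call public.some_stored_procedure();",
)
print(combined)
# delete from my_table where 1=1; call public.some_stored_procedure();
```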