## JDBC I/O in Apache Spark

In this tutorial you'll learn the basics of writing Apache Spark DataFrames to a SQL database, and reading them back,
using Apache Spark's JDBC API.

#### Tutorial Requirements
This tutorial assumes you have write privileges (including table-create privileges) on a postgresql database.
The instructions below show you where to fill in your database connection information.

This tutorial includes the Maven coordinates for using the postgresql JDBC driver.
If you wish to connect to a different SQL database vendor, and you have access to the proper JDBC driver Maven package (or jar file), you should be able to run this demo against a non-postgresql DB.

To run this tutorial you also need to connect to an Apache Spark cluster.
You can enter the spark master hostname in the cell that creates the Spark Session.

#### Apache Spark Configurations

Spark can receive its configuration parameters from a 
[variety](https://spark.apache.org/docs/latest/configuration.html#dynamically-loading-spark-properties)
of channels.
In general, configurations set via a `SparkConf` object (as below) will override all other
configurations.
However, there are a few exceptions to this rule, and the `spark.jars.packages` parameter is one of them,
which is important for this tutorial.
To maximize clarity, this notebook unsets `PYSPARK_SUBMIT_ARGS` and does all of its
configuration through a single `SparkConf` object, so that the settings are easy to read.

In [1]:
import os
# Disable this so that configuration of 'spark.jars.packages' works correctly
if 'PYSPARK_SUBMIT_ARGS' in os.environ:
    del os.environ['PYSPARK_SUBMIT_ARGS']

#### JDBC Drivers
In order to work with Spark's JDBC API, you'll need to provide Spark with a JDBC driver.
Many drivers are available as Maven-style packages,
such as the driver for
[postgresql](https://www.postgresql.org/)
in the cell below.

Consuming JDBC drivers via Maven coordinates and `spark.jars.packages` is convenient,
since Spark will automatically download such packages and install them on Spark executors.

In some cases, database vendors may provide a "raw" jar file instead of a Maven package.
The code comments below show a Spark connection alternative that specifies the driver as a jar file
using the `spark.jars` configuration parameter.
Installing and configuring jar files directly may also be well suited for building container images
that can run on platforms like Kubernetes or OpenShift.
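
Both `spark.jars.packages` and `spark.jars` take a comma-separated list, so more than one driver can be supplied at once.
The following is a minimal sketch of a helper that joins several Maven coordinates into one setting.
The function name `maven_packages` and the second coordinate are hypothetical, used only for illustration:

```python
def maven_packages(*coords):
    """Join Maven coordinates into one comma-separated string for 'spark.jars.packages'."""
    for coord in coords:
        # Each coordinate is expected to look like groupId:artifactId:version
        if len(coord.split(':')) != 3:
            raise ValueError('expected groupId:artifactId:version, got: {c}'.format(c=coord))
    return ','.join(coords)

# 'com.example:other-jdbc-driver:1.0.0' is a made-up coordinate, not a real package
packages = maven_packages(
    'org.postgresql:postgresql:42.2.9',
    'com.example:other-jdbc-driver:1.0.0'
)
print(packages)
```

The resulting string could then be passed to `spark_conf.set('spark.jars.packages', ...)` in place of a single coordinate.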

In [2]:
from pyspark import SparkConf

# Instantiate a spark configuration object to receive settings
spark_conf = SparkConf()

# Maven coordinates for package containing JDBC drivers
jdbc_driver_packages = 'org.postgresql:postgresql:42.2.9'

# Configure spark to see the postgresql driver package
spark_conf.set('spark.jars.packages', jdbc_driver_packages)

# Alternative method: directly list path to your JDBC driver jar (or jars)
# jdbc_driver_jars = '/path/to/postgresql-42.2.9.jar'
# spark_conf.set('spark.jars', jdbc_driver_jars)

<pyspark.conf.SparkConf at 0x7f66b88087f0>

#### Obtaining a Spark Session

Before we can begin, we need to attach to a running Apache Spark cluster.
In this cell, you'll set the hostname of the Spark master to connect to.
The `SparkConf` settings below instruct the session to use just a single executor with 1 CPU core.

In [3]:
from pyspark.sql import SparkSession

# The name of your Spark cluster hostname or ip address
spark_cluster = 'spark-cluster-eje'

# Configure some basic spark cluster sizing parameters
spark_conf.set('spark.cores.max', '1')
spark_conf.set('spark.executor.cores', '1')

spark = SparkSession.builder \
    .master('spark://{cluster}:7077'.format(cluster=spark_cluster)) \
    .appName('Spark-JDBC-Demo') \
    .config(conf = spark_conf) \
    .getOrCreate()

#### Checking JDBC Driver Configuration

As we discussed above, Spark has some subtle quirks in its normal configuration precedence ordering when it comes to `spark.jars.packages` and `spark.jars`.
You can use the `getConf()` method to sanity-check the final settings that Spark is using.
In this cell, we check that the jar files for our postgresql JDBC driver will actually be visible on Spark's classpath.

In [4]:
'postgresql' in spark.sparkContext.getConf().get('spark.jars')

True

#### Example Data

For the purposes of this tutorial, we'll be working with a small example data table.
The first column is some consecutive integers, and the second is the squares of the first:

In [5]:
data_raw = [(x, x*x) for x in range(1000)]
data_df = spark.createDataFrame(data_raw, ['x', 'xsq'])
data_df.show(5)

+---+---+
|  x|xsq|
+---+---+
|  0|  0|
|  1|  1|
|  2|  4|
|  3|  9|
|  4| 16|
+---+---+
only showing top 5 rows



#### Spark JDBC URL and Properties

Apache Spark JDBC calls take two data structures to specify a database connection.
The first is a string containing a JDBC connection URL.
Such a URL typically includes the following db connect info:
* vendor (here, 'postgresql')
* hostname
* port
* database name

The second structure is a property map.
In python this is a `dict` structure, containing:
* db user name
* password
* Java class name of the JDBC driver

For some vendors, other properties are expected.
A common additional property is `sslConnection`, as shown below in the comments.

Remember, it is best practice to store username and password in environment variables or other forms that can be set without explicitly typing security information in your code!
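
As one possible way to follow this advice, the sketch below reads the credentials from environment variables instead of hard-coding them.
The variable names `DEMO_DB_USER` and `DEMO_DB_PASSWORD` and the helper `jdbc_credentials` are assumptions made for this illustration; use whatever names your environment actually defines:

```python
import os

def jdbc_credentials(env=os.environ):
    """Build the user/password entries of a JDBC property map from environment variables."""
    # DEMO_DB_USER and DEMO_DB_PASSWORD are hypothetical variable names
    names = ('DEMO_DB_USER', 'DEMO_DB_PASSWORD')
    missing = [name for name in names if name not in env]
    if missing:
        # Fail early with a clear message rather than connecting with partial credentials
        raise KeyError('missing environment variables: {m}'.format(m=', '.join(missing)))
    return {'user': env['DEMO_DB_USER'], 'password': env['DEMO_DB_PASSWORD']}
```

A property map could then be assembled with something like `dict(jdbc_credentials(), driver='org.postgresql.Driver')`, so that no password ever appears in the notebook itself.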

The exact syntax of the JDBC URL varies from vendor to vendor.
Refer to the vendor's JDBC driver documentation for connection specifics.
Vendors that publish JDBC drivers usable by Spark will generally include Spark example connections.

In [6]:
spark_jdbc_url = 'jdbc:postgresql://{host}:{port}/{database}'.format( \
    host     = 'postgresql', \
    port     = '5432', \
    database = 'demodb')

spark_jdbc_prop = { \
    'user':     'eje', \
    'password': 'eje12345', \
    'driver':   'org.postgresql.Driver'
    # 'sslConnection': 'false'
    # Some DB vendors expect other connection properties.
    # Setting 'sslConnection' is one common vendor-specific parameter
}

#### Writing a DataFrame with JDBC

The following Spark call uses our database connect info above to write our example data to a database table.
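The two write modes you will use most often are `overwrite` and `append`.
Spark also supports `ignore` and `error` (the default, which fails if the table already exists).
Note that in `overwrite` mode you must have both write and table create privileges on your db!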

In [7]:
data_df.write.jdbc( \
    table      = 'demo', \
    mode       = 'overwrite', \
    url        = spark_jdbc_url, \
    properties = spark_jdbc_prop \
)

#### Reading a DataFrame with JDBC

The basics of reading a DataFrame from a JDBC query are (almost) as simple as writing.
The database connection information is the same.
Here, we must specify a database query, written in the vendor's supported dialect of SQL.
A query can be very simple, as in the example below, or hundreds of lines of complex SQL code!
It is generally best practice to set the query string separately, as in this example.

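Note that in the `read.jdbc` call below, we have enclosed the raw query in parentheses and given it the alias `tmp`. Spark places the `table` argument into the `from` clause of the queries it generates, so a query used there must be written as an aliased subquery.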

In [8]:
query = 'select * from demo'

query_df = spark.read.jdbc( \
    table      = '({q}) tmp'.format(q=query), \
    url        = spark_jdbc_url, \
    properties = spark_jdbc_prop \
)

query_df.show(5)

+---+---+
|  x|xsq|
+---+---+
|  0|  0|
|  1|  1|
|  2|  4|
|  3|  9|
|  4| 16|
+---+---+
only showing top 5 rows



#### The Perils of Partitioning

Apache Spark's scalable compute model depends on being able to break data into multiple partitions so that it can parallelize work across each partition.
Let's look at how our dataframe got partitioned when we read it above.
The following cell prints out the number of partitions and the number of records in each partition.

As you can see from the output, Spark put all the records in our query into a single partition!
For our small example data, this is not a problem.
However, if we are working with large volumes of data, this is bad news!
With all our data in a single partition, Spark cannot process our data in parallel.
Worse, if a sufficiently large query result is pushed into a single executor,
it can easily cause an out of memory error and crash the executor!

In [9]:
print("partitions: {np}\nsizes: {sz}".format( \
    np = query_df.rdd.getNumPartitions(), \
    sz = query_df.rdd.mapPartitions(lambda itr: [len(list(itr))]).collect() \
))

partitions: 1
sizes: [1000]


#### Proper DataFrame Partitioning with JDBC

Fortunately, Spark provides a way to perform JDBC reads and correctly partition the result.

When you read data into a Spark DataFrame using a JDBC query,
Spark needs extra information about which data from the DB to put into each partition.
Spark does this by generating one query for each partition, under the hood.
To enable this, you must provide your JDBC read with a list of "partitioning predicates":
Spark will generate its DataFrame with one partition for each predicate you give it, and each partition receives the rows for which the corresponding predicate is true.

For example, if we wanted to split our data into 3 partitions,
we would want Spark to run the following 3 queries:

```sql
select x, xsq from demo where mod(x, 3) = 0  /* query for 1st partition */
select x, xsq from demo where mod(x, 3) = 1  /* query for 2nd partition */
select x, xsq from demo where mod(x, 3) = 2  /* query for 3rd partition */
```

Spark expects us to provide a list that looks like this:
```python
[ 'mod(x, 3) = 0', 'mod(x, 3) = 1', 'mod(x, 3) = 2' ]
```

Notice that these clauses have been designed so that every record ends up in exactly one of our partitions,
and also that our partition sizes should be roughly equal, with 1 out of 3 records satisfying each.
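
Both properties can be checked without a database.
The sketch below mirrors the `mod(x, n) = k` predicates in plain Python and applies them to the same values of `x` as our example data.
The helper name `simulate_mod_predicates` is hypothetical, and note that Python's `%` only matches SQL `mod` for non-negative values, which is all we have here:

```python
def simulate_mod_predicates(values, n):
    """Apply Python equivalents of the predicates mod(x, n) = k, for k in 0..n-1."""
    values = list(values)
    # One Python predicate per partition, mirroring the SQL predicate strings
    predicates = [lambda x, k=k: x % n == k for k in range(n)]
    # Number of rows each partition would receive
    sizes = [sum(1 for v in values if p(v)) for p in predicates]
    # True only if every value satisfies exactly one predicate
    exactly_one = all(sum(1 for p in predicates if p(v)) == 1 for v in values)
    return sizes, exactly_one

sizes, exactly_one = simulate_mod_predicates(range(1000), 3)
print(sizes, exactly_one)  # [334, 333, 333] True
```

If a value matched no predicate, the corresponding row would be silently missing from the DataFrame; if it matched more than one, the row would be duplicated.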

In practice, we may very well want to create a large number of these partitioning predicates,
and so it is a good idea to generate them with a function,
such as the `qpreds` function below:

In [10]:
def qpreds(n, rowcol):
    return ["mod({rc}, {np}) = {mk}".format(mk=k, np=n, rc=rowcol) for k in range(n)]

qpreds(3, 'x')

['mod(x, 3) = 0', 'mod(x, 3) = 1', 'mod(x, 3) = 2']

#### A JDBC Read With Partitioning

With our `qpreds` function above, we can easily add the additional `predicates` parameter
to our JDBC read so that Spark can create a well partitioned DataFrame from our query.
In our example below, we configure our predicates for 5 partitions:

In [11]:
# Perform a JDBC read with proper partitioning
query = 'select * from demo'

query_df_pp = spark.read.jdbc( \
    table      = '({q}) tmp'.format(q=query), \
    url        = spark_jdbc_url, \
    properties = spark_jdbc_prop, \
    predicates = qpreds(5, 'x') \
)
query_df_pp.show(5)

+---+---+
|  x|xsq|
+---+---+
|  0|  0|
|  5| 25|
| 10|100|
| 15|225|
| 20|400|
+---+---+
only showing top 5 rows



#### Verifying Partitions
Now, when we check our partitions, we see a well partitioned DataFrame that has the 5 partitions we desired,
with the query results evenly distributed among the partitions.

In [12]:
print("partitions: {np}\nsizes: {sz}".format( \
    np = query_df_pp.rdd.getNumPartitions(), \
    sz = query_df_pp.rdd.mapPartitions(lambda itr: [len(list(itr))]).collect() \
))

partitions: 5
sizes: [200, 200, 200, 200, 200]


#### Other Partitioning Techniques

In our previous examples we took advantage of having a column `x` in our data that was an integer with uniformly distributed values, which made it easy to generate equal-sized partitions.
In real data we may not have this kind of convenient column to partition with,
but there are a couple of techniques that we can use with any data.

The first technique is hashing.
For SQL dialects that support a hashing function, you can pick a column (or columns) to apply a hash to,
and then take the modulus of the resulting hash value.
A hypothetical example of such predicates might look like this:

```sql
mod(vendor_hash(my_column), 3) = 0
mod(vendor_hash(my_column), 3) = 1
mod(vendor_hash(my_column), 3) = 2
```

If you use this technique, you may need to tweak your `qpreds` function to generate predicates of this form.
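
One way such a tweak might look is sketched below.
Here `qpreds_hash` is a hypothetical variant of `qpreds`, and `vendor_hash` remains a placeholder for whatever hash function your SQL dialect provides; the generated predicates have the form `mod(vendor_hash(column), n) = k`:

```python
def qpreds_hash(n, col, hashfn='vendor_hash'):
    """Generate n predicates of the form mod(hashfn(col), n) = k, for k in 0..n-1."""
    # 'vendor_hash' is a placeholder name, not a real SQL function
    return ["mod({hf}({c}), {np}) = {mk}".format(hf=hashfn, c=col, np=n, mk=k)
            for k in range(n)]

preds = qpreds_hash(3, 'my_column')
for p in preds:
    print(p)
```

Be aware that some hash functions return negative values, and in some dialects `mod` of a negative number is itself negative.
Rows with a negative result would match none of these predicates and would be dropped from the DataFrame, so it is worth checking how your database behaves before relying on this form.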

Not all SQL dialects have this kind of hash function, but there is almost always some way to assign
a unique integer to each query output row.
In postgresql, this is the `row_number()` window function, and we can add it to the `select` list of our query,
as in the following example.
In postgresql, it must be followed by an `over` clause that tells it what ordering you wish the numbering to use.
Because the expression is not given an alias, postgresql names the new column `row_number`, and that is the name we pass to `qpreds`:

In [13]:
# The additional row_number clause is not necessary if you are partitioning via 
# an existing integer field, or hashing, etc.
query = 'select *, row_number() over (order by x) from demo'

query_df_pp2 = spark.read.jdbc( \
    table      = '({q}) tmp'.format(q=query), \
    url        = spark_jdbc_url, \
    properties = spark_jdbc_prop, \
    predicates = qpreds(5, 'row_number') \
)
query_df_pp2.show(5)

+---+---+----------+
|  x|xsq|row_number|
+---+---+----------+
|  4| 16|         5|
|  9| 81|        10|
| 14|196|        15|
| 19|361|        20|
| 24|576|        25|
+---+---+----------+
only showing top 5 rows



#### Row Numbering Adds a Column
You can see from the output above that the row-numbering technique adds a `row_number` column to your
query results.
You may want to drop this column, for example with `query_df_pp2.drop('row_number')`, if you are generating output to some other channel.

Lastly, we check our partitioning to see that it worked correctly:

In [14]:
print("partitions: {np}\nsizes: {sz}".format( \
    np = query_df_pp2.rdd.getNumPartitions(), \
    sz = query_df_pp2.rdd.mapPartitions(lambda itr: [len(list(itr))]).collect() \
))

partitions: 5
sizes: [200, 200, 200, 200, 200]
