# Estimating the gravity equation


### Table of Contents

* Initialize and configure Spark
* Load and prepare data
    * Distances
    * Geography
    * GDP
    * Trade flows
* Combine everything into one dataset
* Estimate gravity equation

In [1]:
# Initialization and configuration
import findspark
findspark.init()

from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession

# Spark
from pyspark.sql.types import *
import pyspark.sql.functions as F

## Initialize and configure Spark

http://localhost:4040/

In [2]:
# Configuration
conf = SparkConf().setAppName("Gravity").setMaster("local[4]")
conf.set("spark.driver.maxResultSize", "2g")
conf.set("spark.driver.memory", "2g")
conf.set("spark.executor.memory", "2g") 
conf.set("spark.executor.pyspark.memory", "2g")

# Initialization
sc = SparkContext(conf=conf)
sqlc = SQLContext(sc)
spark = SparkSession(sc)

## Load and prepare data

In [3]:
# Function for renaming several columns at once
# from https://stackoverflow.com/questions/38798567/pyspark-rename-more-than-one-column-using-withcolumnrenamed/48095327
def rename_columns(df, columns):
    for old_name, new_name in columns.items():
        df = df.withColumnRenamed(old_name, new_name)
    return df

### Flows

In [113]:
flows = spark.read.parquet("data/parquet/full2018.parquet")

flows = flows.withColumn("YEAR", flows["PERIOD"].substr(1, 4).cast(IntegerType()))
flows = flows \
    .select("REPORTER_ISO", "PARTNER_ISO", "PRODUCT_NC", "FLOW", "YEAR", "VALUE_IN_EUROS") \
    .filter(flows["YEAR"] == 2018) \
    .filter(flows["PRODUCT_NC"] != "TOTAL") \
    .groupBy("REPORTER_ISO", "PARTNER_ISO", "FLOW", "YEAR").agg({"VALUE_IN_EUROS": "sum"}) \
    .groupBy("REPORTER_ISO", "PARTNER_ISO", "YEAR").pivot("FLOW", [1, 2]).sum("sum(VALUE_IN_EUROS)") \
    .withColumnRenamed("1", "IMPORTS").withColumnRenamed("2", "EXPORTS")

flows = flows \
    .fillna({"IMPORTS": 0, "EXPORTS": 0})
flows = flows \
    .withColumn("TRADE", flows["IMPORTS"] + flows["EXPORTS"]) \
    .cache()

In [115]:
flows.printSchema()

root
 |-- REPORTER_ISO: string (nullable = true)
 |-- PARTNER_ISO: string (nullable = true)
 |-- YEAR: integer (nullable = true)
 |-- IMPORTS: long (nullable = false)
 |-- EXPORTS: long (nullable = false)
 |-- TRADE: long (nullable = false)



In [118]:
flows.count()

6274

In [116]:
flows.show(5)

+------------+-----------+----+----------+----------+----------+
|REPORTER_ISO|PARTNER_ISO|YEAR|   IMPORTS|   EXPORTS|     TRADE|
+------------+-----------+----+----------+----------+----------+
|          NL|         QW|2018|3185020237|1535657158|4720677395|
|          PT|         XK|2018|      3839|   1164953|   1168792|
|          ES|         CR|2018| 207058853| 212704580| 419763433|
|          IT|         BH|2018| 282682377| 267854643| 550537020|
|          GR|         DJ|2018|      8321|    889926|    898247|
+------------+-----------+----+----------+----------+----------+
only showing top 5 rows



### Distances

http://www.cepii.fr/CEPII/en/bdd_modele/presentation.asp?id=6


* meaning of other distance measures? distw, distwces?
* meaning of comcol, curcol, col45, smctry, etc?
* common currency?

* iso_o: 3d iso code of "origin"
* iso_d: 3d iso code of "destination"
* contig: contiguous border
* comlang_off: common official language
* dist: distance

In [119]:
distances = spark.read.csv("data/dist_cepii.csv", sep=";", header=True, inferSchema=True) \
    .select("iso_o", "iso_d", "contig", "comlang_off", "dist")

In [120]:
distances.printSchema()

root
 |-- iso_o: string (nullable = true)
 |-- iso_d: string (nullable = true)
 |-- contig: integer (nullable = true)
 |-- comlang_off: integer (nullable = true)
 |-- dist: double (nullable = true)



In [121]:
distances.count()

50176

In [122]:
distances.show(5)

+-----+-----+------+-----------+--------+
|iso_o|iso_d|contig|comlang_off|    dist|
+-----+-----+------+-----------+--------+
|  ABW|  ABW|     0|          0|5.225315|
|  ABW|  AFG|     0|          0|13257.81|
|  ABW|  AGO|     0|          0|9516.913|
|  ABW|  AIA|     0|          0|983.2682|
|  ABW|  ALB|     0|          0|9091.742|
+-----+-----+------+-----------+--------+
only showing top 5 rows



### Geographies

* http://www.cepii.fr/CEPII/en/bdd_modele/presentation.asp?id=6
* .filter(geographies["maincity"] == 1)   # removes rows where there is a capital and a main city


In [123]:
geographies = spark.read.csv("data/geo_cepii.csv", sep=";", header=True, inferSchema=True) \
    .select("iso2", "iso3", "country", "landlocked", "maincity")
    
geographies = geographies \
    .filter(geographies["maincity"] == 1) \
    .drop("maincity")

In [124]:
geographies.printSchema()

root
 |-- iso2: string (nullable = true)
 |-- iso3: string (nullable = true)
 |-- country: string (nullable = true)
 |-- landlocked: integer (nullable = true)



In [125]:
geographies.count()

225

In [126]:
geographies.show(5)

+----+----+-----------+----------+
|iso2|iso3|    country|landlocked|
+----+----+-----------+----------+
|  AW| ABW|      Aruba|         0|
|  AF| AFG|Afghanistan|         1|
|  AO| AGO|     Angola|         0|
|  AI| AIA|   Anguilla|         0|
|  AL| ALB|    Albania|         0|
+----+----+-----------+----------+
only showing top 5 rows



### GDP

* in constant 2010 us dollars
* https://databank.worldbank.org/reports.aspx?source=2&series=NY.GDP.MKTP.KD&country=#

In [127]:
gdp = spark.read.csv("data/worldbank_gdp.csv", sep=",", header=True, inferSchema=True) \
    .select("Country Name", "Country Code", "Time", "Value")

In [128]:
gdp.printSchema()

root
 |-- Country Name: string (nullable = true)
 |-- Country Code: string (nullable = true)
 |-- Time: integer (nullable = true)
 |-- Value: double (nullable = true)



In [129]:
gdp.count()

2645

In [130]:
gdp.show(5)

+------------+------------+----+-------------------+
|Country Name|Country Code|Time|              Value|
+------------+------------+----+-------------------+
| Afghanistan|         AFG|2009|1.38651943147185E10|
| Afghanistan|         AFG|2010|1.58565747314411E10|
| Afghanistan|         AFG|2011|1.59241799977914E10|
| Afghanistan|         AFG|2012|1.79548771466564E10|
| Afghanistan|         AFG|2013|1.89604839698941E10|
+------------+------------+----+-------------------+
only showing top 5 rows



### Combine dataframes

In [131]:
# Flows: REPORTER_ISO, PARTNER_ISO - 2d
# Distances: iso_o, iso_d - 3d
# gdp: Country Code - 3d
# geographies: iso2, iso3

In [132]:
# flows
df = flows \
    .join(F.broadcast(geographies), flows["REPORTER_ISO"] == geographies["iso2"], how="inner") \
    .withColumnRenamed("iso3", "REPORTER_ISO_3") \
    .drop("iso2", "country", "landlocked") \
    .join(F.broadcast(geographies), flows["PARTNER_ISO"] == geographies["iso2"], how="inner") \
    .withColumnRenamed("iso3", "PARTNER_ISO_3") \
    .drop("iso2", "country", "landlocked") \

# gdp
df = df \
    .join(F.broadcast(gdp), (df["REPORTER_ISO_3"] == gdp["Country Code"]) & (df["YEAR"] == gdp["Time"]), how="inner") \
    .drop("Country Name", "Country Code", "Time") \
    .withColumnRenamed("Value", "REPORTER_GDP") \
    .join(F.broadcast(gdp), (df["PARTNER_ISO_3"] == gdp["Country Code"]) & (df["YEAR"] == gdp["Time"]), how="inner") \
    .drop("Country Name", "Country Code", "Time") \
    .withColumnRenamed("Value", "PARTNER_GDP")

# distance
df = df \
    .join(F.broadcast(distances), (df["REPORTER_ISO_3"] == distances["iso_o"]) & (df["PARTNER_ISO_3"] == distances["iso_d"]), how="inner") \
    .drop("iso_o", "iso_d") \
    .withColumnRenamed("dist", "DISTANCE")

In [133]:
df.show(5)

+------------+-----------+----+----------+----------+-----------+--------------+-------------+-------------------+-------------------+------+-----------+--------+
|REPORTER_ISO|PARTNER_ISO|YEAR|   IMPORTS|   EXPORTS|      TRADE|REPORTER_ISO_3|PARTNER_ISO_3|       REPORTER_GDP|        PARTNER_GDP|contig|comlang_off|DISTANCE|
+------------+-----------+----+----------+----------+-----------+--------------+-------------+-------------------+-------------------+------+-----------+--------+
|          ES|         CR|2018| 207058853| 212704580|  419763433|           ESP|          CRI|1.54872368448831E12| 4.9457675613131E10|     0|          1|8483.913|
|          IT|         BH|2018| 282682377| 267854643|  550537020|           ITA|          BHR|2.13876616980778E12|3.36465609433779E10|     0|          0|3883.804|
|          GR|         DJ|2018|      8321|    889926|     898247|           GRC|          DJI|2.52723303555798E11|               null|     0|          0|3516.045|
|          SE|        

In [134]:
df.printSchema()

root
 |-- REPORTER_ISO: string (nullable = true)
 |-- PARTNER_ISO: string (nullable = true)
 |-- YEAR: integer (nullable = true)
 |-- IMPORTS: long (nullable = false)
 |-- EXPORTS: long (nullable = false)
 |-- TRADE: long (nullable = false)
 |-- REPORTER_ISO_3: string (nullable = true)
 |-- PARTNER_ISO_3: string (nullable = true)
 |-- REPORTER_GDP: double (nullable = true)
 |-- PARTNER_GDP: double (nullable = true)
 |-- contig: integer (nullable = true)
 |-- comlang_off: integer (nullable = true)
 |-- DISTANCE: double (nullable = true)



In [135]:
df.count()

5162