# Estimating the gravity equation


### Table of Contents

* Initialize and configure Spark
* Load and prepare data
    * Distances
    * Geography
    * GDP
    * Trade flows
* Combine everything into one dataset
* Estimate gravity equation

In [1]:
# Initialization and configuration
import findspark
findspark.init()

from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession

# Spark
from pyspark.sql.types import *
import pyspark.sql.functions as F

## Initialize and configure Spark

http://localhost:4040/

In [2]:
# Configuration
conf = SparkConf().setAppName("Gravity").setMaster("local[4]")
conf.set("spark.driver.maxResultSize", "2g")
conf.set("spark.driver.memory", "2g")
conf.set("spark.executor.memory", "2g") 
conf.set("spark.executor.pyspark.memory", "2g")

# Initialization
sc = SparkContext(conf=conf)
sqlc = SQLContext(sc)
spark = SparkSession(sc)

## Load and prepare data

In [3]:
# Function for renaming several columns at once
# from https://stackoverflow.com/questions/38798567/pyspark-rename-more-than-one-column-using-withcolumnrenamed/48095327
def rename_columns(df, columns):
    for old_name, new_name in columns.items():
        df = df.withColumnRenamed(old_name, new_name)
    return df

### Distances

http://www.cepii.fr/CEPII/en/bdd_modele/presentation.asp?id=6

In [25]:
distances = spark.read.csv("data/dist_cepii.csv", sep=";", header=True) \
    .select("iso_o", "iso_d", "contig", "comlang_off", "dist")

# meaning of other distance measures? distw, distwces?
# meaning of comcol, curcol, col45, smctry, etc?
# common currency?
# add common colonizer

# iso_o: 3d iso code of "origin"
# iso_d: 3d iso code of "destination"
# contig: contiguous border
# comlang_off: common official language
# dist: distance

In [27]:
distances.count()

50176

In [26]:
distances.show(10)

+-----+-----+------+-----------+--------+
|iso_o|iso_d|contig|comlang_off|    dist|
+-----+-----+------+-----------+--------+
|  ABW|  ABW|     0|          0|5,225315|
|  ABW|  AFG|     0|          0|13257,81|
|  ABW|  AGO|     0|          0|9516,913|
|  ABW|  AIA|     0|          0|983,2682|
|  ABW|  ALB|     0|          0|9091,742|
|  ABW|  AND|     0|          1|7572,788|
|  ABW|  ANT|     0|          1|136,3848|
|  ABW|  ARE|     0|          0|12735,01|
|  ABW|  ARG|     0|          1| 5396,22|
|  ABW|  ARM|     0|          0|11107,78|
+-----+-----+------+-----------+--------+
only showing top 10 rows



### Geographies

* http://www.cepii.fr/CEPII/en/bdd_modele/presentation.asp?id=6
* .filter(geographies["maincity"] == 1)   # removes rows where there is a capital and a main city


In [28]:
geographies = spark.read.csv("data/geo_cepii.csv", sep=";", header=True) \
    .select("iso2", "iso3", "country", "landlocked", "continent", "maincity")
    
geographies = geographies \
    .filter(geographies["maincity"] == 1) \
    .drop("maincity")

In [29]:
geographies.count()

225

In [31]:
geographies.select("iso2").distinct().count()

225

In [32]:
geographies.show(10)

+----+----+--------------------+----------+---------+
|iso2|iso3|             country|landlocked|continent|
+----+----+--------------------+----------+---------+
|  AW| ABW|               Aruba|         0|  America|
|  AF| AFG|         Afghanistan|         1|     Asia|
|  AO| AGO|              Angola|         0|   Africa|
|  AI| AIA|            Anguilla|         0|  America|
|  AL| ALB|             Albania|         0|   Europe|
|  AD| AND|             Andorra|         0|   Europe|
|  AN| ANT| Netherland Antilles|         0|  America|
|  AE| ARE|United Arab Emirates|         0|     Asia|
|  AR| ARG|           Argentina|         0|  America|
|  AM| ARM|             Armenia|         1|     Asia|
+----+----+--------------------+----------+---------+
only showing top 10 rows



In [33]:
# map iso2 and iso3 codes
iso_map = spark.read.csv("data/geo_cepii.csv", sep=";", header=True) \
    .select("iso2", "iso3").distinct()

In [34]:
iso_map.count()

225

In [35]:
iso_map.show(20)

+----+----+
|iso2|iso3|
+----+----+
|  ZA| ZAF|
|  AE| ARE|
|  JM| JAM|
|  OM| OMN|
|  LB| LBN|
|  BB| BRB|
|  CR| CRI|
|  US| USA|
|  KG| KGZ|
|  SO| SOM|
|  ZM| ZMB|
|  AR| ARG|
|  KI| KIR|
|  AL| ALB|
|  TJ| TJK|
|  GF| GUF|
|  MM| MMR|
|  MX| MEX|
|  SD| SDN|
|  LA| LAO|
+----+----+
only showing top 20 rows



### GDP

in constant 2010 us dollars

https://databank.worldbank.org/reports.aspx?source=2&series=NY.GDP.MKTP.KD&country=#

In [38]:
gdp = spark.read.csv("data/worldbank_gdp.csv", sep=",", header=True) \
    .select("Country Name", "Country Code", "Time", "Value")

# Define schema Convert: Time to int and Value to float

In [39]:
gdp.printSchema()

root
 |-- Country Name: string (nullable = true)
 |-- Country Code: string (nullable = true)
 |-- Time: string (nullable = true)
 |-- Value: string (nullable = true)



In [40]:
gdp.count()

2645

In [41]:
gdp.show(10)

+------------+------------+----+----------------+
|Country Name|Country Code|Time|           Value|
+------------+------------+----+----------------+
| Afghanistan|         AFG|2009|13865194314.7185|
| Afghanistan|         AFG|2010|15856574731.4411|
| Afghanistan|         AFG|2011|15924179997.7914|
| Afghanistan|         AFG|2012|17954877146.6564|
| Afghanistan|         AFG|2013|18960483969.8941|
| Afghanistan|         AFG|2014|19477070577.8584|
| Afghanistan|         AFG|2015|19759744157.4377|
| Afghanistan|         AFG|2016|20206376461.4103|
| Afghanistan|         AFG|2017|20744935406.0955|
| Afghanistan|         AFG|2018| 20958745169.393|
+------------+------------+----+----------------+
only showing top 10 rows



### Trade flows

In [None]:
# see below, move up

## Combine all dataframes

### combine gdp and geographies

In [58]:
gdp_geo = gdp.select("COUNTRY_ISO3", "YEAR", "GDP") \
    .join(geographies, "COUNTRY_ISO3", how="inner")

gdp_geo.show(10)

+------------+----+----------------+-----------+------------+----------+---------+-------+--------+----+--------+
|COUNTRY_ISO3|YEAR|             GDP|COUNTRY_ISO|COUNTRY_NAME|landlocked|continent|city_en|     lat| lon|maincity|
+------------+----+----------------+-----------+------------+----------+---------+-------+--------+----+--------+
|         AFG|2009|13865194314.7185|         AF| Afghanistan|         1|     Asia|  Kabul|34,51667|69,2|       1|
|         AFG|2010|15856574731.4411|         AF| Afghanistan|         1|     Asia|  Kabul|34,51667|69,2|       1|
|         AFG|2011|15924179997.7914|         AF| Afghanistan|         1|     Asia|  Kabul|34,51667|69,2|       1|
|         AFG|2012|17954877146.6564|         AF| Afghanistan|         1|     Asia|  Kabul|34,51667|69,2|       1|
|         AFG|2013|18960483969.8941|         AF| Afghanistan|         1|     Asia|  Kabul|34,51667|69,2|       1|
|         AFG|2014|19477070577.8584|         AF| Afghanistan|         1|     Asia|  Kabu

In [59]:
gdp_geo.count()

1990

### Trade flows (imports + exports)

In [None]:
# need to join with above table bc iso and iso3 difference

In [60]:
flows = spark.read.parquet("data/parquet/full2018.parquet")
flows.printSchema()

root
 |-- REPORTER: integer (nullable = true)
 |-- REPORTER_ISO: string (nullable = true)
 |-- PARTNER: integer (nullable = true)
 |-- PARTNER_ISO: string (nullable = true)
 |-- TRADE_TYPE: string (nullable = true)
 |-- PRODUCT_NC: string (nullable = true)
 |-- PRODUCT_SITC: string (nullable = true)
 |-- PRODUCT_CPA2002: string (nullable = true)
 |-- PRODUCT_CPA2008: string (nullable = true)
 |-- PRODUCT_CPA2_1: string (nullable = true)
 |-- PRODUCT_BEC: string (nullable = true)
 |-- PRODUCT_SECTION: string (nullable = true)
 |-- FLOW: integer (nullable = true)
 |-- STAT_REGIME: integer (nullable = true)
 |-- SUPP_UNIT: string (nullable = true)
 |-- PERIOD: string (nullable = true)
 |-- VALUE_IN_EUROS: long (nullable = true)
 |-- QUANTITY_IN_KG: long (nullable = true)
 |-- SUP_QUANTITY: integer (nullable = true)



In [154]:
# 1 aggregate trade flows properly

flows = spark.read.parquet("data/parquet/full2018.parquet")

flows = flows \
    .withColumn("YEAR", flows["PERIOD"].substr(1, 4).cast(IntegerType())) \
    .select("REPORTER", "REPORTER_ISO", "PARTNER", "PARTNER_ISO", "TRADE_TYPE", "PRODUCT_NC", "FLOW", "YEAR", "VALUE_IN_EUROS") \
    .filter(flows["PRODUCT_NC"] != "TOTAL") \
    .groupBy("REPORTER", "REPORTER_ISO", "PARTNER", "PARTNER_ISO", "TRADE_TYPE", "FLOW", "YEAR").agg({"VALUE_IN_EUROS": "sum"}) \
    .groupBy("REPORTER", "REPORTER_ISO", "PARTNER", "PARTNER_ISO", "TRADE_TYPE", "YEAR").pivot("FLOW", [1, 2]).sum("sum(VALUE_IN_EUROS)") \
    .fillna(0, subset=["1", "2"])

flows = flows.withColumn("TOTAL_TRADE", flows["1"] + flows["2"]) \
    .withColumnRenamed("1", "IMPORTS").withColumnRenamed("2", "EXPORTS") \
    .select("REPORTER_ISO", "PARTNER_ISO", "YEAR", "TOTAL_TRADE") \
    .cache()

flows = flows.join(broadcast(iso_map), flows["REPORTER_ISO"] == iso_map["ISO2"], how="inner") \
    .withColumnRenamed("ISO3", "REPORTER_ISO3") \
    .select("REPORTER_ISO", "REPORTER_ISO3", "PARTNER_ISO", "TOTAL_TRADE")

flows = flows.join(broadcast(iso_map), flows["PARTNER_ISO"] == iso_map["ISO2"], how="inner") \
    .withColumnRenamed("ISO3", "PARTNER_ISO3") 

flows = flows.select("REPORTER_ISO", "REPORTER_ISO3", "PARTNER_ISO", "PARTNER_ISO3", "TOTAL_TRADE") \
    .cache()    

In [155]:
flows.printSchema()

root
 |-- REPORTER_ISO: string (nullable = true)
 |-- REPORTER_ISO3: string (nullable = true)
 |-- PARTNER_ISO: string (nullable = true)
 |-- PARTNER_ISO3: string (nullable = true)
 |-- TOTAL_TRADE: long (nullable = true)



In [156]:
flows.count()

5654

In [159]:
flows.select("REPORTER_ISO", "PARTNER_ISO").distinct().count()

5654

In [157]:
flows.show()

+------------+-------------+-----------+------------+------------+
|REPORTER_ISO|REPORTER_ISO3|PARTNER_ISO|PARTNER_ISO3| TOTAL_TRADE|
+------------+-------------+-----------+------------+------------+
|          LV|          LVA|         AU|         AUS|    26891032|
|          RO|          ROM|         AI|         AIA|          37|
|          LT|          LTU|         CY|         CYP|    46406342|
|          LV|          LVA|         BF|         BFA|      776438|
|          BE|          BEL|         AD|         AND|    12203654|
|          CZ|          CZE|         DJ|         DJI|     1173300|
|          GR|          GRC|         MV|         MDV|      543677|
|          FI|          FIN|         QA|         QAT|    58097374|
|          MT|          MLT|         HU|         HUN|    26726761|
|          SE|          SWE|         DE|         DEU| 40726945397|
|          PL|          POL|         BD|         BGD|   688842891|
|          ES|          ESP|         VE|         VEN|   444385

In [None]:
#    .join(broadcast(iso_map), flows["PARTNER_ISO"] == iso_map["ISO2"], how="inner") \
#    .withColumnRenamed("ISO3", "PARTNER_ISO3") \


In [134]:
flows.show(20)

+------------+-------------+-----------+-----------+
|REPORTER_ISO|REPORTER_ISO3|PARTNER_ISO|TOTAL_TRADE|
+------------+-------------+-----------+-----------+
|          LV|          LVA|         AU|   26891032|
|          RO|          ROM|         AI|         37|
|          LT|          LTU|         CY|   46406342|
|          LV|          LVA|         BF|     776438|
|          GB|          GBR|         HM|      39927|
|          BE|          BEL|         AD|   12203654|
|          CZ|          CZE|         DJ|    1173300|
|          GR|          GRC|         MV|     543677|
|          FI|          FIN|         QA|   58097374|
|          MT|          MLT|         HU|   26726761|
|          SE|          SWE|         DE|40726945397|
|          PL|          POL|         BD|  688842891|
|          ES|          ESP|         VE|  444385782|
|          PL|          POL|         QZ|   23727166|
|          DE|          DEU|         XS| 4484932716|
|          DE|          DEU|         XS| 44849

In [123]:
iso_map.show(20)

+----+----+
|ISO2|ISO3|
+----+----+
|  AW| ABW|
|  AF| AFG|
|  AO| AGO|
|  AI| AIA|
|  AL| ALB|
|  AD| AND|
|  AN| ANT|
|  AE| ARE|
|  AR| ARG|
|  AM| ARM|
|  TF| ATF|
|  AG| ATG|
|  AU| AUS|
|  AU| AUS|
|  AT| AUT|
|  AZ| AZE|
|  BI| BDI|
|  BE| BEL|
|  BJ| BEN|
|  BJ| BEN|
+----+----+
only showing top 20 rows



In [114]:
flows.join(broadcast(gdp_geo), 
           (flows["REPORTER_ISO"] == gdp_geo["COUNTRY_ISO"]) & (flows["YEAR"] == gdp_geo["YEAR"])) \
    .select("REPORTER_ISO", "PARTNER_ISO", flows["YEAR"], "TOTAL_TRADE", "GDP", "landlocked") \
    .show(20)

# how to model landlocked? 1 if any country is landlocked? yes - need to model that separately

+------------+-----------+----+-----------+----------------+----------+
|REPORTER_ISO|PARTNER_ISO|YEAR|TOTAL_TRADE|             GDP|landlocked|
+------------+-----------+----+-----------+----------------+----------+
|          LV|         AU|2018|   26891032|31606334082.6379|         0|
|          LT|         CY|2018|   46406342|49290096837.1762|         0|
|          LV|         BF|2018|     776438|31606334082.6379|         0|
|          GB|         HM|2018|      39927| 2858097953047.3|         0|
|          BE|         AD|2018|   12203654|533218198385.386|         0|
|          CZ|         DJ|2018|    1173300|248204259552.577|         1|
|          GR|         MV|2018|     543677|252723303555.798|         0|
|          FI|         QA|2018|   58097374|267320507829.469|         0|
|          MT|         HU|2018|   26726761|13725030463.5762|         0|
|          SE|         DE|2018|40726945397| 582804342166.04|         0|
|          PL|         BD|2018|  688842891|631952707856.598|    

In [107]:
gdp_geo.select("COUNTRY_ISO", "COUNTRY_ISO3", "GDP", "landlocked").show(10)

+-----------+------------+----------------+----------+
|COUNTRY_ISO|COUNTRY_ISO3|             GDP|landlocked|
+-----------+------------+----------------+----------+
|         AF|         AFG|13865194314.7185|         1|
|         AF|         AFG|15856574731.4411|         1|
|         AF|         AFG|15924179997.7914|         1|
|         AF|         AFG|17954877146.6564|         1|
|         AF|         AFG|18960483969.8941|         1|
|         AF|         AFG|19477070577.8584|         1|
|         AF|         AFG|19759744157.4377|         1|
|         AF|         AFG|20206376461.4103|         1|
|         AF|         AFG|20744935406.0955|         1|
|         AF|         AFG| 20958745169.393|         1|
+-----------+------------+----------------+----------+
only showing top 10 rows

