# Traveler DB
### Data Engineering Capstone Project

#### Project Summary
This project takes i94 data and combines it with city demographic data to populate a data mart on an AWS Redshift instance.

The project follows these steps:
* Step 1: Scope the Project and Gather Data
* Step 2: Explore and Assess the Data
* Step 3: Define the Data Model
* Step 4: Run ETL to Model the Data
* Step 5: Complete Project Write Up

In [1]:
# Do all imports and installs here
import pandas as pd
import pyspark.sql.functions as sfunc
from pyspark.sql.functions import udf, col, date_add, lit, to_date, when, sum as sql_sum, monotonically_increasing_id, upper, concat, substring
from pyspark.sql.functions import sequence, dayofmonth, weekofyear, month, year, date_format
from pyspark.sql import Row
import boto3
import psycopg2
import os
import shutil
import configparser



In [3]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.\
config("spark.jars.packages","saurfang:spark-sas7bdat:2.0.0-s_2.11")\
.enableHiveSupport().getOrCreate()



### Step 1: Scope the Project and Gather Data

#### Scope 
This project is taken from the perspective of a hotel company that is looking to gain insight into what kind of international travelers visit the cities and states they do business in so that they can better market their hotel to those visitors.  The i94 data used in conjunction with city demographics should provide insights into how to best market to travelers in existing markets for the hotel company, as well as to help generate a business strategy for expansion into new cities / markets.

#### Describe and Gather Data 
In this project, the i94 data is used in conjunction with the us-cities-demographics.csv file as well as various csv files cobbled together from the contents of the I94_SAS_Lables_Descriptions.SAS file.

These derived csv files include the following:
 * addr_codes.csv
 * countries.csv
 * ports.csv

### Step 2: Explore and Assess the Data
#### Explore the Data 
Identify data quality issues, like missing values, duplicate data, etc.
 * The demographic data contains one row per city for each distinct value in the column "Race".  This will be flattened to one row per city, with the "Count" column summed into a separate population column for each race value.
 * There are many missing values for addr_code and similar columns in the i94 data. These will be addressed by:
  * Assigning dead-end foreign key values within the fact
  * Supplying associated key values within respective dimensions
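
The dead-end key idea can be sketched in plain Python, independent of Spark. The dimension carries one extra row keyed -999, and any fact value that fails the lookup falls back to that key. The sample codes and the `resolve_key` helper are hypothetical; only the -999 / 'ZZ' convention comes from this notebook.

```python
# Hypothetical miniature of state_terr_dim: code -> surrogate key.
UNKNOWN_KEY = -999  # dead-end key used throughout this notebook

state_terr_dim = {
    "NY": 1,
    "CA": 2,
    "ZZ": UNKNOWN_KEY,  # the row that satisfies otherwise-null references
}

def resolve_key(code, dimension, default=UNKNOWN_KEY):
    """Return the surrogate key for code, or the dead-end key when code is missing or unmatched."""
    if code is None:
        return default
    return dimension.get(code, default)

fact_addr_codes = ["NY", None, "XX", "CA"]  # illustrative i94addr values
fact_keys = [resolve_key(code, state_terr_dim) for code in fact_addr_codes]
print(fact_keys)  # [1, -999, -999, 2]
```

Every fact row then joins to exactly one dimension row, so inner joins against the dimension do not silently drop travelers with missing codes.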

#### Cleaning Steps
Document steps necessary to clean the data

#### Clean countries data

In [4]:
#gather countries
df_nation = spark.read.load("countries.csv",
                     format="csv", sep=",", inferSchema="true", header="true")
# add increasing id
df_nation = df_nation.withColumn('idx', monotonically_increasing_id())

# set increasing id to actually be 1-increments
# note: order by the idx column itself; quoting it as 'idx' would order by a constant string
df_nation = df_nation.selectExpr("row_number() over (order by idx) as nation_key",
                                 "COUNTRY_CODE as nation_code", "COUNTRY_DESCR as nation_descr").dropDuplicates()

# generate row to satisfy null fk references in fact
nullrow_nation = df_nation.selectExpr("-999 as nation_key",
                                 "-999 as nation_code", "'UNKNOWN' as nation_descr").dropDuplicates()

# union df and null row to prepare df for parquet write
df_nation = df_nation.union(nullrow_nation)

#### Clean addr_codes

In [6]:
# gather states & territories
df_state_terr = spark.read.load("addr_codes.csv",
                               format="csv", sep=",", inferSchema="true", header="true")

# add increasing id
df_state_terr = df_state_terr.withColumn('idx', monotonically_increasing_id())

# set increasing id to actually be 1-increments
# note: order by the idx column itself; quoting it as 'idx' would order by a constant string
df_state_terr = df_state_terr.selectExpr("row_number() over (order by idx) as state_terr_key",
                                        "addr_code as state_terr_code", "addr_descr as state_terr_descr").dropDuplicates()

# generate row to satisfy null fk references in fact
nullrow_state_terr = df_state_terr.selectExpr("-999 as state_terr_key", "'ZZ' as state_terr_code", "'UNKNOWN' as state_terr_descr").dropDuplicates()

# union df and null row to prepare df for parquet write
df_state_terr = df_state_terr.union(nullrow_state_terr)

#### Clean city demographic data

In [7]:
# gather raw city demographic data
df_city_demo = spark.read.load("us-cities-demographics.csv",
                     format="csv", sep=";", inferSchema="true", header="true")

# rename the count column to something less "reserved"
df_city_demo = df_city_demo.withColumnRenamed('Count','pop_count')

# add columns for the demographic race/ethnicity elements
df_city_demo = df_city_demo.withColumn('pop_african_american', when(df_city_demo.Race=="Black or African-American", df_city_demo.pop_count).otherwise(0)) \
    .withColumn('pop_hispanic_latino', when(df_city_demo.Race=="Hispanic or Latino", df_city_demo.pop_count).otherwise(0)) \
    .withColumn('pop_native_american', when(df_city_demo.Race=="American Indian and Alaska Native", df_city_demo.pop_count).otherwise(0)) \
    .withColumn('pop_asian', when(df_city_demo.Race=="Asian", df_city_demo.pop_count).otherwise(0)) \
    .withColumn('pop_white', when(df_city_demo.Race=="White", df_city_demo.pop_count).otherwise(0))


# sum the populations of the various race / ethnicity demographics by city
df_city_demo = df_city_demo.groupBy("City", "State", "Median Age", "Male Population", "Female Population", "Total Population", "Number of Veterans", "Foreign-born",
                           "Average Household Size", "State Code").agg(sql_sum("pop_african_american").alias("pop_african_american"),
                                                                       sql_sum("pop_hispanic_latino").alias("pop_hispanic_latino"),
                                                                       sql_sum("pop_native_american").alias("pop_native_american"),
                                                                       sql_sum("pop_asian").alias("pop_asian"), sql_sum("pop_white").alias("pop_white"))

# set up a mapping dictionary to rename columns
mapping = dict(zip(['City', 'State', 'State Code', 'Median Age', 'Average Household Size', 'Total Population', 'Male Population', 'Female Population', 
                    'Number of Veterans', 'Foreign-born', 'pop_african_american', 'pop_hispanic_latino', 'pop_native_american', 'pop_asian', 'pop_white'], 
                   ['city', 'state', 'state_code', 'median_age', 'avg_household_size', 'pop_total', 'pop_male', 'pop_female', 
                    'pop_veteran', 'pop_foreign_born', 'pop_african_american', 'pop_hispanic_latino', 'pop_native_american', 'pop_asian', 'pop_white']))

# rename mapped columns
df_city_demo = df_city_demo.select([col(c).alias(mapping.get(c, c)) for c in df_city_demo.columns])

# inner join to df_state_terr to only get cities of us states / territories
df_city_demo = df_city_demo.join(df_state_terr, (df_city_demo.state_code == df_state_terr.state_terr_code), 'inner')


# add increasing id
df_city_demo = df_city_demo.withColumn('idx', monotonically_increasing_id())

# set increasing id to actually be 1-increments
df_city_demo = df_city_demo.selectExpr("row_number() over (order by idx) as city_demo_key", "state_terr_key as city_state_terr_key",
                                        "city as city_name", "state as state_name", "state_terr_code as state_code", 
                                         "cast(median_age as decimal(5,2)) as median_age", "cast(avg_household_size as decimal(5,2)) as avg_household_size",
                                         "pop_total", "pop_male", "pop_female", "pop_foreign_born", "cast(pop_african_american as integer) as pop_african_american",
                                        "cast(pop_hispanic_latino as integer) as pop_hispanic_latino", "cast(pop_native_american as integer) as pop_native_american", 
                                       "cast(pop_asian as integer) as pop_asian", "cast(pop_white as integer) as pop_white")


# generate row to satisfy null fk references in fact
nullrow_city_demo = df_city_demo.selectExpr("-999 as city_demo_key", "-999 as city_state_terr_key",
                                                 "'UNKNOWN' as city_name", "'UNKNOWN' as state_name", "'ZZ' as state_code", 
                                         "NULL as median_age", "NULL as avg_household_size",
                                         "NULL as pop_total", "NULL as pop_male", "NULL as pop_female", "NULL as pop_foreign_born", "NULL as pop_african_american",
                                        "NULL as pop_hispanic_latino", "NULL as pop_native_american", 
                                       "NULL as pop_asian", "NULL as pop_white").dropDuplicates()

# union df and null row to prepare df for parquet write
df_city_demo = df_city_demo.union(nullrow_city_demo)

#### Clean port data

In [8]:
# gather ports
df_port = spark.read.load("ports.csv",
                               format="csv", sep=",", inferSchema="true", header="true")

df_port = df_port.withColumn('port_state_terr', substring(df_port.PORT_DESCR, -2, 2))


# left join to df_state_terr to get ports of us states / territories (both on us_state_terr and the substring of the name due to false positive matches)
df_port = df_port.join(df_state_terr, (df_port.port_state_terr == df_state_terr.state_terr_code) & (df_port.PORT_US_STATE_TERR == df_state_terr.state_terr_code), 'left')

# generate port_state_terr_key column based on the matched state_terr_key, or -999 when there is no match
df_port = df_port.withColumn('port_state_terr_key', when(df_port.state_terr_key.isNull(), -999).otherwise(df_port.state_terr_key))

# left join to df_city_demo to get cities with demographic information where applicable
df_port = df_port.join(df_city_demo, (df_port.PORT_DESCR == concat(upper(df_city_demo.city_name), lit(', '), upper(df_city_demo.state_code))), 'left_outer')

# generate port_city_demo_key column based on the matched city_demo_key, or -999 when there is no match
df_port = df_port.withColumn('port_city_demo_key', when(df_port.city_demo_key.isNull(), -999).otherwise(df_port.city_demo_key))

# add increasing id
df_port = df_port.withColumn('idx', monotonically_increasing_id())

# set increasing id to actually be 1-increments
df_port = df_port.selectExpr("row_number() over (order by idx) as port_key", "port_code as port_code", "port_descr as port_descr",
                                        "state_terr_code as port_state_terr", "port_state_terr_key", "port_city_demo_key").dropDuplicates()

# generate row to satisfy null fk references in fact
nullrow_port = df_port.selectExpr("-999 as port_key", "'ZZZ' as port_code", "'UNKNOWN' as port_descr", 
                                  "'ZZ' as port_state_terr", "-999 as port_state_terr_key", "-999 as port_city_demo_key").dropDuplicates()

# union df and null row to prepare df for parquet write
df_port = df_port.union(nullrow_port)

#### Generate date dimension data

In [9]:
def genrows(i):
    """ generates i rows numbered 0 to i-1 in the column date_num """
    # spark.range builds all rows in one step instead of chaining i single-row unions
    return spark.range(i).selectExpr("cast(id as int) as date_num")

df_datenums = genrows(366) # here we generate 366 rows, one for each day in the year 2016


df_bdate = spark.createDataFrame([Row(date_text = '2016-01-01')])
df_bdate = df_bdate.withColumn('b_date', to_date(df_bdate.date_text))
df_bdate = df_bdate.drop('date_text')

df_date = df_datenums.crossJoin(df_bdate)
df_date = df_date.selectExpr("*", "date_add(b_date, date_num) as date_key")

df_date = df_date.drop('date_num')
df_date = df_date.drop('b_date')

df_date = df_date.withColumn('date_day', dayofmonth(df_date.date_key))\
    .withColumn('date_week', weekofyear(df_date.date_key))\
    .withColumn('date_month', month(df_date.date_key))\
    .withColumn('date_year', year(df_date.date_key))\
    .withColumn('date_weekday', date_format(df_date.date_key, 'u'))


### Step 3: Define the Data Model
#### 3.1 Conceptual Data Model
The data model is a modified star schema data warehouse model with an outrigger dimension for city demographics that does not tie to the fact table.

The fact table is an aggregation of travelers by the intersection of the foreign keys, the travel date, and the traveler characteristics. Most characteristic columns are quasi-boolean smallint values (0 or 1) so that they can be aggregated to count travelers with that characteristic by day.  The fact table uses arrival_date as its distkey. Some days will have more travelers than others, but the date should still spread rows reasonably evenly across the cluster, and most queries are expected to be time-based.
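
Because each fact row already carries a traveler_count for its group, one way to count travelers with a given characteristic is to weight the 0/1 flag by traveler_count before summing. The sketch below shows that arithmetic in plain Python on made-up rows; the values and the `count_by_day` helper are illustrative assumptions, not query output.

```python
from collections import defaultdict

# Made-up fact rows: (arrival_date, gender_male, gender_female, traveler_count)
rows = [
    ("2016-01-01", 1, 0, 40),
    ("2016-01-01", 0, 1, 35),
    ("2016-01-02", 1, 0, 10),
    ("2016-01-02", 0, 0, 5),   # gender not reported
]

def count_by_day(fact_rows, flag_index):
    """Sum flag * traveler_count per arrival_date for the flag at flag_index."""
    totals = defaultdict(int)
    for row in fact_rows:
        arrival_date, traveler_count = row[0], row[-1]
        totals[arrival_date] += row[flag_index] * traveler_count
    return dict(totals)

male_by_day = count_by_day(rows, 1)
female_by_day = count_by_day(rows, 2)
print(male_by_day)    # {'2016-01-01': 40, '2016-01-02': 10}
print(female_by_day)  # {'2016-01-01': 35, '2016-01-02': 0}
```

Summing the flag alone would count groups rather than travelers, which is why the weighting by traveler_count matters here.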

The dimensional tables are distributed to all nodes because they hold little data and add little storage overhead.  Their distribution to all nodes should help with query performance on joins.

This model was chosen based on a use case of a hotel company's interest in how to market to various types of international travelers and how that may vary by city and/or state.

<img src="travelerdb_diagram.PNG">

#### 3.2 Mapping Out Data Pipelines
List the steps necessary to pipeline the data into the chosen data model

### Step 4: Run Pipelines to Model the Data 
#### 4.1 Create the data model
Build the data pipelines to create the data model.

#### Create the destination tables

In [26]:
configredshift = configparser.ConfigParser()
configredshift.read_file(open('awscfg.cfg'))

rshost = configredshift.get('CLUSTER','HOST')
rsdbname = configredshift.get('CLUSTER','DB_NAME')
rsuser = configredshift.get('CLUSTER','DB_USER')
rspassword = configredshift.get('CLUSTER','DB_PASSWORD')
rsport = configredshift.get('CLUSTER','DB_PORT')

conn = psycopg2.connect(host=rshost, dbname=rsdbname, user=rsuser, password=rspassword, port=rsport)
cur = conn.cursor()

date_table_drop = "drop table if exists date_dim;"
nation_table_drop = "drop table if exists nation_dim;"
state_terr_table_drop = "drop table if exists state_terr_dim;"
port_table_drop = "drop table if exists port_dim;"
city_demo_table_drop = "drop table if exists city_demo_dim;"
traveler_table_drop = "drop table if exists traveler_fact;"

drop_table_queries = [traveler_table_drop, date_table_drop, port_table_drop, city_demo_table_drop, state_terr_table_drop, nation_table_drop]

date_table_create = """create table date_dim(date_key date, date_day integer, date_week integer, date_month integer, date_year integer, date_weekday varchar,
    primary key(date_key)) diststyle all;"""

nation_table_create = """create table nation_dim(nation_key integer, nation_code integer, nation_descr varchar,
    primary key(nation_key)) diststyle all;"""

state_terr_table_create = """create table state_terr_dim(state_terr_key integer, state_terr_code varchar, state_terr_descr varchar,
    primary key(state_terr_key)) diststyle all;"""

city_demo_table_create = """create table city_demo_dim(city_demo_key integer, city_state_terr_key integer, city_name varchar,
    state_name varchar, state_code varchar, median_age decimal(5, 2), avg_household_size decimal(5,2), pop_total integer, pop_male integer, pop_female integer, pop_foreign_born integer,
    pop_african_american integer, pop_hispanic_latino integer, pop_native_american integer, pop_asian integer, pop_white integer, primary key(city_demo_key),
    foreign key(city_state_terr_key) references state_terr_dim(state_terr_key)) diststyle all;"""

port_table_create = """create table port_dim(port_key integer, port_code varchar, port_descr varchar, port_state_terr varchar, port_state_terr_key integer, port_city_demo_key integer,
    primary key(port_key), foreign key(port_state_terr_key) references state_terr_dim(state_terr_key),
    foreign key(port_city_demo_key) references city_demo_dim(city_demo_key)) diststyle all;"""

traveler_table_create = """create table traveler_fact(arrival_date date, citizen_nation_key integer, residence_nation_key integer, traveler_port_key integer,
    traveler_port_state_terr_key integer, visiting_state_terr_key integer, gender_reported smallint, gender_male smallint, gender_female smallint, gender_other smallint, age_reported smallint,
    age_0_11 smallint, age_12_17 smallint, age_18_24 smallint, age_25_34 smallint, age_35_44 smallint, age_45_54 smallint, age_55_64 smallint, age_65_74 smallint, age_75p smallint,
    arrival_type_reported smallint, air_arrival smallint, sea_arrival smallint, land_arrival smallint, visa_type_reported smallint, business_visa smallint, pleasure_visa smallint,
    student_visa smallint, traveler_count integer,
    foreign key(citizen_nation_key) references nation_dim(nation_key), foreign key(residence_nation_key) references nation_dim(nation_key),
    foreign key(traveler_port_key) references port_dim(port_key), foreign key(traveler_port_state_terr_key) references state_terr_dim(state_terr_key),
    foreign key(visiting_state_terr_key) references state_terr_dim(state_terr_key),
    foreign key(arrival_date) references date_dim(date_key)) distkey(arrival_date) ;"""


create_table_queries = [date_table_create, nation_table_create, state_terr_table_create, city_demo_table_create, port_table_create, traveler_table_create]



def drop_tables(cur, conn):
    for query in drop_table_queries:
        cur.execute(query)
        conn.commit()

def create_tables(cur, conn):
    for query in create_table_queries:
        cur.execute(query)
        conn.commit()
        
drop_tables(cur, conn)

create_tables(cur, conn)

conn.close()


#### Write dimensional dataframes to local parquet folders / files

In [12]:
def conditionally_remove_dir(path):
    """ removes a directory and its contents """
    if os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)

dframes = ["df_date", "df_nation", "df_state_terr", "df_city_demo", "df_port"]

for dframe in dframes:
    conditionally_remove_dir(dframe)
    
df_date.write.parquet("df_date")
df_nation.write.parquet("df_nation")
df_state_terr.write.parquet("df_state_terr")
df_city_demo.write.parquet("df_city_demo")
df_port.write.parquet("df_port")  

#### Create functions to upload parquet files to S3 and ingest to Redshift

In [36]:

def upload_files(path):
    """ uploads parquet files contained within a specified path to s3 """
    configupload = configparser.ConfigParser()
    configupload.read_file(open('awscfg.cfg'))
    
    keyid = configupload.get('KEY_CREDS','KEY_ID')
    keysecret = configupload.get('KEY_CREDS','KEY_SECRET')
    s3region = configupload.get('S3','REGION')
    s3bucket = configupload.get('S3','BUCKET')
    
    session = boto3.Session(
        aws_access_key_id=keyid,
        aws_secret_access_key=keysecret,
        region_name=s3region
    )
    s3 = session.resource('s3')
    bucket = s3.Bucket(s3bucket)
 
    for subdir, dirs, files in os.walk(path):
        for file in files:
            full_path = os.path.join(subdir, file)
            # only parquet part files are uploaded; other files are skipped without being opened
            if full_path.endswith(".parquet"):
                with open(full_path, 'rb') as data:
                    bucket.put_object(Key=full_path[len(path)+1:], Body=data)
                                   

def get_parquet_files(path) :
    """ generates a list of parquet file names for a given path """
    filelist = []
    for subdir, dirs, files in os.walk(path):
        for file in files:
            full_path = os.path.join(subdir, file)
            if full_path.endswith(".parquet"):
                filelist.append(file)
    return filelist


def insert_to_redshift(parquetname, targetname):
    """ copies parquet files of a specified parquet from s3 to specified redshift table """
    configredshift = configparser.ConfigParser()
    configredshift.read_file(open('awscfg.cfg'))
    
    rshost = configredshift.get('CLUSTER','HOST')
    rsdbname = configredshift.get('CLUSTER','DB_NAME')
    rsuser = configredshift.get('CLUSTER','DB_USER')
    rspassword = configredshift.get('CLUSTER','DB_PASSWORD')
    rsport = configredshift.get('CLUSTER','DB_PORT')
    bucketsource = configredshift.get('S3','BUCKET')
    bucketiam = configredshift.get('IAM_ROLE','ARN')

    conn = psycopg2.connect(host=rshost, dbname=rsdbname, user=rsuser, password=rspassword, port=rsport)
    cur = conn.cursor()
    
    filestoget = get_parquet_files(parquetname)
    
    for parquetfile in filestoget:

        copycommand = (""" COPY {}
        FROM 's3://{}/{}'
        IAM_ROLE '{}'
        FORMAT AS PARQUET;
        """).format(targetname, bucketsource, parquetfile, bucketiam)
        
        cur.execute(copycommand)
        conn.commit()

    # release the connection once all files are copied
    conn.close()

def delete_parquets(path):
    """ deletes parquet files from s3 """
    configupload = configparser.ConfigParser()
    configupload.read_file(open('awscfg.cfg'))
    
    delkeyid = configupload.get('KEY_CREDS','KEY_ID')
    delkeysecret = configupload.get('KEY_CREDS','KEY_SECRET')
    dels3region = configupload.get('S3','REGION')
    dels3bucket = configupload.get('S3','BUCKET')
    
    session = boto3.Session(
        aws_access_key_id=delkeyid,
        aws_secret_access_key=delkeysecret,
        region_name=dels3region
    )
    s3 = session.resource('s3')
    bucket = s3.Bucket(dels3bucket)
    parquetstodel = get_parquet_files(path)
    
    for parquet in parquetstodel:
        obj = s3.Object(dels3bucket, parquet)
        obj.delete()
        
def get_row_count(tablename):
    """ retrieves the row count of a specified table """
    configredshift = configparser.ConfigParser()
    configredshift.read_file(open('awscfg.cfg'))

    rshost = configredshift.get('CLUSTER','HOST')
    rsdbname = configredshift.get('CLUSTER','DB_NAME')
    rsuser = configredshift.get('CLUSTER','DB_USER')
    rspassword = configredshift.get('CLUSTER','DB_PASSWORD')
    rsport = configredshift.get('CLUSTER','DB_PORT')

    conn = psycopg2.connect(host=rshost, dbname=rsdbname, user=rsuser, password=rspassword, port=rsport)
    cur = conn.cursor()
    
    rowcountsql = ("SELECT COUNT(*) FROM {};").format(tablename)
    
    cur.execute(rowcountsql)
    
    sqlresult = cur.fetchone()
    tablecount = sqlresult[0]

    # release the connection before returning
    conn.close()

    return tablecount


def get_parquet_count(dfname):
    """ retrieves the row count of a specified parquet file """
    parquetdf = spark.read.parquet(dfname)
    parquetcount = parquetdf.count()
    return parquetcount

def verify_counts(dfname, tablename, initct):
    """ compares the row counts of the target table to the sum of the row counts of the parquet and the row count of the target prior to load """
    rowct = get_row_count(tablename)
    pqct = get_parquet_count(dfname)
    
    if rowct != (pqct+initct):
        raise ValueError("row count in table {} does not match previous table row count plus row count in parquet folder {}.".format(tablename, dfname))


def remove_dir(path):
    """ removes files and the directory of the specified path """
    if os.path.isfile(path):
        os.remove(path)  # remove the file
    elif os.path.isdir(path):
        shutil.rmtree(path)  # remove dir and all contents
    else:
        raise ValueError("file {} is not a file or dir.".format(path))
            

            
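def upload_dataframe(dfname, tablename):
    """ uploads a local parquet folder to s3, copies it into the target redshift table, cleans up and verifies row counts """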
    prevcount = get_row_count(tablename)
    upload_files(dfname)
    insert_to_redshift(dfname, tablename)
    delete_parquets(dfname)
    verify_counts(dfname, tablename, prevcount)
    remove_dir(dfname)
            
 
            

#### Upload dimension data to S3 and Redshift

In [28]:
def process_dims():
    df_table_map = dict(zip(["df_date", "df_nation", "df_state_terr", "df_city_demo", "df_port"],["date_dim", "nation_dim", "state_terr_dim", "city_demo_dim", "port_dim"]))
    for (key, value) in df_table_map.items():
        upload_dataframe(key, value)
        

process_dims()



#### Create functions to clean I94 Data and write to parquet 

In [39]:


def get_month(monthno):
    """ associates a 3 char string representing a month with that month's number value """
    month_map = dict(zip([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec']))
    # direct lookup; returns None for a month number outside 1-12
    return month_map.get(monthno)
        
def remove_fact_dir(path):
    """ removes files and a directory for a given path """
    if os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)


def clean_fact(runmonth):
    """ cleans the i94 data to fit to the target table structure """

    curmonth_char = get_month(runmonth)

    location_prefix = '../../data/18-83510-I94-Data-2016/i94_'
    location_suffix = '16_sub.sas7bdat'

    data_location = location_prefix + curmonth_char + location_suffix

    df_nation_cit = df_nation.selectExpr("nation_key as nation_key_cit", "nation_descr as nation_descr_cit", "nation_code as nation_code_cit")
    df_nation_res = df_nation.selectExpr("nation_key as nation_key_res", "nation_descr as nation_descr_res", "nation_code as nation_code_res")


    df_fact = spark.read.format('com.github.saurfang.sas.spark').load(data_location)

    df_fact = df_fact.withColumn('arrdate_int', df_fact.arrdate.cast('integer')) \
        .withColumn('diff_date', to_date(lit('1960-01-01'))) \
        .withColumn('gender_reported', when(df_fact.gender.isNotNull()==1, 1).otherwise(0)) \
        .withColumn('gender_male', when(df_fact.gender=="M", 1).otherwise(0)) \
        .withColumn('gender_female', when(df_fact.gender=="F", 1).otherwise(0)) \
        .withColumn('gender_other', when(df_fact.gender=="M", 0).when(df_fact.gender=="F", 0).when(df_fact.gender.isNull()==1, 0).otherwise(1)) \
        .withColumn('age_reported', when(df_fact.i94bir.isNotNull()==1, 1).otherwise(0)) \
        .withColumn('age_0_11', when(df_fact.i94bir <= 11, 1).otherwise(0)) \
        .withColumn('age_12_17', when(df_fact.i94bir.between(12, 17)==1, 1).otherwise(0)) \
        .withColumn('age_18_24', when(df_fact.i94bir.between(18, 24)==1, 1).otherwise(0)) \
        .withColumn('age_25_34', when(df_fact.i94bir.between(25, 34)==1, 1).otherwise(0)) \
        .withColumn('age_35_44', when(df_fact.i94bir.between(35, 44)==1, 1).otherwise(0)) \
        .withColumn('age_45_54', when(df_fact.i94bir.between(45, 54)==1, 1).otherwise(0)) \
        .withColumn('age_55_64', when(df_fact.i94bir.between(55, 64)==1, 1).otherwise(0)) \
        .withColumn('age_65_74', when(df_fact.i94bir.between(65, 74)==1, 1).otherwise(0)) \
        .withColumn('age_75p', when(df_fact.i94bir >= 75, 1).otherwise(0)) \
        .withColumn('arrival_type_reported', when(df_fact.i94mode.isNull()==1, 0).when(df_fact.i94mode==9, 0).otherwise(1)) \
        .withColumn('air_arrival', when(df_fact.i94mode ==1, 1).otherwise(0)) \
        .withColumn('sea_arrival', when(df_fact.i94mode ==2, 1).otherwise(0)) \
        .withColumn('land_arrival', when(df_fact.i94mode ==3, 1).otherwise(0)) \
        .withColumn('visa_type_reported', when(df_fact.i94visa ==1, 1).when(df_fact.i94visa ==2, 1).when(df_fact.i94visa ==3, 1).otherwise(0)) \
        .withColumn('business_visa', when(df_fact.i94visa ==1, 1).otherwise(0)) \
        .withColumn('pleasure_visa', when(df_fact.i94visa ==2, 1).otherwise(0)) \
        .withColumn('student_visa', when(df_fact.i94visa ==3, 1).otherwise(0))

    # add column arrival_date as diff_date plus arrdate_int 
    df_fact = df_fact.selectExpr("*", "date_add(diff_date, arrdate_int) as arrival_date", "cast(count as int) as traveler_count", "cast(i94cit as int) as i94_cit",
                              "cast(i94res as int) as i94_res", "i94port as i94_port", "i94addr as i94_addr")

    df_fact = df_fact.select([df_fact.arrival_date, df_fact.i94_cit, df_fact.i94_res, df_fact.i94_port, df_fact.gender_reported, df_fact.gender_male, 
                            df_fact.gender_female, df_fact.gender_other, df_fact.age_reported, df_fact.age_0_11, df_fact.age_12_17, df_fact.age_18_24,
                            df_fact.age_25_34, df_fact.age_35_44, df_fact.age_45_54, df_fact.age_55_64, df_fact.age_65_74, df_fact.age_75p,
                            df_fact.arrival_type_reported, df_fact.air_arrival, df_fact.sea_arrival, df_fact.land_arrival, 
                            df_fact.visa_type_reported, df_fact.business_visa, df_fact.pleasure_visa, df_fact.student_visa, 
                            df_fact.i94_addr, df_fact.traveler_count ])

    # left join to df_nation_cit to get nation_key of i94cit
    df_fact = df_fact.join(df_nation_cit, (df_fact.i94_cit == df_nation_cit.nation_code_cit), 'left_outer')

    # left join to df_nation_res to get nation_key of i94res
    df_fact = df_fact.join(df_nation_res, (df_fact.i94_res == df_nation_res.nation_code_res), 'left_outer')

    # left join to df_port to get port_key and port_state_terr_key of i94port
    df_fact = df_fact.join(df_port, (df_fact.i94_port == df_port.port_code), 'left_outer')

    # left join to df_state_terr to get visiting_state_terr_key of i94addr
    df_fact = df_fact.join(df_state_terr, (df_fact.i94_addr == df_state_terr.state_terr_code), 'left_outer')

    # Set foreign key values
    df_fact = df_fact.withColumn('citizen_nation_key', when(df_fact.nation_key_cit.isNull()==1, -999).otherwise(df_fact.nation_key_cit)) \
        .withColumn('residence_nation_key', when(df_fact.nation_key_res.isNull()==1, -999).otherwise(df_fact.nation_key_res)) \
        .withColumn('traveler_port_key', when(df_fact.port_key.isNull()==1, -999).otherwise(df_fact.port_key)) \
        .withColumn('traveler_port_state_terr_key', when(df_fact.port_key.isNull()==1, -999).otherwise(df_fact.port_state_terr_key)) \
        .withColumn('visiting_state_terr_key', when(df_fact.state_terr_key.isNull()==1, -999).otherwise(df_fact.state_terr_key))

    # Group by date, foreign keys and traveler characteristics
    df_fact = df_fact.groupBy("arrival_date", "citizen_nation_key", "residence_nation_key", "traveler_port_key", "traveler_port_state_terr_key", 
                                "gender_reported", "gender_male", "gender_female", "gender_other",
                               "age_reported", "age_0_11", "age_12_17", "age_18_24", "age_25_34","age_35_44", "age_45_54", "age_55_64","age_65_74", "age_75p",
                               "arrival_type_reported", "air_arrival", "sea_arrival", "land_arrival", "visa_type_reported",
                               "business_visa", "pleasure_visa","student_visa", "visiting_state_terr_key").agg(sql_sum("traveler_count").alias("traveler_count"))

    # set columns for parquet write
    df_fact = df_fact.selectExpr("arrival_date", "citizen_nation_key", "residence_nation_key", "traveler_port_key", "traveler_port_state_terr_key", 
                               "visiting_state_terr_key", "cast(gender_reported as smallint) as gender_reported", "cast(gender_male as smallint) as gender_male", 
                                 "cast(gender_female as smallint) as gender_female", "cast(gender_other as smallint) as gender_other",
                               "cast(age_reported as smallint) as age_reported", "cast(age_0_11 as smallint) as age_0_11", "cast(age_12_17 as smallint) as age_12_17", 
                                 "cast(age_18_24 as smallint) as age_18_24", "cast(age_25_34 as smallint) as age_25_34","cast(age_35_44 as smallint) as age_35_44", 
                                 "cast(age_45_54 as smallint) as age_45_54", "cast(age_55_64 as smallint) as age_55_64","cast(age_65_74 as smallint) as age_65_74", 
                                 "cast(age_75p as smallint) as age_75p", "cast(arrival_type_reported as smallint) as arrival_type_reported", 
                                 "cast(air_arrival as smallint) as air_arrival", "cast(sea_arrival as smallint) as sea_arrival", "cast(land_arrival as smallint) as land_arrival", 
                                 "cast(visa_type_reported as smallint) as visa_type_reported", "cast(business_visa as smallint) as business_visa", 
                                 "cast(pleasure_visa as smallint) as pleasure_visa",  "cast(student_visa as smallint) as student_visa","cast(traveler_count as int) as traveler_count")
    
    remove_fact_dir("df_fact")
    df_fact.write.parquet("df_fact")
    

def process_fact(runmonth):
    clean_fact(runmonth)
    upload_dataframe("df_fact", "traveler_fact")




#### Upload fact data to S3 and Redshift

In [40]:
start_month = 1
end_month = 2

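months_to_run = range(start_month, end_month + 1)

# Loop variable is named run_month so it does not shadow the month() function
# imported from pyspark.sql.functions at the top of the notebook.
for run_month in months_to_run:
    process_fact(run_month)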



#### 4.2 Data Quality Checks
The project includes two methods for maintaining data quality:
 * Implementation of multiple foreign key constraints
 * Row count checks on each load
 
The implementation of the primary key / foreign key relationships is illustrated in the image in section 3.1.
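
One caveat worth keeping in mind: Amazon Redshift treats primary key and foreign key constraints as informational and does not enforce them on load, so an explicit orphan-key check can complement the declared constraints. The following is a minimal sketch of such a check, not part of the project code. It runs against an in-memory SQLite database so that it is self-contained; the same anti-join query could presumably be issued through a psycopg2 cursor. The dimension table `nation_dim`, its column `nation_key`, and the helper `count_orphans` are hypothetical names chosen for illustration.

```python
import sqlite3


def count_orphans(conn, fact_table, fact_key, dim_table, dim_key):
    """Count fact rows whose non-null key has no matching dimension row."""
    query = (
        f"SELECT COUNT(*) FROM {fact_table} f "
        f"LEFT JOIN {dim_table} d ON f.{fact_key} = d.{dim_key} "
        f"WHERE f.{fact_key} IS NOT NULL AND d.{dim_key} IS NULL"
    )
    cur = conn.cursor()
    cur.execute(query)
    return cur.fetchone()[0]


# Stand-in database: nation_dim and nation_key are hypothetical names.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE nation_dim (nation_key INTEGER)")
conn.execute(
    "CREATE TABLE traveler_fact (citizen_nation_key INTEGER, traveler_count INTEGER)"
)
conn.executemany("INSERT INTO nation_dim VALUES (?)", [(1,), (2,)])
conn.executemany(
    "INSERT INTO traveler_fact VALUES (?, ?)", [(1, 10), (2, 5), (3, 7)]
)

# Key 3 has no dimension row, so exactly one fact row is an orphan.
orphans = count_orphans(
    conn, "traveler_fact", "citizen_nation_key", "nation_dim", "nation_key"
)
```

A non-zero result would indicate fact rows that reference a missing dimension key and could be used to fail the load.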

Row count checks are implemented within the <b>upload_dataframe</b> function. Because the fact table is loaded in monthly batches (owing to the volume of data and the way it is retrieved from its source), each load is incremental, so the check must be embedded in the load mechanism itself rather than run once at the end.

 * First, a row count is retrieved from the table prior to the table being loaded.
 * Once the table is loaded, a new row count is retrieved from the target table.
 * Once that count is retrieved, a count of rows from the local parquet files is retrieved.
 * Finally, the row count of the target table is compared to the sum of the count from the local parquet file and the count of rows in the table prior to the load.
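
The comparison in the last step can be sketched as follows. This is only an illustration of the arithmetic, not the body of <b>upload_dataframe</b>; the helper name `check_row_counts` is hypothetical, and the three counts are assumed to have been obtained elsewhere (for example, a `SELECT COUNT(*)` on the target table before and after the load, and a `count()` on the parquet data read back with Spark).

```python
def check_row_counts(count_before, count_after, parquet_count):
    """Verify that an incremental load added exactly the staged rows.

    count_before  -- rows in the target table before the load
    count_after   -- rows in the target table after the load
    parquet_count -- rows in the local parquet files that were loaded
    """
    expected = count_before + parquet_count
    if count_after != expected:
        raise ValueError(
            "Row count mismatch: expected {} rows after load, found {}".format(
                expected, count_after
            )
        )
    return expected


# Example: 100 rows already loaded, 50 staged rows, 150 rows after the load.
total_rows = check_row_counts(count_before=100, count_after=150, parquet_count=50)
```

Raising an exception on a mismatch stops the monthly loop at the failing batch, so a partial load is noticed before the next month is processed.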

#### 4.3 Data dictionary 
Create a data dictionary for your data model. For each field, provide a brief description of what the data is and where it came from. You can include the data dictionary in the notebook or in a separate file.

 *  The data dictionary is available in the project workspace and can be found in the file <b>travelerdb_data_dictionary.txt</b>

### Step 5: Complete Project Write Up
##### Choice of Technologies / Tools
 * Spark and Redshift were chosen as the key tools for the project. This is due to the following:
  * Spark offers an ideal environment for processing the i94 data considering its size
  * The data lends itself well to Redshift, as it can easily be broken out into fact and dimension elements
  * The scale and shape of the data does not exceed the limitations of a parallel-processed data warehousing technology

##### Propose how often the data should be updated and why.
##### Data Update Cadence
 * At its most frequent, this data could be updated monthly, because the source data is grouped by month.
 * In practice, the i94 data appears to be released only in yearly chunks, so one could realistically expect new data to arrive only once per year.

##### Write a description of how you would approach the problem differently under the following scenarios:
 * The data was increased by 100x:
  * I believe Spark + Redshift could still be a good choice, though it would require either going to 16 nodes or upgrading to a ds2.xlarge instance.
  * There would be a significant increase in S3 costs that would have to be considered.
 * The data populates a dashboard that must be updated on a daily basis by 7am every day.
  * Assuming, hypothetically, that new data arrived daily, it would probably be best to implement an Airflow solution to orchestrate the Spark jobs.
  * One would need to understand when upstream data becomes available each day, and establish a baseline for how long the DAG takes to complete.
  * Given the above and the SLA, one could decide whether to run on an availability-driven or a deadline-driven schedule.
 * The database needed to be accessed by 100+ people.
  * I believe Redshift would still be fine here, but it may be necessary to configure WLM queues to manage the query volume, or to enable concurrency scaling.
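
For the daily 7am dashboard scenario above, the choice between an availability-driven and a deadline-driven schedule comes down to simple time arithmetic. The sketch below is illustrative only: the helper `latest_start`, the 90-minute baseline runtime, the 30-minute safety margin, and the 03:15 arrival time are all assumed values, not measurements from this project.

```python
from datetime import datetime, timedelta


def latest_start(deadline, baseline_runtime, safety_margin=timedelta(minutes=30)):
    """Latest time the pipeline can start and still meet the deadline."""
    return deadline - baseline_runtime - safety_margin


def slack(data_available, deadline, baseline_runtime,
          safety_margin=timedelta(minutes=30)):
    """Time between upstream data arriving and the latest safe start.

    A negative result means the SLA cannot be met with the current runtime.
    """
    return latest_start(deadline, baseline_runtime, safety_margin) - data_available


# Assumed values for illustration only.
deadline = datetime(2020, 1, 2, 7, 0)        # dashboard SLA: 7am
runtime = timedelta(minutes=90)              # assumed baseline DAG duration
arrival = datetime(2020, 1, 2, 3, 15)        # assumed upstream arrival time

start_by = latest_start(deadline, runtime)   # 05:00 with the default margin
spare = slack(arrival, deadline, runtime)    # 1 hour 45 minutes of slack
```

With ample slack, triggering the run as soon as the data arrives (availability-driven) is the simpler option; with little or negative slack, the runtime itself has to be reduced before any schedule can meet the deadline.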

In [104]:
def remove_dir(path):
    """Remove a file or directory tree. <path> may be relative or absolute."""
    if os.path.isfile(path):
        os.remove(path)  # remove the file
    elif os.path.isdir(path):
        shutil.rmtree(path)  # remove the directory and all of its contents
    else:
        raise ValueError("{} is not a file or directory.".format(path))
        
remove_dir('df_fact')