# ETL pipeline

這部分會從頭到尾完成一個 ETL pipeline，以 GDP 資料為例，整個 notebook 的編排如下:
1. Cell 1 會連接到 SQLite 資料庫，worldbank.db，在其中創造一個 table 來存放 gdp 資料。
2. Cell 2 為一個名為 extract_lines() 的函式，它是一個 [Python generator](https://wiki.python.org/moin/Generators)，用來一行一行讀入資料。
3. Cell 3 為名為 transform_indicator_data() 的函式，輸入為一行資料，輸出為轉換後的資料。
4. Cell 4 為名為 load_indicator_data() 的函式，會把轉換後的資料存進 worldbank.db。
5. Cell 5 執行整個 pipeline。
6. Cell 6 會確認程式是否正常執行。

## 1. Create database
其 schema 如下

gdp
- countryname text
- countrycode text
- year integer
- gdp real

(countrycode, year) is the primary key

In [1]:
import pandas as pd
import numpy as np
import sqlite3

# connect to the database
# sqlite3 will create this database file if it does not exist already
conn = sqlite3.connect('worldbank.db')

try:
    # get a cursor
    cur = conn.cursor()

    # drop the gdp table in case it already exists, so this cell can be re-run
    cur.execute('DROP TABLE IF EXISTS gdp')

    # create the table gdp; (countrycode, year) is the composite primary key,
    # so a second row for the same country and year is rejected
    sql_query = 'CREATE TABLE gdp (countryname TEXT, countrycode TEXT, year INTEGER, gdp REAL, PRIMARY KEY (countrycode, year));'
    cur.execute(sql_query)

    # make the schema change permanent
    conn.commit()
finally:
    # close the connection even if one of the statements above raised,
    # so the database file is never left open
    conn.close()

## 2. Generator

In [2]:
# Generator for reading in one line at a time
def extract_lines(file):
    '''Lazily yield the lines of an open file, one per iteration.

    Input:
        file: an object with a readline() method (e.g. an open text file)
    Output:
        yields each line exactly as readline() returned it; iteration ends at
        the first empty (falsy) value, which is what readline() returns at
        end of file
    '''
    # prime the loop with the first line, then keep reading after every yield
    current = file.readline()
    while current:
        yield current
        current = file.readline()

## 3. Transform

In [3]:
# "Country Name" values that are not actually countries (aggregates, income
# groups, regions). Built once instead of on every call; a frozenset gives a
# constant-time membership test.
_NON_COUNTRIES = frozenset([
    'World',
    'High income',
    'OECD members',
    'Post-demographic dividend',
    'IDA & IBRD total',
    'Low & middle income',
    'Middle income',
    'IBRD only',
    'East Asia & Pacific',
    'Europe & Central Asia',
    'North America',
    'Upper middle income',
    'Late-demographic dividend',
    'European Union',
    'East Asia & Pacific (excluding high income)',
    'East Asia & Pacific (IDA & IBRD countries)',
    'Euro area',
    'Early-demographic dividend',
    'Lower middle income',
    'Latin America & Caribbean',
    'Latin America & the Caribbean (IDA & IBRD countries)',
    'Latin America & Caribbean (excluding high income)',
    'Europe & Central Asia (IDA & IBRD countries)',
    'Middle East & North Africa',
    'Europe & Central Asia (excluding high income)',
    'South Asia (IDA & IBRD)',
    'South Asia',
    'Arab World',
    'IDA total',
    'Sub-Saharan Africa',
    'Sub-Saharan Africa (IDA & IBRD countries)',
    'Sub-Saharan Africa (excluding high income)',
    'Middle East & North Africa (excluding high income)',
    'Middle East & North Africa (IDA & IBRD countries)',
    'Central Europe and the Baltics',
    'Pre-demographic dividend',
    'IDA only',
    'Least developed countries: UN classification',
    'IDA blend',
    'Fragile and conflict affected situations',
    'Heavily indebted poor countries (HIPC)',
    'Low income',
    'Small states',
    'Other small states',
    'Not classified',
    'Caribbean small states',
    'Pacific island small states',
])


def transform_indicator_data(data, colnames):
    ''' This function transforms one row of indicator data for loading into the database.

    Input:
        data: a row of data from the gdp csv file (a list of strings); it is
            not modified
        colnames: a list of column names from the gdp csv file, one per entry
            of data
    Output:
        result: a list of [countryname, countrycode, year, gdp] lists, one per
            year column that holds a value (values are returned as the strings
            found in the row), or None when the row's country name is one of
            the known non-country aggregates
    Raises:
        ValueError: if data and colnames do not have the same length
    '''
    # get rid of quote marks; build a new list so the caller's row stays untouched
    cleaned = [datum.replace('"', '') for datum in data]

    # the first field is the country name; aggregates are skipped entirely
    if cleaned[0] in _NON_COUNTRIES:
        return None

    # one-row dataframe; empty strings mark missing values, so turn them into nan.
    # The width comes from colnames rather than a hard-coded column count.
    df = pd.DataFrame([cleaned], columns=colnames).replace('', np.nan)

    # Drop the indicator columns and the '\n' column left behind by the trailing
    # comma of each csv line; errors='ignore' tolerates rows parsed without it
    df = df.drop(columns=['\n', 'Indicator Name', 'Indicator Code'], errors='ignore')

    # Reshape to long format: one row per year column
    df_melt = pd.melt(df, id_vars=['Country Name', 'Country Code'], var_name='year', value_name='gdp')

    # keep only the years that actually have a gdp value
    df_melt = df_melt[df_melt['gdp'].notna()]

    # each remaining row becomes [country, countrycode, year, gdp]
    return [list(row) for row in df_melt.itertuples(index=False, name=None)]

## 4. Load

In [4]:
def load_indicator_data(results, db_path='worldbank.db'):
    '''This function iterates through the input and inserts each row into the gdp table.

    Input:
        results: a list of [countryname, countrycode, year, gdp] rows outputted
            from the transformation step; None or an empty list means there is
            nothing to load
        db_path: path of the SQLite database file holding the gdp table
            (defaults to 'worldbank.db')
    Output:
        None

    A row that SQLite rejects (for example a primary key violation) is
    reported with print() and skipped; the remaining rows are still inserted.
    '''
    # nothing to insert: do not open a connection at all
    if not results:
        return None

    # Placeholders let sqlite3 bind the values itself, so quote characters in a
    # country name cannot break (or inject into) the statement. Numeric strings
    # are stored as numbers through the INTEGER/REAL affinity of the columns.
    sql_string = 'INSERT INTO gdp (countryname, countrycode, year, gdp) VALUES (?, ?, ?, ?);'

    # connect to the database using the sqlite3 library
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()

        for result in results:
            # extract the countryname, countrycode, year, and gdp from each row
            countryname, countrycode, year, gdp = result

            try:
                cur.execute(sql_string, (countryname, countrycode, year, gdp))
            # print out database errors (like if the primary key constraint is violated)
            except sqlite3.Error as e:
                print('error occurred:', e, result)

        # commit whatever was inserted successfully
        conn.commit()
    finally:
        # always release the connection, even if a malformed row raised above
        conn.close()

    return None

## 5. ETL pipeline

In [5]:
import csv

# open the data file
with open('./data/gdp_data.csv', encoding='utf-8') as f:
    # Read the file line by line through the generator, and let csv.reader do
    # the splitting: unlike str.split(','), it keeps a quoted value that itself
    # contains a comma in one field (and strips the quote marks), so such rows
    # are no longer dropped by the length check below.
    for data in csv.reader(extract_lines(f)):
        # check the length of the row because the first four lines of the csv file are not data
        if len(data) != 63:
            continue

        # check if the row holds the column names
        if data[0] == 'Country Name':
            colnames = list(data)
            # Every line ends with a trailing comma, which produces one empty
            # extra field. transform_indicator_data drops that artifact column
            # under the name '\n', so keep that name for it.
            if colnames[-1] == '':
                colnames[-1] = '\n'
        else:
            # transform and load the row of indicator data
            results = transform_indicator_data(data, colnames)
            load_indicator_data(results)

## 6. Test

In [6]:
# connect to the database
conn = sqlite3.connect('worldbank.db')

try:
    # read the whole gdp table into a dataframe to inspect what the pipeline loaded
    df = pd.read_sql("SELECT * FROM gdp", con=conn)
finally:
    # a SELECT leaves nothing to commit; just release the connection,
    # even if the query failed
    conn.close()

df

Unnamed: 0,countryname,countrycode,year,gdp
0,Aruba,ABW,1994,1.330168e+09
1,Aruba,ABW,1995,1.320670e+09
2,Aruba,ABW,1996,1.379888e+09
3,Aruba,ABW,1997,1.531844e+09
4,Aruba,ABW,1998,1.665363e+09
5,Aruba,ABW,1999,1.722799e+09
6,Aruba,ABW,2000,1.873453e+09
7,Aruba,ABW,2001,1.920263e+09
8,Aruba,ABW,2002,1.941095e+09
9,Aruba,ABW,2003,2.021302e+09
