# Code to import census CSVs to SQL

### Settings

In [None]:
# Notebook-wide settings. The cells below read these module-level names.
# Directory holding the CSV files to import (listed by get_files()).
inputpath = 'to_import'
# Directory for generated .sql schema files and SQL dialect for csvsql.
# NOTE: the visible call make_table_schemas(files) relies on that function's
# own defaults ('sql', 'mysql') rather than on these two variables.
outputpath = 'sql'
flavour = 'mysql'
# NOTE(review): not referenced by any cell visible in this notebook.
insertdata = False
# SSH endpoint, passed as the (ssh_address, ssh_port) tuple to
# sshtunnel.SSHTunnelForwarder.
ssh_address = 'your SSH address'
# NOTE(review): placeholder is a string; the port presumably needs to be an
# integer (e.g. 22) - confirm against the sshtunnel documentation.
ssh_port = 'your SSH port'
# replace these lines with your login details
ssh_username = 'your SSH username'
ssh_password = 'your SSH password'
mysql_username = 'your MySQL username'
mysql_password = 'your MySQL password'
# Name of the MySQL database the tables are created in.
mysql_database = '2016_census'

## Get the CSV Files and make .SQL files from them

#### Function to get csv files to parse

In [None]:
import csvkit
import os
from os import listdir
from os.path import isfile, join

In [None]:
def get_files(inputpath):
    """Return the names of the regular files directly inside *inputpath*.

    Sub-directories are skipped and the names are returned in the order
    ``os.listdir`` yields them (no sorting is applied).

    If the directory cannot be listed (missing, not a directory, permission
    denied) a message is printed and an empty list is returned, so callers
    can always iterate over the result.
    """
    try:
        return [f for f in listdir(inputpath) if isfile(join(inputpath, f))]
    except OSError as err:
        # Best-effort: report the problem but keep the notebook running.
        print('Couldn\'t get files for some reason: ' + str(err))
        return []

In [None]:
# Test get_files(): list the files in the input directory and show them.
# The resulting module-level `files` list is what the schema-generation
# cell passes to make_table_schemas().
inputpath = 'to_import'
files = get_files(inputpath)
print(files)

#### Functions to actually make the schemas

In [None]:
def make_table_schema(file, inputpath, outputpath, flavour):
    """Build the shell command that writes a schema script for one CSV file.

    The command pipes the first 10 lines of ``inputpath/file`` into
    ``csvsql`` and redirects its output to ``outputpath/<name>.sql``, where
    ``<name>`` is *file* without its extension (also used as the table name).

    Args:
        file: CSV file name inside *inputpath*.
        inputpath: Directory containing the CSV.
        outputpath: Directory the ``.sql`` file is written to.
        flavour: SQL dialect handed to ``csvsql --dialect``.

    Returns:
        The command string, or ``None`` (after printing a message) when the
        arguments are not strings.

    NOTE: paths and names are inserted into the command unquoted, so values
    containing spaces or shell metacharacters will not work.
    """
    try:
        # Strip the real extension instead of assuming it is 4 characters long.
        table = os.path.splitext(file)[0]
        pieces = [
            # Windows only - get first 10 rows - saves having to infer types
            # from entire file which can be slow
            'powershell -command "& {Get-Content ', inputpath, '/', file,
            ' -TotalCount 10}" |',
            # pick the dialect
            ' csvsql --dialect ', flavour, ' --table ', table,
            # feed output path
            ' > ', outputpath, '/', table, '.sql',
        ]
        # str.join raises TypeError on non-string parts, like '+' would.
        return ''.join(pieces)
    except (TypeError, AttributeError):
        print('Couldn\'t make command to build table schema for file: ' + str(file))
        return None

def make_table_schemas(files, inputpath = 'to_import', outputpath = 'sql', flavour = 'mysql'):
    """Run the schema-generation command for every file name in *files*.

    For each file the command from ``make_table_schema`` is executed with
    ``os.system``. Success is reported only when the command exits with
    status 0; any failure is printed and the loop moves on to the next file.

    Args:
        files: Iterable of CSV file names; ``None`` is treated as empty.
        inputpath: Directory containing the CSVs.
        outputpath: Directory the ``.sql`` files are written to.
        flavour: SQL dialect handed to ``csvsql``.
    """
    for file in files or []:
        try:
            command = make_table_schema(file, inputpath, outputpath, flavour)
            # No command means the builder could not produce one.
            status = None if command is None else os.system(command)
        except Exception:
            status = None

        if status == 0:
            print('Made schema for: ' + file)
        elif status is None:
            print('Fell over making schema for: ' + str(file))
        else:
            # Non-zero exit: the shell pipeline itself reported a failure.
            print('Fell over making schema for: ' + str(file)
                  + ' (exit status ' + str(status) + ')')

#### Go! Make the schemas

In [None]:
# Generate one .sql schema file per CSV listed in `files`, using the
# function's default arguments ('to_import', 'sql', 'mysql').
make_table_schemas(files)

## Execute the .SQL Files

#### Function to get the .SQL file names

In [None]:
def get_sql_files(path = 'sql'):
    """Return the names of the regular files directly inside *path*.

    No extension filtering is done: every regular file in the directory is
    returned, in ``os.listdir`` order.

    If the directory cannot be listed a message is printed and an empty
    list is returned, so callers can always iterate over the result.
    """
    try:
        return [f for f in listdir(path) if isfile(join(path, f))]
    except OSError as err:
        # Best-effort: report the problem but keep the notebook running.
        print('Couldn\'t get the sql files for some reason: ' + str(err))
        return []

#### Function to get the .SQL file contents

In [None]:
def get_sql_file_contents(file, filepath=False):
    """Read a SQL script and return its text.

    Args:
        file: File name, or a full path when *filepath* is falsy.
        filepath: Optional directory; when given it is prefixed to *file*
            with a ``/`` separator.

    Returns:
        The file contents as a string, or ``None`` (after printing a
        message) if the file cannot be opened or decoded.
    """
    if filepath:
        file = filepath + '/' + file

    try:
        # `with` closes the handle even if read() raises.
        with open(file, 'r') as fd:
            return fd.read()
    except (OSError, UnicodeDecodeError):
        print('Couldn\'t get the sql file contents for some reason: ' + file)
        return None

#### Function to monkey patch the average columns - explained in readme

In [None]:
def monkey_patch_averages(contents):
    """Give the two "average" columns an explicit DECIMAL (4,2) type.

    Every occurrence of an unsized ``DECIMAL NOT NULL`` definition for
    ``Average_num_psns_per_bedroom`` or ``Average_household_size`` in the
    schema text is rewritten to ``DECIMAL (4,2) NOT NULL``. All other text
    is returned untouched.
    """
    for column in ('Average_num_psns_per_bedroom', 'Average_household_size'):
        unsized = '`' + column + '` DECIMAL NOT NULL'
        sized = '`' + column + '` DECIMAL (4,2) NOT NULL'
        contents = contents.replace(unsized, sized)
    return contents

### Actually run it all - Execute the .SQL Files

In [None]:
import pymysql
import sshtunnel
import time
# Execute every generated schema script against the MySQL server, which is
# reached through an SSH tunnel to 127.0.0.1:3306 on the remote host.
sql_files = get_sql_files()

with sshtunnel.SSHTunnelForwarder(
        (ssh_address, ssh_port),
        ssh_username=ssh_username,
        ssh_password=ssh_password,
        remote_bind_address=("127.0.0.1", 3306)
) as tunnel:
    # sleep to give the tunnel a chance to get established
    time.sleep(1)
    connection = pymysql.connect(host="127.0.0.1",
                                 port=tunnel.local_bind_port,
                                 user=mysql_username,
                                 password=mysql_password,
                                 db=mysql_database,
                                 charset='utf8mb4')
    try:
        # loop through sql files; a failure in one file is printed and the
        # loop carries on with the next
        for sql_file in sql_files or []:
            try:
                # read them
                query = get_sql_file_contents(sql_file, filepath='sql')
                if query is None:
                    # Nothing could be read, so there is nothing to execute.
                    print(sql_file + ' skipped')
                    continue
                query = monkey_patch_averages(query)
                # The cursor is closed as soon as the statement has run.
                with connection.cursor() as cur:
                    cur.execute(query)
                connection.commit()
                print(sql_file + ' done')
            except Exception as e:
                print(e)
    finally:
        # Always release the connection, even if the loop is interrupted.
        connection.close()

## Write data from the .CSVs into the tables

#### Function to create a mysql connection string to use with Pandas

In [None]:
def create_mysql_engine_string(user, password, host, db, port=3306):
    """Build a ``mysql://user:password@host:port/db`` connection URL.

    The user name and password are percent-encoded so that characters with
    a special meaning in URLs (``@``, ``:``, ``/`` ...) cannot corrupt the
    URL. Pass them raw, not pre-encoded.

    Args:
        user: Database user name.
        password: Database password.
        host: Host name or address of the MySQL server.
        db: Database name.
        port: TCP port; converted with ``str()``.

    Returns:
        The connection URL as a string.
    """
    from urllib.parse import quote

    # safe='' also encodes '/', which quote() leaves alone by default.
    credentials = quote(user, safe='') + ':' + quote(password, safe='')
    return 'mysql://' + credentials + '@' + host + ':' + str(port) + '/' + db

#### Function to take a filename, read it into a Pandas dataframe, and write that dataframe to a mysql table

In [None]:
def insert_into_mysql(file, connection, path=False, flavor = 'mysql'):
    """Load one CSV file and append its rows to the table named after it.

    The table name is *file* without its extension. Errors are printed, not
    raised, so one bad file does not stop a batch.

    Args:
        file: CSV file name.
        connection: Object handed to ``DataFrame.to_sql`` as ``con``.
            NOTE(review): for MySQL, pandas presumably needs a SQLAlchemy
            engine/connection here rather than a raw DBAPI connection -
            confirm against the pandas ``to_sql`` documentation.
        path: Optional directory prefixed to *file* with ``/``.
        flavor: Only forwarded when the installed pandas still accepts a
            ``flavor`` argument; otherwise ignored.
    """
    import inspect

    try:
        tablename = os.path.splitext(file)[0]  # strip the extension ('.csv')

        if path:
            file = path + '/' + file

        # header=0 makes it treat the first row as headers
        df = pandas.read_csv(file, header=0, sep=',')

        # if_exists = append means insert into
        # index=False means don't try to write the Pandas index as a column
        to_sql_args = dict(con=connection, name=tablename,
                           if_exists='append', index=False)
        # Passing an unknown `flavor` keyword raises TypeError, which the
        # handler below would swallow, so no rows would be written.
        if 'flavor' in inspect.signature(df.to_sql).parameters:
            to_sql_args['flavor'] = flavor
        df.to_sql(**to_sql_args)
    except Exception as e:
        print(e)

## Actually run it - Insert data into the tables

In [None]:
from sqlalchemy import create_engine
import pandas
# Insert the CSV rows into the tables through an SSH tunnel to the MySQL
# server on the remote host.
inputpath = 'to_import'
files = get_files(inputpath)


with sshtunnel.SSHTunnelForwarder(
        (ssh_address, ssh_port),
        ssh_username = ssh_username,
        ssh_password = ssh_password,
        remote_bind_address=("127.0.0.1", 3306)
) as tunnel:
    # sleep to give the tunnel a chance to get established
    time.sleep(1)
    engine_string = create_mysql_engine_string(mysql_username,mysql_password,"127.0.0.1",mysql_database,tunnel.local_bind_port)
    engine = create_engine(engine_string)
    try:
        # for each .csv, write
        # NOTE(review): files[1:] skips the first listed file, while the other
        # loops in this notebook process every file - confirm this is intended.
        for file in (files or [])[1:]:
            # Hand pandas the SQLAlchemy engine itself, not a raw DBAPI
            # connection, and let it manage its own connections.
            insert_into_mysql(file, engine, path=inputpath)
            print('Inserted data for: ' + file)
    finally:
        # Close pooled connections before the tunnel goes away.
        engine.dispose()

## Add Indexes (primary key on the first column of each table)

In [None]:
def add_index(file):
    """Return the SQL that makes a table's first CSV column its primary key.

    The table name is *file* without its extension and the column name is
    the first header of ``inputpath/file`` (``inputpath`` is the
    module-level setting). Both identifiers are backtick-quoted, with
    embedded backticks doubled, so unusual names still give valid MySQL.

    Args:
        file: CSV file name inside ``inputpath``.

    Returns:
        An ``ALTER TABLE ... ADD PRIMARY KEY (...)`` statement.
    """
    # nrows=0 parses just the header row instead of loading the whole file.
    header = pandas.read_csv(inputpath + '/' + file, header=0, sep=',', nrows=0)
    col_name = '`' + str(header.columns[0]).replace('`', '``') + '`'
    table_name = '`' + os.path.splitext(file)[0].replace('`', '``') + '`'
    query = "ALTER TABLE " + table_name + " ADD PRIMARY KEY (" + col_name + ")"
    return query

In [None]:
# Add a primary key to every imported table, going through an SSH tunnel to
# the MySQL server on the remote host.
with sshtunnel.SSHTunnelForwarder(
        (ssh_address, ssh_port),
        ssh_username = ssh_username,
        ssh_password = ssh_password,
        remote_bind_address=("127.0.0.1", 3306)
) as tunnel:
    # sleep to give the tunnel a chance to get established
    time.sleep(1)
    connection = pymysql.connect(host="127.0.0.1",
                                 port=tunnel.local_bind_port,
                                 user=mysql_username,
                                 password=mysql_password,
                                 db=mysql_database,
                                 charset='utf8mb4')
    try:
        # loop through tables; a failure for one table is printed and the
        # loop carries on with the next
        for file in files or []:
            try:
                query = add_index(file)
                # The cursor is closed as soon as the statement has run.
                with connection.cursor() as cur:
                    cur.execute(query)
                connection.commit()
                print(file + " done")
            except Exception as e:
                print(e)
    finally:
        # Always release the connection, even if the loop is interrupted.
        connection.close()