# SQL Package

Provides simple functionality to interact with a PostgreSQL server using Python classes.

**Overview of functionality:**
* Database(self, user, password, host, dbname, port)
    * properties
        * user
        * password
        * host
        * dbname
        * port
    * methods
        * create(name) x
        * connect()
        * drop(name)
* Table(self, dbname, table, schema)
    * accepts db properties
    * properties
        * connect() --> inherited
        * fetch_data(sql, con, parse_dates)
        * get_names()
        * format_names(char_dict)
        * update_names(names_dict)
        * add_columns(columns_list, type=None)
        * compare_column_order(dataframe)
        * match_columns(dataframe)
        * save_csv(data, local_path, match_column_order=True)
        * update_values(local_path, container_path)
        * update_types(types_dict)
        * close()

## Setup

In [1]:
import os
import sys
from pathlib import Path
#sys.path[0] = str(Path(__file__).resolve().parents[2]) # Set path for custom modules
import warnings
from io import StringIO

# Set path for modules
sys.path[0] = '../'

from dotenv import load_dotenv, find_dotenv
import numpy as np
import pandas as pd

# SQL libraries
import psycopg2

# Set notebook display options
pd.set_option('display.max_rows', 2000)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

# Get project root directory
root_dir = os.path.dirname(os.getcwd())

In [2]:
#sys.modules

In [3]:
class Database():
    
    # if modulename not in sys.modules: print...
    load_dotenv(find_dotenv());
    
    def __init__(self, user=None, password=None,
                 dbname=None, host=None, port=None):
        
        # Loaded from .env if not explicit
        self.user = user if user is not None else os.getenv("POSTGRES_USER")
        self.password = password if password is not None else os.getenv("POSTGRES_PASSWORD")
        self.dbname = dbname if dbname is not None else os.getenv("POSTGRES_DB")
        self.host = host if host is not None else os.getenv("DB_HOST")
        self.port = port if port is not None else os.getenv("DB_PORT")
        
        
        # Root directory
        self._root_dir = os.path.dirname(os.getcwd())
        #sys.path[0] = str(Path(__file__).resolve().parents[2])
        
    def _connect(self):

        """
        Connects to PostgreSQL database using psycopg2 driver. Same
        arguments as psycopg2.connect().

        Params
        --------
        dbname
        user
        password
        host
        port
        connect_timeout
        """

        try:
            con = psycopg2.connect(dbname=self.dbname,
                                   user=self.user,
                                   password=self.password,
                                    host=self.host, 
                                    port=self.port,
                                  connect_timeout=3)            

        except Exception as e:
            print('Error:\n', e)
            return None


        return con
    
    @property
    def _con(self):
        try:
            con = self._connect()
            print('Connected as user "{}" to database "{}" on http://{}:{}.'.format(self.user,self.dbname,
                                                               self.host,self.port))
            con.close()
        except Exception as e:
            con.rollback()
            print('Error:\n', e)
        finally:
            if con is not None:
                con.close()
        

In [4]:
db = Database()

In [607]:
class Table(Database):
    def __init__(self, user=None, password=None, dbname=None, host=None, port=None, table=None):
        super().__init__(user, password, dbname, host, port)
        
        self.table = table
        
        # Loaded from .env if not explicit
        self.user = user if user is not None else os.getenv("POSTGRES_USER")
        self.password = password if password is not None else os.getenv("POSTGRES_PASSWORD")
        self.dbname = dbname if dbname is not None else os.getenv("POSTGRES_DB")
        self.host = host if host is not None else os.getenv("DB_HOST")
        self.port = port if port is not None else os.getenv("DB_PORT")
    
    # Connect to database
    def __connect(self):
        return super(Table, self)._connect()
    
    # Check info on connection
    def __con(self):
        return super(Table, self)._con
    
    # Fetch data from sql query
    def fetch_data(self, sql, coerce_float=False, parse_dates=None):
        
        con = self.__connect()
        
        # Fetch fresh data
        data = pd.read_sql_query(sql=sql, con=con, coerce_float=coerce_float, parse_dates=parse_dates)

        # Replace None with np.nan
        data.fillna(np.nan, inplace=True)
        
        # Close db connection
        con.close()

        return data
    
    # Get names of column
    def get_names(self):
        
        # Specific query to retrieve table names
        sql = "SELECT * FROM information_schema.columns WHERE table_name = N'{}'".format(self.table)
        
        # Run query and extract
        con = self.__connect()
        data = pd.read_sql_query(sql, con)
        column_series = data['column_name']
        con.close()
    
        return column_series
    
    # Get types of columns, returns dict
    def get_types(self):
        
        # Specific query to retrieve table names
        sql = '''SELECT column_name, 
        CASE 
            WHEN domain_name is not null then domain_name
            WHEN data_type='character varying' THEN 'varchar('||character_maximum_length||')'
            WHEN data_type='character' THEN 'char('||character_maximum_length||')'
            WHEN data_type='numeric' THEN 'numeric'
            ELSE data_type
        END AS type
        FROM information_schema.columns WHERE table_name = 'permits_raw';
        '''
        
        # Run query and extract
        con = self.__connect()
        data = pd.read_sql_query(sql, con)
        con.close()
        
        types_dict = dict(zip(data['column_name'], data['type'].str.upper()))
        
        return types_dict

    # Standardize column names using dictionary of character replacements
    def reformat_names(self, replace_map):
        
        series = self.get_names()
        
        def replace_chars(text):
            for oldchar, newchar in replace_map.items():
                text = text.replace(oldchar, newchar).lower()
            return text
        
        return series.apply(replace_chars)
    
    # Update column names in db table
    def update_names(self, replace_map):
        
        # Extract current columns in table
        old_columns = self.get_names()
        
        # Create list of reformatted columns to replace old columns 
        new_columns = self.reformat_names(replace_map)
    
        # SQL query string to change column names
        sql = 'ALTER TABLE {} '.format(self.table) + 'RENAME "{old_name}" to {new_name};'

        sql_query = []

        # Iterate through old column names and replace each with reformatted name 
        for idx, name in old_columns.iteritems():
            sql_query.append(sql.format(old_name=name, new_name=new_columns[idx]))
            
        # Join list to string
        sql_query = '\n'.join(sql_query)
        
        # Execute query against database
        con = self.__connect()
        try:
            cur = con.cursor()
            cur.execute(sql_query)
            con.commit()
            cur.close()
            print('Updated table "{}".'.format(self.table))
        except Exception as e:
            con.rollback()
            print('Error:\n', e)
        finally:
            if con is not None:
                con.close()
                
    # Add new columns to database
    def add_columns(self, data):

        # Get names of current columns in PostgreSQL table
        current_names = self.get_names()

        # Get names of updated table not in current table
        updated_names = data.columns.tolist()
        new_names = list(set(updated_names) - set(current_names))

        # Check names list is not empty
        if not new_names:
            print("Table columns are already up to date.")
            return

        # Format strings for query
        alter_table_sql = "ALTER TABLE {db_table}\n"
        add_column_sql = "\tADD COLUMN {column} TEXT,\n"

        # Create a list and append ADD column statements
        sql_query = [alter_table_sql.format(db_table=self.table)]
        for name in new_names:
            sql_query.append(add_column_sql.format(column=name))

        # Join into one string
        sql_query = ''.join(sql_query)[:-2] + ";"

        # Execute query against database
        con = self.__connect()
        try:
            cur = con.cursor()
            cur.execute(sql_query)
            con.commit()
            cur.close()
            print('Updated table "{}".'.format(self.table))
        except Exception as e:
            con.rollback()
            print('Error:\n', e)
        finally:
            if con is not None:
                con.close()

    # Compare order of columns in dataframe against order of columns in database                
    def compare_column_order(self, data):
        
        # Get columns from database as list
        db_columns = self.get_names().tolist()
        
        # Select columns from dataframe as list
        data_columns = data.columns.tolist()
        
        if set(data_columns) == set(db_columns):
            
            str1 = 'Dataframe columns match table "{}" '.format(self.table)
            
            if data_columns == db_columns:
                print(str1 + "and are in identical order.")
                return True
            else:
                print("but are not in identical order.")                
                return False            
        else:
            if len(data_columns) > len(db_columns):
                print('Dataframe has more columns than table "{}".'.format(self.table))
                return False
            else:
                print('Dataframe has less columns than table "{}".'.format(self.table))
                return False
    
    # Rearrange the order of columns in dataframe to match order in table
    def match_column_order(self, data):
        
        # Get columns from database as list
        db_columns = self.get_names().tolist()

        # Select columns from dataframe as list
        data_columns = data.columns.tolist()
        
        if set(data_columns) == set(db_columns):
            if data_columns != db_columns:
                print('Rearranged dataframe columns to match table "{}".'.format(self.table))
                return data[db_columns]
            else:
                print('Dataframe columns already match table "{}".'.format(self.table))
                return data
        else:
            if len(data_columns) > len(db_columns):
                print('Dataframe has more columns than table "{}".'.format(self.table))
                return data
            else:
                print('Dataframe has less columns than table "{}".'.format(self.table))
                return data
            
    # Builds a query to update postgres from a csv file
    def update_values(self, data, id_col, columns=None, sep=','):
        
        # Fetch data types
        types_dict = self.get_types()
        
        # Append id_col to selected columns
        columns = None if not columns else [id_col] + columns
        
        # CREATE TABLE query
        tmp_table = "tmp_" + self.table

        column_names = self.get_names().tolist() if not columns else columns
        
        # Subsets types_dict by columns argument and formats into string
        types_dict = types_dict if not columns else {key:value for key, value in types_dict.items() if key in set(columns)}
        names = ',\n\t'.join(['{key} {val}'.format(key=key, val=val) for key, val in types_dict.items()])
        
        # Build queries
        sql_create_tmp_table = 'DROP TABLE IF EXISTS {};\n\n'.format(tmp_table)
        sql_create_tmp_table = sql_create_tmp_table + 'CREATE TABLE {tmp_table} (\n\t{names}\n);\n\n' \
                                .format(tmp_table=tmp_table, names=names)
           
        sql_update_query = 'UPDATE {db_table}\n'.format(db_table=self.table)
        
        sql_set = ["SET "]
        
        for name in column_names:
            set_sql = "{name} = {tmp_name},\n\t".format(name=name, tmp_name=tmp_table + '.' + name)
            sql_set.append(set_sql)
            
        sql_set = ''.join(sql_set)
        sql_set = sql_set[:-3] + "\n"
        
        sql_from = "FROM {tmp_table}\nWHERE {db_table}.{id_col} = {tmp_table}.{id_col};\n\n" \
                            .format(tmp_table=tmp_table, db_table=self.table, id_col=id_col)
        sql_drop = 'DROP TABLE {};\n'.format(tmp_table)
        
        sql_query_1 = sql_create_tmp_table
        sql_query_2 = sql_update_query + sql_set + sql_from + sql_drop

        # Preview sql query to debug
        #print(sql_query_1 + "# Copy into temp_table\ncur.copy_from(...)\n\n"+ sql_query_2)
        
        # Run update query
        data_buffer = StringIO(data.to_csv(header=False, index=False))
        con = self.__connect()
        try:
            cur = con.cursor()
            
            # Create tmp_table
            cur.execute(sql_query_1)

            # Copy into temp_table
            data_buffer.read()
            cur.copy_from(file=data_buffer, table=tmp_table, columns=columns, sep=sep)
            data_buffer.close()
            
            # Update from temp_table into table and delete temp
            cur.execute(sql_query_2)
            con.commit()
            cur.close()
            print('Updated table "{}".'.format(self.table))
        except Exception as e:
            con.rollback()
            print('Error:\n', e)
        finally:
            if con is not None:
                con.close()

In [608]:
permits = Table(table="permits_raw")
data = permits.fetch_data(sql="SELECT * FROM permits_raw;")

In [609]:
permits.update_values(data, id_col="pcis_permit_no", columns=['assessor_book', 'latitude'])

Updated table "permits_raw".


In [414]:
### Rewrite to save table as csv, not dataframe
# Save csv with option to match order of columns in postgres
def save_csv(data, path, index=False):

    # Check unique columns
    if data.columns.tolist() != data.columns.unique().tolist():
        raise IndexError("Dataframe has duplicate columns.")

    if index:
        warnings.warn('Setting "index=True" may cause problems when importing from csv file.')


    # Write to csv
    data.to_csv(path, index=False)

In [400]:
permits._con

Connected as user "postgres" to database "permits" on http://localhost:5432.


In [401]:
permits.get_names()[:3]

0      assessor_book
1      assessor_page
2    assessor_parcel
Name: column_name, dtype: object

In [260]:
# Map of character replacements
replace_map = {' ': '_', '-': '_', '#': 'No', '/': '_', 
               '.': '', '(': '', ')': '', "'": ''}

permits.reformat_names(replace_map)[:3]

0      assessor_book
1      assessor_page
2    assessor_parcel
Name: column_name, dtype: object

In [261]:
permits.update_names(replace_map)

Connecting...
Executing query on table "permits_raw"...
Error:
 column "assessor_book" of relation "permits_raw" already exists



In [262]:
permits.add_columns(data)

Table columns are already up to date.


In [263]:
permits.get_names()[:3]

0      assessor_book
1      assessor_page
2    assessor_parcel
Name: column_name, dtype: object

In [264]:
permits.compare_column_order(data)

Dataframe columns match table "permits_raw" and are in identical order.


True

In [349]:
permits.match_column_order(data).head();

Dataframe columns already match table "permits_raw".


In [266]:
path = root_dir + "/data/interim/test.csv"

save_csv(data, path, index=True)

  if __name__ == '__main__':


In [348]:
types_dict = {'status':'VARCHAR(50)', 'permit_type':'VARCHAR(50)', 'permit_sub_type':'VARCHAR(50)', 
                'permit_category':'VARCHAR(50)', 'initiating_office':'VARCHAR(50)', 
                'license_type':'VARCHAR(50)', 'zone':'VARCHAR(50)', 'census_tract':'VARCHAR(50)', 
                'applicant_relationship':'VARCHAR(50)', 'block':'VARCHAR(50)', 'lot':'VARCHAR(50)', 
                'reference_no_old_permit_no':'VARCHAR(50)','pcis_permit_no':'VARCHAR(50)', 
               'address_fraction_start': 'CHAR(3)', 'address_fraction_end': 'CHAR(3)', 
                'street_direction': 'CHAR(1)', 'street_name': 'VARCHAR(50)', 'street_suffix': 'VARCHAR(10)',
               'suffix_direction': 'VARCHAR(10)', 'unit_range_start': 'VARCHAR(50)', 'unit_range_end': 'VARCHAR(50)',
               'work_description': 'TEXT', 'floor_area_la_zoning_code_definition': 'VARCHAR(10)', 
               'contractors_business_name': 'VARCHAR(100)', 'contractor_address': 'VARCHAR(100)',
               'contractor_city': 'VARCHAR(50)', 'contractor_state': 'CHAR(2)', 'license_type': 'VARCHAR(10)', 
               'principal_first_name': 'VARCHAR(50)', 'principal_middle_name': 'VARCHAR(50)', 
                'principal_last_name': 'VARCHAR(50)', 'applicant_first_name': 'VARCHAR(50)', 
                'applicant_last_name': 'VARCHAR(50)', 'applicant_business_name': 'VARCHAR(100)',
               'applicant_address_1': 'VARCHAR(50)', 'applicant_address_2': 'VARCHAR(50)', 
                'applicant_address_3': 'VARCHAR(50)', 'occupancy': 'VARCHAR(50)', 
                'floor_area_la_building_code_definition': 'VARCHAR(10)', 'census_tract': 'VARCHAR(10)',
                'latitude_longitude': 'VARCHAR(50)', 'assessor_parcel': 'CHAR(3)', 'tract': 'VARCHAR(200)',
              'assessor_book': 'SMALLINT', 'assessor_page': 'SMALLINT', 'council_district': 'SMALLINT', 
                'project_number': 'SMALLINT', 'address_start': 'INTEGER', 
                'address_end': 'INTEGER', 'no_of_residential_dwelling_units': 'SMALLINT', 
                'no_of_accessory_dwelling_units': 'SMALLINT', 'no_of_stories': 'SMALLINT', 
                'license_no': 'INTEGER', 'zip_code': 'INTEGER', 'existing_code': 'SMALLINT', 
                'proposed_code': 'SMALLINT', 'valuation':'NUMERIC', 'latitude':'NUMERIC', 'longitude':'NUMERIC',
             'status_date': 'DATE', 'issue_date': 'DATE', 'license_expiration_date': 'DATE', 'event_code':'VARCHAR(50)',
             'full_address':'VARCHAR(100)'}

In [416]:
sql = permits.update_values(data, id_col='pcis_permit_no')
#print(sql)

#for key, value in types_dict.items():
    #print(''.join([key, ' ', value, ',']))


AttributeError: 'DataFrame' object has no attribute 'tolist'

In [189]:
permits_fetch_

62