In [None]:
# default_exp core

In [None]:
#hide
%load_ext autoreload
%autoreload 2

# Core

In [None]:
#export
import os
import random
import uuid
from io import StringIO
from typing import Collection, Dict, Optional

import pandas as pd
import psycopg2
from dotenv import load_dotenv

# Type alias for psycopg2 connection objects, used in the annotations below.
Connection = psycopg2.extensions.connection

In [None]:
#exporti
def _get_connection_arguments(server: str,
                              dotenv_path: str = '.env',
                              **conn_kwargs) -> Dict[str,str]:
    """Collect the psycopg2 connection arguments for `server`.

    After loading `dotenv_path` into the environment, each of host, port,
    database, user and password is read from the variable
    POSTGRES_{VARIABLE}_{SERVER} (all upper-case). Variables that are not set
    map to None. Entries in `conn_kwargs` take precedence over the values
    read from the environment.
    """
    load_dotenv(dotenv_path)
    suffix = server.upper()
    arguments = {}
    for name in ('host', 'port', 'database', 'user', 'password'):
        arguments[name] = os.getenv(f'POSTGRES_{name.upper()}_{suffix}')
    # Explicit keyword arguments override whatever the .env file provided.
    return {**arguments, **conn_kwargs}

In [None]:
#export
def create_connection(server: str,
                      dotenv_path: str = '.env',
                      **conn_kwargs) -> Connection:
    """Open a psycopg2 connection to the named server.

    Settings are looked up as POSTGRES_{VARIABLE}_{SERVER} in `dotenv_path`
    (via `_get_connection_arguments`); keyword arguments override them and
    are forwarded to `psycopg2.connect`.
    """
    return psycopg2.connect(
        **_get_connection_arguments(server, dotenv_path, **conn_kwargs))

The first step is to create a connection to the database. The connection settings are read from a .env file whose variables follow the convention POSTGRES_{VARIABLE}_{SERVER_NAME}, where VARIABLE is one of HOST, PORT, DATABASE, USER or PASSWORD, as in the following example:

In [None]:
# Sample .env file following the POSTGRES_{VARIABLE}_{SERVER_NAME} convention
!cat ../.env_sample

POSTGRES_HOST_LOCAL=localhost
POSTGRES_PORT_LOCAL=5432
POSTGRES_DATABASE_LOCAL=test
POSTGRES_USER_LOCAL=jose
POSTGRES_PASSWORD_LOCAL=mypwd


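We can then create a connection to the database using `create_connection` by specifying the server name and the path to the .env file. The server name is upper-cased when looking up the variables, so `'local'` matches `POSTGRES_*_LOCAL`. Any additional keyword arguments are passed on to `psycopg2.connect` and take precedence over the values in the file.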

In [None]:
connection = create_connection('local', '../.env_sample')
# The cursor is closed when the `with` block ends, even if the query fails;
# `connection` itself stays open because later cells keep using it.
with connection.cursor() as cursor:
    cursor.execute('select now()')
    print(cursor.fetchall())

[(datetime.datetime(2021, 1, 21, 16, 57, 37, 19648, tzinfo=psycopg2.tz.FixedOffsetTimezone(offset=-360, name=None)),)]


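A connection can also be passed to pandas to query a table. The helper below opens a fresh connection, prints the contents of `test_table` and closes the connection again. It is used throughout this notebook to check what has actually been committed.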

In [None]:
import pandas as pd

def print_test_table_contents():
    """Print every row of test_table, read through a fresh connection.

    Because a new connection is opened each time, the output reflects what
    other sessions can see, which the cells below use to show whether an
    insert was persisted.
    """
    connection = create_connection('local', '../.env_sample')
    try:
        print(pd.read_sql_query('SELECT * FROM test_table', connection))
    finally:
        # Release the connection even if the query raises.
        connection.close()

print_test_table_contents()

Empty DataFrame
Columns: [x, y]
Index: []


In [None]:
#export
def append_df_to_table(conn: Connection,
                       table: str,
                       df: pd.DataFrame) -> None:
    """Append a dataframe to an existing postgresql table.

    The rows are streamed with ``COPY ... FROM STDIN`` in CSV format, so string
    values containing commas, quotes, tabs, newlines or backslashes are
    transferred verbatim. The text format previously used treated
    backslashes as escapes and did not understand the CSV quoting pandas adds.
    Values are matched to the table's columns by position, so `df` must list
    its columns in the table's column order. Missing values (NaN/None) are
    written as empty unquoted fields, which COPY's CSV format loads as NULL.

    Nothing is committed here; the caller owns the transaction.
    """
    # `table` is interpolated as-is (unquoted, so `schema.table` works);
    # only pass trusted identifiers, never user input.
    copy_sql = f'COPY {table} FROM STDIN WITH (FORMAT csv)'
    with StringIO() as buffer, conn.cursor() as cur:
        df.to_csv(buffer, header=False, index=False)
        buffer.seek(0)
        cur.copy_expert(copy_sql, buffer)

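If we have an already existing table and want to append a pandas dataframe to it, we can use `append_df_to_table`. The dataframe's columns are matched to the table's columns by position, so they must be in the same order.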

In [None]:
# One tuple per row; column names match test_table (x, y).
df = pd.DataFrame([(1, 'ABC'), (2, 'DEF'), (3, 'GHI')], columns=['x', 'y'])
df

Unnamed: 0,x,y
0,1,ABC
1,2,DEF
2,3,GHI


In [None]:
# `connection` is never committed, so these rows only exist inside its open transaction.
append_df_to_table(connection, 'test_table', df)

We can check that the data has been inserted.

In [None]:
# Same connection (and transaction) as the insert, so the uncommitted rows are visible here.
pd.read_sql_query('select * from test_table', connection)

Unnamed: 0,x,y
0,1,ABC
1,2,DEF
2,3,GHI


However, these changes have not been committed yet. Unless we call `connection.commit()`, the transaction is discarded when the connection is closed.

In [None]:
# Closing without commit() discards the pending insert.
connection.close()

In [None]:
# A fresh connection shows the table is empty again.
print_test_table_contents()

Empty DataFrame
Columns: [x, y]
Index: []


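Another option is to use the connection as a context manager, which commits the changes when the `with` block exits normally. Leaving the `with` block does not close the connection, so it still has to be closed explicitly.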

In [None]:
conn = create_connection('local', '../.env_sample')
try:
    with conn:  # commits when the block finishes, rolls back if it raises
        append_df_to_table(conn, 'test_table', df)
finally:
    # Leaving `with conn` ends the transaction, not the connection, so close it
    # here; the `finally` guarantees this even when the append fails.
    conn.close()

In [None]:
# The rows appended inside the `with` block were committed.
print_test_table_contents()

   x    y
0  1  ABC
1  2  DEF
2  3  GHI


If any append inside the transaction fails, the whole transaction is rolled back, so not even `valid_df` is inserted; otherwise it is committed.

In [None]:
valid_df = pd.DataFrame({'x': [4], 'y': ['JKL']})
invalid_df = pd.DataFrame({'x': ['a'], 'y': [1]})
conn = create_connection('local', '../.env_sample')
try:
    # The exception must propagate out of `with conn` for the context manager
    # to roll back; catching it inside the block would make it try to commit.
    with conn:
        append_df_to_table(conn, 'test_table', valid_df)
        print('Successfully inserted valid_df')
        append_df_to_table(conn, 'test_table', invalid_df)
        print('Successfully inserted invalid_df')
except psycopg2.Error:
    print('Transaction failed')
finally:
    conn.close()

Successfully inserted valid_df
Transaction failed


In [None]:
# Nothing from the failed transaction was kept, not even valid_df.
print_test_table_contents()

   x    y
0  1  ABC
1  2  DEF
2  3  GHI


In [None]:
#export
def update_table_from_df(conn: Connection,
                         table: str,
                         df: pd.DataFrame,
                         join_cols: Collection[str],
                         update_cols: Optional[Collection[str]] = None,
                         extra_where: Optional[str] = None) -> int:
    """Updates a postgresql table using the contents of a dataframe.

    The dataframe is copied into a temporary table created with the same
    column types as `table`. `table` is then updated from it with a single
    UPDATE ... FROM, and the temporary table is dropped afterwards.
    Nothing is committed here; the caller owns the transaction.

    join_cols: columns used in the WHERE statement to match rows of `df` to
        rows of `table`. A single column name may be passed as a string.
    update_cols: columns to overwrite. If None (the default) every column of
        `df` that is not in join_cols is updated. A single name may be a string.
    extra_where: optional SQL condition placed at the start of the WHERE
        statement. The target table is aliased as `old_table` and the dataframe
        rows as `new_table`.

    Returns the number of rows updated (the cursor's rowcount).

    Raises ValueError if join_cols is empty or no column is left to update.

    Table and column names are interpolated into the SQL unquoted, so they
    must come from trusted code, never from user input.
    """
    # A bare string is itself a Collection[str]; normalise it so that
    # join_cols='id' means ['id'] rather than ['i', 'd'].
    if isinstance(join_cols, str):
        join_cols = [join_cols]
    join_cols = list(join_cols)
    if not join_cols:
        raise ValueError('join_cols must contain at least one column')
    if update_cols is None:
        update_cols = df.columns.drop(join_cols)
    elif isinstance(update_cols, str):
        update_cols = [update_cols]
    update_cols = list(update_cols)
    if not update_cols:
        raise ValueError('there are no columns to update')
    # A uuid avoids clashing with temp tables left behind by earlier calls in
    # the same session (temp tables persist for the session unless dropped).
    temp_table = f'temp_replacements_{uuid.uuid4().hex}'
    create_query = f"""
        CREATE TEMP TABLE {temp_table}
        AS
        SELECT {', '.join(df.columns)}
        FROM {table}
        LIMIT 0
    """
    def _create_equals_statements(cols, join_str, left_prefix='old_table.'):
        statements = [f'{left_prefix}{col} = new_table.{col}' for col in cols]
        return join_str.join(statements)
    set_statement = _create_equals_statements(update_cols, join_str=',\n\t', left_prefix='')
    where_statement = _create_equals_statements(join_cols, join_str='\n\tAND ')
    if extra_where is not None:
        where_statement = extra_where + '\n\tAND ' + where_statement
    update_query = f"""
        UPDATE {table} AS old_table
        SET {set_statement}
        FROM {temp_table} new_table
        WHERE {where_statement}
    """
    with conn.cursor() as cur:
        cur.execute(create_query)
        append_df_to_table(conn, temp_table, df)
        cur.execute(update_query)
        n_updated = cur.rowcount
        # Drop the staging table now instead of letting it live for the session.
        cur.execute(f'DROP TABLE {temp_table}')
    return n_updated

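We can also update existing rows from a pandas dataframe with `update_table_from_df`. Rows are matched on `join_cols` and, by default, all the remaining columns of the dataframe are overwritten.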

In [None]:
# Replacement value for the row whose x is 2.
new_values = pd.DataFrame([(2, 'MNO')], columns=['x', 'y'])
new_values

Unnamed: 0,x,y
0,2,MNO


In [None]:
conn = create_connection('local', '../.env_sample')
try:
    with conn:  # commit the update when the block finishes
        # join_cols is a collection of column names, so pass a list even for one column.
        update_table_from_df(conn, 'test_table', new_values, join_cols=['x'])
finally:
    conn.close()

In [None]:
# The row with x=2 now has y='MNO'.
print_test_table_contents()

   x    y
0  1  ABC
1  3  GHI
2  2  MNO
