# Import libraries

In [1]:
import psycopg2
from psycopg2 import sql
import pandas as pd
import numpy as np
from datetime import datetime
import os
from dotenv import load_dotenv
import json
from tqdm import tqdm
from tabulate import tabulate
from collections import defaultdict
from mylogger import getmylogger

# Initialization and database connection

In [2]:
load_dotenv()

# Connection parameters are read from the environment (.env is loaded above).
# Values are passed through unchanged: wrapping them in str() would turn an
# unset variable into the literal string "None" (e.g. a database named "None").
# psycopg2 drops keyword arguments whose value is None, so an unset variable
# simply leaves that parameter to libpq's own defaults.
db_connection_dict = {
    'dbname': os.getenv('DB_NAME'),
    'user': os.getenv('DB_USER'),
    'password': os.getenv('DB_PASSWORD'),
    'host': os.getenv('DB_HOST'),
    'port': os.getenv('DB_PORT'),
    # Resolve unqualified table names in the "colombia" schema.
    'options': """-c search_path="colombia" """
}

In [3]:
class database:
    """Open a PostgreSQL connection with psycopg2 when instantiated.

    Attributes:
        conn: the open psycopg2 connection, or ``None`` if connecting failed
            (the error is printed, not raised).
    """

    def __init__(self, param_dict):
        self.conn = self.connect_bd(param_dict)

    def connect_bd(self, param_dict):
        """Connect via ``psycopg2.connect(**param_dict)`` and set UTF-8 client encoding.

        Args:
            param_dict: keyword arguments forwarded to ``psycopg2.connect``.

        Returns:
            The open connection, or ``None`` if any step failed.
        """
        conn = None
        try:
            conn = psycopg2.connect(**param_dict)
            conn.set_client_encoding('UTF8')
            print("Connection successful")
        except Exception as error:
            # psycopg2.DatabaseError is a subclass of Exception, so this covers it.
            print(error)
            if conn is not None:
                # The connection was opened but a later step failed: close it
                # so the server session is not leaked when we return None.
                conn.close()
            conn = None

        return conn

In [4]:
logger = getmylogger(__name__)
conn = database(db_connection_dict).conn
# database() prints the error and leaves .conn as None when connecting fails;
# stop here with an explicit message instead of an AttributeError on None.cursor().
if conn is None:
    raise RuntimeError("Database connection failed; check the DB_* settings in .env")
cursor = conn.cursor()

Connection successful


# Queries

In [5]:
# Configuration driving the query builders below; the code reads its
# "Joins", "Identifying_columns" and "Columns_to_update" sections.
# JSON is UTF-8 by specification, so don't depend on the platform's locale encoding.
with open('query_parameters.json', encoding='utf-8') as json_file:
    query_parameters = json.load(json_file)

In [6]:
# Join keywords accepted from query_parameters.json. The join type is spliced
# into the SQL text verbatim (it cannot be an Identifier/Literal), so it must
# come from this fixed vocabulary.
_ALLOWED_JOIN_TYPES = frozenset({
    "inner",
    "left", "right", "full",
    "left outer", "right outer", "full outer",
})


def create_join_clause(parameters_dict, column_name):
    """Build the JOIN lines configured for one data column.

    Reads ``parameters_dict["Joins"][column_name]``, a mapping of
    ``table -> spec``. A spec with a non-empty ``"tables_to_join"`` list (of
    ``{"join_type", "name"}`` dicts) and a ``"primary_key"`` produces, per
    listed target, the line::

        <join_type> join "<target>" on "<table>"."<primary_key>" = "<target>"."<fk>"

    where ``<fk>`` is ``Joins[column_name][<target>]["foreign_keys"][<table>]``.

    Args:
        parameters_dict: parsed query_parameters.json.
        column_name: key under ``"Joins"`` to build the clause for.

    Returns:
        A psycopg2 ``sql`` composable with one line per join, or ``sql.SQL("")``
        when nothing needs joining (including a missing ``"Joins"`` section or
        column entry).

    Raises:
        ValueError: if a ``join_type`` is not one of ``_ALLOWED_JOIN_TYPES``
            (case and extra whitespace are ignored).
    """
    joins = (parameters_dict.get("Joins") or {}).get(column_name) or {}
    parts = []

    for table, spec in joins.items():
        tables_to_join = spec.get("tables_to_join")
        if not tables_to_join:
            continue
        # Same for every target joined from this table.
        primary_key = spec.get("primary_key").strip()

        for target in tables_to_join:
            join_type = " ".join(target.get("join_type").split()).lower()
            if join_type not in _ALLOWED_JOIN_TYPES:
                raise ValueError(
                    "Unsupported join_type {!r} for table {!r} in Joins[{!r}]".format(
                        target.get("join_type"), table, column_name))
            target_name = target.get("name").strip()
            foreign_key = joins.get(target_name).get("foreign_keys").get(table).strip()
            parts.append(
                sql.SQL("{} join {} on {} = {}\n").format(
                    sql.SQL(join_type),
                    sql.Identifier(target_name),
                    sql.Identifier(table, primary_key),
                    sql.Identifier(target_name, foreign_key),
                )
            )

    return sql.Composed(parts) if parts else sql.SQL("")

In [7]:
# Preview the JOIN clause generated for one configured column.
q = create_join_clause(query_parameters, 'Planted_area_ha')
rendered_joins = q.as_string(conn)
print(rendered_joins)

inner join "parcelwaves" on "parcels"."id" = "parcelwaves"."parcelid"



In [8]:
# Spreadsheet whose rows are written back by create_select_request; it must
# contain every column named in query_parameters.json.
excel_file = os.path.join("missing_data", "missing_data.xlsx")
missing_data_df = pd.read_excel(io=excel_file)

In [9]:
def execute_query(cursor, query):
    """Execute ``query`` on ``cursor``, committing on success and rolling back on failure.

    Errors are printed rather than raised so a batch loop can carry on.

    Args:
        cursor: DB-API cursor; its ``connection`` is the one committed or
            rolled back.
        query: SQL string or psycopg2 ``sql`` composable.

    Returns:
        True if the statement ran and was committed, False if it raised and
        the transaction was rolled back.
    """
    # End the transaction on the cursor's own connection rather than the
    # notebook-global ``conn``, so this works for any cursor passed in.
    connection = cursor.connection
    try:
        cursor.execute(query)
    except Exception as e:
        print(e)
        connection.rollback()
        return False
    connection.commit()
    return True

In [10]:
def create_where_condition(cursor, where_condition, table_and_field_to_check, value_to_check):
    """Append one ``table.field = value`` test to a WHERE clause.

    The first test (when ``where_condition`` renders to an empty string)
    starts the clause with ``WHERE``; later tests are joined with `` AND``.

    Args:
        cursor: connection or cursor, used only to render ``where_condition``.
        where_condition: psycopg2 ``sql`` composable built so far
            (``sql.SQL("")`` before the first call).
        table_and_field_to_check: name parts for ``sql.Identifier``,
            e.g. ``(table, field)``.
        value_to_check: right-hand side, passed as ``sql.Literal``.
            NOTE(review): ``None`` renders as ``= NULL``, which never matches.

    Returns:
        ``sql.Composed`` of ``where_condition`` followed by the new test.
    """
    starts_clause = not where_condition.as_string(cursor)
    template = "WHERE {field_name}={value}" if starts_clause else " AND {field_name}={value}"
    new_test = sql.SQL(template).format(
        field_name=sql.Identifier(*table_and_field_to_check),
        value=sql.Literal(value_to_check),
    )
    return sql.Composed([where_condition, new_test])

In [11]:
def update_request_by_id(cursor, table, field, value, select_query):
    """Set ``field = value`` in ``table`` for every id returned by ``select_query``.

    The UPDATE is logged at INFO level on a single line (newlines replaced by
    spaces) and then run through ``execute_query``.

    Args:
        cursor: cursor used to render and execute the statement.
        table: name of the table to update.
        field: column to set.
        value: new value, passed as ``sql.Literal``.
        select_query: psycopg2 ``sql`` composable selecting the target ids;
            embedded as ``WHERE id in(...)``.

    Returns:
        Whatever ``execute_query`` returns.
    """
    pieces = [
        sql.SQL("UPDATE {table} SET {field} = {value} WHERE id in").format(
            table=sql.Identifier(table),
            field=sql.Identifier(field),
            value=sql.Literal(value),
        ),
        sql.SQL("("),
        select_query,
        sql.SQL(")"),
    ]
    update_statement = sql.Composed(pieces)

    single_line = update_statement.as_string(cursor).replace('\n', ' ')
    logger.info(single_line)

    return execute_query(cursor, update_statement)
    

In [12]:
def pretty_table_from_query_result(cursor, result):
    """Render rows fetched from ``cursor`` as a plain-text table.

    Args:
        cursor: cursor that produced ``result``; the first element of each
            ``cursor.description`` entry is used as a column header.
        result: sequence of row tuples, e.g. from ``cursor.fetchall()``.

    Returns:
        The ``tabulate`` rendering (headers are shown even when ``result`` is
        empty), or ``""`` when the cursor has no result description.
    """
    if cursor.description is None:
        return ""
    # Rows are passed through as-is: keying them by column name would merge
    # columns that share a name (e.g. two "id" columns from a join).
    columns = [desc[0] for desc in cursor.description]
    return tabulate(result, headers=columns)

In [21]:
def create_select_request(cursor, parameters_dict, dataframe):
    """Write values from ``dataframe`` into the database, locating rows by progressive narrowing.

    For every row of ``dataframe`` and every entry of
    ``parameters_dict["Columns_to_update"]`` (dataframe column -> target
    ``table_name``/``field_name``), a ``SELECT DISTINCT <table>.id`` query is
    built. The identifying columns from
    ``parameters_dict["Identifying_columns"][column]`` are then added to its
    WHERE clause one at a time (string values are stripped and lower-cased),
    and the query is re-run after each addition. When it returns exactly one
    id, that id is updated via ``update_request_by_id`` if the row has a value
    for the column. Narrowing stops there either way. Outcomes are logged at
    DEBUG level.

    Args:
        cursor: psycopg2 cursor used for every query.
        parameters_dict: parsed query_parameters.json (keys
            "Columns_to_update", "Identifying_columns", "Joins").
        dataframe: records to write; every column named in the parameters is
            read as an attribute of the ``itertuples()`` rows.

    Returns:
        None.
    """
    columns_to_update = parameters_dict.get("Columns_to_update")

    for row in tqdm(dataframe.itertuples(), total=dataframe.shape[0]):
        for key1, value1 in columns_to_update.items():

            table_to_update = value1.get("table_name")
            field_to_update = value1.get("field_name")
            value_to_update = getattr(row, key1)

            join_clause = create_join_clause(parameters_dict, key1)
            identifying_columns = parameters_dict.get("Identifying_columns").get(key1)

            # SELECT/FROM does not depend on the identifying columns: build it once.
            select_fields = sql.SQL("SELECT DISTINCT {fields}\nFROM {table}\n").format(
                fields=sql.Identifier(table_to_update, "id"),
                table=sql.Identifier(table_to_update))

            values_to_check = []
            where_condition = sql.SQL("")

            for position, (key2, value2) in enumerate(identifying_columns.items()):
                table_and_field_to_check = (value2.get("table_name"), value2.get("field_name"))

                value_to_check = getattr(row, key2)
                if isinstance(value_to_check, str):
                    value_to_check = value_to_check.strip().lower()
                values_to_check.append(value_to_check)

                where_condition = create_where_condition(
                    cursor, where_condition, table_and_field_to_check, value_to_check)
                query = sql.Composed([select_fields, join_clause, where_condition])
                query_text = query.as_string(cursor).replace('\n', ' ')

                execute_query(cursor, query)
                try:
                    result = cursor.fetchall()
                except psycopg2.Error as error:
                    # execute_query prints and rolls back a failed statement
                    # without raising, so fetchall() is expected to fail here.
                    # Every later query in this loop keeps the same WHERE
                    # prefix, so skip this column for this row instead of
                    # aborting the whole run.
                    logger.debug("Select query failed ({}): {}".format(error, query_text))
                    break

                if position == 0:
                    logger.debug("-----------------------------------------------")

                if not result:
                    logger.debug("No result has been returned by query:")
                    logger.debug(query_text)
                elif len(result) == 1:
                    # NOTE(review): falsy values such as 0 or "" are treated as
                    # "no value", like NaN — confirm that is intended.
                    if value_to_update and pd.notna(value_to_update):
                        logger.debug("Select Query:")
                        logger.debug(query_text)
                        # Render before the UPDATE reuses the cursor and resets its description.
                        logger.debug("Query results:\n" + pretty_table_from_query_result(cursor, result))
                        logger.debug("One result has been returned. Making the update.")
                        logger.debug("Update Query:")
                        update_request_by_id(cursor, table_to_update, field_to_update, value_to_update, query)
                    else:
                        # str() each value: identifying values may be numbers,
                        # and str.join only accepts strings.
                        logger.debug("No value to update for {} {}.".format(
                            ', '.join(map(str, values_to_check)),
                            'value' if len(values_to_check) <= 1 else 'values'))
                    # Exactly one id matched; extra AND conditions can only keep
                    # that id or none, so further narrowing is pointless.
                    break
                else:
                    logger.debug("Select Query:")
                    logger.debug(query_text)
                    logger.debug("Multiple results have been returned by query:")
                    logger.debug("Query results:\n" + pretty_table_from_query_result(cursor, result))

    

In [18]:
# Process every spreadsheet row (tqdm shows progress); details go to the DEBUG log.
create_select_request(cursor=cursor, parameters_dict=query_parameters, dataframe=missing_data_df)

100%|██████████| 191/191 [00:06<00:00, 31.05it/s]
