## Imports

In [67]:
import psycopg2
import numpy as np
import math

## Setting up ENV and GLOBAL variables

In [99]:

dbname = "census"
user = "nightfury"
host = "localhost"
password = ""
db_relation = "adult"
reference_dataset = db_relation
target_dataset = "target_dataset"


# measure_attributes=['age','fnlwgt', 'education_num','capital_gain','capital_loss','hours_per_week']
# groupby_attributes=['workclass','education','occupation','relationship','race','sex','native_country','salary_range']
# aggregate_functions=['sum','avg','max','min','count']

measure_attributes=['age','capital_gain','hours_per_week']
groupby_attributes=['workclass','relationship','sex']
aggregate_functions=['avg','count']

#Connecting with local db
try:
    conn = psycopg2.connect(f"dbname='{dbname}' user='{user}' host='{host}' password='{password}'")
    cur = conn.cursor()
except Exception as e:
    print (f"Unable to connect to the database. Error: {str(e)}")

#To get top k utility views
K = 5

## Util functions

In [97]:

def get_cursor(conn):
    cur = conn.cursor()
    return cur

def get_all_views(groupby_attributes, measure_attributes, aggregate_functions):
    views = np.array(np.meshgrid(groupby_attributes, measure_attributes, aggregate_functions)).T.reshape(-1,3)
    return views

def query_generator(groupby_attribute, measure_attribute, aggregate_function, target_relation, reference_relation):
    target_query = f"""SELECT {groupby_attribute}, {aggregate_function}({measure_attribute}) FROM {target_relation} GROUP BY {groupby_attribute}"""
    reference_query = f"""SELECT {groupby_attribute}, {aggregate_function}({measure_attribute}) FROM {reference_relation} GROUP BY {groupby_attribute}"""
    return target_query, reference_query

def execute_get_query(cursor, query):
    cursor.execute(query)
    rows = cursor.fetchall()
    return rows

def get_key_based_values(target_result, reference_result):
    key_set = set()
    target_result_dict = {}
    reference_result_dict = {}
    final_target_values = []
    final_reference_values = []
    
    for key, value in target_result:
        key_set.add(key)
        target_result_dict[key] = float(value)
        
    for key, value in reference_result:
        key_set.add(key)
        reference_result_dict[key] = float(value)
        
    for key in key_set:
        final_target_values.append(target_result_dict.get(key, float(10e-20)) if target_result_dict.get(key, float(10e-20)) != 0.0 else float(10e-20))
        final_reference_values.append(reference_result_dict.get(key, float(10e-20)) if reference_result_dict.get(key, float(10e-20)) != 0.0 else float(10e-20))
    
    return final_target_values, final_reference_values

def calculate_kl_divergence(vector1, vector2):
    return sum(vector1[i] * math.log2(vector1[i]/vector2[i]) for i in range(len(vector1)))

def get_top_k_utility_views(kl_divergence_view_mapping_list, k):
    top_k_views = [view for kl_divergence, view in sorted(kl_divergence_view_mapping_list)[0:k]]
    return top_k_views


## Getting user input and setting the target_db

In [38]:
#test_query = select * from adult where relationship =' Unmarried'; 
try:
    query = input("Enter a SELECT query")
    #cur = get_cursor(conn)
    cur.execute(f"""DROP table IF EXISTS {target_dataset};""")
    print(f"""Target dataset create command: create table {target_dataset} as {query};""")
    cur.execute(f"""create table {target_dataset} as {query};""")
    conn.commit()
except Exception as e:
    print(f"Error in establishing target db. Error: {str(e)}")
    
print(f"The reference dataset is {reference_dataset}")

In [94]:
def get_kl_divergence_mapping_list(views):
    kl_divergence_view_mapping_list = []
    for groupby_attribute, measure_attribute, aggregate_function in views:
        target_query, reference_query = query_generator(groupby_attribute, measure_attribute, aggregate_function, target_dataset, reference_dataset)
        target_result = execute_get_query(get_cursor(conn), target_query)
        reference_result = execute_get_query(get_cursor(conn), reference_query)
        target_values, reference_values = get_key_based_values(target_result, reference_result)        
        kl_divergence = calculate_kl_divergence(target_values, reference_values)
        kl_divergence_view_mapping_list.append((kl_divergence, (groupby_attribute, measure_attribute, aggregate_function)))
        
    return kl_divergence_view_mapping_list
        
views = get_all_views(groupby_attributes, measure_attributes, aggregate_functions)
kl_divergence_view_mapping_list = get_kl_divergence_mapping_list(views)

In [100]:
top_k_views = get_top_k_utility_views(kl_divergence_view_mapping_list,  K)
print(top_k_views)

[('workclass', 'age', 'count'), ('workclass', 'capital_gain', 'count'), ('workclass', 'hours_per_week', 'count'), ('sex', 'age', 'count'), ('sex', 'capital_gain', 'count')]
