# In [1]:  (notebook cell marker)
import os
import pandas as pd
import sqlalchemy
import networkx as nx
from sqlalchemy import create_engine, MetaData, inspect, event, text
from collections import deque
from queue import Queue
from collections import deque
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pandas as pd
# Connection settings are read from the environment; every entry has a
# placeholder fallback that is used when its variable is unset.
db_params = dict(
    host=os.environ.get('DB_HOST', 'default_host'),
    dbname=os.environ.get('DB_NAME', 'default_db_name'),
    user=os.environ.get('DB_USER', 'default_user'),
    password=os.environ.get('DB_PASSWORD', 'default_password'),
    port=int(os.environ.get('DB_PORT', '5432')),  # fallback port when DB_PORT is unset
)
# ANSI escape sequences used to colour console output
class Style:
    """Namespace of ANSI colour codes (regular weight) plus the reset code."""

    RED = '\033[0;31m'
    GREEN = '\033[0;32m'
    YELLOW = '\033[0;33m'
    WHITE = '\033[0;37m'
    RESET = '\033[0m'  # switches the terminal back to its default attributes
def get_table_columns(engine, schema_name, table_name, excluded_columns):
    """
    List the column names that information_schema.columns reports for
    ``schema_name.table_name``, leaving out every name that appears in
    ``excluded_columns``.
    """
    column_lookup = sqlalchemy.text("""
        SELECT column_name 
        FROM information_schema.columns 
        WHERE table_schema = :schema_name AND table_name = :table_name;
    """)
    bind_values = {'schema_name': schema_name, 'table_name': table_name}

    columns = []
    with engine.connect() as conn:
        for row in conn.execute(column_lookup, bind_values).fetchall():
            name = row[0]
            if name in excluded_columns:
                continue
            columns.append(name)

    return columns

# Establishing the connection.
# The URL is assembled with URL.create() instead of string formatting so that
# credentials containing characters such as '@', ':' or '/' are escaped
# properly rather than corrupting the connection string. db_url is therefore a
# sqlalchemy.engine.URL object; str(db_url) renders it with the password masked.
db_url = sqlalchemy.engine.URL.create(
    drivername="postgresql+psycopg2",
    username=db_params['user'],
    password=db_params['password'],
    host=db_params['host'],
    port=db_params['port'],
    database=db_params['dbname'],
)
engine = create_engine(db_url)
# Schema used by the queries in this script (DB_SCHEMA environment variable,
# with a placeholder fallback when it is unset).
schema_name = os.getenv('DB_SCHEMA', 'default_schema')

@event.listens_for(engine, "connect")
def set_search_path(dbapi_connection, connection_record):
    """
    Point each newly opened DBAPI connection at the configured schema.

    Registered as a "connect" listener on ``engine``, so it runs once for
    every raw connection the pool creates.

    Args:
        dbapi_connection: the raw DBAPI (psycopg2) connection.
        connection_record: the pool record for the connection; unused.
    """
    # A plain SET is undone when the surrounding transaction is rolled back,
    # so issue it in autocommit mode and restore the previous mode afterwards.
    existing_autocommit = dbapi_connection.autocommit
    dbapi_connection.autocommit = True
    try:
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute(f"SET search_path TO {schema_name};")
        finally:
            # Release the cursor even when the SET statement fails.
            cursor.close()
    finally:
        dbapi_connection.autocommit = existing_autocommit
# with conn.cursor() as cur:
#     cur.execute(f"SET search_path TO {schema_name};")

# Placeholder for empty fields in the .dump file
empty_placeholder = "\\N"
# JSON object substituted for fields that are None; replace it with a more
# suitable default when one is known
default_json = dict()

# Open one connection up front: reaching the print means the database answered.
with engine.connect() as connection:
    print(f"\t{Style.GREEN}Successfully connected to the database!{Style.RESET}")
    search_path_statement = text(f"SET search_path TO {schema_name};")
    connection.execute(search_path_statement)


def generate_truncate_sql(export_path, related_tables):
    """
    Write ``truncate_tables.sql`` into ``export_path``.

    The script holds one ``TRUNCATE TABLE ... restart identity cascade``
    statement per entry of ``related_tables``, each qualified with the
    module-level ``schema_name``. The location of the script is printed
    once it has been written.
    """
    script_path = os.path.join(export_path, "truncate_tables.sql")
    with open(script_path, 'w', encoding='utf-8') as script:
        script.writelines(
            f"TRUNCATE TABLE {schema_name}.{name} restart identity cascade;\n"
            for name in related_tables
        )
    print(f"Truncate script created at {script_path}")


def get_key_value(table, column, value, connection):
    """
    Look up ``value`` in ``table.column`` through a DBAPI connection.

    Returns the first column of the first matching row, or None when the
    query yields no row.
    """
    lookup_sql = f"SELECT {column} FROM {table} WHERE {column} = %s;"
    with connection.cursor() as cursor:
        cursor.execute(lookup_sql, [value])
        row = cursor.fetchone()
    if not row:
        return None
    return row[0]


# Connection settings, rebuilt from the same environment variables and
# fallbacks as the db_params definition near the top of the file.
db_params = {
    'host': os.environ.get('DB_HOST', 'default_host'),
    'dbname': os.environ.get('DB_NAME', 'default_db_name'),
    'user': os.environ.get('DB_USER', 'default_user'),
    'password': os.environ.get('DB_PASSWORD', 'default_password'),
    # DB_PORT arrives as text and is converted to an integer
    'port': int(os.environ.get('DB_PORT', '5432')),
}
def extract_relationships(excluded_tables,excluded_columns):
    """
    Collect the foreign-key relationships of the configured schema.

    Reads information_schema through the module-level ``engine`` and keeps
    only FOREIGN KEY constraints whose table lives in the module-level
    ``schema_name``.

    Args:
        excluded_tables: table names to ignore; a relationship is dropped
            when either of its tables is listed.
        excluded_columns: column names to ignore; a relationship is dropped
            when either of its columns is listed.

    Returns:
        A list of ``(primary_table, primary_column, foreign_table,
        foreign_column)`` tuples, where ``primary_*`` is the table/column
        that carries the foreign key and ``foreign_*`` is the table/column
        it references.
    """
    # Constraint names alone do not identify a constraint: the same name can
    # exist in several schemas (and on several tables). The joins therefore
    # also match on constraint schema and owning table, and the result is
    # limited to the schema this script works on.
    # NOTE: constraint_column_usage carries no column position, so a
    # multi-column foreign key still yields every column combination.
    foreign_key_query = sqlalchemy.text("""
    SELECT DISTINCT
        tc.table_name AS primary_table,
        kcu.column_name AS primary_column,
        ccu.table_name AS foreign_table_name,
        ccu.column_name AS foreign_column
    FROM
        information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.constraint_schema = kcu.constraint_schema
        AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
        ON ccu.constraint_name = tc.constraint_name
        AND ccu.constraint_schema = tc.constraint_schema
    WHERE
        tc.constraint_type = 'FOREIGN KEY'
        AND tc.table_schema = :schema_name;
    """)

    # Sets give constant-time membership checks regardless of what the
    # caller passed in (list, tuple, set, ...).
    skipped_tables = set(excluded_tables)
    skipped_columns = set(excluded_columns)

    relationships = []
    with engine.connect() as conn:
        result = conn.execute(foreign_key_query, {'schema_name': schema_name})
        for row in result:
            primary_table, primary_column, foreign_table, foreign_column = row

            if primary_table in skipped_tables or foreign_table in skipped_tables:
                continue  # one side of the relationship is an excluded table
            if primary_column in skipped_columns or foreign_column in skipped_columns:
                continue  # one side of the relationship is an excluded column
            relationships.append((primary_table, primary_column, foreign_table, foreign_column))

    return relationships

def generate_relationships_dict(relationships,excluded_columns,starting_table):
    """
    Group relationship tuples by their first element.

    Every ``(primary_table, primary_column, foreign_table, foreign_column)``
    entry whose two columns are both absent from ``excluded_columns`` is
    stored under ``primary_table`` as
    ``(foreign_table, primary_column, foreign_column)``, in input order.
    ``starting_table`` is accepted but not used.
    """
    grouped = {}
    for source, source_col, target, target_col in relationships:
        if source_col not in excluded_columns and target_col not in excluded_columns:
            grouped.setdefault(source, []).append((target, source_col, target_col))
    return grouped

def get_primary_key(engine, schema_name, table_name):
    """
    Retrieve the primary key column for a given table.

    Args:
        engine: SQLAlchemy engine used to run the lookup.
        schema_name: schema that contains the table.
        table_name: table whose primary key is wanted.

    Returns:
        The name of the primary-key column, or None when the table has no
        PRIMARY KEY constraint. For a composite key the first column (by
        ordinal position within the key) is returned.
    """
    # Matching on constraint_name alone can pick up a same-named constraint
    # from another schema (e.g. "users_pkey" existing in two schemas), so the
    # join also ties key_column_usage to the same schema and table. ORDER BY
    # makes the chosen column deterministic for composite keys.
    primary_key_query = sqlalchemy.text("""
        SELECT kcu.column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.constraint_schema = kcu.constraint_schema
        AND tc.table_schema = kcu.table_schema
        AND tc.table_name = kcu.table_name
        WHERE tc.table_schema = :schema_name
        AND tc.table_name = :table_name
        AND tc.constraint_type = 'PRIMARY KEY'
        ORDER BY kcu.ordinal_position
        LIMIT 1;
    """)

    with engine.connect() as conn:
        result = conn.execute(primary_key_query, {'schema_name': schema_name, 'table_name': table_name})
        primary_key_column = result.fetchone()
        return primary_key_column[0] if primary_key_column else None


def create_graph(relationships_dict):
    """
    Build a directed graph from a ``{table: [(foreign_table, primary_column,
    foreign_column), ...]}`` mapping.

    Each relation becomes an edge ``table -> foreign_table`` that stores the
    two column names in its ``primary_key`` and ``foreign_key`` attributes.
    """
    graph = nx.DiGraph()
    # Flatten the mapping into one stream of edges, keeping the input order.
    edges = (
        (source, target, source_col, target_col)
        for source, links in relationships_dict.items()
        for target, source_col, target_col in links
    )
    for source, target, source_col, target_col in edges:
        graph.add_edge(source, target, primary_key=source_col, foreign_key=target_col)
    return graph

# Assuming engine is your SQLAlchemy database engine
def extract_relationships_last(engine):
    """
    Return every foreign-key relationship visible through ``engine``.

    Args:
        engine: SQLAlchemy engine used to query information_schema.

    Returns:
        A list of ``(primary_table, primary_column, foreign_table,
        foreign_column)`` tuples, one per row of the query; ``primary_*`` is
        the side that carries the foreign key, ``foreign_*`` the side it
        references.
    """
    # A constraint name is only meaningful together with its schema, so the
    # joins also match on constraint_schema (and on the owning table for
    # key_column_usage). Joining on the name alone pairs unrelated
    # constraints that merely share a name.
    foreign_key_query = text("""
    SELECT DISTINCT
        tc.table_name AS primary_table,
        kcu.column_name AS primary_column,
        ccu.table_name AS foreign_table_name,
        ccu.column_name AS foreign_column
    FROM
        information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.constraint_schema = kcu.constraint_schema
        AND tc.table_name = kcu.table_name
        JOIN information_schema.constraint_column_usage AS ccu
        ON ccu.constraint_name = tc.constraint_name
        AND ccu.constraint_schema = tc.constraint_schema
    WHERE
        tc.constraint_type = 'FOREIGN KEY';
    """)

    with engine.connect() as conn:
        result = conn.execute(foreign_key_query)
        return [
            (row.primary_table, row.primary_column, row.foreign_table_name, row.foreign_column)
            for row in result
        ]

def create_graph_last(relationships):
    """
    Build a directed graph from ``(primary_table, primary_column,
    foreign_table, foreign_column)`` tuples.

    Every tuple adds the edge ``primary_table -> foreign_table`` carrying the
    column names as ``primary_key`` and ``foreign_key`` attributes.
    """
    graph = nx.DiGraph()
    for relation in relationships:
        source, source_col, target, target_col = relation
        edge_attrs = {'primary_key': source_col, 'foreign_key': target_col}
        graph.add_edge(source, target, **edge_attrs)
    return graph



def find_all_paths_from_node_last(graph, start_node):
    """
    Map every node other than ``start_node`` to the list of simple paths
    that lead from that node to ``start_node``.

    Nodes with no such path are still present, mapped to an empty list.
    """
    return {
        node: list(nx.all_simple_paths(graph, node, start_node))
        for node in graph.nodes()
        if node != start_node
    }
# NOTE(review): not referenced by any code visible in this file — confirm it
# is still needed before relying on it.
union_counts = 0
def generate_sql_queries_restructured(all_paths, schema_name, target_node, key_value, *, graph=None, key_column="id"):
    """
    Compose one SELECT statement per table found in ``all_paths``.

    For every path of a table, a query is built that selects the table's
    rows, LEFT JOINs along the path and filters on the target table's key.
    Several paths for one table are combined with UNION.

    Args:
        all_paths: mapping ``table -> list of paths``; each path is a
            sequence of table names whose consecutive pairs are edges of the
            graph. The query's FROM clause starts at ``path[0]``.
        schema_name: schema used to qualify every table reference.
        target_node: table whose key is compared against ``key_value``.
        key_value: value the target table's key column must equal.
        graph: graph whose edges carry ``primary_key`` / ``foreign_key``
            attributes; defaults to the module-level ``G``.
        key_column: column of ``target_node`` used in the filter.

    Returns:
        ``(query_dict, union_count_dict)``: the SQL text per table, and the
        number of UNIONs for each table that needed more than one path.
        Tables without a primary key or without any path get no query.
    """
    if graph is None:
        # Keep the original behaviour of reading the module-level graph.
        graph = G

    # Double embedded single quotes so the value stays a valid SQL string
    # literal instead of terminating it early.
    key_literal = str(key_value).replace("'", "''")

    query_dict = {}
    union_count_dict = {}  # number of UNIONs per table

    for table, paths in all_paths.items():
        primary_key_column = get_primary_key(engine, schema_name, table)
        if not primary_key_column:  # no query can be anchored without a key
            print(f"NO PRIMARY KEY COLUMN FOR {table}")
            continue

        # Both clauses are identical for every path of this table.
        select_clause = f'SELECT DISTINCT "{schema_name}"."{table}".*'
        where_clause = (
            f'WHERE "{schema_name}"."{target_node}"."{key_column}" = \'{key_literal}\' '
            f'AND "{schema_name}"."{table}"."{primary_key_column}" IS NOT NULL'
        )

        path_queries = []
        for path in paths:
            joins = []
            # Walk the path edge by edge, joining each table to its predecessor.
            for prev_table, current_table in zip(path, path[1:]):
                edge = graph[prev_table][current_table]
                primary_column = edge['primary_key']
                foreign_column = edge['foreign_key']
                join_condition = (
                    f'"{schema_name}"."{prev_table}"."{primary_column}" = '
                    f'"{schema_name}"."{current_table}"."{foreign_column}"'
                )
                joins.append(f'LEFT JOIN "{schema_name}"."{current_table}" ON {join_condition}')

            from_clause = f'FROM "{schema_name}"."{path[0]}"'
            path_queries.append(f'{select_clause} {from_clause} {" ".join(joins)} {where_clause}')

        # Combine the per-path queries with UNION when there is more than one.
        if len(path_queries) > 1:
            query_dict[table] = ' UNION '.join(path_queries)
            union_count_dict[table] = len(path_queries) - 1
        elif len(path_queries) == 1:
            print(f"Found one path for {Style.YELLOW}{table}{Style.RESET}")
            query_dict[table] = path_queries[0]
        else:
            print(f"No paths found for {Style.RED}{table}{Style.RESET}")
    return query_dict, union_count_dict
def export_table_data(all_paths_from_tenant, queries,starting_table,starting_key_column,key_value,export_path='/app/outputs', log_file_name='query_log.txt',tables_file='tables_log.txt',all_paths='all_paths.txt',include_Starting = True):
    """
    Write the generated queries, the table list and the paths to text files.

    Three files are created inside ``export_path`` (which is created when it
    does not exist yet):

    * ``log_file_name``: optionally the query for the starting table, then
      one ``Table: <name>`` / query pair per entry of ``queries``.
    * ``tables_file``: one ``Table: <name>`` line per entry of ``queries``.
    * ``all_paths``: one ``Table:<name>`` / ``Paths:<paths>`` pair per entry
      of ``all_paths_from_tenant``.

    Args:
        all_paths_from_tenant: mapping ``table -> paths`` to dump.
        queries: mapping ``table -> SQL text``.
        starting_table: table the starting query selects from.
        starting_key_column: column of ``starting_table`` that is compared
            with ``key_value``.
        key_value: key to filter the starting table on; strings are written
            as SQL string literals, anything else via ``str()``.
        export_path: output directory.
        log_file_name, tables_file, all_paths: file names inside
            ``export_path``.
        include_Starting: when true, the starting table's query is written
            first and qualified with the module-level ``schema_name``.
    """
    # Create the output directory on demand instead of failing on open().
    os.makedirs(export_path, exist_ok=True)

    log_file_path = os.path.join(export_path, log_file_name)
    tables_file_path = os.path.join(export_path, tables_file)
    all_paths_file_path = os.path.join(export_path, all_paths)

    with open(log_file_path, 'w', encoding='utf-8') as log_file:
        if include_Starting:
            if isinstance(key_value, str):
                # SQL quoting: wrap in single quotes and double any embedded
                # one. repr() would switch to double quotes for such values,
                # which SQL reads as an identifier.
                escaped_value = key_value.replace("'", "''")
                key_literal = f"'{escaped_value}'"
            else:
                key_literal = str(key_value)

            starting_ref = f'"{schema_name}"."{starting_table}"'
            initial_query = (
                f'SELECT DISTINCT * FROM {starting_ref}  '
                f'WHERE {starting_ref}."{starting_key_column}" = {key_literal} '
                f'AND {starting_ref}."id" IS NOT NULL'
            )
            log_file.write(f"Table: {starting_table}\n{initial_query}\n")

        for table, query in queries.items():
            log_file.write(f"Table: {table}\n{query}\n")

    with open(tables_file_path, 'w', encoding='utf-8') as tables_log:
        for table in queries:
            tables_log.write(f"Table: {table}\n")

    # Same explicit encoding as the other two files, so table names are
    # written identically regardless of the platform default.
    with open(all_paths_file_path, 'w', encoding='utf-8') as paths_log:
        for table, path in all_paths_from_tenant.items():
            paths_log.write(f"Table:{table}\n Paths:{path}\n")

# Table the export starts from and the column used to select its row
starting_table = os.getenv('STARTING_TABLE', 'default_value_for_starting_table')
starting_key_column = os.getenv('STARTING_KEY_COLUMN', 'default_value_for_starting_key_column')

# KEY_VALUE may be numeric or textual: use an int when it parses as one,
# otherwise keep the raw string
key_value_str = os.getenv('KEY_VALUE', 'default_value_for_key_value')
try:
    key_value = int(key_value_str)
except ValueError:
    key_value = key_value_str

# EXCLUDED_TABLES is a comma-separated list. Whitespace around each name and
# empty entries are dropped, so "a, b," yields ['a', 'b'] rather than names
# that can never match a table (' b', '').
excluded_tables_str = os.getenv('EXCLUDED_TABLES', '')
excluded_tables = [name.strip() for name in excluded_tables_str.split(',') if name.strip()]
# Columns whose foreign keys are ignored when relationships are collected
excluded_columns = {'modified_by_id', 'created_by_id','offer_symbol_group'}

# --- Pipeline: schema relationships -> graph -> paths -> SQL -> log files ---

# Foreign-key relationships, minus the excluded tables and columns
relationships = extract_relationships(excluded_tables,excluded_columns)
print(f"\t{Style.YELLOW}Relationships have been extracted from schema!{Style.RESET}")
# Grouped view of the relationships; in the visible code it is only used for
# the count printed below (the graph is built from `relationships` directly)
dictrel = generate_relationships_dict(relationships,excluded_columns,starting_table)
print(f"\t{Style.YELLOW}Total tables found in relationship with {Style.RESET}{Style.GREEN}{starting_table}{Style.RESET} = {Style.GREEN}{len(dictrel)}{Style.RESET}")
# G must keep this name: generate_sql_queries_restructured reads it as a
# module-level global
G = create_graph_last(relationships)
print(f"\t{Style.YELLOW}Graph representation has been created succesfully !{Style.RESET}")
# One entry per graph node other than starting_table, so the length printed
# below is a number of tables rather than a number of individual paths
all_paths_from_tenant = find_all_paths_from_node_last(G, starting_table)
print(f"\t{Style.YELLOW}Total paths found from {Style.RESET}{Style.GREEN}{starting_table}{Style.RESET}{Style.WHITE} to any other table = {Style.RESET}{Style.GREEN}{len(all_paths_from_tenant)}{Style.RESET}\n")

# SQL text per table, plus how many UNIONs each multi-path table needed
sql_queries,union_count_dict  = generate_sql_queries_restructured(all_paths_from_tenant, schema_name,starting_table,key_value)
print(f"\t{Style.YELLOW}Sql queries composed Total = {Style.RESET}{Style.GREEN}{len(sql_queries)}{Style.RESET}\n")
print(f"\t{Style.YELLOW}Total Queries that constructed using Union :{Style.RESET}")


for table, union_count in union_count_dict.items():
    print(f"\tTable: {Style.GREEN}{table}{Style.RESET}, Number of UNIONs: {Style.GREEN}{union_count}{Style.RESET}")
print(f"\t{Style.YELLOW}Data exporting..{Style.RESET} \n")
# Uses export_table_data's default export_path and file names, which are the
# ones named in the final message
export_table_data(all_paths_from_tenant,sql_queries,starting_table,starting_key_column,key_value)
print(f"\t{Style.GREEN}Log files created succesfully..in{Style.RESET} {Style.RED}/app/outputs{Style.RESET} named: {Style.YELLOW}query_log.txt{Style.RESET}----{Style.YELLOW}tables_log.txt{Style.RESET}----{Style.YELLOW}all_paths.txt{Style.RESET} \n")




# Output captured from the last run of this cell:
# OperationalError: (psycopg2.OperationalError) could not translate host name "default_host" to address: No such host is known.
#
# (Background on this error at: https://sqlalche.me/e/20/e3q8)