In [1]:
import subprocess
import os
import tempfile
import csv
import re
import sqlalchemy
from sqlalchemy import create_engine, MetaData, inspect, event, text
# Function to extract table names from the query log file
def extract_table_names(file_path):
    """Collect the distinct table names listed in a log file.

    Each line is stripped of surrounding whitespace; lines that then begin
    with ``Table:`` contribute the text that follows that marker (stripped
    again). All other lines are ignored.

    Args:
        file_path: Path of the log file to scan.

    Returns:
        A set of table-name strings; empty when no line matches.
    """
    marker = "Table:"
    with open(file_path, 'r') as log:
        stripped_lines = (raw.strip() for raw in log)
        # The comprehension is evaluated while the file is still open.
        return {
            entry.split(marker)[1].strip()
            for entry in stripped_lines
            if entry.startswith(marker)
        }
class Style:
    """ANSI escape sequences used to colorize console output."""

    # Bold foreground colors.
    RED = '\033[1;31m'
    GREEN = '\033[1;32m'
    YELLOW = '\033[1;33m'

    # Plain (non-bold) white foreground.
    WHITE = '\033[37m'

    # Resets the terminal back to its default attributes.
    RESET = '\033[0m'



# Function to get the primary key column of a table
def get_primary_key(engine, schema_name, table_name):
    """Return the first primary-key column of a table, or None.

    The column is looked up in ``information_schema``. For a composite
    primary key only the leading column (lowest ``ordinal_position``) is
    returned, so the result is the same on every call.

    Args:
        engine: SQLAlchemy engine used to open a connection.
        schema_name: Schema that contains the table.
        table_name: Table whose primary key is wanted.

    Returns:
        The column name as a string, or None when the table has no
        primary-key constraint (or does not exist).
    """
    primary_key_query = sqlalchemy.text("""
        SELECT kcu.column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema  -- Ensuring schema matches in both tables
            AND tc.table_name = kcu.table_name      -- Constraint names alone do not identify a table
        WHERE tc.table_schema = :schema_name
            AND tc.table_name = :table_name
            AND tc.constraint_type = 'PRIMARY KEY'
        ORDER BY kcu.ordinal_position;
    """)

    with engine.connect() as conn:
        result = conn.execute(primary_key_query, {'schema_name': schema_name, 'table_name': table_name})
        # Only the first row is needed; the ORDER BY makes it the leading key column.
        primary_key_row = result.fetchone()

    return primary_key_row[0] if primary_key_row else None
# Function to execute PostgreSQL configuration commands
def execute_config_commands(db_params):
    """Run a batch of PostgreSQL ``SET`` commands through one ``psql -c`` call.

    NOTE(review): ``SET`` is session-scoped and each ``psql`` invocation is
    presumably its own session, so these settings likely do not carry over
    to later ``psql`` calls - confirm whether this function is still needed.

    Args:
        db_params: Mapping with ``host``, ``port``, ``user``, ``dbname`` and
            ``password`` entries describing the connection.

    Returns:
        None. Failures are reported on stdout rather than raised.
    """
    postgres_config_commands = [
        "SET statement_timeout = 0;",
        "SET lock_timeout = 0;",
        "SET idle_in_transaction_session_timeout = 0;",
        "SET client_encoding = 'UTF8';",
        "SET standard_conforming_strings = on;",
        #"SELECT pg_catalog.set_config('search_path', '', false);",
        "SET check_function_bodies = false;",
        "SET xmloption = content;",
        "SET client_min_messages = warning;",
        "SET row_security = off;",
        "SET session_replication_role = 'replica';"  # This line disables foreign key checks
    ]
    config_command = " ".join(postgres_config_commands)

    # An argument list (no shell) means connection parameters containing
    # spaces, quotes or shell metacharacters cannot alter the command.
    command = [
        "psql",
        "-h", str(db_params['host']),
        "-p", str(db_params['port']),
        "-U", str(db_params['user']),
        "-d", str(db_params['dbname']),
        "-c", config_command,
    ]
    # Extend the inherited environment instead of replacing it, so PATH and
    # any PGSSLMODE-style libpq variables still reach psql.
    env = {**os.environ, 'PGPASSWORD': db_params['password']}

    try:
        result = subprocess.run(command, env=env)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"{Style.RED}Error executing configuration commands: {e}{Style.RESET}")
        return

    # subprocess.run does not raise on a non-zero exit, so report it explicitly.
    if result.returncode != 0:
        print(f"{Style.RED}Error executing configuration commands: psql exited with code {result.returncode}{Style.RESET}")

def get_row_count(table_name, db_params,schema_name):
    """Return the number of rows in ``schema_name.table_name`` using ``psql``.

    Args:
        table_name: Name of the table to count.
        db_params: Mapping with ``host``, ``port``, ``user``, ``dbname`` and
            ``password`` entries describing the connection.
        schema_name: Schema that contains the table.

    Returns:
        The row count as an int. 0 is returned when psql cannot be started,
        when the query fails, or when no numeric line is found in the output.
    """
    # Quote the identifiers (doubling embedded quotes) so the count targets
    # exactly the same relation as the quoted DELETE statement built in
    # delete_csv_data_from_db.
    quoted_schema = '"' + str(schema_name).replace('"', '""') + '"'
    quoted_table = '"' + str(table_name).replace('"', '""') + '"'
    query = f"SELECT COUNT(*) AS row_count FROM {quoted_schema}.{quoted_table};"

    # -A (unaligned) and -t (tuples only) reduce the output to the bare number.
    command = [
        "psql",
        "-h", str(db_params['host']),
        "-p", str(db_params['port']),
        "-U", str(db_params['user']),
        "-d", str(db_params['dbname']),
        "-A", "-t",
        "-c", query,
    ]
    # Extend the inherited environment instead of replacing it, so PATH and
    # any PGSSLMODE-style libpq variables still reach psql.
    env = {**os.environ, 'PGPASSWORD': db_params['password']}

    try:
        result = subprocess.run(command, env=env, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"Error getting row count for {table_name}: {e}")
        return 0

    # The first line that is purely digits is the count.
    for line in result.stdout.splitlines():
        candidate = line.strip()
        if candidate.isdigit():
            return int(candidate)

    if result.stdout:
        print(f"No row count found in output for table {schema_name}.{table_name}.")
        print(f"Result stdout: {result.stdout}")
    else:
        # Empty stdout means the query produced nothing, typically because
        # psql failed; stderr carries the reason.
        print(f"Could not read row count for {schema_name}.{table_name} (psql exit code {result.returncode}).")
        print(f"Anything in stderr? {result.stderr}")
    return 0
def get_foreign_keys(engine, schema, table_name):
    """Return the columns of a table that take part in foreign-key constraints.

    Args:
        engine: SQLAlchemy engine used to open a connection.
        schema: Schema that contains the table.
        table_name: Table whose foreign-key columns are wanted.

    Returns:
        A list of column-name strings ordered by constraint name and then by
        position inside the constraint; empty when the table has no foreign
        keys. A column used by several constraints appears once per constraint.
    """
    foreign_key_query = text("""
    SELECT 
        kcu.column_name
    FROM 
        information_schema.table_constraints AS tc 
        JOIN information_schema.key_column_usage AS kcu 
        ON tc.constraint_name = kcu.constraint_name
        AND tc.table_schema = kcu.table_schema
        AND tc.table_name = kcu.table_name
    WHERE 
        tc.constraint_type = 'FOREIGN KEY' 
        AND tc.table_name = :table_name 
        AND tc.table_schema = :schema
    ORDER BY
        kcu.constraint_name, kcu.ordinal_position;
    """)
    # The table_name join condition matters: foreign-key constraint names are
    # not guaranteed to be unique within a schema, so joining on name and
    # schema alone can pull in columns of another table's constraint.

    with engine.connect() as conn:
        result = conn.execute(foreign_key_query, {'table_name': table_name, 'schema': schema})
        # Each row holds a single value: the column name.
        return [row[0] for row in result]
def delete_csv_data_from_db(table_name, db_params, schema):
    csv_file_path = f"/app/outputs/output_{table_name}.csv"  # Adjusted path
    print(f"Deleting data from : {csv_file_path}")
    if not os.path.exists(csv_file_path):
        print(f"CSV file for {table_name} not found. Skipping...")
        return 0

    # SQLAlchemy engine setup
    db_url = f"postgresql+psycopg2://{db_params['user']}:{db_params['password']}@{db_params['host']}:{db_params['port']}/{db_params['dbname']}"
    engine = create_engine(db_url)
    
    # Get the primary key column
    primary_key_column = get_primary_key(engine, schema, table_name)
     # Get initial row count
    initial_count = get_row_count(table_name, db_params, schema)
    print(f"\t{Style.GREEN}Initial row count for {table_name} is {initial_count}{Style.RESET}")
    # If no primary key is found, attempt to use foreign keys
    if primary_key_column is None:
        print(f"No primary key found for {table_name}. Attempting to delete using foreign keys.")
        foreign_keys = get_foreign_keys(engine, schema, table_name)
        if not foreign_keys:
            print(f"No foreign keys found for {table_name}. Cannot perform deletion without a unique identifier.")
            return 0

        matching_condition = ' AND '.join(f"{table_name}.{key} = temp.{key}" for key in foreign_keys)
        delete_statement = f"DELETE FROM \"{schema}\".\"{table_name}\" USING temp WHERE {matching_condition};"
    else:
        # Check if the primary key column exists in the CSV file
        with open(csv_file_path, mode='r', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile)
            headers = next(reader)
            if primary_key_column not in headers:
                print(f"Primary key column '{primary_key_column}' not found in CSV headers for {table_name}. Skipping deletion.")
                return 0

        delete_statement = f"DELETE FROM \"{schema}\".\"{table_name}\" WHERE EXISTS (SELECT 1 FROM temp WHERE temp.{primary_key_column} = \"{table_name}\".{primary_key_column});"

    # Perform deletion
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.sql') as sql_script:
        sql_script.write(f"SET search_path TO {schema};\n")
        sql_script.write(f"SELECT * FROM aiven_extras.session_replication_role('replica');\n")
        sql_script.write(f"CREATE TEMP TABLE temp (LIKE \"{schema}\".\"{table_name}\" INCLUDING DEFAULTS);\n")
        sql_script.write(f"\\copy temp FROM '{csv_file_path}' WITH CSV HEADER;\n")
        sql_script.write(delete_statement)
        sql_script.write(f"DROP TABLE temp;\n")
        sql_script_path = sql_script.name

    # Execute the script using the psql command-line interface
    command = f"psql -h {db_params['host']} -p {db_params['port']} -U {db_params['user']} -d {db_params['dbname']} -f {sql_script_path}"

    try:
        print(f"Executing: {command}")
        result = subprocess.run(command, shell=True, env={'PGPASSWORD': db_params['password']}, capture_output=True, text=True)

        # Log any stdout or stderr for debugging
        print(f"Std out: {result.stdout}")
        print(f"Std err: {result.stderr}")

    except Exception as e:
        print(f"Error in deleting CSV data for {table_name}: {e}")
        return 0
    finally:
        # Clean up temporary file
        os.remove(sql_script_path)

    # Get new row count and calculate the number of rows deleted
    new_count = get_row_count(table_name, db_params, schema)
    rows_deleted = initial_count - new_count
    print(f"\tFor table {Style.YELLOW}{table_name} deleted {Style.RESET}{Style.GREEN}{rows_deleted}{Style.RESET} rows")
    return rows_deleted,initial_count
    
def run_psql_command(command, db_params):
    """Run a shell command (expected to be a ``psql`` call) and pick out its value.

    The command string is executed through the shell, so it must come from a
    trusted source.

    Args:
        command: Complete shell command line to execute.
        db_params: Mapping whose ``password`` entry is exported to the child
            process as ``PGPASSWORD``.

    Returns:
        A ``(value, error)`` tuple. On success ``error`` is None and ``value``
        is the stripped third line of stdout, or the string
        ``"Expected output not found"`` when fewer than three lines were
        printed. When the command exits non-zero, ``value`` is None and
        ``error`` is the stripped stderr text.
    """
    # Pass the password through environment variables for security. The
    # inherited environment is extended rather than replaced, so PATH and any
    # PGSSLMODE-style libpq variables still reach the child process.
    env_vars = {**os.environ, 'PGPASSWORD': db_params['password']}
    try:
        completed_process = subprocess.run(command, shell=True, env=env_vars, text=True, capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        return None, (e.stderr or "").strip()

    output_lines = completed_process.stdout.splitlines()

    # Debugging: print all output lines
    print("\tOutput from psql command:")
    for line in output_lines:
        print(f"{Style.GREEN}{line}{Style.RESET}")

    # Assuming the actual value is in the third line of the output
    # (after psql's column header and separator lines).
    if len(output_lines) >= 3:
        return output_lines[2].strip(), None
    return "Expected output not found", None
# Function to set and then immediately check the session_replication_role
def set_and_check_replication_role(db_params):
    """Call ``aiven_extras.session_replication_role('replica')`` and print the outcome.

    The statement is sent through ``run_psql_command``; the value it reports
    is printed on success, otherwise the error text is printed in red.

    NOTE(review): the role is set inside a single ``psql -c`` invocation, so
    it presumably does not persist for later psql calls - confirm.

    Args:
        db_params: Mapping with ``host``, ``port``, ``user``, ``dbname`` and
            ``password`` entries describing the connection.
    """
    connection_flags = f"-h {db_params['host']} -p {db_params['port']} -U {db_params['user']} -d {db_params['dbname']}"
    statement = "SELECT * FROM aiven_extras.session_replication_role('replica');"
    psql_invocation = f'psql {connection_flags} -c "{statement}"'

    reported_role, failure = run_psql_command(psql_invocation, db_params)

    if reported_role is None:
        print(f"{Style.RED}Error: {failure}{Style.RESET}")
        return

    print(f"{Style.YELLOW}Session_replication_role after setting to replica: {Style.RESET}{Style.GREEN}{reported_role}{Style.RESET}")

# Calling the function

# Main function to execute the deletion process
def main():
    """Delete CSV-exported rows for every logged table and write a summary.

    Connection settings come from the ``DB_HOST``, ``DB_NAME``, ``DB_USER``,
    ``DB_PASSWORD``, ``DB_PORT`` and ``DB_SCHEMA`` environment variables.
    Table names are read from ``/app/outputs/tables_log.txt``, and the
    per-table counts are written to
    ``/app/outputs/deleted_rows_per_table_<dbname>_<user>_<schema>.txt``.
    """
    table_log_file = '/app/outputs/tables_log.txt'  # Adjusted path
    db_params = {
        'host': os.environ.get('DB_HOST', 'localhost'),
        'dbname': os.environ.get('DB_NAME', 'defaultdb'),
        'user': os.environ.get('DB_USER', 'default'),
        'password': os.environ.get('DB_PASSWORD', 'password'),
        # NOTE(review): 'default_port' is not a usable port, so psql fails
        # when DB_PORT is unset. Left as-is so a deletion run never silently
        # targets an unintended server - confirm the desired default.
        'port': os.environ.get('DB_PORT', 'default_port')
    }
    schema = os.getenv('DB_SCHEMA', 'public')

    # Set and report the session_replication_role
    set_and_check_replication_role(db_params)

    # Maps each table name to (initial row count, rows deleted)
    deleted_rows_count = {}
    export_filename = f"/app/outputs/deleted_rows_per_table_{db_params['dbname']}_{db_params['user']}_{schema}.txt"
    table_names = extract_table_names(table_log_file)

    # Sorted so the console log and the report have a stable order.
    for table_name in sorted(table_names):
        print(f"\tExecuting delete for table {Style.YELLOW}{table_name}{Style.RESET}")
        outcome = delete_csv_data_from_db(table_name, db_params, schema)
        # Accept a bare integer as well as a (rows_deleted, initial_count)
        # tuple, so a skipped table cannot break the unpacking.
        if isinstance(outcome, tuple):
            rows_deleted, initial_count = outcome
        else:
            rows_deleted, initial_count = outcome, 0
        deleted_rows_count[table_name] = (initial_count, rows_deleted)

    # Write the row counts to a new file
    with open(export_filename, 'w') as f:
        for table, (initial, deleted) in deleted_rows_count.items():
            f.write(f"{table}: Initial count {initial} rows, DELETED {deleted} rows\n")


# Execute the main function
if __name__ == "__main__":
    main()

CSV file for psp.app_deposits not found. Skipping...
CSV file for psp.app_brands not found. Skipping...
CSV file for psp.app_availability_rules not found. Skipping...
