In [1]:
import subprocess
import os
import tempfile
import csv
import re
# Function to extract table names from the query log file
def extract_table_names(file_path):
    """Collect the distinct table names listed in a log file.

    A line counts when, after stripping surrounding whitespace, it starts
    with ``Table:``. The name is the stripped text that follows the marker
    (up to any repeated ``Table:`` on the same line).

    Args:
        file_path: Path of the text file to scan.

    Returns:
        A set with the names found; empty when no line matches.
    """
    marker = "Table:"
    with open(file_path, 'r') as log:
        stripped_lines = (raw.strip() for raw in log)
        # The comprehension is evaluated eagerly, while the file is open.
        return {
            entry.split(marker)[1].strip()
            for entry in stripped_lines
            if entry.startswith(marker)
        }
class Style:
    """ANSI escape sequences used to colour terminal output."""

    # Restores the terminal's default attributes.
    RESET = '\033[0m'
    # Plain (non-bold) white foreground.
    WHITE = '\033[37m'
    # Bold foreground colours.
    RED = '\033[1;31m'
    GREEN = '\033[1;32m'
    YELLOW = '\033[1;33m'
# Function to execute PostgreSQL configuration commands
def execute_config_commands(db_params):
    """Send a batch of ``SET`` statements to PostgreSQL through ``psql``.

    All statements are joined and passed to one ``psql -c`` invocation.
    psql's own output is not captured, so it goes to this process's
    stdout/stderr.

    NOTE(review): these are session-level ``SET`` commands, so they
    presumably affect only the psql session started here and not any later
    psql invocation -- confirm this is the intended use.

    Args:
        db_params: Mapping with the keys ``host``, ``port``, ``user``,
            ``dbname`` and ``password``.

    Returns:
        None. Failures are reported on stdout instead of being raised.
    """
    postgres_config_commands = [
        "SET statement_timeout = 0;",
        "SET lock_timeout = 0;",
        "SET idle_in_transaction_session_timeout = 0;",
        "SET client_encoding = 'UTF8';",
        "SET standard_conforming_strings = on;",
        "SET check_function_bodies = false;",
        "SET xmloption = content;",
        "SET client_min_messages = warning;",
        "SET row_security = off;",
        "SET session_replication_role = 'replica';"  # This line disables foreign key checks
    ]
    # An argument list (no shell) keeps hosts, users or database names that
    # contain spaces or shell metacharacters from being re-parsed.
    command = [
        "psql",
        "-h", str(db_params['host']),
        "-p", str(db_params['port']),
        "-U", str(db_params['user']),
        "-d", str(db_params['dbname']),
        "-c", " ".join(postgres_config_commands),
    ]

    try:
        # Extend the inherited environment instead of replacing it, so PATH
        # and any PG* / SSL settings stay visible to psql.
        env = {**os.environ, 'PGPASSWORD': db_params['password']}
        result = subprocess.run(command, env=env)
    except Exception as e:
        print(f"{Style.RED}Error executing configuration commands: {e}{Style.RESET}")
        return

    # subprocess.run() does not raise on a non-zero exit unless check=True,
    # so a failing psql has to be reported explicitly.
    if result.returncode != 0:
        print(f"{Style.RED}Error executing configuration commands: psql exited with status {result.returncode}{Style.RESET}")

def get_row_count(table_name, db_params, schema_name):
    """Return the number of rows in ``schema_name.table_name`` using ``psql``.

    Runs ``SELECT COUNT(*)`` through ``psql -c`` and takes the first output
    line that consists only of digits as the count.

    Args:
        table_name: Table to count; interpolated into the SQL unquoted.
        db_params: Mapping with the keys ``host``, ``port``, ``user``,
            ``dbname`` and ``password``.
        schema_name: Schema that qualifies the table; interpolated unquoted.

    Returns:
        The row count, or 0 when it could not be determined (psql could not
        be started, produced no output, or printed no numeric line). A
        diagnostic is printed in each of those cases.
    """
    query = f"SELECT COUNT(*) AS row_count FROM {schema_name}.{table_name};"
    # Passing the query with -c as a list element avoids both a shell and a
    # temporary SQL file.
    command = [
        "psql",
        "-h", str(db_params['host']),
        "-p", str(db_params['port']),
        "-U", str(db_params['user']),
        "-d", str(db_params['dbname']),
        "-c", query,
    ]

    try:
        # Extend the inherited environment instead of replacing it, so PATH
        # and any PG* / SSL settings stay visible to psql.
        env = {**os.environ, 'PGPASSWORD': db_params['password']}
        result = subprocess.run(command, env=env, capture_output=True, text=True)
    except Exception as e:
        print(f"Error getting row count for {table_name}: {e}")
        return 0

    if not result.stdout:
        # A successful COUNT(*) always prints a number, even for an empty
        # table, so no output at all means the query itself did not run.
        print(f"Could not get row count for table {schema_name}.{table_name} (psql exit status {result.returncode}).")
        print(f"Anything in stderr? {result.stderr}")
        return 0

    for line in result.stdout.splitlines():
        candidate = line.strip()
        if candidate.isdigit():
            return int(candidate)

    print(f"No row count found in output for table {schema_name}.{table_name}.")
    print(f"Result stdout: {result.stdout}")
    return 0

def import_csv_to_db(table_name, db_params, schema):
    """Load ``/app/outputs/output_<table_name>.csv`` into a table via ``psql``.

    A temporary SQL script is generated and run with ``psql -f``. The script
    sets ``search_path`` to ``schema``, calls
    ``aiven_extras.session_replication_role('replica')``, copies the CSV
    (header row skipped) into a temporary table shaped like the target, and
    inserts that table's rows into the target. The number of inserted rows
    is derived from ``get_row_count`` before and after the run.

    Args:
        table_name: Target table; also selects the CSV file name.
        db_params: Mapping with the keys ``host``, ``port``, ``user``,
            ``dbname`` and ``password``.
        schema: Schema placed on the ``search_path`` and used for counting.

    Returns:
        A ``(rows_inserted, initial_count)`` tuple on every path:
        ``(0, 0)`` when the CSV file does not exist, and
        ``(0, initial_count)`` when psql could not be started.
    """
    csv_file_path = f"/app/outputs/output_{table_name}.csv"  # Adjusted path
    print(f"\t{Style.YELLOW}Importing!! : {csv_file_path}{Style.RESET}")
    if not os.path.exists(csv_file_path):
        print(f"\t{Style.RED}CSV file for {table_name} not found. Skipping...{Style.RESET}")
        # Same tuple shape as the success path, so callers can always unpack.
        return 0, 0

    # Get initial row count
    initial_count = get_row_count(table_name, db_params, schema)
    print(f"\t{Style.GREEN}initial_count for {table_name} is {initial_count}{Style.RESET}")

    script_lines = [
        f"SET search_path TO {schema};",
        "SELECT * FROM aiven_extras.session_replication_role('replica');",
        f"CREATE TEMP TABLE tmp_{table_name} (LIKE \"{table_name}\" INCLUDING DEFAULTS);",
        # \copy is a psql meta-command, so the file is read on the client side.
        f"\\copy tmp_{table_name} FROM '{csv_file_path}' WITH CSV HEADER;",
        f"INSERT INTO \"{table_name}\" SELECT * FROM tmp_{table_name};",
        f"DROP TABLE tmp_{table_name};",
    ]
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.sql') as sql_script:
        sql_script.write("\n".join(script_lines) + "\n")
        sql_script_path = sql_script.name

    # An argument list (no shell) keeps connection parameters that contain
    # spaces or shell metacharacters from being re-parsed.
    command = [
        "psql",
        "-h", str(db_params['host']),
        "-p", str(db_params['port']),
        "-U", str(db_params['user']),
        "-d", str(db_params['dbname']),
        "-f", sql_script_path,
    ]

    try:
        print(f"\tExecuting : {' '.join(command)}")
        # Extend the inherited environment instead of replacing it, so PATH
        # and any PG* / SSL settings stay visible to psql.
        env = {**os.environ, 'PGPASSWORD': db_params['password']}
        result = subprocess.run(command, env=env, capture_output=True, text=True)

        # Log any stdout or stderr for debugging
        print(f"\tStd out: {result.stdout}  <----")
        print(f"\tStd err: {result.stderr}  <---")

    except Exception as e:
        print(f"\t{Style.RED}Error importing CSV for {table_name}: error is  {e}{Style.RESET}")
        return 0, initial_count
    finally:
        # Clean up temporary file
        os.remove(sql_script_path)

    # Get new row count and calculate the number of rows inserted
    new_count = get_row_count(table_name, db_params, schema)
    rows_inserted = new_count - initial_count
    print(f"\tFor table {Style.YELLOW}{table_name} inserted {Style.RESET}{Style.GREEN}{rows_inserted}{Style.RESET}")
    return rows_inserted, initial_count
    
def run_psql_command(command, db_params):
    """Run a shell command line (expected to invoke ``psql``) and pick one output line.

    Args:
        command: Complete command line as a single string; it is executed
            through the shell.
        db_params: Mapping providing at least ``password``, which is exported
            to the child process as ``PGPASSWORD``.

    Returns:
        A ``(value, error)`` tuple:
        * ``(third_output_line, None)`` when the command succeeds and prints
          at least three lines;
        * ``("Expected output not found", None)`` when it succeeds with fewer
          lines;
        * ``(None, stderr_text)`` when it exits with a non-zero status.
    """
    # Pass the password through the environment rather than the command line.
    # The inherited environment is extended, not replaced, so PATH and any
    # PG* / SSL settings stay visible to the child process.
    env_vars = {**os.environ, 'PGPASSWORD': db_params['password']}

    try:
        # NOTE: shell=True is kept because callers hand over a ready-made
        # command string; that string must not contain untrusted input.
        completed_process = subprocess.run(command, shell=True, env=env_vars, text=True, capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        return None, e.stderr.strip()

    output_lines = completed_process.stdout.splitlines()

    # Debugging: print all output lines
    print("\tOutput from psql command:")
    for line in output_lines:
        print(f"{Style.GREEN}{line}{Style.RESET}")

    # Assuming the actual value is in the third line of the output
    # (after the column-header line and the separator line).
    if len(output_lines) >= 3:
        return output_lines[2].strip(), None
    return "Expected output not found", None
# Function to set and then immediately check the session_replication_role
def set_and_check_replication_role(db_params):
    """Call ``aiven_extras.session_replication_role('replica')`` and print the outcome.

    The statement is run in its own ``psql -c`` invocation through
    ``run_psql_command``. The value that helper extracts from the output is
    printed on success; otherwise the error text it returned is printed.

    Args:
        db_params: Mapping with the keys ``host``, ``port``, ``user``,
            ``dbname`` and ``password``.
    """
    sql = "SELECT * FROM aiven_extras.session_replication_role('replica');"
    set_role_command = (
        f"psql -h {db_params['host']} -p {db_params['port']} "
        f"-U {db_params['user']} -d {db_params['dbname']} "
        f"-c \"{sql}\""
    )
    reported_role, failure = run_psql_command(set_role_command, db_params)

    if reported_role is None:
        print(f"{Style.RED}Error: {failure}{Style.RESET}")
        return

    print(f"{Style.YELLOW}Session_replication_role after setting to replica: {Style.RESET}{Style.GREEN}{reported_role}{Style.RESET}")

# Calling the function

# Main function to execute the import process
def main():
    """Import every table listed in the table log and write a per-table report.

    Connection settings come from the ``DB_HOST``, ``DB_NAME``, ``DB_USER``,
    ``DB_PASSWORD``, ``DB_PORT`` and ``DB_SCHEMA`` environment variables.
    For each table named in ``/app/outputs/tables_log.txt`` the matching CSV
    is imported. A report is then written that combines the row counts
    before the import, the rows inserted, and the expected counts read from
    ``/app/outputs/table_line_counts.txt``.
    """
    table_log_file = '/app/outputs/tables_log.txt'  # Adjusted path
    db_params = {
        'host': os.environ.get('DB_HOST', 'localhost'),
        'dbname': os.environ.get('DB_NAME', 'defaultdb'),
        'user': os.environ.get('DB_USER', 'default'),
        'password': os.environ.get('DB_PASSWORD', 'password'),
        # 5432 is PostgreSQL's standard port; a non-numeric placeholder
        # would be rejected by psql.
        'port': os.environ.get('DB_PORT', '5432')
    }
    schema = os.getenv('DB_SCHEMA', 'public')

    # Check the current session_replication_role
    set_and_check_replication_role(db_params)

    # Dictionary to hold the count of rows imported for each table
    imported_rows_count = {}
    export_filename = f"/app/outputs/imported_rows_per_table_{db_params['dbname']}_{db_params['user']}_{schema}.txt"

    # Sorted so that the processing and report order is the same on every run
    # (a set has no defined iteration order).
    for table_name in sorted(extract_table_names(table_log_file)):
        print(f"Executing import for table {Style.YELLOW}{table_name}{Style.RESET}")
        outcome = import_csv_to_db(table_name, db_params, schema)
        # Tolerate a bare integer result: treat it as the number of rows
        # inserted, with an initial count of 0.
        if isinstance(outcome, tuple):
            rows_inserted, initial_count = outcome
        else:
            rows_inserted, initial_count = outcome, 0
        imported_rows_count[table_name] = (initial_count, rows_inserted)

    # Read the expected per-table counts ("<table>: <count>" per line).
    table_dict = {}
    outputfile = "/app/outputs/table_line_counts.txt"
    if os.path.exists(outputfile):
        with open(outputfile, 'r') as file:
            for line in file:
                # Split on the last ': ' so a name containing ': ' still parses.
                name, separator, count = line.strip().rpartition(': ')
                if not separator:
                    continue  # blank or malformed line
                try:
                    table_dict[name] = int(count)
                except ValueError:
                    continue  # count is not a number
    else:
        print(f"{Style.RED}{outputfile} not found; expected counts will be reported as unknown{Style.RESET}")

    # Write the row counts to a new file
    with open(export_filename, 'w') as f:
        for table, (initial, inserted) in imported_rows_count.items():
            # A table missing from the counts file must not abort the report.
            expected = table_dict.get(table, 'unknown')
            f.write(f"Table {table}\n Initial count {initial} rows before import\nInserted {inserted} rows\nData READ to migrate {expected}\n\n")


# Execute the main function
if __name__ == "__main__":
    main()

CSV file for psp.app_deposits not found. Skipping...
CSV file for psp.app_brands not found. Skipping...
CSV file for psp.app_availability_rules not found. Skipping...
