In [1]:
"""
This script provides an AI-driven solution for generating sample data based on
a provided database schema. It accepts either a .sql DDL file or a structured
.xlsx file as input. The script utilizes the Google Gemini API to analyze the
schema, infer data types and relationships, and create a comprehensive data
generation plan. This plan is then executed using the Faker library to produce
realistic, referentially-intact sample data in CSV format.

The architecture is object-oriented and optimized for scalability by using a
batch-processing and streaming approach to handle large datasets without
exhausting system memory.
"""



# --- Import required libraries ---
import pandas as pd                       
from faker import Faker                   
import uuid                               
import time                              
import random                            
import datetime                          
import json                               
import os                                 
import google.generativeai as genai       
from dotenv import load_dotenv            
import inspect                           
genai.configure(api_key="AIzaSyDQ8kPzvFKt59vF63arz_e2T82Z0E2pIzg")
# --- Configuration and Initialization ---
# Load environment variables from a .env file for secure API key management.
load_dotenv()
try:
    # Authenticate with the Google Generative AI service using API key from environment.
    GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
    genai.configure(api_key=GOOGLE_API_KEY)
except (KeyError, TypeError):
    # Terminate program if the API key is not found or not configured properly.
    print("ERROR: GOOGLE_API_KEY not found. Please create a .env file or set the environment variable.")
    exit()


class SchemaReader:
    """
    Handles the reading and initial parsing of schema definition files.
    Supports .sql, .xlsx and .xls formats.
    """
    @staticmethod
    def from_file(file_path):
        """
        Reads a schema file from the provided path, determines its type,
        and returns the raw content for processing by the AIPlanner.

        Args:
            file_path (str): The local path to the schema file.

        Returns:
            tuple[str, str]: A tuple containing the schema content as a string
                             and the identified file type ('sql' or 'excel_json').
                             Excel content is the first worksheet serialised as a
                             JSON array of row records.

        Raises:
            FileNotFoundError: If ``file_path`` is not an existing regular file.
            ValueError: If the extension is not .sql, .xlsx or .xls, or if a .sql
                        file is not valid UTF-8 (UnicodeDecodeError is a ValueError).
        """
        # Inform the user that the schema file is being read
        print(f"\nStep 1: Reading schema file: '{os.path.basename(file_path)}'...")

        # isfile (not exists) so a directory path is reported as "not found"
        # instead of failing later inside open().
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found at path: '{file_path}'.")

        lower_path = file_path.lower()

        # Case 1: SQL DDL file
        if lower_path.endswith('.sql'):
            # Decode explicitly instead of relying on the platform default (cp1252 on
            # Windows). 'utf-8-sig' also strips the byte-order mark that many Windows
            # editors prepend, which would otherwise be sent to the model verbatim.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                return f.read(), 'sql'

        # Case 2: Excel workbook (.xlsx or .xls)
        if lower_path.endswith(('.xlsx', '.xls')):
            # pandas reads the first worksheet by default; NaN -> '' keeps the JSON clean.
            df = pd.read_excel(file_path).fillna('')
            return df.to_json(orient='records', indent=2), 'excel_json'

        # Case 3: Unsupported file types
        raise ValueError("Unsupported file format. Please provide a .sql or .xlsx file.")


class AIPlanner:
    """
    Interfaces with the Google Gemini API to generate a data creation plan
    based on a given schema.
    """
    def __init__(self, model_name="gemini-1.5-flash-latest"):
        """Initializes the AI model client with specified model name."""
        self.model = genai.GenerativeModel(model_name)

    @staticmethod
    def _build_prompt(schema_definition, file_type):
        """Return the instruction prompt asking the model for a JSON generation plan."""
        return f"""
        You are an expert data architect specializing in the Python Faker library. Your task is to create a complete data generation plan from a schema.
        Analyze the schema to determine generation order and select the most realistic Faker METHOD for each column.

        IMPORTANT: The "provider" you choose MUST be a real, callable method on a Faker() instance (e.g., `fake.name()`, `fake.address()`, `fake.email()`).
        DO NOT use provider category names like 'internet', 'commerce', or 'finance'.
        - For an email column, the correct provider is 'email', NOT 'internet'.
        - For a price column, use 'pydecimal', NOT 'finance'.
        - For a product name, use 'word' or 'bs', NOT 'commerce'.
        - For a date between two points, use 'date_between', NOT 'date_between_dates'.

        Your output MUST be a single, valid JSON array. Each element of the array must be an object representing a table. Each table object must have:
        1.  "table_name": The name of the table.
        2.  "generation_order": An integer indicating the order of generation (e.g., 1, 2, 3...). A table referenced by a foreign key must have a lower number than the tables that reference it.
        3.  "columns": A list of column objects.

        Each column object must have:
        1.  "column_name": The name of the column.
        2.  "is_primary_key": A boolean (true/false).
        3.  "foreign_key": A string referencing another table's primary key (e.g., "dim_customer.customer_id") or null.
        4.  "faker": An object with "provider" (a string representing a callable Faker method name) and "params" (a JSON object).

        Schema Definition ({file_type}):
        ---
        {schema_definition}
        ---

        Generate the complete JSON plan now. Do not include any text or formatting outside of the JSON array.
        """

    @staticmethod
    def _is_valid_table(table):
        """
        Return True if ``table`` carries the keys the PlanExecutor indexes:
        a non-empty 'table_name', a numeric 'generation_order' (used as a sort
        key), and a 'columns' list whose entries are dicts with 'column_name'.
        """
        if not isinstance(table, dict) or not table.get('table_name'):
            return False
        order = table.get('generation_order')
        # bool is a subclass of int; true/false is not a meaningful order.
        if isinstance(order, bool) or not isinstance(order, (int, float)):
            return False
        columns = table.get('columns')
        return isinstance(columns, list) and all(
            isinstance(col, dict) and 'column_name' in col for col in columns
        )

    @staticmethod
    def _parse_plan(raw_text):
        """
        Parse the model's raw text into a validated plan.

        Markdown code fences are stripped. Either a bare JSON array or an object
        wrapping one (the first list-valued entry is used) is accepted.

        Raises:
            ValueError: If the text is not valid JSON (json.JSONDecodeError is a
                        ValueError) or the plan is empty / structurally invalid.
        """
        json_text = raw_text.strip().replace("```json", "").replace("```", "").strip()
        parsed_json = json.loads(json_text)

        if isinstance(parsed_json, list):
            plan = parsed_json
        elif isinstance(parsed_json, dict):
            # Handle the AI wrapping the list inside a dictionary, e.g. {"tables": [...]}.
            plan = next((value for value in parsed_json.values() if isinstance(value, list)), [])
        else:
            plan = []

        if not plan or not all(AIPlanner._is_valid_table(table) for table in plan):
            raise ValueError("The AI's plan did not conform to the expected structure.")
        return plan

    def create_generation_plan(self, schema_definition, file_type, max_attempts=3):
        """
        Constructs a prompt with the schema, sends it to the Gemini API,
        and parses the JSON response into a structured generation plan.

        Args:
            schema_definition (str): The schema content from the SchemaReader.
            file_type (str): The type of schema provided ('sql' or 'excel_json').
            max_attempts (int): How many times to query the model before giving up.

        Returns:
            list[dict]: A list of dictionaries, where each dictionary represents
                        a table and its complete generation plan.

        Raises:
            RuntimeError: If no attempt yields a valid plan. The last underlying
                          error is attached as ``__cause__``.
        """
        print("   ->  Consulting with Gemini to create a smart generation plan...")
        prompt = self._build_prompt(schema_definition, file_type)

        last_error = None
        for attempt in range(1, max_attempts + 1):
            try:
                response = self.model.generate_content(prompt)
                plan = self._parse_plan(response.text)
            except Exception as e:  # API errors, invalid JSON and invalid structure all trigger a retry
                last_error = e
                will_retry = attempt < max_attempts
                suffix = " Retrying..." if will_retry else ""
                print(f"      ->  Gemini API attempt {attempt} failed: {e}.{suffix}")
                if will_retry:
                    # Exponential backoff (2s, 4s, ...); no pointless sleep after the final attempt.
                    time.sleep(2 ** attempt)
                continue

            print("   ->  Gemini has provided a valid generation plan.")
            return plan

        # RuntimeError is still an Exception, so existing `except Exception` callers keep working.
        raise RuntimeError(
            "Failed to retrieve a valid generation plan from the AI after multiple attempts."
        ) from last_error


class PlanExecutor:
    """
    Executes the AI-generated plan to produce sample data. Manages data
    generation, referential integrity, and file output using a scalable
    batch-processing approach.

    Rows are written to CSV batch by batch; only primary-key values are kept
    in memory (in ``primary_key_store``) so later tables can reference them.
    """

    # Maximum number of attempts to produce an unused primary key for one row.
    # Without a cap, a provider whose value space is smaller than the requested
    # row count (or one that keeps returning the same fallback value) loops forever.
    MAX_PK_RETRIES = 1000

    def __init__(self, plan):
        """Initializes the executor with the AI-generated plan."""
        # Tables are processed in ascending generation_order so parents exist before children.
        self.plan = sorted(plan, key=lambda x: x['generation_order'])
        # "table.column" -> list of generated PK values (a list so random.choice can sample FKs).
        self.primary_key_store = {}
        # table name -> set of PK tuples already emitted, for O(1) uniqueness checks.
        self._seen_primary_keys = {}
        # (table, column) pairs already warned about an invalid provider, to avoid per-row spam.
        self._warned_columns = set()
        self.faker_instances = {}    # Cache of Faker instances per locale

    def _get_faker_instance(self, locale='en_GB'):
        """
        Retrieves or creates a cached Faker instance for a specific locale
        to optimize performance.
        """
        if locale not in self.faker_instances:
            self.faker_instances[locale] = Faker(locale)  # Create new instance if not already cached
        return self.faker_instances[locale]

    def _generate_one_value(self, provider_name, params, fake_instance):
        """
        Generates a single data value using the specified Faker provider and
        parameters. Includes pre-processing for special values and parameter validation.

        Returns the string "generation_error" (after logging) if generation fails.
        """
        try:
            # Special case: generate UUID
            if provider_name == 'uuid4':
                return str(uuid.uuid4())

            # Work on a copy so the plan itself is never mutated; tolerate a null "params".
            processed_params = dict(params or {})
            for key, value in processed_params.items():
                if isinstance(value, str) and value.lower() == 'today':
                    processed_params[key] = datetime.date.today()

            provider_method = getattr(fake_instance, provider_name)

            # Drop parameters the provider does not accept (the AI sometimes invents them).
            sig = inspect.signature(provider_method)
            valid_params = {k: v for k, v in processed_params.items() if k in sig.parameters}

            return provider_method(**valid_params)
        except Exception as e:
            # Log error and fall back to a sentinel value rather than aborting the run.
            print(f"   -> Error: Data generation failed for provider '{provider_name}'. Details: {e}")
            return "generation_error"

    def _generate_column_value(self, table_name, col_rule):
        """
        Produce one value for a column: a random parent PK for foreign-key columns,
        otherwise a Faker value (falling back to 'word' for an invalid provider).
        """
        col_name = col_rule['column_name']
        faker_rule = col_rule.get('faker') or {}

        # Case 1: Foreign key column - pick a value from the parent's PK store.
        fk_def = col_rule.get('foreign_key')
        if isinstance(fk_def, str) and '.' in fk_def:
            parent_values = self.primary_key_store.get(fk_def.strip())
            return random.choice(parent_values) if parent_values else "FK_REFERENCE_NOT_FOUND"

        # Case 2: Normal Faker column.
        fake = self._get_faker_instance(faker_rule.get('locale', 'en_GB'))
        provider = faker_rule.get('provider')
        params = faker_rule.get('params')

        if not (isinstance(provider, str) and callable(getattr(fake, provider, None))):
            if (table_name, col_name) not in self._warned_columns:
                self._warned_columns.add((table_name, col_name))
                print(f"   ->  AI suggested an invalid provider '{provider}' for '{col_name}'. Using a safe default.")
            provider, params = 'word', {}

        return self._generate_one_value(provider, params, fake)

    def _generate_row(self, table_name, columns, pk_rules, seen_keys):
        """
        Build one row (dict in column order). All PK columns together form the key;
        on a collision only the key columns are regenerated, up to MAX_PK_RETRIES times.

        Raises:
            RuntimeError: If no unused key is found within MAX_PK_RETRIES attempts.
        """
        row = {rule['column_name']: self._generate_column_value(table_name, rule) for rule in columns}
        if not pk_rules:
            return row

        pk_names = [rule['column_name'] for rule in pk_rules]
        for _ in range(self.MAX_PK_RETRIES):
            key = tuple(row[name] for name in pk_names)
            if key not in seen_keys:
                seen_keys.add(key)
                # Record each key column so child tables can reference it.
                for name in pk_names:
                    self.primary_key_store.setdefault(f"{table_name}.{name}", []).append(row[name])
                return row
            # Collision: regenerate only the key columns in place (column order is preserved).
            for rule in pk_rules:
                row[rule['column_name']] = self._generate_column_value(table_name, rule)

        raise RuntimeError(
            f"Could not generate a unique primary key for table '{table_name}' "
            f"(columns {pk_names}) after {self.MAX_PK_RETRIES} attempts; the provider "
            f"may not be able to produce enough distinct values."
        )

    def execute(self, row_counts, batch_size=10000):
        """
        Iterates through the generation plan, generating data in batches to
        conserve memory. It maintains referential integrity and streams the
        output directly to CSV files named '<table>_sample_data.csv'.

        Args:
            row_counts (dict[str, int]): Rows to generate per table name. Tables
                missing from the mapping, or with a count <= 0, are skipped.
            batch_size (int): Maximum number of rows built and written per CSV write.

        Raises:
            ValueError: If batch_size is smaller than 1.
            RuntimeError: If a table's primary key cannot be made unique.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

        print(f"\nStep 2: Executing generation plan (in batches of {batch_size})...")

        for table_rule in self.plan:
            table_name = table_rule['table_name']
            num_rows = row_counts.get(table_name, 0)

            if num_rows <= 0:
                print(f"\nSkipping table '{table_name}' as {num_rows} rows were requested.")
                continue

            print(f"\nProcessing table: '{table_name}' for {num_rows} rows...")

            output_filename = f"{table_name}_sample_data.csv"
            columns = table_rule.get('columns') or []
            # Composite keys (e.g. junction tables) must be unique as a tuple, not per column.
            pk_rules = [rule for rule in columns if rule.get('is_primary_key')]
            seen_keys = self._seen_primary_keys.setdefault(table_name, set())

            for start in range(0, num_rows, batch_size):
                current_batch_size = min(batch_size, num_rows - start)
                batch_data = [
                    self._generate_row(table_name, columns, pk_rules, seen_keys)
                    for _ in range(current_batch_size)
                ]

                # First batch truncates the file and writes the header; later batches append rows only.
                is_first_batch = start == 0
                pd.DataFrame(batch_data).to_csv(
                    output_filename,
                    index=False,
                    mode='w' if is_first_batch else 'a',
                    header=is_first_batch,
                )

                print(f"   -> Progress: {start + current_batch_size} / {num_rows} rows generated.")

            print(f"   -> Success. Data for '{table_name}' saved to '{output_filename}'.")


class DataGenerator:
    """
    Main class that orchestrates the data generation workflow by coordinating
    the SchemaReader, AIPlanner, and PlanExecutor.
    """

    @staticmethod
    def _clean_path(raw_path):
        """
        Strip surrounding whitespace and one matching pair of quotes, so paths
        pasted with quotes (e.g. Windows "Copy as path") are accepted.
        """
        path = raw_path.strip()
        if len(path) >= 2 and path[0] == path[-1] and path[0] in ('"', "'"):
            path = path[1:-1].strip()
        return path

    @staticmethod
    def _ask_int(prompt, minimum=0, default=None, input_func=input):
        """
        Prompt repeatedly until the user enters an integer >= ``minimum``.

        Args:
            prompt (str): Text shown to the user.
            minimum (int): Smallest accepted value.
            default (int | None): Returned for blank input when not None.
            input_func (callable): Source of answers; ``input`` by default.

        Returns:
            int: The accepted value.
        """
        while True:
            answer = input_func(prompt).strip()
            if not answer and default is not None:
                return default
            try:
                value = int(answer)
            except ValueError:
                print("   Please enter a valid integer.")
                continue
            if value < minimum:
                print(f"   Please enter an integer greater than or equal to {minimum}.")
                continue
            return value

    def run(self):
        """
        Initializes and runs the end-to-end data generation process based on user input.
        Any error is reported to the console instead of propagating.
        """
        print("--- AI-Driven Data Generator Initialized ---")

        try:
            # Step 1: Ask user for schema file input path
            file_input = self._clean_path(input("Enter the path to the .sql or .xlsx schema file: "))

            # Step 2: Read schema using SchemaReader
            reader = SchemaReader()
            schema_content, file_type = reader.from_file(file_input)

            # Step 3: Generate data creation plan using AIPlanner
            planner = AIPlanner()
            ai_plan = planner.create_generation_plan(schema_content, file_type)

            # Step 4: Ask user for number of rows per table (0 skips the table)
            row_counts = {}
            print("\nPlease specify the number of rows for each table.")
            for table_rule in sorted(ai_plan, key=lambda x: x['generation_order']):
                table_name = table_rule['table_name']
                row_counts[table_name] = self._ask_int(
                    f" - Enter the total number of rows for table '{table_name}': ",
                    minimum=0,
                )

            # Step 5: Ask user for batch size; re-prompt on bad input instead of
            # aborting after the (slow) AI planning step. Blank uses 10000.
            batch_size = self._ask_int(
                "\nEnter the batch size for processing (e.g., 10000 for large datasets): ",
                minimum=1,
                default=10000,
            )

            # Step 6: Execute plan with PlanExecutor
            executor = PlanExecutor(ai_plan)
            executor.execute(row_counts, batch_size)

        except Exception as e:
            # Top-level boundary: report any failure gracefully
            print(f"\nAn error occurred during execution: {e}")

        print("\n--- Data Generation Process Complete ---")


# --- Script entry point ---
# Runs the interactive workflow only when this file is executed directly
# (or as a notebook cell), never when it is imported as a module.
if __name__ == "__main__":
    generator = DataGenerator()
    generator.run()


ERROR: GOOGLE_API_KEY not found. Please create a .env file or set the environment variable.
--- AI-Driven Data Generator Initialized ---


  from .autonotebook import tqdm as notebook_tqdm


Enter the path to the .sql or .xlsx schema file:  C:\\Users\\Jayavarshini.v\\Downloads\\new_schema.sql



Step 1: Reading schema file: 'new_schema.sql'...
   ->  Consulting with Gemini to create a smart generation plan...
   ->  Gemini has provided a valid generation plan.

Please specify the number of rows for each table.


 - Enter the total number of rows for table 'dim_offices':  20
 - Enter the total number of rows for table 'dim_employees':  30
 - Enter the total number of rows for table 'dim_products':  40
 - Enter the total number of rows for table 'dim_customers':  35
 - Enter the total number of rows for table 'dim_projects':  35
 - Enter the total number of rows for table 'junc_project_assignments':  50
 - Enter the total number of rows for table 'fact_sales':  50

Enter the batch size for processing (e.g., 10000 for large datasets):  30



Step 2: Executing generation plan (in batches of 30)...

Processing table: 'dim_offices' for 20 rows...
   -> Progress: 20 / 20 rows generated.
   -> Success. Data for 'dim_offices' saved to 'dim_offices_sample_data.csv'.

Processing table: 'dim_employees' for 30 rows...
   -> Progress: 30 / 30 rows generated.
   -> Success. Data for 'dim_employees' saved to 'dim_employees_sample_data.csv'.

Processing table: 'dim_products' for 40 rows...
   -> Progress: 30 / 40 rows generated.
   -> Progress: 40 / 40 rows generated.
   -> Success. Data for 'dim_products' saved to 'dim_products_sample_data.csv'.

Processing table: 'dim_customers' for 35 rows...
   -> Progress: 30 / 35 rows generated.
   -> Progress: 35 / 35 rows generated.
   -> Success. Data for 'dim_customers' saved to 'dim_customers_sample_data.csv'.

Processing table: 'dim_projects' for 35 rows...
   -> Progress: 30 / 35 rows generated.
   -> Progress: 35 / 35 rows generated.
   -> Success. Data for 'dim_projects' saved to 'dim_pr