In [2]:


from snowflake.connector.pandas_tools import write_pandas
from langchain_openai import AzureChatOpenAI
from langchain_core.prompts import PromptTemplate
from  datetime import datetime
from langchain_text_splitters import RecursiveJsonSplitter
import snowflake.connector
import pandas as pd
import json
import csv
import os
from dotenv import load_dotenv
from pathlib import Path
from typing import Dict, List, Any
from azure.storage.filedatalake import DataLakeServiceClient
from io import StringIO

load_dotenv()


env_vars = {
        "SNOWFLAKE_USER": os.environ.get("SNOWFLAKE_USER"),
        "SNOWFLAKE_PASSWORD": os.environ.get("SNOWFLAKE_PASSWORD"),
        "SNOWFLAKE_ACCOUNT": os.environ.get("SNOWFLAKE_ACCOUNT"),
        "SNOWFLAKE_WAREHOUSE": os.environ.get("SNOWFLAKE_WAREHOUSE"),
        "SNOWFLAKE_DATABASE": os.environ.get("SNOWFLAKE_DATABASE"),
        "SNOWFLAKE_SCHEMA": os.environ.get("SNOWFLAKE_SCHEMA"),
        "AZURE_OPENAI_ENDPOINT": os.environ.get("AZURE_OPENAI_ENDPOINT"),
        "AZURE_OPENAI_4o_DEPLOYMENT_NAME": os.environ.get("AZURE_OPENAI_4o_DEPLOYMENT_NAME"),
        "AZURE_OPENAI_API_VERSION": os.environ.get("AZURE_OPENAI_API_VERSION"),
        "AZURE_OPENAI_API_KEY": os.environ.get("AZURE_OPENAI_API_KEY"),
    }

conn = snowflake.connector.connect(
        user=env_vars.get("SNOWFLAKE_USER"),
        password=env_vars.get("SNOWFLAKE_PASSWORD"),
        account=env_vars.get("SNOWFLAKE_ACCOUNT"),
        warehouse=env_vars.get("SNOWFLAKE_WAREHOUSE"),
        database=env_vars.get("SNOWFLAKE_DATABASE"),
        schema=env_vars.get("SNOWFLAKE_SCHEMA"),
    )

model = AzureChatOpenAI(
        azure_endpoint=env_vars.get("AZURE_OPENAI_ENDPOINT"),
        azure_deployment=env_vars.get("AZURE_OPENAI_4o_DEPLOYMENT_NAME"),
        openai_api_version=env_vars.get("AZURE_OPENAI_API_VERSION"),
        openai_api_key=env_vars.get("AZURE_OPENAI_API_KEY"),
    )

cursor = conn.cursor()
azure_storage_connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")

adls_client = DataLakeServiceClient.from_connection_string(azure_storage_connection_string)

cursor.execute("""
        SELECT table_name 
        FROM information_schema.tables
        WHERE table_schema = 'TEST' AND table_type = 'BASE TABLE'
    """)

tables =  cursor.fetchall()


table_names = [table[0] for table in tables]
# Initialize an empty dictionary to store data from all tables
all_data = {}
# Fetch data from all tables
for table_name in table_names:
    cursor.execute(f"SELECT * FROM {table_name}")  # Limit the rows for simplicity
    data2 = cursor.fetchall()
    df = pd.DataFrame(data2, columns=[col[0] for col in cursor.description])
    all_data[table_name] = df

    cursor.execute("""
    SELECT 
        TABLE_NAME, 
        COLUMN_NAME, 
        DATA_TYPE, 
        IS_NULLABLE, 
        COLUMN_DEFAULT 
    FROM 
        INFORMATION_SCHEMA.COLUMNS
    WHERE 
        TABLE_SCHEMA = 'TEST'
    ORDER BY table_name
""")

metadata = cursor.fetchall()
tables = {}

for table_name, column_name, data_type, is_nullable, _ in metadata:
    if table_name not in tables:
        tables[table_name] = []
    
    tables[table_name].append({
        "column_name": column_name,
        "data_type": data_type,
        "is_nullable": is_nullable
    })

json_data = json.dumps(tables, indent=4)
splitter = RecursiveJsonSplitter(max_chunk_size=300)
texts = splitter.split_text(json_data=json.loads(json_data))


syn_data = {}
prompt_template = """
    Generate 1 row of good and bad-quality data for each table based on the given metadata. 
    All data should strictly comply with constraints and data types of respective tables, 
    while bad data should simulate realistic yet invalid scenarios violating constraints like:
    1. Negative or illogical values (e.g., negative age or weight).
    2. Invalid or out-of-range dates (e.g., February 30, year > 9999).
    3. Nullability violations (e.g., null in non-nullable fields).
    4. Duplicate primary keys.
    5. Logical inconsistencies (e.g., start date after end date).
    6. missing values.
    
    
    Output format: 
    please DO NOT give invalid timestamp data.
    strictly Provide a JSON array format containing serializable data for each table.
    Provide the output in pure json JSON array format which I can parse as a json data to various platforms.
    Generate json serializable data.
    Dont provide any comments in between and any description.
    dont repeat table name inside the json data.
    Only json format data is allowed without any // comment in it.
    This is required format in which we require generated data.

    here is the input table Metadata: 
    {metadata}

    """
for i in range(10): #(len(tex)):

    # Create the prompt
    prompt = PromptTemplate(input_variables=["metadata"], template=prompt_template)
    formatted_prompt = prompt.format(metadata=texts[i])
    
    df1 = texts[i]
    metadata_dict1 = json.loads(df1)
    table_name = list(metadata_dict1.keys())[0]
    # print(formatted_prompt)
    response = model.invoke(formatted_prompt)

    synthetic_data = response.content.replace("```json", "").replace("```", "").strip()
    # Display the result
    syn_data[str(table_name)] = synthetic_data

    parsed_data = json.loads(json.dumps(syn_data))
    final_data = json.loads(parsed_data[table_name])

    syn_data[table_name] = final_data
   
try:
    data1 = json.dumps(syn_data)
except json.JSONDecodeError:
    print("Error: Failed to parse generated JSON data.")
    data1 = []
    


# Data in JSON format
data = json.loads(data1)

output_dir = "test_output1"
os.makedirs(output_dir, exist_ok=True)

# Process each table
for table_name, rows in data.items():
    if rows:  # Check if the table has data
        # Define output CSV file path
        output_csv_file = os.path.join(output_dir, f"{table_name}.csv")
        
        # Get column names from the first row
        column_names = rows[0].keys()

        # Write data to CSV file
        with open(output_csv_file, mode="w", newline="", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=column_names)
            
            # Write header and rows
            writer.writeheader()
            writer.writerows(rows)

        print(f"Table '{table_name}' saved to {output_csv_file}")
    else:
        print(f"Table '{table_name}' is empty. No file created.")


JSONDecodeError: Expecting property name enclosed in double quotes: line 33 column 51 (char 1024)