# ✨ Coherent Data Generator

## Real-world data has meaning and relationships — and this is where this tool shines.

Dependencies between fields are detected, and coherent data is generated.
Example:
When **Ghana** is given as the context, fields like `name`, `food`, etc. will be Ghanaian, and fields such as phone numbers will carry the appropriate `+233` country prefix.

Unlike Faker, which fills each field independently, this tool keeps related fields consistent with one another.

## Steps
Schema -> Generate Data

Schema Sources: 
- Use the guided schema builder
- Bring your own schema from an SQL Data Definition Language (DDL)
- Prompting
- Giving a domain to a seasoned "Data Scientist" persona, which defines the features of the dataset

In [None]:
import json

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import pandas as pd

from pydantic import BaseModel, Field
from IPython.display import display, Markdown

In [None]:
model_id = "Qwen/Qwen3-4B-Instruct-2507"

# Report which accelerator (if any) torch sees. `device` is only printed here;
# actual weight placement is delegated to device_map="auto" below.
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = 'cpu'
print(f'Device: {device}')

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# dtype="auto" asks transformers to pick the dtype from the checkpoint config;
# device_map="auto" lets it place the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map="auto")

## Schema Definitions

In [None]:
# Reserved for future use: when the user asks for it from the UI, errors in a
# SQL DDL statement could be fixed instead of simply rejected.
# NOTE: no class docstring on purpose, as pydantic would copy it into the
# model's JSON schema.
class SQLValidationResult(BaseModel):
    # Whether the DDL statement is valid SQL.
    is_valid: bool
    # Whether an invalid statement is considered repairable.
    is_fixable: bool
    # Why validation failed; defaults to an empty string.
    reason: str = Field(description='validation failure reason', default='')


# One field (column) of a Schema. Schema.model_json_schema(), which embeds this
# model, is pasted into the LLM prompts, so every attribute carries a
# description to tell the model what to fill in.
class FieldDescriptor(BaseModel):
    name: str = Field(..., description='Name of the field')
    data_type: str = Field(..., description='Type of the field')
    # Previously the only attribute without a description in the JSON schema.
    nullable: bool = Field(..., description='Whether the field may be null')
    description: str = Field(..., description='Description of the field')


# A named collection of field descriptors describing the records to generate.
# (Kept docstring-free so the JSON schema used in prompts stays unchanged.)
class Schema(BaseModel):
    name: str = Field(default=..., description='Name of the schema')
    fields: list[FieldDescriptor] = Field(
        default=...,
        description='List of fields in the schema',
    )

## LLM Interactions

### Generate Content from LLM

In [None]:
def generate(
    messages: list[dict[str, str]],
    temperature: float = 0.1,
    max_new_tokens: int = 16384,
) -> str:
    """Run a chat completion on the loaded model and return the reply text.

    Args:
        messages: Chat messages (``{'role': ..., 'content': ...}`` dicts),
            rendered with the tokenizer's chat template.
        temperature: Sampling temperature. A value <= 0 switches to greedy
            decoding instead of sampling.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        cut off), with special tokens removed.
    """
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # `temperature` only takes effect when sampling, and transformers rejects
    # a non-positive temperature when do_sample=True, so make the choice
    # explicit and fall back to greedy decoding for temperature <= 0.
    if temperature > 0:
        sampling_kwargs = {'do_sample': True, 'temperature': temperature}
    else:
        sampling_kwargs = {'do_sample': False}

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        **sampling_kwargs,
    )

    # generate() returns prompt + completion; keep only the completion.
    prompt_length = model_inputs.input_ids.shape[1]
    output_ids = generated_ids[0][prompt_length:].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True)

### Generate Data Given A Valid Schema

In [None]:
def _strip_code_fences(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence (```json ... ```)."""
    stripped = text.strip()
    if stripped.startswith('```'):
        # Drop the opening fence line, which may carry a language tag like `json`.
        stripped = stripped.partition('\n')[2].rstrip()
        if stripped.endswith('```'):
            stripped = stripped[:-3]
    return stripped.strip()


def generate_data(schema: str | dict, context: str = '', num_records: int = 5) -> str:
    """Ask the model for ``num_records`` synthetic records matching *schema*.

    Args:
        schema: Schema description. Strings (e.g. the reply of ``parse_ddl``)
            are sent verbatim; any other value (e.g. ``Schema.model_dump()``)
            is serialised with ``json.dumps`` so the model sees JSON rather
            than a Python repr.
        context: Optional free-text context (e.g. a country) that should drive
            the generated values.
        num_records: Number of records to request.

    Returns:
        The model reply with any surrounding Markdown code fence removed.
        Callers parse it with ``json.loads``; the content is not validated here.
    """
    if not isinstance(schema, str):
        schema = json.dumps(schema, indent=2)

    system_prompt = (
        'You are a synthetic data generator: you generate records that '
        'conform to the given schema.\n'
        'When a context is provided, intelligently use it to drive the '
        'field values.\n\n'
        'Example:\n'
        'If Africa is given as the context, fields like name, first_name, '
        'last_name, etc. will be generated with values that fit Africa.\n\n'
        'If no context is provided, generate data randomly.\n\n'
        'Output ONLY a JSON array of objects (one object per record), '
        'without Markdown code fences or any other commentary.'
    )

    prompt = (
        f'Generate exactly {num_records} records.\n\n'
        f'Schema:\n{schema}\n\n'
        f'Context:\n{context or "None"}\n'
    )

    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': prompt},
    ]

    # Instruct-tuned models often wrap JSON in ```json fences despite the
    # instructions; strip them so json.loads in the callers succeeds.
    return _strip_code_fences(generate(messages))

### SQL

In [None]:
def sql_validator(ddl: str):
    """Ask the model whether *ddl* is valid SQL.

    Returns the raw model reply. The system prompt asks for ``1`` (valid) or
    ``0`` (not valid), but the reply is not parsed or checked here.
    """
    instructions = (
        '\n'
        '  You are an SQL validator, your task is to validate if the given SQL is valid or not.\n'
        '  ONLY return a binary response of 1 and 0. Where 1=valid and 0 = not valid.\n'
        '  '
    )
    conversation = [
        {'role': 'system', 'content': instructions},
        {'role': 'user', 'content': f'Validate: {ddl}'},
    ]
    return generate(conversation)


# Future work: repair errors in a SQL DDL statement when they are fixable.
def sql_fixer(ddl: str):
    """Placeholder for automatic SQL DDL repair; not implemented yet.

    Currently ignores *ddl* and returns ``None``.
    """
    return None


def parse_ddl(ddl: str):
    """Ask the model to turn a SQL DDL statement into a ``Schema`` JSON document.

    Args:
        ddl: The SQL DDL statement (e.g. ``CREATE TABLE ...``) to analyse.

    Returns:
        The raw model reply, which the prompt asks to be JSON conforming to
        ``Schema``'s JSON schema. It is not parsed or validated here.
    """
    # json.dumps (rather than the dict's Python repr) shows the model real JSON
    # with double-quoted keys, which is also the format it must produce.
    json_schema = json.dumps(Schema.model_json_schema())
    system_prompt = (
        'You are an SQL analyzer, your task is to extract column information '
        'to a specific JSON structure.\n\n'
        'The output must conform to the following JSON schema:\n'
        f'{json_schema}\n\n'
        'Respond with the JSON document only, without Markdown code fences '
        'or any other commentary.'
    )
    prompt = f'Generate schema for: {ddl}'

    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': prompt},
    ]

    return generate(messages)

### Data Scientist

Just give it a domain and you will be amazed by the features it comes up with.

In [None]:
def create_domain_schema(domain: str):
    """Ask the model to design a dataset schema for *domain*.

    Args:
        domain: Subject area of the dataset, e.g. ``'Hospital Patients'``.

    Returns:
        The raw model reply, which the prompt asks to be JSON conforming to
        ``Schema``'s JSON schema. It is not parsed or validated here.
    """
    # json.dumps (rather than the dict's Python repr) shows the model real
    # JSON, which is also the format it is asked to produce.
    json_schema = json.dumps(Schema.model_json_schema())
    system_prompt = (
        'You are an expert Data Scientist tasked with describing the features '
        'of a dataset for aspiring data scientists in a chosen domain.\n\n'
        'Follow these steps EXACTLY:\n'
        '**Define 6–10 features** for the given domain. Include:\n'
        ' - At least 2 numerical features\n'
        ' - At least 2 categorical features\n'
        ' - 1 boolean or binary feature\n'
        ' - 1 timestamp or date feature\n'
        ' - Realistic dependencies (e.g., "if loan_amount > 50000, '
        'credit_score should be high")\n\n'
        'Populate your response into the JSON schema below. Strictly output '
        '**JSON** only, without Markdown code fences or commentary:\n'
        f'{json_schema}'
    )
    prompt = f'Describe the data point. Domain: {domain}'

    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': prompt},
    ]

    return generate(messages)


# TODO: Use Gradio Examples to make it easier to load different statements.
# Default DDL shown in the SQL tab's code editor.
sql = '''
CREATE TABLE users (
    id BIGINT PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    email TEXT,
    gender ENUM('F', 'M'),
    country VARCHAR(100),
    mobile_number VARCHAR(100),
    created_at TIMESTAMP DEFAULT NOW()
);
'''

In [None]:
# Report the loaded model's memory footprint in GB. The format spec must be
# ',.2f' (thousands separator, then precision); the previous ', .2f' was
# rejected by format() with "Invalid format specifier".
print(f'{model.get_memory_footprint() / 1e9:,.2f} GB')

## Export Functions

In [None]:
from enum import StrEnum


class ExportFormat(StrEnum):
    CSV = 'CSV'
    JSON = 'JSON'
    Excel = 'Excel'
    Parquet = 'Parquet'
    TSV = 'TSV'
    HTML = 'HTML'
    Markdown = 'Markdown'
    SQL = 'SQL'


def export_data(df, format_type, table_name: str = 'users'):
    """Serialise *df* into the requested export format.

    Args:
        df: DataFrame to export.
        format_type: An ``ExportFormat`` member or its string value (the UI
            dropdown passes plain strings, which compare equal to the
            ``StrEnum`` members).
        table_name: Table name used in the ``CREATE TABLE``/``INSERT``
            statements of the SQL export.

    Returns:
        ``str`` for text formats, ``bytes`` for Excel and Parquet, or ``None``
        when *df* is missing/empty, the format is unknown, or serialisation
        fails (the error is printed).
    """
    # Imported locally: the notebook only imports `io` in a later cell.
    import io
    import sqlite3
    from contextlib import closing

    if df is None or df.empty:
        return None

    try:
        if format_type == ExportFormat.CSV:
            return df.to_csv(index=False)

        if format_type == ExportFormat.TSV:
            return df.to_csv(sep='\t', index=False)

        if format_type == ExportFormat.JSON:
            return df.to_json(orient='records', indent=2)

        if format_type == ExportFormat.Excel:
            buffer = io.BytesIO()
            df.to_excel(buffer, index=False, engine='openpyxl')
            return buffer.getvalue()

        if format_type == ExportFormat.Parquet:
            buffer = io.BytesIO()
            df.to_parquet(buffer, index=False)
            return buffer.getvalue()

        if format_type == ExportFormat.HTML:
            return df.to_html(index=False)

        if format_type == ExportFormat.Markdown:
            return df.to_markdown(index=False)

        if format_type == ExportFormat.SQL:
            # A throwaway in-memory SQLite database renders the frame as a SQL
            # dump. closing() guarantees the connection is released (a sqlite3
            # connection's own `with` only manages transactions).
            with closing(sqlite3.connect(':memory:')) as connection:
                df.to_sql(table_name, con=connection, index=False)
                return '\n'.join(connection.iterdump())

    except Exception as e:
        print(f"Export error: {str(e)}")
        return None

    # Unknown format.
    return None


def prepare_download(df, format_type):
    """Export *df* and write it to a temporary file for the download widget.

    Args:
        df: DataFrame to export (may be ``None`` before anything was generated).
        format_type: An ``ExportFormat`` member or its string value.

    Returns:
        Path (as ``str``) of a file named ``generated_data.<ext>`` inside a
        fresh temporary directory, or ``None`` if there is nothing to export
        or the export failed.
    """
    import tempfile
    from pathlib import Path

    if df is None:
        return None

    content = export_data(df, format_type)
    if content is None:
        return None

    extensions = {
        ExportFormat.CSV: '.csv',
        ExportFormat.JSON: '.json',
        ExportFormat.Excel: '.xlsx',
        ExportFormat.Parquet: '.parquet',
        ExportFormat.TSV: '.tsv',
        ExportFormat.HTML: '.html',
        ExportFormat.Markdown: '.md',
        ExportFormat.SQL: '.sql',
    }
    filename = f'generated_data{extensions.get(format_type, ".txt")}'

    # A fresh directory per download lets the file keep a friendly name
    # without clobbering earlier downloads.
    path = Path(tempfile.mkdtemp()) / filename
    if isinstance(content, bytes):
        path.write_bytes(content)
    else:
        # UTF-8 so non-ASCII values (e.g. localized names) don't depend on the
        # platform locale; newline='' so line endings already produced by
        # pandas (CRLF from to_csv on Windows) aren't translated a second time.
        with open(path, 'w', encoding='utf-8', newline='') as fh:
            fh.write(content)
    return str(path)

## Gradio UI

In [None]:
import gradio as gr
from pydantic import BaseModel, Field
import json
import pandas as pd
import io

# Field types offered in the manual-entry "Data Type" dropdowns.
DATA_TYPES = [
    'string',
    'integer',
    'float',
    'boolean',
    'date',
    'datetime',
    'array',
    'object',
]

def generate_from_sql(sql: str, context: str, num_records: int = 10):
    """Derive a schema from a SQL DDL statement and generate records for it.

    Returns ``(schema_text, DataFrame)`` on success. On any failure it returns
    ``('Error: <message>', None)`` so the UI can display the problem.
    """
    try:
        print(f'SQL: {sql}')
        schema_text = parse_ddl(sql)
        records = json.loads(generate_data(schema_text, context, num_records))
        return schema_text, pd.DataFrame(records)
    except Exception as err:
        return f'Error: {str(err)}', None


def generate_from_data_scientist(domain: str, context: str, num_records: int = 10):
    """Let the model design a schema for *domain*, then generate records for it.

    Returns ``(schema_text, DataFrame)`` on success, or
    ``('Error: <message>', None)`` if any step fails.
    """
    try:
        print(f'Domain: {domain}')
        schema_text = create_domain_schema(domain)
        print(schema_text)
        records = json.loads(generate_data(schema_text, context, num_records))
        return schema_text, pd.DataFrame(records)
    except Exception as err:
        return f'Error: {str(err)}', None


def generate_from_dynamic_fields(schema_name, context: str, num_fields, num_records: int, *field_values):
    """Generate records from a schema assembled in the manual-entry tab.

    Args:
        schema_name: Name for the schema; required.
        context: Free-text context forwarded to ``generate_data``.
        num_fields: Number of field rows in use (the slider value); only the
            first ``num_fields`` rows are read.
        num_records: Number of records to request.
        *field_values: Flattened per-row component values, four per row in
            the order (name, data type, nullable, description).

    Returns:
        ``(schema_json, DataFrame)`` on success, or ``('Error: <message>', None)``.
    """
    try:
        if not schema_name:
            return 'Error: Schema name is required', None

        # The slider value may arrive as a float (e.g. 3.0); range() needs an
        # int, and a negative count must not turn into a from-the-end slice.
        row_count = max(int(num_fields), 0)

        # Group the flat values into complete 4-tuples, ignoring any
        # incomplete trailing row.
        rows = [
            field_values[start:start + 4]
            for start in range(0, 4 * row_count, 4)
            if start + 4 <= len(field_values)
        ]

        fields = [
            FieldDescriptor(
                name=name.strip(),
                data_type=dtype,
                nullable=bool(nullable),
                description=desc or '',
            )
            for name, dtype, nullable, desc in rows
            # Skip blank rows, including names that are only whitespace.
            if name and name.strip() and dtype
        ]
        if not fields:
            return 'Error: At least one field is required', None

        schema = Schema(name=schema_name, fields=fields)
        # Serialise once as real JSON. The same text is sent to the model
        # (instead of a dict repr) and shown in the UI.
        schema_json = json.dumps(schema.model_dump(), indent=2)
        data = json.loads(generate_data(schema_json, context, num_records))
        return schema_json, pd.DataFrame(data)

    except Exception as e:
        return f'Error: {str(e)}', None



title='✨ Coherent Data Generator'

# Number of field rows pre-built in the manual-entry tab (also the slider maximum).
MAX_FIELDS = 20

with gr.Blocks(title=title, theme=gr.themes.Monochrome()) as ui:
    gr.Markdown(f'# {title}')
    gr.Markdown('Embrace the Coherent Data wins 🏆!')

    # Holds the most recently generated DataFrame for the download step.
    df_state = gr.State(value=None)

    with gr.Row():
        num_records_input = gr.Number(
            label='Number of Records to Generate',
            value=10,
            minimum=1,
            maximum=10000,
            step=1,
            precision=0
        )

        context_input = gr.Textbox(
            label='Context',
            placeholder='70% Ghana and 30% Nigeria data. Start ID generation from 200',
            lines=1
        )

    with gr.Tabs() as tabs:
        with gr.Tab('Manual Entry', id=0):
            schema_name_input = gr.Textbox(label='Schema Name', placeholder='Enter schema name')

            gr.Markdown('### Fields')

            with gr.Row():
                num_fields_slider = gr.Slider(
                    minimum=1,
                    maximum=MAX_FIELDS,
                    value=3,
                    step=1,
                    label='Number of Fields',
                    interactive=True
                )

            gr.HTML('''
            <div style="display: flex; gap: 8px; margin-bottom: 8px; font-weight: bold;">
                <div style="flex: 2;">Field Name</div>
                <div style="flex: 2;">Data Type</div>
                <div style="flex: 1;">Nullable</div>
                <div style="flex: 3;">Description</div>
            </div>
            ''')

            # All rows are built up front and the slider only toggles their
            # visibility. Each row adds four components in (name, type,
            # nullable, description) order, the flat layout that
            # generate_from_dynamic_fields reads.
            field_components = []
            row_components = []

            for i in range(MAX_FIELDS):
                with gr.Row(visible=(i < 3)) as row:
                    field_name = gr.Textbox(label='', container=False, scale=2)
                    data_type = gr.Dropdown(choices=DATA_TYPES, value='string', label='', container=False, scale=2)
                    nullable = gr.Checkbox(label='', container=False, scale=1)
                    description = gr.Textbox(label='', container=False, scale=3)

                    row_components.append(row)
                    field_components.extend([field_name, data_type, nullable, description])

            submit_btn = gr.Button('Generate', variant='primary')

            num_fields_slider.change(
                fn=lambda x: [gr.update(visible=(i < x)) for i in range(MAX_FIELDS)],
                inputs=[num_fields_slider],
                outputs=row_components
            )

        with gr.Tab('SQL', id=1):
            gr.Markdown('### Parse SQL DDL')
            ddl_input = gr.Code(
                value=sql,
                label='SQL DDL Statement',
                language='sql',
                lines=10
            )
            ddl_btn = gr.Button('Generate', variant='primary')

        with gr.Tab('>_ Prompt', id=2):
            gr.Markdown('### You are on your own here, so be creative 💡')
            prompt_input = gr.Textbox(
                label='Prompt',
                placeholder='Type your prompt',
                lines=10
            )
            prompt_btn = gr.Button('Generate', variant='primary')

        with gr.Tab('Data Scientist 🎩', id=3):
            gr.Markdown('### You are on your own here, so be creative 💡')
            domain_input = gr.Dropdown(
                label='Domain',
                choices=['E-commerce Customers', 'Hospital Patients', 'Loan Applications'],
                allow_custom_value=True
            )

            data_scientist_generate_btn = gr.Button('Generate', variant='primary')

    with gr.Accordion('Generated Schema', open=False):
        output = gr.Code(label='Schema (JSON)', language='json')

    gr.Markdown('## Generated Data')
    dataframe_output = gr.Dataframe(
        label='',
        interactive=False,
        wrap=True
    )

    gr.Markdown('### Export Data')
    with gr.Row():
        format_dropdown = gr.Dropdown(
            choices=[format.value for format in ExportFormat],
            value=ExportFormat.CSV,
            label='Export Format',
            scale=2
        )
        download_btn = gr.Button('Download', variant='secondary', scale=1)

    download_file = gr.File(label='Download File', visible=True)

    def _handle_result(result):
        # Every generate_from_* helper returns (schema_text, df_or_None).
        # Show both and mirror the frame into df_state for the download step.
        schema_text, df = result
        return schema_text, df, df

    def update_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values):
        result = generate_from_dynamic_fields(schema_name, context, num_fields, num_records, *field_values)
        return _handle_result(result)

    def update_from_sql(sql: str, context, num_records: int):
        result = generate_from_sql(sql, context, num_records)
        return _handle_result(result)

    def update_from_data_scientist(domain: str, context, num_records: int):
        result = generate_from_data_scientist(domain, context, num_records)
        return _handle_result(result)

    def update_from_prompt(prompt: str, context, num_records: int):
        # Free-form mode: the user's prompt itself describes the records, so
        # it is passed to generate_data in place of a schema.
        if not prompt or not prompt.strip():
            return 'Error: A prompt is required', None, None
        try:
            data = json.loads(generate_data(prompt, context, num_records))
            df = pd.DataFrame(data)
        except Exception as e:
            return f'Error: {str(e)}', None, None
        return prompt, df, df

    submit_btn.click(
        fn=update_from_dynamic_fields,
        inputs=[schema_name_input, context_input, num_fields_slider, num_records_input] + field_components,
        outputs=[output, dataframe_output, df_state]
    )

    ddl_btn.click(
        fn=update_from_sql,
        inputs=[ddl_input, context_input, num_records_input],
        outputs=[output, dataframe_output, df_state]
    )

    # Previously the Prompt tab's Generate button had no handler and did nothing.
    prompt_btn.click(
        fn=update_from_prompt,
        inputs=[prompt_input, context_input, num_records_input],
        outputs=[output, dataframe_output, df_state]
    )

    data_scientist_generate_btn.click(
        fn=update_from_data_scientist,
        inputs=[domain_input, context_input, num_records_input],
        outputs=[output, dataframe_output, df_state]
    )

    download_btn.click(
        fn=prepare_download,
        inputs=[df_state, format_dropdown],
        outputs=[download_file]
    )


ui.launch(debug=True)
