In [None]:
import os
import json
import pandas as pd
import random
import re
import subprocess
import pyarrow as pa
from typing import List
import openai
import anthropic
from dotenv import load_dotenv
import gradio as gr

In [None]:
# Load API keys (OPENAI_API_KEY, ANTHROPIC_API_KEY) from a local .env file.
# override=True lets values from .env replace variables already set in the environment.
load_dotenv(override=True)

In [None]:
# --- Schema Definition ---
# Each entry is a (column name, column type, example / format hint) tuple.
# The entries are rendered, in order, into the default schema text shown in the UI,
# so the hint strings end up verbatim in the prompt sent to the model.
SCHEMA = [
    ("Team", "TEXT", '"Toronto Raptors"'),
    ("NAME", "TEXT", '"Otto Porter Jr."'),
    ("Jersey", "TEXT", '"10", or "NA" if null'),
    ("POS", "TEXT", 'One of ["PF","SF","G","C","SG","F","PG"]'),
    ("AGE", "INT", 'integer age in years, e.g., 22'),
    ("HT", "TEXT", '`6\' 7"` or `6\' 10"`'),
    ("WT", "TEXT", '"232 lbs"'),
    ("COLLEGE", "TEXT", '"Michigan", or "--" if null'),
    ("SALARY", "TEXT", '"$9,945,830", or "--" if null')
]

In [None]:
# Default schema text for the textbox: one numbered line per SCHEMA column,
# formatted as "<n>. <name> (<type>) Example: <hint>".
DEFAULT_SCHEMA_TEXT = "\n".join(
    f"{position}. {name} ({dtype}) Example: {example}"
    for position, (name, dtype, example) in enumerate(SCHEMA, start=1)
)

In [None]:
# Available models (choices for the model dropdown).
# query_model picks the backend from the identifier's prefix:
#   "gpt..."        -> OpenAI chat completions
#   "claude..."     -> Anthropic messages
#   "ollama:<name>" -> local `ollama run <name>` subprocess
MODELS = [
    "gpt-4o",
    "claude-3-5-haiku-20241022",
    "ollama:llama3.2:latest"
]

In [None]:
# Available file formats (choices for the format dropdown).
# generate_dataset rejects any format not listed here, and each entry has a
# matching branch in save_dataset.
FILE_FORMATS = [".csv", ".tsv", ".jsonl", ".parquet", ".arrow"]

In [None]:
def get_prompt(n: int, schema_text: str, system_prompt: str) -> str:
    """Assemble the full generation prompt sent to the model.

    The prompt consists of five sections separated by blank lines: the
    caller's system prompt, the generation instruction (with the row count),
    the schema description, and two fixed output constraints.

    Args:
        n: Number of rows the model is asked to produce.
        schema_text: Human-readable description of the fields of each row.
        system_prompt: Free-form preamble placed at the top of the prompt.

    Returns:
        The assembled prompt with leading and trailing whitespace removed.
    """
    sections = [
        f"{system_prompt}",
        f"Generate {n} rows of realistic basketball player data in JSONL format, each line a JSON object with the following fields:",
        f"{schema_text}",
        "Do NOT repeat column values from one row to another.",
        "Only output valid JSONL.",
    ]
    # Sections are separated by exactly one blank line; the outer strip drops
    # any whitespace the system prompt brings at the very start.
    return "\n\n".join(sections).strip()

In [None]:
# --- LLM Interface ---
def query_model(prompt: str, model: str = "gpt-4o") -> List[dict]:
    """Send ``prompt`` to the selected backend and parse its JSONL reply.

    The backend is chosen from the (case-insensitive) prefix of ``model``:

    * ``gpt...``        -> OpenAI chat completions (key from ``OPENAI_API_KEY``)
    * ``claude...``     -> Anthropic messages (key from ``ANTHROPIC_API_KEY``)
    * ``ollama:<name>`` -> ``ollama run <name>`` subprocess, prompt on stdin

    Args:
        prompt: Full prompt text, sent as a single user message.
        model: Model identifier; see the prefixes above.

    Returns:
        One dict per reply line that starts with ``{``. Any other line
        (prose, code fences, blanks) is ignored. An empty or missing reply
        yields an empty list.

    Raises:
        RuntimeError: For any failure - unsupported model, backend/API error,
            non-zero Ollama exit status, or a ``{`` line that is not valid
            JSON. The original exception is chained as ``__cause__``.
    """
    try:
        backend = model.lower()

        if backend.startswith("gpt"):
            client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            content = response.choices[0].message.content

        elif backend.startswith("claude"):
            client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
            response = client.messages.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=4000,
                temperature=0.7
            )
            content = response.content[0].text

        elif backend.startswith("ollama:"):
            # Split on the FIRST colon only: Ollama names are "<model>:<tag>",
            # so "ollama:llama3.2:latest" must become "llama3.2:latest".
            # Splitting on every colon would silently drop the tag.
            ollama_model = model.split(":", 1)[1]
            result = subprocess.run(
                ["ollama", "run", ollama_model],
                input=prompt,
                text=True,
                capture_output=True
            )
            if result.returncode != 0:
                raise RuntimeError(f"Ollama error: {result.stderr}")
            content = result.stdout

        else:
            raise ValueError(
                "Unsupported model. Use 'gpt-4o', 'claude-3-5-haiku-20241022', "
                "or 'ollama:llama3.2:latest'"
            )

        # Parse JSONL output. A backend may return no text at all (None);
        # treat that as an empty reply rather than crashing on .strip().
        reply = (content or "").strip()
        json_lines = [line.strip() for line in reply.splitlines() if line.strip().startswith("{")]
        return [json.loads(line) for line in json_lines]

    except Exception as e:
        # Single error type/message prefix for callers; keep the root cause chained.
        raise RuntimeError(f"Model query failed: {e}") from e

In [None]:
# --- Output Formatter ---
def save_dataset(records: List[dict], file_format: str, filename: str):
    """Write ``records`` to ``filename`` in the requested ``file_format``.

    Args:
        records: Rows to persist, one dict per row.
        file_format: One of ".csv", ".tsv", ".jsonl", ".parquet", ".arrow".
        filename: Destination path; written as given.

    Raises:
        ValueError: If ``file_format`` is not one of the formats above.
    """
    # The frame is built up front; every tabular writer below consumes it.
    frame = pd.DataFrame(records)

    # Delimited text: CSV and TSV differ only in the separator.
    if file_format in (".csv", ".tsv"):
        delimiter = "\t" if file_format == ".tsv" else ","
        frame.to_csv(filename, sep=delimiter, index=False)
        return

    # JSON Lines is written straight from the raw records, one object per line.
    if file_format == ".jsonl":
        with open(filename, "w") as out:
            out.writelines(json.dumps(row) + "\n" for row in records)
        return

    if file_format == ".parquet":
        frame.to_parquet(filename, engine="pyarrow", index=False)
        return

    # Arrow IPC file format, written via pyarrow.
    if file_format == ".arrow":
        arrow_table = pa.Table.from_pandas(frame)
        with pa.OSFile(filename, "wb") as sink, pa.ipc.new_file(sink, arrow_table.schema) as writer:
            writer.write(arrow_table)
        return

    raise ValueError("Unsupported file format")

In [None]:
# --- Main Generation Function ---
def generate_dataset(schema_text, system_prompt, model, nr_records, file_format, save_as):
    """Validate the UI inputs, query the model, save the result and build a preview.

    Args:
        schema_text: Schema description inserted into the prompt.
        system_prompt: Preamble inserted at the top of the prompt.
        model: Model identifier forwarded to ``query_model``.
        nr_records: Requested row count. Accepts an int or an integral float
            (numeric UI fields may deliver ``25.0``); must be greater than 10.
        file_format: Output format; must be one of ``FILE_FORMATS``.
        save_as: Target filename. Surrounding whitespace is ignored; when
            blank, ``basketball_dataset<ext>`` is used; the extension is
            appended if missing.

    Returns:
        A ``(status_message, preview)`` tuple. On success ``preview`` is a
        DataFrame with the first 10 records; on any error it is ``None`` and
        the message starts with "❌ Error:". This function does not raise.
    """
    try:
        # Validation
        # Normalise the count to a real int: a float would leak into the
        # prompt as "Generate 25.0 rows", and None (cleared field) would
        # otherwise surface as a raw TypeError message.
        try:
            requested = float(nr_records)
        except (TypeError, ValueError):
            return "❌ Error: Nr_records must be a whole number greater than 10.", None
        if not requested.is_integer():  # also rejects NaN and infinity
            return "❌ Error: Nr_records must be a whole number greater than 10.", None
        nr_records = int(requested)

        if nr_records <= 10:
            return "❌ Error: Nr_records must be greater than 10.", None

        if file_format not in FILE_FORMATS:
            return "❌ Error: Invalid file format specified.", None

        # Use the stripped name so stray spaces never end up in the filename.
        save_as = (save_as or "").strip()
        if not save_as:
            save_as = f"basketball_dataset{file_format}"
        elif not save_as.endswith(file_format):
            save_as = save_as + file_format

        # Generate prompt
        prompt = get_prompt(nr_records, schema_text, system_prompt)

        # Query model
        records = query_model(prompt, model=model)

        if not records:
            return "❌ Error: No valid records generated from the model.", None

        # Save dataset
        save_dataset(records, file_format, save_as)

        # Create preview
        df = pd.DataFrame(records)
        preview = df.head(10)  # Show first 10 rows

        # The model may return fewer/more rows than requested; report the actual count.
        success_message = f"✅ Dataset generated successfully!\n📁 Saved to: {save_as}\n📊 Generated {len(records)} records"

        return success_message, preview

    except Exception as e:
        # UI boundary: turn any failure into a status message instead of a traceback.
        return f"❌ Error: {str(e)}", None

In [None]:
# --- Gradio Interface ---
def create_interface():
    """Build the Gradio Blocks UI for the dataset generator.

    Layout: a wide left column with the inputs (schema, prompt, model, record
    count, file format, filename) and the Generate button, and a narrower
    right column with the status text and a preview table. Clicking the
    button calls ``generate_dataset`` with the six inputs, in the order of
    its parameters, and routes its two return values to status and preview.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` interface.
    """
    with gr.Blocks(title="Dataset Generator", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Dataset Generator")
        gr.Markdown("Generate realistic datasets using AI models")

        with gr.Row():
            with gr.Column(scale=2):
                schema_input = gr.Textbox(
                    label="Schema",
                    value=DEFAULT_SCHEMA_TEXT,
                    lines=15,
                    placeholder="Define your dataset schema here..."
                )

                system_prompt_input = gr.Textbox(
                    label="Prompt",
                    value="You are a helpful assistant that generates realistic basketball player data.",
                    lines=1,
                    placeholder="Enter system prompt for the model..."
                )

                with gr.Row():
                    model_dropdown = gr.Dropdown(
                        label="Model",
                        choices=MODELS,
                        value=MODELS[1],  # Default to Claude
                        interactive=True
                    )

                    nr_records_input = gr.Number(
                        label="Nr. records",
                        value=25,
                        minimum=11,
                        maximum=1000,
                        step=1,
                        # precision=0 makes Gradio round to, and deliver, an int;
                        # without it the handler receives a float (25.0), which
                        # would read "Generate 25.0 rows" in the prompt.
                        precision=0
                    )

                with gr.Row():
                    file_format_dropdown = gr.Dropdown(
                        label="File format",
                        choices=FILE_FORMATS,
                        value=".csv",
                        interactive=True
                    )

                    save_as_input = gr.Textbox(
                        label="Save as",
                        value="basketball_dataset",
                        placeholder="Enter filename (extension will be added automatically)"
                    )

                generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

            with gr.Column(scale=1):
                output_status = gr.Textbox(
                    label="Status",
                    lines=4,
                    interactive=False
                )

                output_preview = gr.Dataframe(
                    label="Preview (First 10 rows)",
                    interactive=False,
                    wrap=True
                )

        # Connect the generate button.
        # The order of `inputs` must match generate_dataset's positional parameters.
        generate_btn.click(
            fn=generate_dataset,
            inputs=[
                schema_input,
                system_prompt_input,
                model_dropdown,
                nr_records_input,
                file_format_dropdown,
                save_as_input
            ],
            outputs=[output_status, output_preview]
        )

        gr.Markdown("""
        ### 📝 Instructions:
        1. **Schema**: Define the structure of your dataset (pre-filled with basketball player schema)
        2. **Prompt**: System prompt to guide the AI model
        3. **Model**: Choose between GPT, Claude, or Ollama models
        4. **Nr. records**: Number of records to generate (minimum 11)
        5. **File format**: Choose output format (.csv, .tsv, .jsonl, .parquet, .arrow)
        6. **Save as**: Filename (extension added automatically)
        7. Click **Generate** to create your dataset
        
        ### 🔧 Requirements:
        - Set up your API keys in `.env` file (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`)
        - For Ollama models, ensure Ollama is installed and running locally
        """)

    return interface

In [None]:
# Build the UI and start the Gradio server; inbrowser=True asks Gradio to
# open the app in a browser tab once it is running.
interface = create_interface()
interface.launch(inbrowser=True)