In [22]:
import os
import requests
from IPython.display import Markdown, display, update_display
from openai import OpenAI
from google.colab import drive
from huggingface_hub import login
from google.colab import userdata
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
import torch

from functools import lru_cache
from diffusers import StableDiffusionPipeline
import gradio as gr

In [23]:
hf_token = userdata.get('HF_TOKEN')
login(hf_token, add_to_git_credential=True)

In [24]:
TEXT_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
IMAGE_MODEL_ID = "runwayml/stable-diffusion-v1-5"

FORMAT_RULES = {
    "JSON": "Return a JSON array containing records_count objects with consistent fields tailored to the context.",
    "CSV": "Return a CSV document with a header row and records_count data rows aligned to the context.",
    "Raw Text": "Return records_count prose entries separated by blank lines that reflect the context.",
    "Code": "Return records_count code snippets grouped in a single fenced block that models the context."
}

@lru_cache(maxsize=2)
def load_text_components(use_quant: bool):
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if use_quant:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4"
        )
        model = AutoModelForCausalLM.from_pretrained(
            TEXT_MODEL_ID,
            device_map="auto",
            quantization_config=quant_config,
            trust_remote_code=True
        )
    else:
        kwargs = {"trust_remote_code": True}
        kwargs["device_map"] = "auto"
        kwargs["torch_dtype"] = torch.float16

        model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_ID, **kwargs)

    model.eval()
    return tokenizer, model

def build_text_messages(style: str, context: str, return_format: str, record_count: int):
    context_value = context.strip() if context else "general purpose scenario"
    style_value = style.strip() if style else "Balanced"
    directive = FORMAT_RULES[return_format]
    system_prompt = "You generate synthetic datasets that are high quality, diverse, and free of personally identifiable information. " + directive + " Ensure outputs are consistent in structure, imaginative in content, and avoid explanations."
    user_prompt = f"Context: {context_value}\nStyle: {style_value}\nRecords: {record_count}\nOutput style: {return_format}"
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

def generate_text_data(style: str, context: str, return_format: str, quantize: bool, record_count: int):
    tokenizer, model = load_text_components(bool(quantize))
    messages = build_text_messages(style, context, return_format, int(record_count))
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
    inputs = inputs.to("cuda")
    attention_mask = torch.ones_like(inputs)
    with torch.inference_mode():
        generated = model.generate(
            input_ids=inputs,
            attention_mask=attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.05,
            do_sample=True
        )

    output_ids = generated[:, inputs.shape[-1]:]
    text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return text.strip()

@lru_cache(maxsize=1)
def load_image_pipeline():
    pipeline = StableDiffusionPipeline.from_pretrained(IMAGE_MODEL_ID, torch_dtype=torch.float16)
    pipeline = pipeline.to("cuda")
    return pipeline

def generate_image_data(style: str, context: str, image_prompt: str, image_count: int):
    pipeline = load_image_pipeline()
    parts = []
    if image_prompt:
        parts.append(image_prompt.strip())

    if context:
        parts.append(context.strip())

    base = ", ".join([p for p in parts if p])
    if not base:
        base = "Synthetic data concept visualization"

    prompt = f"{base}, {style.lower()} style"
    images = pipeline(prompt, num_images_per_prompt=int(image_count), guidance_scale=7.0, num_inference_steps=30).images
    return images

def run_generation(data_type: str, style: str, context: str, return_format: str, quantize: bool, image_prompt: str, record_count: int, image_count: int):
    if data_type == "Text":
        text = generate_text_data(style, context, return_format, quantize, record_count)
        return gr.update(value=text, visible=True), gr.update(value=[], visible=False)

    images = generate_image_data(style, context, image_prompt, image_count)
    return gr.update(value="", visible=False), gr.update(value=images, visible=True)

def toggle_inputs(data_type: str):
    if data_type == "Text":
        return (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(value="", visible=True),
            gr.update(value=[], visible=False)
        )
    return (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(value="", visible=False),
        gr.update(value=[], visible=True)
)

In [25]:
with gr.Blocks(title="Synthetic Data Generator") as demo:
    gr.Markdown("## Synthetic Data Generator")
    with gr.Row():
        data_type = gr.Radio(["Text", "Image"], label="Type", value="Text")
        style = gr.Dropdown(["Concise", "Detailed", "Narrative", "Technical", "Tabular"], label="Style", value="Detailed")

    context_input = gr.Textbox(label="Context", lines=4, placeholder="Describe the entities, attributes, and purpose of the dataset.")
    return_format = gr.Dropdown(["JSON", "CSV", "Raw Text", "Code"], label="Return Format", value="JSON")
    quantize = gr.Checkbox(label="Quantize", value=False)
    record_count = gr.Slider(1, 20, value=5, step=1, label="Records")
    image_prompt = gr.Textbox(label="Image Prompt", lines=2, visible=False, placeholder="Detail the visual you want to synthesize.")
    image_count = gr.Slider(1, 4, value=1, step=1, label="Images", visible=False)
    generate_button = gr.Button("Generate")
    text_output = gr.Textbox(label="Text Output", lines=12)
    image_output = gr.Gallery(label="Generated Images", visible=False, columns=2, rows=1)
    data_type.change(
        toggle_inputs,
        inputs=data_type,
        outputs=[return_format, quantize, image_prompt, record_count, image_count, text_output, image_output]
    )
    generate_button.click(
        run_generation,
        inputs=[data_type, style, context_input, return_format, quantize, image_prompt, record_count, image_count],
        outputs=[text_output, image_output]
    )


demo.launch(debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://b5fd391afd63f4968c.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


  0%|          | 0/30 [00:00<?, ?it/s]

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7862 <> https://b5fd391afd63f4968c.gradio.live


