# ArtKid ControlNet Playground

Upload a sketch/photo, tweak model and parameters, run generation with Replicate, and preview results.

Instructions:
- Get a Replicate API token and paste it when prompted.
- Keep the default model version (ControlNet Scribble) or paste a different model version hash into the Version field.
- Upload an image and click Generate.

Notes:
- This notebook uses the Replicate Files API to upload the image, then creates a prediction and polls it until it finishes (120-second timeout by default).
- Default model: `jagilley/controlnet-scribble`, pinned to a fixed version hash (`DEFAULT_MODEL_VERSION`).


In [None]:
import io
import json
import os
import time
import urllib.parse
import requests
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple

import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image

REPLICATE_API = "https://api.replicate.com/v1"
DEFAULT_MODEL_VERSION = "435061a1b5a4c1e26740464bf786efdfa9cb3a3ac488595a2de23e143fdb0117"  # jagilley/controlnet-scribble

@dataclass
class Prediction:
    """Snapshot of a Replicate prediction as parsed from an API response.

    Only the fields this notebook reads are kept; instances are built by
    ``create_prediction`` and ``get_prediction``.
    """

    # Identifier used to poll the prediction.
    id: str
    # Lifecycle state string as reported by the API (e.g. "starting", "succeeded").
    status: str
    # Model output taken from the response's "output" key; None when absent.
    output: Optional[Any]
    # Error text taken from the response's "error" key; None when absent.
    error: Optional[str]


def upload_file_to_replicate(file_bytes: bytes, filename: str, token: str) -> str:
    """Upload a file to the Replicate Files API and return a URL for it.

    The file is sent as a multipart form under the ``content`` field, carrying
    its original filename and a content type guessed from that filename.

    Args:
        file_bytes: Raw contents of the file to upload.
        filename: Original name of the file; used for the multipart part and
            to guess the content type.
        token: Replicate API token, sent as a Bearer credential.

    Returns:
        The URL of the uploaded file, suitable for use as a model input.

    Raises:
        RuntimeError: If the API answers with a non-2xx status, or if the
            response body carries no file URL.
    """
    import mimetypes  # local import: only this function needs it

    url = f"{REPLICATE_API}/files"
    # Deliberately no Content-Type header: requests has to generate the
    # multipart boundary itself.
    headers = {"Authorization": f"Bearer {token}"}
    part_name = filename or "upload"
    content_type = mimetypes.guess_type(part_name)[0] or "application/octet-stream"
    resp = requests.post(
        url,
        files={"content": (part_name, file_bytes, content_type)},
        headers=headers,
        timeout=60,
    )
    # Any 2xx is success (file creation answers 201 Created); this matches the
    # status check used by create_prediction / get_prediction.
    if resp.status_code >= 300:
        raise RuntimeError(f"Upload failed: {resp.status_code} {resp.text}")
    data = resp.json()
    # The file's link lives under urls.get; a top-level "url" is accepted as a
    # fallback for older response shapes.
    file_url = (data.get("urls") or {}).get("get") or data.get("url")
    if not file_url:
        raise RuntimeError(f"Upload failed: no file URL in response: {resp.text}")
    return file_url


def create_prediction(version: str, input_payload: Dict[str, Any], token: str) -> Prediction:
    """Start a new prediction on Replicate.

    Args:
        version: Model version hash to run.
        input_payload: Model inputs, sent as the request's ``input`` object.
        token: Replicate API token, sent as a Bearer credential.

    Returns:
        A ``Prediction`` built from the API's JSON response.

    Raises:
        RuntimeError: If the API answers with a status code of 300 or above.
    """
    response = requests.post(
        f"{REPLICATE_API}/predictions",
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        json={"version": version, "input": input_payload},
        timeout=60,
    )
    if response.status_code >= 300:
        raise RuntimeError(f"Prediction create failed: {response.status_code} {response.text}")
    body = response.json()
    return Prediction(
        id=body["id"],
        status=body["status"],
        output=body.get("output"),
        error=body.get("error"),
    )


def get_prediction(pred_id: str, token: str) -> Prediction:
    """Fetch the current state of an existing prediction.

    Args:
        pred_id: Identifier of the prediction to look up.
        token: Replicate API token, sent as a Bearer credential.

    Returns:
        A ``Prediction`` built from the API's JSON response.

    Raises:
        RuntimeError: If the API answers with a status code of 300 or above.
    """
    response = requests.get(
        f"{REPLICATE_API}/predictions/{pred_id}",
        headers={"Authorization": f"Bearer {token}"},
        timeout=60,
    )
    if response.status_code >= 300:
        raise RuntimeError(f"Prediction get failed: {response.status_code} {response.text}")
    body = response.json()
    return Prediction(
        id=body["id"],
        status=body["status"],
        output=body.get("output"),
        error=body.get("error"),
    )


def run_until_done(
    pred: Prediction,
    token: str,
    timeout_s: int = 120,
    poll_interval_s: float = 1.8,
) -> Prediction:
    """Poll a prediction until it leaves the "starting"/"processing" states.

    Args:
        pred: The prediction to wait for; returned unchanged if it is already
            in some other state.
        token: Replicate API token used for the polling requests.
        timeout_s: Maximum number of seconds to wait before giving up.
        poll_interval_s: Seconds to sleep between polling requests.

    Returns:
        The most recently fetched ``Prediction`` whose status is neither
        "starting" nor "processing".

    Raises:
        TimeoutError: If the prediction is still pending after ``timeout_s``
            seconds.
        RuntimeError: Propagated from ``get_prediction`` on an API error.
    """
    # Monotonic clock: the timeout must not be affected by wall-clock
    # adjustments (NTP sync, manual changes) while we wait.
    start = time.monotonic()
    cur = pred
    while cur.status in ("starting", "processing"):
        if time.monotonic() - start > timeout_s:
            raise TimeoutError("Prediction timeout")
        time.sleep(poll_interval_s)
        cur = get_prediction(cur.id, token)
    return cur


def fetch_image(url: str) -> Image.Image:
    """Download the image at ``url`` and return it converted to RGB mode.

    Raises:
        requests.HTTPError: If the download answers with an error status
            (via ``raise_for_status``).
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    buffer = io.BytesIO(response.content)
    decoded = Image.open(buffer)
    return decoded.convert("RGB")


# Widgets

# --- Credentials and model selection ---
replicate_token = widgets.Password(
    description="Token",
    placeholder="Replicate API token",
    layout=widgets.Layout(width="450px"),
)
model_version = widgets.Text(
    value=DEFAULT_MODEL_VERSION,
    description="Version",
    layout=widgets.Layout(width="600px"),
)

# --- Generation parameters (sent as the prediction's input payload) ---
prompt = widgets.Text(
    value="friendly watercolor, children illustration, soft colors",
    description="Prompt",
    layout=widgets.Layout(width="600px"),
)
image_resolution = widgets.Dropdown(
    options=["512", "768", "1024"],
    value="768",
    description="Resolution",
)
ddim_steps = widgets.IntSlider(value=28, min=10, max=50, step=1, description="Steps")
scale = widgets.FloatSlider(value=8.0, min=1.0, max=15.0, step=0.5, description="Scale")
seed = widgets.IntText(value=42, description="Seed")
a_prompt = widgets.Text(
    value="best quality, highly detailed, watercolor children illustration, soft colors",
    description="A+ prompt",
    layout=widgets.Layout(width="600px"),
)
n_prompt = widgets.Text(
    value="lowres, bad anatomy, extra fingers, cropped, worst quality, low quality, jpeg artifacts",
    description="Negative",
    layout=widgets.Layout(width="600px"),
)

# --- Input file and trigger ---
uploader = widgets.FileUpload(accept="image/*", multiple=False)
run_btn = widgets.Button(description="Generate", button_style="primary")

# --- Output panes: progress/error messages, input preview, generated result ---
status_out = widgets.Output()
input_img_out = widgets.Output()
result_out = widgets.Output()


def _read_uploaded_file(value: Any) -> Tuple[bytes, str]:
    """Return ``(content_bytes, filename)`` for the first entry of ``FileUpload.value``.

    Two shapes are supported: ipywidgets 7 exposes a dict keyed by filename
    whose entries keep the name under ``metadata``; ipywidgets 8 exposes a
    tuple of entries with ``name`` at the top level and ``content`` as a
    memoryview. The content is normalized to ``bytes`` in both cases.
    """
    entry = next(iter(value.values())) if isinstance(value, dict) else value[0]
    name = entry["name"] if "name" in entry else entry["metadata"]["name"]
    return bytes(entry["content"]), name


def _first_output_url(output: Any) -> Optional[str]:
    """Return the URL to display from a prediction output, or None if there is none."""
    if isinstance(output, str):
        return output or None
    if isinstance(output, list) and output:
        # NOTE(review): some ControlNet models may return the control map as
        # the first element; confirm index 0 is the image that should be shown.
        return output[0] or None
    return None


def on_generate_click(_):
    """Handle the Generate button: upload the image, run a prediction, show results.

    Progress and error messages are printed into ``status_out``. On success
    the uploaded image is previewed in ``input_img_out`` and the first output
    image, followed by its URL, is shown in ``result_out``. Failures are
    reported as messages rather than raised, since this runs as a widget
    callback.

    Args:
        _: The button instance passed by ipywidgets (unused).
    """
    with status_out:
        clear_output()
        token = replicate_token.value
        if not token:
            print("Provide Replicate token")
            return
        if not uploader.value:
            print("Upload an image")
            return

        # Read uploaded file (shape differs between ipywidgets 7 and 8).
        try:
            file_bytes, filename = _read_uploaded_file(uploader.value)
        except (KeyError, IndexError, TypeError, StopIteration) as e:
            print("Could not read uploaded file:", e)
            return

        print("Uploading...")
        try:
            control_url = upload_file_to_replicate(file_bytes, filename, token)
        except Exception as e:
            print("Upload error:", e)
            return

        print("Creating prediction...")
        payload = {
            "prompt": prompt.value,
            "image": control_url,
            "num_samples": "1",
            "image_resolution": image_resolution.value,
            "ddim_steps": int(ddim_steps.value),
            "scale": float(scale.value),
            "seed": int(seed.value),
            "a_prompt": a_prompt.value,
            "n_prompt": n_prompt.value,
        }
        # A blank Version field falls back to the pinned default.
        version = model_version.value.strip() or DEFAULT_MODEL_VERSION
        try:
            pred = create_prediction(version, payload, token)
            pred = run_until_done(pred, token)
        except Exception as e:
            print("Prediction error:", e)
            return
        if pred.status != "succeeded":
            print("Failed:", pred.error)
            return

        # Show results
        with input_img_out:
            clear_output()
            try:
                display(Image.open(io.BytesIO(file_bytes)).convert("RGB"))
            except Exception as e:
                # The preview is best-effort, but report why it is missing
                # instead of failing silently.
                print("Input preview error:", e)
        with result_out:
            clear_output()
            try:
                out_url = _first_output_url(pred.output)
                if out_url:
                    display(fetch_image(out_url))
                    print(out_url)
                else:
                    print("No output image URL")
            except Exception as e:
                print("Display error:", e)


# Wire the Generate button to its handler.
run_btn.on_click(on_generate_click)

# Control panel: single-widget rows, except the numeric parameters, which
# share one row.
controls = widgets.VBox(
    [
        *(widgets.HBox([w]) for w in (replicate_token, model_version, prompt)),
        widgets.HBox([image_resolution, ddim_steps, scale, seed]),
        *(widgets.HBox([w]) for w in (a_prompt, n_prompt)),
        uploader,
        run_btn,
    ]
)

# Render order: controls, input preview, result, then status messages.
display(controls)
display(widgets.HTML("<h4>Input</h4>"))
display(input_img_out)
display(widgets.HTML("<h4>Result</h4>"))
display(result_out)
display(status_out)
