In [47]:
!wget https://raw.githubusercontent.com/jomo0825/MrFuTextualInversion/main/textual_inversion.py
!pip install gradio
!pip install diffusers transformers accelerate ftfy

import gradio as gr
import sys
import threading
from textual_inversion import main as train_textual_inversion  # Assuming main function is the entry point in textual_inversion.py
import time, os, logging
from os import path
import subprocess
import shlex
import queue
from PIL import Image
import shutil


output_dir = "./output"
logging_dir = "./log"
dataset_dir = "./dataset"

# Delete the 'output' folder and its contents
shutil.rmtree(output_dir, ignore_errors=True)
# Delete the 'log' folder and its contents
shutil.rmtree(logging_dir, ignore_errors=True)

def parse_lr_schedule(lr_schedule_str):
    schedule = []
    segments = lr_schedule_str.split(',')
    for segment in segments:
        if ':' in segment:
            lr, steps = segment.split(':')
            schedule.append((float(lr), int(steps)))
        else:
            schedule.append((float(segment), None))  # Final constant learning rate
    return schedule

def get_learning_rate_at_step(lr_schedule, step):
    current_step = 0
    for lr, segment_steps in lr_schedule:
        if segment_steps is None or step < current_step + segment_steps:
            return lr
        current_step += segment_steps
    return lr_schedule[-1][0]  # Return the last LR if beyond defined steps

# Callback to update the preview image in the UI
def preview_callback(image, step):
    global current_preview, current_status
    current_preview = image
    current_status = f"Preview updated at step {step}"

def run_training(dataset_path, prompt, placeholder_token, initializer_token, num_training_steps,
                 learning_rate, batch_size, preview_save_steps, preview_seed):
    global current_preview, current_status
    current_preview = None  # Reset the preview
    current_status = "Training started..."  # Initial status

    # Construct the command with all arguments
    command = [
        "python", "textual_inversion.py",
        "--pretrained_model_name_or_path", "stable-diffusion-v1-5/stable-diffusion-v1-5",
        "--train_data_dir", dataset_path,
        "--placeholder_token", placeholder_token,
        "--initializer_token", initializer_token,
        "--resolution", "512",
        "--train_batch_size", str(batch_size),
        "--gradient_accumulation_steps", "1",
        "--learning_rate", str(learning_rate),
        "--max_train_steps", str(num_training_steps),
        "--save_steps", str(preview_save_steps),
        "--validation_steps", str(preview_save_steps),
        "--output_dir", output_dir,
        "--logging_dir", logging_dir,
        "--validation_prompt", prompt,
        "--learnable_property", "object",
        "--seed", str(preview_seed)
    ]

    # Print the command for debugging
    print("Command:", " ".join(command))

    # Create the directories if they don't exist
    os.makedirs(logging_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    # Disable logging
    logging.getLogger("accelerate").disabled = True

    # Run the command in a separate process
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # Wait for the process to complete
    stdout, stderr = process.communicate()

    # Print the output and errors (for debugging)
    print("Output:", stdout.decode())
    print("Errors:", stderr.decode())

    # Update status when training completes
    current_status = "Training completed!"

    yield gr.update(value=current_preview), gr.update(value=current_status)


def ui():
    with gr.Blocks() as demo:
        gr.Markdown("# Stable Diffusion Textual Inversion UI")
        gr.Markdown("Generate images using a preloaded textual inversion model.")

        with gr.Row():
            dataset_path = gr.Textbox(label="Dataset Path", value="dataset", interactive=True)

        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                placeholder_token = gr.Textbox(label="Placeholder Token", placeholder="Enter placeholder token here", interactive=True)
                initializer_token = gr.Textbox(label="Initializer Token", placeholder="Enter initializer token here", interactive=True)
                prompt = gr.Textbox(label="Preview Prompt", placeholder="Enter your prompt here", interactive=True)
                num_training_steps = gr.Number(label="Number of Training Steps", value=100, interactive=True)
                learning_rate = gr.Number(label="Learning Rate", value=0.001, interactive=True)
                batch_size = gr.Number(label="Batch Size", value=1, interactive=True)
                preview_save_steps = gr.Number(label="Preview/Save Every N Steps", value=10, interactive=True)
                preview_seed = gr.Number(label="Preview Seed", value=1, interactive=True)
            with gr.Column(scale=1, min_width=300):
                output_image = gr.Image(label="Generated Image")

        generate_button = gr.Button("Start Training")

        generate_status = gr.Textbox(value="Status messages will appear here.", label="Status", interactive=False)

        generate_button.click(
            fn=run_training,
            inputs=[dataset_path, prompt, placeholder_token, initializer_token, num_training_steps,
                    learning_rate, batch_size, preview_save_steps, preview_seed],
            outputs=[output_image, generate_status],
            show_progress=True,
            queue=True
        )

    return demo


# Create the directories if they don't exist
os.makedirs(dataset_dir, exist_ok=True)

demo = ui()

demo.launch()


--2025-03-07 21:02:01--  https://raw.githubusercontent.com/jomo0825/MrFuTextualInversion/main/textual_inversion.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42572 (42K) [text/plain]
Saving to: ‘textual_inversion.py.26’


2025-03-07 21:02:01 (5.84 MB/s) - ‘textual_inversion.py.26’ saved [42572/42572]

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://6ad71a71840f2e0451.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working di



In [33]:
demo.close()

Closing server running on port: 7867
