Create a virtual environment

In [None]:
# Create the environment under the home directory so its location matches the
# `source ~/<your-venv-name>/bin/activate` step no matter which directory this
# command is run from.
python3 -m venv ~/<your-venv-name>

Activate the environment

In [None]:
# Activate the virtual environment in the current shell session.
# The path assumes the environment directory sits directly under your home
# directory (~); adjust it if the environment was created somewhere else.
source ~/<your-venv-name>/bin/activate

In [None]:
# Install diffusers straight from its GitHub repository (no tag or commit is
# pinned, so this picks up whatever the default branch currently contains).
pip install git+https://github.com/huggingface/diffusers.git

Clone the Hugging Face `community-events` repo and install the JAX/Flax requirements

In [None]:
# Fetch the community-events repository and move into the directory that holds
# the JAX ControlNet sprint training scripts.
# NOTE(review): the later `python3 train_controlnet_flax.py` command is
# presumably meant to be run from this directory — confirm.
git clone https://github.com/huggingface/community-events
cd community-events/jax-controlnet-sprint/training_scripts
# -U upgrades packages that are already installed to the newest version the
# requirements file allows.
pip install -U -r requirements_flax.txt

Verify that jax was installed properly

In [None]:
import jax

# Ask JAX how many devices it can see; getting a number back at all shows the
# import and backend initialisation worked.
device_total = jax.device_count()
device_total

Install Weights & Biases (`wandb`) for experiment logging

In [None]:
# wandb is the tracker selected by `--report_to="wandb"` in the training command.
pip install wandb

Download the conditioning images used for validation during training

In [None]:
# Download the two conditioning images into the current directory; the training
# command passes them to --validation_image as ./conditioning_image_1.png and
# ./conditioning_image_2.png.
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png

Log in to Hugging Face

In [None]:
# Authenticate this machine with the Hugging Face Hub.
# NOTE(review): presumably required because the training command uses
# --push_to_hub — confirm.
huggingface-cli login

Set the environment variables

In [None]:
# Base model, passed to the training command as --pretrained_model_name_or_path.
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
# Passed as --output_dir. The shell leaves `{timestamp}` as literal text.
# NOTE(review): presumably the training script replaces it with the run's
# start time — confirm.
export OUTPUT_DIR="runs/fill-circle-{timestamp}"
# Used for both --tracker_project_name and --hub_model_id.
export HUB_MODEL_ID="controlnet-fill-circle"

Start training

In [None]:
# Launch ControlNet training on the fusing/fill50k dataset at 512 resolution.
# Reads MODEL_DIR, OUTPUT_DIR and HUB_MODEL_ID from the environment, so the
# `export` lines must have been run in the same shell first.
# Two validation images are given together with two validation prompts.
# NOTE(review): presumably they are paired by position — confirm.
# --report_to="wandb" needs the wandb package; --push_to_hub uploads under
# $HUB_MODEL_ID.
# NOTE(review): --from_pt together with --revision="non-ema" presumably loads
# PyTorch weights from that revision and converts them — confirm.
python3 train_controlnet_flax.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --resolution=512 \
 --learning_rate=1e-5 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --validation_steps=1000 \
 --train_batch_size=2 \
 --revision="non-ema" \
 --from_pt \
 --report_to="wandb" \
 --tracker_project_name=$HUB_MODEL_ID \
 --num_train_epochs=11 \
 --push_to_hub \
 --hub_model_id=$HUB_MODEL_ID

Save checkpoints

In [None]:
 # Not a standalone command: append this flag to the
 # `python3 train_controlnet_flax.py` invocation.
 # NOTE(review): presumably a checkpoint is written every 500 steps — confirm.
 --checkpointing_steps=500

You can resume training from a saved checkpoint

In [None]:
 # Not a standalone command: append this flag to the
 # `python3 train_controlnet_flax.py` invocation.
 # NOTE(review): `./control_out/500` is an example path and does not match the
 # OUTPUT_DIR exported earlier (`runs/fill-circle-{timestamp}`); point it at
 # the directory where your checkpoint was actually written.
 --controlnet_model_name_or_path="./control_out/500" 

It is recommended to set --snr_gamma to 5.0

Profile the code

In [None]:
 # Not a standalone command: append this flag to the
 # `python3 train_controlnet_flax.py` invocation.
 # Use a single `=`: with `==` the option's value is read as the string "=5"
 # instead of the number 5.
 --profile_steps=5

To inspect the profile trace, run the following commands

In [None]:
# TensorBoard and its profiler plugin are used to view the recorded trace.
pip install tensorflow tensorboard-plugin-profile
# Example run directory — replace it with the output directory of your own run.
tensorboard --logdir runs/fill-circle-100steps-20230411_165612/

Killing a leftover process that is still holding the TPU (replace `PID` with its process ID)

In [None]:
# PID is a placeholder: replace it with the ID of the process to terminate.
# -9 sends SIGKILL, which the process cannot catch or ignore.
kill -9 PID

Pushing the model to the Hub

In [None]:
# Write the model into a local folder; the upload step uses the same folder
# path as its `folder_path`.
local_repo_dir = "path_to_your_model_repository"
model.save_pretrained(local_repo_dir)

In [None]:
from huggingface_hub import create_repo, upload_folder

# Keep the repo id in one place so the repository that is created and the
# upload target cannot drift apart.
repo_id = "username/my-awesome-model"

# exist_ok=True keeps this cell re-runnable: without it, create_repo raises an
# error when the repository already exists.
create_repo(repo_id, exist_ok=True)
upload_folder(
    folder_path="path_to_your_model_repository",
    repo_id=repo_id,
)

Creating the space

In [None]:
import gradio as gr

# inference function takes prompt, negative prompt and image
def infer(prompt, negative_prompt, image):
    """Template for the demo's inference function.

    Replace the body with your own inference code and return the generated
    image so the interface can show it in its "image" output.

    Args:
        prompt: Text from the first "text" input.
        negative_prompt: Text from the second "text" input.
        image: Value from the "image" input.

    Returns:
        The image to display (once implemented).

    Raises:
        NotImplementedError: Always, until the template is filled in.
    """
    # Fail with a clear message instead of returning an undefined name,
    # which would surface as a confusing NameError at call time.
    raise NotImplementedError("implement your inference function here")

# The input components must line up, in order, with the parameters of `infer`
# (prompt, negative prompt, image); the output component shows its return value.
demo_interface = gr.Interface(
    fn=infer,
    inputs=["text", "text", "image"],
    outputs="image",
)
demo_interface.launch()

Customize the space

In [None]:
# Heading and blurb shown with the demo.
title = "Animal Pose controlnet"
description = "This is a demo on ControlNet based on canny filter."

# Examples must follow the interface inputs: one inner list per example, and
# one element per component in `inputs`, in the same order.
examples = [["a cat with cake texture", "low quality", "cat_image.png"]]

styled_interface = gr.Interface(
    fn=infer,
    inputs=["text", "text", "image"],
    outputs="image",
    title=title,
    description=description,
    examples=examples,
    theme="gradio/soft",
)
styled_interface.launch()

Using Blocks

In [None]:
import gradio as gr

def infer_segmentation(prompt, negative_prompt, image):
    """Template for inference with segmentation control.

    Replace the body with your own inference code and return the output image.

    Args:
        prompt: Prompt text.
        negative_prompt: Negative prompt text.
        image: Input image.

    Returns:
        The output image (once implemented).

    Raises:
        NotImplementedError: Always, until the template is filled in.
    """
    # Fail with a clear message instead of returning an undefined name,
    # which would surface as a confusing NameError at call time.
    raise NotImplementedError("implement your inference function for segmentation control")

def infer_canny(prompt, negative_prompt, image):
    """Template for inference with canny control.

    Replace the body with your own inference code and return the output image.

    Args:
        prompt: Prompt text.
        negative_prompt: Negative prompt text.
        image: Input image.

    Returns:
        The output image (once implemented).

    Raises:
        NotImplementedError: Always, until the template is filled in.
    """
    # Fail with a clear message instead of returning an undefined name,
    # which would surface as a confusing NameError at call time.
    raise NotImplementedError("implement your inference function for canny control")

# Two-tab demo: each tab has its own prompt, negative prompt and image inputs,
# an output image, and a button that calls the matching inference function.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Different Controls")
    gr.Markdown("In this app, you can find different ControlNets with different filters. ")

    # Canny tab: the button runs infer_canny on the three inputs.
    with gr.Tab("ControlNet on Canny Filter "):
        prompt_input_canny = gr.Textbox(label="Prompt")
        negative_prompt_canny = gr.Textbox(label="Negative Prompt")
        canny_input = gr.Image(label="Input Image")
        canny_output = gr.Image(label="Output Image")
        canny_submit = gr.Button(value="Submit")
        canny_submit.click(
            fn=infer_canny,
            inputs=[prompt_input_canny, negative_prompt_canny, canny_input],
            outputs=[canny_output],
        )

    # Segmentation tab: the button runs infer_segmentation on the three inputs.
    with gr.Tab("ControlNet with Semantic Segmentation"):
        prompt_input_seg = gr.Textbox(label="Prompt")
        negative_prompt_seg = gr.Textbox(label="Negative Prompt")
        seg_input = gr.Image(label="Image")
        seg_output = gr.Image(label="Output Image")
        seg_submit = gr.Button(value="Submit")
        seg_submit.click(
            fn=infer_segmentation,
            inputs=[prompt_input_seg, negative_prompt_seg, seg_input],
            outputs=[seg_output],
        )

demo.launch()