# Scaling Personalized AI with Flux.1 and DreamBooth: A Step-by-Step Guide

Hey there, AI enthusiasts! 👋 Ready to dive into the wild world of personalized AI models? Buckle up, because we're about to embark on an epic journey to create a system that can handle thousands of personalized Flux.1 model finetunings like it's no big deal. We'll be using the awesome power of DreamBooth and some nifty open-source tools like ZenML to make this magic happen.

By the time you're done with this notebook, you'll be slinging personalized AI models like a pro. Whether you're a seasoned ML wizard or a curious newbie, this guide will give you the superpowers you need to bring these ideas to life in your own mad scientist projects. Let's get this party started! 🎉

## Step 1: Setting Up Our Environment

First things first, let's get our environment ready for some serious AI action. We'll import all the necessary libraries and set up the Docker and Kubernetes settings our pipeline steps will run with.

In [None]:
import base64
import os
import tempfile
from pathlib import Path
from typing import Annotated, List, Tuple

import torch
from accelerate.utils import write_basic_config
from diffusers import AutoPipelineForText2Image, StableVideoDiffusionPipeline
from diffusers.utils import export_to_video
from PIL import Image as PILImage
from rich import print
from train_dreambooth_lora_flux import main as dreambooth_main
from zenml import pipeline, step
from zenml.config import DockerSettings
from zenml.integrations.huggingface.steps import run_with_accelerate
from zenml.integrations.kubernetes.flavors import (
    KubernetesOrchestratorSettings,
)
from zenml.logger import get_logger
from zenml.types import HTMLString
from zenml.utils import io_utils
from zenml.client import Client

logger = get_logger(__name__)

MNT_PATH = "/mnt/data"

docker_settings = DockerSettings(
    parent_image="pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime",
    environment={
        "PJRT_DEVICE": "CUDA",
        "USE_TORCH_XLA": "false",
        "MKL_SERVICE_FORCE_INTEL": 1,
        "HF_TOKEN": os.environ["HF_TOKEN"],
        "HF_HOME": MNT_PATH,
    },
    python_package_installer="uv",
    requirements="requirements.txt",
    python_package_installer_args={
        "system": None,
    },
    apt_packages=["git", "ffmpeg", "gifsicle"],
    # prevent_build_reuse=True,
)

kubernetes_settings = KubernetesOrchestratorSettings(
    pod_settings={
        "affinity": {
            "nodeAffinity": {
                "requiredDuringSchedulingIgnoredDuringExecution": {
                    "nodeSelectorTerms": [
                        {
                            "matchExpressions": [
                                {
                                    "key": "zenml.io/gpu",
                                    "operator": "In",
                                    "values": ["yes"],
                                }
                            ]
                        }
                    ]
                }
            }
        },
        "volumes": [
            {
                "name": "data-volume",
                "persistentVolumeClaim": {"claimName": "pvc-managed-premium"},
            }
        ],
        "volume_mounts": [{"name": "data-volume", "mountPath": MNT_PATH}],
    },
)

print("Environment setup complete! 🚀")
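The node-affinity block above is deeply nested and easy to mistype. If you end up targeting several GPU node pools or volume claims, a small helper that builds the same `pod_settings` structure can save some pain. This is a minimal sketch: `build_gpu_pod_settings` is a hypothetical helper of our own, and the label key `zenml.io/gpu`, the claim name and the mount path are simply the values this notebook already uses.

```python
from typing import Any, Dict


def build_gpu_pod_settings(
    claim_name: str,
    mount_path: str,
    label_key: str = "zenml.io/gpu",
    label_value: str = "yes",
    volume_name: str = "data-volume",
) -> Dict[str, Any]:
    """Hypothetical helper: build a pod_settings dict shaped like the one
    passed to KubernetesOrchestratorSettings above."""
    # Only schedule pods on nodes that carry the GPU label
    gpu_term = {
        "matchExpressions": [
            {"key": label_key, "operator": "In", "values": [label_value]}
        ]
    }
    return {
        "affinity": {
            "nodeAffinity": {
                "requiredDuringSchedulingIgnoredDuringExecution": {
                    "nodeSelectorTerms": [gpu_term]
                }
            }
        },
        # Attach the persistent volume claim and mount it into the container
        "volumes": [
            {"name": volume_name, "persistentVolumeClaim": {"claimName": claim_name}}
        ],
        "volume_mounts": [{"name": volume_name, "mountPath": mount_path}],
    }


pod_settings = build_gpu_pod_settings("pvc-managed-premium", "/mnt/data")
# e.g. KubernetesOrchestratorSettings(pod_settings=pod_settings)
print(pod_settings["volume_mounts"])
```

Keeping the volume name in one place also guarantees that `volumes` and `volume_mounts` always refer to the same volume.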

## Step 2: Data Loading Magic

Now that we've got our environment set up, let's pull our training images down from the artifact store, take a quick look at them, and then create a function to load them. This bad boy will help us grab all those juicy image paths we'll use to train our model.

In [1]:
from zenml.client import Client
from zenml.utils import io_utils

images_path = "az://demo-zenmlartifactstore/hamza-faces"
images_dir_path = "/tmp/hamza-faces/"
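# Instantiate the active stack's artifact store so ZenML can resolve az:// paths
# (the returned value itself is unused)
_ = Client().active_stack.artifact_store.path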

io_utils.copy_dir(
    destination_dir=images_dir_path,
    source_dir=images_path,
    overwrite=True,
)


In [3]:
import os

import ipywidgets as widgets
from IPython.display import display

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def display_image_gallery(images_dir_path, thumbnail_size=(200, 200)):
    # Collect all image files from the directory, in a stable order
    image_files = sorted(
        f for f in os.listdir(images_dir_path) if f.lower().endswith(IMAGE_EXTENSIONS)
    )
    thumb_width, thumb_height = thumbnail_size

    # Large widget that shows whichever image was selected last
    selected_image = widgets.Image(
        layout=widgets.Layout(max_width="100%", height="auto", margin="10px 0")
    )

    tiles = []
    for img in image_files:
        # Read the raw bytes; ipywidgets.Image wants bytes plus a format name
        with open(os.path.join(images_dir_path, img), "rb") as f:
            value = f.read()
        fmt = img.rsplit(".", 1)[-1].lower()
        fmt = "jpeg" if fmt == "jpg" else fmt

        thumbnail = widgets.Image(
            value=value,
            format=fmt,
            layout=widgets.Layout(width=f"{thumb_width}px", height=f"{thumb_height}px"),
        )

        # widgets.Image has no on_click and Button.icon only takes an icon name,
        # so pair each thumbnail with a small "View" button instead
        button = widgets.Button(
            description="View",
            tooltip=img,
            layout=widgets.Layout(width=f"{thumb_width}px"),
        )

        # Bind this tile's bytes/format as defaults so each button shows its own image
        def on_button_click(_, value=value, fmt=fmt):
            selected_image.format = fmt
            selected_image.value = value

        button.on_click(on_button_click)
        tiles.append(widgets.VBox([thumbnail, button], layout=widgets.Layout(margin="5px")))

    # Lay the tiles out in a responsive grid
    thumbnail_grid = widgets.GridBox(
        tiles,
        layout=widgets.Layout(
            grid_template_columns=f"repeat(auto-fill, minmax({thumb_width}px, 1fr))"
        ),
    )

    # Display the gallery
    display(thumbnail_grid, selected_image)


# Usage
display_image_gallery(images_dir_path)


In [None]:
def load_image_paths(image_dir: Path) -> List[Path]:
    logger.info(f"Loading images from {image_dir}")
    return (
        list(image_dir.glob("**/*.png"))
        + list(image_dir.glob("**/*.jpg"))
        + list(image_dir.glob("**/*.jpeg"))
    )

@step(
    settings={"orchestrator.kubernetes": kubernetes_settings},
    enable_cache=False,
)
def load_data(instance_example_dir: str) -> List[PILImage.Image]:
    instance_example_paths = load_image_paths(Path(instance_example_dir))
    logger.info(f"Loaded {len(instance_example_paths)} images")
    return [PILImage.open(path) for path in instance_example_paths]

print("Data loading function ready to roll! 📸")
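One gotcha with `load_image_paths`: glob patterns like `**/*.png` are case-sensitive on case-sensitive filesystems (typical Linux setups, including most Kubernetes nodes), so a file named `IMG_001.JPG` would be silently skipped. A minimal alternative, assuming you want to accept any capitalization of the same three extensions, is to filter on the lowercased suffix. `load_image_paths_ci` is a hypothetical name, not part of the original code.

```python
import tempfile
from pathlib import Path
from typing import List

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def load_image_paths_ci(image_dir: Path) -> List[Path]:
    """Recursively collect image files, matching extensions case-insensitively."""
    return sorted(
        path
        for path in image_dir.rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )


# Quick check against a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    for name in ["a.png", "B.JPG", "c.jpeg", "notes.txt", "sub/d.PNG"]:
        (root / name).touch()
    found = [p.relative_to(root).as_posix() for p in load_image_paths_ci(root)]

print(found)
```

A single `rglob("*")` pass also avoids listing the directory tree three times and cannot return the same file twice.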

## Step 3: Training Our Model Like a Boss

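Alright, now we're getting to the good stuff. Let's set up our model training step: it saves our images to a temporary directory, configures Accelerate for bf16 mixed precision, and hands everything off to the DreamBooth LoRA training script for FLUX.1. This is where the magic happens, folks!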

In [None]:
@run_with_accelerate(num_processes=1, multi_gpu=True)
@step(
    settings={"orchestrator.kubernetes": kubernetes_settings},
    enable_cache=False,
)
def train_model(
    instance_example_images: List[PILImage.Image],
    instance_name: str,
    class_name: str,
    model_name: str,
    hf_repo_suffix: str,
    prefix: str,
    resolution: int,
    train_batch_size: int,
    rank: int,
    gradient_accumulation_steps: int,
    learning_rate: float,
    lr_scheduler: str,
    lr_warmup_steps: int,
    max_train_steps: int,
    push_to_hub: bool,
    checkpointing_steps: int,
    seed: int,
) -> None:
    # Set up a temporary directory for our images
    image_dir = Path(tempfile.mkdtemp(prefix="instance_images_"))
    for i, image in enumerate(instance_example_images):
        image.save(image_dir / f"image_{i}.png")

    logger.info(f"Saved images to {image_dir}")
    images_dir_path = str(image_dir)

    # Configure accelerate for some speedy training
    write_basic_config(mixed_precision="bf16")

    instance_phrase = f"{instance_name} the {class_name}"
    instance_prompt = f"{prefix} {instance_phrase}".strip()

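    # Set up our training arguments. This lightweight class stands in for the
    # argparse.Namespace that the training script's main() normally receives:
    # it only carries the attributes set below, so if your copy of
    # train_dreambooth_lora_flux.py reads any other argument, add it here too.
    class Args:
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)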

    args = Args(
        mixed_precision="bf16",
        pretrained_model_name_or_path=model_name,
        instance_data_dir=images_dir_path,
        output_dir=hf_repo_suffix,
        instance_prompt=instance_prompt,
        resolution=resolution,
        train_batch_size=train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        rank=rank,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        max_train_steps=max_train_steps,
        checkpointing_steps=checkpointing_steps,
        seed=seed,
        push_to_hub=push_to_hub,
    )

    # Fire up that training engine!
    print("Launching dreambooth training script")
    dreambooth_main(args)

print("Model training step locked and loaded! 💪")
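Two numbers in this step are worth sanity-checking before you burn GPU hours: the exact instance prompt the model is trained on, and how often each training image is seen. The sketch below reproduces the prompt template from `train_model` and does the step arithmetic under the simplifying assumption that one optimizer step consumes `train_batch_size * gradient_accumulation_steps * num_processes` images. The function names are hypothetical, and the 20-image dataset size is purely illustrative, not something from this notebook.

```python
from typing import Dict


def build_instance_prompt(prefix: str, instance_name: str, class_name: str) -> str:
    # Same template as train_model: "<prefix> <instance_name> the <class_name>"
    instance_phrase = f"{instance_name} the {class_name}"
    return f"{prefix} {instance_phrase}".strip()


def training_budget(
    num_images: int,
    max_train_steps: int,
    train_batch_size: int,
    gradient_accumulation_steps: int,
    num_processes: int = 1,
) -> Dict[str, float]:
    # Images consumed per optimizer step across all processes
    effective_batch = train_batch_size * gradient_accumulation_steps * num_processes
    samples_seen = max_train_steps * effective_batch
    return {
        "effective_batch": effective_batch,
        "samples_seen": samples_seen,
        # Average number of passes over the instance images
        "epochs": samples_seen / num_images,
    }


prompt = build_instance_prompt("A photo of", "htahir1", "Pakistani man")
budget = training_budget(
    num_images=20,  # illustrative dataset size
    max_train_steps=1600,
    train_batch_size=1,
    gradient_accumulation_steps=1,
)
print(prompt)  # A photo of htahir1 the Pakistani man
print(budget)  # {'effective_batch': 1, 'samples_seen': 1600, 'epochs': 80.0}
```

With the pipeline defaults, every extra image in the dataset lowers the number of passes per image, so if you change the dataset size you may want to scale `max_train_steps` along with it.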

## Step 4: Batch Inference - Let's See What We've Created!

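Now that we've trained our model, it's time to put it to the test. Let's set up a batch inference step that loads FLUX.1-dev with our freshly trained LoRA weights from the Hugging Face Hub and renders a gallery of images from a handful of prompts!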

In [None]:
@step(settings={"orchestrator.kubernetes": kubernetes_settings})
def batch_inference(
    hf_username: str,
    hf_repo_suffix: str,
    instance_name: str,
    class_name: str,
) -> PILImage.Image:
    model_path = f"{hf_username}/{hf_repo_suffix}"
    pipe = AutoPipelineForText2Image.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights(
        model_path, weight_name="pytorch_lora_weights.safetensors"
    )

    instance_phrase = f"{instance_name} the {class_name}"
    prompts = [
        f"A photo of {instance_phrase} wearing a beret in front of the Eiffel Tower",
        f"A photo of {instance_phrase} on a busy Paris street",
        f"A photo of {instance_phrase} sitting at a Parisian cafe",
        # ... (add more prompts as desired)
    ]

    images = pipe(
        prompt=prompts,
        num_inference_steps=50,
        guidance_scale=7.5,
        height=512,
        width=512,
    ).images

    # Stitch the generated pictures into one gallery image, up to 5 per row
    width, height = images[0].size
    cols = min(5, len(images))
    rows = (len(images) + cols - 1) // cols  # ceiling division
    gallery_img = PILImage.new("RGB", (width * cols, height * rows))

    for i, image in enumerate(images):
        gallery_img.paste(image, ((i % cols) * width, (i // cols) * height))

    return gallery_img

print("Batch inference step ready to generate some masterpieces! 🎨")
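The gallery code places image `i` at column `i % cols` and row `i // cols`. The sketch below pulls that arithmetic out into a pure function (the name `gallery_layout` is hypothetical) so you can check the canvas size and paste offsets without a GPU. It caps the grid at 5 columns and derives the number of rows from the number of images.

```python
from typing import List, Tuple


def gallery_layout(
    num_images: int, width: int, height: int, max_cols: int = 5
) -> Tuple[Tuple[int, int], List[Tuple[int, int]]]:
    """Return the canvas size and the top-left paste offset of every image."""
    if num_images <= 0:
        raise ValueError("need at least one image")
    cols = min(max_cols, num_images)
    rows = (num_images + cols - 1) // cols  # ceiling division
    offsets = [((i % cols) * width, (i // cols) * height) for i in range(num_images)]
    return (cols * width, rows * height), offsets


canvas, offsets = gallery_layout(num_images=3, width=512, height=512)
print(canvas)   # (1536, 512)
print(offsets)  # [(0, 0), (512, 0), (1024, 0)]
```

With seven images the same function gives a 2560x1024 canvas, and the sixth image wraps to the start of the second row at `(0, 512)`.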

## Step 5: From Still to Motion - Let's Make Some Video Magic!

Why stop at images when we can create videos? Let's add a step that generates a fresh image with our finetuned model and animates it into a short video clip with Stable Video Diffusion!

In [None]:
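def get_optimal_size(
    image: PILImage.Image, max_size: int = 1024
) -> Tuple[int, int]:
    """Cap the longer side at max_size while keeping the aspect ratio;
    images that already fit keep their size."""
    width, height = image.size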
    aspect_ratio = width / height
    if width > height:
        new_width = min(width, max_size)
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = min(height, max_size)
        new_width = int(new_height * aspect_ratio)
    return (new_width, new_height)

@step(
    settings={"orchestrator.kubernetes": kubernetes_settings},
    enable_cache=False,
)
def image_to_video(
    hf_username: str,
    hf_repo_suffix: str,
    instance_name: str,
) -> Tuple[
    Annotated[PILImage.Image, "generated_image"],
    Annotated[bytes, "video_data"],
    Annotated[HTMLString, "video_html"],
]:
    model_path = f"{hf_username}/{hf_repo_suffix}"
    pipe = AutoPipelineForText2Image.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights(
        model_path, weight_name="pytorch_lora_weights.safetensors"
    )

    image = pipe(
        prompt=f"A photo of {instance_name} on a busy Paris street",
        num_inference_steps=70,
        guidance_scale=7.5,
        height=512,
        width=512,
    ).images[0]

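    # Free the FLUX pipeline's GPU memory before loading the video model
    del pipe
    torch.cuda.empty_cache()

    video_pipeline = StableVideoDiffusionPipeline.from_pretrained(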
        "stabilityai/stable-video-diffusion-img2vid-xt",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    video_pipeline.enable_model_cpu_offload()

    optimal_size = get_optimal_size(image)
    image = image.resize(optimal_size)
    optimal_width, optimal_height = optimal_size

    frames = video_pipeline(
        image,
        num_inference_steps=50,
        decode_chunk_size=8,
        generator=torch.manual_seed(42),
        height=optimal_height,
        width=optimal_width,
    ).frames[0]

    output_file = "generated_video.mp4"
    export_to_video(frames, output_file, fps=7)

    with open(output_file, "rb") as file:
        video_data = file.read()

    html_visualization_str = f"""
    <html>
    <body>
        <video width="{optimal_width}" height="{optimal_height}" controls>
            <source src="data:video/mp4;base64,{base64.b64encode(video_data).decode()}" type="video/mp4">
            Your browser does not support the video tag.
        </video>
    </body>
    </html>
    """

    return (image, video_data, HTMLString(html_visualization_str))

print("Video generation step ready to bring your images to life! 🎬")
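`get_optimal_size` only ever shrinks the longer side to `max_size` and keeps the aspect ratio. The diffusers Stable Video Diffusion pipeline expects height and width to be divisible by 8 (check the input validation in your diffusers version), which 512x512 satisfies, but other sizes may not. Here is a stdlib-only sketch that also rounds both sides down; `fit_within` and its `multiple` parameter are our own additions, not part of the original step.

```python
from typing import Tuple


def fit_within(
    width: int, height: int, max_size: int = 1024, multiple: int = 8
) -> Tuple[int, int]:
    """Shrink (never enlarge) to fit max_size, keep the aspect ratio,
    then round both sides down to a multiple of `multiple`."""
    aspect_ratio = width / height
    if width > height:
        new_width = min(width, max_size)
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = min(height, max_size)
        new_width = int(new_height * aspect_ratio)
    # Round down so the video model receives dimensions it accepts
    return (
        max(multiple, new_width // multiple * multiple),
        max(multiple, new_height // multiple * multiple),
    )


print(fit_within(512, 512))    # (512, 512)
print(fit_within(1536, 1024))  # (1024, 680)
print(fit_within(768, 1152))   # (680, 1024)
```

Rounding down rather than up keeps the result within `max_size`, at the cost of trimming a few pixels of aspect-ratio accuracy.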

## Step 6: Putting It All Together - Our DreamBooth Pipeline

Now for the grand finale - let's string all these awesome steps together into one epic pipeline!

In [None]:
@pipeline(settings={"docker": docker_settings})
def dreambooth_pipeline(
    instance_example_dir: str = "data/hamza-instance-images",
    instance_name: str = "htahir1",
    class_name: str = "Pakistani man",
    model_name: str = "black-forest-labs/FLUX.1-dev",
    hf_username: str = "htahir1",
    hf_repo_suffix: str = "flux-dreambooth-hamza",
    prefix: str = "A photo of",
    resolution: int = 512,
    train_batch_size: int = 1,
    rank: int = 16,
    gradient_accumulation_steps: int = 1,
    learning_rate: float = 0.0004,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 0,
    max_train_steps: int = 1600,
    push_to_hub: bool = True,
    checkpointing_steps: int = 1000,
    seed: int = 117,
):
    data = load_data(instance_example_dir)
    train_model(
        data,
        instance_name=instance_name,
        class_name=class_name,
        model_name=model_name,
        hf_repo_suffix=hf_repo_suffix,
        prefix=prefix,
        resolution=resolution,
        train_batch_size=train_batch_size,
        rank=rank,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        max_train_steps=max_train_steps,
        push_to_hub=push_to_hub,
        checkpointing_steps=checkpointing_steps,
        seed=seed,
    )
    batch_inference(
        hf_username,
        hf_repo_suffix,
        instance_name,
        class_name,
        after="train_model",
    )
    image_to_video(
        hf_username, hf_repo_suffix, instance_name, after="batch_inference"
    )

print("Dreambooth pipeline assembled and ready for action! 🚀")
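ZenML orders steps by their data dependencies (the `data` artifact flowing from `load_data` into `train_model`) plus any explicit `after=` arguments. The `after=` edges matter here because `batch_inference` and `image_to_video` load the LoRA weights from the Hugging Face Hub rather than from an artifact, so they have no data link to `train_model`. The graph below is a toy model of that reasoning built with Python's `graphlib`; it is not how ZenML schedules steps internally.

```python
from graphlib import TopologicalSorter

# Each step maps to the set of steps it must wait for
dependencies = {
    "load_data": set(),
    "train_model": {"load_data"},           # data edge: the images artifact
    "batch_inference": {"train_model"},     # explicit after="train_model"
    "image_to_video": {"batch_inference"},  # explicit after="batch_inference"
}
order = list(TopologicalSorter(dependencies).static_order())
print(order)  # ['load_data', 'train_model', 'batch_inference', 'image_to_video']

# Without the after= edges, both inference steps would be ready immediately
no_after = TopologicalSorter(
    {**dependencies, "batch_inference": set(), "image_to_video": set()}
)
no_after.prepare()
ready = sorted(no_after.get_ready())
print(ready)  # ['batch_inference', 'image_to_video', 'load_data']
```

In that second case the inference steps could try to pull LoRA weights that have not been pushed yet. Chaining `image_to_video` after `batch_inference` presumably also keeps the two GPU-heavy inference steps from running at the same time.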

## Step 7: Launch the Pipeline and Watch the Magic Happen!

Alright, folks, this is it - the moment of truth! Let's fire up our pipeline and see this baby in action!

In [None]:
if __name__ == "__main__":
    dreambooth_pipeline()

print("Pipeline launched! Sit back, relax, and prepare to be amazed! 🍿")
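The intro promised thousands of personalized finetunings. In this notebook that boils down to calling `dreambooth_pipeline` once per subject with its own parameters. Here is a minimal sketch, assuming each subject's images already sit in their own directory and that a per-subject repo suffix is enough to keep the Hub repositories apart. `Subject`, `pipeline_kwargs` and all example names and paths are hypothetical.

```python
import re
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Subject:
    instance_name: str  # identifier the model learns for this subject
    class_name: str     # broader class used in the prompt
    image_dir: str      # where this subject's training images live


def pipeline_kwargs(subject: Subject, hf_username: str) -> Dict[str, str]:
    # Derive a Hub-friendly, per-subject repo suffix from the instance name
    slug = re.sub(r"[^a-z0-9-]+", "-", subject.instance_name.lower()).strip("-")
    return {
        "instance_example_dir": subject.image_dir,
        "instance_name": subject.instance_name,
        "class_name": subject.class_name,
        "hf_username": hf_username,
        "hf_repo_suffix": f"flux-dreambooth-{slug}",
    }


subjects: List[Subject] = [
    Subject("alice_01", "woman", "data/alice"),
    Subject("Bob Smith", "man", "data/bob"),
]
runs = [pipeline_kwargs(s, hf_username="my-hf-user") for s in subjects]
for kwargs in runs:
    print(kwargs["hf_repo_suffix"])
    # dreambooth_pipeline(**kwargs)  # one pipeline run per subject
```

Because `batch_inference` loads the LoRA from `{hf_username}/{hf_repo_suffix}`, `hf_username` should match the Hugging Face account behind `HF_TOKEN` that the training step pushes to.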

And there you have it, folks! You've just built and launched a kickass pipeline for personalized AI model generation. From loading data to training models, from generating images to creating videos - you've done it all!

Remember, this is just the beginning. Feel free to tweak, adjust, and experiment with the parameters to see what kind of magic you can create. The AI world is your oyster, and you've got the tools to make some serious pearls!

Happy coding, and may your models be ever accurate and your latency low! 🚀🎉