# Scaling Personalized AI with Flux.1 and Stable Diffusion Image to Video: A Step-by-Step Guide

Hey there, AI enthusiasts! 👋 Ready to dive into the wild world of personalized AI models? Buckle up, because we're about to embark on an epic journey to create a system that can handle thousands of personalized Flux.1 model finetunings like it's no big deal. We'll be using the awesome power of DreamBooth and some nifty open-source tools like ZenML to make this magic happen.

By the time you're done with this notebook, you'll be slinging personalized AI models like a pro. Whether you're a seasoned ML wizard or a curious newbie, this guide will give you the superpowers you need to bring these ideas to life in your own mad scientist projects. Let's get this party started! 🎉

## Step 1: Setting Up Our Environment

First things first, let's get our environment ready for some serious AI action. We'll install the requirements, connect to a ZenML server, and activate the stack that will run our pipeline.

In [None]:
# Install requirements
!pip install -r requirements.txt  

In [None]:
# Connect to ZenML server
!zenml connect --url <your-zenml-server-url>

In [None]:
!zenml init

In [None]:
# Activate the stack, then install the ZenML integrations it needs
!zenml stack set azure_temp_gpu
!zenml integration install kubernetes azure -y

In [10]:
# On macOS with Docker Desktop, the notebook kernel may not see the docker CLI
import os

# Add Docker to the PATH
os.environ['PATH'] = f"{os.environ['PATH']}:/Applications/Docker.app/Contents/Resources/bin/"
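
The path above is specific to Docker Desktop on macOS. As a minimal sketch, you could check whether the `docker` CLI is already discoverable before touching `PATH`. The helper name `ensure_on_path` is made up for this example and is not part of ZenML or Docker.

```python
import os
import shutil


def ensure_on_path(executable: str, extra_dir: str) -> bool:
    """Return True if `executable` is discoverable, adding `extra_dir` to PATH if needed."""
    if shutil.which(executable):
        return True  # already on PATH, nothing to change
    if os.path.isdir(extra_dir):
        # Append the directory, skipping empty entries so PATH stays well-formed
        parts = [p for p in (os.environ.get("PATH", ""), extra_dir) if p]
        os.environ["PATH"] = os.pathsep.join(parts)
    return shutil.which(executable) is not None


# Same Docker Desktop directory as in the cell above
docker_found = ensure_on_path(
    "docker", "/Applications/Docker.app/Contents/Resources/bin/"
)
print(f"docker CLI found: {docker_found}")
```

This only makes the CLI visible to the notebook kernel. The Docker daemon still has to be running on its own.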


## Step 2: Inspect Our Dataset

Now that we've got our environment set up, let's pull the training images down from the artifact store and take a look at them. First we copy the remote folder to a local directory, then a small gallery helper shows a random sample of the images we'll use to train our model.

In [11]:
from zenml.client import Client
from zenml.utils import io_utils

images_path = "az://demo-zenmlartifactstore/hamza-faces"
images_dir_path = "/tmp/hamza-faces/"
# Touching the artifact store initializes it, so that io_utils can resolve az:// paths
_ = Client().active_stack.artifact_store.path

io_utils.copy_dir(
    destination_dir=images_dir_path,
    source_dir=images_path,
    overwrite=True,
)
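
Before training on the copy, it can be worth checking that the download actually produced images. Here's a small sketch that counts files by extension, using the same extensions the gallery function below filters on. `count_images` is a hypothetical helper name, and the directory is the one used in the cell above.

```python
from collections import Counter
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images(directory: str) -> Counter:
    """Count image files in `directory` by lowercase extension (non-recursive)."""
    root = Path(directory)
    if not root.is_dir():
        return Counter()  # missing directory: report zero images instead of raising
    return Counter(
        p.suffix.lower()
        for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


counts = count_images("/tmp/hamza-faces/")
print(f"Found {sum(counts.values())} images: {dict(counts)}")
```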

In [12]:
import os
import random

import ipywidgets as widgets
from IPython.display import display


def display_image_gallery(
    images_dir_path, sample_size=10, thumbnail_size=(200, 200)
):
    # Get all image files from the directory
    image_files = [
        f
        for f in os.listdir(images_dir_path)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ]

    # Sample the images
    sampled_files = random.sample(
        image_files, min(sample_size, len(image_files))
    )

    # Create thumbnail widgets
    thumbnails = [
        widgets.Image(
            value=open(os.path.join(images_dir_path, img), "rb").read(),
            format=img.split(".")[-1],
            layout=widgets.Layout(
                width=f"{thumbnail_size[0]}px",
                height=f"{thumbnail_size[1]}px",
                margin="5px",
            ),
        )
        for img in sampled_files
    ]

    # Create a grid of thumbnails
    thumbnail_grid = widgets.GridBox(
        thumbnails,
        layout=widgets.Layout(
            grid_template_columns="repeat(auto-fill, minmax(200px, 1fr))"
        ),
    )

    # Display the gallery
    display(
        widgets.HTML(
            f"<h3>Displaying {len(sampled_files)} of {len(image_files)} images</h3>"
        )
    )
    display(thumbnail_grid)


# Usage
display_image_gallery(images_dir_path)

HTML(value='<h3>Displaying 10 of 201 images</h3>')


## Step 3: Training Our Model Like a Boss

Alright, now we're getting to the good stuff. Let's set up our model training step. This is where the magic happens, folks!

In [13]:
from rich import print
from train_dreambooth_lora_flux import main as dreambooth_main
from zenml import step
from zenml.client import Client
from zenml.integrations.huggingface.steps import run_with_accelerate
from zenml.utils import io_utils
from zenml.logger import get_logger

logger = get_logger(__name__)


@run_with_accelerate(
    num_processes=1, multi_gpu=False, mixed_precision="bf16"
)  # Adjust num_processes as needed
@step
def train_model(
    images_path: str,
    instance_name: str,
    class_name: str,
    model_name: str,
    hf_repo_suffix: str,
    prefix: str,
    resolution: int,
    train_batch_size: int,
    rank: int,
    gradient_accumulation_steps: int,
    learning_rate: float,
    lr_scheduler: str,
    lr_warmup_steps: int,
    max_train_steps: int,
    push_to_hub: bool,
    checkpointing_steps: int,
    seed: int,
) -> None:

    images_dir_path = "/tmp/hamza-faces/"
    # Touching the artifact store initializes it, so that io_utils can resolve az:// paths
    _ = Client().active_stack.artifact_store.path

    io_utils.copy_dir(
        destination_dir=images_dir_path,
        source_dir=images_path,
        overwrite=True,
    )

    instance_phrase = f"{instance_name} the {class_name}"
    instance_prompt = f"{prefix} {instance_phrase}".strip()

    # Plain attribute container standing in for the parsed command-line args
    # that the original training script's main() expects
    class Args:
        def __init__(self, **kwargs):
            self.mixed_precision = kwargs.get("mixed_precision", "bf16")
            self.pretrained_model_name_or_path = kwargs.get(
                "pretrained_model_name_or_path"
            )
            self.revision = kwargs.get("revision", None)
            self.variant = kwargs.get("variant", None)
            self.dataset_name = kwargs.get("dataset_name", None)
            self.dataset_config_name = kwargs.get("dataset_config_name", None)
            self.instance_data_dir = kwargs.get("instance_data_dir")
            self.cache_dir = kwargs.get("cache_dir", None)
            self.image_column = kwargs.get("image_column", "image")
            self.caption_column = kwargs.get("caption_column", None)
            self.repeats = kwargs.get("repeats", 1)
            self.class_data_dir = kwargs.get("class_data_dir", None)
            self.output_dir = kwargs.get("output_dir")
            self.instance_prompt = kwargs.get("instance_prompt")
            self.class_prompt = kwargs.get("class_prompt", None)
            self.max_sequence_length = kwargs.get("max_sequence_length", 512)
            self.validation_prompt = kwargs.get("validation_prompt", None)
            self.num_validation_images = kwargs.get("num_validation_images", 4)
            self.validation_epochs = kwargs.get("validation_epochs", 50)
            self.rank = kwargs.get("rank", 4)
            self.with_prior_preservation = kwargs.get(
                "with_prior_preservation", False
            )
            self.prior_loss_weight = kwargs.get("prior_loss_weight", 1.0)
            self.num_class_images = kwargs.get("num_class_images", 100)
            self.seed = kwargs.get("seed", None)
            self.resolution = kwargs.get("resolution", 512)
            self.center_crop = kwargs.get("center_crop", False)
            self.random_flip = kwargs.get("random_flip", False)
            self.train_text_encoder = kwargs.get("train_text_encoder", False)
            self.train_batch_size = kwargs.get("train_batch_size", 4)
            self.sample_batch_size = kwargs.get("sample_batch_size", 4)
            self.num_train_epochs = kwargs.get("num_train_epochs", 1)
            self.max_train_steps = kwargs.get("max_train_steps", None)
            self.checkpointing_steps = kwargs.get("checkpointing_steps", 500)
            self.checkpoints_total_limit = kwargs.get(
                "checkpoints_total_limit", None
            )
            self.resume_from_checkpoint = kwargs.get(
                "resume_from_checkpoint", None
            )
            self.gradient_accumulation_steps = kwargs.get(
                "gradient_accumulation_steps", 1
            )
            self.gradient_checkpointing = kwargs.get(
                "gradient_checkpointing", False
            )
            self.learning_rate = kwargs.get("learning_rate", 1e-4)
            self.guidance_scale = kwargs.get("guidance_scale", 3.5)
            self.text_encoder_lr = kwargs.get("text_encoder_lr", 5e-6)
            self.scale_lr = kwargs.get("scale_lr", False)
            self.lr_scheduler = kwargs.get("lr_scheduler", "constant")
            self.lr_warmup_steps = kwargs.get("lr_warmup_steps", 500)
            self.lr_num_cycles = kwargs.get("lr_num_cycles", 1)
            self.lr_power = kwargs.get("lr_power", 1.0)
            self.dataloader_num_workers = kwargs.get(
                "dataloader_num_workers", 0
            )
            self.weighting_scheme = kwargs.get("weighting_scheme", "none")
            self.logit_mean = kwargs.get("logit_mean", 0.0)
            self.logit_std = kwargs.get("logit_std", 1.0)
            self.mode_scale = kwargs.get("mode_scale", 1.29)
            self.optimizer = kwargs.get("optimizer", "AdamW")
            self.use_8bit_adam = kwargs.get("use_8bit_adam", False)
            self.adam_beta1 = kwargs.get("adam_beta1", 0.9)
            self.adam_beta2 = kwargs.get("adam_beta2", 0.999)
            self.prodigy_beta3 = kwargs.get("prodigy_beta3", None)
            self.prodigy_decouple = kwargs.get("prodigy_decouple", True)
            self.adam_weight_decay = kwargs.get("adam_weight_decay", 1e-04)
            self.adam_weight_decay_text_encoder = kwargs.get(
                "adam_weight_decay_text_encoder", 1e-03
            )
            self.adam_epsilon = kwargs.get("adam_epsilon", 1e-08)
            self.prodigy_use_bias_correction = kwargs.get(
                "prodigy_use_bias_correction", True
            )
            self.prodigy_safeguard_warmup = kwargs.get(
                "prodigy_safeguard_warmup", True
            )
            self.max_grad_norm = kwargs.get("max_grad_norm", 1.0)
            self.push_to_hub = kwargs.get("push_to_hub", False)
            self.hub_token = kwargs.get("hub_token", None)
            self.hub_model_id = kwargs.get("hub_model_id", None)
            self.logging_dir = kwargs.get("logging_dir", "logs")
            self.allow_tf32 = kwargs.get("allow_tf32", False)
            self.report_to = kwargs.get("report_to", "tensorboard")
            self.local_rank = kwargs.get("local_rank", -1)
            self.prior_generation_precision = kwargs.get(
                "prior_generation_precision", None
            )

    # Build the arguments for this run; anything not passed keeps its default above
    args = Args(
        mixed_precision="bf16",
        pretrained_model_name_or_path=model_name,
        instance_data_dir=images_dir_path,
        output_dir=hf_repo_suffix,
        instance_prompt=instance_prompt,
        resolution=resolution,
        train_batch_size=train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        rank=rank,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        max_train_steps=max_train_steps,
        checkpointing_steps=checkpointing_steps,
        seed=seed,
        push_to_hub=bool(push_to_hub),
    )

    # Run the main function with the created args
    print("Launching dreambooth training script")
    dreambooth_main(args)
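
The hand-written `Args` class above is long because it repeats every default. As a hedged sketch of a more compact alternative, the same idea can be expressed as a defaults dictionary merged with overrides into a `types.SimpleNamespace`. The names `TRAINING_DEFAULTS`, `make_args` and `build_instance_prompt` are invented for this example, and only a subset of the defaults is listed, so the real script would need the full set.

```python
from types import SimpleNamespace

# Subset of the defaults from the Args class above (illustrative, not complete)
TRAINING_DEFAULTS = {
    "mixed_precision": "bf16",
    "rank": 4,
    "resolution": 512,
    "train_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "learning_rate": 1e-4,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 500,
    "max_train_steps": None,
    "checkpointing_steps": 500,
    "seed": None,
    "push_to_hub": False,
    "instance_prompt": None,
}


def make_args(**overrides) -> SimpleNamespace:
    """Merge overrides into the defaults, rejecting misspelled argument names."""
    unknown = set(overrides) - set(TRAINING_DEFAULTS)
    if unknown:
        raise ValueError(f"Unknown training arguments: {sorted(unknown)}")
    return SimpleNamespace(**{**TRAINING_DEFAULTS, **overrides})


def build_instance_prompt(prefix: str, instance_name: str, class_name: str) -> str:
    """Same prompt construction as in train_model: '<prefix> <name> the <class>'."""
    return f"{prefix} {instance_name} the {class_name}".strip()


args = make_args(
    rank=32,
    max_train_steps=1300,
    instance_prompt=build_instance_prompt(
        "A portrait photo of", "sks htahir1", "man"
    ),
)
print(args.instance_prompt)  # A portrait photo of sks htahir1 the man
```

Unlike `kwargs.get(...)`, this version fails loudly on a typo such as `max_train_step=1300` instead of silently falling back to the default.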

## Step 4: Batch Inference - Let's See What We've Created!

Now that we've trained our model, it's time to put it to the test. Let's set up a batch inference step to generate some cool images!

In [14]:
import torch
from diffusers import AutoPipelineForText2Image
from PIL import Image as PILImage
from zenml import step
from zenml.logger import get_logger

logger = get_logger(__name__)

@step
def batch_inference(
    hf_username: str,
    hf_repo_suffix: str,
    instance_name: str,
    class_name: str,
) -> PILImage.Image:
    model_path = f"{hf_username}/{hf_repo_suffix}"
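    # Note: the pipeline's default model_name trains the LoRA on FLUX.1-dev,
    # while inference here loads it on top of FLUX.1-schnell
    pipe = AutoPipelineForText2Image.from_pretrained(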
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights(
        model_path, weight_name="pytorch_lora_weights.safetensors"
    )

    instance_phrase = f"{instance_name} the {class_name}"
    prompts = [
        f"A portrait photo of {instance_phrase} in a Superman pose",
        f"A portrait photo of {instance_phrase} flying like Superman",
        f"A portrait photo of {instance_phrase} standing like Superman",
        f"A portrait photo of {instance_phrase} as a football player in an action pose",
        f"A portrait photo of {instance_phrase} as a firefighter in a heroic stance",
        f"A portrait photo of {instance_phrase} in a spacesuit in space",
        f"A portrait photo of {instance_phrase} on the Moon",
        f"A portrait photo of {instance_phrase} as an astronaut working on a satellite",
        f"A portrait photo of {instance_phrase} as an astronaut looking out a spacecraft window",
        f"A portrait photo of {instance_phrase} as an astronaut on a spacewalk",
        f"A portrait photo of {instance_phrase} in a heroic Superman pose",
        f"A portrait photo of {instance_phrase} as an astronaut on Mars",
        f"A portrait photo of {instance_phrase} flying like Superman",
        f"A portrait photo of {instance_phrase} as an astronaut floating in zero gravity",
        f"A portrait photo of {instance_phrase} as a superhero in a powerful stance",
    ]

    images = pipe(
        prompt=prompts,
        num_inference_steps=35,
        guidance_scale=8.5,
        height=256,
        width=256,
    ).images

    width, height = images[0].size
    rows, cols = 3, 5
    gallery_img = PILImage.new("RGB", (width * cols, height * rows))

    for i, image in enumerate(images):
        gallery_img.paste(image, ((i % cols) * width, (i // cols) * height))

    return gallery_img
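
The gallery above hard-codes `rows, cols = 3, 5`, which matches the 15 prompts. If you change the number of prompts, the layout has to change with it. As a small sketch, the canvas size and paste offsets can be derived from the image count instead. `grid_layout` is a hypothetical helper and uses plain integers, so it runs without PIL.

```python
from typing import List, Tuple


def grid_layout(
    n_images: int, cols: int, tile_w: int, tile_h: int
) -> Tuple[Tuple[int, int], List[Tuple[int, int]]]:
    """Return the canvas size and the top-left paste offset of each tile."""
    rows = -(-n_images // cols)  # ceiling division, so a partial last row still fits
    canvas = (tile_w * cols, tile_h * rows)
    offsets = [
        ((i % cols) * tile_w, (i // cols) * tile_h) for i in range(n_images)
    ]
    return canvas, offsets


canvas, offsets = grid_layout(n_images=15, cols=5, tile_w=256, tile_h=256)
print(canvas)      # (1280, 768)
print(offsets[7])  # (512, 256)
```

Inside the step, `canvas` would be passed to `PILImage.new("RGB", canvas)` and each offset to `gallery_img.paste(image, offset)`.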


## Step 5: From Still to Motion - Let's Make Some Video Magic!

Why stop at images when we can create videos? Let's add a step that generates a fresh set of images and turns each one into a short video clip!

In [15]:
import base64
from typing import Annotated, Tuple, List

import torch
from diffusers import AutoPipelineForText2Image, StableVideoDiffusionPipeline
from diffusers.utils import export_to_video
from PIL import Image as PILImage
from zenml import step
from zenml.types import HTMLString
from zenml.logger import get_logger

logger = get_logger(__name__)

def get_optimal_size(
    image: PILImage.Image, max_size: int = 1024
) -> Tuple[int, int]:
    width, height = image.size
    aspect_ratio = width / height
    if width > height:
        new_width = min(width, max_size)
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = min(height, max_size)
        new_width = int(new_height * aspect_ratio)
    return (new_width, new_height)


@step
def image_to_video(
    hf_username: str,
    hf_repo_suffix: str,
    instance_name: str,
) -> Tuple[
    Annotated[List[PILImage.Image], "generated_images"],
    Annotated[List[bytes], "video_data_list"],
    Annotated[HTMLString, "video_html"],
]:

    model_path = f"{hf_username}/{hf_repo_suffix}"
    pipe = AutoPipelineForText2Image.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights(
        model_path, weight_name="pytorch_lora_weights.safetensors"
    )

    # The class name is hard-coded to "man" here; batch_inference takes it as a parameter
    instance_phrase = f"{instance_name} the man"
    prompts = [
        f"A portrait photo of {instance_phrase} in a Superman pose",
        f"A portrait photo of {instance_phrase} flying like Superman",
        f"A portrait photo of {instance_phrase} standing like Superman",
        f"A portrait photo of {instance_phrase} as a football player in an action pose",
        f"A portrait photo of {instance_phrase} as a firefighter in a heroic stance",
        f"A portrait photo of {instance_phrase} in a spacesuit in space",
        f"A portrait photo of {instance_phrase} on the Moon",
    ]

    images = pipe(
        prompt=prompts,
        num_inference_steps=40,
        guidance_scale=8.5,
        height=512,
        width=512,
    ).images

    video_pipeline = StableVideoDiffusionPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    video_pipeline.enable_model_cpu_offload()

    video_data_list = []
    for i, image in enumerate(images):
        frames = video_pipeline(
            image,
            num_inference_steps=80,
            generator=torch.manual_seed(77),
            height=512,
            width=512,
        ).frames[0]

        output_file = f"generated_video_{i}.mp4"
        export_to_video(frames, output_file, fps=5)

        with open(output_file, "rb") as file:
            video_data = file.read()
            video_data_list.append(video_data)

    html_visualization_str = """
    <html>
    <head>
    </head>
    <body>
        <div style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: 100vh; margin: 0; padding: 0;">
    """
    for i, video_data in enumerate(video_data_list):
        html_visualization_str += f"""
            <div style="margin-bottom: 20px;">
                <video width="512" height="512" controls autoplay loop>
                    <source src="data:video/mp4;base64,{base64.b64encode(video_data).decode()}" type="video/mp4">
                    Your browser does not support the video tag.
                </video>
            </div>
        """
    html_visualization_str += """
        </div>
    </body>
    </html>
    """
    
    return (images, video_data_list, HTMLString(html_visualization_str))
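
The HTML visualization works by embedding each MP4 as a base64 data URI, so the page needs no external files. Here's a minimal, stdlib-only sketch of that embedding step on its own. The helper names `video_tag` and `videos_to_html` are made up for illustration, and the bytes are a placeholder rather than a real video.

```python
import base64
from typing import Iterable


def video_tag(video_bytes: bytes, size: int = 512) -> str:
    """Build a <video> element with the MP4 bytes embedded as a data URI."""
    encoded = base64.b64encode(video_bytes).decode("ascii")
    return (
        f'<video width="{size}" height="{size}" controls autoplay loop>'
        f'<source src="data:video/mp4;base64,{encoded}" type="video/mp4">'
        "Your browser does not support the video tag."
        "</video>"
    )


def videos_to_html(videos: Iterable[bytes]) -> str:
    """Wrap one <video> element per clip in a minimal HTML page."""
    tags = "\n".join(
        f'<div style="margin-bottom: 20px;">{video_tag(v)}</div>' for v in videos
    )
    return f"<html><body>\n{tags}\n</body></html>"


html = videos_to_html([b"fake-mp4-bytes"])  # placeholder bytes, not a real video
print(html.count("<video"))  # 1
```

Keep in mind that base64 makes the payload roughly a third larger than the raw bytes, so a page with many clips can get heavy.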

## Step 6: Putting It All Together - Our Dreambooth Pipeline

Now for the grand finale - let's string all these awesome steps together into one epic pipeline!

In [23]:
from rich import print
from zenml import pipeline


@pipeline
def dreambooth_pipeline(
    instance_example_dir: str = "az://demo-zenmlartifactstore/hamza-faces",
    instance_name: str = "sks htahir1",
    class_name: str = "man",
    model_name: str = "black-forest-labs/FLUX.1-dev",
    hf_username: str = "strickvl",
    hf_repo_suffix: str = "flux-dreambooth-hamza",
    prefix: str = "A portrait photo of",
    resolution: int = 512,
    train_batch_size: int = 1,
    rank: int = 32,
    gradient_accumulation_steps: int = 1,
    learning_rate: float = 0.0002,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 0,
    max_train_steps: int = 1300,
    push_to_hub: bool = True,
    checkpointing_steps: int = 1000,
    seed: int = 117,
):
    train_model(
        instance_example_dir,
        instance_name=instance_name,
        class_name=class_name,
        model_name=model_name,
        hf_repo_suffix=hf_repo_suffix,
        prefix=prefix,
        resolution=resolution,
        train_batch_size=train_batch_size,
        rank=rank,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler=lr_scheduler,
        lr_warmup_steps=lr_warmup_steps,
        max_train_steps=max_train_steps,
        push_to_hub=push_to_hub,
        checkpointing_steps=checkpointing_steps,
        seed=seed,
    )
    batch_inference(
        hf_username,
        hf_repo_suffix,
        instance_name,
        class_name,
        after="train_model",
    )
    image_to_video(
        hf_username, hf_repo_suffix, instance_name, after="batch_inference"
    )


print("Dreambooth pipeline assembled and ready for action! 🚀")
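
The pipeline takes everything person-specific as parameters, and that's what makes it reusable for many subjects. As a rough sketch of how per-subject parameter sets could be generated, the snippet below builds one keyword dictionary per person. The `Subject` record, the example users and the `az://example-container/...` URIs are all hypothetical. Only the parameter names come from `dreambooth_pipeline` above.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Subject:
    """Hypothetical record describing one person to personalize a model for."""

    user_id: str
    class_name: str
    images_uri: str


def pipeline_params(subject: Subject) -> Dict[str, str]:
    """Map a subject onto the person-specific dreambooth_pipeline parameters."""
    return {
        "instance_example_dir": subject.images_uri,
        "instance_name": f"sks {subject.user_id}",
        "class_name": subject.class_name,
        "hf_repo_suffix": f"flux-dreambooth-{subject.user_id}",
    }


subjects = [
    Subject("alice", "woman", "az://example-container/alice-faces"),
    Subject("bob", "man", "az://example-container/bob-faces"),
]

for params in map(pipeline_params, subjects):
    print(params)
    # dreambooth_pipeline(**params)  # uncomment where the pipeline above is defined
```

One caveat from the code above: `image_to_video` builds its prompts with a fixed "the man", so `class_name` would not carry through to that step without a change there.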

## Step 7: Launch the Pipeline and Watch the Magic Happen!

Alright, folks, this is it - the moment of truth! Let's fire up our pipeline and see this baby in action!

In [24]:
dreambooth_pipeline.with_options(config_path="configs/k8s_run_refactored_multi_video.yaml")()

print("Pipeline launched! Sit back, relax, and prepare to be amazed! 🍿")

Initiating a new run for the pipeline: dreambooth_pipeline.
Uploading notebook code...
Upload finished.
Reusing existing build e100d938-3463-400f-a4a2-07ef891964b6 for stack azure_temp_gpu.
Archiving pipeline code...
Uploading code to az://demo-zenmlartifactstore/code_uploads/e84ae2456e57b3d805d357562797d9d10c097702.tar.gz (Size: 1.87 MiB).
Code upload finished.
Executing a new run.
Using a build:
 Image(s): demozenmlcontainerregistry.azurecr.io/zenml@sha256:c996d64b7caac50cd9215f7ce068cf3b308d8e764ec67ce99ef6d0bc14561926, demozenmlcontainerregistry.azurecr.io/zenml@sha256:c996d64b7caac50cd9215f7ce068cf3b308d8e764ec67ce99ef6d0bc14561926, demozenmlcontainerregistry.azurecr.io/zenml@sha256:c996d64b7caac50cd9215f7ce068cf3b308d8e764ec67ce99ef6d0bc14561926
Using user: hamza@zenml.io


And there you have it, folks! You've just built and launched a kickass pipeline for personalized AI model generation. From loading data to training models, from generating images to creating videos - you've done it all!

Remember, this is just the beginning. Feel free to tweak, adjust, and experiment with the parameters to see what kind of magic you can create. The AI world is your oyster, and you've got the tools to make some serious pearls!
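
When you tweak the training parameters, a quick sanity check is how many times each image gets seen. The arithmetic below is a back-of-the-envelope sketch. It assumes a single process, that every training step consumes `train_batch_size * gradient_accumulation_steps` samples, and that all 201 images counted in Step 2 are used for training.

```python
def dataset_passes(
    max_train_steps: int,
    train_batch_size: int,
    gradient_accumulation_steps: int,
    num_images: int,
) -> float:
    """Approximate number of times each training image is seen."""
    samples_seen = (
        max_train_steps * train_batch_size * gradient_accumulation_steps
    )
    return samples_seen / num_images


# Pipeline defaults from Step 6 with the image count from Step 2
passes = dataset_passes(
    max_train_steps=1300,
    train_batch_size=1,
    gradient_accumulation_steps=1,
    num_images=201,
)
print(f"{passes:.2f}")  # 6.47
```

With a much smaller image set per person, the same 1300 steps would mean many more passes over each image, so `max_train_steps` is a natural first knob to revisit.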

Happy coding, and may your models be ever accurate and your latency low! 🚀🎉