# Serving `FLUX.1-Dev` Model

Use `Serverless GPU Compute` to load and register the model.

In [0]:
# Restart the notebook's Python process so the session starts from a clean state
# (e.g. after installing or upgrading libraries).
dbutils.library.restartPython()

## Model Wrapper

In [0]:
%%writefile flux1dev_model.py
# Model sharding across GPUs: https://huggingface.co/docs/diffusers/en/training/distributed_inference#model-sharding
# Quantizing the model (e.g. with bitsandbytes) could make it fit on a single GPU; that is not done here.

import mlflow

class StableDiffusionImgToImg(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper serving FLUX.1-dev as an image-to-image pipeline.

    ``load_context`` downloads the pipeline components from the Hugging Face Hub
    and shards them across every visible GPU. ``predict`` generates one image
    from the first request row and returns it as a base64-encoded JPEG string.
    """

    def __init__(self):
        # Set by load_context(); predict() refuses to run while it is None.
        self.pipeline = None

    def load_context(self, context):
        """Load the FLUX.1-dev img2img pipeline, sharded across all visible GPUs.

        Called by MLflow when the model is loaded. ``context`` is not read: the
        weights come from the Hugging Face Hub, not from ``context.artifacts``.

        FLUX.1-dev is a gated repository, so a Hub token must be supplied through
        the ``HUGGING_FACE_HUB_TOKEN`` (or ``HF_TOKEN``) environment variable,
        e.g. a serving-endpoint environment variable backed by a secret.

        Raises:
            RuntimeError: if no CUDA device is visible.
        """
        import os

        import torch
        from diffusers import FluxImg2ImgPipeline, FluxTransformer2DModel
        from transformers import T5EncoderModel

        repo_id = "black-forest-labs/FLUX.1-dev"

        # The Hub token is picked up from the environment by huggingface_hub.
        # Do not hard-code it here, and never overwrite a value the serving
        # environment already provides.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

        n_gpus = torch.cuda.device_count()
        if n_gpus == 0:
            raise RuntimeError("FLUX.1-dev serving requires at least one CUDA GPU.")

        # Per-GPU memory cap used by the "balanced" device map.
        # NOTE(review): 24GB is the full capacity of an A10G and leaves no
        # headroom; lower it if loading runs out of memory.
        max_memory = {i: "24GB" for i in range(n_gpus)}

        # Components are loaded in float16 and sharded; no quantization is applied.
        # NOTE(review): FLUX / T5 are commonly run in bfloat16; float16 may overflow.
        text_encoder_2 = T5EncoderModel.from_pretrained(
            repo_id,
            subfolder="text_encoder_2",
            torch_dtype=torch.float16,
            device_map="balanced",
            max_memory=max_memory,
        )

        transformer = FluxTransformer2DModel.from_pretrained(
            repo_id,
            subfolder="transformer",
            torch_dtype=torch.float16,
            device_map="balanced",
            max_memory=max_memory,
        )

        self.flush()

        # Reuse the already-sharded components; the remaining ones are loaded here.
        self.pipeline = FluxImg2ImgPipeline.from_pretrained(
            repo_id,
            text_encoder_2=text_encoder_2,
            transformer=transformer,
            torch_dtype=torch.float16,
            device_map="balanced",
            max_memory=max_memory,
        )

        self.flush()

    def flush(self):
        """Run garbage collection, release cached CUDA memory and reset peak stats."""
        import gc

        import torch

        gc.collect()
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() subsumes the deprecated reset_max_memory_allocated().
        torch.cuda.reset_peak_memory_stats()

    def image_to_base64(self, image):
        """Encode a PIL image as a base64 JPEG string (UTF-8 text)."""
        from io import BytesIO
        import base64

        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def base64_to_image(self, base64_string):
        """Decode a base64 string into a PIL image (mode as stored in the data)."""
        from io import BytesIO
        import base64

        from PIL import Image

        img_data = base64.b64decode(base64_string)
        # The BytesIO is kept alive by the image, which reads from it lazily.
        return Image.open(BytesIO(img_data))

    def predict(self, context, model_input):
        """Generate one image from the first row of ``model_input``.

        Args:
            context: MLflow context (unused).
            model_input: DataFrame (or dict of sequences) with the columns
                ``prompt``, ``init_image`` (base64-encoded image),
                ``num_inference_steps``, ``strength`` and ``guidance_scale``.
                Only index 0 of each column is used; other rows are ignored.

        Returns:
            str: the generated image as a base64-encoded JPEG string.

        Raises:
            RuntimeError: if called before ``load_context``.
        """
        import torch
        import torchvision.transforms as T

        if self.pipeline is None:
            raise RuntimeError("Pipeline is not loaded; call load_context() first.")

        # Inputs are placed on cuda:0; presumably the balanced device map puts the
        # entry modules there (check self.pipeline.hf_device_map if unsure).
        entry_device = torch.device("cuda:0")

        prompt = model_input["prompt"][0]
        # Force 3 channels: RGBA / grayscale / palette uploads would otherwise
        # produce a tensor with the wrong channel count.
        init_image = self.base64_to_image(model_input["init_image"][0]).convert("RGB")
        # Cast numpy scalars (e.g. numpy.int32 from the signature) to Python numbers.
        num_inference_steps = int(model_input["num_inference_steps"][0])
        strength = float(model_input["strength"][0])
        guidance_scale = float(model_input["guidance_scale"][0])

        to_tensor = T.Compose(
            [
                T.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
                T.Lambda(lambda x: x.to(torch.float16)),  # match the pipeline dtype
            ]
        )
        # Add the batch dimension and move to the entry device.
        init_image_tensor = to_tensor(init_image).unsqueeze(0).to(entry_device)

        image = self.pipeline(
            prompt=prompt,
            image=init_image_tensor,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
        ).images[0]

        return self.image_to_base64(image)
    
# Models-from-code: declare the instance MLflow should use when it loads this file
# (the file path is passed as `python_model` when the model is logged).
mlflow.models.set_model(StableDiffusionImgToImg())

## Log & Register Model

In [0]:
# Create and set MLflow experiment
import os
import mlflow
from mlflow.exceptions import RestException

experiment_name = f"{os.getcwd()}/flux1dev-serving"

# Create the experiment if needed, then always make it the active experiment so
# the run started later in this notebook is logged into it.
try:
    mlflow.create_experiment(name=experiment_name)
    print(f"Creating new experiment {experiment_name}")
except RestException as e:
    # Only "already exists" is expected here; surface any other API error.
    if "RESOURCE_ALREADY_EXISTS" not in str(e):
        raise
    print(f"Experiment {experiment_name} already exists.")

mlflow.set_experiment(experiment_name)

In [0]:
# Register model
# Register models in Unity Catalog (three-level `catalog.schema.model` names).
# NOTE: relies on `mlflow` imported in the previous cell.
mlflow.set_registry_uri("databricks-uc")

import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, ColSpec, TensorSpec
import pandas as pd
from PIL import Image
from io import BytesIO
import base64

test_image_path = f"{os.getcwd()}/test_image.jpg"  # sample init image next to this notebook
catalog_name = "users"  # TODO: Unity Catalog catalog
schema_name = "david_huang"  # TODO: schema within the catalog
model_name = "flux-auto-img2img-model"  # TODO: registered model name
vol_name = "flux1dev_artifact"  # TODO: UC Volume passed to log_model as the "repository" artifact
model_file_name = "flux1dev_model.py"  # TODO: should match the %%writefile target above


def load_image_from_volume(volume_path):
    """Open an image file (e.g. on a UC Volume) and return an RGB copy of it.

    The source file is closed before returning.
    """
    from PIL import Image

    source = Image.open(volume_path)
    try:
        rgb_image = source.convert("RGB")
    finally:
        source.close()
    return rgb_image


def image_to_base64(image):
    """Serialize an image as JPEG and return it base64-encoded as UTF-8 text."""
    with BytesIO() as buf:
        image.save(buf, format="JPEG")
        raw_bytes = buf.getvalue()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


# Model signature: the request columns in, one base64 image string column out.
input_schema = Schema(
    [
        ColSpec(col_type, col_name)
        for col_name, col_type in (
            ("prompt", DataType.string),
            ("num_inference_steps", DataType.integer),
            ("init_image", DataType.string),
            ("strength", DataType.float),
            ("guidance_scale", DataType.float),
        )
    ]
)
output_schema = Schema([ColSpec(DataType.string, "image")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)


# Base64-encoded JPEG of the sample image (also reused by the deployment test below).
image = image_to_base64(load_image_from_volume(test_image_path))

# Input example logged with the model; dtypes are cast to match the signature
# (integer -> int32, float -> float32).
input_example = pd.DataFrame(
    {
        "prompt": ["a photo of cat dressed in Gandalf from Lord of the Ring"],
        "num_inference_steps": [8],
        "init_image": [image],
        "strength": [0.8],
        "guidance_scale": [12.0],
    }
).astype(
    {
        "num_inference_steps": "int32",
        "strength": "float32",
        "guidance_scale": "float32",
    }
)

In [0]:
# Log the model with its details such as artifacts, pip requirements and input example
with mlflow.start_run() as run:
    mlflow.pyfunc.log_model(
        name="model",
        # Models-from-code: log the file written by %%writefile. Use the configured
        # name so that changing `model_file_name` actually takes effect.
        python_model=os.path.join(os.getcwd(), model_file_name),
        input_example=input_example,
        signature=signature,
        registered_model_name=f"{catalog_name}.{schema_name}.{model_name}",
        # NOTE(review): the model's load_context() does not read context.artifacts,
        # so this Volume is packaged with the model but not used at serving time.
        artifacts={"repository": f"/Volumes/{catalog_name}/{schema_name}/{vol_name}"},
        # Serving dependencies (accelerate is left unpinned).
        pip_requirements=[
            "transformers==4.48.0",
            "torch==2.5.1",
            "torchvision==0.20.1",
            "accelerate",
            "diffusers==0.32.2",
            "huggingface_hub==0.27.1",
            "invisible-watermark>=0.2.0",
            "bitsandbytes==0.45.4",
            "sentencepiece==0.2.0",
        ],
    )

## Serve on Model Serving

Steps to serve the model on Model Serving:
1. Create an endpoint in the Model Serving UI with the following configuration:
    * GPU Medium (A10G x 8)
    * Small (0-4 concurrency)
    * Scale to zero (optional, see step 3)
2. Point the endpoint to the registered model.
3. (Optional) Enable scale to zero during development / testing.
4. Wait 20-30 minutes for the endpoint to finish provisioning and loading the model.

## Test Deployment

In [None]:
import mlflow.deployments
import pandas as pd
import base64
from PIL import Image
from io import BytesIO

# Databricks deployment client used to query the serving endpoint.
client = mlflow.deployments.get_deploy_client("databricks")

# Name of the Model Serving endpoint that serves the registered model.
endpoint_name = "dhuang-flux1dev"  # TODO: Replace with your actual endpoint name

In [None]:
# Build a single-row request with the same columns as the registration input example.
test_input = pd.DataFrame(
    [
        {
            "prompt": "a photo of a cat dressed as Don Corleone in the Godfather",
            "num_inference_steps": 8,
            "init_image": image,  # Using the same test image from registration
            "strength": 0.8,
            "guidance_scale": 12.0,
        }
    ]
)

# Query the endpoint; the request body wraps the records under the "inputs" key.
response = client.predict(
    endpoint=endpoint_name,
    inputs={"inputs": test_input.to_dict(orient="records")},
)

# The model returns the generated image as a base64-encoded JPEG string.
predicted_image_base64 = response["predictions"]
image_data = base64.b64decode(predicted_image_base64)
result_image = Image.open(BytesIO(image_data))
display(result_image)

In [None]:
# Save result image
# Writes the decoded endpoint output as JPEG to the current working directory.
result_image.save("result_image.jpg", format="JPEG")