# Stable Fast 3D Mesh Reconstruction and OpenVINO

<div class="alert alert-block alert-danger"> <b>Important note:</b> This notebook requires Python >= 3.9. Please make sure that your environment fulfills this requirement before running it. </div>

[Stable Fast 3D (SF3D)](https://huggingface.co/stabilityai/stable-fast-3d) is a large reconstruction model based on [TripoSR](https://huggingface.co/spaces/stabilityai/TripoSR), which takes in a single image of an object and generates a textured UV-unwrapped 3D mesh asset.

You can find [the source code on GitHub](https://github.com/Stability-AI/stable-fast-3d) and read the paper [SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement](https://arxiv.org/abs/2408.00653).

![Teaser Video](https://github.com/Stability-AI/stable-fast-3d/blob/main/demo_files/teaser.gif?raw=true)

#### Table of contents:

- [Prerequisites](#Prerequisites)
- [Get the original model](#Get-the-original-model)
- [Convert the model to OpenVINO IR](#Convert-the-model-to-OpenVINO-IR)
- [Compiling models and prepare pipeline](#Compiling-models-and-prepare-pipeline)
- [Interactive inference](#Interactive-inference)

### Installation Instructions

This is a self-contained example that relies solely on its own code.

We recommend running the notebook in a virtual environment. You only need a Jupyter server to start.
For details, please refer to [Installation Guide](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/README.md#-installation-guide).


## Prerequisites
[back to top ⬆️](#Table-of-contents:)

In [None]:
import requests

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
)
with open("notebook_utils.py", "w") as f:
    f.write(r.text)

r = requests.get(
    url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/pip_helper.py",
)
with open("pip_helper.py", "w") as f:
    f.write(r.text)

from pip_helper import pip_install


pip_install("-q", "gradio>=4.19", "openvino>=2024.3.0", "wheel", "gradio-litmodel3d==0.0.1")

pip_install(
    "-q",
    "torch>=2.2.2",
    "torchvision",
    "transformers>=4.42.3",
    "rembg==2.0.57",
    "trimesh==4.4.1",
    "einops==0.7.0",
    "omegaconf>=2.3.0",
    "jaxtyping==0.2.31",
    "gpytoolbox==0.2.0",
    "open_clip_torch==2.24.0",
    "git+https://github.com/vork/PyNanoInstantMeshes.git",
    "--extra-index-url",
    "https://download.pytorch.org/whl/cpu",
)

In [None]:
import sys
from pathlib import Path

if not Path("stable-fast-3d").exists():
    !git clone https://github.com/Stability-AI/stable-fast-3d

sys.path.append("stable-fast-3d")
pip_install("-q", "stable-fast-3d/texture_baker/")

## Get the original model

In [None]:
from sf3d.system import SF3D


model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)

## Convert the model to OpenVINO IR
[back to top ⬆️](#Table-of-contents:)

Define the conversion function for PyTorch modules. We use the `ov.convert_model` function to obtain an OpenVINO Intermediate Representation (IR) object and the `ov.save_model` function to save it as an XML file.

In [None]:
import torch

import openvino as ov


def convert(model: torch.nn.Module, xml_path: str, example_input):
    xml_path = Path(xml_path)
    if not xml_path.exists():
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            converted_model = ov.convert_model(model, example_input=example_input)
        ov.save_model(converted_model, xml_path, compress_to_fp16=False)

        # cleanup memory
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        torch.jit._state._clear_class_state()
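
The `convert` helper above exports a model only when the target XML file is missing, so re-running the notebook reuses earlier results. A minimal, framework-free sketch of that caching pattern follows. `export_once` and `fake_export` are hypothetical names used for illustration only, and the stand-in writes a small text file instead of calling `ov.convert_model` and `ov.save_model`.

```python
import tempfile
from pathlib import Path
from typing import Callable


def export_once(target: Path, export_fn: Callable[[Path], None]) -> bool:
    """Run export_fn only if target does not exist yet.

    Returns True when the export ran and False when the existing file was reused.
    """
    target = Path(target)
    if target.exists():
        return False
    # Create parent directories first, as the notebook's convert() does.
    target.parent.mkdir(parents=True, exist_ok=True)
    export_fn(target)
    return True


def fake_export(path: Path) -> None:
    # Hypothetical stand-in for the real conversion and save steps.
    path.write_text("<net/>")


with tempfile.TemporaryDirectory() as tmp:
    xml_path = Path(tmp) / "models" / "demo_ir.xml"
    first_run = export_once(xml_path, fake_export)
    second_run = export_once(xml_path, fake_export)

print(first_run, second_run)
```

One consequence of this pattern is that a changed patch or example input does not trigger a new conversion by itself; the corresponding file under `models/` has to be deleted first.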

The original model is a pipeline of several models: `image_tokenizer`, `tokenizer`, `backbone`, `post_processor`, `camera_embedder`, `decoder`, `image_estimator` and `global_estimator`. We convert the internal models one by one.

`image_tokenizer` contains `Dinov2Embeddings`, whose `interpolate_pos_encoding` method calls `nn.functional.interpolate`. That function expects `scale_factor` to be a tuple of floats, but during conversion the tuple of floats is turned into a tuple of tensors, which raises an error. To avoid this, we patch the method so that both values are explicitly cast to `float`.

In [None]:
import math
import types

from torch import nn

from sf3d.models.tokenizers.dinov2 import Dinov2Embeddings


class Dinov2EmbeddingsPatched(Dinov2Embeddings):
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        height = height // self.config.patch_size
        width = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        # Cast both scale factors to plain Python floats: during conversion they
        # would otherwise reach `interpolate` as tensors, which raises an error.
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                float(height / math.sqrt(num_positions)),
                float(width / math.sqrt(num_positions)),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
            raise ValueError("Width or height does not match with the interpolated position embeddings")
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)


model.image_tokenizer.model.embeddings.interpolate_pos_encoding = types.MethodType(
    Dinov2EmbeddingsPatched.interpolate_pos_encoding, model.image_tokenizer.model.embeddings
)
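
The last statement of the cell above swaps a single method on an already constructed object instead of rebuilding the module. The sketch below shows the same technique with plain Python classes and repeats the grid-size arithmetic of the patched method. `Embeddings` and `EmbeddingsPatched` are hypothetical stand-ins, not the SF3D classes, and the patch size of 14 with 37 × 37 pre-trained positions are assumed values; under these assumptions the resulting token count matches the 1297 used in the `encoder_hidden_states` example shape for the backbone.

```python
import math
import types


class Embeddings:
    """Hypothetical stand-in for a module whose method needs patching."""

    def __init__(self, patch_size: int, num_positions: int):
        self.patch_size = patch_size
        self.num_positions = num_positions

    def output_grid(self, height: int):
        return "original"


class EmbeddingsPatched(Embeddings):
    def output_grid(self, height: int) -> int:
        # Patches per side, plus 0.1 so that rounding cannot lose a row.
        patches = height // self.patch_size + 0.1
        side = math.sqrt(self.num_positions)
        scale = float(patches / side)  # plain Python float, as in the patch
        # Interpolation produces floor(input_size * scale) positions per side.
        return int(side * scale)


# Assumed values: patch size 14 and 37 * 37 pre-trained positions.
emb = Embeddings(patch_size=14, num_positions=37 * 37)
before = emb.output_grid(512)

# Bind the subclass function to this one existing instance.
emb.output_grid = types.MethodType(EmbeddingsPatched.output_grid, emb)
grid = emb.output_grid(512)
tokens = grid * grid + 1  # patch tokens plus one class token

print(before, grid, tokens)
```

Because the function is bound to an instance of the base class, the patched body cannot rely on zero-argument `super()`; the notebook's patched method avoids it as well.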

In [None]:
example_input = {
    "images": torch.rand([1, 1, 3, 512, 512], dtype=torch.float32),
    "modulation_cond": torch.rand([1, 1, 768], dtype=torch.float32),
}

IMAGE_TOKENIZER_OV_PATH = Path("models/image_tokenizer_ir.xml")
convert(model.image_tokenizer, IMAGE_TOKENIZER_OV_PATH, example_input)

In [None]:
TOKENIZER_OV_PATH = Path("models/tokenizer_ir.xml")
convert(model.tokenizer, TOKENIZER_OV_PATH, torch.tensor(1))

In [None]:
example_input = {
    "hidden_states": torch.rand([1, 1024, 27648], dtype=torch.float32),
    "encoder_hidden_states": torch.rand([1, 1297, 1024], dtype=torch.float32),
}

BACKBONE_OV_PATH = Path("models/backbone_ir.xml")
convert(model.backbone, BACKBONE_OV_PATH, example_input)

In [None]:
POST_PROCESSOR_OV_PATH = Path("models/post_processor_ir.xml")
convert(
    model.post_processor,
    POST_PROCESSOR_OV_PATH,
    torch.rand([1, 3, 1024, 32, 32], dtype=torch.float32),
)

In [None]:
CAMERA_EMBEDDER_OV_PATH = Path("models/camera_embedder_ir.xml")


class CameraEmbedderWrapper(torch.nn.Module):
    def __init__(self, camera_embedder):
        super().__init__()
        self.camera_embedder = camera_embedder

    def forward(
        self,
        rgb_cond=None,
        mask_cond=None,
        c2w_cond=None,
        intrinsic_cond=None,
        intrinsic_normed_cond=None,
    ):
        kwargs = {
            "rgb_cond": rgb_cond,
            "mask_cond": mask_cond,
            "c2w_cond": c2w_cond,
            "intrinsic_cond": intrinsic_cond,
            "intrinsic_normed_cond": intrinsic_normed_cond,
        }
        embedding = self.camera_embedder(**kwargs)

        return embedding


example_input = {
    "rgb_cond": torch.rand([1, 1, 512, 512, 3], dtype=torch.float32),
    "mask_cond": torch.rand([1, 1, 512, 512, 1], dtype=torch.float32),
    "c2w_cond": torch.rand([1, 1, 1, 4, 4], dtype=torch.float32),
    "intrinsic_cond": torch.rand([1, 1, 1, 3, 3], dtype=torch.float32),
    "intrinsic_normed_cond": torch.rand([1, 1, 1, 3, 3], dtype=torch.float32),
}
convert(
    CameraEmbedderWrapper(model.camera_embedder),
    CAMERA_EMBEDDER_OV_PATH,
    example_input,
)

In [None]:
class ImageEstimatorWrapper(torch.nn.Module):
    def __init__(self, image_estimator):
        super().__init__()
        self.image_estimator = image_estimator

    def forward(self, cond_image):
        outputs = self.image_estimator(cond_image)
        # Keep only the outputs whose names start with "decoder_".
        filtered_outputs = {}
        for k, v in outputs.items():
            if k.startswith("decoder_"):
                filtered_outputs[k] = v
        return filtered_outputs


IMAGE_ESTIMATOR_OV_PATH = Path("models/image_estimator_ir.xml")
example_input = {
    "cond_image": torch.rand([1, 1, 512, 512, 3], dtype=torch.float32),
}
convert(
    ImageEstimatorWrapper(model.image_estimator),
    IMAGE_ESTIMATOR_OV_PATH,
    example_input,
)

The decoder's `forward` method accepts lists of head names to include or exclude and uses them to select a subset of heads. An IR model cannot take a list of strings as input, so we build two decoders instead, one for each required head configuration.

In [None]:
include_cfg_decoder = [h for h in model.decoder.cfg.heads if h.name in ["vertex_offset", "density"]]
exclude_cfg_decoder = [h for h in model.decoder.cfg.heads if h.name not in ["density", "vertex_offset"]]


INCLUDE_DECODER_OV_PATH = Path("models/include_decoder_ir.xml")
EXCLUDE_DECODER_OV_PATH = Path("models/exclude_decoder_ir.xml")


# Set the head list on the decoder config, which is where it is read from above.
model.decoder.cfg.heads = include_cfg_decoder
convert(
    model.decoder,
    INCLUDE_DECODER_OV_PATH,
    torch.rand([1, 535882, 120], dtype=torch.float32),
)


model.decoder.cfg.heads = exclude_cfg_decoder
convert(
    model.decoder,
    EXCLUDE_DECODER_OV_PATH,
    torch.rand([263302, 120], dtype=torch.float32),
)
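
The split above can be pictured as partitioning the head list into two disjoint groups and then choosing a pre-built decoder according to which argument the caller supplies. The sketch below uses plain strings. Apart from `density` and `vertex_offset`, the head names are hypothetical placeholders, and `pick_decoder` is an illustrative helper that mirrors the dispatch rule of the `DecoderWrapper` class defined later in this notebook.

```python
from typing import List, Optional

GEOMETRY_HEADS = ["vertex_offset", "density"]

# Hypothetical head list; only the two geometry names come from the notebook.
all_heads = ["density", "features", "vertex_offset", "perturb_normal"]

# Two disjoint groups that together cover every head.
include_heads = [h for h in all_heads if h in GEOMETRY_HEADS]
exclude_heads = [h for h in all_heads if h not in GEOMETRY_HEADS]


def pick_decoder(include: Optional[List[str]] = None, exclude: Optional[List[str]] = None) -> str:
    """Return the name of the pre-built variant that would serve the call."""
    # Any `include` argument selects the first model, otherwise the second one.
    return "include_decoder" if include is not None else "exclude_decoder"


print(include_heads, exclude_heads)
print(pick_decoder(include=GEOMETRY_HEADS))
print(pick_decoder(exclude=GEOMETRY_HEADS))
```

This works only because the pipeline calls the decoder with these two fixed head selections; any other combination would need its own converted model.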

## Compiling models and prepare pipeline
[back to top ⬆️](#Table-of-contents:)

Select a device from the dropdown list for running inference with OpenVINO.

In [None]:
from notebook_utils import device_widget

device = device_widget()

device

In [None]:
core = ov.Core()

compiled_image_tokenizer = core.compile_model(IMAGE_TOKENIZER_OV_PATH, device.value)
compiled_tokenizer = core.compile_model(TOKENIZER_OV_PATH, device.value)
compiled_backbone = core.compile_model(BACKBONE_OV_PATH, device.value)
compiled_post_processor = core.compile_model(POST_PROCESSOR_OV_PATH, device.value)
compiled_camera_embedder = core.compile_model(CAMERA_EMBEDDER_OV_PATH, device.value)
compiled_image_estimator = core.compile_model(IMAGE_ESTIMATOR_OV_PATH, device.value)
compiled_include_decoder = core.compile_model(INCLUDE_DECODER_OV_PATH, device.value)
compiled_exclude_decoder = core.compile_model(EXCLUDE_DECODER_OV_PATH, device.value)

Let's create callable wrapper classes for the compiled models so that they can be used by the original `SF3D` class. Note that all wrapper classes return `torch.Tensor` objects instead of NumPy arrays.

In [None]:
from typing import List, Optional


class ImageTokenizerWrapper(torch.nn.Module):
    def __init__(self, image_tokenizer):
        super().__init__()
        self.image_tokenizer = image_tokenizer

    def forward(self, images, modulation_cond):
        inputs = {
            "images": images,
            "modulation_cond": modulation_cond,
        }
        outs = self.image_tokenizer(inputs)[0]

        return torch.from_numpy(outs)


class TokenizerWrapper(torch.nn.Module):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.detokenize = model.detokenize

    def forward(self, batch_size):
        outs = self.tokenizer(batch_size)[0]

        return torch.from_numpy(outs)


class BackboneWrapper(torch.nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, hidden_states, encoder_hidden_states, **kwargs):
        inputs = {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states.detach().numpy(),
        }

        outs = self.backbone(inputs)[0]

        return torch.from_numpy(outs)


class PostProcessorWrapper(torch.nn.Module):
    def __init__(self, post_processor):
        super().__init__()
        self.post_processor = post_processor

    def forward(self, triplanes):
        outs = self.post_processor(triplanes)[0]

        return torch.from_numpy(outs)


class CameraEmbedderWrapper(torch.nn.Module):
    def __init__(self, camera_embedder):
        super().__init__()
        self.camera_embedder = camera_embedder

    def forward(self, **kwargs):
        outs = self.camera_embedder(kwargs)[0]

        return torch.from_numpy(outs)


class ImageEstimatorWrapper(torch.nn.Module):
    def __init__(self, image_estimator):
        super().__init__()
        self.image_estimator = image_estimator

    def forward(self, cond_image):
        outs = self.image_estimator(cond_image)

        results = {}
        for k, v in outs.to_dict().items():
            results[k.names.pop()] = torch.from_numpy(v)
        return results


class DecoderWrapper(torch.nn.Module):
    def __init__(self, include_decoder, exclude_decoder):
        super().__init__()
        self.include_decoder = include_decoder
        self.exclude_decoder = exclude_decoder

    def forward(self, x, include: Optional[List] = None, exclude: Optional[List] = None):
        if include is not None:
            outs = self.include_decoder(x)
        else:
            outs = self.exclude_decoder(x)
        results = {}
        for k, v in outs.to_dict().items():
            results[k.names.pop()] = torch.from_numpy(v)
        return results
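
The two dictionary-returning wrappers above rename each output by the tensor name attached to its output port. The sketch below imitates that step with plain Python objects. `FakePort` and the two `decoder_*` names are hypothetical and only stand in for what a compiled model returns; the name is read with `next(iter(...))`, an alternative to `pop()` that leaves the name set untouched.

```python
from typing import Dict, List, Set


class FakePort:
    """Hypothetical stand-in for an output port that carries tensor names."""

    def __init__(self, names: Set[str]):
        self.names = names


def outputs_by_name(raw: Dict[FakePort, List[float]]) -> Dict[str, List[float]]:
    """Re-key a port-indexed result dictionary by tensor name."""
    results = {}
    for port, value in raw.items():
        # Take one name without mutating the port's own set.
        name = next(iter(port.names))
        results[name] = value
    return results


raw_outputs = {
    FakePort({"decoder_roughness"}): [0.4],
    FakePort({"decoder_metallic"}): [0.1],
}
named = outputs_by_name(raw_outputs)

print(sorted(named))
```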

Replace all models in the original model with wrapper instances:

In [None]:
model.image_tokenizer = ImageTokenizerWrapper(compiled_image_tokenizer)
model.tokenizer = TokenizerWrapper(compiled_tokenizer, model.tokenizer)
model.backbone = BackboneWrapper(compiled_backbone)
model.post_processor = PostProcessorWrapper(compiled_post_processor)
model.camera_embedder = CameraEmbedderWrapper(compiled_camera_embedder)
model.image_estimator = ImageEstimatorWrapper(compiled_image_estimator)
model.decoder = DecoderWrapper(compiled_include_decoder, compiled_exclude_decoder)

## Interactive inference
[back to top ⬆️](#Table-of-contents:)

The demo code below is taken from the original `gradio_app.py`, with the model replaced by the one defined above.

In [None]:
import os
import tempfile
import time
from contextlib import nullcontext
from functools import lru_cache
from typing import Any

import gradio as gr
import numpy as np
import rembg
import torch
from gradio_litmodel3d import LitModel3D
from PIL import Image

import sf3d.utils as sf3d_utils
from sf3d.system import SF3D

os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")

rembg_session = rembg.new_session()

COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 1.6
COND_FOVY_DEG = 40
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Cached. Doesn't change
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH)

generated_files = []

example_files = [os.path.join("stable-fast-3d/demo_files/examples", f) for f in os.listdir("stable-fast-3d/demo_files/examples")]


def run_model(input_image, remesh_option, vertex_count, texture_size):
    start = time.time()
    with torch.no_grad():
        with nullcontext():
            model_batch = create_batch(input_image)
            model_batch = {k: v.to("cpu") for k, v in model_batch.items()}
            print(f"{model_batch.keys()=}")
            print(f"{texture_size=}")
            print(f"{remesh_option=}")
            print(f"{vertex_count=}")
            trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, texture_size, remesh_option, vertex_count)
            trimesh_mesh = trimesh_mesh[0]

    # Create new tmp file
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")

    trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
    generated_files.append(tmp_file.name)

    print("Generation took:", time.time() - start, "s")

    return tmp_file.name


def create_batch(input_image: Image) -> dict[str, Any]:
    img_cond = torch.from_numpy(np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) / 255.0).float().clip(0, 1)
    mask_cond = img_cond[:, :, -1:]
    rgb_cond = torch.lerp(torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond)

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched


@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    base = np.zeros((squares, squares)) + min_value
    base[1::2, ::2] = 1
    base[::2, 1::2] = 1

    repeat_mult = size // squares
    return base.repeat(repeat_mult, axis=0).repeat(repeat_mult, axis=1)[:, :, None].repeat(3, axis=-1)


def remove_background(input_image: Image) -> Image:
    return rembg.remove(input_image, session=rembg_session)


def resize_foreground(
    image: Image,
    ratio: float,
) -> Image:
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = Image.fromarray(new_image, mode="RGBA").resize((COND_WIDTH, COND_HEIGHT))
    return new_image


def square_crop(input_image: Image) -> Image:
    # Perform a center square crop
    min_size = min(input_image.size)
    left = (input_image.size[0] - min_size) // 2
    top = (input_image.size[1] - min_size) // 2
    right = (input_image.size[0] + min_size) // 2
    bottom = (input_image.size[1] + min_size) // 2
    return input_image.crop((left, top, right, bottom)).resize((COND_WIDTH, COND_HEIGHT))


def show_mask_img(input_image: Image) -> Image:
    img_numpy = np.array(input_image)
    alpha = img_numpy[:, :, 3] / 255.0
    chkb = checkerboard(32, 512) * 255
    new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")


def run_button(
    run_btn,
    input_image,
    background_state,
    foreground_ratio,
    remesh_option,
    vertex_count,
    texture_size,
):
    if run_btn == "Run":
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        glb_file: str = run_model(background_state, remesh_option.lower(), vertex_count, texture_size)
        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print("Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB")

        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(value=glb_file, visible=True),
            gr.update(visible=True),
        )
    elif run_btn == "Remove Background":
        rem_removed = remove_background(input_image)

        sqr_crop = square_crop(rem_removed)
        fr_res = resize_foreground(sqr_crop, foreground_ratio)

        return (
            gr.update(value="Run", visible=True),
            sqr_crop,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(value=None, visible=False),
            gr.update(visible=False),
        )


def requires_bg_remove(image, fr):
    if image is None:
        return (
            gr.update(visible=False, value="Run"),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    alpha_channel = np.array(image.getchannel("A"))
    min_alpha = alpha_channel.min()

    if min_alpha == 0:
        print("Already has alpha")
        sqr_crop = square_crop(image)
        fr_res = resize_foreground(sqr_crop, fr)
        return (
            gr.update(value="Run", visible=True),
            sqr_crop,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    return (
        gr.update(value="Remove Background", visible=True),
        None,
        None,
        gr.update(value=None, visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )


def update_foreground_ratio(img_proc, fr):
    foreground_res = resize_foreground(img_proc, fr)
    return (
        foreground_res,
        gr.update(value=show_mask_img(foreground_res)),
    )


with gr.Blocks() as demo:
    img_proc_state = gr.State()
    background_remove_state = gr.State()
    gr.Markdown(
        """
    # SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement

    **SF3D** is a state-of-the-art method for 3D mesh reconstruction from a single image.
    This demo allows you to upload an image and generate a 3D mesh model from it.

    **Tips**
    1. If the image already has an alpha channel, you can skip the background removal step.
    2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape.
    3. You can select the remeshing option to control the mesh topology. This can introduce artifacts in the mesh on thin surfaces and should be turned off in such cases.
    4. You can upload your own HDR environment map to light the 3D model.
    """
    )
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_img = gr.Image(type="pil", label="Input Image", sources="upload", image_mode="RGBA")
                preview_removal = gr.Image(
                    label="Preview Background Removal",
                    type="pil",
                    image_mode="RGB",
                    interactive=False,
                    visible=False,
                )

            foreground_ratio = gr.Slider(
                label="Foreground Ratio",
                minimum=0.5,
                maximum=1.0,
                value=0.85,
                step=0.05,
            )

            foreground_ratio.change(
                update_foreground_ratio,
                inputs=[img_proc_state, foreground_ratio],
                outputs=[background_remove_state, preview_removal],
            )

            remesh_option = gr.Radio(
                choices=["None", "Triangle", "Quad"],
                label="Remeshing",
                value="None",
                visible=True,
            )

            vertex_count_slider = gr.Slider(
                label="Target Vertex Count",
                minimum=1000,
                maximum=20000,
                value=10000,
                step=1000,
                visible=True,
            )

            texture_size = gr.Slider(
                label="Texture Size",
                minimum=512,
                maximum=2048,
                value=1024,
                step=256,
                visible=True,
            )

            run_btn = gr.Button("Run", variant="primary", visible=False)

        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                visible=False,
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            with gr.Column(visible=False, scale=1.0) as hdr_row:
                gr.Markdown(
                    """## HDR Environment Map

                Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
                """
                )

                with gr.Row():
                    hdr_illumination_file = gr.File(label="HDR Env Map", file_types=[".hdr"], file_count="single")
                    example_hdris = [os.path.join("stable-fast-3d/demo_files/hdri", f) for f in os.listdir("stable-fast-3d/demo_files/hdri")]
                    hdr_illumination_example = gr.Examples(
                        examples=example_hdris,
                        inputs=hdr_illumination_file,
                    )

                    hdr_illumination_file.change(
                        lambda x: gr.update(env_map=x.name if x is not None else None),
                        inputs=hdr_illumination_file,
                        outputs=[output_3d],
                    )

    examples = gr.Examples(
        examples=example_files,
        inputs=input_img,
    )

    input_img.change(
        requires_bg_remove,
        inputs=[input_img, foreground_ratio],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )

    run_btn.click(
        run_button,
        inputs=[
            run_btn,
            input_img,
            background_remove_state,
            foreground_ratio,
            remesh_option,
            vertex_count_slider,
            texture_size,
        ],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )


# Launch the demo locally; if that fails, retry with a public share link.
try:
    demo.queue().launch(debug=True)
except Exception:
    demo.queue().launch(debug=True, share=True)