In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("..")

In [None]:
import matplotlib.pyplot as plt
import torch
from hy3dgen.rembg import BackgroundRemover
from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline
from hy3dgen.texgen import Hunyuan3DPaintPipeline
from hy3dgen.text2image import HunyuanDiTPipeline

from hy3dgen.shapegen.postprocessors import (
    FloaterRemover,
    DegenerateFaceRemover,
    FaceReducer,
    mesh_normalize,
)
import src.hooked_model.scheduler
from src.hooked_model.hooked_model import HookedDiffusionModel
from src.hooked_model.hooks import AblateHook
from src.hooked_model.utils import get_timesteps

In [None]:
# Distilled HunyuanDiT v1.2 checkpoint used for the text -> image stage.
TEXT2IMAGE_MODEL_ID = "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers-Distilled"
text2image = HunyuanDiTPipeline(model_path=TEXT2IMAGE_MODEL_ID)

In [None]:
# Text prompt for the reference image that feeds the whole image -> 3D pipeline.
prompt = (
    "A photo of an astronaut "
    "riding a horse on mars"
)

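### How to register an ablation hook and use it during inference

The cells below first generate a reference image and a textured 3D mesh; the ablation-hook cells are at the end of the notebook and are currently commented out.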

In [None]:
# Generate the reference image from the prompt. The prompt is used as the
# figure title so the output is self-explanatory when the notebook is skimmed.
image = text2image(prompt)

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title(prompt)
ax.axis("off")  # ticks carry no information for a generated photo
plt.show()  # render the figure instead of leaking the AxesImage repr


In [None]:
# Free the text-to-image model and return cached CUDA memory before loading the
# background remover. The `del` is guarded so re-running this cell does not
# raise NameError once `text2image` has already been deleted.
try:
    del text2image
except NameError:
    pass
torch.cuda.empty_cache()

rembg = BackgroundRemover()
# NOTE(review): `image` is overwritten in place, so re-running this cell feeds
# the already-processed image back through the remover.
image = rembg(image)

fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title("Background removed")
ax.axis("off")
plt.show()

In [None]:
# Free the background remover and return cached CUDA memory before loading the
# shape-generation model. The `del` is guarded so re-running this cell does not
# raise NameError once `rembg` has already been deleted.
try:
    del rembg
except NameError:
    pass
torch.cuda.empty_cache()

# Image -> shape pipeline, loaded from the "turbo" DiT subfolder of Hunyuan3D-2.
pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
    "tencent/Hunyuan3D-2",
    subfolder="hunyuan3d-dit-v2-0-turbo",
    use_safetensors=True,
)
# Enable the FlashVDM option of the pipeline (see the Hunyuan3D-2 docs for details).
pipeline.enable_flashvdm()

In [None]:
# Shape stage: reconstruct an untextured mesh from the background-free image.
# The pipeline returns a sequence of meshes; only the first one is kept.
mesh = pipeline(
    image,
    generator=torch.manual_seed(12345),  # fixed seed for reproducible geometry
    output_type="trimesh",
    num_inference_steps=5,
    num_chunks=200000,
    octree_resolution=380,
)[0]
mesh


In [None]:
# Drop the shape pipeline (if it is still loaded) and return cached CUDA memory
# before loading the texture model; on a re-run `pipeline` is already gone.
if "pipeline" in globals():
    del pipeline
    torch.cuda.empty_cache()

# Texture stage: paint the generated mesh using the reference image.
pipeline_texgen = Hunyuan3DPaintPipeline.from_pretrained("tencent/Hunyuan3D-2")
# NOTE(review): this rebinds `mesh` to the textured result, so the untextured
# shape is no longer available and re-running this cell passes the already
# textured mesh back into the texture pipeline.
mesh = pipeline_texgen(mesh, image=image)

In [None]:
# Save the mesh returned by the texture stage, before any geometry clean-up.
mesh.export("output.glb")

In [None]:
# Geometry clean-up, applied in list order: degenerate-face removal, floater
# removal, face reduction, then normalization. Each entry is a callable that
# takes a mesh and returns the processed mesh.
postprocessing_steps = [
    DegenerateFaceRemover(),
    FloaterRemover(),
    FaceReducer(),
    mesh_normalize,
]

# NOTE(review): `mesh` is rebound in place (here it is the textured mesh from
# the texture cell), so re-running this cell applies the whole chain again.
for postprocess in postprocessing_steps:
    mesh = postprocess(mesh)
mesh.export("output_processed.glb")

In [None]:
# Re-texture the cleaned-up geometry produced by the postprocessing cell and save it.
# NOTE(review): rebinds `mesh` in place, so re-running this cell re-textures the
# already-textured result.
mesh = pipeline_texgen(mesh, image=image)
mesh.export("output_processed_textured.glb")

In [None]:
# NOTE(review): disabled cell — `pipe` is not defined anywhere in this notebook.
# Presumably it should be a diffusers-style pipeline exposing `.scheduler.config`;
# create it before uncommenting.
# scheduler = src.hooked_model.scheduler.DDIMScheduler.from_config(
#     pipe.scheduler.config
# )


In [None]:
# NOTE(review): disabled cell — depends on `model`, `pipe` and `scheduler`, none
# of which are defined by the active cells of this notebook.
# hooked_model = HookedDiffusionModel(
#     model=model,
#     scheduler=scheduler,
#     encode_prompt=pipe.encode_prompt,
#     get_timesteps=get_timesteps,
#     vae=pipe.vae,
# )


In [None]:
# import re

# hookpoints = []
# pattern = re.compile(r".*\.attentions\.(\d+)$")
# for n, m in pipe.unet.named_modules():
#     match = pattern.match(n)
#     if match:
#         hookpoints.append(n)
#         print(n)


In [None]:
# prompts = ["A photo of an astronaut in Van Gogh style" for _ in range(4)]

In [None]:
# NOTE(review): if re-enabled, this loop rebinds the global `image` used by the
# 3D pipeline cells above; consider a different name. The loop index `i` is unused.
# all_images = []

# for i, hookpoint in enumerate(hookpoints):
#     image = hooked_model.run_with_hooks(
#         {hookpoint: AblateHook()},
#         prompt=prompts,
#         num_inference_steps=50,
#         guidance_scale=7.5,
#         generator=torch.Generator(device="cuda").manual_seed(1),
#     )

#     all_images.append(image)

In [None]:
# def display_images(all_images, hookpoints, images_per_row=4):
#     rows = len(all_images)
#     fig, axes = plt.subplots(
#         rows, images_per_row, figsize=(images_per_row * 3, rows * 3)
#     )
#     fig.subplots_adjust(
#         hspace=0.5, wspace=0.5
#     )  # Adjust space between rows and columns

#     for i, row_images in enumerate(
#         all_images[:rows]
#     ):  # Limit to the first `rows`
#         for j, image in enumerate(
#             row_images[:images_per_row]
#         ):  # Limit to `images_per_row`
#             ax = axes[i, j] if rows > 1 else axes[j]  # Handle single row case
#             ax.imshow(image)
#             ax.axis("off")  # Turn off axes for a cleaner look
#             if j == 0:
#                 ax.set_title(hookpoints[i])
#     plt.tight_layout()
#     plt.show()


# display_images(all_images, hookpoints)
