1+ from cog import BasePredictor , BaseModel , Input , Path
2+ import os
3+ import time
4+ import torch
5+ import subprocess
6+ from PIL import Image
7+ import shutil
8+ from hy3dgen .shapegen import (
9+ FaceReducer , FloaterRemover , DegenerateFaceRemover ,
10+ Hunyuan3DDiTFlowMatchingPipeline
11+ )
12+ from hy3dgen .rembg import BackgroundRemover
13+ from hy3dgen .texgen import Hunyuan3DPaintPipeline
14+
15+ MODEL_REPO = "tencent/Hunyuan3D-2"
16+ MODEL_URL = "https://weights.replicate.delivery/default/tencent/Hunyuan3D-2/hunyuan3d-dit-v2-0/model.tar"
17+ DELIGHT_URL = "https://weights.replicate.delivery/default/tencent/Hunyuan3D-2/hunyuan3d-dit-v2-0/delight.tar"
18+ PAINT_URL = "https://weights.replicate.delivery/default/tencent/Hunyuan3D-2/hunyuan3d-dit-v2-0/paint.tar"
19+
20+ def download_weights (url , dest ):
21+ start = time .time ()
22+ print ("downloading url: " , url )
23+ print ("downloading to: " , dest )
24+ subprocess .check_call (["pget" , "-xf" , url , dest ], close_fds = False )
25+ print ("downloading took: " , time .time () - start )
26+
27+ class Output (BaseModel ):
28+ mesh : Path
29+
30+ class Predictor (BasePredictor ):
31+ def setup (self ) -> None :
32+ model_base_path = os .path .expanduser (os .path .join ("~/.cache/hy3dgen" , MODEL_REPO ))
33+ os .makedirs (model_base_path , exist_ok = True )
34+
35+ model_path = os .path .join (model_base_path , "hunyuan3d-dit-v2-0" )
36+ delight_path = os .path .join (model_base_path , "hunyuan3d-delight-v2-0" )
37+ paint_path = os .path .join (model_base_path , "hunyuan3d-paint-v2-0" )
38+ if not os .path .exists (model_path ):
39+ download_weights (MODEL_URL , model_path )
40+ if not os .path .exists (delight_path ):
41+ download_weights (DELIGHT_URL , delight_path )
42+ if not os .path .exists (paint_path ):
43+ download_weights (PAINT_URL , paint_path )
44+
45+ self .i23d_worker = Hunyuan3DDiTFlowMatchingPipeline .from_pretrained (MODEL_REPO )
46+ self .texgen_worker = Hunyuan3DPaintPipeline .from_pretrained (MODEL_REPO )
47+ self .floater_remove_worker = FloaterRemover ()
48+ self .degenerate_face_remove_worker = DegenerateFaceRemover ()
49+ self .face_reduce_worker = FaceReducer ()
50+ self .rmbg_worker = BackgroundRemover ()
51+
52+ def predict (
53+ self ,
54+ image : Path = Input (
55+ description = "Input image for generating 3D shape" ,
56+ default = None
57+ ),
58+ steps : int = Input (
59+ description = "Number of inference steps" ,
60+ default = 50 ,
61+ ge = 20 ,
62+ le = 50 ,
63+ ),
64+ guidance_scale : float = Input (
65+ description = "Guidance scale for generation" ,
66+ default = 5.5 ,
67+ ge = 1.0 ,
68+ le = 20.0 ,
69+ ),
70+ seed : int = Input (
71+ description = "Random seed for generation" ,
72+ default = 1234
73+ ),
74+ octree_resolution : int = Input (
75+ description = "Octree resolution for mesh generation" ,
76+ choices = [256 , 384 , 512 ],
77+ default = 256
78+ ),
79+ remove_background : bool = Input (
80+ description = "Whether to remove background from input image" ,
81+ default = True
82+ ),
83+ ) -> Output :
84+ if os .path .exists ("output" ):
85+ shutil .rmtree ("output" )
86+
87+ os .makedirs ("output" , exist_ok = True )
88+
89+ max_facenum = 40000
90+
91+ generator = torch .Generator ()
92+ generator = generator .manual_seed (seed )
93+
94+ if image is not None :
95+ input_image = Image .open (str (image ))
96+ if remove_background or input_image .mode == "RGB" :
97+ input_image = self .rmbg_worker (input_image .convert ('RGB' ))
98+ else :
99+ raise ValueError ("Image must be provided" )
100+
101+ input_image .save ("output/input.png" )
102+
103+ mesh = self .i23d_worker (
104+ image = input_image ,
105+ num_inference_steps = steps ,
106+ guidance_scale = guidance_scale ,
107+ generator = generator ,
108+ octree_resolution = octree_resolution
109+ )[0 ]
110+
111+ mesh = self .floater_remove_worker (mesh )
112+ mesh = self .degenerate_face_remove_worker (mesh )
113+ mesh = self .face_reduce_worker (mesh , max_facenum = max_facenum )
114+ mesh = self .texgen_worker (mesh , input_image )
115+ output_path = Path ("output/mesh.glb" )
116+ mesh .export (str (output_path ), include_normals = True )
117+
118+ if not Path (output_path ).exists ():
119+ raise RuntimeError (f"Failed to generate mesh file at { output_path } " )
120+
121+ return Output (mesh = output_path )
0 commit comments