Skip to content

Commit eb520ea

Browse files
committed
cog integration
1 parent cf284a3 commit eb520ea

File tree

5 files changed

+221
-0
lines changed

5 files changed

+221
-0
lines changed

.github/workflows/push.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
name: Push to Replicate
2+
3+
on:
4+
# Allows manual triggering from GitHub.com
5+
workflow_dispatch:
6+
inputs:
7+
model_name:
8+
description: "Model name to push (default: ndreca/hunyuan3d-2)"
9+
required: false
10+
default: "ndreca/hunyuan3d-2"
11+
# Uncomment the lines below to trigger on every push to main
12+
# push:
13+
# branches:
14+
# - main
15+
16+
jobs:
17+
push_to_replicate:
18+
name: Push to Replicate
19+
runs-on: ubuntu-latest
20+
21+
steps:
22+
- name: Free disk space
23+
uses: jlumbroso/free-disk-space@v1.3.1
24+
with:
25+
tool-cache: false
26+
docker-images: false
27+
28+
- name: Checkout
29+
uses: actions/checkout@v4
30+
31+
- name: Setup Cog
32+
uses: replicate/setup-cog@v2
33+
with:
34+
token: ${{ secrets.REPLICATE_API_TOKEN }}
35+
36+
- name: Push to Replicate
37+
run: |
38+
if [ -n "${{ github.event.inputs.model_name }}" ]; then
39+
cog push r8.im/${{ github.event.inputs.model_name }}
40+
else
41+
cog push
42+
fi

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,9 @@ gradio_cache/
166166
# and can be added to the global gitignore or merged into this file. For a more nuclear
167167
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
168168
#.idea/
169+
170+
__pycache__
171+
.cog
172+
checkpoints
173+
output
174+
.DS_Store

cog.yaml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
build:
2+
gpu: true
3+
cuda: "12.4"
4+
python_version: "3.10"
5+
system_packages:
6+
- "libgl1-mesa-glx"
7+
- "libglib2.0-0"
8+
- "libglu1-mesa"
9+
- "libglew2.2"
10+
python_packages:
11+
- gradio_litmodel3d
12+
- ninja
13+
- pybind11
14+
- trimesh
15+
- diffusers
16+
- tqdm
17+
- einops
18+
- opencv-python
19+
- numpy
20+
- "torch==2.4.0"
21+
- transformers
22+
- torchvision
23+
- torchaudio
24+
- ConfigArgParse
25+
- xatlas
26+
- scikit-learn
27+
- scikit-image
28+
- tritonclient
29+
- gevent
30+
- geventhttpclient
31+
- facexlib
32+
- accelerate
33+
- ipdb
34+
- omegaconf
35+
- pymeshlab
36+
- pytorch_lightning
37+
- taming-transformers-rom1504
38+
- kornia
39+
- rembg
40+
- onnxruntime
41+
- pygltflib
42+
- sentencepiece
43+
- gradio
44+
- uvicorn
45+
- "fastapi==0.112.2"
46+
- wheel
47+
48+
run:
49+
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_Linux_x86_64" && chmod +x /usr/local/bin/pget
50+
- curl -o /tmp/custom_rasterizer-0.1-cp310-cp310-linux_x86_64.whl -L "https://huggingface.co/spaces/tencent/Hunyuan3D-2/resolve/main/custom_rasterizer-0.1-cp310-cp310-linux_x86_64.whl" && pip install /tmp/custom_rasterizer-0.1-cp310-cp310-linux_x86_64.whl
51+
52+
predict: "predict.py:Predictor"

demo.png

209 KB
Loading

predict.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from cog import BasePredictor, BaseModel, Input, Path
2+
import os
3+
import time
4+
import torch
5+
import subprocess
6+
from PIL import Image
7+
import shutil
8+
from hy3dgen.shapegen import (
9+
FaceReducer, FloaterRemover, DegenerateFaceRemover,
10+
Hunyuan3DDiTFlowMatchingPipeline
11+
)
12+
from hy3dgen.rembg import BackgroundRemover
13+
from hy3dgen.texgen import Hunyuan3DPaintPipeline
14+
15+
MODEL_REPO = "tencent/Hunyuan3D-2"
16+
MODEL_URL = "https://weights.replicate.delivery/default/tencent/Hunyuan3D-2/hunyuan3d-dit-v2-0/model.tar"
17+
DELIGHT_URL = "https://weights.replicate.delivery/default/tencent/Hunyuan3D-2/hunyuan3d-dit-v2-0/delight.tar"
18+
PAINT_URL = "https://weights.replicate.delivery/default/tencent/Hunyuan3D-2/hunyuan3d-dit-v2-0/paint.tar"
19+
20+
def download_weights(url, dest):
21+
start = time.time()
22+
print("downloading url: ", url)
23+
print("downloading to: ", dest)
24+
subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
25+
print("downloading took: ", time.time() - start)
26+
27+
class Output(BaseModel):
28+
mesh: Path
29+
30+
class Predictor(BasePredictor):
31+
def setup(self) -> None:
32+
model_base_path = os.path.expanduser(os.path.join("~/.cache/hy3dgen", MODEL_REPO))
33+
os.makedirs(model_base_path, exist_ok=True)
34+
35+
model_path = os.path.join(model_base_path, "hunyuan3d-dit-v2-0")
36+
delight_path = os.path.join(model_base_path, "hunyuan3d-delight-v2-0")
37+
paint_path = os.path.join(model_base_path, "hunyuan3d-paint-v2-0")
38+
if not os.path.exists(model_path):
39+
download_weights(MODEL_URL, model_path)
40+
if not os.path.exists(delight_path):
41+
download_weights(DELIGHT_URL, delight_path)
42+
if not os.path.exists(paint_path):
43+
download_weights(PAINT_URL, paint_path)
44+
45+
self.i23d_worker = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(MODEL_REPO)
46+
self.texgen_worker = Hunyuan3DPaintPipeline.from_pretrained(MODEL_REPO)
47+
self.floater_remove_worker = FloaterRemover()
48+
self.degenerate_face_remove_worker = DegenerateFaceRemover()
49+
self.face_reduce_worker = FaceReducer()
50+
self.rmbg_worker = BackgroundRemover()
51+
52+
def predict(
53+
self,
54+
image: Path = Input(
55+
description="Input image for generating 3D shape",
56+
default=None
57+
),
58+
steps: int = Input(
59+
description="Number of inference steps",
60+
default=50,
61+
ge=20,
62+
le=50,
63+
),
64+
guidance_scale: float = Input(
65+
description="Guidance scale for generation",
66+
default=5.5,
67+
ge=1.0,
68+
le=20.0,
69+
),
70+
seed: int = Input(
71+
description="Random seed for generation",
72+
default=1234
73+
),
74+
octree_resolution: int = Input(
75+
description="Octree resolution for mesh generation",
76+
choices=[256, 384, 512],
77+
default=256
78+
),
79+
remove_background: bool = Input(
80+
description="Whether to remove background from input image",
81+
default=True
82+
),
83+
) -> Output:
84+
if os.path.exists("output"):
85+
shutil.rmtree("output")
86+
87+
os.makedirs("output", exist_ok=True)
88+
89+
max_facenum = 40000
90+
91+
generator = torch.Generator()
92+
generator = generator.manual_seed(seed)
93+
94+
if image is not None:
95+
input_image = Image.open(str(image))
96+
if remove_background or input_image.mode == "RGB":
97+
input_image = self.rmbg_worker(input_image.convert('RGB'))
98+
else:
99+
raise ValueError("Image must be provided")
100+
101+
input_image.save("output/input.png")
102+
103+
mesh = self.i23d_worker(
104+
image=input_image,
105+
num_inference_steps=steps,
106+
guidance_scale=guidance_scale,
107+
generator=generator,
108+
octree_resolution=octree_resolution
109+
)[0]
110+
111+
mesh = self.floater_remove_worker(mesh)
112+
mesh = self.degenerate_face_remove_worker(mesh)
113+
mesh = self.face_reduce_worker(mesh, max_facenum=max_facenum)
114+
mesh = self.texgen_worker(mesh, input_image)
115+
output_path = Path("output/mesh.glb")
116+
mesh.export(str(output_path), include_normals=True)
117+
118+
if not Path(output_path).exists():
119+
raise RuntimeError(f"Failed to generate mesh file at {output_path}")
120+
121+
return Output(mesh=output_path)

0 commit comments

Comments
 (0)