In [None]:
# --- Imports: third-party first, then project-local modules -----------------
import numpy as np
import torch

from diffusion import init_pipe
from dino import init_dino
from sam2_setup import init_sam2
from utils import compute_features, cosine_similarity, double_plot, get_colors, load_mesh

# --- Device selection --------------------------------------------------------
# Preference order: Apple-silicon MPS, then the first CUDA GPU, then CPU.
# `device` is reused by every later cell (model init, mesh loading, features).
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Using GPU")
else:
    device = torch.device("cpu")
    print("No GPU/MPS available, falling back to CPU.")

In [2]:
# --- Feature-extraction configuration ----------------------------------------
# num_views, H, W and tolerance are passed to compute_features and
# run_batch_evaluation in the cells below.
num_views = 2      # views rendered per mesh
H, W = 512, 512    # render resolution (height, width)
tolerance = 0.004  # NOTE(review): forwarded as-is to compute_features; meaning is defined in utils

# Optional output: when not None, batched_renderings, normal_batched_renderings,
# camera and depth are saved to 'rendered_mesh_output.pt'.
save_path = None

# --- Experiment toggles -------------------------------------------------------
# NOTE(review): most of these flags are not passed to any call visible in this
# notebook; confirm whether they are still consumed before relying on them.
use_normal_map = True
num_images_per_prompt = 1
bq = True
use_sam = False
use_only_diffusion = False
use_diffusion = True
is_tosca = False

In [3]:
# Model initialisation.
# SAM 2 is optional and is only built when `use_sam` (config cell) is True.
# It presumably needs the checkpoint below to be downloaded first:
# !mkdir -p data/checkpoints
# !wget -P data/checkpoints https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
sam_model = init_sam2(device) if use_sam else None
pipe = init_pipe(device)
dino_model = init_dino(device)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'pipeline_controlnet_img2img.StableDiffusionControlNetImg2ImgPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


In [4]:
# Data: download the SHREC20b meshes into the ./data folder from
#   https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link
# e.g. (the URL is quoted so the shell does not try to glob-expand the `?`):
#   pip install gdown && gdown --folder "https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link"


def _shrec_tex_obj_path(object_name):
    """Return the path of the textured SHREC20b .obj file for `object_name`.

    The untextured alternative lives at
    f"data/SHREC20b_lores/models/{object_name}.obj".
    """
    return f"data/SHREC20b_lores_tex/{object_name}_tex/{object_name}_tex.obj"


source_object_name = "hippo"  # mesh file stem (the .obj name without extension)
source_file_path = _shrec_tex_obj_path(source_object_name)
source_prompt = "hippo"  # text prompt for diffusion (e.g. "camel" rather than the file stem "camel_a")

target_object_name = "rhino"
target_file_path = _shrec_tex_obj_path(target_object_name)
target_prompt = "rhino"

source_mesh = load_mesh(source_file_path, device)
target_mesh = load_mesh(target_file_path, device)


Loading mesh from data/SHREC20b_lores_tex/hippo_tex/hippo_tex.obj
Detected texture references. Using load_objs_as_meshes.
Loading mesh from data/SHREC20b_lores_tex/rhino_tex/rhino_tex.obj
Detected texture references. Using load_objs_as_meshes.


In [5]:
# Per-vertex features for both meshes, computed with identical settings
# (source first, then target).
# NOTE(review): both calls receive the same `save_path`, so when it is set the
# target run appears to overwrite the source run's saved file; output.mp4 is
# likewise rewritten on each call (see the cell output).
f_source, f_target = (
    compute_features(device, sam_model, dino_model, pipe, mesh, prompt, num_views, H, W, tolerance, save_path)
    for mesh, prompt in ((source_mesh, source_prompt), (target_mesh, target_prompt))
)

Starting batch_render with num_views=2, H=512, W=512
Rendering completed successfully
Starting batch_render with num_views=2, H=512, W=512
Rendering completed successfully
Video saved to output.mp4



  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 4/4 [00:20<00:00,  5.10s/it]


Number of missing features:  7672
Copied features from nearest vertices
Time taken in mins:  0.3797622203826904
Starting batch_render with num_views=2, H=512, W=512
Rendering completed successfully
Starting batch_render with num_views=2, H=512, W=512
Rendering completed successfully
Video saved to output.mp4



100%|██████████| 4/4 [00:18<00:00,  4.73s/it]

Number of missing features:  8564
Copied features from nearest vertices
Time taken in mins:  0.3517560283342997





In [7]:
# Colour transfer, source -> target. Per the shapes printed below,
# cosine_similarity(f_source, f_target) is (n_source, n_target). An argmax over
# dim 0 therefore gives, for each target vertex, the most similar source vertex.
s = cosine_similarity(f_source.to(device), f_target.to(device)).argmax(dim=0).cpu().numpy()
print(f"f_source.shape: {f_source.shape}")
print(f"f_target.shape: {f_target.shape}")
print(f"s.shape: {s.shape}")

# Each target vertex borrows the colour of its matched source vertex.
cmap_source = get_colors(source_mesh)
cmap_target = cmap_source[s]

double_plot(source_mesh, target_mesh, cmap_source, cmap_target)

f_source.shape: torch.Size([11765, 2048])
f_target.shape: torch.Size([12395, 2048])
s.shape: (12395,)


HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))

In [10]:
from eval import evaluate_meshes

# Nearest-neighbour map in the source -> target direction.
# cosine_similarity(f_target, f_source) is (n_target, n_source), so an argmax over
# dim 0 gives, for each source vertex, the index of the most similar target vertex.
# Hence s.shape == (n_source,).
mapping_path = 'predicted_mapping.npy'  # single name for the file written here and read by evaluate_meshes
s = cosine_similarity(f_target.to(device), f_source.to(device)).argmax(dim=0).cpu().numpy()
assert s.shape == (f_source.shape[0],), "expected one target index per source vertex"
np.save(mapping_path, s)

print(f"f_source.shape: {f_source.shape}")
print(f"f_target.shape: {f_target.shape}")
print(f"s.shape: {s.shape}")

# Evaluate against the ground-truth correspondences.
# To evaluate on the untextured meshes instead, pass
# f"data/SHREC20b_lores/models/{name}.obj" as the mesh paths.
avg_error, accuracy, distances = evaluate_meshes(
    source_file_path=source_file_path,
    target_file_path=target_file_path,
    source_gt_path=f'data/SHREC20b_lores_gts/{source_object_name}.mat',
    target_gt_path=f'data/SHREC20b_lores_gts/{target_object_name}.mat',
    mapping_path=mapping_path,
    debug=False,
)

print(f"Average correspondence error (err): {avg_error:.6f}")
print(f"Correspondence accuracy (acc, γ=1%): {accuracy:.6f}")

f_source.shape: torch.Size([11765, 2048])
f_target.shape: torch.Size([12395, 2048])
s.shape: (11765,)
Average correspondence error (err): 0.446379
Correspondence accuracy (acc, γ=1%): 0.040816


In [None]:
from eval_batch import run_batch_evaluation

# Test sets under <shrec_base_path>/test-sets/test-set<N>.txt:
#   0 - partial-to-full scans
#   1 - full-to-full, highest isometry
#   2 - full-to-full, high isometry
#   3 - full-to-full, low isometry
#   4 - full-to-full, lowest isometry
# NOTE(review): this uses the untextured SHREC20b_lores tree, whereas the
# single-pair cells above load the *_tex meshes; confirm this is intended.
shrec_base_path = "data/SHREC20b_lores"
test_set_id = 1

results = run_batch_evaluation(
    pairs_file=f"{shrec_base_path}/test-sets/test-set{test_set_id}.txt",
    base_path=shrec_base_path,
    device=device,
    sam_model=sam_model,
    dino_model=dino_model,
    pipe=pipe,
    num_views=num_views,
    H=H,
    W=W,
    tolerance=tolerance,
)