In [None]:
import torch

def _pick_device() -> torch.device:
    """Return the preferred torch device, announcing the choice.

    Preference order: Apple MPS, then the first CUDA GPU, then CPU.
    """
    if torch.backends.mps.is_available():
        print("Using MPS")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("Using GPU")
        return torch.device("cuda:0")
    print("No GPU/MPS available, falling back to CPU.")
    return torch.device("cpu")


device = _pick_device()

from diffusion import init_pipe
from utils import cosine_similarity, double_plot, get_colors
from dataloaders.mesh_container import MeshContainer
from dino import init_dino
from sam2_setup import init_sam2
from pytorch3d.io import load_objs_as_meshes
from utils import compute_features
import numpy as np

In [2]:
# --- Rendering / feature-extraction settings (forwarded to compute_features
#     and run_batch_evaluation below) ---
num_views = 50      # presumably the number of rendered views per mesh
H, W = 512, 512     # presumably render height / width in pixels
tolerance = 0.004   # meaning defined by compute_features / run_batch_evaluation

# --- Pipeline toggles ---
# NOTE(review): most of these flags are not passed to any call in this notebook;
# confirm where (if anywhere) they take effect before relying on them.
use_normal_map = True
num_images_per_prompt = 1
bq = True
use_sam = False
use_only_diffusion = False
use_diffusion = True
is_tosca = False

# If not None, compute_features saves its rendered outputs (batched_renderings,
# normal_batched_renderings, camera, depth) based on this path.
save_path = 'rendered_mesh_output.pt'

In [None]:
# SAM 2 presumably needs its checkpoint on disk first, e.g.:
#   !mkdir -p data/checkpoints
#   !wget -P data/checkpoints https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
# Load SAM 2 only when it will be used (`use_sam`). When it is unused, the
# feature-extraction cell sets `sam_model = None`, so loading it would only cost
# time and device memory.
sam_model = init_sam2(device) if use_sam else None
pipe = init_pipe(device)
dino_model = init_dino(device)

In [4]:
# Download the data into the `data` folder first, from:
#   https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link
# e.g.: pip install gdown && gdown --folder https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link
source_file_path = "data/SHREC20b_lores/models/cow.obj"
source_prompt = "cow"

target_file_path = "data/SHREC20b_lores/models/camel_a.obj"
target_prompt = "camel"

# MeshContainer is imported unconditionally in the first cell. An
# `except NameError` fallback could therefore only trigger on a NameError raised
# *inside* load_from_file, which would hide a real bug. It would also yield
# pytorch3d `Meshes` objects, which expose vertices via verts_list()/verts_packed()
# rather than the `.vert` attribute used by the plotting cell. Load directly so
# that failures surface.
source_mesh = MeshContainer().load_from_file(source_file_path)
target_mesh = MeshContainer().load_from_file(target_file_path)

In [None]:
import os

# SAM is only used when `use_sam` is enabled; otherwise clear the model so that
# compute_features (and any later cell reading `sam_model`) runs without it.
if not use_sam:
    sam_model = None


def _per_mesh_save_path(base_path, tag):
    """Return `base_path` with `_<tag>` inserted before its extension, or None.

    Without a per-mesh suffix both compute_features calls would write to the
    same file, and the second call would overwrite the first's saved renderings.
    """
    if base_path is None:
        return None
    root, ext = os.path.splitext(base_path)
    return f"{root}_{tag}{ext}"


f_source = compute_features(
    device, sam_model, dino_model, pipe, source_mesh, source_prompt,
    num_views, H, W, tolerance, _per_mesh_save_path(save_path, "source"),
)
f_target = compute_features(
    device, sam_model, dino_model, pipe, target_mesh, target_prompt,
    num_views, H, W, tolerance, _per_mesh_save_path(save_path, "target"),
)

In [None]:
# Nearest-match correspondence. For each index along dim 1 of the similarity
# matrix, take the index along dim 0 with the highest similarity. The result
# indexes `cmap_source` below, so s[j] is a source-side index (presumably a
# source vertex matched to target vertex j; verify against cosine_similarity).
s = (
    cosine_similarity(f_source.to(device), f_target.to(device))
    .argmax(dim=0)
    .cpu()
    .numpy()
)
np.save('predicted_mapping.npy', s)

# Colour the source vertices, then transfer those colours to the target through
# the mapping so that corresponding regions share a colour.
cmap_source = get_colors(source_mesh.vert)
cmap_target = cmap_source[s]
double_plot(source_mesh, target_mesh, cmap_source, cmap_target)

In [None]:
from pathlib import Path

from eval import evaluate_meshes

# Evaluate the mapping saved above against ground truth. Every path is derived
# from the meshes chosen in the loading cell, so switching source/target there
# keeps this evaluation consistent (ground truth lives at
# data/SHREC20b_lores_gts/<mesh stem>.mat).
avg_error, accuracy, distances = evaluate_meshes(
    source_file_path=source_file_path,
    target_file_path=target_file_path,
    source_gt_path=f"data/SHREC20b_lores_gts/{Path(source_file_path).stem}.mat",
    target_gt_path=f"data/SHREC20b_lores_gts/{Path(target_file_path).stem}.mat",
    mapping_path='predicted_mapping.npy',
    debug=False,
)

print(f"Average correspondence error (err): {avg_error:.6f}")
print(f"Correspondence accuracy (acc, γ=1%): {accuracy:.6f}")

In [None]:
from eval_batch import run_batch_evaluation

# Batch evaluation over the pairs listed in the SHREC'20 test split file.
# SAM is forwarded only when `use_sam` is enabled. The result therefore does not
# depend on whether an earlier cell happened to clear `sam_model` (hidden state
# under out-of-order execution).
results = run_batch_evaluation(
    pairs_file='data/SHREC20b_lores/test-sets/test-set0.txt',
    base_path="data/SHREC20b_lores",
    device=device,
    sam_model=sam_model if use_sam else None,
    dino_model=dino_model,
    pipe=pipe,
    num_views=num_views,
    H=H,
    W=W,
    tolerance=tolerance,
)