In [None]:
import torch

def _select_device():
    """Return the preferred torch device: Apple MPS first, then CUDA GPU 0, else CPU.

    Also prints which backend was chosen.
    """
    if torch.backends.mps.is_available():
        print("Using MPS")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("Using GPU")
        return torch.device("cuda:0")
    print("No GPU/MPS available, falling back to CPU.")
    return torch.device("cpu")


device = _select_device()

from utils import cosine_similarity, double_plot, get_colors
from diffusion import init_pipe
from dino import init_dino
from sam2_setup import init_sam2
from utils import compute_features, load_mesh
import numpy as np

In [2]:
# Settings forwarded to compute_features by the feature-extraction cell.
num_views = 50                # presumably the number of rendered views per mesh
H = W = 512                   # render height / width (pixels, presumably)
tolerance = 0.004             # NOTE(review): semantics live inside compute_features — confirm
use_normal_map = True
num_images_per_prompt = 1
bq = True                     # NOTE(review): meaning of `bq` not visible here — confirm

# Feature-backbone switches.
use_sam = False               # if False -> use DINO
use_only_diffusion = False
use_diffusion = False

# Dataset / input switches.
is_tosca = False
tex = False                   # whether the textured (_tex) meshes are used

In [None]:
# Initialise the SAM 2, diffusion and DINO backbones once on `device`; later cells reuse them.
# If the SAM 2.1 checkpoint has not been downloaded yet, run:
#   !mkdir -p data/checkpoints
#   !wget -P data/checkpoints https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
# (tuple elements are evaluated left to right: SAM 2, then the pipeline, then DINO)
sam_model, pipe, dino_model = init_sam2(device), init_pipe(device), init_dino(device)

In [None]:
# Download the data from here into the data folder: https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link
# e.g. pip install gdown && gdown --folder https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link

source_object_name = "hippo"  # .obj file name (without extension)
source_file_path = f"data/SHREC20b_lores/models/{source_object_name}.obj"
source_file_path_tex = f"data/SHREC20b_lores_tex/{source_object_name}_tex/{source_object_name}_tex.obj"
source_prompt = "hippo"  # prompt for diffusion (e.g. "camel" instead of "camel_a")

target_object_name = "cow"
target_file_path = f"data/SHREC20b_lores/models/{target_object_name}.obj"
target_file_path_tex = f"data/SHREC20b_lores_tex/{target_object_name}_tex/{target_object_name}_tex.obj"
target_prompt = "cow"

source_mesh = load_mesh(source_file_path, device)
target_mesh = load_mesh(target_file_path, device)
print()

# The textured meshes are only loaded when `tex` is enabled, so the separate
# SHREC20b_lores_tex download is not required otherwise. Passing None together
# with tex=False mirrors how the batch-evaluation cell calls the pipeline
# (tex=False, source_tex_mesh=None).
source_tex_mesh = load_mesh(source_file_path_tex, device) if tex else None
target_tex_mesh = load_mesh(target_file_path_tex, device) if tex else None

In [None]:
def extract_mesh_features(mesh, prompt, tex_mesh, save_path=None):
    """Run compute_features on one mesh using the notebook-wide settings from the config cell.

    Args:
        mesh: mesh object returned by load_mesh.
        prompt: text prompt forwarded to compute_features.
        tex_mesh: textured counterpart of `mesh`, forwarded unchanged.
        save_path: forwarded to compute_features. Set it to e.g.
            f"data/rendered_meshes/{name}_rendered.pt" to produce the renders
            used by the DINO Tracker pipeline; None skips that.

    Returns:
        Whatever compute_features returns (the per-vertex features used below).
    """
    return compute_features(device, sam_model, dino_model, pipe, mesh, prompt, num_views, H, W, tolerance,
        save_path, use_normal_map, tex, tex_mesh, num_images_per_prompt, bq,
        use_sam, use_only_diffusion, use_diffusion, is_tosca)


# save_path = f"data/rendered_meshes/{source_object_name}_rendered.pt" used for DINO Tracker pipeline
save_path = None
f_source = extract_mesh_features(source_mesh, source_prompt, source_tex_mesh, save_path)
print()

# save_path = f"data/rendered_meshes/{target_object_name}_rendered.pt"
save_path = None
f_target = extract_mesh_features(target_mesh, target_prompt, target_tex_mesh, save_path)

In [None]:
# Plot and save the correspondence map: colour the target mesh by vertex
# position, then transfer each colour to the source vertex whose feature is
# most similar.

# Assumes cosine_similarity(a, b) returns a (len(a), len(b)) matrix (as the
# evaluation cell below also relies on). With a = target features, argmax over
# dim 0 yields, for every SOURCE vertex, the index of its best-matching TARGET
# vertex -- exactly what `cmap_target[...]` needs to give one colour per source
# vertex. (Previously the arguments were swapped, producing one entry per
# target vertex indexed by source ids.)
sim_target_source = cosine_similarity(f_target.to(device), f_source.to(device))
source_to_target = torch.argmax(sim_target_source, dim=0).cpu().numpy()

cmap_target = get_colors(target_mesh.vert)
cmap_source = cmap_target[source_to_target]
assert len(cmap_source) == len(source_mesh.vert), "need exactly one colour per source vertex"

save_path = f"{target_object_name}_to_{source_object_name}_plot.html"
double_plot(target_mesh, source_mesh, cmap_target, cmap_source, save_path=save_path, show=True)

In [None]:
# Evaluate the correspondence quantitatively against the ground-truth .mat files.

from eval import evaluate_meshes

# Assuming cosine_similarity(a, b) is (len(a), len(b)): argmax over dim 0 gives,
# for every source vertex, the index of its most similar target vertex.
s = cosine_similarity(f_target.to(device), f_source.to(device))
s = torch.argmax(s, dim=0).cpu().numpy()
mapping_path = 'predicted_mapping.npy'
np.save(mapping_path, s)

avg_error, accuracy, distances = evaluate_meshes(
    source_file_path=source_file_path,
    target_file_path=target_file_path,
    source_gt_path=f'data/SHREC20b_lores_gts/{source_object_name}.mat',
    target_gt_path=f'data/SHREC20b_lores_gts/{target_object_name}.mat',
    mapping_path=mapping_path,
    debug=False,
)


def _format_decimal_comma(value):
    """Format `value` with two decimals and a comma as decimal separator (1.5 -> '1,50')."""
    return f"{value:.2f}".replace(".", ",")


# avg_error / accuracy stay numeric so this cell can be re-run and later cells
# can still compute with them; only the printed text uses comma decimals.
print(f"Average correspondence error (err): {_format_decimal_comma(avg_error)}")
print(f"Correspondence accuracy (acc, γ=1%): {_format_decimal_comma(accuracy * 100)}%")

In [None]:
# Run the evaluation on all pairs listed in a test-set file.

import eval_batch

# SHREC20b_lores/test-sets/
# test-set0.txt - partial-to-full scans
# test-set1.txt - full-to-full highest isometry
# test-set2.txt - full-to-full high isometry
# test-set3.txt - full-to-full low isometry
# test-set4.txt - full-to-full lowest isometry
# test-set5.txt - combines all test sets
pairs_file = 'data/SHREC20b_lores/test-sets/test-set5.txt'
base_path = "data/SHREC20b_lores"

# Shared settings come from the config cell so single-pair and batch runs agree.
# The literal values below are deliberate:
#  - save_path=None: the global `save_path` was rebound to an .html path by the
#    plotting cell and must not be forwarded here.
#  - use_diffusion=True / tex=False: kept as in the original batch run.
#    NOTE(review): use_diffusion differs from the config cell (False) — confirm intent.
results_features = eval_batch.run_batch_evaluation(
    pairs_file=pairs_file,
    base_path=base_path,
    device=device,
    sam_model=sam_model,
    dino_model=dino_model,
    pipe=pipe,
    source_mesh=None,
    source_prompt=None,
    num_views=num_views,
    H=H,
    W=W,
    tolerance=tolerance,
    save_path=None,
    use_normal_map=use_normal_map,
    tex=False,
    source_tex_mesh=None,
    num_images_per_prompt=num_images_per_prompt,
    bq=bq,
    use_sam=use_sam,
    use_only_diffusion=use_only_diffusion,
    use_diffusion=True,
    is_tosca=is_tosca,
)

# Evaluate the DINO-Tracker pipeline on the same pairs. The module is imported
# under its own name so it no longer shadows eval_batch.run_batch_evaluation.

import eval_batch_dinotracker

dinotracker_path = "/workspace/dino-tracker/dataset"  # set to path where output for dinotracker is stored

results_dinotracker = eval_batch_dinotracker.run_batch_evaluation(
    pairs_file=pairs_file,
    base_path=base_path,
    device=device,
    num_views=num_views,
    H=H,
    W=W,
    tolerance=tolerance,
    use_normal_map=use_normal_map,
    bq=bq,
    is_tex=True,  # set true to use texturized image output from dino-tracker
    dinotracker_path=dinotracker_path,
    use_only_dino=True,  # set true to use only dino features for correspondence, False to use dinotracker
)

# Backward compatibility: `results` previously held the output of the last run.
results = results_dinotracker