In [None]:
import torch

def _select_device():
    """Pick the compute device, preferring Apple MPS, then CUDA, then CPU.

    Prints which backend was chosen and returns the corresponding ``torch.device``.
    """
    if torch.backends.mps.is_available():
        print("Using MPS")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("Using GPU")
        return torch.device("cuda:0")
    print("No GPU/MPS available, falling back to CPU.")
    return torch.device("cpu")


device = _select_device()

from diffusion import init_pipe
from utils import cosine_similarity, double_plot, get_colors
from dino import init_dino
from sam2_setup import init_sam2
from utils import compute_features, load_mesh
import numpy as np

In [2]:
# Rendering / feature-extraction settings; passed to compute_features and
# run_batch_evaluation in later cells.
num_views = 50          # presumably the number of rendered views per mesh
H = W = 512             # rendered image height and width (presumably pixels)
tolerance = 0.004       # NOTE(review): meaning is defined inside compute_features; confirm before tuning

# Pipeline toggles. None of these is passed to compute_features or
# run_batch_evaluation below. NOTE(review): confirm where each one is consumed.
use_normal_map = True
num_images_per_prompt = 1
bq = True
use_sam = False
use_only_diffusion = False
use_diffusion = True
is_tosca = False

In [None]:
# Load the feature-extraction models. All three are passed to compute_features later.
#
# SAM2 is optional and controlled by the `use_sam` flag in the config cell.
# Before enabling it, download the checkpoint (uncomment to run from the notebook):
# !mkdir -p data/checkpoints
# !wget -P data/checkpoints https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
sam_model = init_sam2(device) if use_sam else None
pipe = init_pipe(device)
dino_model = init_dino(device)

In [None]:
# Data: download the SHREC20b folders into ./data from
#   https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link
# e.g.  pip install gdown && gdown --folder https://drive.google.com/drive/folders/1C6lFfCbwQqxlvUE8niVfbeIzeblCnXcx?usp=share_link
#
# For each shape, `*_object_name` is the base name of the .obj file, and `*_prompt` is
# the text prompt for the diffusion model. The prompt is the species name without the
# _a/_b suffix, e.g. "camel" rather than "camel_a".

# Source shape.
# Textured variant: data/SHREC20b_lores_tex/<name>_tex/<name>_tex.obj
source_object_name = "cow"
source_prompt = "cow"
source_file_path = f"data/SHREC20b_lores/models/{source_object_name}.obj"

# Target shape (same layout and textured variant as above).
target_object_name = "camel_a"
target_prompt = "camel"
target_file_path = f"data/SHREC20b_lores/models/{target_object_name}.obj"

# Load both meshes onto the selected device.
source_mesh = load_mesh(source_file_path, device)
target_mesh = load_mesh(target_file_path, device)


In [None]:
# Per-vertex features for the source and target meshes.
# When save_path is not None, compute_features also writes batched_renderings,
# normal_batched_renderings, camera and depth to that path.

save_path = f"data/rendered_meshes/{source_object_name}_rendered.pt"
f_source = compute_features(
    device, sam_model, dino_model, pipe,
    source_mesh, source_prompt,
    num_views, H, W, tolerance,
    save_path,
)

save_path = f"data/rendered_meshes/{target_object_name}_rendered.pt"
f_target = compute_features(
    device, sam_model, dino_model, pipe,
    target_mesh, target_prompt,
    num_views, H, W, tolerance,
    save_path,
)

In [6]:
import os

# Pre-compute features for every .obj in the SHREC20b low-res models folder.
# The returned features are discarded; each call is given a save_path so that the
# renderings are written to data/rendered_meshes/<name>_rendered.pt.
models_folder_path = "data/SHREC20b_lores/models"

# os.listdir returns entries in arbitrary order, so sort them to make runs and logs reproducible.
for object_file in sorted(os.listdir(models_folder_path)):
    if not object_file.endswith(".obj"):
        continue

    # splitext strips only the final extension, so names containing extra dots stay intact.
    object_name = os.path.splitext(object_file)[0]
    object_file_path = os.path.join(models_folder_path, object_file)
    # Prompt = species name without the _a/_b variant suffix (e.g. camel_a -> camel).
    object_prompt = object_name.split("_")[0]
    print(f"Processing {object_name} with prompt {object_prompt}")

    object_mesh = load_mesh(object_file_path, device)

    save_path = f"data/rendered_meshes/{object_name}_rendered.pt"
    compute_features(device, sam_model, dino_model, pipe, object_mesh, object_prompt, num_views, H, W, tolerance, save_path)



Processing hippo with prompt hippo
Processing giraffe_b with prompt giraffe
Processing cow with prompt cow
Processing giraffe_a with prompt giraffe
Processing bison with prompt bison
Processing rhino with prompt rhino
Processing elephant_a with prompt elephant
Processing dog with prompt dog
Processing elephant_b with prompt elephant
Processing leopard with prompt leopard
Processing pig with prompt pig
Processing camel_b with prompt camel
Processing bear with prompt bear
Processing camel_a with prompt camel


In [None]:
# Transfer a per-vertex colour map from the source mesh to the target mesh
# through nearest-feature matches.
# NOTE(review): assumes cosine_similarity(a, b) returns a (len(a), len(b)) matrix.
# Under that assumption, argmax over dim 0 gives each target vertex the index of
# its best-matching source vertex. Confirm against utils.cosine_similarity.
s = cosine_similarity(f_source.to(device), f_target.to(device))
s = torch.argmax(s, dim=0).cpu().numpy()

print(f"f_source.shape: {f_source.shape}")
print(f"f_target.shape: {f_target.shape}")
print(f"s.shape: {s.shape}")

cmap_source = get_colors(source_mesh.vert)
cmap_target = cmap_source[s]  # colour of each target vertex = colour of its matched source vertex

double_plot(source_mesh, target_mesh, cmap_source, cmap_target)

In [None]:
from eval import evaluate_meshes

# evaluate_meshes reads the predicted mapping from disk, so write it to a file first.
# A single variable is used for both the save and the load so they cannot drift apart.
mapping_path = 'predicted_mapping.npy'

# Note the argument order (f_target, f_source).
# NOTE(review): if cosine_similarity(a, b) is (len(a), len(b)), this yields one target
# index per source vertex. Confirm this is the direction evaluate_meshes expects.
s = cosine_similarity(f_target.to(device), f_source.to(device))
s = torch.argmax(s, dim=0).cpu().numpy()
np.save(mapping_path, s)

print(f"f_source.shape: {f_source.shape}")
print(f"f_target.shape: {f_target.shape}")
print(f"s.shape: {s.shape}")

# Score the predicted mapping against the SHREC20b ground-truth .mat files.
avg_error, accuracy, distances = evaluate_meshes(
    source_file_path=source_file_path,
    target_file_path=target_file_path,
    source_gt_path=f'data/SHREC20b_lores_gts/{source_object_name}.mat',
    target_gt_path=f'data/SHREC20b_lores_gts/{target_object_name}.mat',
    mapping_path=mapping_path,
    debug=False,
)

print(f"Average correspondence error (err): {avg_error}")
# NOTE(review): the "%" suffix assumes evaluate_meshes returns accuracy already in percent.
print(f"Correspondence accuracy (acc, γ=1%): {accuracy}%")

In [None]:
from eval_batch import run_batch_evaluation

# Pair lists available in data/SHREC20b_lores/test-sets/:
#   test-set0.txt  partial-to-full scans
#   test-set1.txt  full-to-full, highest isometry
#   test-set2.txt  full-to-full, high isometry
#   test-set3.txt  full-to-full, low isometry
#   test-set4.txt  full-to-full, lowest isometry
# Only the untextured meshes are supported for now.
results = run_batch_evaluation(
    # which pairs to evaluate, and where the meshes live
    pairs_file='data/SHREC20b_lores/test-sets/test-set1.txt',
    base_path="data/SHREC20b_lores",
    # models loaded earlier in the notebook
    device=device, sam_model=sam_model, dino_model=dino_model, pipe=pipe,
    # rendering settings from the config cell
    num_views=num_views, H=H, W=W, tolerance=tolerance,
)