In [None]:
import torch

# Pick the compute device: Apple MPS first, then the first CUDA GPU, else CPU.
# The chosen `device` is passed to every model/mesh helper further down.
if torch.backends.mps.is_available():
    # Actually use MPS when it is reported available, so the printed message
    # matches where tensors are placed.
    # NOTE(review): if a downstream op (e.g. mesh rendering) lacks MPS kernels,
    # force "cpu" here AND change the message accordingly — confirm.
    device = torch.device("mps")
    print("Using MPS")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("Using GPU")
else:
    device = torch.device("cpu")
    print("No GPU/MPS available, falling back to CPU.")

In [None]:
# Disabled: multi-view video generation with MeshVideoGenerator.
# Kept as `#` comments rather than a bare string literal, because a trailing
# string expression would be echoed as this cell's output. Uncomment to use.
#
# from mesh_video_generator import MeshVideoGenerator
#
# hw = 32
# num_views = 2
# data_dirs = ["meshes", "SHREC19_MH_dataset", "SHREC20b_hires/models", "MPI-FAUST/training/scans"]
#
# generator = MeshVideoGenerator(
#     output_dir=data_dirs[0],
#     hw=hw,
#     num_views=num_views,
#     use_normal_map=True,
#     device=device,
# )
# generator.process_folder("meshes", display_frames=False)
# generator.process_single_mesh("meshes", index=5)

In [None]:
#import torch
from diff3f import get_features_per_vertex
from diffusion import init_pipe

from utils import convert_mesh_container_to_torch_mesh, cosine_similarity, double_plot, get_colors
from dataloaders.mesh_container import MeshContainer
from dino import init_dino
from sam2_setup import init_sam2
#from functional_map import compute_surface_map

In [4]:
# Feature-extraction settings; most are forwarded to get_features_per_vertex.
num_views = 2                 # number of rendered views per mesh
H = 32                        # view height (presumably pixels — confirm)
W = 32                        # view width (presumably pixels — confirm)
num_images_per_prompt = 1
tolerance = 0.004
random_seed = 42
use_normal_map = True
is_tosca = False              # dataset flag for convert_mesh_container_to_torch_mesh

# Backbone / mode switches.
bq = True
use_sam = True
use_only_diffusion = False
use_diffusion = False

# Apply the seed so stochastic steps are reproducible across Restart & Run All.
# Before this, `random_seed` was declared but never used.
import random

import numpy as np

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # also seeds CUDA generators

In [5]:
def compute_features(device, sam_model, dino_model, pipe, m, prompt, num_views, H, W, tolerance, is_tosca=False):
    """Extract per-vertex features for one mesh and return them on the CPU.

    The mesh container ``m`` is converted to a torch mesh on ``device``. Its
    first vertex list is then passed, together with the models and render
    settings, to ``get_features_per_vertex``.

    The remaining extraction flags are read from the notebook's module-level
    configuration at call time: ``use_normal_map``, ``num_images_per_prompt``,
    ``bq``, ``use_sam``, ``use_only_diffusion`` and ``use_diffusion``.

    Args:
        device: torch device on which the mesh is built and features computed.
        sam_model: model returned by ``init_sam2``.
        dino_model: model returned by ``init_dino``.
        pipe: pipeline returned by ``init_pipe``.
        m: mesh container, as loaded with ``MeshContainer().load_from_file``.
        prompt: text prompt for the object (e.g. ``"cow"``).
        num_views: number of views; forwarded unchanged.
        H: view height; forwarded unchanged.
        W: view width; forwarded unchanged.
        tolerance: forwarded unchanged.
        is_tosca: forwarded to ``convert_mesh_container_to_torch_mesh``.
            Defaults to ``False``, the value that used to be hard-coded. Pass
            the notebook's ``is_tosca`` setting to honour it.

    Returns:
        The feature tensor returned by ``get_features_per_vertex``, moved to CPU.
    """
    mesh = convert_mesh_container_to_torch_mesh(m, device=device, is_tosca=is_tosca)
    mesh_vertices = mesh.verts_list()[0]  # vertices of the first (only) mesh

    features = get_features_per_vertex(
        device=device,
        sam_model=sam_model,
        pipe=pipe,
        dino_model=dino_model,
        mesh=mesh,
        prompt=prompt,
        num_views=num_views,
        H=H,
        W=W,
        tolerance=tolerance,
        use_normal_map=use_normal_map,
        num_images_per_prompt=num_images_per_prompt,
        mesh_vertices=mesh_vertices,
        bq=bq,
        use_sam=use_sam,
        use_only_diffusion=use_only_diffusion,
        use_diffusion=use_diffusion,
    )
    return features.cpu()

In [6]:
# Build the three backbones on the selected device, in this order:
# SAM2, then the diffusion pipeline, then DINO.
# NOTE(review): the diffusion pipeline is loaded even though use_diffusion and
# use_only_diffusion are both False in the config; consider skipping it if
# get_features_per_vertex accepts pipe=None in that mode — confirm.
sam_model, pipe, dino_model = (
    init_sam2(device),
    init_pipe(device),
    init_dino(device),
)

In [7]:
# Source/target pair whose vertices will be put into correspondence.
source_file_path, target_file_path = "meshes/cow.obj", "meshes/camel.obj"
source_mesh, target_mesh = (
    MeshContainer().load_from_file(path)
    for path in (source_file_path, target_file_path)
)

In [None]:
# Per-vertex features for each mesh; each prompt names the object shown.
f_source, f_target = (
    compute_features(device, sam_model, dino_model, pipe, mesh, label, num_views, H, W, tolerance)
    for mesh, label in ((source_mesh, "cow"), (target_mesh, "camel"))
)

In [None]:
# Match vertices by feature similarity, reducing over dim 0. `s` is used below
# to index the source colour map, so each entry is a source-vertex index
# (assumes cosine_similarity returns an (n_source, n_target) matrix — confirm).
s = (
    cosine_similarity(f_source.to(device), f_target.to(device))
    .argmax(dim=0)
    .cpu()
    .numpy()
)

# Colour the source mesh, then transfer those colours to the matched target vertices.
cmap_source = get_colors(source_mesh.vert)
cmap_target = cmap_source[s]
double_plot(source_mesh, target_mesh, cmap_source, cmap_target)