In [None]:
import torch

# Set to True to actually run on Apple MPS. The notebook previously fell back to CPU
# on MPS machines (while printing "Using MPS"); False keeps that device choice.
ALLOW_MPS = False


def select_device(allow_mps: bool = ALLOW_MPS) -> torch.device:
    """Choose the torch device used by the rest of the notebook and report it.

    Priority: MPS (only when ``allow_mps`` is True) > CUDA > CPU. If MPS is present
    but not allowed, CPU is used, and the printed message says so.
    NOTE(review): MPS is presumably disabled because a downstream library lacks MPS
    support -- confirm before enabling.
    """
    if torch.backends.mps.is_available():
        if allow_mps:
            print("Using MPS")
            return torch.device("mps")
        print("MPS available but disabled (allow_mps=False); using CPU.")
        return torch.device("cpu")
    if torch.cuda.is_available():
        print("Using GPU")
        return torch.device("cuda:0")
    print("No GPU/MPS available, falling back to CPU.")
    return torch.device("cpu")


device = select_device()

In [None]:
""" from mesh_video_generator import MeshVideoGenerator

hw = 512
num_views = 10

generator = MeshVideoGenerator(
    output_dir="outputs",
    hw=hw,
    num_views=num_views,
    use_normal_map=True,
    device=device
) 
generator.process_folder("camel_tex", display_frames=False)
#generator.process_single_mesh("meshes", index=5) """

In [None]:
#import torch
from diff3f import get_features_per_vertex
from diffusion import init_pipe

from utils import convert_mesh_container_to_torch_mesh, cosine_similarity, double_plot, get_colors
from dataloaders.mesh_container import MeshContainer
from dino import init_dino
from sam2_setup import init_sam2
#from functional_map import compute_surface_map

In [11]:
# Feature-extraction configuration. num_views / H / W / tolerance are passed to
# compute_features explicitly; the remaining flags are read by compute_features as
# notebook globals and forwarded to get_features_per_vertex.
num_views = 50
H = 512
W = 512
num_images_per_prompt = 1
tolerance = 0.004
random_seed = 42
use_normal_map = True
is_tosca = False  # only takes effect if forwarded to the mesh conversion (check compute_features)

bq = True
use_sam = False
use_only_diffusion = False
use_diffusion = True

# random_seed used to be defined but never applied. Seed every RNG the pipeline
# may draw from so that Restart & Run All reproduces the same features.
import random

import numpy as np

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # seeds the CPU and all CUDA generators

In [5]:
def compute_features(device, sam_model, dino_model, pipe, m, prompt, num_views, H, W, tolerance,
                     is_tosca=False):
    """Compute per-vertex features for one mesh and return them on the CPU.

    Args:
        device: torch device passed to the mesh conversion and to feature extraction.
        sam_model, dino_model, pipe: backbones from the model-initialisation cell,
            forwarded unchanged to ``get_features_per_vertex``.
        m: either a ``MeshContainer`` (detected by its ``.vert`` attribute), which is
            converted with ``convert_mesh_container_to_torch_mesh``, or an object that
            is already a PyTorch3D ``Meshes`` and is used as-is.
        prompt: text prompt forwarded to ``get_features_per_vertex``.
        num_views, H, W, tolerance: forwarded to ``get_features_per_vertex``.
        is_tosca: forwarded to ``convert_mesh_container_to_torch_mesh``. Defaults to
            ``False`` (the value previously hard-coded here); pass the config cell's
            ``is_tosca`` to honour that setting.

    The other extraction options (``use_normal_map``, ``num_images_per_prompt``, ``bq``,
    ``use_sam``, ``use_only_diffusion``, ``use_diffusion``) are read from the
    notebook-level configuration at call time.

    Returns:
        The tensor returned by ``get_features_per_vertex``, moved to the CPU.
    """
    if hasattr(m, "vert"):
        # MeshContainer -> PyTorch3D mesh.
        mesh = convert_mesh_container_to_torch_mesh(m, device=device, is_tosca=is_tosca)
    else:
        # Assumed to already be a PyTorch3D Meshes object.
        mesh = m

    # Vertices of the first mesh in the batch.
    mesh_vertices = mesh.verts_list()[0]

    features = get_features_per_vertex(
        device=device,
        sam_model=sam_model,
        pipe=pipe,
        dino_model=dino_model,
        mesh=mesh,
        prompt=prompt,
        num_views=num_views,
        H=H,
        W=W,
        tolerance=tolerance,
        # Notebook-level configuration (see docstring).
        use_normal_map=use_normal_map,
        num_images_per_prompt=num_images_per_prompt,
        mesh_vertices=mesh_vertices,
        bq=bq,
        use_sam=use_sam,
        use_only_diffusion=use_only_diffusion,
        use_diffusion=use_diffusion,
    )
    return features.cpu()

In [None]:
# Initialise the three backbones once with `device`; each compute_features call reuses them.
# NOTE(review): SAM2 is loaded even when use_sam = False in the config cell. Skipping it
# could save memory if get_features_per_vertex tolerates sam_model=None (unverified).
sam_model = init_sam2(device)
pipe = init_pipe(device)
dino_model = init_dino(device)

In [7]:
from pytorch3d.io import load_objs_as_meshes

# Candidate (source, target) mesh pairs. The paths used to be reassigned three times in
# a row, so only the last pair was ever used; select one explicitly by key instead.
MESH_PAIRS = {
    "cow_tex": ("cow_tex/cow_tex.obj", "camel_tex/camel_tex.obj"),
    "meshes": ("meshes/cow.obj", "meshes/camel.obj"),
    "shrec20b_lores": ("SHREC20b_lores/models/cow.obj", "SHREC20b_lores/models/camel_a.obj"),
}
MESH_PAIR = "shrec20b_lores"  # the pair the notebook effectively used before
source_file_path, target_file_path = MESH_PAIRS[MESH_PAIR]

# Alternative loader producing PyTorch3D Meshes (compute_features handles both kinds):
# source_mesh = load_objs_as_meshes([source_file_path], device=device)
# target_mesh = load_objs_as_meshes([target_file_path], device=device)

source_mesh = MeshContainer().load_from_file(source_file_path)
target_mesh = MeshContainer().load_from_file(target_file_path)

In [None]:
# Per-vertex features for both shapes; the text prompt names each shape.
f_source = compute_features(
    device, sam_model, dino_model, pipe, source_mesh, "cow",
    num_views=num_views, H=H, W=W, tolerance=tolerance,
)
f_target = compute_features(
    device, sam_model, dino_model, pipe, target_mesh, "camel",
    num_views=num_views, H=H, W=W, tolerance=tolerance,
)

In [None]:
import numpy as np
import meshplot as mp

def get_colorss(vertices):
    """Colour vertices by their bounding-box-normalised positions (used as RGB).

    Accepts an object with ``verts_list()`` (e.g. a PyTorch3D ``Meshes``; the first
    mesh is used), a torch tensor, or an array-like of shape (V, D). Each column is
    rescaled independently to [0, 1]. An axis with zero extent (e.g. a planar mesh)
    maps to 0 instead of producing NaN from 0/0.

    Note: this is a local variant; the correspondence cell calls ``get_colors`` from
    ``utils``, not this function.
    """
    if hasattr(vertices, "verts_list"):
        vertices = vertices.verts_list()[0]
    if torch.is_tensor(vertices):
        # detach() so tensors that require grad can still be converted.
        vertices = vertices.detach().cpu().numpy()
    vertices = np.asarray(vertices)

    min_coord = np.min(vertices, axis=0, keepdims=True)
    span = np.max(vertices, axis=0, keepdims=True) - min_coord
    # Guard degenerate (flat) axes: dividing by 1 leaves their normalised value at 0.
    safe_span = np.where(span > 0, span, 1)
    return (vertices - min_coord) / safe_span

def double_plot(myMesh1, myMesh2, cmap1=None, cmap2=None):
    """Draw two meshes as two meshplot subplots of one figure (grid spec ``s=[2, 2, i]``).

    Each mesh may be a PyTorch3D ``Meshes`` (first mesh of the batch is used) or an
    object exposing ``.vert`` / ``.face`` arrays, such as ``MeshContainer``.
    ``cmap1`` / ``cmap2`` are forwarded as the per-vertex colours (``c=``).

    NOTE(review): this definition shadows ``double_plot`` imported from ``utils``.
    """

    def _as_arrays(mesh):
        # Raw attributes for MeshContainer-like objects; numpy copies for Meshes.
        if not hasattr(mesh, "verts_list"):
            return mesh.vert, mesh.face
        return mesh.verts_list()[0].cpu().numpy(), mesh.faces_list()[0].cpu().numpy()

    verts1, faces1 = _as_arrays(myMesh1)
    verts2, faces2 = _as_arrays(myMesh2)

    viewer = mp.subplot(verts1, faces1, c=cmap1, s=[2, 2, 0])
    mp.subplot(verts2, faces2, c=cmap2, s=[2, 2, 1], data=viewer)


# Nearest-feature correspondence: argmax over dim 0 picks, for each column (target
# vertex), the best-matching row (source vertex) -- consistent with indexing the
# source colours by `s` below. Done in one expression so no extra global keeps the
# full similarity matrix alive.
s = torch.argmax(
    cosine_similarity(f_source.to(device), f_target.to(device)),
    dim=0,
).cpu().numpy()
np.savetxt('s.csv', s, delimiter=',', fmt='%d')

cmap_source = get_colors(source_mesh.vert)
cmap_target = cmap_source[s]  # each target vertex inherits its matched source colour
double_plot(source_mesh, target_mesh, cmap_source, cmap_target)