In [11]:
import operator
import os
from functools import partial, reduce

import ipywidgets as iw
import meshplot
import numpy as np
import numpy.linalg as LA
import pandas as ps
import torch
from IPython.display import clear_output, display, HTML

from brepmatching.data import BRepMatchingDataset, load_data

In [160]:
#utils
def normalize(x):
    """Return `x` divided by its norm (LA.norm); a zero input yields NaNs."""
    length = LA.norm(x)
    return x / length

def joinmeshes(meshes):
    """
    Concatenate a list of (V, F) meshes into one (V, F) mesh.

    Each mesh's face indices are shifted by the number of vertices that precede it in the
    stacked vertex array. A one-element list is returned as its single (V, F) tuple, uncopied.
    """
    if len(meshes) <= 1:
        return meshes[0]
    shifted_faces = []
    base = 0
    for verts, faces in meshes:
        shifted_faces.append(faces + base)
        base += verts.shape[0]
    all_verts = np.vstack([verts for verts, _ in meshes])
    return all_verts, np.vstack(shifted_faces)

def getUp(disp):
    """
    Return the float64 unit axis vector along the smallest-magnitude component of `disp`.

    Ties go to the lowest axis index. That axis is the one least aligned with `disp`, so
    crossing `disp` with it is well conditioned for any nonzero `disp`.
    """
    axis = np.argmin(np.abs(disp))
    return (np.arange(3) == axis).astype(float)

def lookAt(target):
    """
    Build the 3x3 linear map that sends the local +z axis onto the vector `target`.

    The returned matrix M has columns (s, u, target):
      - s and u are unit vectors, orthogonal to each other and to `target`;
      - the third column is `target` itself, so M @ [0, 0, 1] == target.
    Local x/y directions are therefore rotated but not scaled, and only z is stretched to
    |target|. No translation is applied; callers add the start point themselves, e.g.
    (M @ P.T).T + origin.

    Note: (s, u, target) is a left-handed frame (det(M) == -|target|), so M mirrors
    orientation and reverses triangle winding of the mapped geometry.
    A zero-length `target` divides by zero and yields NaN entries.
    """
    up = getUp(target)  # axis least aligned with target, so the cross product is nonzero
    scale = LA.norm(target)
    f = target / scale
    s = np.cross(f, up)
    s = s / LA.norm(s)
    u = np.cross(s, f)  # already unit length: s and f are orthonormal
    return np.column_stack([s, u, f * scale])
    
    
def cylinder(p1, p2, radius, N=5):
    """
    Build an open (uncapped) N-sided cylinder mesh running from p1 to p2.

    Returns (V, F):
      V: (2N, 3) vertices. The first N form a ring of the given radius around p1; the
         last N are the same ring translated by p2 - p1.
      F: (2N, 3) int64 triangles, two per side quad.
    """
    axis = p2 - p1
    # Two unit vectors spanning the plane orthogonal to the axis.
    e1 = normalize(np.cross(axis, getUp(axis)))
    e2 = normalize(np.cross(axis, e1))
    angles = np.linspace(0, 2*np.pi, N, endpoint=False)
    bottom = p1 + radius * (np.outer(np.cos(angles), e1) + np.outer(np.sin(angles), e2))
    top = bottom + axis

    cur = np.arange(N)
    prev = np.roll(cur, 1)
    # The side quad between ring slots `prev` and `cur` is split into two triangles.
    lower_tris = np.column_stack([cur, prev + N, prev])
    upper_tris = np.column_stack([cur, cur + N, prev + N])
    faces = np.vstack([lower_tris, upper_tris]).astype(np.int64)
    return np.vstack([bottom, top]), faces

def convert_brep_to_numpy(brep):
    """
    Return a clone of `brep` in which every torch.Tensor stored in its `_store` is replaced
    by a numpy array. Non-tensor entries are left alone, and the input is not modified.

    Tensors are detached and moved to CPU first. Plain `.numpy()` raises for tensors
    that require grad or live on a GPU.
    """
    brep = brep.clone()
    store = vars(brep)['_store']
    # Snapshot the keys: setattr below writes back into the same store we iterate over.
    for key in list(store):
        value = store[key]
        if isinstance(value, torch.Tensor):
            setattr(brep, key, value.detach().cpu().numpy())
    return brep

def add_edges(V, F, E_to_edges, rad, N, edges_mask=None):
    """
    Build one thin cylinder per mesh edge so topological edges can be drawn as geometry.

    V: Nx3 vertex positions.
    F: Mx3 triangle vertex indices.
    E_to_edges: integer array whose rows 0, 1, 2 are read as (triangle index, corner index
        within that triangle, topological edge id). The mesh edge runs from corner c to
        corner (c - 1) % 3 of the triangle. This layout is inferred from the indexing
        below; TODO confirm against the dataset.
    rad: cylinder radius.
    N: number of sides per cylinder.
    edges_mask: optional boolean array indexed by topological edge id. When given, only
        mesh edges whose topological edge is True are kept.

    Returns (meshes, mesh_to_E): meshes[i] is a (V_i, F_i) cylinder and mesh_to_E[i] is
    the topological edge id it belongs to. Zero-length mesh edges are skipped, because
    lookAt divides by the edge length and would produce NaN vertices.
    """
    if edges_mask is not None:
        E_to_edges = E_to_edges[:, edges_mask[E_to_edges[2]]]
    tri, corner, topo_ids = E_to_edges[0], E_to_edges[1], E_to_edges[2]
    starts = F[tri, corner]
    ends = F[tri, (corner - 1) % 3]

    # Drop degenerate edges up front so `meshes` and `topo_ids` stay aligned.
    nondegenerate = LA.norm(V[ends] - V[starts], axis=1) > 0
    starts, ends, topo_ids = starts[nondegenerate], ends[nondegenerate], topo_ids[nondegenerate]

    # A unit-height cylinder along +z; lookAt rotates/stretches it onto each edge.
    cyl_V, cyl_F = cylinder(np.zeros(3), np.array([0, 0, 1]), rad, N)
    meshes = []
    for a, b in zip(starts, ends):
        edge_V = (lookAt(V[b] - V[a]) @ cyl_V.T).T + V[a]
        meshes.append((edge_V, cyl_F))  # all cylinders share the same face array
    return meshes, topo_ids

def F_to_topo(meshes, mesh_to_topo):
    """
    Per-face topology ids for the concatenation of `meshes`, in order.

    meshes: list of (V, F) tuples; only F.shape[0] of each is used.
    mesh_to_topo: one topology id per mesh.

    Returns an int64 array in which every face of mesh i is labelled mesh_to_topo[i].
    Raises ValueError if `meshes` and `mesh_to_topo` differ in length. Otherwise part of
    the result would be unfilled or some ids ignored.
    """
    mesh_to_topo = np.asarray(mesh_to_topo, dtype=np.int64)
    if len(meshes) != len(mesh_to_topo):
        raise ValueError(f'got {len(meshes)} meshes but {len(mesh_to_topo)} topology ids')
    face_counts = np.array([F_e.shape[0] for _, F_e in meshes], dtype=np.int64)
    return np.repeat(mesh_to_topo, face_counts)
        

In [None]:
# Dataset location. Override with the BREPMATCHING_DATA_DIR environment variable.
DATA_DIR = os.environ.get('BREPMATCHING_DATA_DIR', '/fast/jamesn8/brepmatching')
data_path = os.path.join(DATA_DIR, 'ExpertDataWithBaseline.zip')
data_cache = os.path.join(DATA_DIR, 'ExpertDataWithBaseline.pt')

# Kept under its own name: the viewer cell below rebinds `cache_data` to its edge-mesh cache,
# so sharing that name made re-running this cell clobber the viewer's state.
dataset_cache = load_data(data_path, data_cache)
data = BRepMatchingDataset(dataset_cache, mode='train', test_size=0, val_size=0) # TODO: make sure it's using the correct set/groups

In [169]:
# Navigation / labelling buttons (their callbacks are attached at the end of this cell).
(button_next, button_skip, button_skip_back,
 button_prev, button_correct, button_incorrect) = [
    iw.Button(description=label)
    for label in ('Next', 'Jump to next part', 'Jump to previous part',
                  'Previous', 'Correct', 'Incorrect')
]

# Viewer state shared with the callbacks: the current part, the position among its
# selectable topologies, and which topology kind is reviewed ('faces' or 'edges').
data_ind, topo_ind = 0, 0
topo_type = 'edges'

# Output area that display_current renders into.
out = iw.Output(layout={'border': '1px solid black'})

# Lazily filled per-part cache of edge cylinder meshes, keyed by part index.
cache_data = {'left_edge_meshes': {}, 'right_edge_meshes': {}}

def get_choice_indices(brep, topo_type):
    """
    Indices of right-side (modified part) topologies of kind `topo_type` that do not
    appear in the baseline's exact matches, in ascending order.

    brep: object exposing `right_<topo_type>` (first dimension = number of topologies) and
        `bl_exact_<topo_type>_matches`, whose row 1 holds matched right-side indices.
    topo_type: 'faces' or 'edges'.
    """
    num_right_topos = getattr(brep, f'right_{topo_type}').shape[0]
    exact_matches = getattr(brep, f'bl_exact_{topo_type}_matches')
    unmatched = np.ones(num_right_topos, dtype=bool)
    unmatched[exact_matches[1]] = False
    return np.flatnonzero(unmatched)

def _face_colors(num_faces, face_to_topo, match_state, highlight, highlight_color,
                 exact_color, overlap_color):
    """
    Per-face RGB colors for one viewer mesh.

    num_faces: total number of triangles in the mesh being drawn.
    face_to_topo: topology id of each of the FIRST len(face_to_topo) triangles. Any
        remaining triangles (e.g. the surface under an edge overlay) stay gray.
    match_state: per-topology state: 0 = unmatched, 1 = exact match, 2 = overlap match.
    highlight: topology id painted with `highlight_color`, or None for no highlight.
    """
    colors = np.full([num_faces, 3], 0.25)
    tagged = colors[:len(face_to_topo)]  # basic slice -> view, so writes land in `colors`
    state = match_state[face_to_topo]
    tagged[state == 1] = exact_color
    tagged[state == 2] = overlap_color
    if highlight is not None:
        tagged[face_to_topo == highlight] = highlight_color
    return colors


def _with_edge_overlay(V, F, edge_meshes, mesh_to_edge, edge_mask):
    """
    Prepend the cached edge cylinders selected by `edge_mask` to the (V, F) surface.

    edge_meshes, mesh_to_edge: output of add_edges (one cylinder per mesh edge, plus the
        topological edge id of each).
    edge_mask: boolean array indexed by topological edge id; True = draw that edge.

    Returns (V_joined, F_joined, F_to_edge). The combined mesh lists the cylinder faces
    first, and F_to_edge gives the topological edge id of each of those leading faces.
    """
    keep = edge_mask[mesh_to_edge]
    kept_meshes = [mesh for mesh, kept in zip(edge_meshes, keep) if kept]
    F_to_edge = F_to_topo(kept_meshes, mesh_to_edge[keep])
    V_joined, F_joined = joinmeshes(kept_meshes + [(V, F)])
    return V_joined, F_joined, F_to_edge


@out.capture(clear_output=True, wait=True)
def display_current(brep, topo_ind, topo_type, cache, edge_radius=0.001, edge_resolution=3):
    """
    Draw the modified (right) part and the original (left) part for one right-side
    topology that has no exact baseline match. Its predicted match on the left is
    highlighted too.

    brep: a part converted with convert_brep_to_numpy.
    topo_ind: position in get_choice_indices(brep, topo_type) of the topology to show.
    topo_type: 'faces' or 'edges'.
    cache: dict with 'left_edge_meshes' / 'right_edge_meshes' sub-dicts. In edges mode
        they are filled lazily with add_edges output, keyed by the global `data_ind`.
    edge_radius, edge_resolution: radius and side count of the edge cylinders. They only
        take effect the first time a part's edges are cached.

    Reads the globals `data_ind` (cache key, header) and `data` (header). Output goes to
    the `out` widget, which is cleared first.
    """
    choice_indices = get_choice_indices(brep, topo_type)
    if len(choice_indices) == 0:
        display(iw.Label(f'brep index: {data_ind+1} / {len(data)}; no {topo_type} without an exact match'))
        return
    selected = choice_indices[topo_ind]

    num_right_topos = getattr(brep, 'right_' + topo_type).shape[0]
    num_left_topos = getattr(brep, 'left_' + topo_type).shape[0]
    exact_matches = getattr(brep, 'bl_exact_' + topo_type + '_matches')
    overlap_matches = getattr(brep, 'bl_overlap_' + topo_type + '_matches')
    baseline_matches = getattr(brep, 'os_bl_'+ topo_type +'_matches') #TODO: This will just be the predicted match

    # Match state per topology: 0 = none, 1 = exact, 2 = overlap (overlap wins if both).
    right_state = np.zeros(num_right_topos, dtype=np.int64)
    left_state = np.zeros(num_left_topos, dtype=np.int64)
    right_state[exact_matches[1]] = 1
    left_state[exact_matches[0]] = 1
    right_state[overlap_matches[1]] = 2
    left_state[overlap_matches[0]] = 2

    # Look up the predicted (baseline) left match of the selected right topology, if any.
    right_to_baseline = np.full([num_right_topos], -1)
    right_to_baseline[baseline_matches[1]] = np.arange(baseline_matches.shape[1])
    baseline_match_ind = right_to_baseline[selected]
    has_baseline_match = baseline_match_ind >= 0
    baseline_match = baseline_matches[0, baseline_match_ind] if has_baseline_match else None

    exact_match_color = np.array([0.5, 0.5, 1.0]) #normal map blue
    overlap_match_color = np.array([0.5, 0.9, 0.9]) #teal
    selected_color = np.array([0, 1.0, 0])
    match_color = selected_color

    V_l, F_l = brep.left_V, brep.left_F.T
    V_r, F_r = brep.right_V, brep.right_F.T

    if topo_type == 'faces':
        left_F_to_topo = brep.left_F_to_faces[0]
        right_F_to_topo = brep.right_F_to_faces[0]
    elif topo_type == 'edges':
        # Cylinders for every edge of the part are built once and cached (slow).
        if data_ind not in cache['right_edge_meshes']:
            cache['right_edge_meshes'][data_ind] = add_edges(V_r, F_r, brep.right_E_to_edges, edge_radius, edge_resolution)
        if data_ind not in cache['left_edge_meshes']:
            cache['left_edge_meshes'][data_ind] = add_edges(V_l, F_l, brep.left_E_to_edges, edge_radius, edge_resolution)

        # Draw exact/overlap-matched edges, plus the selected edge and its predicted match.
        # Those two need not be exact/overlap matched themselves and would otherwise be
        # filtered out, leaving nothing green on screen.
        right_edge_mask = right_state > 0
        right_edge_mask[selected] = True
        left_edge_mask = left_state > 0
        if has_baseline_match:
            left_edge_mask[baseline_match] = True

        right_meshes, right_mesh_to_e = cache['right_edge_meshes'][data_ind]
        V_r, F_r, right_F_to_topo = _with_edge_overlay(V_r, F_r, right_meshes, right_mesh_to_e, right_edge_mask)
        left_meshes, left_mesh_to_e = cache['left_edge_meshes'][data_ind]
        V_l, F_l, left_F_to_topo = _with_edge_overlay(V_l, F_l, left_meshes, left_mesh_to_e, left_edge_mask)
    else:
        raise ValueError(f"topo_type must be 'faces' or 'edges', got {topo_type!r}")

    c_r = _face_colors(F_r.shape[0], right_F_to_topo, right_state, selected, selected_color,
                       exact_match_color, overlap_match_color)
    c_l = _face_colors(F_l.shape[0], left_F_to_topo, left_state, baseline_match, match_color,
                       exact_match_color, overlap_match_color)

    shading = {"flat":True, # Flat or smooth shading of triangles
               "width": 400, "height": 400, # Size of the viewer canvas
               "antialias": True, # Antialising, might not work on all GPUs
               "scale": 2.0, # Scaling of the model
               "side": "DoubleSide", # FrontSide, BackSide or DoubleSide rendering of the triangles
               "colormap": "viridis", "normalize": [None, None], # Colormap and normalization for colors
               "background": "#ffffff", # Background color of the canvas
               "line_width": 1.0, "line_color": "black", # Line properties of overlay lines
               "bbox": False, # Enable plotting of bounding box
               "point_color": "red", "point_size": 0.01 # Point properties of overlay points
              }
    # Both parts are scaled by the left part's largest extent so they stay comparable.
    maxdim = max(V_l.max(0) - V_l.min(0))

    display(iw.Label(f'brep index: {data_ind+1} / {len(data)}; topo index: {topo_ind+1} / {len(choice_indices)} (id {selected})'))

    display(iw.Label(value='Modified part'))
    display(iw.Label(value='Blue: exact matches; Teal: Overlap matches; Green: Selected topology'))
    meshplot.plot(V_r/maxdim, F_r, c=c_r, shading=shading)

    display(iw.Label(value='Original part'))
    display(iw.Label(value='Blue: exact matches; Teal: Overlap matches'))
    display(iw.Label(value=f'Predicted match: {"EXISTS (green)" if has_baseline_match else "NONE"}'))
    meshplot.plot(V_l/maxdim, F_l, c=c_l, shading=shading)

    
def advance(b, advance=True, skip=False):
    """
    Button callback: step to the next/previous selectable topology and redraw.

    b: the clicked button (unused; ipywidgets passes it to on_click handlers).
    advance: True to move forward, False to move backward.
    skip: if True, jump straight to the first topology of the adjacent part.

    Stepping past either end of a part continues into the adjacent part: forward lands on
    its first topology, backward on its last. Part indices wrap around the dataset, and
    parts with no selectable topology are passed over. Each part is tried at most once.
    Updates the globals `data_ind` and `topo_ind`.
    """
    global data_ind, topo_ind
    step = 1 if advance else -1

    if not skip:
        brep = convert_brep_to_numpy(data[data_ind])
        candidate = topo_ind + step
        if 0 <= candidate < len(get_choice_indices(brep, topo_type)):
            topo_ind = candidate
            display_current(brep, topo_ind, topo_type, cache_data)
            return

    # Leave the current part, wrapping so data_ind always stays a valid index.
    num_choices = 0
    for _ in range(len(data)):
        data_ind = (data_ind + step) % len(data)
        brep = convert_brep_to_numpy(data[data_ind])
        num_choices = len(get_choice_indices(brep, topo_type))
        if num_choices > 0:
            break
    # Skips and forward steps land on the first topology; backward steps on the last.
    topo_ind = max(num_choices - 1, 0) if (step < 0 and not skip) else 0
    display_current(brep, topo_ind, topo_type, cache_data)

def record_result(b, result):
    """
    Callback for the Correct / Incorrect buttons.

    Recording is not implemented yet; for now this only moves on to the next topology.
    """
    # TODO: Write to data frame and save
    # TODO: Make sure the dataframe has separate entries for PART, TOPO, TOPO_TYPE and TEST IT
    advance(b)

# Wire up the buttons. ipywidgets calls each handler with the clicked button as its only
# positional argument, so extra arguments are bound by keyword. A positional
# partial(record_result, True) would put True into `b` and the button into `result`.
button_next.on_click(partial(advance, advance=True))
button_prev.on_click(partial(advance, advance=False))
button_skip.on_click(partial(advance, advance=True, skip=True))
button_skip_back.on_click(partial(advance, advance=False, skip=True))
button_correct.on_click(partial(record_result, result=True))
button_incorrect.on_click(partial(record_result, result=False))
display(iw.HBox([button_prev, button_next, button_skip_back, button_skip]))
display(iw.HBox([button_correct, button_incorrect]))
display_current(convert_brep_to_numpy(data[data_ind]), topo_ind, topo_type, cache_data)
display(out)

HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

HBox(children=(Button(description='Correct', style=ButtonStyle()), Button(description='Incorrect', style=Butto…

Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_right='1px solid b…

In [None]:
#Too slow to render: part 3, 16 (sort of)