In [1]:
import numpy as np
import ipywidgets as iw
from brepmatching.data import BRepMatchingDataset, load_data
import torch
import meshplot
from IPython.display import clear_output, display, HTML
from functools import partial, reduce
import operator
import pandas as ps

In [2]:
# Locations handed to load_data: the dataset archive and its .pt cache file.
data_path = '/fast/jamesn8/brepmatching/ExpertDataWithBaseline.zip'
data_cache = '/fast/jamesn8/brepmatching/ExpertDataWithBaseline.pt'

cache_data = load_data(data_path, data_cache)

# Wrap the loaded data in 'train' mode with zero-sized test and validation splits.
# TODO: make sure it's using the correct set/groups
data = BRepMatchingDataset(
    cache_data,
    mode='train',
    test_size=0,
    val_size=0,
)

In [7]:
# Navigation buttons; their click handlers are attached further down in this cell.
button_prev, button_next, button_skip_back, button_skip = (
    iw.Button(description=caption)
    for caption in ('Previous', 'Next', 'Jump to previous part', 'Jump to next part')
)

# Labelling buttons for the currently displayed topology.
button_correct, button_incorrect = (
    iw.Button(description=caption) for caption in ('Correct', 'Incorrect')
)

# Browsing cursor shared by the callbacks below:
#   data_ind  -- index into `data` of the brep being shown
#   topo_ind  -- position within that brep's list of choice indices
#   topo_type -- attribute-name fragment selecting the topology kind
data_ind, topo_ind = 0, 0
topo_type = 'faces'

# Output area that display_current renders into.
out = iw.Output(layout={'border': '1px solid black'})

def get_choice_indices(brep, topo_type):
    """Return the right-side topology indices that have no exact baseline match.

    Args:
        brep: object exposing ``left_<topo_type>``, ``right_<topo_type>`` and
            ``bl_exact_<topo_type>_matches`` attributes.
        topo_type: attribute-name fragment, e.g. ``'faces'``.

    Returns:
        The result of ``nonzero()`` on a 1-D mask, i.e. an ``(N, 1)`` integer
        tensor holding every right-side index that does not appear in row 1 of
        the exact-match tensor.
    """
    #TODO: Also mask out from the choices the topologies for which predictions agree (once you have those)
    # Not used below; the lookup is kept so a brep without the left-side
    # attribute still fails here rather than later.
    left_count = getattr(brep, f'left_{topo_type}').shape[0]
    right_count = getattr(brep, f'right_{topo_type}').shape[0]
    matches = getattr(brep, f'bl_exact_{topo_type}_matches')

    # Start with every right-side topology marked unmatched, then clear the
    # entries listed in row 1 of the exact matches.
    unmatched = torch.ones(right_count, dtype=torch.int64)
    unmatched[matches[1]] = 0
    return unmatched.nonzero()

@out.capture(clear_output=True, wait=True)
def display_current(brep, topo_ind, topo_type):
    """Render one unmatched right-side topology next to the original part.

    Output is captured into the ``out`` widget (cleared on every call). The
    header label additionally reads the module-level ``data_ind`` and ``data``.

    Args:
        brep: dataset item exposing the ``left_*`` / ``right_*`` geometry and
            the ``bl_exact_*`` / ``os_bl_*`` match tensors used below.
        topo_ind: position within ``get_choice_indices(brep, topo_type)``.
        topo_type: attribute-name fragment; only ``'faces'`` is rendered, any
            other value shows just the header label.

    If ``topo_ind`` does not address an existing choice (for example the brep
    has no unmatched topology at all), a short notice is shown instead of
    raising ``IndexError``.
    """
    num_right_topos = getattr(brep, 'right_' + topo_type).shape[0]
    num_left_topos = getattr(brep, 'left_' + topo_type).shape[0]
    choice_indices = get_choice_indices(brep, topo_type)
    num_choices = len(choice_indices)

    # Exactly the positions for which choice_indices[topo_ind] would raise;
    # this covers an empty choice list for any topo_ind.
    if not -num_choices <= topo_ind < num_choices:
        display(iw.Label(f'brep index: {data_ind+1} / {len(data)}; nothing to review for {topo_type} (topo index {topo_ind+1} / {num_choices})'))
        return

    # Shape (1,) tensor (a row of the nonzero() result); it broadcasts in the
    # comparisons and index lookups below.
    selected = choice_indices[topo_ind]
    display(iw.Label(f'brep index: {data_ind+1} / {len(data)}; topo index: {topo_ind+1} / {num_choices} (id {selected.item()})'))
    exact_matches = getattr(brep, 'bl_exact_' + topo_type + '_matches')
    baseline_matches = getattr(brep, 'os_bl_'+ topo_type +'_matches')

    # Per-topology 0/1 flags: row 0 of the match tensor indexes the left part,
    # row 1 the right part.
    right_exact_matched_mask = torch.zeros(num_right_topos, dtype=torch.int64)
    left_exact_matched_mask = torch.zeros(num_left_topos, dtype=torch.int64)
    right_exact_matched_mask[exact_matches[1]] = 1
    left_exact_matched_mask[exact_matches[0]] = 1

    # Only faces have a rendering path so far.
    if topo_type != 'faces':
        return

    # Map each right-side face to the column of its baseline match (-1 = none).
    topo2match = torch.full([num_right_topos], -1)
    topo2match[baseline_matches[1]] = torch.arange(baseline_matches.shape[1])
    baseline_match_ind = topo2match[selected]
    has_baseline_match = baseline_match_ind >= 0
    has_predicted_match = False #TODO: Get predicted matches

    # One RGB row per mesh element, default dark grey.
    # NOTE(review): assumes left_F / right_F are laid out (3, num_triangles),
    # since they are transposed before plotting -- confirm.
    c_l = torch.full([brep.left_F.shape[1], 3], 0.25)
    c_r = torch.full([brep.right_F.shape[1], 3], 0.25)

    #color all exact matches
    c_l[left_exact_matched_mask[brep.left_F_to_faces[0]] > 0] = torch.tensor([0.5, 0.5, 1.0])
    c_r[right_exact_matched_mask[brep.right_F_to_faces[0]] > 0] = torch.tensor([0.5, 0.5, 1.0])

    # Selected topology in green on the modified part.
    c_r[brep.right_F_to_faces[0] == selected] = torch.tensor([0, 1.0, 0])

    if has_baseline_match:
        # Baseline counterpart in orange on the original part.
        baseline_match = baseline_matches[0, baseline_match_ind]
        c_l[brep.left_F_to_faces[0] == baseline_match] = torch.tensor([1, 0.25, 0])

    c_l = c_l.numpy()
    c_r = c_r.numpy()

    shading = {"flat":True, # Flat or smooth shading of triangles
           #"wireframe":True, "wire_width": 0.03, "wire_color": "black", # Wireframe rendering
           "width": 400, "height": 400, # Size of the viewer canvas
           "antialias": True, # Antialising, might not work on all GPUs
           "scale": 2.0, # Scaling of the model
           "side": "DoubleSide", # FrontSide, BackSide or DoubleSide rendering of the triangles
           "colormap": "viridis", "normalize": [None, None], # Colormap and normalization for colors
           "background": "#ffffff", # Background color of the canvas
           "line_width": 1.0, "line_color": "black", # Line properties of overlay lines
           "bbox": False, # Enable plotting of bounding box
           "point_color": "red", "point_size": 0.01 # Point properties of overlay points
          }

    V_l = brep.left_V.numpy()
    V_r = brep.right_V.numpy()

    # Both parts are scaled by the largest bounding-box extent of the LEFT part.
    maxdim = (brep.left_V.max(0)[0] - brep.left_V.min(0)[0]).max().item()
    display(iw.Label(value='Modified part'))
    display(iw.Label(value='Blue: exact matches; Green: Selected topology'))
    meshplot.plot(V_r/maxdim, brep.right_F.T.numpy(), c=c_r, shading=shading)

    display(iw.Label(value='Original part'))
    display(iw.Label(value='Blue: exact matches'))
    display(iw.Label(value=f'Predicted match: {"EXISTS (green)" if has_predicted_match else "NONE"}'))
    display(iw.Label(value=f'Baseline match: {"EXISTS (orange)" if has_baseline_match else "NONE"}'))
    # The returned viewer is not needed here; return_plot is left as it was.
    meshplot.plot(V_l/maxdim, brep.left_F.T.numpy(), c=c_l, shading=shading, return_plot=True)

def _seek_reviewable_brep(start_ind, step):
    """Walk ``data`` from ``start_ind`` in direction ``step`` (+1 / -1).

    Returns ``(index, brep, num_choices)`` for the first brep that has at least
    one choice index for the current ``topo_type``, or ``None`` when the walk
    leaves the dataset without finding one.
    """
    ind = start_ind
    while 0 <= ind < len(data):
        candidate = data[ind]
        num_choices = len(get_choice_indices(candidate, topo_type))
        if num_choices > 0:
            return ind, candidate, num_choices
        ind += step
    return None


def advance(b, advance=True, skip=False):
    """Button callback: move the browsing cursor and redraw.

    Args:
        b: the widget that fired the callback (unused).
        advance: ``True`` to move forward, ``False`` to move backward.
        skip: ``True`` to jump to a neighbouring brep, ``False`` to step one
            topology at a time (crossing into the neighbouring brep when the
            current one is exhausted).

    Updates the module-level ``data_ind`` / ``topo_ind``. At either end of the
    dataset the cursor and the display are left unchanged, instead of wrapping
    to ``data[-1]`` or raising ``IndexError``. Breps without any choice index
    are skipped, because ``display_current`` has nothing to show for them.
    """
    global data_ind
    global topo_ind
    step = 1 if advance else -1

    if not skip:
        # Try to stay within the current brep first.
        brep = data[data_ind]
        num_choices = len(get_choice_indices(brep, topo_type))
        next_topo_ind = topo_ind + step
        if 0 <= next_topo_ind < num_choices:
            topo_ind = next_topo_ind
            display_current(brep, topo_ind, topo_type)
            return

    target = _seek_reviewable_brep(data_ind + step, step)
    if target is None:
        # Already at the first / last reviewable brep.
        return

    data_ind, brep, num_choices = target
    # Stepping backwards lands on the last topology of the previous brep;
    # every other move starts at its first topology.
    topo_ind = num_choices - 1 if (not skip and not advance) else 0
    display_current(brep, topo_ind, topo_type)

def record_result(b, result):
    """Button callback stub for the Correct / Incorrect buttons.

    Args:
        b: the widget that fired the callback; forwarded to ``advance``.
        result: the verdict to record (currently ignored).

    Nothing is persisted yet; the call only moves on via ``advance(b)``.
    """
    #TODO: Write to data frame and save
    advance(b)

# on_click invokes each handler with the clicked button as its only positional
# argument, so every extra option has to be bound by keyword.
button_next.on_click(partial(advance, advance=True))
button_prev.on_click(partial(advance, advance=False))
button_skip.on_click(partial(advance, advance=True, skip=True))
button_skip_back.on_click(partial(advance, advance=False, skip=True))

# Bind the verdict to `result` by keyword. A positional bind would fill the
# `b` slot instead, and the button itself would arrive as `result`.
button_correct.on_click(partial(record_result, result=True))
button_incorrect.on_click(partial(record_result, result=False))

display(iw.HBox([button_prev, button_next, button_skip_back, button_skip]))
display(iw.HBox([button_correct, button_incorrect]))

# Render the starting item into `out`, then show the output area itself.
display_current(data[data_ind], topo_ind, topo_type)
display(out)

HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

HBox(children=(Button(description='Correct', style=ButtonStyle()), Button(description='Incorrect', style=Butto…

Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_right='1px solid b…

In [12]:
brep.

tensor([True])