In [None]:
from brepmatching.data import BRepMatchingDataset, load_data
from brepmatching.visualization import show_image, render_predictions
import torch
import os
from torch_geometric.data.batch import Batch
import os
from tqdm import tqdm
from brepmatching.utils import *
import numpy as np

In [None]:
# Root directory of the preprocessed BRep matching caches. Set the
# BREP_DATA_DIR environment variable to point elsewhere instead of editing
# this absolute path.
DATA_DIR = os.environ.get('BREP_DATA_DIR', '/projects/grail/milinknb/brep-data')
# Version suffix shared by the three cache variants.
DATA_VERSION = 'V4'

topo_path = os.path.join(DATA_DIR, f'Topo{DATA_VERSION}.pt')
geo_path = os.path.join(DATA_DIR, f'Geo{DATA_VERSION}.pt')
both_path = os.path.join(DATA_DIR, f'Both{DATA_VERSION}.pt')

In [None]:
def _load_cache(path):
    """Load one preprocessed dataset cache written with ``torch.save``.

    ``torch.load`` unpickles the file, so only use this on trusted caches.
    Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``, which
    refuses pickled objects outside a small allowlist. These caches feed
    ``BRepMatchingDataset`` and presumably hold such objects (TODO confirm), so
    full unpickling is requested explicitly whenever the keyword exists.
    """
    import inspect  # only needed to feature-detect torch.load's signature

    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    # Older PyTorch has no weights_only keyword and always fully unpickles.
    return torch.load(path)


geo_cache = _load_cache(geo_path)
topo_cache = _load_cache(topo_path)
both_cache = _load_cache(both_path)

In [None]:
# Wrap the 'test' split of each cache (geo, topo, both — in that order).
ds_geo, ds_topo, ds_both = (
    BRepMatchingDataset(cache, 'test')
    for cache in (geo_cache, topo_cache, both_cache)
)

In [None]:
# Materialise every example of each test dataset into a plain list, indexing
# 0..len-1 so only __len__ and __getitem__ are required of the dataset.
geo_test_graphs, topo_test_graphs, both_test_graphs = (
    [dataset[idx] for idx in range(len(dataset))]
    for dataset in (ds_geo, ds_topo, ds_both)
)

In [None]:
def check_extra_overlaps(test_graphs, topo_type='faces'):
    """Count, per example, overlap-baseline matches that Onshape did not produce.

    For each graph, the columns of ``bl_overlap_<topo_type>_matches`` and of
    ``os_bl_<topo_type>_matches`` are turned into sets of index tuples. The
    function records how many overlap tuples are absent from the Onshape set.

    Args:
        test_graphs: indexable sequence of per-example graphs carrying both
            match tensors. Each tensor is transposed so that each column becomes
            one tuple — presumably shape (2, num_matches); TODO confirm.
        topo_type: which attribute family to compare. The default 'faces'
            reproduces the original behaviour; 'edges' also works, since
            ``bl_overlap_edges_matches`` / ``os_bl_edges_matches`` are read
            elsewhere in this notebook.

    Returns:
        np.ndarray with one count per graph, in input order.
    """
    overlap_key = f'bl_overlap_{topo_type}_matches'
    onshape_key = f'os_bl_{topo_type}_matches'

    extra_overlaps = []
    for i in tqdm(range(len(test_graphs))):
        ex = test_graphs[i]

        overlap_matches = {tuple(pair) for pair in getattr(ex, overlap_key).T.tolist()}
        onshape_matches = {tuple(pair) for pair in getattr(ex, onshape_key).T.tolist()}

        extra_overlaps.append(len(overlap_matches - onshape_matches))
    return np.array(extra_overlaps)

In [None]:
# One array of per-example extra-overlap counts for each test set.
geo_extra_overlaps, topo_extra_overlaps, both_extra_overlaps = map(
    check_extra_overlaps,
    (geo_test_graphs, topo_test_graphs, both_test_graphs),
)

In [None]:
# Per test set: histogram of how many overlap-baseline face matches each
# example has that are missing from its Onshape matches (check_extra_overlaps).
# NOTE(review): `plt` is never imported explicitly in this notebook; it
# presumably arrives via `from brepmatching.utils import *` — consider an
# explicit `import matplotlib.pyplot as plt` in the import cell.
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=True)
overlap_sets = [
    ('geo', geo_extra_overlaps),
    ('topo', topo_extra_overlaps),
    ('both', both_extra_overlaps),
]
for ax, (label, overlaps) in zip(axes, overlap_sets):
    ax.hist(overlaps, bins=20)
    ax.set_title(label)
    ax.set_xlabel('extra overlap matches per example')
axes[0].set_ylabel('number of test examples')
fig.suptitle('Overlap face matches missing from Onshape matches');

In [None]:
# Keys passed to PyG's `follow_batch`, so the collated batch also tracks
# which example each of these entries came from.
follow_batch = [
    'left_vertices', 'right_vertices',
    'left_edges', 'right_edges',
    'left_faces', 'right_faces',
    'faces_matches', 'edges_matches', 'vertices_matches',
]

# Collate each whole test set into a single batch (geo, topo, both).
geo_batch, topo_batch, both_batch = (
    Batch.from_data_list(graphs, follow_batch=follow_batch)
    for graphs in (geo_test_graphs, topo_test_graphs, both_test_graphs)
)

In [None]:
def compute_metrics_from_batch(batch, small_overlap_thresh=.8, large_overlap_thresh = .0):
    """Score three baseline match sets on a collated test batch.

    Baselines, each evaluated for faces, edges and vertices:
      * 'exact'       -- the stored ``bl_exact_<type>_matches``.
      * 'coincidence' -- exact matches plus the overlap matches whose
        smaller-side percentage is >= ``small_overlap_thresh`` and whose
        larger-side percentage is >= ``large_overlap_thresh``. Overlaps are
        only used for faces and edges; vertices use exact matches alone.
      * 'onshape'     -- the stored ``os_bl_<type>_matches``.

    Args:
        batch: collated batch readable both as ``batch[key]`` and ``batch.key``.
        small_overlap_thresh: minimum ``bl_overlap_smaller_*_percentages``.
        large_overlap_thresh: minimum ``bl_overlap_larger_*_percentages``.

    Returns:
        Nested dict ``metrics[method][topo_type]`` holding whatever
        ``compute_metrics_from_matches`` returns for that match set.
    """
    topo_types = ('faces', 'edges', 'vertices')

    def _kept_overlaps(plural, singular):
        # Columns of the overlap-match tensor whose two percentages both clear
        # their thresholds.
        keep = (
            (getattr(batch, f'bl_overlap_smaller_{singular}_percentages') >= small_overlap_thresh)
            & (getattr(batch, f'bl_overlap_larger_{singular}_percentages') >= large_overlap_thresh)
        )
        return getattr(batch, f'bl_overlap_{plural}_matches')[:, keep]

    exact = {t: batch[f'bl_exact_{t}_matches'] for t in topo_types}

    # NOTE(review): a pair present in both the exact and the kept overlap set
    # appears twice after concatenation; confirm compute_metrics_from_matches
    # tolerates duplicate columns.
    coincidence = {
        'faces': torch.cat([exact['faces'], _kept_overlaps('faces', 'face')], dim=1),
        'edges': torch.cat([exact['edges'], _kept_overlaps('edges', 'edge')], dim=1),
        'vertices': exact['vertices'],
    }

    onshape = {t: batch[f'os_bl_{t}_matches'] for t in topo_types}

    # Scored in the order exact -> coincidence -> onshape, faces -> edges -> vertices.
    metrics = {}
    for method, match_sets in (('exact', exact), ('coincidence', coincidence), ('onshape', onshape)):
        metrics[method] = {
            t: compute_metrics_from_matches(batch, t, match_sets[t]) for t in topo_types
        }
    return metrics

In [None]:
# Baseline metrics for every test set, keyed by dataset name (geo, topo, both).
all_metrics = {
    dataset_name: compute_metrics_from_batch(dataset_batch)
    for dataset_name, dataset_batch in (
        ('geo', geo_batch),
        ('topo', topo_batch),
        ('both', both_batch),
    )
}

In [None]:
# Persist the nested baseline metrics (dataset -> method -> topology type).
BASELINE_METRICS_PATH = 'baselinesV4.pt'
torch.save(all_metrics, BASELINE_METRICS_PATH)

In [None]:

def generate_plots(geo_test_graphs, topo_test_graphs, both_test_graphs, outdir, match_prefix='', title_prefix=''):
    """Save one exact-matching baseline plot per (test set, topology type).

    Each of the Geo/Topo/Both test sets is collated into a single PyG Batch.
    For faces, edges and vertices, the stored ``bl_exact_<topo>_matches`` are
    scored with ``compute_metrics_from_matches``, drawn with ``plot_the_fives``
    and saved to ``<outdir>/<Name>_<topo>.png``.

    Args:
        geo_test_graphs, topo_test_graphs, both_test_graphs: per-example graphs
            of each test set.
        outdir: directory the PNGs are written into. ``make_containing_dir`` is
            called on each output path first (presumably creating ``outdir``).
        match_prefix: currently unused. NOTE(review): it is accepted but never
            read; decide what it should select (match attribute? filename?)
            or drop it.
        title_prefix: text prepended to every plot title. The default '' keeps
            the original titles.
    """
    # Same follow_batch keys as the notebook's batch cell; built once instead
    # of on every loop iteration.
    follow_batch = [
        'left_vertices', 'right_vertices',
        'left_edges', 'right_edges',
        'left_faces', 'right_faces',
        'faces_matches', 'edges_matches', 'vertices_matches',
    ]
    test_sets = (('Geo', geo_test_graphs), ('Topo', topo_test_graphs), ('Both', both_test_graphs))

    for name, test_set in tqdm(test_sets, 'Test Sets'):
        big_batch = Batch.from_data_list(test_set, follow_batch=follow_batch)
        for topo_type in tqdm(['faces', 'edges', 'vertices'], 'Match Sets', leave=False):
            metrics = compute_metrics_from_matches(big_batch, topo_type, big_batch[f'bl_exact_{topo_type}_matches'])
            # The single metrics row is stacked twice (paired with x-values 0.0
            # and 1.0) and its last two entries are dropped before plotting;
            # presumably this draws each metric as a flat line — TODO confirm
            # against plot_the_fives / compute_metrics_from_matches.
            plot = plot_the_fives(
                *np.stack([metrics] * 2)[:, :-2].T,
                np.array([0.0, 1.0]),
                f'{title_prefix}Exact Matching {name} ({topo_type})',
            )
            outpath = os.path.join(outdir, f'{name}_{topo_type}.png')
            make_containing_dir(outpath)
            # NOTE(review): the figure is never closed; if plot_the_fives
            # returns a matplotlib Figure, plt.close(plot) here would avoid
            # accumulating open figures.
            plot.savefig(outpath)

In [None]:
def plot_grid(metrics, method='exact', datasets=('geo', 'topo', 'both'), topo_types=('faces', 'edges', 'vertices')):
    """Draw a (topology type x dataset) grid of baseline metric plots.

    Args:
        metrics: nested dict ``metrics[dataset][method][topo_type]``, e.g. the
            ``all_metrics`` dict built from compute_metrics_from_batch.
        method: baseline to plot ('exact', 'coincidence' or 'onshape').
        datasets: dataset keys, one per column (defaults to the original three).
        topo_types: topology types, one per row (defaults to the original three).

    Returns:
        The matplotlib Figure holding the grid.
    """
    # 8x8 inches per panel; squeeze=False keeps `axes` 2-D even for a 1-wide grid.
    fig, axes = plt.subplots(
        len(topo_types), len(datasets),
        figsize=(8 * len(datasets), 8 * len(topo_types)),
        squeeze=False,
    )
    for col, ds in enumerate(datasets):
        for row, t in enumerate(topo_types):
            m = metrics[ds][method][t]
            # Same stacking as generate_plots: the one metrics row is duplicated
            # against x-values 0.0 and 1.0, minus its last two entries.
            plot_the_fives(*np.stack([m] * 2)[:, :-2].T, np.array([0.0, 1.0]), f'{method} -- {ds} ({t})', ax=axes[row, col])
    return fig


coincidence_fig = plot_grid(all_metrics, 'coincidence')
exact_fig = plot_grid(all_metrics, 'exact')
onshape_fig = plot_grid(all_metrics, 'onshape')

# Set to True to write the three grids to disk.
# NOTE(review): the '_99_99' suffix suggests 0.99/0.99 overlap thresholds, but
# all_metrics was computed with compute_metrics_from_batch's default thresholds;
# rename the files (or recompute) before saving.
SAVE_GRID_FIGURES = False
if SAVE_GRID_FIGURES:
    coincidence_fig.savefig('coincidence_99_99.png')
    exact_fig.savefig('exact_99_99.png')
    onshape_fig.savefig('onshape_99_99.png')
