In [3]:
from brepmatching.data import BRepMatchingDataset
from brepmatching.visualization import show_image, render_predictions

In [5]:
# Open the 'test' split of the three baseline dataset variants. Each dataset
# is given a .zip path and a .pt path; presumably the raw archive and a
# preprocessed cache -- confirm against BRepMatchingDataset.
ds_geo = BRepMatchingDataset('../../../brepmatching/GeoWithBaseline.zip','../../../brepmatching/GeoWithBaseline.pt',mode='test')
ds_topo = BRepMatchingDataset('../../../brepmatching/TopoWithBaseline.zip','../../../brepmatching/TopoWithBaseline.pt',mode='test')
ds_both = BRepMatchingDataset('../../../brepmatching/TopoAndGeoWithBaseline.zip','../../../brepmatching/TopoAndGeoWithBaseline.pt',mode='test')

In [9]:
# Materialise every sample of each test set into a plain list, in the
# order geo -> topo -> both, so they can be saved and batched later.
geo_test_graphs, topo_test_graphs, both_test_graphs = (
    [dataset[idx] for idx in range(len(dataset))]
    for dataset in (ds_geo, ds_topo, ds_both)
)

In [10]:
import torch

In [12]:
# Persist the materialised test sets so later runs can skip dataset loading.
for graphs, out_path in (
    (geo_test_graphs, '../data/geo_test_set.pt'),
    (topo_test_graphs, '../data/topo_test_set.pt'),
    (both_test_graphs, '../data/both_test_set.pt'),
):
    torch.save(graphs, out_path)

In [13]:
from automate import HetData
from torch import is_tensor
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
import torch
from typing import Optional

In [14]:
# Topology kinds handled by the metrics code, one tuple per kind:
# (plural name, singular name, one-letter abbreviation).
# The "loops" entry is intentionally disabled (commented out).
TOPO_KINDS: list[tuple[str, str, str]] = [
  ("faces", "face", "f"),
  # ("loops", "loop", "l"),
  ("edges", "edge", "e"),
  ("vertices", "vertex", "v")
]

def count_batches(data: HetData) -> int:
    """Return the number of graphs collated into ``data``.

    Looks at the final entry of the left and right face batch vectors and
    returns the larger of the two plus one. This presumes the final entry
    holds the highest batch id (i.e. the vectors are sorted) -- TODO confirm.
    """
    # TODO: is there a better way to count batches?
    last_left = data.left_faces_batch[-1].item()
    last_right = data.right_faces_batch[-1].item()
    return int(max(last_left, last_right)) + 1


# Names of the metrics, in the order compute_metrics_impl returns them.
# The first five are fractions of the right-side topology count ("nr");
# precision and recall use their own denominators.
METRIC_COLS = ["true_pos", "true_neg", "missed", "incorrect", "false_pos", # relative to nr
               "precision", "recall"]
# Derived from the list so the two can never drift apart.
NUM_METRICS = len(METRIC_COLS)

def separate_batched_matches(matches: torch.Tensor,
                             left_topo_batches: torch.Tensor,
                             right_topo_batches: torch.Tensor) -> list[torch.Tensor]:
    """
    Split a batched 2xn match tensor into one match tensor per batch instance.

    Args:
        matches: 2xn integer tensor; row 0 indexes the left nodes and row 1
            the right nodes, both using batch-global indices.
        left_topo_batches: 1-D integer tensor giving the batch id of every
            left node. The number of batches is read from its last entry, so
            it is expected to be sorted in ascending order.
        right_topo_batches: same as ``left_topo_batches`` for the right nodes.

    Returns:
        A list with one 2xk tensor per batch, holding the matches whose left
        node belongs to that batch, re-indexed to be local to the instance.
        An empty list is returned when both batch vectors are empty.

    Raises:
        AssertionError: if a re-indexed match falls outside its instance.
    """
    # Highest batch id seen on either side; a side with no nodes at all is
    # skipped instead of raising IndexError on `[-1]`.
    last_ids = [int(batch[-1].item())
                for batch in (left_topo_batches, right_topo_batches)
                if batch.numel() > 0]
    if not last_ids:
        return []
    num_batches = max(last_ids) + 1

    device = matches.device
    # Node counts per batch in a single pass (instead of one scan per batch).
    left_batch_counts = torch.bincount(
        left_topo_batches, minlength=num_batches)[:num_batches].to(device)
    right_batch_counts = torch.bincount(
        right_topo_batches, minlength=num_batches)[:num_batches].to(device)
    # Exclusive prefix sums: global index of the first node of every batch.
    left_batch_offsets = torch.cumsum(left_batch_counts, dim=0) - left_batch_counts
    right_batch_offsets = torch.cumsum(right_batch_counts, dim=0) - right_batch_counts

    # A match is assigned to the batch of its left node.
    match_batches = left_topo_batches[matches[0]]

    match_list = []
    for b in range(num_batches):
        # Boolean-mask indexing returns a copy, so the in-place shifts below
        # never modify the caller's `matches`.
        filtered_matches = matches[:, match_batches == b]
        filtered_matches[0] -= left_batch_offsets[b]
        filtered_matches[1] -= right_batch_offsets[b]
        assert(filtered_matches[0].numel() == 0 or (0 <= filtered_matches[0].min() and filtered_matches[0].max() < left_batch_counts[b]))
        assert(filtered_matches[1].numel() == 0 or (0 <= filtered_matches[1].min() and filtered_matches[1].max() < right_batch_counts[b]))
        match_list.append(filtered_matches)
    return match_list

def compute_metrics_impl(matches: torch.Tensor,
                         gt_matches: torch.Tensor,
                         n_topos_left: int,
                         n_topos_right: int) -> np.ndarray:
    """Score predicted matches against ground truth for a single instance.

    Both match tensors are 2xn with row 0 = left index and row 1 = right
    index. Every right topology is classified by comparing its predicted
    left partner with its ground-truth partner (-1 meaning "unmatched").

    Args:
        matches: predicted matches.
        gt_matches: ground-truth matches.
        n_topos_left: number of left topologies (not used by the computation).
        n_topos_right: number of right topologies; denominator of the first
            five metrics.

    Returns:
        Array ``[true_pos, true_neg, missed, incorrect, false_pos,
        precision, recall]``. With no right topologies the first five are
        ``[0, 1, 0, 0, 0]``; precision / recall default to 1.0 when their
        denominator is zero.
    """
    device = matches.device

    def _right_to_left(pairs: torch.Tensor) -> torch.Tensor:
        # Dense lookup: for each right topology its left partner, or -1.
        lookup = torch.full((n_topos_right, ), -1, device=device)
        lookup[pairs[1]] = pairs[0]
        return lookup

    pred = _right_to_left(matches)
    gt = _right_to_left(gt_matches)

    pred_matched = pred >= 0
    n_pred_matched = int(pred_matched.sum().item())
    n_gt_matched = int((gt >= 0).sum().item())

    agree = (pred == gt)
    n_agree = int(agree.sum().item())
    n_true_pos = int(agree.logical_and(pred_matched).sum().item())
    n_true_neg = n_agree - n_true_pos

    n_disagree = int((pred != gt).sum())
    n_false_pos = int((gt[pred_matched] == -1).sum())
    n_missed = int((gt[pred == -1] >= 0).sum())
    # Whatever disagreement is left: matched on both sides, but to different partners.
    n_wrong_partner = n_disagree - n_false_pos - n_missed

    if n_topos_right > 0:
        rates = [count / n_topos_right
                 for count in (n_true_pos, n_true_neg, n_missed, n_wrong_partner, n_false_pos)]
    else:
        rates = [0.0, 1.0, 0.0, 0.0, 0.0]

    precision = (n_true_pos / n_pred_matched) if n_pred_matched > 0 else 1.0
    recall = (n_true_pos / n_gt_matched) if n_gt_matched > 0 else 1.0

    return np.array(rates + [precision, recall])

def compute_metrics_from_matches(data: HetData, kinds: str, matches: torch.Tensor) -> np.ndarray:
    """Average the matching metrics of ``matches`` over every instance in a batch.

    Args:
        data: batched data providing ``{kinds}_matches`` (ground truth) and
            the ``left_{kinds}_batch`` / ``right_{kinds}_batch`` vectors.
        kinds: plural topology name used to build those keys (e.g. "faces").
        matches: 2xn batched predicted matches to evaluate.

    Returns:
        Array of ``NUM_METRICS`` values: the unweighted mean over instances
        of what ``compute_metrics_impl`` returns.
    """
    gt_matches = data[f"{kinds}_matches"]       # assume non-empty
    batch_left = data[f"left_{kinds}_batch"]
    batch_right = data[f"right_{kinds}_batch"]

    # Per-instance match tensors with instance-local indices.
    pred_per_instance = separate_batched_matches(matches, batch_left, batch_right)
    gt_per_instance = separate_batched_matches(gt_matches, batch_left, batch_right)

    totals = np.zeros(NUM_METRICS)
    for b, (pred_b, gt_b) in enumerate(zip(pred_per_instance, gt_per_instance)):
        n_left = int((batch_left == b).sum().item())
        n_right = int((batch_right == b).sum().item())
        totals += compute_metrics_impl(pred_b, gt_b, n_left, n_right)

    # Every instance counts equally, regardless of its size.
    totals /= len(pred_per_instance)
    return totals

###### PLOTTING ######

def plot_metric(metric, thresholds, name):
    """Plot one metric against the threshold and return the figure.

    Args:
        metric: y values, one per threshold.
        thresholds: x values.
        name: metric name, used for the y label and the title.

    Returns:
        An 8x8 inch matplotlib ``Figure`` with the y axis fixed to [-0.1, 1.1].
    """
    figure = Figure(figsize=(8, 8))
    axes = figure.add_subplot()
    axes.plot(thresholds, metric)

    # Labelling
    axes.set_xlabel('Threshold')
    axes.set_ylabel(name)
    axes.set_title(name + ' vs threshold')

    # Fixed range keeps plots of different metrics comparable.
    axes.set_ylim(-0.1, 1.1)
    axes.grid()
    return figure

def plot_the_fives(true_pos: np.ndarray,
                   true_neg: np.ndarray,
                   missed: np.ndarray,
                   incorrect: np.ndarray,
                   false_pos: np.ndarray,
                   thresholds: np.ndarray,
                   title: str) -> Figure:
    """Stack-plot the five outcome fractions against the threshold.

    The layers are stacked bottom-to-top as false positive, incorrect,
    missed, true negative, true positive.

    Args:
        true_pos, true_neg, missed, incorrect, false_pos: one value per
            threshold for each outcome.
        thresholds: x values.
        title: figure title.

    Returns:
        An 8x8 inch matplotlib ``Figure``.
    """
    # (legend label, fill colour, series) from the bottom layer upwards.
    layers = [
        ("False Positive", "#BA5050", false_pos),
        ("Incorrect", "#D4756C", incorrect),
        ("Missed", "#D6CFB8", missed),
        ("True Negative", "#61B5CF", true_neg),
        ("True Positive", "#468CB8", true_pos),
    ]
    layer_labels = [label for label, _, _ in layers]
    layer_colors = [color for _, color, _ in layers]
    layer_series = [series for _, _, series in layers]

    figure = Figure(figsize=(8, 8))
    axes = figure.add_subplot()
    axes.stackplot(thresholds, *layer_series,
                   labels=layer_labels,
                   colors=layer_colors)
    axes.legend(loc="upper left")
    axes.set_xlabel("Threshold")
    axes.set_title(title)
    axes.set_ylim(-0.1, 1.1)
    axes.grid()
    return figure
    
    
def plot_multiple_metrics(metrics: dict[str, np.ndarray], 
                          thresholds: np.ndarray,
                          title: str):
    """Plot several metrics against the threshold on shared axes.

    Args:
        metrics: mapping from legend label to the values to plot.
        thresholds: x values shared by every curve.
        title: figure title.

    Returns:
        An 8x8 inch matplotlib ``Figure``. The first three curves get fixed
        colours; any further curve uses matplotlib's default colour.
    """
    fixed_colors = ('#0000ff', '#e08c24', '#ff0000')

    figure = Figure(figsize=(8, 8))
    axes = figure.add_subplot()
    for position, (label, series) in enumerate(metrics.items()):
        # None lets matplotlib pick the colour once the fixed ones run out.
        color = fixed_colors[position] if position < len(fixed_colors) else None
        axes.plot(thresholds, series, label=label, color=color)
    axes.legend()
    axes.set_xlabel('Threshold')
    axes.set_title(title)
    axes.set_ylim(-0.1, 1.1)
    axes.grid()
    return figure

def plot_tradeoff(x, y, values, indices, xname, yname, suffix=''):
    """Plot ``y`` against ``x`` and annotate selected points.

    Args:
        x, y: coordinates of the curve.
        values: per-point values; those at ``indices`` are shown, rounded to
            two decimals, next to their point.
        indices: positions of the points to highlight and annotate.
        xname, yname: axis labels, also used in the title.
        suffix: text appended to the title.

    Returns:
        An 8x8 inch matplotlib ``Figure`` with both axes fixed to [-0.1, 1.1].
    """
    figure = Figure(figsize=(8, 8))
    axes = figure.add_subplot()
    axes.plot(x, y)

    # Highlighted points as (x, y, value) triples.
    marked = [(x[i], y[i], values[i]) for i in indices]
    axes.scatter([px for px, _, _ in marked], [py for _, py, _ in marked])
    for px, py, pv in marked:
        axes.annotate(str(round(pv,2)), (px, py))

    axes.set_title(yname + ' VS ' + xname + suffix)
    axes.set_xlabel(xname)
    axes.set_ylabel(yname)
    for set_limits in (axes.set_xlim, axes.set_ylim):
        set_limits(-0.1, 1.1)
    axes.grid()
    return figure

class Running_avg:
    """Weighted running average of fixed-size vectors."""

    def __init__(self, dim):
        """Start with a zero accumulator of shape ``dim`` and zero total weight."""
        self.state = np.zeros(dim)  # weighted sum of the values seen so far
        self.count = 0.0            # sum of the weights seen so far

    def __call__(self, value, weight):
        """Add ``value`` to the average with the given ``weight``."""
        contribution = value * weight
        self.state += contribution
        self.count += weight

    def reset(self):
        """Return the current average and clear the accumulator.

        Returns the scalar 0 (not an array) when nothing has been accumulated.
        """
        if self.count > 0:
            average = self.state / self.count
        else:
            average = 0
        self.state.fill(0)
        self.count = 0.0
        return average

def logsumexp(x, keep_mask=None, add_one=True, dim=1):
    """Masked log-sum-exp along ``dim``, optionally with an extra zero logit.

    Args:
        x: input tensor of any rank.
        keep_mask: optional boolean tensor; entries where it is False are
            excluded from the sum (set to -inf beforehand).
        add_one: when True a zero logit is appended along ``dim`` before
            reducing, so the result is ``log(1 + sum(exp(x)))``.
        dim: dimension to reduce; negative values are allowed.

    Returns:
        Tensor shaped like ``x`` with ``dim`` reduced to size 1. Slices in
        which ``keep_mask`` keeps nothing are set to 0.
    """
    if keep_mask is not None:
        x = x.masked_fill(~keep_mask, -torch.inf)
    if add_one:
        # Build the zero slice from x's own shape so the padding works for
        # any rank and any valid `dim`, not only for 2-D inputs.
        pad_shape = list(x.shape)
        pad_shape[dim] = 1
        zeros = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
        x = torch.cat([x, zeros], dim=dim)

    output = torch.logsumexp(x, dim=dim, keepdim=True)
    if keep_mask is not None:
        output = output.masked_fill(~torch.any(keep_mask, dim=dim, keepdim=True), 0)
    return output

In [15]:
# Grab the first sample of the geometry test set for interactive inspection.
data = ds_geo[0]

In [19]:
from torch_geometric.data.batch import Batch

In [22]:
# Keys for which Batch.from_data_list should also build a `<key>_batch`
# assignment vector; compute_metrics_from_matches reads e.g.
# `left_faces_batch` / `right_faces_batch` from the collated batch.
follow_batch=['left_vertices','right_vertices','left_edges', 'right_edges','left_faces','right_faces', 'faces_matches', 'edges_matches', 'vertices_matches']
# Collate the whole geometry test set into one big batch.
big_batch = Batch.from_data_list(geo_test_graphs,follow_batch=follow_batch)

In [25]:
# Score the `os_bl_faces_matches` stored in the batch against the ground-truth
# face matches, averaged over every graph of the geometry test set.
geo_face_metrics = compute_metrics_from_matches(big_batch, 'faces', big_batch.os_bl_faces_matches)

In [50]:
from matplotlib import pyplot as plt
from tqdm import tqdm

In [58]:
# NOTE(review): `writer` is created by the SummaryWriter('dummy_log/plots')
# cell further down; on a fresh kernel that cell has to run before this one.
plots = []
test_sets = (('Geo', geo_test_graphs), ('Topo', topo_test_graphs), ('Both', both_test_graphs))
for name, test_set in tqdm(test_sets, 'Test Sets'):
    big_batch = Batch.from_data_list(test_set, follow_batch=follow_batch)
    ds_plots = []
    for topo_type in tqdm(['faces', 'edges', 'vertices'], 'Match Sets', leave=False):
        title = f'Onshape Baseline {name} ({topo_type})'
        metrics = compute_metrics_from_matches(big_batch, topo_type, big_batch[f'os_bl_{topo_type}_matches'])
        # The baseline has no threshold: repeat its five outcome fractions
        # (precision / recall dropped) at x=0 and x=1 to draw flat bands.
        bands = np.stack([metrics[:-2]] * 2, axis=1)
        plot = plot_the_fives(*bands, np.array([0.0, 1.0]), title)
        ds_plots.append(plot)
        writer.add_figure(title, plot)
    plots.append(ds_plots)

Test Sets: 100%|██████████| 3/3 [00:20<00:00,  6.67s/it]


In [62]:
# Make sure all figures logged so far are written out to the event file.
writer.flush()

In [60]:
# Display the sample's repr to inspect its available fields and their shapes.
data

HetData(left_F=[3, 164], right_F=[3, 160], left_edge_export_ids=[48], right_edge_export_ids=[42], left_faces=[18, 62], right_faces=[16, 62], left_F_to_faces=[1, 164], right_F_to_faces=[1, 160], left_mcfs=[390, 6], right_mcfs=[344, 6], left_face_to_loop=[2, 18], right_face_to_loop=[2, 16], left_mcf_refs=[3, 390], right_mcf_refs=[3, 344], left_vertex_to_flat_topos=[2, 32], right_vertex_to_flat_topos=[2, 28], left_edge_samples=[48, 7, 10], right_edge_samples=[42, 7, 10], left_edge_to_flat_topos=[2, 48], right_edge_to_flat_topos=[2, 42], left_loop_export_ids=[18], right_loop_export_ids=[16], left_face_to_flat_topos=[2, 18], right_face_to_flat_topos=[2, 16], left_vertex_export_ids=[32], right_vertex_export_ids=[28], left_loops=[18, 38], right_loops=[16, 38], left_graph_idx=[1, 1], right_graph_idx=[1, 1], left_flat_topos=[116, 0], right_flat_topos=[102, 0], left_edges=[48, 72], right_edges=[42, 72], left_face_samples=[18, 9, 10, 10], right_face_samples=[16, 9, 10, 10], left_V_to_vertices=[2,

In [54]:
from torch.utils.tensorboard import SummaryWriter

In [55]:
# TensorBoard writer that every plotting cell logs its figures to.
# NOTE(review): cells that call `writer.add_figure` depend on this cell
# having been run first.
writer = SummaryWriter('dummy_log/plots')

In [None]:
# add_figure needs a tag as its first argument; plots[0][0] is the figure for
# the first test set ('Geo') and the first topology kind ('faces').
writer.add_figure('Onshape Baseline Geo (faces)', plots[0][0])

In [61]:
# Sanity check: score the ground-truth matches against themselves and log
# the resulting plots for every test set and topology kind.
test_sets = (('Geo', geo_test_graphs), ('Topo', topo_test_graphs), ('Both', both_test_graphs))
for name, test_set in tqdm(test_sets, 'Test Sets'):
    big_batch = Batch.from_data_list(test_set, follow_batch=follow_batch)
    for topo_type in tqdm(['faces', 'edges', 'vertices'], 'Match Sets', leave=False):
        title = f'Ground Truth {name} ({topo_type})'
        metrics = compute_metrics_from_matches(big_batch, topo_type, big_batch[f'{topo_type}_matches'])
        # Repeat the five outcome fractions at x=0 and x=1 to draw flat bands.
        bands = np.stack([metrics[:-2]] * 2, axis=1)
        plot = plot_the_fives(*bands, np.array([0.0, 1.0]), title)
        writer.add_figure(title, plot)

Test Sets: 100%|██████████| 3/3 [00:18<00:00,  6.19s/it]


In [65]:
import os

# Evaluate the `bl_exact_*` matches of every test set and topology kind,
# logging each plot to TensorBoard and saving it as a PNG.
os.makedirs('./metrics/exact/', exist_ok=True)
test_sets = (('Geo', geo_test_graphs), ('Topo', topo_test_graphs), ('Both', both_test_graphs))
for name, test_set in tqdm(test_sets, 'Test Sets'):
    big_batch = Batch.from_data_list(test_set, follow_batch=follow_batch)
    for topo_type in tqdm(['faces', 'edges', 'vertices'], 'Match Sets', leave=False):
        title = f'Exact Matching {name} ({topo_type})'
        metrics = compute_metrics_from_matches(big_batch, topo_type, big_batch[f'bl_exact_{topo_type}_matches'])
        # Repeat the five outcome fractions at x=0 and x=1 to draw flat bands.
        bands = np.stack([metrics[:-2]] * 2, axis=1)
        plot = plot_the_fives(*bands, np.array([0.0, 1.0]), title)
        writer.add_figure(title, plot)
        plot.savefig(f'./metrics/exact/{name}_{topo_type}.png')

writer.flush()

Test Sets: 100%|██████████| 3/3 [00:19<00:00,  6.45s/it]
