## Reconstruction Gallery

In [None]:
# TODO - move any reasonable-to-precompute torch files to disk, put unreasonable ones in gitignore
#      - add params for gallery size, flag for force recompute, and cache loading for prerenders

import torch
import numpy as np
from rendering import find_best_angle_from_part, render_part, render_mesh
from tqdm.notebook import tqdm
from render_shape import preds_to_mesh
from zipfile import ZipFile
import json
from pspy import Part
from train_latent_space import BRepFaceAutoencoder
from matplotlib import pyplot as plt
import os

# --- Input / output locations -------------------------------------------
# Every dataset artifact lives directly under `data_root`.
data_root = '/media/ben/Data/cad'
f360seg_index_path = f'{data_root}/fusion360seg.json'
f360seg_zip_path = f'{data_root}/fusion360seg.zip'
model_checkpoint_path = f'{data_root}/BRepFaceAutoencoder_64_1024_4.ckpt'
computed_f360seg_codes_path = f'{data_root}/fusion360seg_coded.pt'
# These two are resolved relative to the current working directory.
render_losses_path = 'f360_render_test_losses.pt'
figure_out_path = 'renderings.png'

# --- Gallery layout -----------------------------------------------------
# The gallery shows rows * cols parts; `size` scales the figure size
# per grid cell when the figure is created.
rows, cols = 6, 4
size = 5

# Number of samples along each axis of the per-face evaluation grid.
grid_density = 100

# Set to True to ignore any cached renderings and render again.
force_rendering = False





# Load the saved per-part test losses. Each entry is indexed as
# (avg_loss, part_size): entry[0] is a tensor-like scalar (it has .item()),
# entry[1] is used directly as a number.
# NOTE(review): the unit of "part size" is not visible here (face count?) - confirm.
testing_losses = torch.load(render_losses_path)
avg_losses = np.array([entry[0].item() for entry in testing_losses])
part_sizes = np.array([entry[1] for entry in testing_losses])
total_losses = avg_losses * part_sizes


def _ranked_by(field):
    """Return (index, (avg, size, total)) records sorted ascending by stats[field]."""
    records = enumerate(zip(avg_losses, part_sizes, total_losses))
    return sorted(records, key=lambda record: record[1][field])


# Rankings keep the original position of each part as the first tuple element.
avg_sorted = _ranked_by(0)
total_sorted = _ranked_by(2)
# Keep only parts whose size exceeds 20, preserving the total-loss ordering.
total_filtered = [record for record in total_sorted if record[1][1] > 20]


# Produce `gts` (ground-truth renders) and `renders` (reconstruction renders)
# for the first rows*cols entries of `total_filtered`, reusing the on-disk
# cache when it exists, is large enough, and `force_rendering` is off.
render_results_path = 'render_results.pt'
n_gallery = rows * cols

recompute_renderings = True
if os.path.exists(render_results_path) and not force_rendering:
    render_results = torch.load(render_results_path)
    renders = render_results['renders']
    gts = render_results['gts']
    # The cache is only usable if BOTH lists can fill the whole gallery.
    if min(len(renders), len(gts)) >= n_gallery:
        recompute_renderings = False

if recompute_renderings:
    with open(f360seg_index_path, 'r') as f:
        index = json.load(f)
    # Archive member names of the test split, built from the index template.
    parts_list = [index['template'].format(*x) for x in index['test']]

    model = BRepFaceAutoencoder(64, 1024, 4)
    # The model is never moved off the CPU, so load the checkpoint onto the
    # CPU as well; this also lets a GPU-saved checkpoint load on CPU-only hosts.
    ckpt = torch.load(model_checkpoint_path, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    # NOTE(review): the model is left in its default (training) mode; if the
    # decoder contains dropout/batch-norm layers, model.eval() is probably
    # wanted here - confirm against BRepFaceAutoencoder.

    def predict(face_codes, model, N=grid_density):
        """Decode every face code on a shared N x N parameter grid.

        Args:
            face_codes: 2-D tensor with one row per face.
            model: object exposing a `decoder` callable.
            N: number of samples per grid axis; each axis spans [-0.1, 1.1].

        Returns:
            Whatever `model.decoder` returns for the concatenated
            (grid point, face code) rows - n_faces * N * N input rows,
            grouped face by face.
        """
        n_faces = face_codes.shape[0]
        line = torch.linspace(-0.1, 1.1, N)
        grid = torch.cartesian_prod(line, line)
        # Tile the grid once per face and pair each copy with that face's code.
        grids = grid.repeat(n_faces, 1)
        indices = torch.arange(n_faces).repeat_interleave(N*N, dim=0)
        with torch.no_grad():
            indexed_codes = face_codes[indices]
            uv_codes = torch.cat([grids, indexed_codes], dim=1)
            preds = model.decoder(uv_codes)
        return preds

    codes = torch.load(computed_f360seg_codes_path)

    gts = []
    renders = []
    N = grid_density
    # Context manager guarantees the archive handle is released even if a
    # render fails part-way through.
    with ZipFile(f360seg_zip_path, 'r') as data:
        for k in tqdm(range(n_gallery)):
            # Position of this part in the test split (and in parts_list).
            i = total_filtered[k][0]
            preds = predict(codes[parts_list[i]]['x'], model, N)
            V, F, C = preds_to_mesh(preds, N)
            part = Part(data.open(parts_list[i]).read().decode('utf-8'))
            # Use the same pose/zoom for both images so they are comparable.
            pose, zoom = find_best_angle_from_part(part)
            ground_truth = render_part(part, pose, zoom)
            rendering = render_mesh(V, F, C, pose, zoom)
            gts.append(ground_truth)
            renders.append(rendering)

    render_results = {'gts': gts, 'renders': renders}
    torch.save(render_results, render_results_path)



# Lay the gallery out as a rows x (2*cols) grid of image axes: each gallery
# slot takes two adjacent columns, ground truth on the left and the
# reconstruction on the right. The figure is written to `figure_out_path`.
M = rows*cols
s = size
fig, axes = plt.subplots(
    rows,
    cols*2,
    figsize=(2*cols*s, s*rows),
    gridspec_kw={'wspace': 0, 'hspace': 0},
    dpi=300,
    # Always get a 2-D axes array so axes[row, col] also works when rows == 1.
    squeeze=False,
)
for i in range(M):
    row, col = divmod(i, cols)
    axes[row, 2*col].imshow(gts[i])
    axes[row, 2*col].axis('off')
    axes[row, 2*col+1].imshow(renders[i])
    axes[row, 2*col+1].axis('off')
fig.savefig(figure_out_path)

## MFCAD Figure

In [None]:
# Hand-picked examples for the MFCAD figure.
# NOTE(review): presumably indices into the MFCAD dataset - confirm.
gallery_examples = [
    646, 965, 673,
    417, 888, 1090,
]



## Limitations Plot

In [None]:
import altair as alt
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Load the accuracy table and add the helper columns used by the chart:
#   one     - constant 1, so sum(one) counts rows
#   correct - 'correct' where label == 1, otherwise 'incorrect'
#   logmse  - natural log of the mse column
accs = pd.read_parquet('df_face_acc.parquet')
accs['one'] = 1


def _describe_label(value):
    """Map a label value to 'correct' (value == 1) or 'incorrect'."""
    if value == 1:
        return 'correct'
    return 'incorrect'


iscorr = np.vectorize(_describe_label)
accs['correct'] = iscorr(accs.label)
accs['logmse'] = np.log(accs.mse)

# Lift altair's default row limit before charting this table.
alt.data_transformers.disable_max_rows()

# Normalized stacked histogram of log(mse) for rows with mse < 1, coloured by
# label, with one facet row per train_size value.
low_mse_rows = accs[accs.mse < 1]
logmse_axis = alt.X('logmse', bin=alt.Bin(maxbins=50))
share_axis = alt.Y('sum(one)', stack='normalize', axis=alt.Axis(format='%'))
label_colour = alt.Color('label:N')
stacked_bars = alt.Chart(low_mse_rows).mark_bar().encode(
    x=logmse_axis,
    y=share_axis,
    color=label_colour,
)
# Left as the final bare expression so the notebook cell displays the chart.
stacked_bars.facet(row='train_size')