# Imports

In [None]:
import meshplot as mp
from datasets import BRepDataModule
from automate import PartFeatures
import meshplot as mp
import numpy as np
import os
from matplotlib import pyplot as plt
from zipfile import ZipFile
from pspy import Part
from pylab import cm

## Dataloaders and Model Definitions

In [2]:
import torch_geometric as tg
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
import random
import torch
import json
from tqdm import tqdm

def create_subset(data, seed, size):
    """Return a reproducible random sample of ``size`` items from ``data``.

    Args:
        data: Sequence to sample from (without replacement).
        seed: Seed that fully determines which items are chosen.
        size: Number of items to draw.

    Returns:
        A new list with ``size`` items of ``data``.

    Raises:
        ValueError: If ``size`` is larger than ``len(data)`` (raised by
            ``random.sample``).
    """
    # A private generator yields exactly the same sample as
    # ``random.seed(seed); random.sample(...)`` but leaves the process-wide
    # ``random`` state untouched, so calling this has no hidden side effects.
    return random.Random(seed).sample(data, size)

class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder.

    The first ``GCNConv`` maps ``in_channels`` to ``2 * out_channels``; after
    a ReLU the second maps down to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = tg.nn.GCNConv(in_channels, hidden_channels)
        self.conv2 = tg.nn.GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)

class CodeGraphAutoencoder(pl.LightningModule):
    """Lightning wrapper around a graph autoencoder with a ``GCNEncoder``.

    Args:
        in_channels: Width of the input node features.
        out_channels: Width of the latent node embeddings.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Honour the constructor arguments; they were previously ignored in
        # favour of a hard-coded GCNEncoder(64, 16). Passing (64, 16)
        # reproduces the old architecture exactly.
        self.model = tg.nn.GAE(GCNEncoder(in_channels, out_channels))

    def forward(self, batch):
        """Return latent node embeddings for ``batch``."""
        return self.model.encode(batch.x, batch.edge_index)

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss; return it for backprop."""
        z = self(batch)
        loss = self.model.recon_loss(z, batch.edge_index)
        # batch_size is the number of rows in z (one embedding per node).
        self.log('train_loss',loss, on_step=True,on_epoch=True,batch_size=z.shape[0])
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss on a validation batch."""
        z = self(batch)
        loss = self.model.recon_loss(z, batch.edge_index)
        self.log('val_loss',loss,on_step=True,on_epoch=True,batch_size=z.shape[0])

    def test_step(self, batch, batch_idx):
        """Log the epoch-level reconstruction loss on a test batch."""
        z = self(batch)
        loss = self.model.recon_loss(z, batch.edge_index)
        self.log('test_loss',loss,on_step=False,on_epoch=True,batch_size=z.shape[0])

    def configure_optimizers(self):
        """Adam over all parameters with a learning rate of 1e-2."""
        return torch.optim.Adam(self.parameters(), lr=1e-2)

from automate import LinearBlock, BipartiteResMRConv
from torch.nn import ModuleList
from torchmetrics import Accuracy
from torch.nn.functional import cross_entropy
class CodePredictor(pl.LightningModule):
    """Per-node classifier: message-passing layers followed by an MLP head.

    Args:
        in_channels: Width of the input node features (kept through all
            message-passing layers).
        out_channels: Number of output scores per node.
        mlp_layers: Number of ``in_channels``-wide sizes passed to
            ``LinearBlock`` before the final ``out_channels`` size.
        mp_layers: Number of ``BipartiteResMRConv`` layers.
    """

    def __init__(self, in_channels, out_channels, mlp_layers, mp_layers):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mlp_layers = mlp_layers
        self.mp_layers = mp_layers

        message_layers = [BipartiteResMRConv(in_channels) for _ in range(mp_layers)]
        self.mp = ModuleList(message_layers)
        layer_sizes = [in_channels] * mlp_layers
        self.mlp = LinearBlock(*layer_sizes, out_channels, last_linear=True)

        # One metric object per stage so their running state never mixes.
        # NOTE(review): newer torchmetrics releases may require a `task`
        # argument here - confirm against the pinned version.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, data):
        """Return per-node class scores for the graph(s) in ``data``."""
        # Only data.x is used as input; concatenating data.z is disabled.
        features = data.x
        for layer in self.mp:
            # The same tensor is passed as both feature arguments of the layer.
            features = layer(features, features, data.edge_index)
        return self.mlp(features)

    def _score_batch(self, batch):
        """Run the model on ``batch``; return ``(loss, preds, target)``."""
        scores = self(batch)
        target = batch.y.reshape(-1)
        return cross_entropy(scores, target), scores.argmax(dim=1), target

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch loss/accuracy; return the loss."""
        loss, preds, target = self._score_batch(batch)
        self.train_acc(preds, target)
        n_targets = len(target)
        self.log('train_loss', loss, on_epoch=True, on_step=True, batch_size=n_targets)
        self.log('train_acc', self.train_acc, on_epoch=True, on_step=True, batch_size=n_targets)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log per-step and per-epoch validation loss/accuracy."""
        loss, preds, target = self._score_batch(batch)
        self.val_acc(preds, target)
        n_targets = len(target)
        self.log('val_loss', loss, on_epoch=True, on_step=True, batch_size=n_targets)
        self.log('val_acc', self.val_acc, on_epoch=True, on_step=True, batch_size=n_targets)

    def test_step(self, batch, batch_idx):
        """Log epoch-level test loss/accuracy."""
        loss, preds, target = self._score_batch(batch)
        self.test_acc(preds, target)
        n_targets = len(target)
        self.log('test_loss', loss, on_epoch=True, on_step=False, batch_size=n_targets)
        self.log('test_acc', self.test_acc, on_epoch=True, on_step=False, batch_size=n_targets)

    def configure_optimizers(self):
        """Adam over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
        

class DictDataset(torch.utils.data.Dataset):
    """In-memory graph dataset built from an index and a dict of records.

    Args:
        index: Mapping with ``'train'`` and ``'test'`` lists of keys and a
            ``'template'`` format string; ``template.format(*key)`` is the
            lookup key into ``data``.
        data: Mapping from formatted key to the keyword arguments of
            ``tg.data.Data``.
        mode: ``'train'`` or ``'validate'`` select the corresponding part of
            a split of ``index['train']``; any other value selects
            ``index['test']``.
        val_frac: Fraction of the training keys held out for validation.
        seed: Seed for both the optional subsampling and the split.
        undirected: If true, append the reversed edges to every graph.
        train_size: If truthy, subsample this many training keys before
            splitting into train/validation.
    """

    def __init__(self, index, data, mode='train', val_frac=.2,seed=42,undirected=True,train_size=None):
        keyset = index['test']
        if mode in ['train', 'validate']:
            keyset = index['train']
            if train_size:
                keyset = create_subset(keyset, seed, train_size)
            # The same seed in 'train' and 'validate' mode gives complementary
            # halves of one split.
            train_keys, val_keys = train_test_split(keyset, test_size=val_frac, random_state=seed)
            keyset = train_keys if mode == 'train' else val_keys
        template = index['template']
        # Kept as an int-keyed dict (not a list) so the attribute type stays
        # the same for any external code that inspects ``.data``.
        self.data = {i: tg.data.Data(**data[template.format(*key)]) for i, key in enumerate(keyset)}
        if undirected:
            for graph in self.data.values():
                # Row-swapped copy of edge_index appended along the edge axis.
                graph.edge_index = torch.cat([graph.edge_index, graph.edge_index[[1,0]]],dim=1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the graph at position ``idx``.

        Raises:
            IndexError: If ``idx`` is not a stored position. Python's
                sequence-iteration protocol relies on this to stop, so
                ``for g in dataset`` and ``list(dataset)`` terminate instead
                of failing with ``KeyError``.
        """
        try:
            return self.data[idx]
        except KeyError:
            raise IndexError(f'index {idx!r} is out of range for dataset of size {len(self.data)}') from None

class DictDatamodule(pl.LightningDataModule):
    """Lightning data module bundling train/validation/test ``DictDataset``s.

    All three datasets are built eagerly in the constructor. The train and
    validation sets share ``val_frac``, ``seed`` and ``train_size``; the test
    set uses the ``DictDataset`` defaults.
    """

    def __init__(self, index, data, val_frac=0.2, seed=42, batch_size=32,train_size=None):
        super().__init__()
        self.val_frac = val_frac
        self.seed = seed
        self.batch_size = batch_size
        # Positional tail shared by both splits: val_frac, seed, undirected, train_size.
        split_args = (val_frac, seed, True, train_size)
        self.ds_train = DictDataset(index, data, 'train', *split_args)
        self.ds_val = DictDataset(index, data, 'validate', *split_args)
        self.ds_test = DictDataset(index, data, 'test')

    def train_dataloader(self):
        """Shuffled loader over the training set, using 8 persistent workers."""
        # Never request a batch larger than the dataset itself.
        effective_batch = min(len(self.ds_train), self.batch_size)
        return DataLoader(
            self.ds_train,
            batch_size=effective_batch,
            shuffle=True,
            num_workers=8,
            persistent_workers=True,
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation set."""
        effective_batch = min(len(self.ds_val), self.batch_size)
        return DataLoader(self.ds_val, batch_size=effective_batch, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test set, one graph per batch."""
        return DataLoader(self.ds_test, batch_size=1, shuffle=False)

## Load The Data Module

In [3]:
# Load the dataset index (JSON) and the pre-computed data (torch file).
with open('/projects/grail/benjonesnb/cadlab/siggasia2022/mfcad.json','r') as f:
    index = json.load(f)
# Loaded outside the `with` block: it does not use the JSON file handle.
coded_set = torch.load('/projects/grail/benjonesnb/cadlab/siggasia2022/precoded/mfcad_coded.pt')

In [4]:
# Build all three dataset splits up front (seed 0, batches of up to 512 graphs).
datamodule = DictDatamodule(index, coded_set, seed=0, batch_size=512)

In [10]:
# Shortcut to the test split; the evaluation cells below index into it directly.
test_ds = datamodule.ds_test

## Find and load trained model checkpoints

In [6]:
def load_model(mp_layers, tb_dir):
    """Load a trained ``CodePredictor`` from a TensorBoard run directory.

    Looks in ``<tb_dir>/version_0/checkpoints`` for ``.ckpt`` files whose
    name contains ``'val_loss'`` and loads the first one in sorted order.

    Args:
        mp_layers: Number of message-passing layers of the saved model.
        tb_dir: Run directory containing ``version_0/checkpoints``.

    Returns:
        A ``CodePredictor(64, 16, 2, mp_layers)`` with the checkpoint's
        ``state_dict`` loaded, on CPU.

    Raises:
        FileNotFoundError: If the checkpoint directory is missing or holds
            no matching checkpoint.
    """
    checkpoint_dir = os.path.join(tb_dir, 'version_0', 'checkpoints')
    # Sorted so the chosen file does not depend on os.listdir's arbitrary order.
    checkpoints = sorted(
        os.path.join(checkpoint_dir, f)
        for f in os.listdir(checkpoint_dir)
        if f.endswith('.ckpt') and 'val_loss' in f
    )
    if not checkpoints:
        raise FileNotFoundError(f"no '*val_loss*.ckpt' checkpoint found in {checkpoint_dir}")
    model = CodePredictor(64, 16, 2, mp_layers)
    # torch.load unpickles the file: only use it on checkpoints you trust.
    sd = torch.load(checkpoints[0],map_location='cpu')['state_dict']
    model.load_state_dict(sd)
    # NOTE(review): the model is returned in training mode (no .eval() call);
    # if the layers contain dropout or batch norm, confirm that is intended
    # for the inference cells below.
    return model

def load_models(root='/projects/grail/benjonesnb/cadlab/siggasia2022/tensorboard/f360seg/', seed = 0):
    """Load every trained model under ``root`` that was trained with ``seed``.

    Run directories are parsed as ``<model>_<train_size>_..._<seed>`` (fields
    separated by ``'_'``). Entries that are not directories, contain no
    ``'_'``, or whose size/seed fields are not integers are skipped.

    Args:
        root: Directory containing one sub-directory per training run.
        seed: Only runs whose last ``'_'`` field equals this value are loaded.

    Returns:
        Dict mapping model name to a list of ``(train_size, model)`` tuples
        sorted by ascending ``train_size``.

    Raises:
        ValueError: If a model name does not end in digits (the trailing
            digits give the number of message-passing layers).
    """
    runs_by_model = {}
    for d in os.listdir(root):
        path = os.path.join(root, d)
        if '_' not in d or not os.path.isdir(path):
            continue
        parts = d.split('_')
        try:
            train_size = int(parts[1])
            run_seed = int(parts[-1])
        except ValueError:
            # Unrelated directory that merely contains an underscore.
            continue
        if run_seed != seed:
            continue
        runs_by_model.setdefault(parts[0], []).append((train_size, path))

    m_dir_dict = {}
    for name, runs in runs_by_model.items():
        # All trailing digits, so layer counts above 9 are handled too.
        digits = name[len(name.rstrip('0123456789')):]
        if not digits:
            raise ValueError(f'cannot infer message-passing layer count from model name {name!r}')
        mp_layers = int(digits)
        m_dir_dict[name] = [
            (train_size, load_model(mp_layers, path))
            for train_size, path in sorted(runs, key=lambda run: run[0])
        ]
    return m_dir_dict
# Load every seed-0 model trained on MFCAD, then keep the 'mp2' family:
# a list of (train_size, model) tuples sorted by train_size.
all_models = load_models('/projects/grail/benjonesnb/cadlab/siggasia2022/tensorboard/mfcad/')
our_models = all_models['mp2']

In [7]:
# Training-set sizes of the loaded models, in ascending order.
[train_size for train_size, _model in our_models]

[10, 100, 1000, 10000, 13940]

## Compute a "difficulty" scale for test set examples

In [12]:
# accs[i, j] = fraction of labels in test example i that model j predicts correctly.
accs = np.zeros((len(test_ds), len(our_models)))
with torch.no_grad():
    for i in tqdm(range(len(test_ds))):
        data = test_ds[i]
        # Ground truth is the same for every model, so extract it once per example.
        targets = data.y.numpy().flatten()
        for j, (ts, m) in enumerate(our_models):
            preds = m(data).argmax(dim=1).numpy().flatten()
            accs[i, j] = (preds == targets).sum() / len(targets)

100%|██████████████████████████████████████████████████████████████| 1548/1548 [00:11<00:00, 135.90it/s]


In [13]:
# Per-example area under the accuracy curve across models: the row sum minus
# half of the first and last entries is the trapezoidal rule with unit spacing
# between consecutive models (columns are ordered by training-set size).
local_auc = accs.sum(axis = 1) - accs[:,0] / 2 - accs[:,-1] / 2

In [14]:
# Size of each test example, measured as the number of entries in its label tensor y.
# Indexed access is deliberate: the dataset is only addressed by integer position.
complexity = np.array([len(test_ds[i].y) for i in range(len(test_ds))])

In [17]:
# (test_index, (local_auc, complexity)) tuples ordered by ascending local_auc.
sorted_by_auc = sorted(enumerate(zip(local_auc, complexity)), key=lambda x:x[1][0])

In [19]:
# Keep only examples with more than 20 labels; the ascending-AUC order is preserved.
filtered_examples = [entry for entry in sorted_by_auc if entry[1][1] > 20]

## Setup Plotting Code

In [36]:
# Number of distinct labels = largest label id in the test set + 1.
num_labels = max(label for labels in index['test_labels'] for label in labels) + 1
# plt.get_cmap replaces cm.get_cmap, which was deprecated in Matplotlib 3.7 and
# removed in 3.9; the (name, lut) call is the same on older releases.
cmap = plt.get_cmap('tab20', num_labels)
# Archive holding the part files; left open for the lifetime of the notebook
# because plot_preds reads from it on demand.
zf = ZipFile('/projects/grail/benjonesnb/cadlab/siggasia2022/mfcad.zip')
# Name of each test part, in the same order as the test split of the index.
test_keys = [index['template'].format(*x) for x in index['test']]
def plot_part(V, F, E2T, FC=None):
    """Render a triangle mesh with meshplot and overlay selected triangle sides.

    Args:
        V: Vertex array passed to ``mp.plot``.
        F: Triangle index array; columns 0..2 are the triangle's corners.
        E2T: Per-triangle array with three columns; a side is drawn when its
            entry is ``>= 0``. Presumably the id of the topological edge the
            side lies on, negative for interior sides - confirm against pspy.
        FC: Optional colours forwarded to ``mp.plot`` as ``c``.
    """
    viewer = mp.plot(V, F, c=FC)
    # Column k of E2T pairs with the k-th of these corner pairs of F.
    corner_pairs = ([2, 0], [0, 1], [1, 2])
    segments = [F[E2T[:, k] >= 0][:, pair] for k, pair in enumerate(corner_pairs)]
    E = np.concatenate(segments, axis=0)
    viewer.add_edges(V, E, shading={'line_width':0.5})
def plot_preds(i, test_ds, test_keys=test_keys, zf=zf, models=our_models, colormap=cmap, just_gt=False):
    """Plot the ground-truth labelling of test part ``i``, then each model's prediction.

    Args:
        i: Position of the example in ``test_ds`` and ``test_keys``.
        test_ds: Dataset providing the graph for the example.
        test_keys: Archive member names, parallel to ``test_ds``.
        zf: Open ``ZipFile`` containing the part files.
        models: Iterable of ``(train_size, model)`` tuples.
        colormap: Callable mapping label ids to RGBA colours.
        just_gt: If true, stop after drawing the ground truth.
    """
    data = test_ds[i]
    with zf.open(test_keys[i],'r') as f:
        p = Part(f.read().decode('utf-8'))
    mesh = p.mesh
    topology = p.mesh_topology

    def draw(labels):
        # face_to_topology selects one label per mesh triangle; keep RGB, drop alpha.
        colors = colormap(labels[topology.face_to_topology])[:, :3]
        plot_part(mesh.V, mesh.F, topology.edge_to_topology, colors)

    draw(data.y.numpy())
    if just_gt:
        return

    # Run every model first, then draw, so all inference finishes before any plot.
    with torch.no_grad():
        model_preds = [model(data).argmax(dim=1).numpy() for _, model in models]
    for preds in model_preds:
        draw(preds)


## Explore Examples

In [40]:
# How many test examples have more than 20 labels.
len(filtered_examples)

1019

In [62]:
# Test-set indices singled out for closer inspection; used below to index
# `accs` and to pick the parts that get plotted.
gallery_examples = [646, 965, 673, 417, 888, 1090]

In [63]:
# Position within filtered_examples to inspect; the list is sorted by
# ascending local_auc, so -1 is the example with the highest score.
k = -1
# Prints (test_index, (local_auc, complexity)).
print(filtered_examples[k])
# Uncomment to render the ground truth and every model's prediction for it:
#plot_preds(filtered_examples[k][0], test_ds, test_keys, zf, our_models, cmap)

(1090, (3.9523809523809526, 21))


In [67]:
# One block per gallery example: per-model accuracy, smallest training set first.
for example_accs in accs[gallery_examples]:
    print('new row')
    for model_acc in example_accs:
        percent = model_acc * 100
        print(f'{percent:.1f}%')
                

new row
26.9%
34.6%
84.6%
73.1%
73.1%
new row
40.9%
50.0%
90.9%
95.5%
100.0%
new row
36.4%
59.1%
90.9%
100.0%
100.0%
new row
43.5%
60.9%
95.7%
100.0%
100.0%
new row
44.0%
68.0%
100.0%
100.0%
100.0%
new row
90.5%
100.0%
100.0%
100.0%
100.0%


In [277]:
# Render only the ground-truth labelling of each gallery example.
# `cmap` replaces `color_pallet`, which is never defined in this notebook and
# raised NameError on a fresh run; `cmap` is the colormap built in the plotting
# setup cell above.
for i in gallery_examples:
    plot_preds(i, test_ds, test_keys, zf, our_models, cmap, just_gt=True)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.004967…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…