In [5]:
import os, sys

## Make the parent of the current working directory importable, so that the
## project-level packages used below (util, utils, GINN, models, train) resolve
## when the kernel is started from inside the notebooks folder.
parent_dir = os.path.abspath(os.pardir)
sys.path.append(parent_dir)
print(sys.path)
import trimesh
import k3d
import numpy as np
from util.visualization.utils_mesh import get_watertight_mesh_for_latent
import torch
import numpy as np
import k3d
import trimesh
from tqdm import tqdm
from utils import get_model, get_model_path_via_wandb_id_from_fs, get_stateless_net_with_partials, set_else_default

['/system/user/radler/repos/objgen/notebooks', '/system/apps/userenv/radler/ginn/lib/python312.zip', '/system/apps/userenv/radler/ginn/lib/python3.12', '/system/apps/userenv/radler/ginn/lib/python3.12/lib-dynload', '', '/system/apps/userenv/radler/ginn/lib/python3.12/site-packages', '/system/user/radler/repos/objgen']


In [11]:
## Parameters
## The configuration is assembled from thematic groups and merged into one flat
## dict `cfg` (same keys, values and key order as a single literal would give).
device = 'cuda:1'
simjeb_root_dir = '../GINN/simJEB/data'

## Settings consumed by get_model(cfg) and the input/latent dimensions
model_cfg = {
    'layers': [128, 128, 128],
    'w0': 1.0,
    'w0_initial': 18,
    'wire_scale': 6,
    'model': 'cond_wire',
    'nx': 3,
    'ny': 1,
    'nz': 1,
}

## Problem definition and sampling budgets
problem_cfg = {
    'problem': 'simjeb',
    'simjeb_root_dir': simjeb_root_dir,
    'envelope_sample_from': 'exterior',
    'n_points_domain': 2048,  # used for eikonal loss
    'n_points_envelope': 16384,
    'n_points_interfaces': 4096,
    'n_points_normals': 4096,
    'mc_resolution': 128,
    'bounds': torch.from_numpy(np.load(f'{simjeb_root_dir}/bounds.npy')).float().to(device),
    'device': device,
    'ginn_bsize': 2,
}

## Flow of the starting points onto the surface
surf_flow_cfg = {
    'surf_pts_recompute_every_n_epochs': 1,
    'surf_pts_nof_points': 8192, # 32768  ## nof points for initializing the flow to surface
    'surf_pts_lr': 0.01, ## learning rate for non-Newton optimizer
    'surf_pts_n_iter': 10, # iterations of surface flow
    'surf_pts_prec_eps': 1.0e-3,  ## precision threshold for early stopping surface flow and filtering the points
    'surf_pts_converged_interval': 1, ## how often to check the convergence
    'surf_pts_use_newton': True, ## whether to use Newton iteration or Adam
    'surf_pts_newton_clip': 0.15, ## magnitude for clipping the Newton update
    'surf_pts_inflate_bounds_amount': 0.05, ## inflate the (otherwise tight) bounding box by this fraction
}

## Uniform redistribution of the surface points
surf_uniform_cfg = {
    'surf_pts_uniform_n_iter': 10, ## nof iterations for repelling the points
    'surf_pts_uniform_nof_neighbours': 16, ## nof neighbors for knn
    'surf_pts_uniform_stepsize': 0.75, ## step size for the repelling update
    'surf_pts_uniform_n_iter_reproj': 5, ## nof Newton-iterations for reprojecting the points
    'surf_pts_uniform_prec_eps': 1.0e-3, ## precision for reprojection (similar to above)
    'surf_pts_uniform_min_count': 1000, ## minimum number of points to redistribute. Less than this is meaningless
    'surf_pts_surpress_tqdm': True,
    'surf_pts_uniform_reproject_surpress_tqdm': True,
}

## Figure output switches
fig_cfg = {
    'fig_show': False,
    'fig_save': False,
    'fig_wandb': False,
}

cfg = {**model_cfg, **problem_cfg, **surf_flow_cfg, **surf_uniform_cfg, **fig_cfg}
torch.set_default_device(device)

In [7]:
########## MODELS ##########

## No smoothing
key = '7bnersho'
model_path = get_model_path_via_wandb_id_from_fs(key)
## Two latent codes, one row per shape
z = torch.tensor([[0.0], [0.1]], device=device)

########## END MODELS ##########

## MODEL
## Build the network from the config, restore the stored weights and obtain the
## stateless network together with its partial derivatives.
model = get_model(cfg).to(device)
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
nep = get_stateless_net_with_partials(model, cfg['nz'])
params, f, vf_x, vf_xx = nep.params_, nep.f_, nep.vf_x_, nep.vf_xx_

## visualize shapes for a range of z: one marching-cubes mesh per latent code
meshes = [
    get_watertight_mesh_for_latent(f, params, z_, cfg['bounds'], cfg['mc_resolution'], device, surpress_watertight=True)
    for z_ in tqdm(z)
]

Found model at /system/user/publicwork/radler/ginndata/saved_models/cond_wire/2024_09_04__14_45_35-7bnersho/2024_09_04__14_45_35-7bnersho-model.pt


100%|██████████| 2/2 [00:00<00:00,  2.77it/s]


In [12]:
from GINN import problem_sampler
from GINN.visualize.plotter_dummy import DummyPlotHelper
from GINN.helpers.timer_helper import TimerHelper
from GINN.helpers.mp_manager import MPManager


## Helper objects required by ShapeBoundaryHelper further below.
## The order matters: the timer helper needs the manager's lock, and the manager
## is handed the timer helper only afterwards.
mpm = MPManager(cfg)
t_helper = TimerHelper(cfg, lock=mpm.get_lock())
mpm.set_timer_helper(t_helper)  # avoid circular dependencies
## Sampler for problem-specific points; used below to obtain the interface points
p_sampler = problem_sampler.ProblemSampler(cfg)

In [25]:
import torch
import torch.nn.functional as F
from tqdm import trange
from models.point_wrapper import PointWrapper
from train.losses import get_gauss_curvature, get_mean_curvature_normalized

import logging
import time
from util.misc import set_and_true, set_else_default
from util.sample_utils import inflate_bounds, precompute_sample_grid

class ShapeBoundaryHelper:
    """
    Finds points on the surface (zero level) of an implicit network and
    redistributes them more uniformly over that surface.

    Pipeline, as driven by ``get_surface_pts``:
      1. jittered grid points are moved onto the surface (``flow_to_surface_pts``),
      2. the points repel their nearest neighbours along the surface and are
         re-projected (``resample``),
      3. per-point weights are computed.

    NOTE(review): ``get_is_out_mask`` is called in ``flow_to_surface_pts`` but is not
    imported in the visible part of this notebook. It has to be in scope (e.g. from
    an earlier kernel session) for the class to work - add the import where it lives.
    """

    def __init__(self, config, netp, mp_manager: 'MPManager', plot_helper, timer_helper: 'TimerHelper', x_interface, device):
        """
        Args:
            config: flat dict of hyperparameters. 'bounds' and 'surf_pts_nof_points'
                are required; most other keys fall back to defaults.
            netp: network wrapper exposing ``f``, ``vf_x`` and ``vf_xx``;
                the result of ``netp.detach()`` is stored.
            mp_manager: used for dispatching plot calls.
            plot_helper: provides ``do_plot`` and the plotting callbacks.
            timer_helper: its ``record`` context manager is used for timing.
            x_interface: interface points, only read when
                'reweigh_surface_pts_close_to_interface' is enabled.
            device: device the bounds are moved to.

        The type annotations are strings so that defining the class does not
        require ``MPManager``/``TimerHelper`` to be imported first.
        """
        self.config = config
        self.netp = netp.detach()
        self.mpm = mp_manager
        self.logger = logging.getLogger('surf_pts_helper')
        self.plotter = plot_helper
        self.timer_helper = timer_helper
        self.bounds = self.config['bounds'].to(device)
        self.bounds = inflate_bounds(self.bounds, amount=set_else_default('surf_pts_inflate_bounds_amount', self.config, 0.05))
        self.grid_find_surface, self.grid_dist_find_surface = precompute_sample_grid(self.config['surf_pts_nof_points'], self.bounds)

        self.record_time = timer_helper.record
        self.p_surface = None
        self.x_interface = x_interface
        self.knn_k = set_else_default('surf_pts_uniform_nof_neighbours', config, 16)
        # NOTE: more neighbors pushes the points more to edges, so might be favourable for smoothness

    def get_surface_pts(self, z):
        """
        Find surface points for the latent codes ``z`` and compute their weights.

        Returns:
            ``(p_surface, weights_surf_pts)`` on success, ``(None, None)`` if no
            surface points were found or the redistribution failed.
            The weights are uniform (1/N) unless
            'reweigh_surface_pts_close_to_interface' is enabled, in which case they
            are a power of the clamped distance to the nearest interface point,
            normalized to sum to 1.
        """
        success, p_surface = self._get_and_plot_surface_flow(z)
        if not success:
            return None, None

        ## Redistribution is only attempted with more than the minimum number of
        ## points; with fewer, the result of the flow is used as it is.
        if len(p_surface) > set_else_default('surf_pts_uniform_min_count', self.config, 1000):
            success, p_surface = self.resample(p_surface, z, num_iters=set_else_default('surf_pts_uniform_n_iter', self.config, 10))
            if not success:
                return None, None

        weights_surf_pts = torch.ones(len(p_surface)) / p_surface.data.shape[0]
        if set_and_true('reweigh_surface_pts_close_to_interface', self.config):
            ## The pairwise distance matrix is [n_surface, n_interface], so it is
            ## only built when the reweighting is actually requested.
            dist = torch.min(torch.norm(p_surface.data[:, None, :] - self.x_interface[None, :, :], dim=2), dim=1)[0]
            dist = torch.clamp(dist, max=self.config['reweigh_surface_pts_close_to_interface_cutoff'])
            weights_surf_pts = torch.pow(dist, self.config['reweigh_surface_pts_close_to_interface_power'])
            weights_surf_pts = weights_surf_pts / weights_surf_pts.sum()  ## normalize to sum to 1

        if set_and_true('plot_surface_points', self.config):
            y_x_surf = self.netp.vf_x(p_surface.data, p_surface.z_in(z)).squeeze(1)
            y_xx_surf = self.netp.vf_xx(p_surface.data, p_surface.z_in(z)).squeeze(1)

            mean_curvatures = get_mean_curvature_normalized(y_x_surf, y_xx_surf)
            gauss_curvatures = get_gauss_curvature(y_x_surf, y_xx_surf)
            E_strain = (2*mean_curvatures)**2 - 2*gauss_curvatures
            E_strain = torch.log(E_strain + 1e-3)  ## log to make the values more interpretable

            self.mpm.plot(self.plotter.plot_shape_and_points, 'plot_surface_points',
                                arg_list=[p_surface.detach().cpu().numpy(), 'Surface points', E_strain.detach().cpu().numpy()])
        return p_surface, weights_surf_pts

    def _get_and_plot_surface_flow(self, z):
        """
        Draw fresh starting points from the jittered grid, let them flow to the
        surface and optionally plot the descent trajectories.

        Returns:
            ``(True, p_surface)`` on success, ``(False, None)`` otherwise.
        """
        with self.record_time('cp_helper: flow_to_surface_points'):
            p = self.get_grid_starting_pts(self.grid_find_surface, self.grid_dist_find_surface)
            success, tup = self.flow_to_surface_pts(p, z,
                lr=self.config['surf_pts_lr'],
                n_iter=self.config['surf_pts_n_iter'],
                plot_descent=self.plotter.do_plot('plot_surface_descent'),
                use_newton=self.config['surf_pts_use_newton'],
                surpress_tqdm=set_else_default('surf_pts_surpress_tqdm', self.config, False),
                )

        if not success:
            self.logger.debug(f'No surface points found')
            return False, None
        p_surface, x_path_over_iters = tup
        if self.plotter.do_plot('plot_surface_descent'):
            self.mpm.plot(self.plotter.plot_descent_trajectories, 'plot_surface_descent', [p.detach().cpu().numpy(), x_path_over_iters.cpu().numpy()])

        return True, p_surface

    def flow_to_surface_pts(self, p, z, lr, n_iter, plot_descent, filter_thr=None, newton_clip=None, min_count=None, use_sgd=False, use_newton=True, surpress_tqdm=False):
        """
        A simple optimization loop to let starting points p flow to zero.
        NOTE: Adam/SGD is kept for historic reasons, but going forward we might want to
        either split or remove it as the current code is a bit unreadable.
        The main difference between Adam and Newton:
        Adam requires to register the variables, which stay fixed size: filtering just selects a subset for evaluation and updating.
        Newton update is manual, so we can always throw away points.

        Args:
            p: starting points (point wrapper with ``data``, ``z_in`` and ``select_w_mask``).
            z: latent codes, mapped to the points via ``p.z_in(z)``.
            lr: learning rate of the Adam/SGD optimizer (unused with Newton).
            n_iter: maximum number of iterations; expected to be >= 1.
            plot_descent: if True, the positions of all iterations are recorded.
            filter_thr: points with |f| >= filter_thr are dropped at the end
                (default: config 'surf_pts_prec_eps' or 1e-3).
            newton_clip: clipping magnitude of the Newton step
                (default: config 'surf_pts_newton_clip' or 0.15).
            min_count: fewer remaining points than this count as failure
                (default: config 'surf_pts_uniform_min_count' or 100).
            use_sgd: use SGD instead of Adam (only when ``use_newton`` is False).
            use_newton: use the manual Newton update instead of an optimizer.
            surpress_tqdm: disable the progress bar.

        Returns:
            ``(True, (p_in, x_path_over_iters))`` on success, where
            ``x_path_over_iters`` is None unless ``plot_descent`` is set;
            ``(False, None)`` if no point is inside the bounds or fewer than
            ``min_count`` points converged.
        """

        ## Filter far away from surface so we get a more uniform distribution and need less iterations
        # y = self.netp.f(p.data, p.z_in(z)).squeeze(1)
        # init_mask = torch.abs(y) < 5e-2
        # p = p.select_w_mask(incl_mask=init_mask)

        ## Initialize parameters
        if filter_thr is None:
            filter_thr = set_else_default('surf_pts_prec_eps', self.config, 1e-3)
        if newton_clip is None:
            newton_clip = set_else_default('surf_pts_newton_clip', self.config, 0.15)
        if min_count is None:
            ## NOTE(review): the fallback is 100 here but 1000 for the same key in
            ## get_surface_pts/resample - confirm which one is intended.
            min_count = set_else_default('surf_pts_uniform_min_count', self.config, 100)

        ## Initialize plotting
        x_path_over_iters = None
        if plot_descent:
            ## NaN marks positions of points that were dropped along the way
            x_path_over_iters = torch.full([n_iter + 1, len(p), self.config['nx']], torch.nan)
            idxs_in_orig = torch.arange(0, len(p))

        ## Initialize points and optimizer
        if use_newton:
            p_in = p
        else:
            p.data.requires_grad = True
            optimizer_cls = torch.optim.SGD if use_sgd else torch.optim.Adam
            opt = optimizer_cls([p.data], lr=lr)

        ## Iterate
        for i in (pbar := trange(n_iter, disable=surpress_tqdm)):

            ## Mask
            if use_newton:
                ## Newton: points outside the bounds are discarded for good
                out_mask = get_is_out_mask(p_in.data, self.bounds)
                p_in = p_in.select_w_mask(incl_mask=~out_mask)
                if plot_descent:
                    idxs_in_orig = idxs_in_orig[~out_mask]
                    x_path_over_iters[i][idxs_in_orig] = p_in.data.detach()
            else:
                ## Adam/SGD: p keeps its size, only the in-bounds subset is evaluated
                opt.zero_grad()
                out_mask = get_is_out_mask(p.data, self.bounds)
                p_in = p.select_w_mask(incl_mask=~out_mask)
                if plot_descent:
                    x_path_over_iters[i] = p.data.detach()

            if len(p_in) == 0:
                self.logger.debug(f'No surf_pts_n_iter points found in the domain')
                return False, None

            ## Main update
            if use_newton:
                with torch.no_grad():
                    z_ = p_in.z_in(z)
                    y = self.netp.f(p_in.data, z_).squeeze(1)
                    y_x = self.netp.vf_x(p_in.data, z_).squeeze(1)
                    ## Step along the normalized gradient by the clipped function value.
                    ## NOTE: a vanishing gradient norm gives inf/nan entries here.
                    update = y_x * (torch.clip(y, -newton_clip, newton_clip)/y_x.norm(dim=1))[:,None]
                    p_in.data = p_in.data - update

                    ## For compatibility with remaining code
                    y_in = y

                    ## Logging
                    if not surpress_tqdm:
                        loss = y_in.square().mean()
                        pbar.set_description(f"Flow to surface points: {len(p_in)}/{len(p)}; {loss.item():.2e}")
            else:
                y_in = self.netp.f(p_in.data, p_in.z_in(z)).squeeze(1)  ## [bx]

                # L2 loss works better than L1 loss
                loss = y_in.square().mean()
                if torch.isnan(loss):
                    self.logger.debug(f'Early stop "Finding surface points" at it {i} due to nan loss')
                    ## Actually stop: stepping the optimizer with a NaN loss would
                    ## only spread NaNs into the points.
                    break

                loss.backward()
                opt.step()
                if not surpress_tqdm:
                    pbar.set_description(f"Flow to surface points: {len(p_in)}/{len(p)}; {loss.item():.2e}")

            ## Early stopping
            ## NOTE: this check uses the config threshold, not the filter_thr argument
            if i % self.config['surf_pts_converged_interval'] == 0:
                # stop if |points| < thresh
                if (torch.abs(y_in) < self.config['surf_pts_prec_eps']).all():
                    self.logger.debug(f'Early stop "Finding surface points" at it {i}')
                    break


        ## Filter non-converged points
        ## (y_in holds the values from before the last update step)
        converged_mask = torch.abs(y_in) < filter_thr
        p_in = p_in.select_w_mask(incl_mask=converged_mask)

        ## Exit early if no points are left
        if len(p_in)<min_count:
            self.logger.debug(f'Only {len(p_in)} surface points found, not continuing')
            return False, None

        ## Handle the last iteration of plotting
        if plot_descent:
            if use_newton:
                idxs_in_orig = idxs_in_orig[converged_mask]
                x_path_over_iters[i+1][idxs_in_orig] = p_in.data.detach()
            else:
                x_path_over_iters[i+1] = p.data.detach()
            x_path_over_iters = x_path_over_iters[:i+2] ## remove the unfilled part due to early stopping

        ## Disable gradient tracking for Adam
        if not use_newton:
            p_in = p_in.detach()
            p_in.data.requires_grad = False

        return True, (p_in, x_path_over_iters)

    def get_normals(self, p, z, invert=False):
        """
        Return the unit-length gradients of the network at the points ``p`` as a
        PointWrapper sharing the map of ``p``; ``invert`` flips their sign.
        """
        f_x = self.netp.vf_x(p.data, p.z_in(z)).squeeze(1)  ## [bx nx]
        if invert:
            f_x = -f_x
        p_normals = PointWrapper(f_x, map=p.get_map())
        p_normals.data = F.normalize(p_normals.data, dim=-1)
        return p_normals

    def get_nn_idcs(self, x, k):
        """
        Return the indices of the ``k`` nearest neighbours of every row of ``x``,
        shape [n_points, k]. The first sorted entry is skipped, which is the
        point itself as long as there are no duplicate points.
        """
        dist = torch.cdist(x, x, compute_mode='use_mm_for_euclid_dist')
        # dist = (x.unsqueeze(1) - x.unsqueeze(0)).norm(dim=-1)
        idcs = dist.argsort(dim=-1)[:, 1:k+1]
        return idcs

    def resample(self, points_init, z, num_iters=0, debug=True):
        """
        Redistribute surface points more uniformly.

        Every iteration (i) pushes each point away from its k nearest neighbours
        with Gaussian distance weights, keeping only the component tangential to
        the surface, and (ii) re-projects all points onto the surface with a few
        Newton iterations. ``points_init`` is updated in place.

        Args:
            points_init: surface points (point wrapper) of all shapes.
            z: latent codes.
            num_iters: number of repel + reproject iterations.
            debug: if True, the mean neighbour density is logged each iteration.

        Returns:
            ``(True, points)`` on success. ``(False, None)`` if a shape's bounding
            box diagonal is degenerate, the re-projection fails, or fewer than
            'surf_pts_uniform_min_count' (default 1000) points remain.
        """

        ## Initialize parameters
        n_iter_reproj = set_else_default('surf_pts_uniform_n_iter_reproj', self.config, 5)
        filter_thr_reproj = set_else_default('surf_pts_uniform_filter_thr_reproj', self.config, 1e-3) ## lower thr requires more n_iter
        stepsize = set_else_default('surf_pts_uniform_stepsize', self.config, 0.75) ## .75 worked well with 8 and 16 nns
        min_count = set_else_default('surf_pts_uniform_min_count', self.config, 1000)

        density_w = None  ## set by the shape loop, logged in the following iteration
        for i_iter in range(num_iters):
            if debug and i_iter > 0 and density_w is not None:
                self.logger.debug(f'iter: {i_iter} \t density: {density_w.mean().item():.3f} \t nof pts: {len(points_init)}')

            for i_shape in range(len(points_init.get_map())):

                start_t = time.time()
                points = points_init.pts_of_shape(i_shape)

                ## NOTE(review): the normals of all shapes are re-evaluated for every
                ## shape; this could probably be hoisted out of the shape loop.
                normals_init = self.get_normals(points_init, z)
                normals = normals_init.pts_of_shape(i_shape)
                num_points = points.shape[0]

                ## NOTE: not sure if this should be recomputed every iteration if the nof points doesn't change much
                diag = (points.view(-1, 3).max(dim=0).values - points.view(-1, 3).min(0).values).norm().item()
                if diag < 1e-6: ## Fail if the diagonal is too small
                    return False, None
                inv_sigma_spatial = num_points / diag

                knn_indices = self.get_nn_idcs(points, self.knn_k) # [n_points, k]
                knn_nn = points[knn_indices] # [n_points, k, 3]
                knn_diff = points.unsqueeze(1) - knn_nn  # [n_points, k, 3]
                knn_dists_sq = torch.sum(knn_diff**2, dim=-1)  # [n_points, k]
                spatial_w = torch.exp(-knn_dists_sq * inv_sigma_spatial)  # [n_points, k]
                move = torch.sum(spatial_w[..., None] * knn_diff, dim=-2)

                if debug:
                    ## Neighbour density, only used for the debug log
                    density_w = torch.sum(spatial_w, dim=-1, keepdim=True)  # [n_points, 1] ## NOTE: can change sum to mean to make invariant to number of neighbors

                ## Project the move onto the tangential plane: subtract the normal
                ## component (move . n) n. The dot product has to be summed over the
                ## coordinate axis; an element-wise product is not a projection.
                move -= (move * normals).sum(dim=-1, keepdim=True) * normals
                ## Scale the update
                move *= stepsize ## the update size is a hyperparameter. Larger steps needs better reprojection

                ## Update the points
                points += move
                points_init.set_pts_of_shape(i_shape, points)

                print(f'move done in {time.time() - start_t:.3f} s')

            start_t = time.time()
            ## Reproject
            ## NOTE: the majority of time is spent here
            success, ret = self.flow_to_surface_pts(
                points_init,
                z,
                lr=None,
                n_iter=n_iter_reproj,
                plot_descent=False,
                use_newton=True,
                surpress_tqdm=set_else_default('surf_pts_uniform_reproject_surpress_tqdm', self.config, True),
                filter_thr=filter_thr_reproj,
                )
            if success:
                points_init, _ = ret
            else:
                self.logger.debug("No points left after reprojection. Try decreasning the update size, increasing the number of reprojection iterations or decreasing the filtering threshold")
                return False, None
            if len(points_init) < min_count:
                ## Stop redistributing if there are not enough points.
                ## Better return failure so that the integrals don't have a high variance
                self.logger.debug(f'Only {len(points_init)} points left after reprojection, not continuing')
                return False, None

            print(f'reproject done in {time.time() - start_t:.3f} s')

        return True, points_init


    def get_grid_starting_pts(self, x_grid, grid_dist):
        '''
        Create grid once at the beginning.
        Translate the grid by a random offset.

        Args:
            x_grid: precomputed grid points, [n_points nx].
            grid_dist: grid spacing used to scale the random offsets.

        Returns:
            PointWrapper built from a [bz n_points nx] tensor, with bz = config['ginn_bsize'].
        '''
        ## Translate the grid by a random offset
        xc_offset = torch.rand((self.config['ginn_bsize'], self.config['nx'])) * grid_dist  # bz nx

        # x_grid: [n_points nx]
        x = x_grid.unsqueeze(0) + xc_offset.unsqueeze(1)  # bz n_points nx

        ## Translate each point by a random offset
        ## (the noise has the shape of x_grid, so it is broadcast over the batch dimension)
        x += torch.randn(x_grid.shape) * grid_dist / 3

        return PointWrapper.create_from_equal_bx(x)


## Find surface points for the latent codes z, then run 10 more redistribution
## iterations on them, timing both stages.
t0 = time.time()
boundary_helper = ShapeBoundaryHelper(cfg, nep, mp_manager=mpm, plot_helper=DummyPlotHelper(), timer_helper=t_helper, x_interface=p_sampler.sample_from_interface()[0], device=device)
p_surface, _ = boundary_helper.get_surface_pts(z)
## get_surface_pts signals failure with (None, None); fail here with a clear
## message instead of an AttributeError on None further down.
if p_surface is None:
    raise RuntimeError('get_surface_pts did not return any surface points')
print(f"found pts in {(time.time() - t0):.3f} s"); t0 = time.time()
## Snapshot of the points before resampling, which updates p_surface in place
p_orig = p_surface.data.clone().detach()
# p_resampled, move, knn_diff, density_w, knn_dists_sq, knn_indices, points_prev = boundary_helper.resample(p_surface, z, num_iters=10, debug=True)
success, p_resampled = boundary_helper.resample(p_surface, z, num_iters=10, debug=True)
## resample signals failure with (False, None); the plotting cell needs p_resampled
if not success:
    raise RuntimeError('resample failed to redistribute the surface points')
print(f"redistributed pts in {(time.time() - t0):.3f} s"); t0 = time.time()

move done in 0.009 s
move done in 0.022 s
reproject done in 0.104 s
move done in 0.009 s
move done in 0.025 s
reproject done in 0.103 s
move done in 0.008 s
move done in 0.023 s
reproject done in 0.101 s
move done in 0.009 s
move done in 0.022 s
reproject done in 0.101 s
move done in 0.009 s
move done in 0.022 s
reproject done in 0.103 s
move done in 0.009 s
move done in 0.023 s
reproject done in 0.092 s
move done in 0.008 s
move done in 0.020 s
reproject done in 0.071 s
move done in 0.007 s
move done in 0.020 s
reproject done in 0.065 s
move done in 0.010 s
move done in 0.019 s
reproject done in 0.066 s
move done in 0.007 s
move done in 0.020 s
reproject done in 0.066 s
found pts in 1.369 s
move done in 0.009 s
move done in 0.017 s
reproject done in 0.073 s
move done in 0.008 s
move done in 0.018 s
reproject done in 0.069 s
move done in 0.008 s
move done in 0.020 s
reproject done in 0.066 s
move done in 0.009 s
move done in 0.019 s
reproject done in 0.061 s
move done in 0.008 s
move d

In [26]:
## Show every mesh together with its resampled surface points; the shapes are
## laid out on a grid with n_cols columns in the y-z plane.
verts, faces = meshes[0]
color = 0xbbbbbb


fig = k3d.plot(height=1000) #, camera_fov=1.0)
n_cols = 2

bounds = cfg['bounds'].cpu().numpy()
shape_grid = 1.5*(bounds[:,1] - bounds[:,0]) ## distance between the shape grid for plotting

for i_shape, (verts, faces) in enumerate(meshes):
    i_row, i_col = divmod(i_shape, n_cols)
    group = f'Shape {i_shape}'
    ## Mesh and points of one shape share the same offset
    offset = [0, shape_grid[1]*i_col, shape_grid[2]*i_row]
    fig += k3d.mesh(verts, faces, group=group, color=color, side='double', flat_shading=False, opacity=1.0, name=f"Shape_{i_shape}", translation=list(offset))
    shape_pts = p_resampled.pts_of_shape(i_shape).cpu().numpy()
    fig += k3d.points(shape_pts, point_size=0.003, name=f'pts_{i_shape}', translation=list(offset))
    

# fig += k3d.mesh(verts, faces, color=color, side='double', flat_shading=False, opacity=1.0, name="Shape")

## Debug nearest-neighbor
# knn_edges = torch.stack((
#     torch.arange(knn_indices.shape[0]).repeat(knn_indices.shape[1],1).T, 
#     knn_indices)).flatten(start_dim=1).T.cpu().numpy()
# fig += k3d.lines(points_prev.cpu().numpy(), knn_edges, indices_type='segment', shader="simple", width=0.01, color=0x0000ff)

## Debug density
# fig += k3d.points(points_prev.cpu().numpy(), point_size=0.003, attribute=density_w.detach().cpu().numpy(), name='pts') ## can also plotmean_knn_dists

## Debug directions
# fig += k3d.vectors(points_prev.cpu().numpy(), move.detach().cpu().numpy(), color=0x0000ff, line_width=0.0001, head_size=0.01, name='move')

## Debug resampling
# fig += k3d.points(p_orig.detach().cpu().numpy(), point_size=0.003, color=0x00ff00, name='pts')
# fig += k3d.points(p_resampled.data.detach().cpu().numpy(), point_size=0.003, color=0xff0000, name='resampled')
# fig += k3d.points(p_surface.data.detach().cpu().numpy(), point_size=0.003, color=0x0000ff, name='surface')
# fig += k3d.points(p_reproj.data.detach().cpu().numpy(), point_size=0.003, color=0x0000ff, name='reproj')
# fig += k3d.vectors(p_resampled.detach().cpu().numpy(), (p_resampled.data - p_reproj.data).detach().cpu().numpy(), color=0x0000ff, line_width=0.0001, head_size=0.01, name='move')

## Plot trajectories
def trajectories_to_verts_inds(trajectories):
    """
    Convert point trajectories into line-segment geometry.

    Args:
        trajectories: tensor of shape [K, P, 3] (K steps, P points, 3 coordinates).

    Returns:
        vertices: numpy array [K*P, 3] with all positions, step after step.
        indices: numpy array [(K-1)*P, 2]; every row connects a vertex with the
            vertex of the same point in the next step.
    """
    vertices = trajectories.reshape(-1,3).cpu().numpy()
    n_steps, n_points, _ = trajectories.shape
    ## The same point in the next step sits n_points entries further down
    segment_starts = np.arange((n_steps - 1) * n_points)
    segment_ends = segment_starts + n_points
    indices = np.stack((segment_starts, segment_ends), axis=1)
    return vertices, indices
## Filter trajectories
# converged_mask = x_path_over_iters[-1,:,0].isnan()
# trajs = x_path_over_iters #x_path_over_iters[:,~converged_mask][:,::1]
# fig += k3d.lines(*trajectories_to_verts_inds(trajs), indices_type='segment', shader="simple", width=0.01, color=0x0000ff)
# fig += k3d.points(p_in.data.detach().cpu().numpy(), point_size=0.003, color=0x00ff00, name='pts')
# fig += k3d.points(x_path_over_iters[0].cpu().numpy(), point_size=0.003, color=0xff0000, name='pts_init')


## Bounds
# from k3d import platonic
# x_min, x_max, y_min, y_max, z_min, z_max = boundary_helper.bounds.flatten().cpu().numpy()
# cube = platonic.Cube()
# cube.vertices = np.array([
#     [x_max, y_max, z_max],
#     [x_max, y_max, z_min],
#     [x_max, y_min, z_max],
#     [x_max, y_min, z_min],
#     [x_min, y_max, z_max],
#     [x_min, y_max, z_min],
#     [x_min, y_min, z_max],
#     [x_min, y_min, z_min],
#     ])
# fig += k3d.mesh(cube.vertices, cube.indices, side="double", opacity=.5, color=0xff0000)

## Render the assembled k3d figure in the cell output
fig.display()



Output()