In [1]:
import os
import sys
import time
import numpy as np
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.measure
import plyfile
from plyfile import PlyData
from sklearn.neighbors import KDTree
import trimesh
import torch_geometric
from torch_geometric.nn import (NNConv, GMMConv, GraphConv, Set2Set)
from torch_geometric.nn import (SplineConv, graclus, max_pool, max_pool_x, global_mean_pool)
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter("ignore")

In [2]:
def load_pressure_predictor(load_directory):
    """
    Build the pressure-prediction network and restore its saved weights.

    :param load_directory: directory that contains the ``cfdModel.nn`` state dict
    :return: a ``SplineCNN8Residuals`` built with 3 input features, moved to
        ``cuda:0`` and switched to evaluation mode
    """
    predictor = SplineCNN8Residuals(3)
    weights = torch.load(load_directory + "/cfdModel.nn")
    predictor.load_state_dict(weights)
    predictor = predictor.to("cuda:0")
    return predictor.eval()

def load_latent_vectors(load_directory, checkpoint):
    """
    Load the latent codes stored in ``<load_directory>/<checkpoint>.pth``.

    :param load_directory: directory holding the checkpoint file
    :param checkpoint: checkpoint name without the ``.pth`` extension
    :return: the ``"latent_codes"`` entry of the checkpoint, moved to the
        current CUDA device
    :raises FileNotFoundError: when the checkpoint file does not exist
    """
    filename = os.path.join(load_directory, checkpoint + ".pth")
    if not os.path.isfile(filename):
        # The two literals are joined by implicit concatenation so that
        # ``.format`` applies to the whole message. With an explicit ``+`` it
        # would bind to the second literal only and leave a raw "{}" behind.
        raise FileNotFoundError(
            "The experiment directory ({}) does not include a latent code file"
            " for checkpoint '{}'".format(load_directory, checkpoint)
        )
    data = torch.load(filename)
    return data["latent_codes"].cuda()

def load_decoder(load_directory, checkpoint):
    """
    Rebuild the SDF decoder described by ``specs.json`` and load its weights.

    :param load_directory: directory containing ``specs.json`` and the checkpoint
    :param checkpoint: checkpoint name without the ``.pth`` extension
    :return: the bare ``Decoder`` module on the current CUDA device, in eval mode
    :raises FileNotFoundError: when ``specs.json`` is missing
    """
    specs_filename = os.path.join(load_directory, "specs.json")
    if not os.path.isfile(specs_filename):
        raise FileNotFoundError(
            'The experiment directory does not include specifications file "specs.json"'
        )
    # Use a context manager so the specification file is closed deterministically.
    with open(specs_filename) as specs_file:
        specs = json.load(specs_file)
    latent_size = specs["CodeLength"]
    decoder = Decoder(latent_size, **specs["NetworkSpecs"])
    # The state dict is loaded through a DataParallel wrapper and the wrapper is
    # dropped afterwards; presumably the checkpoint keys carry the "module."
    # prefix -- verify against the training code.
    decoder = torch.nn.DataParallel(decoder)
    saved_model_state = torch.load(os.path.join(load_directory, checkpoint + ".pth"))
    decoder.load_state_dict(saved_model_state["model_state_dict"])
    decoder = decoder.module.cuda()
    decoder.eval()
    return decoder

In [3]:
from abc import ABC, abstractmethod

class objective_func(ABC):
    """
    Abstract base for objective functions consumed by the optimisers.

    Subclasses implement ``func`` and are expected to set the ``optimal`` and
    ``optimum`` attributes that the two getters return.
    """

    @abstractmethod
    def func(self, x):
        """Evaluate the objective at ``x``."""
        pass

    def dfunc(self, x):
        """Back-propagate ``func(x)`` and return whatever ends up in ``x.grad``."""
        value = self.func(x)
        value.backward()
        return x.grad

    def get_optimal(self):
        """Return the stored ``optimal`` attribute."""
        return self.optimal

    def get_optimum(self):
        """Return the stored ``optimum`` attribute."""
        return self.optimum
    

class Decoder(nn.Module):
    """
    Fully connected decoder mapping ``[latent code, xyz]`` rows to one value.

    The layer widths are ``[latent_size + 3] + dims + [1]``. Linear layers are
    stored as ``lin<i>`` and optional LayerNorm layers as ``bn<i>``; those
    attribute names are also the state-dict key prefixes.

    :param latent_size: length of the latent code preceding the 3 coordinates
    :param dims: list with the hidden layer widths
    :param dropout: collection of layer indices followed by dropout, or None
    :param dropout_prob: dropout probability used for those layers
    :param norm_layers: layer indices that get weight norm or LayerNorm
    :param latent_in: layer indices whose input is concatenated with the
        original network input
    :param weight_norm: use weight normalisation instead of LayerNorm on the
        layers listed in ``norm_layers``
    :param xyz_in_all: when truthy, re-append the coordinates before every
        layer but the first
    :param use_tanh: apply an extra Tanh right after the last linear layer
    :param latent_dropout: apply dropout (p=0.2) to the latent part of the input
    """

    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        super().__init__()

        widths = [latent_size + 3] + dims + [1]

        self.num_layers = len(widths)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        final_idx = self.num_layers - 2
        for idx in range(final_idx + 1):
            fan_out = widths[idx + 1]
            if idx + 1 in latent_in:
                # the next layer receives this output concatenated with the input
                fan_out -= widths[0]
            elif xyz_in_all and idx != final_idx:
                # leave room for the xyz columns appended before the next layer
                fan_out -= 3

            linear = nn.Linear(widths[idx], fan_out)
            if weight_norm and idx in self.norm_layers:
                linear = nn.utils.weight_norm(linear)
            setattr(self, "lin" + str(idx), linear)

            if (
                not weight_norm
                and self.norm_layers is not None
                and idx in self.norm_layers
            ):
                setattr(self, "bn" + str(idx), nn.LayerNorm(fan_out))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()

    # input: N x (L+3)
    def forward(self, input):
        """Run the network on an ``N x (L+3)`` input whose last 3 columns are xyz."""
        xyz = input[:, -3:]

        hidden = input
        if input.shape[1] > 3 and self.latent_dropout:
            codes = F.dropout(input[:, :-3], p=0.2, training=self.training)
            hidden = torch.cat([codes, xyz], 1)

        final_idx = self.num_layers - 2
        for idx in range(final_idx + 1):
            if idx in self.latent_in:
                hidden = torch.cat([hidden, input], 1)
            elif idx != 0 and self.xyz_in_all:
                hidden = torch.cat([hidden, xyz], 1)

            hidden = getattr(self, "lin" + str(idx))(hidden)

            if idx == final_idx:
                # last layer: optional Tanh, no normalisation / ReLU / dropout
                if self.use_tanh:
                    hidden = self.tanh(hidden)
                continue

            if (
                self.norm_layers is not None
                and idx in self.norm_layers
                and not self.weight_norm
            ):
                hidden = getattr(self, "bn" + str(idx))(hidden)
            hidden = self.relu(hidden)
            if self.dropout is not None and idx in self.dropout:
                hidden = F.dropout(hidden, p=self.dropout_prob, training=self.training)

        if hasattr(self, "th"):
            hidden = self.th(hidden)

        return hidden
class SplineBlock(nn.Module):
    """
    Three stacked ``SplineConv`` layers with ELU activations.

    The third convolution takes ``2 * mid_features + 3`` inputs because
    ``data['x']`` is concatenated to the features before it (assumes
    ``data['x']`` has 3 columns -- TODO confirm).

    :param num_in_features: number of input features per node
    :param num_outp_features: number of output features per node
    :param mid_features: width of the first convolution; the second is twice as wide
    :param kernel: kernel size passed to every ``SplineConv``
    :param dim: pseudo-coordinate dimensionality passed to every ``SplineConv``
    :param batchnorm1: whether the first convolution is followed by batch norm
    """

    def __init__(self, num_in_features, num_outp_features, mid_features, kernel=3, dim=3, batchnorm1=True):
        super().__init__()
        wide_features = 2 * mid_features
        self.conv1 = SplineConv(num_in_features, mid_features, dim, kernel, is_open_spline=False)
        # ``batchnorm1`` holds either the BatchNorm module or the falsy flag given.
        self.batchnorm1 = torch.nn.BatchNorm1d(mid_features) if batchnorm1 else batchnorm1
        self.conv2 = SplineConv(mid_features, wide_features, dim, kernel, is_open_spline=False)
        self.batchnorm2 = torch.nn.BatchNorm1d(wide_features)
        self.conv3 = SplineConv(wide_features + 3, num_outp_features, dim, kernel, is_open_spline=False)

    def forward(self, res, data):
        """Apply the block to node features ``res`` on the graph described by ``data``."""
        edge_index = data['edge_index']
        edge_attr = data['edge_attr']

        hidden = self.conv1(res, edge_index, edge_attr)
        if self.batchnorm1:
            hidden = self.batchnorm1(hidden)
        hidden = F.elu(hidden)

        hidden = self.conv2(hidden, edge_index, edge_attr)
        hidden = F.elu(self.batchnorm2(hidden))

        hidden = torch.cat([hidden, data['x']], dim=1)
        return self.conv3(hidden, edge_index, edge_attr)

class SplineCNN8Residuals(nn.Module):
    """
    Eight chained ``SplineBlock`` modules with two skip concatenations.

    ``block5`` consumes the ``block4`` output (8 features) concatenated with
    ``data['x']``; ``block8`` consumes the ``block7`` output (64) concatenated
    with the ``block4`` output (8) and ``data['x']``. The widths 11 and 75
    therefore assume ``data['x']`` has 3 columns -- TODO confirm.

    :param num_features: number of input features per node
    :param kernel: kernel size forwarded to each block
    :param dim: pseudo-coordinate dimensionality forwarded to each block
    """

    def __init__(self, num_features, kernel=3, dim=3):
        super().__init__()
        self.block1 = SplineBlock(num_features, 16, 8, kernel, dim)
        self.block2 = SplineBlock(16, 64, 32, kernel, dim)
        self.block3 = SplineBlock(64, 64, 128, kernel, dim)
        self.block4 = SplineBlock(64, 8, 16, kernel, dim)
        self.block5 = SplineBlock(11, 32, 16, kernel, dim)
        self.block6 = SplineBlock(32, 64, 32, kernel, dim)
        self.block7 = SplineBlock(64, 64, 128, kernel, dim)
        self.block8 = SplineBlock(75, 4, 16, kernel, dim)

    def forward(self, data):
        """Return the ``block8`` output (4 features per node) for the graph ``data``."""
        node_inputs = data['x']

        features = node_inputs
        for block in (self.block1, self.block2, self.block3):
            features = block(features, data)
        skip = self.block4(features, data)

        features = torch.cat([skip, node_inputs], dim=1)
        for block in (self.block5, self.block6, self.block7):
            features = block(features, data)

        features = torch.cat([features, skip, node_inputs], dim=1)
        return self.block8(features, data)

def create_mesh(
    decoder, latent_vec, filename='', N=256, max_batch=32 ** 3, offset=None, scale=None
):
    """
    Evaluate the decoder on an ``N x N x N`` grid over ``[-1, 1]^3`` and mesh it.

    :param decoder: network called on ``[latent, xyz]`` rows; put into eval mode here
    :param latent_vec: latent code, expanded to one row per grid sample
    :param filename: base name; ``".ply"`` is appended and the result is forwarded to
        ``convert_sdf_samples_to_ply``
    :param N: number of samples along each axis
    :param max_batch: maximum number of grid points decoded per forward pass
    :param offset: forwarded to ``convert_sdf_samples_to_ply``
    :param scale: forwarded to ``convert_sdf_samples_to_ply``
    :return: whatever ``convert_sdf_samples_to_ply`` returns for the sampled grid
    """
    ply_filename = filename

    decoder.eval()

    # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
    voxel_origin = [-1, -1, -1]
    voxel_size = 2.0 / (N - 1)
    num_samples = N ** 3

    overall_index = torch.arange(0, num_samples, 1, out=torch.LongTensor())
    samples = torch.zeros(num_samples, 4)

    # columns 0..2 first receive the integer grid index along each axis
    samples[:, 2] = overall_index % N
    samples[:, 1] = (overall_index.long() // N) % N
    samples[:, 0] = ((overall_index.long() // N) // N) % N

    # ... and are then converted to coordinates. Each column uses the origin
    # component of its own axis (all components are equal today, so this only
    # matters if the origin is ever changed).
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[0]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[2]

    samples.requires_grad = False

    # The decoded values are detached straight away, so skip building the
    # autograd graph for each batch in the first place.
    with torch.no_grad():
        for head in range(0, num_samples, max_batch):
            tail = min(head + max_batch, num_samples)
            sample_subset = samples[head:tail, 0:3].cuda()
            latent_repeat = latent_vec.expand(tail - head, -1)
            inputs = torch.cat([latent_repeat, sample_subset], 1)
            samples[head:tail, 3] = decoder(inputs).squeeze(1).detach().cpu()

    sdf_values = samples[:, 3].reshape(N, N, N).data.cpu()

    return convert_sdf_samples_to_ply(
        sdf_values,
        voxel_origin,
        voxel_size,
        ply_filename + ".ply",
        offset,
        scale,
    )

def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    voxel_grid_origin,
    voxel_size,
    ply_filename_out,
    offset=None,
    scale=None,
):
    """
    Run marching cubes on an SDF grid and pack the mesh into a ``PlyData`` object.

    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
    :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
    :voxel_size: float, the size of the voxels
    :ply_filename_out: string; currently unused -- nothing is written to disk
    :offset: optional value subtracted from the vertices (after scaling)
    :scale: optional value the vertices are divided by
    :return: ``plyfile.PlyData`` with "vertex", "face" and "normals" elements

    This function adapted from: https://github.com/RobotLocomotion/spartan
    """

    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()

    # ``marching_cubes_lewiner`` no longer exists in recent scikit-image
    # releases; fall back to ``marching_cubes`` when it is missing.
    marching_cubes = getattr(skimage.measure, "marching_cubes_lewiner", None)
    if marching_cubes is None:
        marching_cubes = skimage.measure.marching_cubes
    verts, faces, normals, values = marching_cubes(
        numpy_3d_sdf_tensor, level=0.0, spacing=[voxel_size] * 3
    )

    # shift the vertices from grid-relative coordinates by the grid origin
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
    mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
    mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    xyz_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4")]
    verts_tuple = np.zeros((num_verts,), dtype=xyz_dtype)
    norms_tuple = np.zeros((num_verts,), dtype=xyz_dtype)
    # fill the structured arrays column by column instead of row by row
    for axis, field in enumerate(("x", "y", "z")):
        verts_tuple[field] = mesh_points[:, axis]
        norms_tuple[field] = normals[:, axis]

    faces_tuple = np.zeros((num_faces,), dtype=[("vertex_indices", "i4", (3,))])
    faces_tuple["vertex_indices"] = faces

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
    el_norms = plyfile.PlyElement.describe(norms_tuple, "normals")

    ply_data = plyfile.PlyData([el_verts, el_faces, el_norms])
    return ply_data

def compute_lift_faces_diff(mesh, preds):
    """
    Sum, over all faces, the mean face pressure times a signed face area term.

    :param mesh: dict with ``'x'`` (vertex positions, one row per vertex) and
        ``'face'`` (vertex indices, one column per face)
    :param preds: per-vertex predictions; column 0 is used as the pressure
    :return: scalar tensor; NaN contributions are left out of the sum
    """
    faces = mesh['face']
    # average the three corner pressures of every face
    pressures = torch.mean(preds[faces, 0], dim=0)

    # TODO: change to x if needed
    pos = mesh['x']
    edge_a = pos[faces[1]] - pos[faces[0]]
    edge_b = pos[faces[2]] - pos[faces[0]]
    # ``dim=1`` is explicit: without it the cross product is taken along the
    # first dimension of size 3, which is the wrong axis for exactly 3 faces.
    cross_prod = torch.cross(edge_a, edge_b, dim=1)
    # negated half of the cross product's first component
    area = -cross_prod[:, 0] / 2
    lift = torch.mul(pressures, area)
    return torch.sum(lift[~torch.isnan(lift)])

def boundsLoss(points, box=[(-1, 1, 0)]):
    """
    Penalise coordinates that leave the given axis-aligned intervals.

    :param points: 2-D tensor, one point per row
    :param box: iterable of ``(lower, upper, column)`` triples; the default list
        is only read, never modified
    :return: sum over the triples of the mean distance below ``lower`` plus the
        mean distance above ``upper`` (the int ``0`` when ``box`` is empty)
    """
    total = 0
    for lower, upper, column in box:
        coords = points[:, column]
        below = torch.mean(F.relu(lower - coords))
        above = torch.mean(F.relu(coords - upper))
        total += below + above
    return total

def innerBoundsLoss(points, r=1, center=(0, 0, 0)):
    """
    Penalise points whose squared distance to ``center`` is below ``r``.

    :param points: 2-D tensor, one point per row
    :param r: threshold compared against the *squared* distance
    :param center: coordinates of the centre
    :return: mean of ``relu(r - squared_distance)`` over all points
    """
    # Build the centre on the device of ``points`` instead of a fixed 'cuda:0'.
    centre = torch.as_tensor(center, dtype=torch.float32, device=points.device)
    radiuses = torch.sum((points - centre) ** 2, dim=1)
    return torch.mean(F.relu(r - radiuses))

def calculate_loss(mesh, local_preds, constraint_rad=0.1):
    """
    Combine the lift term with the bounding-box and the two inner-sphere penalties.

    :param mesh: mesh dict; ``mesh['x']`` holds the vertex positions
    :param local_preds: per-vertex predictions passed to ``compute_lift_faces_diff``
    :param constraint_rad: radius of the first sphere; the second uses half of it
    :return: the total loss tensor (its three parts are also printed)
    """
    vertices = mesh['x']

    loss = compute_lift_faces_diff(mesh, local_preds)
    lift_part = loss.clone().detach().cpu().numpy()

    loss += boundsLoss(vertices, box=[(-0.6, 0.6, 0)])
    lift_and_box = loss.clone().detach().cpu().numpy()

    big_sphere = innerBoundsLoss(vertices, r=constraint_rad**2, center=(-0.05, 0.05, 0))
    small_sphere = innerBoundsLoss(vertices, r=(constraint_rad / 2)**2, center=(0.3, 0, 0))
    loss += big_sphere + small_sphere

    total = loss.detach().cpu().numpy()
    print("three parts (321) of loss: %.3f, %.3f, %.3f"%(total - lift_and_box, lift_and_box-lift_part, lift_part))
    return loss

def transformPoints(points, AvgTransform):
    """
    Apply a homogeneous transform to a set of 3-D points.

    :param points: tensor with one ``(x, y, z)`` point per row
    :param AvgTransform: array-like matrix applied to the ``[x, y, z, 1]`` columns
    :return: the first three components of the transformed points, one per row,
        on the same device as ``points``
    """
    # Follow the device and dtype of ``points`` rather than assuming CUDA.
    matrix = torch.as_tensor(AvgTransform, dtype=points.dtype, device=points.device)
    ones = points.new_ones((len(points), 1))
    stacked = torch.cat([points, ones], dim=1)
    transformed = torch.matmul(matrix, stacked.t()).t()[:, :3]
    return transformed

def transform_mesh(points, ply_mesh, AvgTransform):
    """
    Transform the vertices and assemble the graph dict used by the predictor.

    :param points: vertex positions, one row per vertex
    :param ply_mesh: PLY data whose ``['face']['vertex_indices']`` lists the faces
    :param AvgTransform: matrix forwarded to ``transformPoints``
    :return: dict with ``'x'`` (transformed vertices), ``'face'`` (one column per
        face), ``'edge_index'`` (one column per edge) and ``'edge_attr'``
        (position difference between the two end points of each edge), all on
        the device of the transformed vertices
    """
    transformed_points = transformPoints(points, AvgTransform)
    device = transformed_points.device

    face_indices = ply_mesh['face']['vertex_indices']
    edges = np.asarray(trimesh.geometry.faces_to_edges(face_indices))
    np_points = transformed_points.cpu().detach().numpy()
    # one vectorised subtraction instead of a Python loop over every edge
    edge_attr = np_points[edges[:, 0]] - np_points[edges[:, 1]]
    mesh = {'x': transformed_points,
        'face':torch.tensor(face_indices, dtype=torch.long).to(device).t(),
        'edge_attr':torch.tensor(edge_attr, dtype=torch.float).to(device),
        'edge_index':torch.tensor(edges, dtype=torch.long).t().contiguous().to(device)
        }
    return mesh


def decode_sdf(decoder, latent_vector, queries):
    """
    Evaluate ``decoder`` on the query rows, optionally prefixed by a latent code.

    :param decoder: callable applied to the assembled input
    :param latent_vector: code expanded to one row per query, or None to pass
        ``queries`` through unchanged
    :param queries: 2-D tensor, one query per row
    :return: the decoder output
    """
    num_queries = queries.shape[0]

    if latent_vector is None:
        return decoder(queries)

    tiled_code = latent_vector.expand(num_queries, -1)
    return decoder(torch.cat([tiled_code, queries], 1))
class single_experiment:
    """
    Run one optimiser on one objective and classify the outcome.

    :param tol: distance below which the found minimum counts as the global one
    """

    def __init__(self, tol=0.1):
        self.tol = tol

    def set_objective(self, objective_func):
        """Store the objective passed to the optimiser in ``do``."""
        self.objective_func = objective_func

    def set_optimizer(self, optimizer):
        """Store the optimiser whose ``optimise`` method ``do`` calls."""
        self.optimizer = optimizer

    def do(self):
        """
        Optimise, label the result and return it.

        :return: ``(status, optimum, optimal, evals)`` when the optimiser's
            ``record`` flag equals False, otherwise the statistics dict
            extended with the reference and the found minimum
        """
        objective = self.objective_func
        solver = self.optimizer

        optimal, optimum, statistics = solver.optimise(objective)
        dist_arg = np.linalg.norm(optimal.detach().cpu().numpy() - objective.get_optimal())
        dist_val = np.linalg.norm(optimum.detach().cpu().numpy() - objective.get_optimum())

        close_enough = dist_arg < self.tol or dist_val < self.tol
        if close_enough:
            statistics['status'] = 'global minimum'
        elif statistics['status'] != 'diverge':
            statistics['status'] = 'local minimum'

        print("distance domain, codomain: ", dist_arg, dist_val)
        if solver.verbose:
            print("Result: ", statistics['status'])
            print("found minimum: {}, minimum position: {}, evals: {}".format(optimum, torch.norm(optimal).item(), statistics['evals']))

        # explicit comparison kept on purpose: only a value equal to False
        # selects the tuple form
        if solver.record == False:
            return statistics['status'], optimum, optimal, statistics['evals']

        statistics['optimal'] = objective.get_optimal()
        statistics['optimum'] = objective.get_optimum()
        statistics['found_optimal'] = optimal
        statistics['found_optimum'] = optimum
        return statistics
            

def get_trimesh_from_torch_geo_with_colors(mesh, preds, vmin=-8, vmax=8):
    """
    Build a ``trimesh.Trimesh`` whose vertices are coloured by ``preds[:, 0]``.

    :param mesh: dict with ``'x'`` (vertices) and ``'face'`` (one column per face)
    :param preds: per-vertex predictions; column 0 drives the colour
    :param vmin: value mapped to the low end of the "hot" colour map
    :param vmax: value mapped to the high end of the "hot" colour map
    :return: the coloured ``trimesh.Trimesh``
    """
    color_scale = mpl.colors.Normalize(vmin= vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=color_scale, cmap=cm.hot)

    verticies = mesh['x'].cpu().detach()
    faces = mesh['face'].t().cpu().detach()
    vertex_colors = [mapper.to_rgba(value) for value in preds[:, 0].cpu().detach()]
    return trimesh.Trimesh(vertices=verticies, faces=faces, vertex_colors=vertex_colors)



In [4]:
# Notebook-wide state: input/output directories, the two trained networks, the
# latent codes and the average transform. Later cells read these globals
# directly (e.g. ``visual_Mesh`` uses ``decoder``, ``predictor`` and ``AvgTransform``).
DIR_to_load_data = 'starting_data'
experiment_directory = "data_for_this_experiments"

# pressure predictor restored from <DIR_to_load_data>/cfdModel.nn
predictor = load_pressure_predictor(DIR_to_load_data)

# SDF decoder restored from specs.json + decoderModel.pth
decoder = load_decoder(DIR_to_load_data, "decoderModel")

# latent codes from latentCodes.pth, detached from any autograd graph
latent_vectors = load_latent_vectors(DIR_to_load_data, "latentCodes").detach()

# transform matrix passed to transform_mesh / transformPoints
AvgTransform = np.load(DIR_to_load_data + "/avg_trans_matrix.npy") #computeAvgTransform()

# latent code number 32 is the one selected for optimisation
LATENT_TO_OPTIMIZE = latent_vectors[32]
# KD-tree over the first row of every latent entry
# (presumably each entry has shape (1, L) -- TODO confirm)
LATENT_KD_TREE = KDTree(np.array([lv.cpu().detach().numpy()[0] for lv in latent_vectors]))
# /cvlabdata2/home/artem/Data/cars_remeshed_dsdf/transforms/"

In [5]:
def visual_Mesh(ilatent, N):
    """
    Decode a latent code into a mesh, predict its pressure field and colour it.

    Relies on the notebook globals ``decoder``, ``predictor`` and ``AvgTransform``
    and prints the lift loss of the decoded shape.

    :param ilatent: latent code handed to ``create_mesh``
    :param N: grid resolution handed to ``create_mesh``
    :return: the coloured mesh built by ``get_trimesh_from_torch_geo_with_colors``
    """
    ply_mesh = create_mesh(decoder, ilatent, N=N, max_batch=int(2 ** 8))

    vertex_data = ply_mesh['vertex']
    xyz_columns = [vertex_data[axis][:, None] for axis in ('x', 'y', 'z')]
    points = torch.cuda.FloatTensor(np.hstack(xyz_columns))

    scaled_mesh = transform_mesh(points, ply_mesh, AvgTransform)
    pressure_field = predictor(scaled_mesh)
    loss = compute_lift_faces_diff(scaled_mesh, pressure_field)
    print("latent loss. %f "%(loss))
    return get_trimesh_from_torch_geo_with_colors(scaled_mesh, pressure_field)

In [6]:

class cma_es():
    """
    Covariance Matrix Adaptation Evolution Strategy (CMA-ES) minimiser.

    Two ways of driving it are provided:

    * ``optimise(obj)`` runs the whole sample / rank / adapt loop itself;
    * ``init(latent)`` followed by repeated ``step(...)`` calls lets the caller
      evaluate candidates and feed the results back one at a time.

    All tensors are created on the device of the start point (``x0`` for
    ``optimise``, ``latent`` for ``init``).
    """

    def __init__(self, dim=2):
        """Create an optimiser for a ``dim``-dimensional problem with default parameters."""
        self.dim = dim
        paras = {'x0': torch.zeros((dim,)),
                 'std': torch.ones((dim,)) * 3,
                 'tol': 1e-5,
                 'adjust_func': None,
                 'record': False,
                 'verbose': False}
        self.set_parameters(paras)

    def set_parameters(self, paras):
        """
        Configure the optimiser from a dict.

        Required keys: ``'x0'``, ``'std'``, ``'tol'``, ``'adjust_func'``.
        Optional keys: ``'max_iter'`` (400), ``'cluster_size'`` and
        ``'survival_size'`` (None selects the built-in defaults), ``'record'``
        (True) and ``'verbose'`` (True).
        """
        self.paras = paras
        self.x0 = paras['x0']
        self.std = paras['std']
        self.tol = paras['tol']
        self.adjust_func = paras['adjust_func']
        self.max_iter = paras.get('max_iter', 400)
        # set none to use default value
        self.cluster_size = paras.get('cluster_size')
        self.survival_size = paras.get('survival_size')
        self.record = paras.get('record', True)
        self.verbose = paras.get('verbose', True)

    @staticmethod
    def _symmetric_eig(matrix):
        """Return ``(eigenvalues, eigenvectors)`` of a symmetric matrix."""
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "eigh"):
            eigvals, eigvecs = linalg.eigh(matrix)
            return eigvals, eigvecs
        # very old PyTorch without torch.linalg: fall back to the legacy call
        eigvals, eigvecs = torch.eig(matrix, eigenvectors=True)
        return eigvals[:, 0], eigvecs

    def optimise(self, obj):
        '''
        @param obj: objective function class instance; ``obj.func(candidate)`` must
                    return ``(new_candidate, value, info)`` where ``info['evals']``
                    is the number of evaluations spent on that candidate
        return arg: found minimum arguments
               val: found minimum value
               stats: collection of recorded statistics for post-analysis
        '''
        if self.verbose:
            print("\n\n*******starting optimisation from intitial mean: ", torch.norm(self.x0).detach().cpu().numpy())

        # User defined input parameters
        device = self.x0.device
        dim = self.dim
        sigma = 0.3
        D = self.std.to(device) / sigma
        mean = self.x0.reshape(dim, 1).detach()
        # the size of solutions group
        lambda_ = 4 + int(3 * np.log(dim)) if self.cluster_size is None else self.cluster_size
        # only best "mu" solutions are used to generate iterations
        mu = int(lambda_ / 2) if self.survival_size is None else self.survival_size
        # used to combine best "mu" solutions (torch.log needs a tensor, not a float)
        weights = torch.log(torch.tensor(mu + 1 / 2)) - torch.log(torch.arange(mu, dtype=torch.float) + 1)
        weights = (weights / torch.sum(weights)).to(device)
        mueff = 1 / torch.sum(weights ** 2)

        # Strategy parameter setting: Adaptation
        # time constant for cumulation for C
        cc = (4 + mueff / dim) / (dim + 4 + 2 * mueff / dim)
        # t-const for cumulation for sigma control
        cs = (mueff + 2) / (dim + mueff + 5)
        # learning rate for rank-one update of C
        c1 = 2 / ((dim + 1.3) ** 2 + mueff)
        # and for rank-mu update
        cmu = min(1 - c1, 2 * (mueff - 2 + 1 / mueff) / ((dim + 2) ** 2 + mueff))
        # damping for sigma, usually close to 1
        damps = 1 + 2 * max(0, torch.sqrt((mueff - 1) / (dim + 1)) - 1) + cs

        # Initialize dynamic (internal) strategy parameters and constants
        # evolution paths for C and sigma
        pc = torch.zeros((dim, 1), device=device)
        ps = torch.zeros((dim, 1), device=device)
        # B defines the coordinate system
        B = torch.eye(int(dim), device=device)
        # covariance matrix C
        C = B @ torch.diag(D ** 2) @ B.transpose(1, 0)
        # C^-1/2
        invsqrtC = B @ torch.diag(D ** -1) @ B.transpose(1, 0)
        # expectation of ||N(0,I)|| == norm(randn(N,1))
        chiN = dim ** 0.5 * (1 - 1 / (4 * dim) + 1 / (21 * dim ** 2))

        # --------------------  Initialization --------------------------------
        start = mean.reshape(-1)
        x = start.repeat(lambda_, 1).to(device)
        fs = torch.full((lambda_,), 100.0, device=device)
        stats = {}
        stats['inner'] = []
        stats['val'], stats['arg'] = [], []
        stats['x_adjust'] = []
        stats['evals'] = []
        stats['mean'], stats['std'] = [], []
        stats['status'] = None
        iter_eval = torch.zeros((lambda_,))
        inner_stats = [{} for _ in range(lambda_)]
        evals_history = []
        iter_, eval_ = 0, 0
        sum_eval = 0.0
        pre_arg = x[0].clone()
        pre_val = fs[0].clone()
        best_val = fs[0] + 1e2
        best_arg = x[0].clone()

        # optimise by iterations
        while iter_ < self.max_iter:
            iter_ += 1
            # generate candidate solutions with some stochastic elements
            for i in range(lambda_):
                # noise is drawn on the CPU and moved, so the default generator
                # drives the sampling on every device
                noise = torch.randn(dim, 1).to(device)
                candidate = (mean + sigma * B @ torch.diag(D) @ noise).reshape(1, -1)
                candidate_new, val, inner_stats[i] = obj.func(candidate.requires_grad_(True))
                del candidate
                x[i] = candidate_new.detach().reshape(-1)
                fs[i] = val.detach()

                eval_ += inner_stats[i]['evals']
                iter_eval[i] = inner_stats[i]['evals']

            # sort the value and positions of solutions
            idx = torch.argsort(fs)
            x_ascending = x[idx]
            elite = x_ascending[:mu]

            # update the parameter for next iteration
            mean_old = mean
            mean = (weights @ elite).reshape(dim, 1)
            ps = (1 - cs) * ps + torch.sqrt(cs * (2 - cs) * mueff) * invsqrtC @ (mean - mean_old) / sigma
            # NOTE(review): the exponent uses the iteration count divided by the
            # population size, whereas ``step`` feeds a per-candidate counter into
            # the same formula -- confirm which one is intended.
            hsig = (torch.norm(ps) / torch.sqrt(1 - (1 - cs) ** (2 * iter_ / lambda_)) / chiN < 1.4 + 2 / (dim + 1)).int()
            pc = (1 - cc) * pc + hsig * torch.sqrt(cc * (2 - cc) * mueff) * (mean - mean_old) / sigma
            sigma = sigma * torch.exp((cs / damps) * (torch.norm(ps) / chiN - 1))

            # the covariance update uses the already updated step size
            artmp = (1 / sigma) * (elite - mean_old.reshape(1, dim))
            C = (1 - c1 - cmu) * C \
                + c1 * (pc * pc.transpose(1, 0) + (1 - hsig) * cc * (2 - cc) * C) \
                + cmu * artmp.transpose(1, 0) @ torch.diag(weights) @ artmp
            # enforce symmetry from the upper triangle before decomposing
            C = (torch.triu(C) + torch.triu(C, 1).transpose(1, 0))
            D, B = self._symmetric_eig(C)
            D = torch.sqrt(D)
            invsqrtC = B @ torch.diag(D ** -1) @ B.transpose(1, 0)
            arg = x_ascending
            val = fs[idx]
            if self.verbose:
                print("**************** cma iter: ", iter_, "**********************")
                print("loss: %.5f"%val[0].item())
                print("evals: ", iter_eval.sum())
                print("\n")
            # record data during process for post analysis
            if self.record:
                stats['arg'].append(x_ascending[0].cpu().numpy())
                stats['val'].append(fs[idx].detach().cpu().numpy())
                sum_eval += float(iter_eval.sum())
                evals_history.append(sum_eval)
            # keep track of the best candidate seen so far
            if best_val > val[0]:
                best_val = val[0]
                best_arg = arg[0]
            # check the stop condition
            if torch.max(D) > (torch.min(D) * 1e4):
                stats['status'] = 'diverge'
                print('diverge, concentrate in low dimension manifold')
                break
            dis_arg = torch.norm(arg - pre_arg, dim=1).mean()
            dis_val = torch.abs(val - pre_val).mean()
            if dis_arg < self.tol and dis_val < self.tol:
                break
            pre_arg = arg
            pre_val = val
        if self.verbose:
            print('total iterations = {}, total evaluatios = {}'.format(iter_, eval_))
            print('found minimum position = {}, found minimum = {}'.format(best_arg.detach().cpu().numpy()[:10], best_val.detach().cpu().numpy()))

        # carry statistics info before quit
        if self.record:
            stats['arg'] = np.array(stats['arg'])
            stats['val'] = np.array(stats['val'])
            # cumulative evaluations after each iteration; 'evals' itself keeps
            # holding the total count, as before
            stats['evals_history'] = np.array(evals_history)
        stats['evals'] = eval_
        return best_arg, best_val, stats

    def update_mean(self, xs):
        """Weighted recombination of the selected candidates ``xs`` into a column vector."""
        return (self.WEIGHTS @ xs).reshape(self.DIM, 1)

    def update_ps(self):
        """Return the updated evolution path used for step-size control."""
        return (1 - self.CS) * self.ps + torch.sqrt(self.CS * (2 - self.CS) * self.MUEFF) * self.invsqrtC @ (self.mean - self.mean_old) / self.sigma

    def update_pc(self, iter_):
        """Return the updated evolution path used for the rank-one covariance update."""
        hsig = (torch.norm(self.ps) / torch.sqrt(1 - (1 - self.CS)**(2 * iter_/self.LAMBDA_)) / self.chiN < 1.4 + 2/(self.DIM + 1)).int()
        return (1 - self.CC) * self.pc + hsig * torch.sqrt(self.CC * (2 - self.CC) * self.MUEFF) * (self.mean - self.mean_old) / self.sigma

    def update_C(self, iter_, xs):
        """Return the covariance matrix after the rank-one and rank-mu updates."""
        hsig = (torch.norm(self.ps) / torch.sqrt(1 - (1 - self.CS)**(2 * iter_/self.LAMBDA_)) / self.chiN < (1.4 + 2/(self.DIM + 1))).int()
        artmp = (1 / self.sigma) * (xs - self.mean_old.reshape(1, self.DIM))
        return (1 - self.C1 - self.CMU) * self.c + self.C1 *  \
                (self.pc * self.pc.transpose(1,0) + (1 - hsig) * self.CC * \
                (2 - self.CC) * self.c) + self.CMU * artmp.transpose(1,0) @  \
                torch.diag(self.WEIGHTS) @ artmp

    def update_sigma(self):
        """Return the updated global step size."""
        return self.sigma * torch.exp((self.CS / self.DAMPS) * (torch.norm(self.ps)/ self.chiN - 1))

    def init(self, latent):
        """
        Prepare the state used by ``step``.

        :param latent: start point of shape ``(1, DIM)``; its second dimension
            defines the problem size and its device is used for all state tensors
        """
        device = latent.device

        # User defined input parameters
        self.DIM = latent.shape[1]
        # the size of solutions group
        self.LAMBDA_ = 4 + int(3 * np.log(self.DIM)) if self.cluster_size is None else self.cluster_size
        # only best "mu" solutions are used to generate iterations
        self.MU = int(self.LAMBDA_ / 2) if self.survival_size is None else self.survival_size
        # used to combine best "mu" solutions
        self.WEIGHTS = torch.log(torch.tensor(self.MU + 1/2)) - torch.log(torch.arange(self.MU, dtype=torch.float) + 1)
        self.WEIGHTS = (self.WEIGHTS / torch.sum(self.WEIGHTS)).to(device)
        self.MUEFF = 1 / torch.sum(self.WEIGHTS**2)

        # Strategy parameter setting: Adaptation
        # time constant for cumulation for C
        self.CC = (4 + self.MUEFF / self.DIM) / (self.DIM + 4 + 2 * self.MUEFF / self.DIM)
        # t-const for cumulation for sigma control
        self.CS = (self.MUEFF + 2) / (self.DIM + self.MUEFF + 5)
        # learning rate for rank-one update of C
        self.C1 = 2 / ((self.DIM + 1.3)**2 + self.MUEFF)
        # and for rank-mu update
        self.CMU = min(1 - self.C1, 2 * (self.MUEFF - 2 + 1 / self.MUEFF) / ((self.DIM + 2)**2 + self.MUEFF))
        # damping for sigma, usually close to 1
        self.DAMPS = 1 + 2 * max(0, torch.sqrt((self.MUEFF - 1)/( self.DIM + 1)) - 1) + self.CS
        # expectation of ||N(0,I)|| == norm(randn(N,1))
        self.chiN = self.DIM**0.5 * (1 - 1/(4 * self.DIM) + 1 / (21 * self.DIM**2))

        self.sigma = 0.3
        self.d = self.std.to(device) / self.sigma
        self.mean = latent.detach().reshape(self.DIM, 1)
        # Initialize dynamic (internal) strategy parameters and constants
        # evolution paths for C and sigma
        self.pc = torch.zeros((self.DIM, 1), device=device)
        self.ps = torch.zeros((self.DIM, 1), device=device)
        # B defines the coordinate system
        self.b = torch.eye(int(self.DIM), device=device)
        # covariance matrix C
        self.c = self.b @ torch.diag(self.d**2) @ self.b.transpose(1, 0)
        # C^-1/2
        self.invsqrtC = self.b @ torch.diag(self.d**-1) @ self.b.transpose(1, 0)
        # candidates / values collected for the current generation
        self.xs = []
        self.fs = []
        self.j = 0

    def add_inner(self, opt):
        """Attach an inner optimiser object; it is only stored here."""
        self.opt = opt

    def step(self, i, x, val, grad):
        """
        Feed back one evaluated candidate and draw the next one.

        :param i: running call counter; every ``LAMBDA_``-th call (except call 0)
            triggers a distribution update instead of storing ``x`` / ``val``
        :param x: the candidate that was just evaluated
        :param val: objective value of ``x``
        :param grad: unused
        :return: next candidate of shape ``(1, DIM)``
        """
        if i % self.LAMBDA_ != 0 or i == 0:
            self.xs.append(x.detach().squeeze())
            self.fs.append(val)
        else:
            idx = torch.argsort(torch.tensor(self.fs))
            x_ascending = torch.stack(self.xs)[idx]

            # update the parameter for next iteration
            self.mean_old = self.mean
            self.mean = self.update_mean(x_ascending[:self.MU])
            self.ps = self.update_ps()
            self.pc = self.update_pc(i)
            self.sigma = self.update_sigma()
            self.c = self.update_C(i, x_ascending[:self.MU])
            self.c = (torch.triu(self.c) + torch.triu(self.c, 1).transpose(1,0))
            self.d, self.b = self._symmetric_eig(self.c)
            self.d = torch.sqrt(self.d)
            self.invsqrtC = self.b @ torch.diag(self.d**-1) @ self.b.transpose(1,0)
            self.xs = []
            self.fs = []
            self.j += 1
            print("********** %d th iter of CMA completed************"%(self.j))
        noise = torch.randn(self.DIM, 1).to(self.mean.device)
        x_new = self.mean + self.sigma * self.b @ torch.diag(self.d) @ noise
        return x_new.reshape(1, self.DIM)

In [7]:
def method4_to_arbitatry_loss(points, ply_mesh, model, constraint_rad=0.1, axis=0):
    """Compute per-vertex multipliers from two gradients with respect to ``points``.

    On entry ``points.grad`` must already hold a gradient (the caller fills it
    by back-propagating the SDF value).  That gradient is copied, the buffer is
    cleared, and the surrogate loss is back-propagated into ``points`` instead.
    For every vertex the negated dot product of the two gradients is returned.

    Args:
        points: Vertex tensor with a populated ``.grad``.
        ply_mesh: Mesh data forwarded to ``transform_mesh``.
        model: Callable applied to the transformed mesh to get predictions.
        constraint_rad: Forwarded to ``calculate_loss``.
        axis: Unused here; kept so existing callers keep working.

    Returns:
        ``(sign, loss, local_preds, mesh)`` where ``sign`` is a list holding
        one scalar tensor per vertex.
    """
    sdf_grad = points.grad.clone()
    points.grad.data.zero_()

    mesh = transform_mesh(points, ply_mesh, AvgTransform)
    local_preds = model(mesh)
    loss = calculate_loss(mesh, local_preds, constraint_rad=constraint_rad)
    loss.backward()

    # After backward(), points.grad holds d(loss)/d(points).
    sign = []
    for sdf_dir, loss_dir in zip(sdf_grad, points.grad):
        sign.append(-sdf_dir.dot(loss_dir))

    return sign, loss, local_preds, mesh

def optimize_shape_deepSDF(decoder, latent, initial_points=None, num_points=None, 
                           num_iters=100, point_iters=100, num_neignours_constr=10,
                           lr=0.2, decreased_by=2, adjust_lr_every=10, alpha_penalty=0.05,
                           multiplier_func=method4_to_arbitatry_loss, verbose=None, save_to_dir=None, N=256):
    """Optimise a latent code with CMA-ES against a surrogate loss plus a penalty.

    Each iteration meshes the current latent code, back-propagates the SDF
    value to the mesh vertices, asks ``multiplier_func`` for per-vertex
    multipliers and the physical loss, and builds the scalar objective
    ``sum(sdf * multipliers) + alpha_penalty * mean squared distance to the
    nearest stored latent codes``.  That scalar is handed to the CMA-ES
    optimizer, which returns the next latent code to evaluate.

    Relies on the module-level names ``cma_es``, ``create_mesh``,
    ``decode_sdf``, ``LATENT_KD_TREE`` and ``latent_vectors``.

    Args:
        decoder: Network passed to ``create_mesh`` and ``decode_sdf``.
        latent: Starting latent code; assumed to be of shape ``(1, D)`` on the
            GPU -- TODO confirm against the loader.
        num_iters: Number of objective evaluations.
        num_neignours_constr: Number of nearest stored latent codes used by
            the penalty term.
        alpha_penalty: Weight of the penalty term.
        multiplier_func: Callable ``(points, ply_mesh)`` returning
            ``(multipliers, loss, predictions, mesh)``.  The default needs a
            ``model`` argument as well, so callers must pass a wrapper that
            binds it.
        verbose: Print the losses every ``verbose`` iterations; ``None`` or
            ``0`` disables it.
        save_to_dir: Directory for the ``.npy`` progress files; ``None``
            disables saving.
        N: Grid resolution forwarded to ``create_mesh``.
        initial_points, num_points, point_iters, lr, decreased_by,
        adjust_lr_every: Unused; kept for interface compatibility.

    Returns:
        ``(loss_plot, math_loss_plot, penalty_plot, latent_plot)``.
        ``latent_plot[i]`` is the candidate returned by the optimizer after
        iteration ``i``, i.e. the code that is evaluated in iteration ``i + 1``.
    """
    if save_to_dir is not None:
        for sub_dir in ('meshes', 'predictions'):
            os.makedirs(os.path.join(save_to_dir, sub_dir), exist_ok=True)

    decoder.eval()
    # Work on an independent leaf tensor so the caller's tensor is untouched.
    latent = latent.clone().detach().requires_grad_(True)
    latent_dim = latent.shape[-1]

    optimizer = cma_es(dim=latent_dim)
    optParas ={'x0': latent,
           'std': torch.ones((latent_dim,), device=latent.device) * 0.03, 
           'tol': 1e-6, 
           'adjust_func': None, 
           'record': True, 
           'max_iter': 50,
           'cluster_size': 6,
           'verbose': True}
    optimizer.set_parameters(optParas)
    optimizer.init(latent)

    loss_plot = []
    penalty_plot = []
    latent_plot = []
    math_loss_plot = []
    for i in range(num_iters):
        start = time.time()
        with torch.no_grad():
            ply_mesh = create_mesh( decoder,
                                    latent,
                                    N=N,
                                    max_batch=int(2 ** 18),
                                    offset=None,
                                    scale=None)
        end = time.time()
        print("mesh time: %.1f "%(end-start))

        points = torch.tensor(np.hstack(( ply_mesh['vertex']['x'][:, None], 
                                                    ply_mesh['vertex']['y'][:, None], 
                                                    ply_mesh['vertex']['z'][:, None]))).cuda(0)
        points.requires_grad = True

        # Fill points.grad with d(SDF)/d(points); multiplier_func reads it.
        sdf_value = decode_sdf(decoder, latent, points)
        sdf_value.backward(torch.ones([len(points), 1], dtype=torch.float32).cuda(0))

        mults, loss_value, preds, transformed_mesh = multiplier_func(points, ply_mesh)         
        multipliers = torch.cuda.FloatTensor(mults).cuda(0)

        # CMA-ES is derivative-free: only the objective value is needed, so
        # skip graph construction and keep the logged tensors graph-free.
        with torch.no_grad():
            sdf_value = torch.squeeze(decode_sdf(decoder, latent, points.detach()))
            final_loss = torch.sum(sdf_value * multipliers)

            # Soft-constraints
            _, indeces = LATENT_KD_TREE.query(latent.cpu().detach(), k=num_neignours_constr)
            # indeces has one row per query; taking row 0 keeps a 1-D index
            # even when a single neighbour is requested.
            apenalty = alpha_penalty * torch.sum((latent - latent_vectors[indeces[0]]) ** 2, dim=2).mean()

            math_loss = apenalty + final_loss

        latent = optimizer.step(i, latent, math_loss.item(), None)
        latent = latent.detach().requires_grad_(True)

        end_end = time.time()
        print("backward time: %.2f"%(end_end-end))

        math_loss_plot.append(math_loss)
        penalty_plot.append(apenalty)
        loss_plot.append(loss_value.cpu().detach().numpy())
        latent_plot.append(latent.detach().clone())
        if save_to_dir is not None:
            # Convert to host-side numeric arrays; NumPy cannot serialise GPU tensors.
            np.save(os.path.join(save_to_dir, "phy_loss_plot.npy"), loss_plot)    
            np.save(os.path.join(save_to_dir, "latent_series.npy"),
                    np.stack([code.cpu().numpy() for code in latent_plot]))    
            np.save(os.path.join(save_to_dir, "math_loss_plot.npy"),
                    np.array([value.item() for value in math_loss_plot]))   

        if verbose and i % verbose == 0:
            print('Iter ', i)
            print('phys Loss: %.5f'%loss_value.item())
            print('apenality: %.4f'%apenalty.item())
            print("math_loss: %.4f"%(math_loss.item()))
        print("\n")
        if i > 2 and np.abs(loss_plot[-1] - loss_plot[-2]) < 1e-4:
            print("one time of low progress!")

    return loss_plot, math_loss_plot, penalty_plot, latent_plot



def make_full_transformation(initial_latent, experiment_name, 
                             decoder, model, alpha_penalty=0.05, constraint_rad=0.1, axis=0, **kwargs):
    '''
    Run the latent-code optimisation for one experiment.

    Creates the directory ``experiment_name`` if it does not exist, binds
    ``model``, ``constraint_rad`` and ``axis`` into the multiplier function and
    delegates to ``optimize_shape_deepSDF``, whose result is returned as is.

    kwargs are forwarded to ``optimize_shape_deepSDF``, for example:
        num_iters=1000, 
        adjust_lr_every=10, 
        decreased_by=1.2,
        lr=0.005
        verbose=10,
    '''
    output_dir = experiment_name
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    def bound_multiplier_func(x, y):
        # Same two-argument interface the optimisation loop expects.
        return method4_to_arbitatry_loss(x, y, model,
                                         constraint_rad=constraint_rad,
                                         axis=axis)

    result = optimize_shape_deepSDF(decoder, initial_latent,
                                    initial_points=None,
                                    alpha_penalty=alpha_penalty,
                                    num_points=None,
                                    point_iters=2,
                                    multiplier_func=bound_multiplier_func,
                                    save_to_dir=output_dir,
                                    **kwargs)
    return result
   

In [None]:
# Starting point of the optimisation: entry 32 of the loaded latent vectors.
LATENT_TO_OPTIMIZE = latent_vectors[32]
# Directory handed to make_full_transformation as experiment_name.
DIR_for_dump_data = './cma'
# NOTE(review): these two settings are not passed to the call below; presumably
# they are read as module-level globals by a learning-rate schedule -- confirm
# they are still needed for the CMA-ES run.
punch_lr_at_reindex_by=1
reindex_latent_each = 10000

# Seed NumPy and PyTorch before the run.
np.random.seed(101)
torch.manual_seed(0)
# Time the whole run; the tuple returned by make_full_transformation is kept in resCMA.
%time resCMA = make_full_transformation(LATENT_TO_OPTIMIZE.detach(), \
                         experiment_name=DIR_for_dump_data, decoder=decoder, model=predictor, \
                         alpha_penalty=0.2, axis=0, \
                         constraint_rad=0.05,  \
                         num_iters=360,  \
                         adjust_lr_every=20,  \
                         decreased_by=1.1,  \
                         lr=0.2,  \
                         verbose=1, \
                         N=256,  \
                         num_neignours_constr=10)

mesh time: 11.4 
three parts (321) of loss: 0.000, 0.000, 0.097
backward time: 35.28
Iter  0
phys Loss: 0.09732
apenality: 0.0797
math_loss: 0.0796


mesh time: 11.3 
three parts (321) of loss: 0.000, 0.000, 0.070
backward time: 34.80
Iter  1
phys Loss: 0.07042
apenality: 0.1166
math_loss: 0.1165


mesh time: 11.2 
three parts (321) of loss: 0.000, 0.000, 0.087
backward time: 32.49
Iter  2
phys Loss: 0.08696
apenality: 0.1405
math_loss: 0.1404


mesh time: 11.1 
three parts (321) of loss: 0.000, 0.000, 0.093
backward time: 35.62
Iter  3
phys Loss: 0.09329
apenality: 0.1251
math_loss: 0.1250


mesh time: 11.3 
three parts (321) of loss: 0.000, 0.000, 0.093
backward time: 36.37
Iter  4
phys Loss: 0.09337
apenality: 0.1249
math_loss: 0.1248


one time of low progress!
mesh time: 11.1 
three parts (321) of loss: 0.000, 0.000, 0.085
backward time: 33.73
Iter  5
phys Loss: 0.08528
apenality: 0.1239
math_loss: 0.1239


mesh time: 11.2 
three parts (321) of loss: 0.000, 0.000, 0.106
torch.Size