In [None]:
import os
import sys
import time
import numpy as np
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.measure
import plyfile
from plyfile import PlyData
from sklearn.neighbors import KDTree
import trimesh
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import torch_geometric
from torch_geometric.nn import (NNConv, GMMConv, GraphConv, Set2Set)
from torch_geometric.nn import (SplineConv, graclus, max_pool, max_pool_x, global_mean_pool)

In [None]:
# Environment check: PyTorch version and whether CUDA is usable (the notebook calls .cuda() throughout).
print(torch.__version__)
torch.cuda.is_available()

In [None]:
# Report the installed CUDA compiler (toolkit) version.
! nvcc --version

In [None]:
# Show the visible GPUs and driver status.
! nvidia-smi

In [None]:
! mkdir networks
! mv ../deep_sdf_decoder.py networks/

In [None]:
# File / sub-directory names of the DeepSDF experiment layout.
# NOTE(review): none of these appear to be referenced by the visible notebook code
# (load_model uses the literal "specs.json") — confirm before relying on them.
model_params_subdir = "ModelParameters"
optimizer_params_subdir = "OptimizerParameters"
latent_codes_subdir = "LatentCodes"
logs_filename = "Logs.pth"
reconstructions_subdir = "Reconstructions"
reconstruction_meshes_subdir = "Meshes"
reconstruction_codes_subdir = "Codes"
specifications_filename = "specs.json"
data_source_map_filename = ".datasources.json"
evaluation_subdir = "Evaluation"
sdf_samples_subdir = "SdfSamples"
surface_samples_subdir = "SurfaceSamples"
normalization_param_subdir = "NormalizationParameters"
training_meshes_subdir = "TrainingMeshes"

In [None]:
def load_latent_vectors(experiment_directory, checkpoint):
    """Load the trained latent-code table stored in ``<experiment_directory>/<checkpoint>.pth``.

    :param experiment_directory: directory that holds the checkpoint file
    :param checkpoint: checkpoint file name without the ``.pth`` extension
    :returns: the checkpoint's ``"latent_codes"`` entry, moved to the GPU
    :raises Exception: if the checkpoint file does not exist
    """
    filename = os.path.join(experiment_directory, checkpoint + ".pth")
    if not os.path.isfile(filename):
        # Format the concatenated message as a whole: previously `.format` bound only to
        # the second literal, so the directory was dropped and put in the checkpoint slot.
        raise Exception(
            ("The experiment directory ({}) does not include a latent code file"
             + " for checkpoint '{}'").format(experiment_directory, checkpoint)
        )
    data = torch.load(filename)
    return data["latent_codes"].cuda()

def load_model(experiment_directory, checkpoint):
    """Build the DeepSDF decoder described by ``specs.json`` and load its trained weights.

    :param experiment_directory: directory containing ``specs.json`` and ``<checkpoint>.pth``
    :param checkpoint: checkpoint file name without the ``.pth`` extension
    :returns: the unwrapped decoder on the GPU, in eval mode
    :raises Exception: if ``specs.json`` is missing
    """
    specs_filename = os.path.join(experiment_directory, "specs.json")

    if not os.path.isfile(specs_filename):
        raise Exception(
            'The experiment directory does not include specifications file "specs.json"'
        )

    # Context manager so the specs file handle is closed immediately.
    with open(specs_filename) as specs_file:
        specs = json.load(specs_file)

    # The decoder class lives in the local `networks` package, chosen by the specs.
    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    latent_size = specs["CodeLength"]
    decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"])

    # Wrap before loading: the checkpoint was presumably saved from a DataParallel model
    # (keys prefixed with "module."); the wrapper is stripped again below.
    decoder = torch.nn.DataParallel(decoder)

    saved_model_state = torch.load(
        os.path.join(experiment_directory, checkpoint + ".pth")
    )
    decoder.load_state_dict(saved_model_state["model_state_dict"])

    decoder = decoder.module.cuda()
    decoder.eval()

    return decoder

def create_mesh(decoder, latent_vec, filename='', N=256, max_batch=32 ** 3, offset=None, scale=None):
    """Sample the decoder's SDF on an N^3 grid over [-1, 1]^3 and extract its zero level set.

    :param decoder: network evaluated through ``decode_sdf`` on ``[latent, xyz]`` rows
    :param latent_vec: latent code of the shape (may be None for a latent-free decoder)
    :param filename: base name; ``filename + ".ply"`` is forwarded to the converter
    :param N: grid resolution per axis
    :param max_batch: number of grid points decoded per forward pass
    :param offset: optional offset forwarded to ``convert_sdf_samples_to_ply``
    :param scale: optional scale forwarded to ``convert_sdf_samples_to_ply``
    :returns: the PlyData returned by ``convert_sdf_samples_to_ply``
    """
    ply_filename = filename

    decoder.eval()

    # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
    voxel_origin = [-1, -1, -1]
    voxel_size = 2.0 / (N - 1)

    num_samples = N ** 3
    overall_index = torch.arange(0, num_samples, 1, dtype=torch.long)
    samples = torch.zeros(num_samples, 4)

    # Columns 0..2 hold the (x, y, z) grid index. Floor division is required: on current
    # PyTorch `/` between integer tensors is true division and yields fractional indices.
    samples[:, 2] = overall_index % N
    samples[:, 1] = (overall_index // N) % N
    samples[:, 0] = ((overall_index // N) // N) % N

    # Turn grid indices into coordinates (all origin components are -1, so the
    # axis/origin pairing is symmetric).
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]

    samples.requires_grad = False

    # Decode in chunks of max_batch points; column 3 receives the SDF values.
    for head in range(0, num_samples, max_batch):
        tail = min(head + max_batch, num_samples)
        sample_subset = samples[head:tail, 0:3].cuda()
        samples[head:tail, 3] = (
            decode_sdf(decoder, latent_vec, sample_subset).squeeze(1).detach().cpu()
        )

    sdf_values = samples[:, 3].reshape(N, N, N)

    return convert_sdf_samples_to_ply(
        sdf_values.data.cpu(),
        voxel_origin,
        voxel_size,
        ply_filename + ".ply",
        offset,
        scale,
    )

def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    voxel_grid_origin,
    voxel_size,
    ply_filename_out,
    offset=None,
    scale=None,
):
    """
    Convert sdf samples to an in-memory ``plyfile.PlyData`` mesh of the zero level set.

    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
    :param voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
    :param voxel_size: float, the size of the voxels
    :param ply_filename_out: string, intended output path. NOTE: nothing is written here;
        callers persist the result themselves (e.g. ``ply_data.write(path)``)
    :param offset: optional, subtracted from the vertices after ``scale`` is applied
    :param scale: optional, the vertices are divided by it
    :returns: PlyData with "vertex", "face" and "normals" elements

    This function adapted from: https://github.com/RobotLocomotion/spartan
    """
    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()

    # scikit-image 0.19 removed marching_cubes_lewiner; its replacement
    # skimage.measure.marching_cubes uses the Lewiner method by default.
    marching_cubes = getattr(skimage.measure, "marching_cubes_lewiner", None)
    if marching_cubes is None:
        marching_cubes = skimage.measure.marching_cubes
    verts, faces, normals, _values = marching_cubes(
        numpy_3d_sdf_tensor, level=0.0, spacing=[voxel_size] * 3
    )

    # transform from voxel coordinates to camera coordinates
    # note x and y are flipped in the output of marching_cubes
    mesh_points = verts + np.asarray(voxel_grid_origin, dtype=verts.dtype)

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    num_verts = verts.shape[0]
    xyz_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4")]
    verts_tuple = np.zeros((num_verts,), dtype=xyz_dtype)
    norms_tuple = np.zeros((num_verts,), dtype=xyz_dtype)
    # Fill the structured arrays column-wise instead of one Python tuple per vertex.
    for column, field in enumerate(("x", "y", "z")):
        verts_tuple[field] = mesh_points[:, column]
        norms_tuple[field] = normals[:, column]

    faces_tuple = np.empty((faces.shape[0],), dtype=[("vertex_indices", "i4", (3,))])
    faces_tuple["vertex_indices"] = faces

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
    el_norms = plyfile.PlyElement.describe(norms_tuple, "normals")

    return plyfile.PlyData([el_verts, el_faces, el_norms])


def make_mesh_from_points(points, ply_mesh):
    """Build the graph dict consumed by the pressure model from mesh vertices.

    :param points: vertex positions, tensor (N, 3); may require grad
    :param ply_mesh: PlyData whose ``face`` element provides the triangle indices
    :returns: dict with ``x`` (transformed vertices, differentiable w.r.t. ``points``),
        ``face`` (3, F), ``edge_attr`` (E, 3) vertex-position differences and
        ``edge_index`` (2, E), all on the device of the transformed points
    """
    transformed_points = transformPoints(points)
    device = transformed_points.device

    face_indices = ply_mesh['face']['vertex_indices']
    edges = np.asarray(trimesh.geometry.faces_to_edges(face_indices), dtype=np.int64)

    # Edge features come from a detached copy, so no gradient flows through them.
    # Vectorised gather replaces a per-edge Python loop.
    np_points = transformed_points.cpu().detach().numpy()
    edge_attr = np_points[edges[:, 0]] - np_points[edges[:, 1]]

    mesh = {
        'x': transformed_points,
        'face': torch.tensor(face_indices, dtype=torch.long).to(device).t(),
        'edge_attr': torch.from_numpy(edge_attr).float().to(device),
        'edge_index': torch.from_numpy(edges).t().contiguous().to(device),
    }
    return mesh

def transformPoints(points, matrix=None):
    """Apply a 4x4 homogeneous transform to an (N, 3) tensor of points.

    :param points: tensor of shape (N, 3)
    :param matrix: 4x4 array-like transform; defaults to the notebook-level ``AvgTransform``
    :returns: tensor (N, 3) on ``points.device``, differentiable w.r.t. ``points``
    """
    if matrix is None:
        matrix = AvgTransform
    # Create the matrix and the homogeneous column on the points' own device/dtype
    # (previously the matrix used the default CUDA device and the column cuda:0).
    matrix = torch.as_tensor(matrix, dtype=points.dtype, device=points.device)
    ones = torch.ones((len(points), 1), dtype=points.dtype, device=points.device)
    homogeneous = torch.cat([points, ones], dim=1)
    return torch.matmul(matrix, homogeneous.t()).t()[:, :3]

def get_trimesh_from_torch_geo_with_colors(mesh, preds, vmin=-8, vmax=8):
    """Build a ``trimesh.Trimesh`` whose vertices are coloured by ``preds[:, 0]``.

    :param mesh: graph dict with ``x`` (N, 3) vertices and ``face`` (3, F) indices
    :param preds: per-vertex predictions; column 0 is mapped through ``cm.hot``
    :param vmin: lower bound of the colour normalisation
    :param vmax: upper bound of the colour normalisation
    """
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=norm, cmap=cm.hot)

    vertices = mesh['x'].cpu().detach()
    faces = mesh['face'].t().cpu().detach()
    # One vectorised colormap lookup giving an (N, 4) RGBA array,
    # instead of one Python-level to_rgba call per vertex.
    colors = mapper.to_rgba(preds[:, 0].cpu().detach().numpy())
    return trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)


def decode_sdf(decoder, latent_vector, queries):
    """Run ``decoder`` on the query points, prefixing each row with the latent code if given.

    With a latent, every row becomes ``[latent_vector, query]`` (the latent broadcast over
    all rows); without one the queries go to the decoder unchanged.
    """
    if latent_vector is not None:
        broadcast_latent = latent_vector.expand(queries.shape[0], -1)
        queries = torch.cat([broadcast_latent, queries], 1)
    return decoder(queries)
def compute_lift_faces_diff(data_instance, answers, axis=0):
    """Sum over faces of (mean vertex pressure) * (-n_axis / 2).

    ``n`` is the unnormalised face normal ``(v1 - v0) x (v2 - v0)``; faces whose term is
    NaN are skipped.

    :param data_instance: graph dict with ``x`` (N, 3) vertex positions and ``face``
        (3, F) vertex indices, one column per triangle
    :param answers: per-vertex predictions; column 0 is used as the pressure
    :param axis: component (0, 1 or 2) of the face normal to project onto
    :returns: scalar tensor
    """
    faces = data_instance['face']
    # Face pressure = mean of its three vertex pressures.
    pressures = torch.mean(answers[faces, 0], dim=0)

    pos = data_instance['x']
    v0, v1, v2 = pos[faces[0]], pos[faces[1]], pos[faces[2]]
    # dim=1 must be explicit: without it torch.cross uses the FIRST dimension of size 3,
    # which is the face dimension whenever the mesh has exactly three faces.
    cross_prod = torch.cross(v1 - v0, v2 - v0, dim=1)
    mult = -cross_prod[:, axis] / 2
    lift = torch.mul(pressures, mult)
    return torch.sum(lift[~torch.isnan(lift)])

def boundsLoss(points, box=[(-1, 1, 0)]):
    """Hinge penalty for coordinates leaving an axis-aligned interval.

    For every ``(lower, upper, axis)`` entry of ``box`` adds
    ``mean(relu(lower - p[:, axis])) + mean(relu(p[:, axis] - upper))``.
    Returns 0 for an empty ``box``.
    """
    total = 0
    for lower, upper, axis in box:
        coord = points[:, axis]
        below = torch.mean(F.relu(lower - coord))
        above = torch.mean(F.relu(coord - upper))
        total += below + above
    return total

def innerBoundsLoss(points, r=1, center=(0, 0, 0)):
    """Penalise points that lie inside a ball around ``center``.

    :param points: tensor (N, 3)
    :param r: threshold compared against the SQUARED distance (callers pass radius ** 2)
    :param center: ball centre
    :returns: ``mean(relu(r - |p - center|^2))``
    """
    # Centre is created on the points' device/dtype (was hardcoded to cuda:0).
    center_t = torch.as_tensor(center, dtype=points.dtype, device=points.device)
    squared_dist = torch.sum((points - center_t) ** 2, dim=1)
    return torch.mean(F.relu(r - squared_dist))

def calculate_loss(mesh, local_preds, axis=0, constraint_rad=0.1, verbose=True):
    """Shape-optimisation objective: lift term plus geometric penalties.

    :param mesh: graph dict from ``make_mesh_from_points``
    :param local_preds: per-vertex predictions of the pressure model for ``mesh``
    :param axis: blends the lift terms, ``(1 - axis) * lift_x + axis * lift_y``
    :param constraint_rad: radius of the keep-out ball centred at (-0.05, 0.05, 0);
        a second ball of half that radius is centred at (0.3, 0, 0)
    :param verbose: print each component (the previous, default behaviour)
    :returns: scalar loss tensor
    """
    lift = (1 - axis) * compute_lift_faces_diff(mesh, local_preds, axis=0) + \
                 axis * compute_lift_faces_diff(mesh, local_preds, axis=1)
    # Keep the x coordinates within [-0.6, 0.6].
    bounds = boundsLoss(mesh['x'], box=[(-0.6, 0.6, 0)])
    # innerBoundsLoss compares squared distances, hence the squared radii.
    inner = innerBoundsLoss(mesh['x'], r=constraint_rad**2, center=(-0.05, 0.05, 0)) \
          + innerBoundsLoss(mesh['x'], r=(constraint_rad / 2)**2, center=(0.3, 0, 0))
    if verbose:
        # Print each component directly rather than as differences of running totals.
        print("first part of loss: ", lift.detach().cpu().numpy())
        print("second part of loss: ", bounds.detach().cpu().numpy())
        print("third part of loss: ", inner.detach().cpu().numpy())
    return lift + bounds + inner

def soft_constraints(latent, latent_vectors, num_neignours_constr, alpha_penalty, kd_tree=None):
    """Penalty pulling ``latent`` towards its nearest training latent codes.

    :param latent: latent code being optimised (the KD-tree query point)
    :param latent_vectors: table of training latents, indexed by neighbour id
    :param num_neignours_constr: number of nearest neighbours k
    :param alpha_penalty: weight applied to the mean squared distance
    :param kd_tree: object with an sklearn-style ``query(X, k)``; defaults to ``LATENT_KD_TREE``
    :returns: ``alpha_penalty * mean_k ||latent - neighbour_k||^2`` (differentiable w.r.t. ``latent``)
    """
    if kd_tree is None:
        kd_tree = LATENT_KD_TREE
    # Neighbour search on a detached CPU copy (one query row); gradient flows only
    # through the distance computation below.
    query = latent.detach().cpu().numpy().reshape(1, -1)
    _, indices = kd_tree.query(query, k=num_neignours_constr)
    index = torch.as_tensor(np.asarray(indices[0]), dtype=torch.long, device=latent_vectors.device)
    neighbours = latent_vectors[index]
    # Flatten per-neighbour differences so both (K, D) and (K, 1, D) tables work, also for k == 1.
    diffs = (latent - neighbours).reshape(len(index), -1)
    penalty = torch.sum(diffs ** 2, dim=1).mean()
    return penalty * alpha_penalty

def computeAvgTransform(transforms_dir="/cvlabdata2/home/artem/Data/cars_remeshed_dsdf/transforms/"):
    """Average every ``.npy`` transform matrix found (recursively) under ``transforms_dir``.

    :param transforms_dir: directory to search; defaults to the original hard-coded location
    :returns: element-wise mean of the loaded matrices
    :raises FileNotFoundError: if no ``.npy`` file is found (previously a silent NaN result)
    """
    objects = []
    for dirpath, _dirnames, filenames in os.walk(transforms_dir):
        objects += [os.path.join(dirpath, name) for name in filenames if name.endswith('.npy')]
    if not objects:
        raise FileNotFoundError("No .npy transform matrices found under {}".format(transforms_dir))
    # Sort so the summation order (and thus float rounding) does not depend on os.walk order.
    matrices = [np.load(path) for path in sorted(objects)]
    return np.mean(np.array(matrices), axis=0)


class SplineBlock(nn.Module):
    """conv1 -> [BN] -> ELU -> conv2 -> BN -> ELU -> concat(data['x']) -> conv3.

    The vertex features ``data['x']`` are re-injected before ``conv3``, whose input width
    is ``2 * mid_features + 3`` (assumes ``data['x']`` has 3 columns).
    """
    def __init__(self, num_in_features, num_outp_features, mid_features, kernel=3, dim=3, batchnorm1=True):
        super(SplineBlock, self).__init__()
        self.conv1 = SplineConv(num_in_features, mid_features, dim, kernel, is_open_spline=False)
        # Keep the submodule name `batchnorm1` so checkpoint keys still match; use None
        # (instead of overloading the attribute with the boolean flag) when disabled.
        self.batchnorm1 = torch.nn.BatchNorm1d(mid_features) if batchnorm1 else None
        self.conv2 = SplineConv(mid_features, 2 * mid_features, dim, kernel, is_open_spline=False)
        self.batchnorm2 = torch.nn.BatchNorm1d(2 * mid_features)
        self.conv3 = SplineConv(2 * mid_features + 3, num_outp_features, dim, kernel, is_open_spline=False)

    def forward(self, res, data):
        edge_index, edge_attr = data['edge_index'], data['edge_attr']
        res = self.conv1(res, edge_index, edge_attr)
        if self.batchnorm1 is not None:
            res = self.batchnorm1(res)
        res = F.elu(res)
        res = F.elu(self.batchnorm2(self.conv2(res, edge_index, edge_attr)))
        res = torch.cat([res, data['x']], dim=1)
        return self.conv3(res, edge_index, edge_attr)

class SplineCNN8Residuals(nn.Module):
    """Eight SplineBlocks with two skip connections; 4 output channels per vertex.

    Stage 1 (blocks 1-4) produces an 8-channel code ``skip``. Stage 2 (blocks 5-7)
    processes ``[skip, x]``, and block 8 reads ``[stage2, skip, x]``. The input widths
    11 = 8 + 3 and 75 = 64 + 8 + 3 assume ``data['x']`` has 3 columns.
    """
    def __init__(self, num_features, kernel=3, dim=3):
        super(SplineCNN8Residuals, self).__init__()
        # Attribute names and creation order are kept: they define state_dict keys and init order.
        self.block1 = SplineBlock(num_features, 16, 8, kernel, dim)
        self.block2 = SplineBlock(16, 64, 32, kernel, dim)
        self.block3 = SplineBlock(64, 64, 128, kernel, dim)
        self.block4 = SplineBlock(64, 8, 16, kernel, dim)
        self.block5 = SplineBlock(11, 32, 16, kernel, dim)
        self.block6 = SplineBlock(32, 64, 32, kernel, dim)
        self.block7 = SplineBlock(64, 64, 128, kernel, dim)
        self.block8 = SplineBlock(75, 4, 16, kernel, dim)

    def forward(self, data):
        coords = data['x']

        hidden = coords
        for block in (self.block1, self.block2, self.block3):
            hidden = block(hidden, data)
        skip = self.block4(hidden, data)

        hidden = torch.cat([skip, coords], dim=1)
        for block in (self.block5, self.block6, self.block7):
            hidden = block(hidden, data)

        return self.block8(torch.cat([hidden, skip, coords], dim=1), data)


In [None]:
# Load the pretrained DeepSDF decoder, its training latent codes and the pressure (CFD) model.
DIR_for_dump_data = './starting_data'
experiment_directory = DIR_for_dump_data



decoder = load_model(experiment_directory, "decoderModel")
latent_vectors = load_latent_vectors(experiment_directory, "latentCodes")
latent_vectors = latent_vectors.detach()

# Starting shape of the optimisation (index 32; no rationale given for this choice).
LATENT_TO_OPTIMIZE = latent_vectors[32]
# KD-tree over the training latents for the nearest-neighbour soft constraint.
# `[0]` drops a leading axis — assumes each stored code is (1, D); TODO confirm.
LATENT_KD_TREE = KDTree(np.array([lv.cpu().detach().numpy()[0] for lv in latent_vectors]))
# Cached 4x4 transform read as a global by transformPoints (presumably a saved computeAvgTransform() result).
AvgTransform = np.load(DIR_for_dump_data + "/avg_trans_matrix.npy") #computeAvgTransform()

# Pressure predictor on 3-D vertex features, on the GPU in eval mode.
model = SplineCNN8Residuals(3)
model.load_state_dict(torch.load(experiment_directory + "/cfdModel.nn"))
model = model.to("cuda:0")
model = model.eval()

In [None]:
# Extract the starting mesh (as PlyData) from latent code 32 on a 256^3 grid.
initial_la = latent_vectors[32]
ply_mesh = create_mesh( decoder,
                        initial_la,
                        N=256,
                        max_batch=int(2 ** 18),
                        offset=None,
                        scale=None)

In [None]:
# Stack the PLY vertex fields into an (N, 3) CUDA tensor; gradients w.r.t. the
# vertex positions are needed by the shape-derivative cells below.
points = torch.cuda.FloatTensor(np.hstack(( ply_mesh['vertex']['x'][:, None], 
                                            ply_mesh['vertex']['y'][:, None], 
                                            ply_mesh['vertex']['z'][:, None])))

# from mesh to pressure field
points.requires_grad = True
mesh = make_mesh_from_points(points, ply_mesh)
#del ply_mesh, points
local_preds = model(mesh)

In [None]:
# Objective (lift + geometric penalties) for the starting shape.
loss = calculate_loss(mesh, local_preds, axis=0, constraint_rad=0.05)

In [None]:
# dL/dp: gradient of the objective w.r.t. every vertex position.
loss.backward()
dL_dp = points.grad.clone()

In [None]:
# Reset points.grad, then backprop the SDF with unit weights so points.grad holds dSDF/dp.
points.grad.data.zero_()
sdf_value = decode_sdf(decoder, initial_la, points)
sdf_value.backward(torch.ones([len(points), 1], dtype=torch.float32).cuda())

In [None]:
# assemble constant: per-vertex multiplier -<dL/dp, dSDF/dp>, vectorised over all vertices
# (replaces a Python loop building a list of 0-d CUDA tensors)
multipliers = -torch.sum(dL_dp * points.grad, dim=1)

In [None]:
# Backprop sum_i m_i * SDF(latent, p_i) into the latent code with the vertices held fixed;
# the result accumulates in initial_la.grad.
points = points.detach()
initial_la = initial_la.detach().requires_grad_(True)
latent_inputs = initial_la.expand(points.shape[0], -1)
inputs = torch.cat([latent_inputs, points], 1).cuda() 
sdf_value = decoder(inputs)
final_loss = torch.sum(sdf_value.squeeze() * multipliers)
final_loss.backward()

In [None]:
#initial_la.grad

In [None]:
#initial_la.grad.data.zero_()

In [None]:
# Nearest-neighbour soft constraint (10 neighbours, weight 0.2); its gradient is added to initial_la.grad.
# NOTE(review): requires soft_constraints to return a scalar tensor.
apenalty = soft_constraints(initial_la, latent_vectors, num_neignours_constr=10, alpha_penalty=0.2)
apenalty.backward()

In [None]:
#initial_la.grad

In [None]:
# Display the multiplier-weighted SDF loss.
final_loss

In [None]:
# Display the weighted soft-constraint penalty computed above
# (`penalty` was never defined at notebook level and raised NameError).
apenalty

In [None]:
# Inner keep-out penalty for a radius-0.25 ball around (0.3, 0, 0) (r is a squared radius).
innerBoundsLoss(mesh['x'], r=(0.5 / 2)**2, center=(0.3, 0, 0))

In [None]:
# x-bounds penalty for a very tight interval [-0.006, 0.006].
boundsLoss(mesh['x'], box=[(-0.006, 0.006, 0)])

In [None]:
# Model output for vertex 0.
local_preds[0]

In [None]:
# (number of vertices, 3)
points.shape

In [None]:
# Save the starting mesh; create the output directory first so this works on a fresh checkout.
os.makedirs("data_for_this_experiments", exist_ok=True)
ply_mesh.write("data_for_this_experiments/mesh32.ply")

In [None]:
class cfd_obj:
    """Objective and gradient wrapper for latent-space shape optimisation.

    ``func(latent)`` decodes a mesh and evaluates the loss. ``dfunc(latent)`` returns
    d loss / d latent via the shape-derivative trick:
      1. dL/dx_i at every surface vertex (through the pressure model),
      2. unit normals n_i from the SDF gradient at those vertices,
      3. backprop of sum_i (-<dL/dx_i, n_i>) * SDF(latent, x_i) plus the soft constraint.
    """
    def __init__(self, decoder, p_predictor, latent_target=None,
                 num_neighbours_constr=10, alpha_penalty=0.2):
        self.N_MARCHING_CUBE = 128
        self.regl2 = 1e-3
        self.iter = 0
        self.quick = True
        # The transform transformPoints actually applies (loaded in the setup cell);
        # avoids re-reading computeAvgTransform()'s hard-coded directory.
        self.AvgTransform = AvgTransform
        self.decoder = decoder
        self.pressure_pred = p_predictor
        # Previously read an undefined global `latent_target`; now an optional argument.
        self.optimal = None if latent_target is None else latent_target.detach().cpu().numpy()
        self.constraint_rad = 0.05
        # Soft-constraint settings (defaults match the experiment cell: 10 neighbours, weight 0.2).
        self.num_neighbours_constr = num_neighbours_constr
        self.alpha_penalty = alpha_penalty
        # Cache written by func() and reused once by dfunc() for the same latent.
        self.last_loss = None
        self.last_latent = None
        self.xyz_upstream = None

    def func(self, latent):
        """Decode ``latent`` to a mesh and return the scalar optimisation loss."""
        # from latent to xyz; marching-cubes sampling needs no autograd graph
        with torch.no_grad():
            ply_mesh = create_mesh(self.decoder, latent, N=self.N_MARCHING_CUBE, max_batch=int(2 ** 18))
        vertex = ply_mesh['vertex']
        points = torch.cuda.FloatTensor(np.hstack((vertex['x'][:, None],
                                                   vertex['y'][:, None],
                                                   vertex['z'][:, None])))
        # from mesh to pressure field; track gradients on the very tensor fed to the model
        self.xyz_upstream = points.requires_grad_(True)
        scaled_mesh = make_mesh_from_points(self.xyz_upstream, ply_mesh)
        pressure_field = self.pressure_pred(scaled_mesh)
        loss = calculate_loss(scaled_mesh, pressure_field, axis=0, constraint_rad=self.constraint_rad)
        self.last_loss = loss
        # Snapshot rather than alias: optimisers update `latent` in place.
        self.last_latent = latent.detach().clone()
        return loss

    def dfunc(self, latent):
        """Return d loss / d latent (also left in ``latent.grad``)."""
        if latent.grad is not None:
            latent.grad.detach_()
            latent.grad.zero_()
        # step 1: dL/dx_i, reusing func()'s graph if it was built for this very latent
        if (self.quick and self.last_loss is not None and self.last_latent is not None
                and torch.equal(latent.detach(), self.last_latent)):
            loss = self.last_loss
        else:
            loss = self.func(latent)
        loss.backward()
        # That graph is now consumed; force a fresh func() on the next call.
        self.last_loss = None
        dL_dx_i = self.xyz_upstream.grad

        # step 2: unit mesh normals from the SDF gradient at the vertices
        xyz = self.xyz_upstream.detach().clone().requires_grad_(True)
        latent_inputs = latent.expand(xyz.shape[0], -1)
        inputs = torch.cat([latent_inputs, xyz], 1).cuda()
        pred_sdf = self.decoder(inputs)
        # retain_graph: pred_sdf is backpropagated again in step 3
        torch.sum(pred_sdf).backward(retain_graph=True)
        normals = xyz.grad / torch.norm(xyz.grad, 2, 1).unsqueeze(-1)

        # step 3: assemble the inflow derivative; drop the latent grad left by step 2
        if latent.grad is not None:
            latent.grad.detach_()
            latent.grad.zero_()
        multipliers = -torch.matmul(dL_dx_i.unsqueeze(1), normals.unsqueeze(-1)).squeeze(-1)
        loss_backward = torch.sum(multipliers * pred_sdf)

        # artificial loss: nearest-neighbour soft constraint
        apenalty = soft_constraints(latent, latent_vectors, self.num_neighbours_constr, self.alpha_penalty)
        loss_backward = loss_backward + apenalty

        # Backpropagate
        loss_backward.backward()

        return latent.grad

In [None]:
# One manual step of the shape-derivative computation (scratch version of an
# optimize_shape_deepSDF iteration). These names were previously undefined on a fresh
# kernel; the values mirror the make_full_transformation call further down.
N = 256
axis = 0
constraint_rad = 0.05
num_neignours_constr = 10
alpha_penalty = 0.2
latent = LATENT_TO_OPTIMIZE.detach().clone().requires_grad_(True)

# from latent to mesh/point
with torch.no_grad():
    ply_mesh = create_mesh( decoder,
                            latent,
                            N=N,
                            max_batch=int(2 ** 18),
                            offset=None,
                            scale=None)
points = torch.cuda.FloatTensor(np.hstack(( ply_mesh['vertex']['x'][:, None],
                                            ply_mesh['vertex']['y'][:, None],
                                            ply_mesh['vertex']['z'][:, None])))

# from mesh to pressure field
points.requires_grad = True
mesh = make_mesh_from_points(points, ply_mesh)
local_preds = model(mesh)
loss = calculate_loss(mesh, local_preds, axis=axis, constraint_rad=constraint_rad)
loss.backward()
dL_dp = points.grad.clone()

# calculate mesh normal: points.grad <- dSDF/dp
points.grad.zero_()
sdf_value = decode_sdf(decoder, latent, points)
sdf_value.backward(torch.ones([len(points), 1], dtype=torch.float32).cuda())

# assemble constant: -<dL/dp, dSDF/dp> per vertex (vectorised)
multipliers = -torch.sum(dL_dp * points.grad, dim=1)

# get gradient of sdf w.r.t. latent (fixed: `.decode_sdf` was a SyntaxError)
latent.grad.zero_()
sdf_value = torch.squeeze(decode_sdf(decoder, latent, points))
final_loss = torch.sum(sdf_value * multipliers)
final_loss.backward()

# artificial loss (alpha_penalty was missing from the call)
apenalty = soft_constraints(latent, latent_vectors, num_neignours_constr, alpha_penalty)
apenalty.backward()

In [None]:
def method4_to_arbitatry_loss(points, ply_mesh, model, constraint_rad=0.1, axis=0):
    """Per-vertex multipliers for the shape-derivative latent update.

    Expects ``points.grad`` to already hold dSDF/dp (the caller backpropagates the SDF
    first). Evaluates the pressure-model loss, backpropagates it into ``points`` and
    returns ``-<dSDF/dp_i, dL/dp_i>`` for every vertex.

    :returns: (list of per-vertex multipliers as Python floats, loss tensor,
        model predictions, transformed mesh dict)
    """
    initial_dir = points.grad.clone()
    points.grad.zero_()

    mesh = make_mesh_from_points(points, ply_mesh)
    local_preds = model(mesh)
    loss = calculate_loss(mesh, local_preds, axis=axis, constraint_rad=constraint_rad)
    loss.backward()

    # Vectorised row-wise dot product (was a Python loop producing 0-d CUDA tensors).
    sign = (-torch.sum(initial_dir * points.grad, dim=1)).tolist()

    return sign, loss, local_preds, mesh



def optimize_shape_deepSDF(decoder, latent, initial_points=None, num_points=None, 
                           num_iters=100, point_iters=100, num_neignours_constr=10,
                           lr=0.2, decreased_by=2, adjust_lr_every=10, alpha_penalty=0.05,
                           multiplier_func=method4_to_arbitatry_loss, verbose=None, save_to_dir=None, N=256,
                           punch_lr_at_reindex_by=None, reindex_latent_each=None):
    """Optimise a DeepSDF latent code against the pressure-based objective with SGD.

    Each iteration decodes ``latent`` to a mesh (resolution ``N``) and obtains per-vertex
    multipliers from ``multiplier_func(points, ply_mesh)``. It then backprops
    ``sum_i m_i * SDF(latent, p_i)`` plus a nearest-neighbour soft constraint into
    ``latent`` and takes one step. A coloured mesh and the loss curves are written to
    ``save_to_dir`` every iteration.

    :param multiplier_func: ``(points, ply_mesh) -> (multipliers, loss, preds, mesh)``;
        the default needs extra arguments, so callers bind it (see make_full_transformation)
    :param num_neignours_constr: neighbour count of the soft constraint
    :param alpha_penalty: weight of the soft constraint
    :param lr: initial learning rate
    :param decreased_by: divisor applied to the learning rate at each decay step
    :param adjust_lr_every: number of iterations between decay steps
    :param verbose: print a summary every ``verbose`` iterations (None = never)
    :param save_to_dir: output directory (required)
    :param punch_lr_at_reindex_by: extra learning-rate factor applied every
        ``reindex_latent_each`` iterations; falls back to the notebook global of the same
        name, or 1 if that global is undefined
    :param reindex_latent_each: period of that extra factor; falls back to the notebook
        global of the same name, or 10000 if that global is undefined
    :param initial_points: accepted for compatibility; unused
    :param num_points: accepted for compatibility; unused
    :param point_iters: accepted for compatibility; unused
    :raises ValueError: if ``save_to_dir`` is None
    """
    # These were read from globals defined in a later cell; resolve them explicitly.
    if punch_lr_at_reindex_by is None:
        punch_lr_at_reindex_by = globals().get("punch_lr_at_reindex_by", 1)
    if reindex_latent_each is None:
        reindex_latent_each = globals().get("reindex_latent_each", 10000)

    def adjust_learning_rate(
        initial_lr, optimizer, num_iterations, decreased_by, adjust_lr_every
    ):
        """Step-decay schedule (times the periodic punch factor); writes and returns the new lr."""
        lr = initial_lr * ((1 / decreased_by) ** (num_iterations // adjust_lr_every)) \
                        * ((punch_lr_at_reindex_by) ** (num_iterations // reindex_latent_each))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        return lr

    if save_to_dir is None:
        raise ValueError("save_to_dir is required: meshes and loss curves are written there")
    os.makedirs(os.path.join(save_to_dir, 'meshes'), exist_ok=True)
    os.makedirs(os.path.join(save_to_dir, 'predictions'), exist_ok=True)

    ref_latent = latent.clone().detach()
    decoder.eval()
    # Fresh leaf tensor so SGD can optimise it even if the input latent is non-leaf.
    latent = latent.detach().clone().requires_grad_(True)
    optimizer = torch.optim.SGD([latent], lr=lr)

    loss_plot = []
    latent_dist = []
    lr_plot = []
    latent_plot = []

    for i in range(num_iters):
        save_path = os.path.join(save_to_dir, 'meshes/' + str(i).zfill(5) + ".ply")

        adjust_learning_rate(lr, optimizer, i, decreased_by, adjust_lr_every)

        with torch.no_grad():
            ply_mesh = create_mesh( decoder,
                                    latent,
                                    N=N,
                                    max_batch=int(2 ** 18),
                                    offset=None,
                                    scale=None)

        points = torch.cuda.FloatTensor(np.hstack(( ply_mesh['vertex']['x'][:, None],
                                                    ply_mesh['vertex']['y'][:, None],
                                                    ply_mesh['vertex']['z'][:, None])))
        points.requires_grad = True

        # points.grad <- dSDF/dp; multiplier_func reads it as the surface direction.
        sdf_value = decode_sdf(decoder, latent, points)
        sdf_value.backward(torch.ones([len(points), 1], dtype=torch.float32).cuda())

        mults, loss_value, preds, transformed_mesh = multiplier_func(points, ply_mesh)
        multipliers = torch.cuda.FloatTensor(mults)

        # Clears the latent gradient left by the dSDF/dp backward above.
        optimizer.zero_grad()
        sdf_value = torch.squeeze(decode_sdf(decoder, latent, points))

        final_loss = torch.sum(sdf_value * multipliers)
        final_loss.backward()
        print("backward loss: ", final_loss)

        # Soft-constraints: mean squared distance to the nearest training latents.
        distances, indeces = LATENT_KD_TREE.query(latent.cpu().detach(), k=num_neignours_constr)
        penalty = torch.mean(torch.stack([
            torch.sum((latent - latent_vectors[idx]) ** 2) for idx in indeces[0]
        ]))
        apenalty = penalty * alpha_penalty
        apenalty.backward()
        print("penality: ", apenalty)
        print("together: ", apenalty + final_loss, "\n")

        optimizer.step()

        tri_mesh = get_trimesh_from_torch_geo_with_colors(transformed_mesh, preds)
        tri_mesh.export(save_path)

        loss_plot.append(loss_value.cpu().detach().numpy())
        latent_dist.append(torch.sum((latent - ref_latent) ** 2).cpu().detach().numpy())
        latent_plot.append(latent.detach().cpu().numpy())
        # Store a plain float: a CUDA tensor that requires grad cannot be converted by
        # np.save below. (Despite its name, lr_plot records the unweighted penalty.)
        lr_plot.append(penalty.item())

        if verbose is not None and i % verbose == 0:
            print('Iter ', i, 'Loss: ', loss_value.detach().cpu().numpy(), ' LD: ', lr_plot[-1])

        # Checkpoint the curves every iteration so a partial run is still usable.
        np.save(os.path.join(save_to_dir, "latent_plot.npy"), latent_plot)
        np.save(os.path.join(save_to_dir, "loss_plot.npy"), loss_plot)
        np.save(os.path.join(save_to_dir, "latent_dist.npy"), latent_dist)
        np.save(os.path.join(save_to_dir, "lr_plot.npy"), lr_plot)



def make_full_transformation(initial_latent, experiment_name, 
                             decoder, model, alpha_penalty=0.05, constraint_rad=0.1, axis=0, **kwargs):
    '''
    Run optimize_shape_deepSDF from ``initial_latent``, writing results under ``experiment_name``.

    Per-vertex multipliers come from method4_to_arbitatry_loss bound to ``model``,
    ``constraint_rad`` and ``axis``. Other keyword arguments are forwarded, e.g.
        num_iters=1000,
        adjust_lr_every=10,
        decreased_by=1.2,
        lr=0.005,
        verbose=10,
    '''
    output_dir = experiment_name
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    def bound_multipliers(pts, mesh):
        # Adapts method4_to_arbitatry_loss to the (points, ply_mesh) callback signature.
        return method4_to_arbitatry_loss(pts, mesh, model,
                                         constraint_rad=constraint_rad,
                                         axis=axis)

    optimize_shape_deepSDF(decoder, initial_latent,
                           initial_points=None,
                           num_points=None,
                           point_iters=2,
                           alpha_penalty=alpha_penalty,
                           multiplier_func=bound_multipliers,
                           save_to_dir=output_dir,
                           **kwargs)
   

In [None]:

# Output directory of the optimisation run (note: re-binds DIR_for_dump_data, which
# pointed at './starting_data' in the setup cell).
DIR_for_dump_data = './data_for_this_experiments'
# Learning-rate "punch" settings read by optimize_shape_deepSDF's schedule;
# a factor of 1 means no extra decay.
punch_lr_at_reindex_by=1
reindex_latent_each = 10000


In [None]:
# Run 30 optimisation steps from latent 32. Only NumPy is seeded here; torch is not.
np.random.seed(101)
make_full_transformation(LATENT_TO_OPTIMIZE.detach(),
                         experiment_name=DIR_for_dump_data, decoder=decoder, model=model,
                         alpha_penalty=0.2, axis=0,
                         constraint_rad=0.05,
                         num_iters=30,
                         adjust_lr_every=20,
                         decreased_by=1.1, 
                         lr=0.2,
                         verbose=1,
                         N=256,
                         num_neignours_constr=10)

In [None]:
#preds = np.load("data_for_this_experiments/OptimizationPaper/predictions/00000.npy")
# Reload the last mesh (iteration 29 of the 30 run above) and the saved curves.
# allow_pickle: presumably because some curves may have been saved as object arrays — confirm.
mesh = trimesh.load_mesh("data_for_this_experiments/meshes/00029.ply")
loss_plot = np.load("data_for_this_experiments/loss_plot.npy", allow_pickle=True)
lr_plot = np.load("data_for_this_experiments/lr_plot.npy", allow_pickle=True)
latent_plot = np.load("data_for_this_experiments/latent_plot.npy", allow_pickle=True)

In [None]:
#plt.plot(loss_plot)

In [None]:
# 10 nearest training latents to the final optimised code.
distances, indeces = LATENT_KD_TREE.query(torch.tensor(latent_plot[-1]), k=10)

In [None]:
# Show the neighbour indices.
indeces

In [None]:
# Blend the first (10%) and second (90%) neighbours returned by the query.
newl = 0.1 * latent_vectors[indeces.squeeze()[0]] + 0.9 * latent_vectors[indeces.squeeze()[1]]

In [None]:
#visual_Mesh(torch.tensor(newl)).show()

In [None]:
#visual_Mesh(torch.tensor(latent_vectors[indeces.squeeze()[:2]].mean(axis=0))).show()

In [None]:
#visual_Mesh(torch.tensor(latent_vectors[indeces].mean(axis=0))).show()

In [None]:
#visual_Mesh(latent_vectors[267]).show()

In [None]:
#visual_Mesh(latent_vectors[350]).show()

In [None]:
#visual_Mesh(torch.tensor(latent_plot[-1]).cuda()).show()

In [None]:
#plt.plot(lr_plot)

In [None]:
# Reload the mesh saved at iteration 4 (uncomment .show() to view it).
mesh = trimesh.load_mesh("data_for_this_experiments/meshes/00004.ply")
#mesh.show()

In [None]:
# Reload the mesh saved at iteration 29 (uncomment .show() to view it).
mesh = trimesh.load_mesh("data_for_this_experiments/meshes/00029.ply")
#mesh.show()

In [None]:
# 10 nearest training latents to the starting code (latent 32).
distances, indeces = LATENT_KD_TREE.query(LATENT_TO_OPTIMIZE.cpu().detach(), k=10)

In [None]:
# `mesh` is a trimesh.Trimesh here (loaded above), which is not subscriptable; use its faces array.
mesh.faces.shape

In [None]:
# Euclidean distance between the 1st and 9th neighbours of the starting latent.
torch.norm(latent_vectors[indeces.squeeze()[0]] - latent_vectors[indeces.squeeze()[8]])

In [None]:
# Euclidean distance between the 10th and 9th neighbours.
torch.norm(latent_vectors[indeces.squeeze()[9]] - latent_vectors[indeces.squeeze()[8]])

In [None]:
# Euclidean distance between the 8th and 9th neighbours.
torch.norm(latent_vectors[indeces.squeeze()[7]] - latent_vectors[indeces.squeeze()[8]])

In [None]:

def visual_Mesh(ilatent):
    """Decode ``ilatent``, predict its pressure field and return a pressure-coloured trimesh.

    Also prints the x-axis lift term (``compute_lift_faces_diff`` with axis=0).
    Runs without autograd, since the result is only visualised.
    """
    with torch.no_grad():
        ply_mesh = create_mesh(decoder,
                               ilatent,
                               N=256,
                               max_batch=int(2 ** 18))
        vertex = ply_mesh['vertex']
        points = torch.cuda.FloatTensor(np.hstack((vertex['x'][:, None],
                                                   vertex['y'][:, None],
                                                   vertex['z'][:, None])))
        scaled_mesh = make_mesh_from_points(points, ply_mesh)
        pressure_field = model(scaled_mesh)
        loss = compute_lift_faces_diff(scaled_mesh, pressure_field, axis=0)
    print("latent loss. %f " % loss.item())
    return get_trimesh_from_torch_geo_with_colors(scaled_mesh, pressure_field)
    

In [None]:
#visual_Mesh(latent_vectors[26]).show() # 

In [None]:
# NOTE(review): this path ("Expirements/OptimizationPaper/...") differs from the
# data_for_this_experiments/ directory used by the run above, and the run saves only
# iterations 0-29 — confirm this file exists.
mesh37 = trimesh.load_mesh("Expirements/OptimizationPaper/meshes/00037.ply")
mesh37.show()


In [None]:
# FIXME: `tansformed_points` (typo), `edge_attr` and `edges` are not notebook-level names
# (they only exist as locals inside make_mesh_from_points), so this cell raises NameError
# on a fresh kernel.
mesh = {'point_pos': tansformed_points, 
        'edge_vec':torch.tensor(edge_attr, dtype=torch.float).to('cuda:0'),
        'edge_index':torch.tensor(edges, dtype=torch.long).t().contiguous().to('cuda:0')
        }