# Fourier shape representation

In [1]:
# %matplotlib widget
# from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from livelossplot import PlotLosses

# import open3d as o3d
# from open3d import JVisualizer

import numpy as np
from tqdm.notebook import tqdm as tqdm
import torch
import torch.nn as nn
import trimesh
from trimesh.sample import sample_surface
import torch.utils.data as data

## build model & dataset

In [2]:
class GeometryDataset(data.Dataset):
    def __init__(self, mesh_path, samples, *args):
        super(GeometryDataset, self).__init__(*args)
        self.mesh = trimesh.load(mesh_path)
#         self.mesh.vertices[:,1]+=-0.1
        # self.mesh.vertices /= 100
        self.sample = sample_surface(self.mesh, samples)
        self.pnts = torch.from_numpy(self.sample[0])
        self.normals = torch.from_numpy(self.mesh.face_normals[self.sample[1]])

    def __getitem__(self, index: int):
        points = self.pnts[index]
        normals = self.normals[index]
        return points, normals

    def __len__(self) -> int:
        return len(self.pnts)

class NormalPerPoint(object):
    def __init__(self, global_sigma, local_sigma=0.01):
        self.global_sigma = global_sigma
        self.local_sigma = local_sigma

    def get_points(self, pc_input, local_sigma=None):
        batch_size, dim = pc_input.shape

        if local_sigma is not None:
            sample_local = pc_input + (
                torch.randn_like(pc_input) * local_sigma.unsqueeze(-1)
            )
        else:
            sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma)

        sample_global = (
            torch.rand(batch_size // 8, dim, device=pc_input.device)
            * (self.global_sigma * 2)
        ) - self.global_sigma

        sample = torch.cat([sample_local, sample_global], dim=0)

        return sample
    
class ImplicitNetwork(nn.Module):
    def __init__(
        self,
        d_in,
        d_out,
        dims,
        geometric_init=True,
        bias=1.0,
        skip_in=(),
        weight_norm=True,
        multires=0,
    ):
        super().__init__()
        dims = [d_in] + dims + [d_out]
        self.embed_fn = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires)
            self.embed_fn = embed_fn
            dims[0] = input_ch
        self.num_layers = len(dims)
        self.skip_in = skip_in
        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)
            if geometric_init:
                if l == self.num_layers - 2:
                    torch.nn.init.normal_(
                        lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001
                    )
                    torch.nn.init.constant_(lin.bias, -bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(
                        lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)
                    )
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(
                        lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)
                    )
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(
                        lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)
                    )

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.softplus = nn.Softplus(beta=100)

    def forward(self, input, compute_grad=False):
        if self.embed_fn is not None:
            input = self.embed_fn(input)
        x = input
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            if l in self.skip_in:
                x = torch.cat([x, input], 1) / np.sqrt(2)
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.softplus(x)
        return x

    def gradient(self, x):
        x.requires_grad_(True)
        y = self.forward(x)[:, :1]
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        return gradients.unsqueeze(1)

from torch.autograd import grad
def gradient(inputs, outputs):
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0][:, -3:]
    return points_grad


## visualize data

In [3]:
# mesh_path = "./bunny/reconstruction/bun_zipper.ply"
# dataset =GeometryDataset(mesh_path,samples=100000)
# from torch.utils.data import DataLoader
# dataloader = DataLoader(dataset,batch_size=1000, num_workers=4,pin_memory=True)

# mesh=dataset.mesh
# vertices = dataset.pnts
# fig = plt.figure()
# ax = Axes3D(fig)
# # ax.plot_trisurf(mesh.vertices[:, 0], mesh.vertices[:,1], triangles=mesh.faces, Z=mesh.vertices[:,2]) 
# plot_geeks = ax.scatter(vertices[:, 0], vertices[:,1], vertices[:,2],color='green')
# ax.set_title("3D plot")
# ax.set_xlabel('x-axis')
# ax.set_ylabel('y-axis')
# ax.set_zlabel('z-axis')
# plt.show()

## init models

In [4]:
torch.set_default_dtype(torch.float32)
model = ImplicitNetwork(d_in=3,d_out=1,dims=[512, 512, 512, 512, 512, 512, 512, 512],bias=0.2,skip_in=[4]).cuda()
sampler = NormalPerPoint(global_sigma=0.1, local_sigma=0.01)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.0005,
    weight_decay=0,
)
model.load_state_dict(torch.load("./checkpoints/model_0072000.pth")["model"])

<All keys matched successfully>

## training

In [5]:
# # plt_groups = {'mnfld_loss_train':[],'grad_loss':[]}
# # plotlosses_model = PlotLosses(groups=plt_groups)
# plotlosses_model = PlotLosses()
# steps=0
# for i in range(1000000):
#     for mnfld_pnts, normals in tqdm(dataloader, leave=False):
#         mnfld_pnts=mnfld_pnts.cuda()
#         normals=normals.cuda()
#         nonmnfld_pnts = sampler.get_points(mnfld_pnts)
#         # forward
#         mnfld_pnts.requires_grad_()
#         nonmnfld_pnts.requires_grad_()
#         mnfld_pred = model(mnfld_pnts)
#         nonmnfld_pred = model(nonmnfld_pnts)
#         mnfld_grad = gradient(mnfld_pnts, mnfld_pred)
#         nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred)
#         # maniflod_loss
#         mnfld_loss = (mnfld_pred.abs()).mean()
#         # eikonal loss
#         grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
#         # regularize: prevents off-surface locations to create zero-isosurface
#         sdf_global_without_surface = nonmnfld_pred[mnfld_pnts.shape[0] :]
#         reg_loss = torch.exp(
#             -100 * (sdf_global_without_surface.abs())
#         ).mean()

#         loss = (
#             mnfld_loss
#             + 0.1 * grad_loss
#             + 0.1 * reg_loss
#         )
#         # normal loss
#         if False:
#             normals = normals.view(-1, 3)
#             normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
#             loss = loss + cfg.SOLVER.LOSS.NORMAL_LAMBDA * normals_loss
#         # back propagation
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         steps+=1
#         if steps%100==0:
#             torch.save(model.state_dict(), "./checkpoints/last.pth")
#             plotlosses_model.update({'loss':loss.item(),'mnfld_loss_train':mnfld_loss.item(),'grad_loss':grad_loss.item()}, current_step=steps)
#             plotlosses_model.send()


## get volume and mesh

In [6]:
from plot import plot_surface
RESULOTION=100
mesh,volume = plot_surface(model,path="./results",iteration=72000, shapename="result",resolution=RESULOTION,mc_value=0.,is_uniform_grid=True,verbose=False,
                        save_html=False,
                        save_ply=True,
                        overwrite=True,cube_length=1.5)
print("number of inner points :",(volume<0).sum())
print("number of outer points :",(volume>0).sum())

face_normals incorrect shape, ignoring!


number of inner points : 192972
number of outer points : 807028


## Fourier transform for geometry
### visualize volume

In [None]:
import trimesh
from skimage import measure
def get_mesh(volume):
    verts, faces, normals, values = measure.marching_cubes_lewiner(volume=volume,level=0.)
    meshexport = trimesh.Trimesh(verts, faces, normals, vertex_colors=values)   
    return meshexport

In [None]:
# test_mesh=get_mesh(volume)
# _=test_mesh.export("gt.ply", "ply")

### FFT for shape

In [None]:
F=torch.fft.fftn(torch.from_numpy(volume).cuda(),s=[100,100,100])
F_shift=torch.fft.fftshift(F)
F_low=torch.zeros_like(F_shift)
F_low[40:60,40:60,40:60]=F_shift[40:60,40:60,40:60]
F_high=F_shift.clone()
F_high[40:60,40:60,40:60]=0

### IFFT for shape

In [None]:
f=torch.fft.ifftn(torch.fft.ifftshift(F_shift))
f=torch.real(f).cpu().numpy()

f_low=torch.fft.ifftn(torch.fft.ifftshift(F_low))
f_low=torch.real(f_low).cpu().numpy()

f_high=torch.fft.ifftn(torch.fft.ifftshift(F_high))
f_high=torch.real(f_high).cpu().numpy()

In [None]:
res_mesh=get_mesh(f_low)
_=res_mesh.export("low.ply", "ply")
res_mesh=get_mesh(f_high)
_=res_mesh.export("high.ply", "ply")
res_mesh=get_mesh(f)
_=res_mesh.export("full.ply", "ply")

## Learned Fourier transform for geometry

In [7]:
def fft_3(fscale):
    fx = torch.fft.fftfreq(int(100* fscale)) * fscale
    x_train = torch.stack(torch.meshgrid(fx.reshape(-1), fx.reshape(-1),fx.reshape(-1)), -1)
    return x_train

In [8]:
# torch.fft.fftfreq(int(100))
freq=fft_3(1.0)
freq.shape


torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2228.)



torch.Size([100, 100, 100, 3])

In [None]:
F[0,0,0]

In [None]:
F[99,99,99]

In [9]:
model_kspace=ImplicitNetwork(d_in=3,d_out=2,dims=[512, 512, 512, 512],geometric_init=False).cuda()

In [10]:
model_kspace(freq.cuda()).shape

RuntimeError: CUDA out of memory. Tried to allocate 1.91 GiB (GPU 0; 11.93 GiB total capacity; 9.57 GiB already allocated; 1.74 GiB free; 9.57 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF