From a7bd8a61b3777bfbb4beba76aab29c7dd9de366a Mon Sep 17 00:00:00 2001 From: heiwang1997 <462149548@qq.com> Date: Tue, 31 May 2022 19:18:08 +0800 Subject: [PATCH] merge to merged branch --- .gitignore | 3 + jittor/LICENSE | 24 + jittor/README.md | 81 ++++ jittor/convert.py | 33 ++ jittor/criterion.py | 30 ++ jittor/data-shapenet.yaml | 16 + jittor/data_generator.py | 265 +++++++++++ jittor/lif_dataset.py | 85 ++++ jittor/lr_schedule.py | 72 +++ jittor/network.py | 80 ++++ jittor/sampler_cuda/CMakeLists.txt | 17 + jittor/sampler_cuda/PreprocessMesh.cu | 631 ++++++++++++++++++++++++++ jittor/sampler_cuda/ShaderProgram.cpp | 134 ++++++ jittor/sampler_cuda/Utils.cu | 249 ++++++++++ jittor/sampler_cuda/Utils.h | 54 +++ jittor/train.py | 141 ++++++ jittor/train.yaml | 28 ++ jittor/utils/exp_util.py | 228 ++++++++++ jittor/utils/motion_util.py | 142 ++++++ 19 files changed, 2313 insertions(+) create mode 100644 jittor/LICENSE create mode 100644 jittor/README.md create mode 100644 jittor/convert.py create mode 100644 jittor/criterion.py create mode 100644 jittor/data-shapenet.yaml create mode 100644 jittor/data_generator.py create mode 100644 jittor/lif_dataset.py create mode 100644 jittor/lr_schedule.py create mode 100644 jittor/network.py create mode 100644 jittor/sampler_cuda/CMakeLists.txt create mode 100644 jittor/sampler_cuda/PreprocessMesh.cu create mode 100644 jittor/sampler_cuda/ShaderProgram.cpp create mode 100644 jittor/sampler_cuda/Utils.cu create mode 100644 jittor/sampler_cuda/Utils.h create mode 100644 jittor/train.py create mode 100644 jittor/train.yaml create mode 100644 jittor/utils/exp_util.py create mode 100644 jittor/utils/motion_util.py diff --git a/.gitignore b/.gitignore index 18fc2ca..cf70687 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,6 @@ debug/ pytorch/sampler_cuda/bin pytorch/sampler_cuda/build + +jittor/sampler_cuda/bin +jittor/sampler_cuda/build diff --git a/jittor/LICENSE b/jittor/LICENSE new file mode 100644 index 0000000..fdddb29 --- /dev/null +++ b/jittor/LICENSE @@ -0,0 +1,24 @@ +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to diff --git a/jittor/README.md b/jittor/README.md new file mode 100644 index 0000000..a2d659e --- /dev/null +++ b/jittor/README.md @@ -0,0 +1,81 @@ +# DI-Fusion: 基于深度先验的在线隐式三维重建 +本仓库是[此文章](https://github.com/huangjh-pub/di-fusion)提出的用于隐式三维重建的网络部分的[计图](https://cg.cs.tsinghua.edu.cn/jittor/)实现,作者是:[黄家晖](https://cg.cs.tsinghua.edu.cn/people/~huangjh/zh)、[黄石生](https://cg.cs.tsinghua.edu.cn/people/~shisheng/)、宋浩轩和[胡事民](https://cg.cs.tsinghua.edu.cn/shimin.htm) + +- 计图是由清华大学计算机系 [图形学实验室](https://cg.cs.tsinghua.edu.cn/) 推出的一个完全基于动态编译、内部使用创新的元算子和统一计算图的深度学习框架。 + +- DI-Fusion是一个基于RGBD输入的在线三维重建系统。它的相机定位追踪模块以及地图表示完全基于由深度神经网络建模的局部隐式表示。请进一步参考我们的[ArXiv报告](http://arxiv.org/abs/2012.05551)和[视频](https://youtu.be/yxkIQFXQ6rw)。 +- PyTorch version **available [here](https://github.com/huangjh-pub/di-fusion).** + +## 网络训练 + +### 训练数据生成 + +首先,请编译我们的CUDA点云采样器: + +```bash +cd sampler_cuda +mkdir build; cd build +cmake .. +make -j +``` + +编译成功之后,会在`sampler_cuda/bin/`文件夹下,生成名为`PreprocessMeshCUDA`的可执行文件,接着运行: + +```bash +python data_generator.py data-shapenet.yaml --nproc 4 +``` + +即可生成用于训练的数据。 + +### 网络训练 + +当完成了训练数据的生成之后,请运行: + +```bash +python train.py train.yaml +``` + +即可开始训练,如果您在上一步更改了数据存放的位置,可能需要修改`train.yaml`中对应的路径,让程序正确的找到数据集。 + +### 速度对比 + +计图框架采用了先进的元算子融合以及统一计算图技术,这使得执行效率大大提高,下表对比了使用计图的版本和使用PyTorch的版本的训练速度,训练同一个epoch**计图所用的时间仅有PyTorch的三分之一**,原需要训练1-2天的模型,可以在半天内取得较好的收敛效果。 + +| | PyTorch | PyTorch JIT | 计图 | +| ------------------- | ------- | ----------- | ---- | +| 每秒训练步数 (it/s) | 13 | 14 | 39 | + +## 运行 (beta) + +如果需要运行完整的DI-Fusion系统,首先需要进行简单的权值格式转换: + +```bash +python convert.py +``` + +假设本仓库的路径是``,那么上述权值转换程序将输出`/model_300.pth.tar`和`/encoder_300.pth.tar`两个文件。 + +接着,执行如下命令,拷贝官方实现的代码: + +```bash +git clone https://github.com/huangjh-pub/di-fusion.git +cd di-fusion +cp /model_300.pth.tar ./ckpt/default/ +cp /encoder_300.pth.tar ./ckpt/default/ +``` + +执行完成之后,请依照[这里](https://github.com/huangjh-pub/di-fusion#running)的做法继续之后的操作步骤。 + +## 引用 + +欢迎您引用我们的工作,蟹蟹: + +``` +@inproceedings{huang2021difusion, + title={DI-Fusion: Online Implicit 3D Reconstruction with Deep Priors}, + author={Huang, Jiahui and Huang, Shi-Sheng and Song, Haoxuan and Hu, Shi-Min}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2021} +} +``` + diff --git a/jittor/convert.py b/jittor/convert.py new file mode 100644 index 0000000..ad587e0 --- /dev/null +++ b/jittor/convert.py @@ -0,0 +1,33 @@ +import torch +import pickle +from pathlib import Path + +ENC_MAPPING = { + "layer0.conv": 0, + "layer0.normlayer.bn": 1, + "layer1.conv": 3, + "layer1.normlayer.bn": 4, + "layer2.conv": 6, + "layer2.normlayer.bn": 7, + "layer3.conv": 9 +} + +enc_jittor_path = Path("../di-checkpoints/default/encoder_300.jt.tar") +dec_jittor_path = Path("../di-checkpoints/default/model_300.jt.tar") + +with enc_jittor_path.open("rb") as f: + enc_jt_weight = pickle.load(f) +pth_dict = {} +for wkey in list(enc_jt_weight.keys()): + wnew_key = None + for mkey in ENC_MAPPING.keys(): + if str(ENC_MAPPING[mkey]) in wkey: + wnew_key = wkey.replace(str(ENC_MAPPING[mkey]), mkey) + pth_dict[wnew_key] = torch.from_numpy(enc_jt_weight[wkey]).cuda() +torch.save({"epoch": 300, "model_state": pth_dict}, "./encoder_300.pth.tar") + +with dec_jittor_path.open("rb") as f: + dec_jt_weight = pickle.load(f) +for wkey in list(dec_jt_weight.keys()): + dec_jt_weight[wkey] = torch.from_numpy(dec_jt_weight[wkey]).cuda() +torch.save({"epoch": 300, "model_state": dec_jt_weight}, "./model_300.pth.tar") diff --git a/jittor/criterion.py b/jittor/criterion.py new file mode 100644 index 0000000..7e7fe6f --- /dev/null +++ b/jittor/criterion.py @@ -0,0 +1,30 @@ +import math +import jittor as jt + + +def normal_log_prob(loc, scale, value): + var = (scale ** 2) + log_scale = scale.log() + return -((value - loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) + + +def neg_log_likelihood(args, info: dict, pd_sdf, pd_sdf_std, gt_sdf, **kwargs): + if args.enforce_minmax: + gt_sdf = jt.clamp(gt_sdf, -args.clamping_distance, args.clamping_distance) + pd_sdf = jt.clamp(pd_sdf, -args.clamping_distance, args.clamping_distance) + + pd_dist_loc = pd_sdf.squeeze(1) + pd_dist_std = pd_sdf_std.squeeze(1) + + sdf_loss = -normal_log_prob(pd_dist_loc, pd_dist_std, gt_sdf.squeeze(1)).sum() / info["num_sdf_samples"] + return { + 'll': sdf_loss + } + + +def reg_loss(args, info: dict, latent_vecs, **kwargs): + l2_size_loss = jt.sum(jt.norm(latent_vecs, k=2, dim=1)) + reg_loss = min(1, info["epoch"] / 100) * l2_size_loss / info["num_sdf_samples"] + return { + 'reg': reg_loss * args.code_reg_lambda + } diff --git a/jittor/data-shapenet.yaml b/jittor/data-shapenet.yaml new file mode 100644 index 0000000..a0140a4 --- /dev/null +++ b/jittor/data-shapenet.yaml @@ -0,0 +1,16 @@ +provider: shapenet_model + +provider_kwargs: + shapenet_path: "/dataset/ShapeNetCore.v2" + categories: ["03001627", "02871439", "03211117", "04256520", "03636649", "04379243"] + shapes_per_category: [100, 100, 100, 100, 100, 100] + scale: [1.0, 1.7, 0.6, 1.8, 0.5, 1.5] +output: "../di-datasets/shapenet_plivoxs" + +sample_method: 1 +sampler_var: 0.00015 +sampler_count: 800000 +voxel_size: 0.08 +nn_size: 2.0 + +nproc: 4 diff --git a/jittor/data_generator.py b/jittor/data_generator.py new file mode 100644 index 0000000..b538d8a --- /dev/null +++ b/jittor/data_generator.py @@ -0,0 +1,265 @@ +import numpy as np +import functools +from multiprocessing import Pool, Value, Manager +from sklearn.neighbors import NearestNeighbors +from pathlib import Path +import subprocess +import json +import shutil +import os, math +import random +import argparse +import logging +from utils import motion_util, exp_util + +CUDA_SAMPLER_PATH = Path(__file__).resolve().parent / "sampler_cuda" / "bin" / "PreprocessMeshCUDA" +_counter = Value('i', 0) +_bad_counter = Value('i', 0) + + +class ShapeNetGenerator: + """ + Use ShapeNet core to generate data. + """ + + def __init__(self, shapenet_path, categories, shapes_per_category, scale): + self.categories = categories + self.shapes_per_category = shapes_per_category + self.scale = scale + + # Sample objects + self.data_sources = [] + self.data_scales = [] + for category_name, category_shape_count, category_scale in zip(self.categories, self.shapes_per_category, + self.scale): + category_path = Path(shapenet_path) / category_name + sampled_objects = os.listdir(category_path) + if category_shape_count != -1: + sampled_objects = random.sample(sampled_objects, category_shape_count) + self.data_sources += [category_path / s for s in sampled_objects] + self.data_scales += [category_scale for _ in sampled_objects] + + def __len__(self): + return len(self.data_sources) + + @staticmethod + def _equidist_point_on_sphere(samples): + points = [] + phi = math.pi * (3. - math.sqrt(5.)) + + for i in range(samples): + y = 1 - (i / float(samples - 1)) * 2 + radius = math.sqrt(1 - y * y) + theta = phi * i + + x = math.cos(theta) * radius + z = math.sin(theta) * radius + points.append((x, y, z)) + + return np.asarray(points) + + def get_source(self, data_id): + return str(self.data_sources[data_id]) + + def __getitem__(self, idx): + data_source = self.data_sources[idx] + data_scale = self.data_scales[idx] + obj_path = data_source / "models" / "model_normalized.obj" + + vp_camera = self._equidist_point_on_sphere(300) + camera_ext = [] + for camera_i in range(vp_camera.shape[0]): + iso = motion_util.Isometry.look_at(vp_camera[camera_i], np.zeros(3, )) + camera_ext.append(iso) + camera_int = [0.8, 0.0, 2.5] # (window-size-half, z-min, z-max) under ortho-proj. + + return str(obj_path), [camera_int, camera_ext], None, data_scale + + def clean(self, data_id): + pass + + +def generate_samples(idx: int, args: argparse.ArgumentParser, provider, output_base, source_list): + mesh_path, vcam, ref_bin_path, sampler_mult = provider[idx] + + # Tmp file for sampling. + output_tmp_path = output_base / ("%06d.raw" % idx) + surface_tmp_path = output_base / ("%06d.surf" % idx) + vcam_file_path = output_base / ("%06d.cam" % idx) + + # Save the camera + with vcam_file_path.open('wb') as f: + np.asarray(vcam[0]).flatten().astype(np.float32).tofile(f) + np.asarray([cam.to_gl_camera().inv().matrix.T for cam in vcam[1]]).reshape(-1, 16).astype(np.float32).tofile(f) + + # Call CUDA sampler + arg_list_common = [str(CUDA_SAMPLER_PATH), + '-m', mesh_path, + '-s', str(int(args.sampler_count * sampler_mult * sampler_mult)), + '-o', str(output_tmp_path), + '-c', str(vcam_file_path), + '-r', str(args.sample_method), + '--surface', str(surface_tmp_path)] + arg_list_data = arg_list_common + ['-p', '0.8', '--var', str(args.sampler_var), '-e', str(args.voxel_size * 2.5)] + if ref_bin_path is not None: + arg_list_data += ['--ref', ref_bin_path, '--max_ref_dist', str(args.max_ref_dist)] + + is_bad = False + sampler_pass = args.__dict__.get("sampler_pass", 1) + + data_arr = [] + surface_arr = [] + for sid in range(sampler_pass): + subproc = subprocess.Popen(arg_list_data, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subproc.wait() + # Read raw file and convert it to numpy file. + try: + cur_data_arr = np.fromfile(str(output_tmp_path), dtype=np.float32).reshape(-1, 4) + cur_surface_arr = np.fromfile(str(surface_tmp_path), dtype=np.float32).reshape(-1, 6) + data_arr.append(cur_data_arr) + surface_arr.append(cur_surface_arr) + os.unlink(output_tmp_path) + os.unlink(surface_tmp_path) + except FileNotFoundError: + print(' '.join(arg_list_data)) + is_bad = True + break + + # Do cleaning of sampler. + os.unlink(vcam_file_path) + + if is_bad: + provider.clean(idx) + with _bad_counter.get_lock(): + _bad_counter.value += 1 + return + + data_arr = np.concatenate(data_arr, axis=0) * sampler_mult + surface_arr = np.concatenate(surface_arr, axis=0) + surface_arr[:, :3] *= sampler_mult + + # Badly, some surface arr may have NaN normal, we prune them + surface_arr_nan_row = np.any(np.isnan(surface_arr), axis=1) + surface_arr = surface_arr[~surface_arr_nan_row] + + # Do LIF splitting. + voxel_size = args.voxel_size + data_xyz = data_arr[:, :3] + data_sdf = data_arr[:, 3] + surface_xyz = surface_arr[:, :3] + + voxel_centers = np.ceil(data_xyz / voxel_size) - 1 + voxel_centers = np.unique(voxel_centers, axis=0) + voxel_centers = (voxel_centers + 0.5) * voxel_size + nbrs = NearestNeighbors(radius=voxel_size * (args.nn_size / 2.0), metric='chebyshev').fit(data_xyz) + lif_indices = nbrs.radius_neighbors(voxel_centers, return_distance=False) + nbrs_local = NearestNeighbors(radius=voxel_size * 0.5, metric='chebyshev').fit(data_xyz) + local_indices = nbrs_local.radius_neighbors(voxel_centers, return_distance=False) + nbrs_surface = NearestNeighbors(radius=voxel_size * (args.nn_size / 2.0), metric='chebyshev').fit(surface_xyz) + surface_indices = nbrs_surface.radius_neighbors(voxel_centers, return_distance=False) + + lif_data = [] + lif_data_count = [] + surface_data_count = [] + used_points = 0 + + for vox_center, lif_index, local_index, surface_index in zip(voxel_centers, lif_indices, local_indices, + surface_indices): + if local_index.shape[0] < 50 or surface_index.shape[0] < 50: + continue + + inner_sdf = data_sdf[lif_index] + num_pos = np.count_nonzero(inner_sdf > 0) + pos_ratio = num_pos / lif_index.shape[0] + if pos_ratio < 0.1 or pos_ratio > 0.9: + continue + + vox_min = vox_center - 0.5 * voxel_size + vox_max = vox_center + 0.5 * voxel_size + + lif_data_count.append(lif_index.shape[0]) + surface_data_count.append(surface_index.shape[0]) + used_points += local_index.shape[0] + + # Gather and normalize data (so that the center of data is 0) + output_data_xyzs = data_arr[lif_index] + output_surface_xyzn = surface_arr[surface_index] + output_data_xyzs[:, :3] = (output_data_xyzs[:, :3] - vox_center) / (vox_max - vox_min) + output_surface_xyzn[:, :3] = (output_surface_xyzn[:, :3] - vox_center) / (vox_max - vox_min) + output_data_xyzs[:, 3] /= voxel_size + + if np.max(output_data_xyzs[:, 3]) > 10.0: + print("Error", np.max(output_data_xyzs[:, 3])) + + lif_data.append({"min": vox_min, + "max": vox_max, + "data": output_data_xyzs, + "surface": output_surface_xyzn}) + + output_lif_base = output_base / "payload" + output_lif_inds = [] + with _counter.get_lock(): + mesh_idx = _counter.value + _counter.value += 1 + for lif_id in range(len(lif_data)): + output_lif_inds.append(len(source_list)) + source_list.append([provider.get_source(idx), mesh_idx, len(output_lif_inds) - 1]) + if len(lif_data_count) > 0: + print( + f"{_counter.value}: + {len(output_lif_inds)} = {len(source_list)}, {used_points} / {data_arr.shape[0]}, " + f"mean lif #: {int(np.mean(lif_data_count))}, mean surface #: {int(np.mean(surface_data_count))}") + + # Output mesh + output_obj_path = output_base / "mesh" / ("%06d.obj" % mesh_idx) + shutil.copy(mesh_path, output_obj_path) + provider.clean(idx) + + # Write data. + for new_lif_id, new_lif_data in zip(output_lif_inds, lif_data): + np.savez(output_lif_base / ("%08d.npz" % new_lif_id), **new_lif_data) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + + exp_util.init_seed(4) + mesh_providers = { + 'shapenet_model': ShapeNetGenerator, + } + + parser = exp_util.ArgumentParserX(add_hyper_arg=True, description='ImplicitSLAM LIF Data Generator.') + args = parser.parse_args() + + print(args) + + dataset = mesh_providers[args.provider](**args.provider_kwargs) + output_path = Path(args.output) + + if output_path.exists(): + print("Removing old dataset...") + shutil.rmtree(output_path) + + output_path.mkdir(parents=True, exist_ok=True) + + (output_path / "mesh").mkdir(exist_ok=True, parents=True) + (output_path / "payload").mkdir(exist_ok=True, parents=True) + + with (output_path / "config.json").open("w") as f: + json.dump(vars(args), f, indent=2) + + # Shared structures: + manager = Manager() + source_list = manager.list() + + if args.nproc > 0: + p = Pool(processes=args.nproc) + p.map(functools.partial(generate_samples, args=args, output_base=output_path, + provider=dataset, source_list=source_list), range(len(dataset))) + else: + for idx in range(len(dataset)): + generate_samples(idx, args, dataset, output_path, source_list) + + with (output_path / "source.json").open("w") as f: + json.dump(list(source_list), f, indent=2) + + print(f"Done with {_bad_counter.value} bad shapes") diff --git a/jittor/lif_dataset.py b/jittor/lif_dataset.py new file mode 100644 index 0000000..c67324e --- /dev/null +++ b/jittor/lif_dataset.py @@ -0,0 +1,85 @@ +from pathlib import Path +from jittor.dataset import Dataset +import json +import numpy as np +import random +from utils.motion_util import Quaternion + + +def perturb_normal(normals, theta_range): + normal_x_1 = np.stack([-normals[:, 1], normals[:, 0], np.zeros_like(normals[:, 0])], axis=1) + normal_x_2 = np.stack([-normals[:, 2], np.zeros_like(normals[:, 0]), normals[:, 0]], axis=1) + normal_x_mask = np.abs(np.abs(normals[:, 2]) - 1.0) > 0.1 + normal_x = np.zeros_like(normals) + normal_x[normal_x_mask] = normal_x_1[normal_x_mask] + normal_x[~normal_x_mask] = normal_x_2[~normal_x_mask] + normal_x /= np.linalg.norm(normal_x, axis=1, keepdims=True) + normal_y = np.cross(normals, normal_x) + + phi = np.random.rand(normal_x.shape[0], 1) * 2.0 * np.pi + phi_dir = np.cos(phi) * normal_x + np.sin(phi) * normal_y + theta = np.random.rand(normal_x.shape[0], 1) * theta_range + perturbed_normal = np.cos(theta) * normals + np.sin(theta) * phi_dir + return perturbed_normal + + +class LifDataset(Dataset): + def __init__(self, data_path, num_sample, num_surface_sample: int = 0, augment_rotation=None, + augment_noise=(0.0, 0.0), batch_size=1): + super().__init__() + self.batch_size = batch_size + self.shuffle = True + self.num_workers = 8 + self.drop_last = True + + self.data_path = Path(data_path) + with (self.data_path / "source.json").open() as f: + self.data_sources = json.load(f) + self.num_sample = num_sample + self.num_surface_sample = num_surface_sample + self.surface_format = None + self.augment_rotation = augment_rotation + self.augment_noise = augment_noise + self.total_len = len(self.data_sources) + + def get_raw_data(self, idx): + sdf_path = self.data_path / "payload" / ("%08d.npz" % idx) + return np.load(sdf_path) + + def __getitem__(self, idx): + if idx < 0: + assert -idx <= len(self) + idx = len(self) + idx + + # Load data + lif_raw = self.get_raw_data(idx) + lif_data = lif_raw["data"] + lif_surface = None + if self.num_surface_sample > 0: + lif_surface = lif_raw["surface"] + + pos_mask = lif_data[:, 3] > 0 + pos_tensor = lif_data[pos_mask] + neg_tensor = lif_data[~pos_mask] + half = int(self.num_sample / 2) + random_pos = (np.random.rand(half) * pos_tensor.shape[0]).astype(int) + random_neg = (np.random.rand(half) * neg_tensor.shape[0]).astype(int) + sample_pos = pos_tensor[random_pos] + sample_neg = neg_tensor[random_neg] + samples = np.concatenate([sample_pos, sample_neg], 0) + + lif_surface = lif_surface[np.random.choice(lif_surface.shape[0], size=self.num_surface_sample, replace=True), :] + + # Data augmentation + if self.augment_rotation is not None: + base_rot = random.choice([0.0, 90.0, 180.0, 270.0]) + rand_rot = Quaternion(axis=[0.0, 1.0, 0.0], degrees=base_rot + 30.0 * random.random()) + samples[:, 0:3] = samples[:, 0:3] @ rand_rot.rotation_matrix.T.astype(np.float32) + lif_surface[:, :3] = lif_surface[:, :3] @ rand_rot.rotation_matrix.T.astype(np.float32) + lif_surface[:, 3:6] = lif_surface[:, 3:6] @ rand_rot.rotation_matrix.T.astype(np.float32) + + if self.augment_noise[0] > 0.0: + lif_surface[:, :3] += np.random.randn(lif_surface.shape[0], 3) * self.augment_noise[0] + lif_surface[:, 3:6] = perturb_normal(lif_surface[:, 3:6], np.deg2rad(self.augment_noise[1])) + + return samples.astype(np.float32), lif_surface.astype(np.float32), idx diff --git a/jittor/lr_schedule.py b/jittor/lr_schedule.py new file mode 100644 index 0000000..cd1b061 --- /dev/null +++ b/jittor/lr_schedule.py @@ -0,0 +1,72 @@ + +class LearningRateSchedule: + def get_learning_rate(self, epoch): + pass + + +class ConstantLearningRateSchedule(LearningRateSchedule): + def __init__(self, value): + self.value = value + + def get_learning_rate(self, epoch): + return self.value + + +class StepLearningRateSchedule(LearningRateSchedule): + def __init__(self, initial, interval, factor): + self.initial = initial + self.interval = interval + self.factor = factor + + def get_learning_rate(self, epoch): + return self.initial * (self.factor ** (epoch // self.interval)) + + +class WarmupLearningRateSchedule(LearningRateSchedule): + def __init__(self, initial, warmed_up, length): + self.initial = initial + self.warmed_up = warmed_up + self.length = length + + def get_learning_rate(self, epoch): + if epoch > self.length: + return self.warmed_up + return self.initial + (self.warmed_up - self.initial) * epoch / self.length + + +def get_learning_rate_schedules(args): + schedule_specs = args.lr_schedule + schedules = [] + for schedule_specs in schedule_specs: + if schedule_specs["Type"] == "Step": + schedules.append( + StepLearningRateSchedule( + schedule_specs["Initial"], + schedule_specs["Interval"], + schedule_specs["Factor"], + ) + ) + elif schedule_specs["Type"] == "Warmup": + schedules.append( + WarmupLearningRateSchedule( + schedule_specs["Initial"], + schedule_specs["Final"], + schedule_specs["Length"], + ) + ) + elif schedule_specs["Type"] == "Constant": + schedules.append(ConstantLearningRateSchedule(schedule_specs["Value"])) + + else: + raise Exception( + 'no known learning rate schedule of type "{}"'.format( + schedule_specs["Type"] + ) + ) + + return schedules + + +def adjust_learning_rate(lr_schedules, optimizer, epoch): + for i, param_group in enumerate(optimizer.param_groups): + param_group["lr"] = lr_schedules[i].get_learning_rate(epoch) diff --git a/jittor/network.py b/jittor/network.py new file mode 100644 index 0000000..4cdf19f --- /dev/null +++ b/jittor/network.py @@ -0,0 +1,80 @@ +import math +import jittor as jt +import jittor.nn as nn + + +SHAPE_MULT = 1 + + +class WNLinear(nn.Module): + def __init__(self, in_features, out_features, bias=True): + super(WNLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight_v = nn.init.invariant_uniform((out_features, in_features), "float32") + self.weight_g = jt.norm(self.weight_v, k=2, dim=1, keepdim=True) + bound = 1.0 / math.sqrt(in_features) + self.bias = nn.init.uniform((out_features,), "float32", -bound, bound) if bias else None + + def execute(self, x): + weight = self.weight_g * (self.weight_v / jt.norm(self.weight_v, k=2, dim=1, keepdim=True)) + x = nn.matmul_transpose(x, weight) + if self.bias is not None: + return x + self.bias + return x + + +class DIDecoder(nn.Module): + def __init__(self): + super().__init__() + self.lin0 = WNLinear(32, 128 * SHAPE_MULT) + self.lin1 = WNLinear(128 * SHAPE_MULT, 128 * SHAPE_MULT) + self.lin2 = WNLinear(128 * SHAPE_MULT, 128 * SHAPE_MULT - 32) + self.lin3 = WNLinear(128 * SHAPE_MULT, 128 * SHAPE_MULT) + self.lin4 = WNLinear(128 * SHAPE_MULT, 1) + self.uncertainty_layer = nn.Linear(128 * SHAPE_MULT, 1) + self.relu = nn.ReLU() + self.dropout = [0, 1, 2, 3, 4, 5] + self.th = nn.Tanh() + + def execute(self, ipt): + x = self.lin0(ipt) + x = self.relu(x) + x = nn.dropout(x, p=0.2, is_train=True) + + x = self.lin1(x) + x = self.relu(x) + x = nn.dropout(x, p=0.2, is_train=True) + + x = self.lin2(x) + x = self.relu(x) + x = nn.dropout(x, p=0.2, is_train=True) + + x = jt.contrib.concat([x, ipt], 1) + x = self.lin3(x) + x = self.relu(x) + x = nn.dropout(x, p=0.2, is_train=True) + + std = self.uncertainty_layer(x) + std = 0.05 + 0.5 * nn.softplus(std) + x = self.lin4(x) + x = self.th(x) + + return x, std + + +class DIEncoder(nn.Module): + def __init__(self): + super().__init__() + self.mlp = nn.Sequential( + nn.Conv1d(6, 32 * SHAPE_MULT, kernel_size=1, bias=False), nn.BatchNorm1d(32 * SHAPE_MULT), nn.ReLU(), + nn.Conv1d(32 * SHAPE_MULT, 64 * SHAPE_MULT, kernel_size=1, bias=False), nn.BatchNorm1d(64 * SHAPE_MULT), nn.ReLU(), + nn.Conv1d(64 * SHAPE_MULT, 256 * SHAPE_MULT, kernel_size=1, bias=False), nn.BatchNorm1d(256 * SHAPE_MULT), nn.ReLU(), + nn.Conv1d(256 * SHAPE_MULT, 29, kernel_size=1, bias=True) + ) + + def execute(self, x): + x = x.transpose([0, 2, 1]) + x = self.mlp(x) # (B, L, N) + r = jt.mean(x, dim=-1) + return r diff --git a/jittor/sampler_cuda/CMakeLists.txt b/jittor/sampler_cuda/CMakeLists.txt new file mode 100644 index 0000000..eebb407 --- /dev/null +++ b/jittor/sampler_cuda/CMakeLists.txt @@ -0,0 +1,17 @@ +project(Sampler LANGUAGES CXX CUDA) +cmake_minimum_required(VERSION 3.8) + +find_package(CLI11 CONFIG REQUIRED) +find_package(Eigen3 3.3 REQUIRED) +find_package(Pangolin REQUIRED) +find_package(flann REQUIRED) + +add_executable(PreprocessMeshCUDA PreprocessMesh.cu ShaderProgram.cpp Utils.cu) +target_link_libraries(PreprocessMeshCUDA PRIVATE Eigen3::Eigen CLI11::CLI11 pangolin flann_cuda -lcurand) +target_compile_features(PreprocessMeshCUDA PRIVATE cxx_std_14) +set_target_properties(PreprocessMeshCUDA PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bin") +target_compile_options(PreprocessMeshCUDA PRIVATE + $<$: --use_fast_math> + $<$: -fPIC -O3 -march=native > +) +set_target_properties(PreprocessMeshCUDA PROPERTIES CUDA_SEPARABLE_COMPILATION ON) diff --git a/jittor/sampler_cuda/PreprocessMesh.cu b/jittor/sampler_cuda/PreprocessMesh.cu new file mode 100644 index 0000000..fd1a377 --- /dev/null +++ b/jittor/sampler_cuda/PreprocessMesh.cu @@ -0,0 +1,631 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +// The GPU version is drafted by heiwang1997@github.com + +#define FLANN_USE_CUDA +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include "Utils.h" +#include +#include +#include +#include + +extern pangolin::GlSlProgram GetShaderProgram(); +static const int METHOD2_SAMPLES_MULT = 10; + +__global__ static void SampleUniformKernel(size_t num_uniform_sample, float2 ub_x, float2 ub_y, float2 ub_z, + curandState *rng_state, float4* output) { + unsigned int sample_id = blockIdx.x * blockDim.x + threadIdx.x; + if (sample_id >= num_uniform_sample) { + return; + } + curandState local_rng = rng_state[sample_id]; + + float4 pos; + pos.x = curand_uniform(&local_rng) * (ub_x.y - ub_x.x) + ub_x.x; + pos.y = curand_uniform(&local_rng) * (ub_y.y - ub_y.x) + ub_y.x; + pos.z = curand_uniform(&local_rng) * (ub_z.y - ub_z.x) + ub_z.x; + pos.w = 0.0f; + output[sample_id] = pos; +} + +__device__ __forceinline__ static float4 toFloat4(const Eigen::Vector3f& vec) { + return make_float4(vec(0), vec(1), vec(2), 0.0f); +} + +__global__ static void SampleSurfacePointKernel(size_t num_sample, float4* __restrict__ ref_xyz, + float4* __restrict__ ref_normal, float total_area, curandState *rng_state, + const float* __restrict__ tri_cdf, unsigned char* __restrict__ triangles, size_t triangles_pitch, + size_t num_tris, unsigned char* __restrict__ vertices, size_t vertices_pitch, int n_samples) { + unsigned int sample_id = blockIdx.x * blockDim.x + threadIdx.x; + if (sample_id >= num_sample) { + return; + } + curandState local_rng = rng_state[sample_id]; + + // Generate a uniform number, binary search that in cdf to get triangle id. + const float u = curand_uniform(&local_rng) * total_area; + int tri_id = lower_bound(tri_cdf, u, num_tris); + auto* inds = (uint32_t*) (triangles + triangles_pitch * tri_id); + + // Randomly sample in that triangle. + auto* ap = (float*) (vertices + vertices_pitch * inds[0]); + auto* bp = (float*) (vertices + vertices_pitch * inds[1]); + auto* cp = (float*) (vertices + vertices_pitch * inds[2]); + + Eigen::Vector3f va(ap[0], ap[1], ap[2]); + Eigen::Vector3f vb(bp[0], bp[1], bp[2]); + Eigen::Vector3f vc(cp[0], cp[1], cp[2]); + float4 normal = toFloat4((vb - va).cross(vc - va).normalized()); + +//#pragma unroll + for (int k = 0; k < n_samples; ++k) { + size_t buf_id = sample_id * n_samples + k; + float r1 = curand_uniform(&local_rng); + float r2 = curand_uniform(&local_rng); + float wa = 1 - sqrt(r1); + float wb = (1 - wa) * (1 - r2); + float wc = r2 * (1 - wa); + ref_xyz[buf_id] = toFloat4(wa * va + wb * vb + wc * vc); + ref_normal[buf_id] = normal; + } + + rng_state[sample_id] = local_rng; +} + +__global__ static void SamplePointKernel(size_t num_half_sample, float4* output, + float total_area, curandState *rng_state, const float* __restrict__ tri_cdf, + unsigned char* __restrict__ triangles, size_t triangles_pitch, size_t num_tris, + unsigned char* __restrict__ vertices, size_t vertices_pitch, float small_std, float large_std) { + unsigned int sample_id = blockIdx.x * blockDim.x + threadIdx.x; + if (sample_id >= num_half_sample) { + return; + } + curandState local_rng = rng_state[sample_id]; + + // Generate a uniform number, binary search that in cdf to get triangle id. + const float u = curand_uniform(&local_rng) * total_area; + int tri_id = lower_bound(tri_cdf, u, num_tris); + auto* inds = (uint32_t*) (triangles + triangles_pitch * tri_id); + + // Randomly sample in that triangle. + const float r1 = curand_uniform(&local_rng); + const float r2 = curand_uniform(&local_rng); + + auto* ap = (float*) (vertices + vertices_pitch * inds[0]); + auto* bp = (float*) (vertices + vertices_pitch * inds[1]); + auto* cp = (float*) (vertices + vertices_pitch * inds[2]); + + const float wa = 1 - sqrt(r1); + const float wb = (1 - wa) * (1 - r2); + const float wc = r2 * (1 - wa); + + float4 pos1, pos2; + pos1.x = wa * ap[0] + wb * bp[0] + wc * cp[0] + curand_normal(&local_rng) * small_std; + pos1.y = wa * ap[1] + wb * bp[1] + wc * cp[1] + curand_normal(&local_rng) * small_std; + pos1.z = wa * ap[2] + wb * bp[2] + wc * cp[2] + curand_normal(&local_rng) * small_std; + pos1.w = 0.0f; + + pos2.x = wa * ap[0] + wb * bp[0] + wc * cp[0] + curand_normal(&local_rng) * large_std; + pos2.y = wa * ap[1] + wb * bp[1] + wc * cp[1] + curand_normal(&local_rng) * large_std; + pos2.z = wa * ap[2] + wb * bp[2] + wc * cp[2] + curand_normal(&local_rng) * large_std; + pos2.w = 0.0f; + + output[sample_id] = pos1; + output[sample_id + num_half_sample] = pos2; + + rng_state[sample_id] = local_rng; +} + +__global__ static void ComputeSDFKernel(size_t num_samples, int num_votes, const float4* __restrict__ ref_xyz, + const float4* __restrict__ ref_normals, const int* __restrict__ knn_index, + float4* __restrict__ sample_xyz, float stdv, float max_ref_dist) { + unsigned int sample_id = blockIdx.x * blockDim.x + threadIdx.x; + if (sample_id >= num_samples) { + return; + } + + float4 cur_xyz = sample_xyz[sample_id]; + Eigen::Vector3f sample_pos(cur_xyz.x, cur_xyz.y, cur_xyz.z); + + float sdf; + int num_pos = 0; + for (int vote_i = 0; vote_i < num_votes; ++vote_i) { + int cur_ind = knn_index[sample_id * num_votes + vote_i]; + float4 nb_xyz = ref_xyz[cur_ind]; + float4 nb_normal = ref_normals[cur_ind]; + + Eigen::Vector3f nb_pos(nb_xyz.x, nb_xyz.y, nb_xyz.z); + Eigen::Vector3f nb_norm(nb_normal.x, nb_normal.y, nb_normal.z); + Eigen::Vector3f ray_vec = sample_pos - nb_pos; + float ray_vec_len = ray_vec.norm(); + + // SDF value will take the first vote. (nearest neighbour) + if (vote_i == 0) { + if (ray_vec_len > max_ref_dist) { + // Just invalidate this point. + num_pos = 1; + break; + } + if (ray_vec_len < stdv) { + sdf = abs(nb_norm.dot(ray_vec)); + } else { + sdf = ray_vec_len; + } + } + float d = nb_norm.dot(ray_vec / ray_vec_len); + if (d > 0) { + num_pos += 1; + } + } + + if (num_pos == 0) { + sample_xyz[sample_id].w = -sdf; + } else if (num_pos == num_votes) { + sample_xyz[sample_id].w = sdf; + } else { + sample_xyz[sample_id].w = NAN; + } +} + +struct Keep3Functor { + __host__ __device__ float4 operator()(const float4 &x) const + {return make_float4(x.x, x.y, x.z, 0.0);} +}; + +struct ValidWFunctor { + __host__ __device__ bool operator()(const float4 &x) const + { + return !isnan(x.w); + } +}; + +void GenerateSDFSamples( + int sample_method, + pangolin::GlGeometry &geom, + int num_half_surface_sample, + int num_uniform_sample, + thrust::device_vector& ref_xyz, + thrust::device_vector& ref_normal, + float var_small, float var_large, float max_ref_dist, + int num_votes, + thrust::host_vector& valid_data, + float2 ub_x, float2 ub_y, float2 ub_z) { + + //// Generate sampled points, by first sample from triangle and then perturb. + const int num_surface_sample = num_half_surface_sample * 2; + const int num_total_sample = num_surface_sample + num_uniform_sample; + + // Map OpenGL geometry into cuda. + cudaGraphicsResource_t vbo_handle; + cudaGraphicsResource_t ibo_handle; + unsigned char* vbo_data; size_t vbo_nbytes; + size_t vbo_stride = geom.buffers["geometry"].attributes["vertex"].stride_bytes; + unsigned char* ibo_data; size_t ibo_nbytes; size_t ibo_stride = 0; + size_t num_tris = 0; + + cudaSafeCall(cudaGraphicsGLRegisterBuffer(&vbo_handle, geom.buffers["geometry"].bo, cudaGraphicsMapFlagsReadOnly)); + for (const auto& object : geom.objects) { + // assert: object.first == "mesh" + auto it_vert_indices = object.second.attributes.find("vertex_indices"); + if (it_vert_indices != object.second.attributes.end()) { + cudaSafeCall(cudaGraphicsGLRegisterBuffer(&ibo_handle, object.second.bo, cudaGraphicsMapFlagsReadOnly)); + ibo_stride = it_vert_indices->second.stride_bytes; + num_tris = it_vert_indices->second.num_elements; + break; + } + } + cudaSafeCall(cudaGraphicsMapResources(1, &vbo_handle)); + cudaSafeCall(cudaGraphicsMapResources(1, &ibo_handle)); + cudaSafeCall(cudaGraphicsResourceGetMappedPointer((void**)&vbo_data, &vbo_nbytes, vbo_handle)); + cudaSafeCall(cudaGraphicsResourceGetMappedPointer((void**)&ibo_data, &ibo_nbytes, ibo_handle)); + + // Compute triangle areas. + thrust::device_vector tri_area(num_tris); + { + dim3 dimBlock = dim3(128); + dim3 dimGrid = dim3((num_tris + dimBlock.x - 1) / dimBlock.x); + TriangleAreaKernel<<>>(vbo_data, vbo_stride, ibo_data, ibo_stride, num_tris, tri_area.data().get()); + cudaKernelCheck + } + + // Convert to CDF. + thrust::inclusive_scan(tri_area.begin(), tri_area.end(), tri_area.begin()); + float total_area = tri_area.back(); + + // Allocate space for RNG. + curandState* p_rng_state; + int num_rng = std::max(num_half_surface_sample, num_uniform_sample); + cudaSafeCall(cudaMalloc((void**)&p_rng_state, num_rng * sizeof(curandState))); + { + dim3 dimBlock = dim3(256); + dim3 dimGrid = dim3((num_rng + dimBlock.x - 1) / dimBlock.x); + RNGSetupKernel<<>>(p_rng_state, num_rng); + cudaKernelCheck + } + + // If use method2, just make samples and write to ref_xyz and ref_normal + if (sample_method == 2) { + ref_xyz.resize(num_rng * METHOD2_SAMPLES_MULT); + ref_normal.resize(num_rng * METHOD2_SAMPLES_MULT); + dim3 dimBlock(256); + dim3 dimGrid((num_rng + dimBlock.x - 1) / dimBlock.x); + SampleSurfacePointKernel<<>>(num_rng, ref_xyz.data().get(), ref_normal.data().get(), + total_area, p_rng_state, tri_area.data().get(), ibo_data, ibo_stride, num_tris, vbo_data, vbo_stride, METHOD2_SAMPLES_MULT); + cudaKernelCheck + } + + // Sample and perturb surface samples according to CDF. + thrust::device_vector sampled_points(num_total_sample); + { + dim3 dimBlock = dim3(256); + dim3 dimGrid = dim3((num_half_surface_sample + dimBlock.x - 1) / dimBlock.x); + SamplePointKernel<<>>(num_half_surface_sample, sampled_points.data().get(), + total_area, p_rng_state, tri_area.data().get(), ibo_data, ibo_stride, num_tris, vbo_data, vbo_stride, + std::sqrt(var_small), std::sqrt(var_large)); + cudaKernelCheck + } + + cudaSafeCall(cudaGraphicsUnmapResources(1, &vbo_handle)); + cudaSafeCall(cudaGraphicsUnmapResources(1, &ibo_handle)); + cudaSafeCall(cudaGraphicsUnregisterResource(vbo_handle)); + cudaSafeCall(cudaGraphicsUnregisterResource(ibo_handle)); + + // Also, add uniform samples. + if (num_uniform_sample > 0) { + dim3 dimBlock = dim3(256); + dim3 dimGrid = dim3((num_uniform_sample + dimBlock.x - 1) / dimBlock.x); + SampleUniformKernel<<>>(num_uniform_sample, ub_x, ub_y, ub_z, p_rng_state, + sampled_points.data().get() + num_surface_sample); + cudaKernelCheck + } + cudaSafeCall(cudaFree(p_rng_state)); + + //// Query all the generated samples to the view-sampled geometry, to get the sdf value. + thrust::transform(ref_xyz.begin(), ref_xyz.end(), ref_xyz.begin(), Keep3Functor()); + // thrust::transform(sampled_points.begin(), sampled_points.end(), sampled_points.begin(), Keep3Functor()); + + // Do kNN to retrieve nearest idx for all queries. + thrust::device_vector dist(num_total_sample * num_votes); + thrust::device_vector indices(num_total_sample * num_votes); + + std::cout << ref_xyz.size() << ", " << ref_normal.size() << std::endl; + flann::Matrix knn_ref((float*)ref_xyz.data().get(), ref_xyz.size(), 3, 4 * sizeof(float)); + flann::KDTreeCuda3dIndexParams knn_params; + knn_params["input_is_gpu_float4"] = true; + flann::KDTreeCuda3dIndex > knn_index(knn_ref, knn_params); + knn_index.buildIndex(); + flann::Matrix knn_dist((float*)dist.data().get(), num_total_sample, num_votes); + flann::Matrix knn_indices((int*)indices.data().get(), num_total_sample, num_votes); + flann::Matrix knn_query((float*)sampled_points.data().get(), num_total_sample, 3, 4 * sizeof(float)); + flann::SearchParams params; + params.matrices_in_gpu_ram = true; + params.sorted = true; + knn_index.knnSearch(knn_query, knn_indices, knn_dist, num_votes, params); + + // Compute SDF for the samples. Invalid samples' sdf will be marked NaN. + { + dim3 dimBlock = dim3(256); + dim3 dimGrid = dim3((num_total_sample + dimBlock.x - 1) / dimBlock.x); + ComputeSDFKernel<<>>(num_total_sample, num_votes, ref_xyz.data().get(), ref_normal.data().get(), + indices.data().get(), sampled_points.data().get(), + std::sqrt(var_small), max_ref_dist); + cudaKernelCheck + } + + // Copy and delete all invalid sdfs. + thrust::device_vector sampled_points_valid(num_total_sample); + auto result_end = thrust::copy_if(sampled_points.begin(), sampled_points.end(), + sampled_points_valid.begin(), ValidWFunctor()); + valid_data.resize(thrust::distance(sampled_points_valid.begin(), result_end)); + thrust::copy(sampled_points_valid.begin(), result_end, valid_data.begin()); +} + +struct TriTestFunctor { + __host__ __device__ int operator()(const int& x, const int& y) const { + return (abs(x) == y) ? y : -1; + } +}; + +struct PlusPositiveFunctor { + __host__ __device__ int operator()(const int &lhs, const int &rhs) const { + if (lhs < 0) return rhs; + if (rhs < 0) return lhs; + return lhs + rhs; + } +}; + +int main(int argc, char **argv) { + std::string meshFileName; + bool vis = false; + + std::string outputFileName; + std::string cameraFileName; + std::string surfaceFileName; // For output of sampled surface. + std::string referenceFileName; + float variance = 0.005; + int num_sample = 500000; + float num_samp_near_surf_ratio = 47.0f / 50.0f; + float uniform_sample_bbox_expand = 1.2f; + int reference_method = 1; + float max_ref_dist = 1.0e8f; + + CLI::App app{"PreprocessMesh"}; + app.add_option("-m", meshFileName, "Mesh File Name for Reading")->required(); + app.add_option("-o", outputFileName, "Output file path")->required(); + app.add_option("--surface", surfaceFileName, "Output surface file path")->required(); + app.add_option("-s", num_sample, "Number of attempted samples"); + app.add_option("-p", num_samp_near_surf_ratio, "Portion of near surface"); + app.add_option("-e", uniform_sample_bbox_expand, "Expansion of the bounding box for uniform sampling"); + app.add_option("--var", variance, "Set Variance"); + app.add_option("-r", reference_method, "Method 1 is camera. Method 2 is mesh normal. Method 3 is reference points.")->required(); + app.add_option("-c", cameraFileName, "Name of the camera definition (required for ref-method 1)."); + app.add_option("--ref", referenceFileName, "Name of the reference file."); + app.add_option("--max_ref_dist", max_ref_dist, "Maximum reference dist to prune some invalid data."); + app.add_flag("-v", vis, "enable visualization"); + + CLI11_PARSE(app, argc, argv); + + if (reference_method == 1 && cameraFileName.empty()) { + std::cout << "Camera not provided!" << std::endl; + return -1; + } + + float second_variance = variance / 5; + std::cout << "variance: " << variance << " second: " << second_variance << std::endl; + + pangolin::Geometry geom = pangolin::LoadGeometry(meshFileName); + std::cout << geom.objects.size() << " objects" << std::endl; + + // linearize the object indices + LinearizeObject(geom); + + // remove textures + geom.textures.clear(); + + // Get bounding boxes. + float2 ub_x, ub_y, ub_z; + ComputeNormalizationParameters(geom, uniform_sample_bbox_expand, ub_x, ub_y, ub_z); + + // Get surface samples and normals + thrust::device_vector point_normals; + thrust::device_vector point_verts; + + if (vis) + pangolin::CreateWindowAndBind("Main", 640, 480); + else + pangolin::CreateWindowAndBind("Main", 1, 1); + + glPixelStorei(GL_UNPACK_ALIGNMENT, 1); + glPixelStorei(GL_UNPACK_ROW_LENGTH, 0); + glPixelStorei(GL_UNPACK_SKIP_PIXELS, 0); + glPixelStorei(GL_UNPACK_SKIP_ROWS, 0); + glEnable(GL_DEPTH_TEST); + glDisable(GL_DITHER); + glDisable(GL_POINT_SMOOTH); + glDisable(GL_LINE_SMOOTH); + glDisable(GL_POLYGON_SMOOTH); + glHint(GL_POINT_SMOOTH, GL_DONT_CARE); + glHint(GL_LINE_SMOOTH, GL_DONT_CARE); + glHint(GL_POLYGON_SMOOTH_HINT, GL_DONT_CARE); + glDisable(GL_MULTISAMPLE_ARB); + glShadeModel(GL_FLAT); + + // Check if OpenGL direct resource mapping is usable. (interop with CUDA) + // This may fail for software rendering (e.g. when using remote-glx). + { + unsigned int nGLDevCount; + int cudaDevices[4]; + cudaSafeCall(cudaGLGetDevices(&nGLDevCount, cudaDevices, 4, cudaGLDeviceListAll)); + if (nGLDevCount == 0) { + std::cerr << "No OpenGL hardware found." << std::endl; + return -1; + } + } + + // Map geometry to gpu. + pangolin::GlGeometry gl_geom = pangolin::ToGlGeometry(geom); + + if (reference_method == 1) { + // Method1: Use camera to render the mesh. Take inverse camera ray as normal + pangolin::Image modelFaces = pangolin::get>( + geom.objects.begin()->second.attributes["vertex_indices"]); + size_t num_tri = modelFaces.h; + + // Load in camera definition + float max_dist; + float z_extent[2]; + std::vector view_matrices; + { + std::ifstream fin(cameraFileName, std::ios::in | std::ios::binary); + if (!fin) { + std::cerr << "File " << cameraFileName << " not found!" << std::endl; + return -1; + } + fin.read((char*)&max_dist, 4); + fin.read((char*)z_extent, sizeof(float) * 2); + float buffer[16]; + while (true) { + fin.read((char*)buffer, sizeof(float) * 16); + if (!fin) break; + view_matrices.emplace_back(); + auto& m = view_matrices.back(); + for (int i = 0; i < 16; ++i) { + m.m[i] = buffer[i]; + } + } + std::cout << "Available cameras = " << view_matrices.size() << std::endl; + } + + // Define Projection and initial ModelView matrix + pangolin::OpenGlRenderState s_cam2( + pangolin::ProjectionMatrixOrthographic(-max_dist, max_dist, max_dist, -max_dist, z_extent[0], z_extent[1]), + pangolin::ModelViewLookAt(0, 0, -1, 0, 0, 0, pangolin::AxisY)); + + // Create Interactive View in window + pangolin::GlSlProgram prog = GetShaderProgram(); + + if (vis) { + pangolin::OpenGlRenderState s_cam( + pangolin::ProjectionMatrixOrthographic(-max_dist, max_dist, -max_dist, max_dist, z_extent[0], z_extent[1]), + pangolin::ModelViewLookAt(0, 0, -1, 0, 0, 0, pangolin::AxisY)); + s_cam.SetModelViewMatrix(view_matrices[0]); + pangolin::Handler3D handler(s_cam); + pangolin::View &d_cam = pangolin::CreateDisplay() + .SetBounds(0.0, 1.0, 0.0, 1.0, -640.0f / 480.0f) + .SetHandler(&handler); + + while (!pangolin::ShouldQuit()) { + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + d_cam.Activate(s_cam); + prog.Bind(); + prog.SetUniform("MVP", s_cam.GetProjectionModelViewMatrix()); + prog.SetUniform("V", s_cam.GetModelViewMatrix()); + pangolin::GlDraw(prog, gl_geom, nullptr); + prog.Unbind(); + + // Swap frames and Process Events + pangolin::FinishFrame(); + } + } + + // Create Framebuffer with attached textures + size_t w = 400; + size_t h = 400; + size_t wh = w * h; + pangolin::GlRenderBuffer zbuffer(w, h, GL_DEPTH_COMPONENT32); + pangolin::GlTexture normals(w, h, GL_RGBA32F); + pangolin::GlTexture vertices(w, h, GL_RGBA32F); + pangolin::GlFramebuffer framebuffer(vertices, normals, zbuffer); + + // Register CUDA buffer. + cudaGraphicsResource_t cudaResourceNormals; + cudaGraphicsResource_t cudaResourceVertices; + cudaSafeCall(cudaGraphicsGLRegisterImage(&cudaResourceNormals, normals.tid, GL_TEXTURE_2D, cudaGraphicsRegisterFlagsReadOnly)); + cudaSafeCall(cudaGraphicsGLRegisterImage(&cudaResourceVertices, vertices.tid, GL_TEXTURE_2D, cudaGraphicsRegisterFlagsReadOnly)); + + // Thrust container for normals and vertices; + unsigned int valid_point_num = 0; + point_normals.resize(wh * 2); + point_verts.resize(wh * 2); + + thrust::device_vector tri_pos(num_tri, 0); + thrust::device_vector tri_all(num_tri, 0); + + for (unsigned int v = 0; v < view_matrices.size(); v++) { + // change camera location + s_cam2.SetModelViewMatrix(view_matrices[v]); + // Draw the scene to the framebuffer + framebuffer.Bind(); + glViewport(0, 0, w, h); + glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT); + + prog.Bind(); + prog.SetUniform("MVP", s_cam2.GetProjectionModelViewMatrix()); + prog.SetUniform("V", s_cam2.GetModelViewMatrix()); + pangolin::GlDraw(prog, gl_geom, nullptr); + prog.Unbind(); + + framebuffer.Unbind(); + + // Map Resource + cudaArray_t normals_tex_arr, verts_tex_arr; + cudaSafeCall(cudaGraphicsMapResources(1, &cudaResourceNormals)); + cudaSafeCall(cudaGraphicsMapResources(1, &cudaResourceVertices)); + cudaSafeCall(cudaGraphicsSubResourceGetMappedArray(&normals_tex_arr, cudaResourceNormals, 0, 0)); + cudaSafeCall(cudaGraphicsSubResourceGetMappedArray(&verts_tex_arr, cudaResourceVertices, 0, 0)); + + if (point_normals.size() - valid_point_num < wh) { + point_normals.resize(valid_point_num + wh); + point_verts.resize(valid_point_num + wh); + } + auto view_points = ValidPointsNormalCUDA(normals_tex_arr, verts_tex_arr, normals.width, normals.height, tri_pos, tri_all, + point_normals.data().get() + valid_point_num, point_verts.data().get() + valid_point_num); + valid_point_num += view_points; + + // Unmap Resource + cudaSafeCall(cudaGraphicsUnmapResources(1, &cudaResourceNormals)); + cudaSafeCall(cudaGraphicsUnmapResources(1, &cudaResourceVertices)); + } + + // Unregister normal and vertex buffer. + cudaSafeCall(cudaGraphicsUnregisterResource(cudaResourceNormals)); + cudaSafeCall(cudaGraphicsUnregisterResource(cudaResourceVertices)); + + point_normals.resize(valid_point_num); + point_verts.resize(valid_point_num); + } else if (reference_method == 2) { + // Method2: Believe in mesh normal. + // Do nothing here. + } else { + // Method3: Load reference points captured outside. + thrust::host_vector cpu_point_verts; + thrust::host_vector cpu_point_normals; + { + std::ifstream fin(referenceFileName, std::ios::in | std::ios::binary); + int point_count; + fin.read((char*)&point_count, sizeof(int)); +// std::cout << "Point Count: " << point_count << std::endl; + cpu_point_verts.resize(point_count); + cpu_point_normals.resize(point_count); + fin.read((char*)cpu_point_verts.data(), sizeof(float4) * point_count); + fin.read((char*)cpu_point_normals.data(), sizeof(float4) * point_count); + } + point_verts = cpu_point_verts; + point_normals = cpu_point_normals; + } + + int num_samp_near_surf = num_sample * num_samp_near_surf_ratio; + std::cout << "num_samp_near_surf: " << num_samp_near_surf << std::endl; + + thrust::host_vector sampled_data; + + auto start = std::chrono::high_resolution_clock::now(); + GenerateSDFSamples(reference_method, gl_geom, num_samp_near_surf / 2, num_sample - num_samp_near_surf, + point_verts, point_normals, variance, second_variance, max_ref_dist, 11, sampled_data, ub_x, ub_y, ub_z); + auto finish = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(finish - start).count(); + std::cout << elapsed << std::endl; + + std::cout << "num points sampled: " << sampled_data.size() << std::endl; + + // Write raw file. + { + std::ofstream fout(outputFileName, std::ios::out | std::ios::binary); + fout.write((char*) sampled_data.data(), sizeof(float4) * sampled_data.size()); + fout.close(); + } + + // Write surface file. + if (!surfaceFileName.empty()) { + std::ofstream fout(surfaceFileName, std::ios::out | std::ios::binary); + thrust::host_vector cpu_point_verts = point_verts; + thrust::host_vector cpu_point_normals = point_normals; + // NOTE: Here sampling method1 is also downsampled. + int increment = reference_method > 2 ? 1 : METHOD2_SAMPLES_MULT; + + for (int i = 0; i < cpu_point_normals.size(); i += increment) { + fout.write((char*) (cpu_point_verts.data() + i), sizeof(float3)); + fout.write((char*) (cpu_point_normals.data() + i), sizeof(float3)); + } + fout.close(); + } + + return 0; +} diff --git a/jittor/sampler_cuda/ShaderProgram.cpp b/jittor/sampler_cuda/ShaderProgram.cpp new file mode 100644 index 0000000..614ef8c --- /dev/null +++ b/jittor/sampler_cuda/ShaderProgram.cpp @@ -0,0 +1,134 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include + +constexpr const char* shaderText = R"Shader( +@start vertex +#version 330 core + +layout(location = 0) in vec3 vertex; + +out vec4 position_world; +out vec4 position_camera; +out vec3 viewDirection_camera; + +uniform mat4 MVP; +uniform mat4 V; + +void main(){ + + // Projected image coordinate + gl_Position = MVP * vec4(vertex,1); + + // world coordinate location of the vertex + position_world = vec4(vertex,1); + position_camera = V * vec4(vertex, 1); + + viewDirection_camera = normalize(vec3(0,0,0) - position_camera.xyz); + +} + +@start geometry +#version 330 + +layout ( triangles ) in; +layout ( triangle_strip, max_vertices = 3 ) out; + +in vec4 position_world[]; +in vec3 viewDirection_camera[]; + +out vec3 normal_camera; +out vec3 normal_world; +out vec4 xyz_world; +out vec3 viewDirection_cam; +out vec4 xyz_camera; + +uniform mat4 V; + +void main() +{ + vec3 A = position_world[1].xyz - position_world[0].xyz; + vec3 B = position_world[2].xyz - position_world[0].xyz; + vec3 normal = normalize(cross(A,B)); + vec3 normal_cam = (V * vec4(normal,0)).xyz; + + gl_Position = gl_in[0].gl_Position; + normal_camera = normal_cam; + normal_world = normal; + xyz_world = position_world[0]; + xyz_camera = V * xyz_world; + viewDirection_cam = viewDirection_camera[0]; + gl_PrimitiveID = gl_PrimitiveIDIn; + EmitVertex(); + + gl_Position = gl_in[1].gl_Position; + normal_camera = normal_cam; + normal_world = normal; + xyz_world = position_world[1]; + xyz_camera = V * xyz_world; + viewDirection_cam = viewDirection_camera[1]; + gl_PrimitiveID = gl_PrimitiveIDIn; + + EmitVertex(); + + gl_Position = gl_in[2].gl_Position; + normal_camera = normal_cam; + normal_world = normal; + xyz_world = position_world[2]; + xyz_camera = V * xyz_world; + viewDirection_cam = viewDirection_camera[2]; + gl_PrimitiveID = gl_PrimitiveIDIn; + + EmitVertex(); + EndPrimitive(); +} + +@start fragment +#version 330 core + +in vec3 viewDirection_cam; +in vec3 normal_world; +in vec3 normal_camera; +in vec4 xyz_world; +in vec4 xyz_camera; +in int gl_PrimitiveID ; + +layout(location = 0) out vec4 out_xyz; +layout(location = 1) out vec4 out_normal; + +void main(){ + vec3 view_vector = vec3(0,0,1); + + // Check if we need to flip the normal. + vec3 normal_world_cor; + float d = dot(normalize(normal_camera), normalize(view_vector)); + + if (abs(d) < 0.001) { + out_xyz = vec4(0,0,0,0); + out_normal = vec4(0,0,0,0); + return; + } + else{ + if (d < 0) { + normal_world_cor = -normal_world; + } else { + normal_world_cor= normal_world; + } + + out_xyz = xyz_world; + out_xyz.w = gl_PrimitiveID + 1.0f; + + out_normal = vec4(normalize(normal_world_cor),1); + out_normal.w = gl_PrimitiveID + 1.0f; + } +} +)Shader"; + +pangolin::GlSlProgram GetShaderProgram() { + pangolin::GlSlProgram program; + + program.AddShader(pangolin::GlSlAnnotatedShader, shaderText); + program.Link(); + + return program; +} diff --git a/jittor/sampler_cuda/Utils.cu b/jittor/sampler_cuda/Utils.cu new file mode 100644 index 0000000..db832c3 --- /dev/null +++ b/jittor/sampler_cuda/Utils.cu @@ -0,0 +1,249 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include "Utils.h" + +#include + +std::vector EquiDistPointsOnSphere(const uint numSamples, const float radius) { + std::vector points(numSamples); + const float offset = 2.f / numSamples; + + const float increment = static_cast(M_PI) * (3.f - std::sqrt(5.f)); + + for (uint i = 0; i < numSamples; i++) { + const float y = ((i * offset) - 1) + (offset / 2); + const float r = std::sqrt(1 - std::pow(y, 2.f)); + + const float phi = (i + 1.f) * increment; + + const float x = cos(phi) * r; + const float z = sin(phi) * r; + + points[i] = radius * Eigen::Vector3f(x, y, z); + } + + return points; +} + +__global__ static void ValidPointsNormalKernel(cudaTextureObject_t normals, cudaTextureObject_t verts, + unsigned int img_width, unsigned int img_height, int* __restrict__ tri_pos, int* __restrict__ tri_total, + float4* __restrict__ output_normals, float4* __restrict__ output_xyz, int* __restrict__ output_count) { + unsigned int x = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int y = blockIdx.y * blockDim.y + threadIdx.y; + if (x >= img_width || y >= img_height) return; + + float4 normal = tex2D(normals, (float) x, (float) y); + float4 xyz = tex2D(verts, (float) x, (float) y); + + if (normal.w == 0.0f || xyz.w == 0.0f) { + return; + } + unsigned int triInd = (unsigned int)(normal.w + 0.01f) - 1; + + // Compute a proxy for normal direction. + // Nearby triangles tend to share similar normals so this branching is warp-efficient. + int normal_dir = normal.x > 0 ? 1 : -1; + if (fabs(normal.x) < 1e-6) { + normal_dir = normal.y > 0 ? 1 : -1; + if (fabs(normal.y) < 1e-6) { + normal_dir = normal.z > 0 ? 1 : -1; + } + } + atomicAdd(tri_total + triInd, 1); + atomicAdd(tri_pos + triInd, normal_dir); + + // Gather all data. This step can be largely accelerated using a compaction algorithm. + int idx = atomicAdd(output_count, 1); + output_normals[idx] = normal; + output_xyz[idx] = xyz; +} + +unsigned int ValidPointsNormalCUDA(cudaArray_t normals, cudaArray_t verts, unsigned int img_width, unsigned int img_height, + thrust::device_vector& tri_pos, thrust::device_vector& tri_total, float4* output_normals, float4* output_xyz) { + + cudaResourceDesc normal_res_desc{}; // Will init all fields to 0. + cudaResourceDesc vert_res_desc{}; + normal_res_desc.resType = vert_res_desc.resType = cudaResourceTypeArray; + normal_res_desc.res.array.array = normals; + vert_res_desc.res.array.array = verts; + + // See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#texture-object-api for default values. + cudaTextureDesc normal_tex_desc{}; + cudaTextureDesc vert_tex_desc{}; + + cudaTextureObject_t normal_tex_obj, vert_tex_obj; + cudaCreateTextureObject(&normal_tex_obj, &normal_res_desc, &normal_tex_desc, nullptr); + cudaCreateTextureObject(&vert_tex_obj, &vert_res_desc, &vert_tex_desc, nullptr); + + dim3 dimBlock = dim3(16, 16); + unsigned int xBlocks = (img_width + dimBlock.x - 1) / dimBlock.x; + unsigned int yBlocks = (img_height + dimBlock.y - 1) / dimBlock.y; + dim3 dimGrid = dim3(xBlocks, yBlocks); + + thrust::device_vector nOutput(1, 0); + ValidPointsNormalKernel<<>>(normal_tex_obj, vert_tex_obj, img_width, img_height, tri_pos.data().get(), + tri_total.data().get(), output_normals, output_xyz, nOutput.data().get()); + cudaKernelCheck + + return nOutput[0]; +} + +void ComputeNormalizationParameters( + pangolin::Geometry &geom, + const float buffer, float2& ub_x, float2& ub_y, float2& ub_z) { + + float xMin = 1000000, xMax = -1000000, yMin = 1000000, yMax = -1000000, zMin = 1000000, + zMax = -1000000; + + pangolin::Image vertices = + pangolin::get>(geom.buffers["geometry"].attributes["vertex"]); + + const std::size_t numVertices = vertices.h; + + ///////// Only consider vertices that were used in some face + std::vector verticesUsed(numVertices, 0); + // turn to true if the vertex is used + for (const auto &object : geom.objects) { + auto itVertIndices = object.second.attributes.find("vertex_indices"); + if (itVertIndices != object.second.attributes.end()) { + pangolin::Image ibo = + pangolin::get>(itVertIndices->second); + + for (uint i = 0; i < ibo.h; ++i) { + for (uint j = 0; j < 3; ++j) { + verticesUsed[ibo(j, i)] = 1; + } + } + } + } + ///////// + + // compute min max in each dimension + for (size_t i = 0; i < numVertices; i++) { + // pass when it's not used. + if (verticesUsed[i] == 0) + continue; + xMin = fmin(xMin, vertices(0, i)); + yMin = fmin(yMin, vertices(1, i)); + zMin = fmin(zMin, vertices(2, i)); + xMax = fmax(xMax, vertices(0, i)); + yMax = fmax(yMax, vertices(1, i)); + zMax = fmax(zMax, vertices(2, i)); + } + + float center_x = (xMax + xMin) / 2.0f; + float center_y = (yMax + yMin) / 2.0f; + float center_z = (zMax + zMin) / 2.0f; + + float size_x = (xMax - xMin) + buffer; + float size_y = (yMax - yMin) + buffer; + float size_z = (zMax - zMin) + buffer; + + ub_x = make_float2(center_x - size_x / 2.0f, center_x + size_x / 2.0f); + ub_y = make_float2(center_y - size_y / 2.0f, center_y + size_y / 2.0f); + ub_z = make_float2(center_z - size_z / 2.0f, center_z + size_z / 2.0f); +} + +__device__ int lower_bound(const float* __restrict__ A, float val, int n) { + int l = 0; + int h = n; + while (l < h) { + int mid = l + (h - l) / 2; + if (val <= A[mid]) { + h = mid; + } else { + l = mid + 1; + } + } + return l; +} + +void LinearizeObject(pangolin::Geometry& geom) { + int total_num_faces = 0; + for (const auto &object : geom.objects) { + auto it_vert_indices = object.second.attributes.find("vertex_indices"); + if (it_vert_indices != object.second.attributes.end()) { + pangolin::Image ibo = + pangolin::get>(it_vert_indices->second); + + total_num_faces += ibo.h; + } + } + + // const int total_num_indices = total_num_faces * 3; + pangolin::ManagedImage new_buffer(3 * sizeof(uint32_t), total_num_faces); + + pangolin::Image new_ibo = + new_buffer.UnsafeReinterpret().SubImage(0, 0, 3, total_num_faces); + + int index = 0; + + for (const auto &object : geom.objects) { + auto it_vert_indices = object.second.attributes.find("vertex_indices"); + if (it_vert_indices != object.second.attributes.end()) { + pangolin::Image ibo = + pangolin::get>(it_vert_indices->second); + + for (int i = 0; i < ibo.h; ++i) { + new_ibo.Row(index).CopyFrom(ibo.Row(i)); + ++index; + } + } + } + + geom.objects.clear(); + auto faces = geom.objects.emplace(std::string("mesh"), pangolin::Geometry::Element()); + + faces->second.Reinitialise(3 * sizeof(uint32_t), total_num_faces); + + faces->second.CopyFrom(new_buffer); + + new_ibo = faces->second.UnsafeReinterpret().SubImage(0, 0, 3, total_num_faces); + faces->second.attributes["vertex_indices"] = new_ibo; +} + +__global__ void TriangleAreaKernel(unsigned char* __restrict__ vertices, size_t vertices_pitch, + unsigned char* __restrict__ triangles, size_t triangles_pitch, size_t num_tris, float* __restrict__ areas) { + unsigned int tri_id = blockIdx.x * blockDim.x + threadIdx.x; + if (tri_id >= num_tris) { + return; + } + auto* inds = (uint32_t*) (triangles + triangles_pitch * tri_id); + + auto* ap = (float*) (vertices + vertices_pitch * inds[0]); + auto* bp = (float*) (vertices + vertices_pitch * inds[1]); + auto* cp = (float*) (vertices + vertices_pitch * inds[2]); + + Eigen::Vector3f a(ap[0], ap[1], ap[2]); + Eigen::Vector3f b(bp[0], bp[1], bp[2]); + Eigen::Vector3f c(cp[0], cp[1], cp[2]); + + const Eigen::Vector3f ab = b - a; + const Eigen::Vector3f ac = c - a; + float abnorm = ab.norm(), acnorm = ac.norm(); + float costheta = ab.dot(ac) / (abnorm * acnorm); + + if (costheta < -1) // meaning theta is pi + costheta = cos(static_cast(M_PI) * 359.f / 360); + else if (costheta > 1) // meaning theta is zero + costheta = cos(static_cast(M_PI) * 1.f / 360); + + const float sinTheta = sqrt(1 - costheta * costheta); + + float area = 0.5f * abnorm * acnorm * sinTheta; + if (isnan(area)) { + area = 0.0f; + } + + areas[tri_id] = area; +} + +__global__ void RNGSetupKernel(curandState *state, size_t num_kernel) { + unsigned int id = threadIdx.x + blockIdx.x * blockDim.x; + if (id >= num_kernel) { + return; + } + /* Each thread gets same seed, a different sequence + number, no offset */ + curand_init(1000, id, 0, &state[id]); +} diff --git a/jittor/sampler_cuda/Utils.h b/jittor/sampler_cuda/Utils.h new file mode 100644 index 0000000..ede9fec --- /dev/null +++ b/jittor/sampler_cuda/Utils.h @@ -0,0 +1,54 @@ +// Copyright 2004-present Facebook. All Rights Reserved. +// The GPU version is drafted by heiwang1997@github.com + +#include + +#include +#include +#include +#include +#include +#include + +std::vector EquiDistPointsOnSphere(const uint numSamples, const float radius); + +unsigned int ValidPointsNormalCUDA(cudaArray_t normals, cudaArray_t verts, unsigned int img_width, unsigned int img_height, + thrust::device_vector& tri_pos, thrust::device_vector& tri_total, + float4* output_normals, float4* output_xyz); + +__device__ int lower_bound(const float* __restrict__ A, float val, int n); + +__global__ void TriangleAreaKernel(unsigned char* __restrict__ vertices, size_t vertices_pitch, + unsigned char* __restrict__ triangles, size_t triangles_pitch, size_t num_tris, float* __restrict__ areas); + +__global__ void RNGSetupKernel(curandState *state, size_t num_kernel); + +void ComputeNormalizationParameters( + pangolin::Geometry& geom, + const float buffer, float2& ub_x, float2& ub_y, float2& ub_z); + +void LinearizeObject(pangolin::Geometry& geom); + +#ifndef cudaSafeCall +#define cudaSafeCall(err) __cudaSafeCall(err, __FILE__, __LINE__) + +inline void __cudaSafeCall( cudaError err, const char *file, const int line ) +{ + if( cudaSuccess != err) { + printf("%s(%i) : cudaSafeCall() Runtime API error : %s.\n", + file, line, cudaGetErrorString(err) ); + exit(-1); + } +} + +#endif + +#ifndef cudaKernelCheck + +// For normal operation +#define cudaKernelCheck + +// For debugging purposes +//#define cudaKernelCheck { cudaDeviceSynchronize(); __cudaSafeCall(cudaPeekAtLastError(), __FILE__, __LINE__); } + +#endif \ No newline at end of file diff --git a/jittor/train.py b/jittor/train.py new file mode 100644 index 0000000..95da43a --- /dev/null +++ b/jittor/train.py @@ -0,0 +1,141 @@ +import json +import jittor as jt +import logging +import shutil +from pathlib import Path + +import lr_schedule +import tqdm +import yaml +from tensorboardX import SummaryWriter + +import lif_dataset as ldata +import criterion +from utils import exp_util +from network import DIDecoder, DIEncoder + + +if jt.has_cuda: + jt.flags.use_cuda = 1 + + +class TensorboardViz(object): + + def __init__(self, logdir): + self.logdir = logdir + self.writter = SummaryWriter(self.logdir) + + def text(self, _text): + # Enhance line break and convert to code blocks + _text = _text.replace('\n', ' \n\t') + self.writter.add_text('Info', _text) + + def update(self, mode, it, eval_dict): + self.writter.add_scalars(mode, eval_dict, global_step=it) + + def flush(self): + self.writter.flush() + + +parser = exp_util.ArgumentParserX(add_hyper_arg=True) + + +def main(): + logging.basicConfig(level=logging.INFO) + args = parser.parse_args() + logging.info(args) + + checkpoints = list(range(args.snapshot_frequency, args.num_epochs + 1, args.snapshot_frequency)) + for checkpoint in args.additional_snapshots: + checkpoints.append(checkpoint) + checkpoints.sort() + + lr_schedules = lr_schedule.get_learning_rate_schedules(args) + + model = DIDecoder() + encoder = DIEncoder() + + lif_loader = ldata.LifDataset(**args.train_set[0], num_sample=args.samples_per_lif, batch_size=args.batch_size) + + loss_func_args = exp_util.dict_to_args(args.training_loss) + loss_funcs = [ + getattr(criterion, t) for t in loss_func_args.types + ] + + optimizer_all = jt.nn.Adam(model.parameters() + encoder.parameters(), lr=lr_schedules[0].get_learning_rate(0)) + + save_base_dir = Path("../di-checkpoints/%s" % args.run_name) + shutil.rmtree(save_base_dir, ignore_errors=True) + save_base_dir.mkdir(parents=True, exist_ok=True) + + viz = TensorboardViz(logdir=str(save_base_dir / 'tensorboard')) + viz.text(yaml.dump(vars(args))) + with (save_base_dir / "hyper.json").open("w") as f: + json.dump(vars(args), f, indent=2) + + start_epoch = 1 + epoch_bar = tqdm.trange(start_epoch, args.num_epochs + 1, desc='epochs') + + it = 0 + for epoch in epoch_bar: + model.train() + encoder.train() + lr_schedule.adjust_learning_rate(lr_schedules, optimizer_all, epoch) + train_meter = exp_util.AverageMeter() + train_running_meter = exp_util.RunningAverageMeter(alpha=0.3) + batch_bar = tqdm.tqdm(total=len(lif_loader), leave=False, desc='train') + + for sdf_data, surface_data, idx in lif_loader: + # Process the input data + sdf_data = sdf_data.reshape(-1, sdf_data.size(-1)) + surface_data = surface_data # (B, N, 6) + + num_sdf_samples = sdf_data.shape[0] + xyz = sdf_data[:, 0:3] + sdf_gt = sdf_data[:, 3:] + + optimizer_all.zero_grad() + + lat_vecs = encoder(surface_data) # (B, L) + lat_vecs = lat_vecs.unsqueeze(1).repeat(1, args.samples_per_lif, 1).view(-1, lat_vecs.size(-1)) # (BxS, L) + + net_input = jt.concat([lat_vecs, xyz], dim=1) + pred_sdf, pred_sdf_std = model(net_input) + + loss_dict = {} + for loss_func in loss_funcs: + loss_dict.update(loss_func( + args=loss_func_args, pd_sdf=pred_sdf, pd_sdf_std=pred_sdf_std, gt_sdf=sdf_gt, + latent_vecs=lat_vecs, coords=xyz, + info={"num_sdf_samples": num_sdf_samples, "epoch": epoch} + )) + loss_sum = sum(loss_dict.values()) + loss_res = {"value": loss_sum.item()} + optimizer_all.step(loss_sum) + + batch_bar.update() + train_running_meter.append_loss(loss_res) + batch_bar.set_postfix(train_running_meter.get_loss_dict()) + epoch_bar.refresh() + it += 1 + + if it % 10 == 0: + for loss_name, loss_val in loss_res.items(): + viz.update('train/' + loss_name, it, {'scalar': loss_val}) + train_meter.append_loss(loss_res) + + batch_bar.close() + + train_avg = train_meter.get_mean_loss_dict() + for meter_key, meter_val in train_avg.items(): + viz.update("epoch_sum/" + meter_key, epoch, {'train': meter_val}) + for sid, schedule in enumerate(lr_schedules): + viz.update(f"train_stat/lr_{sid}", epoch, {'scalar': schedule.get_learning_rate(epoch)}) + + if epoch in checkpoints: + model.save(str(save_base_dir / f"model_{epoch}.jt.tar")) + encoder.save(str(save_base_dir / f"encoder_{epoch}.jt.tar")) + + +if __name__ == '__main__': + main() diff --git a/jittor/train.yaml b/jittor/train.yaml new file mode 100644 index 0000000..e6b8787 --- /dev/null +++ b/jittor/train.yaml @@ -0,0 +1,28 @@ +run_name: "default" + +num_epochs: 300 +batch_size: 64 +samples_per_lif: 4096 +min_context_points: 16 + +lr_schedule: + - { "Type" : "Step", "Initial" : 0.001, "Interval" : 80, "Factor" : 0.4 } + +# Dataset. +train_set: + - { "data_path": "../di-datasets/shapenet_plivoxs", "augment_rotation": 'Y', "num_surface_sample": 128, "augment_noise": [0.025, 40.0] } + +# Code specification +code_bound: null +code_length: 29 + +# Snapshots saving parameters +snapshot_frequency: 100 +additional_snapshots: [] + +# SDF samples +training_loss: + types: [ "neg_log_likelihood", "reg_loss" ] + enforce_minmax: true + clamping_distance: 0.2 + code_reg_lambda: 1.0e-2 diff --git a/jittor/utils/exp_util.py b/jittor/utils/exp_util.py new file mode 100644 index 0000000..b16ae11 --- /dev/null +++ b/jittor/utils/exp_util.py @@ -0,0 +1,228 @@ +import argparse +from pathlib import Path +import numpy as np +import sys +import json +import yaml +import random +import pickle +from collections import defaultdict, OrderedDict + + +def parse_config_json(json_path: Path, args: argparse.Namespace = None): + """ + Parse a json file and add key:value to args namespace. + Json file format [ {attr}, {attr}, ... ] + {attr} = { "_": COMMENT, VAR_NAME: VAR_VALUE } + """ + if args is None: + args = argparse.Namespace() + + with json_path.open() as f: + json_text = f.read() + + try: + raw_configs = json.loads(json_text) + except: + # Do some fixing of the json text + json_text = json_text.replace("\'", "\"") + json_text = json_text.replace("None", "null") + json_text = json_text.replace("False", "false") + json_text = json_text.replace("True", "true") + raw_configs = json.loads(json_text) + + if isinstance(raw_configs, dict): + raw_configs = [raw_configs] + configs = {} + for raw_config in raw_configs: + for rkey, rvalue in raw_config.items(): + if rkey != "_": + configs[rkey] = rvalue + + if configs is not None: + for ckey, cvalue in configs.items(): + args.__dict__[ckey] = cvalue + return args + + +def parse_config_yaml(yaml_path: Path, args: argparse.Namespace = None, override: bool = True): + """ + Parse a yaml file and add key:value to args namespace. + """ + if args is None: + args = argparse.Namespace() + with yaml_path.open() as f: + configs = yaml.load(f, Loader=yaml.FullLoader) + if configs is not None: + if "include_configs" in configs.keys(): + base_config = configs["include_configs"] + del configs["include_configs"] + base_config_path = yaml_path.parent / Path(base_config) + with base_config_path.open() as f: + base_config = yaml.load(f, Loader=yaml.FullLoader) + base_config.update(configs) + configs = base_config + for ckey, cvalue in configs.items(): + if override or ckey not in args.__dict__.keys(): + args.__dict__[ckey] = cvalue + return args + + +def dict_to_args(data: dict): + args = argparse.Namespace() + for ckey, cvalue in data.items(): + args.__dict__[ckey] = cvalue + return args + + +class ArgumentParserX(argparse.ArgumentParser): + def __init__(self, base_config_path=None, add_hyper_arg=True, **kwargs): + super().__init__(**kwargs) + self.add_hyper_arg = add_hyper_arg + self.base_config_path = base_config_path + if self.add_hyper_arg: + self.add_argument('hyper', type=str, help='Path to the yaml parameter') + self.add_argument('--exec', type=str, help='Extract code to modify the args') + + def parse_args(self, args=None, namespace=None): + # Parse arg for the first time to extract args defined in program. + _args = self.parse_known_args(args, namespace)[0] + # Add the types needed. + file_args = argparse.Namespace() + if self.base_config_path is not None: + file_args = parse_config_yaml(Path(self.base_config_path), file_args) + if self.add_hyper_arg: + if _args.hyper.endswith("json"): + file_args = parse_config_json(Path(_args.hyper), file_args) + else: + file_args = parse_config_yaml(Path(_args.hyper), file_args) + for ckey, cvalue in file_args.__dict__.items(): + try: + self.add_argument('--' + ckey, type=type(cvalue), default=cvalue, required=False) + except argparse.ArgumentError: + continue + # Parse args fully to extract all useful information + _args = super().parse_args(args, namespace) + # After that, execute exec part. + exec_code = _args.exec + if exec_code is not None: + for exec_cmd in exec_code.split(";"): + exec_cmd = "_args." + exec_cmd.strip() + exec(exec_cmd) + return _args + + +class AverageMeter: + def __init__(self): + self.loss_dict = OrderedDict() + + def export(self, f): + if isinstance(f, str): + f = open(f, 'wb') + pickle.dump(self.loss_dict, f) + + def load(self, f): + if isinstance(f, str): + f = open(f, 'rb') + self.loss_dict = pickle.load(f) + return self + + def append_loss(self, losses): + for loss_name, loss_val in losses.items(): + if loss_val is None: + continue + loss_val = float(loss_val) + if np.isnan(loss_val): + continue + if loss_name not in self.loss_dict.keys(): + self.loss_dict.update({loss_name: [loss_val]}) + else: + self.loss_dict[loss_name].append(loss_val) + + def get_mean_loss_dict(self): + loss_dict = {} + for loss_name, loss_arr in self.loss_dict.items(): + loss_dict[loss_name] = np.mean(loss_arr) + return loss_dict + + def get_mean_loss(self): + mean_loss_dict = self.get_mean_loss_dict() + if len(mean_loss_dict) == 0: + return 0.0 + else: + return sum(mean_loss_dict.values()) / len(mean_loss_dict) + + def get_printable_mean(self): + text = "" + all_loss_sum = 0.0 + for loss_name, loss_mean in self.get_mean_loss_dict().items(): + all_loss_sum += loss_mean + text += "(%s:%.4f) " % (loss_name, loss_mean) + text += " sum = %.4f" % all_loss_sum + return text + + def get_newest_loss_dict(self, return_count=False): + loss_dict = {} + loss_count_dict = {} + for loss_name, loss_arr in self.loss_dict.items(): + if len(loss_arr) > 0: + loss_dict[loss_name] = loss_arr[-1] + loss_count_dict[loss_name] = len(loss_arr) + if return_count: + return loss_dict, loss_count_dict + else: + return loss_dict + + def get_printable_newest(self): + nloss_val, nloss_count = self.get_newest_loss_dict(return_count=True) + return ", ".join([f"{loss_name}[{nloss_count[loss_name] - 1}]: {nloss_val[loss_name]}" + for loss_name in nloss_val.keys()]) + + def print_format_loss(self, color=None): + if hasattr(sys.stdout, "terminal"): + color_device = sys.stdout.terminal + else: + color_device = sys.stdout + if color == "y": + color_device.write('\033[93m') + elif color == "g": + color_device.write('\033[92m') + elif color == "b": + color_device.write('\033[94m') + print(self.get_printable_mean(), flush=True) + if color is not None: + color_device.write('\033[0m') + + +class RunningAverageMeter: + def __init__(self, alpha=1.0): + self.alpha = alpha + self.loss_dict = OrderedDict() + + def append_loss(self, losses): + for loss_name, loss_val in losses.items(): + if loss_val is None: + continue + loss_val = float(loss_val) + if np.isnan(loss_val): + continue + if loss_name not in self.loss_dict.keys(): + self.loss_dict.update({loss_name: loss_val}) + else: + old_mean = self.loss_dict[loss_name] + self.loss_dict[loss_name] = self.alpha * old_mean + (1 - self.alpha) * loss_val + + def get_loss_dict(self): + return {k: v for k, v in self.loss_dict.items()} + + +def init_seed(seed=0): + random.seed(seed) + np.random.seed(seed) + # According to https://pytorch.org/docs/stable/notes/randomness.html, + # As pytorch run-to-run reproducibility is not guaranteed, and perhaps will lead to performance degradation, + # We do not use manual seed for training. + # This would influence stochastic network layers but will not influence data generation and processing w/o pytorch. + # torch.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False diff --git a/jittor/utils/motion_util.py b/jittor/utils/motion_util.py new file mode 100644 index 0000000..2cef7f9 --- /dev/null +++ b/jittor/utils/motion_util.py @@ -0,0 +1,142 @@ +import numpy as np +from pyquaternion import Quaternion + + +def project_orthogonal(rot): + u, s, vh = np.linalg.svd(rot, full_matrices=True, compute_uv=True) + rot = u @ vh + if np.linalg.det(rot) < 0: + u[:, 2] = -u[:, 2] + rot = u @ vh + return rot + + +class Isometry: + GL_POST_MULT = Quaternion(degrees=180.0, axis=[1.0, 0.0, 0.0]) + + def __init__(self, q=None, t=None): + if q is None: + q = Quaternion() + if t is None: + t = np.zeros(3) + if not isinstance(t, np.ndarray): + t = np.asarray(t) + assert t.shape[0] == 3 and t.ndim == 1 + self.q = q + self.t = t + + def __repr__(self): + return f"Isometry: t = {self.t}, q = {self.q}" + + @property + def rotation(self): + return Isometry(q=self.q) + + @property + def matrix(self): + mat = self.q.transformation_matrix + mat[0:3, 3] = self.t + return mat + + @staticmethod + def from_matrix(mat, t_component=None, ortho=False): + assert isinstance(mat, np.ndarray) + if t_component is None: + assert mat.shape == (4, 4) + if ortho: + mat[:3, :3] = project_orthogonal(mat[:3, :3]) + return Isometry(q=Quaternion(matrix=mat), t=mat[:3, 3]) + else: + assert mat.shape == (3, 3) + assert t_component.shape == (3,) + if ortho: + mat = project_orthogonal(mat) + return Isometry(q=Quaternion(matrix=mat), t=t_component) + + @property + def continuous_repr(self): + rot = self.q.rotation_matrix[:, 0:2].T.flatten() # (6,) + return np.concatenate([rot, self.t]) # (9,) + + @staticmethod + def from_continuous_repr(rep, gs=True): + if isinstance(rep, list): + rep = np.asarray(rep) + assert isinstance(rep, np.ndarray) + assert rep.shape == (9,) + # For rotation, use Gram-Schmidt orthogonalization + col1 = rep[0:3] + col2 = rep[3:6] + if gs: + col1 /= np.linalg.norm(col1) + col2 = col2 - np.dot(col1, col2) * col1 + col2 /= np.linalg.norm(col2) + col3 = np.cross(col1, col2) + return Isometry(q=Quaternion(matrix=np.column_stack([col1, col2, col3])), t=rep[6:9]) + + @property + def full_repr(self): + rot = self.q.rotation_matrix.T.flatten() + return np.concatenate([rot, self.t]) + + @staticmethod + def from_full_repr(rep, ortho=False): + assert isinstance(rep, np.ndarray) + assert rep.shape == (12,) + rot = rep[0:9].reshape(3, 3).T + if ortho: + rot = project_orthogonal(rot) + return Isometry(q=Quaternion(matrix=rot), t=rep[9:12]) + + @staticmethod + def random(): + return Isometry(q=Quaternion.random(), t=np.random.random((3,))) + + def inv(self): + qinv = self.q.inverse + return Isometry(q=qinv, t=-(qinv.rotate(self.t))) + + def dot(self, right): + return Isometry(q=(self.q * right.q), t=(self.q.rotate(right.t) + self.t)) + + def to_gl_camera(self): + return Isometry(q=(self.q * self.GL_POST_MULT), t=self.t) + + @staticmethod + def look_at(source: np.ndarray, target: np.ndarray, up: np.ndarray = None): + z_dir = target - source + z_dir /= np.linalg.norm(z_dir) + if up is None: + up = np.asarray([0.0, 1.0, 0.0]) + if np.linalg.norm(np.cross(z_dir, up)) < 1e-6: + up = np.asarray([1.0, 0.0, 0.0]) + else: + up /= np.linalg.norm(up) + x_dir = np.cross(z_dir, up) + x_dir /= np.linalg.norm(x_dir) + y_dir = np.cross(z_dir, x_dir) + R = np.column_stack([x_dir, y_dir, z_dir]) + return Isometry(q=Quaternion(matrix=R), t=source) + + def tangent(self, prev_iso, next_iso): + t = 0.5 * (next_iso.t - prev_iso.t) + l1 = Quaternion.log((self.q.inverse * prev_iso.q).normalised) + l2 = Quaternion.log((self.q.inverse * next_iso.q).normalised) + e = Quaternion() + e.q = -0.25 * (l1.q + l2.q) + e = self.q * Quaternion.exp(e) + return Isometry(t=t, q=e) + + def __matmul__(self, other): + if isinstance(other, Isometry): + return self.dot(other) + if type(other) != np.ndarray or other.ndim == 1: + return self.q.rotate(other) + self.t + else: + return other @ self.q.rotation_matrix.T + self.t[np.newaxis, :] + + @staticmethod + def interpolate(source, target, alpha): + iquat = Quaternion.slerp(source.q, target.q, alpha) + it = source.t * (1 - alpha) + target.t * alpha + return Isometry(q=iquat, t=it)