Skip to content

Commit

Permalink
fix missing imports & add more warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
dendenxu committed Apr 12, 2024
1 parent c42f15e commit 93d7d35
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 3 deletions.
182 changes: 182 additions & 0 deletions fast_gauss/base_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from __future__ import annotations
from copy import copy
from typing import Mapping, TypeVar, Union, Iterable, Callable, Dict, List
# these are generic type vars to tell mapping to accept any type vars when creating a type
KT = TypeVar("KT") # key type
VT = TypeVar("VT") # value type

# TODO: move this to engine implementation
# TODO: this is a special type just like Config
# ? However, dotdict is a general purpose data passing object, instead of just designed for config
# The only reason we defined those special variables are for type annotations
# If removed, all will still work flawlessly, just no editor annotation for output, type and meta


def return_dotdict(func: Callable):
def inner(*args, **kwargs):
return dotdict(func(*args, **kwargs))
return inner


class DoNothing:
def __getattr__(self, name):
def method(*args, **kwargs):
pass
return method


class dotdict(dict, Dict[KT, VT]):
"""
This is the default data passing object used throughout the codebase
Main function: dot access for dict values & dict like merging and updates
a dictionary that supports dot notation
as well as dictionary access notation
usage: d = make_dotdict() or d = make_dotdict{'val1':'first'})
set attributes: d.val2 = 'second' or d['val2'] = 'second'
get attributes: d.val2 or d['val2']
"""

def update(self, dct: Dict = None, **kwargs):
dct = copy(dct) # avoid modifying the original dict, use super's copy to avoid recursion

# Handle different arguments
if dct is None:
dct = kwargs
elif isinstance(dct, Mapping):
dct.update(kwargs)
else:
super().update(dct, **kwargs)
return

# Recursive updates
for k, v in dct.items():
if k in self:

# Handle type conversions
target_type = type(self[k])
if not isinstance(v, target_type):
# NOTE: bool('False') will be True
if target_type == bool and isinstance(v, str):
dct[k] = v == 'True'
else:
dct[k] = target_type(v)

if isinstance(v, dict):
self[k].update(v) # recursion from here
else:
self[k] = v
else:
if isinstance(v, dict):
self[k] = dotdict(v) # recursion?
elif isinstance(v, list):
self[k] = [dotdict(x) if isinstance(x, dict) else x for x in v]
else:
self[k] = v
return self

def __init__(self, *args, **kwargs):
self.update(*args, **kwargs)

copy = return_dotdict(dict.copy)
fromkeys = return_dotdict(dict.fromkeys)

# def __hash__(self):
# # return hash(''.join([str(self.values().__hash__())]))
# return super(dotdict, self).__hash__()

# def __init__(self, *args, **kwargs):
# super(dotdict, self).__init__(*args, **kwargs)

"""
Uncomment following lines and
comment out __getattr__ = dict.__getitem__ to get feature:
returns empty numpy array for undefined keys, so that you can easily copy things around
TODO: potential caveat, harder to trace where this is set to np.array([], dtype=np.float32)
"""

def __getitem__(self, key):
try:
return dict.__getitem__(self, key)
except KeyError as e:
raise AttributeError(e)
# MARK: Might encounter exception in newer version of pytorch
# Traceback (most recent call last):
# File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 245, in _feed
# obj = _ForkingPickler.dumps(obj)
# File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
# cls(buf, protocol).dump(obj)
# KeyError: '__getstate__'
# MARK: Because you allow your __getattr__() implementation to raise the wrong kind of exception.
# FIXME: not working typing hinting code
__getattr__: Callable[..., 'torch.Tensor'] = __getitem__ # type: ignore # overidden dict.__getitem__
__getattribute__: Callable[..., 'torch.Tensor'] # type: ignore
# __getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__

# TODO: better ways to programmically define these special variables?

@property
def meta(self) -> dotdict:
# Special variable used for storing cpu tensor in batch
if 'meta' not in self:
self.meta = dotdict()
return self.__getitem__('meta')

@meta.setter
def meta(self, meta):
self.__setitem__('meta', meta)

@property
def output(self) -> dotdict: # late annotation needed for this
# Special entry for storing output tensor in batch
if 'output' not in self:
self.output = dotdict()
return self.__getitem__('output')

@output.setter
def output(self, output):
self.__setitem__('output', output)

@property
def persistent(self) -> dotdict: # late annotation needed for this
# Special entry for storing persistent tensor in batch
if 'persistent' not in self:
self.persistent = dotdict()
return self.__getitem__('persistent')

@persistent.setter
def persistent(self, persistent):
self.__setitem__('persistent', persistent)

@property
def type(self) -> str: # late annotation needed for this
# Special entry for type based construction system
return self.__getitem__('type')

@type.setter
def type(self, type):
self.__setitem__('type', type)

def to_dict(self):
out = dict()
for k, v in self.items():
if isinstance(v, dotdict):
v = v.to_dict() # recursion point
out[k] = v
return out


class default_dotdict(dotdict):
def __init__(self, default_type=object, *arg, **kwargs):
super().__init__(*arg, **kwargs)
dict.__setattr__(self, 'default_type', default_type)

def __getitem__(self, key):
try:
return super().__getitem__(key)
except (AttributeError, KeyError) as e:
super().__setitem__(key, dict.__getattribute__(self, 'default_type')())
return super().__getitem__(key)
10 changes: 9 additions & 1 deletion fast_gauss/console_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class Colors:
from rich.progress import BarColumn, TimeElapsedColumn, TimeRemainingColumn, filesize, ProgressColumn
from tqdm.std import tqdm as std_tqdm
from tqdm.rich import tqdm_rich, FractionColumn, RateColumn
from easyvolcap.utils.base_utils import default_dotdict, dotdict, DoNothing
from .base_utils import default_dotdict, dotdict, DoNothing

pdbr_theme = 'ansi_dark'
pdbr.utils.set_traceback(pdbr_theme)
Expand Down Expand Up @@ -787,3 +787,11 @@ def build_parser(d: dict, parser: argparse.ArgumentParser = None, **kwargs):
parser.add_argument(f'--{k}', type=type(v), default=v, help=markup_to_ansi(help_pattern.format(v)))

return parser


def warn_once(message: str):
if not hasattr(warn_once, 'warned'):
warn_once.warned = set()
if message not in warn_once.warned:
log(yellow(message))
warn_once.warned.add(message)
20 changes: 19 additions & 1 deletion fast_gauss/gsplat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def __init__(self,
self.resize_buffers(init_buffer_size)
self.resize_textures(*init_texture_size)

log(green_slim(f'GSplatContextManager initialized with attribute dtype: {self.dtype}, texture dtype: {self.tex_dtype}, offline rendering: {self.offline_rendering}, buffer size: {init_buffer_size}, texture size: {init_texture_size}'))

if not self.offline_rendering:
log(green_slim('Using online rendering mode, in this mode, calling the rendering function of fast_gauss will write directly to the currently bound framebuffer'))
log(green_slim('In this mode, the output of all rasterization calls will be None (same output count). Please do not perform further processing on them.'))
log(green_slim('Please make sure to set up the correct GUI environment before calling the rasterization function, see more in readme.md'))

def opengl_options(self):
# Performs face culling
gl.glDisable(gl.GL_CULL_FACE)
Expand Down Expand Up @@ -220,6 +227,13 @@ def resize_buffers(self, v: int = 0):

@torch.no_grad()
def render(self, xyz3: torch.Tensor, cov6: torch.Tensor, rgb3: torch.Tensor, occ1: torch.Tensor, raster_settings: 'GaussianRasterizationSettings'):
if xyz3.dtype != self.dtype:
warn_once(yellow(f'Input tensors has dtype {xyz3.dtype}, expected {self.dtype}, will cast to {self.dtype}'))
xyz3, cov6, rgb3, occ1 = xyz3.to(self.dtype), cov6.to(self.dtype), rgb3.to(self.dtype), occ1.to(self.dtype)
for key in raster_settings:
if isinstance(raster_settings[key], torch.Tensor):
raster_settings[key] = raster_settings[key].to(self.dtype)

# Prepare OpenGL texture size
H, W = raster_settings.image_height, raster_settings.image_width
self.resize_textures(H, W)
Expand All @@ -237,7 +251,7 @@ def render(self, xyz3: torch.Tensor, cov6: torch.Tensor, rgb3: torch.Tensor, occ

# Upload sorted data to OpenGL for rendering
from cuda import cudart
from easyvolcap.utils.cuda_utils import CHECK_CUDART_ERROR, FORMAT_CUDART_ERROR
from .cuda_utils import CHECK_CUDART_ERROR, FORMAT_CUDART_ERROR
kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice

CHECK_CUDART_ERROR(cudart.cudaGraphicsMapResources(1, self.cu_vbo, torch.cuda.current_stream().cuda_stream))
Expand Down Expand Up @@ -293,6 +307,10 @@ def render(self, xyz3: torch.Tensor, cov6: torch.Tensor, rgb3: torch.Tensor, occ
torch.cuda.current_stream().cuda_stream)) # stream
CHECK_CUDART_ERROR(cudart.cudaGraphicsUnmapResources(1, cu_tex, torch.cuda.current_stream().cuda_stream))

if rgba_map.dtype != xyz3.dtype:
warn_once(yellow(f'Using texture dtype {rgba_map.dtype}, expected {xyz3.dtype} for the output, will cast to {xyz3.dtype}'))
rgba_map = rgba_map.to(xyz3.dtype)

return rgba_map # H, W, 4
else:
return None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fast_gauss"
version = "0.0.5"
version = "0.0.6"
description = "A geometry-shader-based, global CUDA sorted high-performance 3D Gaussian Splatting rasterizer. Can achieve a 5-10x speedup in rendering compared to the vanialla diff-gaussian-rasterization."
readme = "readme.md"
license = { file = "license" }
Expand Down

0 comments on commit 93d7d35

Please sign in to comment.