In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F
import tqdm

import raymarching2

# Benchmark inputs: 2000 rays that all start at (0.1, 0.1, 0.1) and point in
# random unit directions. The tensors are created on the CPU and then moved
# to the GPU, so the directions are drawn from the CPU random generator.
# NOTE(review): no seed is set in this cell, so the inputs differ per run.
rays_o = (torch.zeros((2000, 3)) + 0.1).to("cuda")
rays_d = torch.randn((2000, 3)).to("cuda")
rays_d = F.normalize(rays_d, dim=-1)

# Every byte is 255, i.e. every bit is set.
# NOTE(review): presumably 5 levels of a 128^3 occupancy grid packed 8 cells
# per byte, all marked occupied -- confirm against the raymarching2 extension.
density_bitfield = (torch.ones(
    (5, 128 ** 3 // 8), dtype=torch.uint8
) * 255).to("cuda")

# Six values; the pure-PyTorch sampler defined later in this notebook reads
# them as (x_min, y_min, z_min, x_max, y_max, z_max), i.e. the unit cube.
aabb = torch.tensor([0., 0., 0., 1., 1., 1.]).to("cuda")
torch.cuda.synchronize()

# Random densities and colours with 1024 slots per ray, created directly on
# the GPU as leaf tensors that track gradients, plus a random background
# colour drawn on the CPU and moved over.
sigmas = torch.rand((2000, 1024), device=rays_o.device, requires_grad=True)
rgbs = torch.rand((2000, 1024, 3), device=rays_o.device, requires_grad=True)
bkgd_rgb = torch.rand(3).to("cuda")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from torch.profiler import profile, record_function, ProfilerActivity

In [10]:
# Dry run of the sampler, used only to learn how many samples each ray gets.
# The last argument is a budget of 1024 samples per ray.
indices, positions, dirs, deltas, ts = raymarching2.generate_training_samples(
    rays_o, rays_d, aabb, density_bitfield, 1024 * rays_o.shape[0]
)

# Each row of `indices` unpacks into (ray_id, sample_id, sample_cnt). For
# every row, take the first `sample_cnt` entries of that ray from the dense
# [2000, 1024] tensors and concatenate them into one flat, packed tensor.
# NOTE(review): iterating a tensor row by row in Python is slow; it is only
# done once here, outside the timed loop.
sigmas_collector = []
rgbs_collector = []
for ray_id, sample_id, sample_cnt in indices:
    sigmas_collector.append(
        sigmas[ray_id, 0: sample_cnt]
    )
    rgbs_collector.append(
        rgbs[ray_id, 0: sample_cnt]
    )
sigmas_collector = torch.cat(sigmas_collector)
rgbs_collector = torch.cat(rgbs_collector)

# Detach from `sigmas` / `rgbs` so the packed tensors are independent leaf
# tensors that receive their own `.grad`.
sigmas_collector = sigmas_collector.detach().clone().requires_grad_(True)
rgbs_collector = rgbs_collector.detach().clone().requires_grad_(True)

# Timed loop: sample + render + backward, 1000 times.
# A new profiler is created on every iteration, so after the loop `prof`
# only holds the events of the final iteration.
# The gradients are never zeroed inside the loop, so `.grad` accumulates the
# sum over all 1000 backward passes.
for _ in tqdm.tqdm(range(1000)):
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("volumetric_rendering"):
            indices, positions, dirs, deltas, ts = raymarching2.generate_training_samples(
                rays_o, rays_d, aabb, density_bitfield, 1024 * rays_o.shape[0]
            )
            (
                accumulated_weight, 
                accumulated_depth, 
                accumulated_color, 
                accumulated_position
            ) = raymarching2.volumetric_rendering(
                indices, positions, deltas, ts,
                sigmas_collector, rgbs_collector,
                bkgd_rgb
            )
            accumulated_color.sum().backward()
    torch.cuda.synchronize()

# Kernel timings of the last iteration, then scalar summaries of the result.
# The gradient sums printed here are the accumulated totals described above.
print (prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
print ("accumulated_color", accumulated_color.sum())
print ("grad sigmas", sigmas_collector.grad.sum())
print ("grad rgbs", rgbs_collector.grad.sum())

# Reset the accumulated gradients; assigning to `_` suppresses the cell's
# rich output.
_ = sigmas_collector.grad.zero_()
_ = rgbs_collector.grad.zero_()



100%|██████████| 1000/1000 [00:08<00:00, 122.06it/s]

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void volumetric_rendering_kernel<float>(unsigned int...         0.00%       0.000us         0.00%       0.000us       0.000us       2.251ms        35.94%       2.251ms       2.251ms             1  
void kernel_generate_training_samples<float>(unsigne...         0.00%       0.000us         0.00%       0.000us       0.000us       2.081ms        33.22%       2.081ms       2.081ms             1  
void volu




In [11]:
def generate_training_samples(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    aabb: torch.Tensor,
    max_samples: int = 10_000,
):
    """Densely sample points along rays and flag those inside a box.

    Pure-PyTorch reference sampler: every ray is sampled at the same 1024
    uniformly spaced distances, starting at t = 0 with a spacing of
    sqrt(3) / 1024, and each sample is tested against the bounding box.

    Args:
        rays_o: ray origins; reshaped to [N, 3].
        rays_d: ray directions; reshaped to [N, 3].
        aabb: six values (x_min, y_min, z_min, x_max, y_max, z_max).
        max_samples: currently unused; kept so existing callers that pass
            it keep working.

    Returns:
        points: sample positions, [N, 1024, 3].
        selector: bool mask, [N, 1024], True where the sample lies inside
            the box (bounds inclusive).
        t_vals: sample distances, [N, 1024]. This is an expanded view whose
            rows share memory, so do not write to it in place.
    """
    device = rays_o.device
    NERF_STEPS = 1024
    # 1.7320508075688772 is sqrt(3), the diagonal of a unit cube.
    STEPSIZE = 1.7320508075688772 / NERF_STEPS

    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)

    t_vals = torch.arange(0.0, NERF_STEPS, device=device) * STEPSIZE
    t_vals = t_vals[None, :].expand((rays_o.shape[0], -1))
    points = rays_o[:, None, :] + rays_d[:, None, :] * t_vals[:, :, None]

    # Inside test on all three axes at once; equivalent to six per-axis
    # comparisons joined with `&`.
    lower, upper = aabb[:3], aabb[3:6]
    selector = ((points >= lower) & (points <= upper)).all(dim=-1)
    return points, selector, t_vals


def volumetric_rendering(rgb, density, t_vals, dirs, color_bkgd):
    """Composite per-sample colours and densities along each ray.

    Args:
        rgb: torch.Tensor(float32), color, [batch_size, num_samples, 3].
        density: torch.Tensor(float32), density, [batch_size, num_samples, 1].
        t_vals: torch.Tensor(float32), sample distances,
            [batch_size, num_samples]. At least two samples are needed,
            because the width of the last interval is copied from the
            first one (which is exact for uniformly spaced samples).
        dirs: torch.Tensor(float32), ray directions, [batch_size, 3].
        color_bkgd: torch.Tensor(float32), background color, [3].

    Returns:
        comp_rgb: torch.Tensor(float32), [batch_size, 3]. Composited color
            blended with the background by the leftover transmittance.
        depth: torch.Tensor(float32), [batch_size]. Weighted sum of
            `t_vals` (not normalised by `acc`).
        acc: torch.Tensor(float32), [batch_size]. Accumulated opacity.
        weights: torch.Tensor(float32), [batch_size, num_samples].
    """
    # Interval widths; the last sample has no successor, so it reuses the
    # width of the first interval instead of an "infinite" far interval.
    t_dists = torch.cat(
        [
            t_vals[..., 1:] - t_vals[..., :-1],
            t_vals[..., 1:2] - t_vals[..., 0:1],
        ],
        dim=-1,
    )
    # Scale by the direction length so un-normalised directions still give
    # distances in world units.
    delta = t_dists * torch.linalg.norm(dirs[..., None, :], dim=-1)

    # Note that we're quietly turning density from [..., 1] to [...].
    density_delta = density[..., 0] * delta

    alpha = 1 - torch.exp(-density_delta)
    # Transmittance before each sample: exp(-sum of the preceding
    # density * delta terms), with 1 for the first sample.
    trans = torch.exp(
        -torch.cat(
            [
                torch.zeros_like(density_delta[..., :1]),
                torch.cumsum(density_delta[..., :-1], dim=-1),
            ],
            dim=-1,
        )
    )
    weights = alpha * trans

    comp_rgb = (weights[..., None] * rgb).sum(dim=-2)
    acc = weights.sum(dim=-1)
    depth = (weights * t_vals).sum(dim=-1)

    comp_rgb = comp_rgb + color_bkgd * (1.0 - acc[..., None])
    return comp_rgb, depth, acc, weights

# Timed loop for the pure-PyTorch reference: dense sampling + rendering +
# backward, 1000 times. Make sure earlier GPU work has finished first.
torch.cuda.synchronize()
# A new profiler is created on every iteration, so after the loop `prof`
# only holds the events of the final iteration. Gradients on `sigmas` and
# `rgbs` are not zeroed inside the loop and therefore accumulate over all
# 1000 backward passes.
for _ in tqdm.tqdm(range(1000)):
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("volumetric_rendering"):
            points, selector, t_vals = generate_training_samples(rays_o, rays_d, aabb)
            # Multiplying by the boolean mask zeroes the density of samples
            # outside the box, so they do not contribute to the render.
            comp_rgb, depth, acc, weights = volumetric_rendering(
                rgbs, (sigmas * selector).unsqueeze(-1), t_vals, 
                rays_d, bkgd_rgb
            )
            comp_rgb.sum().backward()

    torch.cuda.synchronize()

# Scalar summaries of the last render and of the accumulated gradients,
# followed by the kernel timings of the last iteration.
print ("comp_rgb", comp_rgb.sum())
print ("grad sigmas", sigmas.grad.sum())
print ("grad rgbs", rgbs.grad.sum())
# print (rgbs.grad)
print (prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Reset the accumulated gradients; assigning to `_` suppresses the cell's
# rich output.
_ = sigmas.grad.zero_()
_ = rgbs.grad.zero_()


100%|██████████| 1000/1000 [00:08<00:00, 116.70it/s]

comp_rgb tensor(3263.0479, device='cuda:0', grad_fn=<SumBackward0>)
grad sigmas tensor(-75984.8594, device='cuda:0')
grad rgbs tensor(882290.8750, device='cuda:0')
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     386.000us        15.16%     386.000us      48.250us             8  
void at::native::elementwise_kernel<128, 4,




In [None]:
import torch
import torch.nn.functional as F

import raymarching2


# Smaller ray batch for visualisation: a 10 x 100 grid of rays that all
# start at (0.1, 0.1, 0.1) and point in random unit directions.
# This rebinds `rays_o`, `rays_d`, `density_bitfield` and `aabb`.
rays_o = (torch.zeros((10, 100, 3)) + 0.1).to("cuda")
rays_d = torch.randn((10, 100, 3)).to("cuda")
rays_d = F.normalize(rays_d, dim=-1)

# Every byte is 255, i.e. every bit is set.
density_bitfield = (torch.ones(
    (5, 128 ** 3 // 8), dtype=torch.uint8
) * 255).to("cuda")

aabb = torch.tensor([0., 0., 0., 1., 1., 1.]).to("cuda")

# NOTE(review): this call passes four arguments and unpacks
# (positions, dirs, deltas, nears, fars), whereas the benchmark cell earlier
# in this notebook passes a fifth sample-budget argument and unpacks
# (indices, positions, dirs, deltas, ts). This cell has no execution count,
# so it may target an older raymarching2 API -- confirm before running.
positions, dirs, deltas, nears, fars = raymarching2.generate_training_samples(
    rays_o, rays_d, aabb, density_bitfield
)
torch.cuda.synchronize()

In [None]:
from tava.utils.plotly import Trimesh, plot_scene, PointCloud

def aabb_to_mesh(aabb):
    """Build a triangle mesh of an axis-aligned bounding box.

    Args:
        aabb: tensor of six values
            (x_min, y_min, z_min, x_max, y_max, z_max).

    Returns:
        vertices: [8, 3] tensor with the box corners, same dtype and
            device as `aabb`.
        faces: [12, 3] int32 tensor of vertex indices, two triangles per
            box face, on the same device as `aabb`.
    """
    # Corner i picks, per axis, either the min (index 0..2) or the max
    # (index 3..5) entry of `aabb`.
    vertices = torch.stack([
        aabb[[0, 1, 2]],  # 0: (min, min, min)
        aabb[[3, 1, 2]],  # 1: (max, min, min)
        aabb[[0, 4, 2]],  # 2: (min, max, min)
        aabb[[0, 1, 5]],  # 3: (min, min, max)
        aabb[[3, 4, 2]],  # 4: (max, max, min)
        aabb[[0, 4, 5]],  # 5: (min, max, max)
        aabb[[3, 1, 5]],  # 6: (max, min, max)
        aabb[[3, 4, 5]],  # 7: (max, max, max)
    ])
    # All triangles share the same winding (normals point into the box).
    faces = torch.tensor([
        [0, 1, 4], [0, 4, 2],  # z = z_min
        [0, 3, 6], [0, 6, 1],  # y = y_min
        [1, 6, 4], [4, 6, 7],  # x = x_max
        [2, 4, 7], [2, 7, 5],  # y = y_max
        [2, 5, 0], [0, 5, 3],  # x = x_min
        [3, 7, 6], [3, 5, 7],  # z = z_max
    ], dtype=torch.int32, device=aabb.device)
    return vertices, faces

# Show the bounding box as a semi-transparent mesh together with the
# generated sample positions.
vertices, faces = aabb_to_mesh(aabb)
plot_scene(
    {
        "bbox": {
            "struct": Trimesh(
                vertices.cpu().numpy(),
                faces.cpu().numpy(),
            ),
            "mesh_opacity": 0.7
        },
        "samples": {
            "struct": PointCloud(
                # Keep only positions whose coordinate sum is positive.
                # NOTE(review): presumably this drops unused, zero-filled
                # sample slots -- confirm against raymarching2's output.
                positions[positions.sum(dim=-1) > 0].cpu().numpy()
            )
        }
    }
)

In [None]:
# Number of sample positions whose coordinate sum is positive, i.e. the
# number of points drawn by the plot above.
(positions.sum(dim=-1) > 0).sum()

In [2]:
from typing import Callable

import torch


@torch.no_grad()
def root_finding(
    x: torch.Tensor, 
    func: Callable, 
    tol: float = 1e-5, 
    dvg_thresh: float = 1.0,
    max_iters: int = 50,
    dtype = torch.float64,  # double is necessary
    eps: float = 1e-6,
    verbose: bool = False,
):
    """Batched root finding for a multidimensional function: f(x) = 0.

    Uses Broyden's ("good") method: the inverse Jacobian is initialised from
    the true Jacobian at the starting point and then refined after every
    step with a rank-one Sherman-Morrison update, so `func` only has to
    provide the Jacobian once.

    Args:
        x: the initial guess of the root. [..., D]
        func: callable invoked as `func(x, mask=..., return_jac=...)`.
            With `return_jac=True` (first call only, `mask=None`) it must
            return `(f(x), J_f(x))` with shapes [..., D] and [..., D, D].
            With `return_jac=False` it must return `f(x)` with shape
            [..., D]; `mask` is then a bool tensor [...] marking the points
            whose values are actually used, so `func` may skip the rest.
        tol: a point is accepted once ||f(x)|| < tol.
        dvg_thresh: a point is rejected once ||f(x)|| > dvg_thresh.
        max_iters: maximum number of Broyden iterations.
        dtype: dtype used for the internal state (x, f, inverse Jacobian).
            `func` itself is always called with x in its original dtype.
        eps: magnitude added to the update denominator (with matching
            sign) to keep it away from zero.
        verbose: print per-iteration statistics.

    Returns:
        total_processed: number of (point, iteration) pairs processed. The
            int 0 if no iteration ran, otherwise a 0-dim integer tensor.
        iters_taken: number of iterations executed (int).
        x: the root estimates, in the dtype of the input. [..., D]
        J: the estimated Jacobian, obtained by inverting the final inverse
            Jacobian estimate, in the dtype of the input. [..., D, D]
        err: ||f(x)|| per point, in `dtype`. [...]
        accept_mask: bool mask of points with err < tol. [...]
    """
    origin_dtype = x.dtype

    # The decorator disables autograd; re-enable it for this one call so
    # `func` can use autograd to build the Jacobian if it wants to.
    with torch.enable_grad():
        f, J = func(x, mask=None, return_jac=True)

    x = x.clone().to(dtype)
    f = f.to(dtype)
    J = J.to(dtype)

    J_inv = J.inverse()
    err = torch.linalg.norm(f, dim=-1)
    accept_mask = err < tol
    reject_mask = err > dvg_thresh
    process_mask = ~ (accept_mask | reject_mask)

    iters_taken = 0
    total_processed = 0
    while process_mask.any() and iters_taken < max_iters:
        iters_taken += 1
        total_processed += process_mask.sum()

        # select the slice of data that still needs to be processed.
        _x = x[process_mask]
        _f = f[process_mask]
        _J_inv = J_inv[process_mask]
        
        # quasi-Newton step: x <- x - J_inv f
        _dx = torch.einsum("nij,nj->ni", _J_inv, _f)  # [N, D]
        _x_new = _x - _dx  # [N, D]
        x[process_mask] = _x_new  # [N, D]

        # re-evaluate f at the new points.
        _f_new = func(
            x.to(origin_dtype), mask=process_mask, return_jac=False
        ).to(dtype)[process_mask]  # [N, D]
        _df = _f_new - _f  # [N, D]
        f[process_mask] = _f_new  # [N, D]

        # Sherman-Morrison update of J_inv for the next iteration:
        #   J_inv += (d - J_inv df) (d^T J_inv) / (d^T J_inv df)
        # where d is the step that was just taken.
        _u = torch.einsum("nij,nj->ni", _J_inv, _df)  # [N, D]
        _d = - 1 * _dx   # [N, D]
        _a = _d - _u  # [N, D]
        _b = torch.einsum("ni,ni->n", _d, _u)
        # Push the denominator away from zero. An exact zero counts as
        # non-negative so it is guarded as well (otherwise the division
        # below yields inf/NaN).
        _b = torch.where(_b >= 0, _b + eps, _b - eps)
        _vT = torch.einsum("nij,ni->nj", _J_inv, _d)
        J_inv[process_mask] += torch.einsum(
            "nj,ni->nij", _vT, _a / _b[:, None]
        )  # [N, D, D]
        
        # re-classify every point: accepted, rejected or still in progress.
        err = torch.linalg.norm(f, dim=-1)
        accept_mask = err < tol
        reject_mask = err > dvg_thresh
        process_mask = ~ (accept_mask | reject_mask)

        if verbose:
            print(
                "iter: %d | " % iters_taken +
                "processing: %.2f%% | " % (process_mask.float().mean() * 100) +
                "accept: %.2f%% | " % (accept_mask.float().mean() * 100) +
                "reject: %.2f%% | " % (reject_mask.float().mean() * 100) +
                "err(min): %.5f | " % err.min() +
                "err(max): %.5f | " % err.max() +
                "err(mean): %.5f | " % err.mean()
            )

    x = x.to(origin_dtype)
    J = J_inv.inverse().to(origin_dtype)
    return total_processed, iters_taken, x, J, err, accept_mask


  from .autonotebook import tqdm as notebook_tqdm


In [10]:
import tqdm
import torch

# Benchmark `root_finding` on f(x) = x^2 - x (roots at 0 and 1) for 100000
# starting points drawn uniformly from [0, 0.1)^3.
torch.manual_seed(1234)
x_init = torch.rand((100000, 3)).float().to("cuda") * 0.1
# Analytic Jacobian of f at `x_init`: diag(2x - 1), one 3x3 matrix per point.
x_jac = (2 * x_init - 1)[:, None, :] * torch.eye(3).to(x_init)

def func(x, mask=None, return_jac=False):
    # Test function for the solver. `mask` is accepted but ignored.
    out = x ** 2 - x
    if return_jac:
        # jac = torch.eye(
        #     3, device=x.device
        # )[None].expand(x.shape[0], -1, -1)
        # Returns the precomputed Jacobian at `x_init`, whatever `x` is.
        # This is only correct because the solver requests the Jacobian
        # once, at the starting point.
        jac = x_jac
        return out, jac
    else:
        return out


# Run the solver 200 times from the same starting points to time it.
for _ in tqdm.tqdm(range(200)):
    (
        total_processed, iters_taken, x, J, err, accept_mask  
    ) = root_finding(x_init, func, verbose=False)
    torch.cuda.synchronize()
# Mean of the roots, their shape, and the fraction of accepted points.
print (x.mean(), x.shape, accept_mask.float().mean())

100%|██████████| 200/200 [00:16<00:00, 12.15it/s]

tensor(2.2364e-07, device='cuda:0') torch.Size([100000, 3]) tensor(1., device='cuda:0')





In [13]:
@torch.no_grad()
def _broyden(
    g, x_init, J_inv_init, max_steps=50, cvg_thresh=1e-5, dvg_thresh=1, eps=1e-6
):
    """Find roots of the given function g(x) = 0.
    This function is impleneted based on https://github.com/locuslab/deq.
    Tensor shape abbreviation:
        N: number of points
        D: space dimension
    Args:
        g (function): the function of which the roots are to be determined. shape: [N, D, 1]->[N, D, 1]
        x_init (tensor): initial value of the parameters. shape: [N, D, 1]
        J_inv_init (tensor): initial value of the inverse Jacobians. shape: [N, D, D]
        max_steps (int, optional): max number of iterations. Defaults to 50.
        cvg_thresh (float, optional): covergence threshold. Defaults to 1e-5.
        dvg_thresh (float, optional): divergence threshold. Defaults to 1.
        eps (float, optional): a small number added to the denominator to prevent numerical error. Defaults to 1e-6.
    Returns:
        result (tensor): root of the given function. shape: [N, D, 1]
        diff (tensor): corresponding loss. [N]
        valid_ids (tensor): identifiers of converged points. [N]
    """

    # initialization
    x = x_init.clone().detach()
    J_inv = J_inv_init.clone().detach()

    ids_val = torch.ones(x.shape[0], device=x.device).bool()

    gx = g(x, mask=ids_val)
    update = -J_inv.bmm(gx)

    x_opt = x.clone()
    gx_norm_opt = torch.linalg.norm(gx.squeeze(-1), dim=-1)

    delta_gx = torch.zeros_like(gx)
    delta_x = torch.zeros_like(x)

    ids_val = torch.ones_like(gx_norm_opt).bool()

    total_processed = 0
    for i_step in range(max_steps):
        total_processed += ids_val.sum()

        # update paramter values
        delta_x[ids_val] = update
        x[ids_val] += delta_x[ids_val]
        delta_gx[ids_val] = g(x, mask=ids_val) - gx[ids_val]
        gx[ids_val] += delta_gx[ids_val]

        # store values with minial loss
        gx_norm = torch.linalg.norm(gx.squeeze(-1), dim=-1)
        ids_opt = gx_norm < gx_norm_opt
        gx_norm_opt[ids_opt] = gx_norm.clone().detach()[ids_opt]
        x_opt[ids_opt] = x.clone().detach()[ids_opt]

        print (i_step, gx_norm)

        # exclude converged and diverged points from furture iterations
        ids_val = (gx_norm_opt > cvg_thresh) & (gx_norm < dvg_thresh)
        if ids_val.sum() <= 0:
            break

        # compute paramter update for next iter
        vT = (delta_x[ids_val]).transpose(-1, -2).bmm(J_inv[ids_val])
        a = delta_x[ids_val] - J_inv[ids_val].bmm(delta_gx[ids_val])
        b = vT.bmm(delta_gx[ids_val])
        b[b >= 0] += eps
        b[b < 0] -= eps
        u = a / b
        J_inv[ids_val] += u.bmm(vT)
        update = -J_inv[ids_val].bmm(gx[ids_val])


        # print (
        #     # "_x_new", x[ids_val].mean(),
        #     # "_f_new", gx[ids_val].mean(),
        #     # "err", gx_norm_opt.mean(),
        #     # "process_mask", ids_val.float().mean(),
        #     "J_inv", J_inv.mean(),
        #     # "_b", b.mean(),
        #     # "_a", a.mean(),
        #     # "vT", vT.mean(),
        #     "u", u.shape, u.mean(),
        #     "vT", vT.shape, vT.mean(),
        # )
        # break

    iters_taken = i_step + 1
    x = x_opt
    J = J_inv.inverse()
    err = gx_norm_opt
    accept_mask = gx_norm_opt <= cvg_thresh

    return total_processed, iters_taken, x, J, err, accept_mask


@torch.no_grad()
def root_finding2(
    x: torch.Tensor, 
    func: Callable, 
    tol: float = 1e-5, 
    dvg_thresh: float = 1.0,
    max_iters: int = 50,
    dtype = torch.float64,  # double is necessary
    eps: float = 1e-10,
    verbose: bool = False,
):
    origin_dtype = x.dtype

    with torch.enable_grad():
        f, J = func(x, mask=None, return_jac=True)

    x = x.clone().to(dtype)
    f = f.to(dtype)
    J = J.to(dtype)

    J_inv = J.inverse()

    def _func_g(x, mask = None):
        x = x.squeeze(-1)
        f = func(
            x.to(origin_dtype), 
            mask=mask.to(origin_dtype) if mask is not None else None, 
            return_jac=False
        ).to(dtype)
        f = f[mask].unsqueeze(-1)
        return f

    x = x.unsqueeze(-1)
    total_processed, iters_taken, x, J, err, accept_mask = _broyden(
        _func_g, 
        x, 
        J_inv, 
        max_steps=max_iters, 
        cvg_thresh=tol, 
        dvg_thresh=dvg_thresh, 
        eps=eps
    )
    x = x.squeeze(-1)

    x = x.to(origin_dtype)
    J = J.to(origin_dtype)
    return total_processed, iters_taken, x, J, err, accept_mask

In [16]:
# Single-point sanity check of `root_finding2` on f(x) = x^2 - x, starting
# from a random point in [0, 0.1)^3.
torch.manual_seed(1234)
x_init = torch.rand((1, 3)).float().to("cuda") * 0.1

def func(x, mask=None, return_jac=False):
    # Test function for the solver. `mask` is accepted but ignored.
    out = x ** 2 - x
    if return_jac:
        # jac = torch.eye(
        #     3, device=x.device
        # )[None].expand(x.shape[0], -1, -1)
        # Analytic Jacobian of f at `x`: diag(2x - 1), one 3x3 matrix per
        # point.
        jac = (2 * x - 1)[:, None, :] * torch.eye(3).to(x)
        return out, jac
    else:
        return out

# The cell output lists the step index and the residual norm per iteration.
(
    total_processed, iters_taken, x, J, err, accept_mask  
) = root_finding2(x_init, func, verbose=True)


0 tensor([0.0019], device='cuda:0', dtype=torch.float64)
1 tensor([8.3125e-05], device='cuda:0', dtype=torch.float64)
2 tensor([1.2042e-06], device='cuda:0', dtype=torch.float64)
