In [1]:
from pathlib import Path
import sys
import os
# Make the project root (parent of this notebook's directory) importable;
# presumably this is where the `modules` package imported below lives.
# Guarded so that re-running the cell does not append duplicate entries.
PROJECT_ROOT = Path(os.path.abspath('')).parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
print(PROJECT_ROOT)
print(sys.path)

/home/amai/normalrf
['/home/amai/normalrf/scripts', '/home/amai/.conda/envs/31/lib/python310.zip', '/home/amai/.conda/envs/31/lib/python3.10', '/home/amai/.conda/envs/31/lib/python3.10/lib-dynload', '', '/home/amai/.local/lib/python3.10/site-packages', '/home/amai/.conda/envs/31/lib/python3.10/site-packages', '/home/amai/.conda/envs/31/lib/python3.10/site-packages/GDAL-3.5.0-py3.10-linux-x86_64.egg', '/home/amai/.conda/envs/31/lib/python3.10/site-packages/tinycudann-1.6-py3.10-linux-x86_64.egg', '/home/amai/.conda/envs/31/lib/python3.10/site-packages/raymarching_full-0.0.0-py3.10-linux-x86_64.egg', '/home/amai/triton/python', '/home/amai/normalrf']


In [2]:
import torch
from modules import safemath
import math
import plotly.express as px
from icecream import ic

In [40]:

def normalize(x):
    """Scale ``x`` to unit length along its last dimension.

    A small constant (1e-8) is added to the norm, so an all-zero vector maps
    to zeros instead of NaN.
    """
    length = torch.linalg.norm(x, dim=-1, keepdim=True)
    return x / (length + 1e-8)

class PseudoRandomSampler(torch.nn.Module):
    """Provide per-ray 3-D sample points in [0, 1)^3 from a scrambled Sobol sequence.

    The first ``max_samples`` Sobol points are cached in the ``angs`` buffer.
    Every ray shares those points, shifted by a random per-ray offset in
    [0, 0.25) and wrapped back into [0, 1).
    """

    def __init__(self, max_samples) -> None:
        super().__init__()
        self.sampler = torch.quasirandom.SobolEngine(dimension=3, scramble=True)
        self.max_samples = max_samples
        self.register_buffer('angs', self.sampler.draw(max_samples))

    def draw(self, B, num_samples):
        """Return a (B, num_samples, 3) tensor of sample points in [0, 1).

        If more points are requested than are cached, the cache is regrown
        from the start of the sequence, so already-cached points keep their
        positions. The regrown cache stays on the buffer's device and dtype.
        """
        if num_samples > self.max_samples:
            self.max_samples = num_samples
            # SobolEngine is stateful: without reset() it would continue after the
            # points already cached, and the new buffer would not start with the
            # same (well-stratified) prefix.
            self.sampler.reset()
            self.angs = self.sampler.draw(self.max_samples).to(
                device=self.angs.device, dtype=self.angs.dtype)
        angs = self.angs[:num_samples].unsqueeze(0).expand(B, num_samples, 3)
        # Per-ray random shift so rays do not all use identical directions.
        offset = torch.rand(B, 1, 3, device=angs.device, dtype=angs.dtype) * 0.25
        return (angs + offset) % 1.0

    def update(self, *args, **kwargs):
        """No-op hook; accepts and ignores any arguments."""
        pass

class SGGXSampler(PseudoRandomSampler):
    """Importance-sample reflected directions from an anisotropic SGGX lobe.

    Half-vectors use the SGGX visible-normal construction with the surface
    normal as projection axis. They are generated in a local tangent frame
    (tangent, bitangent, normal), where the SGGX matrix is ``diag(r1, r2, 1)``.
    """

    def sample(self, viewdir, normal, r1, r2, ray_mask, eps=torch.finfo(torch.float32).eps, **kwargs):
        """Draw one reflected direction per True entry of ``ray_mask``.

        Args:
            viewdir: (B, 3) view directions.
            normal: (B, 3) surface normals (assumed unit length -- TODO confirm at call sites).
            r1, r2: roughness along tangent / bitangent; after stacking they must
                provide 3*B values (a scalar works when B == 1).
            ray_mask: (B, num_samples) bool mask of samples to generate.
            eps: kept for API compatibility; replaced below by the machine epsilon
                of ``normal.dtype``.

        Returns:
            L: (M, 3) reflected directions, M = ray_mask.sum().
            row_world_basis_mask: (M, 3, 3) local-to-world matrices whose columns
                are tangent, bitangent, normal.
            prob: (M,) SGGX NDF D(h) = 1 / (pi sqrt(det S) (h^T S^-1 h)^2)
                at each sampled local half-vector.
        """
        num_samples = ray_mask.shape[1]
        device = normal.device
        dtype = normal.dtype
        B = normal.shape[0]
        eps = torch.finfo(dtype).eps

        # Local shading frame; rows are tangent, bitangent, normal.
        z_up = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=dtype).reshape(1, 3).expand(B, 3)
        x_up = torch.tensor([-1.0, 0.0, 0.0], device=device, dtype=dtype).reshape(1, 3).expand(B, 3)
        # abs(): a normal near -z is as parallel to z as one near +z, and
        # cross(z, n) ~ 0 would collapse the frame.
        up = torch.where(normal[:, 2:3].abs() < 0.999, z_up, x_up)
        tangent = normalize(torch.linalg.cross(up, normal))
        bitangent = normalize(torch.linalg.cross(normal, tangent))
        row_world_basis = torch.stack([tangent, bitangent, normal], dim=1)  # (B, 3, 3)

        # SGGX matrix in the local frame. The half-vector below is built in local
        # coordinates and the pdf evaluates D with diag(S_diagv), so sampling must
        # use that same matrix; rotating it by the basis would tilt the lobe away
        # from the normal for any normal other than +z.
        S_diagv = torch.stack([r1, r2, torch.ones_like(r1)], dim=-1).reshape(-1, 3)
        S = torch.diag_embed(S_diagv).expand(B, 3, 3)

        # Visible-normal sampling matrix with projection axis k = 2 (the normal).
        M = torch.zeros((B, 3, 3), device=device, dtype=dtype)
        tmp = (S[:, 1, 1] * S[:, 2, 2] - S[:, 1, 2] ** 2).clip(min=eps).sqrt()
        inv_sqrt_Skk = 1 / S[:, 2, 2].clip(min=eps).sqrt().clip(min=eps)
        M[:, 0, 0] = torch.linalg.det(S).abs().sqrt() / tmp
        M[:, 1, 0] = -inv_sqrt_Skk * (S[:, 0, 2] * S[:, 1, 2] - S[:, 0, 1] * S[:, 2, 2]) / tmp
        M[:, 1, 1] = inv_sqrt_Skk * tmp
        M[:, 2, 0] = inv_sqrt_Skk * S[:, 0, 2]
        M[:, 2, 1] = inv_sqrt_Skk * S[:, 1, 2]
        M[:, 2, 2] = inv_sqrt_Skk * S[:, 2, 2]

        angs = self.draw(B, num_samples).to(device)

        # Broadcast per-ray quantities to samples, keeping only masked ones to save memory.
        M_mask = M.reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]
        S_mask_v = S_diagv.reshape(B, 1, 3).expand(B, num_samples, 3)[ray_mask]
        row_world_basis_mask = row_world_basis.permute(0, 2, 1).reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        u1_mask = angs[..., 0][ray_mask]
        u2_mask = angs[..., 1][ray_mask]

        # Uniform point on the unit disk (u, v) lifted to the hemisphere (w).
        u1sqrt = u1_mask.clip(min=eps).sqrt()
        u = (2 * math.pi * u2_mask).cos() * u1sqrt
        v = (2 * math.pi * u2_mask).sin() * u1sqrt
        w = (1 - u ** 2 - v ** 2).clip(min=eps).sqrt()

        H_l = normalize(u[:, None] * M_mask[:, 0] + v[:, None] * M_mask[:, 1] + w[:, None] * M_mask[:, 2])

        # Force the first sample of each ray to the normal, so it reflects V about n.
        first = torch.zeros_like(ray_mask)
        first[:, 0] = True
        H_l[first[ray_mask]] = torch.tensor([0.0, 0.0, 1.0], device=H_l.device, dtype=H_l.dtype)

        H = torch.matmul(row_world_basis_mask, H_l.unsqueeze(-1)).squeeze(-1)

        V = viewdir.unsqueeze(1).expand(-1, num_samples, 3)[ray_mask]
        L = 2.0 * (V * H).sum(dim=-1, keepdim=True) * H - V

        # h^T S^-1 h for diagonal S, then the SGGX NDF.
        quad = (H_l ** 2 / S_mask_v.clip(min=eps)).sum(dim=-1)
        prob = 1 / (math.pi * S_mask_v.prod(dim=-1).sqrt() * quad ** 2)

        return L, row_world_basis_mask, prob

    def compute_prob(self, halfvec, eN, r1, r2, **kwargs):
        """Evaluate the SGGX NDF D(h) for S = diag(r1, r2, 1).

        D(h) = 1 / (pi * sqrt(det S) * (h^T S^-1 h)^2), matching ``sample``.
        ``halfvec`` is used as given -- presumably already expressed in the
        local (tangent, bitangent, normal) frame; TODO confirm at call sites.
        ``eN`` is unused.
        """
        S_diagv = torch.stack([r1, r2, torch.ones_like(r1)], dim=-1).reshape(-1, 3)
        eps = torch.finfo(S_diagv.dtype).eps
        h = halfvec.reshape(-1, 3)
        # The NDF uses the INVERSE of S (previously S itself was used here).
        quad = (h ** 2 / S_diagv.clip(min=eps)).sum(dim=-1)
        return 1 / (math.pi * S_diagv.prod(dim=-1).sqrt() * quad ** 2)


In [48]:
# Sanity check: draw N samples for one ray looking straight down the normal and
# visualise the environment-map mip level each sample would read.
device = torch.device('cpu')
eps = torch.finfo(torch.float32).eps
N = 50
ray_mask = torch.ones((1, N), device=device, dtype=bool)
roughness = torch.tensor(0.1)
normal = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
viewdir = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
sampler = SGGXSampler(512)
L, basis, prob = sampler.sample(viewdir, normal, roughness, roughness, ray_mask)

indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).expand(ray_mask.shape)[ray_mask]
# Log of the solid angle attributed to each sample: -log(pdf * n).
mipval = -(prob * indiv_num_samples).log()
H = normalize(L + viewdir)
ic(L.min())

# Env-map resolution; the texel size below presumably assumes a lat-long
# (equirectangular) map -- confirm against the renderer.
w = 1024
h = 512
dw = (1 - L[:, 2] ** 2).clip(min=eps).sqrt() * 2 * math.pi / w
dh = torch.ones_like(dw) * math.pi / h
saSample = mipval.reshape(-1)  # already a log solid angle

# Log solid angle of one level-0 texel in each sample's direction.
saTexel = dw.clip(min=eps).log() + dh.clip(min=eps).log()
saTexel_w = saTexel
saTexel_h = saTexel
# Both terms are already logs, so level = 0.5 * log2(sample SA / texel SA).
# Taking another log of saTexel (which is negative) would clip it to log(eps)
# and make the level independent of texel size.
miplevel = (saSample - saTexel) / math.log(2) / 2
miplevel_w = miplevel
miplevel_h = miplevel

px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=2 ** miplevel_h).show()

ic| L.min(): tensor(-0.9219)
