In [5]:
from pathlib import Path
import sys
import os
sys.path.append(str(Path(os.path.abspath('')).parent))
print(Path(os.path.abspath('')).parent)
print(sys.path)

/home/dronelab/normalrf
['/home/dronelab/normalrf/scripts', '/home/dronelab/miniconda3/envs/31/lib/python310.zip', '/home/dronelab/miniconda3/envs/31/lib/python3.10', '/home/dronelab/miniconda3/envs/31/lib/python3.10/lib-dynload', '', '/home/dronelab/miniconda3/envs/31/lib/python3.10/site-packages', '/home/dronelab/miniconda3/envs/31/lib/python3.10/site-packages/raymarching-0.0.0-py3.10-linux-x86_64.egg', '/home/dronelab/miniconda3/envs/31/lib/python3.10/site-packages/tinycudann-1.6-py3.10-linux-x86_64.egg', '/home/dronelab/normalrf', '/home/dronelab/normalrf', '/home/dronelab/normalrf']


In [6]:
import torch
from modules import safemath
import math
import plotly.express as px
from icecream import ic

In [13]:

def normalize(x):
    return x / (torch.linalg.norm(x, dim=-1, keepdim=True)+1e-8)

class PseudoRandomSampler(torch.nn.Module):
    def __init__(self, max_samples) -> None:
        super().__init__()
        self.sampler = torch.quasirandom.SobolEngine(dimension=3, scramble=True)
        self.max_samples = max_samples
        angs = self.sampler.draw(max_samples)
        self.register_buffer('angs', angs)

    def draw(self, B, num_samples):
        if num_samples > self.max_samples:
            self.max_samples = num_samples
            self.angs = self.sampler.draw(self.max_samples)
        angs = self.angs.reshape(1, self.max_samples, 3)[:, :num_samples, :].expand(B, num_samples, 3)
        # self.sampler = torch.quasirandom.SobolEngine(dimension=2, scramble=True)
        # add random offset
        offset = torch.rand(B, 1, 3, device=angs.device)*0.25
        angs = (angs + offset) % 1.0
        return angs

    def update(self, *args, **kwargs):
        pass

class SGGXSampler(PseudoRandomSampler):

    def sample(self, refdirs, viewdir, normal, r1, r2, ray_mask):
        num_samples = ray_mask.shape[1]
        # viewdir: (B, 3)
        # normal: (B, 3)
        # r1, r2: B roughness values for anisotropic roughness
        device = normal.device
        B = normal.shape[0]
        eps=torch.finfo(normal.dtype).eps

        # establish basis for BRDF
        z_up = torch.tensor([0.0, 0.0, 1.0], device=device).reshape(1, 3).expand(B, 3)
        x_up = torch.tensor([-1.0, 0.0, 0.0], device=device).reshape(1, 3).expand(B, 3)
        up = torch.where(normal[:, 2:3] < 0.999, z_up, x_up)
        tangent = normalize(torch.linalg.cross(up, normal))
        bitangent = normalize(torch.linalg.cross(normal, tangent))
        # B, 3, 3
        row_world_basis = torch.stack([tangent, bitangent, normal], dim=1).reshape(B, 3, 3)

        # B, 3, 3
        S_diag = torch.diag_embed(torch.stack([r1, r2, torch.ones_like(r1)], dim=0))
        S = torch.matmul(torch.matmul(row_world_basis, S_diag), row_world_basis.permute(0, 2, 1))
        M = torch.zeros((B, 3, 3), device=device)
        tmp = (S[:, 1, 1]*S[:, 2, 2] - S[:, 1, 2]**2).clip(min=eps).sqrt()
        M[:, 0, 0] = torch.linalg.det(S).abs().sqrt() / tmp
        # checked
        inv_sqrt_Sii = 1/S[:, 2, 2].clip(min=eps).sqrt().clip(min=eps)
        # checked
        M[:, 1, 0] = -inv_sqrt_Sii*(S[:, 0, 2]*S[:, 1, 2] - S[:, 0, 1]*S[:, 2, 2])/tmp
        M[:, 1, 1] = inv_sqrt_Sii*tmp

        # checked
        M[:, 2, 0] = inv_sqrt_Sii * S[:, 0, 2]
        M[:, 2, 1] = inv_sqrt_Sii * S[:, 1, 2]
        M[:, 2, 2] = inv_sqrt_Sii * S[:, 2, 2]

        angs = self.draw(B, num_samples).to(device)

        M_mask = M.reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        # here is where things get really large
        u1 = angs[..., 0]
        u2 = angs[..., 1]

        # stretch and mask stuff to reduce memory
        r1_mask = r1.reshape(-1, 1).expand(u1.shape)[ray_mask]
        r2_mask = r2.reshape(-1, 1).expand(u1.shape)[ray_mask]
        row_world_basis_mask = row_world_basis.permute(0, 2, 1).reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        u1_mask = u1[ray_mask]
        u2_mask = u2[ray_mask]

        u1sqrt = u1_mask.clip(min=eps).sqrt()
        u = (2*math.pi*u2_mask).cos() * u1sqrt
        v = (2*math.pi*u2_mask).sin() * u1sqrt
        w = (1-u**2-v**2).clip(min=eps).sqrt()

        ic(u.shape, M_mask.shape)
        H_l = normalize(u[:, None] * M_mask[:, 0] + v[:, None] * M_mask[:, 1] + w[:, None] * M_mask[:, 2])

        first = torch.zeros_like(ray_mask)
        first[:, 0] = True
        H_l[first[ray_mask], 0] = 0
        H_l[first[ray_mask], 1] = 0
        H_l[first[ray_mask], 2] = 1

        H = torch.matmul(row_world_basis_mask, H_l.unsqueeze(-1)).squeeze(-1)
        # H = torch.einsum('bni,bij->bnj', H_l, row_world_basis)

        V = viewdir.unsqueeze(1).expand(-1, num_samples, 3)[ray_mask]
        # N = normal.reshape(-1, 1, 3).expand(-1, num_samples, 3)[ray_mask]
        L = (2.0 * (V * H).sum(dim=-1, keepdim=True) * H - V)

        return L, row_world_basis_mask

    def calculate_mipval(self, H, V, N, ray_mask, roughness, row_world_basis, eps=torch.finfo(torch.float32).eps):
        num_samples = ray_mask.shape[1]
        NdotH = ((H * N).sum(dim=-1)).abs().clip(min=eps, max=1)
        HdotV = (H * V).sum(dim=-1).abs().clip(min=eps, max=1)
        NdotV = (N * V).sum(dim=-1).abs().clip(min=eps, max=1)
        logD = 2*torch.log(roughness.clip(min=eps)) - 2*torch.log((NdotH**2*(roughness**2-1)+1).clip(min=eps))
        # ic(NdotH.shape, NdotH, D, D.mean())
        # px.scatter(x=NdotH[0].detach().cpu().flatten(), y=D[0].detach().cpu().flatten()).show()
        # assert(False)
        # ic(NdotH.mean())
        lpdf = logD + torch.log(HdotV) - torch.log(NdotV)# - torch.log(roughness.clip(min=1e-5))
        # pdf = D * HdotV / NdotV / roughness.reshape(-1, 1)
        # pdf = NdotH / 4 / HdotV
        # pdf = D# / NdotH
        indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).expand(-1, num_samples)[ray_mask]
        mipval = -torch.log(indiv_num_samples.clip(min=1)) - lpdf
        return mipval

In [18]:
device = torch.device('cpu')
N = 50
ray_mask = torch.ones((1, N), device=device, dtype=bool)
roughness = torch.tensor(0.20)
normal = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
viewdir = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
# sampler = CosineLobeSampler(512)
sampler = SGGXSampler(512)
L, basis = sampler.sample(None, viewdir, normal, roughness, roughness, ray_mask)
H = normalize(L + normal)
mipval = sampler.calculate_mipval(H, viewdir, normal, ray_mask, roughness, torch.eye(3).reshape(1, 3, 3))
ic(L.min())

h = w = 512
# h = w = 2048
distortion = 1
saTexel = distortion / h / w
num_pixels = h * w * 6
miplevel = ((mipval - math.log(saTexel)) / math.log(2))/2
miplevel = miplevel.clip(0)
res = h / 2**miplevel
ic(mipval, miplevel, math.log(saTexel*N))
# px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=miplevel)
px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2]).show()


ic| u.shape: torch.Size([50]), M_mask.shape: torch.Size([50, 3, 3])
ic| L.min(): tensor(-0.9608)
ic| mipval: tensor([-7.1309, -1.7621, -2.7569, -3.9091, -4.4523, -3.3410, -0.8835, -5.2810,
                    -5.6833,  0.1311, -2.9704, -4.8148, -4.2591, -2.3407, -1.1942, -6.3227,
                    -6.7113, -1.3580, -2.2410, -4.0005, -4.9261, -3.2699, -6.9665, -5.6009,
                    -5.2040, -0.2617, -3.6293, -4.5604, -3.6490, -2.6658, -1.8956, -6.2183,
                    -6.0931, -2.0042, -2.5540, -3.7298, -4.6360, -3.5259, -0.4791, -5.0941,
                    -5.4778, -7.0810, -3.1572, -4.9984, -4.0860, -2.1248, -1.4971, -6.5770,
                    -6.4557, -1.0461])
    miplevel: tensor([3.8561, 7.7289, 7.0113, 6.1802, 5.7884, 6.5900, 8.3627, 5.1905, 4.9004,
                      9.0946, 6.8573, 5.5268, 5.9277, 7.3115, 8.1386, 4.4392, 4.1588, 8.0204,
                      7.3835, 6.1142, 5.4466, 6.6413, 3.9747, 4.9598, 5.2461, 8.8112, 6.3820,
                      5.7103, 