In [1]:
from pathlib import Path
import sys
import os
sys.path.append(str(Path(os.path.abspath('')).parent))
print(Path(os.path.abspath('')).parent)
print(sys.path)

/home/amai/normalrf
['/home/amai/normalrf/scripts', '/home/amai/.conda/envs/31/lib/python310.zip', '/home/amai/.conda/envs/31/lib/python3.10', '/home/amai/.conda/envs/31/lib/python3.10/lib-dynload', '', '/home/amai/.local/lib/python3.10/site-packages', '/home/amai/.conda/envs/31/lib/python3.10/site-packages', '/home/amai/.conda/envs/31/lib/python3.10/site-packages/GDAL-3.5.0-py3.10-linux-x86_64.egg', '/home/amai/.conda/envs/31/lib/python3.10/site-packages/tinycudann-1.6-py3.10-linux-x86_64.egg', '/home/amai/.conda/envs/31/lib/python3.10/site-packages/raymarching_full-0.0.0-py3.10-linux-x86_64.egg', '/home/amai/triton/python', '/home/amai/normalrf']


In [2]:
import torch
from modules import safemath
import math
import plotly.express as px
from icecream import ic

In [15]:

def normalize(x):
    return x / (torch.linalg.norm(x, dim=-1, keepdim=True)+1e-8)

class PseudoRandomSampler(torch.nn.Module):
    def __init__(self, max_samples) -> None:
        super().__init__()
        self.sampler = torch.quasirandom.SobolEngine(dimension=3, scramble=True)
        self.max_samples = max_samples
        angs = self.sampler.draw(max_samples)
        self.register_buffer('angs', angs)

    def draw(self, B, num_samples):
        if num_samples > self.max_samples:
            self.max_samples = num_samples
            self.angs = self.sampler.draw(self.max_samples)
        angs = self.angs.reshape(1, self.max_samples, 3)[:, :num_samples, :].expand(B, num_samples, 3)
        # self.sampler = torch.quasirandom.SobolEngine(dimension=2, scramble=True)
        # add random offset
        offset = torch.rand(B, 1, 3, device=angs.device)*0.25
        angs = (angs + offset) % 1.0
        return angs

    def update(self, *args, **kwargs):
        pass

class SGGXSampler(PseudoRandomSampler):

    def sample(self, refdirs, viewdir, normal, r1, r2, ray_mask):
        num_samples = ray_mask.shape[1]
        # viewdir: (B, 3)
        # normal: (B, 3)
        # r1, r2: B roughness values for anisotropic roughness
        device = normal.device
        B = normal.shape[0]
        eps=torch.finfo(normal.dtype).eps

        # establish basis for BRDF
        z_up = torch.tensor([0.0, 0.0, 1.0], device=device).reshape(1, 3).expand(B, 3)
        x_up = torch.tensor([-1.0, 0.0, 0.0], device=device).reshape(1, 3).expand(B, 3)
        up = torch.where(normal[:, 2:3] < 0.999, z_up, x_up)
        tangent = normalize(torch.linalg.cross(up, normal))
        bitangent = normalize(torch.linalg.cross(normal, tangent))
        # B, 3, 3
        row_world_basis = torch.stack([tangent, bitangent, normal], dim=1).reshape(B, 3, 3)

        # B, 3, 3
        S_diag = torch.diag_embed(torch.stack([r1, r2, torch.ones_like(r1)], dim=0))
        S = torch.matmul(torch.matmul(row_world_basis, S_diag), row_world_basis.permute(0, 2, 1))
        M = torch.zeros((B, 3, 3), device=device)
        tmp = (S[:, 1, 1]*S[:, 2, 2] - S[:, 1, 2]**2).clip(min=eps).sqrt()
        M[:, 0, 0] = torch.linalg.det(S).abs().sqrt() / tmp
        # checked
        inv_sqrt_Sii = 1/S[:, 2, 2].clip(min=eps).sqrt().clip(min=eps)
        # checked
        M[:, 1, 0] = -inv_sqrt_Sii*(S[:, 0, 2]*S[:, 1, 2] - S[:, 0, 1]*S[:, 2, 2])/tmp
        M[:, 1, 1] = inv_sqrt_Sii*tmp

        # checked
        M[:, 2, 0] = inv_sqrt_Sii * S[:, 0, 2]
        M[:, 2, 1] = inv_sqrt_Sii * S[:, 1, 2]
        M[:, 2, 2] = inv_sqrt_Sii * S[:, 2, 2]

        angs = self.draw(B, num_samples).to(device)

        M_mask = M.reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]
        S_mask = S.reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        # here is where things get really large
        u1 = angs[..., 0]
        u2 = angs[..., 1]

        # stretch and mask stuff to reduce memory
        r1_mask = r1.reshape(-1, 1).expand(u1.shape)[ray_mask]
        r2_mask = r2.reshape(-1, 1).expand(u1.shape)[ray_mask]
        row_world_basis_mask = row_world_basis.permute(0, 2, 1).reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        u1_mask = u1[ray_mask]
        u2_mask = u2[ray_mask]

        u1sqrt = u1_mask.clip(min=eps).sqrt()
        u = (2*math.pi*u2_mask).cos() * u1sqrt
        v = (2*math.pi*u2_mask).sin() * u1sqrt
        w = (1-u**2-v**2).clip(min=eps).sqrt()

        ic(u.shape, M_mask.shape)
        H_l = normalize(u[:, None] * M_mask[:, 0] + v[:, None] * M_mask[:, 1] + w[:, None] * M_mask[:, 2])

        first = torch.zeros_like(ray_mask)
        first[:, 0] = True
        H_l[first[ray_mask], 0] = 0
        H_l[first[ray_mask], 1] = 0
        H_l[first[ray_mask], 2] = 1

        H = torch.matmul(row_world_basis_mask, H_l.unsqueeze(-1)).squeeze(-1)
        # H = torch.einsum('bni,bij->bnj', H_l, row_world_basis)

        V = viewdir.unsqueeze(1).expand(-1, num_samples, 3)[ray_mask]
        # N = normal.reshape(-1, 1, 3).expand(-1, num_samples, 3)[ray_mask]
        L = (2.0 * (V * H).sum(dim=-1, keepdim=True) * H - V)

        temp = torch.matmul(torch.matmul(H_l.reshape(-1, 1, 3), S_mask), H_l.reshape(-1, 3, 1))
        ic(temp.shape, S_mask.shape, H_l.shape)
        prob = 1 / (math.pi * torch.linalg.det(S).sqrt().reshape(-1) * (temp.reshape(-1))**2)

        return L, row_world_basis_mask, prob

    def calculate_mipval(self, H, V, N, ray_mask, roughness, row_world_basis, eps=torch.finfo(torch.float32).eps):
        num_samples = ray_mask.shape[1]
        NdotH = ((H * N).sum(dim=-1)).abs().clip(min=eps, max=1)
        HdotV = (H * V).sum(dim=-1).abs().clip(min=eps, max=1)
        NdotV = (N * V).sum(dim=-1).abs().clip(min=eps, max=1)
        logD = 2*torch.log(roughness.clip(min=eps)) - 2*torch.log((NdotH**2*(roughness**2-1)+1).clip(min=eps))
        # ic(NdotH.shape, NdotH, D, D.mean())
        # px.scatter(x=NdotH[0].detach().cpu().flatten(), y=D[0].detach().cpu().flatten()).show()
        # assert(False)
        # ic(NdotH.mean())
        lpdf = logD + torch.log(HdotV) - torch.log(NdotV)# - torch.log(roughness.clip(min=1e-5))
        # pdf = D * HdotV / NdotV / roughness.reshape(-1, 1)
        # pdf = NdotH / 4 / HdotV
        # pdf = D# / NdotH
        indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).expand(-1, num_samples)[ray_mask]
        mipval = -torch.log(indiv_num_samples.clip(min=1)) - lpdf
        return mipval

In [16]:
device = torch.device('cpu')
N = 50
ray_mask = torch.ones((1, N), device=device, dtype=bool)
roughness = torch.tensor(0.1)
normal = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
viewdir = torch.tensor([0, 0, -1.0], device=device).reshape(1, 3)
# sampler = CosineLobeSampler(512)
sampler = SGGXSampler(512)
L, basis, prob = sampler.sample(None, viewdir, normal, roughness, roughness, ray_mask)

        num_samples = ray_mask.shape[1]
indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).expand(-1, num_samples)[ray_mask]
mipval = prob * indiv_num_samples
H = normalize(L + viewdir)
mipval = sampler.calculate_mipval(H, viewdir, normal, ray_mask, roughness, torch.eye(3).reshape(1, 3, 3))
ic(L.min())

h = w = 512
# h = w = 2048
distortion = 1
saTexel = distortion / h / w
num_pixels = h * w * 6
miplevel = ((mipval - math.log(saTexel)) / math.log(2))/2
miplevel = miplevel.clip(0)
res = h / 2**miplevel
ic(mipval, miplevel, math.log(saTexel*N))
# px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=miplevel)
px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2]).show()


ic| u.shape: torch.Size([50]), M_mask.shape: torch.Size([50, 3, 3])
ic| temp.shape: torch.Size([50, 1, 1])
    S_mask.shape: torch.Size([50, 3, 3])
    H_l.shape: torch.Size([50, 3])
ic| L.min(): tensor(-1.)
ic| mipval: tensor([-8.5172, -6.4890, -2.1749, -3.3524, -4.1602, -1.6194, -8.4313, -4.4508,
                    -4.9154,  0.1623, -0.9173, -3.7164, -2.9369, -1.6676, -7.1201, -6.1790,
                    -5.8431, -7.4031, -2.1418, -2.8813, -3.7666, -0.1884,  0.6439, -5.1949,
                    -4.7240, -8.0156, -1.0563, -4.2170, -3.2926, -2.6081, -6.7290, -5.3754,
                    -5.4613, -6.8803, -2.5379, -3.2075, -4.2992, -1.1594, -7.7944, -4.6462,
                    -5.0630,  1.3805, -0.4294, -3.9082, -2.7277, -2.0059, -7.7188, -5.9955,
                    -6.0704, -7.0129])
    miplevel: tensor([2.8561, 4.3192, 7.4311, 6.5818, 5.9991, 7.8318, 2.9181, 5.7894, 5.4543,
                      9.1171, 8.3383, 6.3192, 6.8815, 7.7971, 3.8639, 4.5428, 4.7851, 3.6598,
             