In [1]:
from pathlib import Path
import sys
import os
sys.path.append(str(Path(os.path.abspath('')).parent))
print(Path(os.path.abspath('')).parent)

/home/amai/normalrf


In [5]:
import torch
from modules import safemath
import math
import plotly.express as px
from icecream import ic

In [46]:

def normalize(x):
    return x / (torch.linalg.norm(x, dim=-1, keepdim=True)+1e-8)

class PseudoRandomSampler(torch.nn.Module):
    def __init__(self, max_samples) -> None:
        super().__init__()
        self.sampler = torch.quasirandom.SobolEngine(dimension=3, scramble=True)
        self.max_samples = max_samples
        angs = self.sampler.draw(max_samples)
        self.register_buffer('angs', angs)

    def draw(self, B, num_samples):
        if num_samples > self.max_samples:
            self.max_samples = num_samples
            self.angs = self.sampler.draw(self.max_samples)
        angs = self.angs.reshape(1, self.max_samples, 3)[:, :num_samples, :].expand(B, num_samples, 3)
        # self.sampler = torch.quasirandom.SobolEngine(dimension=2, scramble=True)
        # add random offset
        offset = torch.rand(B, 1, 3, device=angs.device)*0.25
        angs = (angs + offset) % 1.0
        return angs

    def update(self, *args, **kwargs):
        pass

class CosineLobeSampler(PseudoRandomSampler):
    def sample(self, refdirs, viewdir, normal, r1, r2, ray_mask, eps=torch.finfo(torch.float32).eps):
        num_samples = ray_mask.shape[1]
        # viewdir: (B, 3)
        # normal: (B, 3)
        # r1, r2: B roughness values for anisotropic roughness
        device = normal.device
        B = normal.shape[0]

        # establish basis for BRDF
        z_up = torch.tensor([0.0, 0.0, 1.0], device=device).reshape(1, 3).expand(B, 3)
        x_up = torch.tensor([-1.0, 0.0, 0.0], device=device).reshape(1, 3).expand(B, 3)
        up = torch.where(normal[:, 2:3] < 0.9, z_up, x_up)
        tangent = normalize(torch.linalg.cross(up, normal))
        bitangent = normalize(torch.linalg.cross(normal, tangent))
        # B, 3, 3
        row_world_basis = torch.stack([tangent, bitangent, normal], dim=1).reshape(B, 3, 3)

        # GGXVNDF
        # V_l = torch.matmul(torch.inverse(row_world_basis.permute(0, 2, 1)), viewdir.unsqueeze(-1)).squeeze(-1)
        # ic((normal*viewdir).sum(dim=-1).min(), (normal*viewdir).sum(dim=-1).max())
        # ic(1, V_l.min(dim=0), V_l.max(dim=0))
        V_l = torch.matmul(row_world_basis, viewdir.unsqueeze(-1)).squeeze(-1)
        # ic(2, V_l.min(dim=0), V_l.max(dim=0))
        r1_c = r1.squeeze(-1)
        r2_c = r2.squeeze(-1)

        angs = self.draw(B, num_samples).to(device)

        # here is where things get really large
        u1 = angs[..., 0]
        u2 = angs[..., 1]
        u3 = angs[..., 2]

        # stretch and mask stuff to reduce memory
        r_mask1 = r1_c.reshape(-1, 1).expand(u1.shape)[ray_mask]
        r_mask2 = r2_c.reshape(-1, 1).expand(u1.shape)[ray_mask]

        u_mask = angs[ray_mask]
        u1_mask = u1[ray_mask]
        u2_mask = u2[ray_mask]
        u3_mask = u3[ray_mask]
        row_world_basis_mask = row_world_basis.permute(0, 2, 1).reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        
        sphere_noise = normalize(u_mask*2-1)
        # so this function is the inverse of the CDF
        H_l = normalize(r_mask1.reshape(-1, 1) * sphere_noise + torch.tensor([0.0, 0.0, 1.0], device=device).reshape(1, -1))

        first = torch.zeros_like(ray_mask)
        first[:, 0] = True
        H_l[first[ray_mask], 0] = 0
        H_l[first[ray_mask], 1] = 0
        H_l[first[ray_mask], 2] = 1

        H = torch.matmul(row_world_basis_mask, H_l.unsqueeze(-1)).squeeze(-1)
        # H = torch.einsum('bni,bij->bnj', H_l, row_world_basis)

        V = viewdir.unsqueeze(1).expand(-1, num_samples, 3)[ray_mask]
        # N = normal.reshape(-1, 1, 3).expand(-1, num_samples, 3)[ray_mask]
        L = (2.0 * (V * H).sum(dim=-1, keepdim=True) * H - V)

        return L, row_world_basis_mask

    def calculate_mipval(self, H, V, N, ray_mask, roughness, row_world_basis, eps=torch.finfo(torch.float32).eps):
        num_samples = ray_mask.shape[1]
        device = ray_mask.device
        # H_l = torch.matmul(row_world_basis.permute(0, 2, 1), H.unsqueeze(-1)).squeeze(-1)
        # sphere_noise = (H_l - torch.tensor([0.0, 0.0, 1.0], device=device).reshape(1, -1)) / roughness

        costheta = (H * N).sum(dim=-1).clip(min=eps, max=1)

        HdotV = (H * V).sum(dim=-1).abs().clip(min=eps, max=1)
        NdotV = (N * V).sum(dim=-1).abs().clip(min=eps, max=1)
        lpdf1 = -(2*math.pi*roughness.clip(min=eps)).log() + (2*costheta**2 + roughness**2-1).clip(min=eps).log() - 0.5*(roughness**2 + costheta**2 - 1).clip(min=eps).log()
        lpdf = torch.where(costheta > (1-roughness*2).clip(min=eps).sqrt(), lpdf1, torch.zeros_like(lpdf1))
        indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).expand(-1, num_samples)[ray_mask]

        # logD = -math.log(2*math.pi**2) - dTdtheta
        # ic(NdotH.shape, NdotH, D, D.mean())
        # px.scatter(x=NdotH[0].detach().cpu().flatten(), y=D[0].detach().cpu().flatten()).show()
        # assert(False)
        # pdf = D * HdotV / NdotV / roughness.reshape(-1, 1)
        # pdf = NdotH / 4 / HdotV
        # pdf = D# / NdotH
        # mipval = -torch.log(indiv_num_samples.clip(min=1)) - lpdf
        ic(lpdf.exp())
        mipval = -torch.log(indiv_num_samples.clip(min=1)) - lpdf
        return mipval

class GGXSampler(PseudoRandomSampler):

    def sample(self, refdirs, viewdir, normal, r1, r2, ray_mask, eps=torch.finfo(torch.float32).eps):
        num_samples = ray_mask.shape[1]
        # viewdir: (B, 3)
        # normal: (B, 3)
        # r1, r2: B roughness values for anisotropic roughness
        device = normal.device
        B = normal.shape[0]

        # establish basis for BRDF
        z_up = torch.tensor([0.0, 0.0, 1.0], device=device).reshape(1, 3).expand(B, 3)
        x_up = torch.tensor([-1.0, 0.0, 0.0], device=device).reshape(1, 3).expand(B, 3)
        up = torch.where(normal[:, 2:3] < 0.999, z_up, x_up)
        tangent = normalize(torch.linalg.cross(up, normal))
        bitangent = normalize(torch.linalg.cross(normal, tangent))
        # B, 3, 3
        row_world_basis = torch.stack([tangent, bitangent, normal], dim=1).reshape(B, 3, 3)


        # GGXVNDF
        # V_l = torch.matmul(torch.inverse(row_world_basis.permute(0, 2, 1)), viewdir.unsqueeze(-1)).squeeze(-1)
        # ic((normal*viewdir).sum(dim=-1).min(), (normal*viewdir).sum(dim=-1).max())
        # ic(1, V_l.min(dim=0), V_l.max(dim=0))
        V_l = torch.matmul(row_world_basis, viewdir.unsqueeze(-1)).squeeze(-1)
        # ic(2, V_l.min(dim=0), V_l.max(dim=0))
        r1_c = r1.squeeze(-1)
        r2_c = r2.squeeze(-1)
        V_stretch = normalize(torch.stack([r1_c*V_l[..., 0], r2_c*V_l[..., 1], V_l[..., 2]], dim=-1)).unsqueeze(1)
        T1 = torch.where(V_stretch[..., 2:3] < 0.999, normalize(torch.linalg.cross(V_stretch, z_up.unsqueeze(1), dim=-1)), x_up.unsqueeze(1))
        T2 = normalize(torch.linalg.cross(T1, V_stretch, dim=-1))
        z = V_stretch[..., 2].reshape(-1, 1)
        a = (1 / (1+z.detach()).clip(min=1e-5)).clip(max=1e4)
        angs = self.draw(B, num_samples).to(device)

        # here is where things get really large
        u1 = angs[..., 0]
        u2 = angs[..., 1]

        # stretch and mask stuff to reduce memory
        a_mask = a.expand(u1.shape)[ray_mask]

        r_mask_u1 = r1_c.reshape(-1, 1).expand(u1.shape)[ray_mask]
        r_mask1 = r_mask_u1
        r_mask_u2 = r2_c.reshape(-1, 1).expand(u1.shape)[ray_mask]
        r_mask2 = r_mask_u2

        z_mask = z.expand(u1.shape)[ray_mask]
        u1_mask = u1[ray_mask]
        u2_mask = u2[ray_mask]
        T1_mask = T1.expand(-1, num_samples, 3)[ray_mask]
        T2_mask = T2.expand(-1, num_samples, 3)[ray_mask]
        V_stretch_mask = V_stretch.expand(-1, num_samples, 3)[ray_mask]
        row_world_basis_mask = row_world_basis.permute(0, 2, 1).reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        r = torch.sqrt(u1_mask)
        phi = torch.where(u2_mask < a_mask, u2_mask/a_mask*math.pi, (u2_mask-a_mask)/(1-a_mask)*math.pi + math.pi)
        P1 = (r*safemath.safe_cos(phi)).unsqueeze(-1)
        P2 = (r*safemath.safe_sin(phi)*torch.where(u2_mask < a_mask, torch.tensor(1.0, device=device), z_mask)).unsqueeze(-1)
        # ic((1-a).min(), a.min(), a.max(), phi.min(), phi.max(), (1-a).max())
        N_stretch = P1*T1_mask + P2*T2_mask + (1 - P1*P1 - P2*P2).clip(min=0).sqrt() * V_stretch_mask
        # H_l = normalize(torch.stack([r_mask1*N_stretch[..., 0], r_mask2*N_stretch[..., 1], N_stretch[..., 2].clip(min=0)], dim=-1))
        H_l = normalize(torch.stack([r_mask1*N_stretch[..., 0], r_mask2*N_stretch[..., 1], N_stretch[..., 2]], dim=-1))

        first = torch.zeros_like(ray_mask)
        first[:, 0] = True
        H_l[first[ray_mask], 0] = 0
        H_l[first[ray_mask], 1] = 0
        H_l[first[ray_mask], 2] = 1

        H = torch.matmul(row_world_basis_mask, H_l.unsqueeze(-1)).squeeze(-1)
        # H = torch.einsum('bni,bij->bnj', H_l, row_world_basis)

        V = viewdir.unsqueeze(1).expand(-1, num_samples, 3)[ray_mask]
        # N = normal.reshape(-1, 1, 3).expand(-1, num_samples, 3)[ray_mask]
        L = (2.0 * (V * H).sum(dim=-1, keepdim=True) * H - V)

        prob = -(math.pi * r_mask1 * r_mask2 * (
            H_l[:, 0]**2 / (r_mask1**2).clip(min=eps) + 
            H_l[:, 1]**2 / (r_mask2**2).clip(min=eps) + 
            H_l[:, 2]**2
            )**2).clip(min=eps).log()

        # prob = 2*torch.log(r_mask1.clip(min=eps)) - 2*torch.log(math.pi*(H_l[:, 2]**2*(r_mask1**2-1)+1).clip(min=eps))
        return L, row_world_basis_mask, prob

    def calculate_mipval(self, H, V, N, ray_mask, roughness, row_world_basis, eps=torch.finfo(torch.float32).eps):
        num_samples = ray_mask.shape[1]
        NdotH = ((H * N).sum(dim=-1)).abs().clip(min=eps, max=1)
        HdotV = (H * V).sum(dim=-1).abs().clip(min=eps, max=1)
        NdotV = (N * V).sum(dim=-1).abs().clip(min=eps, max=1)
        logD = 2*torch.log(roughness.clip(min=eps)) - 2*torch.log(math.pi*(NdotH**2*(roughness**2-1)+1).clip(min=eps))
        # ic(NdotH.shape, NdotH, D, D.mean())
        # px.scatter(x=NdotH[0].detach().cpu().flatten(), y=D[0].detach().cpu().flatten()).show()
        # assert(False)
        # ic(NdotH.mean())
        lpdf = logD + torch.log(HdotV) - torch.log(NdotV)# - torch.log(roughness.clip(min=1e-5))
        # pdf = D * HdotV / NdotV / roughness.reshape(-1, 1)
        # pdf = NdotH / 4 / HdotV
        # pdf = D# / NdotH
        indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).expand(-1, num_samples)[ray_mask]
        mipval = -torch.log(indiv_num_samples.clip(min=1)) - lpdf
        return logD
        return mipval

In [56]:
device = torch.device('cpu')
eps=torch.finfo(torch.float32).eps
N = 50
ray_mask = torch.ones((1, N), device=device, dtype=bool)
roughness = torch.tensor(0.10)
normal = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
viewdir = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
# sampler = CosineLobeSampler(512)
sampler = GGXSampler(512)
L, basis, prob = sampler.sample(None, viewdir, normal, roughness, roughness, ray_mask)
H = normalize(L + normal)
indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).expand(ray_mask.shape)[ray_mask]
ic(prob, indiv_num_samples)
mipval = -prob.clip(min=eps).log() - (indiv_num_samples).clip(min=eps).log()
# mipval = sampler.calculate_mipval(H, viewdir, normal, ray_mask, roughness, torch.eye(3).reshape(1, 3, 3))
ic(mipval)

w = 1024
h = 512
dw = (1-L[:, 2]**2).clip(min=eps).sqrt() * 2 * math.pi / w
dh = torch.ones_like(dw) * math.pi / h
saSample = mipval.reshape(-1)

# saTexel is the ratio to the solid angle subtended by one pixel of the 0th mipmap level
# num_pixels = self.bg_mat.numel() // 3
# saTexel = distortion / num_pixels
miplevel_w = ((saSample - torch.log(dw.clip(min=eps))) / math.log(2))
miplevel_h = ((saSample - torch.log(dh.clip(min=eps))) / math.log(2))

# distortion_w = 1/(1-u[:, 2]**2).clip(min=eps)
dw2 = (1-L[:, 2]**2).clip(min=eps).sqrt()
dh2 = torch.ones_like(dw2)

saTexel_w = dw2 / h / w
saTexel_h = dh2 / h / w
# saTexel is the ratio to the solid angle subtended by one pixel of the 0th mipmap level
# num_pixels = self.bg_mat.numel() // 3
# saTexel = distortion / num_pixels
miplevel_w2 = ((saSample - torch.log(saTexel_w.clip(min=eps))) / math.log(2))/2
miplevel_h2 = ((saSample - torch.log(saTexel_h.clip(min=eps))) / math.log(2))/2

# px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=miplevel)
px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=miplevel_h).show()
px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=miplevel_h2).show()
px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=mipval).show()


ic| prob: tensor([ 3.4604, -0.7333, -1.1486,  1.8972,  2.1551,  3.2609,  1.1138,  2.6503,
                   2.4583,  0.7068,  3.3713,  2.3286,  1.5428, -3.6098,  0.1535,  2.9956,
                   3.0623,  0.4187,  3.4604,  1.4391,  2.2601,  3.3303,  0.9094,  2.5459,
                   2.7515,  1.3274,  3.1651,  1.9843,  1.7028, -2.2611, -0.2368,  2.9183,
                   2.9334, -0.5268, -2.4754,  1.8028,  2.0714,  3.1516,  1.1999,  2.7680,
                   2.6238,  0.8388,  3.2759,  2.2950,  1.4915,  3.4095,  0.3280,  3.1228,
                   3.0198, -0.1280])
    indiv_num_samples: tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
                               50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
                               50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50])
ic| mipval: tensor([-5.1534, 12.0304, 12.0304, -4.5524, -4.6799, -5.0940, -4.0198, -4.8867,
                    -4.8115, -3.5651,