In [4]:
from pathlib import Path
import sys
import os
sys.path.append(str(Path(os.path.abspath('')).parent))

In [35]:
import torch
from models import safemath
import math
import plotly.express as px
from icecream import ic

In [108]:


def normalize(x):
    return x / (torch.linalg.norm(x, dim=-1, keepdim=True)+1e-8)

class GGXSampler(torch.nn.Module):
    def __init__(self, num_samples, min_roughness) -> None:
        super().__init__()
        self.sampler = torch.quasirandom.SobolEngine(dimension=2, scramble=True)
        self.num_samples = num_samples
        angs = self.sampler.draw(num_samples)
        self.register_buffer('angs', angs)
        self.min_roughness = min_roughness
        # plt.scatter(self.angs[:, 0], self.angs[:, 1])
        # plt.show()

    def draw(self, B, num_samples):
        # self.angs = self.sampler.draw(self.num_samples)
        angs = self.angs.reshape(1, self.num_samples, 2)[:, :num_samples, :].expand(B, num_samples, 2)
        # self.sampler = torch.quasirandom.SobolEngine(dimension=2, scramble=True)
        # add random offset
        offset = torch.rand(B, 1, 2, device=angs.device)*0.25
        angs = (angs + offset) % 1.0
        return angs

    def update(self, *args, **kwargs):
        pass

    def sample(self, num_samples, refdirs, viewdir, normal, r1, r2, ray_mask):
        # viewdir: (B, 3)
        # normal: (B, 3)
        # r1, r2: B roughness values for anisotropic roughness
        device = normal.device
        B = normal.shape[0]

        # establish basis for BRDF
        z_up = torch.tensor([0.0, 0.0, 1.0], device=device).reshape(1, 3).expand(B, 3)
        x_up = torch.tensor([1.0, 0.0, 0.0], device=device).reshape(1, 3).expand(B, 3)
        up = torch.where(normal[:, 2:3] < 0.9, z_up, x_up)
        tangent = normalize(torch.linalg.cross(up, normal))
        bitangent = normalize(torch.linalg.cross(normal, tangent))
        # B, 3, 3
        row_world_basis = torch.stack([tangent, bitangent, normal], dim=1).reshape(B, 3, 3)


        # GGXVNDF
        # V_l = torch.matmul(torch.inverse(row_world_basis.permute(0, 2, 1)), viewdir.unsqueeze(-1)).squeeze(-1)
        # ic((normal*viewdir).sum(dim=-1).min(), (normal*viewdir).sum(dim=-1).max())
        # ic(1, V_l.min(dim=0), V_l.max(dim=0))
        V_l = torch.matmul(row_world_basis, viewdir.unsqueeze(-1)).squeeze(-1)
        # ic(2, V_l.min(dim=0), V_l.max(dim=0))
        r1_c = r1.clip(min=self.min_roughness).squeeze(-1)
        r2_c = r2.clip(min=self.min_roughness).squeeze(-1)
        V_stretch = normalize(torch.stack([r1_c*V_l[..., 0], r2_c*V_l[..., 1], V_l[..., 2]], dim=-1)).unsqueeze(1)
        T1 = torch.where(V_stretch[..., 2:3] < 0.999, normalize(torch.linalg.cross(V_stretch, z_up.unsqueeze(1), dim=-1)), x_up.unsqueeze(1))
        T2 = normalize(torch.linalg.cross(T1, V_stretch, dim=-1))
        z = V_stretch[..., 2].reshape(-1, 1)
        a = (1 / (1+z.detach()).clip(min=1e-5)).clip(max=1e4)
        angs = self.draw(B, num_samples).to(device)

        # here is where things get really large
        u1 = angs[..., 0]
        u2 = angs[..., 1]

        # stretch and mask stuff to reduce memory
        a_mask = a.expand(u1.shape)[ray_mask]

        r_mask_u1 = r1_c.reshape(-1, 1).expand(u1.shape)[ray_mask]
        r_mask1 = r_mask_u1.clip(min=self.min_roughness)
        r_mask_u2 = r2_c.reshape(-1, 1).expand(u1.shape)[ray_mask]
        r_mask2 = r_mask_u2.clip(min=self.min_roughness)

        z_mask = z.expand(u1.shape)[ray_mask]
        u1_mask = u1[ray_mask]
        u2_mask = u2[ray_mask]
        T1_mask = T1.expand(-1, num_samples, 3)[ray_mask]
        T2_mask = T2.expand(-1, num_samples, 3)[ray_mask]
        V_stretch_mask = V_stretch.expand(-1, num_samples, 3)[ray_mask]
        row_world_basis_mask = row_world_basis.permute(0, 2, 1).reshape(B, 1, 3, 3).expand(B, num_samples, 3, 3)[ray_mask]

        r = torch.sqrt(u1_mask)
        phi = torch.where(u2_mask < a_mask, u2_mask/a_mask*math.pi, (u2_mask-a_mask)/(1-a_mask)*math.pi + math.pi)
        P1 = (r*safemath.safe_cos(phi)).unsqueeze(-1)
        P2 = (r*safemath.safe_sin(phi)*torch.where(u2_mask < a_mask, torch.tensor(1.0, device=device), z_mask)).unsqueeze(-1)
        # ic((1-a).min(), a.min(), a.max(), phi.min(), phi.max(), (1-a).max())
        N_stretch = P1*T1_mask + P2*T2_mask + (1 - P1*P1 - P2*P2).clip(min=0).sqrt() * V_stretch_mask
        H_l = normalize(torch.stack([r_mask1*N_stretch[..., 0], r_mask2*N_stretch[..., 1], N_stretch[..., 2].clip(min=0)], dim=-1))
        first = torch.zeros_like(ray_mask)
        first[:, 0] = True
        H_l[first[ray_mask], 0] = 0
        H_l[first[ray_mask], 1] = 0
        H_l[first[ray_mask], 2] = 1
        H = torch.matmul(row_world_basis_mask, H_l.unsqueeze(-1)).squeeze(-1)
        # H = torch.einsum('bni,bij->bnj', H_l, row_world_basis)

        V = viewdir.unsqueeze(1).expand(-1, num_samples, 3)[ray_mask]
        # N = normal.reshape(-1, 1, 3).expand(-1, num_samples, 3)[ray_mask]
        L = (2.0 * (V * H).sum(dim=-1, keepdim=True) * H - V)

        # calculate mipval, which will be used to calculate the mip level
        # half is considered to be the microfacet normal
        # viewdir = incident direction

        # H = normalize(L + viewdir.reshape(-1, 1, 3))


        return L, row_world_basis_mask

def ggx_dist(NdotH, roughness):
    # takes the cos of the zenith angle between the micro surface and the macro surface
    # and returns the probability of that micro surface existing
    a2 = roughness**2
    # return a2 / np.pi / ((NdotH**2*(a2-1)+1)**2).clip(min=1e-8)
    return ((a2 / (NdotH.clip(min=0, max=1)**2*(a2-1)+1))**2).clip(min=0, max=1)

def calculate_mipval(H, V, N, ray_mask, roughness):
    num_samples = ray_mask.shape[1]
    NdotH = ((H * N).sum(dim=-1)).clip(min=1e-8, max=1)
    HdotV = (H * V).sum(dim=-1).abs()
    NdotV = (N * V).sum(dim=-1).abs().clip(min=1e-8, max=1)
    # D = ggx_dist(NdotH, roughness.clip(min=1e-3))
    # return ((a2 / (NdotH.clip(min=0, max=1)**2*(a2-1)+1))**2).clip(min=0, max=1)
    logD = 2*torch.log(roughness) - 2*torch.log(NdotH**2*(roughness**2-1)+1)
    # ic(NdotH.shape, NdotH, D, D.mean())
    # px.scatter(x=NdotH[0].detach().cpu().flatten(), y=D[0].detach().cpu().flatten()).show()
    # assert(False)
    # ic(NdotH.mean())
    lpdf = logD# + torch.log(HdotV.clip(min=1e-5)) - torch.log(NdotV)# - torch.log(roughness.clip(min=1e-5))
    # lpdf = torch.log(D.clip(min=1e-5)) # - torch.log(roughness.clip(min=1e-5))
    # pdf = D * HdotV / NdotV / roughness.reshape(-1, 1)
    # pdf = NdotH / 4 / HdotV
    # pdf = D# / NdotH
    indiv_num_samples = ray_mask.sum(dim=1, keepdim=True).clip(min=1).expand(-1, num_samples)[ray_mask]
    ic(logD, lpdf, -torch.log(indiv_num_samples))
    mipval = -torch.log(indiv_num_samples) - lpdf
    # mipval =  - lpdf
    return mipval

In [109]:
device = torch.device('cpu')
N = 50
ray_mask = torch.ones((1, 50), device=device, dtype=bool)
roughness = torch.tensor(0.01)
normal = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
viewdir = torch.tensor([0, 0, 1.0], device=device).reshape(1, 3)
sampler = GGXSampler(N, 0)
L, basis = sampler.sample(N, None, viewdir, normal, roughness, roughness, ray_mask)
H = normalize(L + normal)
mipval = calculate_mipval(H, viewdir, normal, ray_mask, roughness)

h = w = 512
# h = w = 2048
distortion = 1
saTexel = distortion / h / w
num_pixels = h * w * 6
miplevel = ((mipval - math.log(saTexel)) / math.log(2))/2
miplevel = miplevel.clip(0)
res = h / 2**miplevel
ic(mipval, math.log(saTexel*N))
px.scatter_3d(x=L[:, 0], y=L[:, 1], z=L[:, 2], color=2**miplevel)


ic| logD: tensor([9.2100, 2.3606, 5.9802, 8.0938, 7.5470, 6.8858, 8.8993, 8.5794, 8.7824,
                  9.0741, 6.3062, 7.1543, 7.9214, 5.4476, 4.1286, 8.5043, 8.4465, 3.5527,
                  4.7885, 7.7523, 7.3798, 6.6418, 9.1169, 8.8311, 8.7143, 9.0148, 6.9587,
                  7.5985, 8.0533, 5.8606, 9.1722, 8.2018, 8.2721, 9.2076, 5.6174, 7.9902,
                  7.6923, 7.0635, 8.9678, 8.6683, 8.8629, 9.1374, 6.5401, 7.3308, 7.8071,
                  4.9593, 3.0266, 8.4174, 8.5432, 4.3622])
    lpdf: tensor([9.2100, 2.3606, 5.9802, 8.0938, 7.5470, 6.8858, 8.8993, 8.5794, 8.7824,
                  9.0741, 6.3062, 7.1543, 7.9214, 5.4476, 4.1286, 8.5043, 8.4465, 3.5527,
                  4.7885, 7.7523, 7.3798, 6.6418, 9.1169, 8.8311, 8.7143, 9.0148, 6.9587,
                  7.5985, 8.0533, 5.8606, 9.1722, 8.2018, 8.2721, 9.2076, 5.6174, 7.9902,
                  7.6923, 7.0635, 8.9678, 8.6683, 8.8629, 9.1374, 6.5401, 7.3308, 7.8071,
                  4.9593, 3.0266, 8.4174,