In [None]:
from fit_gtTF import *

import os

# Fit the glTF BRDF model to every MERL material and persist the fitted
# parameters; the analysis cell below reads them back from this file.
MATERIALS_JSON = "../results/materials.json"

# Create the output directory *before* the (expensive) fit so a missing
# ../results directory fails fast instead of after all materials are fitted.
os.makedirs(os.path.dirname(MATERIALS_JSON), exist_ok=True)

result = fit_all_merl_materials("../merl100/brdfs/")

# `result` is written verbatim in text mode, so it must be a `str`
# (presumably serialized JSON — TODO confirm against fit_all_merl_materials).
with open(MATERIALS_JSON, "w", encoding="utf-8") as f:
    f.write(result)

In [None]:
import cv2
import matplotlib.pyplot as plt
import os
import merl
import numpy as np
from bsdf import *
from glTF import *
from fit_gtTF import *


def linear_to_srgb(c):
    """Apply the sRGB transfer curve (linear light -> gamma-encoded), element-wise.

    Values at or below the 0.0031308 knee use the linear segment ``12.92 * c``;
    larger values use ``1.055 * c**(1/2.4) - 0.055``. Inputs are not clipped —
    callers clip the result to [0, 1] before quantizing.

    Args:
        c: scalar or array-like of linear-light values (may be negative).

    Returns:
        ndarray with the same shape as ``c`` holding the encoded values.
    """
    c = np.asarray(c)
    # np.where evaluates BOTH branches for every element, so np.power would see
    # negative inputs and emit "invalid value" warnings (or raise under
    # np.errstate(invalid="raise")). Clamping the power branch's input to the
    # knee is harmless: wherever the clamp changes a value, the linear branch is
    # the one selected.
    encoded = (1.0 + 0.055) * np.power(np.maximum(c, 0.0031308), 1.0 / 2.4) - 0.055
    return np.where(c <= 0.0031308, 12.92 * c, encoded)


# Compare the fitted glTF BRDF against measured MERL data on one
# (theta_h, theta_d) slice per material, tile the slices into three canvases
# (measured, fitted, squared error), report the MSE, save and plot them.

# --- Inputs -----------------------------------------------------------------
# Named `merl_dir` (not `dir`) so the builtin `dir()` is not shadowed.
merl_dir = "../merl100/brdfs/"
materials = get_merl_material_list(merl_dir)
assert materials, f"no MERL materials found in {merl_dir}"

# Fitted glTF parameters written by the fitting cell above.
param_dicts = read_glTF_materials("../results/materials.json")
model = glTF_brdf(True)
brdf_np = model.get_np()


def _sph_to_dir(theta, phi):
    """Return unit vectors for polar angle `theta` / azimuth `phi`, stacked on the last axis."""
    return np.stack(
        (np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)),
        axis=-1,
    )


# --- Evaluation geometry ----------------------------------------------------
# A dense (theta_h, theta_d) grid at phi_h = 0, phi_d = pi/2, converted to
# outgoing/incoming directions around the normal +z.
# NOTE(review): phi_d = pi/2 is assumed to correspond to MERL phi_d bin 90
# used in the loop below — TODO confirm against the merl module.
theta_h, _, theta_d, _ = merl.generate_dense_half_diffs(0)
theta_h, theta_d = np.meshgrid(theta_h, theta_d, indexing='ij')
theta_o, phi_o, theta_i, phi_i = merl.half_diff_to_std_coords(
    theta_h, 0, theta_d, np.pi / 2
)
n = np.array([0.0, 0.0, 1.0])     # surface normal
v = _sph_to_dir(theta_o, phi_o)   # view (outgoing) directions
l = _sph_to_dir(theta_i, phi_i)   # light (incoming) directions

# --- Canvas layout ----------------------------------------------------------
width = 10                                       # tiles per row
height = int(np.ceil(len(materials) / width))    # rows needed for all materials
slice_width, slice_height = 90, 90               # pixels per tile
canvas_shape = (height * slice_height, width * slice_width, 3)

merl_canvas = np.ones(canvas_shape)    # white background for unused tiles
gltf_canvas = np.ones(canvas_shape)
# Unused tiles must carry zero error; a ones-background would show them as
# maximal error and would bias an image-wide average.
sqerr_canvas = np.zeros(canvas_shape)

sq_err_sum = 0.0
n_values = 0
for i, material in enumerate(materials):
    merl_raw = merl.read_merl_brdf(os.path.join(merl_dir, f"{material}.binary"))
    # Take phi_d bin 90, move the leading (presumably colour) axis last, then
    # swap the two angular axes so the slice matches the model output layout.
    merl_slice = np.moveaxis(merl_raw[:, :, :, 90], 0, -1).swapaxes(0, 1)
    assert merl_slice.shape[:2] == (slice_height, slice_width), merl_slice.shape

    base_color, alpha, metallic, ior = param_dicts[material]
    model_output = brdf_np(v, n, l, base_color, alpha, metallic, ior)
    # NOTE(review): assumes brdf_np puts channels first so this swap + reshape
    # lines up with merl_slice element-for-element — TODO confirm.
    model_slice = np.swapaxes(model_output, 0, -1).reshape(merl_slice.shape)

    # Plain (absolute) squared error; a relative variant would divide by
    # (merl_slice**2 + 0.01).
    sq_err = (model_slice - merl_slice) ** 2
    sq_err_sum += sq_err.sum()
    n_values += sq_err.size

    x, y = i % width, i // width
    rows = slice(slice_height * y, slice_height * (y + 1))
    cols = slice(slice_width * x, slice_width * (x + 1))
    merl_canvas[rows, cols] = merl_slice
    gltf_canvas[rows, cols] = model_slice
    sqerr_canvas[rows, cols] = sq_err

# Mean over real samples only (not over unused background tiles).
mse = sq_err_sum / n_values
print(f"MSE: {mse}")


# --- Save -------------------------------------------------------------------
def _save_srgb_png(path, img_linear):
    """Write a linear-RGB canvas to `path` as an 8-bit sRGB PNG.

    Rows are flipped so tile row 0 lands at the bottom (matching
    ``imshow(..., origin='lower')`` below) and channels are reversed because
    OpenCV expects BGR. Raises OSError if OpenCV reports a failed write.
    """
    srgb = np.clip(linear_to_srgb(img_linear[::-1, :, ::-1]), 0, 1)
    # Round rather than truncate so e.g. 0.999 maps to 255, not 254.
    if not cv2.imwrite(path, np.round(255 * srgb).astype(np.uint8)):
        raise OSError(f"cv2.imwrite failed for {path}")


_save_srgb_png("../results/merl_slices.png", merl_canvas)
_save_srgb_png("../results/glTF_slices.png", gltf_canvas)
# NOTE: the saved error image encodes the *squared* error, whereas the plot
# below shows its square root (absolute error).
_save_srgb_png("../results/error_slices.png", sqerr_canvas)

# --- Plot -------------------------------------------------------------------
fig, axs = plt.subplots(1, 3)
panels = (
    ("MERL", merl_canvas),
    ("glTF fit", gltf_canvas),
    ("|error|", np.sqrt(sqerr_canvas)),
)
for ax, (title, canvas) in zip(axs, panels):
    ax.set_xticks([])
    ax.set_xticks([], minor=True)
    ax.set_yticks([])
    ax.set_yticks([], minor=True)
    ax.set_title(title)
    ax.imshow(np.clip(linear_to_srgb(canvas), 0, 1), origin='lower')
plt.show()