In [None]:
import sys

# Notebook convenience: emulate a command-line invocation so that the
# argparse call further down sees these flags instead of the kernel's argv.
sys.argv = [
    'eval.py',
    '--checkpoint_name', 'fleece',
    '--yarn_name', 'fleece',
]

In [1]:
from config import device, variant

import mitsuba as mi
import drjit as dr
# Activate the Mitsuba variant named in config before any variant-specific
# Mitsuba API is used; ScalarTransform4f and the project modules (which
# presumably build Mitsuba plugins/types) are imported only after this call.
mi.set_variant(variant)

import torch

import numpy as np
from mitsuba import ScalarTransform4f as sT

from utils.geometry import create_single_yarn

from config.parameters import get_fiber_parameters
from bsdf.neuralyarn import NeuralYarn
from network.model import Model_M, Model_T
from network.wrapper import MiModelWrapper

import torch

# import matplotlib.pyplot as plt
import numpy as np

import argparse
import os

def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1.

    Raises:
        ValueError: if ``text`` is not an integer literal (argparse reports it).
        argparse.ArgumentTypeError: if the integer is < 1.
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f'expected a positive integer, got {text!r}')
    return value


parser = argparse.ArgumentParser(description="Fitting the RDM")

# Both names are needed to locate the checkpoint directory and the fiber
# parameters; fail fast with a usage message instead of crashing later on None.
parser.add_argument("--checkpoint_name", required=True, help="Checkpoint name to store outputs")
parser.add_argument("--yarn_name", required=True, help="Name of the yarn defined in config/parameters.py")
# Zero/negative values would leave the averaged image undefined (division by
# the batch count, or an empty render), so reject them up front.
parser.add_argument("--batch_size", type=_positive_int, help='Samples per pixel per batch rendering', default=32)
parser.add_argument("--num_batches", type=_positive_int, help='Number of batches for rendering', default=8)

args = parser.parse_args()

# Checkpoint directory for this run; model weights and fitted PDF parameters
# are read from here, and rendered images are written below it.
output_dir = os.path.join('checkpoints/', args.checkpoint_name)
parameters = get_fiber_parameters(args.yarn_name)

model_m = Model_M().to(device)
model_t = Model_T().to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on
# a GPU can still be loaded when `device` is the CPU (and vice versa).
model_m.load_state_dict(
    torch.load(os.path.join(output_dir, 'model_m.pth'), map_location=device, weights_only=True)
)
model_t.load_state_dict(
    torch.load(os.path.join(output_dir, 'model_t.pth'), map_location=device, weights_only=True)
)

# Use the NpzFile as a context manager so its file handle is closed; indexing
# it materialises each array in memory, so they stay valid after the block.
with np.load(os.path.join(output_dir, 'pdf.npz')) as npz:
    kappa_R = npz['kappa_R']
    beta_M = npz['beta_M']
    gamma_M = npz['gamma_M']
    kappa_M = npz['kappa_M']

# Output activations: exp for model_m (strictly positive output) and the
# logistic sigmoid for model_t (output in (0, 1)); presumably applied by the
# wrapper to the raw network outputs.
mlp_m = MiModelWrapper(model_m, activation=dr.exp)
mlp_t = MiModelWrapper(model_t, activation=lambda x: 1.0 / (1.0 + dr.exp(-x)))

neuralyarn = NeuralYarn.create(parameters, mlp_m, mlp_t, kappa_R, beta_M, gamma_M, kappa_M)

def mega_kernel(state):
    """Switch Dr.Jit megakernel mode on or off.

    Sets the ``LoopRecord``, ``VCallRecord`` and ``VCallOptimize`` JIT flags,
    in that order, to ``state``.
    """
    for flag in (
        dr.JitFlag.LoopRecord,
        dr.JitFlag.VCallRecord,
        dr.JitFlag.VCallOptimize,
    ):
        dr.set_flag(flag, state)


# Disable megakernel mode for the renders below.
mega_kernel(False)

def _area_sphere_light(center):
    """Return a sphere of radius 2.5 at ``center`` carrying an area emitter
    with constant RGB radiance 30."""
    return {
        'type': 'sphere',
        'center': center,
        'radius': 2.5,
        'emitter': {
            'type': 'area',
            'radiance': {'type': 'rgb', 'value': 30},
        },
    }


# Two point-symmetric area lights, a camera at z=10 looking at the origin,
# and the yarn curve shaded with the neural BSDF.
scene_dict = {
    'type': 'scene',
    'integrator': {
        'type': 'path',
        'max_depth': -1,        # no path-length limit
        'rr_depth': 9999999,    # Russian roulette effectively never starts
        'hide_emitters': True,
    },
    'light1': _area_sphere_light([0, 10, 10]),
    'light2': _area_sphere_light([0, -10, -10]),
    'sensor': {
        'type': 'perspective',
        'to_world': sT.look_at(origin=[0, 0, 10], target=[0, 0, 0], up=[0, 1, 0]),
        'film': {'type': 'hdrfilm', 'width': 256, 'height': 256},
    },
    'yarn': {
        'type': 'linearcurve',
        'filename': './curves/ply.txt',
        'bsdf': neuralyarn,
    },
}

os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)


def _render_averaged(scene, spp, num_batches, label):
    """Render ``scene`` ``num_batches`` times at ``spp`` and return the mean.

    Batch ``i`` uses ``seed=i``, so two scenes rendered with this helper share
    the same seed sequence. Rendering runs with both PyTorch and Dr.Jit
    gradient tracking disabled.

    Raises:
        ValueError: if ``num_batches`` < 1 (the mean would be undefined).
    """
    if num_batches < 1:
        raise ValueError(f'num_batches must be >= 1, got {num_batches}')
    print(label)
    image = None
    with torch.no_grad(), dr.suspend_grad():
        for i in range(num_batches):
            batch = mi.render(scene, spp=spp, seed=i)
            # Explicit first-batch check instead of try/except, so genuine
            # render or arithmetic errors are not silently swallowed.
            if image is None:
                image = batch
            else:
                image += batch
            print(f'Batch ({i+1}/{num_batches})')
        image /= num_batches
    return image


scene = mi.load_dict(scene_dict)

batch_sample = args.batch_size
samples = args.num_batches

pred_image = _render_averaged(scene, batch_sample, samples, 'Rendering NN')
mi.util.write_bitmap(os.path.join(output_dir, 'images/pred.png'), pred_image)

# Reference scene: identical except that the 'yarn' entry comes from
# create_single_yarn(parameters). The shallow copy shares the other nested
# dicts, which is fine because only the top-level 'yarn' key is replaced.
scene_dict_ref = scene_dict.copy()
scene_dict_ref['yarn'] = create_single_yarn(parameters)
scene_ref = mi.load_dict(scene_dict_ref)

true_image = _render_averaged(scene_ref, batch_sample, samples, 'Rendering Ref')
mi.util.write_bitmap(os.path.join(output_dir, 'images/true.png'), true_image)


Rendering NN
Batch (0/8)
Batch (1/8)
Batch (2/8)
Batch (3/8)
Batch (4/8)
Batch (5/8)
Batch (6/8)
Batch (7/8)
Rendering Ref
Batch (0/8)
Batch (1/8)
Batch (2/8)
Batch (3/8)
Batch (4/8)
Batch (5/8)
Batch (6/8)
Batch (7/8)
