In [None]:
import torch
import falcor
import time
import numpy as np
import pyexr as exr
import sys
import os
import dataclasses
import datetime
import glob
import argparse

sys.path.append(os.path.join(os.path.abspath(''), ".."))
import common
import material_utils
from loss import compute_render_loss_L1, compute_render_loss_L2


import common
import os
from falcor import Camera, float3, uint2
import copy
import numpy as np
from typing import List, Tuple

import torch
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

import json
try:
	import tinycudann as tcnn
except ImportError:
	print("This sample requires the tiny-cuda-nn extension for PyTorch.")
	print("You can install it by running:")
	print("============================================================")
	print("tiny-cuda-nn$ cd bindings/torch")
	print("tiny-cuda-nn/bindings/torch$ python setup.py install")
	print("============================================================")
	sys.exit()

In [None]:
import yaml

class ExperimentParams:
    """Bootstrap settings object that only carries the models directory prefix."""

    def __init__(self):
        # Relative directory prefix; scene file paths are concatenated onto it.
        models_dir = "models/"
        self.MODELS_PATH = models_dir

# Instantiate the bootstrap settings and echo the models directory as a quick sanity check.
params = ExperimentParams()
print(params.MODELS_PATH)  

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches

def float3tonumpy(v):
    """Convert an object exposing ``x``, ``y`` and ``z`` attributes into a 3-element NumPy array."""
    components = [v.x, v.y, v.z]
    return np.array(components)

def numpytofloat3(n):
    """Build a Falcor ``float3`` by unpacking the elements of a copy of ``n``."""
    values = np.copy(n)
    return float3(*values)

def numpytouint2(n):
    """Build a Falcor ``uint2`` by unpacking the elements of a copy of ``n``."""
    values = np.copy(n)
    return uint2(*values)


class ExperimentParams:
    """Global experiment configuration, read through the module-level ``gparams`` instance.

    NOTE(review): this re-uses the name of the bootstrap ``ExperimentParams`` class
    defined in an earlier cell and therefore shadows it; the existing ``params``
    object is still an instance of the earlier class.
    """
    # Render resolution in pixels; also sizes the testbed and the per-pixel ML buffers.
    render_width = 256
    render_height = 256

    # Forwarded to the PathTracer pass as its "useMIS" option.
    enable_MIS = False 
    # Written to the path tracer's indirectSky flag while training frames are rendered.
    # Disabling it allows, for example, comparing NIRC against NRC.
    enable_sky_learning = True

    relative_error = False

    # Selects the primary pass: "GBufferRT" when True, "VBufferRT" otherwise.
    USE_GBUFFER = True
    # Default epsilon added to the denominator of the relative L2 losses.
    relative_error_eps = 0.001
    

    # Per-method flags (keys: NIRC, NIRC_EM, SH, VMF, NCV); presumably force loss recomputation - TODO confirm.
    RECOMPUTE_LOSS = {"NIRC": False, "NIRC_EM": False, "SH": False, "VMF": False, "NCV": False}

    GRADIENT_CLIP = 0.001
    # When True, check_tensor() returns immediately. Disable it when debugging NaNs, Infs or a loss stuck at 0.
    SKIP_TENSOR_CHECK = True
    ADAM_BETAS = [0.98, 0.9999]
    ADAM_LR = 0.01
    # Use the same number of training frames for every model so experiments stay comparable.
    # Use a much larger number (4k-5k?) for the final experiments.
    num_training_frames = 150

    # Alternative absolute location previously used: "E:/Models/". Keeping the models next to the script is simpler.
    MODELS_PATH = params.MODELS_PATH

    is_training = False



# Shared configuration instance read by the helpers and cells below.
gparams = ExperimentParams()



class ValidationPoint:
    """A pixel selected for validation, optionally annotated with the surface it hits.

    The point is created from normalized image coordinates; the surface data
    (position, look-at target, normal) is attached later via ``init_surface_data``.
    """
    # Normalized image coordinates supplied by the caller.
    xn: float
    yn: float

    # Integer pixel coordinates derived from the render resolution.
    x: int
    y: int

    # Row-major linear pixel index: y * render_width + x.
    id: int

    # Surface data; only valid once ``is_init`` is True.
    position: np.ndarray
    target: np.ndarray
    normal: np.ndarray

    is_init: bool

    def __init__(self, xn: float, yn: float):
        """Store the normalized coordinates and derive pixel coordinates and linear index.

        Args:
            xn: Horizontal coordinate as a fraction of ``gparams.render_width``.
            yn: Vertical coordinate as a fraction of ``gparams.render_height``.
        """
        self.xn = xn
        self.yn = yn

        self.x = int(gparams.render_width*xn)
        self.y = int(gparams.render_height*yn)

        self.id = self.y*gparams.render_width+self.x

        self.is_init = False


    def init_surface_data(self, p: np.ndarray, t: np.ndarray, n: np.ndarray):
        """Attach surface data to the point and mark it as initialized.

        Args:
            p: Surface position.
            t: Look-at target.
            n: Surface normal.
        """
        self.is_init = True

        self.position = p
        self.target = t
        # The declared attribute is ``normal``; the original code only wrote ``n``,
        # leaving ``normal`` unset. ``n`` is kept as an alias for existing readers.
        self.normal = n
        self.n = n




class SceneConfig:
    """Per-scene experiment settings: scene file, camera pose, training knobs and validation points.

    Constructing an instance also creates ``checkpoints/<model_name>`` on disk.
    """
    model_path: str
    camera_position: np.ndarray
    camera_target: np.ndarray
    tonemapper_exposure: float
    focal_length: float

    validation_points: "List[ValidationPoint]"
    emissive_factor: float = 1.0


    def __init__(self, model_path: str, camera_position: Tuple[float, float, float], camera_target: Tuple[float, float, float], tonemapper_exposure: float = 0, focal_length: float = None, selected_points: List[Tuple[float, float]] = (),
                 up: Tuple[float, float, float] = (0.0, 1.0, 0.0), lr_factor: float = 1.0, epochs: int = 4000, roughness: float = 1.0, var_est_steps: int = 1000):
        """Store the configuration and create the checkpoint directory.

        Args:
            model_path: Scene file path relative to ``gparams.MODELS_PATH``.
            camera_position: Camera position (x, y, z).
            camera_target: Camera look-at point (x, y, z).
            tonemapper_exposure: Stored as-is on the instance.
            focal_length: Optional focal length; the scene's own value is kept when None.
            selected_points: Normalized (x, y) image coordinates of validation points.
            up: Camera up vector.
            lr_factor: Stored as-is on the instance.
            epochs: Stored as-is on the instance.
            roughness: Written to the scene's ``roughnessMultiplier`` by ``prepare_scene``.
            var_est_steps: Stored as-is on the instance.
        """
        self.model_path = gparams.MODELS_PATH + model_path
        self.model_name = os.path.dirname(model_path)
        self.camera_position = np.array(camera_position)
        self.camera_target = np.array(camera_target)
        self.epochs = epochs
        self.var_est_steps = var_est_steps
        self.tonemapper_exposure = tonemapper_exposure
        self.roughness = roughness
        self.lr_factor = lr_factor
        self.checkpoint_dir = os.path.join('checkpoints', self.model_name)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.focal_length = focal_length
        self.validation_points = []
        # Copy so that instances never share (and cannot mutate) the default up vector.
        self.up = list(up)

        for p in selected_points:
            self.add_validation_point(p)

    def camera_position_f3(self) -> "float3":
        """Return the camera position as a Falcor ``float3``."""
        return numpytofloat3(self.camera_position)

    def camera_target_f3(self) -> "float3":
        """Return the camera target as a Falcor ``float3``."""
        return numpytofloat3(self.camera_target)

    def __repr__(self) -> str:
        return f"SceneConfig(model_path={self.model_path}, camera_position={self.camera_position.tolist()}, camera_target={self.camera_target.tolist()}, tonemapper_exposure={self.tonemapper_exposure}, focal_length={self.focal_length})"

    def add_validation_point(self, p: Tuple[float, float]):
        """Append a validation point given normalized (x, y) image coordinates."""
        self.validation_points.append(ValidationPoint(xn=p[0], yn=p[1]))

    def setup_validation_points(self, mlData):
        """Fill every validation point with surface data read from ``mlData``.

        ``mlData["worldpos"]`` and ``mlData["normal"]`` are indexed by each point's
        linear pixel id. The stored position is nudged along the normal and the
        target is placed one normal length further along it.
        """
        positions = mlData["worldpos"]
        normals = mlData["normal"]

        for point in self.validation_points:
            pos = positions[point.id].cpu().numpy()
            normal = normals[point.id].cpu().numpy()

            point.init_surface_data(p = pos+normal*0.00001, t = pos+normal+normal*0.00001, n = normal)

    def prepare_scene(self, scene, debugPointID=None):
        """Configure ``scene`` for rendering.

        Args:
            scene: Scene object whose ``roughnessMultiplier`` and ``camera`` are written.
            debugPointID: When None, the configured perspective camera is used.
                Otherwise the camera is placed at that validation point and switched
                to the hemispherical camera.

        Raises:
            AssertionError: The selected validation point has no surface data yet.
                Raised before the scene is modified.
        """
        vp = None
        if debugPointID is not None:
            vp = self.validation_points[debugPointID]
            # Validate before touching the scene: an uninitialized point has no position/target.
            if not vp.is_init:
                raise AssertionError("validation point has no surface data; call setup_validation_points() first")

        scene.roughnessMultiplier = self.roughness

        if vp is None:
            scene.camera.position = numpytofloat3(self.camera_position)
            scene.camera.target = numpytofloat3(self.camera_target)
            if self.focal_length is not None:
                scene.camera.focalLength = self.focal_length
            scene.camera.useHemisphericalCamera = False
            scene.camera.up = numpytofloat3(self.up)
        else:
            scene.camera.useHemisphericalCamera = True
            scene.camera.position = numpytofloat3(vp.position)
            scene.camera.target = numpytofloat3(vp.target)
            scene.camera.up = numpytofloat3(self.up)


# Registry of available scenes, keyed by name. Only CornellBox is active; the
# commented-out entries are alternative presets kept for reference.
scenes = {
    "CornellBox": SceneConfig("CornellBox/cornell_box.pyscene", camera_position=[0, 0.28, 0.6], camera_target=[0, 0.28, 0], selected_points=[[0.35, 0.5], [0.65, 0.4]], epochs=2000),
    #"CornellBoxEnv": SceneConfig("CornellBoxEnv/cornell_box_env.pyscene", camera_position=[0, 0.28, 0.6], camera_target=[0, 0.28, 0], selected_points=[[0.35, 0.5], [25.0/256.0, 150.0/256.0]]),
    #"EnvDebug": SceneConfig("EnvDebug/env_debug.pyscene", camera_position=[0, 0.28, 0.6], camera_target=[0, 0.28, 0], selected_points=[[0.35, 0.5], [0.65, 0.4]]),
    #"CornellBox": SceneConfig("CornellBox/cornell_box.pyscene", camera_position=[0, 0.28, 1.2], camera_target=[0, 0.28, 0], selected_points=[[0.35, 0.5], [0.65, 0.4]]),
    #"Bistro" : SceneConfig("Bistro/BistroExterior.pyscene", camera_position=[-29.195, 5.145, -8.768], camera_target=[-28.212, 5.137, -8.586], selected_points=[[400.0/1280, 400.0/720], [410/1200.0, 200.0/720], [800/1280.0, 600/720.0]]),
    #"CountryKitchen": SceneConfig("CountryKitchen/Country-Kitchen.gltf", camera_position=[1.456, 1.509, 1.602], camera_target=[0.678, 1.471, 0.974], selected_points=[[0.35, 0.5], [0.65, 0.4]], lr_factor=0.5, epochs=2000, var_est_steps=3000),
    #"SanMiguel": SceneConfig("SanMiguel/san-miguel.pyscene", camera_position=[22.2190, 3.3300, 6.7120], camera_target=[21.5799, 3.3054, 5.9432], selected_points=[[0.35, 0.5], [0.65, 0.4]], focal_length=17.750),
    #"TheWhiteRoomCycles": SceneConfig("TheWhiteRoomCycles/the-white-room_0001.gltf", camera_position=[2.781, 1.247, 5.251], camera_target=[2.151, 1.195, 4.477], selected_points=[[0.35, 0.5], [0.65, 0.4]], lr_factor=0.25),
    #"Sponza": SceneConfig("Sponza/Sponza.pyscene", camera_position=[8.084, 1.708, 0.827], camera_target=[7.091, 1.684, 0.715], selected_points=[[0.5, 0.1], [0.65, 0.2]], focal_length=16.750),
    #"SponzaSpecular": SceneConfig("SponzaSpecular/SponzaSpecular.pyscene", camera_position=[-9.0311, 1.1590, -1.2149], camera_target=[-8.0785, 1.0369, -0.9365], selected_points=[[0.5, 0.8], [0.15, 0.5]], epochs=2000, roughness=0.25) 
}


## Scene Select

In [None]:
# Active scene for every cell below; the key must exist in `scenes`.
scene_cfg = scenes["CornellBox"]

In [None]:
# Resolve the current working directory and make its parent folder importable.
CUR_DIR = os.path.abspath('')
sys.path.append(os.path.join(CUR_DIR, ".."))


output_dir = CUR_DIR + "/results/"


device_id = 0
# Create the Falcor testbed at the configured render resolution.
testbed = common.create_testbed([gparams.render_width, gparams.render_height])
device = testbed.device


# Load the reference scene.
ref_scene = common.load_scene(
    testbed,
    scene_cfg.model_path,
    gparams.render_width / gparams.render_height,
)


# useAnalyticLights and useEnvLight were observed to be False when the renderer is
# started from the Python side, so all light types are enabled explicitly here.
# The root cause has not been found; it looks like a bug. TODO: confirm and report it.
ref_scene.renderSettings = falcor.SceneRenderSettings(useEnvLight=True, useAnalyticLights=True, useEmissiveLights=True, useGridVolumes=True)


# Structured buffers shared with the ML side, one element per pixel.
# "color" = albedo + specular reflectance.
# struct_size is in bytes: 8 float3 fields (12 bytes each) + 2 float fields (4 bytes each),
# matching the entries of field_types in order.
field_types = {"radiance": "float3", "dir": "float3", "thp": "float3", "worldpos": "float3", "normal": "float3", "color": "float3", "dradiance": "float3", "view": "float3", "roughness": "float", "pdf": "float"}
ml_data = device.create_structured_buffer(
    struct_size = 12*8+4*2,
    element_count=gparams.render_width*gparams.render_height,
    bind_flags=falcor.ResourceBindFlags.ShaderResource  
    | falcor.ResourceBindFlags.UnorderedAccess
    | falcor.ResourceBindFlags.Shared
)


# Per-pixel ray buffer: 2 float3 fields (12 bytes each), matching rays_fields.
rays_fields = {"worldpos": "float3", "dir": "float3"}
ml_rays_data = device.create_structured_buffer(
    struct_size = 12*2,
    element_count=gparams.render_width*gparams.render_height,
    bind_flags=falcor.ResourceBindFlags.ShaderResource  
    | falcor.ResourceBindFlags.UnorderedAccess
    | falcor.ResourceBindFlags.Shared
)


# NOTE(review): presumably blocks until Falcor has finished the pending work above - confirm against the Falcor API.
device.render_context.wait_for_falcor()

In [None]:
# Build the render graph: primary (G/V-buffer) pass -> PathTracer -> AccumulatePass -> ToneMapper.
# NOTE(review): this cell mixes snake_case (create_pass/add_edge/mark_output) and
# camelCase (createPass/addEdge/markOutput) binding names; both spellings are kept as-is.
render_graph = testbed.create_render_graph("StandardPathTracer")

# Create the PathTracer pass.
path_tracer_pass = render_graph.create_pass(
    "PathTracer",
    "PathTracer",
    {
        "samplesPerPixel": 1, "useSER": False, "useMIS": gparams.enable_MIS, "disableCaustics": True
    }
)


# The primary pass is GBufferRT or VBufferRT depending on the configuration.
primary_render_pass_name = "GBufferRT" if gparams.USE_GBUFFER else "VBufferRT"
if  gparams.USE_GBUFFER:
    # Create the GBufferRT pass and expose its extra ".vbuffercache" and ".brdf" outputs.
    primary_render_pass = render_graph.create_pass(
        primary_render_pass_name,
        primary_render_pass_name,
        {
            "samplePattern": "Center",
            "sampleCount": 1,
            "useAlphaTest": True
        }
    )
    render_graph.mark_output(primary_render_pass_name+".vbuffercache")
    render_graph.mark_output(primary_render_pass_name+".brdf")
else:
    # Create the VBufferRT pass.
    primary_render_pass = render_graph.create_pass(
        primary_render_pass_name,
        primary_render_pass_name,
        {
            "samplePattern": "Center",
            "sampleCount": 1,
            "useAlphaTest": True
        }
    )

AccumulatePass = render_graph.createPass("AccumulatePass", "AccumulatePass", {'enabled': True, 'precisionMode': 'Single'})
ToneMapper = render_graph.createPass("ToneMapper", "ToneMapper", {'autoExposure': False, 'exposureCompensation': 0.0, 'outputFormat': 'RGBA32Float' }) 



# Add edges to connect the passes.

render_graph.add_edge(primary_render_pass_name+".vbuffer", "PathTracer.vbuffer")


render_graph.add_edge(primary_render_pass_name+".viewW", "PathTracer.viewW")
render_graph.add_edge(primary_render_pass_name+".mvec", "PathTracer.mvec")

render_graph.addEdge("PathTracer.color", "AccumulatePass.input")
render_graph.addEdge("AccumulatePass.output", "ToneMapper.src")


# Mark the outputs that are read back: tone-mapped image, accumulated image and raw path tracer color.
render_graph.markOutput("ToneMapper.dst")
render_graph.mark_output("AccumulatePass.output")
render_graph.mark_output("PathTracer.color")

# Assign the configured render graph to the testbed.
testbed.render_graph = render_graph

# Hand the shared structured buffers to the passes.
path_tracer_pass.mlData = ml_data
primary_render_pass.mlRaysData = ml_rays_data


In [None]:
# Apply the configured camera and roughness multiplier to the reference scene.
scene_cfg.prepare_scene(ref_scene)

In [None]:
import matplotlib.pyplot as plt 


def adjust_gamma(image, gamma=2.2):
    """Gamma-encode ``image`` by raising it to the power ``1 / gamma``."""
    exponent = 1 / gamma
    return image ** exponent


def frameRender(num_samples=1024, vis = True, tonemapped=False,  directEmissive=True, directSky=True):
    """Render ``num_samples`` frames with accumulation enabled and return the RGB image.

    Args:
        num_samples: Number of ``testbed.frame()`` calls.
        vis: When True the image is also shown with ``plt.imshow``.
        tonemapped: Read "ToneMapper.dst" and gamma-encode it instead of reading
            the raw "AccumulatePass.output".
        directEmissive: Value written to the path tracer's ``directEmissive`` flag for this render.
        directSky: Value written to the path tracer's ``directSky`` flag for this render.

    Returns:
        The first three channels of the selected render-graph output as a NumPy array.

    The pass settings changed for the render are restored on every exit path.
    Previously the ``vis=False`` early return skipped the restore, leaving
    accumulation, MIS and the 16-sample primary pass enabled.
    """
    # Changing these may trigger a shader recompile because the defines change.
    AccumulatePass.reset()
    AccumulatePass.enabled = True
    primary_render_pass.sampleCount = 16

    path_tracer_pass.useMIS = True
    path_tracer_pass.mlTraining = False
    path_tracer_pass.directEmissive = directEmissive
    path_tracer_pass.directSky = directSky

    try:
        for _ in range(num_samples):
            testbed.frame()

        if tonemapped:
            img = testbed.render_graph.get_output("ToneMapper.dst").to_numpy()[:, :, :3]
            img = adjust_gamma(img)
        else:
            img = testbed.render_graph.get_output("AccumulatePass.output").to_numpy()[:, :, :3]

        if vis:
            plt.imshow(img)
    finally:
        # Return the passes to the state used outside of frameRender.
        path_tracer_pass.directEmissive = True
        path_tracer_pass.directSky = True
        path_tracer_pass.useMIS = gparams.enable_MIS
        primary_render_pass.sampleCount = 1
        AccumulatePass.enabled = False

    return img

In [None]:
def trainFrame():
    """Render one frame in ML-training mode, then switch the path tracer back.

    During the frame MIS is off, ``mlTraining`` is on and ``indirectSky`` follows
    ``gparams.enable_sky_learning``.
    NOTE(review): afterwards ``useMIS`` is set to True rather than to
    ``gparams.enable_MIS`` - confirm this is intended.
    """
    path_tracer_pass.useMIS = False
    path_tracer_pass.indirectSky = gparams.enable_sky_learning
    path_tracer_pass.mlTraining = True
    testbed.frame()
    path_tracer_pass.useMIS = True
    path_tracer_pass.indirectSky = True
    path_tracer_pass.mlTraining = False
        
def setupBRDFCache():
    """Render one training-mode frame with ``cacheVisibility`` enabled on the primary pass.

    The path tracer flags are toggled exactly as in a training frame, and
    ``cacheVisibility`` is switched off again afterwards.
    NOTE(review): ``brdfRender`` is set to False both before and after the frame;
    the first assignment presumably guards against a leftover True - confirm.
    """
    path_tracer_pass.useMIS = False
    path_tracer_pass.indirectSky = gparams.enable_sky_learning
    path_tracer_pass.mlTraining = True
    primary_render_pass.cacheVisibility = True
    primary_render_pass.brdfRender = False
    testbed.frame()
    primary_render_pass.brdfRender = False
    primary_render_pass.cacheVisibility = False
    path_tracer_pass.useMIS = True
    path_tracer_pass.indirectSky = True
    path_tracer_pass.mlTraining = False


def getBRDF(pixel):
    """Render one frame with ``brdfRender`` enabled for ``pixel`` and return the ".brdf" output.

    Args:
        pixel: Two-element sequence converted to a ``uint2`` and written to the
            primary pass' ``targetPixel``.

    Returns:
        The first three channels of the primary pass' ".brdf" output as a NumPy array.
        NOTE(review): the ".brdf" output is only marked when the GBufferRT pass is used.
    """
    path_tracer_pass.useMIS = False
    path_tracer_pass.indirectSky = gparams.enable_sky_learning
    path_tracer_pass.mlTraining = True
    primary_render_pass.brdfRender = True
    primary_render_pass.targetPixel = numpytouint2(pixel)
    testbed.frame()
    # Restore the non-training state before reading the output back.
    primary_render_pass.brdfRender = False
    path_tracer_pass.useMIS = True
    path_tracer_pass.indirectSky = True
    path_tracer_pass.mlTraining = False
    brdf = testbed.render_graph.get_output(primary_render_pass_name+".brdf").to_numpy()[:, :, :3]
    return brdf



In [None]:
#testbed.end_frame_forced()

In [None]:
# Optional sanity check (safe to remove): render from the scene's configured camera.
scene_cfg.prepare_scene(ref_scene)

im = frameRender(num_samples = 1, vis=True, tonemapped=True)
# this should produce a converged render from the correct camera viewpoint (+- as in the paper teaser)

In [None]:
def falcor_to_torch(buffer: falcor.Buffer, dtype=torch.float32):
    """Copy a Falcor buffer into a freshly allocated flat tensor.

    The destination holds ``buffer.element_count * 12`` zero-initialized values
    of ``dtype``; the function returns after waiting on the CUDA side.
    """
    flat_length = buffer.element_count * 12
    staging = torch.tensor([0] * flat_length, dtype=dtype)
    buffer.copy_to_torch(staging)
    device.render_context.wait_for_cuda()
    return staging

In [None]:
def extract_field(data, start, struct_size, field_size):
    """Gather one field from a flat interleaved tensor.

    Starting at ``start`` and stepping by ``struct_size``, a slice of
    ``field_size`` elements is taken at each position and all slices are
    concatenated.
    """
    pieces = []
    for begin in range(start, data.numel(), struct_size):
        pieces.append(data[begin:begin + field_size])
    return torch.cat(pieces)

def falcor_to_torch_split_interleaved(buffer, field_types, dtype=torch.float32):
    """Read an interleaved structured buffer and split it into one tensor per field.

    Args:
        buffer: Object exposing ``element_count`` and ``to_torch(shape)``.
        field_types: Ordered mapping of field name to "float3" or "float"; the
            order defines the layout of one record.
        dtype: Accepted for interface compatibility; not used by this function.

    Returns:
        Dict mapping each field name to a column slice of the
        ``(-1, record_width)`` view of the buffer data.
    """
    # Number of scalar values per supported field type (extend if needed).
    values_per_type = {"float3": 3, "float": 1}
    widths = {name: values_per_type[kind] for name, kind in field_types.items()}

    # Width of one complete record and of the whole flat buffer.
    struct_size = sum(widths.values())
    total_size = struct_size * buffer.element_count
    all_data = buffer.to_torch([total_size])
    device.render_context.wait_for_cuda()

    if not widths:
        return {}

    # One row per record; each field is a contiguous run of columns.
    records = all_data.view(-1, struct_size)
    tensors = {}
    column = 0
    for name, width in widths.items():
        tensors[name] = records[:, column:column + width]
        column += width
    return tensors


In [None]:
# Read both shared buffers back as per-field tensors.
mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)
mlDataRaysOutput = falcor_to_torch_split_interleaved(ml_rays_data, rays_fields)

In [None]:
# Keep independent copies of the current buffer contents as a reference snapshot.
mlDataOutput_ref = {k: v.clone() for k, v in mlDataOutput.items()}
mlDataRaysOutput_ref = {k: v.clone() for k, v in mlDataRaysOutput.items()}

In [None]:
# Give each validation point the surface data of its pixel; these are the locations
# where the hemispherical incident-light visualization is captured.
#scene_cfg.add_validation_point([900/1280.0, 100/720.0])
scene_cfg.setup_validation_points(mlDataOutput_ref)
# Preview the hemispherical view from validation point 0.
scene_cfg.prepare_scene(ref_scene, debugPointID=0)
im = frameRender(num_samples = 1, vis=True, tonemapped=True)

In [None]:
# Switch back to the configured scene camera and render a preview.
scene_cfg.prepare_scene(ref_scene)

im = frameRender(num_samples = 1, vis=True, tonemapped=True)

In [None]:
# Render once from the scene camera, then run the visibility-caching frame.
scene_cfg.prepare_scene(ref_scene)
_ = frameRender(num_samples=1, tonemapped=True)
setupBRDFCache()

In [None]:

def check_tensor(tensor, tensor_name="Tensor", force_analyze=False, min=None, max=None):
    """
    Check a tensor for NaN/Inf values and optional value bounds, printing diagnostics.

    Args:
    tensor (torch.Tensor): The tensor to check.
    tensor_name (str): The name of the tensor to display in messages.
    force_analyze (bool): Run the check and print statistics plus a histogram even
        when checks are globally skipped or the tensor is fine.
    min: Optional lower bound every element must satisfy (``>=``).
    max: Optional upper bound every element must satisfy (``<=``).

    Returns:
    bool: True when the tensor passed all checks, or when checking is skipped via
    ``gparams.SKIP_TENSOR_CHECK`` and ``force_analyze`` is False.
    """
    if gparams.SKIP_TENSOR_CHECK and not force_analyze:
        return True

    has_nan = bool(torch.isnan(tensor).any().item())
    has_inf = bool(torch.isinf(tensor).any().item())

    good = True
    # Print relevant messages based on the tensor's content
    if has_nan and has_inf:
        print(f"{tensor_name} contains both NaN and Inf values.")
        good = False
    elif has_nan:
        print(f"{tensor_name} contains NaN values.")
        good = False
    elif has_inf:
        print(f"{tensor_name} contains Inf values.")
        good = False

    if min is not None and not torch.all(tensor >= min).item():
        good = False
        print(f"{tensor_name} is smaller than {min}")

    if max is not None and not torch.all(tensor <= max).item():
        good = False
        print(f"{tensor_name} is bigger than {max}")

    if good and not force_analyze:
        return True

    print("Contains NaN values:", has_nan)
    print("Contains Inf values:", has_inf)

    if tensor.numel() == 0:
        # min/max/mean are undefined for an empty tensor.
        print(f"{tensor_name} is empty; no statistics available.")
        return good

    # Basic statistics
    tensor_min = torch.min(tensor).item()
    tensor_max = torch.max(tensor).item()
    tensor_mean = torch.mean(tensor.float()).item()  # Ensure tensor is float for mean calculation
    print("Minimum value:", tensor_min)
    print("Maximum value:", tensor_max)
    print("Average value:", tensor_mean)

    # Histogram of the finite values only: a histogram over data containing
    # NaN/Inf raises, which is exactly the situation this function diagnoses.
    tensor_np = tensor.detach().cpu().numpy().ravel()
    finite_values = tensor_np[np.isfinite(tensor_np)]
    if finite_values.size > 0:
        plt.figure(figsize=(10, 4))
        plt.hist(finite_values, bins=100, color='blue', alpha=0.7)
        plt.title(f"Histogram of {tensor_name} with shape {tensor.shape}")
        plt.xlabel('Value')
        plt.ylabel('Frequency')
        plt.grid(True)
        plt.show()
    else:
        print(f"{tensor_name} has no finite values to plot.")

    return good

In [None]:
# Rows of the reference snapshot whose throughput ("thp") is strictly positive in every channel.
default_mask = (mlDataOutput_ref["thp"] > 0.0).all(dim=1)

def get_mask_for_training(rad, default_mask_use=True):
    """Return a boolean row mask selecting usable samples of ``rad``.

    A row is kept when every entry is >= 0 and no entry is NaN or Inf. When
    ``default_mask_use`` is True the result is additionally AND-ed with the
    module-level ``default_mask``.
    """
    has_nan = torch.isnan(rad).any(dim=1)   # rows containing at least one NaN
    has_inf = torch.isinf(rad).any(dim=1)   # rows containing at least one Inf
    non_negative = (rad >= 0).all(dim=1)
    usable = (~has_nan & non_negative) & ~has_inf
    if not default_mask_use:
        return usable
    return default_mask & usable

In [None]:
# Torch device used by the helpers below when they allocate tensors.
device_t = torch.device("cuda")

In [None]:
def packModelInput(dposition, ddir, dcolor, droughness, dnormal):
    """Assemble the network input by concatenating the encoded features column-wise.

    Args:
        dposition: World positions; rescaled by the scene bounds via (p - min) / (max - min).
        ddir: Directions or None. When given they are normalized to unit length
            and mapped from [-1, 1] to [0, 1].
        dcolor: Color feature, passed through unchanged.
        droughness: Roughness feature, passed through unchanged.
        dnormal: Normals, converted to two normalized spherical angles.

    Returns:
        ``cat(normals, positions, droughness, dcolor[, dirs])`` along dim 1.
    """
    dirs = None
    if ddir is not None:
        # Euclidean length. The previous code used torch.square here, i.e. it divided
        # by |d|^4, which only normalizes vectors that already have unit length.
        length = torch.sqrt(torch.sum(ddir**2, dim=1, keepdim=True)) + 1.0e-10
        dirs = (ddir / length) * 0.5 + 0.5

    normals = cart_to_sph(dnormal)

    # Scene bounding box corners as 3-element tensors.
    scene_min = torch.tensor(float3tonumpy(ref_scene.bounds.min_point), device="cuda:0")
    scene_max = torch.tensor(float3tonumpy(ref_scene.bounds.max_point), device="cuda:0")

    # Broadcast over rows: (N, 3) positions against the 3-element bounds.
    positions = (dposition-scene_min)/(scene_max-scene_min)

    if dirs is not None:
        assert(check_tensor(dirs, "dirs"))

    assert(check_tensor(normals, "normals"))
    assert(check_tensor(positions, "positions"))
    assert(check_tensor(dcolor, "dcolor"))
    assert(check_tensor(droughness, "droughness"))

    features = [normals, positions, droughness, dcolor]
    if dirs is not None:
        features.append(dirs)
    return torch.cat(features, dim=1)
        


def aces_tonemap(image, exposure=1.0):
    """
    ACES-style rational tonemapping curve.
    :param image: Input HDR image (or scalar).
    :param exposure: Multiplier applied to the input before the curve.
    :return: Tonemapped image.
    """
    a, b, c, d, e = 2.51, 0.03, 2.43, 0.59, 0.14

    x = exposure * image
    numerator = x * (a * x + b)
    denominator = x * (c * x + d) + e
    return numerator / denominator


def apply_gamma_correction(image):
    """
    Gamma-encode an image with a fixed gamma of 2.2.
    :param image: Input image to correct.
    :return: ``image`` raised to the power ``1 / 2.2``.
    """
    inverse_gamma = 1.0 / 2.2
    return np.power(image, inverse_gamma)


def cart_to_sph(tensor):
    """Convert rows of (x, y, z) vectors to two normalized spherical angles.

    Args:
        tensor: 2-D tensor whose columns 0, 1 and 2 hold x, y and z.

    Returns:
        Tensor with one row per input row and two columns:
        ``theta = acos(z) / pi`` and ``phi = atan2(y, x) / (2 * pi) + 0.5``.
    """
    # Clamp z before acos: a z component that drifts slightly outside [-1, 1]
    # through rounding would otherwise produce NaN.
    z = torch.clamp(tensor[:, 2], -1.0, 1.0)
    theta = torch.acos(z) / np.pi
    phi = (torch.atan2(tensor[:, 1], tensor[:, 0]) / np.pi) * 0.5 + 0.5
    return torch.stack([theta, phi], dim=1)


def vectorized_uniform_stratified_sampling(width, height, tile_width, tile_height):
    """Pick one random pixel inside every tile of a regular grid.

    The image is split into ``width // tile_width`` by ``height // tile_height``
    tiles (row-major order) and one integer offset is drawn per tile and axis.

    Returns:
        1-D tensor of row-major linear pixel indices (``y * width + x``), one per tile.
    """
    cols = width // tile_width
    rows = height // tile_height
    tile_count = cols * rows

    # Tile coordinates enumerated row by row.
    tile_col = torch.arange(cols, device=device_t).repeat(rows)
    tile_row = torch.arange(rows, device=device_t).repeat_interleave(cols)

    # Random offset inside each tile (x is drawn before y).
    jitter_x = torch.randint(0, tile_width, (tile_count,), device=device_t)
    jitter_y = torch.randint(0, tile_height, (tile_count,), device=device_t)

    pixel_y = tile_row * tile_height + jitter_y
    pixel_x = tile_col * tile_width + jitter_x
    return (pixel_y * width + pixel_x)

In [None]:
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
import math

import os
import torch

import torch.nn as nn

def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total
    
def l2_regularization(model, lambda_l2):
    """Return ``lambda_l2`` times the sum of squared norms of all model parameters."""
    penalty = torch.tensor(0.0).to(device_t)
    for weights in model.parameters():
        penalty += torch.norm(weights)**2
    scaled = lambda_l2 * penalty
    return scaled

# assume one pdf for both prediction and ref, that these samples have been already divided by pdf!! pdf must be 1 for delta surfaces!!
def relativeL2(prediction, ref, pdf=None, eps=None, div=None):
    """Relative L2 loss: ``pdf * (prediction - ref)^2 / (mean(div)^2 + eps)``, averaged.

    Args:
        prediction: Predicted values, one row per sample.
        ref: Reference values, broadcastable against ``prediction``.
        pdf: Optional per-sample weight; None means a weight of 1
            (previously None raised a TypeError).
        eps: Stabilizer added to the denominator; defaults to
            ``gparams.relative_error_eps``.
        div: Values whose per-row mean forms the denominator; defaults to the
            detached prediction.

    Returns:
        Mean over all finite entries of the weighted relative error.
    """
    eps = gparams.relative_error_eps if eps is None else eps
    div = prediction.detach() if div is None else div
    weight = 1.0 if pdf is None else pdf
    denominator = torch.mean(div, dim=1).view(-1,1)**2 + eps

    rL2 = (weight * (prediction - ref)**2) / denominator
    # Drop NaN/Inf entries (e.g. from a zero denominator) before averaging.
    rL2 = rL2[~(torch.isinf(rL2) | torch.isnan(rL2))]
    return rL2.mean()


def relativeL2Special(prediction, ref, pdf=None,  div=None):
    """Mean of the finite entries of ``pdf * (prediction - ref)^2 / div``.

    Both ``pdf`` and ``div`` must be supplied by the caller despite their None defaults.
    """
    residual = prediction - ref
    ratio = (pdf * residual**2) / div
    finite = ~(torch.isinf(ratio) | torch.isnan(ratio))
    return ratio[finite].mean()

def relativeL2PDFCounted(prediction, ref, pdf=None, eps=None, div=None):
    """Relative L2 loss divided by the pdf: ``(prediction - ref)^2 / ((mean(prediction)^2 + eps) * pdf)``.

    Args:
        prediction: Predicted values, one row per sample.
        ref: Reference values, broadcastable against ``prediction``.
        pdf: Optional per-sample pdf; None means 1 (previously None raised a TypeError).
        eps: Stabilizer added to the denominator; defaults to
            ``gparams.relative_error_eps``.
        div: Accepted for signature parity with ``relativeL2`` but ignored; the
            denominator always uses the detached prediction.

    Returns:
        Mean over all finite entries.
    """
    eps = gparams.relative_error_eps if eps is None else eps
    weight = 1.0 if pdf is None else pdf
    denominator = torch.mean(prediction.detach(), dim=1).view(-1,1)**2 + eps

    rL2 = ((prediction - ref)**2) / (denominator * weight)
    # Drop NaN/Inf entries (e.g. from a zero pdf) before averaging.
    rL2 = rL2[~(torch.isinf(rL2) | torch.isnan(rL2))]
    return rL2.mean()

# assume one pdf for both prediction and ref, that these samples have been already divided by pdf!! pdf must be 1 for delta surfaces!!
def relativeL22(prediction, ref, pdf=None, eps=None, div=None):
    """Relative L2 loss divided by the pdf: ``(prediction - ref)^2 / ((mean(prediction)^2 + eps) * pdf)``.

    Args:
        prediction: Predicted values, one row per sample.
        ref: Reference values, broadcastable against ``prediction``.
        pdf: Optional per-sample pdf; None means 1 (previously None raised a TypeError).
        eps: Stabilizer added to the denominator; defaults to
            ``gparams.relative_error_eps``.
        div: Accepted for interface compatibility but ignored; the denominator
            always uses the detached prediction.

    Returns:
        Mean over all finite entries.
    """
    eps = gparams.relative_error_eps if eps is None else eps
    weight = 1.0 if pdf is None else pdf
    denominator = torch.mean(prediction.detach(), dim=1).view(-1,1)**2 + eps

    rL2 = ((prediction - ref)**2) / (denominator * weight)
    # Drop NaN/Inf entries (e.g. from a zero pdf) before averaging.
    rL2 = rL2[~(torch.isinf(rL2) | torch.isnan(rL2))]
    return rL2.mean()


# we transform delta pdfs from 0 to 1.0
def safe_pdf(pdf):
    """Return ``pdf`` with every non-positive entry replaced by 1.0."""
    non_positive = pdf <= 0.000000000
    unit = torch.tensor(1.0)
    return torch.where(non_positive, unit, pdf)

# Assumes one pdf for both prediction and ref, but only ref has been divided by the pdf.
def relativeL2_PDF_NotApplied(prediction, ref, pdf=None, eps=None, div=None):
    """Relative L2 loss where ``ref`` is first multiplied by the pdf.

    Computes ``(prediction - ref * pdf)^2 / ((mean(prediction)^2 + eps) * pdf)``
    and averages the finite entries.

    Args:
        prediction: Predicted values, one row per sample.
        ref: Reference values, broadcastable against ``prediction``.
        pdf: Optional per-sample pdf; None means 1 (previously None raised a TypeError).
        eps: Stabilizer added to the denominator; defaults to
            ``gparams.relative_error_eps``.
        div: Accepted for interface compatibility but ignored; the denominator
            always uses the detached prediction.

    Returns:
        Mean over all finite entries.
    """
    eps = gparams.relative_error_eps if eps is None else eps
    weight = 1.0 if pdf is None else pdf
    denominator = torch.mean(prediction.detach(), dim=1).view(-1,1)**2 + eps

    rL2 = ((prediction - ref * weight)**2) / (denominator * weight)
    # Drop NaN/Inf entries (e.g. from a zero pdf) before averaging.
    rL2 = rL2[~(torch.isinf(rL2) | torch.isnan(rL2))]
    return rL2.mean()


def L2(prediction, ref, pdf=1):
    """Mean of the squared error ``(prediction - ref)^2`` scaled by ``pdf``."""
    squared_error = (prediction - ref)**2
    return (pdf * squared_error).mean()

def aces_tonemap(image, exposure=1.0):
    """
    ACES-style rational tonemapping curve.
    :param image: Input HDR image (or scalar).
    :param exposure: Multiplier applied to the input before the curve.
    :return: Tonemapped image.
    """
    # Curve coefficients.
    a, b = 2.51, 0.03
    c, d, e = 2.43, 0.59, 0.14

    scaled = exposure * image
    top = scaled * (a * scaled + b)
    bottom = scaled * (c * scaled + d) + e
    return top / bottom

def apply_gamma_correction(image):
    """
    Clamp an image to [0, 1], gamma-encode it (gamma 2.2) and quantize to 8 bits.
    :param image: Input image to correct.
    :return: ``uint8`` image with values in [0, 255].
    """
    gamma = 2.2
    clamped = np.clip(image, 0.0, 1.0)
    encoded = np.power(clamped, 1.0 / gamma)
    quantized = (encoded*255).astype(np.uint8)
    return np.clip(quantized, 0, 255)

# model(X).cpu()
def visualize_model_output(Y, width, height, epoch, tonemap_gamma=True):
    """Show a flat RGB tensor as an image titled with the 1-based epoch number.

    Args:
        Y: CPU tensor reshaped to ``(height, width, 3)`` after conversion to float32 NumPy.
        width: Image width in pixels.
        height: Image height in pixels.
        epoch: Zero-based epoch index; the title shows ``epoch + 1``.
        tonemap_gamma: Apply ACES tonemapping followed by gamma correction before display.
    """
    with torch.no_grad():
        pixels = Y.numpy().astype(np.float32).reshape(height, width, 3)

        if tonemap_gamma:
            tonemapped = aces_tonemap(pixels)
            pixels = apply_gamma_correction(tonemapped)

        plt.imshow(pixels)
        plt.title(f'Model Output at Epoch {epoch + 1}')
        plt.show()

# Function to determine if visualization should occur at the current epoch
def should_visualize(epoch, conditions):
    """Decide whether to visualize at ``epoch``.

    ``conditions`` is a sequence of ``(max_epoch, interval)`` pairs; the first
    pair with ``epoch < max_epoch`` supplies the interval, otherwise 200 is used.
    Returns True when ``epoch + 1`` is a multiple of the chosen interval.
    """
    chosen_interval = 200
    for max_epoch, interval in conditions:
        if epoch < max_epoch:
            chosen_interval = interval
            break
    return (epoch + 1) % chosen_interval == 0

class LambdaLayer(torch.nn.Module):
    """Module wrapper that applies an arbitrary callable in ``forward``."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        fn = self.func
        return fn(x)

# Define the MLP class
import torch.nn as nn
import torch.nn.init as init

class MLP(nn.Module):
    """Fully connected network: ``num_layers`` hidden Linear+ReLU blocks and a linear output layer.

    Args:
        input_dim: Size of each input sample.
        output_dim: Size of each output sample.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of hidden Linear+ReLU blocks; 0 yields a single
            ``input_dim -> output_dim`` linear map.
    """
    def __init__(self, input_dim, output_dim, hidden_dim=64, num_layers=6):
        super().__init__()
        layers = []
        in_features = input_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, hidden_dim))
            layers.append(nn.ReLU())
            in_features = hidden_dim
        # Use the running width rather than hidden_dim so that num_layers == 0
        # still produces a layer matching the input size.
        layers.append(nn.Linear(in_features, output_dim))
        self.model = nn.Sequential(*layers)
        # Initialize weights using Xavier initialization
        self.model.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-uniform weights and zero biases for every ``nn.Linear``; other modules are untouched."""
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)

    def forward(self, x):
        return self.model(x)


In [None]:
import torch.nn as nn

class PixelLatentLayer(nn.Module):
    """Learnable latent vector per pixel, stored as a ``(width, height, params_per_pixel)`` parameter."""

    def __init__(self, width, height, params_per_pixel, device, init_method='xavier'):
        """Allocate the latent grid and fill it uniformly in ``[-bound, bound]``.

        ``init_method`` selects the bound: 'xavier' -> sqrt(2 / (64 + params_per_pixel)),
        'he' -> sqrt(2 / 64), anything else (including 'uniform_large') -> 1e-4.
        """
        super().__init__()
        self.width = width
        self.height = height
        self.params_per_pixel = params_per_pixel

        storage = torch.empty(width, height, params_per_pixel, device=device)
        self.latents_per_pixel = nn.Parameter(storage)

        if init_method == 'xavier':
            bound = math.sqrt(2 / (64 + params_per_pixel))
        elif init_method == 'he':
            bound = math.sqrt(2 / (64))
        else:
            # 'uniform_large' and every unrecognized method share this bound.
            bound = 1e-4

        nn.init.uniform_(self.latents_per_pixel, -bound, bound)

    def forward(self, pixel_indices):
        """Look up latents; column 0 of ``pixel_indices`` indexes dim 0, column 1 indexes dim 1."""
        first = pixel_indices[:, 0]
        second = pixel_indices[:, 1]
        return self.latents_per_pixel[first, second]

In [None]:


# Build the reference radiance estimate as a running mean over many training
# frames, or load it from the per-scene checkpoint when one already exists.

# Checkpoint path
checkpoint_path = os.path.join(scene_cfg.checkpoint_dir, f'ref_estimations_{gparams.render_width}x{gparams.render_height}.pth')

# Try to load cached ref_estimations if it exists
found = False
if os.path.exists(checkpoint_path):
    found = True
    ref_estimations = torch.load(checkpoint_path, map_location=device_t)
    print("Loaded cached ref_estimations from checkpoint.")
else:
    print("No cached ref_estimations found. Initializing new tensor.")

#found = False
scene_cfg.prepare_scene(ref_scene)
if not found:
    for epoch in range(4000):
        # Render a new frame each epoch
        trainFrame()

        # Get the newly rendered data
        mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)

        # "+0.0" produces a new tensor, so the in-place masking below does not
        # write into the tensor returned for the buffer.
        y = mlDataOutput["radiance"]+0.0

        # Zero out rows that are negative, NaN or Inf.
        mask = get_mask_for_training(y, False)
        y[~mask] = 0.0

        if epoch == 0:
            ref_estimations = y
        else:
            # Incremental mean over epoch + 1 samples. The previous weight of
            # 1 / epoch overwrote the first sample at epoch 1.
            a = 1.0 / (epoch + 1)
            ref_estimations = y * a + ref_estimations * (1.0 - a)

        # Save the ref_estimations tensor every 100 epochs
        if (epoch + 1) % 100 == 0:
            torch.save(ref_estimations, checkpoint_path)
            print(f"Saved ref_estimations at epoch {epoch + 1}.")

        if (epoch + 1) % 1000 == 0:
            visualize_model_output(ref_estimations.cpu(), gparams.render_width, gparams.render_height, epoch)
else:
    visualize_model_output(ref_estimations.cpu(), gparams.render_width, gparams.render_height, 0)

In [None]:
class VarianceHelper:
    """Per-pixel first/second moment trackers used to report a relative variance.

    Two estimators are kept side by side, each as flat ``(width * height, 3)`` tensors:

    * an incremental (equal-weight) running mean, ``m1`` / ``m2``, with per-pixel
      sample counts in ``epochs``;
    * an exponential moving average, ``m1_ema`` / ``m2_ema``, with per-pixel sample
      counts in ``epochs_ema`` and smoothing factor ``alpha``.

    ``variance_per_epoch`` caches ``(epoch, mean EMA variance)`` tuples appended by
    :meth:`get_variance_ema`.
    """

    def __init__(self, width, height, device, alpha=0.95):
        self.width = width
        self.height = height
        self.device = device
        self.m1 = torch.zeros((width * height, 3), device=device)
        self.m2 = torch.zeros((width * height, 3), device=device)
        self.epochs = torch.zeros((width * height), dtype=torch.int32, device=device)
        self.epochs_ema = torch.zeros((width * height), dtype=torch.int32, device=device)
        self.alpha = alpha  # Smoothing factor for EMA
        self.m1_ema = torch.zeros((width * height, 3), device=device)
        self.m2_ema = torch.zeros((width * height, 3), device=device)
        self.variance_per_epoch = []  # Store (epoch, variance) tuples

    def state_dict(self):
        """Return every piece of estimator state as a plain dict (for checkpointing)."""
        return {
            'm1': self.m1,
            'm2': self.m2,
            'epochs': self.epochs,
            'epochs_ema': self.epochs_ema,
            'm1_ema': self.m1_ema,
            'm2_ema': self.m2_ema,
            'alpha': self.alpha,
            'variance_per_epoch': self.variance_per_epoch
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`, moving tensors to ``self.device``."""
        self.m1 = state_dict['m1'].to(self.device)
        self.m2 = state_dict['m2'].to(self.device)
        self.epochs = state_dict['epochs'].to(self.device)
        self.epochs_ema = state_dict['epochs_ema'].to(self.device)
        self.m1_ema = state_dict['m1_ema'].to(self.device)
        self.m2_ema = state_dict['m2_ema'].to(self.device)
        self.alpha = state_dict['alpha']
        self.variance_per_epoch = state_dict['variance_per_epoch']

    @staticmethod
    def _select_finite(indices, bias):
        """Validate shapes and drop rows of ``bias`` that contain NaN or Infinity.

        Returns ``(valid_indices, valid_bias)``, or ``None`` when no row is finite.
        Raises ``ValueError`` when ``bias`` is not 2-D or its row count differs from ``indices``.
        """
        if bias.dim() != 2 or bias.size(0) != indices.size(0):
            raise ValueError("Bias tensor must be 2-dimensional with shape (N, 3) matching the number of indices.")

        valid_mask = torch.isfinite(bias).all(dim=1)
        if not valid_mask.any():
            return None
        return indices[valid_mask], bias[valid_mask]

    def update(self, indices, bias):
        """
        Update the incremental variance estimators with new bias data.
        Skips updates for entries where bias contains NaN or Infinity.

        Besides ``m1`` / ``m2`` this also blends the sample into ``m1_ema`` / ``m2_ema``
        with the plain ``alpha`` weight; only ``epochs`` (not ``epochs_ema``) is incremented.
        """
        selected = self._select_finite(indices, bias)
        if selected is None:
            # If no valid data, exit early
            return
        valid_indices, valid_bias = selected

        # Weight of the new sample in an equal-weight running mean: 1 / (samples so far + 1).
        weights = 1.0 / (self.epochs[valid_indices].float() + 1)
        weights = weights.unsqueeze(1)  # Make it (N, 1) to match bias shape

        # Update m1 and m2 only for valid masked pixels (incremental estimator)
        self.m1[valid_indices] = valid_bias * weights + (1 - weights) * self.m1[valid_indices]
        self.m2[valid_indices] = (valid_bias ** 2) * weights + (1 - weights) * self.m2[valid_indices]

        # Update EMA m1 and m2 for valid masked pixels (EMA estimator)
        self.m1_ema[valid_indices] = self.alpha * self.m1_ema[valid_indices] + (1 - self.alpha) * valid_bias
        self.m2_ema[valid_indices] = self.alpha * self.m2_ema[valid_indices] + (1 - self.alpha) * (valid_bias ** 2)

        # Increment the epoch counter for valid masked pixels
        self.epochs[valid_indices] += 1

    def update_ema(self, bias, indices=None):
        """
        Update the EMA-based variance estimators with new bias data.
        Skips updates for entries where bias contains NaN or Infinity.

        When ``indices`` is None the rows of ``bias`` are taken to cover every pixel in order.
        """
        if indices is None:
            indices = torch.arange(self.m1_ema.shape[0], device=self.m1_ema.device)

        selected = self._select_finite(indices, bias)
        if selected is None:
            # If no valid data, exit early
            return
        valid_indices, valid_bias = selected

        # Per-row smoothing factor: a pixel seen for the first time uses alpha = 0 so that
        # its EMA starts at the sample itself instead of being blended with the zero init.
        alpha_tensor = torch.ones_like(self.m1_ema[valid_indices]) * self.alpha
        alpha_tensor[self.epochs_ema[valid_indices] == 0] = 0

        # Update EMA m1 and m2 for valid masked pixels (EMA estimator)
        self.m1_ema[valid_indices] = alpha_tensor * self.m1_ema[valid_indices] + (1 - alpha_tensor) * valid_bias
        self.m2_ema[valid_indices] = alpha_tensor * self.m2_ema[valid_indices] + (1 - alpha_tensor) * (valid_bias ** 2)

        # Increment the epoch counter for valid masked pixels
        self.epochs_ema[valid_indices] += 1

    @staticmethod
    def _resolve_reference(estimations, eps):
        """Fill in the notebook-level defaults at call time.

        Default argument values are evaluated once, when the ``def`` executes. Binding
        ``ref_estimations`` there required the global to exist before the class could be
        defined and silently kept a stale tensor if the global was rebound afterwards.
        """
        if estimations is None:
            estimations = ref_estimations
        if eps is None:
            eps = gparams.relative_error_eps
        return estimations, eps

    def get_variance(self, estimations=None, eps=None):
        """Mean relative variance of the incremental estimator over pixels with >= 1 sample.

        ``estimations`` defaults to the global ``ref_estimations`` and ``eps`` to
        ``gparams.relative_error_eps``. Returns a 0-dim tensor (NaN when no pixel has
        been updated yet, because the mean is then taken over an empty selection).
        """
        estimations, eps = self._resolve_reference(estimations, eps)
        # Compute variance using the incremental estimator
        variance = (self.m2 - self.m1 ** 2) / (estimations ** 2 + eps)
        # Return mean variance where epochs >= 1
        return torch.mean(variance[self.epochs >= 1])

    def get_variance_ema(self, estimations=None, eps=None, epoch=None):
        """Mean relative variance of the EMA estimator over pixels with >= 1 EMA sample.

        Same defaults as :meth:`get_variance`. Returns a Python float; when ``epoch`` is
        given, ``(epoch, value)`` is also appended to ``variance_per_epoch``.
        """
        estimations, eps = self._resolve_reference(estimations, eps)
        # Compute variance using the EMA estimator
        variance_ema = (self.m2_ema - self.m1_ema ** 2) / (estimations ** 2 + eps)
        # Calculate mean variance where epochs_ema >= 1
        mean_variance_ema = torch.mean(variance_ema[self.epochs_ema >= 1])

        # Cache the epoch and mean variance for this epoch
        if epoch is not None:
            self.variance_per_epoch.append((epoch, mean_variance_ema.item()))

        return mean_variance_ema.item()

    def get_cached_variance_per_epoch(self):
        """Return the list of ``(epoch, mean EMA variance)`` tuples recorded so far."""
        return self.variance_per_epoch


In [None]:
def updateVarianceBaseline(steps=1000, recompute=False):
    """Estimate the relative variance of the raw renderer output and store it on `scene_cfg`.

    Args:
        steps: number of frames rendered with `trainFrame()` and fed to the estimator.
        recompute: when False and `ref_variance.pth` exists in `scene_cfg.checkpoint_dir`,
            the cached value is loaded and no rendering happens.

    Side effects: sets `scene_cfg.ref_variance` and (when computing) writes the checkpoint file.
    """
    # Define checkpoint path for saving and loading variance data
    checkpoint_path = os.path.join(scene_cfg.checkpoint_dir, 'ref_variance.pth')

    # Check if recompute is set to False and the checkpoint file exists
    if not recompute and os.path.exists(checkpoint_path):
        # map_location keeps the cached tensor on the notebook's device, matching the other
        # checkpoint loads in this notebook, even if it was saved from a different one.
        scene_cfg.ref_variance = torch.load(checkpoint_path, map_location=device_t)
        print(f"Loaded variance from {checkpoint_path}: {scene_cfg.ref_variance}")
        return

    # Else, perform variance computation
    variance_helper = VarianceHelper(gparams.render_width, gparams.render_height, device=device_t)

    for i in range(steps):
        trainFrame()  # Render a new frame

        # Get machine learning data and split it accordingly
        mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)
        radiance = mlDataOutput["radiance"] + 0  # Assuming radiance is a tensor of shape [N, 3]

        # Find valid indices where radiance values are >= 0 and no inf or nan for all RGB channels
        indices = torch.all((radiance >= 0) & torch.isfinite(radiance), dim=1).nonzero(as_tuple=True)[0]

        # Select the valid radiance values
        radiance = radiance[indices]

        # Update the variance helper (the raw radiance plays the role of the "bias" sample)
        variance_helper.update(indices=indices, bias=radiance)
        print(f"Variance Estimation {i}: {variance_helper.get_variance()}")

    # Get the final variance after the loop
    final_variance = variance_helper.get_variance()

    # Assign the computed variance to scene_cfg.ref_variance
    scene_cfg.ref_variance = final_variance

    # Save the final variance to the checkpoint file (torch.save does not create directories)
    os.makedirs(scene_cfg.checkpoint_dir, exist_ok=True)
    torch.save(final_variance, checkpoint_path)
    print(f"Saved computed variance to {checkpoint_path}: {scene_cfg.ref_variance}")

# Estimate the baseline variance over scene_cfg.var_est_steps frames.
# recompute=True bypasses any cached ref_variance.pth and overwrites it with the new value.
updateVarianceBaseline(steps=scene_cfg.var_est_steps, recompute=True)

In [None]:
from zunis.models.flows.sampling import UniformSampler
from zunis.training.weighted_dataset.stateful_trainer import StatefulTrainer


class NCVModel(torch.nn.Module):
    """Neural control variate built from one zunis flow per colour channel.

    Components created in ``__init__``:

    * a feature encoder — either a tiny-cuda-nn encoding of world-space inputs
      (``encoding_mode="world"``) or a learned per-pixel latent table (``"screen"``);
    * an optional ``integral_model`` MLP that predicts a 3-channel integral from the features;
    * ``self.trainers``: a list of ``StatefulTrainer`` objects (3 when ``color`` else 1)
      whose ``flow`` is evaluated in :meth:`forward`;
    * an Adam optimizer with separate learning rates for flow and non-flow parameters;
    * a :class:`VarianceHelper` for variance bookkeeping.

    NOTE: ``self.trainers`` is a plain Python list, so the flows are not part of
    ``self.state_dict()``; :meth:`save_checkpoint` / :meth:`load_checkpoint` store them explicitly.
    """

    def __init__(self, width, height, params_per_pixel, config_path, integral_model=True, color=True, encoding_mode="world", coord_system="spherical", optional_features=True, model="default", count_transf_pdf=True):
        super(NCVModel, self).__init__()

        num_dims = 2

        self.encoding_mode = encoding_mode

        self.world_encoder = None
        self.screen_encoder = None
        self.variance_helper = VarianceHelper(width, height, device=device_t)

        num_feature_dims = 0
        if self.encoding_mode == "world":
            with open(config_path) as config_file:
                config = json.load(config_file)

            # 11 input dims — presumably the width of the vector assembled by
            # world_prepare_input_ (TODO confirm against cart_to_sph's output width).
            encoding = tcnn.Encoding(11, config["encoding"]).to(device_t)
            halftofloat = LambdaLayer(lambda x: x.float())
            self.world_encoder = torch.nn.Sequential(encoding, halftofloat)
            num_feature_dims = encoding.n_output_dims
        elif self.encoding_mode == "screen":
            self.params_per_pixel = params_per_pixel
            self.screen_encoder = PixelLatentLayer(width, height, params_per_pixel, device_t, init_method="default")
            num_feature_dims = params_per_pixel
        else:
            assert False, f"Unknown encoding_mode: {encoding_mode!r}"

        if integral_model == True:
            self.integral_model = MLP(num_feature_dims, 3, num_layers=1).to(device_t)
        else:
            self.integral_model = None

        self.color = color

        # The cylindrical parameterisation has three coordinates (rho, phi, z).
        if coord_system == "cylindrical":
            num_dims = 3

        sampler = UniformSampler(d=num_dims, device=device_t)
        self.optional_features = optional_features
        self.coord_system = coord_system
        # d_opt_feats is 0 when optional_features is False (bool * int).
        # FIXME (carried over from the original author): d_hidden / n_hidden are hard-coded here.
        flow_options = {'cell_params': {'d_opt_feats': num_feature_dims * self.optional_features + 0, "n_bins": 64, "model": model, "d_hidden": 64, "n_hidden": 6}, "masking_options": {"repetitions": 2}}

        num_channels = 3 if self.color else 1
        self.trainers = [StatefulTrainer(d=num_dims, loss="l2", flow_options=flow_options, flow="pwquad", flow_prior=sampler, device=device_t) for _ in range(num_channels)]

        # Collect parameters into separate groups
        flow_params = [param for trainer in self.trainers for param in trainer.flow.parameters()]

        # Collect other parameters
        other_params = []
        if self.integral_model is not None:
            other_params.extend(self.integral_model.parameters())

        if self.world_encoder is not None:
            other_params.extend(self.world_encoder.parameters())

        if self.screen_encoder is not None:
            other_params.extend(self.screen_encoder.parameters())

        # Prepare parameter groups with different learning rates
        param_groups = [
            {'params': flow_params, 'lr': 0.00025 * scene_cfg.lr_factor, 'eps':1e-15}
        ]

        if other_params:
            param_groups.append({'params': other_params, 'lr': 0.001 * scene_cfg.lr_factor, 'eps':1e-15})

        # Create the optimizer with parameter groups
        self.optimizer = torch.optim.Adam(
            param_groups
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=200, gamma=0.5)

        for trainer in self.trainers:
            print(f"NN params {count_parameters(trainer.flow)}")
        # Either sub-model may be absent (integral_model=False / encoding_mode="screen");
        # skip it instead of handing None to count_parameters.
        if self.integral_model is not None:
            print(f"NCV params {count_parameters(self.integral_model)}")
        if self.world_encoder is not None:
            print(f"NCV params {count_parameters(self.world_encoder)}")

        self.checkpoint_path = None
        self.loss = 0
        # Running relative-L2 metric. It is read by save_checkpoint and blended into by the
        # training loop, so it must exist before either runs.
        self.rL2 = 0.0
        self.device = device_t
        self.count_transf_pdf = count_transf_pdf

        if not optional_features:
            print("Warning, the model is training without the features. use it only for 1 pixel training!")

    def set_checkpoint_path(self, checkpoint_dir):
        """Create ``checkpoint_dir`` if needed and choose the checkpoint file name.

        The file is ``ncv_nrc_checkpoint.pth`` when an integral model is present and
        ``ncv_checkpoint.pth`` otherwise.
        """
        os.makedirs(checkpoint_dir, exist_ok=True)

        name = 'ncv_checkpoint.pth'
        if self.integral_model is not None:
            name = 'ncv_nrc_checkpoint.pth'

        self.checkpoint_path = os.path.join(checkpoint_dir, name)

    def save_checkpoint(self, epoch, loss):
        """Write model, optimizer, scheduler, flow and variance state to ``self.checkpoint_path``.

        Does nothing when no checkpoint path has been set.
        """
        if self.checkpoint_path:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                # The flows live in a plain list, so they are saved explicitly, together
                # with the inversion flags of each flow and of its sub-transforms.
                'flow_state_dicts': [trainer.flow.state_dict() for trainer in self.trainers],
                'flow_inverses': [trainer.flow.inverse for trainer in self.trainers],
                'flows_inverses': [
                    [flow.transform.inverse for flow in trainer.flow.flows]
                    for trainer in self.trainers
                ],
                'loss': loss,
                'rL2': self.rL2,
                'variance_helper_state': self.variance_helper.state_dict()
            }

            if self.integral_model is not None:
                checkpoint['integral_model_state_dict'] = self.integral_model.state_dict()
            torch.save(checkpoint, self.checkpoint_path)

    def load_checkpoint(self, skip=False):
        """Restore the state written by :meth:`save_checkpoint`.

        Returns ``(epoch, loss)`` from the checkpoint, or ``(0, None)`` when ``skip`` is
        True, no path is set, or the file does not exist.
        """
        if not skip and self.checkpoint_path and os.path.exists(self.checkpoint_path):
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)

            self.load_state_dict(checkpoint['model_state_dict'])
            # Only restore the integral head when this instance actually has one.
            if 'integral_model_state_dict' in checkpoint and self.integral_model is not None:
                self.integral_model.load_state_dict(checkpoint['integral_model_state_dict'])

            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            flow_state_dicts = checkpoint['flow_state_dicts']
            flows_inverses = checkpoint['flows_inverses']
            # Defaults to True (inverted) when absent; forward() inverts every flow before use.
            flow_inverses = checkpoint.get('flow_inverses', [True] * len(self.trainers))

            print(flows_inverses)

            for i in range(len(self.trainers)):
                # Load flow state dict
                self.trainers[i].flow.load_state_dict(flow_state_dicts[i])
                # Restore inversion state for each flow in the trainer
                for j, inverse_flag in enumerate(flows_inverses[i]):
                    self.trainers[i].flow.flows[j].transform.inverse = inverse_flag
                self.trainers[i].flow.inverse = flow_inverses[i]

            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            self.rL2 = checkpoint.get('rL2', 0.0)  # Use get to provide a default value if 'rL2' isn't in the checkpoint

            # Load VarianceHelper state
            if 'variance_helper_state' in checkpoint:
                self.variance_helper.load_state_dict(checkpoint['variance_helper_state'])

            print(f"Checkpoint loaded: Epoch {epoch}, Loss {loss}")

            return epoch, loss
        else:
            return 0, None

    def toCylindrical(self, dirs):
        """Map direction vectors to (rho, phi, z) coordinates rescaled and clamped to [0, 1].

        Returns ``(cylindrical, jacobian)`` where ``jacobian = rho / (4 * pi)``, clamped
        below at 1e-6.
        """
        rho = torch.sqrt(dirs[:, 0] ** 2 + dirs[:, 1] ** 2)
        phi = torch.atan2(dirs[:, 1], dirs[:, 0])
        z = dirs[:, 2]

        rho_norm = rho
        phi_norm = (phi + torch.pi) / (2 * torch.pi)
        z_norm = (z + 1) / 2

        cylindrical = torch.stack((rho_norm, phi_norm, z_norm), dim=1)
        cylindrical = torch.clamp(cylindrical, 0.0, 1.0)

        jacobian = rho / (4 * torch.pi)
        jacobian = torch.clamp(jacobian, min=1e-6)

        assert(check_tensor(cylindrical, "cylindrical", min=0.0, max=1.0))
        assert(check_tensor(jacobian, "jacobian", min=1e-6))

        return cylindrical, jacobian

    def to2D(self, x):
        """Identity parameterisation: return ``x`` unchanged with a unit jacobian per row."""
        return x, torch.ones((x.shape[0]), device=x.device)

    def toSpherical(self, dirs):
        """Map direction vectors to (theta / pi, (phi + pi) / (2 * pi)), clamped to [0, 1].

        Returns ``(coords, jacobian)`` with ``jacobian = |sin(theta)| / (2 * pi^2)``.
        """
        x, y, z = dirs[:, 0], dirs[:, 1], dirs[:, 2]

        theta = torch.atan2(torch.sqrt(x**2 + y**2), z)
        phi = torch.atan2(y, x)

        normalized_azimuth = (phi + np.pi) / (2 * np.pi)
        normalized_theta = theta / np.pi

        jacobian = torch.abs(torch.sin(theta))/(2 * np.pi * np.pi)

        assert(check_tensor(normalized_azimuth, "normalized_azimuth", min=0.0, max=1.0))
        assert(check_tensor(normalized_theta, "normalized_theta", min=0.0, max=1.0))

        res = torch.stack([normalized_theta, normalized_azimuth], dim=1)
        res = torch.clamp(res, 0.0, 1.0)
        return res, jacobian

    def transform_coordinates(self, x):
        """Dispatch to the parameterisation selected by ``self.coord_system``."""
        if self.coord_system == "cylindrical":
            return self.toCylindrical(x)
        elif self.coord_system == "spherical":
            return self.toSpherical(x)
        elif self.coord_system == "2D":
            return self.to2D(x)
        else:
            assert False, f"Unknown coord_system: {self.coord_system!r}"

    def forward(self, x, jacobian, pixels, world_data, integral_values=None, surface_color=None, detach=False, pdf_b=None, only_integral=False):
        """Evaluate the flows (and the integral head) in chunks of at most 256 * 256 samples.

        Args:
            x: unit-cube coordinates from :meth:`transform_coordinates`.
            jacobian: per-sample jacobian; replaced by ones when ``count_transf_pdf`` is False.
            pixels: per-sample pixel indices for the screen encoder (may be None).
            world_data: per-sample input of the world encoder (may be None).
            integral_values: per-sample integral used when there is no integral model.
            surface_color: accepted for call-site compatibility; not used here.
            detach: detach the flow output and the integral-model output from the graph.
            pdf_b: optional per-sample divisor applied to the prediction.
            only_integral: skip the flows and return ``(None, estimations)``.

        Returns:
            ``(predictions, estimations)``. ``predictions`` has one column per trainer.
            ``estimations`` is None when neither an integral model nor ``integral_values``
            is available.
        """
        assert(check_tensor(x, "positions", min=0.00000, max=1.0))

        if not self.count_transf_pdf:
            jacobian = jacobian * 0.0 + 1.0

        check_tensor(torch.log(jacobian.unsqueeze(-1)), "torch.log(pdf.unsqueeze(-1))")

        batch_size_limit = 256 * 256  # Maximum number of entries per batch
        num_samples = x.shape[0]
        outputs = []
        estimations_list = []

        for start_idx in range(0, num_samples, batch_size_limit):
            end_idx = min(start_idx + batch_size_limit, num_samples)

            # Slice the batch
            x_batch = x[start_idx:end_idx]
            jacobian_batch = jacobian[start_idx:end_idx]
            pixels_batch = pixels[start_idx:end_idx] if pixels is not None else None
            world_data_batch = world_data[start_idx:end_idx] if world_data is not None else None

            # The flow input is the coordinates with log(jacobian) appended as the last column.
            xd_batch = torch.cat((x_batch, torch.log(jacobian_batch.unsqueeze(-1))), -1)

            # Feature prefetching within the batch
            if not self.optional_features:
                features_batch = None
            else:
                features_list = []
                if self.world_encoder is not None:
                    features_world = self.world_encoder(world_data_batch)
                    features_list.append(features_world)
                if self.screen_encoder is not None:
                    features_screen = self.screen_encoder(pixels_batch)
                    features_list.append(features_screen)
                if features_list:
                    features_batch = torch.cat(features_list, dim=-1)
                else:
                    features_batch = None

            # Integral model processing within the batch
            if self.integral_model is None:
                if integral_values is not None:
                    estimation_batch = integral_values[start_idx:end_idx]
                else:
                    estimation_batch = None
            else:
                estimation_batch = torch.relu(self.integral_model(features_batch))
                if detach:
                    estimation_batch = estimation_batch.detach()

            estimations_list.append(estimation_batch)

            if only_integral:
                continue  # Skip processing trainers if only the integral is needed

            batch_predictions = []

            for i, trainer in enumerate(self.trainers):
                if not trainer.flow.inverse:
                    trainer.flow.invert()

                # Flow processing within the batch
                zj = trainer.flow(xd_batch, opt_feats=features_batch)
                if detach:
                    zj = zj.detach()

                # Last column carries the accumulated log-jacobian; the rest is the latent point.
                z = zj[:, :-1]
                logqx = zj[:, -1] + trainer.latent_prior.log_prob(z)
                prediction = torch.exp(logqx).unsqueeze(-1)

                # Scale by the (detached) integral estimate of channel i.
                if estimation_batch is not None:
                    estimation_i = estimation_batch[:, i:i+1] if estimation_batch.ndim > 1 else estimation_batch.unsqueeze(-1)
                    estimation_i = estimation_i.detach()
                    prediction = prediction * estimation_i

                if pdf_b is not None:
                    prediction = prediction / pdf_b[start_idx:end_idx]

                batch_predictions.append(prediction)

            # Concatenate predictions from all trainers for this batch
            prediction_batch = torch.cat(batch_predictions, dim=1)  # Shape: [batch_size, num_trainers]
            outputs.append(prediction_batch)

        # Entries are None when there is neither an integral model nor integral_values;
        # torch.cat cannot concatenate None, so report "no estimate" instead of raising.
        if any(estimation is None for estimation in estimations_list):
            estimations = None
        else:
            estimations = torch.cat(estimations_list, dim=0)

        if only_integral:
            return None, estimations

        # Concatenate outputs from all batches
        predictions = torch.cat(outputs, dim=0)  # Shape: [num_samples, num_trainers]
        return predictions, estimations

    def prepare_input_(self, dirs, pixel_indices):
        """Transform directions with the configured coordinate system and range-check the result."""
        x, jacobian = self.transform_coordinates(dirs)

        assert(check_tensor(x, "x", min=0.0, max=1.0, force_analyze=False))
        assert(check_tensor(jacobian, "jacobian", min=0.0, max=1.0, force_analyze=False))
        return x, jacobian, pixel_indices

    def world_prepare_input_(self, dposition, dcolor, droughness, dnormal, dview, alpha=1.0):
        """Assemble the world-encoder input vector.

        Normals and view directions go through ``cart_to_sph``; positions are rescaled by
        the bounds of the global ``ref_scene``. Every part is multiplied by ``alpha``
        except the colour. All parts are asserted to lie in [0, 1].
        """
        normals = dnormal
        normals = cart_to_sph(normals)

        scene_min = torch.tensor(float3tonumpy(ref_scene.bounds.min_point), device=device_t)
        scene_max = torch.tensor(float3tonumpy(ref_scene.bounds.max_point), device=device_t)
        positions = (dposition - scene_min) / (scene_max - scene_min)

        views = cart_to_sph(dview)

        assert(check_tensor(normals, "normals",min=0.0, max=1.0))
        assert(check_tensor(positions, "positions",min=0.0, max=1.0))
        assert(check_tensor(dcolor, "dcolor",min=0.0, max=1.0))
        assert(check_tensor(droughness, "droughness",min=0.0, max=1.0))
        assert(check_tensor(views, "dview", min=0.0, max=1.0))

        return torch.cat((normals*alpha, positions*alpha, droughness*alpha, views*alpha,dcolor), dim=1)

    def world_prepare_input(self, mlDataOutput, alpha=1.0):
        """Build the world-encoder input from the renderer fields (each copied via ``+ 0.0``)."""
        return self.world_prepare_input_(mlDataOutput["worldpos"] + 0.0, mlDataOutput["color"] + 0.0, mlDataOutput["roughness"] + 0.0, mlDataOutput["normal"] + 0.0, mlDataOutput["view"] + 0.0, alpha=alpha)

    def prepare_input(self, mlDataOutput, pixel_indices, alpha=1.0):
        """Return ``(pdf, x, jacobian, pixels, world_input)`` for one rendered frame.

        The raw ``"pdf"`` field is passed through ``safe_pdf``.
        """
        ddir = mlDataOutput["dir"] + 0
        pdf = mlDataOutput["pdf"] + 0.0
        pdf = safe_pdf(pdf)
        x, jacobian, pixels = self.prepare_input_(ddir, pixel_indices)
        world_input = self.world_prepare_input(mlDataOutput, alpha=alpha)
        return pdf, x, jacobian, pixels, world_input

    def prepare_output(self, mlDataOutput):
        """Return a copy of the frame's ``"radiance"`` field (the regression target)."""
        Radiance = mlDataOutput["radiance"] + 0.0
        return Radiance



def train_ncv(ncv_model, tile_width, tile_height, B, N, start_epoch=0, integral_delay=50, onlyIntegral=False, onlyRender=False, debug_train_id=None, var_bootstrap_steps=0):
    """Train the neural control variate, then measure its variance.

    Epochs ``[start_epoch, N)`` are training epochs (optimizer steps, EMA variance).
    Epochs ``[N, N + scene_cfg.var_est_steps)`` only evaluate: outputs are detached and
    the incremental variance estimator is updated instead.

    Args:
        ncv_model: the NCVModel to train.
        tile_width, tile_height: tile size handed to the stratified pixel sampler.
        B: number of mini-batches drawn per rendered frame (forced to 1 in debug mode).
        N: number of training epochs.
        start_epoch: first epoch; when ``>= N`` the call either re-measures the loss
            (``gparams.RECOMPUTE_LOSS["NCV"]``) or only renders.
        integral_delay: epochs during which only the integral loss is optimised
            (forced to 0 when the model has no integral head).
        onlyIntegral: zero the control-variate loss so that only the integral head trains.
        onlyRender: skip all batches, visualise once and return.
        debug_train_id: when set, train on this single sample index instead of a batch
            (presumably an int flat pixel index — TODO confirm with callers).
        var_bootstrap_steps: frames used to pre-fill the EMA variance when starting at epoch 0.
    """
    scene_cfg.prepare_scene(ref_scene)

    if start_epoch >= N:
        if gparams.RECOMPUTE_LOSS["NCV"]:
            start_epoch = N
            ncv_model.rL2 = 0
        else:
            onlyRender = True

    # Sample counter for the equal-weight rL2 mean accumulated during evaluation epochs.
    re_accum = 1

    if start_epoch == 0:
        with torch.no_grad():
            for i in range(var_bootstrap_steps):
                trainFrame()
                mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)

                # Prepare inputs
                Y = ncv_model.prepare_output(mlDataOutput)
                SurfaceColor = mlDataOutput["color"] + 0.0

                arr = torch.arange(mlDataOutput["dir"].shape[0])
                pixel_indices = torch.stack((arr % gparams.render_width, arr // gparams.render_width), dim=1).to(device_t)
                integral_values = ref_estimations

                PDF, X, Jacobian, Pixels, WorldData = ncv_model.prepare_input(mlDataOutput, pixel_indices, alpha=1.0)

                res, _ = ncv_model(x=X, jacobian=Jacobian, world_data=WorldData, integral_values=integral_values, surface_color=SurfaceColor, pixels=Pixels, pdf_b=PDF, detach=True)
                res = res.detach()
                bias = res-Y
                ncv_model.variance_helper.update_ema(indices=None, bias=bias)
                print(f"Initial Variance Estimation {i}, rVar: {ncv_model.variance_helper.get_variance_ema()}")

    # These adjustments do not depend on the epoch, so they are applied once up front.
    if ncv_model.integral_model is None:
        onlyIntegral = False
        integral_delay = 0

    if debug_train_id is not None:
        B = 1

    width, height = gparams.render_width, gparams.render_height

    # (max_epoch, interval) pairs consumed by should_visualize
    conditions = [
            (10, 5),
            (100, 10),
            (500, 50),
        ]

    loss_integral = 0
    rL2Integral = 0
    for epoch in range(start_epoch, N + scene_cfg.var_est_steps):
        trainFrame()
        mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)

        Y = ncv_model.prepare_output(mlDataOutput)
        Y = torch.clamp(Y, -10000, 10000)
        SurfaceColor = mlDataOutput["color"] + 0.0

        batch_size = 0
        arr = torch.arange(mlDataOutput["dir"].shape[0])
        pixel_indices = torch.stack((arr % gparams.render_width, arr // gparams.render_width), dim=1).to(device_t)

        PDF, X, Jacobian, Pixels, WorldData = ncv_model.prepare_input(mlDataOutput, pixel_indices, alpha=1.0)

        loss_val = 0

        # average_loss is assigned at the end of the first iteration, which this test skips.
        if epoch % 100 == 99 and epoch != start_epoch:
            ncv_model.save_checkpoint(epoch, average_loss)

        # Epochs past N are evaluation-only.
        bTrainNow = epoch < N

        if not onlyRender:
            for batch_index in range(B):
                if bTrainNow:
                    ncv_model.optimizer.zero_grad()

                # Apply tiling to the current batch and drop samples rejected by the training mask
                batch_indices = vectorized_uniform_stratified_sampling(width, height, tile_width, tile_height)
                mask = get_mask_for_training(Y[batch_indices], False)
                batch_indices = batch_indices[mask]

                y = Y[batch_indices]
                world_data = WorldData[batch_indices]
                pixels = Pixels[batch_indices]
                jacobian = Jacobian[batch_indices]
                pdf = PDF[batch_indices]
                x = X[batch_indices]
                surface_color = SurfaceColor[batch_indices]
                integral_values = ref_estimations[batch_indices]
                batch_size = y.shape[0]

                if debug_train_id is not None:
                    # Single-sample debugging: every per-sample tensor must be replaced, or
                    # the model receives one coordinate row alongside batch-sized companions.
                    # (Previously world_data and jacobian kept their batch-sized values.)
                    x = torch.unsqueeze(X[debug_train_id], dim=0)
                    world_data = torch.unsqueeze(WorldData[debug_train_id], dim=0)
                    pixels = torch.unsqueeze(Pixels[debug_train_id], dim=0)
                    jacobian = torch.unsqueeze(Jacobian[debug_train_id], dim=0)
                    pdf = torch.unsqueeze(PDF[debug_train_id], dim=0)
                    y = torch.unsqueeze(Y[debug_train_id], dim=0)
                    surface_color = torch.unsqueeze(SurfaceColor[debug_train_id], dim=0)
                    integral_values = torch.unsqueeze(ref_estimations[debug_train_id], dim=0)
                    # The variance helper requires one index per bias row.
                    batch_indices = torch.as_tensor([debug_train_id], device=batch_indices.device)

                    if torch.any(y < 0.0):
                        continue

                # Non-positive pdfs become 1; positive ones get a small offset before the
                # model divides its prediction by them.
                mask = (pdf > 0.0)
                pdf[~mask] = 1.0
                pdf[mask] = pdf[mask] + 0.0001

                outputs, integral = ncv_model(x=x, jacobian=jacobian, world_data=world_data, integral_values=integral_values, surface_color=surface_color, pixels=pixels, pdf_b=pdf, detach=False)

                # Relative-error normaliser: squared channel-mean of the integral plus epsilon.
                divider = torch.mean(integral.detach(), dim=-1) ** 2 + gparams.relative_error_eps
                divider = divider.unsqueeze(-1)

                if not bTrainNow or onlyIntegral:
                    outputs = outputs.detach()

                loss = relativeL2Special(outputs, y, pdf=pdf, div=divider)

                if onlyIntegral:
                    loss = loss*0

                if epoch >= integral_delay:
                    bias = outputs.detach()-y
                    if not bTrainNow:
                        ncv_model.variance_helper.update(indices=batch_indices, bias=bias)
                    else:
                        ncv_model.variance_helper.update_ema(indices=batch_indices, bias=bias)

                total_loss = loss
                loss_val = loss.item()
                rL2 = loss_val

                if ncv_model.integral_model is not None:
                    loss_integral = relativeL2(integral, y, pdf=pdf*0.0+1.0)
                    rL2Integral = loss_integral.item()
                    if epoch >= integral_delay:
                        total_loss = loss*0.8+loss_integral*0.2
                    else:
                        total_loss = loss_integral

                if bTrainNow:
                    # Exponentially smoothed metric while training ...
                    if epoch == 0 and batch_index == 0:
                        ncv_model.rL2 = rL2
                    ncv_model.rL2 = rL2*0.90+ncv_model.rL2*0.10
                else:
                    # ... and an equal-weight running mean during evaluation.
                    overe = 1.0 / re_accum
                    ncv_model.rL2 = rL2* overe + (1-overe) * ncv_model.rL2
                    re_accum += 1

                if bTrainNow:
                    total_loss.backward()
                    if ncv_model.integral_model is not None:
                        torch.nn.utils.clip_grad_norm_(ncv_model.integral_model.parameters(), 3000)  # Clipping gradients to avoid explosion
                    ncv_model.optimizer.step()

        if epoch == start_epoch:
            average_loss = loss_val
        else:
            smoothing = 0.95
            average_loss = average_loss * smoothing + (1.0 - smoothing) * loss_val

        if bTrainNow and epoch>=integral_delay:
            var = ncv_model.variance_helper.get_variance_ema(epoch=epoch-integral_delay)
        else:
            var = ncv_model.variance_helper.get_variance()

        if (epoch + 1) % 5 == 0:
            print(f'Epoch [{epoch + 1}/{N}], Batch: {batch_size} AveragedLoss: {average_loss}, Loss: {loss_val}, rL2_loss: {ncv_model.rL2} rL2_integral: {rL2Integral} variance: {var}')

        # Same meaning as the unparenthesised "A and B or C": visualise on schedule while
        # training, and always when only rendering.
        if (should_visualize(epoch, conditions) and bTrainNow) or onlyRender:
            if ncv_model.integral_model is not None:
                outputs, integral = ncv_model(x=X, jacobian=Jacobian, world_data=WorldData, integral_values=None, surface_color=SurfaceColor, pixels=Pixels, pdf_b=PDF, only_integral=True, detach=True)
                visualize_model_output(integral.cpu(), gparams.render_width, gparams.render_height, epoch)

        if onlyRender:
            break

# Build the NCV model (world encoding, cylindrical coordinates, with integral head),
# point it at the scene's checkpoint directory and train it from scratch:
# load_checkpoint(skip=True) never reads a checkpoint and returns epoch 0.
ncv_model = NCVModel(
    width=gparams.render_width,
    height=gparams.render_height,
    config_path="data/ncv.json",
    params_per_pixel=74,
    encoding_mode="world",
    integral_model=True,
    coord_system="cylindrical",
    optional_features=True,
    model="oneblob",
    count_transf_pdf=False,
    color=True,
).to(device_t)
ncv_model.set_checkpoint_path(scene_cfg.checkpoint_dir)
start_epoch, _ = ncv_model.load_checkpoint(skip=True)
train_ncv(
    ncv_model,
    tile_width=4,
    tile_height=4,
    B=1,
    N=scene_cfg.epochs,
    start_epoch=start_epoch,
    var_bootstrap_steps=2,
)

# ALL MODELS


## NRC

In [None]:
# class NRC_Model(torch.nn.Module):
#     def __init__(self, config_path, device):
#         super(NRC_Model, self).__init__()
#         with open(config_path) as config_file:
#             config = json.load(config_file)
        
#         encoding = tcnn.Encoding(n_input_dims, config["encoding"])
#         network = tcnn.Network(encoding.n_output_dims, n_output_dims, config["network"])
#         model = torch.nn.Sequential(encoding, network)

#         self.model = tcnn.NetworkWithInputEncoding(
#             n_input_dims=12, 
#             n_output_dims=3, 
#             encoding_config=config["encoding"], 
#             network_config=config["network"]
#         ).to(device)
#         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, betas=[0.9, 0.99], eps=1e-16)
#         self.checkpoint_path = None
#         self.device = device
#         self.rL2 = None

#     def set_checkpoint_path(self, checkpoint_dir):
#         os.makedirs(checkpoint_dir, exist_ok=True)
#         self.checkpoint_path = os.path.join(checkpoint_dir, f'nrc_checkpoint.pth')

#     def save_checkpoint(self, epoch, loss):
#         if self.checkpoint_path:
#             torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': self.model.state_dict(),
#                 'optimizer_state_dict': self.optimizer.state_dict(),
#                 'loss': loss,
#                 'rL2': self.rL2,
#             }, self.checkpoint_path)

#     def load_checkpoint(self, skip=False):
#         if not skip and self.checkpoint_path and os.path.exists(self.checkpoint_path):
#             checkpoint = torch.load(self.checkpoint_path)
#             self.model.load_state_dict(checkpoint['model_state_dict'])
#             self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#             epoch = checkpoint['epoch']
#             loss = checkpoint['loss']
#             self.rL2 = checkpoint['rL2']
            
#             print(f"Checkpoint loaded: Epoch {epoch}, Loss {loss}")
#             return epoch, loss
#         else:
#             return 0, None

#     def forward(self, x, color, d_radiance=0):
#         return self.model(x)*color+d_radiance

#     def prepare_input(self, mlDataOutput, camera_position_tensor):
#         pdf = mlDataOutput["pdf"]
#         view_vectors = camera_position_tensor - mlDataOutput["worldpos"]
#         norm_view_vectors = view_vectors / torch.norm(view_vectors, dim=1, keepdim=True)
#         X = packModelInput(mlDataOutput["worldpos"], norm_view_vectors, mlDataOutput["color"], mlDataOutput["roughness"], mlDataOutput["normal"])
#         return X, pdf


# # Assuming scene_cfg, gparams, ml_data, and other necessary variables are properly defined
# scene_cfg.prepare_scene(ref_scene)

# # Initialize NRC_Model
# nrc_model = NRC_Model(config_path="data/nrc.json", device=device_t)
# nrc_model.set_checkpoint_path(scene_cfg.checkpoint_dir)

# # Load checkpoint if it exists
# start_epoch, _ = nrc_model.load_checkpoint(skip=False)

# # Training parameters
# N = gparams.num_training_frames  # number of epochs
# B = 16  # number of batches
# simulate_tiles = False  # like in the original NRC work
# tile_width, tile_height = 4, 4  # Sampling one point per 2x2 tile
# eps = 0.001
# lambda_l2 = 0.00000001  # Weight for L2 regularization
# L2Reg = False

# # Enter training mode
# nrc_model.train()

# # Convert camera position to PyTorch tensor
# camera_position_tensor = torch.tensor(scene_cfg.camera_position, dtype=torch.float32).to(device_t)

# relative_error = True

# N = 0


# for epoch in range(start_epoch, N):
#     start_render = time.time()
#     # Render a new frame each epoch
#     trainFrame()
#     # Get the newly rendered data
#     mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)

#     thp = mlDataOutput["thp"]
#     render_time = time.time() - start_render

#     # Prepare inputs using the new method
#     X, pdf = nrc_model.prepare_input(mlDataOutput, camera_position_tensor)

#     color = mlDataOutput["color"]
#     y = mlDataOutput["radiance"].clone()  # Clone to manage memory

#     direct_radiance =  mlDataOutput["dradiance"]

#     # y = y*thp

#     width, height = gparams.render_width, gparams.render_height
#     total_samples = X.size()[0]
#     samples_per_batch = total_samples // B

#     for batch_index in range(B):
#         nrc_model.optimizer.zero_grad()

#         # Apply tiling to the current batch
#         batch_indices = vectorized_uniform_stratified_sampling(width, height, tile_width, tile_height)
#         batch_indices = batch_indices[:samples_per_batch]  # Ensure the batch size matches

#         batch_x = X[batch_indices]
#         batch_y = y[batch_indices]
#         batch_pdf = pdf[batch_indices]

#         mask = get_mask_for_training(batch_y, False)
#         batch_y = batch_y[mask]
#         batch_x = batch_x[mask]
#         batch_pdf = batch_pdf[mask]
#         batch_color = color[batch_indices][mask]
#         # Forward pass
#         outputs = nrc_model(batch_x, batch_color)
#         denominator = torch.norm(outputs.detach(), dim=1).view(-1,1) + gparams.relative_error_eps if gparams.relative_error else 1.0
#         loss = ((outputs - batch_y)**2 / denominator).mean()
#         # Add L2 regularization if enabled
#         total_loss = loss


#         denominator = torch.norm(ref_estimations[batch_indices][mask], dim=1).view(-1,1)+gparams.relative_error_eps
#         nrc_model.rL2 = ((outputs.detach() - batch_y)**2 / denominator).mean()


#         # Backward and optimize
#         total_loss.backward()
#         torch.nn.utils.clip_grad_norm_(nrc_model.model.parameters(), 1000)  # Clipping gradients to avoid explosion
#         nrc_model.optimizer.step()

#     # Reporting
#     if epoch == start_epoch:
#         average_loss = loss.item()
#     else:
#         alpha = 0.95
#         average_loss = average_loss * alpha + (1.0 - alpha) * loss.item()

#     if (epoch + 1) % 5 == 0:
#         print(f'Epoch [{epoch + 1}/{N}], AveragedLoss: {average_loss}, Loss: {loss.item()}, rL2_loss: {nrc_model.rL2}')

#     if (epoch + 1) % 500 == 0:
#         nrc_model.eval()
#         visualize_model_output(nrc_model(X, color).cpu(), gparams.render_width, gparams.render_height, epoch)
#         nrc_model.train()


#     if (epoch + 1) % 100 == 0:
#         nrc_model.save_checkpoint(epoch, average_loss)


## NIRC

In [None]:

class NIRC_Model(torch.nn.Module):
    """Encoded-input MLP whose ReLU-clamped output is multiplied by throughput.

    The network is ``tcnn.Encoding`` -> cast to float32 -> ``MLP`` with three
    outputs. The instance also owns its Adam optimizer, a StepLR scheduler, a
    ``VarianceHelper`` and the checkpoint save/load logic.

    Relies on notebook globals: ``scene_cfg`` (learning-rate factor) and
    ``ref_scene`` (scene bounds used to normalise positions).
    """

    def __init__(self, width, height, config_path, device, init_method='xavier'):
        """Build the network, optimizer, scheduler and variance helper.

        Args:
            width, height: forwarded to ``VarianceHelper``.
            config_path: JSON file; its ``"encoding"`` entry configures the
                tiny-cuda-nn encoding.
            device: device the encoding, MLP and variance helper are placed on.
            init_method: accepted for call-site compatibility; it is not
                referenced anywhere in this class.
        """
        super(NIRC_Model, self).__init__()

        with open(config_path) as config_file:
            config = json.load(config_file)

        # 12 input features; must match the width of the tensor returned by
        # prepare_input_.
        encoding = tcnn.Encoding(12, config["encoding"]).to(device)
        halftofloat = LambdaLayer(lambda x: x.float())
        mlp = MLP(encoding.n_output_dims, 3).to(device)
        print(f"NN params {count_parameters(mlp)}")
        self.model = torch.nn.Sequential(encoding, halftofloat, mlp)
        self.variance_helper = VarianceHelper(width, height, device=device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001*scene_cfg.lr_factor, betas=[0.9, 0.99], eps=1e-15)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=200, gamma=0.5)
        print(f"NN params {count_parameters(self.model)}")
        self.checkpoint_path = None
        self.device = device
        self.rL2 = 0

    def set_checkpoint_path(self, checkpoint_dir):
        """Create ``checkpoint_dir`` if needed and remember the checkpoint file path."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.checkpoint_path = os.path.join(checkpoint_dir, f'nirc_checkpoint.pth')

    def save_checkpoint(self, epoch, loss):
        """Write model/optimizer/variance state to the checkpoint file.

        Nothing is written when no path is set or when ``loss`` is NaN,
        infinite or not strictly positive, so a diverged run does not
        overwrite a usable checkpoint.
        """
        if self.checkpoint_path and not math.isnan(loss) and not math.isinf(loss) and loss > 0:
            print(f"model saved on epoch {epoch} and rL2: {self.rL2}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'variance_helper_state': self.variance_helper.state_dict(),
                'loss': loss,
                'rL2': self.rL2,
            }, self.checkpoint_path)

    def load_checkpoint(self, skip=False):
        """Restore state from the checkpoint file, if allowed and present.

        Args:
            skip: when True the checkpoint is ignored even if it exists.

        Returns:
            ``(epoch, loss)`` from the checkpoint, or ``(0, None)`` when
            nothing was loaded. The scheduler state is not part of the
            checkpoint.
        """
        if not skip and self.checkpoint_path and os.path.exists(self.checkpoint_path):
            checkpoint = torch.load(self.checkpoint_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            self.rL2 = checkpoint['rL2']

            # Older checkpoints may not contain the variance helper state.
            if 'variance_helper_state' in checkpoint:
                self.variance_helper.load_state_dict(checkpoint['variance_helper_state'])

            # The value printed after "Loss" is the stored rL2.
            print(f"Checkpoint loaded: Epoch {epoch}, Loss {self.rL2 }")
            return epoch, loss
        else:
            return 0, None

    def forward(self, x, thp):
        """Return ``relu(model(x)) * thp``."""
        return torch.relu(self.model(x))*thp

    def prepare_input_(self, dposition, ddir, dcolor, droughness, dnormal):
        """Assemble the network input from per-sample attributes.

        Directions are normalised (with a 1e-9 epsilon on the length) and
        remapped with ``*0.5+0.5``; normals go through ``cart_to_sph``;
        positions are normalised by the bounds of the global ``ref_scene``.

        Returns:
            ``torch.cat((normals, positions, droughness, dcolor, dirs), dim=1)``.

        Raises:
            AssertionError: if ``check_tensor`` rejects any component.
        """
        norm = (torch.sqrt(torch.sum(ddir**2, dim=1))+0.000000001)
        norm = torch.unsqueeze(norm, dim=1)
        ddir = ddir/norm
        dirs = ddir*0.5+0.5

        normals = dnormal
        normals = cart_to_sph(normals)

        # Scene bounds are created on the same device as the positions instead
        # of a hard-coded "cuda:0", so the subtraction below works wherever the
        # input data lives.
        # NOTE(review): dtype follows the NumPy array from float3tonumpy --
        # confirm the resulting promotion of `positions` is intended.
        bounds_device = dposition.device
        scene_min = torch.tensor(float3tonumpy(ref_scene.bounds.min_point), device=bounds_device)
        scene_max = torch.tensor(float3tonumpy(ref_scene.bounds.max_point), device=bounds_device)

        positions = (dposition-scene_min)/(scene_max-scene_min)

        assert(check_tensor(dirs, "dirs"))
        assert(check_tensor(normals, "normals"))
        assert(check_tensor(positions, "positions"))
        assert(check_tensor(dcolor, "dcolor"))
        assert(check_tensor(droughness, "droughness"))

        return torch.cat((normals, positions, droughness, dcolor, dirs), dim=1)

    def prepare_input(self, mlDataOutput):
        """Build the network input from the ``worldpos``, ``dir``, ``color``,
        ``roughness`` and ``normal`` entries of ``mlDataOutput``.

        ``+0.0`` produces new tensors so the source entries are not aliased.
        """
        return self.prepare_input_(mlDataOutput["worldpos"]+0.0, mlDataOutput["dir"]+0.0, mlDataOutput["color"]+0.0, mlDataOutput["roughness"]+0.0, mlDataOutput["normal"]+0.0)
        

# Set-up for the NIRC experiment. Relies on notebook globals defined in
# earlier cells: scene_cfg, gparams, ref_scene, device_t.
scene_cfg.prepare_scene(ref_scene)

# Initialize NIRC_Model (NIRC_Model.__init__ does not use init_method)
nirc_model = NIRC_Model(init_method='default', config_path="data/nirc.json", width=gparams.render_width, height=gparams.render_height, device=device_t)
nirc_model.set_checkpoint_path(scene_cfg.checkpoint_dir)

# skip=True makes load_checkpoint return (0, None) without reading any file,
# so this cell always trains from epoch 0.
start_epoch, _ = nirc_model.load_checkpoint(skip=True)

# Training parameters
N = scene_cfg.epochs  # number of training epochs
simulate_tiles = False  # like in the original NRC work
tile_width, tile_height = 8, 8  # tile size handed to vectorized_uniform_stratified_sampling
B = 4  # number of batches (optimizer steps) per epoch
eps = 0.005
lambda_l2 = 0.00000001  # Weight for L2 regularization
L2Reg = False



# Enter training mode
nirc_model.train()
# onlyRender: the model already reached N epochs and the loss should not be
# recomputed, so the loop below only visualises one frame and stops.
onlyRender = False
if start_epoch >= N:
    if gparams.RECOMPUTE_LOSS["NIRC"]:
        start_epoch = N
        nirc_model.rL2 = 0
    else:
        onlyRender = True


res_average = None  # running blend of rendered outputs (only used by a disabled branch of the loop)
re_accum = 1  # sample counter for the running mean of rL2 in the evaluation epochs


# Fresh run only: before any training, render 100 frames and feed the bias of
# the untrained model into the variance helper's EMA.
if start_epoch == 0:
    nirc_model.eval()
    for i in range(100):
        trainFrame()
        mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)
        # Prepare inputs
        y = 0.0+mlDataOutput["radiance"]  # 0.0 + ... yields a new tensor
        y = torch.clamp(y, -10000, 10000)
        thp = mlDataOutput["thp"]+0.000
        X = nirc_model.prepare_input(mlDataOutput)
        res = nirc_model(X, thp).detach()        
        bias = res-y
        # indices=None: the bias covers every sample of the frame, not a batch subset
        nirc_model.variance_helper.update_ema(indices=None, bias=bias)
        print(f"Initial Variance Estimation {i}, rVar: {nirc_model.variance_helper.get_variance_ema()}")
        
    nirc_model.train()

# Main NIRC loop. Epochs below N train the model; the following
# scene_cfg.var_est_steps epochs only evaluate it (no optimizer step) and
# accumulate loss / variance statistics.
for epoch in range(start_epoch, N + scene_cfg.var_est_steps):
    start_render = time.time()
    # Render a new frame each epoch
    trainFrame()
    # Get the newly rendered data
    mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)

    render_time = time.time() - start_render

    # Prepare inputs
    y = 0.0+mlDataOutput["radiance"]  # 0.0 + ... yields a new tensor
    thp = mlDataOutput["thp"]+0.000
    pdf = mlDataOutput["pdf"]+0.00
    X = nirc_model.prepare_input(mlDataOutput)
    width, height = gparams.render_width, gparams.render_height
    total_samples = X.size()[0]
    samples_per_batch = total_samples // B
    loss_val = 0

    # loss_epoch < 0 while training; >= 0 during the evaluation-only epochs
    loss_epoch = epoch-N
    bTrainNow = loss_epoch < 0


    if not onlyRender:
        for batch_index in range(B):
            if bTrainNow:
                nirc_model.optimizer.zero_grad()

            # Step 1: Obtain initial batch indices
            batch_indices = vectorized_uniform_stratified_sampling(width, height, tile_width, tile_height)

            # Step 2: Fetch batch_y to compute the mask
            batch_y = y[batch_indices]

            # Step 3: Compute the mask based on batch_y
            mask = get_mask_for_training(batch_y, False)

            # Step 4: Update the indices based on the mask
            batch_indices = batch_indices[mask]

            # Step 5: Fetch all the data for the batch using updated indices
            batch_x = X[batch_indices]
            batch_y = y[batch_indices]
            batch_thp = thp[batch_indices]
            batch_pdf = pdf[batch_indices]


            # Forward pass
            outputs = nirc_model(batch_x, batch_thp)

            if not bTrainNow:
                outputs = outputs.detach()
                loss_val = 0  # overwritten below by loss.item()

            # For paper correct use
            # loss = relativeL2(outputs, batch_y, pdf=batch_pdf)
            
            # For comparison with Control Variates (SH, VMF, NCV), as they demand higher stability
            loss = relativeL2(outputs, batch_y, pdf=batch_pdf, div=ref_estimations[batch_indices])

            if True:
                # Bias statistics: plain accumulation while evaluating, EMA while training
                batch_y = torch.clamp(batch_y, -10000, 10000)
                bias = outputs.detach()-batch_y
                if not bTrainNow:
                    nirc_model.variance_helper.update(indices=batch_indices, bias=bias)
                else:
                    nirc_model.variance_helper.update_ema(bias, indices=batch_indices)


            total_loss = loss 

            loss_val = loss.item()
            rL2 = loss.item()
            if bTrainNow:
                # Seed on the very first batch, then blend with weight 0.90 on the newest value
                if epoch == 0 and batch_index == 0:
                    nirc_model.rL2 = rL2
                nirc_model.rL2 = rL2*0.90+nirc_model.rL2*0.10
            else:
                # Running arithmetic mean over the evaluation batches
                overe = 1.0 / re_accum
                nirc_model.rL2 = rL2* overe + (1-overe) * nirc_model.rL2
                re_accum += 1

                    # Backward and optimize
            if bTrainNow:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(nirc_model.parameters(), 1000)  # Clipping gradients to avoid explosion
                nirc_model.optimizer.step()
            # Unreachable: this loop only runs when onlyRender is False
            if onlyRender:
                break

        # Stepped once per epoch, including the evaluation-only epochs
        nirc_model.scheduler.step()

    if bTrainNow:
        var = nirc_model.variance_helper.get_variance_ema(epoch=epoch)
    else:
        var = nirc_model.variance_helper.get_variance()
    # Reporting
    if epoch == start_epoch:
        average_loss = loss_val
    else:
        alpha = 0.95
        average_loss = average_loss * alpha + (1.0 - alpha) * loss_val

    if (epoch + 1) % 5 == 0:
        print(f'Epoch [{epoch + 1}/{N}], AveragedLoss: {average_loss}, Loss: {loss_val}, rL2_loss: {nirc_model.rL2} variance: {var}')

    # if (loss_val == 0 or math.isinf(loss_val) or math.isnan(loss_val)) and bTrainNow:
    #     start_epoch, _ = nirc_model.load_checkpoint(skip=False)
    #     epoch = start_epoch
    #     print("RESTART!")
    #     continue

    # (max_epoch, interval) pairs passed to should_visualize
    conditions = [
        (10, 5),
        (50, 10),
        (300, 50),
    ]


    # Disabled by the trailing `and False`: blended visualisation during evaluation epochs
    if not bTrainNow and False:
        nirc_model.eval()
        res = nirc_model(X, thp).detach()      
        if res_average is None:
            res_average = res
        else:
            res_average = res_average*0.99+res*0.01
            
        if epoch % 10 == 0:
            visualize_model_output(res_average.cpu(), gparams.render_width, gparams.render_height, epoch)
        nirc_model.train()

    # Parsed as (should_visualize(...) and bTrainNow) or onlyRender
    if should_visualize(epoch, conditions) and bTrainNow or onlyRender:
        nirc_model.eval()
        res = nirc_model(X, thp).detach()        

        visualize_model_output(res.cpu(), gparams.render_width, gparams.render_height, epoch)
        nirc_model.train()

    # Checkpoint on every 100th epoch (epochs 99, 199, ...)
    if (epoch) % 100 == 99:
        nirc_model.save_checkpoint(epoch, average_loss)

    if onlyRender:
        break

## NIRC Equal Memory

In [None]:
# Define the SH Layer
class SHLayer(nn.Module):
    """Evaluates a fixed set of 25 polynomial basis features of a direction.

    Output columns belonging to bands at or above ``num_bands`` are left at
    zero; the output width is always 25.
    """

    def __init__(self, num_bands=5):
        super(SHLayer, self).__init__()
        self.num_bands = num_bands

    def forward(self, dirs):
        """Return a ``(dirs.size(0), 25)`` feature tensor on ``dirs.device``.

        Only the first three columns of ``dirs`` are read; they are used as
        given (no normalisation happens here).
        """
        x = dirs[:, 0]
        y = dirs[:, 1]
        z = dirs[:, 2]

        # Shared monomials
        xy, xz, yz = x * y, x * z, y * z
        x2, y2, z2 = x * x, y * y, z * z
        x4, y4, z4 = x2 * x2, y2 * y2, z2 * z2

        sh_features = torch.zeros(dirs.size(0), 25, device=dirs.device)

        # Band 0 is a constant
        sh_features[:, 0] = 0.28209479177387814

        # Each entry: (index of the band's first output column, column values in order)
        enabled_bands = []

        if self.num_bands > 1:
            enabled_bands.append((1, [
                -0.48860251190291987 * y,
                0.48860251190291987 * z,
                -0.48860251190291987 * x,
            ]))

        if self.num_bands > 2:
            enabled_bands.append((4, [
                1.0925484305920792 * xy,
                -1.0925484305920792 * yz,
                0.94617469575755997 * z2 - 0.31539156525251999,
                -1.0925484305920792 * xz,
                0.54627421529603959 * (x2 - y2),
            ]))

        if self.num_bands > 3:
            enabled_bands.append((9, [
                0.59004358992664352 * y * (-3 * x2 + y2),
                2.8906114426405538 * xy * z,
                0.45704579946446572 * y * (1 - 5 * z2),
                0.3731763325901154 * z * (5 * z2 - 3),
                0.45704579946446572 * x * (1 - 5 * z2),
                1.4453057213202769 * z * (x2 - y2),
                0.59004358992664352 * x * (-x2 + 3 * y2),
            ]))

        if self.num_bands > 4:
            enabled_bands.append((16, [
                2.5033429417967046 * xy * (x2 - y2),
                1.7701307697799304 * yz * (-3 * x2 + y2),
                0.94617469575756008 * xy * (7 * z2 - 1),
                0.66904654355728921 * yz * (3 - 7 * z2),
                -3.1735664074561294 * z2 + 3.7024941420321507 * z4 + 0.31735664074561293,
                0.66904654355728921 * xz * (3 - 7 * z2),
                0.47308734787878004 * (x2 - y2) * (7 * z2 - 1),
                1.7701307697799304 * xz * (-x2 + 3 * y2),
                -3.7550144126950569 * x2 * y2 + 0.62583573544917614 * x4 + 0.62583573544917614 * y4,
            ]))

        # Write the columns in increasing index order
        for first_column, columns in enabled_bands:
            for offset, values in enumerate(columns):
                sh_features[:, first_column + offset] = values

        return sh_features

# Define the combined model
class CombinedModel(nn.Module):
    """MLP fed with per-pixel latent features concatenated with SHLayer features."""

    def __init__(self, width, height, params_per_pixel, num_bands, device):
        """Create the latent layer, the SH layer and the 3-output MLP on ``device``."""
        super(CombinedModel, self).__init__()
        self.params_per_pixel = params_per_pixel
        self.pixel_latent_layer = PixelLatentLayer(width, height, params_per_pixel, device)
        self.sh_layer = SHLayer(num_bands).to(device)
        # MLP input width: params_per_pixel latent features plus 25 SH features
        self.mlp = MLP(params_per_pixel + 25, 3).to(device)
        print(f"NN params {count_parameters(self.mlp)}")

    def forward(self, directions, pixel_indices):
        """Run the MLP on ``[latent(pixel_indices) | sh(directions)]`` joined along dim 1."""
        sh_part = self.sh_layer(directions)
        latent_part = self.pixel_latent_layer(pixel_indices)
        return self.mlp(torch.cat([latent_part, sh_part], dim=1))


class NIRC_EQMem_Model(torch.nn.Module):
    """Wrapper around ``CombinedModel`` with optimizer, scheduler, variance
    helper and checkpointing.

    The forward pass returns the ReLU-clamped network output multiplied by
    the throughput.
    """

    def __init__(self, width, height, config_path, device, params_per_pixel=2, init_method='xavier'):
        """Build the model and its training state.

        Args:
            width, height: forwarded to ``CombinedModel`` and ``VarianceHelper``.
            config_path: accepted for call-site compatibility; not referenced
                in this class.
            device: device for the sub-modules and the variance helper.
            params_per_pixel: latent features per pixel; also part of the
                checkpoint file name.
            init_method: accepted for call-site compatibility; not referenced
                in this class.
        """
        super(NIRC_EQMem_Model, self).__init__()
        self.model = CombinedModel(width, height,params_per_pixel, 5, device)
        #self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001*scene_cfg.lr_factor, betas=[0.9, 0.99], eps=1e-15)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001, betas=[0.9, 0.99], eps=1e-15)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=300, gamma=0.25)
        self.params_per_pixel = params_per_pixel
        self.checkpoint_path = None
        self.device = device
        self.rL2 = 0
        self.variance_helper = VarianceHelper(width, height, device=device)

    def set_checkpoint_path(self, checkpoint_dir):
        """Create ``checkpoint_dir`` if needed and remember the checkpoint file path."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.checkpoint_path = os.path.join(checkpoint_dir, f'nirc_eqmem_{self.params_per_pixel}_checkpoint.pth')

    def save_checkpoint(self, epoch, loss):
        """Write model/optimizer/variance state to the checkpoint file, if a path is set."""
        if self.checkpoint_path:
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'variance_helper_state': self.variance_helper.state_dict(),
                'loss': loss,
                'rL2': self.rL2,
            }, self.checkpoint_path)

    def load_checkpoint(self, skip=False):
        """Restore state from the checkpoint file, if allowed and present.

        Returns:
            ``(epoch, loss)`` from the checkpoint, or ``(0, None)`` when
            ``skip`` is True, no path is set, or the file does not exist.
        """
        if not skip and self.checkpoint_path and os.path.exists(self.checkpoint_path):
            checkpoint = torch.load(self.checkpoint_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            self.rL2 = checkpoint['rL2']
            # Older checkpoints may not contain the variance helper state.
            if 'variance_helper_state' in checkpoint:
                self.variance_helper.load_state_dict(checkpoint['variance_helper_state'])

            # The value printed after "Loss" is the stored rL2.
            print(f"Checkpoint loaded: Epoch {epoch}, Loss {self.rL2 }")
            return epoch, loss
        else:
            return 0, None

    def forward(self, x, pixel_indices, thp):
        """Return ``relu(model(x, pixel_indices)) * thp``."""
        return torch.relu(self.model(x, pixel_indices))*thp

    def prepare_input_(self, dirs, pixel_indices, mask=None):
        """Normalise directions and optionally filter both inputs by ``mask``.

        Args:
            dirs: direction vectors, one per row.
            pixel_indices: per-row pixel coordinates matching ``dirs``.
            mask: optional index/boolean mask applied along dim 0 to both
                the directions and the pixel indices.

        Returns:
            ``(normalised_dirs, pixel_indices)``, filtered when ``mask`` is given.
        """
        # No epsilon is added to the length, so a zero-length row yields NaN/Inf.
        norm = (torch.sqrt(torch.sum(dirs**2, dim=1))).unsqueeze(1)
        ddir = dirs / norm
        if mask is not None:
            # Filter the directions themselves; the previous code indexed an
            # undefined local and raised as soon as a mask was supplied.
            ddir = ddir[mask]
            pixel_indices = pixel_indices[mask]

        return ddir, pixel_indices

    def prepare_input(self, mlDataOutput, pixel_indices, mask=None):
        """Run ``prepare_input_`` on a fresh copy of ``mlDataOutput["dir"]``."""
        ddir = mlDataOutput["dir"]+0
        return self.prepare_input_(ddir, pixel_indices, mask)


# Set-up for the equal-memory NIRC experiment. Relies on notebook globals
# defined in earlier cells: scene_cfg, gparams, ref_scene, device_t.
scene_cfg.prepare_scene(ref_scene)

# Initialize NIRC_EQMem_Model (its __init__ does not use init_method or config_path)
nirc_eqmem_model = NIRC_EQMem_Model(init_method='uniform_large', config_path="data/nirc_equal_mem_dummy.json", width=gparams.render_width, height=gparams.render_height, params_per_pixel=74, device=device_t)
nirc_eqmem_model.set_checkpoint_path(scene_cfg.checkpoint_dir)

# Load checkpoint if it exists; returns epoch 0 when there is none
start_epoch, _ = nirc_eqmem_model.load_checkpoint(skip=False)

# Training parameters
N = scene_cfg.epochs  # number of training epochs
B = 4  # number of batches (optimizer steps) per epoch
simulate_tiles = False  # like in the original NRC work
tile_width, tile_height = 8, 8  # tile size handed to vectorized_uniform_stratified_sampling
eps = 0.005
lambda_l2 = 0.00000001  # Weight for L2 regularization
L2Reg = False

# Enter training mode
nirc_eqmem_model.train()

re_accum = 1  # sample counter for the running mean of rL2 in the evaluation epochs

# onlyRender: the checkpoint already reached N epochs and the loss should not
# be recomputed, so the loop below only visualises one frame and stops.
onlyRender = False
if start_epoch >= N:
    if gparams.RECOMPUTE_LOSS["NIRC_EM"]:
        start_epoch = N
        nirc_eqmem_model.rL2 = 0
    else:
        onlyRender = True


# Fresh run only: before any training, render 50 frames and feed the bias of
# the untrained model into the variance helper's EMA.
if start_epoch == 0:
    nirc_eqmem_model.eval()
    for i in range(50):
        trainFrame()
        mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)
        # Prepare inputs
        y = 0.0+mlDataOutput["radiance"]  # 0.0 + ... yields a new tensor
        y = torch.clamp(y, -10000, 10000)
        thp = mlDataOutput["thp"]+0.000
        # Flat sample index -> (column, row) using the render width
        arr = torch.arange(thp.shape[0])
        pixel_indices = torch.stack((arr % gparams.render_width, arr // gparams.render_width), dim=1).to(device_t)
        X, pixel_indices= nirc_eqmem_model.prepare_input(mlDataOutput, pixel_indices)
        res = nirc_eqmem_model(X,pixel_indices,  thp).detach()        
        bias = res-y
        # indices=None: the bias covers every sample of the frame, not a batch subset
        nirc_eqmem_model.variance_helper.update_ema(indices=None, bias=bias)
        print(f"Initial Variance Estimation {i}, rVar: {nirc_eqmem_model.variance_helper.get_variance_ema()}")
        
    nirc_eqmem_model.train()

# Main equal-memory NIRC loop. Epochs below N train the model; the following
# scene_cfg.var_est_steps epochs only evaluate it (no optimizer step) and
# accumulate loss / variance statistics. With onlyRender set, a single frame
# is rendered and visualised and the loop stops.
for epoch in range(start_epoch, N + scene_cfg.var_est_steps):
    start_render = time.time()
    # Render a new frame each epoch
    trainFrame()
    # Get the newly rendered data
    mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)
    render_time = time.time() - start_render

    # Flat sample index -> (column, row) using the render width
    arr = torch.arange(mlDataOutput["dir"].shape[0])
    pixel_indices = torch.stack((arr % gparams.render_width, arr // gparams.render_width), dim=1).to(device_t)

    # Prepare inputs (the additions yield new tensors)
    y = mlDataOutput["radiance"]+0.0
    thp = mlDataOutput["thp"]+0.0
    pdf = safe_pdf(mlDataOutput["pdf"]+0.0)
    # y = y*thp
    X, pixel_indices = nirc_eqmem_model.prepare_input(mlDataOutput, pixel_indices)

    width, height = gparams.render_width, gparams.render_height
    total_samples = X.size()[0]
    samples_per_batch = total_samples // B
    # Defined up front so the reporting below never depends on a value left
    # over from another cell when no batch runs (the onlyRender path).
    loss_val = 0

    # loss_epoch < 0 while training; >= 0 during the evaluation-only epochs
    loss_epoch = epoch-N
    bTrainNow = loss_epoch < 0
    for batch_index in range(B):
        if onlyRender:
            continue

        if bTrainNow:
            nirc_eqmem_model.optimizer.zero_grad()

        # Step 1: Obtain initial batch indices
        batch_indices = vectorized_uniform_stratified_sampling(width, height, tile_width, tile_height)

        # Step 2: Fetch batch_y to compute the mask
        batch_y = y[batch_indices]

        # Step 3: Compute the mask based on batch_y
        mask = get_mask_for_training(batch_y, False)

        # Step 4: Update the indices based on the mask
        batch_indices = batch_indices[mask]

        # Step 5: Fetch all the data for the batch using updated indices
        batch_x = X[batch_indices]
        batch_y = y[batch_indices]
        batch_thp = thp[batch_indices]
        batch_pdf = pdf[batch_indices]
        batch_pixel_indices = pixel_indices[batch_indices]

        # Forward pass
        outputs = nirc_eqmem_model(batch_x, batch_pixel_indices, batch_thp)

        if not bTrainNow:
            outputs = outputs.detach()
            loss_val = 0  # overwritten below by loss.item()

        loss = relativeL2(outputs, batch_y, pdf=batch_pdf, div=ref_estimations[batch_indices])

        if True:
            # Bias statistics: plain accumulation while evaluating, EMA while training
            batch_y = torch.clamp(batch_y, -10000, 10000)
            bias = outputs.detach()-batch_y
            if not bTrainNow:
                nirc_eqmem_model.variance_helper.update(indices=batch_indices, bias=bias)
            else:
                nirc_eqmem_model.variance_helper.update_ema(bias, indices=batch_indices)

        loss_val = loss.item()
        rL2 = loss.item()
        if bTrainNow:
            # Seed on the very first batch, then blend with weight 0.90 on the newest value
            if epoch == 0 and batch_index == 0:
                nirc_eqmem_model.rL2 = rL2
            nirc_eqmem_model.rL2 = rL2*0.90+nirc_eqmem_model.rL2*0.10
        else:
            # Running arithmetic mean over the evaluation batches
            overe = 1.0 / re_accum
            nirc_eqmem_model.rL2 = rL2* overe + (1-overe) * nirc_eqmem_model.rL2
            re_accum += 1

        # Backward and optimize
        if bTrainNow:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(nirc_eqmem_model.model.parameters(), 3000)  # Clipping gradients to avoid explosion
            nirc_eqmem_model.optimizer.step()
    # Stepped once per epoch, including evaluation-only and onlyRender epochs
    nirc_eqmem_model.scheduler.step()
    # Reporting
    if epoch == start_epoch:
        average_loss = loss_val
    else:
        alpha = 0.95
        average_loss = average_loss * alpha + (1.0 - alpha) * loss_val

    if bTrainNow:
        var = nirc_eqmem_model.variance_helper.get_variance_ema(epoch=epoch)
    else:
        var = nirc_eqmem_model.variance_helper.get_variance()

    if (epoch + 1) % 5 == 0 or onlyRender:
        print(f'Epoch [{epoch + 1}/{N}], AveragedLoss: {average_loss}, Loss: {loss_val}, rL2_loss: {nirc_eqmem_model.rL2} variance: {var}')

    # Degenerate loss while training: reload the last checkpoint and skip the
    # rest of this epoch.
    # NOTE(review): assigning to `epoch` does not rewind a `for ... in range`
    # loop, so training resumes at the next epoch number with the reloaded
    # weights -- confirm whether a true rewind was intended.
    if (loss_val == 0 or math.isinf(loss_val) or math.isnan(loss_val)) and bTrainNow:
        start_epoch, _ = nirc_eqmem_model.load_checkpoint(skip=False)
        epoch = start_epoch
        print("RESTART!")
        continue

    # (max_epoch, interval) pairs passed to should_visualize
    conditions = [
        (10, 10),
        (25, 20),
        (500, 40),
    ]
    # Visualise on the schedule while training, on the first evaluation epoch,
    # and always in onlyRender mode
    if (should_visualize(epoch, conditions) and bTrainNow) or loss_epoch == 0 or onlyRender:
        #nirc_eqmem_model.eval()
        visualize_model_output((nirc_eqmem_model(X, pixel_indices, thp).detach()).cpu(), gparams.render_width, gparams.render_height, epoch)
        #nirc_eqmem_model.train()

    # Checkpoint on every 100th epoch (epochs 99, 199, ...)
    if (epoch) % 100 == 99:
        nirc_eqmem_model.save_checkpoint(epoch, average_loss)

    if onlyRender:
        break


## SH

In [None]:
class SHModel(torch.nn.Module):
    """Per-pixel real spherical-harmonics (SH) radiance model.

    One learnable set of ``num_bands ** 2`` SH coefficients is stored for every
    pixel and every RGB channel. ``forward`` evaluates the SH expansion of each
    pixel for the direction supplied for that pixel.

    Args:
        num_bands: Number of SH bands. ``eval_sh`` implements bands 0-4, so any
            band beyond the fifth contributes nothing.
        width: Image width in pixels.
        height: Image height in pixels.
    """

    def __init__(self, num_bands, width, height):
        super(SHModel, self).__init__()
        self.num_bands = num_bands
        self.num_coeffs = self.num_bands**2  # Total number of SH coefficients for n bands
        # Shape (3, width * height, num_coeffs): one coefficient set per RGB channel and pixel.
        self.sh_coefficients = torch.nn.Parameter(torch.zeros(3, width * height, self.num_coeffs))
        self.rL2 = 0
        # Defined up front so save/load_checkpoint degrade to a no-op instead of
        # raising AttributeError when set_checkpoint_path() was never called.
        self.checkpoint_path = None

    def forward(self, dirs, debugID=None, mask=None):
        """Evaluate RGB radiance for a batch of directions.

        Args:
            dirs: Tensor of shape (batch, 3); rows are normalized here before use.
            debugID: Optional pixel index. When given, that single pixel's
                coefficients are used for every direction in the batch.
            mask: Optional index/boolean mask applied to the per-pixel coefficient
                rows so that they line up with an already-masked ``dirs``.

        Returns:
            Tensor of shape (batch, 3) with the SH expansion per colour channel.
        """
        batch_size = dirs.size(0)
        norm = (torch.sqrt(torch.sum(dirs**2, dim=1))).unsqueeze(1)
        dirs = dirs / norm

        radiance = torch.zeros(batch_size, 3, device=dirs.device)

        for i in range(3):  # Process each color channel
            if debugID is not None:
                # Use only the coefficients corresponding to the debugID
                coefficients = self.sh_coefficients[i, debugID, :].unsqueeze(0).repeat(batch_size, 1)
            else:
                coefficients = self.sh_coefficients[i]

            if mask is not None:
                coefficients = coefficients[mask]
            radiance[:, i] = self.eval_sh(dirs, coefficients)

        return radiance

    def set_checkpoint_path(self, checkpoint_dir):
        """Create ``checkpoint_dir`` if needed and derive the checkpoint file name from the coefficient count."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.checkpoint_path = os.path.join(checkpoint_dir, f'sh_{self.num_coeffs * 3}_checkpoint.pth')

    def save_checkpoint(self, epoch, loss):
        """Write epoch, band count, module state, loss and rL2 to ``checkpoint_path`` (no-op when no path is set)."""
        if self.checkpoint_path:
            torch.save({
                'epoch': epoch,
                'num_bands': self.num_bands,
                'sh_coefficients': self.state_dict(),
                'loss': loss,
                'rL2': self.rL2,
            }, self.checkpoint_path)

    def load_checkpoint(self, skip=False):
        """Restore the model from ``checkpoint_path`` if it exists.

        Args:
            skip: When True the checkpoint is ignored.

        Returns:
            ``(epoch, loss)`` as stored in the checkpoint, or ``(0, None)`` when
            nothing was loaded.
        """
        if not skip and self.checkpoint_path and os.path.exists(self.checkpoint_path):
            checkpoint = torch.load(self.checkpoint_path)

            self.num_bands = checkpoint['num_bands']
            self.num_coeffs = self.num_bands**2
            self.load_state_dict(checkpoint['sh_coefficients'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            self.rL2 = checkpoint['rL2']

            print(f"Checkpoint loaded: Epoch {epoch}, Loss {self.rL2}")
            return epoch, loss
        else:
            return 0, None

    def eval_sh(self, dirs, coefficients):
        """Evaluate the SH expansion for one colour channel.

        Args:
            dirs: Tensor of shape (batch, 3) holding the directions.
            coefficients: Tensor of shape (batch, num_coeffs), one row per direction.

        Returns:
            Tensor of shape (batch,) with the summed basis-function contributions.
        """
        x, y, z = dirs[:, 0], dirs[:, 1], dirs[:, 2]
        data_out = torch.zeros(dirs.size(0), device=dirs.device)

        # Products and powers shared by the band polynomials below (computed once).
        xy = x * y
        xz = x * z
        yz = y * z
        x2 = x * x
        y2 = y * y
        z2 = z * z
        x4 = x2 * x2
        y4 = y2 * y2
        z4 = z2 * z2

        # First band: 0
        data_out += 0.28209479177387814 * coefficients[:, 0]

        # Second band: 1
        if self.num_bands > 1:
            data_out += -0.48860251190291987 * y * coefficients[:, 1]
            data_out += 0.48860251190291987 * z * coefficients[:, 2]
            data_out += -0.48860251190291987 * x * coefficients[:, 3]

        # Third band: 2
        if self.num_bands > 2:
            data_out += 1.0925484305920792 * xy * coefficients[:, 4]
            data_out += -1.0925484305920792 * yz * coefficients[:, 5]
            data_out += (0.94617469575755997 * z2 - 0.31539156525251999) * coefficients[:, 6]
            data_out += -1.0925484305920792 * xz * coefficients[:, 7]
            data_out += (0.54627421529603959 * (x2 - y2)) * coefficients[:, 8]

        # Fourth band: 3
        if self.num_bands > 3:
            data_out += 0.59004358992664352 * y * (-3 * x2 + y2) * coefficients[:, 9]
            data_out += 2.8906114426405538 * xy * z * coefficients[:, 10]
            data_out += 0.45704579946446572 * y * (1 - 5 * z2) * coefficients[:, 11]
            data_out += 0.3731763325901154 * z * (5 * z2 - 3) * coefficients[:, 12]
            data_out += 0.45704579946446572 * x * (1 - 5 * z2) * coefficients[:, 13]
            data_out += 1.4453057213202769 * z * (x2 - y2) * coefficients[:, 14]
            data_out += 0.59004358992664352 * x * (-x2 + 3 * y2) * coefficients[:, 15]

        # Fifth band: 4
        if self.num_bands > 4:
            data_out += 2.5033429417967046 * xy * (x2 - y2) * coefficients[:, 16]
            data_out += 1.7701307697799304 * yz * (-3 * x2 + y2) * coefficients[:, 17]
            data_out += 0.94617469575756008 * xy * (7 * z2 - 1) * coefficients[:, 18]
            data_out += 0.66904654355728921 * yz * (3 - 7 * z2) * coefficients[:, 19]
            data_out += (-3.1735664074561294 * z2 + 3.7024941420321507 * z4 + 0.31735664074561293) * coefficients[:, 20]
            data_out += 0.66904654355728921 * xz * (3 - 7 * z2) * coefficients[:, 21]
            data_out += 0.47308734787878004 * (x2 - y2) * (7 * z2 - 1) * coefficients[:, 22]
            data_out += 1.7701307697799304 * xz * (-x2 + 3 * y2) * coefficients[:, 23]
            data_out += (-3.7550144126950569 * x2 * y2 + 0.62583573544917614 * x4 + 0.62583573544917614 * y4) * coefficients[:, 24]

        return data_out
    

import time

# --- SH experiment setup -------------------------------------------------------
# This cell relies on globals created by earlier cells: scene_cfg, ref_scene,
# gparams, device_t and VarianceHelper.
# NOTE(review): prepare_scene presumably (re)configures the renderer for the
# reference scene before frames are rendered -- confirm against scene_cfg.
scene_cfg.prepare_scene(ref_scene)

# Five bands give 25 SH coefficients per pixel and colour channel.
sh_model = SHModel(num_bands=5, width=gparams.render_width, height=gparams.render_height).to(device_t)
# The optimizer is a cell-level global; it is not part of the SH checkpoint.
optimizer = torch.optim.Adam(sh_model.parameters(), lr=gparams.ADAM_LR*scene_cfg.lr_factor, betas=gparams.ADAM_BETAS)

sh_model.set_checkpoint_path(scene_cfg.checkpoint_dir)

# Load checkpoint if it exists; start_epoch is 0 when nothing was loaded.
start_epoch, _ = sh_model.load_checkpoint(skip=False)

relative_error = False
N = 1000  # number of training epochs; epochs >= N are evaluation-only
variance_helper = VarianceHelper(gparams.render_width, gparams.render_height, device=device_t)

# We use gradient backpropagation rather than a plain SH projection because it
# converges faster (the power of Adam).
# If training already finished, either re-measure the loss from scratch
# (RECOMPUTE_LOSS["SH"]) or just render one frame and stop (onlyRender).
onlyRender = False
if start_epoch >= N:
    if gparams.RECOMPUTE_LOSS["SH"]:
        start_epoch = N
        sh_model.rL2 = 0
    else:
        onlyRender = True

# Number of evaluation samples seen so far (running mean of rL2 during evaluation).
re_accum = 1


# Running mean of the visualized renders and its sample counter.
sum_c = 1
sumrender = None
# Visualization schedule as (max_epoch, interval) tuples consumed by
# should_visualize; loop-invariant, so it is built once instead of every epoch.
conditions = [
    (10, 2),
    (25, 5),
    (50, 20),
    (1000, 50),
]

for epoch in range(start_epoch, N + gparams.scene_cgf.var_est_steps):
    # Epochs below N train the model; epochs from N onwards only evaluate it.
    loss_epoch = epoch-N
    bTrainNow = loss_epoch < 0
    if not bTrainNow:
        sh_model.eval()
    else:
        optimizer.zero_grad()

    start_render = time.time()
    # Render a new frame each epoch
    trainFrame()

    # Get the newly rendered data
    mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)

    # "+0.0" produces fresh tensors instead of reusing the ones in mlDataOutput.
    y = mlDataOutput["radiance"]+0.0
    thp = mlDataOutput["thp"]+0.0
    pdf = safe_pdf(mlDataOutput["pdf"]+0.0)
    dirs = mlDataOutput["dir"]+0.0

    # Unmasked copies: the visualization below evaluates every pixel.
    thp_original = thp
    pdf_original = pdf
    dirs_original = dirs

    # Keep only the samples get_mask_for_training accepts, in y and in every
    # per-sample tensor that must stay aligned with it.
    mask = get_mask_for_training(y)

    num_elements = pdf.shape[0]
    y = y[mask]
    thp = thp[mask]
    dirs = dirs[mask]
    pdf = pdf[mask]

    num_used_elements = pdf.shape[0]

    loss_val = 0

    # Forward pass
    if not onlyRender:
        outputs = sh_model(dirs, mask=mask)

        if not bTrainNow:
            outputs = outputs.detach()

        # Model output divided by the sampling pdf, compared against the sampled radiance y.
        mcestimator = outputs/pdf
        loss = relativeL2(mcestimator, y, pdf, div=ref_estimations[mask])
        rL2 = loss.detach()

        loss_val = loss.item()
        rL2 = rL2.item()
        if bTrainNow:
            # Exponential moving average while training (seeded on the very first epoch).
            if epoch == 0:
                sh_model.rL2 = rL2
            sh_model.rL2 = rL2*0.90+sh_model.rL2*0.10
        else:
            bias = mcestimator.detach()-y
            # NOTE(review): the variance helper is never updated in this loop, so the
            # "Variance" value printed below does not reflect these samples.
            #variance_helper.update(mask, bias)

            # Plain running mean over the evaluation epochs.
            overe = 1.0 / re_accum
            sh_model.rL2 = rL2* overe + (1-overe) * sh_model.rL2
            re_accum += 1

        start_backprop = time.time()

        # Backward and optimize
        if bTrainNow:
            loss.backward()
            #torch.nn.utils.clip_grad_norm_(sh_model.parameters(), 1000)
            optimizer.step()

    # Seed the running average on the first epoch this loop actually executes.
    # Testing `epoch == 0` left average_loss undefined (or leaked from an earlier
    # cell) whenever training resumed from a checkpoint or ran in onlyRender mode.
    if epoch == start_epoch:
        average_loss = loss_val
    else:
        alpha = 0.95
        average_loss = average_loss*alpha+(1.0-alpha)*loss_val

    if (epoch+1) % 2 == 0 or onlyRender:
        print(f'Epoch [{epoch+1}/{N}],  AverageLoss: {average_loss}, Loss: {loss_val}, rL2_loss: {sh_model.rL2} Num Used Elements {num_used_elements}/{num_elements} Variance {variance_helper.get_variance()}')

    if (epoch) % 100 == 99:
        sh_model.save_checkpoint(epoch, average_loss)

    if (should_visualize(epoch, conditions) and bTrainNow) or epoch == N or onlyRender:
        sh_model.eval()

        # Detach so neither the preview nor the running mean keeps autograd graphs
        # alive across epochs.
        p = (sh_model(dirs_original)/pdf_original).detach()

        # Running mean of all visualized renders.
        if sum_c == 1:
            sumrender = p
        else:
            w = 1.0/sum_c
            sumrender = p*w + sumrender*(1-w)
        sum_c += 1
        visualize_model_output((p).cpu(), gparams.render_width, gparams.render_height, epoch)
        sh_model.train()

    if onlyRender:
        break

## VMF (von Mises-Fisher mixture)

In [None]:
def log_sum_exp(values, dim=1):
    """Numerically stable ``log(sum(exp(values)))`` along ``dim``.

    Delegates to ``torch.logsumexp``. Unlike a hand-rolled max-shift, it does not
    return NaN when a slice is entirely ``-inf`` (the result is ``-inf``) or
    contains ``+inf`` (the result is ``+inf``).

    Args:
        values: Input tensor.
        dim: Dimension to reduce over.

    Returns:
        Tensor with ``dim`` kept as a size-1 dimension.
    """
    return torch.logsumexp(values, dim=dim, keepdim=True)

class VMFModel(torch.nn.Module):
    """Per-pixel mixture of von Mises-Fisher (vMF) lobes with learnable RGB amplitudes.

    Lobe directions, sharpness (kappa) and mixture weights are stored in the plain
    tensor ``vmf_coefficients`` and are updated by the online EM code in
    ``EM_step`` / ``update_parameters``. Only ``amplitudes`` is an
    ``nn.Parameter``; it is the single tensor handed to the Adam optimizer that
    ``__init__`` creates.

    Layout of ``vmf_coefficients``: one row per pixel with ``params_per_lobe`` (8)
    slots per lobe. Slots 0-2 hold the mean direction, slot 3 kappa and slot 4 the
    mixture weight; slots 5-7 are neither read nor written by this class.

    Args:
        width, height: Image size in pixels.
        K: Number of vMF lobes per pixel.
        alpha: Exponent of the online-EM step size ``epoch ** -alpha``; ``0``
            disables the decay (old and new statistics are both weighted 1).
        max_kappa: Upper clamp for kappa during EM updates.
        min_kappa: Lower clamp for kappa; defaults to the initial sharpness.
        device: Device for all state; defaults to CUDA.
        kappa_force: When given, used as the initial kappa instead of the value
            derived from the lobe spacing.
        debug_amplitudes: Initialise amplitudes with one distinct hue per lobe.
        normals: Optional (num_pixels, 3) normals; initial lobe directions are
            rotated into each normal's frame and flipped into its hemisphere.
        em_delay: ``update_parameters`` only runs for epochs greater than this.
        batch_size: ``update_parameters`` runs every ``batch_size`` epochs.
    """

    def __init__(self, width, height, K, alpha = 0.8, max_kappa=200,min_kappa=None, device=None, kappa_force=None, debug_amplitudes=True, normals=None, em_delay=0, batch_size=1):
        super(VMFModel, self).__init__()
        self.K = K
        self.max_kappa = max_kappa
        self.min_kappa = min_kappa
        self.num_pixels = width * height
        self.alpha = alpha
        self.em_delay = em_delay
        # Honour the caller's value; this used to be hard-coded to 10, which
        # silently ignored the batch_size argument.
        self.batch_size = batch_size
        self.device = device if device is not None else torch.device("cuda")
        self.params_per_lobe = 8
        # Defined up front so save/load_checkpoint are a no-op (instead of raising
        # AttributeError) when set_checkpoint_path() was never called.
        self.checkpoint_path = None
        self.vmf_coefficients = torch.zeros(self.num_pixels, K * self.params_per_lobe, device=self.device)
        # Sufficient statistics of the online EM, one slice per lobe.
        self.S_data = torch.zeros((K, self.num_pixels, 3), device=self.device)
        self.S_weight_accum = torch.zeros((K, self.num_pixels, 1), device=self.device)
        self.stats_point_weight = torch.zeros((self.num_pixels, 1), device=self.device)
        self.accumulated_lum = torch.zeros((self.num_pixels, 1), device=self.device)
        # Allocate on the model's own device rather than on a notebook global.
        self.amplitudes = torch.nn.Parameter(torch.ones((self.num_pixels, K, 3), device=self.device))
        self.optimizer = torch.optim.Adam(self.parameters(), lr=gparams.ADAM_LR, betas=gparams.ADAM_BETAS)

        self.init_params(kappa_force, debug_amplitudes=debug_amplitudes, normals=normals)
        self.rL2 = 0

        self.params_per_pixel = K * (self.params_per_lobe - 1)

    def set_checkpoint_path(self, checkpoint_dir):
        """Create ``checkpoint_dir`` if needed and derive the checkpoint file name from ``params_per_pixel``."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.checkpoint_path = os.path.join(checkpoint_dir, f'vmf_{self.params_per_pixel}_checkpoint.pth')

    def load_state_dict(self, state):
        """Restore every field written by ``state_dict`` (overrides ``nn.Module.load_state_dict``).

        Args:
            state: Dictionary produced by this class's ``state_dict``.
        """
        self.K = state['K']
        self.max_kappa = state['max_kappa']
        self.min_kappa = state['min_kappa']
        self.num_pixels = state['num_pixels']
        self.alpha = state['alpha']
        self.em_delay = state['em_delay']
        self.batch_size = state['batch_size']
        self.device = state['device']
        self.params_per_lobe = state['params_per_lobe']
        self.vmf_coefficients = state['vmf_coefficients']
        self.S_data = state['S_data']
        self.amplitudes.data = state['amplitudes']  # Load amplitudes
        self.S_weight_accum = state['S_weight_accum']
        self.stats_point_weight = state['stats_point_weight']
        self.accumulated_lum = state['accumulated_lum']
        self.params_per_pixel = state['params_per_pixel']
        super().load_state_dict(state['model_state'])

    def state_dict(self):
        """Return hyper-parameters, EM statistics and the ``nn.Module`` state in one dictionary.

        Overrides ``nn.Module.state_dict`` because most of the model state lives in
        plain tensors and attributes that the base implementation does not cover.
        """
        state = {
            'K': self.K,
            'max_kappa': self.max_kappa,
            'min_kappa': self.min_kappa,
            'num_pixels': self.num_pixels,
            'alpha': self.alpha,
            'em_delay': self.em_delay,
            'batch_size': self.batch_size,
            'device': self.device,
            'params_per_lobe': self.params_per_lobe,
            'vmf_coefficients': self.vmf_coefficients,
            'S_data': self.S_data,
            'amplitudes': self.amplitudes.data,  # Include amplitudes
            'S_weight_accum': self.S_weight_accum,
            'stats_point_weight': self.stats_point_weight,
            'accumulated_lum': self.accumulated_lum,
            'params_per_pixel': self.params_per_pixel,
            'model_state': super().state_dict()  # Save the state of the nn.Module
        }
        return state

    def save_checkpoint(self, epoch, loss):
        """Write model state, optimizer state, epoch, loss and rL2 to ``checkpoint_path`` (no-op when no path is set)."""
        if self.checkpoint_path:
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.state_dict(),  # Use the updated state_dict
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
                'rL2': self.rL2,
            }, self.checkpoint_path)

    def load_checkpoint(self, skip=False):
        """Restore model and optimizer from ``checkpoint_path`` if it exists.

        Args:
            skip: When True the checkpoint is ignored.

        Returns:
            ``(epoch, loss)`` as stored in the checkpoint, or ``(0, None)`` when
            nothing was loaded.
        """
        if not skip and self.checkpoint_path and os.path.exists(self.checkpoint_path):
            print(self.checkpoint_path)
            checkpoint = torch.load(self.checkpoint_path)
            self.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            self.rL2 = checkpoint['rL2']

            print(f"Checkpoint loaded: Epoch {epoch}, Loss {loss}")
            return epoch, loss
        else:
            return 0, None

    def init_params(self, kappa_force=None, debug_amplitudes=False, normals=None):
        """Initialise lobe directions, kappa and weights (and optionally debug amplitudes).

        Directions come from a Fibonacci sphere; when ``normals`` is given they are
        rotated into each pixel's normal frame and flipped into its hemisphere.
        Every lobe starts with the same kappa and a weight of ``1 / K``.
        """
        # Fibonacci sphere sampling for mean direction initialization
        golden_ratio = (1 + 5 ** 0.5) / 2
        indices = torch.arange(0, self.K, dtype=torch.float32)
        theta = 2 * np.pi * indices / golden_ratio
        phi = torch.acos(1 - 2 * (indices + 0.5) / self.K)
        x = torch.cos(theta) * torch.sin(phi)
        y = torch.sin(theta) * torch.sin(phi)
        z = torch.cos(phi)

        mean_directions_init = torch.stack([x, y, z], dim=1).to(self.device)

        # Compute initial sharpness for kappa values based on direction pairs
        sharpness = self.calculate_kappa(kappa_force, mean_directions_init)
        if self.min_kappa is None:
            self.min_kappa = sharpness

        mean_directions_init = mean_directions_init.unsqueeze(0).repeat(self.num_pixels, 1, 1)
        print(mean_directions_init.shape)
        if normals is not None:
            # Assume normals are normalized and have shape (num_pixels, 3)
            mean_directions_init = self.reproject_to_tangent_space(normals, mean_directions_init)

        print(f"Kappa is init with value: {sharpness} for the {self.K} vMF lobes")

        # Initialize vmf_coefficients with new sharpness values
        for k in range(self.K):
            start = k * self.params_per_lobe
            self.vmf_coefficients[:, start:start+3] = mean_directions_init[:, k]
            self.vmf_coefficients[:, start+3] = torch.full((self.num_pixels,), sharpness, device=self.device)
            self.vmf_coefficients[:, start+4] = 1.0 / self.K

        # Debug amplitude initialization
        if debug_amplitudes:
            self.amplitudes.data = self.generate_colors(self.K, self.num_pixels, self.device)

    def calculate_kappa(self, kappa_force, mean_directions):
        """Return the initial kappa.

        ``kappa_force`` wins when given. Otherwise kappa is derived from the
        smallest dot product between lobe 0 and its normalized half-vectors with
        the other lobes; a single lobe gets kappa 1.
        """
        if kappa_force is not None:
            return kappa_force
        if self.K == 1:
            return torch.tensor(1.0, device=self.device)
        minDP = 1.0
        for i in range(1, self.K):
            h = (mean_directions[i] + mean_directions[0])
            h = h / (h.norm(dim=-1, keepdim=True) + 0.00001)
            minDP = min(minDP, torch.dot(h, mean_directions[0]).item())
        return (torch.log(torch.tensor(0.65, device=self.device)) * self.K) / (minDP - 1.0001)

    def reproject_to_tangent_space(self, normals, directions):
        """Rotate per-pixel lobe directions into each normal's frame and flip them into its hemisphere.

        Args:
            normals: Tensor of shape (num_pixels, 3).
            directions: Tensor of shape (num_pixels, K, 3).

        Returns:
            Tensor of shape (num_pixels, K, 3).
        """
        # Create local coordinate systems. `dim=1` is passed explicitly: without
        # it torch.cross picks the first dimension of size 3 (deprecated).
        # NOTE(review): a normal parallel to `up` yields a zero tangent/bitangent.
        up = torch.tensor([0.0, 0.0, 1.0], device=self.device).expand_as(normals)
        tangent = torch.cross(normals, up, dim=1)
        tangent = tangent / (tangent.norm(dim=1, keepdim=True) + 1e-10)  # Normalize
        bitangent = torch.cross(normals, tangent, dim=1)
        bitangent = bitangent / (bitangent.norm(dim=1, keepdim=True) + 1e-10)

        # Columns are tangent, bitangent, normal: [num_pixels, 3, 3]
        rotation_matrix = torch.stack([tangent, bitangent, normals], dim=-1)

        reprojected_directions = torch.zeros_like(directions)  # [num_pixels, K, 3]
        for k in range(self.K):
            # Apply the rotation to the k-th direction vector: [num_pixels, 3]
            rotated = torch.bmm(rotation_matrix, directions[:, k, :].unsqueeze(-1)).squeeze(-1)

            # Flip directions that point below the surface into the upper hemisphere.
            below = torch.sum(rotated * normals, dim=1, keepdim=True) < 0
            reprojected_directions[:, k, :] = torch.where(below, -rotated, rotated)

        return reprojected_directions

    def generate_colors(self, K, num_pixels, device):
        """Return a (num_pixels, K, 3) tensor holding one distinct, fully saturated hue per lobe.

        The result is materialised with ``contiguous()``: it becomes the data of
        the ``amplitudes`` parameter, and an expanded view (all pixels aliasing the
        same memory) cannot be updated in place by the optimizer.
        """
        # Generate distinct colors by varying the hue in HSV space
        hues = torch.linspace(0, 1 - 1e-6, steps=K)  # Small epsilon to simulate endpoint=False
        colors = torch.zeros((K, 3), device=device)
        for i, hue in enumerate(hues.tolist()):
            colors[i] = torch.tensor(self.hsv_to_rgb(hue, 1.0, 1.0))  # Saturation and Value are set to 1 for vibrant colors
        return colors.unsqueeze(0).expand(num_pixels, -1, -1).contiguous()  # Shape: (num_pixels, K, 3)

    def hsv_to_rgb(self, h, s, v):
        """Convert one HSV colour (hue as a fraction of a full turn) to an ``[r, g, b]`` list."""
        sector = int(h * 6.)
        f = h * 6. - sector
        p = v * (1. - s)
        q = v * (1. - f * s)
        t = v * (1. - (1. - f) * s)
        # One RGB ordering per 60-degree hue sector.
        return [
            [v, t, p],
            [q, v, p],
            [p, v, t],
            [p, q, v],
            [t, p, v],
            [v, p, q],
        ][sector % 6]

    def forward(self, dirs, mask=None, debugID=None):
        """Alias for ``compute_radiance``."""
        return self.compute_radiance(dirs, mask, debugID)

    def compute_radiance(self, dirs, mask = None, debugID = None):
        """Evaluate the mixture's RGB radiance for one direction per row.

        Args:
            dirs: Tensor of shape (n, 3).
            mask: Optional index/boolean mask selecting the pixel rows that
                correspond to ``dirs``.
            debugID: Optional pixel index; that pixel's lobes are then used for
                every row (takes precedence over ``mask``).

        Returns:
            Tensor of shape (n, 3): the sum over lobes of
            ``weight * vmf_pdf * relu(amplitude)``.
        """
        num_pixels = dirs.shape[0]
        radiance = torch.zeros((num_pixels, 3), device=self.device)
        vmfs_ = self.vmf_coefficients
        amplitudes_ = self.amplitudes

        if mask is not None:
            vmfs_ = vmfs_[mask]
            amplitudes_ = amplitudes_[mask]

        if debugID is not None:
            vmfs_ = self.vmf_coefficients[debugID].unsqueeze(0).repeat(num_pixels, 1)
            amplitudes_ = self.amplitudes[debugID].unsqueeze(0).repeat(num_pixels, 1, 1)

        for k in range(self.K):
            start_index = k * self.params_per_lobe
            mu = vmfs_[:, start_index:start_index+3]
            kappa = vmfs_[:, start_index+3]
            weight = vmfs_[:, start_index+4]
            mu = torch.nn.functional.normalize(mu, dim=1)
            amplitudes = amplitudes_[:, k]
            cos_theta = torch.einsum('ij,ij->i', dirs, mu)
            # vMF density on the sphere, written with exp(kappa * (cos - 1)) for stability.
            pdf_constant = kappa / (2 * torch.pi * (1.0 - torch.exp(-2.0 * kappa)))
            pdf_values = pdf_constant * torch.exp(kappa * (cos_theta - 1))
            amplitudes = torch.relu(amplitudes)  # amplitudes must not go negative
            assert(check_tensor(weight, "weight"))
            assert(check_tensor(pdf_values, "pdf_values"))
            assert(check_tensor(amplitudes, "amplitudes"))
            radiance += weight.unsqueeze(1) * pdf_values.unsqueeze(1) * amplitudes
        return radiance

    def calculate_responsibilities(self, dirs, y_lum):
        """E-step: posterior probability of every lobe for each pixel's sample.

        Args:
            dirs: Tensor of shape (num_pixels, 3), one sampled direction per pixel.
            y_lum: Tensor of shape (num_pixels, 1) used to weight the log-likelihood.

        Returns:
            ``(responsibilities, log_likelihood)`` where ``responsibilities`` has
            shape (num_pixels, K) with rows summing to 1 and ``log_likelihood`` is
            a scalar tensor.
        """
        num_pixels = dirs.shape[0]
        mus = torch.zeros((num_pixels, self.K, 3), device=dirs.device, dtype=dirs.dtype)
        kappas = torch.zeros((num_pixels, self.K), device=dirs.device, dtype=dirs.dtype)
        weights = torch.zeros((num_pixels, self.K), device=dirs.device, dtype=dirs.dtype)
        for k in range(self.K):
            start_idx = k * self.params_per_lobe
            mus[:, k, :] = self.vmf_coefficients[:, start_idx:start_idx+3]
            kappas[:, k] = self.vmf_coefficients[:, start_idx+3]
            weights[:, k] = self.vmf_coefficients[:, start_idx+4]

        # The mixture weights must be normalized before they are used. A second,
        # SKIP_TENSOR_CHECK-gated copy of this check used to sit here; it read
        # `responsibilities` before assignment and has been removed.
        weights_sum = torch.sum(weights, dim=1)
        if not torch.allclose(weights_sum, torch.ones_like(weights_sum)):
            assert(check_tensor(weights_sum, "weights_sum", True))
            print("weights Normalization check failed.")
            print("Row sums:", weights_sum)
            assert(0)

        dot_products = torch.einsum('ijk,ik->ij', mus, dirs)

        pdf_constant = kappas / (2 * torch.pi * (1.0 - torch.exp(-2.0 * kappas)))
        pdf_values = pdf_constant * torch.exp(kappas * (dot_products - 1))

        # Small offsets keep the log and the normalization below finite.
        pdf_values = pdf_values+1e-8
        responsibilities = weights * pdf_values
        responsibilities = responsibilities+1e-10
        log_likelihood = (torch.log(responsibilities)*y_lum).mean()

        eps = 1e-16
        responsibilities = responsibilities/(torch.sum(responsibilities, dim=1, keepdim=True)+eps)

        assert(check_tensor(responsibilities, "responsibilities"))

        if not gparams.SKIP_TENSOR_CHECK:
            row_sums = torch.sum(responsibilities, dim=1)

            assert(check_tensor(row_sums, "rows_sums"))
            if not torch.allclose(row_sums, torch.ones_like(row_sums)):
                assert(check_tensor(responsibilities, "responsibilities", True))
                print("responsibilities Normalization check failed.")
                print("Row sums:", row_sums)
                assert(0)

        return responsibilities, log_likelihood

    def update_parameters(self, new_w, epoch):
        """M-step: blend new direction/kappa estimates into ``vmf_coefficients`` and refresh the weights.

        Args:
            new_w: Blend factor for the new estimates (old values get ``1 - new_w``).
            epoch: Current epoch; used as the sample count in the weight prior.
        """
        eps = 1e-18
        float_min = 1e-38  # Close to the smallest positive normal float
        num_data = epoch

        for k in range(self.K):
            start = k * self.params_per_lobe
            S_data_k = self.S_data[k]
            norm_S = torch.sqrt(torch.sum(S_data_k ** 2, dim=1, keepdim=True))

            # Mean direction and mean resultant length from the sufficient statistics.
            new_mu = S_data_k / (norm_S + eps)
            assert(check_tensor(new_mu, "new_mu"))
            R_bar = norm_S / (self.S_weight_accum[k] + eps)
            assert(check_tensor(R_bar, "R_bar"))
            # kappa estimate from the mean resultant length, clamped to the allowed range.
            new_kappa = (R_bar * (3.0 - R_bar ** 2)) / (1.0 - R_bar ** 2 + eps)
            assert(check_tensor(new_kappa, "new_kappa"))
            new_kappa = torch.clamp(new_kappa, min=self.min_kappa, max=self.max_kappa)

            # Blend with the previous direction and re-normalize.
            new_mu = new_mu*new_w+(1.0-new_w)*self.vmf_coefficients[:, start:start+3]
            new_mu = new_mu/(torch.sqrt(torch.sum(new_mu**2, dim=1, keepdim=True))+eps)
            self.vmf_coefficients[:, start:start+3] = new_mu

            # squeeze(1) rather than squeeze() so a single-pixel model keeps its pixel axis.
            new_kappa = new_kappa.squeeze(1)*new_w+self.vmf_coefficients[:, start+3]*(1.0-new_w)
            self.vmf_coefficients[:, start+3] = new_kappa

            # Weight update with a small uniform prior; pixels whose lobe received
            # no weight so far keep their previous value.
            has_data = (self.S_weight_accum[k] > float_min).squeeze(1)
            new_weights = ((self.S_weight_accum[k] / (self.stats_point_weight + eps)) * num_data + 1e-2) / (num_data + self.K * 1e-2)
            new_weights = new_weights.squeeze(1)
            self.vmf_coefficients[:, start+4][has_data] = new_weights[has_data]

            assert(check_tensor(new_mu, "new_mu"))
            assert(check_tensor(new_kappa, "new_kappa"))

        # After the loop, re-normalize weights to ensure they sum to 1 across all K lobes for each pixel
        weight_sums = torch.sum(self.vmf_coefficients[:, 4::self.params_per_lobe], dim=1, keepdim=True)  # Sum weights across all K
        self.vmf_coefficients[:, 4::self.params_per_lobe] /= weight_sums  # Normalize weights
        assert(check_tensor(self.vmf_coefficients[:, 4::self.params_per_lobe], "new_weights"))

    def EM_step(self, epoch, dirs, y, update_params=True):
        """Run one online-EM step with a single sample per pixel.

        Args:
            epoch: Current epoch; drives the step size ``epoch ** -alpha`` and the
                update schedule.
            dirs: Tensor of shape (num_pixels, 3) with the sampled directions.
            y: Tensor of shape (num_pixels, 3) with the sampled radiance.
            update_params: Accepted for interface compatibility; not used.

        Returns:
            The weighted log-likelihood computed by ``calculate_responsibilities``.
        """
        assert(check_tensor(y, "y"))
        # The per-sample weight is the L2 norm of the RGB radiance.
        y_luminance = torch.norm(y,  dim=1, keepdim=True)
        y_luminance += 1e-8
        responsibilities, log_likehood = self.calculate_responsibilities(dirs, y_luminance)

        new_w = 1.0
        old_w = 1.0
        if self.alpha != 0.0:
            if epoch > 0:
                new_w = pow(epoch, -self.alpha)
                old_w = 1.0-new_w
            else:
                # pow(0, -alpha) raises ZeroDivisionError (and a negative base yields
                # a complex number); treat this as the very first sample instead.
                new_w = 1.0
                old_w = 0.0

        if epoch != 0:
            self.accumulated_lum = y_luminance + self.accumulated_lum
        else:
            self.accumulated_lum = y_luminance
        weight = y_luminance

        self.stats_point_weight = old_w*self.stats_point_weight+new_w*weight
        # Split each sample's weight across the lobes by responsibility.
        weight = weight.expand(-1, self.K) * responsibilities

        assert(check_tensor(dirs, "dirs"))
        assert(check_tensor(responsibilities, "responsibilities"))
        assert(check_tensor(y_luminance, "y_luminance"))

        for k in range(self.K):
            self.S_data[k] = self.S_data[k]*old_w+dirs * weight[:, k].unsqueeze(1)*new_w
            self.S_weight_accum[k] = self.S_weight_accum[k]*old_w+weight[:, k].unsqueeze(1)*new_w

        assert(check_tensor(self.S_data, "S_data"))
        assert(check_tensor(self.S_weight_accum, "S_weight_accum"))
        # Parameters are refreshed only after the warm-up and once per batch of epochs.
        if (epoch > self.em_delay and (epoch % self.batch_size) == 0):
            self.update_parameters(new_w, epoch)
        return log_likehood

    def visualize_directions_for_pixel(self, pixel_id):
        """Visualizes the vMF directions for a given pixel ID."""
        if pixel_id >= self.num_pixels:
            print("Pixel ID is out of bounds.")
            return

        # Extract the direction vectors for the given pixel ID
        directions = self.vmf_coefficients[pixel_id, :].view(self.K, self.params_per_lobe)[:, :3]  # Directions are stored in the first 3 slots of each lobe

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])

        # Plot the origin
        ax.scatter([0], [0], [0], color='red', label='Origin')

        # Plot each direction vector
        for i in range(directions.shape[0]):
            x, y, z = directions[i].cpu().numpy()
            ax.quiver(0, 0, 0, x, y, z, length=1.0, normalize=True)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.legend()
        plt.title(f'Direction Vectors for Pixel ID {pixel_id}')
        plt.show()

def plot_radiance(radiance, title="Radiance"):
    """Show a flat per-pixel RGB radiance tensor as an image.

    The tensor is moved to the CPU, reshaped to
    (gparams.render_height, gparams.render_width, 3) and clipped to [0, 1]
    before being displayed without axes.
    """
    plt.figure(figsize=(10, 4))
    image_shape = (gparams.render_height, gparams.render_width, 3)
    pixels = radiance.detach().cpu().numpy().reshape(image_shape)
    plt.imshow(np.clip(pixels, 0, 1))
    plt.title(title)
    plt.axis('off')
    plt.show()


# Build the per-pixel vMF mixture model and decide where the fitting loop starts.
device_t = torch.device("cuda")

normals = mlDataOutput_ref["normal"]
# Author's notes on the model:
#  * K is the number of lobes per pixel.
#  * Fitting follows the online EM scheme of "On-line Learning of Parametric Mixture Models
#    for Light Transport Simulation", with the same regularization and the same alpha
#    (convergence) parameter.
#  * Lobes are initialized with a Fibonacci sequence over the upper hemisphere of the surface,
#    which is why the normals are passed in.
#  * batch_size: the EM update is batched, as in that paper.
model_vmf = VMFModel(
    gparams.render_width,
    gparams.render_height,
    K=11,
    max_kappa=250,
    alpha=0.75,
    device=device_t,
    debug_amplitudes=False,
    normals=normals,
    em_delay=20,
    batch_size=40,
)
model_vmf.set_checkpoint_path(scene_cfg.checkpoint_dir)

# Schedule (in epochs). The configured frame count is read and then immediately
# overridden by the hard-coded value below.
EM_STEPS = gparams.num_training_frames
EM_STEPS = 2500
AMPLITUDE_DELAY = 20               # first epoch that runs the gradient-based step
AMPLITUDE_ADDITIONAL_STEPS = 500   # gradient-only epochs appended after the EM phase
N = EM_STEPS + AMPLITUDE_ADDITIONAL_STEPS

Visualize = True

# Resume from the latest checkpoint (the second return value is not needed here).
start_epoch, _ = model_vmf.load_checkpoint(skip=False)
scene_cfg.prepare_scene(ref_scene)

isFirstWas = False

# Initial prediction so there is something to plot before the first gradient epoch.
estimation_rad = model_vmf(mlDataOutput["dir"])
variance_helper = VarianceHelper(gparams.render_width, gparams.render_height, device=device_t)

# A checkpoint that already covers all N epochs means either "only render the result"
# or, when loss recomputation is requested, "restart the evaluation phase at epoch N".
onlyRender = start_epoch >= N and not gparams.RECOMPUTE_LOSS["VMF"]
if start_epoch >= N and not onlyRender:
    start_epoch = N
    model_vmf.rL2 = 0

# vMF fitting loop. Every iteration renders a fresh batch of samples and then, depending on epoch:
#   * epoch < EM_STEPS               -> one online-EM step (model_vmf.EM_step),
#   * AMPLITUDE_DELAY <= epoch < N   -> one optimizer step on model_vmf.parameters(),
#   * epoch >= N                     -> evaluation only: rL2 becomes a running mean.
# With onlyRender set, a single iteration is executed just to plot the current prediction.
re_accum = 1      # 1 + number of evaluation epochs already folded into the running rL2 mean
average_loss = 0  # exponential moving average of the training loss (stored in checkpoints)
loss = 0

# Visualization schedule handed to should_visualize(); it does not change between epochs.
conditions = [
    (100, 10),
    (25, 5),
    (50, 50),
]

# NOTE(review): the attribute is spelled `scene_cgf` here -- confirm this is the intended name.
for epoch in range(start_epoch, N + gparams.scene_cgf.var_est_steps):
    start_render = time.time()
    trainFrame()

    # Get the newly rendered data
    mlDataOutput = falcor_to_torch_split_interleaved(ml_data, field_types)

    dirs = mlDataOutput["dir"] + 0.0
    # Rows whose direction length deviates from 1 beyond floating point tolerance;
    # kept in a notebook-level variable for inspection after the run.
    lengths = torch.sqrt(torch.sum(dirs ** 2, dim=1, keepdim=True))
    tolerance = 1e-5
    incorrect_indices = torch.where(torch.abs(lengths - 1) > tolerance)[0]

    # (#pixels, 3)
    y = mlDataOutput["radiance"] + 0.0
    pdf = safe_pdf(mlDataOutput["pdf"]+0.0)
    mask = get_mask_for_training(y)

    y = torch.relu(y)

    # Neutralize samples excluded by the mask: zero radiance and the fixed direction (1, 0, 0).
    y[~mask] = 0
    dirs[~mask, :] = 0.0
    dirs[~mask, 0] = 1.0
    loss_val = None
    used = torch.count_nonzero(mask).item()
    total_pixels = mask.shape[0]  # not named `all`: that would shadow the builtin notebook-wide

    loss_epoch = epoch-N
    # Defined on every iteration: the visualization check below reads it even when the
    # gradient branch is skipped (epoch < AMPLITUDE_DELAY or onlyRender).
    bTrainNow = loss_epoch < 0
    if epoch >= AMPLITUDE_DELAY and not onlyRender:
        if bTrainNow:
            model_vmf.optimizer.zero_grad()

        # Forward pass
        estimation_rad = model_vmf(dirs)

        y_masked = y[mask]
        estimation_rad_masked = estimation_rad[mask]
        batch_pdf = pdf[mask]
        if not bTrainNow:
            estimation_rad_masked = estimation_rad_masked.detach()

        if bTrainNow:
            mc = estimation_rad_masked/batch_pdf
            loss = relativeL2(mc, y_masked, pdf=batch_pdf, div=ref_estimations[mask])

        # NOTE(review): this metric compares ref_estimations (not the model output) with y.
        rL2 = relativeL2(ref_estimations[mask], y_masked, pdf=batch_pdf, div=ref_estimations[mask]).item()

        loss_val = loss.item() if bTrainNow else 0

        if bTrainNow:
            # Seed the moving average on the first gradient epoch. The original test was
            # `epoch == 0`, which can never be true inside this branch.
            if epoch == AMPLITUDE_DELAY:
                model_vmf.rL2 = rL2
            model_vmf.rL2 = rL2*0.90+model_vmf.rL2*0.10
        else:
            # Evaluation phase: incremental mean over the evaluation epochs.
            overe = 1.0 / re_accum
            model_vmf.rL2 = rL2* overe + (1-overe) * model_vmf.rL2
            # Per-sample error of the estimate. The variance_helper update that consumed it is
            # disabled, so the variance printed below is not updated by this loop.
            bias = estimation_rad_masked/batch_pdf -y_masked
            re_accum += 1

        # Backward and optimize
        if bTrainNow:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_vmf.parameters(),  5000)
            model_vmf.optimizer.step()

        # Track the loss without its autograd graph; accumulating the live tensor would chain
        # the graphs of all previous epochs together.
        loss_detached = loss.detach() if torch.is_tensor(loss) else loss
        if epoch == AMPLITUDE_DELAY:
            average_loss = loss_detached
        else:
            alpha = 0.95
            average_loss = average_loss*alpha+(1.0-alpha)*loss_detached

    log_likehood = None
    if epoch < EM_STEPS and not onlyRender:
        log_likehood = model_vmf.EM_step(epoch+1, dirs, y).item()

    if (epoch) % 100 == 99:
        model_vmf.save_checkpoint(epoch, average_loss)
    scheduled_plot = should_visualize(epoch, conditions) and (bTrainNow or not isFirstWas) and Visualize
    if scheduled_plot or loss_epoch == 0 or onlyRender:
        plot_radiance(estimation_rad/pdf)
    if (epoch+1) % 10 == 0 or onlyRender:
        print(f'Epoch [{epoch+1}/{N}],  Log_Likehood: {log_likehood} L2: {loss_val} rL2: {model_vmf.rL2} used: {used}/{total_pixels} variance: {variance_helper.get_variance()}')
    isFirstWas = True

    if onlyRender:
        break

In [None]:
# Notebook-level cache for the ground-truth frame rendered by ExperimentPreview.do();
# running this cell clears it so the next preview renders the frame again.
gt_render_cache = None


from matplotlib import colors

import matplotlib.pyplot as plt
import pandas as pd
import os
from matplotlib import font_manager as fm
from matplotlib.ticker import ScalarFormatter
import numpy as np
def scientific_formatter(x, pos=None):
    """Format a tick value as '<mantissa>e<exponent>' with one mantissa decimal.

    Intended for use with matplotlib's `FuncFormatter`, which calls it as `(value, position)`.

    Args:
        x: Tick value.
        pos: Tick position supplied by matplotlib; unused.

    Returns:
        '0' for zero, a fixed-point string with one decimal for negative values, and
        scientific notation such as '1.5e3' or '2.5e-4' for positive values.
    """
    if x == 0:
        return '0'
    if x < 0:
        return f'{x:.1f}'

    exponent = int(np.floor(np.log10(abs(x))))
    mantissa_text = f'{x / (10**exponent):.1f}'
    # Rounding to one decimal can push the mantissa up to 10.0 (e.g. 9.96); carry that
    # into the exponent so the result stays normalized ('1.0e1' instead of '10.0e0').
    if mantissa_text == '10.0':
        mantissa_text = '1.0'
        exponent += 1
    return f'{mantissa_text}e{exponent}'
    

class ExperimentPreview:
    """Renders and saves a visual comparison of the fitted models against a reference render.

    `do()` produces, inside `scene_cfg.checkpoint_dir`:
      * `variance.svg`            - NCV vs. NIRC variance curves (only with an NCV model and showNIRC),
      * `experiment_preview.png`  - a ground-truth frame with the validation points marked,
      * `experiment_<model><i>.png` - one tonemapped image per enabled model and validation point,
      * `experiment_full_result.png` - the grid figure containing all of the above panels.

    The models themselves (`nirc_model`, `nirc_eqmem_model`, `sh_model`, `model_vmf`) and the render
    helpers are taken from the notebook's global namespace; only the NCV model is passed in.
    """
    num_samples: int
    scene_cfg: SceneConfig

    showNIRC: bool
    showSH: bool
    showVMF: bool

    # Font file used for tick labels and legend of the variance plot.
    _FONT_PATH = "LinLibertine_R.ttf"
    # Marker / border colors, cycled over the validation points.
    _POINT_COLORS = ['red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'black', 'white']

    def __init__(self, scene_cfg: SceneConfig, num_samples: int = 64, showNIRC: bool = True, showSH: bool = True, showVMF: bool = True, ncv = None, showNIRCEQMem: bool = True, exposure: float = 1.0):
        """Store the configuration.

        Args:
            scene_cfg: Scene configuration providing `validation_points`, `checkpoint_dir`
                and `prepare_scene`.
            num_samples: Sample count passed to `frameRender` for the per-point reference renders.
            showNIRC, showSH, showVMF, showNIRCEQMem: Enable the corresponding panel.
            ncv: NCV model, or None to skip the NCV panel and the variance plot.
            exposure: Fallback exposure for validation points that have no entry in the
                `exposures` sequence given to `do()`.
        """
        self.num_samples = num_samples
        self.scene_cfg = scene_cfg
        self.showNIRC = showNIRC
        self.showNIRCEQMem = showNIRCEQMem
        self.showSH = showSH
        self.exposure = exposure
        self.showVMF = showVMF
        self.ncv = ncv

    def _checkpoint_path(self, filename):
        """Return the path of `filename` inside this experiment's checkpoint directory."""
        return os.path.join(self.scene_cfg.checkpoint_dir, filename)

    @staticmethod
    def _with_colored_border(image, color_name):
        """Return `image` framed by a `color_name` border of `gparams.render_width // 100` pixels."""
        padding = gparams.render_width // 100
        if padding == 0:
            # `array[-0:]` selects everything, which would paint the whole image.
            return image
        framed = np.pad(image, ((padding, padding), (padding, padding), (0, 0)), mode='constant', constant_values=0)
        rgb = colors.to_rgb(color_name)
        framed[:padding, :, :] = rgb
        framed[-padding:, :, :] = rgb
        framed[:, :padding, :] = rgb
        framed[:, -padding:, :] = rgb
        return framed

    @staticmethod
    def _repeat_for_rays(values, positions):
        """Broadcast one per-pixel vector to one row per ray (row count and device from `positions`)."""
        return values[None, :] * torch.ones((positions.shape[0], 1), device=positions.device)

    @staticmethod
    def _pixel_index_rows(pixel_id, positions):
        """Return an int tensor with one (x, y) pixel coordinate row per ray, all equal to `pixel_id`."""
        flat = torch.ones((positions.shape[0]), device=positions.device) * pixel_id
        return torch.stack((flat % gparams.render_width, flat // gparams.render_width), dim=1).to(torch.int)

    @staticmethod
    def _as_image(prediction):
        """Reshape a flat per-ray RGB prediction into a float32 (height, width, 3) numpy image."""
        image = prediction.view(gparams.render_height, gparams.render_width, 3).detach().cpu().numpy()
        return image.astype(np.float32)

    @staticmethod
    def _rmse(reference, image):
        """Root mean squared difference over all entries of the two images."""
        return np.linalg.norm(reference - image) / np.sqrt(reference.size)

    def _publish_panel(self, axis, hdr_image, exposure, title, filename, border_color=None):
        """Tonemap `hdr_image`, save it as `filename` and show it (optionally framed) on `axis`."""
        tonemapped = apply_gamma_correction(aces_tonemap(hdr_image, exposure=exposure))
        # The file on disk is always the unframed image; the border is for the grid figure only.
        plt.imsave(self._checkpoint_path(filename), tonemapped)
        shown = tonemapped if border_color is None else self._with_colored_border(tonemapped, border_color)
        axis.imshow(shown)
        axis.set_title(title, fontsize=10)

    def _plot_variance_curves(self):
        """Save `variance.svg` with the cached per-epoch variance of the NCV and NIRC models."""
        from matplotlib.ticker import FuncFormatter, MaxNLocator

        prop = fm.FontProperties(fname=self._FONT_PATH)
        sci_formatter = FuncFormatter(scientific_formatter)
        linewidth = 3
        markersize = 3

        # rc_context keeps the large font size from leaking into later figures and cells.
        with plt.rc_context({'font.size': 28}):
            plt.figure(figsize=(12, 6))

            epochs_ncv, variance_values_ncv = zip(*self.ncv.variance_helper.get_cached_variance_per_epoch())
            epochs_nirc, variance_values_nirc = zip(*nirc_model.variance_helper.get_cached_variance_per_epoch())
            plt.plot(epochs_ncv, variance_values_ncv, label='NCV', color="#9b5c6f", marker="o", markersize=markersize, linewidth=linewidth)
            plt.plot(epochs_nirc, variance_values_nirc, label='NIRC', color="#537b8d", marker="o", markersize=markersize, linewidth=linewidth)
            plt.xticks(fontsize=12, fontproperties=prop)
            plt.yticks(fontsize=12, fontproperties=prop)

            axis = plt.gca()
            axis.yaxis.set_major_formatter(sci_formatter)        # scientific notation on the Y axis
            axis.yaxis.set_major_locator(MaxNLocator(nbins=3))   # limit the number of Y ticks
            plt.grid(True, linestyle='-', alpha=1.0, color="white", linewidth=2)
            axis.set_facecolor('#e6f2ff')
            plt.legend(loc='upper right', fontsize=22, prop=prop)

            plt.tight_layout(h_pad=0.0, w_pad=0.0, pad=0.0)
            plt.savefig(self._checkpoint_path('variance.svg'), format='svg')
            plt.close()

    def _save_validation_overview(self, gt_render):
        """Show `gt_render` with one colored circle per validation point and save it as a PNG."""
        fig, ax = plt.subplots(1)
        ax.imshow(gt_render)

        for i, p in enumerate(self.scene_cfg.validation_points):
            marker_color = self._POINT_COLORS[i % len(self._POINT_COLORS)]
            circle = patches.Circle((p.x, p.y), gparams.render_width // 100, facecolor=marker_color, edgecolor='black', linewidth=(gparams.render_width // 500))
            ax.add_patch(circle)

        # Save before show(): once the figure has been shown, savefig may write an empty canvas.
        fig.savefig(self._checkpoint_path('experiment_preview.png'))
        plt.show()

    def do(self, gt_render = None, exposures=(5, 5, 5, 5)):
        """Render all comparison panels, display them and write them to the checkpoint directory.

        Args:
            gt_render: Image used for the validation-point overview. When None, a one-sample
                tonemapped frame is rendered once and reused through the notebook-level
                `gt_render_cache`.
            exposures: Tonemapping exposure per validation point, indexed by point order.
                Points beyond the end of the sequence use `self.exposure`.
        """
        global gt_render_cache

        if self.ncv is not None and self.showNIRC:
            self._plot_variance_curves()

        # `is None` rather than `== None`: the latter is elementwise for numpy arrays.
        if gt_render is None:
            if gt_render_cache is None:
                self.scene_cfg.prepare_scene(ref_scene)
                gt_render_cache = frameRender(num_samples=1, tonemapped=True, vis=False)
            gt_render = gt_render_cache

        self._save_validation_overview(gt_render)

        points = self.scene_cfg.validation_points
        n_cameras = len(points)
        # Columns: 0 ground truth, 1 NIRC, 2 NIRC per-pixel, 3 NCV, 4 SH, 5 vMF.
        # squeeze=False keeps `axes` two-dimensional even for a single validation point.
        figure, axes = plt.subplots(n_cameras, 6, figsize=(25, 6 * n_cameras), squeeze=False)
        figure.suptitle('Ground Truth, Predicted Renders, and SH Renders', fontsize=16)

        for i, c in enumerate(points):
            self.scene_cfg.prepare_scene(ref_scene, debugPointID=i)
            validation_gt_render = frameRender(num_samples=self.num_samples, vis=False, tonemapped=False, directEmissive=False, directSky=True)
            brdf = getBRDF([c.x, c.y])
            brdf_gpu = torch.from_numpy(brdf).cuda().view(-1, 3)
            validation_gt_render = validation_gt_render * brdf

            # Per-point inputs: rays for the rendered frame and the surface data of the pixel
            pixel_id = c.id
            mlDataRaysOutput = falcor_to_torch_split_interleaved(ml_rays_data, rays_fields)
            positions = mlDataRaysOutput["worldpos"]+0.0
            directions = mlDataRaysOutput["dir"]+0.0

            color = mlDataOutput_ref["color"][pixel_id]+0.0
            r = mlDataOutput_ref["roughness"][pixel_id]+0.0
            n = mlDataOutput_ref["normal"][pixel_id]+0.0
            v = mlDataOutput_ref["view"][pixel_id]+0.0

            exposure = exposures[i] if i < len(exposures) else self.exposure
            border_color = self._POINT_COLORS[i % len(self._POINT_COLORS)]

            self._publish_panel(axes[i, 0], validation_gt_render, exposure, 'Ground Truth', f'experiment_gt{i}.png')

            if self.showNIRC:
                roughness = self._repeat_for_rays(r, positions)
                colors_arr = self._repeat_for_rays(color, positions)
                normals = self._repeat_for_rays(n, positions)

                X = nirc_model.prepare_input_(positions, directions, colors_arr, roughness, normals)
                assert X.shape[1] == 12, 'Input shape mismatch'

                Y_image = self._as_image(nirc_model(X, thp=brdf_gpu).detach())
                self._publish_panel(
                    axes[i, 1], Y_image, exposure,
                    f'NIRC Render rL2 = {nirc_model.rL2:.4f} rVar = {nirc_model.variance_helper.get_variance():.4f}',
                    f'experiment_nirc{i}.png', border_color=border_color)

            if self.showNIRCEQMem:
                pixel_indices = self._pixel_index_rows(pixel_id, positions)
                dirs, pixel_indices = nirc_eqmem_model.prepare_input_(directions, pixel_indices)
                Y_image = self._as_image(nirc_eqmem_model(dirs, pixel_indices=pixel_indices, thp=brdf_gpu).detach())
                self._publish_panel(
                    axes[i, 2], Y_image, exposure,
                    f'NIRC PerPixel Render L2 error = {nirc_eqmem_model.rL2:.4f}, var = {nirc_eqmem_model.variance_helper.get_variance():.4f}',
                    f'experiment_nirc_eqmem{i}.png', border_color=border_color)

            # Fresh copy of the ray directions for the remaining models
            # (presumably the prepare_input_ calls above may modify their inputs -- TODO confirm).
            directions = mlDataRaysOutput["dir"]+0.0

            if self.showSH:
                shs_render_image = self._as_image(sh_model(directions, debugID=pixel_id))
                l2_error_sh = self._rmse(validation_gt_render, shs_render_image)
                self._publish_panel(
                    axes[i, 4], shs_render_image, exposure,
                    f'SH Render L2 error = {l2_error_sh:.4f}, rL2 = {sh_model.rL2:.4f}',
                    f'experiment_sh{i}.png', border_color=border_color)

            if self.showVMF:
                vmf_render_image = self._as_image(model_vmf(directions, debugID=pixel_id))
                l2_error_vmf = self._rmse(validation_gt_render, vmf_render_image)
                self._publish_panel(
                    axes[i, 5], vmf_render_image, exposure,
                    f'vMF Render L2 error = {l2_error_vmf:.4f}, rL2 = {model_vmf.rL2:.4f}',
                    f'experiment_vmf{i}.png', border_color=border_color)

            if self.ncv is not None:
                roughness = self._repeat_for_rays(r, positions)
                colors_arr = self._repeat_for_rays(color, positions)
                normals = self._repeat_for_rays(n, positions)
                views = self._repeat_for_rays(v, positions)

                pixel_indices = self._pixel_index_rows(pixel_id, positions)

                c_multiplier = ref_estimations[pixel_id]* torch.ones((positions.shape[0], 1), device=positions.device)
                x, jacobian, pixels = self.ncv.prepare_input_(directions, pixel_indices)
                world = self.ncv.world_prepare_input_(positions, colors_arr, roughness, normals, views)

                Y, integral = self.ncv(x=x, jacobian=jacobian, pixels=pixels, world_data = world, integral_values=c_multiplier, surface_color=colors_arr, detach=True)

                self._publish_panel(
                    axes[i, 3], self._as_image(Y), exposure,
                    f'NCV Render rL2 = {self.ncv.rL2:.4f} rVar = {self.ncv.variance_helper.get_variance():.4f}',
                    f'experiment_ncv{i}.png')

        figure.tight_layout(rect=[0, 0.03, 1, 0.95])
        figure.savefig(self._checkpoint_path('experiment_full_result.png'))
        plt.show()


# Comparison with NIRC, per-pixel NIRC, vMF and NCV panels enabled; SH is disabled.
ex = ExperimentPreview(
    scene_cfg=scene_cfg,
    num_samples=128,
    ncv=ncv_model,
    showNIRC=True,
    showNIRCEQMem=True,
    showSH=False,
    showVMF=True,
)
# One tonemapping exposure per validation point, in point order.
ex.do(exposures=[2, 2, 2, 2, 0.1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
# Display the variance currently reported by the NIRC model's variance helper.
nirc_model.variance_helper.get_variance()

In [None]:

# NCV-only preview. ExperimentPreview has no `showNCV` parameter (passing it raises TypeError);
# the NCV panel is enabled by handing over the model through `ncv`.
ex = ExperimentPreview(scene_cfg=scene_cfg, ncv=ncv_model, showNIRC=False, showNIRCEQMem=False, showSH=False, showVMF=False)
# One tonemapping exposure per validation point, in point order.
ex.do(exposures=[3, 3, 1, 4, 0.1, 1, 1, 1, 1, 1, 1, 1])