Notes:  
- The number of levels in the grid encoder seems to be the main cause of training instability; using more levels resolves the issue.

# Imports

In [1]:
%matplotlib inline
import os
from typing import Optional, Tuple, List, Union, Callable

import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import axes3d
from tqdm.notebook import tqdm, trange
import random
from PIL import Image

# Make runs repeatable: seed Python, NumPy and PyTorch with the same value
seed = 2024
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [2]:
###
### Select training parameters
###
# Encoders: the available input encodings and the one selected for this run.
# NOTE(review): the Enum's internal name is 'Encoder' while the variable is
# `Encoders`; only the repr of the members is affected.
from enum import Enum
Encoders = Enum('Encoder', ["NONE","FREQ", "HASH"])
encoder_used = Encoders.HASH

## Freq encoder
d_input = 3           # Number of input dimensions
n_freqs = 10          # Number of encoding functions for samples
log_space = True      # If set, frequencies scale in log space
use_viewdirs = True   # If set, use view direction as input
n_freqs_views = 4     # Number of encoding functions for views
n_freqs_time = 4      # Number of encoding functions for time

## Hash encoder
# (no notebook-level options are defined for the hash encoder in this cell)

# Stratified sampling
n_samples = 32         # Number of spatial samples per ray
perturb = True         # If set, applies noise to sample positions
inverse_depth = False  # If set, samples points linearly in inverse depth

# Model
d_filter = 128          # Dimensions of linear layer filters
n_layers = 3#3            # Number of layers in network bottleneck
skip = []#[4]              # Layers at which to apply input residual
use_fine_model = True   # If set, creates a fine model
d_filter_fine = 128     # Dimensions of linear layer filters of fine network
n_layers_fine = 4#6       # Number of layers in fine network bottleneck

# Hierarchical sampling
n_samples_hierarchical = 64   # Number of samples per ray
perturb_hierarchical = False  # If set, applies noise to sample positions

# Ray integration bounds: distance of the nearest / farthest sample along a ray
near, far = 2., 6.

# Optimizer
lr = 1e-4  # Learning rate
lr_hash_enc = 5e-4  # Presumably the learning rate of the hash-encoder parameters — verify against the optimizer setup
scheduler_start_end_factors = [1.0, 0.9] # Linear decay of learning rate

# Training
n_epochs = 3000
batch_size = 2**14          # Number of rays per gradient step (power of 2)
one_image_per_step = True   # One image per gradient step (disables batching)
chunksize = 2**14           # Modify as needed to fit in GPU memory
center_crop = True          # Crop the center of image (presumably only in one_image_per_step mode — TODO confirm)
center_crop_iters = 0      # Stop cropping center after this many epochs
display_rate = 50          # Display test output every X epochs
shuffle_data = True		    # Shuffle rays on every iteration. Mostly for debug.

# Early Stopping
warmup_iters = 100          # Number of iterations during warmup phase
warmup_min_fitness = 10.0   # Min val PSNR to continue training at warmup_iters
n_restarts = 10             # Number of times to restart if training stalls

# We bundle the kwargs for various functions to pass all at once.
kwargs_sample_stratified = {
	'n_samples': n_samples,
	'perturb': perturb,
	'inverse_depth': inverse_depth
}
# Hierarchical sampling has its own perturbation switch. It previously received
# the stratified `perturb` flag, which left `perturb_hierarchical` without effect.
kwargs_sample_hierarchical = {
	'perturb': perturb_hierarchical
}

# Rendered test images are written to a directory named after the active encoder.
test_save_dir = f'./results/{encoder_used.name}/'
os.makedirs(test_save_dir, exist_ok=True)

# Data

In [3]:
# if not os.path.exists('./data/tiny_dnerf_data.npz'):
	# !wget https://github.com/ivanvoid/Tiny_Kilo-D-NeRF-NGP/blob/main/tiny_dnerf_data.npz
    # !wget https://github.com/ivanvoid/Tiny_Kilo-D-NeRF-NGP/blob/8091c93e53f9c77ccc84cd275aad28171d21a928/tiny_dnerf_data.npz

In [5]:
# data = np.load('tiny_dnerf_data.npz')
# data = np.load('tiny_dnerf_data.npz', allow_pickle=True)
import zipfile

# Sanity-check the dataset archive before loading it: the `.npz` file is
# expected to be a zip container. A `False` here means the file on disk is not
# a usable archive (for example an incomplete or wrong download).
is_zipfile = zipfile.is_zipfile('../data/tiny_dnerf_data.npz')
print(f"Is the file a valid zipfile? {is_zipfile}")

# If it is a valid zip file, list its members (name, modification date, size)
if is_zipfile:
    with zipfile.ZipFile('../data/tiny_dnerf_data.npz', 'r') as zip_ref:
        zip_ref.printdir()


Is the file a valid zipfile? True
File Name                                             Modified             Size
images.npy                                     1980-01-01 00:00:00      9720128
poses.npy                                      1980-01-01 00:00:00         5312
times.npy                                      1980-01-01 00:00:00          452
focal.npy                                      1980-01-01 00:00:00          136


In [6]:
data = np.load('../data/tiny_dnerf_data.npz')

n_training = 79    # Frames [0, n_training) form the training set
testimg_idx = 80   # Index of the single held-out frame used for test renders

# Training split. Only the first three image channels are kept (drops any 4th channel).
images = torch.from_numpy(data['images'][:n_training,:,:,:3])
poses = torch.from_numpy(data['poses'][:n_training])
focal = torch.from_numpy(data['focal'])
times = torch.from_numpy(data['times'][:n_training])

# Test
testimg = torch.from_numpy(data['images'][testimg_idx,:,:,:3])
testpose = torch.from_numpy(data['poses'][testimg_idx])
# Keep a length-1 slice (not a scalar), and bound it on the right so that frames
# stored after `testimg_idx` can never end up in the test timestep.
testtime = torch.from_numpy(data['times'][testimg_idx:testimg_idx + 1])

height, width = images.shape[1:3]


# Rays and Ray Sampling

In [None]:
def get_rays(
	height: int,
	width: int,
	focal_length: float,
	c2w: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
	r"""
	Build one ray per pixel for a pinhole camera.

	Args:
		height: Image height in pixels.
		width: Image width in pixels.
		focal_length: Focal length in pixels.
		c2w: Camera-to-world matrix; the top-left 3x3 block is used as the
			rotation and the last column of the first three rows as the origin.
	Returns:
		rays_o, rays_d: Ray origins and (un-normalised) directions, both of
		shape (height, width, 3).
	"""

	# Pixel coordinate grids of shape (height, width): `cols` runs along the
	# image width, `rows` along the image height.
	cols, rows = torch.meshgrid(
			torch.arange(width, dtype=torch.float32).to(c2w),
			torch.arange(height, dtype=torch.float32).to(c2w),
			indexing='xy')

	# Camera-frame directions: x to the right, y up, looking down -z.
	x_cam = (cols - width * .5) / focal_length
	y_cam = -(rows - height * .5) / focal_length
	z_cam = -torch.ones_like(cols)
	directions = torch.stack([x_cam, y_cam, z_cam], dim=-1)

	# Rotate every direction into the world frame.
	rotation = c2w[:3, :3]
	rays_d = (directions[..., None, :] * rotation).sum(dim=-1)

	# All rays start at the optical centre of the camera.
	rays_o = c2w[:3, -1].expand(rays_d.shape)
	return rays_o, rays_d
	
def sample_stratified(
	rays_o: torch.Tensor,
	rays_d: torch.Tensor,
	near: float,
	far: float,
	n_samples: int,
	perturb: Optional[bool] = True,
	inverse_depth: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
	r"""
	Place `n_samples` depths between `near` and `far` on every ray.

	Depths are spaced evenly in depth, or evenly in inverse depth when
	`inverse_depth` is set. With `perturb`, each depth is jittered uniformly
	inside its bin; the same jitter is shared by all rays of the call.

	Returns:
		pts: Sample positions, shape rays_o.shape[:-1] + (n_samples, 3).
		z_vals: Sample depths, shape rays_o.shape[:-1] + (n_samples,).
	"""

	# Normalised positions in [0, 1] that get mapped onto [near, far]
	t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
	if inverse_depth:
		# Even spacing in inverse depth (disparity)
		z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
	else:
		# Even spacing in depth
		z_vals = near * (1.-t_vals) + far * (t_vals)

	if perturb:
		# Bin edges are the midpoints between neighbouring depths, closed off
		# by the first and last depth themselves.
		mids = .5 * (z_vals[1:] + z_vals[:-1])
		bin_hi = torch.cat([mids, z_vals[-1:]], dim=-1)
		bin_lo = torch.cat([z_vals[:1], mids], dim=-1)
		jitter = torch.rand([n_samples], device=z_vals.device)
		z_vals = bin_lo + (bin_hi - bin_lo) * jitter

	# Share the depths across every ray
	ray_batch_shape = list(rays_o.shape[:-1])
	z_vals = z_vals.expand(ray_batch_shape + [n_samples])

	# Walk along each ray: origin + direction * depth
	pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
	return pts, z_vals


In [None]:
def sample_pdf(
	bins: torch.Tensor,
	weights: torch.Tensor,
	n_samples: int,
	perturb: bool = False
) -> torch.Tensor:
	r"""
	Draw `n_samples` values per ray by inverse-transform sampling.

	`weights` are normalised into a PDF, accumulated into a CDF, and the CDF is
	inverted at `n_samples` quantiles: evenly spaced ones when `perturb` is
	unset, uniformly random ones otherwise. `bins` provides the value attached
	to every CDF knot, so its last dimension must match
	`weights.shape[-1] + 1`.

	Returns:
		Tensor of shape [n_rays, n_samples].
	"""

	# PDF from the weights; the small offset keeps all-zero rows well defined.
	padded = weights + 1e-5
	pdf = padded / torch.sum(padded, -1, keepdim=True) # [n_rays, weights.shape[-1]]

	# CDF with a leading zero so that it has one knot per bin value.
	running = torch.cumsum(pdf, dim=-1)
	cdf = torch.cat([torch.zeros_like(running[..., :1]), running], dim=-1) # [n_rays, weights.shape[-1] + 1]

	# Quantiles at which the CDF is inverted.
	quantile_shape = list(cdf.shape[:-1]) + [n_samples]
	if perturb:
		u = torch.rand(quantile_shape, device=cdf.device) # [n_rays, n_samples]
	else:
		u = torch.linspace(0., 1., n_samples, device=cdf.device).expand(quantile_shape) # [n_rays, n_samples]
	u = u.contiguous() # searchsorted wants contiguous input

	# Locate the CDF segment that contains each quantile.
	inds = torch.searchsorted(cdf, u, right=True) # [n_rays, n_samples]
	last_knot = cdf.shape[-1] - 1
	below = (inds - 1).clamp(min=0)
	above = inds.clamp(max=last_knot)
	inds_g = torch.stack([below, above], dim=-1) # [n_rays, n_samples, 2]

	# Look up CDF values and bin values at both ends of every segment.
	matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
	cdf_g = cdf.unsqueeze(-2).expand(matched_shape).gather(-1, inds_g)
	bins_g = bins.unsqueeze(-2).expand(matched_shape).gather(-1, inds_g)
	cdf_lo, cdf_hi = cdf_g[..., 0], cdf_g[..., 1]
	bin_lo, bin_hi = bins_g[..., 0], bins_g[..., 1]

	# Interpolate linearly inside the segment; degenerate segments use a unit
	# denominator instead of dividing by (almost) zero.
	denom = cdf_hi - cdf_lo
	denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
	t = (u - cdf_lo) / denom

	return bin_lo + t * (bin_hi - bin_lo) # [n_rays, n_samples]

def sample_hierarchical(
	rays_o: torch.Tensor,
	rays_d: torch.Tensor,
	z_vals: torch.Tensor,
	weights: torch.Tensor,
	n_samples: int,
	perturb: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
	r"""
	Add `n_samples` extra depths per ray where the coarse weights are large.

	Returns:
		pts: Points at the merged, sorted depths,
			shape [N_rays, N_samples + n_samples, 3].
		z_vals_combined: The merged, sorted depths.
		new_z_samples: Only the newly drawn depths (detached from the graph).
	"""

	# The midpoints between coarse depths act as bin values; the first and last
	# weights are left out so that there is one more bin value than weights.
	z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
	inner_weights = weights[..., 1:-1]
	new_z_samples = sample_pdf(z_vals_mid, inner_weights, n_samples,
													perturb=perturb).detach()

	# Merge old and new depths and evaluate the rays at the sorted union.
	merged = torch.cat([z_vals, new_z_samples], dim=-1)
	z_vals_combined = torch.sort(merged, dim=-1).values
	pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]  # [N_rays, N_samples + n_samples, 3]
	return pts, z_vals_combined, new_z_samples

def cumprod_exclusive(
	tensor: torch.Tensor
) -> torch.Tensor:
	r"""
	Exclusive cumulative product along the last dimension.

	Element `i` of the result is the product of elements `0 .. i-1` of the
	input, so the first element is always 1. This mirrors
	`tf.math.cumprod(..., exclusive=True)`.
	(Courtesy of https://github.com/krrish94/nerf-pytorch)

	Args:
	tensor (torch.Tensor): Input whose exclusive cumprod along dim=-1 is wanted.
	Returns:
	cumprod (torch.Tensor): New tensor of the same shape; the input is not modified.
	"""

	# Inclusive running product ...
	inclusive = tensor.cumprod(dim=-1)
	# ... shifted one slot to the right (the last value wraps to the front) ...
	shifted = inclusive.roll(shifts=1, dims=-1)
	# ... and the wrapped-around front value replaced by the empty product.
	shifted[..., 0] = 1.

	return shifted


# Encodings

## Frequency encoder

In [None]:
class PositionalEncoder(nn.Module):
	r"""
	Sine/cosine frequency encoding of input coordinates.

	The output concatenates the raw input with sin(x * f) and cos(x * f) for
	each of `n_freqs` frequencies, giving `d_input * (1 + 2 * n_freqs)`
	channels (available as `d_output`).
	"""
	def __init__(
		self,
		d_input: int,
		n_freqs: int,
		log_space: bool = False
	):
		super().__init__()
		self.d_input = d_input
		self.n_freqs = n_freqs
		self.log_space = log_space
		self.d_output = d_input * (1 + 2 * self.n_freqs)

		# Frequencies are either powers of two (log space) or evenly spaced
		# between 2^0 and 2^(n_freqs - 1).
		if self.log_space:
			freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)
		else:
			freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)

		# Identity first, then a (sin, cos) pair per frequency. Default
		# arguments bind the current loop values to each closure.
		self.embed_fns = [lambda x: x]
		for freq in freq_bands:
			for trig in (torch.sin, torch.cos):
				self.embed_fns.append(lambda x, freq=freq, trig=trig: trig(x * freq))

	def forward(
		self,
		x
	) -> torch.Tensor:
		r"""
		Encode `x`; all embeddings are concatenated along the last dimension.
		"""
		encoded = [embed(x) for embed in self.embed_fns]
		return torch.cat(encoded, dim=-1)


## Hash encoding

In [None]:
class Grid(nn.Module):
    """
    Multi-resolution grid encoder with learnable per-vertex features.

    Every level owns a feature table. Levels whose dense grid fits into
    `2 ** hashtable_power` entries are indexed densely; larger levels fall back
    to a prime-based spatial hash. A query point is normalised with the scene
    bounds, and on every level the features of the surrounding cell corners are
    blended by (bi/tri)linear interpolation. The per-level results are
    concatenated, giving `d_output = num_lvl * feature_dim` channels.

    Args:
        feature_dim: Number of feature channels stored per grid vertex.
        grid_dim: Dimensionality of the grid (2 or 3; the hash has three primes).
        num_lvl: Number of resolution levels.
        max_res: Resolution of the finest level.
        min_res: Resolution of the coarsest level.
        hashtable_power: log2 of the maximum number of entries per level.
        device: Device on which the tables and helper tensors are created.
        bounds_min: Lower corner of the region that is mapped onto the grid.
        bounds_max: Upper corner of that region. Points outside the bounds are
            clipped to the border cells.
    """
    def __init__(self,
                 feature_dim: int,
                 grid_dim: int,
                 num_lvl: int,
                 max_res: int,
                 min_res: int,
                 hashtable_power: int,
                 device='cpu',
                 bounds_min=(-3.5552, -2.1935, -2.8307),
                 bounds_max=(1.8859, 3.1777, 2.0705)
                 ):
        super().__init__()

        # Device used for everything this module allocates
        self.device = device

        self.feature_dim = feature_dim  # Channels per grid vertex
        self.grid_dim = grid_dim  # Dimensionality of the grid (e.g., 2D, 3D)
        self.num_lvl = num_lvl  # Number of levels in the grid hierarchy
        self.max_res = max_res  # Resolution of the finest level
        self.min_res = min_res  # Resolution of the coarsest level
        self.hashtable_power = hashtable_power  # Table size cap is 2^hashtable_power

        # forward() concatenates `feature_dim` channels per level.
        self.d_output = num_lvl * feature_dim

        # Constants for hash calculations
        self.prime = [3367900313, 2654435761, 805459861]  # One multiplier per axis
        self.max_entry = 2 ** self.hashtable_power  # Maximum number of entries per table

        # Geometric growth factor between consecutive levels. A single level has
        # no growth; guarding it also avoids a division by zero.
        if self.num_lvl > 1:
            self.factor_b = np.exp((np.log(self.max_res) - np.log(self.min_res)) / (self.num_lvl - 1))
        else:
            self.factor_b = 1.0

        # Resolutions are stored as plain ints so that all index arithmetic
        # stays in exact integer math.
        self.resolutions = [int(np.floor(self.min_res * (self.factor_b ** i)))
                            for i in range(self.num_lvl)]

        # One learnable table per level, capped at `max_entry` rows
        self.hashtable = nn.ParameterList([])
        for res in self.resolutions:
            total_res = res ** self.grid_dim  # Vertices of the dense grid at this level
            table_size = min(total_res, self.max_entry)
            # Small random initial features
            table = torch.randn(int(table_size), self.feature_dim, device=self.device) * 0.001
            self.hashtable.append(nn.Parameter(table))

        # Scene bounds used to normalise inputs. Non-persistent buffers follow
        # the module across devices without changing the state_dict layout.
        self.register_buffer(
            'bounds_min',
            torch.as_tensor(bounds_min, dtype=torch.float32).to(self.device),
            persistent=False)
        self.register_buffer(
            'bounds_max',
            torch.as_tensor(bounds_max, dtype=torch.float32).to(self.device),
            persistent=False)

        # Corner offsets of a unit cell, enumerated with the first axis varying
        # fastest. get_corner() and interpolation_weights() both use this table,
        # so corner k and weight k always refer to the same vertex.
        offsets = [[(k >> axis) & 1 for axis in range(self.grid_dim)]
                   for k in range(2 ** self.grid_dim)]
        self.register_buffer(
            'corner_offsets',
            torch.tensor(offsets, dtype=torch.long, device=self.device),
            persistent=False)

    def forward(self, x):
        """
        Encode points `x` of shape (N, grid_dim) into (N, num_lvl * feature_dim).
        """
        # Map the scene bounds onto [0, 1]
        x = (x - self.bounds_min) / (self.bounds_max - self.bounds_min)

        out_feature = []
        for lvl in range(self.num_lvl):
            res = self.resolutions[lvl]
            table = self.hashtable[lvl]

            # Continuous grid coordinates and the lower corner of their cell
            coord = self.to_hash_space(x, res)
            floor_corner = torch.floor(coord)

            # Integer coordinates of all cell corners: (N, 2^grid_dim, grid_dim)
            corners = self.get_corner(floor_corner).to(torch.long)
            # Table row of every corner: (N, 2^grid_dim)
            feature_index = self.hash(corners, table.shape[0], res).to(torch.long)
            corner_feature = table[feature_index]  # (N, 2^grid_dim, feature_dim)

            # Blend the corner features; the weight is broadcast over channels
            weights = self.interpolation_weights(coord - floor_corner)
            out_feature.append((corner_feature * weights.unsqueeze(-1)).sum(-2))
        return torch.cat(out_feature, -1)  # Concatenate features from all levels

    def to_hash_space(self, x, resolution):
        """
        Scale normalised coordinates to grid coordinates of the given resolution.

        The result is clipped slightly below `resolution - 1` so that the upper
        corner of the containing cell is always a valid grid vertex.
        """
        return torch.clip(x * (resolution - 1), 0, resolution - 1.0001)

    def interpolation_weights(self, diff):
        """
        Multilinear interpolation weights for fractional cell offsets `diff`.

        Returns a tensor of shape diff.shape[:-1] + (2^grid_dim,), ordered like
        the corners returned by get_corner(). Per axis a corner contributes the
        fractional offset if it lies on the upper side and one minus the offset
        otherwise; the weight is the product over the axes.
        """
        offsets = self.corner_offsets.to(device=diff.device, dtype=diff.dtype)
        frac = diff.unsqueeze(-2)
        per_axis = offsets * frac + (1 - offsets) * (1 - frac)
        return per_axis.prod(-1)

    def alt_weights(self, corner, coord):
        """
        Alternative weights based on the Euclidean distance to each corner,
        L1-normalised per point. Not used by forward().
        """
        diag_length = torch.full_like(coord[:, 0], 2. ** (1 / 2))  # Reference length (sqrt(2))
        w = torch.empty(corner.shape[0], corner.shape[1], device=coord.device)
        for c in range(corner.shape[1]):
            dist = torch.norm(corner[:, c, :] - coord, dim=1)  # Distance to corner c
            w[:, c] = diag_length - dist  # Closer corners get larger weights
        normed_w = torch.nn.functional.normalize(w, p=1)  # Normalize the weights
        return normed_w

    def hash(self, x, num_entry, res):
        """
        Map integer grid coordinates `x` (last dim = grid_dim) to table rows.

        Tables smaller than `max_entry` hold the dense grid and are addressed
        with a row-major style linear index; full-size tables use an XOR of the
        per-axis coordinates multiplied by primes, modulo the table size.
        """
        if num_entry != self.max_entry:
            # Dense grid: exact integer index x0 + x1 * res + x2 * res^2 ...
            res = int(res)
            index = 0
            for i in range(self.grid_dim):
                index = index + x[..., i] * res ** i
            return index
        else:
            # Hashed grid
            _sum = 0
            for i in range(self.grid_dim):
                _sum = _sum ^ (x[..., i] * self.prime[i])
            return _sum % num_entry  # Keep the row within the table

    def get_corner(self, floor_corner):
        """
        Return all corners of the cells whose lower corner is `floor_corner`.

        Output shape is floor_corner.shape[:-1] + (2^grid_dim, grid_dim); the
        corners are ordered like the weights from interpolation_weights().
        """
        offsets = self.corner_offsets.to(device=floor_corner.device, dtype=floor_corner.dtype)
        return floor_corner.unsqueeze(-2) + offsets


# Model and Forward pass

In [None]:
class NeRF(nn.Module):
	r"""
	Radiance field MLP.

	A stack of `n_layers` ReLU-activated linear layers; after every layer whose
	index is in `skip`, the network input is concatenated to the activations.
	Without view directions a single head outputs 4 values per point. With
	`d_viewdirs` set, one head outputs 1 value from the trunk features and a
	view-dependent branch outputs 3 values; `forward` returns them as
	[3 branch values, 1 trunk value].
	"""
	def __init__(
		self,
		d_input: int = 3,
		n_layers: int = 8,
		d_filter: int = 256,
		skip: Tuple[int] = (4,),
		d_viewdirs: Optional[int] = None
	):
		super().__init__()
		self.d_input = d_input
		self.skip = skip
		self.act = nn.functional.relu
		self.d_viewdirs = d_viewdirs

		# Trunk: a layer that follows a skip index also receives the raw input.
		trunk = [nn.Linear(self.d_input, d_filter)]
		for i in range(n_layers - 1):
			in_features = d_filter + self.d_input if i in skip else d_filter
			trunk.append(nn.Linear(in_features, d_filter))
		self.layers = nn.ModuleList(trunk)

		# Output heads
		if self.d_viewdirs is None:
			# No view directions: one head produces all 4 outputs
			self.output = nn.Linear(d_filter, 4)
		else:
			# View directions: separate alpha head and view-dependent RGB branch
			self.alpha_out = nn.Linear(d_filter, 1)
			self.rgb_filters = nn.Linear(d_filter, d_filter)
			self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2)
			self.output = nn.Linear(d_filter // 2, 3)

	def forward(
		self,
		x: torch.Tensor,
		viewdirs: Optional[torch.Tensor] = None
	) -> torch.Tensor:
		r"""
		Evaluate the network on points `x`, optionally with view directions.

		Raises:
			ValueError: If `viewdirs` is given although the module was built
				with `d_viewdirs=None`.
		"""

		uses_viewdirs = self.d_viewdirs is not None
		if not uses_viewdirs and viewdirs is not None:
			raise ValueError('Cannot input x_direction if d_viewdirs was not given.')

		# Trunk with skip connections back to the input
		x_input = x
		for i, layer in enumerate(self.layers):
			x = self.act(layer(x))
			if i in self.skip:
				x = torch.cat([x, x_input], dim=-1)

		if not uses_viewdirs:
			# Single head
			return self.output(x)

		# Alpha comes straight from the trunk features ...
		alpha = self.alpha_out(x)

		# ... while RGB additionally sees the view direction.
		features = self.rgb_filters(x)
		features = torch.cat([features, viewdirs], dim=-1)
		rgb = self.output(self.act(self.branch(features)))

		return torch.cat([rgb, alpha], dim=-1)

In [None]:
class DNeRF(nn.Module):
	r"""Dynamic neural radiance field: a deformation MLP followed by a NeRF.

	The deformation network maps (point, time) inputs to a 3D offset. The
	offset is added to the first three input columns, the displaced points are
	passed through `encode`, and the result is fed to the wrapped `NeRF`.
	NOTE(review): the wrapped NeRF is built with `d_input` inputs, so `encode`
	presumably has to return `d_input` channels — confirm against the caller.
	"""
	def __init__(self,
		d_input: int = 3,
		n_layers: int = 8,
		d_filter: int = 256,
		skip: Tuple[int] = (4,),
		d_viewdirs: Optional[int] = None,

		d_time: int = 1,
		encode: Callable = None,
		zero_canonical: bool = True
	):
		super().__init__()
		# Settings shared by both sub-networks
		self.d_input = d_input
		self.n_layers = n_layers
		self.d_filter = d_filter
		self.skip = skip
		self.act = nn.ReLU()
		# Settings of the time deformation network
		self.d_time = d_time
		self.encode = encode
		self.zero_canonical = zero_canonical

		# Static radiance field
		self.nerf_model = NeRF(d_input, n_layers, d_filter, skip, d_viewdirs)
		# Time deformation network
		self.deformation_model = self._create()


	def _create(self):
		r"""Build the deformation MLP: `n_layers` hidden layers plus a 3-channel head."""
		input_dimention = self.d_input+self.d_time

		# Hidden layers whose index is in `skip` also take the network input.
		widths_in = [input_dimention] + [
			self.d_filter+input_dimention if i in self.skip else self.d_filter
			for i in range(1, self.n_layers)
		]
		layers = [nn.Linear(width, self.d_filter) for width in widths_in]
		layers.append(nn.Linear(self.d_filter, 3))

		return nn.ModuleList(layers)

	def _query_time(self, x: torch.Tensor) -> torch.Tensor:
		r"""Run the deformation MLP; every layer but the last is followed by ReLU."""
		last = len(self.deformation_model)-1
		hidden = x
		for i, layer in enumerate(self.deformation_model):
			if i in self.skip:
				hidden = torch.cat([x, hidden], -1)

			hidden = layer(hidden)

			if i != last:
				hidden = self.act(hidden)
		return hidden

	def forward(
		self,
		x: torch.Tensor,
		timesteps: torch.Tensor,
		viewdirs: Optional[torch.Tensor] = None,
	) -> torch.Tensor:
		r"""
		Deform the points according to their timestep, then query the NeRF.

		The first three columns of `x` are treated as the point position and the
		first column of `timesteps` as the time value.
		"""

		debug('DNERF INPUTS: '+str(x.shape) +' ; '+ str(timesteps.shape))
		dx = self._query_time(torch.cat([x, timesteps], -1))
		if self.zero_canonical:
			# Rows at time zero stay undeformed.
			cond = timesteps[:,0] == 0
			dx[cond] = 0.0

		points_dx = self.encode(x[:,:3] + dx)

		return self.nerf_model(points_dx, viewdirs)

In [None]:
def get_chunks(
	inputs: torch.Tensor,
	chunksize: int = 2**15
) -> List[torch.Tensor]:
	r"""
	Split `inputs` along its first dimension into consecutive pieces of at most
	`chunksize` rows. The last piece may be shorter; an empty input yields an
	empty list.
	"""
	chunks = []
	for start in range(0, inputs.shape[0], chunksize):
		chunks.append(inputs[start:start + chunksize])
	return chunks

def prepare_chunks(
	points: torch.Tensor,
	encoding_function: Callable[[torch.Tensor], torch.Tensor],
	chunksize: int = 2**15
) -> List[torch.Tensor]:
	r"""
	Flatten `points` to rows of 3 coordinates, encode them, and split the
	encoded rows into chunks of at most `chunksize`.
	"""
	flat_points = points.reshape((-1, 3))
	encoded = encoding_function(flat_points)
	return get_chunks(encoded, chunksize=chunksize)

def prepare_viewdirs_chunks(
	points: torch.Tensor,
	rays_d: torch.Tensor,
	encoding_function: Callable[[torch.Tensor], torch.Tensor],
	chunksize: int = 2**15
) -> List[torch.Tensor]:
	r"""
	Produce one encoded unit view direction per sampled point, in chunks.

	Each ray direction is normalised, repeated for every sample of its ray
	(so that it lines up with `points` flattened to rows), encoded and split
	into chunks of at most `chunksize`.
	"""
	# Unit-length direction per ray
	unit_dirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
	# One copy per sample along the ray, flattened to rows
	per_point = unit_dirs[:, None, ...].expand(points.shape).reshape((-1, 3))
	return get_chunks(encoding_function(per_point), chunksize=chunksize)

def prepare_time(
	timesteps, 
	time_encoding_fn: Callable[[torch.Tensor], torch.Tensor], 
	chunksize: int = 2**15
) -> List[torch.Tensor]:
	r"""
	Encode the per-point timesteps and split them into chunks of at most
	`chunksize` rows.
	"""
	encoded_times = time_encoding_fn(timesteps)
	return get_chunks(encoded_times, chunksize=chunksize)
    
def _expand_timesteps(
	timesteps: torch.Tensor,
	query_points: torch.Tensor
) -> torch.Tensor:
	r"""
	Return one timestep per sampled point as a column of shape [n_points, 1].

	A single timestep is shared by all points. Otherwise one timestep per ray is
	expected and repeated for every sample of that ray, matching the row order
	of `query_points.reshape(-1, 3)`.
	"""
	flat_t = timesteps.reshape(-1)
	if flat_t.shape[0] == 1: # only one image and one timestep
		n_points = query_points.reshape(-1, 3).shape[0]
		return flat_t.expand(n_points).reshape(-1, 1)
	# One timestep per ray. Forcing a column first makes `repeat` duplicate each
	# ray's timestep along the sample dimension for both [n_rays] and
	# [n_rays, 1] inputs; a 1-D input would otherwise be tiled as a whole and
	# end up attached to the wrong points.
	n_samples_per_ray = query_points.shape[1]
	return flat_t.reshape(-1, 1).repeat(1, n_samples_per_ray).reshape(-1, 1)

def _query_model(
	model: nn.Module,
	query_points: torch.Tensor,
	rays_d: torch.Tensor,
	timesteps: torch.Tensor,
	encoding_fn: Callable[[torch.Tensor], torch.Tensor],
	viewdirs_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]],
	time_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]],
	chunksize: int
) -> Tuple[torch.Tensor, int]:
	r"""
	Evaluate `model` chunk by chunk on all query points.

	Returns the raw predictions reshaped to [n_rays, n_samples, channels] and
	moved to the CPU, together with the number of chunks that were evaluated.
	"""
	# Prepare batches.
	batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize)
	if viewdirs_encoding_fn is not None:
		batches_viewdirs = prepare_viewdirs_chunks(
			query_points, rays_d,
			viewdirs_encoding_fn,
			chunksize=chunksize)
	else:
		batches_viewdirs = [None] * len(batches)

	# Time. Without a time encoder the raw timesteps are passed through.
	expanded_timesteps = _expand_timesteps(timesteps, query_points)
	if time_encoding_fn is not None:
		batches_times = prepare_time(expanded_timesteps, time_encoding_fn, chunksize)
	else:
		batches_times = get_chunks(expanded_timesteps, chunksize=chunksize)

	# Run the model on all chunks and concatenate the results (chunking avoids
	# out-of-memory issues).
	predictions = [
		model(batch, viewdirs=batch_viewdirs, timesteps=batch_time)
		for batch, batch_viewdirs, batch_time in zip(batches, batches_viewdirs, batches_times)
	]
	raw = torch.cat(predictions, dim=0)
	raw = raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]])
	# Volume rendering downstream runs on the CPU.
	# NOTE(review): presumably done to save GPU memory — confirm.
	return raw.to('cpu'), len(predictions)

def nerf_forward(
	rays_o: torch.Tensor,
	rays_d: torch.Tensor,
	timesteps: torch.Tensor,
	near: float,
	far: float,
	encoding_fn: Callable[[torch.Tensor], torch.Tensor],
	coarse_model: nn.Module,
	kwargs_sample_stratified: dict = None,
	n_samples_hierarchical: int = 0,
	kwargs_sample_hierarchical: dict = None,
	fine_model = None,
	viewdirs_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
	time_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
	chunksize: int = 2**15
) -> dict:
	r"""
	Compute forward pass through model(s).

	Rays are sampled with stratified sampling and rendered with the coarse
	model. If `n_samples_hierarchical > 0`, additional samples are drawn from
	the coarse weights and the union of samples is rendered with `fine_model`
	(or with `coarse_model` when no fine model is given).

	Args:
		rays_o, rays_d: Ray origins and directions, one row per ray.
		timesteps: Either a single timestep shared by all rays or one
			timestep per ray (shape [n_rays] or [n_rays, 1]).
		near, far: Integration bounds along each ray.
		encoding_fn: Encoder applied to the sampled points.
		viewdirs_encoding_fn: Optional encoder for view directions; when None
			the models receive `viewdirs=None`.
		time_encoding_fn: Optional encoder for timesteps; when None the raw
			timesteps are used.
		chunksize: Maximum number of points per model call.
	Returns:
		Dict with 'rgb_map', 'depth_map', 'acc_map', 'weights' and
		'z_vals_stratified'; with hierarchical sampling also
		'z_vals_hierarchical', 'rgb_map_0', 'depth_map_0' and 'acc_map_0'
		(the coarse results).
	"""
	# Set no kwargs if none are given.
	if kwargs_sample_stratified is None:
		kwargs_sample_stratified = {}
	if kwargs_sample_hierarchical is None:
		kwargs_sample_hierarchical = {}

	# Sample query points along each ray.
	query_points, z_vals = sample_stratified(
			rays_o, rays_d, near, far, **kwargs_sample_stratified)

	###
	# Coarse model pass.
	raw, n_chunks = _query_model(
		coarse_model, query_points, rays_d, timesteps,
		encoding_fn, viewdirs_encoding_fn, time_encoding_fn, chunksize)
	debug('Predictions length: '+str(n_chunks))

	# Perform differentiable volume rendering to re-synthesize the RGB image.
	rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d)

	outputs = {
			'z_vals_stratified': z_vals
	}

	###
	# Fine model pass.
	debug('Fine model pass\n')
	if n_samples_hierarchical > 0:
		# Save previous outputs to return.
		rgb_map_0, depth_map_0, acc_map_0 = rgb_map, depth_map, acc_map

		# Apply hierarchical sampling for fine query points.
		query_points, z_vals_combined, z_hierarch = sample_hierarchical(
			rays_o, rays_d, z_vals, weights, n_samples_hierarchical,
			**kwargs_sample_hierarchical)

		# Forward pass new samples through fine model.
		fine_model = fine_model if fine_model is not None else coarse_model
		raw, _ = _query_model(
			fine_model, query_points, rays_d, timesteps,
			encoding_fn, viewdirs_encoding_fn, time_encoding_fn, chunksize)

		# Perform differentiable volume rendering to re-synthesize the RGB image.
		rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals_combined, rays_d)

		# Store outputs.
		outputs['z_vals_hierarchical'] = z_hierarch
		outputs['rgb_map_0'] = rgb_map_0
		outputs['depth_map_0'] = depth_map_0
		outputs['acc_map_0'] = acc_map_0

	# Store outputs.
	outputs['rgb_map'] = rgb_map
	outputs['depth_map'] = depth_map
	outputs['acc_map'] = acc_map
	outputs['weights'] = weights
	return outputs

# Raw data to outputs

In [None]:
def raw2outputs(
	raw: torch.Tensor,
	z_vals: torch.Tensor,
	rays_d: torch.Tensor,
	raw_noise_std: float = 0.0,
	white_bkgd: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
	r"""
	Convert the raw NeRF output into RGB and other maps.

	Args:
		raw: Model output of shape [n_rays, n_samples, 4]; the first three
			channels are colour logits, the fourth the raw density.
		z_vals: Sample depths along each ray, [n_rays, n_samples].
		rays_d: Ray directions, [n_rays, 3].
		raw_noise_std: Standard deviation of Gaussian noise added to the raw
			density (0 disables it).
		white_bkgd: If set, composite the colour onto a white background.
	Returns:
		rgb_map [n_rays, 3], depth_map [n_rays], acc_map [n_rays] and
		weights [n_rays, n_samples], all on the device of `raw`.
	"""
	device = raw.device
	# Difference between consecutive elements of `z_vals`. [n_rays, n_samples]
	# The last sample gets a very large distance so that it can absorb all
	# remaining light.
	dists = z_vals[..., 1:] - z_vals[..., :-1]
	dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1)

	# Multiply each distance by the norm of its corresponding direction ray
	# to convert to real world distance (accounts for non-unit directions).
	dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

	# Add noise to model's predictions for density. Can be used to
	# regularize network during training (prevents floater artifacts).
	# `randn_like` creates the noise with the device and dtype of `raw`.
	noise = 0.
	if raw_noise_std > 0.:
		noise = torch.randn_like(raw[..., 3]) * raw_noise_std

	# Predict density of each sample along each ray. Higher values imply
	# higher likelihood of being absorbed at this point. [n_rays, n_samples]
	alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists.to(device))

	# Compute weight for RGB of each sample along each ray. [n_rays, n_samples]
	# The higher the alpha, the lower subsequent weights are driven.
	weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)

	# Compute weighted RGB map.
	rgb = torch.sigmoid(raw[..., :3])  # [n_rays, n_samples, 3]
	rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # [n_rays, 3]

	# Estimated depth map is predicted distance.
	depth_map = torch.sum(weights * z_vals.to(device), dim=-1)

	# Sum of weights along each ray. In [0, 1] up to numerical error.
	acc_map = torch.sum(weights, dim=-1)

	# To composite onto a white background, use the accumulated alpha map.
	if white_bkgd:
		rgb_map = rgb_map + (1. - acc_map[..., None])

	return rgb_map, depth_map, acc_map, weights


# Evaluation tools

In [None]:
def plot_samples(
    z_vals: torch.Tensor,
    z_hierarch: Optional[torch.Tensor] = None,
    ax: Optional[np.ndarray] = None):
    r"""
    Draw the sample positions of one ray on a single axis.

    Stratified samples are drawn in blue at height 1; hierarchical samples,
    when given, are drawn in red at height 0.

    Args:
        z_vals: Positions of the stratified samples along the ray.
        z_hierarch: Optional positions of the hierarchical samples.
        ax: Axis to draw on. When None, a new one is taken from `plt.subplot()`.

    Returns:
        The axis that was drawn on.
    """
    # Stratified samples sit on the line y = 1.
    stratified_heights = np.zeros_like(z_vals) + 1

    target_ax = plt.subplot() if ax is None else ax
    target_ax.plot(z_vals, stratified_heights, 'b-o')

    # Hierarchical samples (if any) sit on the line y = 0.
    if z_hierarch is not None:
        target_ax.plot(z_hierarch, np.zeros_like(z_hierarch), 'r-o')

    # Cosmetics: fixed vertical range, hidden y axis, grid on.
    target_ax.set_ylim([-1, 2])
    target_ax.set_title('Stratified  Samples (blue) and Hierarchical Samples (red)')
    target_ax.axes.yaxis.set_visible(False)
    target_ax.grid(True)
    return target_ax


In [None]:
def render(i, rgb_predicted, iternums, outputs,train_psnrs,val_psnrs, is_save=False):
    r"""
    Show a four-panel progress figure and optionally save the images.

    Panels, left to right: the predicted image, the target image (`testimg`),
    the train/validation PSNR curves, and the sample positions of the middle
    ray of the batch.

    Args:
        i: Current iteration, shown in the first panel's title.
        rgb_predicted: Predicted colours for every pixel of `testimg`; it is
            reshaped here to `[height, width, 3]` of `testimg`.
        iternums: Iterations at which validation ran (x values of the blue curve).
        outputs: Dict holding 'z_vals_stratified' and, optionally,
            'z_vals_hierarchical'.
        train_psnrs: One training PSNR per iteration (red curve).
        val_psnrs: One validation PSNR per entry of `iternums` (blue curve).
        is_save: When True, also write 'predicted_image.png' and 'GT_image.png'.
    """
    # Take the image size from `testimg` itself. The module-level
    # `height`/`width` are reassigned by the training loop for every (possibly
    # center-cropped) training image, so they need not match the prediction.
    img_height, img_width = testimg.shape[:2]
    predicted_np = rgb_predicted.reshape([img_height, img_width, 3]).detach().cpu().numpy()
    target_np = testimg.detach().cpu().numpy()

    _, ax = plt.subplots(1, 4, figsize=(24,4), gridspec_kw={'width_ratios': [1, 1, 1, 3]})

    # Predicted vs. target image.
    ax[0].imshow(predicted_np)
    ax[0].set_title(f'Iteration: {i}')
    ax[1].imshow(target_np)
    ax[1].set_title(f'Target')

    # PSNR curves. The x range follows the list length so the plot cannot
    # fail on a length mismatch with `i`.
    ax[2].plot(range(len(train_psnrs)), train_psnrs, 'r')
    ax[2].plot(iternums, val_psnrs, 'b')
    ax[2].set_title('PSNR (train=red, val=blue)')

    # Sample positions along the ray in the middle of the batch.
    z_vals_strat = outputs['z_vals_stratified'].view((-1, n_samples))
    z_sample_strat = z_vals_strat[z_vals_strat.shape[0] // 2].detach().cpu().numpy()
    if 'z_vals_hierarchical' in outputs:
        z_vals_hierarch = outputs['z_vals_hierarchical'].view((-1, n_samples_hierarchical))
        z_sample_hierarch = z_vals_hierarch[z_vals_hierarch.shape[0] // 2].detach().cpu().numpy()
    else:
        z_sample_hierarch = None
    _ = plot_samples(z_sample_strat, z_sample_hierarch, ax=ax[3])
    ax[3].margins(0)
    plt.show()

    if is_save:
        # Clip before the uint8 cast so slightly out-of-range values do not wrap around.
        predicted_u8 = (np.clip(predicted_np, 0., 1.) * 255).astype(np.uint8)
        Image.fromarray(predicted_u8).save('predicted_image.png')

        target_u8 = (np.clip(target_np, 0., 1.) * 255).astype(np.uint8)
        Image.fromarray(target_u8).save('GT_image.png')


def eval_one_time(i, iternums, train_psnrs, val_psnrs, model, encode, fine_model, encode_viewdirs, is_save=False):
    r"""
    Render the held-out test view, record its PSNR and plot the result.

    Args:
        i: Current iteration; appended to `iternums`.
        iternums: List that receives `i` (mutated in place).
        train_psnrs: Training PSNR history, forwarded to `render`.
        val_psnrs: List that receives the validation PSNR (mutated in place).
        model: Coarse network.
        encode: Encoding function for the sample points.
        fine_model: Fine network, or None.
        encode_viewdirs: Encoding function for view directions, or None.
        is_save: Forwarded to `render`; when True the images are written to disk.
    """
    # Switch every network to eval mode for the rendering pass and remember
    # the previous mode so training can continue unaffected afterwards.
    networks = [net for net in (model, fine_model) if net is not None]
    previous_modes = [net.training for net in networks]
    for net in networks:
        net.eval()

    try:
        # No gradients are needed for validation; skipping the autograd graph
        # saves memory on the full-resolution render.
        with torch.no_grad():
            height, width = testimg.shape[:2]
            rays_o, rays_d = get_rays(height, width, focal, testpose)
            rays_o = rays_o.reshape([-1, 3]).to(device)
            rays_d = rays_d.reshape([-1, 3]).to(device)
            # NOTE(review): the training loop calls `nerf_forward` with
            # `timesteps` as third positional argument and `time_encoding_fn`;
            # this call passes neither. Confirm against `nerf_forward`'s
            # signature and supply the test view's timestep if required.
            outputs = nerf_forward(rays_o, rays_d,
                                   near, far, encode, model,
                                   kwargs_sample_stratified=kwargs_sample_stratified,
                                   n_samples_hierarchical=n_samples_hierarchical,
                                   kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                                   fine_model=fine_model,
                                   viewdirs_encoding_fn=encode_viewdirs,
                                   chunksize=chunksize)

            rgb_predicted = outputs['rgb_map']
            testimg_flat = testimg.reshape(-1, 3)
            loss = torch.nn.functional.mse_loss(rgb_predicted, testimg_flat.to(device))
            val_psnr = -10. * torch.log10(loss)
    finally:
        for net, was_training in zip(networks, previous_modes):
            net.train(was_training)

    val_psnrs.append(val_psnr.item())
    iternums.append(i)

    print("Val Loss: ", loss.item(), " Val PSNR: ", val_psnr.item())
    # Plot example outputs outside
    render(i, rgb_predicted, iternums, outputs,train_psnrs,val_psnrs, is_save)



In [None]:
# Evaluation rendering
def pose_spherical(theta, phi, radius):
    r"""
    Build a 4x4 camera-to-world matrix from spherical coordinates.

    The pose is composed, right to left, from a translation along z by
    `radius`, a rotation about the x axis by `phi`, a rotation about the
    y axis by `theta`, and a final matrix that negates x and swaps y and z.

    Args:
        theta: Angle of the y-axis rotation, in degrees.
        phi: Angle of the x-axis rotation, in degrees.
        radius: Translation along the z axis.

    Returns:
        A float32 tensor of shape [4, 4].
    """
    def _translate_z(distance):
        return torch.Tensor([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, distance],
            [0, 0, 0, 1]]).float()

    def _rotate_about_x(angle):
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        return torch.Tensor([
            [1, 0, 0, 0],
            [0, cos_a, -sin_a, 0],
            [0, sin_a, cos_a, 0],
            [0, 0, 0, 1]]).float()

    def _rotate_about_y(angle):
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        return torch.Tensor([
            [cos_a, 0, -sin_a, 0],
            [0, 1, 0, 0],
            [sin_a, 0, cos_a, 0],
            [0, 0, 0, 1]]).float()

    # Negates the x axis and exchanges the y and z axes.
    axis_swap = torch.Tensor(np.array([
        [-1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]]))

    phi_rad = phi/180.*np.pi
    theta_rad = theta/180.*np.pi

    # Keep the products strictly right-to-left so the floating-point result
    # matches step-by-step composition.
    pose = _rotate_about_x(phi_rad) @ _translate_z(radius)
    pose = _rotate_about_y(theta_rad) @ pose
    return axis_swap @ pose


def save_training_progress(
    model,fine_model,encode,encode_viewdirs, chunksize,
    near, far, height, width, focal, epoch, device
):
    r"""
    Render one frame of a 40-view orbit and save it as a PNG in `test_save_dir`.

    The view is chosen by `epoch % 40` from 40 evenly spaced angles in
    [-180, 180), all at phi = -30 and radius 4. The file is named after
    `epoch`, zero-padded to five digits.

    Args:
        model: Coarse network.
        fine_model: Fine network, or None.
        encode: Encoding function for the sample points.
        encode_viewdirs: Encoding function for view directions, or None.
        chunksize: Forwarded to `nerf_forward`.
        near, far: Ray bounds forwarded to `nerf_forward`.
        height, width: Size of the rendered frame in pixels.
        focal: Focal length forwarded to `get_rays`.
        epoch: Frame index; selects the view and names the output file.
        device: Device the rays are moved to.
    """
    n_views = 40
    # Only the pose that is actually rendered is built.
    orbit_angles = np.linspace(-180, 180, n_views + 1)[:-1]
    pose = pose_spherical(orbit_angles[epoch % n_views], -30.0, 4.0)

    # Rendering a progress frame needs no gradients.
    with torch.no_grad():
        rays_o, rays_d = get_rays(height, width, focal, pose)
        rays_o = rays_o.reshape([-1, 3]).to(device)
        rays_d = rays_d.reshape([-1, 3]).to(device)
        # NOTE(review): the training loop calls `nerf_forward` with
        # `timesteps` as third positional argument and `time_encoding_fn`;
        # this call passes neither. Confirm against `nerf_forward`'s signature.
        outputs = nerf_forward(rays_o, rays_d,
                               near, far, encode, model,
                               kwargs_sample_stratified=kwargs_sample_stratified,
                               n_samples_hierarchical=n_samples_hierarchical,
                               kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                               fine_model=fine_model,
                               viewdirs_encoding_fn=encode_viewdirs,
                               chunksize=chunksize)

    frame = outputs['rgb_map'].reshape([height, width, 3]).detach().cpu().numpy()
    # Clip before the uint8 cast so slightly out-of-range values do not wrap around.
    frame = (np.clip(frame, 0., 1.) * 255).astype(np.uint8)

    # Treat `test_save_dir` as a directory, matching the "{folder}/*.png"
    # pattern `make_gif` uses to collect the frames, and create it on demand.
    if test_save_dir:
        os.makedirs(test_save_dir, exist_ok=True)
    Image.fromarray(frame).save(os.path.join(test_save_dir, f'{epoch:05d}.png'))

# Training preparation 

In [None]:
class EarlyStopping:
    r"""
    Track the best fitness seen so far and signal when it has stalled.

    An iteration counts as an improvement only when the fitness exceeds the
    best value by more than `margin`.
    """
    def __init__(
        self,
        patience: int = 30,
        margin: float = 1e-4
    ):
        # A falsy patience (0 or None) disables stopping altogether.
        self.patience = patience if patience else float('inf')
        self.margin = margin
        self.best_iter = 0
        self.best_fitness = 0.0  # In our case PSNR

    def __call__(
        self,
        iter: int,
        fitness: float
    ):
        r"""
        Update the best fitness and report whether training should stop.

        Returns True once `patience` or more iterations have passed since
        the last improvement.
        """
        gain = fitness - self.best_fitness
        if gain > self.margin:
            self.best_fitness = fitness
            self.best_iter = iter
        iters_since_best = iter - self.best_iter
        return iters_since_best >= self.patience

def crop_center(
    img: torch.Tensor,
    frac: float = 0.5
) -> torch.Tensor:
    r"""
    Crop the central region of an image.

    A margin of `round(size * frac / 2)` pixels is removed from each side of
    the first two dimensions, so `frac` is the fraction of each dimension
    that is cut away (the default 0.5 keeps the middle half).

    Args:
        img: Image tensor whose first two dimensions are cropped.
        frac: Fraction of each of the first two dimensions to remove.

    Returns:
        The cropped view of `img`. When the margin rounds to zero the full
        image is returned.
    """
    height, width = img.shape[0], img.shape[1]
    h_offset = round(height * (frac / 2))
    w_offset = round(width * (frac / 2))
    # Use explicit end indices: a negative-index slice `[0:-0]` would be
    # empty when the offset is zero.
    return img[h_offset:height - h_offset, w_offset:width - w_offset]



In [None]:
# Encoders
# Build the point encoder selected by `encoder_used`. Failing here with a
# clear message is preferable to a NameError on `encoder` further down.
if encoder_used is Encoders.FREQ:
    encoder = PositionalEncoder(d_input, n_freqs, log_space=log_space)
elif encoder_used is Encoders.HASH:
    encoder = Grid(
        feature_dim=3,
        grid_dim=3,
        num_lvl=17,
        max_res=2**14,
        min_res=16,
        hashtable_power=19,
        device=device,
    ).to(device)
else:
    raise NotImplementedError(
        f'No point encoder is configured for {encoder_used}; '
        'use Encoders.FREQ or Encoders.HASH.')

# The encoder is handed to the networks wrapped in a lambda.
# NOTE(review): presumably this keeps it from being registered as a submodule
# of the networks, since its parameters get their own optimizer group below -
# confirm before replacing the lambda with the module itself.
encode = lambda x: encoder(x)

# View direction encoders
if use_viewdirs:
    encoder_viewdirs = PositionalEncoder(
        d_input, n_freqs_views, log_space=log_space)
    encode_viewdirs = lambda x: encoder_viewdirs(x)
    d_viewdirs = encoder_viewdirs.d_output
else:
    encode_viewdirs = None
    d_viewdirs = None

# Time encoder
encoder_time = PositionalEncoder(1, n_freqs_time, log_space=log_space)
encode_time = lambda x: encoder_time(x)
d_timesteps = encoder_time.d_output

print(f'Encode input points into {encoder.d_output} dimentions!')

# Models
model = DNeRF(
    d_input=encoder.d_output,
    n_layers=n_layers,
    d_filter=d_filter,
    skip=skip,
    d_viewdirs=d_viewdirs,
    d_time=d_timesteps,
    encode=encode,
    zero_canonical=zero_canonical)
model.to(device)

if use_fine_model:
    fine_model = DNeRF(
        d_input=encoder.d_output,
        n_layers=n_layers_fine,
        d_filter=d_filter_fine,
        skip=skip_fine,
        d_viewdirs=d_viewdirs,
        d_time=d_timesteps,
        encode=encode,
        zero_canonical=zero_canonical)
    fine_model.to(device)
else:
    fine_model = None

# Optimizer
# Parameter groups are assembled conditionally so that disabling the fine
# model (`fine_model is None`) does not crash on `fine_model.parameters()`.
param_groups = [{'params': model.parameters(), 'lr': lr}]
if fine_model is not None:
    param_groups.append({'params': fine_model.parameters(), 'lr': lr})
if encoder_used is Encoders.HASH:
    # The hash grid is trained with its own learning rate.
    param_groups.append({'params': encoder.parameters(), 'lr': lr_hash_enc})
optimizer = torch.optim.Adam(param_groups)

# Scheduler
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=scheduler_start_end_factors[0],
    end_factor=scheduler_start_end_factors[1],
    total_iters=n_epochs)

# Early Stopping
warmup_stopper = EarlyStopping(patience=50)

# Training

In [None]:
# Main optimisation loop: one randomly chosen training image per iteration.
model.train()
if fine_model is not None:
    fine_model.train()

train_psnrs = []
val_psnrs = []
iternums = []

pbar = tqdm(range(n_epochs), desc="Training Epochs")
for i in pbar:
    ###
    ### Train one time
    ###
    # One image per step

    # Randomly pick an image as the target.
    # Target Image
    target_img_idx = np.random.randint(images.shape[0])
    target_img = images[target_img_idx].to(device)
    # During the first `center_crop_iters` iterations only the image centre is used.
    if center_crop and i < center_crop_iters:
        target_img = crop_center(target_img)
    height, width = target_img.shape[:2]
    target_img = target_img.reshape([-1, 3])
    # Pose
    target_pose = poses[target_img_idx].to(device)
    # Rays
    rays_o, rays_d = get_rays(height, width, focal, target_pose)
    rays_o = rays_o.reshape([-1, 3]).to(device)
    rays_d = rays_d.reshape([-1, 3]).to(device)
    # Handle Time
    timesteps = times[target_img_idx].reshape(1).to(device)

    # Run one iteration of TinyNeRF and get the rendered RGB image.
    # (This statement is indented with spaces like the rest of the loop body;
    # a tab here makes Python 3 reject the cell with a TabError.)
    outputs = nerf_forward(
        rays_o, rays_d, timesteps,
        near, far, encode, model,
        kwargs_sample_stratified=kwargs_sample_stratified,
        n_samples_hierarchical=n_samples_hierarchical,
        kwargs_sample_hierarchical=kwargs_sample_hierarchical,
        fine_model=fine_model,
        viewdirs_encoding_fn=encode_viewdirs,
        time_encoding_fn=encode_time,
        chunksize=chunksize)

    # Check for any numerical issues.
    for k, v in outputs.items():
        if torch.isnan(v).any():
            print(f"! [Numerical Alert] {k} contains NaN.")
        if torch.isinf(v).any():
            print(f"! [Numerical Alert] {k} contains Inf.")

    # Backprop!
    rgb_predicted = outputs['rgb_map']

    loss = torch.nn.functional.mse_loss(
        rgb_predicted.to(device), target_img.to(device))
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    psnr = -10. * torch.log10(loss)
    psnr = psnr.item()

    ###
    # Evaluation and logging

    epoch = i
    pbar.set_description(f"Epoch {epoch + 1}/{n_epochs}, PSNR: {psnr:.2f}, Loss: {loss.item():.5f}")

    train_psnrs.append(psnr)

    # Evaluate testimg at given display rate.
    if i % display_rate == 0:
        eval_one_time(i, iternums, train_psnrs, val_psnrs, model, encode, fine_model, encode_viewdirs)

        save_training_progress(
            model,fine_model,encode,encode_viewdirs, chunksize,
            near, far, height, width, focal, epoch//display_rate, device
        )

        # `eval_one_time` calls `model.eval()`; make sure the networks are
        # back in training mode before the next optimisation step.
        model.train()
        if fine_model is not None:
            fine_model.train()


In [None]:
import glob
from PIL import Image, ImageDraw, ImageFont

def add_frame_number(image, frame_number,fontsize=12):
    """
    Stamp `frame_number` in white near the bottom-right corner of `image`.

    The image is drawn on in place and also returned. The text is rendered
    with PIL's default font; `fontsize` is only used as the assumed text
    height when computing the vertical position.
    """
    margin = 10
    canvas = ImageDraw.Draw(image)
    default_font = ImageFont.load_default()  # You can choose a different font if you prefer
    label = str(frame_number)

    # Anchor the label so that it ends `margin` pixels before the right and bottom edges.
    label_width = canvas.textlength(label, font=default_font)
    left = image.width - label_width - margin
    top = image.height - fontsize - margin

    canvas.text((left, top), label, font=default_font, fill="white")
    return image

def make_gif(frame_folder, save_path="./my_awesome.gif"):
    """
    Combine every PNG in `frame_folder` into a looping animated GIF.

    Frames are taken in sorted filename order, stamped with their index via
    `add_frame_number`, and shown for 100 ms each.

    Args:
        frame_folder: Directory that contains the `*.png` frames.
        save_path: Destination of the GIF file.

    Raises:
        FileNotFoundError: If `frame_folder` contains no PNG files.
    """
    im_paths = sorted(glob.glob(f"{frame_folder}/*.png"))
    if not im_paths:
        raise FileNotFoundError(f"No PNG frames found in {frame_folder!r}")

    # Open each frame and add its frame number.
    frames = [add_frame_number(Image.open(path), index)
              for index, path in enumerate(im_paths)]

    # Save gif. Only the frames after the first are appended; passing the
    # whole list would repeat the first frame at the start of the animation.
    first_frame, *remaining_frames = frames
    first_frame.save(save_path, format="GIF", append_images=remaining_frames,
                     optimize=True, save_all=True, duration=100, loop=0)


# Turn the saved progress frames into an animation named after the encoder in use.
make_gif(
    frame_folder=test_save_dir,
    save_path=f'./xyz{encoder_used.name}_2RGBA_animation.gif',
)


In [None]:
# Final evaluation pass that also writes the predicted and ground-truth
# images to disk (used for the paper figure).
eval_one_time(
    i, iternums, train_psnrs, val_psnrs,
    model, encode, fine_model, encode_viewdirs,
    is_save=True,
)

In [None]:

# Persist the trained weights under ./results, creating the folder if needed.
os.makedirs('./results', exist_ok=True)
torch.save(model.state_dict(), f'./results/nerf_{encoder_used.name}.tensor')
# The fine network is optional (`use_fine_model`), so only save it when it exists.
if fine_model is not None:
    torch.save(fine_model.state_dict(), f'./results/nerf_fine_{encoder_used.name}.tensor')
# The hash grid is optimised through its own parameter group, separate from
# the networks, so its weights are stored in a file of their own.
# NOTE(review): assumes `Grid` is an nn.Module exposing `state_dict()` - confirm.
if encoder_used is Encoders.HASH:
    torch.save(encoder.state_dict(), f'./results/encoder_{encoder_used.name}.tensor')


In [None]:
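# To reload previously saved weights, uncomment the lines below.
# The paths include the './results/' folder that the weights are saved to.
# encoder_used_name = 'HASH'
# model.load_state_dict(torch.load(f'./results/nerf_{encoder_used_name}.tensor'))
# fine_model.load_state_dict(torch.load(f'./results/nerf_fine_{encoder_used_name}.tensor'))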