In [1]:
from itertools import islice, permutations
from typing import Any, Dict, List

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from nerfacc.estimators.occ_grid import OccGridEstimator
from torch import Tensor
from tqdm import tqdm

from datasets.nerf_synthetic import SubjectLoader
from radiance_fields.ngp import NGPRadianceField
from radiance_fields.ngp_single_mlp import NGPRadianceFieldSingleMlp
from utils import render_image_with_occgrid_test

In [None]:
# dataset parameters
scene = "chair"
# NOTE(review): absolute, machine-specific path; adjust for your environment.
data_root = "/media/data7/fballerini/datasets/nerf_synthetic"
test_dataset_kwargs = {}
# Single device shared by the dataset, the scene bounds, the estimator and the model.
device = "cuda:0"

# scene parameters
aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5], device=device)
near_plane = 0.0

# model parameters
grid_resolution = 128
grid_nlvl = 1
# Radiance-field architecture. It must match the checkpoint loaded later
# (load_state_dict is strict by default). The checkpoint directory name
# "torch_sine_single_1_8_32" presumably encodes encoding_type / activation /
# n_levels / n_features_per_level / base_resolution. TODO confirm.
radiance_field_kwargs = dict(
    use_viewdirs=False,
    base_resolution=32,
    n_levels=1,
    n_features_per_level=8,
    encoding_type="torch",
    mlp_activation="Sine",
    n_neurons=64,
    n_hidden_layers=3,
)

# render parameters (not used in the cells shown here; presumably consumed by rendering code)
render_step_size = 5e-3
alpha_thre = 0.0
cone_angle = 0.0

test_dataset = SubjectLoader(
    subject_id=scene,
    root_fp=data_root,
    split="test",
    num_rays=None,  # passed as None, as in the original setup; semantics defined by SubjectLoader
    device=device,
    **test_dataset_kwargs,
)

estimator = OccGridEstimator(
    roi_aabb=aabb,
    resolution=grid_resolution,
    levels=grid_nlvl,
).to(device)

radiance_field = NGPRadianceFieldSingleMlp(
    # Bounds of the last occupancy-grid level (the only level while grid_nlvl == 1).
    aabb=estimator.aabbs[-1],
    **radiance_field_kwargs,
).to(device)

In [3]:
# Load checkpoint "A" for the selected scene. The filename follows `scene`, so switching
# scenes cannot silently load another scene's weights.
# map_location pins every tensor to the same device the dataset/model use ("cuda:0"),
# independent of the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
ckpt_path = f"/media/data7/fballerini/nerfacc/examples/ckpts/torch_sine_single_1_8_32/{scene}_A.pt"
sd_A = torch.load(ckpt_path, map_location="cuda:0")

In [38]:
# Pull the embedding table of encoding level 0 out of the checkpoint's radiance-field
# state dict and show its shape (one row per table entry, one column per feature).
rf_state = sd_A["radiance_field"]
vox = rf_state["encoding.levels.0.embedding.weight"]
vox.shape

torch.Size([32768, 8])

In [39]:
# Feature vector of an arbitrary table row (567), before any channel permutation.
vox[567, :]

tensor([ 4.2767e-04,  6.2468e-05,  1.8006e-03, -2.1123e-03,  4.2656e-04,
        -8.5651e-04, -4.7835e-04,  6.7730e-04], device='cuda:0')

In [40]:
# Pick the permutation at index `perm_idx` of the 8 feature channels, in the
# lexicographic order produced by itertools.permutations.
# islice stops after perm_idx + 1 items instead of materializing all 8! = 40320 tuples
# just to index one of them.
perm_idx = 7
perm = next(islice(permutations(range(8)), perm_idx, None))
perm

(0, 1, 2, 3, 5, 4, 7, 6)

In [41]:
# Reorder the feature channels (dim=1) of the embedding table according to `perm`.
# Start from the checkpoint tensor rather than the current `vox`. This makes re-running
# the cell idempotent instead of permuting an already-permuted table again.
# index_select returns a new tensor, so sd_A itself is left untouched.
vox_src = sd_A["radiance_field"]["encoding.levels.0.embedding.weight"]
vox = torch.index_select(
    vox_src,
    dim=1,
    index=torch.tensor(perm, dtype=torch.long, device=vox_src.device),
)

In [42]:
# Same row (567) after its channels were reordered by `perm`; compare with the earlier output.
vox[567, :]

tensor([ 4.2767e-04,  6.2468e-05,  1.8006e-03, -2.1123e-03, -8.5651e-04,
         4.2656e-04,  6.7730e-04, -4.7835e-04], device='cuda:0')

In [4]:
# Restore the estimator and the radiance field from checkpoint A (default load_state_dict
# arguments). The channel-permutation cells only rebind `vox` and never modify sd_A.
# This therefore loads the checkpoint's original weights regardless of execution order.
# The displayed value is the radiance-field load result.
estimator_load_result = estimator.load_state_dict(sd_A["estimator"])
radiance_field_load_result = radiance_field.load_state_dict(sd_A["radiance_field"])
radiance_field_load_result

<All keys matched successfully>