In [1]:
import torch
from torch import nn
import tinycudann as tcnn
import torch.nn.functional as F
import vren
from einops import rearrange
import numpy as np



In [2]:
L = 4
F_ = 2
log2_T = 10
N_min = 4
b = 1.25
rgb_act = "Sigmoid"

xyz_encoder = \
    tcnn.NetworkWithInputEncoding(
        n_input_dims=3, n_output_dims=16,
        encoding_config={
            "otype": "Grid",
            "type": "Hash",
            "n_levels": L,
            "n_features_per_level": F_,
            "log2_hashmap_size": log2_T,
            "base_resolution": N_min,
            "per_level_scale": b,
            "interpolation": "Linear"
        },
        network_config={
            "otype": "FullyFusedMLP",
            "activation": "ReLU",
            "output_activation": "None",
            "n_neurons": 64,
            "n_hidden_layers": 1,
        }
    )

dir_encoder = \
    tcnn.Encoding(
        n_input_dims=3,
        encoding_config={
            "otype": "SphericalHarmonics",
            "degree": 4,
        },
    )

rgb_net = \
    tcnn.Network(
        n_input_dims=32, n_output_dims=3,
        network_config={
            "otype": "FullyFusedMLP",
            "activation": "ReLU",
            "output_activation": rgb_act,
            "n_neurons": 64,
            "n_hidden_layers": 2,
        }
    )

In [3]:
encoding = tcnn.Encoding(3, encoding_config={
            "otype": "Frequency",
            "n_frequencies": 10,
        })
# encoding = tcnn.Encoding(3, encoding_config={
#             "otype": "Grid",
#             "type": "Hash",
#             "n_levels": L,
#             "n_features_per_level": F_,
#             "log2_hashmap_size": log2_T,
#             "base_resolution": N_min,
#             "per_level_scale": b,
#             "interpolation": "Linear"
#         })
network = tcnn.Network(encoding.n_output_dims, 16, network_config={
            "otype": "FullyFusedMLP",
            "activation": "ReLU",
            "output_activation": "None",
            "n_neurons": 64,
            "n_hidden_layers": 1,
        })
model = torch.nn.Sequential(encoding, network)

In [4]:
encoding(torch.tensor([[1.0, 2.0, 3.0]]).cuda())

tensor([[ 0.0000e+00, -1.0000e+00, -1.7881e-07,  1.0000e+00, -5.3644e-07,
          1.0000e+00, -1.3113e-06,  1.0000e+00, -2.8014e-06,  1.0000e+00,
         -5.7817e-06,  1.0000e+00, -1.1802e-05,  1.0000e+00, -2.3782e-05,
          1.0000e+00, -4.7743e-05,  1.0000e+00, -9.5665e-05,  1.0000e+00,
         -1.7881e-07,  1.0000e+00, -5.3644e-07,  1.0000e+00, -1.3113e-06,
          1.0000e+00, -2.8014e-06,  1.0000e+00, -5.7817e-06,  1.0000e+00,
         -1.1802e-05,  1.0000e+00, -2.3782e-05,  1.0000e+00, -4.7743e-05,
          1.0000e+00, -9.5665e-05,  1.0000e+00, -1.9157e-04,  1.0000e+00,
          5.3644e-07, -1.0000e+00, -1.3113e-06,  1.0000e+00, -2.8014e-06,
          1.0000e+00, -5.7817e-06,  1.0000e+00, -1.1802e-05,  1.0000e+00,
         -2.3782e-05,  1.0000e+00, -4.7743e-05,  1.0000e+00, -9.5665e-05,
          1.0000e+00, -1.9157e-04,  1.0000e+00, -3.8338e-04,  1.0000e+00]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>)

In [5]:
encoding.n_output_dims

60

In [6]:
x = torch.randn((32, 3)).cuda()
gamma_x = xyz_encoder(x)
gamma_x.shape

torch.Size([32, 16])

In [7]:
d = torch.randn((32, 3)).cuda()
d_normed = F.normalize(d)
gamma_d = dir_encoder(d_normed)
gamma_d.shape

torch.Size([32, 16])

In [8]:
gamma_d[0]

tensor([  0.2820,   0.7134,  -0.9604,   1.3145,   4.2891,  -3.1348,   3.3418,
         -5.7773,   2.7891,  16.8594, -22.3125,  12.2266, -11.9844,  22.5312,
        -14.5078,   1.3379], device='cuda:0', dtype=torch.float16,
       grad_fn=<SelectBackward0>)

In [9]:
rgb = rgb_net(torch.cat([gamma_x, gamma_d], 1))

In [10]:
rgb.shape

torch.Size([32, 3])

In [11]:
torch.save(xyz_encoder.state_dict(), "xyz_sd.pth")
torch.save(xyz_encoder, "xyz.pth")

In [12]:
torch.load("xyz.pth")

NetworkWithInputEncoding(n_input_dims=3, n_output_dims=16, seed=1337, dtype=torch.float16, hyperparams={'encoding': {'base_resolution': 4, 'hash': 'CoherentPrime', 'interpolation': 'Linear', 'log2_hashmap_size': 10, 'n_features_per_level': 2, 'n_levels': 4, 'otype': 'Grid', 'per_level_scale': 1.25, 'type': 'Hash'}, 'network': {'activation': 'ReLU', 'n_hidden_layers': 1, 'n_neurons': 64, 'otype': 'FullyFusedMLP', 'output_activation': 'None'}, 'otype': 'NetworkWithInputEncoding'})

In [13]:
torch.load("xyz_sd.pth")

OrderedDict([('params',
              tensor([-1.9296e-01, -1.5868e-02, -1.7857e-01,  ...,  8.1820e-05,
                      -7.1032e-05,  8.4406e-05], device='cuda:0'))])

In [14]:
for name, params in encoding.named_parameters():
    print(name)
    print(params.shape)

params
torch.Size([0])


In [15]:
for name, params in network.named_parameters():
    print(name)
    print(params.shape)

params
torch.Size([5120])


In [16]:
for name, params in xyz_encoder.named_parameters():
    print(name)
    print(params.shape)

params
torch.Size([4144])


In [17]:
xyz_enc = nn.Sequential(encoding, network)
xyz_enc

Sequential(
  (0): Encoding(n_input_dims=3, n_output_dims=60, seed=1337, dtype=torch.float16, hyperparams={'n_frequencies': 10, 'otype': 'Frequency'})
  (1): Network(n_input_dims=60, n_output_dims=16, seed=1337, dtype=torch.float16, hyperparams={'encoding': {'offset': 0.0, 'otype': 'Identity', 'scale': 1.0}, 'network': {'activation': 'ReLU', 'n_hidden_layers': 1, 'n_neurons': 64, 'otype': 'FullyFusedMLP', 'output_activation': 'None'}, 'otype': 'NetworkWithInputEncoding'})
)

In [18]:
for name, params in xyz_enc.named_parameters():
    print(name)
    print(params.shape)

0.params
torch.Size([0])
1.params
torch.Size([5120])
