<a href="https://colab.research.google.com/github/jingz666/NeuF/blob/main/NeuS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.

In [4]:
class Embedder:
  def __init__(self, **kwargs):
    self.kwargs = kwargs
    self.create_embedding_fn()

  def create_embedding_fn(self):
    embed_fns = []
    d = self.kwargs['input_dims']
    out_dim = 0
    if self.kwargs['include_input']:
      embed_fns.append(lambda x:x)
      out_dim += d
    
    max_freq = self.kwargs['max_freq_log2']
    N_freqs = self.kwargs['num_freqs']

    if self.kwargs['log_sampling']:
      freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) # tensor([  1.,   2.,   4.,   8.,  16.,  32.,  64., 128., 256., 512.])
    else:
      freq_bands = torch.linspace(2.**0, 2.**max_freq, N_freqs) # tensor([  1.0000,  57.7778, 114.5556, 171.3333, 228.1111, 284.8889, 341.6667, 398.4445, 455.2222, 512.0000])

    for freq in freq_bands:
      for p_fn in self.kwargs['periodic_fns']:
        embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
        out_dim += d

    self.embed_fns = embed_fns
    self.out_dim = out_dim

    def embed(self, inputs):
      return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

def get_embedder(multires, input_dims=3):
  embed_kwargs = {
      'include_input': True,
      'input_dims': input_dims,
      'max_freq_log2': multires-1,
      'num_freqs': multires,
      'log_sampling': True,
      'periodic_fns': [torch.sin, torch.cos],
  }

  embedder_obj = Embedder(**embed_kwargs)
  def embed(x, eo=embedder_obj): return eo.embed(x)
  return embed, embedder_obj.out_dim 

### This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch

In [5]:
import torch.nn.functional as F
import numpy as np

In [18]:
class NeRF(nn.Module):
  def __init__(self, D=8, W=256, d_in=3, d_in_view=3, multires=0, multires_view=0, output_ch=4, skips=[4], use_viewdirs=False):
    super(NeRF, self).__init__()
    self.D = D
    self.W = W
    self.d_in = d_in
    self.d_in_view = d_in_view
    self.input_ch = 3
    self.input_ch_view = 3
    self.embed_fn = None
    self.embed_fn_view = None

    if multires > 0:
      embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
      self.embed_fn = embed_fn
      self.input_ch = input_ch

    if multires_view > 0:
      embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view)
      self.embed_fn_view = embed_fn_view
      self.input_ch_view = input_ch_view

    self.skips = skips
    self.use_viewdirs = use_viewdirs

#     ModuleList(
#   (0): Linear(in_features=36, out_features=256, bias=True)
#   (1): Linear(in_features=256, out_features=256, bias=True)
#   (2): Linear(in_features=256, out_features=256, bias=True)
#   (3): Linear(in_features=256, out_features=256, bias=True)
#   (4): Linear(in_features=256, out_features=256, bias=True)
#   (5): Linear(in_features=292, out_features=256, bias=True)
#   (6): Linear(in_features=256, out_features=256, bias=True)
#   (7): Linear(in_features=256, out_features=256, bias=True)
# )
    self.pts_linears = nn.ModuleList(
        [nn.Linear(self.input_ch, W)]+
        [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)] # i=0,1,...,7
    )

    self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])

    if use_viewdirs:
      self.feature_linear = nn.Linear(W, W)
      self.alpha_linear = nn.Linear(W, 1)
      self.rgb_linear = nn.Linear(W // 2, 3)
    else:
      self.output_linear = nn.Linear(W, output_ch)
  
  def forward(self, input_pts, input_views):
    if self.embed_fn is not None:
      input_pts = self.embed_fn(input_pts)
    if self.embed_fn_view is not None:
      input_views = self.embed_fn_view(input_views)
    
    h = input_pts
    for i, l in enumerate(self.pts_linears):
      h = self.pts_linears[i](h)
      h = F.relu(h)
      if i in self.skips:
        h = torch.cat([input_pts, h], -1)
    
    if self.use_viewdirs:
      alpha = self.alpha_linear(h)
      feature = self.feature_linear(h)
      h = torch.cat([feature, input_views], -1)

      for i,l in enumerate(self.views_linears):
        h = self.views_linears[i](h)
        h = F.relu(h)

      rgb = self.rgb_linear(h)
      return alpha, rgb
    else:
      assert False

# This implementation is borrowed from IDR: https://github.com/lioryariv/idr
class

In [17]:
nn.ModuleList([nn.Linear(36 + 256, 256 // 2)])

ModuleList(
  (0): Linear(in_features=292, out_features=128, bias=True)
)

In [16]:
print(i) if i not in [4] else nn.Linear(256 + 36, 256) for i in range(8 - 1)]

SyntaxError: ignored