In [1]:
import torch
from model.model import Model

In [2]:
device = torch.device("cpu")

In [3]:
model = Model(
    image_width=640//8,
    image_height=352//8,
    num_time_steps_per_frame=7,
    num_frames=76,
    num_coarse_samples_per_ray=4,
    num_fine_samples_per_ray=4,
    near_depth=0.0,
    far_depth=1.0,
    num_scene_trajectory_basis_coefficients=4,
    num_camera_trajectory_basis_coefficients=4,
    num_voxels_per_axis=8,
    min_bound_per_axis=1.0,
    max_bound_per_axis=10.0,
    voxel_dim=4,
    color_model_hidden_dim=4,
    device=device,
)

In [4]:
sum(p.numel() for p in model.parameters())

1534292

In [5]:
print(model)

Model(
  (model): Render(
    (camera_motion_model): CameraMotionModel(
      (positional_embedding): PositionalEmbedding()
      (linear_layers): ModuleList(
        (0): Linear(in_features=33, out_features=256, bias=True)
        (1-4): 4 x Linear(in_features=256, out_features=256, bias=True)
        (5): Linear(in_features=289, out_features=256, bias=True)
        (6-8): 3 x Linear(in_features=256, out_features=256, bias=True)
        (9): Linear(in_features=289, out_features=256, bias=True)
        (10-12): 3 x Linear(in_features=256, out_features=256, bias=True)
        (13): Linear(in_features=289, out_features=256, bias=True)
        (14-15): 2 x Linear(in_features=256, out_features=256, bias=True)
      )
      (output_layer): Linear(in_features=256, out_features=24, bias=True)
    )
    (render_ray): RenderRay(
      (scene_motion_model): SceneMotionModel(
        (positional_embedding): PositionalEmbedding()
        (linear_layers): ModuleList(
          (0): Linear(in_featur

In [6]:
print(model.total_num_time_steps)

532


In [7]:
sample_output = model(
    torch.tensor([3])
)

In [8]:
print(sample_output)

tensor([[[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.1015, 1.0000,  ..., 0.1015, 1.0000, 1.0000],
           [1.0000, 1.0000, 0.0000,  ..., 1.0000, 0.0000, 0.0000],
           ...,
           [1.0000, 1.0000, 0.0000,  ..., 1.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0103,  ..., 0.0000, 0.0572, 1.0000],
           [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
           [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
           [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.1015, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
           [1.0000, 1.0000

In [11]:
print(sample_output.shape)

torch.Size([1, 7, 3, 80, 44])
