In [None]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
import numpy as np
import skimage
import skimage.transform
from mpl_toolkits.axes_grid1 import make_axes_locatable

import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

from dataset import diff_rendering_dataset
from rendering import VolumeRenderer
from fields import RadianceField
from training import fit_inverse_graphics_representation

In [None]:
# Set device
#
# Prefer the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make cuda:0 the current CUDA device for subsequent CUDA allocations.
    torch.cuda.set_device(device)
print(f"Device: {device}")

In [None]:
# Load data
#
# Reads the per-view camera poses and images stored as .npy files under
# `data_path` and places them directly on `device`.

data_path = "./data/bunny"

# torch.as_tensor(..., device=device) converts straight onto the target device,
# instead of first materialising a CPU tensor and then copying it again with
# `.to(device)`. dtype=float32 preserves what the legacy `torch.Tensor(...)`
# constructor produced for the poses.
cam2world = torch.as_tensor(
    np.load(os.path.join(data_path, "cam2world.npy")),
    dtype=torch.float32,
    device=device,
)
# Images keep whatever dtype is stored in the .npy file (as `torch.tensor` did).
images = torch.as_tensor(
    np.load(os.path.join(data_path, "images.npy")),
    device=device,
)

# 3x3 camera intrinsics matrix.
# NOTE(review): not passed to the dataset or renderer in the cells shown here —
# confirm where (or whether) it is consumed.
intrinsics = torch.tensor([[0.7, 0., 0.5],
                           [0., 0.7, 0.5],
                           [0., 0., 1.]], device=device)
print(cam2world.shape, images.shape)

In [None]:
# Generate dataset
#
# `bunny_dataset` is advanced with `next()`, so it is used as an iterator. The
# sample drawn here is only a visual sanity check; note it is consumed from the
# same iterator that is later passed to training as `data_iter`.

bunny_dataset = diff_rendering_dataset(images, cam2world, device=device)
model_input, gt = next(bunny_dataset)

# Reshape the flat ground truth to (H, W, 3) using the image height/width.
# `reshape` (unlike `view`) also works when `gt` is not contiguous.
gt_image = gt.reshape(images.shape[1], images.shape[2], 3).detach().cpu().numpy()

fig, ax = plt.subplots()
ax.imshow(gt_image)
ax.set_title("Ground-truth sample from bunny_dataset")
ax.axis("off")
plt.show()

In [None]:
# Train
#
# Fit a radiance field to the bunny images with a differentiable volume renderer.

# Seed torch so parameter initialisation (and any other torch RNG use) is
# reproducible across Restart & Run All.
torch.manual_seed(0)

# Derive the render resolution from the loaded data rather than hard-coding
# 128x128, so it always matches the (H, W) of the ground-truth images that
# `bunny_dataset` yields (the same H, W used to preview `gt` above).
img_resolution = (images.shape[1], images.shape[2], 3)

radiance_field = RadianceField(scene_rep_name="HybridVoxelNeuralField", device=device).to(device)
# near/far presumably bound the depth range sampled along each ray and
# n_samples the number of samples per ray — confirm against VolumeRenderer.
renderer = VolumeRenderer(near=1.5, far=4.5, n_samples=128, white_back=True, rand=False).to(device)
fit_inverse_graphics_representation(
    representation=radiance_field,
    renderer=renderer,
    data_iter=bunny_dataset,
    img_resolution=img_resolution,
    lr=1e-3,
    total_steps=2_001
)