Test Planar Cube data.

In [None]:
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
import einops

from state_encoder_3d.models import (
    LatentNeRF, VolumeRenderer, init_weights_normal
)
from state_encoder_3d.dataset import PlanarCubeDataset

In [None]:
# Run on the first CUDA device when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Make the chosen GPU the current CUDA device as well.
    torch.cuda.set_device(device)
print(f"Using device {device}")

In [None]:
def _cycle_batches(loader):
    """Yield batches from ``loader`` indefinitely, restarting it when exhausted.

    The training loop pulls batches with ``next()`` for a fixed number of
    steps. A plain ``iter(DataLoader)`` is single-pass and would raise
    ``StopIteration`` as soon as a finite dataset has been consumed once.

    Args:
        loader: Any re-iterable object (here a ``DataLoader``).

    Yields:
        Whatever ``loader`` yields, pass after pass.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            # An empty loader would spin forever; end the generator instead
            # so that next() raises StopIteration, as the plain iterator did.
            return


batch_size = 2
dataset = PlanarCubeDataset(
    data_store_path="../data/planar_cube_grid.zarr",
    max_num_instances=1,
    num_views=1,
)
# Still consumed with next(dataloader), but no longer limited to one pass.
dataloader = _cycle_batches(
    torch.utils.data.DataLoader(dataset, batch_size=batch_size)
)

In [None]:
def plot_output_ground_truth(img, depth, gt_img, resolution):
    """Show the model output, the ground truth and the depth map side by side.

    Args:
        img: Tensor with the rendered colors; viewed as ``resolution``.
        depth: Tensor with the rendered depth; viewed as ``resolution[:2]``.
        gt_img: Tensor with the reference colors; viewed as ``resolution``.
        resolution: Target shape for the images. Its first two entries are
            used as the shape of the depth map.
    """
    _, axes = plt.subplots(1, 3, figsize=(18, 6), squeeze=False)

    # (title, tensor, shape to view it as, extra imshow keyword arguments)
    panels = (
        ("Trained MLP", img, resolution, {}),
        ("Ground Truth", gt_img, resolution, {}),
        ("Depth", depth, resolution[:2], {"cmap": 'Greys'}),
    )
    for ax, (title, tensor, shape, imshow_kwargs) in zip(axes[0], panels):
        # Detach from autograd and move to host memory before handing to numpy.
        pixels = tensor.detach().cpu().view(*shape).numpy()
        ax.imshow(pixels, **imshow_kwargs)
        ax.set_title(title)
        ax.set_axis_off()

    plt.show()

In [None]:
# Width of the latent code; passed to LatentNeRF as ``latent_ch``.
latent_dim = 256

# Build the network on the target device, then run the weight initializer
# over it (``to`` and ``apply`` both return the module itself).
nerf = LatentNeRF(latent_ch=latent_dim).to(device).apply(init_weights_normal)

# Near and far based on z_distances in PlanarCubeEnvironment
renderer = VolumeRenderer(
    near=4,
    far=13,
    n_samples=100,
    white_back=False,
).to(device)

In [None]:
# Loss
def img2mse(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    Args:
        x: Tensor of predicted values.
        y: Tensor of target values; must be broadcastable against ``x``.

    Returns:
        Scalar tensor holding the mean of the element-wise squared difference.
    """
    return torch.mean((x - y) ** 2)

# Adam over every parameter of the NeRF, with a fixed learning rate.
lr = 5e-4
optim = torch.optim.Adam(
    nerf.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
)

In [None]:
# Constant latent as we have a single scene: one random code, tiled so that
# every sample of a batch is conditioned on the same vector. It is not handed
# to the optimizer, so it stays fixed during training.
latent = 0.1*torch.rand((1,latent_dim), device=device)
latent = einops.repeat(latent, "b ... -> (repeat b) ...", repeat=batch_size)

num_steps = 10001
steps_til_summary = 100
for step in tqdm(range(num_steps)):
    model_input, gt_image = next(dataloader)
    xy_pix = model_input['x_pix'].to(device)
    intrinsics = model_input['intrinsics'].to(device)
    c2w = model_input['cam2world'].to(device)

    # A DataLoader batches along dim 0 and may deliver a final batch smaller
    # than batch_size. All rows of `latent` are identical, so slicing keeps
    # the latent aligned with however many samples actually arrived.
    batch_latent = latent[:c2w.shape[0]]

    rgb, depth = renderer(c2w, intrinsics, xy_pix, nerf, batch_latent)

    loss = img2mse(rgb, gt_image.to(device))

    optim.zero_grad()
    loss.backward()
    optim.step()

    # Every so often, we want to show what our model has learned.
    # It would be boring otherwise!
    if step % steps_til_summary == 0:
        print(f"Step {step}: loss = {loss.item():.5f}")

        # Only the first sample of the batch is plotted, viewed as 64x64x3.
        plot_output_ground_truth(rgb[0], depth[0], gt_image[0], resolution=(64,64,3))