In [1]:
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This example demonstrates scene optimization with the PyTorch3D
pulsar interface. For this, a reference image has been pre-generated
(in the PyTorch3D repository it lives at
`../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`;
this notebook loads a copy from `./data/examples_TestRenderer_test_smallopt.png`).
The scene is initialized with random spheres, and gradient-based
optimization is used to converge towards a faithful scene representation.
"""
import logging
import math

import cv2
import imageio
import numpy as np
import torch

# Import `look_at_view_transform` as needed in the suggestion later in the
# example.
from pytorch3d.renderer.cameras import PerspectiveCameras  # , look_at_view_transform
from pytorch3d.renderer.points import (
    PointsRasterizationSettings,
    PointsRasterizer,
    PulsarPointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from torch import nn, optim

In [4]:
LOGGER = logging.getLogger(__name__)

# Problem size: number of spheres, and output image resolution in pixels.
N_POINTS = 10000
WIDTH, HEIGHT = 1000, 1000
# All tensors of the example are created on the GPU.
DEVICE = torch.device("cuda")

class SceneModel(nn.Module):
    """
    A simple scene model to demonstrate use of pulsar in PyTorch modules.

    The scene model is parameterized with sphere locations (vert_pos),
    channel content (vert_col) and radiuses (vert_rad). The camera is fixed:
    it is described by the PyTorch3D `PerspectiveCameras` object stored in
    `self.cameras` (identity rotation, zero translation) and is not optimized.

    The forward method of the model renders this scene description. Any
    of these parameters could instead be passed as inputs to the forward
    method and come from a different model.

    Note: the `cam_params` buffer is not read by `forward`; it is kept only so
    that the module's state_dict layout stays unchanged.
    """

    def __init__(self):
        super().__init__()
        # Gamma value handed to the pulsar renderer in `forward`.
        self.gamma = 1.0
        # Points. Seeding makes the random initialization reproducible; note
        # that this reseeds torch's global RNG as a side effect.
        torch.manual_seed(1)
        # x and y in [-5, 5), z in [25, 35): inside the znear/zfar range
        # (1.0 .. 45.0) used for rendering in `forward`.
        vert_pos = torch.rand(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 10.0
        vert_pos[:, 2] += 25.0
        vert_pos[:, :2] -= 5.0
        self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
        # Every sphere starts mid-grey.
        self.register_parameter(
            "vert_col",
            nn.Parameter(
                torch.ones(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 0.5,
                requires_grad=True,
            ),
        )
        # Created on DEVICE like the other parameters so the model is usable
        # even without an explicit `.to(DEVICE)`.
        self.register_parameter(
            "vert_rad",
            nn.Parameter(
                torch.ones(N_POINTS, dtype=torch.float32, device=DEVICE) * 0.3,
                requires_grad=True,
            ),
        )
        # Unused by `forward` (see class docstring); kept for state_dict compatibility.
        self.register_buffer(
            "cam_params",
            torch.tensor(
                [0.0, 0.0, 0.0, 0.0, math.pi, 0.0, 5.0, 2.0], dtype=torch.float32
            ),
        )
        self.cameras = PerspectiveCameras(
            # The focal length must be double the size for PyTorch3D because of the NDC
            # coordinates spanning a range of two - and they must be normalized by the
            # sensor width (see the pulsar example). This means we need here
            # 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar.
            focal_length=5.0,
            R=torch.eye(3, dtype=torch.float32, device=DEVICE)[None, ...],
            T=torch.zeros((1, 3), dtype=torch.float32, device=DEVICE),
            # PyTorch3D expects (height, width).
            image_size=((HEIGHT, WIDTH),),
            device=DEVICE,
        )
        raster_settings = PointsRasterizationSettings(
            # PyTorch3D expects (height, width).
            image_size=(HEIGHT, WIDTH),
            # The Parameter object itself is stored, so optimizer updates to
            # `vert_rad` are seen by the renderer without rebuilding it.
            radius=self.vert_rad,
        )
        rasterizer = PointsRasterizer(
            cameras=self.cameras, raster_settings=raster_settings
        )
        self.renderer = PulsarPointsRenderer(rasterizer=rasterizer, n_track=32)

    def forward(self):
        """
        Render the current scene description with pulsar.

        Returns:
            The first (and only) image of the rendered batch. Callers treat it
            as a channels-last (H, W, 3) tensor — TODO confirm against the
            PulsarPointsRenderer documentation.
        """
        # The Pointclouds object creates copies of its arguments - that's why
        # we have to create a new object in every forward step.
        pcl = Pointclouds(
            points=self.vert_pos[None, ...], features=self.vert_col[None, ...]
        )
        return self.renderer(
            pcl,
            gamma=(self.gamma,),
            zfar=(45.0,),
            znear=(1.0,),
            # Treat `vert_rad` as world-space radii rather than pixel sizes.
            radius_world=True,
            # White background.
            bg_col=torch.ones((3,), dtype=torch.float32, device=DEVICE),
        )[0]


def _load_reference(ref_path: str) -> torch.Tensor:
    """
    Load the reference image as a float32 tensor on DEVICE, scaled by 1/255.

    Assumes an 8-bit image with at least three channels. The image is mirrored
    horizontally, presumably to match the horizontal axis convention of the
    PyTorch3D renderer (TODO confirm). Channels beyond the first three (e.g. an
    alpha channel) are dropped so the reference broadcasts against the
    3-channel render.
    """
    image = imageio.imread(ref_path)[:, ::-1, :3]
    # `.copy()` makes the negatively-strided view contiguous for torch.
    return (torch.from_numpy(image.copy()).to(torch.float32) / 255.0).to(DEVICE)


def _show_progress(result: torch.Tensor, ref: torch.Tensor, step: int) -> None:
    """Show the current render and a 50/50 blend with the reference in OpenCV windows."""
    result_np = result.detach().cpu().numpy()
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    result_im = (np.clip(result_np, 0.0, 1.0) * 255).astype(np.uint8)
    # OpenCV expects BGR channel order, hence the channel flip.
    cv2.imshow("opt", result_im[:, :, ::-1])
    blend = (result.detach() * 0.5 + ref * 0.5).cpu().numpy()
    overlay_img = np.ascontiguousarray(
        (np.clip(blend, 0.0, 1.0) * 255).astype(np.uint8)[:, :, ::-1]
    )
    overlay_img = cv2.putText(
        overlay_img,
        "Step %d" % step,
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 0, 0),
        2,
        cv2.LINE_AA,
        False,
    )
    cv2.imshow("overlay", overlay_img)
    # Give the OpenCV event loop a chance to repaint without blocking.
    cv2.waitKey(1)


def _prune_points(model: nn.Module) -> None:
    """
    Project the scene parameters back to valid values and retire useless spheres.

    Modifies `model` in place:
    - colours are clamped to [0, 1];
    - spheres whose radius dropped below 0.001 are moved to
      (-1000, -1000, -1000), behind the znear plane, and their radius is reset
      to 0.0001;
    - spheres whose colour is within an L1 distance of 0.2 from white (the
      background colour) are moved away in the same way.
    """
    with torch.no_grad():
        model.vert_col.clamp_(0.0, 1.0)
        too_small = model.vert_rad < 0.001
        model.vert_pos[too_small] = -1000.0
        model.vert_rad[too_small] = 0.0001
        dist_to_white = (model.vert_col - 1.0).abs().sum(dim=1)
        model.vert_pos[dist_to_white <= 0.2] = -1000.0


def cli(
    n_steps: int = 500,
    visualize: bool = True,
    ref_path: str = "./data/examples_TestRenderer_test_smallopt.png",
):
    """
    Scene optimization example using pulsar and the unified PyTorch3D interface.

    Args:
        n_steps: Number of SGD steps to run.
        visualize: If True, show the current render and an overlay with the
            reference in OpenCV windows after every step. Set to False on
            headless machines.
        ref_path: Path of the reference image the scene is fitted to.
    """
    LOGGER.info("Loading reference...")
    ref = _load_reference(ref_path)
    model = SceneModel().to(DEVICE)
    # Separate learning rates per parameter group.
    optimizer = optim.SGD(
        [
            {"params": [model.vert_col], "lr": 1e0},
            {"params": [model.vert_rad], "lr": 5e-3},
            {"params": [model.vert_pos], "lr": 1e-2},
        ]
    )
    LOGGER.info("Optimizing...")
    try:
        for i in range(n_steps):
            optimizer.zero_grad()
            result = model()
            if visualize:
                _show_progress(result, ref, i)
            # Sum of squared per-pixel, per-channel differences.
            loss = ((result - ref) ** 2).sum()
            LOGGER.info("loss %d: %f", i, loss.item())
            loss.backward()
            optimizer.step()
            _prune_points(model)
    finally:
        if visualize:
            # Without this the OpenCV windows stay open after the run (or an error).
            cv2.destroyAllWindows()
    LOGGER.info("Done.")


In [5]:
# Emit INFO-level records (including the per-step loss log), then run the example.
logging.basicConfig(level="INFO")
cli()

INFO:__main__:Loading reference...
INFO:__main__:Optimizing...
INFO:__main__:loss 0: 488075.500000
INFO:__main__:loss 1: 409452.125000
INFO:__main__:loss 2: 346267.125000
INFO:__main__:loss 3: 294852.500000
INFO:__main__:loss 4: 252508.687500
INFO:__main__:loss 5: 217286.156250
INFO:__main__:loss 6: 187805.812500
INFO:__main__:loss 7: 162923.531250
INFO:__main__:loss 8: 141823.625000
INFO:__main__:loss 9: 123827.796875
INFO:__main__:loss 10: 108440.523438
INFO:__main__:loss 11: 95148.750000
INFO:__main__:loss 12: 83743.515625
INFO:__main__:loss 13: 73848.468750
INFO:__main__:loss 14: 65293.097656
INFO:__main__:loss 15: 57912.062500
INFO:__main__:loss 16: 51409.648438
INFO:__main__:loss 17: 45727.000000
INFO:__main__:loss 18: 40828.148438
INFO:__main__:loss 19: 36610.023438
INFO:__main__:loss 20: 32953.355469
INFO:__main__:loss 21: 29807.783203
INFO:__main__:loss 22: 27107.919922
INFO:__main__:loss 23: 24652.171875
INFO:__main__:loss 24: 22547.078125
INFO:__main__:loss 25: 20667.951172


INFO:__main__:loss 230: 690.683289
INFO:__main__:loss 231: 687.790771
INFO:__main__:loss 232: 684.717773
INFO:__main__:loss 233: 681.991821
INFO:__main__:loss 234: 679.000488
INFO:__main__:loss 235: 676.400024
INFO:__main__:loss 236: 674.494324
INFO:__main__:loss 237: 670.569580
INFO:__main__:loss 238: 667.609253
INFO:__main__:loss 239: 665.298584
INFO:__main__:loss 240: 662.967529
INFO:__main__:loss 241: 659.449463
INFO:__main__:loss 242: 657.318970
INFO:__main__:loss 243: 653.644287
INFO:__main__:loss 244: 651.926392
INFO:__main__:loss 245: 648.921570
INFO:__main__:loss 246: 646.026672
INFO:__main__:loss 247: 636.805054
INFO:__main__:loss 248: 634.394531
INFO:__main__:loss 249: 630.962341
INFO:__main__:loss 250: 628.255371
INFO:__main__:loss 251: 626.324402
INFO:__main__:loss 252: 622.607178
INFO:__main__:loss 253: 619.860779
INFO:__main__:loss 254: 617.742981
INFO:__main__:loss 255: 614.695618
INFO:__main__:loss 256: 612.342529
INFO:__main__:loss 257: 611.322754
INFO:__main__:loss 2

INFO:__main__:loss 465: 396.448456
INFO:__main__:loss 466: 395.784729
INFO:__main__:loss 467: 396.074951
INFO:__main__:loss 468: 395.710205
INFO:__main__:loss 469: 395.277100
INFO:__main__:loss 470: 395.182892
INFO:__main__:loss 471: 394.314941
INFO:__main__:loss 472: 394.156921
INFO:__main__:loss 473: 393.524506
INFO:__main__:loss 474: 394.229065
INFO:__main__:loss 475: 393.715820
INFO:__main__:loss 476: 393.433472
INFO:__main__:loss 477: 392.815613
INFO:__main__:loss 478: 392.717438
INFO:__main__:loss 479: 392.244629
INFO:__main__:loss 480: 391.991943
INFO:__main__:loss 481: 391.667542
INFO:__main__:loss 482: 390.612122
INFO:__main__:loss 483: 391.378540
INFO:__main__:loss 484: 390.847931
INFO:__main__:loss 485: 390.661285
INFO:__main__:loss 486: 390.580261
INFO:__main__:loss 487: 390.332764
INFO:__main__:loss 488: 389.850189
INFO:__main__:loss 489: 389.406219
INFO:__main__:loss 490: 389.482422
INFO:__main__:loss 491: 389.108398
INFO:__main__:loss 492: 388.757416
INFO:__main__:loss 4