Camera position optimization with differentiable rendering using depth image #207

Closed
guanming001 opened this issue May 22, 2020 · 5 comments

@guanming001
Contributor

Hi thanks for creating this amazing library!

I tried to modify the tutorial example
camera_position_optimization_with_differentiable_rendering.ipynb to optimize the camera position using a rendered depth image instead of the silhouette image.

But the results seem odd: the camera position does not converge to the desired location and instead keeps zooming into the object. Did I miss anything?

teapot_optimization_demo

Thank you.

The modified Jupyter notebook is attached below for reference:

camera_position_optimization_with_differentiable_rendering_depth.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-P3OUvJirQdR"
   },
   "outputs": [],
   "source": [
    "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "44lB2sH-rQdW"
   },
   "source": [
    "# Camera position optimization using differentiable rendering\n",
    "\n",
    "In this tutorial we will learn the [x, y, z] position of a camera given a reference image using differentiable rendering. \n",
    "\n",
    "We will first initialize a renderer with a starting position for the camera. We will then use this to generate an image, compute a loss with the reference image, and finally backpropagate through the entire pipeline to update the position of the camera. \n",
    "\n",
    "This tutorial shows how to:\n",
    "- load a mesh from an `.obj` file\n",
    "- initialize a `Camera`, `Shader` and `Renderer`,\n",
    "- render a mesh\n",
    "- set up an optimization loop with a loss function and optimizer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AZGmIlmWrQdX"
   },
   "source": [
    "##  0. Install and import modules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qkX7DiM6rmeM"
   },
   "source": [
    "If `torch`, `torchvision` and `pytorch3d` are not installed, run the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 717
    },
    "colab_type": "code",
    "id": "sEVdNGFwripM",
    "outputId": "27047061-a29b-4562-c164-c1288e24c266"
   },
   "outputs": [],
   "source": [
    "!pip install torch torchvision\n",
    "!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "w9mH5iVprQdZ"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm_notebook\n",
    "import imageio\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage import img_as_ubyte\n",
    "\n",
    "# io utils\n",
    "from pytorch3d.io import load_obj\n",
    "\n",
    "# datastructures\n",
    "from pytorch3d.structures import Meshes, Textures\n",
    "\n",
    "# 3D transformations functions\n",
    "from pytorch3d.transforms import Rotate, Translate\n",
    "\n",
    "# rendering components\n",
    "from pytorch3d.renderer import (\n",
    "    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation, \n",
    "    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,\n",
    "    SoftSilhouetteShader, HardPhongShader, PointLights\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "cpUf2UvirQdc"
   },
   "source": [
    "## 1. Load the Obj\n",
    "\n",
    "We will load an obj file and create a **Meshes** object. **Meshes** is a unique datastructure provided in PyTorch3D for working with **batches of meshes of different sizes**. It has several useful class methods which are used in the rendering pipeline. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8d-oREfkrt_Z"
   },
   "source": [
    "If you are running this notebook locally after cloning the PyTorch3D repository, the mesh will already be available. **If using Google Colab, fetch the mesh and save it at the path `data/`**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "sD5KcLuJr0PL",
    "outputId": "e65061fa-dbd5-4c06-b559-3592632983ee"
   },
   "outputs": [],
   "source": [
    "!mkdir -p data\n",
    "!wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VWiPKnEIrQdd"
   },
   "outputs": [],
   "source": [
    "# Set the cuda device \n",
    "device = torch.device(\"cuda:0\")\n",
    "torch.cuda.set_device(device)\n",
    "\n",
    "# Load the obj and ignore the textures and materials.\n",
    "verts, faces_idx, _ = load_obj(\"./data/teapot.obj\")\n",
    "faces = faces_idx.verts_idx\n",
    "\n",
    "# Initialize each vertex to be white in color.\n",
    "verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)\n",
    "textures = Textures(verts_rgb=verts_rgb.to(device))\n",
    "\n",
    "# Create a Meshes object for the teapot. Here we have only one mesh in the batch.\n",
    "teapot_mesh = Meshes(\n",
    "    verts=[verts.to(device)],   \n",
    "    faces=[faces.to(device)], \n",
    "    textures=textures\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mgtGbQktrQdh"
   },
   "source": [
    "\n",
    "\n",
    "## 2. Optimization setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Q6PzKD_NrQdi"
   },
   "source": [
    "### Create a renderer\n",
    "\n",
    "A **renderer** in PyTorch3D is composed of a **rasterizer** and a **shader** which each have a number of subcomponents such as a **camera** (orthgraphic/perspective). Here we initialize some of these components and use default values for the rest. \n",
    "\n",
    "For optimizing the camera position we will use a renderer which produces a **silhouette** of the object only and does not apply any **lighting** or **shading**. We will also initialize another renderer which applies full **phong shading** and use this for visualizing the outputs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KPlby75GrQdj"
   },
   "outputs": [],
   "source": [
    "# Initialize an OpenGL perspective camera.\n",
    "cameras = OpenGLPerspectiveCameras(device=device)\n",
    "\n",
    "# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of \n",
    "# edges. Refer to blending.py for more details. \n",
    "blend_params = BlendParams(sigma=1e-4, gamma=1e-4)\n",
    "\n",
    "# Define the settings for rasterization and shading. Here we set the output image to be of size\n",
    "# 256x256. To form the blended image we use 100 faces for each pixel. We also set bin_size and max_faces_per_bin to None which ensure that \n",
    "# the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for \n",
    "# explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of \n",
    "# the difference between naive and coarse-to-fine rasterization. \n",
    "raster_settings = RasterizationSettings(\n",
    "    image_size=256, \n",
    "    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, \n",
    "    faces_per_pixel=100, \n",
    ")\n",
    "\n",
    "# Create a silhouette mesh renderer by composing a rasterizer and a shader. \n",
    "silhouette_renderer = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=cameras, \n",
    "        raster_settings=raster_settings\n",
    "    ),\n",
    "    shader=SoftSilhouetteShader(blend_params=blend_params)\n",
    ")\n",
    "\n",
    "\n",
    "# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.\n",
    "raster_settings = RasterizationSettings(\n",
    "    image_size=256, \n",
    "    blur_radius=0.0, \n",
    "    faces_per_pixel=1, \n",
    ")\n",
    "# We can add a point light in front of the object. \n",
    "lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))\n",
    "phong_renderer = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=cameras, \n",
    "        raster_settings=raster_settings\n",
    "    ),\n",
    "    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)\n",
    ")\n",
    "\n",
    "# NEW added to generate depth image \n",
    "# using the same raster settings as the above phong renderer\n",
    "depth_renderer = MeshRasterizer(\n",
    "    cameras=cameras, \n",
    "    raster_settings=raster_settings\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "osOy2OIJrQdn"
   },
   "source": [
    "### Create a reference image\n",
    "\n",
    "We will first position the teapot and generate an image. We use helper functions to rotate the teapot to a desired viewpoint. Then we can use the renderers to produce an image. Here we will use both renderers and visualize the silhouette and full shaded image. \n",
    "\n",
    "The world coordinate system is defined as +Y up, +X left and +Z in. The teapot in world coordinates has the spout pointing to the left. \n",
    "\n",
    "We defined a camera which is positioned on the positive z axis hence sees the spout to the right. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 305
    },
    "colab_type": "code",
    "id": "EjJrW7qerQdo",
    "outputId": "93545b65-269e-4719-f4a2-52cbc6c9c974"
   },
   "outputs": [],
   "source": [
    "# Select the viewpoint using spherical angles  \n",
    "distance = 3   # distance from camera to the object\n",
    "elevation = 50.0   # angle of elevation in degrees\n",
    "azimuth = 0.0  # No rotation so the camera is positioned on the +Z axis. \n",
    "\n",
    "# Get the position of the camera based on the spherical angles\n",
    "R, T = look_at_view_transform(distance, elevation, azimuth, device=device)\n",
    "\n",
    "# Render the teapot providing the values of R and T. \n",
    "silhouete = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)\n",
    "image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)\n",
    "# NEW added to generate depth image \n",
    "depth_ref = depth_renderer(meshes_world=teapot_mesh, R=R, T=T)\n",
    "depth_ref = depth_ref.zbuf\n",
    "\n",
    "silhouete = silhouete.cpu().numpy()\n",
    "image_ref = image_ref.cpu().numpy()\n",
    "# NEW added to generate depth image\n",
    "depth_ref = depth_ref.cpu().numpy()\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(1, 3, 1)\n",
    "plt.imshow(silhouete.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image\n",
    "plt.grid(False)\n",
    "plt.subplot(1, 3, 2)\n",
    "plt.imshow(image_ref.squeeze())\n",
    "plt.subplot(1, 3, 3)\n",
    "plt.imshow(depth_ref.squeeze())\n",
    "plt.grid(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "plBJwEslrQdt"
   },
   "source": [
    "### Set up a basic model \n",
    "\n",
    "Here we create a simple model class and initialize a parameter for the camera position. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YBbP1-EDrQdu"
   },
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(self, meshes, renderer, image_ref):\n",
    "        super().__init__()\n",
    "        self.meshes = meshes\n",
    "        self.device = meshes.device\n",
    "        self.renderer = renderer\n",
    "        \n",
    "        # Get the silhouette of the reference RGB image by finding all the non zero values. \n",
    "        # image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 0).astype(np.float32))\n",
    "        # self.register_buffer('image_ref', image_ref)\n",
    "        \n",
    "        # NEW added to get depth image\n",
    "        image_ref = torch.from_numpy((image_ref).astype(np.float32))\n",
    "        self.register_buffer('image_ref', image_ref)\n",
    "        \n",
    "        # Create an optimizable parameter for the x, y, z position of the camera. \n",
    "        self.camera_position = nn.Parameter(\n",
    "            torch.from_numpy(np.array([3.0,  6.9, +2.5], dtype=np.float32)).to(meshes.device))\n",
    "\n",
    "    def forward(self):\n",
    "        \n",
    "        # Render the image using the updated camera position. Based on the new position of the \n",
    "        # camer we calculate the rotation and translation matrices\n",
    "        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)\n",
    "        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)\n",
    "        \n",
    "        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)\n",
    "        image = image.zbuf\n",
    "        \n",
    "        # Calculate the silhouette loss\n",
    "        # loss = torch.sum((image[..., 3] - self.image_ref) ** 2)\n",
    "        \n",
    "        # NEW Calculate the depth image loss\n",
    "        loss = torch.sum((image - self.image_ref) ** 2)\n",
    "        \n",
    "        return loss, image\n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qCGLSJtfrQdy"
   },
   "source": [
    "## 3. Initialize the model and optimizer\n",
    "\n",
    "Now we can create an instance of the **model** above and set up an **optimizer** for the camera position parameter. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "srZPBU7_rQdz"
   },
   "outputs": [],
   "source": [
    "# We will save images periodically and compose them into a GIF.\n",
    "filename_output = \"./teapot_optimization_demo.gif\"\n",
    "writer = imageio.get_writer(filename_output, mode='I', duration=0.3)\n",
    "\n",
    "# Initialize a model using the renderer, mesh and reference image\n",
    "# model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)\n",
    "\n",
    "# NEW use depth_ref instead of image_ref\n",
    "model = Model(meshes=teapot_mesh, renderer=depth_renderer, image_ref=depth_ref).to(device)\n",
    "\n",
    "# Create an optimizer. Here we are using Adam and we pass in the parameters of the model\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dvTLnrWorQd2"
   },
   "source": [
    "### Visualize the starting position and the reference position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 335
    },
    "colab_type": "code",
    "id": "qyRXpP3mrQd3",
    "outputId": "47ecb12a-e68c-47f5-92fc-821a7a9bd661"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "\n",
    "_, image_init = model()\n",
    "plt.subplot(1, 2, 1)\n",
    "# plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])\n",
    "plt.imshow(image_init.detach().squeeze().cpu().numpy()) # NEW for plotting depth image\n",
    "plt.grid(False)\n",
    "plt.title(\"Starting position\")\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.imshow(model.image_ref.cpu().numpy().squeeze())\n",
    "plt.grid(False)\n",
    "# plt.title(\"Reference silhouette\")\n",
    "plt.title(\"Reference depth\") # NEW for plotting depth image\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aGJu7h-lrQd5"
   },
   "source": [
    "## 4. Run the optimization \n",
    "\n",
    "We run several iterations of the forward and backward pass and save outputs every 10 iterations. When this has finished take a look at `./teapot_optimization_demo.gif` for a cool gif of the optimization process!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "79d7fc84b5564206ab64b2759474da04",
      "02acadb61c3949fcaeab177fd184c388",
      "efd9860908c64bfe9d47118be4734648",
      "f8df7c6efb7d47f5be760a39b4bdbcf8",
      "d8a109658c364a00ab4d298112dac6db",
      "2d05db82cc99482bb3d62b6d4e5b1a98",
      "c621d425e2c8426c8cd4f9136d392af1",
      "3df8063f307040ebb8ff8e2f26ccf729"
     ]
    },
    "colab_type": "code",
    "id": "HvnK5VI5rQd6",
    "outputId": "4019c697-3fc6-4c7b-cdfe-225633cc0d60"
   },
   "outputs": [],
   "source": [
    "loop = tqdm_notebook(range(200))\n",
    "for i in loop:\n",
    "    optimizer.zero_grad()\n",
    "    loss, _ = model()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    loop.set_description('Optimizing (loss %.4f)' % loss.data)\n",
    "    \n",
    "    if loss.item() < 200:\n",
    "        break\n",
    "    \n",
    "    # Save outputs to create a GIF. \n",
    "    if i % 10 == 0:\n",
    "        R = look_at_rotation(model.camera_position[None, :], device=model.device)\n",
    "        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]   # (1, 3)\n",
    "        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)\n",
    "        image = image[0, ..., :3].detach().squeeze().cpu().numpy()\n",
    "        image = img_as_ubyte(image)\n",
    "        writer.append_data(image)\n",
    "        \n",
    "        plt.figure()\n",
    "        plt.imshow(image[..., :3])\n",
    "        plt.title(\"iter: %d, loss: %0.2f\" % (i, loss.data))\n",
    "        plt.grid(\"off\")\n",
    "        plt.axis(\"off\")\n",
    "    \n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mWj80P_SsPTN"
   },
   "source": [
    "## 5. Conclusion \n",
    "\n",
    "In this tutorial we learnt how to **load** a mesh from an obj file, initialize a PyTorch3D datastructure called **Meshes**, set up an **Renderer** consisting of a **Rasterizer** and a **Shader**, set up an optimization loop including a **Model** and a **loss function**, and run  the optimization. "
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "colab": {
   "name": "camera_position_optimization_with_differentiable_rendering.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "02acadb61c3949fcaeab177fd184c388": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2d05db82cc99482bb3d62b6d4e5b1a98": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "3df8063f307040ebb8ff8e2f26ccf729": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "79d7fc84b5564206ab64b2759474da04": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_efd9860908c64bfe9d47118be4734648",
       "IPY_MODEL_f8df7c6efb7d47f5be760a39b4bdbcf8"
      ],
      "layout": "IPY_MODEL_02acadb61c3949fcaeab177fd184c388"
     }
    },
    "c621d425e2c8426c8cd4f9136d392af1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "d8a109658c364a00ab4d298112dac6db": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "efd9860908c64bfe9d47118be4734648": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "IntProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "Optimizing (loss 327.5365)",
      "description_tooltip": null,
      "layout": "IPY_MODEL_2d05db82cc99482bb3d62b6d4e5b1a98",
      "max": 200,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_d8a109658c364a00ab4d298112dac6db",
      "value": 200
     }
    },
    "f8df7c6efb7d47f5be760a39b4bdbcf8": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3df8063f307040ebb8ff8e2f26ccf729",
      "placeholder": "​",
      "style": "IPY_MODEL_c621d425e2c8426c8cd4f9136d392af1",
      "value": "100% 200/200 [00:06&lt;00:00, 29.48it/s]"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
@gkioxari self-assigned this May 22, 2020
@gkioxari added the "how to" (How to use PyTorch3D in my project) label May 22, 2020
@gkioxari
Contributor

gkioxari commented May 22, 2020

@guanming001 your notebook is not readable.

For us to be most effective with resolving issues, it'd be best if you could provide concrete errors from the codebase. We can't be consultants for people's projects or help them make their ideas work. However, if you find that a PyTorch3D op should output something but it outputs something else, then this is something we can help you with and we will debug this and try to fix it.

@guanming001
Contributor Author

Hi @gkioxari, sorry for the inconvenience. I have extracted the code from the raw Jupyter notebook, and the full version of the Python code is attached at the end.

Below is a quick overview of the main changes I made to test camera position optimization with differentiable rendering using a depth image:

I added a depth_renderer (a plain MeshRasterizer) to generate the depth image, using the following raster_settings:

# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)

...

# NEW added to generate depth image 
# using the same raster settings as the above phong renderer
depth_renderer = MeshRasterizer(
    cameras=cameras, 
    raster_settings=raster_settings
)

To extract the reference depth image, I used the .zbuf attribute of the rasterizer output:

# NEW added to generate depth image 
depth_ref = depth_renderer(meshes_world=teapot_mesh, R=R, T=T)
depth_ref = depth_ref.zbuf
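
(For reference, the MeshRasterizer call returns a Fragments object rather than an image; a minimal sanity check of its zbuf field, assuming the depth_ref from the snippet above:)

# The rasterizer output is a Fragments object; .zbuf holds the per-pixel face depths
# and has shape (N, H, W, faces_per_pixel).
print(depth_ref.shape)  # torch.Size([1, 256, 256, 1]) with the raster settings above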

Inside the basic Model class, I changed the reference image to the depth image and set the initial camera position closer to the reference view (hoping for an easier optimization):

class Model(nn.Module):

        ...

        # Get the silhouette of the reference RGB image by finding all the non zero values. 
        # image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 0).astype(np.float32))
        # self.register_buffer('image_ref', image_ref)
        
        # NEW added to get depth image
        depth_ref = torch.from_numpy((depth_ref).astype(np.float32))
        self.register_buffer('depth_ref', depth_ref)
        
        # Create an optimizable parameter for the x, y, z position of the camera. 
        self.camera_position = nn.Parameter(
            # Original starting point
            # torch.from_numpy(np.array([3.0,  6.9, +2.5], dtype=np.float32)).to(meshes.device)) 
            # Set to a starting point closer to the reference depth image
            torch.from_numpy(np.array([0.0114, 2.3306, 2.0206], dtype=np.float32)).to(meshes.device)) 

When the forward method is called, the loss is computed from the difference between the rendered and reference depth images:

    def forward(self):
        
        # Render the image using the updated camera position. Based on the new position of the 
        # camera we calculate the rotation and translation matrices
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        # NEW added to generate depth image 
        image = image.zbuf
        
        # Calculate the silhouette loss
        # loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        
        # NEW Calculate the depth image loss
        loss = torch.sum((image - self.depth_ref) ** 2)
        
        return loss, image

I also lowered the learning rate of the Adam optimizer:

# Create an optimizer. Here we are using Adam and we pass in the parameters of the model
# optimizer = torch.optim.Adam(model.parameters(), lr=0.05) # Original
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Reduced learning rate

Even with the easier starting position and the lower learning rate, the optimization still does not converge correctly: as shown in the .gif, the camera keeps moving closer to the object.

Figure_1

teapot_optimization_demo
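
One thing I noticed while debugging (a minimal check, run before the .cpu().numpy() conversion so that depth_ref is still a torch tensor): pixels not covered by any face are filled with -1 in the zbuf, and those background values also enter the squared-depth loss:

valid = depth_ref > 0                # pixels actually covered by the mesh
print(valid.float().mean().item())   # fraction of pixels with a real depth value
print(depth_ref[~valid].unique())    # tensor([-1.]) -> the background fill value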

Full version of the Python code:

camera_position_optimization_with_differentiable_rendering_depth.py
######################################################################
### Testing camera position opt with diff rendering using depth image 
### Modified from 
### https://github.com/facebookresearch/pytorch3d/blob/master/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb
######################################################################

######################################
###  0. Install and import modules ###
######################################
import os
import torch
import numpy as np
from tqdm import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte

# io utils
from pytorch3d.io import load_obj

# datastructures
from pytorch3d.structures import Meshes, Textures

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

# rendering components
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation, 
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights
)


#######################
### 1. Load the Obj ###
#######################
# !mkdir -p data
# !wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj


#############################
### 2. Optimization setup ###
#############################
# Set the cuda device 
device = torch.device("cuda:0")
torch.cuda.set_device(device)

# Load the obj and ignore the textures and materials.
verts, faces_idx, _ = load_obj("./data/teapot.obj")
faces = faces_idx.verts_idx

# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = Textures(verts_rgb=verts_rgb.to(device))

# Create a Meshes object for the teapot. Here we have only one mesh in the batch.
teapot_mesh = Meshes(
    verts=[verts.to(device)],   
    faces=[faces.to(device)], 
    textures=textures
)

# Initialize an OpenGL perspective camera.
cameras = OpenGLPerspectiveCameras(device=device)

# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of 
# edges. Refer to blending.py for more details. 
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 256x256. To form the blended image we use 100 faces for each pixel. We also set bin_size and max_faces_per_bin to None which ensure that 
# the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for 
# explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of 
# the difference between naive and coarse-to-fine rasterization. 
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    faces_per_pixel=100, 
)

# Create a silhouette mesh renderer by composing a rasterizer and a shader. 
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)


# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)
# We can add a point light in front of the object. 
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)

# NEW added to generate depth image 
# using the same raster settings as the above phong renderer
depth_renderer = MeshRasterizer(
    cameras=cameras, 
    raster_settings=raster_settings
)


################################
### Create a reference image ###
################################
# Select the viewpoint using spherical angles  
distance = 3   # distance from camera to the object
elevation = 50.0   # angle of elevation in degrees
azimuth = 0.0  # No rotation so the camera is positioned on the +Z axis. 

# Get the position of the camera based on the spherical angles
R, T = look_at_view_transform(distance, elevation, azimuth, device=device)

# Render the teapot providing the values of R and T. 
silhouete = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)
image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)
# NEW added to generate depth image 
depth_ref = depth_renderer(meshes_world=teapot_mesh, R=R, T=T)
depth_ref = depth_ref.zbuf

silhouete = silhouete.cpu().numpy()
image_ref = image_ref.cpu().numpy()
# NEW added to generate depth image
depth_ref = depth_ref.cpu().numpy()

plt.figure(figsize=(10, 10))
plt.subplot(1, 3, 1)
plt.imshow(silhouete.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image
plt.title("Silhouete")
plt.grid(False)
plt.subplot(1, 3, 2)
plt.imshow(image_ref.squeeze())
plt.title("Image")
plt.subplot(1, 3, 3)
plt.imshow(depth_ref.squeeze())
plt.title("Depth")
plt.grid(False)
plt.show()


############################
### Set up a basic model ###
############################
class Model(nn.Module):
    # def __init__(self, meshes, renderer, image_ref):
    def __init__(self, meshes, renderer, depth_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer
        
        # Get the silhouette of the reference RGB image by finding all the non zero values. 
        # image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 0).astype(np.float32))
        # self.register_buffer('image_ref', image_ref)
        
        # NEW added to get depth image
        depth_ref = torch.from_numpy((depth_ref).astype(np.float32))
        self.register_buffer('depth_ref', depth_ref)
        
        # Create an optimizable parameter for the x, y, z position of the camera. 
        self.camera_position = nn.Parameter(
            # Original starting point
            # torch.from_numpy(np.array([3.0,  6.9, +2.5], dtype=np.float32)).to(meshes.device)) 
            # Set to a starting point closer to the reference depth image
            torch.from_numpy(np.array([0.0114, 2.3306, 2.0206], dtype=np.float32)).to(meshes.device)) 

    def forward(self):
        
        # Render the image using the updated camera position. Based on the new position of the 
        # camera we calculate the rotation and translation matrices
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        # NEW added to generate depth image 
        image = image.zbuf
        
        # Calculate the silhouette loss
        # loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        
        # NEW Calculate the depth image loss
        loss = torch.sum((image - self.depth_ref) ** 2)
        
        return loss, image
  

#############################################
### 3. Initialize the model and optimizer ###
#############################################
# We will save images periodically and compose them into a GIF.
filename_output = "./teapot_optimization_demo.gif"
writer = imageio.get_writer(filename_output, mode='I', duration=0.3)

# Initialize a model using the renderer, mesh and reference image
# model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)

# NEW use depth_ref instead of image_ref
model = Model(meshes=teapot_mesh, renderer=depth_renderer, depth_ref=depth_ref).to(device)

# Create an optimizer. Here we are using Adam and we pass in the parameters of the model
# optimizer = torch.optim.Adam(model.parameters(), lr=0.05) # Original
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Reduced learning rate


##################################################################
### Visualize the starting position and the reference position ###
##################################################################
plt.figure(figsize=(10, 10))

_, image_init = model()
plt.subplot(1, 2, 1)
# plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])
plt.imshow(image_init.detach().squeeze().cpu().numpy()) # NEW for plotting depth image
plt.grid(False)
plt.title("Starting position")

plt.subplot(1, 2, 2)
plt.imshow(model.depth_ref.cpu().numpy().squeeze())
plt.grid(False)
# plt.title("Reference silhouette")
plt.title("Reference depth") # NEW for plotting depth image
plt.show()


###############################
### 4. Run the optimization ###
###############################
plt.figure()
for i in tqdm(range(500)):
    optimizer.zero_grad()
    loss, depth_current = model()
    loss.backward()
    optimizer.step()
    
    # loop.set_description('Optimizing (loss %.4f)' % loss.data)
    
    if loss.item() < 200:
        break
    
    # Save outputs to create a GIF. 
    if i % 10 == 0:
        R = look_at_rotation(model.camera_position[None, :], device=model.device)
        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        writer.append_data(image)
        
        # plt.imshow(image[..., :3])
        plt.imshow(depth_current.detach().squeeze().cpu().numpy())
        plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
        plt.grid("off")
        plt.axis("off")
        # plt.show()
        plt.pause(0.01)
    
writer.close()

Thank you for your help!

@gkioxari
Contributor

Hi @guanming001
Let me rephrase. Do you believe there is something wrong with a PyTorch3D operator? Is there a bug with a PyTorch3D implementation? These are the issues we are addressing here. From your description, there don't seem to be any issues with PyTorch3D. From what I read, your question is regarding your end-to-end model optimization.

I want to emphasize that we are not here to help you train your network for the task at hand. We provide the nails and hammers, but you are in charge of building your own house.

@guanming001
Contributor Author

Hi @gkioxari, thank you for the reply.

I have managed to get the optimization working by setting the invalid depth values (the negative ones) to an arbitrarily large positive number:

depth_ref = depth_ref.zbuf
depth_ref[depth_ref<0] = 10 # Set invalid depth to a large positive number

...

image = image.zbuf
image[image<0] = 10 # Set invalid depth to a large positive number

For example, going back to the original camera starting position and Adam learning rate (0.05), the optimization now converges most of the time.

Figure_1

teapot_optimization_demo
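
For anyone trying the same thing, here is a minimal sketch of the depth loss with the invalid-depth handling folded in. It uses an out-of-place torch.where instead of the in-place indexing above, and the names loosely follow the Model class from my earlier comment:

import torch
from pytorch3d.renderer import look_at_rotation

def depth_loss(rasterizer, meshes, camera_position, depth_ref, background_depth=10.0):
    # Rotation/translation from the current camera position (same as in the Model class).
    R = look_at_rotation(camera_position[None, :], device=camera_position.device)  # (1, 3, 3)
    T = -torch.bmm(R.transpose(1, 2), camera_position[None, :, None])[:, :, 0]     # (1, 3)

    # Rasterize and take the per-pixel depth (zbuf); background pixels are -1.
    fragments = rasterizer(meshes_world=meshes.clone(), R=R, T=T)
    depth = fragments.zbuf

    # Map the invalid (-1) background depths to a large constant on both images.
    depth = torch.where(depth < 0, torch.full_like(depth, background_depth), depth)
    ref = torch.where(depth_ref < 0, torch.full_like(depth_ref, background_depth), depth_ref)

    return torch.sum((depth - ref) ** 2), depth

The in-place assignment above works just as well; this is only a different packaging of the same idea.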

@Pinnh

Pinnh commented Jul 1, 2020

I met the same problem: using phong_renderer is not stable for optimization.
