In [1]:
import os
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from skimage import img_as_ubyte
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
FoVPerspectiveCameras, look_at_view_transform,
look_at_rotation, RasterizationSettings,
MeshRenderer, MeshRasterizer,
BlendParams, SoftSilhouetteShader, HardPhongShader,
PointLights, TexturesVertex,
)

from pathlib import Path
from tqdm import tqdm
import imageio, PIL

In [2]:
!curl https://raw.githubusercontent.com/PacktPublishing/3D-Deep-Learning-with-Python/refs/heads/main/chap4/data/teapot.obj -o ./data/teapot.obj

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  149k  100  149k    0     0   165k      0 --:--:-- --:--:-- --:--:--  165k


In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Input assets (the .obj mesh) live in data_dir; generated artefacts (PNG
# snapshots and the GIF) go to output_dir. Create both up front so the later
# savefig / mimsave calls do not fail on a fresh checkout.
data_dir = Path('./data')
output_dir = Path('./result_teapot')
for folder in (data_dir, output_dir):
    folder.mkdir(parents=True, exist_ok=True)

In [5]:
# Load the teapot from data_dir (the folder the download cell writes into)
# and build a single-mesh Meshes batch on `device`.
verts, faces_idx, _ = load_obj(data_dir / 'teapot.obj')
faces = faces_idx.verts_idx

# Every vertex gets the colour (1, 1, 1), i.e. a plain white texture for the
# Phong renderer. unsqueeze(0) adds the batch dimension; verts is presumably
# (V, 3), giving (1, V, 3) here — TODO confirm against load_obj's contract.
verts_rgb = torch.ones_like(verts).unsqueeze(0)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

teapot_mesh = Meshes(
    verts=[verts.to(device)],
    faces=[faces.to(device)],
    textures=textures,
)



In [6]:
# One perspective camera shared by both renderers (only the device is set).
# The render calls in this notebook pass R and T explicitly to choose the pose.
cameras = FoVPerspectiveCameras(device=device)

In [7]:
# Two renderers share `cameras`:
#   * silhouette_renderer: soft silhouettes, whose alpha channel feeds the loss;
#   * phong_renderer: hard-rasterized, Phong-shaded reference/visualization images.
# Both must use the same resolution, because Model compares the silhouette's
# alpha channel against a mask derived from the Phong reference pixel by pixel.
image_size = 256

blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# blur_radius = sigma * log(1/p - 1) with p = 1e-4. Following the PyTorch3D
# camera-optimization tutorial, this is intended as the distance beyond which
# a face's soft-silhouette contribution is negligible (~p). 100 faces per
# pixel keeps enough overlapping faces for the soft blend.
silhouette_raster_settings = RasterizationSettings(
    image_size=image_size,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
)

silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=silhouette_raster_settings,
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params),
)

# Hard rasterization: no blur and only the nearest face per pixel.
# A separate name avoids silently overwriting the silhouette settings above.
phong_raster_settings = RasterizationSettings(
    image_size=image_size,
    blur_radius=0.0,
    faces_per_pixel=1,
)

lights = PointLights(
    device=device,
    location=((2.0, 2.0, -2.0),),
)

phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=phong_raster_settings,
    ),
    shader=HardPhongShader(
        device=device,
        cameras=cameras,
        lights=lights,
    ),
)

In [8]:
# Reference camera pose that produces the target images: distance 3,
# elevation 50 and azimuth 0 (angles in degrees, PyTorch3D's default).
distance, elevation, azimuth = 3, 50.0, 0.0
R, T = look_at_view_transform(
    dist=distance,
    elev=elevation,
    azim=azimuth,
    device=device,
)

In [9]:
# Render the target views from the reference pose (R, T) with each renderer
# and bring the results over to NumPy: first the soft silhouette, then Phong.
silhouette, image_ref = (
    active_renderer(meshes_world=teapot_mesh, R=R, T=T).cpu().numpy()
    for active_renderer in (silhouette_renderer, phong_renderer)
)

In [10]:
# (batch, height, width, RGBA channels) of the soft silhouette render.
np.shape(silhouette)

(1, 256, 256, 4)

In [11]:
# (batch, height, width, RGBA channels) of the Phong reference render.
np.shape(image_ref)

(1, 256, 256, 4)

In [12]:
# Foreground mask of the Phong reference: 1.0 where the brightest RGB channel
# is not exactly 1 (presumably the white background), 0.0 elsewhere. This is
# the same rule Model uses to build its reference mask.
np.not_equal(image_ref[..., :3].max(axis=-1), 1).astype(np.float32)

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

In [13]:
# The same "max RGB != 1" test, applied to the silhouette render.
# NOTE(review): the silhouette's RGB channels may be constant, in which case
# this is all zeros; the alpha channel (index 3) is what Model compares — confirm.
np.not_equal(silhouette[..., :3].max(axis=-1), 1).astype(np.float32)

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

In [14]:
# Save both optimization targets as PNGs:
#   * the soft silhouette — only the alpha channel of its RGBA image;
#   * the Phong-shaded RGBA reference image.
# Make sure the folder exists, so savefig does not fail on a fresh run.
output_dir.mkdir(parents=True, exist_ok=True)
for target, filename, title in (
    (silhouette.squeeze()[..., 3], 'target_silhouette.png', 'Целевой силуэт (альфа-канал)'),
    (image_ref.squeeze(), 'target_rgb.png', 'Целевое Phong-изображение'),
):
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(target)
    ax.grid(False)
    ax.set_title(title)
    fig.savefig(output_dir / filename)
    plt.close(fig)  # figures are only written to disk, not shown inline

In [15]:
class Model(nn.Module):
    """Fit a camera position so the rendered silhouette matches a reference mask.

    The trainable tensor this class defines is ``camera_position``, a
    world-space eye point. Each forward pass re-aims the camera from that point
    with ``look_at_rotation``, renders ``meshes`` with ``renderer`` and compares
    the rendered alpha channel with a binary reference mask.

    Args:
        meshes: mesh batch to render. Its ``device`` attribute decides where
            the camera parameter lives. It is cloned on every render.
        renderer: callable invoked as ``renderer(meshes_world=..., R=..., T=...)``.
            Its output's last axis is indexed as ``[..., 3]`` (alpha).
        image_ref: reference RGBA render, as a NumPy array or a torch tensor.
            Pixels whose largest RGB value is not exactly 1 become foreground
            (1.0); all others become 0.0. This relies on a pure-white background.
        init_camera_position: starting eye position ``(x, y, z)``.
    """

    def __init__(self, meshes, renderer, image_ref, init_camera_position=(3.0, 6.9, 2.5)):
        super().__init__()

        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer

        # A buffer (not a parameter): it moves with .to(device) but is not trained.
        self.register_buffer('image_ref', self._foreground_mask(image_ref))

        init = torch.as_tensor(init_camera_position, dtype=torch.float32, device=self.device)
        # reshape(3) rejects anything that is not exactly three coordinates;
        # clone() keeps the parameter from sharing storage with the caller's data.
        self.camera_position = nn.Parameter(init.reshape(3).clone())

    @staticmethod
    def _foreground_mask(image_ref):
        """Return a float32 tensor: 1.0 where max(RGB) != 1, 0.0 elsewhere."""
        if isinstance(image_ref, torch.Tensor):
            rgb_max = image_ref.detach()[..., :3].amax(dim=-1)
            return (rgb_max != 1).to(torch.float32)
        rgb_max = np.asarray(image_ref)[..., :3].max(-1)
        return torch.from_numpy((rgb_max != 1).astype(np.float32))

    def camera_extrinsics(self):
        """Return ``(R, T)`` for a camera at ``camera_position``.

        ``R`` comes from ``look_at_rotation``, aimed at its default target.
        ``T = -R^T @ camera_position`` has shape (1, 3), so that the camera
        centre ends up at ``camera_position``.
        """
        R = look_at_rotation(self.camera_position[None, :], device=self.device)
        T = -torch.bmm(
            R.transpose(1, 2),
            self.camera_position[None, :, None],
        )[:, :, 0]  # (1, 3)
        return R, T

    def forward(self):
        """Render from the current camera and return ``(loss, image)``.

        ``loss`` is the *sum* over all pixels of the squared difference between
        the rendered alpha channel and the reference mask. It is not a mean, so
        its scale depends on the image resolution.
        """
        R, T = self.camera_extrinsics()
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image

In [16]:
# Combine the teapot mesh, the soft-silhouette renderer and the Phong reference
# into the camera-fitting module, move it to `device`, and show its structure.
model = Model(
    meshes=teapot_mesh,
    renderer=silhouette_renderer,
    image_ref=image_ref,
)
model = model.to(device)
model

Model(
  (renderer): MeshRenderer(
    (rasterizer): MeshRasterizer(
      (cameras): FoVPerspectiveCameras()
    )
    (shader): SoftSilhouetteShader()
  )
)

In [17]:
# Adam over model.parameters(); camera_position is the nn.Parameter that
# Model defines.
LEARNING_RATE = 0.05
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [18]:
# Render from the initial camera guess, before any optimization step. No
# gradients are needed for this snapshot, so no autograd graph is built.
with torch.no_grad():
    _, image_init = model()

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_init.squeeze().cpu().numpy()[..., 3])  # alpha channel only
ax.grid(False)
ax.set_title("Стартовый силуэт")
output_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(output_dir / 'starting_silhouette.png')
plt.close(fig)

### Оптимизация положения камеры

In [19]:
# Gradient descent on model.camera_position. After every step, the current
# view is rendered with the Phong renderer and stored as one frame of a GIF.
epochs = 200            # hard cap on the number of optimization steps
loss_threshold = 500    # stop early once the silhouette loss drops below this

images = []
for epoch in tqdm(range(epochs)):
    loss, _ = model()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # The frame is only for visualization, so no autograd graph is built for it.
    # Use separate names so the reference pose R, T from the earlier cell is
    # not clobbered (re-running that cell's consumers would otherwise render
    # from the optimized camera).
    with torch.no_grad():
        R_cam = look_at_rotation(model.camera_position[None, :], device=model.device)
        T_cam = -torch.bmm(R_cam.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]  # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R_cam, T=T_cam)
    image = image[0, ..., :3].cpu().numpy()
    # img_as_ubyte rejects float images with values above 1.0, so clamp any
    # shading round-off first.
    images.append(img_as_ubyte(np.clip(image, 0.0, 1.0)))

    # `loss` was measured before this step. Checking after the frame is saved
    # means the GIF also contains the pose that triggered the stop.
    if loss.item() < loss_threshold:
        print('ВСЁ!')
        break

output_dir.mkdir(parents=True, exist_ok=True)
if images:  # mimsave cannot write a GIF with zero frames
    imageio.mimsave(output_dir / 'teapot.gif', images)

 36%|███████████████████████████████████████████████████████████████                                                                                                                | 72/200 [00:03<00:07, 18.25it/s]


ВСЁ!
