In [1]:
import torch
import matplotlib.pyplot as plt

from pytorch3d.renderer import (
FoVPerspectiveCameras,
NDCMultinomialRaysampler,
MonteCarloRaysampler,
EmissionAbsorptionRaymarcher,
ImplicitRenderer,
)

from utils.helper_functions import (
generate_rotating_nerf,
huber,
show_full_render,
sample_images_at_mc_locs
)

from utils.plot_image_grid import image_grid
from utils.generate_cow_renders import generate_cow_renders

from nerf_model import NeuralRadianceField

from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Produce the training data: one camera, one image and one silhouette per view.
# NOTE(review): azimuth_range=180 presumably limits the camera sweep in degrees —
# confirm against utils.generate_cow_renders.
target_cameras, target_images, target_silhouettes = generate_cow_renders(
    num_views=40,
    azimuth_range=180,
)
print(f'Сгенерировано {len(target_images)} изображений/силуэтов/камер.')

Сгенерировано 40 изображений/силуэтов/камер.


In [4]:
# Side length of the target renders (dimension 1 of target_images).
render_size = target_images.shape[1]  # * 2

# Used below as the maximum depth at which points are sampled along a ray.
volume_extent_world = 3.0

# Raysampler for training: 750 rays per image drawn inside the
# [-1, 1] x [-1, 1] xy range, with 128 points per ray between
# depth 0.1 and volume_extent_world.
raysample_mc = MonteCarloRaysampler(
    min_x=-1.0,
    max_x=1.0,
    min_y=-1.0,
    max_y=1.0,
    n_rays_per_image=750,
    n_pts_per_ray=128,
    min_depth=0.1,
    max_depth=volume_extent_world,
)

In [5]:
# Single raymarcher instance shared by both the Monte Carlo renderer and the
# grid renderer that are constructed in later cells.
raymarcher = EmissionAbsorptionRaymarcher()

In [6]:
# Renderer used inside the training loop: Monte Carlo ray sampling combined
# with the shared raymarcher.
renderer_mc = ImplicitRenderer(
    raysampler=raysample_mc,
    raymarcher=raymarcher,
)

In [7]:
# Display the dimensions of the target image tensor.
target_images.size()

torch.Size([40, 128, 128, 3])

In [8]:
# Side length of the target renders (dimension 1 of target_images).
render_size = target_images.shape[1]  # * 2

# Used below as the maximum depth at which points are sampled along a ray.
volume_extent_world = 3.0

# Raysampler for full-image visualisation: a render_size x render_size grid of
# rays, with 128 points per ray between depth 0.1 and volume_extent_world.
raysampler_grid = NDCMultinomialRaysampler(
    image_height=render_size,
    image_width=render_size,
    n_pts_per_ray=128,
    min_depth=0.1,
    max_depth=volume_extent_world,
)

In [9]:
# Renderer passed to show_full_render for the intermediate previews: grid ray
# sampling combined with the shared raymarcher.
renderer_grid = ImplicitRenderer(
    raysampler=raysampler_grid,
    raymarcher=raymarcher,
)

In [10]:
# NOTE: show_full_render is already imported in the first cell; this repeated
# import is redundant and is kept only as a reminder of what the helper does.
from utils.helper_functions import show_full_render # renders the image in batches, concatenates them into one full render, and also computes the loss

In [11]:
# Instantiate the model with its default constructor arguments; later cells
# move it to `device` and hand its parameters to the optimizer.
neural_radiance_field = NeuralRadianceField()

In [12]:
# Fix the PyTorch RNG seed before training.
torch.manual_seed(1)

# Move both renderers, all target data, and the model to the selected device.
renderer_grid, renderer_mc = renderer_grid.to(device), renderer_mc.to(device)

target_cameras, target_images, target_silhouettes = (
    target_cameras.to(device),
    target_images.to(device),
    target_silhouettes.to(device),
)

neural_radiance_field = neural_radiance_field.to(device)

In [13]:
# Training hyperparameters.
lr = 1e-3          # initial Adam learning rate
batch_size = 6     # number of cameras drawn per training iteration
n_iter = 3_000     # total number of optimisation steps

# Adam over every parameter of the radiance field.
optimizer = torch.optim.Adam(
    neural_radiance_field.parameters(),
    lr=lr,
)

In [14]:
def _select_cameras(cameras, idx):
    """Build a new ``FoVPerspectiveCameras`` batch from the entries of ``cameras`` at ``idx``.

    Args:
        cameras: camera batch exposing ``R``, ``T``, ``znear``, ``zfar``,
            ``aspect_ratio`` and ``fov`` attributes that can be indexed by ``idx``.
        idx: index tensor selecting which cameras to keep.

    Returns:
        A ``FoVPerspectiveCameras`` object created on the notebook-level
        ``device`` that holds only the selected cameras.
    """
    return FoVPerspectiveCameras(
        R=cameras.R[idx],
        T=cameras.T[idx],
        znear=cameras.znear[idx],
        zfar=cameras.zfar[idx],
        aspect_ratio=cameras.aspect_ratio[idx],
        fov=cameras.fov[idx],
        device=device,
    )


# NOTE(review): the recorded run of this cell stopped at iteration 0 with a
# CUDA out-of-memory error on a 3.80 GiB GPU. If that recurs, lowering
# batch_size, n_rays_per_image or n_pts_per_ray is the first thing to try —
# TODO confirm which step exhausts memory.

# Per-iteration loss values; also passed to show_full_render for plotting.
loss_history_color, loss_history_sil = [], []

for iteration in tqdm(range(n_iter)):
    if iteration == round(n_iter * 0.75):
        # Drop the learning rate 10x for the last quarter of training.
        # Re-creating the optimizer also starts Adam's running moment
        # estimates from scratch.
        optimizer = torch.optim.Adam(neural_radiance_field.parameters(), lr=lr * 0.1)

    # Pick a random subset of `batch_size` training cameras.
    batch_idx = torch.randperm(len(target_cameras))[:batch_size]
    batch_cameras = _select_cameras(target_cameras, batch_idx)

    # Render the sampled rays; the last dimension is split into the first
    # 3 channels (colour) and the final channel (silhouette).
    rendered_images_silhouettes, sampled_rays = renderer_mc(
        cameras=batch_cameras,
        volumetric_function=neural_radiance_field,
    )
    rendered_images, rendered_silhouettes = rendered_images_silhouettes.split([3, 1], dim=-1)

    # Silhouette error against the target silhouettes looked up at the
    # sampled ray xy locations.
    silhouettes_at_rays = sample_images_at_mc_locs(
        target_silhouettes[batch_idx, ..., None],
        sampled_rays.xys,
    )
    sil_err = huber(rendered_silhouettes, silhouettes_at_rays).abs().mean()

    # Colour error against the target images at the same locations.
    colors_at_rays = sample_images_at_mc_locs(target_images[batch_idx], sampled_rays.xys)
    # Must be named `color_err`: the loss and the history below read that name
    # (the previous `col_err` assignment raised NameError on the first step).
    color_err = huber(rendered_images, colors_at_rays).abs().mean()

    loss = color_err + sil_err
    loss_history_color.append(float(color_err))
    loss_history_sil.append(float(sil_err))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Every 100 iterations, render one randomly chosen view in full and save it.
    if iteration % 100 == 0:
        show_idx = torch.randperm(len(target_cameras))[:1]
        fig = show_full_render(
            neural_radiance_field,
            _select_cameras(target_cameras, show_idx),
            target_images[show_idx][0],
            target_silhouettes[show_idx][0],
            renderer_grid,
            loss_history_color,
            loss_history_sil,
        )
        fig.savefig(f'intermediate_{iteration}')
        # Release the figure so open figures do not accumulate over the run.
        plt.close(fig)

  0%|                                                                                                                                                                                       | 0/3000 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 564.00 MiB (GPU 0; 3.80 GiB total capacity; 1.89 GiB already allocated; 475.69 MiB free; 1.99 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF