In [1]:
import time 
import math
import torch
import matplotlib.pyplot as plt
from renderer import Renderer
from tqdm import tqdm
from torchvision.io import read_image

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N = 5_000
W = H = 256

In [3]:
mu = (torch.rand((N, 3), device=device) - 0.5) * 8.
mu[:,2] = torch.rand((N,)) * 0.001
scales = torch.rand((N, 3), device=device) * 0.1
quats = torch.rand((N, 4), device=device)
cols = torch.rand((N, 3), device=device) 
opcs = torch.rand((N), device=device)

params = {
    'mu': mu, 'scales': scales, 'quats': quats, 'cols': cols, 'opcs': opcs
}    

In [4]:
renderer = Renderer(params=params, device=device)

In [5]:
# Create GT image
gt_image = read_image('./mikey_cropped.jpg').permute(1,2,0) / 255
fov_x = math.pi / 2.0 # Angle of the camera frustum 90°
focal = 0.5 * float(W) / math.tan(0.5 * fov_x) # Distance to Image Plane
viewmat = torch.eye(4, device=device)
viewmat[:3,3] = torch.tensor([0,0,-4])
camera = {'viewmat': viewmat, 'focal': focal, 'H': H, 'W': W}

In [6]:
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(params=renderer.parameters(), lr=1e-2)

In [7]:
%matplotlib notebook
from ipywidgets import *

In [12]:
pred = renderer(camera, gt_image)

plt.ion()
figure, ax = plt.subplots()
im1 = ax.matshow(pred.detach().cpu())

for iter in tqdm(range(1_000)):
    optimizer.zero_grad()

    pred = renderer(camera, gt_image)

    im1.set_data(pred.detach().cpu())
    figure.canvas.draw()
    figure.canvas.flush_events()
    time.sleep(0.1)

    loss = criterion(pred, gt_image)
    
    loss.backward()

    torch.nn.utils.clip_grad_value_(renderer.parameters(), clip_value=1.0)
    for param in renderer.parameters():
        param.grad[param.grad.isnan()] = 0.

    optimizer.step()

    print(f'Iter: {iter}, Loss: {loss.item()}, Grad. Norms: {[p.abs().norm().item() for p in renderer.parameters()]}')

plt.matshow(pred.detach().cpu())
plt.show()

100%|██████████| 16/16 [00:00<00:00, 45.24it/s]


<IPython.core.display.Javascript object>

100%|██████████| 16/16 [00:00<00:00, 31.71it/s]
  0%|          | 1/1000 [00:01<28:57,  1.74s/it]

Iter: 0, Loss: 0.06803075224161148, Grad. Norms: [72.81307983398438, 230.5654296875, 40.130210876464844, 81.34281158447266, 9.75297737121582]


100%|██████████| 16/16 [00:00<00:00, 48.37it/s]
  0%|          | 2/1000 [00:02<24:07,  1.45s/it]

Iter: 1, Loss: 0.06602101027965546, Grad. Norms: [72.89939880371094, 230.5977020263672, 40.05413818359375, 81.27366638183594, 9.887796401977539]


100%|██████████| 16/16 [00:00<00:00, 22.60it/s]
  0%|          | 3/1000 [00:04<25:11,  1.52s/it]

Iter: 2, Loss: 0.0634247437119484, Grad. Norms: [72.97757720947266, 230.62867736816406, 39.98175048828125, 81.20601654052734, 10.014514923095703]


100%|██████████| 16/16 [00:00<00:00, 22.30it/s]
  0%|          | 4/1000 [00:06<25:56,  1.56s/it]

Iter: 3, Loss: 0.060819149017333984, Grad. Norms: [73.04857635498047, 230.65794372558594, 39.91289520263672, 81.13818359375, 10.131562232971191]


100%|██████████| 16/16 [00:00<00:00, 36.95it/s]
  0%|          | 5/1000 [00:07<26:21,  1.59s/it]

Iter: 4, Loss: 0.05881102383136749, Grad. Norms: [73.11547088623047, 230.6859130859375, 39.84735107421875, 81.07036590576172, 10.23958683013916]


100%|██████████| 16/16 [00:00<00:00, 46.00it/s]
  1%|          | 6/1000 [00:09<24:59,  1.51s/it]

Iter: 5, Loss: 0.05721129849553108, Grad. Norms: [73.17970275878906, 230.71372985839844, 39.785213470458984, 81.00308990478516, 10.339388847351074]


  0%|          | 0/16 [00:00<?, ?it/s]
  1%|          | 6/1000 [00:09<26:10,  1.58s/it]


KeyboardInterrupt: 