/
jacobians.py
102 lines (89 loc) · 3.12 KB
/
jacobians.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/06_jacobians.ipynb.
# %% auto 0
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ['JacobianDRR', 'gradient_matching', 'plot_img_jacobian']
# %% ../notebooks/api/06_jacobians.ipynb 3
import torch
class JacobianDRR(torch.nn.Module):
    """Computes the Jacobian of a DRR wrt pose parameters.

    The pose (``rotation``, ``translation``) is held as trainable parameters.
    The renderer is invoked as
    ``drr(rotation, translation, parameterization, convention)``.
    """

    def __init__(self, drr, rotation, translation, parameterization, convention=None):
        super().__init__()
        self.drr = drr
        # Clone so the module owns its pose instead of aliasing the caller's tensors.
        self.rotation = torch.nn.Parameter(rotation.clone())
        self.translation = torch.nn.Parameter(translation.clone())
        self.parameterization = parameterization
        self.convention = convention

    def forward(self):
        """Render at the stored pose and differentiate the render wrt that pose.

        Returns:
            A pair ``(I, J)``. ``I`` is the rendered image. ``J`` concatenates,
            along dim 0, the Jacobian wrt ``rotation`` followed by the Jacobian
            wrt ``translation``, each reshaped by :meth:`permute`.
        """
        pose = (self.rotation, self.translation)
        image = self.cast(*pose)
        # Because ``pose`` is a tuple, jacobian() returns one Jacobian per input.
        # The "forward-mode" strategy requires vectorize=True.
        jac_rotation, jac_translation = torch.autograd.functional.jacobian(
            self.cast,
            pose,
            vectorize=True,
            strategy="forward-mode",
        )
        stacked = torch.cat(
            (self.permute(jac_rotation), self.permute(jac_translation)),
            dim=0,
        )
        return image, stacked

    def cast(self, rotation, translation):
        """Render with ``self.drr`` using the stored parameterization and convention."""
        return self.drr(rotation, translation, self.parameterization, self.convention)

    def permute(self, x):
        """Reorder a raw 6-D Jacobian block and drop two axes.

        ``x`` has the output dims followed by the input dims, as returned by
        ``torch.autograd.functional.jacobian``. The last axis is moved to the
        front, then only index 0 of original axes 1 and 4 is kept.
        """
        # NOTE(review): assumes the render is (batch, channel, H, W) and the pose
        # is (batch, n_params). If so, the result is (n_params, batch, H, W) for
        # channel 0 and input batch 0 only. TODO confirm behaviour for batch > 1.
        reordered = x.permute(-1, 0, 2, 3, 1, 4)
        return reordered[..., 0, 0]
# %% ../notebooks/api/06_jacobians.ipynb 4
def gradient_matching(J0, J1):
    """Distance between two Jacobian stacks after per-slice normalization.

    Every slice over the last two dimensions of ``J0`` and ``J1`` is scaled to
    unit Frobenius norm. The function returns the 2-norm (over all elements) of
    the difference of the normalized tensors.

    Args:
        J0: first Jacobian tensor. It is not modified.
        J1: second Jacobian tensor, broadcast-compatible with ``J0``. It is not
            modified.

    Returns:
        A scalar tensor. A slice whose norm is zero yields NaN/inf in the result.
    """
    # Out-of-place division. The in-place ``/=`` rescaled the caller's tensors
    # and raised for leaf tensors with requires_grad=True.
    J0 = J0 / J0.norm(dim=[-1, -2], keepdim=True)
    J1 = J1 / J1.norm(dim=[-1, -2], keepdim=True)
    return (J0 - J1).norm()
# %% ../notebooks/api/06_jacobians.ipynb 5
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
def plot_img_jacobian(I, J, **kwargs):
    """Plot an image next to six Jacobian maps on a 2x4 grid and show it.

    Layout:
    - The image goes in the top-left panel with a gray colormap and a default colorbar.
    - ``J[0]``, ``J[1]`` and ``J[2]`` fill the rest of the top row, titled
      "J(yaw)", "J(pitch)" and "J(roll)".
    - ``J[3]``, ``J[4]`` and ``J[5]`` fill the rest of the bottom row, titled
      "J(x)", "J(y)" and "J(z)".
    - The bottom-left panel is left empty.

    Args:
        I: image tensor, moved to CPU, detached and squeezed before plotting.
        J: indexable stack of Jacobians, such as the ``J`` returned by
            ``JacobianDRR.forward``. Entries 0-5 are plotted.
        **kwargs: forwarded to ``plt.imshow`` for the Jacobian panels only.
    """

    def fmt(x, pos):
        # Render a colorbar tick as a signed power of ten. The mantissa from
        # "{:.0e}" is used only for its sign, or to detect zero.
        mantissa, exponent = f"{x:.0e}".split("e")
        mantissa, exponent = float(mantissa), int(exponent)
        if mantissa == 0:
            return "0"
        sign = "-" if mantissa < 0 else ""
        if exponent == 0:
            return f"{sign}1"
        return rf"${sign}10^{{{exponent}}}$"

    # (subplot position, title) for J[0] .. J[5], in stacking order.
    panels = [
        (2, "J(yaw)"),
        (3, "J(pitch)"),
        (4, "J(roll)"),
        (6, "J(x)"),
        (7, "J(y)"),
        (8, "J(z)"),
    ]

    plt.figure(figsize=(10, 4), dpi=300, constrained_layout=True)
    for idx, (position, title) in enumerate(panels):
        plt.subplot(2, 4, position)
        plt.title(title)
        plt.imshow(J[idx].squeeze().cpu().detach(), **kwargs)
        plt.colorbar(format=FuncFormatter(fmt))
        plt.axis("off")

    plt.subplot(2, 4, 1)
    plt.title("img")
    plt.imshow(I.cpu().detach().squeeze(), cmap="gray")
    plt.axis("off")
    plt.colorbar()
    plt.show()