Skip to content

Commit

Permalink
simplify point kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jul 24, 2019
1 parent 80c82c6 commit 5d79b25
Showing 1 changed file with 25 additions and 29 deletions.
54 changes: 25 additions & 29 deletions se3cnn/point/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import se3cnn.SO3 as SO3


class SE3PointKernel(torch.nn.Module):
class Kernel(torch.nn.Module):
def __init__(self, Rs_in, Rs_out, RadialModel, get_l_filters=None, sh=None, normalization='norm'):
'''
:param Rs_in: list of couple (multiplicity, representation order)
Expand Down Expand Up @@ -66,27 +66,24 @@ def __repr__(self):
Rs_out=self.Rs_out,
)

def forward(self, difference_matrix):
def forward(self, r):
"""
:param difference_matrix: tensor [[batch,] N_out, N_in, 3]
:return: tensor [l_out * mul_out * m_out, l_in * mul_in * m_in, [batch,] N_out, N_in]
:param r: tensor [batch, 3]
:return: tensor [batch, l_out * mul_out * m_out, l_in * mul_in * m_in]
"""
has_batch = difference_matrix.dim() == 4
if not has_batch:
difference_matrix = difference_matrix.unsqueeze(0)
batch, xyz = r.size()
assert xyz == 3

batch, N_out, N_in, _ = difference_matrix.size()

kernel = difference_matrix.new_zeros(self.n_out, self.n_in, batch, N_out, N_in)
kernel = r.new_zeros(batch, self.n_out, self.n_in)

# precompute all needed spherical harmonics
Ys = self.sh(self.set_of_l_filters, difference_matrix) # [l_filter * m_filter, batch, N_out, N_in]
Ys = self.sh(self.set_of_l_filters, r) # [l_filter * m_filter, batch]

# use the radial model to fix all the degrees of freedom
radii = difference_matrix.norm(2, dim=-1).view(-1) # [batch * N_out * N_in]
# note: for the normalization we assume that the variance of weights[i] is one
weights = self.R(radii).view(batch, N_out, N_in, -1) # [batch, N_out, N_in, l_out * l_in * mul_out * mul_in * l_filter]
begin_w = 0
radii = r.norm(2, dim=1) # [batch]
# note: for the normalization we assume that the variance of coefficients[i] is one
coefficients = self.R(radii) # [batch, l_out * l_in * mul_out * mul_in * l_filter]
begin_c = 0

begin_out = 0
for i, (mul_out, l_out) in enumerate(self.Rs_out):
Expand All @@ -98,51 +95,50 @@ def forward(self, difference_matrix):
for mul_in, l_in in self.Rs_in:
l_filters = self.get_l_filters(l_in, l_out)
num_summed_elements += mul_in * len(l_filters)
num_summed_elements *= N_in # note: ideally the number of neighbours

begin_in = 0
for j, (mul_in, l_in) in enumerate(self.Rs_in):
s_in = slice(begin_in, begin_in + mul_in * (2 * l_in + 1))

l_filters = self.get_l_filters(l_in, l_out)

# extract the subset of the `weights` that corresponds to the couple (l_out, l_in)
# extract the subset of the `coefficients` that corresponds to the couple (l_out, l_in)
n = mul_out * mul_in * len(l_filters)
w = weights[:, :, :, begin_w: begin_w + n].contiguous().view(batch, N_out, N_in, mul_out, mul_in, -1) # [batch, N_out, N_in, mul_out, mul_in, l_filter]
begin_w += n
c = coefficients[:, begin_c: begin_c + n].contiguous().view(batch, mul_out, mul_in, -1) # [batch, mul_out, mul_in, l_filter]
begin_c += n

Qs = getattr(self, "Q_{}_{}".format(i, j)) # [m_out, m_in, l_filter * m_filter]

# note: I don't know if we can vectorize this for loop because [l_filter * m_filter] cannot be put into [l_filter, m_filter]
K = 0
for k, l_filter in enumerate(l_filters):
tmp = sum(2 * l + 1 for l in self.set_of_l_filters if l < l_filter)
Y = Ys[tmp: tmp + 2 * l_filter + 1] # [m, batch, N_out, N_in]
Y = Ys[tmp: tmp + 2 * l_filter + 1] # [m, batch]

tmp = sum(2 * l + 1 for l in l_filters if l < l_filter)
Q = Qs[:, :, tmp: tmp + 2 * l_filter + 1] # [m_out, m_in, m]

# note: The multiplication with `w` could also be done outside of the for loop
K += torch.einsum("ijr,rknm,knmuv->uivjknm", (Q, Y, w[..., k])) # [mul_out, m_out, mul_in, m_in, batch, N_out, N_in]
# note: The multiplication with `c` could also be done outside of the for loop
K += torch.einsum("ijk,kz,zuv->zuivj", (Q, Y, c[..., k])) # [batch, mul_out, m_out, mul_in, m_in]

# put 2l_in+1 to keep the norm of the m vector constant
# put 2l_out+1 to keep the variance of each m component constant
# sum_m Y_m^2 = (2l+1)/(4pi) and norm(Q) = 1 implies that norm(QY) = sqrt(1/4pi)
if self.normalization == 'norm':
K *= math.sqrt(2 * l_in + 1) * math.sqrt(4 * math.pi)
x = math.sqrt(2 * l_in + 1) * math.sqrt(4 * math.pi)
if self.normalization == 'component':
K *= math.sqrt(2 * l_out + 1) * math.sqrt(4 * math.pi)
x = math.sqrt(2 * l_out + 1) * math.sqrt(4 * math.pi)

# normalization assuming that each terms are of order 1 and uncorrelated
K /= num_summed_elements ** 0.5
x /= num_summed_elements ** 0.5

# TODO create tests for these normalizations
K.mul_(x)

if K is not 0:
kernel[s_out, s_in] = K.contiguous().view_as(kernel[s_out, s_in])
kernel[:, s_out, s_in] = K.contiguous().view_as(kernel[:, s_out, s_in])

begin_in += mul_in * (2 * l_in + 1)
begin_out += mul_out * (2 * l_out + 1)

if not has_batch:
kernel = kernel.squeeze(2)

return kernel

0 comments on commit 5d79b25

Please sign in to comment.