In [32]:
# External imports
import torch
import torch.nn.functional as F
import torch.nn as nn

Sketch:
1. Fix some Lie group, whose Lie algebra has a v.s. basis $\{e_{i}\}_{i \in \{1, ..., n\}}$.
2. Fix any NN with output layer of length $n$, e.g. composition of arbitrary linear layers.
3. View outputs as coefficients for generators of lie algebra, in some fixed fashion. 
4. Use output to update a context vector by the result of applying exponential of the corresponding linear combo.
5. Profit.

In [2]:
# Fix a Basis for so(3) Lie algebra:
L_x = torch.tensor([[0, 0, 0],
                    [0, 0, -1],
                    [0, 1, 0]], dtype=float)

L_y = torch.tensor([[0, 0, 1],
                    [0, 0, -1],
                    [-1, 0, 0]], dtype=float)

L_z = torch.tensor([[0, -1, 0],
                    [1, 0, 0],
                    [0, 0, 0]], dtype=float)

In [7]:
from torch import linalg
torch.linalg.norm((.5*L_x + .5*L_y), ord=2)

tensor(1.0000, dtype=torch.float64)

In [95]:
class SO3Block(nn.Module):
    def __init__(self, euclidean_network):
        super(SO3Block, self).__init__()
        self.generators = [L_x, L_y, L_z]
        self.euclidean_network = euclidean_network
        
    def linear_combo(self, alg_coefs):
        terms = [alg_coef*generator for (alg_coef, generator) in zip(alg_coefs, self.generators)]
        linear_combo = torch.zeros(3, 3, dtype=float)
        for term in terms:
            linear_combo += term
        return linear_combo
        
    def exponential_map(self, linear_combo):
        # Using Rodriguez Formula
        two_norm = torch.linalg.norm(linear_combo)
        normalized = linear_combo / two_norm
        g = torch.eye(3) + normalized*torch.sin(two_norm) + torch.matrix_power(normalized, 2)*(1 - torch.cos(two_norm))
        return g   
        
        
    def forward(self, x, c):
        alg_coefs = self.euclidean_network(x)
        linear_combo = self.linear_combo(alg_coefs)
        g = self.exponential_map(linear_combo)
        output = torch.matmul(g, c)
        return output

In [96]:
class FF(nn.Module):
    def __init__(self):
        super(FF, self).__init__()
        self.l1 = nn.Linear(10, 3)
        self.l2 = nn.Linear(3, 3)
    def forward(self, x):
        out = self.l1(x)
        out = F.relu(out)
        out = self.l2(out)
        return out

In [97]:
input = torch.rand(1, 10)

In [98]:
euclidean_network = FF()
lie_block = SO3Block(euclidean_network)

In [114]:
lie_block(input, torch.tensor([.2, .1 , 0], dtype=float))

tensor([ 0.2000,  0.1020, -0.0092], dtype=torch.float64, grad_fn=<MvBackward>)