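# Tutorial: Equivariant Layers from Lie Groups

This notebook wraps a Geomstats Lie group (here $SO(3)$) as an EMLP group, checks the wrapper with EMLP's equivariance tests, and then builds an equivariant MLP (EMLP) over it.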

In [1]:
import os
import sys
import warnings

# Make the parent of the working directory importable (presumably where local
# helper packages such as `emlp_tests` live -- TODO confirm), then mute every
# warning for the session so cell outputs stay readable.
sys.path.append(os.path.split(os.getcwd())[0])
warnings.filterwarnings(action='ignore')

import geomstats.backend as gs

INFO: Using numpy backend


# Wrapper class to instantiate an EMLP group from a Geomstats group

In [2]:
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from emlp.groups import Group

class GeomstatsGroup(Group):
    """Expose a Geomstats matrix Lie group to EMLP as an ``emlp.groups.Group``.

    Only continuous symmetries are supplied: the Lie algebra basis of the
    Geomstats group becomes EMLP's ``lie_algebra`` (one generator matrix per
    basis element). No discrete generators are set here.

    Args:
        geomstats_group: a Geomstats matrix Lie group exposing
            ``lie_algebra.basis`` (e.g. ``SpecialOrthogonal(n=3)``).
    """

    def __init__(self, geomstats_group):
        # continuous generators only: stacked (d, d) matrices, one per basis element
        self.lie_algebra = geomstats_group.lie_algebra.basis

        # EMLP's `d` is the size of the base representation, i.e. the side length
        # of the generator matrices. Geomstats' `dim` is the manifold dimension
        # instead (n(n-1)/2 for SO(n)), which only equals n for SO(3), so derive
        # `d` from the generators themselves.
        self.d = self.lie_algebra.shape[-1]

        # NOTE(review): no args are forwarded to Group.__init__, so every
        # GeomstatsGroup is reported simply as "GeomstatsGroup" (see the solver
        # logs). Verify that wrapping two different Geomstats groups in one
        # session does not make EMLP reuse cached bases across them.
        super().__init__()

INFO: Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
INFO: Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
INFO: Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


In [3]:
test_geomstats_group = SpecialOrthogonal(n=3)
test_group = GeomstatsGroup(test_geomstats_group)

# Show what EMLP received as generators: container type, stacked shape, then the matrices.
print(
    type(test_group.lie_algebra),
    test_group.lie_algebra.shape,
    test_group.lie_algebra,
    sep="\n",
)

<class 'jaxlib.xla_extension.DeviceArray'>
(3, 3, 3)
[[[ 0.  0.  0.]
  [ 0.  0. -1.]
  [ 0.  1.  0.]]

 [[ 0.  0.  1.]
  [ 0.  0.  0.]
  [-1.  0.  0.]]

 [[ 0. -1.  0.]
  [ 1.  0.  0.]
  [ 0.  0.  0.]]]


# Run EMLP's equivariance tests on the wrapped group

In [4]:
from emlp_tests.equivariance_tests import \
    test_sum, test_prod, test_high_rank_representations, test_large_representations

In [5]:
# EMLP's `test_sum` check on the wrapped SO(3) group; its log shows equivariant
# bases being solved for V and V².
test_sum(test_group)

INFO: V cache miss
INFO: Solving basis for V, for G=GeomstatsGroup
INFO: V² cache miss
INFO: Solving basis for V², for G=GeomstatsGroup


In [6]:
# EMLP's `test_prod` check on the wrapped group; its log shows a basis being
# solved for V⁵.
test_prod(test_group)

INFO: V⁵ cache miss
INFO: Solving basis for V⁵, for G=GeomstatsGroup


In [7]:
# EMLP's high-rank check on the wrapped group; its log reports success for the
# tensor representations T(0, 0) through T(7, 0).
test_high_rank_representations(test_group)

INFO: Success with T(0, 0) and G=GeomstatsGroup
INFO: Success with T(1, 0) and G=GeomstatsGroup
INFO: Success with T(2, 0) and G=GeomstatsGroup
INFO: V³ cache miss
INFO: Solving basis for V³, for G=GeomstatsGroup
INFO: Success with T(3, 0) and G=GeomstatsGroup
INFO: V⁴ cache miss
INFO: Solving basis for V⁴, for G=GeomstatsGroup
INFO: Success with T(4, 0) and G=GeomstatsGroup
INFO: Success with T(5, 0) and G=GeomstatsGroup
INFO: V⁶ cache miss
INFO: Solving basis for V⁶, for G=GeomstatsGroup
INFO: Success with T(6, 0) and G=GeomstatsGroup
INFO: V⁷ cache miss
INFO: Solving basis for V⁷, for G=GeomstatsGroup
INFO: Success with T(7, 0) and G=GeomstatsGroup


In [8]:
# EMLP's large-representation check on the wrapped group (reported as a single
# success line in the log).
test_large_representations(test_group)

INFO: Success with G=GeomstatsGroup


In [9]:
from emlp.reps import T
from emlp_tests.equivariance_tests import test_equivariant_matrix, test_bilinear_layer

In [10]:
# Equivariant linear map from one vector + two scalars; the output rep lists
# T(1) twice (presumably to exercise repeated summands).
test_equivariant_matrix(
    test_group,
    T(1) + 2 * T(0),
    T(1) + T(2) + 2 * T(0) + T(1),
)

In [11]:
# Equivariant linear map from 5 scalars + 5 vectors to 3 scalars, a rank-2
# tensor and 2 vectors.
test_equivariant_matrix(
    test_group,
    5 * T(0) + 5 * T(1),
    3 * T(0) + T(2) + 2 * T(1),
)

In [12]:
# Same kind of check, with representations written using scalar multiplication
# of a parenthesized sum.
test_equivariant_matrix(
    test_group,
    5 * (T(0) + T(1)),
    2 * (T(0) + T(1)) + T(2) + T(1),
)

In [13]:
# Bilinear-layer check from 5 scalars + 5 vectors to 3 scalars, a rank-2
# tensor and 2 vectors.
test_bilinear_layer(
    test_group,
    5 * T(0) + 5 * T(1),
    3 * T(0) + T(2) + 2 * T(1),
)

# Build an EMLP Model with the Geomstats Group

In [14]:
import numpy as np

import emlp

# Build an EMLP whose inputs are 5 scalars + 5 vectors and whose outputs are
# 3 invariant scalars (e.g. 3 class logits), then run it on a random minibatch.
SEED = 0               # seeds the input data below (EMLP's weight init is not seeded here)
BATCH_SIZE = 32
NUM_LAYERS = 3         # number of hidden layers
HIDDEN_CHANNELS = 384  # EMLP `ch`: hidden channels; drives the layer sizes reported in the log

repin = 5 * T(0) + 5 * T(1)  # 5 scalars + 5 vectors (logged as 5V⁰+5V)
repout = 3 * T(0)            # 3 scalar (invariant) outputs
group = test_group
model = emlp.nn.EMLP(repin, repout, group=group, num_layers=NUM_LAYERS, ch=HIDDEN_CHANNELS)

rng = np.random.default_rng(SEED)
x = rng.standard_normal((BATCH_SIZE, repin(group).size()))  # random minibatch of inputs
y = model(x)  # one row of outputs per example

# Sanity check: one output vector of size repout per example in the batch.
assert y.shape == (BATCH_SIZE, repout(group).size()), y.shape

INFO: Initing EMLP (objax)
INFO: Reps: [5V⁰+5V, 102V⁰+34V+11V²+3V³, 102V⁰+34V+11V²+3V³, 102V⁰+34V+11V²+3V³]
INFO: Linear W components:8640 rep:750V⁰+920V+225V²+70V³+15V⁴
INFO: BiW components: dim:75648
INFO: Linear W components:165888 rep:15300V⁰+8568V+3928V²+1504V³+325V⁴+66V⁵+9V⁶
INFO: BiW components: dim:75648
INFO: Linear W components:165888 rep:15300V⁰+8568V+3928V²+1504V³+325V⁴+66V⁵+9V⁶
INFO: BiW components: dim:75648
INFO: Linear W components:1152 rep:306V⁰+102V+33V²+9V³
