In [1]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

# Omnistaging is always on in newer JAX releases, where `enable_omnistaging` no longer
# exists and `from jax.config import config` is deprecated/removed. `jax.config` is the
# same Config object on old and new versions, so bind it directly and only enable
# omnistaging where the switch is still present (old JAX, where Linen requires it).
config = jax.config
if hasattr(config, "enable_omnistaging"):
    config.enable_omnistaging()
import logging
# Root logger at INFO so the EMLP construction logs (rep solving, cache misses) are shown.
logging.getLogger().setLevel(logging.INFO)

In [2]:
from RL.emlp_flax import EMLPBlock,EMLP,MLP
from emlp.reps import T,V
from emlp.groups import S,SO

In [3]:
G = SO(2)
# Operand order is kept as-is: it determines how the flat input/output vectors are laid out.
repin = T(1)(G) + T(0)(G)   # rank-1 tensor + scalar (3 entries, cf. the (3, 446) first linear layer)
repout = T(0)(G) + T(2)(G)  # scalar + rank-2 tensor (5 entries, cf. the (5,) output bias)

In [4]:
model = EMLP(repin, repout, G)  # equivariant MLP repin -> repout under G; logs its layer reps

INFO:root:Initing EMLP (flax)
INFO:root:Reps: [V⁰+V, 64V⁰+32V+16V²+8V³+4V⁴+2V⁵, 64V⁰+32V+16V²+8V³+4V⁴+2V⁵, 64V⁰+32V+16V²+8V³+4V⁴+2V⁵]
INFO:root:V cache miss
INFO:root:Solving basis for V, for G=SO(2)
INFO:root:V² cache miss
INFO:root:Solving basis for V², for G=SO(2)
INFO:root:V³ cache miss
INFO:root:Solving basis for V³, for G=SO(2)
INFO:root:V⁴ cache miss
INFO:root:Solving basis for V⁴, for G=SO(2)
INFO:root:V⁵ cache miss
INFO:root:Solving basis for V⁵, for G=SO(2)
INFO:root:V⁶ cache miss
INFO:root:Solving basis for V⁶, for G=SO(2)
INFO:root:Linear W components:1338 rep:126V⁰+158V+48V²+24V³+12V⁴+6V⁵+2V⁶
INFO:root:BiW components: dim:69808
INFO:root:V⁷ cache miss
INFO:root:Solving basis for V⁷, for G=SO(2)
INFO:root:V⁸ cache miss
INFO:root:Solving basis for V⁸, for G=SO(2)
INFO:root:V⁹ cache miss
INFO:root:Solving basis for V⁹, for G=SO(2)
INFO:root:V¹⁰ cache miss
INFO:root:Solving basis for V¹⁰, for G=SO(2)
INFO:root:Linear W components:171264 rep:8064V⁰+6080V+4064V²+2544V³+1528V⁴+89

In [5]:
repin >> repout  # presumably the rep of linear maps repin -> repout (displays V⁰+V+V²+V³ here)

V⁰+V+V²+V³

In [6]:
# One sub-key for the dummy input, one for parameter initialization (fixed seed 0).
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (repin.size(),))  # Dummy input with repin.size() entries
params = model.init(key2, x)  # Initialization call
# Show each parameter's shape. `jax.tree_map` is deprecated (removed in recent JAX), so use
# the stable `jax.tree_util.tree_map`; the lambda argument is named `leaf` so it does not
# shadow the dummy input `x`.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

FrozenDict({
    params: {
        modules_0: {
            bilinear: {
                w: (69808,),
            },
            linear: {
                b: (446,),
                w: (3, 446),
            },
        },
        modules_1: {
            bilinear: {
                w: (69808,),
            },
            linear: {
                b: (446,),
                w: (384, 446),
            },
        },
        modules_2: {
            bilinear: {
                w: (69808,),
            },
            linear: {
                b: (446,),
                w: (384, 446),
            },
        },
        modules_3: {
            b: (5,),
            w: (384, 5),
        },
    },
})

In [7]:
model.apply(params, x)  # forward pass on the dummy input; one value per repout entry (5 here)

DeviceArray([ 0.00313691, -0.00691255, -0.00108144,  0.00487514,
             -0.00263916], dtype=float32)

In [8]:
# Re-applying with the same params and input should reproduce the previous cell's output
# exactly; check it explicitly instead of comparing repeated outputs by eye.
out_first = model.apply(params, x)
out_second = model.apply(params, x)
assert jnp.array_equal(out_first, out_second), "model.apply is not deterministic"
out_first

DeviceArray([ 0.00313691, -0.00691255, -0.00108144,  0.00487514,
             -0.00263916], dtype=float32)

In [9]:
model.apply(params, x)  # repeated call with unchanged params/input; output should match the previous cells

DeviceArray([ 0.00313691, -0.00691255, -0.00108144,  0.00487514,
             -0.00263916], dtype=float32)