In [1]:
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

In [2]:
from emlp.mlp_jax import MLP,EMLP#,LinearBNSwish
from emlp.datasets_jax import Fr
import jax.numpy as jnp
import jax
import optax

In [3]:
from emlp.equivariant_subspaces_jax import T,Scalar,Matrix,Vector,Quad,repsize
from emlp.groups_jax import SO,O,Trivial,Lorentz,O13,SO13,SO13p
from emlp.mlp_jax import EMLPLearned,LieLinear
import itertools
import numpy as np
import torch
from emlp.datasets_jax import Inertia,Fr,ParticleInteraction
import objax

In [4]:
import torch
from torch.utils.data import DataLoader
from slax.utils.utils import LoaderTo,cosLr, islice, export,FixedNumpySeed
from slax.tuning.study import train_trial
from slax.datasetup.datasets import split_dataset
from slax.model_trainers.classifier import Regressor
from functools import partial
import torch.nn as nn
#repmiddle = 100*T(0)+30*T(1)+10*T(2)+3*T(3)#+1*T(4)

def makeTrainer(*,dataset=Fr,network=EMLP,num_epochs=500,ndata=1000+1000,seed=2020,aug=False,
                bs=500,lr=1e-2,device='cuda',split={'train':-1,'test':1000},
                net_config={'num_layers':3,'d':3,'ch':128},
                trainer_config={'log_dir':None,'log_args':{'minPeriod':.02}},save=False):

    # Prep the datasets splits, model, and dataloaders
    with FixedNumpySeed(seed):
        datasets = split_dataset(dataset(ndata),splits=split)
    device = torch.device(device)
    model = network(datasets['train'].rep_in,datasets['train'].rep_out,**net_config)
    dataloaders = {k:LoaderTo(DataLoader(v,batch_size=bs,shuffle=(k=='train'),
                num_workers=0,pin_memory=False)) for k,v in datasets.items()}
    dataloaders['Train'] = dataloaders['train']
    opt_constr = objax.optimizer.Adam
    lr_sched = lambda e: lr
    return Regressor(model,dataloaders,opt_constr,lr_sched,**trainer_config)

In [5]:
trainer = makeTrainer(network=EMLPLearned,lr=3e-3)
#trainer = makeTrainer(network=MLP,lr=1e-3,aug=False)
trainer.train(500)

INFO:root:BiW components:864 dim:0 shape:(144, 6) rep:108T(0, 1)+12T(0, 2)+12T(1, 1)+2T(0, 3)+2T(1, 2)+2T(2, 1)+2T(0, 4) @ d=3
INFO:root:BiW components:18432 dim:3782 shape:(144, 128) rep:2052T(0, 0)+552T(1, 0)+552T(0, 1)+128T(2, 0)+164T(1, 1)+128T(0, 2)+66T(3, 0)+24T(2, 1)+24T(1, 2)+50T(0, 3)+8T(3, 1)+7T(4, 0)+3T(2, 2)+8T(1, 3)+7T(0, 4)+T(3, 2)+T(4, 1)+T(5, 0)+T(2, 3)+T(1, 4)+T(0, 5)+T(3, 3) @ d=3
INFO:root:BiW components:18432 dim:3782 shape:(144, 128) rep:2052T(0, 0)+552T(1, 0)+552T(0, 1)+128T(2, 0)+164T(1, 1)+128T(0, 2)+66T(3, 0)+24T(2, 1)+24T(1, 2)+50T(0, 3)+8T(3, 1)+7T(4, 0)+3T(2, 2)+8T(1, 3)+7T(0, 4)+T(3, 2)+T(4, 1)+T(5, 0)+T(2, 3)+T(1, 4)+T(0, 5)+T(3, 3) @ d=3


HBox(children=(FloatProgress(value=0.0, description='train', max=500.0, style=ProgressStyle(description_width=â€¦

INFO:root:Learned Lie Algebra: Traced<ConcreteArray([[[ 0.6252229  -0.42994434 -1.0035503 ]
  [ 0.87450933  0.04434267 -0.34513336]
  [ 0.67070735  0.48131174 -1.2213402 ]]

 [[ 0.90400136 -1.5148104   1.6624688 ]
  [-0.57737005 -1.4149495  -0.20211951]
  [ 0.08933817 -2.5206723   1.3372997 ]]

 [[-0.30345318 -0.01401055 -0.8413807 ]
  [ 1.3572499  -2.0300913   0.38513002]
  [-1.1501877   0.60647106  0.05988278]]])>with<JVPTrace(level=2/0)>
  with primal = Traced<ConcreteArray([[[ 0.6252229  -0.42994434 -1.0035503 ]
                  [ 0.87450933  0.04434267 -0.34513336]
                  [ 0.67070735  0.48131174 -1.2213402 ]]
                
                 [[ 0.90400136 -1.5148104   1.6624688 ]
                  [-0.57737005 -1.4149495  -0.20211951]
                  [ 0.08933817 -2.5206723   1.3372997 ]]
                
                 [[-0.30345318 -0.01401055 -0.8413807 ]
                  [ 1.3572499  -2.0300913   0.38513002]
                  [-1.1501877   0.60647106  0.0598




NotImplementedError: Singular value decomposition JVP not implemented for full matrices

In [1]:
import jax.numpy as jnp
from emlp.groups_jax import LearnedGroup,SO
from emlp.equivariant_subspaces_jax import T,get_active_subspace,projection_matrix
from emlp.datasets_jax import Fr

In [2]:
#ds = Fr()
#model = EMLPLearned(ds.rep_in,ds.rep_out,**{'num_layers':3,'d':3,'ch':128})

In [3]:
Q = get_active_subspace(SO(3),(3,0))

In [4]:
P = projection_matrix(SO(3),(3,0))

In [5]:
U,S,VT = jnp.linalg.svd(P,full_matrices=True)

In [6]:
S

DeviceArray([3.4641020e+00, 3.4641018e+00, 3.4641018e+00, 3.4641016e+00,
             3.4641016e+00, 3.4641013e+00, 3.4641013e+00, 2.4494901e+00,
             2.4494901e+00, 2.4494898e+00, 2.4494898e+00, 2.4494898e+00,
             2.4494896e+00, 2.4494896e+00, 2.4494896e+00, 2.4494896e+00,
             2.4494894e+00, 1.4142139e+00, 1.4142138e+00, 1.4142137e+00,
             1.4142137e+00, 1.4142137e+00, 1.4142134e+00, 1.4142134e+00,
             1.4142133e+00, 1.4142132e+00, 1.4440218e-07], dtype=float32)

In [7]:
VT[(S>1e-6).sum():].reshape(3,3,3)

DeviceArray([[[-0.0000000e+00,  3.0858658e-08,  8.2125688e-08],
              [ 3.1295126e-08,  1.9933603e-08,  4.0824825e-01],
              [ 8.1934409e-08, -4.0824834e-01,  1.7629272e-08]],

             [[-1.2653363e-08,  4.0951225e-08, -4.0824825e-01],
              [-6.9838137e-08,  2.5179830e-08, -3.1598255e-08],
              [ 4.0824834e-01,  1.6228114e-08, -2.8592899e-08]],

             [[-1.6091979e-08,  4.0824831e-01, -5.3383498e-09],
              [-4.0824831e-01, -3.8538712e-08,  1.9951237e-08],
              [ 1.2598501e-08,  4.5214659e-09, -1.3233216e-08]]],            dtype=float32)

In [8]:
VT[(S>1e-9)

SyntaxError: unexpected EOF while parsing (<ipython-input-8-26a9c2bd0c9d>, line 1)

In [1]:
from emlp.equivariant_subspaces_jax import T
from emlp.utils import NoCache as nocache
from emlp.groups_jax import SO,O,C,D,Scaling,Parity,TimeReversal,Lorentz,SO13p,SO13,Symplectic,Permutation,Trivial,LearnedGroup
from emlp.groups_jax import DiscreteTranslation,SU,Permutation



In [None]:
for i in range(8):
    print(f"T({i}): {T(i)(SO(3)).symmetric_subspace()[0]}")

In [10]:
G = LearnedGroup(3,0,1)
for i in range(8):
    print(f"T({i}): {T(i)(G).symmetric_subspace()[0]}")

T(0): 1
T(1): 0
T(2): 0
T(3): 0
T(4): 0
T(5): 0
T(6): 0
T(7): 1


In [11]:
Q = get_active_subspace(SO(3),(1,1))
P = projection_matrix(SO(3),(1,1))
U,S,VT = jnp.linalg.svd(P,full_matrices=True)

In [3]:
G = LearnedGroup(3,0,3)
for i in range(5):
    for j in range(4):
        print(f"T({i,j}): {T(i,j)(G).symmetric_subspace()[0]}")

T((0, 0)): 1
T((0, 1)): 0
T((0, 2)): 0
T((0, 3)): 0
T((1, 0)): 0
T((1, 1)): 1
T((1, 2)): 0
T((1, 3)): 0
T((2, 0)): 0
T((2, 1)): 0
T((2, 2)): 0
T((2, 3)): 0
T((3, 0)): 0
T((3, 1)): 0
T((3, 2)): 0
T((3, 3)): 0
T((4, 0)): 0
T((4, 1)): 0
T((4, 2)): 0
T((4, 3)): 0


In [4]:
for i in range(5):
    for j in range(4):
        print(f"T({i,j}): {T(i,j)(G).show_subspace()}")

AttributeError: 'DeviceArray' object has no attribute 'abs'

In [8]:

for i in range(3):
    print(f"T({i}): {T(i)(DiscreteTranslation(32)).symmetric_subspace()[0]}")

T(0): 1
T(1): 1
T(2): 32


In [5]:
for i in range(4):
    print(f"T({i}):\n {T(i)(DiscreteTranslation(3)).show_subspace()}")

T(0):
 1.0
T(1):
 [0.5773504 0.5773503 0.5773503]
T(2):
 [[0.57735026 1.7320509  1.1547006 ]
 [1.1547008  0.57735026 1.7320511 ]
 [1.7320509  1.1547006  0.57735026]]
T(3):
 [[[5.196152  1.1547006 1.7320509]
  [2.3094013 3.4641018 0.5773503]
  [4.6188025 4.0414524 2.8867517]]

 [[2.8867517 4.618803  4.0414524]
  [1.7320509 5.1961527 1.1547006]
  [0.5773503 2.3094013 3.4641018]]

 [[3.4641018 0.5773504 2.3094013]
  [4.0414524 2.8867517 4.6188025]
  [1.1547008 1.7320511 5.196152 ]]]


In [8]:
bell_numbers = [1,1,2,5,15,52,203,877]
k_max= 8
n_max=9
printstr = ""
#with nocache():
printstr+="       "
for n in range(2,n_max):
    printstr+=" n={:1d}".format(n)
    printstr+="  "
printstr+="b(k)\n"
printstr+="     "+"______"*(n_max-2)+"\n"
for k in range(k_max):

    printstr+=" k={:1d}|".format(k)
    for n in range(2,n_max):
        printstr+="  "
        if n**k<7000: printstr+="{:4d}".format(T(k)(Permutation(n)).symmetric_subspace()[0])
        else: printstr+="    "
        
        #print(f"S{n} T({k}): {T(k)(Permutation(n)).symmetric_subspace()[0]}")
    printstr+="| {:3d}\n".format(bell_numbers[k])
print(printstr)

        n=2   n=3   n=4   n=5   n=6   n=7   n=8  b(k)
     __________________________________________
 k=0|     1     1     1     1     1     1     1|   1
 k=1|     1     1     1     1     1     1     1|   1
 k=2|     2     2     2     2     2     2     2|   2
 k=3|     4     5     5     5     5     5     5|   5
 k=4|     8    14    15    15    15    15    15|  15
 k=5|    16    41    51    52                  |  52
 k=6|    32   122   187                        | 203
 k=7|    64   365                              | 877



In [2]:

k_max= 8
n_max=9
printstr = ""
#with nocache():
printstr+="       "
for n in range(2,n_max):
    printstr+=" n={:1d}".format(n)
    printstr+="  "
printstr+="\n"
printstr+="     "+"______"*(n_max-2)+"\n"
for k in range(k_max):

    printstr+=" k={:1d}|".format(k)
    for n in range(2,n_max):
        printstr+="  "
        if n**k<7000: printstr+="{:4d}".format(T(k)(DiscreteTranslation(n)).symmetric_subspace()[0])
        else: printstr+="    "
        
        #print(f"S{n} T({k}): {T(k)(Permutation(n)).symmetric_subspace()[0]}")
    printstr+="|\n"
print(printstr)

        n=2   n=3   n=4   n=5   n=6   n=7   n=8  
     __________________________________________
 k=0|     1     1     1     1     1     1     1|
 k=1|     1     1     1     1     1     1     1|
 k=2|     2     3     4     5     6     7     8|
 k=3|     4     9    16    25    36    49    64|
 k=4|     8    27    64   125   216   343   512|
 k=5|    16    81   256   625                  |
 k=6|    32   243  1024                        |
 k=7|    64   729                              |



In [2]:
T(4)(DiscreteTranslation(9)).symmetric_subspace()[0]

729