
2V>>Scalar representation is all zero? #14

Closed

topoliu opened this issue Aug 12, 2021 · 3 comments


topoliu commented Aug 12, 2021

import jax
import jax.numpy as jnp
from emlp.reps import V,Scalar
from emlp.groups import SO
import numpy as np

W = V(SO(3))                                  # the 3D vector representation of SO(3)
rep = 2*W                                     # two vectors as input
P = (rep>>Scalar).equivariant_projector()     # projector onto equivariant linear maps 2V -> Scalar
applyP = lambda v: P@v

P.to_dense()

DeviceArray([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]], dtype=float32)

I am running a toy example: two vectors as input and a scalar as output.
If I rotate the vectors, the output scalar should not change.

But the P matrix I get is all zero!


mfinzi commented Aug 12, 2021

Hi @topoliu,
This is as expected since there are no rotation equivariant linear maps directly from a vector (or a pair of them) to a scalar. You can think of maps from vectors to scalars as mere vectors, and there are no rotation invariant vectors in 3D (except for the 0 vector).
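
To make this concrete, here is a quick numerical check (my addition, not part of the original exchange): the equivariant projector onto invariants is the group average, so averaging R @ v over random rotations R should collapse any vector v to zero. A minimal sketch, assuming scipy is available for sampling Haar-random rotations:

import numpy as np
from scipy.stats import special_ortho_group  # samples Haar-random SO(3) matrices

v = np.random.randn(3)
# Group-average v over many random rotations; this approximates P @ v.
avg = np.mean([special_ortho_group.rvs(3) @ v for _ in range(10000)], axis=0)
print(np.linalg.norm(avg))  # ~0: the only rotation-invariant vector is 0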

However, don't despair: there are many nonlinear equivariant maps, and EMLP will find them by mapping to higher-order feature tensors as well as using equivariant nonlinearities.

To see how this works, suppose the first feature representation of EMLP is 30*Scalar+10*W+3*W**2+W**3.
The equivariant layer from the 2*W input to this feature representation has 32 independent degrees of freedom:

(2*W>>30*Scalar+10*W+3*W**2+W**3).equivariant_basis().shape

Out: (684, 32)
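
As an aside (my addition, hedged): the 32 can be recovered by counting multiplicities. For SO(3), Hom(V,Scalar) is 0-dimensional, Hom(V,V) is 1-dimensional (the identity), Hom(V,V⊗V) is 1-dimensional (the Levi-Civita tensor), and Hom(V,V⊗V⊗V) is 3-dimensional (the three delta pairings), so each input vector contributes 30*0 + 10*1 + 3*1 + 1*3 = 16 degrees of freedom, and 2*W gives 32. A sketch checking the per-output counts, assuming equivariant_basis returns an empty basis when no maps exist:

from emlp.reps import V, Scalar
from emlp.groups import SO

W = V(SO(3))
for rep_out in [Scalar, W, W**2, W**3]:
    # shape is (dim_in*dim_out, number of independent equivariant maps)
    print((W >> rep_out).equivariant_basis().shape)
# anticipated second entries: 0, 1, 1, 3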

And calling vis to visualize this basis, the layer looks like this:
[image: visualization of the equivariant basis for 2*W >> 30*Scalar+10*W+3*W**2+W**3]

This shows that with these higher-order representations and nonlinearities, EMLP can in fact fit your function.
Let's take the example target f: 2V->Scalar, f(u,v) = u^Tv - ||v||^3; EMLP fits it without trouble.

import emlp
from emlp.reps import V,T,Scalar
from emlp.groups import SO
import objax

W = V(SO(3))
repin = 2*W
repout = Scalar

model = emlp.nn.EMLP(repin,repout,SO(3))

import numpy as np
import jax.numpy as jnp
from jax import device_put

u,v = np.random.randn(2,100,3)                        # 100 random pairs of 3D vectors
X = np.concatenate([u,v],-1)                          # (100, 6) inputs matching 2V
Y = ((u*v).sum(-1) - jnp.linalg.norm(v,axis=-1)**3)[:,None]  # targets u^Tv - ||v||^3
Y /= Y.std()                                          # normalize the target scale
Y = device_put(Y)
X = device_put(X)
opt = objax.optimizer.Adam(model.vars())

@objax.Jit
@objax.Function.with_vars(model.vars()+opt.vars())
def loss(x, y):
    yhat = model(x)
    return ((yhat-y)**2).mean()

grad_and_val = objax.GradValues(loss, model.vars())

@objax.Jit
@objax.Function.with_vars(model.vars()+opt.vars())
def train_op(x, y, lr):
    g, v = grad_and_val(x, y)
    opt(lr=lr, grads=g)
    return v

import matplotlib.pyplot as plt
losses = [train_op(X,Y,3e-3) for _ in range(200)]
plt.plot(losses)
plt.yscale('log')
plt.ylabel("MSE")

[image: training loss curve, MSE on a log scale]
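
As a follow-up sanity check (my addition, not from the original comment): after training, the output should be unchanged when both input vectors are rotated by the same rotation. A hedged sketch, again assuming scipy for random rotations:

from scipy.stats import special_ortho_group

R = special_ortho_group.rvs(3)                               # a random rotation
X_rot = device_put(np.concatenate([u @ R.T, v @ R.T], -1))   # rotate both inputs
# the difference should be ~0 up to float32 error if the model is SO(3)-invariant
print(jnp.abs(model(X_rot) - model(X)).max())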


topoliu commented Aug 13, 2021

Hello Finzi,

Thanks for your quick and detailed response.
I'm a little confused about the scenario: I have n 3-dimensional vectors as features and a scalar as the target.
At the very least, the angles and distances between pairs of vectors should not change under SO(3).

In a more realistic scenario, molecular property prediction, the atom coordinates are 3-dimensional vectors and the property is a scalar,
yet (nV>>Scalar) is always zero.

Thanks


mfinzi commented Aug 13, 2021

Hi @topoliu,
As you mention, angles and distances between two vectors are invariant under SO(3). EMLP can fit these functions just fine; to see for yourself, rerun the above example with distance or angle targets, as in the sketch below.
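
A hedged version of those targets (my addition), with the axis and shape adjustments needed to match the (100, 1) Y used in the training example above:

# distance between the two input vectors:
Y = jnp.linalg.norm(u - v, axis=-1)[:,None]
# or the cosine of the angle between them:
Y = ((u*v).sum(-1) / np.sqrt((u*u).sum(-1) * (v*v).sum(-1)))[:,None]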

Again, (nV>>Scalar) is zero, but EMLP(nV,Scalar) is not, since it heavily leverages the bilinear layer to compute inner products (which for O(3) and scalar outputs is essentially all you need; see "Scalars are universal: Gauge-equivariant machine learning, structured like classical physics"). You can form these maps with EMLP, although if you also want permutation equivariance between the vectors then E(n)-GNN is the practical way to go (even if it can be replicated with some effort via EMLP).

I hope this clears up the confusion. The key point is that there is a difference between the equivariant linear maps (found with (nV>>Scalar).equivariant_basis()) and the nonlinear equivariant multilayer perceptron that you get with EMLP(nV,Scalar,SO(3)) even though EMLP makes use of the linear maps within its layers.
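
To make the distinction concrete, a short hedged sketch (my addition; n=5 is an arbitrary choice): the linear problem has no solutions, while the nonlinear model is perfectly well defined.

import emlp
from emlp.reps import V, Scalar
from emlp.groups import SO

n = 5
rep_in = n*V(SO(3))
print((rep_in >> Scalar).equivariant_projector().to_dense())  # all zeros, as above
model = emlp.nn.EMLP(rep_in, Scalar, SO(3))                   # nonlinear and nontrivial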
