2V>>Scalar representation is all zero? #14
Hi @topoliu, this is the expected behavior: the space of equivariant linear maps from 2V to Scalar is zero, so the all-zero projector is correct. However, don't despair: there are many nonlinear equivariant maps, and EMLP will find them by mapping to feature tensors of higher order as well as by using equivariant nonlinearities. To see how this works, let's imagine the first feature representation of EMLP is 30*Scalar+10*W+3*W**2+W**3. The space of equivariant linear maps into it is far from empty:

(2*W>>30*Scalar+10*W+3*W**2+W**3).equivariant_basis().shape
Out: (684, 32)

And calling vis to visualize this basis, the layer looks like this:

[figure: visualization of the equivariant basis for this layer]

With these higher-order representations and nonlinearities, EMLP can in fact fit your functions:

import emlp
from emlp.reps import V,T,Scalar
from emlp.groups import SO
import objax
W = V(SO(3))                    # the 3-dimensional vector representation of SO(3)
repin = 2*W                     # two input vectors
repout = Scalar                 # one invariant scalar output
model = emlp.nn.EMLP(repin,repout,SO(3))
import numpy as np
import jax.numpy as jnp
from jax import device_put
u,v = np.random.randn(2,100,3)  # 100 samples of two 3-vectors each
X = np.concatenate([u,v],-1)
# Target: an invariant function of (u, v) -- a dot product minus a cubed norm.
Y = ((u*v).sum(-1) - jnp.linalg.norm(v,axis=1)**3)[:,None]
Y /= Y.std()
Y = device_put(Y)
X = device_put(X)

opt = objax.optimizer.Adam(model.vars())
@objax.Jit
@objax.Function.with_vars(model.vars()+opt.vars())
def loss(x, y):
    yhat = model(x)
    return ((yhat-y)**2).mean()

grad_and_val = objax.GradValues(loss, model.vars())

@objax.Jit
@objax.Function.with_vars(model.vars()+opt.vars())
def train_op(x, y, lr):
    g, v = grad_and_val(x, y)
    opt(lr=lr, grads=g)  # apply an Adam update with the given learning rate
    return v
import matplotlib.pyplot as plt
losses = [train_op(X,Y,3e-3) for _ in range(200)]
plt.plot(losses)
plt.yscale('log')
plt.ylabel("MSE")
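One end-to-end sanity check (a sketch; the rotation here is built by hand from a QR decomposition rather than with any emlp sampling helper): applying the same rotation to both input vectors should leave the model's output essentially unchanged, since EMLP is equivariant by construction.

# Build a random rotation matrix in SO(3) via QR of a Gaussian matrix.
A = np.random.randn(3, 3)
Q, R = np.linalg.qr(A)
Q = Q * np.sign(np.diag(R))   # a random orthogonal matrix
if np.linalg.det(Q) < 0:
    Q[:, 0] *= -1             # flip a column so that det(Q) = +1

# Rotate both input vectors by the same Q and compare outputs.
Xg = device_put(np.concatenate([u @ Q.T, v @ Q.T], -1))
print(jnp.abs(model(Xg) - model(X)).max())   # should be near zero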
Hello Finzi: Thanks for your quick and detailed response. In a more realistic scenario, molecular property prediction, the atom coordinates are 3-dimensional vectors and the property is a scalar. Thanks
Hi @topoliu, Again, (nV>>Scalar) is zero, but EMLP(nV,Scalar) is not, since it heavily leverages the bilinear layer to compute inner products (which for O(3) and scalar outputs is essentially all you need; see the paper "Scalars are universal: Gauge-equivariant machine learning, structured like classical physics"). I hope this clears up the confusion. The key point is that there is a difference between the equivariant linear maps (found with equivariant_basis or equivariant_projector), which are zero here, and the nonlinear equivariant functions that the full EMLP model can express.
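To make that distinction concrete, here is a small sketch using the same emlp API as above. The printed shapes and values are what I'd expect (the exact normalization of the returned basis vector is up to the solver), and the choice of n=4 "atoms" is just a hypothetical molecule size:

import numpy as np
import emlp
from emlp.reps import V, Scalar
from emlp.groups import SO

W = V(SO(3))

# No nonzero *linear* equivariant maps 2V -> Scalar: the projector is zero.
P = (2*W >> Scalar).equivariant_projector()
print(np.abs(np.asarray(P.to_dense())).max())   # 0.0

# But invariant *bilinear* forms on V exist: linear maps V⊗V -> Scalar.
Q = (W**2 >> Scalar).equivariant_basis()
print(Q.shape)                                  # expect (9, 1): a one-dimensional space
print(np.asarray(Q).reshape(3, 3))              # expect a multiple of the identity, i.e. the dot product

# And the full nonlinear model EMLP(nV, Scalar) is far from trivial, e.g. n=4:
model = emlp.nn.EMLP(4*W, Scalar, SO(3))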
import jax
import jax.numpy as jnp
from emlp.reps import V,Scalar
from emlp.groups import SO
import numpy as np

W = V(SO(3))
rep = 2*W                                    # two 3-vectors as input
P = (rep>>Scalar).equivariant_projector()    # projector onto equivariant linear maps 2V -> Scalar
applyP = lambda v: P@v
P.to_dense()
DeviceArray([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]], dtype=float32)
I am running a toy example: two vectors as input, a scalar as output.
If I rotate the vectors, the output scalar should not change.
But the P matrix I get is all zero!
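As the replies above explain, the all-zero projector is the mathematically correct answer here rather than a bug. A small numpy-only sketch of why: projecting onto the equivariant subspace amounts to averaging over the group, and averaging any candidate linear functional w·v over random rotations collapses it to zero.

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal(3)       # a candidate linear functional f(v) = w @ v
avg = np.zeros(3)
n = 20000
for _ in range(n):
    A = rng.standard_normal((3, 3))
    Q, R = np.linalg.qr(A)
    Q = Q * np.sign(np.diag(R))  # a random orthogonal matrix
    if np.linalg.det(Q) < 0:
        Q[:, 0] *= -1            # restrict to rotations (det = +1)
    avg += w @ Q                 # the rotated functional: f(Qv) = (w @ Q) @ v
avg /= n
print(np.abs(avg).max())         # tends to 0 as n grows: only the zero linear map survives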