First install the repo and requirements.

In [None]:
%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git

# Limited PyTorch Support

We strongly recommend that users of our library write native JAX code. However, we understand that existing codebases or employer constraints sometimes make it unavoidable to use other frameworks such as PyTorch.

To accommodate these cases, we provide a way for PyTorch users to work with the equivariant bases $Q\in \mathbb{R}^{n\times r}$ and projection matrices $P = QQ^\top$ computed by our solver. Because these objects are defined implicitly through `LinearOperators` rather than stored as dense arrays, it is not as straightforward as simply calling `torch.from_numpy(Q)`. However, the operators can still be used within PyTorch code while preserving the gradients of the operation. We provide the function `emlp.reps.pytorch_support.torchify_fn` to do this.
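
To make the relationship between $Q$ and $P$ concrete, here is a minimal NumPy sketch for a toy case built by hand rather than by the solver: the permutation-invariant vectors in $\mathbb{R}^3$, which are the constant vectors. The names `n`, `v`, and `perm` are illustrative assumptions and are not part of the library.

```python
import numpy as np

n = 3
# Orthonormal basis for the permutation-invariant subspace of R^n:
# the constant vectors, so r = 1 and Q has shape (n, r).
Q = np.ones((n, 1)) / np.sqrt(n)
# Orthogonal projector onto that subspace, shape (n, n).
P = Q @ Q.T

v = np.array([1.0, 2.0, 6.0])
perm = np.eye(n)[[2, 0, 1]]   # a permutation matrix (one element of S_3)

projected = P @ v             # every entry equals the mean of v
print(projected)              # [3. 3. 3.]
```

Because $Q$ has orthonormal columns, $P$ is idempotent ($P^2 = P$), and the projected vector is left unchanged by any permutation.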

In [1]:
import torch
import jax
import jax.numpy as jnp
from emlp.reps import V
from emlp.groups import S

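W = V(S(4))          # base vector representation of the permutation group S_4
rep = 3*W + W**2     # direct sum of three copies of W and the tensor product W⊗W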

In [2]:
Q = (rep>>rep).equivariant_basis()
P = (rep>>rep).equivariant_projector()

In [3]:
applyQ = lambda v: Q@v
applyP = lambda v: P@v

The key is to wrap the desired operations in a function, to which we can then apply `torchify_fn`. Instead of taking JAX arrays as inputs and outputting JAX arrays, the resulting functions take in PyTorch tensors and output PyTorch tensors.

In [4]:
from emlp.reps.pytorch_support import torchify_fn
applyQ_torch = torchify_fn(applyQ)
applyP_torch = torchify_fn(applyP)
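
For intuition, a wrapper of this kind can be built from `torch.autograd.Function` and `jax.vjp`. The sketch below is an illustration under that assumption and is not the library's implementation: `torchify_sketch` is a hypothetical name, it handles only functions of a single array, and it copies data through host memory via NumPy.

```python
import numpy as np
import torch
import jax
import jax.numpy as jnp


def torchify_sketch(jax_fn):
    """Hypothetical stand-in for a torchify-style wrapper (single-array functions only)."""

    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Torch -> NumPy -> JAX; this sketch goes through host memory.
            x_jax = jnp.asarray(x.detach().cpu().contiguous().numpy())
            # jax.vjp returns the output and a function computing J^T g.
            y_jax, vjp_fn = jax.vjp(jax_fn, x_jax)
            ctx.vjp_fn = vjp_fn
            return torch.from_numpy(np.array(y_jax)).to(x.device)

        @staticmethod
        def backward(ctx, grad_out):
            g_jax = jnp.asarray(grad_out.detach().cpu().contiguous().numpy())
            (grad_in,) = ctx.vjp_fn(g_jax)
            return torch.from_numpy(np.array(grad_in)).to(grad_out.device)

    return _Wrapped.apply


# A small dense matrix stands in for the implicit operator Q.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
apply_A = torchify_sketch(lambda v: A @ v)

x = torch.tensor([1.0, -1.0], requires_grad=True)
y = apply_A(x)                              # tensor([-1., -1., -1.])
(grad,) = torch.autograd.grad(y.sum(), x)   # A^T 1 = tensor([ 9., 12.])
```

The forward pass evaluates the JAX function, and the backward pass hands PyTorch's incoming gradient to the stored vector-Jacobian product, so gradients flow through the wrapped call like any other PyTorch operation.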

In [5]:
# The same input as a PyTorch tensor (tracking gradients) and as a JAX array
x_torch = torch.arange(Q.shape[-1]).float().cuda()
x_torch.requires_grad = True
x_jax = jnp.asarray(x_torch.detach().cpu().numpy())

In [6]:
Qx1 = applyQ(x_jax)
Qx2 = applyQ_torch(x_torch)
print("jax output: ",Qx1[:5])
print("torch output: ",Qx2[:5])

jax output:  [0.48484263 0.07053992 0.07053989 0.07053995 1.6988853 ]
torch output:  tensor([0.4848, 0.0705, 0.0705, 0.0705, 1.6989], device='cuda:0',
       grad_fn=<SliceBackward>)


The outputs match. Note that the torch outputs are placed on whichever device is the default JAX device. The gradients of the two functions also match:

In [7]:
torch.autograd.grad(Qx2.sum(),x_torch)[0][:5]

tensor([-2.8704,  2.7858, -2.8704,  2.7858, -2.8704], device='cuda:0')

In [8]:
jax.grad(lambda x: (Q@x).sum())(x_jax)[:5]

DeviceArray([-2.8703732,  2.7858496, -2.8703732,  2.7858496, -2.8703732],            dtype=float32)
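
As a sanity check on why the two gradients should agree, note that the map $x \mapsto Qx$ is linear, so the gradient of the summed output has a closed form (writing $\mathbf{1}$ for the all-ones vector, a symbol introduced here for illustration):

$$\nabla_x \left(\mathbf{1}^\top Q x\right) = Q^\top \mathbf{1}$$

The result does not depend on $x$: each entry is the sum of the corresponding column of $Q$, which is what both frameworks should report up to floating-point precision.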

So you can safely use these torchified functions within your model, and still compute the gradients correctly.