In [54]:
import torch
from torch import nn

In [69]:
class _SparseJumpingSquaredReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = x.float()
        nonzeros = (x > 0).to_sparse_csr()
        nonzero_x_shifted = torch.masked.masked_tensor(x * nonzeros, nonzeros).add_(1)
        ctx.save_for_backward(nonzero_x_shifted)

        masked_jsrelu = nonzero_x_shifted.clone().square_().add_(-1).div_(2)

        res: torch.Tensor =  masked_jsrelu.get_data()
        assert res.is_sparse or res.is_sparse_csr, res

        return res

    @staticmethod
    def backward(ctx, grad_output):
        nonzero_x_shifted, = ctx.saved_tensors
        return grad_output * nonzero_x_shifted.get_data()
    
SparseJumpingSquaredReLU = _SparseJumpingSquaredReLU.apply

In [9]:
x = torch.randn([128, 128], device='cuda', requires_grad=True)

In [52]:
with torch.no_grad():
    y = SparseJumpingSquaredReLU(x)

In [72]:
class MLP(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.wi = nn.Linear(128, 128)
        self.wo = nn.Linear(128, 128)
    def forward(self, x):
        x = self.wi(x)
        x = SparseJumpingSquaredReLU(x)
        x = self.wo(x)
        return x

In [73]:
mlp = MLP().to('cuda')
with torch.no_grad():
    z = mlp(x)
z

tensor([[-0.2390, -0.0738, -0.4273,  ...,  0.0761, -0.2038,  0.1887],
        [ 0.0366, -0.3345,  0.0172,  ...,  0.4337,  0.2052, -0.2865],
        [ 0.1740,  0.0359, -0.2689,  ...,  0.0223, -0.4306,  0.1324],
        ...,
        [ 0.5264, -0.1293, -0.2503,  ...,  0.1437, -0.2875, -0.2538],
        [ 0.1357,  0.1196, -0.0615,  ...,  0.2817,  0.3430, -0.2425],
        [ 0.3831,  0.3054,  0.1942,  ...,  0.2678,  0.4195,  0.4509]],
       device='cuda:0')

In [61]:
x.requires_grad = True

z = mlp(x)
loss = (z**2).sum()
loss.backward()

tensor([[-0.1903,  0.1982,  0.6741,  ..., -0.2116,  0.2345,  0.3611],
        [-0.3962,  0.6045, -0.5680,  ...,  0.3701, -1.4278, -0.6893],
        [ 0.1720,  0.1340, -0.0762,  ..., -0.1806, -0.4827,  0.0148],
        ...,
        [-0.3116,  0.0321, -0.2780,  ...,  0.3731, -0.6803, -0.3226],
        [ 0.2022, -0.0880,  0.1228,  ...,  0.1000, -0.1772, -0.2030],
        [ 0.0146,  0.6470,  0.2027,  ..., -0.4298, -0.1915,  0.6884]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


RuntimeError: Function _SparseJumpingSquaredReLUBackward returned an invalid gradient at index 0 - expected layout Strided but got SparseCsr