In [1]:
import torch
from torch import nn
from torch.masked import MaskedTensor
from torch.masked.maskedtensor.core import _tensors_match

In [2]:
from torch.overrides import get_default_nowrap_functions
def _validate_members(self):
    data = self._masked_data if self._masked_data.is_sparse else self._masked_data
    mask = self.get_mask()
    if type(data) != type(mask):
        raise TypeError(f"data and mask must have the same type. Got {type(data)} and {type(mask)}")
    if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
        raise TypeError(f"data layout of {data.layout} is not supported.")
    if data.layout == torch.sparse_coo:
        if not _tensors_match(data.indices(), mask.indices(), exact=True):
            raise ValueError("data and mask are both sparse COO tensors but do not have the same indices.")
    elif data.layout == torch.sparse_csr:
        if not _tensors_match(
            data.crow_indices(), mask.crow_indices(), exact=True
        ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
            raise ValueError("data and mask are both sparse CSR tensors but do not share either crow or col indices.")
    if mask.dtype != torch.bool:
        raise TypeError("mask must have dtype bool.")
    if not (
        data.dtype == torch.float16
        or data.dtype == torch.float32
        or data.dtype == torch.float64
        or data.dtype == torch.bool
        or data.dtype == torch.int8
        or data.dtype == torch.int16
        or data.dtype == torch.int32
        or data.dtype == torch.int64
    ):
        raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
    if data.dim() != mask.dim():
        raise ValueError("data.dim() must equal mask.dim()")
    if data.size() != mask.size():
        raise ValueError("data.size() must equal mask.size()")

def _set_data_mask(self, data, mask):
    self._masked_data = data.coalesce() if data.is_sparse else data
    self._masked_mask = mask
    self._validate_members()

_old_preprocess_data = MaskedTensor._preprocess_data

def _preprocess_data(self, data, mask):
    _old_preprocess_data(self, data, mask)
    if self._masked_data.is_sparse:
        self._masked_data.coalesce()

# Install the notebook's replacement implementations on the imported
# MaskedTensor class. These assignments mutate the class itself, so every
# MaskedTensor used later in this kernel goes through the functions above.
MaskedTensor._validate_members = _validate_members
MaskedTensor._set_data_mask = _set_data_mask
MaskedTensor._preprocess_data = _preprocess_data

    # @classmethod
    # def __torch_function__(cls, func, types, args=(), kwargs=None):
        # print(func)
        # return MaskedTensor.__torch_function__(func, types, args, kwargs)

# COO

In [3]:
class _SparseCOOJumpingSquaredReLU(torch.autograd.Function):
    """Custom autograd op: ``((x + 1)**2 - 1) / 2`` on the positive entries of ``x``.

    The input is cast to float and its first two dimensions are merged. Only
    entries with ``x > 0`` are kept, and the result is returned as the sparse
    data tensor of a ``MaskedTensor`` built with a sparse COO mask.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Remember the incoming shape so backward can restore it.
        ctx.x_shape = x.shape
        flat = x.float().flatten(0, 1)
        positive = (flat > 0).to_sparse_coo()

        # (x + 1) restricted to the positive entries; it is also the derivative
        # of the activation, so it is what backward needs.
        shifted = MaskedTensor(flat * positive, positive).add_(1)
        ctx.save_for_backward(shifted)

        # Work on a clone so the saved tensor is not modified by the in-place ops.
        activated = shifted.clone().square_().add_(-1).div_(2)

        res: torch.Tensor = activated.get_data()
        assert res.is_sparse or res.is_sparse_csr, res
        return res

    @staticmethod
    def backward(ctx, grad_output):
        (shifted,) = ctx.saved_tensors
        # d/dx [((x + 1)**2 - 1) / 2] = x + 1, which is the saved tensor's data.
        grad_flat = grad_output * shifted.get_data()
        return grad_flat.to_dense().reshape(ctx.x_shape)


SparseCOOJumpingSquaredReLU = _SparseCOOJumpingSquaredReLU.apply

In [4]:
from torch import Tensor


class SparseCOOLinear(nn.Linear):
    """``nn.Linear`` that also accepts a 2-D sparse input.

    Dense inputs are handled by the stock ``nn.Linear.forward``. Sparse inputs
    (``is_sparse`` or ``is_sparse_csr``) are multiplied with the transposed
    weight through ``torch.sparse``, and the leading dimension of the result is
    split into ``(-1, n_token)``.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to ``nn.Linear``.
        n_token: default size of the second output dimension on the sparse path.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None, n_token=None) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.n_token = n_token

    def forward(self, input: Tensor, n_token=None) -> Tensor:
        """Apply the layer.

        Args:
            input: dense tensor, or a sparse tensor with at most two dimensions.
            n_token: per-call override of ``self.n_token`` (sparse path only).

        Returns:
            For dense input, the result of ``nn.Linear.forward``. For sparse
            input, the product unflattened to ``(-1, n_token, out_features)``,
            or left two-dimensional when no ``n_token`` is available.
        """
        if not input.is_sparse and not input.is_sparse_csr:
            return super().forward(input)
        assert len(input.shape) <= 2

        weight_t = self.weight.transpose(-1, -2)
        if self.bias is None:
            # Layer was built with bias=False: addmm has nothing to add, so use
            # a plain sparse @ dense product instead of passing None.
            output: torch.Tensor = torch.sparse.mm(input, weight_t)
        else:
            output = torch.sparse.addmm(self.bias, input, weight_t)

        if n_token is None:
            n_token = self.n_token
        if n_token is None:
            # No token count known anywhere: keep the flat (rows, out_features)
            # result rather than failing inside unflatten.
            return output
        return output.unflatten(0, [-1, n_token])

In [5]:
class SparseCOOMLP(nn.Module):
    """Two-layer MLP whose hidden activation is produced as a sparse COO tensor.

    ``key`` is a dense linear layer, ``act`` is ``SparseCOOJumpingSquaredReLU``
    and ``value`` is a ``SparseCOOLinear`` that consumes the sparse activation.
    """

    def __init__(self, dim=768, hidden_dim=3072, n_token=None) -> None:
        super().__init__()
        self.key = nn.Linear(dim, hidden_dim)
        self.act = SparseCOOJumpingSquaredReLU
        self.value = SparseCOOLinear(hidden_dim, dim, n_token=n_token)

    def forward(self, x: torch.Tensor):
        # The activation merges the first two dimensions of its input, so the
        # size of dimension 1 is handed to the output layer to split them again.
        tokens_per_sample = x.shape[1]
        hidden = self.key(x)
        sparse_hidden = self.act(hidden)
        return self.value(sparse_hidden, n_token=tokens_per_sample)
        
class MLP(nn.Module):
    """Dense reference MLP: ``Linear -> ReLU -> Linear``.

    Args:
        dim: input and output feature size.
        hidden_dim: feature size of the hidden layer.
    """

    def __init__(self, dim=768, hidden_dim=3072) -> None:
        super().__init__()
        self.key = nn.Linear(dim, hidden_dim)
        self.act = nn.ReLU()
        self.value = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor):
        """Project to the hidden size, apply ReLU, and project back."""
        return self.value(self.act(self.key(x)))
        

In [6]:
# Random input batch of shape (64, 197, 768), allocated directly on the GPU.
# The last dimension matches the default `dim` of the MLP classes; the same
# tensor is fed to every model in the cells below.
x = torch.randn([64, 197, 768], device='cuda')

In [7]:
# Run 10 forward/backward passes of the sparse-COO MLP on the shared batch,
# using the sum of squared outputs as a scalar loss.
smlp = SparseCOOMLP()
smlp = smlp.to('cuda')
for i in range(10):
    y = smlp(x)
    loss = y.pow(2).sum()
    loss.backward()



In [8]:
# Run 100 forward/backward passes of the dense reference MLP on the shared
# batch, using the sum of squared outputs as a scalar loss.
mlp = MLP()
mlp = mlp.to('cuda')
for i in range(100):
    y_non_sparse = mlp(x)
    loss = y_non_sparse.pow(2).sum()
    loss.backward()

# CSR

In [9]:

class _SparseCSRJumpingSquaredReLU(torch.autograd.Function):
    """Custom autograd op: ``((x + 1)**2 - 1) / 2`` on the positive entries of ``x``.

    The input is cast to float and its first two dimensions are merged. Only
    entries with ``x > 0`` are kept, and the result is returned as the sparse
    data tensor of a ``MaskedTensor`` built with a sparse CSR mask.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Remember the incoming shape so backward can restore it.
        ctx.x_shape = x.shape
        flat = x.float().flatten(0, 1)
        positive = (flat > 0).to_sparse_csr()

        # (x + 1) restricted to the positive entries; it is also the derivative
        # of the activation, so it is what backward needs.
        shifted = MaskedTensor(flat * positive, positive).add_(1)
        ctx.save_for_backward(shifted)

        # Work on a clone so the saved tensor is not modified by the in-place ops.
        activated = shifted.clone().square_().add_(-1).div_(2)

        res: torch.Tensor = activated.get_data()
        assert res.is_sparse or res.is_sparse_csr, res
        return res

    @staticmethod
    def backward(ctx, grad_output):
        (shifted,) = ctx.saved_tensors
        # d/dx [((x + 1)**2 - 1) / 2] = x + 1, which is the saved tensor's data.
        grad_flat = grad_output * shifted.get_data()
        return grad_flat.to_dense().reshape(ctx.x_shape)


SparseCSRJumpingSquaredReLU = _SparseCSRJumpingSquaredReLU.apply

In [10]:
class SparseCSRLinear(nn.Linear):
    """``nn.Linear`` that also accepts a 2-D sparse input.

    Dense inputs are handled by the stock ``nn.Linear.forward``. Sparse inputs
    (``is_sparse`` or ``is_sparse_csr``) are multiplied with the transposed
    weight through ``torch.sparse``, and the leading dimension of the result is
    split into ``(-1, n_token)``.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to ``nn.Linear``.
        n_token: default size of the second output dimension on the sparse path.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None, n_token=None) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.n_token = n_token

    def forward(self, input: Tensor, n_token=None) -> Tensor:
        """Apply the layer.

        Args:
            input: dense tensor, or a sparse tensor with at most two dimensions.
            n_token: per-call override of ``self.n_token`` (sparse path only).

        Returns:
            For dense input, the result of ``nn.Linear.forward``. For sparse
            input, the product unflattened to ``(-1, n_token, out_features)``,
            or left two-dimensional when no ``n_token`` is available.
        """
        if not input.is_sparse and not input.is_sparse_csr:
            return super().forward(input)
        assert len(input.shape) <= 2

        weight_t = self.weight.transpose(-1, -2)
        if self.bias is None:
            # Layer was built with bias=False: addmm has nothing to add, so use
            # a plain sparse @ dense product instead of passing None.
            output: torch.Tensor = torch.sparse.mm(input, weight_t)
        else:
            output = torch.sparse.addmm(self.bias, input, weight_t)

        if n_token is None:
            n_token = self.n_token
        if n_token is None:
            # No token count known anywhere: keep the flat (rows, out_features)
            # result rather than failing inside unflatten.
            return output
        return output.unflatten(0, [-1, n_token])

In [11]:
class SparseCSRMLP(nn.Module):
    """Two-layer MLP whose hidden activation is produced as a sparse CSR tensor.

    ``key`` is a dense linear layer, ``act`` is ``SparseCSRJumpingSquaredReLU``
    and ``value`` is a ``SparseCSRLinear`` that consumes the sparse activation.
    """

    def __init__(self, dim=768, hidden_dim=3072, n_token=None) -> None:
        super().__init__()
        self.key = nn.Linear(dim, hidden_dim)
        self.act = SparseCSRJumpingSquaredReLU
        self.value = SparseCSRLinear(hidden_dim, dim, n_token=n_token)

    def forward(self, x: torch.Tensor):
        # The activation merges the first two dimensions of its input, so the
        # size of dimension 1 is handed to the output layer to split them again.
        tokens_per_sample = x.shape[1]
        hidden = self.key(x)
        sparse_hidden = self.act(hidden)
        return self.value(sparse_hidden, n_token=tokens_per_sample)

In [14]:
# Run 10 forward/backward passes of the sparse-CSR MLP on the shared batch,
# using the sum of squared outputs as a scalar loss.
csr_mlp = SparseCSRMLP()
csr_mlp = csr_mlp.to('cuda')
for i in range(10):
    y = csr_mlp(x)
    loss = y.pow(2).sum()
    loss.backward()