In [1]:
import torch

class MetadataTensor(object):
    """Tensor-like wrapper pairing a ``torch.Tensor`` with a piece of metadata.

    The class is not a ``torch.Tensor`` subclass; it takes part in torch
    operations through the ``__torch_function__`` protocol, which unwraps the
    stored tensor, runs the requested function and re-wraps the result with
    the metadata of the first wrapped argument.
    """

    def __init__(self, data, metadata=None, **kwargs):
        """Store ``torch.as_tensor(data, **kwargs)`` together with ``metadata``."""
        self._t = torch.as_tensor(data, **kwargs)
        self._sussy_metadata = metadata

    def __repr__(self):
        """Show the shape of the wrapped tensor followed by the metadata."""
        return f"{self._t.shape} metadata {self._sussy_metadata}"

    @classmethod
    def _find_carrier(cls, value):
        """Return the first metadata-carrying object in ``value``, or ``None``.

        ``value`` itself is checked first; plain lists and tuples are searched
        recursively so that sequence arguments (e.g. the list given to
        ``torch.cat``) are covered too.
        """
        if hasattr(value, '_sussy_metadata'):
            return value
        # Exact type check: tuple subclasses such as torch.Size are left alone.
        if type(value) in (list, tuple):
            for item in value:
                carrier = cls._find_carrier(item)
                if carrier is not None:
                    return carrier
        return None

    @classmethod
    def _unwrap(cls, value):
        """Replace every wrapper in ``value`` by its underlying tensor.

        Plain lists and tuples are rebuilt with their items unwrapped; any
        other object is returned unchanged.
        """
        if hasattr(value, '_t'):
            return value._t
        if type(value) in (list, tuple):
            return type(value)(cls._unwrap(item) for item in value)
        return value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the unwrapped arguments and wrap the result.

        Args:
            func: the torch function or method being overridden.
            types: argument types implementing ``__torch_function__`` (unused).
            args: positional arguments of the original call.
            kwargs: keyword arguments of the original call, or ``None``.

        Returns:
            ``cls(ret, metadata=...)`` where ``ret`` is the result of ``func``
            and the metadata comes from the first argument (positional first,
            then keyword values) that is or contains a wrapper; ``None`` if
            there is none.

        The call is traced with ``print`` for demonstration purposes.
        """
        print()
        print(f'{cls = }')
        print(f'{func = }')
        print(f'{func.__name__ = }')
        if kwargs is None:
            kwargs = {}
        metadata = None
        all_args = list(args) + list(kwargs.values())
        for i, arg in enumerate(all_args):
            print(f'{i = }, {arg = }, {arg.__class__.__name__ = }')
            carrier = cls._find_carrier(arg)
            if carrier is not None:
                metadata = carrier._sussy_metadata
                break
        # Unwrap positional and keyword arguments the same way, including
        # wrappers nested in lists/tuples; a wrapper left inside a sequence
        # would make func dispatch straight back into this method.
        args = [cls._unwrap(a) for a in args]
        kwargs = {k: cls._unwrap(v) for k, v in kwargs.items()}
        ret = func(*args, **kwargs)
        # The result goes through torch.as_tensor in __init__.
        return cls(ret, metadata=metadata)


In [2]:
# Wrap a random 1x4 tensor with metadata, then pass the wrapper as the `min`
# keyword of Tensor.clamp called on a plain tensor. The trace printed below
# shows the call being routed through MetadataTensor.__torch_function__ and the
# result coming back as a MetadataTensor carrying the same metadata.
# NOTE(review): no random seed is set, so the tensor values shown in the trace
# will differ from run to run.
sus = MetadataTensor(torch.randn(1,4), 'big chungus')
torch.randn(1,4).clamp(min=sus)


cls = <class '__main__.MetadataTensor'>
func = <method 'clamp' of 'torch._C.TensorBase' objects>
func.__name__ = 'clamp'
i = 0, arg = tensor([[ 1.2620, -0.0875, -0.9023, -1.1289]]), arg.__class__.__name__ = 'Tensor'
i = 1, arg = torch.Size([1, 4]) metadata big chungus, arg.__class__.__name__ = 'MetadataTensor'


torch.Size([1, 4]) metadata big chungus

In [13]:
# pyright: reportIncompatibleMethodOverride=false
# pylint: disable=abstract-method
import math

from typing import Any, Literal
import torch

# Type alias for the second operand of the binary Algebra operations.
# NOTE(review): despite the name, this alias currently admits only torch.Tensor,
# not Python numbers — confirm whether `int | float | torch.Tensor` was intended.
NumberOrTensor = torch.Tensor
class Algebra:
    def add(self, x: torch.Tensor, y: "NumberOrTensor") -> torch.Tensor:
        raise NotImplementedError

    def neg(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def sub(self, x: torch.Tensor, y: "NumberOrTensor") -> torch.Tensor:
        return self.neg(self.add(self.neg(x), y)) # can be overridden

    def mul(self, x: torch.Tensor, y: "NumberOrTensor") -> torch.Tensor:
        raise NotImplementedError

    def div(self, x: torch.Tensor, y: "NumberOrTensor") -> torch.Tensor:
        raise NotImplementedError

    def reciprocal(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def pow(self, base: torch.Tensor, exponent: "NumberOrTensor") -> torch.Tensor:
        raise NotImplementedError

    def sum(self, x: torch.Tensor, dim: int | None = None, keepdim=False) -> torch.Tensor:
        raise NotImplementedError

    def min(self, x: torch.Tensor, dim: int | None = None, keepdim=False) -> torch.Tensor:
        if dim is None: return torch.min(x)
        return x.amin(dim, keepdim)

    def max(self, x: torch.Tensor, dim: int | None = None, keepdim=False) -> torch.Tensor:
        if dim is None: return torch.max(x)
        return x.amax(dim, keepdim)

    def matmul(self, x: torch.Tensor, y: torch.Tensor):
        # this imlements matmul by calling add and mul

        x_squeeze = False
        y_squeeze = False

        if x.ndim == 1:
            x_squeeze = True
            x = x.unsqueeze(0)

        if y.ndim == 1:
            y_squeeze = True
            y = y.unsqueeze(1)

        res = self.sum(self.mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim = -2)

        if x_squeeze: res = res.squeeze(-2)
        if y_squeeze: res = res.squeeze(-1)

        return res

    def mm(self, x:torch.Tensor, y:torch.Tensor):
        return self.matmul(x, y)
    
# Names of the public (non-underscore-prefixed) attributes of Algebra.
[attr for attr in dir(Algebra) if attr[:1] != '_']

['add',
 'div',
 'matmul',
 'max',
 'min',
 'mm',
 'mul',
 'neg',
 'pow',
 'reciprocal',
 'sub',
 'sum']