In [1]:
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from typing import *

# Keep handles to the unpatched torch entry points. Later cells replace
# ``torch.Tensor.to`` and ``torch.tensor`` with wrappers that delegate to these
# names, so they must be captured only once: re-running this cell after the
# wrappers are installed would otherwise store the wrappers themselves here
# and make them call themselves recursively.
if "_torch_tensor_to" not in globals():
    _torch_tensor_to = torch.Tensor.to
if "_torch_tensor" not in globals():
    _torch_tensor = torch.tensor

In [2]:
# Core


@dataclass
class ScaledDType:
    """A torch storage dtype paired with the constant that normalises its range.

    ``base_scale`` is ``1 / max``, where ``max`` is the largest value reported
    by ``torch.finfo`` (floating-point dtypes) or ``torch.iinfo`` (all other
    dtypes) for ``dtype``.
    """

    name: str
    dtype: torch.dtype
    base_scale: float

    def __init__(self, name: str, dtype: torch.dtype):
        # Hand-written so callers supply only ``name`` and ``dtype``;
        # ``base_scale`` is always derived from the dtype's limits.
        if dtype.is_floating_point:
            largest = torch.finfo(dtype).max
        else:
            largest = torch.iinfo(dtype).max
        self.name = name
        self.dtype = dtype
        self.base_scale = 1 / largest

    def __hash__(self) -> int:
        # Hash on the identifying pair only; base_scale is a function of dtype.
        key = (self.name, self.dtype)
        return hash(key)

    def __repr__(self) -> str:
        return self.name

    @property
    def bits(self) -> float:
        """Storage width of ``dtype`` in bits, as reported by torch."""
        if self.dtype.is_floating_point:
            info = torch.finfo(self.dtype)
        else:
            info = torch.iinfo(self.dtype)
        return info.bits


# Predefined scaled dtypes: three floating-point and three integer storage formats.
sfloat16 = ScaledDType(name="sfloat16", dtype=torch.float16)
sfloat8_e4m3fn = ScaledDType(name="sfloat8_e4m3fn", dtype=torch.float8_e4m3fn)
sfloat8_e5m2 = ScaledDType(name="sfloat8_e5m2", dtype=torch.float8_e5m2)
sint32 = ScaledDType(name="sint32", dtype=torch.int32)
sint16 = ScaledDType(name="sint16", dtype=torch.int16)
sint8 = ScaledDType(name="sint8", dtype=torch.int8)


@dataclass
class ScaledTensor:
    """A tensor stored in a narrow dtype together with one float32 scale.

    The represented value is ``data * scale * dtype.base_scale``; see ``to``.

    Attributes are validated in ``__post_init__``: ``data`` must use the
    storage dtype of ``dtype`` and ``scale`` must be float32.
    """

    data: Tensor
    scale: Tensor
    dtype: ScaledDType

    def __post_init__(self) -> None:
        assert self.data.dtype == self.dtype.dtype
        assert self.scale.dtype == torch.float32

    def __repr__(self) -> str:
        # detach()/cpu() keep the repr usable for tensors that track gradients
        # or do not live on the CPU; for plain CPU tensors they are no-ops.
        values = self.to(torch.float32).detach().cpu().numpy()
        return f"stensor({values}, dtype={self.dtype})"

    def to(self, dtype: torch.dtype) -> Tensor:
        """Dequantise to a regular tensor of the given torch dtype."""
        return self.data.to(dtype) * self.scale.to(dtype) * self.dtype.base_scale

    @property
    def shape(self) -> torch.Size:
        return self.data.shape

    @staticmethod
    def _absmax(values: Tensor) -> Tensor:
        """Largest magnitude in ``values`` as a float32 scalar (0 if empty).

        The conversion to float32 happens before ``abs``/``max`` so the result
        is not rounded (or flushed to zero) in a narrow storage dtype.
        """
        if values.numel() == 0:
            return torch.zeros((), dtype=torch.float32, device=values.device)
        return values.to(torch.float32).abs().max()

    @staticmethod
    def _store(values: Tensor, dtype: ScaledDType) -> Tensor:
        """Cast real-valued ``values`` to the storage dtype of ``dtype``.

        Integer targets are rounded to nearest instead of truncated toward
        zero, which removes the bias toward zero and keeps a value such as
        32766.998 at the intended maximum code 32767.
        """
        if not dtype.dtype.is_floating_point:
            limit = torch.iinfo(dtype.dtype).max
            # NOTE: the int32 maximum is not exactly representable in float32,
            # so this clamp cannot fully protect sint32 when values is float32.
            values = values.round().clamp(-limit, limit)
        return values.to(dtype.dtype)

    @classmethod
    def quantise(cls, x: Tensor, dtype: ScaledDType) -> "ScaledTensor":
        """Quantise ``x`` with a single scale equal to its largest magnitude.

        An all-zero (or empty) input uses a scale of 1 to avoid dividing by zero.
        """
        scale = cls._absmax(x)
        if scale == 0:
            scale = torch.ones_like(scale)
        return cls(data=cls._store(x / scale / dtype.base_scale, dtype), scale=scale, dtype=dtype)

    @property
    def wasted_bits(self) -> float:
        """Estimate of how many storage bits the current data leaves unused.

        Returns the full bit width when the data is all zero (or empty). For
        integer storage this is ``log2(1 / used_fraction)``; for floating-point
        storage it is ``log2`` of that quantity. The result is never negative.
        """
        used = self._absmax(self.data) * self.dtype.base_scale
        if used == 0:
            return self.dtype.bits
        octaves = torch.log2(1 / used)
        if self.dtype.dtype.is_floating_point:
            # log2(0) is -inf when the full range is used; report 0 instead.
            return max(0.0, float(torch.log2(octaves)))
        return max(0.0, float(octaves))

    def requantise(self) -> "ScaledTensor":
        """Return an equivalent tensor whose data is stretched to the full range.

        The stretch factor is folded into ``scale`` so the represented values
        are unchanged up to rounding. All-zero (or empty) data cannot be
        stretched and is returned as an unmodified copy.
        """
        rescale = self._absmax(self.data) * self.dtype.base_scale
        if rescale == 0:
            return ScaledTensor(self.data.clone(), scale=self.scale.clone(), dtype=self.dtype)
        # Divide in float32 so a tiny rescale factor is not rounded in the
        # storage dtype before it is applied.
        stretched = self._store(self.data.to(torch.float32) / rescale, self.dtype)
        return ScaledTensor(stretched, scale=self.scale * rescale, dtype=self.dtype)


# Monkeypatch for convenience
def _tensor_to(self: Tensor, *args: Any, **kwargs: Any) -> Tensor:
    """Replacement for ``Tensor.to`` that also understands ``ScaledDType``.

    If any positional argument or the ``dtype`` keyword is a ``ScaledDType``,
    it must be the only argument and the tensor is quantised to it. Every
    other call is forwarded unchanged to the original ``Tensor.to``.
    """
    candidates = [*args, kwargs.get("dtype")]
    target = next((c for c in candidates if isinstance(c, ScaledDType)), None)
    if target is None:
        return _torch_tensor_to(self, *args, **kwargs)
    assert len(args) + len(kwargs) == 1
    return ScaledTensor.quantise(self, target)


torch.Tensor.to = _tensor_to


def _tensor(
    data: Any, *, dtype: Union[None, torch.dtype, ScaledDType] = None, **kwargs: Any
) -> Union[Tensor, ScaledTensor]:
    """Replacement for ``torch.tensor`` that also accepts a ``ScaledDType``.

    With a regular (or missing) dtype the call is forwarded to the original
    ``torch.tensor``. With a ``ScaledDType`` the tensor is first built without
    a dtype and then converted through ``.to(dtype)``.
    """
    if not isinstance(dtype, ScaledDType):
        return _torch_tensor(data, dtype=dtype, **kwargs)
    plain = _torch_tensor(data, **kwargs)
    return plain.to(dtype)


torch.tensor = _tensor

# Quantise a small example, then shrink the stored values in place to show
# how wasted_bits and the dequantised values react.
d = torch.tensor([1, 2, 3], dtype=sfloat16)
print(d)
d.data.div_(2.0**16)
print(f"Wasted {d.wasted_bits} bits")
print(d)

stensor([1.0002443 2.0004885 3.       ], dtype=sfloat16)
Wasted 4.0 bits
stensor([1.5262516e-05 3.0525032e-05 4.5776367e-05], dtype=sfloat16)


In [3]:
# Ops
# - Use worst-case scaling rules (no overflow!)
# - Placeholder impl (fast impl requires custom kernels)
# - No autograd support


def add(a: ScaledTensor, b: ScaledTensor) -> ScaledTensor:
    """Add two scaled tensors of the same scaled dtype.

    The output scale is the sum of the input scales, and each operand's data
    is weighted by its share of that sum. Because the two weights add up to 1,
    the combined data cannot exceed the larger input magnitude.

    For integer storage the weighted sum is rounded to nearest before the
    cast. A plain cast truncates toward zero, so a sum that lands just below
    an integer because of float rounding (e.g. 126.99999) would lose a whole
    step.
    """
    assert a.dtype == b.dtype
    scale = a.scale + b.scale
    data = a.data * (a.scale / scale) + b.data * (b.scale / scale)
    if not a.dtype.dtype.is_floating_point:
        data = data.round()
    return ScaledTensor(data.to(a.data.dtype), scale, a.dtype)


def matmul(a: ScaledTensor, b: ScaledTensor) -> ScaledTensor:
    """Matrix-multiply two scaled tensors of the same scaled dtype.

    Both operands are shrunk by ``sqrt(base_scale / K)`` (``K`` is the last
    dimension of ``a``) and cast back to the storage dtype before the product.
    The factor ``K`` is folded into the output scale. If the input magnitudes
    are at most ``1 / base_scale``, a sum of ``K`` products stays within that
    same bound.
    """
    assert a.dtype == b.dtype
    inner = a.shape[-1]
    downscale = (inner / a.dtype.base_scale) ** -0.5
    lhs = (a.data * downscale).to(a.dtype.dtype)
    rhs = (b.data * downscale).to(b.dtype.dtype)
    scale = a.scale * b.scale * inner
    return ScaledTensor(lhs @ rhs, scale, a.dtype)


def relu(a: ScaledTensor) -> ScaledTensor:
    """Apply ReLU to the stored data; scale and dtype are passed through unchanged."""
    rectified = nn.functional.relu(a.data)
    return ScaledTensor(data=rectified, scale=a.scale, dtype=a.dtype)


# Expose the scaled ops through Python's ``+`` and ``@`` operators.
setattr(ScaledTensor, "__add__", add)
setattr(ScaledTensor, "__matmul__", matmul)


print(
    torch.tensor(100, dtype=sint8)
    + torch.tensor(200, dtype=sint8)
)
print(
    torch.full((2, 20), 100).to(sint16)
    @ torch.full((20, 3), 200).to(sint16)
)

stensor(300.0, dtype=sint8)
stensor([[390636.9 390636.9 390636.9]
 [390636.9 390636.9 390636.9]], dtype=sint16)


In [4]:
# Example


class FFN(nn.Module):
    """Two-layer feed-forward block with a 4x wide hidden layer.

    Weights are drawn with ``torch.randn`` and converted with ``.to(dtype)``,
    so ``dtype`` may be a regular torch dtype or a ``ScaledDType``. When the
    activations are ``ScaledTensor`` instances, ``forward`` prints the wasted
    bits after each matmul and requantises between the two layers.
    """

    def __init__(self, hidden_size: int, dtype: Union[torch.dtype, ScaledDType]):
        super().__init__()
        inner_size = 4 * hidden_size
        self.W0 = torch.randn(hidden_size, inner_size).to(dtype)
        self.W1 = torch.randn(inner_size, hidden_size).to(dtype)
        # Pick the activation that matches the weight representation.
        if isinstance(dtype, ScaledDType):
            self.relu = relu
        else:
            self.relu = nn.functional.relu

    def forward(self, x: Union[Tensor, ScaledTensor]) -> Union[Tensor, ScaledTensor]:
        hidden = self.relu(x @ self.W0)

        if isinstance(hidden, ScaledTensor):
            print(f"wasted bits #1: {hidden.wasted_bits}")
            # Stretch the activations back over the full storage range
            # before feeding the second matmul.
            hidden = hidden.requantise()

        out = hidden @ self.W1

        if isinstance(out, ScaledTensor):
            print(f"wasted bits #2: {out.wasted_bits}")

        return out


# Seed torch's RNG so the random weights and input, and therefore the printed
# wasted-bits figures and result, are the same on every Restart & Run All.
torch.manual_seed(0)

hidden_size = 1024
dtype = sint16
module = FFN(hidden_size, dtype)
result = module(torch.randn(10, hidden_size).to(dtype))
print()
print(result)

wasted bits #1: 8.167065620422363
wasted bits #2: 10.912492752075195

stensor([[ -94.06189   -47.030945 -470.30945  ...  -47.030945 -329.2166
     0.      ]
 [  47.030945 -329.2166      0.       ...  235.15472   235.15472
    94.06189 ]
 [  47.030945 -141.09283   -94.06189  ...    0.         47.030945
   423.2785  ]
 ...
 [ -47.030945    0.         47.030945 ...  141.09283   188.12378
   188.12378 ]
 [ 188.12378   188.12378  -235.15472  ...  -47.030945   47.030945
   -94.06189 ]
 [  47.030945    0.       -141.09283  ...   47.030945 -376.24756
     0.      ]], dtype=sint16)
