In [1]:
from functools import partial
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from typing import *
# Keep handles on the *unpatched* torch entry points. A later cell replaces
# torch.Tensor.to and torch.tensor with wrappers that delegate to these names,
# so they must only be captured once per kernel: re-running this cell after the
# patch has been applied would otherwise capture the wrappers themselves and
# make every call recurse forever.
if "_torch_tensor_to" not in globals():
    _torch_tensor_to = torch.Tensor.to
if "_torch_tensor" not in globals():
    _torch_tensor = torch.tensor

def numeric_info(dtype):
    """Return the numeric-limits record for a torch dtype.

    Floating-point dtypes yield ``torch.finfo(dtype)``; every other dtype is
    looked up with ``torch.iinfo(dtype)``.
    """
    if dtype.is_floating_point:
        return torch.finfo(dtype)
    return torch.iinfo(dtype)


In [11]:
# Core

@dataclass
class ScaledDType:
    """A storage dtype paired with the scale that maps its range onto [-1, 1].

    ``base_scale`` is the reciprocal of the largest value representable in
    ``dtype``, so ``stored_value * base_scale`` has magnitude at most 1 for
    values within ``[-max, max]``.
    """

    # Display name; also what repr() returns.
    name: str
    # Underlying torch dtype used to store the payload.
    dtype: torch.dtype
    # 1 / (largest representable value of `dtype`); derived in __init__.
    base_scale: float

    def __init__(self, name: str, dtype: torch.dtype):
        # Hand-written initialiser: base_scale is derived from the dtype
        # rather than supplied by the caller.
        self.name, self.dtype = name, dtype
        limits = numeric_info(dtype)
        self.base_scale = 1 / limits.max

    def __hash__(self) -> int:
        # Hash on the identifying pair only; base_scale follows from dtype.
        identity = (self.name, self.dtype)
        return hash(identity)

    def __repr__(self) -> str:
        return self.name

    @property
    def bits(self) -> float:
        """Bit width of the underlying storage dtype."""
        limits = numeric_info(self.dtype)
        return limits.bits


# Floating-point scaled formats.
sfloat16 = ScaledDType(name="sfloat16", dtype=torch.float16)
sfloat8_e4m3fn = ScaledDType(name="sfloat8_e4m3fn", dtype=torch.float8_e4m3fn)
sfloat8_e5m2 = ScaledDType(name="sfloat8_e5m2", dtype=torch.float8_e5m2)
# Signed-integer scaled formats.
sint32 = ScaledDType(name="sint32", dtype=torch.int32)
sint16 = ScaledDType(name="sint16", dtype=torch.int16)
sint8 = ScaledDType(name="sint8", dtype=torch.int8)


@dataclass
class ScaledTensor:
    """A quantised payload together with the scale needed to decode it.

    The represented value is ``data * scale * dtype.base_scale`` (see ``to``).
    """

    # Quantised payload; must be stored in `dtype.dtype`.
    data: Tensor
    # Scale factor; must be a float32 tensor.
    scale: Tensor
    # Scaled format describing how `data` is stored.
    dtype: ScaledDType

    def __post_init__(self) -> None:
        # Payload storage must match the declared scaled format.
        assert self.data.dtype == self.dtype.dtype
        # The scale is always carried in float32.
        assert self.scale.dtype == torch.float32

    def __repr__(self) -> str:
        decoded = self.to(torch.float32).numpy()
        return f"stensor({decoded}, dtype={self.dtype})"

    def to(self, dtype: torch.dtype) -> Tensor:
        """Decode into a plain tensor of the given torch dtype."""
        payload = self.data.to(dtype)
        factor = self.scale.to(dtype)
        return payload * factor * self.dtype.base_scale

    @property
    def shape(self) -> torch.Size:
        """Shape of the stored payload."""
        return self.data.shape

def quantise(x: Tensor, dtype: ScaledDType) -> ScaledTensor:
    """Quantise ``x`` into ``dtype`` using a single per-tensor scale.

    The scale is the largest absolute value in ``x`` (or 1 when ``x`` is all
    zeros), so the largest element lands on the end of the storage range.

    Args:
        x: Tensor to quantise.
        dtype: Target scaled format.

    Returns:
        A ``ScaledTensor`` whose decoded values approximate ``x``.
    """
    scale = x.abs().max().to(torch.float32)
    if scale == 0:
        # Avoid 0/0 below. ones_like keeps the fallback on the same device as
        # the computed scale instead of always creating it on the CPU.
        scale = torch.ones_like(scale)

    storage = dtype.dtype
    if storage.is_floating_point:
        data = (x / scale / dtype.base_scale).to(storage)
    else:
        # An integer cast truncates toward zero, so floating-point round-off
        # that lands just below an integer (e.g. 126.99999) would lose a whole
        # quantisation step. Round to nearest instead. The arithmetic is done
        # in float64 because float32 cannot represent the int32 maximum, and
        # the clamp guards against residual round-off past the storage range.
        limit = torch.iinfo(storage).max
        normalised = x.to(torch.float64) / scale.to(torch.float64) / dtype.base_scale
        data = normalised.round().clamp(-limit, limit).to(storage)
    return ScaledTensor(data=data, scale=scale, dtype=dtype)

def wasted_bits(st: ScaledTensor) -> float:
    """Estimate how many bits of ``st``'s storage format are left unused.

    The largest stored magnitude is expressed as a fraction of the format's
    maximum. For integer formats the result is ``log2(1 / fraction)``; for
    floating-point formats it is the log2 of that figure.

    Args:
        st: Scaled tensor to inspect.

    Returns:
        A non-negative float. An all-zero payload reports the full bit width
        of the storage format.
    """
    # Convert to float32 *before* applying base_scale: multiplying a float16
    # scalar by a Python float stays in float16, where the product falls into
    # the subnormal range and can round to 0.
    peak = st.data.abs().max().to(torch.float32) * st.dtype.base_scale
    if peak == 0:
        return float(st.dtype.bits)
    # Round-off can push `peak` marginally above 1; never report negative waste.
    unused_log2 = torch.log2(1 / peak).clamp(min=0.0)
    if st.dtype.dtype.is_floating_point:
        # Below 1 the second log would be negative (-inf at exactly 0), which
        # is what a tensor using its full range would otherwise produce.
        return float(torch.log2(unused_log2.clamp(min=1.0)))
    return float(unused_log2)

def requantise(st: ScaledTensor) -> ScaledTensor:
    """Re-normalise ``st`` so its largest stored magnitude fills the format.

    The payload is divided by the fraction of the storage range currently in
    use and that fraction is folded into the scale, so the decoded values are
    unchanged up to quantisation error.

    Args:
        st: Scaled tensor to re-normalise.

    Returns:
        A new ``ScaledTensor`` of the same scaled dtype.
    """
    storage = st.dtype.dtype
    rescale = st.data.abs().max().to(torch.float32) * st.dtype.base_scale
    if rescale == 0:
        # All-zero payload: nothing to stretch, and dividing would give 0/0.
        return ScaledTensor(st.data.clone(), scale=st.scale, dtype=st.dtype)

    if storage.is_floating_point:
        # Divide in float32. With float16 data the division would otherwise be
        # carried out in float16, where a small `rescale` loses precision or
        # underflows to zero.
        data = (st.data.to(torch.float32) / rescale).to(storage)
    else:
        # Round to nearest rather than truncating, and work in float64 so the
        # int32 maximum is representable. The clamp absorbs round-off that
        # would step past the storage range.
        limit = torch.iinfo(storage).max
        stretched = st.data.to(torch.float64) / rescale.to(torch.float64)
        data = stretched.round().clamp(-limit, limit).to(storage)
    return ScaledTensor(data, scale=st.scale * rescale, dtype=st.dtype)


# Monkeypatch for convenience
@partial(setattr, torch.Tensor, 'to')
def _(self: Tensor, *args: Any, **kwargs: Any) -> Tensor:
    """Replacement for ``Tensor.to`` that also accepts a ``ScaledDType``.

    If any positional argument, or the ``dtype`` keyword, is a ``ScaledDType``
    the tensor is quantised into it; otherwise the call is forwarded unchanged
    to the original ``Tensor.to``.
    """
    candidates = (*args, kwargs.get("dtype"))
    target = next((c for c in candidates if isinstance(c, ScaledDType)), None)
    if target is None:
        return _torch_tensor_to(self, *args, **kwargs)
    # A scaled dtype must be the only argument supplied.
    assert len(args) + len(kwargs) == 1
    return quantise(self, target)

@partial(setattr, torch, 'tensor')
def _(data: Any, *, dtype: Union[None, torch.dtype, ScaledDType] = None, **kwargs: Any) -> Union[Tensor, ScaledTensor]:
    """Replacement for ``torch.tensor`` that also accepts a ``ScaledDType``.

    With an ordinary dtype (or none) the call is forwarded to the original
    ``torch.tensor``. With a ``ScaledDType`` the tensor is first built without
    a dtype and then converted with ``.to(dtype)``.
    """
    if not isinstance(dtype, ScaledDType):
        return _torch_tensor(data, dtype=dtype, **kwargs)
    plain = _torch_tensor(data, **kwargs)
    return plain.to(dtype)

# Build a small scaled tensor through the patched torch.tensor and show it.
d = torch.tensor([1, 2, 3], dtype=sfloat16)
print(d)

# Shrink the stored payload in place while leaving the scale untouched, then
# report how much of the format is now unused.
d.data.div_(2.0**16)
print(f"Wasted {wasted_bits(d)} bits")
print(d)

stensor([1.0002443 2.0004885 3.       ], dtype=sfloat16)
Wasted 4.0 bits
stensor([1.5262516e-05 3.0525032e-05 4.5776367e-05], dtype=sfloat16)


In [12]:
# Ops
# - Use worst-case scaling rules (no overflow!)
# - Placeholder impl (a fast impl requires custom kernels)
# - No autograd support


def add(a: ScaledTensor, b: ScaledTensor) -> ScaledTensor:
    """Add two scaled tensors using the worst-case output scale.

    The output scale is ``a.scale + b.scale``, and each payload is weighted by
    its share of that scale before the two are summed.

    Args:
        a: Left operand.
        b: Right operand; must have the same scaled dtype as ``a``.

    Returns:
        A ``ScaledTensor`` of the same scaled dtype holding the sum.
    """
    assert a.dtype == b.dtype
    scale = a.scale + b.scale
    storage = a.data.dtype
    if storage.is_floating_point:
        data = (a.data * (a.scale / scale) + b.data * (b.scale / scale)).to(storage)
    else:
        # An integer cast truncates toward zero, so a weighted sum that comes
        # out as e.g. 126.99999 would drop a full step. Round to nearest
        # instead. float64 keeps int32 payloads exact, and the clamp absorbs
        # round-off that would step past the storage range.
        limit = torch.iinfo(storage).max
        total = scale.to(torch.float64)
        weight_a = a.scale.to(torch.float64) / total
        weight_b = b.scale.to(torch.float64) / total
        mixed = a.data.to(torch.float64) * weight_a + b.data.to(torch.float64) * weight_b
        data = mixed.round().clamp(-limit, limit).to(storage)
    return ScaledTensor(data, scale, a.dtype)


def matmul(a: ScaledTensor, b: ScaledTensor) -> ScaledTensor:
    """Matrix-multiply two scaled tensors using the worst-case output scale.

    Both payloads are shrunk by ``sqrt(base_scale / K)``, where ``K`` is the
    size of ``a``'s last dimension, before the product is taken. The output
    scale is ``a.scale * b.scale * K``.
    """
    assert a.dtype == b.dtype
    inner = a.shape[-1]
    downscale = (inner / a.dtype.base_scale) ** -.5
    # For integer storage the cast truncates toward zero, so the shrunk
    # magnitudes never exceed sqrt(max / K) and a sum of K products stays
    # within the storage range.
    lhs = (a.data * downscale).to(a.dtype.dtype)
    rhs = (b.data * downscale).to(b.dtype.dtype)
    scale = a.scale * b.scale * inner
    return ScaledTensor(lhs @ rhs, scale, a.dtype)


def relu(a: ScaledTensor) -> ScaledTensor:
    """Apply ReLU to the payload; scale and scaled dtype carry over unchanged."""
    rectified = nn.functional.relu(a.data)
    return ScaledTensor(data=rectified, scale=a.scale, dtype=a.dtype)


# Expose the scaled ops through Python's `+` and `@` operators.
ScaledTensor.__add__, ScaledTensor.__matmul__ = add, matmul


# Quick demo: a scalar add in sint8 and a small matmul in sint16.
print(torch.tensor(100).to(sint8) + torch.tensor(200).to(sint8))
print(torch.full((2, 20), 100).to(dtype=sint16) @ torch.full((20, 3), 200).to(dtype=sint16))

stensor(300.0, dtype=sint8)
stensor([[390636.9 390636.9 390636.9]
 [390636.9 390636.9 390636.9]], dtype=sint16)


In [13]:
# Example

class FFN(nn.Module):
    """Two-layer feed-forward block (expand 4x, ReLU, project back).

    Works with either a plain torch dtype or a ``ScaledDType``. In the scaled
    case the forward pass prints the wasted-bits figure after each matmul and
    requantises the hidden activation between the two layers.
    """

    def __init__(self, hidden_size: int, dtype: Union[torch.dtype, ScaledDType]):
        super().__init__()
        inner_size = 4 * hidden_size
        # Random weights, converted (or quantised) into the requested dtype.
        self.W0 = torch.randn(hidden_size, inner_size).to(dtype)
        self.W1 = torch.randn(inner_size, hidden_size).to(dtype)
        # Pick the activation that matches the tensor representation in use.
        if isinstance(dtype, ScaledDType):
            self.relu = relu
        else:
            self.relu = nn.functional.relu

    def forward(self, x: Union[Tensor, ScaledTensor]) -> Union[Tensor, ScaledTensor]:
        hidden = self.relu(x @ self.W0)
        if isinstance(hidden, ScaledTensor):
            print(f"wasted bits #1: {wasted_bits(hidden)}")
            # Reclaim the unused range before the second matmul.
            hidden = requantise(hidden)

        out = hidden @ self.W1
        if isinstance(out, ScaledTensor):
            print(f"wasted bits #2: {wasted_bits(out)}")
        return out

# The weights and the input are drawn from torch.randn; seed the global RNG so
# a fresh Restart & Run All reproduces the same printed result.
torch.manual_seed(0)

hidden_size = 1024
dtype = sint16
module = FFN(hidden_size, dtype)
result = module(torch.randn(10, hidden_size).to(dtype))
print()
print(result)

wasted bits #1: 8.258488655090332
wasted bits #2: 10.414993286132812

stensor([[   0.        -41.124104  164.49641  ...  287.8687     41.124104
   205.62051 ]
 [ -82.24821  -123.37231   -82.24821  ... -164.49641  -164.49641
  -164.49641 ]
 [ 123.37231  -205.62051   411.24103  ...  -82.24821     0.
     0.      ]
 ...
 [ -41.124104  123.37231   205.62051  ... -287.8687   -287.8687
  -411.24103 ]
 [ -82.24821  -164.49641    82.24821  ...   82.24821    41.124104
   -41.124104]
 [  41.124104  205.62051    82.24821  ...   82.24821  -164.49641
  -123.37231 ]], dtype=sint16)
