In [59]:
from types import MethodDescriptorType
from typing import *
from functools import partial
from dataclasses import dataclass

import torch
from torch import Tensor, tensor, nn

def numeric_info(dtype):
  return torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)

def _maxval(dtype):
  return numeric_info(dtype).max

## First, core functionality

Define a `ScaleTensorData` object, with the minimal information, and with operations
defined cleanly, but without PyTorch Tensor integration

In [71]:
# Core
@dataclass
class ScaledTensorData:
    data: Tensor
    scale: Tensor

    def __post_init__(self) -> None:
        if not isinstance(self.scale, Tensor):
            self.scale = tensor(self.scale)
        assert self.scale.dtype == torch.float32
        assert self.scale.shape == () 
        # Possible future expansion to e.g. row-scaled, column-scaled, etc, but
        # for now, insist st_scale is a single-element tensor

    def __repr__(self) -> str:
        return f"ScaledTensor({self.scale} * {self.data})"

    def to_tensor(self, dtype: torch.dtype) -> Tensor:
        return self.data.to(dtype) * self.scale.to(dtype)

    @property
    def shape(self) -> torch.Size:
        return self.data.shape
    
    @property
    def dtype_max(self):
        return numeric_info(self.data.dtype).max

GPU = False
def _uptype(dtype) -> torch.dtype:
    if numeric_info(dtype).bits < 16:
        return torch.float16 if GPU else torch.float32
    else:
        return dtype
    
def _to_uptype(t : Tensor) -> Tensor:
    return torch.as_tensor(t, dtype=_uptype(t.dtype))

def _maxabs(t : Tensor):
    return _to_uptype(t).abs().max().to(torch.float32)

def st_quantise(x: Tensor, dtype: torch.dtype) -> ScaledTensorData:
    """
    Rescale so that max(|data|) == maxFinite(data.dtype)

    """
    maxval = _maxabs(x)
    if maxval == 0:
        # Tensor is all zeros - set scale to 1
        scale = Tensor(1.0, dtype=torch.float32)
    else:
        # Scale so that largest element is the largest finite value of dtype
        scale = maxval / _maxval(dtype)

    return ScaledTensorData((x / scale).to(dtype), scale)

def st_requantise(st: ScaledTensorData) -> ScaledTensorData:
    """
    Rescale so that max(|data|) == maxFinite(data.dtype)

    Equivalent to quantise(st.to_tensor(torch.float32)) but avoids the conversion

    Returned tensor may share its data with input tensor
    """
    maxdataval = _maxabs(st.data)
    if maxdataval == 0:
        # All zero, reset scale to 1
        return ScaledTensorData(st.data, st.scale)
    else:
         rescale = maxdataval / st.dtype_max
    return ScaledTensorData((_to_uptype(st.data) * (1 / rescale)).to(st.data.dtype), st.scale * rescale)

### Quick testing
f16 = tensor([1, 2, 3], dtype=torch.float16)
f8_t = torch.float8_e4m3fn
print(f16)
print(f16.to(f8_t))

st_data = st_quantise(f16, f8_t)

print(st_data)
print(st_requantise(st_data))
print(st_data.to_tensor(f16.dtype), "# <- rounding errors at the high end of the f8 range")


tensor([1., 2., 3.], dtype=torch.float16)
tensor([1., 2., 3.], dtype=torch.float8_e4m3fn)
ScaledTensor(0.0066964286379516125 * tensor([144., 288., 448.], dtype=torch.float8_e4m3fn))
ScaledTensor(0.0066964286379516125 * tensor([144., 288., 448.], dtype=torch.float8_e4m3fn))
tensor([0.9639, 1.9277, 3.0000], dtype=torch.float16) # <- rounding errors at the high end of the f8 range


In [98]:
# Ops
# - Use worst-case scaling rules (no overflow!)
# - Placeholder impl (fast impl requires custom kernels)
# - No autograd support

def st_add(a: ScaledTensorData, b: ScaledTensorData) -> ScaledTensorData:
    out_dtype = a.data.dtype
    scale = a.scale + b.scale
    data = (a.data * (a.scale / scale) + b.data * (b.scale / scale)).to(out_dtype)
    return ScaledTensorData(data, scale)


def st_matmul(a: ScaledTensorData, b: ScaledTensorData) -> ScaledTensorData:
    # simplified version: assume low-precision a.data and b.data are multiplied to 
    # higher precision (e.g. f16), then downscaled.
    assert a.data.dtype == b.data.dtype
    out_dtype = a.data.dtype

    a_maxval = a.scale * a.dtype_max
    b_maxval = b.scale * b.dtype_max

    # Predicted maxval for NxK @ KxM
    K = a.shape[-1]
    out_maxval = a_maxval * b_maxval * K

    out_scale = out_maxval / _maxval(out_dtype) 

    # Assume low-precision muls will accumulate to uptype
    # to simulate this on cpu, uptype before the matmul;
    # on appropriate hardware (e.g. graphcore, h100), call the special matmul
    out_data = (_to_uptype(a.data) @ _to_uptype(b.data)) * (1 / out_scale)
    out_data = out_data.to(out_dtype)

    return ScaledTensorData(out_data, out_scale)


def st_relu(a: ScaledTensorData) -> ScaledTensorData:
    return ScaledTensorData(nn.functional.relu(a.data), a.scale, a.dtype)

st1 = ScaledTensorData(tensor(32, dtype=torch.int8), 0.5)
st2 = ScaledTensorData(tensor(64, dtype=torch.int8), 0.25)

f32 = lambda x: x.to_tensor(torch.float32) if isinstance(x, ScaledTensorData) else x.to(dtype=torch.float32)

print(f'{st1=}')
print(f'{st2=}')
print(f'{st_add(st1, st2)=}')
print(f'{f32(st_add(st1, st2)) = }')
print(f'{f32(st1) + f32(st2)   = }')

st3 = st_quantise(torch.full((2, 3), 100.0), torch.int8)
st4 = st_quantise(torch.full((3, 4), 200.0), torch.int8)
print(f'{st3=}')
print(f'{st4=}')
print(f'{st_matmul(st3, st4) = }')
print(f'{f32(st_matmul(st3, st4)) = } <-- ~48200 quantized')
print(f'{f32(st3) @ f32(st4)      = } <-- 60000 exact')


st1=ScaledTensor(0.5 * 32)
st2=ScaledTensor(0.25 * 64)
st_add(st1, st2)=ScaledTensor(0.75 * 42)
f32(st_add(st1, st2)) = tensor(31.5000)
f32(st1) + f32(st2)   = tensor(32.)
st3=ScaledTensor(0.787401556968689 * tensor([[127, 127, 127],
        [127, 127, 127]], dtype=torch.int8))
st4=ScaledTensor(1.574803113937378 * tensor([[127, 127, 127, 127],
        [127, 127, 127, 127],
        [127, 127, 127, 127]], dtype=torch.int8))
st_matmul(st3, st4) = ScaledTensor(472.4409484863281 * tensor([[102, 102, 102, 102],
        [102, 102, 102, 102]], dtype=torch.int8))
f32(st_matmul(st3, st4)) = tensor([[48188.9766, 48188.9766, 48188.9766, 48188.9766],
        [48188.9766, 48188.9766, 48188.9766, 48188.9766]]) <-- ~48200 quantized
f32(st3) @ f32(st4)      = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- 60000 exact


## Now plug into Tensor subclass for convenience

In [105]:
# https://pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-like-type
class ScaledTensor(Tensor):
    def __init__(self, st_data : ScaledTensorData):
        super().__init__()
        self.st = st_data
        
    def __new__(cls, st_data, *args, **kwargs):
        # Ensure that if the tensor is ever viewed as the base class, it is NaN
        return super().__new__(cls, tensor(torch.nan), *args, **kwargs)

    def __repr__(self) -> str:
        return f"{self.st}"

    def to_tensor(self, dtype: torch.dtype) -> Tensor:
        return self.st.to_tensor(dtype)

    @property
    def shape(self) -> torch.Size:
        return self.st.shape

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in cls.HANDLED_FUNCTIONS:
          ret = cls.HANDLED_FUNCTIONS[func](*args, **kwargs)
          if ret != NotImplemented:
              return ret
          # Otherwise drop through to the fallback

        # Convert to float32 and call func
        print(f"ScaledTensor.__torch_function__: WARNING: Upcasting to float32 for {func}@{types}")

        def to_tensor(t) -> Tensor:
            if isinstance(t, ScaledTensor):
                return t.to_tensor(torch.float32)
            else:
                return t

        new_args = tuple(to_tensor(a) for a in args)
        new_kwargs = { k:to_tensor(v) for (k,v) in kwargs.items() }

        # We don't want the auto-downcast of 
        #    ret = super().__torch_function__(func, types, new_args, new_kwargs)
        # because the super()'s handler will construct other types of tensor which
        # don't simply reinterpret to a ScaledTensor.  If the handler's *do* construct
        # ScaledTensors, of course, they should pass through, so it would also be 
        # incorrect to simply upcast the result.

        with torch._C.DisableTorchFunctionSubclass():
            return func(*new_args, **new_kwargs)

ScaledTensor.HANDLED_FUNCTIONS = {}

def quantise(x: Tensor, dtype: torch.dtype) -> ScaledTensor:
    return ScaledTensor(st_quantise(x, dtype))

def requantise(st) -> ScaledTensor:
    return ScaledTensor(st_requantise(st.st))

f16 = tensor([1, 2, 3], dtype=torch.float16)
f8_t = torch.int8
print(f16)
print(f16.to(f8_t))

st = quantise(f16, f8_t)

print(st)
print(requantise(st))
print('Rounding errors at the high end of the f8 range:', st.to_tensor(f16.dtype))
print('Viewing as a tensor should show NaN:', Tensor(st))


tensor([1., 2., 3.], dtype=torch.float16)
tensor([1, 2, 3], dtype=torch.int8)
ScaledTensor(0.023622047156095505 * tensor([ 42,  84, 127], dtype=torch.int8))
ScaledTensor(0.023622047156095505 * tensor([ 42,  84, 127], dtype=torch.int8))
Rounding errors at the high end of the f8 range: tensor([0.9922, 1.9844, 3.0000], dtype=torch.float16)
Viewing as a tensor should show NaN: tensor(nan)


## Now overrides work, but just punt up to f32 for all ops: 

In [106]:
print(st + 2)
print(2 * st)
print(st + st)

st3 = quantise(torch.full((2, 3), 100.0), torch.int8)
st4 = quantise(torch.full((3, 4), 200.0), torch.int8)
print(f'{st3=}')
print(f'{st4=}')
print(f'{st3 @ st4 = }')
print(f'{f32(st3 @ st4)       = } <-- ~48200 quantized')
print(f'{f32(st3) @ f32(st4)  = } <-- 60000 exact')


tensor([2.9921, 3.9843, 5.0000])
tensor([1.9843, 3.9685, 6.0000])
tensor([1.9843, 3.9685, 6.0000])
st3=ScaledTensor(0.787401556968689 * tensor([[127, 127, 127],
        [127, 127, 127]], dtype=torch.int8))
st4=ScaledTensor(1.574803113937378 * tensor([[127, 127, 127, 127],
        [127, 127, 127, 127],
        [127, 127, 127, 127]], dtype=torch.int8))
st3 @ st4 = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]])
f32(st3 @ st4)       = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- ~48200 quantized
f32(st3) @ f32(st4)  = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- 60000 exact


In [122]:
def torch_function_override(cls, func):
    def doit(impl):
        cls.HANDLED_FUNCTIONS[func] = impl
    return doit

@torch_function_override(ScaledTensor, Tensor.add)
def _(a:Tensor, b: Tensor) -> Tensor:
    if not (isinstance(a, ScaledTensor) and isinstance(b, ScaledTensor)):
        return NotImplemented

    return ScaledTensor(st_add(a.st, b.st))

@torch_function_override(ScaledTensor, Tensor.matmul)
def _(a:Tensor, b: Tensor) -> Tensor:
    if not (isinstance(a, ScaledTensor) and isinstance(b, ScaledTensor)):
        return NotImplemented

    return ScaledTensor(st_matmul(a.st, b.st))

@torch_function_override(ScaledTensor, Tensor.to)
def _(a:Tensor, dtype:torch.dtype) -> Tensor:
    assert isinstance(a, ScaledTensor)
    return a.st.to_tensor(dtype)

print(f'{st + st=}')
print(f'{f32(st + st)=}')
print(f'{f32(st) + f32(st)=}')

print(f'{st3=}')
print(f'{st4=}')
print(f'{st3 @ st4 = }')
print(f'{f32(st3 @ st4)       = } <-- ~48200 quantized')
print(f'{f32(st3) @ f32(st4)  = } <-- 60000 exact')


st + st=ScaledTensor(0.04724409431219101 * tensor([ 42,  84, 127], dtype=torch.int8))
f32(st + st)=tensor([1.9843, 3.9685, 6.0000])
f32(st) + f32(st)=tensor([1.9843, 3.9685, 6.0000])
st3=ScaledTensor(0.787401556968689 * tensor([[127, 127, 127],
        [127, 127, 127]], dtype=torch.int8))
st4=ScaledTensor(1.574803113937378 * tensor([[127, 127, 127, 127],
        [127, 127, 127, 127],
        [127, 127, 127, 127]], dtype=torch.int8))
st3 @ st4 = ScaledTensor(472.4409484863281 * tensor([[102, 102, 102, 102],
        [102, 102, 102, 102]], dtype=torch.int8))
f32(st3 @ st4)       = tensor([[48188.9766, 48188.9766, 48188.9766, 48188.9766],
        [48188.9766, 48188.9766, 48188.9766, 48188.9766]]) <-- ~48200 quantized
f32(st3) @ f32(st4)  = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- 60000 exact


# And in a network...

In [13]:
# Example

class FFN(nn.Module):
    def __init__(self, hidden_size: int, dtype: Union[torch.dtype, ScaledDType]):
        super().__init__()
        self.W0 = torch.randn(hidden_size, 4*hidden_size).to(dtype)
        self.W1 = torch.randn(4*hidden_size, hidden_size).to(dtype)
        self.relu = relu if isinstance(dtype, ScaledDType) else nn.functional.relu

    def forward(self, x: Union[Tensor, ScaledTensor]) -> Union[Tensor, ScaledTensor]:
        y = self.relu(x @ self.W0)

        if isinstance(y, ScaledTensor):
            print(f"wasted bits #1: {wasted_bits(y)}")
            y = requantise(y)

        y = y @ self.W1

        if isinstance(y, ScaledTensor):
            print(f"wasted bits #2: {wasted_bits(y)}")

        return y

hidden_size = 1024
dtype = sint16
module = FFN(hidden_size, dtype)
result = module(torch.randn(10, hidden_size).to(dtype))
print()
print(result)

wasted bits #1: 8.258488655090332
wasted bits #2: 10.414993286132812

stensor([[   0.        -41.124104  164.49641  ...  287.8687     41.124104
   205.62051 ]
 [ -82.24821  -123.37231   -82.24821  ... -164.49641  -164.49641
  -164.49641 ]
 [ 123.37231  -205.62051   411.24103  ...  -82.24821     0.
     0.      ]
 ...
 [ -41.124104  123.37231   205.62051  ... -287.8687   -287.8687
  -411.24103 ]
 [ -82.24821  -164.49641    82.24821  ...   82.24821    41.124104
   -41.124104]
 [  41.124104  205.62051    82.24821  ...   82.24821  -164.49641
  -123.37231 ]], dtype=sint16)


In [None]:
def st_matmul(a: ScaledTensorData, b: ScaledTensorData) -> ScaledTensorData:
    # NxK @ KxM
    K = a.shape[-1]
    
    assert a.data.dtype == b.data.dtype
    out_dtype = a.data.dtype

    a_maxval = a.scale * a.dtype_max
    b_maxval = b.scale * b.dtype_max

    out_maxval = a_maxval * b_maxval * K

    # Scale each multiplicand to sqrt(dtype_max * K)
    a_downscale = 1 / torch.sqrt(a.dtype_max * K)
    b_downscale = 1 / torch.sqrt(b.dtype_max * K)

    print(f'{a_downscale=} {b_downscale=}')

    a_scaled = a.data * a_downscale
    b_scaled = b.data * b_downscale

    # Assume low-precision muls will accumulate to uptype
    # to simulate this on cpu, uptype before the matmul;
    # on appropriate hardware (e.g. graphcore, h100), call the special matmul
    out_data = _uptype(a_scaled) @ _uptype(b_scaled)
    # out_data = (ad @ bd) * (a_downscale * b_downscale)

    # A = ad * as
    # B = bd * bs
    # O = A @ B
    #   = (ad @ bd) * (as * bs)
    #   = (ad @ bd) * (as * bs)

    # O = A @ B = ScaledTensor(A @ B * (out_dtype_max / out_maxval), 1 / (dtype_max / out_maxval))
    #  OD_final = A @ B * (out_dtype_max / out_maxval)
    #  OD_final = ad @ bd * (as * bs * out_dtype_max / out_maxval)
    #  OD_final = ad @ bd * (a_downscale * b_downscale) * (as * bs * out_dtype_max / out_maxval / (a_downscale * b_downscale))
    #  OD_final = out_data * (as * bs * out_dtype_max / out_maxval / (a_downscale * b_downscale))



    scale = a.scale * b.scale * a.shape[-1]
    downscale = (a.shape[-1] / _maxval(out_dtype)) ** -.5
    out_data = (a.data * downscale).to(out_dtype) @ (b.data * downscale).to(out_dtype)
    return ScaledTensorData(out_data, scale)
