In [242]:
from types import MethodDescriptorType
from typing import *
from functools import partial
from dataclasses import dataclass

import torch
from torch import Tensor, tensor, nn
from icecream import ic

def numeric_info(dtype):
  return torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)

def dtype_max(dtype):
  return numeric_info(dtype).max

def _round(x : Tensor, dtype : torch.dtype) -> Tensor:
  if dtype.is_floating_point:
    return x.to(dtype)
  else:
    return torch.round(x).to(dtype)

def function_str(func):
  # https://stackoverflow.com/questions/251464/how-to-get-a-function-name-as-a-string
  # for future expansion e.g. properties
  if hasattr(func, '__module__'):
    return func.__module__ + '.' + func.__qualname__
  else:
    return func.__qualname__


## 1. Core functionality, no syntactic sugar

Define a `ScaleTensorData` object, with the minimal information, and with operations
defined cleanly, but without PyTorch Tensor integration

In [243]:
# Core
GPU = False
def _uptype(dtype) -> torch.dtype:
    map = {
        torch.int8: torch.int16,
        torch.int16: torch.int32,
        torch.float8_e4m3fn: torch.float16,
        torch.float8_e5m2: torch.float16,
        torch.float16: torch.float16 if GPU else torch.float32,
        torch.float32: torch.float32,
    }
    return map[dtype]

def _to_uptype(t : Tensor) -> Tensor:
    return torch.as_tensor(t, dtype=_uptype(t.dtype))


def _maxval(t : Tensor):
    return _to_uptype(t).abs().max().to(torch.float32)


@dataclass
class ScaledTensorData:
    data: Tensor
    scale: Tensor

    def __post_init__(self) -> None:
        if not isinstance(self.scale, Tensor):
            self.scale = tensor(self.scale)
        assert self.scale.dtype == torch.float32
        assert self.scale.shape == () 
        # Possible future expansion to e.g. row-scaled, column-scaled, etc, but
        # for now, insist st_scale is a single-element tensor

    def __repr__(self) -> str:
        return f"ScaledTensor({self.scale} * {self.data})"

    def to_tensor(self, dtype: torch.dtype) -> Tensor:
        return self.data.to(dtype) * self.scale.to(dtype)

    @property
    def shape(self) -> torch.Size:
        return self.data.shape
    
    @property
    def dtype_max(self):
        return dtype_max(self.data.dtype)

def st_quantise(x: Tensor, dtype: torch.dtype) -> ScaledTensorData:
    """
    Rescale so that max(|data|) == maxFinite(data.dtype)

    """
    maxval = _maxval(x)
    if maxval == 0:
        # Tensor is all zeros - set scale to 1
        scale = Tensor(1.0, dtype=torch.float32)
    else:
        # Scale so that largest element is the largest finite value of dtype
        scale = maxval / dtype_max(dtype)

    return ScaledTensorData(_round(x / scale, dtype), scale)


def st_requantise(st: ScaledTensorData) -> ScaledTensorData:
    """
    Rescale so that max(|data|) == maxFinite(data.dtype)

    Equivalent to quantise(st.to_tensor(torch.float32)) but avoids the conversion

    Returned tensor may share its data with input tensor
    """
    maxdataval = _maxval(st.data)
    if maxdataval == 0:
        # All zero, reset scale to 1
        return ScaledTensorData(st.data, st.scale)
    else:
         rescale = maxdataval / st.dtype_max
    return ScaledTensorData((_to_uptype(st.data) * (1 / rescale)).to(st.data.dtype), st.scale * rescale)


def wasted_bits(st_data, maxval = None) -> float:
    """
    By how much is tensor `st_data` not using the full dynamic range of its dtype?

    E.g.
       t = torch.tensor([1,2,-16], dtype=torch.int8)

    Is using only 5 (4 + sign) of the available 8 bits.
    Therefore 
       wasted_bits(t) == 3 == 8-3

    Optional argument maxval, if the maximum value in the tensor has already 
    been computed, perhaps with a higher-accuracy method (e.g. pre-rounding)
    """
    if maxval is None:
        maxval = _maxval(st_data)

    maxval = maxval.to(st_data.dtype)
    dtype_bits = numeric_info(st_data.dtype).bits
    if maxval == 0:
        # All values zero -> all bits are wasted
        return dtype_bits
    
    # Otherwise, how many bits is maxval using.
    if st_data.dtype.is_floating_point:
      # Convert maxval to integer of the same bitwidth
      ints = {
          8: torch.int8,
          16: torch.int16,
          32: torch.int32
      }
      maxval = maxval.view(ints[dtype_bits])

    # Assuming a signed type, max usable bits are dtype_bits-1
    return dtype_bits-1 - torch.log2(maxval)

def test_wasted_bits():
    t = torch.tensor([1,2,-16], dtype=torch.int8)
    assert wasted_bits(t) == 3
test_wasted_bits()

### Quick testing
f16 = tensor([1, 2, 3], dtype=torch.float16)
f8_t = torch.float8_e4m3fn
print(f16)
print(f16.to(f8_t))

st_data = st_quantise(f16, f8_t)

print(st_data)
print(f'{wasted_bits(st_data.data)=} <-- of a max of {numeric_info(f8_t).bits} bits, should be wasting nearly zero')
print(st_requantise(st_data))
print(st_data.to_tensor(f16.dtype), "# <- rounding errors at the high end of the f8 range")


tensor([1., 2., 3.], dtype=torch.float16)
tensor([1., 2., 3.], dtype=torch.float8_e4m3fn)
ScaledTensor(0.0066964286379516125 * tensor([144., 288., 448.], dtype=torch.float8_e4m3fn))
wasted_bits(st_data.data)=tensor(0.023) <-- of a max of 8 bits, should be wasting nearly zero
ScaledTensor(0.0066964286379516125 * tensor([144., 288., 448.], dtype=torch.float8_e4m3fn))
tensor([0.964, 1.928, 3.000], dtype=torch.float16) # <- rounding errors at the high end of the f8 range


## 2. Operators, without syntactic sugar

In [244]:
# Ops
# - Use worst-case scaling rules (no overflow!)
# - Placeholder impl (fast impl requires custom kernels)
# - No autograd support

def st_add(a: ScaledTensorData, b: ScaledTensorData) -> ScaledTensorData:
    out_dtype = a.data.dtype
    scale = a.scale + b.scale
    data = (_to_uptype(a.data) * (a.scale / scale) + _to_uptype(b.data) * (b.scale / scale)).to(out_dtype)
    return ScaledTensorData(data, scale)


def st_matmul(a: ScaledTensorData, b: ScaledTensorData, debug = True) -> ScaledTensorData:
    assert a.data.dtype == b.data.dtype
    in_dtype = a.data.dtype
    out_dtype = a.data.dtype

    a_maxval = a.scale * a.dtype_max
    b_maxval = b.scale * b.dtype_max

    # Predicted maxval for NxK @ KxM
    K = a.shape[-1]
    out_maxval_estimate = a_maxval * b_maxval * K

    out_scale = out_maxval_estimate / dtype_max(out_dtype) 

    # Derivation of matmul scale factors:
    # (ad * as) @ (bd * bs) = (ad @ bd) * (as * bs)
    #                       = (ad @ bd) * (as * bs / os * os)
    #                       = (ad @ bd * as * bs / os) * os
    #                       = (ad @ bd * rat) * os
    #                         where rat = as * bs / os
    #                       = (ad * sqrt(rat)) @ (bd * sqrt(rat)) * os

    rat = a.scale * b.scale / out_scale

    if numeric_info(in_dtype).bits < numeric_info(_uptype(in_dtype)).bits:
        # Assume low-precision muls will accumulate to uptype, so won't overflow
        # to simulate this on cpu, uptype before the matmul;
        # on appropriate hardware (e.g. graphcore, h100), call the special matmul
        adbd = _to_uptype(a.data) @ _to_uptype(b.data)
        if debug:
            out_maxval = _maxval(adbd) * (a.scale * b.scale)
        out_data = adbd * rat
        out_data = out_data.to(out_dtype)
    else:
        # Inputs are in 16+ bits, and we know the products will certainly 
        # overflow, as they are scaled to dtype_max, so downscale before multiplying 
        sqrt_rat = torch.sqrt(rat)
        a_down = _to_uptype(a.data) * sqrt_rat
        b_down = _to_uptype(b.data) * sqrt_rat
        out_data = a_down @ b_down
        if debug:
            out_maxval = _maxval(out_data) * out_scale
        out_data = out_data.to(out_dtype)

    # debug check how bad out_maxval_estimate was
    if debug:
        assert out_maxval_estimate > out_maxval # Should always be an upper bound
        wasted = wasted_bits(out_data)
        if wasted > numeric_info(out_dtype).bits/2:
            print(f'st_matmul: WARNING: Very bad maxval estimate {out_maxval_estimate} vs {out_maxval}, {out_maxval_estimate/out_maxval}x too large - will lose at least {wasted} bits of precision')

        if _maxval(out_data) == 0:
            raise ValueError("All-data zero - rerun with debug and view st_matmul: WARNING above")

    return ScaledTensorData(out_data, out_scale)


def st_relu(a: ScaledTensorData) -> ScaledTensorData:
    data = nn.functional.relu(a.data).to(a.data.dtype)
    return ScaledTensorData(data, a.scale)

# Check operators behae sensibly
st1 = ScaledTensorData(tensor(32, dtype=torch.int8), 0.5)
st2 = ScaledTensorData(tensor(64, dtype=torch.int8), 0.25)

f32 = lambda x: x.to_tensor(torch.float32) if isinstance(x, ScaledTensorData) else x.to(dtype=torch.float32)

torch.set_printoptions(precision=3, threshold=32)

ic(st1)
ic(st2)
ic(st_add(st1, st2))
ic(f32(st_add(st1, st2)) )
ic(f32(st1) + f32(st2))

hidden = 32
t3 = torch.randn(2, hidden)
t4 = torch.randn(hidden, 3)
st3 = st_quantise(t3, torch.int8)
st4 = st_quantise(t4, torch.int8)
ic(st3)
ic(st4)
ic(st_matmul(st3, st4))
print(f'{f32(st_matmul(st3, st4)) = } <-- quantized')
print(f'{f32(st3) @ f32(st4) = } <-- intermediate')
print(f'{t3 @ t4 = } <-- exact')

# st5 = st_quantise(tensor([-2, 3, -0.06, 4]), torch.int8)
# ic(st5, f32(st5))
# rt5 = st_relu(st5)
# ic(rt5, f32(rt5))


ic| st1: ScaledTensor(0.5 * 32)
ic| st2: ScaledTensor(0.25 * 64)
ic| st_add(st1, st2): ScaledTensor(0.75 * 42)
ic| f32(st_add(st1, st2)): tensor(31.500)
ic| f32(st1) + f32(st2): tensor(32.)
ic| st3: ScaledTensor(0.019876088947057724 * tensor([[-21,  78,  17,  ...,   8, -53,  41],
                 [ 40, -30, -38,  ..., -75,  52,   0]], dtype=torch.int8))
ic| st4: ScaledTensor(0.019780419766902924 * tensor([[-28,   3,  -9],
                 [-72,   8, -42],
                 [ 13, -67,  25],
                 ...,
                 [ 60, -50, -25],
                 [-51,   8, -54],
                 [-64, -37,  10]], dtype=torch.int8))
ic| st_matmul(st3, st4): ScaledTensor(1.5977916717529297 * tensor([[-4, -6,  0],
                                 [ 3,  2, -6]], dtype=torch.int8))


f32(st_matmul(st3, st4)) = tensor([[-6.391, -9.587,  0.000],
        [ 4.793,  3.196, -9.587]]) <-- quantized
f32(st3) @ f32(st4) = tensor([[ -6.410, -10.810,  -0.515],
        [  6.196,   4.459, -10.164]]) <-- intermediate
t3 @ t4 = tensor([[ -6.381, -10.807,  -0.489],
        [  6.171,   4.462, -10.129]]) <-- exact


In [245]:
### [Aside: Possibly surprising rounding]

# print('Starting point: ', tensor([-2, 0.09, 4]))

# st5 = st_quantise(tensor([-2, 0.09, 4]), torch.int8)
# print(f'{st5=} {f32(st5)=} <-- 0.0900 input rounds to 0.0630')

# print('So, put in 0.0630 to begin with')
# st5 = st_quantise(tensor([-2, 0.0630, 4]), torch.int8) 
# print(f'{st5=} {f32(st5)=} <-- great, 0.0630 rounds to 0.0630')

# print('But now, put in 0.06')
# st5 = st_quantise(tensor([-2, 0.06, 4]), torch.int8) 
# print(f'{st5=} {f32(st5)=} <-- Eh, 0.06 rounds down to 0.0315?')


## 3. Override Tensor operations in subclass

We define a `ScaledTensor` object that behaves like a torch Tensor.

In [246]:
# See https://pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-like-type
class ScaledTensor(Tensor):
    def __init__(self, st_data : ScaledTensorData):
        super().__init__()
        self.st = st_data
        
    def __new__(cls, st_data, *args, **kwargs):
        # Ensure that if the tensor is ever viewed as the base class, it is NaN
        return super().__new__(cls, tensor(torch.nan), *args, **kwargs)

    def __repr__(self) -> str:
        return f"{self.st}"

    @property
    def shape(self) -> torch.Size:
        return self.st.shape

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Have we registered a handler?
        if func in cls.HANDLED_FUNCTIONS:
          ret = cls.HANDLED_FUNCTIONS[func](*args, **kwargs)
          if ret != NotImplemented:
              return ret
          # Otherwise drop through to the fallback

        # Fallback: Convert to float32 and call func
        print(f"ScaledTensor.__torch_function__: WARNING: Upcasting to float32 for {function_str(func)}@{types}")

        def to_tensor_if_scaled(t) -> Tensor:
            if isinstance(t, ScaledTensor):
                return t.st.to_tensor(torch.float32)
            else:
                return t

        new_args = tuple(to_tensor_if_scaled(a) for a in args)
        new_kwargs = { k:to_tensor_if_scaled(v) for (k,v) in kwargs.items() }

        # We don't want the auto-downcast of 
        #    ret = super().__torch_function__(func, types, new_args, new_kwargs)
        # because the super()'s handler will construct other types of tensor which
        # don't simply reinterpret to a ScaledTensor.

        with torch._C.DisableTorchFunctionSubclass():
            return func(*new_args, **new_kwargs)

ScaledTensor.HANDLED_FUNCTIONS = {}

# Forward the ScaledTensorData ops above on the ScaledTensor type
def quantise(x: Tensor, dtype: torch.dtype) -> ScaledTensor:
    return ScaledTensor(st_quantise(x, dtype))

def requantise(st) -> ScaledTensor:
    return ScaledTensor(st_requantise(st.st))

# Check basic to/from Subclass
f16 = tensor([1, 2, 3], dtype=torch.float16)
f8_t = torch.int8
print(f16)
print(f16.to(f8_t))

st = quantise(f16, f8_t)

print(st)
print(requantise(st))
print('Rounding errors at the high end of the f8 range:', st.st.to_tensor(f16.dtype))
print('Viewing as a tensor should show NaN:', Tensor(st))


tensor([1., 2., 3.], dtype=torch.float16)
tensor([1, 2, 3], dtype=torch.int8)
ScaledTensor(0.023622047156095505 * tensor([ 42,  85, 127], dtype=torch.int8))
ScaledTensor(0.023622047156095505 * tensor([ 42,  85, 127], dtype=torch.int8))
Rounding errors at the high end of the f8 range: tensor([0.992, 2.008, 3.000], dtype=torch.float16)
Viewing as a tensor should show NaN: tensor(nan)


### Now overrides work, but just punt up to f32 for all ops

We will do some adds/multiplies etc, and note that the torch function
implementation issues a sequence of "WARNING: Upcasting to float32"

In [247]:
print(st + 2)
print(2 * st)
print(st + st)

st3 = quantise(torch.full((2, 3), 100.0), torch.int8)
st4 = quantise(torch.full((3, 4), 200.0), torch.int8)

print(f'{f32(st3 @ st4)       = } <-- should be ~48200 quantized, but was done in f32 so exact')
print(f'{f32(st3) @ f32(st4)  = } <-- 60000 exact')


tensor([2.992, 4.008, 5.000])
tensor([1.984, 4.016, 6.000])
tensor([1.984, 4.016, 6.000])
f32(st3 @ st4)       = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- should be ~48200 quantized, but was done in f32 so exact
f32(st3) @ f32(st4)  = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- 60000 exact


In [248]:
def torch_function_override(cls, funcs):
    funcs = funcs if isinstance(funcs, tuple) else (funcs,)

    def doit(impl):
        # This work is also done by functools.wraps, but it doesn't make sensible names,
        # and doesn't fix qualname
        if impl.__name__ == '_':
            impl.__name__ = f'@torch_function_override({cls.__name__}, {funcs})'
            if impl.__qualname__ != '_':
                print(f'torch_function_override: NOTE: {impl.__qualname__} not overridden')
        if impl.__qualname__ == '_':
            impl.__qualname__ = f'torch_function_override({cls.__name__}, {funcs})'

        # Record it in the dictionary
        for func in funcs:
          cls.HANDLED_FUNCTIONS[func] = impl

    return doit

@torch_function_override(ScaledTensor, Tensor.add)
def _(a:Tensor, b: Tensor) -> Tensor:
    if not (isinstance(a, ScaledTensor) and isinstance(b, ScaledTensor)):
        print(f'ScaledTensor: Punting on {type(a), type(b)=}')
        return NotImplemented

    return ScaledTensor(st_add(a.st, b.st))

@torch_function_override(ScaledTensor, Tensor.matmul)
def _(a:Tensor, b: Tensor) -> Tensor:
    if not (isinstance(a, ScaledTensor) and isinstance(b, ScaledTensor)):
        return NotImplemented

    return ScaledTensor(st_matmul(a.st, b.st))

@torch_function_override(ScaledTensor, Tensor.to)
def _(a:Tensor, dtype:torch.dtype) -> Tensor:
    assert isinstance(a, ScaledTensor)
    return a.st.to_tensor(dtype)

@torch_function_override(ScaledTensor, (Tensor.relu, nn.functional.relu))
def _(a:Tensor, inplace = False) -> Tensor:
    assert isinstance(a, ScaledTensor)
    assert not inplace
    return ScaledTensor(st_relu(a.st))

ic(st3)
ic(st4)
ic(st3 @ st4 )
print(f'{f32(st3 @ st4)       = } <-- ~48200 quantized')
print(f'{f32(st3) @ f32(st4)  = } <-- 60000 exact')

print(nn.functional.relu(st))


ic| st3: ScaledTensor(0.787401556968689 * tensor([[127, 127, 127],
                 [127, 127, 127]], dtype=torch.int8))
ic| st4: ScaledTensor(1.574803113937378 * tensor([[127, 127, 127, 127],
                 [127, 127, 127, 127],
                 [127, 127, 127, 127]], dtype=torch.int8))
ic| st3 @ st4: ScaledTensor(472.4409484863281 * tensor([[-45, -45, -45, -45],
                       [-45, -45, -45, -45]], dtype=torch.int8))


f32(st3 @ st4)       = tensor([[-21259.842, -21259.842, -21259.842, -21259.842],
        [-21259.842, -21259.842, -21259.842, -21259.842]]) <-- ~48200 quantized
f32(st3) @ f32(st4)  = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- 60000 exact
ScaledTensor(0.023622047156095505 * tensor([ 42,  85, 127], dtype=torch.int8))


# And in a network...

In [251]:
# Example
from icecream import ic

class FFN(nn.Module):
    def __init__(self, hidden_size: int, dtype: torch.dtype):
        super().__init__()
        q = lambda x: quantise(x, dtype)
        self.W0 = q(torch.randn(hidden_size, 4*hidden_size))
        self.W1 = q(torch.randn(4*hidden_size, hidden_size))
        ic(self.W1)

    def forward(self, x: Union[Tensor, ScaledTensor]) -> Union[Tensor, ScaledTensor]:
        y = nn.functional.relu(x @ self.W0)

        if isinstance(y, ScaledTensor):
            print(f"wasted bits #1: {wasted_bits(y)}")
            y = requantise(y)

        y = y @ self.W1

        if isinstance(y, ScaledTensor):
            print(f"wasted bits #2: {wasted_bits(y)}")

        return y

hidden_size = 1024
storage_type = torch.int8
module = FFN(hidden_size, storage_type)
x = quantise(torch.randn(10, hidden_size), storage_type)
result = module(x)
print()
print(result)

ic| self.W1: ScaledTensor(0.042678870260715485 * tensor([[-19, -18,  -2,  ...,  20,  12,  18],
                     [  7, -54, -29,  ..., -25, -14,  51],
                     [  2,   6,  -8,  ...,   7, -13,  -5],
                     ...,
                     [-11,  46,  -6,  ..., -51,  22,  10],
                     [ -9,   6, -29,  ..., -24,  49,   5],
                     [ -7, -36, -25,  ...,   3,  29, -18]], dtype=torch.int8))




ValueError: All-data zero - rerun with debug and view st_matmul: WARNING above