In [23]:
from types import MethodWrapperType, GetSetDescriptorType
from typing import *
from warnings import warn
from functools import partial
from dataclasses import dataclass
from icecream import ic

import torch
from torch import Tensor, tensor, nn
from torch.utils._pytree import tree_map

import numpy as np

# Make warn=print for notebooks, as otherwise outputs are not interleaved correctly
from termcolor import colored

aten = torch.ops.aten


def warn(msg):
    print(colored("WARNING", "yellow"), f": {msg}")


warn("Nothing to worry about")


def numeric_info(dtype):
    return torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)


def dtype_max(dtype):
    return numeric_info(dtype).max


def _round(x: Tensor, dtype: torch.dtype) -> Tensor:
    """
    Convert to dtype, rounding if the destination is integer
    """
    if dtype.is_floating_point:
        return x.to(dtype)
    else:
        return torch.round(x).to(dtype)


def function_str(func: Callable) -> str:
    return torch.overrides.resolve_name(func)

    # https://stackoverflow.com/questions/251464/how-to-get-a-function-name-as-a-string
    # for future expansion e.g. properties
    if hasattr(func, "__module__"):
        return func.__module__ + "." + func.__qualname__
    else:
        return func.__qualname__


def type_str(x: type) -> str:
    return x.__name__


def _uptype(dtype) -> torch.dtype:
    """
    For DTYPE, what is the type in which most arithmetic (e.g. max, abs) is defined?
    """
    GPU = False  # Check more accurately, and choose bf16 as appropriate
    f16_t = torch.float16 if GPU else torch.float32
    map = {
        torch.int8: torch.int16,
        torch.int16: torch.int16,
        torch.float8_e4m3fn: f16_t,
        torch.float8_e5m2: f16_t,
        torch.float16: f16_t,
        torch.float32: torch.float32,
    }
    return map[dtype]


def _to_uptype(t: Tensor) -> Tensor:
    """
    Convert t to its _uptype
    """
    return torch.as_tensor(t, dtype=_uptype(t.dtype))


def _maxval(t: Tensor):
    """
    Max absolute value of tensor, returned in its `_uptype`
    """
    return _to_uptype(t).abs().max()


torch.set_printoptions(precision=3, threshold=32)

# From https://github.com/albanD/subclass_zoo/blob/main/utils.py
import contextlib


@contextlib.contextmanager
def _no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard



In [24]:
def tensor_oneline_str(t):
    shape_str = "x".join(map(str, t.shape))
    quantiles = torch.tensor([0, 0.05, 0.25, 0.5, 0.75, 1.0])
    vals = torch.quantile(torch.flatten(t).to(torch.float32), quantiles, interpolation="nearest")
    if t.dtype.is_floating_point:
        # scale down vals
        finite_vals = vals[torch.isfinite(vals)]
        max = finite_vals.abs().max()
        if max > 0:
            logmax = torch.floor(torch.log10(max))
            if -2 <= logmax <= 3:
                logmax = 0
            max_scale = 10**-logmax
            max_scale_str = f"10^{int(logmax)} x " if logmax != 0 else ""
        else:
            max_scale = 1
            max_scale_str = ""
        vals_str = max_scale_str + "Quants{" + "|".join(f"{v.item():.3f}" for v in vals * max_scale) + "}"
    else:
        # Assume integer, print as integers
        vals_str = "Quants{" + "|".join(f"{int(v)}" for v in vals) + "}"

    classname = type(t).__name__

    dtype_str = f"{t.dtype}".replace("torch.float", "f").replace("torch.int", "i").replace("torch.uint", "u")

    return f"{classname}({shape_str},{dtype_str}) {vals_str}"


for scale in [-10, -3, -2, -1, 0, 1, 2, 3, 10]:
    kurt = 3
    print(tensor_oneline_str(torch.randn(100, 300) ** kurt * (10**scale)))

print(tensor_oneline_str((torch.randn(100, 300) * 10000).to(torch.uint8)))

Tensor(100x300,f32) 10^-9 x Quants{-5.080|-0.442|-0.031|-0.000|0.031|7.563}
Tensor(100x300,f32) Quants{-0.093|-0.004|-0.000|-0.000|0.000|0.058}
Tensor(100x300,f32) Quants{-0.477|-0.043|-0.003|-0.000|0.003|0.537}
Tensor(100x300,f32) Quants{-7.255|-0.447|-0.030|-0.000|0.031|7.686}
Tensor(100x300,f32) Quants{-54.827|-4.489|-0.285|0.000|0.308|62.219}
Tensor(100x300,f32) Quants{-1006.063|-44.348|-2.883|0.000|3.168|564.190}
Tensor(100x300,f32) Quants{-8416.035|-432.668|-27.389|0.000|31.403|7809.401}
Tensor(100x300,f32) 10^4 x Quants{-5.812|-0.428|-0.030|-0.000|0.031|9.379}
Tensor(100x300,f32) 10^11 x Quants{-7.704|-0.444|-0.029|0.000|0.031|6.558}
Tensor(100x300,u8) Quants{0|12|64|128|192|255}


## 1. Core functionality, no syntactic sugar

Define a `ScaleTensorData` object, with the minimal information, and with operations
defined cleanly, but without PyTorch Tensor integration

In [25]:
@dataclass
class ScaledTensorData:
    data: Tensor
    scale: Tensor

    def __post_init__(self) -> None:
        if not isinstance(self.scale, Tensor):
            self.scale = tensor(self.scale)
        assert self.scale.dtype == torch.float32
        assert self.scale.shape == ()
        # Possible future expansion to e.g. row-scaled, column-scaled, etc, but
        # for now, insist st_scale is a single-element tensor

    def _contents_str(self) -> str:
        return f"{self.scale} * {tensor_oneline_str(self.data)}"

    def __repr__(self) -> str:
        return f"ScaledTensorData({self._contents_str()})"

    def to_tensor(self, dtype: torch.dtype = None) -> Tensor:
        dtype = dtype or self.scale.dtype
        return self.data.to(dtype) * self.scale.to(dtype)

    @property
    def shape(self) -> torch.Size:
        return self.data.shape

    @property
    def dtype_max(self):
        return dtype_max(self.data.dtype)


def st_quantise(x: Tensor, dtype: torch.dtype) -> ScaledTensorData:
    """
    Rescale so other max(|data|) == maxFinite(data.dtype)
    """
    maxval = _maxval(x)
    if maxval == 0:
        # Tensor is all zeros - set scale to 1
        scale = torch.tensor(1.0, dtype=torch.float32)
    else:
        # Scale so other largest element is the largest finite value of dtype
        scale = maxval / dtype_max(dtype)

    return ScaledTensorData(_round(x / scale, dtype), scale)


def st_requantise(st: ScaledTensorData) -> ScaledTensorData:
    """
    Rescale so other max(|data|) == maxFinite(data.dtype)

    Equivalent to quantise(st.to_tensor(torch.float32)) but avoids the conversion

    Returned tensor may share its data with input tensor
    """
    maxdataval = _maxval(st.data)
    if maxdataval == 0:
        # All zero, reset scale to 1
        return ScaledTensorData(st.data, st.scale)
    else:
        rescale = maxdataval / st.dtype_max
    return ScaledTensorData((_to_uptype(st.data) * (1 / rescale)).to(st.data.dtype), st.scale * rescale)


def wasted_bits(st_data, maxval=None) -> float:
    """
    By how much is tensor `st_data` not using the full dynamic range of its dtype?

    E.g.
       t = torch.tensor([1,2,-16], dtype=torch.int8)

    Is using only 5 (4 + sign) of the available 8 bits.
    Therefore
       wasted_bits(t) == 3 == 8-3

    Optional argument maxval, if the maximum value in the tensor has already
    been computed, perhaps with a higher-accuracy method (e.g. pre-rounding)
    """
    if maxval is None:
        maxval = _maxval(st_data)

    maxval = maxval.to(st_data.dtype)
    dtype_bits = numeric_info(st_data.dtype).bits
    if maxval == 0:
        # All values zero -> all bits are wasted
        return dtype_bits

    # Otherwise, how many bits is maxval using.
    if st_data.dtype.is_floating_point:
        # Convert maxval to integer of the same bitwidth
        ints = {8: torch.int8, 16: torch.int16, 32: torch.int32}
        maxval = maxval.view(ints[dtype_bits])

    # Assuming a signed type, max usable bits are dtype_bits-1
    return dtype_bits - 1 - torch.log2(maxval)


def test_wasted_bits():
    t = torch.tensor([1, 2, -16], dtype=torch.int8)
    assert wasted_bits(t) == 3


test_wasted_bits()

### Quick testing
f16 = tensor([1, 2, 3], dtype=torch.float16)
f8_t = torch.float8_e4m3fn
print(f16)
print(f16.to(f8_t))

st_data = st_quantise(f16, f8_t)

# TODO: wasting a lot more bits at the low end here -- range of 144->448.
print(st_data)
print(f"{wasted_bits(st_data.data)=} <-- of a max of {numeric_info(f8_t).bits} bits, should be wasting nearly zero")
print(st_requantise(st_data))
print(st_data.to_tensor(f16.dtype), "# <- rounding errors at the high end of the f8 range")

tensor([1., 2., 3.], dtype=torch.float16)
tensor([1., 2., 3.], dtype=torch.float8_e4m3fn)
ScaledTensorData(0.0066964286379516125 * Tensor(3,f8_e4m3fn) Quants{144.000|144.000|144.000|288.000|448.000|448.000})
wasted_bits(st_data.data)=tensor(0.023) <-- of a max of 8 bits, should be wasting nearly zero
ScaledTensorData(0.0066964286379516125 * Tensor(3,f8_e4m3fn) Quants{144.000|144.000|144.000|288.000|448.000|448.000})
tensor([0.964, 1.928, 3.000], dtype=torch.float16) # <- rounding errors at the high end of the f8 range


## 2. Operations, without syntactic sugar

In [26]:
# Ops
# - Use worst-case scaling rules (no overflow!)
# - Placeholder impl (fast impl requires custom kernels)
# - No autograd support


def st_add(a: ScaledTensorData, b: ScaledTensorData, alpha=1.0) -> ScaledTensorData:
    b_scale = alpha * b.scale
    out_dtype = a.data.dtype
    scale = a.scale + b_scale
    data = (_to_uptype(a.data) * (a.scale / scale) + _to_uptype(b.data) * (b_scale / scale)).to(out_dtype)
    return ScaledTensorData(data, scale)


def st_matmul(a: ScaledTensorData, b: ScaledTensorData, debug=True) -> ScaledTensorData:
    a_uptype = _uptype(a.data.dtype)
    b_uptype = _uptype(b.data.dtype)
    assert a_uptype == b_uptype
    out_dtype = a_uptype

    # Do this in f32 until we are debugged
    return st_quantise(a.to_tensor() @ b.to_tensor(), out_dtype)

    # a_maxval = a.scale * a.dtype_max
    # b_maxval = b.scale * b.dtype_max

    # # Predicted maxval for NxK @ KxM
    # K = a.shape[-1]
    # out_maxval_estimate = a_maxval * b_maxval * np.sqrt(K)*2

    # out_scale = out_maxval_estimate / dtype_max(out_dtype)

    # # Derivation of matmul scale factors:
    # # (ad * as) @ (bd * bs) = (ad @ bd) * (as * bs)
    # #                       = (ad @ bd) * (as * bs / os * os)
    # #                       = (ad @ bd * as * bs / os) * os
    # #                       = (ad @ bd * rat) * os
    # #                         where rat = as * bs / os
    # #                       = (ad * sqrt(rat)) @ (bd * sqrt(rat)) * os

    # rat = a.scale * b.scale / out_scale

    # a_bits = numeric_info(a.data.dtype).bits
    # b_bits = numeric_info(b.data.dtype).bits

    # if max(a_bits, b_bits) < numeric_info(out_dtype).bits:
    #     # Assume low-precision muls will accumulate to uptype, so won't overflow
    #     # to simulate this on cpu, uptype before the matmul;
    #     # on appropriate hardware (e.g. graphcore, h100), call the special matmul
    #     adbd = _to_uptype(a.data) @ _to_uptype(b.data)
    #     if debug:
    #         out_maxval = _maxval(adbd) * (a.scale * b.scale)
    #     out_data = adbd * rat
    #     out_data = out_data.to(out_dtype)
    # else:
    #     # Inputs are in 16+ bits, and we know the products will certainly
    #     # overflow, as they are scaled to dtype_max, so downscale before multiplying
    #     sqrt_rat = torch.sqrt(rat)
    #     a_down = _to_uptype(a.data) * sqrt_rat
    #     b_down = _to_uptype(b.data) * sqrt_rat
    #     out_data = a_down @ b_down
    #     if debug:
    #         out_maxval = _maxval(out_data) * out_scale
    #     out_data = out_data.to(out_dtype)

    # # debug check how bad out_maxval_estimate was
    # if debug:
    #     assert out_maxval_estimate > out_maxval  # Should always be an upper bound
    #     wasted = wasted_bits(out_data)
    #     if wasted > numeric_info(out_dtype).bits / 2:
    #         warn(
    #             f"st_matmul: Very bad maxval estimate {out_maxval_estimate} vs {out_maxval}, {out_maxval_estimate/out_maxval:.1f}x too large - will lose at least {wasted} bits of precision"
    #         )

    #     if _maxval(out_data) == 0:
    #         raise ValueError("All-data zero - rerun with debug and view st_matmul: WARNING above")

    return ScaledTensorData(out_data, out_scale)


def st_relu(a: ScaledTensorData) -> ScaledTensorData:
    data = nn.functional.relu(a.data).to(a.data.dtype)
    return ScaledTensorData(data, a.scale)


# Check operators behave sensibly
st1 = ScaledTensorData(tensor(32, dtype=torch.int8), 0.5)
st2 = ScaledTensorData(tensor(64, dtype=torch.int8), 0.25)

f32 = lambda x: x.to_tensor(torch.float32) if isinstance(x, ScaledTensorData) else x.to(dtype=torch.float32)


ic(st1)
ic(st2)
ic(st_add(st1, st2))
ic(f32(st_add(st1, st2)))
ic(f32(st1) + f32(st2))

hidden = 32
t3 = torch.randn(2, hidden)
t4 = torch.randn(hidden, 3)
st3 = st_quantise(t3, torch.int8)
st4 = st_quantise(t4, torch.int8)
ic(st3)
ic(st4)
ic(st_matmul(st3, st4))
print(f"{f32(st_matmul(st3, st4)) = } <-- quantized")
print(f"{f32(st3) @ f32(st4) = } <-- intermediate")
print(f"{t3 @ t4 = } <-- exact")

# st5 = st_quantise(tensor([-2, 3, -0.06, 4]), torch.int8)
# ic(st5, f32(st5))
# rt5 = st_relu(st5)
# ic(rt5, f32(rt5))

ic| st1: ScaledTensorData(0.5 * Tensor(,i8) Quants{32|32|32|32|32|32})
ic| st2: ScaledTensorData(0.25 * Tensor(,i8) Quants{64|64|64|64|64|64})


ic| st_add(st1, st2): ScaledTensorData(0.75 * Tensor(,i8) Quants{42|42|42|42|42|42})
ic| f32(st_add(st1, st2)): tensor(31.500)
ic| f32(st1) + f32(st2): tensor(32.)
ic| st3: ScaledTensorData(0.01775376871228218 * Tensor(2x32,i8) Quants{-127|-86|-52|-9|25|108})
ic| st4: ScaledTensorData(0.021980037912726402 * Tensor(32x3,i8) Quants{-127|-80|-26|-1|23|96})
ic| st_matmul(st3, st4): ScaledTensorData(0.00042443175334483385 * Tensor(2x3,i16) Quants{-9714|-9714|-7678|-5303|26302|32767})


f32(st_matmul(st3, st4)) = tensor([[11.163, -4.123, -3.259],
        [-2.251,  4.451, 13.907]]) <-- quantized
f32(st3) @ f32(st4) = tensor([[11.163, -4.123, -3.259],
        [-2.251,  4.451, 13.907]]) <-- intermediate
t3 @ t4 = tensor([[11.197, -4.210, -3.202],
        [-2.218,  4.478, 13.879]]) <-- exact


In [27]:
### [Aside: Possibly surprising rounding]

# print('Starting point: ', tensor([-2, 0.09, 4]))

# st5 = st_quantise(tensor([-2, 0.09, 4]), torch.int8)
# print(f'{st5=} {f32(st5)=} <-- 0.0900 input rounds to 0.0630')

# print('So, put in 0.0630 to begin with')
# st5 = st_quantise(tensor([-2, 0.0630, 4]), torch.int8)
# print(f'{st5=} {f32(st5)=} <-- great, 0.0630 rounds to 0.0630')

# print('But now, put in 0.06')
# st5 = st_quantise(tensor([-2, 0.06, 4]), torch.int8)
# print(f'{st5=} {f32(st5)=} <-- Eh, 0.06 rounds down to 0.0315?')

## 3. Sugar hit: override Tensor operations in subclass

We define a `ScaledTensor` object other behaves like a torch Tensor.

In [28]:
# See https://pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-like-type
# TODO: check
#   - _make_wrapper_subclass THPVariable_make_wrapper_subclass https://github.com/pytorch/pytorch/pull/65340
#   - _make_subclass
#   - See [Note] at https://github.com/albanD/subclass_zoo/blob/276d2f005484d80ebbcd9e274d79685adb6a1da2/negative_tensor.py#L24
#     - Doesn't apply in this case as we are composing, not deriving?
# Looking at https://github.com/albanD/subclass_zoo
#   - trivial_tensor doesn't do autograd
#   - inner_autograd_tensor explicitly defers to its `elem`, which is incorrect
# In PyTorch core
#   - MaskedTensor
# See FP8Tensor in subclass_zoo: https://github.com/albanD/subclass_zoo/pull/44/files


class ScaledTensor(Tensor):
    @staticmethod
    def __new__(cls, st: ScaledTensorData, *, requires_grad=False):
        assert not st.data.requires_grad
        ret = torch.Tensor._make_wrapper_subclass(
            cls,
            size=st.data.size(),
            strides=st.data.stride(),
            storage_offset=st.data.storage_offset(),
            dtype=st.scale.dtype,
            layout=st.data.layout,
            requires_grad=requires_grad,
            device=st.data.device,
        )
        ret.st = st
        return ret

    def __repr__(self) -> str:
        # See https://github.com/pytorch/pytorch/issues/73665
        with _no_dispatch():
            return super().__repr__(tensor_contents=str(self.st._contents_str()))

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        func_str = f'{function_str(func)}@({",".join(type_str(type(x)) for x in args)})'
        # print(':', func_str)

        # 1. Is this one we handle?
        extra_msg = ""
        if func in cls.HANDLED_FUNCTIONS:
            # Call the handler
            handler = cls.HANDLED_FUNCTIONS[func]
            try:  # TODO: just for debug
                ret = handler(func, *args, **kwargs)
            except Exception as e:
                warn(f"dispatch handler exception, e={repr(e)}")
                raise

            if ret != NotImplemented:
                return ret

            # handler may have returned "NotImplemented" to tell us to run the fallback
            extra_msg = f" -- [Handler {function_str(handler)} returned NotImplemented]"

        # Not handled, convert all ScaledTensors to Tensors, and run
        func_str = f'{function_str(func)}@({",".join(type_str(type(x)) for x in args)})'
        warn(f"ScaledTensor.__torch_dispatch__: Upcasting to float32 for {func_str}" + extra_msg)

        return upcast_args_and_redispatch(func, *args, **kwargs)

    __torch_function__ = torch._C._disabled_torch_function_impl


ScaledTensor.HANDLED_FUNCTIONS = {}


def tensor_subclass_override(cls, funcs):
    """
    Decorator to add an implementation of an operation to a Tensor subclass

    @tensor_subclass_override(MySubclass, aten.view.default)
    def _(func, *args, *kwargs):
      print(f'Calling wrapped {func} with {len(args)} args)
      with _no_dispatch():
        return func(*args, *kwargs)

    Calling the implementation "_" allows this decorator to overwrite the name with
    a more sensible one "@torch_function_override(MySubclass, aten.view.default)"
    but of course you can just call it "MySubclass_impl_view" or "foo42" if you prefer.
    This is mainly useful in backtraces to see the decorated function easily
    """
    funcs = funcs if isinstance(funcs, tuple) else (funcs,)
    funcs_str = ",".join(map(function_str, funcs))

    def doit(impl):
        # Override impl name if it was just "_"
        if hasattr(impl, "__name__") and impl.__name__ == "_":
            impl.__name__ = f"@torch_function_override({cls.__name__}, {funcs_str})"
            if impl.__qualname__ != "_":
                print(f"torch_function_override: NOTE: {impl.__qualname__} not overridden")

            if impl.__qualname__ == "_":
                impl.__qualname__ = f"torch_function_override({cls.__name__}, {funcs_str})"

        # Record handler in the dictionary for each func
        for func in funcs:
            cls.HANDLED_FUNCTIONS[func] = impl

        return impl

    return doit


def to_f32_if_scaled(t: Any) -> Tensor:
    if isinstance(t, ScaledTensor):
        return t.st.to_tensor(torch.float32)
    else:
        return t


def to_scaled_if_f32(x, q_dtype):
    if isinstance(x, Tensor) and x.dtype == torch.float32:
        return ScaledTensor(st_quantise(x, q_dtype))
    else:
        return x


@tensor_subclass_override(
    ScaledTensor,
    (
        aten.is_same_size.default,
        aten.gt.Scalar,
        aten.eq.Tensor,
        aten.isnan.default,
        aten.ne.Scalar,
        aten.ge.Scalar,
        aten.le.Scalar,
    ),
)
def upcast_args_and_redispatch(func, *args, **kwargs):
    new_args = tree_map(to_f32_if_scaled, args)
    new_kwargs = tree_map(to_f32_if_scaled, kwargs)

    with _no_dispatch():  # TODO: not needed?
        return func(*new_args, **new_kwargs)


# These ops punt silently to float32, but cast down to ScaledTensor - TODO: implement more efficiently
@tensor_subclass_override(
    ScaledTensor,
    (
        aten.convolution.default,
        aten.convolution_backward.default,
        aten.max_pool2d_with_indices.default,
        aten.addmm.default,
        aten.threshold_backward.default,
        aten._log_softmax.default,
        aten.nll_loss_forward.default,
        aten.nll_loss_backward.default,
        aten.sum.dim_IntList,
        aten._log_softmax_backward_data.default,
        aten.max_pool2d_with_indices_backward.default,
        aten._local_scalar_dense.default,
        aten.max.default,
        aten.log10.default,
        aten.floor.default
    ),
)
def to_float32_and_wrap(func, *args, **kwargs):
    ret = upcast_args_and_redispatch(func, *args, **kwargs)

    q_dtypes = {a.st.data.dtype for a in args if isinstance(a, ScaledTensor)}
    if len(q_dtypes) != 1:
        q_dtype = {_uptype(ty) for ty in q_dtypes}
        warn(f"{func}: Widening to {q_dtype}... from {q_dtypes}")
    else:
        q_dtype = q_dtypes
    assert len(q_dtype) == 1
    q_dtype = q_dtype.pop()

    return tree_map(lambda x: to_scaled_if_f32(x, q_dtype), ret)


@tensor_subclass_override(
    ScaledTensor,
    (
        aten.view.default,
        aten.permute.default,
        aten.t.default,
        aten.gather.default,
        aten.index.Tensor,
    ),
)
def redispatch_via_data(func, t_self, *args, **kwargs):
    new_data = func(t_self.st.data, *args, **kwargs)
    return ScaledTensor(ScaledTensorData(new_data, t_self.st.scale))


@tensor_subclass_override(ScaledTensor, (aten.unbind.int))
def _(func, t_self, *args, **kwargs):
    ret = func(t_self.st.data, *args, **kwargs)
    return tuple(ScaledTensor(ScaledTensorData(new_data, t_self.st.scale)) for new_data in ret)


@tensor_subclass_override(
    ScaledTensor,
    (
        aten.detach.default,
        aten.ones_like.default,
        aten.clone.default,
        aten.abs.default,
    ),
)
def redispatch_via_data_and_scale(func, t_self, *args, **kwargs) -> Tensor:
    new_data = func(t_self.st.data, *args, **kwargs)
    new_scale = func(t_self.st.scale, *args, **kwargs)
    return ScaledTensor(ScaledTensorData(new_data, new_scale))


@tensor_subclass_override(ScaledTensor, aten.sort.default)
def _(func, t_self, *args, **kwargs):
    new_data, indices = func(t_self.st.data, *args, **kwargs)
    return ScaledTensor(ScaledTensorData(new_data, t_self.st.scale)), indices


# Forward the ScaledTensorData ops above on the ScaledTensor type
def quantise(x: Tensor, dtype: torch.dtype) -> ScaledTensor:
    return ScaledTensor(st_quantise(x, dtype), requires_grad=x.requires_grad)


def requantise(st) -> ScaledTensor:
    return ScaledTensor(st_requantise(st.st), requires_grad=st.requires_grad)


# Check basic to/from Subclass
f16 = tensor([[1, 2, 3]], dtype=torch.float16)
f8_t = torch.int8
print(f16)

st = quantise(f16, f8_t)

print(f"{st.shape=}")
assert st.shape == f16.shape
s = str(st)
print(f"{st=}")

print(f"{requantise(st)=}")
print("Rounding errors at the high end of the f8 range:", st.st.to_tensor(f16.dtype))

print("Reshaped:", st.T)

print(colored("Expect a warning...", "yellow"))
print('Addition, but note warning above about "Upcasting to float32", so prints as a normal tensor:', st + 2)

st.requires_grad_(True)
torch.sort(st)

tensor([[1., 2., 3.]], dtype=torch.float16)
st.shape=torch.Size([1, 3])
st=ScaledTensor(0.023622047156095505 * Tensor(1x3,i8) Quants{42|42|42|85|127|127})
requantise(st)=ScaledTensor(0.023622047156095505 * Tensor(1x3,i8) Quants{42|42|42|85|127|127})
Rounding errors at the high end of the f8 range: tensor([[0.992, 2.008, 3.000]], dtype=torch.float16)
Reshaped: ScaledTensor(0.023622047156095505 * Tensor(3x1,i8) Quants{42|42|42|85|127|127})


torch.return_types.sort(
values=ScaledTensor(0.023622047156095505 * Tensor(1x3,i8) Quants{42|42|42|85|127|127},
             grad_fn=<SortBackward0>),
indices=tensor([[0, 1, 2]]))

### Now overrides work, but just punt up to f32 for all ops

We will do some adds/multiplies etc, and note other the torch function
implementation issues a sequence of "WARNING: Upcasting to float32"

In [29]:
print(colored("Expect four warnings...", "yellow"))

print(st + 2)
print(2 * st)
print(st + st)

st3 = quantise(torch.full((2, 3), 100.0), torch.int8)
st4 = quantise(torch.full((3, 4), 200.0), torch.int8)

print(f"{f32(st3 @ st4)       = } <-- should be ~21000 quantized, but was done in f32 so exact")
print(f"{f32(st3) @ f32(st4)  = } <-- 60000 exact")

tensor([[2.992, 4.008, 5.000]], grad_fn=<AddBackward0>)
tensor([[1.984, 4.016, 6.000]], grad_fn=<MulBackward0>)
tensor([[1.984, 4.016, 6.000]], grad_fn=<AddBackward0>)
f32(st3 @ st4)       = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- should be ~21000 quantized, but was done in f32 so exact
f32(st3) @ f32(st4)  = tensor([[60000., 60000., 60000., 60000.],
        [60000., 60000., 60000., 60000.]]) <-- 60000 exact


## Autograd

In [30]:
# class ScaledTensor_add(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, a, b):
#         print("st add fwd")
#         assert isinstance(a, ScaledTensor) and isinstance(b, ScaledTensor)
#         return ScaledTensor(st_add(a.st, b.st))

#     @staticmethod
#     def backward(ctx, dout):
#         print("st add bwd")
#         return dout, dout

# # Can't use

# ret = ScaledTensor_add.apply(a, b)
# print("add ret", ret)
# return ret


x = tensor([1.1, 2.2, 3.3])

qx = quantise(x, torch.int8)
qx.requires_grad_(True)

print(f"{qx=}")
qy = qx + qx
# qy = ScaledTensor_add.apply(qx, qx)
print(f"{qy=}")


dx = x * 0.001  # no actual need to do 0.001
qdx = quantise(dx, torch.int8)
print("calling backward")
qy.backward(qdx)
print(f"{qx.grad=}")

qx=ScaledTensor(0.025984251871705055 * Tensor(3,i8) Quants{42|42|42|85|127|127},
             requires_grad=True)
qy=tensor([2.183, 4.417, 6.600], grad_fn=<AddBackward0>)
calling backward
qx.grad=tensor([0.002, 0.004, 0.007])


In [31]:
@tensor_subclass_override(ScaledTensor, aten.add.Tensor)
def _(func_, a: Tensor, b: Tensor, alpha=1) -> Tensor:
    if not (isinstance(a, ScaledTensor) and isinstance(b, ScaledTensor)):
        return NotImplemented

    return ScaledTensor(st_add(a.st, b.st, alpha))


@tensor_subclass_override(ScaledTensor, aten.mm.default)
def _(func_, a: Tensor, b: Tensor, *args, **kwargs) -> Tensor:
    if not (isinstance(a, ScaledTensor) and isinstance(b, ScaledTensor)):
        # Just upcast both.
        print(end="u")
        return upcast_args_and_redispatch(func_, a, b, *args, **kwargs)

    assert not args and not kwargs
    return ScaledTensor(st_matmul(a.st, b.st, debug=True))


@tensor_subclass_override(ScaledTensor, aten.relu.default)
def _(func_, a: Tensor, inplace=False) -> Tensor:
    if not isinstance(a, ScaledTensor):
        return NotImplemented

    assert not inplace
    return ScaledTensor(st_relu(a.st))


@tensor_subclass_override(
    ScaledTensor,
    (
        aten.unsqueeze_.default,
        aten.transpose_.default,
        aten.squeeze_.dim,
    ),
)
def inplace_delegate_to_data(func, t_self, *args, **kwargs):
    ## inplace operation on first operand
    if not isinstance(t_self, ScaledTensor):
        # t_self is a normal tensor, just upcast
        return upcast_args_and_redispatch(func, t_self, *args, **kwargs)

    t_self.st.data = func(t_self.st.data, *args, **kwargs)

    return t_self


# - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
@tensor_subclass_override(ScaledTensor, aten.add_.Tensor)
def _(func, t_self, other, *, alpha=1):
    ## inplace
    if not isinstance(t_self, ScaledTensor):
        # t_self is a normal tensor, just upcast and add in place
        return t_self.add_(other.st.to_tensor(), alpha)

    if isinstance(other, ScaledTensor):
        t_self.st = st_add(t_self.st, other.st)
    else:
        # Do the add in high precision and quantize
        ret = quantise(t_self.st.to_tensor() + other, dtype=t_self.st.data.dtype)
        t_self.st = ret.st

    return t_self


# - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
@tensor_subclass_override(ScaledTensor, aten.mul_.Tensor)
def _(func, t_self, other, *args, **kwargs):
    ## inplace
    if not isinstance(t_self, ScaledTensor):
        # t_self is a normal tensor, just upcast and add in place
        return t_self.mul_(other.st.to_tensor())

    if np.isscalar(other) or torch.numel(other) == 1:
        t_self.st.scale *= other
        return t_self

    assert False


@tensor_subclass_override(ScaledTensor, aten.mul.Tensor)
def _(func, t_self, other, *args, **kwargs):
    if np.isscalar(other) or torch.numel(other) == 1:
        return ScaledTensor(ScaledTensorData(t_self.st.data, t_self.st.scale * other))

    return upcast_args_and_redispatch(t_self, other, *args, **kwargs)


st3 = quantise(torch.full((2, 3), 100.0), torch.int8)
st4 = quantise(torch.full((3, 4), 200.0), torch.int8)


def fred():
    print(f"{st3=}")
    print(st3 @ st4)


fred()

ic(st3)
ic(st3.T)
ic(st3.flatten())
ic(torch.flatten(st3))

# ic(st4)
print(f"{f32(st3 @ st4)       = } <-- ~21000 quantized")
print(f"{f32(st3) @ f32(st4)  = } <-- 60000 exact")

print(nn.functional.relu(st))

torch.addmm(st3 @ st3.T, st3, st3.T)
torch.isnan(st3)

ic| st3: ScaledTensor

st3=ScaledTensor(0.787401556968689 * Tensor(2x3,i8) Quants{127|127|127|127|127|127})
ScaledTensor(1.8311105966567993 * Tensor(2x4,i16) Quants{32767|32767|32767|32767|32767|32767})


(0.787401556968689 * Tensor(2x3,i8) Quants{127|127|127|127|127|127})
ic| st3.T: ScaledTensor(0.787401556968689 * Tensor(3x2,i8) Quants{127|127|127|127|127|127})
ic| st3.flatten(): ScaledTensor(0.787401556968689 * Tensor(6,i8) Quants{127|127|127|127|127|127})
ic| torch.flatten(st3): ScaledTensor(0.787401556968689 * Tensor(6,i8) Quants{127|127|127|127|127|127})


f32(st3 @ st4)       = ScaledTensor(1.8311105966567993 * Tensor(2x4,i16) Quants{32767|32767|32767|32767|32767|32767}) <-- ~21000 quantized
f32(st3) @ f32(st4)  = ScaledTensor(1.8311105966567993 * Tensor(2x4,i16) Quants{32767|32767|32767|32767|32767|32767}) <-- 60000 exact
ScaledTensor(0.023622047156095505 * Tensor(1x3,i8) Quants{42|42|42|85|127|127},
             grad_fn=<ReluBackward0>)


tensor([[False, False, False],
        [False, False, False]])

# And in a network...

In [32]:
# Example


class FFN(nn.Module):
    def __init__(self, hidden_size: int, dtype: torch.dtype):
        super().__init__()
        q = lambda x: quantise(x, dtype)
        self.W0 = q(torch.randn(hidden_size, 4 * hidden_size))
        self.W1 = q(torch.randn(4 * hidden_size, hidden_size))
        print(f"{tensor_oneline_str(self.W1)}")

    def forward(self, x: Union[Tensor, ScaledTensor]) -> Union[Tensor, ScaledTensor]:
        y = nn.functional.relu(x @ self.W0)

        if isinstance(y, ScaledTensor):
            print(f"wasted bits #1: {wasted_bits(y)}")
            y = requantise(y)

        y = y @ self.W1

        if isinstance(y, ScaledTensor):
            print(f"wasted bits #2: {wasted_bits(y)}")

        return y


hidden_size = 64
storage_type = torch.int8
module = FFN(hidden_size, storage_type)
x = quantise(torch.randn(10, hidden_size), storage_type)
result = module(x)
print()
print(result)

ScaledTensor(256x64,f32) Quants{-3.894|-1.594|-0.675|0.000|0.675|3.833}
wasted bits #1: 0.9575748443603516
wasted bits #2: 0.9221382141113281

ScaledTensor(0.008577282540500164 * Tensor(10x64,i16) Quants{-27665|-14893|-4625|2450|8658|32767})


## CIFAR

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 60

trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

import matplotlib.pyplot as plt

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
# imshow(torchvision.utils.make_grid(images))
# print labels
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))

Files already downloaded and verified
Files already downloaded and verified
frog  truck bird  dog   bird  horse dog   ship  truck ship  horse dog   car   horse cat   cat   horse ship  truck ship  cat   deer  plane ship  ship  deer  ship  horse horse horse plane plane frog  cat   frog  horse cat   horse deer  deer  plane bird  deer  deer  deer  plane plane plane bird  frog  truck frog  horse bird  ship  ship  horse car   truck frog 


In [36]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

for n, p in net.named_parameters():
    print(n, tensor_oneline_str(p))

d = net.state_dict()
for k in d:
    d[k] = quantise(d[k], torch.int16)

net.load_state_dict(d, assign=True)

print()
print("Post quantisation")
print("=================")
for n, p in net.named_parameters():
    print(n, p)

conv1.weight Parameter(6x3x5x5,f32) Quants{-0.114|-0.105|-0.066|-0.006|0.062|0.115}
conv1.bias Parameter(6,f32) Quants{-0.112|-0.112|-0.110|-0.081|0.066|0.081}
conv2.weight Parameter(16x6x5x5,f32) Quants{-0.082|-0.074|-0.040|-0.000|0.041|0.082}
conv2.bias Parameter(16,f32) Quants{-0.067|-0.050|-0.006|0.015|0.027|0.079}
fc1.weight Parameter(120x400,f32) Quants{-0.050|-0.045|-0.025|-0.000|0.025|0.050}
fc1.bias Parameter(120,f32) Quants{-0.050|-0.045|-0.033|-0.005|0.021|0.046}
fc2.weight Parameter(84x120,f32) Quants{-0.091|-0.082|-0.046|-0.000|0.046|0.091}
fc2.bias Parameter(84,f32) Quants{-0.089|-0.081|-0.043|0.005|0.043|0.089}
fc3.weight Parameter(10x84,f32) Quants{-0.108|-0.095|-0.051|0.000|0.055|0.109}
fc3.bias Parameter(10,f32) Quants{-0.098|-0.098|-0.037|0.006|0.024|0.051}

Post quantisation
conv1.weight Parameter(ScaledTensor(3.516209289955441e-06 * Tensor(6x3x5x5,i16) Quants{-32335|-29836|-18705|-1838|17765|32767},
             requires_grad=True))
conv1.bias Parameter(ScaledTenso

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [None]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        with torch.autograd.detect_anomaly():
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 200 == 0:  # print every 2000 mini-batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
            running_loss = 0.0


print("Finished Training")

  with torch.autograd.detect_anomaly():


[1,     1] loss: 0.001


  File "/home/awf/micromamba/envs/scaledarith/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/awf/micromamba/envs/scaledarith/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/awf/micromamba/envs/scaledarith/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/awf/micromamba/envs/scaledarith/lib/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/home/awf/micromamba/envs/scaledarith/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/home/awf/micromamba/envs/scaledarith/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/home/awf/micromamba/envs/scaledarith/lib/python3.10/asyncio/base_events.py", line 595, in run_forever
    se

RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.

In [None]:
tuple(tensor_oneline_str(x) for x in data)

('Tensor(4x3x32x32,f32) Quants{-1.000|-0.804|-0.404|-0.137|0.122|1.000}',
 'Tensor(4,i64) Quants{1|1|2|8|8|9}')