In [1]:
from functools import partial
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from typing import *
# Keep handles to the unmodified torch.Tensor.to and torch.tensor callables.
# NOTE(review): neither alias is used in the visible cells; presumably they let
# monkey-patched replacements delegate to the originals - confirm against the
# cell that performs the patching.
_torch_tensor_to = torch.Tensor.to
_torch_tensor = torch.tensor

def numeric_info(dtype):
    """Return the numeric-limits record for ``dtype``.

    Floating-point dtypes are described by ``torch.finfo``; every other dtype
    is passed to ``torch.iinfo``.
    """
    if dtype.is_floating_point:
        return torch.finfo(dtype)
    return torch.iinfo(dtype)



In [10]:
# Core

# https://pytorch.org/docs/stable/notes/extending.html#extending-torch-with-a-tensor-like-type
@dataclass
class ScaledTensor:
    """A tensor stored as ``data`` together with a single float32 ``scale``.

    The represented value is ``data * scale`` (see ``to_tensor``).

    Args:
        data: the stored elements.
        scale: float32 tensor holding exactly one element. When omitted, a
            fresh tensor equal to 1.0 is created for this instance.

    Raises:
        AssertionError: if ``scale`` is not float32 or does not hold exactly
            one element.
    """

    data: Tensor
    scale: Tensor

    def __init__(self, data: Tensor, scale: Optional[Tensor] = None) -> None:
        if scale is None:
            # Build the default per instance: a tensor default in the signature
            # is evaluated once and would be shared (and mutated) by everyone.
            scale = torch.tensor(1.0, dtype=torch.float32)
        self.data = data
        self.scale = scale
        # @dataclass does not generate __init__ when the class defines one, so
        # the __post_init__ hook has to be invoked by hand to run validation.
        self.__post_init__()

    def __post_init__(self) -> None:
        assert self.scale.dtype == torch.float32, "scale must be float32"
        # Accept both 0-d and shape-(1,) scales: quantise() and the default
        # scale are 0-d, so requiring shape == (1,) would reject them.
        assert self.scale.numel() == 1, "scale must hold exactly one element"
        # Possible future expansion to e.g. row-scaled, column-scaled, etc

    def __repr__(self) -> str:
        return f"ScaledTensor({self.scale} * {self.data})"

    def to_tensor(self, dtype: torch.dtype) -> Tensor:
        """Materialise the represented value ``data * scale`` in ``dtype``."""
        return self.data.to(dtype) * self.scale.to(dtype)

    @property
    def shape(self) -> torch.Size:
        """Shape of the stored ``data`` tensor."""
        return self.data.shape

def _uptype(dtype) -> torch.dtype:
    if torch.finfo(dtype).bits < 16:
        return torch.float16
    else:
        return dtype
    
def _to_uptype(t : Tensor) -> Tensor:
    """Convert ``t`` to the dtype that ``_uptype`` selects for ``t.dtype``."""
    wide_dtype = _uptype(t.dtype)
    return t.to(wide_dtype)

def _maxabs(t : Tensor):
    """Return max(|t|) over all elements as a 0-d float32 tensor.

    The input is first converted with ``_to_uptype`` before ``abs``/``max``
    are applied.
    """
    widened = _to_uptype(t)
    peak = widened.abs().max()
    return peak.to(torch.float32)

def quantise(x: Tensor, dtype: torch.dtype) -> ScaledTensor:
    """Quantise ``x`` into ``dtype`` storage with a single float32 scale.

    The scale is chosen so the largest |element| of ``x`` lands on the largest
    finite value of ``dtype``. An all-zero ``x`` gets a scale of 1.
    """
    peak = _maxabs(x)
    if peak != 0:
        # Map the largest magnitude onto the top of dtype's finite range.
        scale = peak / torch.finfo(dtype).max
    else:
        # Nothing to normalise against: any scale represents zero, use 1.
        scale = torch.tensor(1.0, dtype=torch.float32)

    stored = (x / scale).to(dtype)
    return ScaledTensor(data=stored, scale=scale)

def requantise(st: ScaledTensor) -> ScaledTensor:
    """
    Rescale so that max(|data|) == maxFinite(data.dtype)

    Equivalent to quantise(st.to_tensor(torch.float32), st.data.dtype) but
    avoids the conversion

    All-zero data cannot be normalised; in that case the scale is reset to 1
    (matching what quantise does for an all-zero input) and data is reused.

    Returned tensor may share its data with input tensor
    """
    maxdataval = _maxabs(st.data)
    if maxdataval == 0:
        # All zero, reset scale to 1. ones_like keeps the scale's dtype,
        # shape and device.
        return ScaledTensor(st.data, torch.ones_like(st.scale))

    # Factor by which the stored data currently falls short of the full range.
    rescale = maxdataval / torch.finfo(st.data.dtype).max
    # Widen before multiplying, then narrow back to the storage dtype.
    data = (_to_uptype(st.data) * (1 / rescale)).to(st.data.dtype)
    # Fold the factor into the scale so data * scale is unchanged.
    return ScaledTensor(data, st.scale * rescale)

# Smoke test: quantise a small float16 tensor into float8 (e4m3fn) storage ...
f16 = torch.tensor([1, 2, 3], dtype=torch.float16)

st = quantise(f16, torch.float8_e4m3fn)

# ... then print the quantised tensor and its requantised form.
print(st)
print(requantise(st))


ScaledTensor(0.0066964286379516125 * tensor([144., 288., 448.], dtype=torch.float8_e4m3fn))


TypeError: unsupported operand type(s) for +: 'ScaledTensor' and 'int'

In [12]:
# Ops
# - Use worst-case scaling rules (no overflow!)
# - Placeholder impl (fast impl requires custom kernels)
# - No autograd support


def add(a: ScaledTensor, b: ScaledTensor) -> ScaledTensor:
    """Add two scaled tensors using the worst-case scaling rule.

    The output scale is ``a.scale + b.scale``. Each operand's data is weighted
    by its share of that sum before the addition; for positive scales the
    weights are non-negative and sum to 1, so the result cannot exceed the
    larger input magnitude and cannot overflow the storage dtype.

    Both operands must use the same storage dtype, which is also the storage
    dtype of the result.
    """
    # ScaledTensor has no `dtype` attribute; the storage dtype lives on `.data`.
    assert a.data.dtype == b.data.dtype
    scale = a.scale + b.scale
    # Widen before doing arithmetic (as requantise does), then narrow back.
    data = (
        _to_uptype(a.data) * (a.scale / scale)
        + _to_uptype(b.data) * (b.scale / scale)
    ).to(a.data.dtype)
    return ScaledTensor(data, scale)


def matmul(a: ScaledTensor, b: ScaledTensor) -> ScaledTensor:
    """Matrix-multiply two scaled tensors using the worst-case scaling rule.

    With K = a.shape[-1] and M = finfo(storage dtype).max, every output element
    of ``a.data @ b.data`` is bounded by K * M**2. Dividing by K * M therefore
    keeps the stored result within [-M, M]; that factor is moved into the
    output scale so that ``data * scale`` still equals the true product.

    The downscale is split evenly between the two operands, and each operand
    is rounded to the storage dtype before the product, as in the original
    placeholder implementation. The product itself is accumulated in float32.

    NOTE(review): this replaces the former ``a.dtype.base_scale`` factor, which
    the visible ScaledTensor no longer provides, with 1 / finfo.max - confirm
    this is the intended rule.
    """
    # ScaledTensor has no `dtype` attribute; the storage dtype lives on `.data`.
    assert a.data.dtype == b.data.dtype
    storage_dtype = a.data.dtype
    headroom = a.shape[-1] * torch.finfo(storage_dtype).max
    downscale = headroom ** -0.5

    # Widen to multiply, round to storage precision, then widen to float32
    # so the product itself is computed in a dtype that supports matmul.
    lhs = (_to_uptype(a.data) * downscale).to(storage_dtype).to(torch.float32)
    rhs = (_to_uptype(b.data) * downscale).to(storage_dtype).to(torch.float32)

    scale = a.scale * b.scale * headroom
    return ScaledTensor((lhs @ rhs).to(storage_dtype), scale)


def relu(a: ScaledTensor) -> ScaledTensor:
    """Apply ReLU to a scaled tensor, keeping its scale and storage dtype.

    ReLU is applied to the stored data only, which matches ReLU of the
    represented value as long as the scale is positive.
    """
    # Widen before the op (as requantise does), then narrow back to storage.
    rectified = nn.functional.relu(_to_uptype(a.data)).to(a.data.dtype)
    # ScaledTensor takes (data, scale); there is no dtype argument.
    return ScaledTensor(rectified, a.scale)


# Install the functional ops as operators so `a + b` and `a @ b` dispatch to
# add() / matmul() when the left operand is a ScaledTensor.
ScaledTensor.__add__ = add
ScaledTensor.__matmul__ = matmul


# NOTE(review): `sint8` / `sint16` are not defined in any visible cell, so these
# two lines raise NameError on a fresh kernel - presumably left over from an
# earlier scaled-integer API; confirm and port or remove.
print(torch.tensor(100, dtype=sint8) + torch.tensor(200, dtype=sint8))
print(torch.full((2, 20), 100).to(sint16) @ torch.full((20, 3), 200).to(sint16))

stensor(300.0, dtype=sint8)
stensor([[390636.9 390636.9 390636.9]
 [390636.9 390636.9 390636.9]], dtype=sint16)


In [13]:
# Example

class FFN(nn.Module):
    """Two-layer feed-forward block: ``x @ W0`` -> relu -> ``@ W1``.

    The weights are random tensors converted with ``.to(dtype)`` and stored as
    plain attributes (not ``nn.Parameter``). When an intermediate result is a
    ScaledTensor, its wasted bits are printed, and the hidden activation is
    requantised before the second matmul.

    NOTE(review): ``ScaledDType`` and ``wasted_bits`` are not defined in the
    visible cells - confirm they exist before running this cell.
    """

    def __init__(self, hidden_size: int, dtype: Union[torch.dtype, ScaledDType]):
        super().__init__()
        inner_size = 4 * hidden_size
        # Draw W0 before W1 so the random stream is consumed in the same order.
        self.W0 = torch.randn(hidden_size, inner_size).to(dtype)
        self.W1 = torch.randn(inner_size, hidden_size).to(dtype)
        # Scaled dtypes use the ScaledTensor-aware relu defined in this notebook.
        if isinstance(dtype, ScaledDType):
            self.relu = relu
        else:
            self.relu = nn.functional.relu

    def forward(self, x: Union[Tensor, ScaledTensor]) -> Union[Tensor, ScaledTensor]:
        """Run the block on ``x`` and return the second matmul's output."""
        pre_activation = x @ self.W0
        y = self.relu(pre_activation)

        hidden_is_scaled = isinstance(y, ScaledTensor)
        if hidden_is_scaled:
            print(f"wasted bits #1: {wasted_bits(y)}")
            # Recover dynamic range before the second matmul.
            y = requantise(y)

        y = y @ self.W1

        output_is_scaled = isinstance(y, ScaledTensor)
        if output_is_scaled:
            print(f"wasted bits #2: {wasted_bits(y)}")

        return y

# Build an FFN with scaled weights and run it on a random (10, hidden_size) batch.
# NOTE(review): `sint16` is not defined in any visible cell - confirm it is
# still available, otherwise this cell fails on a fresh kernel.
hidden_size = 1024
dtype = sint16
module = FFN(hidden_size, dtype)
result = module(torch.randn(10, hidden_size).to(dtype))
# Blank line separates the "wasted bits" diagnostics from the result dump.
print()
print(result)

wasted bits #1: 8.258488655090332
wasted bits #2: 10.414993286132812

stensor([[   0.        -41.124104  164.49641  ...  287.8687     41.124104
   205.62051 ]
 [ -82.24821  -123.37231   -82.24821  ... -164.49641  -164.49641
  -164.49641 ]
 [ 123.37231  -205.62051   411.24103  ...  -82.24821     0.
     0.      ]
 ...
 [ -41.124104  123.37231   205.62051  ... -287.8687   -287.8687
  -411.24103 ]
 [ -82.24821  -164.49641    82.24821  ...   82.24821    41.124104
   -41.124104]
 [  41.124104  205.62051    82.24821  ...   82.24821  -164.49641
  -123.37231 ]], dtype=sint16)
