In [2]:
import torch

In [205]:
def absmax_quantize_i8(X: torch.Tensor):
    absmax = torch.max(torch.abs(X))
    X_i8 = ((X * 127) / absmax).to(torch.int8)
    return X_i8, X_i8.to(torch.float32) * absmax / 127

def zeropoint_quantize_i8(X: torch.Tensor):
    r = torch.max(X) - torch.min(X)
    r = 1 if r == 0 else r
    scale = 255 / r

    zeropoint = (-scale * torch.min(X) - 128)
    X_i8 =  (X * scale + zeropoint).round().to(torch.int8)
    
    return X_i8, (X_i8 - zeropoint) / scale

def zeropoint_quantize(X):
    # Calculate value range (denominator)
    x_range = torch.max(X) - torch.min(X)
    x_range = 1 if x_range == 0 else x_range

    # Calculate scale
    scale = 255 / x_range

    # Shift by zero-point
    zeropoint = (-scale * torch.min(X) - 128).round()
    # Scale and round the inputs
    X_quant = torch.clip((X * scale + zeropoint).round(), -128, 127)

    # Dequantize
    X_dequant = (X_quant - zeropoint) / scale

    return X_quant.to(torch.int8), X_dequant

In [65]:
def matmul_vector_abs_i8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    A_scale = 127 / torch.max(torch.abs(A), dim=1).values
    B_scale = 127 / torch.max(torch.abs(B), dim=0).values
    C_scale = torch.matmul(A_scale.unsqueeze(1), B_scale.unsqueeze(0))

    A_i8 = (A  * A_scale.unsqueeze(1)).round().to(torch.int8)
    B_i8 = (B  * B_scale.unsqueeze(0)).round().to(torch.int8)

    return torch.matmul(A_i8.to(torch.int16), B_i8.to(torch.int16)) / C_scale

def matmul_i8(A: torch.Tensor, B: torch.Tensor, alpha = 5) -> torch.Tensor:
    A_scale = 127 / torch.max(torch.abs(A))
    B_scale = 127 / torch.max(torch.abs(B))
    C_scale = A_scale * B_scale

    A_i8 = (A  * A_scale).round().to(torch.int8)
    B_i8 = (B  * B_scale).round().to(torch.int8)

    return torch.matmul(A_i8.to(torch.int16), B_i8.to(torch.int16)) / C_scale

def LLM_matmul_i8(X: torch.Tensor, W: torch.Tensor, alpha = 5) -> torch.Tensor:
    X_col_filter = torch.max(torch.abs(X), dim = 0).values > alpha
    X1 = X[:, X_col_filter]
    W1 = W[X_col_filter, :]
    X2 = X[:, ~X_col_filter]
    W2 = W[~X_col_filter, :]
    
    C1 = torch.matmul(X1, W1)
    print(f'Reserved {(X1.shape[1] / X.shape[1] * 100):.1f}%')
    X2_scale = 127 / torch.max(torch.abs(X2), dim=1).values
    W2_scale = 127 / torch.max(torch.abs(W2), dim=0).values
    C2_scale = torch.matmul(X2_scale.unsqueeze(1), W2_scale.unsqueeze(0))

    X2_i8 = (X2  * X2_scale.unsqueeze(1)).round().to(torch.int8)
    W2_i8 = (W2  * W2_scale.unsqueeze(0)).round().to(torch.int8)

    C2 = torch.matmul(X2_i8.to(torch.int16), W2_i8.to(torch.int16)) / C2_scale
    
    return C1 + C2

In [68]:
X = torch.randn(50, 100, dtype=torch.bfloat16)
W = torch.randn(100, 50, dtype=torch.bfloat16)
X[0,0:X.shape[1] // 3] = 6

error = torch.sum(torch.abs(LLM_matmul_i8(X, W) - X @ W))
print(error, error/(X.shape[0] * W.shape[1]))
# print(LLM_matmul_i8(X, W, 3) - X@W)

error = torch.sum(torch.abs(matmul_vector_abs_i8(X, W) - X @ W))
print(error, error/(X.shape[0] * W.shape[1]))
# print(matmul_vector_abs_i8(X, W) - X@W)

Reserved 33.0%
tensor(8832., dtype=torch.bfloat16) tensor(3.5312, dtype=torch.bfloat16)
tensor(12928., dtype=torch.bfloat16) tensor(5.1562, dtype=torch.bfloat16)


In [329]:
A = torch.randn(5, 5)
A_i8, A_recon = absmax_quantize_i8(A)
torch.sum(torch.abs(A - A_recon))

tensor(0.1733)

In [330]:
A_i8, A_recon = zeropoint_quantize_i8(A)
torch.sum(torch.abs(A - A_recon))

tensor(0.0715)

In [331]:
A_i8, A_recon = zeropoint_quantize(A)
torch.sum(torch.abs(A - A_recon))

tensor(0.0871)

In [347]:
A = torch.randn(5, 5).to(torch.float16)
B = torch.randn(5, 5).to(torch.float16)

nd_A_f16 = (255 / (A.max() - A.min())).round().to(A)
nd_B_f16 = (255 / (B.max() - B.min())).round().to(A)
zp_A_i16 = (A * A.min()).round().to(torch.int16)
zp_B_i16 = (B * B.min()).round().to(torch.int16)

A_i8 = (A * nd_A_f16).round().to(torch.int16)
# print(A_i8 - (A * nd_A_f16).round().to(torch.int16))
B_i8 = (B * nd_B_f16).round().to(torch.int16)

C_i32 = (A_i8.to(torch.int32) + zp_A_i16) * (B_i8.to(torch.int32) + zp_B_i16)
torch.sum(torch.abs(C_i32 / (nd_A_f16 * nd_B_f16) - A * B)), C_i32 / (nd_A_f16 * nd_B_f16) - A * B

(tensor(0.9214, dtype=torch.float16),
 tensor([[ 1.9531e-03, -3.5645e-02, -2.9297e-03,  3.8574e-02,  1.3477e-01],
         [-6.4453e-02,  6.5918e-03,  7.2021e-03, -1.6992e-01,  4.5898e-02],
         [-7.2266e-02,  3.5400e-03, -1.0681e-04,  3.1738e-02, -4.2480e-02],
         [-5.1880e-04, -1.5625e-02, -4.8340e-02,  5.3711e-03,  3.6621e-03],
         [-2.5146e-02,  3.0273e-02,  1.1963e-02,  1.2109e-01,  1.2817e-03]],
        dtype=torch.float16))