In [30]:
import torch

In [31]:
def absmax_quantize_i8(X: torch.Tensor):
    absmax = torch.max(torch.abs(X))
    X_i8 = ((X * 127) / absmax).to(torch.int8)
    return X_i8, X_i8.to(torch.float32) * absmax / 127

def zeropoint_quantize_i8(X: torch.Tensor):
    r = torch.max(X) - torch.min(X)
    r = 1 if r == 0 else r
    scale = 255 / r

    zeropoint = (-scale * torch.min(X) - 128)
    X_i8 =  (X * scale + zeropoint).round().to(torch.int8)
    
    return X_i8, (X_i8 - zeropoint) / scale

def zeropoint_quantize(X):
    # Calculate value range (denominator)
    x_range = torch.max(X) - torch.min(X)
    x_range = 1 if x_range == 0 else x_range

    # Calculate scale
    scale = 255 / x_range

    # Shift by zero-point
    zeropoint = (-scale * torch.min(X) - 128).round()
    # Scale and round the inputs
    X_quant = torch.clip((X * scale + zeropoint).round(), -128, 127)

    # Dequantize
    X_dequant = (X_quant - zeropoint) / scale

    return X_quant.to(torch.int8), X_dequant

def zp_mul(A, B):
    # Calculate value range (denominator)
    a_range = torch.max(A) - torch.min(A)
    b_range = torch.max(B) - torch.min(B)
    a_range = 1 if a_range == 0 else a_range
    b_range = 1 if b_range == 0 else b_range
    
    # Calculate scale
    a_scale = 255 / a_range
    b_scale = 255 / b_range
    c_scale = a_scale * b_scale

    # Shift by zero-point
    a_zp = (-a_scale * torch.min(A) - 128).round()
    b_zp = (-b_scale * torch.min(B) - 128).round()
    c_zp = a_zp * b_zp
    
    # Scale and round the inputs
    A_quant = torch.clip((A * a_scale + a_zp).round(), -128, 127).to(torch.int8)
    B_quant = torch.clip((B * b_scale + b_zp).round(), -128, 127).to(torch.int8)    
    
    # print(f'c_scale:{c_scale}, c_zp:{c_zp}')
    # print(f'A_quant:{A_quant}, B_quant:{B_quant}')
    # Multiply
    C_quant = (A_quant.to(torch.int16) * B_quant.to(torch.int16)) + c_zp #- A_quant * b_zp.to(torch.float32) - B_quant * a_zp.to(torch.float32) 
    # print(f'c_quant:{C_quant}')
    C = C_quant / c_scale

    return C

In [32]:
def matmul_vector_abs_i8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    A_scale = 127 / torch.max(torch.abs(A), dim=1).values
    B_scale = 127 / torch.max(torch.abs(B), dim=0).values
    C_scale = torch.matmul(A_scale.unsqueeze(1), B_scale.unsqueeze(0))

    A_i8 = torch.clip((A  * A_scale.unsqueeze(1)).round(), -128, 127).to(torch.int8)
    B_i8 = torch.clip((B  * B_scale.unsqueeze(0)).round(), -128, 127).to(torch.int8)

    return torch.matmul(A_i8.to(torch.int32), B_i8.to(torch.int32)) / C_scale

def LLM_matmul_abs_i8(X: torch.Tensor, W: torch.Tensor, alpha = 5) -> torch.Tensor:
    X_col_filter = torch.max(torch.abs(X), dim = 0).values > alpha
    X1 = X[:, X_col_filter]
    W1 = W[X_col_filter, :]
    X2 = X[:, ~X_col_filter]
    W2 = W[~X_col_filter, :]
    
    O1 = torch.matmul(X1, W1)
    print(f'Reserved {(X1.shape[1] / X.shape[1] * 100):.1f}%')
    X2_scale = 127 / torch.max(torch.abs(X2), dim=1).values
    W2_scale = 127 / torch.max(torch.abs(W2), dim=0).values
    O2_scale = torch.matmul(X2_scale.unsqueeze(1), W2_scale.unsqueeze(0))

    X2_i8 = torch.clip((X2  * X2_scale.unsqueeze(1)).round(), -128, 127).to(torch.int8)
    W2_i8 = torch.clip((W2  * W2_scale.unsqueeze(0)).round(), -128, 127).to(torch.int8)

    O2 = torch.matmul(X2_i8.to(torch.int32), W2_i8.to(torch.int32)) / O2_scale
    
    return O1 + O2.to(O1)

def LLM_matmul_zp_i8(X: torch.Tensor, W: torch.Tensor, alpha = 5) -> torch.Tensor:
    X_col_filter = torch.max(torch.abs(X), dim = 0).values > alpha
    X1 = X[:, X_col_filter]
    W1 = W[X_col_filter, :]
    X2 = X[:, ~X_col_filter]
    W2 = W[~X_col_filter, :]
    
    O1 = torch.matmul(X1, W1)
    print(f'Reserved {(X1.shape[1] / X.shape[1] * 100):.1f}%')
    # Calculate value range (denominator)
    X2_range = torch.max(X2, dim=1).values - torch.min(X2, dim=1).values
    W2_range = torch.max(W2, dim=0).values - torch.min(W2, dim=0).values
    
    # Calculate scale
    X2_scale = 255 / X2_range
    W2_scale = 255 / W2_range
    O2_scale = torch.matmul(X2_scale.unsqueeze(1), W2_scale.unsqueeze(0))

    # Shift by zero-point
    X2_zp = (-X2_scale * torch.min(X2, dim = 1).values - 128).round()
    W2_zp = (-W2_scale * torch.min(W2, dim = 0).values - 128).round()
    O_zp = torch.matmul(X2_zp.unsqueeze(1), W2_zp.unsqueeze(0))    
    
    # Scale and round the inputs
    X2_quant = torch.clip((X2 * X2_scale.unsqueeze(1) + X2_zp.unsqueeze(1)).round(), -128, 127).to(torch.int8)
    W2_quant = torch.clip((W2 * W2_scale.unsqueeze(0) + W2_zp.unsqueeze(0)).round(), -128, 127).to(torch.int8)   
    O2_quant = (X2_quant.to(torch.int32) @ W2_quant.to(torch.int32)) \
                - X2_quant.to(X2) @ W2_zp.unsqueeze(0).expand(X2.shape[1], -1) \
                - X2_zp.unsqueeze(1).expand(-1, W2.shape[0]) @ W2_quant.to(W2) \
                + O_zp * X2.shape[1]
    O2 = O2_quant / O2_scale
    
    return O1 + O2.to(O1)

## Test

In [34]:
X = torch.randn(500, 1000, dtype=torch.bfloat16)
W = torch.randn(1000, 500, dtype=torch.bfloat16)
X[0, 0: X.shape[1] // 10] = 6

error = torch.abs(LLM_matmul_abs_i8(X, W) - X @ W)
filter = error > 1
print(torch.sum(error), torch.sum(error) / (X.shape[0] * W.shape[1]))

error = torch.abs(LLM_matmul_zp_i8(X, W) - X @ W)
filter = error > 1
print(torch.sum(error), torch.sum(error) / (X.shape[0] * W.shape[1]))

error = torch.abs(matmul_vector_abs_i8(X, W) - X @ W)
filter = error > 1
print(torch.sum(error), torch.sum(error) / (X.shape[0] * W.shape[1]))

Reserved 10.0%
tensor(66560., dtype=torch.bfloat16) tensor(0.2656, dtype=torch.bfloat16)
Reserved 10.0%
tensor(68608., dtype=torch.bfloat16) tensor(0.2754, dtype=torch.bfloat16)
tensor(69632., dtype=torch.bfloat16) tensor(0.2793, dtype=torch.bfloat16)
