In [12]:
import torch

In [13]:
def absmax_quantize_i8(X: torch.Tensor):
    absmax = torch.max(torch.abs(X))
    X_i8 = ((X * 127) / absmax).to(torch.int8)
    return X_i8, X_i8.to(torch.float32) * absmax / 127

def zeropoint_quantize_i8(X: torch.Tensor):
    r = torch.max(X) - torch.min(X)
    r = 1 if r == 0 else r
    scale = 255 / r

    zeropoint = (-scale * torch.min(X) - 128)
    X_i8 =  (X * scale + zeropoint).round().to(torch.int8)
    
    return X_i8, (X_i8 - zeropoint) / scale

# def absmax_quantize(X):
#     # Calculate scale
#     scale = 127 / torch.max(torch.abs(X))

#     # Quantize
#     X_quant = (scale * X).round()

#     # Dequantize
#     X_dequant = X_quant / scale

#     return X_quant.to(torch.int8), X_dequant

# def zeropoint_quantize(X):
#     # Calculate value range (denominator)
#     x_range = torch.max(X) - torch.min(X)
#     x_range = 1 if x_range == 0 else x_range

#     # Calculate scale
#     scale = 255 / x_range

#     # Shift by zero-point
#     zeropoint = (-scale * torch.min(X) - 128).round()
#     # Scale and round the inputs
#     X_quant = torch.clip((X * scale + zeropoint).round(), -128, 127)

#     # Dequantize
#     X_dequant = (X_quant - zeropoint) / scale

#     return X_quant.to(torch.int8), X_dequant

def zp_mul(A, B):
    # Calculate value range (denominator)
    a_range = torch.max(A) - torch.min(A)
    b_range = torch.max(B) - torch.min(B)
    a_range = 1 if a_range == 0 else a_range
    b_range = 1 if b_range == 0 else b_range
    
    # Calculate scale
    a_scale = 255 / a_range
    b_scale = 255 / b_range
    c_scale = a_scale * b_scale

    # Shift by zero-point
    a_zp = (-a_scale * torch.min(A) - 128).round()
    b_zp = (-b_scale * torch.min(B) - 128).round()
    c_zp = a_zp * b_zp
    
    # Scale and round the inputs
    A_quant = torch.clip((A * a_scale + a_zp).round(), -128, 127).to(torch.int8)
    B_quant = torch.clip((B * b_scale + b_zp).round(), -128, 127).to(torch.int8)    
    
    # print(f'c_scale:{c_scale}, c_zp:{c_zp}')
    # print(f'A_quant:{A_quant}, B_quant:{B_quant}')
    # Multiply
    C_quant = (A_quant.to(torch.int16) * B_quant.to(torch.int16)) + c_zp #- A_quant * b_zp.to(torch.float32) - B_quant * a_zp.to(torch.float32) 
    # print(f'c_quant:{C_quant}')
    C = C_quant / c_scale

    return C

In [14]:
def matmul_vector_abs_i8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    A_scale = 127 / torch.max(torch.abs(A), dim=1).values
    B_scale = 127 / torch.max(torch.abs(B), dim=0).values
    C_scale = torch.matmul(A_scale.unsqueeze(1), B_scale.unsqueeze(0))

    A_i8 = torch.clip((A  * A_scale.unsqueeze(1)).round(), -128, 127).to(torch.int8)
    B_i8 = torch.clip((B  * B_scale.unsqueeze(0)).round(), -128, 127).to(torch.int8)

    return torch.matmul(A_i8.to(torch.int32), B_i8.to(torch.int32)) / C_scale

def LLM_matmul_abs_i8(X: torch.Tensor, W: torch.Tensor, alpha = 5) -> torch.Tensor:
    X_col_filter = torch.max(torch.abs(X), dim = 0).values > alpha
    X1 = X[:, X_col_filter]
    W1 = W[X_col_filter, :]
    X2 = X[:, ~X_col_filter]
    W2 = W[~X_col_filter, :]
    
    O1 = torch.matmul(X1, W1)
    print(f'Reserved {(X1.shape[1] / X.shape[1] * 100):.1f}%')
    X2_scale = 127 / torch.max(torch.abs(X2), dim=1).values
    W2_scale = 127 / torch.max(torch.abs(W2), dim=0).values
    O2_scale = torch.matmul(X2_scale.unsqueeze(1), W2_scale.unsqueeze(0))

    X2_i8 = torch.clip((X2  * X2_scale.unsqueeze(1)).round(), -128, 127).to(torch.int8)
    W2_i8 = torch.clip((W2  * W2_scale.unsqueeze(0)).round(), -128, 127).to(torch.int8)

    O2 = torch.matmul(X2_i8.to(torch.int32), W2_i8.to(torch.int32)) / O2_scale
    
    return O1 + O2.to(O1)

def LLM_matmul_zp_i8(X: torch.Tensor, W: torch.Tensor, alpha = 5) -> torch.Tensor:
    X_col_filter = torch.max(torch.abs(X), dim = 0).values > alpha
    X1 = X[:, X_col_filter]
    W1 = W[X_col_filter, :]
    X2 = X[:, ~X_col_filter]
    W2 = W[~X_col_filter, :]
    
    O1 = torch.matmul(X1, W1)
    print(f'Reserved {(X1.shape[1] / X.shape[1] * 100):.1f}%')
    # Calculate value range (denominator)
    X2_range = torch.max(X2, dim=1).values - torch.min(X2, dim=1).values
    W2_range = torch.max(W2, dim=0).values - torch.min(W2, dim=0).values
    
    # Calculate scale
    X2_scale = 255 / X2_range
    W2_scale = 255 / W2_range
    O2_scale = torch.matmul(X2_scale.unsqueeze(1), W2_scale.unsqueeze(0))

    # Shift by zero-point
    X2_zp = (-X2_scale * torch.min(X2, dim = 1).values - 128).round()
    W2_zp = (-W2_scale * torch.min(W2, dim = 0).values - 128).round()
    O_zp = torch.matmul(X2_zp.unsqueeze(1), W2_zp.unsqueeze(0))    
    
    # Scale and round the inputs
    X2_quant = torch.clip((X2 * X2_scale.unsqueeze(1) + X2_zp.unsqueeze(1)).round(), -128, 127).to(torch.int8)
    W2_quant = torch.clip((W2 * W2_scale.unsqueeze(0) + W2_zp.unsqueeze(0)).round(), -128, 127).to(torch.int8)   
    O2_quant = (X2_quant.to(torch.int32) @ W2_quant.to(torch.int32)) \
                - X2_quant.to(X2) @ W2_zp.unsqueeze(0).expand(X2.shape[1], -1) \
                - X2_zp.unsqueeze(1).expand(-1, W2.shape[0]) @ W2_quant.to(W2) \
                + O_zp * X2.shape[1]
    O2 = O2_quant / O2_scale
    
    return O1 + O2.to(O1)

## Test

In [15]:
X = torch.randn(500, 1000, dtype=torch.bfloat16)
W = torch.randn(1000, 500, dtype=torch.bfloat16)
X[0, 0: X.shape[1] // 10] = 6

error = torch.abs(LLM_matmul_abs_i8(X, W) - X @ W)
filter = error > 1
print(f'LLM.int8() absmax -> Acc: {torch.sum(error)}, Avg: {torch.sum(error) / (X.shape[0] * W.shape[1])}')

error = torch.abs(LLM_matmul_zp_i8(X, W) - X @ W)
filter = error > 1
print(f'LLM.int8() zero-point -> Acc: {torch.sum(error)}, Avg: {torch.sum(error) / (X.shape[0] * W.shape[1])}')
print(torch.sum(error), torch.sum(error) / (X.shape[0] * W.shape[1]))

error = torch.abs(matmul_vector_abs_i8(X, W) - X @ W)
filter = error > 1
print(f'int8 abs -> Acc: {torch.sum(error)}, Avg: {torch.sum(error) / (X.shape[0] * W.shape[1])}')


Reserved 10.0%
LLM.int8() absmax -> Acc: 66560.0, Avg: 0.265625
Reserved 10.0%
LLM.int8() zero-point -> Acc: 68096.0, Avg: 0.271484375
tensor(68096., dtype=torch.bfloat16) tensor(0.2715, dtype=torch.bfloat16)
int8 abs -> Acc: 69632.0, Avg: 0.279296875


## Modle Quantization Test

In [16]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
torch.manual_seed(0)

# Set device to CPU for now
device = 'cpu'

# Load model and tokenizer
model_id = 'gpt2'
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Print model size
print(f"Model size: {model.get_memory_footprint():,} bytes")

Model size: 510,342,192 bytes


In [17]:
# Extract weights of the first layer
weights = model.transformer.h[0].attn.c_attn.weight.data
print("Original weights:")
print(weights)

# Quantize layer using absmax quantization
weights_abs_quant, weights_abs_dequant = absmax_quantize(weights)
print("\nAbsmax quantized weights:")
print(weights_abs_quant, '\n', weights_abs_dequant)

# Quantize layer using absmax quantization
weights_zp_quant, weights_zp_dequant = zeropoint_quantize(weights)
print("\nZero-point quantized weights:")
print(weights_zp_quant, '\n', weights_zp_dequant)

Original weights:
tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]])

Absmax quantized weights:
tensor([[-21, -12,  -4,  ...,   2,  -3,   1],
        [  4,   7,  11,  ...,  -2,  -1,  -1],
        [  0,   3,  16,  ...,   5,   2,  -1],
        ...,
        [-12,  -1,   9,  ...,   0,  -2,   1],
        [  7,  10,   5,  ...,   1,  -2,  -2],
        [-18,  -9, -11,  ...,   0,   0,   1]], dtype=torch.int8) 
 tensor([[-0.4702, -0.2687, -0.0896,  ...,  0.0448, -0.0672,  0.0224],
        [ 0.0896,  0.1567,  0.2463,  ..., -0.0448, -0.0224, -0.0224],
        [ 0.0000,  0.0672,  0.3583,  ...,  0.1120,  0.0448, -0.0224],
 

In [18]:
import numpy as np
from copy import deepcopy

# Store original weights
weights = [param.data.clone() for param in model.parameters()]

# Create model to quantize
model_abs = deepcopy(model)

# Quantize all model weights
weights_abs = []
for param in model_abs.parameters():
    _, dequantized = absmax_quantize_i8(param.data)
    param.data = dequantized
    weights_abs.append(dequantized)

# Create model to quantize
model_zp = deepcopy(model)

# Quantize all model weights
weights_zp = []
for param in model_zp.parameters():
    _, dequantized = zeropoint_quantize_i8(param.data)
    param.data = dequantized
    weights_zp.append(dequantized)

In [28]:
def generate_text(model, input_text, max_length=50):
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    output = model.generate(inputs=input_ids,
                            max_length=max_length,
                            do_sample=True,
                            top_k=30,
                            pad_token_id=tokenizer.eos_token_id,
                            attention_mask=input_ids.new_ones(input_ids.shape))
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Generate text with original and quantized models
original_text = generate_text(model, "I have a dream")
absmax_text   = generate_text(model_abs, "I have a dream")
zp_text       = generate_text(model_zp, "I have a dream")

print(f"Original model:\n{original_text}")
print("-" * 50)
print(f"Absmax model:\n{absmax_text}")
print("-" * 50)
print(f"Zeropoint model:\n{zp_text}")

Original model:
I have a dream, it's all about getting a big break but that doesn't sound like a great start. I was really happy with the way things turned out for me and it got me through some of the tough times that people were dealing with
--------------------------------------------------
Absmax model:
I have a dream."
And it come I is like he that was he be come

It to that not all was he he to that is he
And in as said he was is she that him all be she that he it said
--------------------------------------------------
Zeropoint model:
I have a dream of becoming something."

"Well, so do I will not be that someday that I will find out."

It is all for the sake of her; but I shall prove her.

"So, my


In [29]:
def calculate_perplexity(model, text):
    # Encode the text
    encodings = tokenizer(text, return_tensors='pt').to(device)

    # Define input_ids and target_ids
    input_ids = encodings.input_ids
    target_ids = input_ids.clone()
    
    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

    # Loss calculation    
    neg_log_likelihood = outputs.loss

    # Perplexity calculation
    ppl = torch.exp(neg_log_likelihood)

    return ppl

ppl     = calculate_perplexity(model, original_text)
ppl_abs = calculate_perplexity(model_abs, absmax_text)
ppl_zp  = calculate_perplexity(model_zp, zp_text)

print(f"Original perplexity:  {ppl.item():.2f}")
print(f"Absmax perplexity:    {ppl_abs.item():.2f}")
print(f"Zeropoint perplexity: {ppl_zp.item():.2f}")

Original perplexity:  13.95
Absmax perplexity:    74.65
Zeropoint perplexity: 19.24
