In [3]:
import io
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
# Use smaller model if you have limited resources
# checkpoint = "HuggingFaceTB/SmolLM2-360M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")

def get_serialized_model_size_in_mb(model):
    """Calculate the size of the serialized model in megabytes."""
    buffer = io.BytesIO()
    torch.save(model, buffer)
    size_in_bytes = buffer.tell()
    buffer.close()
    size_in_mb = size_in_bytes / (1024**2)  # Convert bytes to megabytes
    return size_in_mb

def run_model(model):
    text = "Tell me a fun fact in a short sentence."
    messages = [{"role": "user", "content": text}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    input_length = input_ids.shape[1]
    output_ids = model.generate(input_ids, max_length=200, do_sample=True)[0]
    response_ids = output_ids[input_length:]
    return tokenizer.decode(response_ids, ignore_special_tokens=True)

In [80]:
model = AutoModelForCausalLM.from_pretrained(checkpoint)
size_before_quantization = get_serialized_model_size_in_mb(model)
print(f"Serialized model size before quantization: {size_before_quantization:.2f} MB")
print(run_model(model))

Serialized model size before quantization: 6528.52 MB
<|im_start|>assistant
Did you know ancient Egyptians often buried food alongside their mummies to ensure a full belly in the afterlife?<|im_end|>


In [82]:
# This quantizes the weights from bfloat32 to bfloat16
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.to(torch.bfloat16)
size_after_quantization = get_serialized_model_size_in_mb(model)
print(f"Serialized model size after quantization1: {size_after_quantization:.2f} MB")
print(run_model(model))

Serialized model size after quantization1: 3264.33 MB
<|im_start|>assistant
Did you know that penguins can only fly underwater, not in the sky like humans?<|im_end|>


In [None]:
from typing import Tuple

def int8_symmetric_quantize(
    fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    quant_min = -128
    quant_max = 127
    min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
    max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    scale = scale.view(fp32_tensor.shape[0], -1)
    out = torch.round(fp32_tensor * (1.0 / scale))
    out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
    return out, scale

def quantize_linear_layers_to_int8(model):
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            weight = module.weight.data
            w_int8, scale = int8_symmetric_quantize(weight)
            
            # Store the int8 weights and scale directly in the module
            module.register_buffer('weight_int8', w_int8)
            module.register_buffer('weight_scale', scale)
            # Remove the original weight to save space
            delattr(module, 'weight')
            
            # Create a new forward method that uses the quantized weights
            def new_forward(self, x):
                # Dequantize during inference
                dequantized_weight = self.weight_int8.to(x.dtype) * self.weight_scale
                return torch.nn.functional.linear(x, dequantized_weight, self.bias)
            import types
            module.forward = types.MethodType(new_forward, module)
    return model

model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.to(torch.bfloat16)
model = quantize_linear_layers_to_int8(model)
size_after_quantization2 = get_serialized_model_size_in_mb(model)
print(f"Serialized model size after quantization2: {size_after_quantization2:.2f} MB")
print(run_model(model))

Serialized model size after quantization2: 1825.69 MB
<|im_start|>assistant
Did you know that the longest-living creature in the animal kingdom can grow up to 21ft long and has the average lifespan of 100 years?<|im_end|>


In [84]:
model = AutoModelForCausalLM.from_pretrained(checkpoint)
# Apply dynamic quantization - targets only the linear layers
quantized_model = torch.quantization.quantize_dynamic(
    model, 
    {torch.nn.Linear},  # Only quantize linear layers
    dtype=torch.qint8
)

# Check size and run inference
size_before = get_serialized_model_size_in_mb(model)
size_after = get_serialized_model_size_in_mb(quantized_model)
print(f"Original size: {size_before:.2f} MB")
print(f"Quantized size: {size_after:.2f} MB")
print(f"Size reduction: {100 * (1 - size_after / size_before):.2f}%")
print(run_model(quantized_model))

Original size: 6528.52 MB
Quantized size: 2016.57 MB
Size reduction: 69.11%
assistant A Your I system is what is it says. Do You want a question Your Your Question might be “Is Rome a great city for ladies If You ever get out of town in Italy for tourism or even business. yes the Romans I mean the ancient romans Are we talking about any city A city could be built near Rome if A city A city is What do you want to do You are doing The City Well The Romans built Rome A huge city city So you Can do whatever you want now In Rome Because the Roman government And The Romans Are Famous A whole for building stuff Rome The Romans built more then almost anything now Rome is it was What they called for Rome The ancient roman Rome is also called The Eternal Rome Which means The old Rome Is always Rome The Romans built Rome a very tall tower Which we


In [None]:
# This works in floating point
x = torch.tensor([0.5, 1.5, 2.5])

x_int8 = x.to(torch.int8)
print(x_int8)  # [0, 2, 3] - already losing precision
y_int8 = torch.softmax(x_int8, dim=0)  # Error: softmax not implemented for 'Int8'

tensor([0, 1, 2], dtype=torch.int8)


RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Char'