**Appendix B – Mixed Precision and Quantization**

_This notebook contains all the sample code for appendix B._

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/ageron/handson-mlp/blob/main/Appendix_B_mixed_precision_and_quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
  </td>
  <td>
    <a target="_blank" href="https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-mlp/blob/main/Appendix_B_mixed_precision_and_quantization.ipynb"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" /></a>
  </td>
</table>

# Setup

This project requires Python 3.10 or above:

In [1]:
import sys

assert sys.version_info >= (3, 10)

Are we using Colab or Kaggle?

In [2]:
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules

And PyTorch ≥ 2.6.0:

In [3]:
from packaging.version import Version
import torch

assert Version(torch.__version__) >= Version("2.6.0")

As we did in earlier chapters, let's define the default font sizes to make the figures prettier:

In [4]:
import matplotlib.pyplot as plt

plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

This appendix can be very slow without a hardware accelerator, so if we can find one, let's use it:

In [5]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device

'cuda'

Let's issue a warning if there's no hardware accelerator available:

In [6]:
if device == "cpu":
    print("Neural nets can be very slow without a hardware accelerator.")
    if IS_COLAB:
        print("Go to Runtime > Change runtime type and select a GPU hardware "
              "accelerator.")
    if IS_KAGGLE:
        print("Go to Settings > Accelerator and select GPU.")

# Common Number Representations

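PyTorch supports several floating-point types. Note that 1.234e56 is only representable in `float64`. In every narrower type below, it overflows to `inf` (or to `nan` for `float8_e4m3fn`, which has no infinity):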

In [7]:
import torch

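fp64 = torch.tensor(1.234e56, dtype=torch.float64)  # 1 sign, 11 exponent, 52 mantissa bits

fp32 = torch.tensor(1.234e56, dtype=torch.float32)  # 1 sign, 8 exponent, 23 mantissa bits

fp16 = torch.tensor(1.234e56, dtype=torch.float16)  # 1 sign, 5 exponent, 10 mantissa bits
bf16 = torch.tensor(1.234e56, dtype=torch.bfloat16)  # 1 sign, 8 exponent, 7 mantissa bits

fp8 = torch.tensor(1.234e56, dtype=torch.float8_e5m2)  # 1 sign, 5 exponent, 2 mantissa bits
bf8 = torch.tensor(1.234e56, dtype=torch.float8_e4m3fn)  # 1 sign, 4 exponent, 3 mantissa bits (no inf)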

And several types of integers:

In [8]:
i64 = torch.tensor(123456, dtype=torch.int64)
u64 = torch.tensor(123456, dtype=torch.uint64)

i32 = torch.tensor(123456, dtype=torch.int32)
u32 = torch.tensor(123456, dtype=torch.uint32)

i16 = torch.tensor(1234, dtype=torch.int16)
u16 = torch.tensor(1234, dtype=torch.uint16)

i8 = torch.tensor(123, dtype=torch.int8)
u8 = torch.tensor(123, dtype=torch.uint8)
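Each integer type trades memory for range. A quick sketch using `torch.iinfo()` and `Tensor.element_size()` (bytes per element) shows this for the 8-bit types used later for quantization, plus the wider signed types:

```python
import torch

for dtype in [torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64]:
    info = torch.iinfo(dtype)
    nbytes = torch.empty(0, dtype=dtype).element_size()  # bytes per element
    print(f"{str(dtype):12} bytes={nbytes}  min={info.min}  max={info.max}")
```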

PyTorch does not yet have direct support for 4-bit integers, but we can pack two 4-bit integers into one 8-bit integer using bitwise operations, shifting the first value into the upper 4 bits and storing the second in the lower 4 bits:

In [9]:
u4 = torch.tensor((12 << 4) | 5, dtype=torch.uint8)  # stores 12 and 5

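We can unpack the values by reversing the bit operations: shifting right by 4 bits recovers the first value, and masking with `0xF` (binary `1111`) keeps only the second: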

In [10]:
u4 >> 4, u4 & 0xF

(tensor(12, dtype=torch.uint8), tensor(5, dtype=torch.uint8))
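The same bit operations work element-wise on whole tensors. Here is a minimal sketch that packs consecutive pairs of values into single bytes, halving memory compared to one `uint8` per value. `pack_uint4()` and `unpack_uint4()` are hypothetical helpers written for illustration, and they assume an even number of values, each in the range 0 to 15:

```python
import torch

def pack_uint4(values):  # hypothetical helper: values in [0, 15], even length
    values = values.to(torch.uint8).reshape(-1, 2)
    return (values[:, 0] << 4) | values[:, 1]  # high nibble | low nibble

def unpack_uint4(packed):  # hypothetical helper: inverse of pack_uint4()
    high = packed >> 4
    low = packed & 0xF
    return torch.stack([high, low], dim=1).reshape(-1)

values = torch.tensor([12, 5, 0, 15, 7, 3])
packed = pack_uint4(values)
print(packed)                # 3 bytes for 6 values
print(unpack_uint4(packed))  # the original values (as uint8)
```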

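Signed integers are a bit more involved, because they are represented using two's complement: a negative value such as -5 has all of its high bits set, so we must first mask each value down to its lowest 4 bits (`& 0xF`). The packed byte can exceed 127 (here it is 184), so it's best to store it in an unsigned byte: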

In [11]:
hi, lo = -5, -8
i4 = torch.tensor(((hi & 0xF) << 4) | (lo & 0xF), dtype=torch.uint8)

Unpacking is also a bit trickier, since each 4-bit pattern must be sign-extended back to a signed 8-bit integer:

In [12]:
hi, lo = ((i4 >> 4).to(torch.int8) ^ 8) - 8, ((i4 & 0xF).to(torch.int8) ^ 8) - 8
hi.item(), lo.item()

(-5, -8)
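Why does `(x ^ 8) - 8` restore the sign? In 4-bit two's complement, bit 3 is the sign bit and is worth -8 rather than +8. XORing with 8 flips that bit (adding 8 if it was clear, subtracting 8 if it was set), and subtracting 8 then maps the patterns 8 to 15 onto -8 to -1 while leaving 0 to 7 unchanged. This sketch checks the trick on all 16 possible 4-bit patterns:

```python
import torch

nibbles = torch.arange(16, dtype=torch.uint8)  # every possible 4-bit pattern
decoded = (nibbles.to(torch.int8) ^ 8) - 8     # sign-extend using bit 3
print(decoded)

# Masking the decoded values with 0xF gives back the original bit patterns
assert torch.equal(decoded & 0xF, nibbles.to(torch.int8))
```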

A similar approach can be used to store four 2-bit integers in an unsigned byte (this is left as an exercise to the reader).

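Lastly, here's an example showing how ternary values (-1, 0, or 1) can be represented efficiently, packing 5 values per byte. Each value is shifted to 0, 1, or 2 and used as a base-3 digit; since 3^5 = 243 ≤ 256, any 5 such digits fit in a single byte: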

In [13]:
def pack_ternary(values):
    factors = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8)
    return ((values + 1).to(torch.uint8) * factors).sum(dim=-1).to(torch.uint8)

def unpack_ternary(packed):
    vals = torch.empty(packed.shape + (5,), dtype=torch.int8)
    for i in range(5):
        vals[..., i] = packed % 3
        packed = packed // 3  # new tensor: the caller's tensor is left unchanged
    return vals - 1

In [14]:
values = torch.tensor([-1, 0, 0, 1, -1], dtype=torch.int8)
packed = pack_ternary(values)
packed

tensor(66, dtype=torch.uint8)

In [15]:
unpack_ternary(torch.tensor(0, dtype=torch.uint8))

tensor([-1, -1, -1, -1, -1], dtype=torch.int8)
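Packing 5 values per byte costs 8 / 5 = 1.6 bits per value, close to the theoretical minimum of log2(3) ≈ 1.585 bits. The following sketch checks the round trip on all 243 possible combinations at once. The two helpers are repeated so the cell is self-contained, and this `unpack_ternary()` works on a copy so that it leaves its input unchanged:

```python
import itertools
import math
import torch

def pack_ternary(values):
    factors = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8)
    return ((values + 1).to(torch.uint8) * factors).sum(dim=-1).to(torch.uint8)

def unpack_ternary(packed):
    vals = torch.empty(packed.shape + (5,), dtype=torch.int8)
    packed = packed.clone()  # don't modify the caller's tensor
    for i in range(5):
        vals[..., i] = packed % 3
        packed //= 3
    return vals - 1

all_values = torch.tensor(list(itertools.product([-1, 0, 1], repeat=5)),
                          dtype=torch.int8)  # shape [243, 5]
packed = pack_ternary(all_values)            # shape [243], one byte each
assert torch.equal(unpack_ternary(packed), all_values)
print(f"{8 / 5} bits per value vs. log2(3) = {math.log2(3):.3f}")
```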

## Reduced Precision Models

In [16]:
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1))
# [...] pretend the 32-bit model is trained here
model.half()  # convert the model parameters to half precision (16 bits)

Sequential(
  (0): Linear(in_features=10, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=1, bias=True)
)

In [17]:
X = torch.rand(3, 10, dtype=torch.float16)  # some 16-bit input
y_pred = model(X)  # 16-bit output

In [18]:
model = nn.Sequential(nn.Linear(10, 100, dtype=torch.float16), nn.ReLU(),
                      nn.Linear(100, 1, dtype=torch.float16))
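Halving the number of bits per parameter halves the memory taken by the model's parameters. A quick sketch to check this on the same architecture as above, in both precisions, using `Tensor.element_size()` (bytes per element):

```python
import torch
import torch.nn as nn

def param_bytes(model):
    # total memory used by the parameters, in bytes
    return sum(p.numel() * p.element_size() for p in model.parameters())

model32 = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1))
model16 = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1)).half()
print(param_bytes(model32), param_bytes(model16))
```

The model has 10 × 100 + 100 + 100 × 1 + 1 = 1,201 parameters, so that's 4 bytes per parameter versus 2.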

## Mixed Precision Training

In [19]:
# Mixed precision training (MPT) is not fully supported yet on MPS devices
device2 = "cpu" if device == "mps" else device

In [20]:
from torch.amp import GradScaler

def train_mpt(model, optimizer, criterion, train_loader, n_epochs,
              dtype=torch.float16, init_scale=2.0**16):
    grad_scaler = GradScaler(device=device2, init_scale=init_scale)
    model.train()
    for epoch in range(n_epochs):
        total_loss = 0.
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device2), y_batch.to(device2)
            with torch.autocast(device_type=device2, dtype=dtype):
                y_pred = model(X_batch)
                loss = criterion(y_pred, y_batch)
            total_loss += loss.item()
            print(f"\rEpoch: {epoch + 1}, loss: {total_loss:.3f}", end="")
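            grad_scaler.scale(loss).backward()  # backprop the scaled loss
            grad_scaler.step(optimizer)  # unscale grads; skip step if inf/NaN
            grad_scaler.update()  # adjust the scale factor for the next step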
            optimizer.zero_grad()
        print()
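What makes this training *mixed* precision is that inside the `torch.autocast()` context, eligible operations such as the linear layers' matrix multiplications run in the reduced-precision type, while the model's parameters (and therefore the optimizer updates) stay in 32 bits. A small CPU-only sketch, assuming `bfloat16` autocast on the CPU:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
X = torch.randn(4, 10)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(X)                     # the linear layer runs in bfloat16 here
print(model.weight.dtype, y.dtype)   # the parameters themselves stay float32
```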

In [21]:
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(42)
model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1)).to(device2)

X_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
train_set = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_set, batch_size=32)
optimizer = torch.optim.NAdam(model.parameters(), lr=0.01)
mse = torch.nn.MSELoss()
train_mpt(model, optimizer, mse, train_loader, n_epochs=10)

Epoch: 1, loss: 32.218
Epoch: 2, loss: 30.540
Epoch: 3, loss: 29.854
Epoch: 4, loss: 29.382
Epoch: 5, loss: 29.013
Epoch: 6, loss: 28.641
Epoch: 7, loss: 28.321
Epoch: 8, loss: 28.004
Epoch: 9, loss: 27.736
Epoch: 10, loss: 27.444
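Why do we need the `GradScaler`? `float16` has only 5 exponent bits, so very small gradients underflow to zero. Multiplying the loss by a large factor (2^16 by default in `train_mpt()`) multiplies every gradient by the same factor, keeping them in range, and the scaler divides them back before the optimizer step. This sketch shows the effect on a single tiny value (1e-8, chosen for illustration):

```python
import torch

grad = 1e-8                    # a tiny gradient value
scale = 2.0 ** 16              # same as the default init_scale above
unscaled = torch.tensor(grad, dtype=torch.float16)        # underflows to 0
scaled = torch.tensor(grad * scale, dtype=torch.float16)  # representable
recovered = scaled.float() / scale                        # unscale in float32
print(unscaled.item(), scaled.item(), recovered.item())
```

With `dtype=torch.bfloat16`, which has the same 8 exponent bits as `float32`, loss scaling is generally unnecessary.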


# Quantization

## Asymmetric Linear Quantization

In [22]:
w = torch.tensor([0.1, -0.1, 0.6, 0.0])  # 32-bit floats
s = (w.max() - w.min()) / 255.  # compute the scale
z = -(w.min() / s).round()  # compute the zero point
qw = torch.quantize_per_tensor(w, scale=s, zero_point=z, dtype=torch.quint8)
qw  # this is a quantized tensor internally represented using integers

tensor([ 0.0988, -0.0988,  0.6012,  0.0000], size=(4,), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.002745098201557994,
       zero_point=36)

In [23]:
qw.dequantize()  # back to 32-bit floats (close to the original tensor)

tensor([ 0.0988, -0.0988,  0.6012,  0.0000])
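Internally, the quantized tensor stores the 8-bit integers q = round(w / s) + z (clipped to the range 0 to 255), and dequantization computes (q - z) × s. The following sketch reproduces both steps by hand and compares the result with the tensor's integer representation, obtained with `int_repr()`:

```python
import torch

w = torch.tensor([0.1, -0.1, 0.6, 0.0])
s = (w.max() - w.min()) / 255.  # scale
z = -(w.min() / s).round()      # zero point

q = (w / s).round().add(z).clamp(0, 255).to(torch.uint8)  # quantize
w_hat = (q.float() - z) * s                                # dequantize
print(q, w_hat)

qw = torch.quantize_per_tensor(w, scale=float(s), zero_point=int(z),
                               dtype=torch.quint8)
print(qw.int_repr())            # PyTorch's internal integers
```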

## Symmetric Linear Quantization

In [24]:
w = torch.tensor([0.0, -0.94, 0.92, 0.93])  # 32-bit floats
s = w.abs().max() / 127.
qw = torch.quantize_per_tensor(w, scale=s, zero_point=0, dtype=torch.qint8)
qw

tensor([ 0.0000, -0.9400,  0.9178,  0.9326], size=(4,), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.007401574868708849,
       zero_point=0)
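With symmetric quantization the zero point is 0, so q = round(w / s), and each dequantized value is off by at most half a quantization step, s / 2. For example, 0.92 lies between the grid points 124s and 125s, closer to 124s, which is why it comes back as 0.9178. A sketch checking the stored integers and the error bound:

```python
import torch

w = torch.tensor([0.0, -0.94, 0.92, 0.93])
s = w.abs().max() / 127.
qw = torch.quantize_per_tensor(w, scale=float(s), zero_point=0, dtype=torch.qint8)

print(qw.int_repr())                        # the stored int8 values
error = (qw.dequantize() - w).abs()
print(error.max().item(), (s / 2).item())   # worst error vs. half a step
```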

## Dynamic Quantization

In [25]:
import platform

machine = platform.machine().lower()
engine = "qnnpack" if ("arm" in machine or "aarch64" in machine) else "x86"

In [26]:
from torch.ao.quantization import quantize_dynamic

model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1))
# [...] pretend the 32-bit model is trained here
torch.backends.quantized.engine = engine
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
X = torch.randn(3, 10)
y_pred = qmodel(X)  # float inputs and outputs, but quantized internally
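With dynamic quantization, the weights are converted to int8 ahead of time, while the activations' scale and zero point are computed on the fly for each batch. The sketch below checks that the `nn.Linear` layers were replaced with dynamically quantized versions and that the outputs stay close to the float model's. It picks the engine from `torch.backends.quantized.supported_engines`, an assumption that mirrors the `engine` logic above:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

engines = torch.backends.quantized.supported_engines
torch.backends.quantized.engine = "x86" if "x86" in engines else "qnnpack"

torch.manual_seed(42)
model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1))
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)  # copy

X = torch.randn(3, 10)
with torch.no_grad():
    diff = (model(X) - qmodel(X)).abs().max().item()
print(qmodel[0])  # the dynamically quantized replacement for nn.Linear
print(diff)       # small quantization error
```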

## Static Quantization

In [27]:
calibration_loader = train_loader

In [28]:

from torch.ao.quantization import get_default_qconfig, QuantStub, DeQuantStub

torch.manual_seed(42)
model = nn.Sequential(QuantStub(),
                      nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1),
                      DeQuantStub())
# [...] pretend the 32-bit model is trained here
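model.eval()  # post-training quantization is done in eval mode
model.qconfig = get_default_qconfig(engine)
torch.ao.quantization.prepare(model, inplace=True)  # insert observers
with torch.no_grad():  # calibration: the observers record activation ranges
    for X_batch, _ in calibration_loader:
        model(X_batch)
torch.ao.quantization.convert(model, inplace=True)  # swap in quantized modules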

Sequential(
  (0): Quantize(scale=tensor([0.0604]), zero_point=tensor([57]), dtype=torch.quint8)
  (1): QuantizedLinear(in_features=10, out_features=100, scale=0.040152497589588165, zero_point=65, qscheme=torch.per_channel_affine)
  (2): ReLU()
  (3): QuantizedLinear(in_features=100, out_features=1, scale=0.01066256407648325, zero_point=68, qscheme=torch.per_channel_affine)
  (4): DeQuantize()
)

In [29]:
get_default_qconfig(engine)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
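During calibration, each observer just tracks statistics of the values flowing through it, then derives the quantization parameters from them. For example, a `MinMaxObserver` (simpler than the `HistogramObserver` used for activations above) records the min and max, extends the range to include 0, and computes the scale and zero point using the same asymmetric formulas as earlier. A sketch with two made-up calibration batches:

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

observer = MinMaxObserver(dtype=torch.quint8)  # per-tensor affine, 0..255
for batch in [torch.tensor([0.5, 2.0]), torch.tensor([-1.0, 3.0])]:
    observer(batch)                            # updates the running min and max

scale, zero_point = observer.calculate_qparams()
print(observer.min_val, observer.max_val, scale, zero_point)
```

Here the range is [-1, 3], so the scale is 4 / 255 and the zero point is round(1 / (4 / 255)) = 64.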

In [30]:
for name, module in model.named_modules():
    if hasattr(module, 'weight'):
        print(name, module.weight().shape, module.weight().dtype)

1 torch.Size([100, 10]) torch.qint8
3 torch.Size([1, 100]) torch.qint8


In [31]:
torch.manual_seed(42)
X_batch = torch.randn(3, 10)
y_pred = model(X_batch)
y_pred

tensor([[ 0.2133],
        [ 0.0000],
        [-0.0320]])

Some modules can be fused: for example, a `Linear` layer followed by a `ReLU` can be fused into a single `LinearReLU` module that performs both operations in one step. For this, we can use the `fuse_modules()` function, passing it a list of lists of the names of the modules to fuse. Since modules are referenced by name, each module that we want to fuse must have a name. This is usually done by creating a custom module with one attribute per submodule, but another option is to use an `OrderedDict` like this:

In [32]:
from collections import OrderedDict

torch.manual_seed(42)
model = nn.Sequential(OrderedDict([
    ("quantize", QuantStub()),
    ("linear1", nn.Linear(10, 100)),
    ("relu1", nn.ReLU()),
    ("linear2", nn.Linear(100, 1)),
    ("dequantize", DeQuantStub())]))
# [...] pretend the 32-bit model is trained here
torch.ao.quantization.fuse_modules(model, [['linear1', 'relu1']], inplace=True)

Sequential(
  (quantize): QuantStub()
  (linear1): LinearReLU(
    (0): Linear(in_features=10, out_features=100, bias=True)
    (1): ReLU()
  )
  (relu1): Identity()
  (linear2): Linear(in_features=100, out_features=1, bias=True)
  (dequantize): DeQuantStub()
)
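Fusion doesn't change what the float model computes: `LinearReLU` simply applies the `Linear` layer and then the `ReLU`, and the replaced `relu1` becomes an `Identity`. A quick sketch to check this, passing `inplace=False` so that the original model is kept for comparison:

```python
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

torch.manual_seed(42)
model = nn.Sequential(OrderedDict([
    ("linear1", nn.Linear(10, 100)),
    ("relu1", nn.ReLU()),
    ("linear2", nn.Linear(100, 1))]))
model.eval()
fused = fuse_modules(model, [["linear1", "relu1"]], inplace=False)

X = torch.randn(5, 10)
with torch.no_grad():
    print(torch.allclose(model(X), fused(X)))  # same float outputs
print(type(fused.linear1).__name__, type(fused.relu1).__name__)
```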

When we quantize the model, we get a `QuantizedLinearReLU` module:

In [33]:
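model.eval()  # post-training quantization is done in eval mode
model.qconfig = get_default_qconfig(engine)
torch.ao.quantization.prepare(model, inplace=True)  # insert observers
with torch.no_grad():  # calibration: the observers record activation ranges
    for X_batch, _ in calibration_loader:
        model(X_batch)
torch.ao.quantization.convert(model, inplace=True)  # swap in quantized modules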

Sequential(
  (quantize): Quantize(scale=tensor([0.0604]), zero_point=tensor([57]), dtype=torch.quint8)
  (linear1): QuantizedLinearReLU(in_features=10, out_features=100, scale=0.019709425047039986, zero_point=0, qscheme=torch.per_channel_affine)
  (relu1): Identity()
  (linear2): QuantizedLinear(in_features=100, out_features=1, scale=0.01066256407648325, zero_point=68, qscheme=torch.per_channel_affine)
  (dequantize): DeQuantize()
)

Next, you can use ExecuTorch, TFLite, or any other runtime to deploy your model to the target device. For example, let's export this fused and quantized model to ONNX:

In [34]:
if IS_COLAB:
    %pip install -qU onnx

In [35]:
torch.onnx.export(
    model,
    X_batch,
    "my_quantized_model.onnx",
    opset_version=13,
    input_names=['float_input'],
    output_names=['float_output']
)

## Quantization-Aware Training (QAT)

In [36]:
from torch.ao.quantization import get_default_qat_qconfig

torch.manual_seed(42)
model = nn.Sequential(QuantStub(),
                      nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 1),
                      DeQuantStub())
model.qconfig = get_default_qat_qconfig(engine)
torch.ao.quantization.prepare_qat(model, inplace=True)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
X_train = torch.randn(128, 10)
y_train = torch.randn(128, 1)
for epoch in range(5):
    optimizer.zero_grad()
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

model.eval()
torch.ao.quantization.convert(model, inplace=True)

Epoch 0, Loss: 1.1153
Epoch 1, Loss: 1.1081
Epoch 2, Loss: 1.1034
Epoch 3, Loss: 1.0954
Epoch 4, Loss: 1.0909


Sequential(
  (0): Quantize(scale=tensor([0.0473]), zero_point=tensor([60]), dtype=torch.quint8)
  (1): QuantizedLinear(in_features=10, out_features=100, scale=0.03720283508300781, zero_point=62, qscheme=torch.per_channel_affine)
  (2): ReLU()
  (3): QuantizedLinear(in_features=100, out_features=1, scale=0.008233007043600082, zero_point=68, qscheme=torch.per_channel_affine)
  (4): DeQuantize()
)

In [37]:
torch.manual_seed(42)
X_batch = torch.randn(3, 10)
y_pred = model(X_batch)  # float inputs & outputs, but quantized internally
y_pred

tensor([[ 0.1811],
        [ 0.0412],
        [-0.0082]])
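Under the hood, QAT inserts *fake quantization* operations. In the forward pass, they round values to the quantization grid and immediately convert them back to floats, so the model learns to cope with rounding errors. In the backward pass, the rounding is treated as the identity (the straight-through estimator), so gradients can still flow, except for values outside the representable range. A sketch using `torch.fake_quantize_per_tensor_affine()` with illustrative scale and zero point values:

```python
import torch

scale, zero_point = 2 / 255, 128  # illustrative values: grid covers about [-1, 1]
x = torch.tensor([-0.9, -0.3, 0.0, 0.42, 0.9, 5.0], requires_grad=True)
fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)
print(fq)       # values snapped to the grid (5.0 is clipped to the maximum)

fq.sum().backward()
print(x.grad)   # 1 inside the representable range, 0 outside
```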

# Hugging Face BitsAndBytes (bnb)

In [38]:
if IS_COLAB or IS_KAGGLE:
    %pip install -qU bitsandbytes

In [39]:
if device == "cuda":
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                    bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",
                                                 quantization_config=bnb_config)


In [40]:
if device == "cuda":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer("Wow!", return_tensors="pt").to(model.device)
    # pass the attention mask too, since it can't be inferred when pad == EOS
    output_ids = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))


Wow! I love this! I'm going to try
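The `load_in_4bit=True` option stores the weights in 4 bits, quantized block by block. Each small block of weights (for example, 64 weights) gets its own scale, the block's absolute maximum, so a single outlier only degrades the precision of its own block. NF4 then maps each normalized weight to one of 16 levels designed for normally distributed weights. The CPU sketch below illustrates only the blockwise absmax idea, using uniform symmetric 4-bit codes in [-7, 7] instead of NF4's levels. This is a simplification, not what bitsandbytes actually runs, and the helper functions are hypothetical:

```python
import torch

def blockwise_absmax_quantize(w, block_size=64):
    # hypothetical helper: symmetric 4-bit codes in [-7, 7], one scale per block
    blocks = w.reshape(-1, block_size)
    absmax = blocks.abs().amax(dim=1, keepdim=True)  # one float per block
    codes = torch.round(blocks / absmax * 7).to(torch.int8)
    return codes, absmax

def blockwise_absmax_dequantize(codes, absmax, shape):
    # hypothetical helper: inverse mapping back to floats
    return (codes.float() / 7 * absmax).reshape(shape)

torch.manual_seed(42)
w = torch.randn(256, 64)
w[0, 0] = 20.0                                       # a single outlier
codes, absmax = blockwise_absmax_quantize(w)
w_hat = blockwise_absmax_dequantize(codes, absmax, w.shape)
err = (w_hat - w).abs()
print(err[0].max().item(), err[1:].max().item())     # outlier block vs. others
```

Real 4-bit storage would also pack two codes per byte, as shown earlier in this notebook.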


In [41]:
if device == "cuda":
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                    bnb_4bit_compute_dtype=torch.bfloat16,
                                    bnb_4bit_use_double_quant=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",
                                                 quantization_config=bnb_config)
    model = prepare_model_for_kbit_training(model)
    lora_config = LoraConfig(r=16, lora_alpha=32,
                             target_modules=["q_proj", "v_proj"],
                             lora_dropout=0.05, bias="none",
                             task_type="CAUSAL_LM")
    peft_model = get_peft_model(model, lora_config)
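The `bnb_4bit_use_double_quant=True` option also quantizes the per-block scales. Some back-of-the-envelope arithmetic shows why this helps. It assumes blocks of 64 weights with one 32-bit scale each, and, with double quantization, 8-bit scales grouped in blocks of 256 that share one 32-bit scale (the setup described in the QLoRA paper). The model-size estimate is rough and ignores the layers that stay unquantized:

```python
block_size = 64     # weights per block (first-level quantization)
block_size2 = 256   # scales per block (second-level quantization)

overhead_single = 32 / block_size                                  # fp32 scales
overhead_double = 8 / block_size + 32 / (block_size * block_size2)
print(f"{overhead_single:.3f} vs. {overhead_double:.3f} extra bits per weight")

n_params = 1.1e9    # roughly TinyLlama's size
for name, bits in [("float16", 16), ("4-bit", 4 + overhead_single),
                   ("4-bit + double quant", 4 + overhead_double)]:
    print(f"{name:22} ≈ {n_params * bits / 8 / 1e9:.2f} GB")
```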

In [42]:
if IS_COLAB or IS_KAGGLE:
    %pip install -qU gguf


In [43]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
filename = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"

torch_dtype = torch.float32  # could be torch.float16 or torch.bfloat16 too
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=filename)
# Transformers de-quantizes the GGUF weights to torch_dtype when loading them
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=filename,
                                             dtype=torch_dtype)
