# Quantization Fundamentals – **Lesson 3**
# Loading Models in Different Data Types

> **Goal:** Explore how casting a model’s parameters to lower-precision dtypes (FP16, BF16) affects inference on CPU/GPU, and learn safe workflows to inspect & convert models.


## 0 · Setup

In [None]:
# Make sure we’re using a torch version that supports BF16 on CPU
#!pip install --quiet torch==2.1.1
#!pip install --quiet transformers pillow requests

In [None]:
import torch
import torch.nn as nn
from copy import deepcopy

import requests
from transformers import BlipForConditionalGeneration, BlipProcessor
from PIL import Image
from itertools import islice

## 1 · Dummy Model Definition  
*Replace / modify this section with your own architecture as needed.*


In [None]:

# Paste or edit the DummyModel definition here
class DummyModel(nn.Module):
  """
  A toy model: an embedding layer, two blocks of
  (Linear → LayerNorm), and a linear output head.
  """
  def __init__(self):
    super().__init__()

    torch.manual_seed(123)

    self.token_embedding = nn.Embedding(2, 2)

    # Block 1
    self.linear_1 = nn.Linear(2, 2)
    self.layernorm_1 = nn.LayerNorm(2)

    # Block 2
    self.linear_2 = nn.Linear(2, 2)
    self.layernorm_2 = nn.LayerNorm(2)

    self.head = nn.Linear(2, 2)

  def forward(self, x):
    hidden_states = self.token_embedding(x)

    # Block 1
    hidden_states = self.linear_1(hidden_states)
    hidden_states = self.layernorm_1(hidden_states)

    # Block 2
    hidden_states = self.linear_2(hidden_states)
    hidden_states = self.layernorm_2(hidden_states)

    logits = self.head(hidden_states)
    return logits


In [None]:
model = DummyModel()  # baseline FP32 model
display(model)

## 2 · Utility: Inspect Parameter dtypes 
**Observation:** By default, PyTorch creates all learnable parameters in **`torch.float32`** (32‑bit single precision), which offers ~7 decimal digits of precision.
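To confirm the default and see where the "~7 decimal digits" figure comes from, you can query `torch.get_default_dtype()` and `torch.finfo`. A minimal sketch, assuming nothing earlier in the session has changed the default dtype:

```python
import math

import torch
import torch.nn as nn

# The global default floating-point dtype (float32 unless someone changed it)
default_dtype = torch.get_default_dtype()

# Freshly created layers pick up that default for their weights and biases
layer = nn.Linear(2, 2)

# eps is the gap between 1.0 and the next representable float32 (2**-23);
# -log10(eps) gives the rough number of significant decimal digits
fp32_eps = torch.finfo(torch.float32).eps
decimal_digits = -math.log10(fp32_eps)

print(default_dtype, layer.weight.dtype, layer.bias.dtype)
print(f"float32 eps = {fp32_eps:.3e}  →  ~{decimal_digits:.1f} decimal digits")
```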


In [None]:
def print_param_dtype(m):
    """Print each parameter name and its torch.dtype."""
    for name, param in m.named_parameters():
        print(f"{name:<25}  →  {param.dtype}")

In [None]:
print("FP32 (default) model parameter dtypes:\n")
print_param_dtype(model)

## 3 · Casting to **FP16**

In [None]:
model_fp16 = DummyModel().half()  # .half() is shorthand for .to(torch.float16); the fixed seed gives the same weights as `model`
print("\nAfter .half():")
print_param_dtype(model_fp16)

In [None]:
### 3.1 · Inference Test
dummy_input = torch.LongTensor([[1, 0], [0, 1]])

# FP32 reference output
logits_fp32 = model(dummy_input)
print("FP32 logits:\n", logits_fp32)

# FP16 forward pass (CPU)
try:
    logits_fp16 = model_fp16(dummy_input)
except Exception as e:
    print("\033[91m", type(e).__name__, ":", e, "\033[0m")
else:
    print("FP16 logits:\n", logits_fp16)  # reached only if every op had a CPU Half kernel


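**Why it may fail:** Whether this raises depends on your PyTorch version. Older releases lack CPU kernels for some ops in **Half** (FP16), so the forward pass stops with a `RuntimeError` of the form `"<op>" not implemented for 'Half'`. Newer releases have added many CPU FP16 kernels, so the call may succeed, but FP16 remains primarily a GPU format (e.g. NVIDIA Tensor Cores).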
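Because support depends on the PyTorch build and hardware, one option is to probe a dtype before committing to it. The sketch below uses a hypothetical helper, `forward_succeeds`, on a toy `nn.Sequential` stand-in for `DummyModel`; treating any `RuntimeError` during the forward pass as "unsupported" is a simplification.

```python
import copy

import torch
import torch.nn as nn


def forward_succeeds(module: nn.Module, example_input: torch.Tensor, dtype: torch.dtype) -> bool:
    """Hypothetical helper: cast a *copy* of `module` to `dtype`, run one
    forward pass, and report whether it completed without a RuntimeError."""
    candidate = copy.deepcopy(module).to(dtype)  # leave the caller's model untouched
    # Integer inputs (e.g. token IDs for nn.Embedding) keep their dtype;
    # floating-point inputs must match the weights' dtype.
    if example_input.is_floating_point():
        example_input = example_input.to(dtype)
    try:
        with torch.no_grad():
            candidate(example_input)
        return True
    except RuntimeError as err:  # e.g. an op with no kernel for this dtype
        print(f"{dtype}: {err}")
        return False


# Toy stand-in with the same layer types as DummyModel
toy = nn.Sequential(nn.Embedding(2, 2), nn.Linear(2, 2), nn.LayerNorm(2))
token_ids = torch.LongTensor([[1, 0], [0, 1]])

results = {dt: forward_succeeds(toy, token_ids, dt)
           for dt in (torch.float32, torch.bfloat16, torch.float16)}
print(results)
```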


## 4 · Casting to **BF16** instead
BF16 keeps the *same 8‑bit exponent* as FP32 → **same numeric range**, but only a 7‑bit mantissa → lower precision. Crucially, PyTorch provides BF16 CPU kernels for most common ops, and recent CPUs (AVX‑512‑BF16 / AMX) and GPUs (Ampere+) execute BF16 natively.
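To make "same exponent, shorter mantissa" concrete, here is a standard-library-only sketch. It rounds a Python float to BF16 by keeping the top 16 bits of its FP32 bit pattern (round-to-nearest-even) and uses `struct`'s half-precision `'e'` format to check whether a value fits in FP16. The helper names are illustrative, and NaN handling is deliberately omitted.

```python
import math
import struct


def fp32_bits(x: float) -> int:
    """IEEE-754 single-precision bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even).
    BF16 = the upper 16 bits of FP32: 1 sign, 8 exponent, 7 mantissa bits.
    Sketch only: NaN inputs are not special-cased."""
    bits = fp32_bits(x)
    lsb = (bits >> 16) & 1                 # last mantissa bit that survives
    rounded = (bits + 0x7FFF + lsb) >> 16  # round-to-nearest-even on the 16 dropped bits
    return struct.unpack("<f", struct.pack("<I", (rounded & 0xFFFF) << 16))[0]


def fits_in_fp16(x: float) -> bool:
    """True if x is representable (finite) in IEEE half precision."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False


for value in (math.pi, 65504.0, 1e5, 1e38):
    print(f"{value:>12.6g}  bf16 → {round_to_bf16(value):.6g}   fits in fp16: {fits_in_fp16(value)}")
```

Note the trade-off it exposes: 1e38 survives in BF16 (coarsely) but overflows FP16, while π loses digits in BF16 (3.140625).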


In [None]:
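### 4.1 · Deep‑copy then cast
# nn.Module.to() converts parameters in place and returns the same module,
# so copy first to keep the FP32 `model` intact for comparison.
model_bf16 = deepcopy(model).to(torch.bfloat16)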
print("BF16 parameter dtypes:\n")
print_param_dtype(model_bf16)


### 4.2 · BF16 inference
logits_bf16 = model_bf16(dummy_input)
print("BF16 logits:\n", logits_bf16)
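The `deepcopy` matters because `nn.Module.to()` modifies the module in place and returns it. A small sketch with standalone layers (the variable names are illustrative):

```python
import copy

import torch
import torch.nn as nn

layer = nn.Linear(2, 2)

# .to() converts the module's parameters in place and returns the same object
alias = layer.to(torch.bfloat16)
print(alias is layer, layer.weight.dtype)  # the "original" is now BF16 too

# Deep-copying first leaves the source module in FP32
source = nn.Linear(2, 2)
bf16_copy = copy.deepcopy(source).to(torch.bfloat16)
print(source.weight.dtype, bf16_copy.weight.dtype)
```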

## 5 · Numerical Difference vs FP32

In [None]:
mean_diff = torch.abs(logits_bf16 - logits_fp32).mean().item()
max_diff  = torch.abs(logits_bf16 - logits_fp32).max().item()
print(f"Mean diff : {mean_diff:.3e} | Max diff : {max_diff:.3e}")

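The difference is typically small: BF16 keeps only about 2 to 3 significant decimal digits, so for logits of magnitude ~1 expect errors roughly in the 1e‑3 to 1e‑2 range. That is negligible for many tasks, which makes BF16 a near *drop‑in replacement* on supported hardware.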
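Where does that error scale come from? A sketch, assuming values of magnitude around 1 as a stand-in for logits: casting FP32 values to BF16 and back loses at most half a BF16 ulp, i.e. a relative error of at most 2^-8 (about 0.4 %). This measures a single cast only; in a real forward pass, rounding errors can accumulate across layers.

```python
import torch

torch.manual_seed(0)
reference = torch.randn(1000)                     # stand-in for FP32 logits of magnitude ~1
roundtrip = reference.to(torch.bfloat16).float()  # what survives a BF16 cast

abs_err = (roundtrip - reference).abs()
rel_err = abs_err / reference.abs().clamp_min(1e-12)

# Round-to-nearest into a format with eps = 2**-7 bounds the relative error by eps / 2
bf16_unit_roundoff = torch.finfo(torch.bfloat16).eps / 2

print(f"max abs err {abs_err.max().item():.2e} | "
      f"max rel err {rel_err.max().item():.2e} | bound {bf16_unit_roundoff:.2e}")
```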

#### Pros & Cons of Model Down‑Casting

|  | **FP32** | **BF16** | **FP16** |
|---|---|---|---|
| Memory/Speed | Baseline (4 bytes/value) | ≈50 % smaller (2 bytes/value), faster on BF16‑capable HW | ≈50 % smaller (2 bytes/value), *much* faster on GPUs with Tensor Cores |
| Range (max finite) | 8‑bit exponent (≈3.4e38) | *Same* 8‑bit exponent (≈3.4e38) | 5‑bit exponent (≈6.5e4) |
| Precision | 23‑bit mantissa | 7‑bit mantissa | 10‑bit mantissa |
| CPU support | ✅ | ✅ (fastest with native BF16 instructions) | ⚠️ limited; many ops missing in older PyTorch |
| GPU support | ✅ | ✅ (native from Ampere) | ✅ (Tensor Core acceleration from Volta onward) |
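The range and precision rows can be checked directly with `torch.finfo`, using the fact that `eps` equals 2 raised to minus the number of stored mantissa bits. A short sketch:

```python
import math

import torch

specs = {}
for dt in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dt)
    specs[dt] = {
        "bytes": torch.empty(0, dtype=dt).element_size(),  # storage per element
        "mantissa_bits": round(-math.log2(info.eps)),      # eps = 2**-(stored mantissa bits)
        "max": info.max,                                   # largest finite value (range)
    }
    print(f"{str(dt):<15} {specs[dt]['bytes']} B  "
          f"mantissa={specs[dt]['mantissa_bits']:>2} bits  max={info.max:.3e}")
```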

**Guideline:**
- **Use BF16** for quick wins on modern CPUs / GPUs without rewriting code.
- **Use FP16** mainly on NVIDIA GPUs when kernels exist.
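One way to encode the guideline is a small dtype-picking function. `pick_inference_dtype` below is a hypothetical policy, not a PyTorch API; it relies on what `torch.cuda.is_bf16_supported()` reports and assumes BF16 is acceptable for your workload on CPU.

```python
import torch


def pick_inference_dtype(device: str) -> torch.dtype:
    """Hypothetical policy mirroring the guideline; adjust for your hardware."""
    if device.startswith("cuda") and torch.cuda.is_available():
        # BF16 where the GPU reports support (Ampere and newer), else FP16
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if device == "cpu":
        # BF16 runs on CPU in recent PyTorch; it is fastest with native BF16 instructions
        return torch.bfloat16
    return torch.float32  # conservative default for anything else (e.g. CUDA requested but absent)


cpu_choice = pick_inference_dtype("cpu")
gpu_choice = pick_inference_dtype("cuda")
print(cpu_choice, gpu_choice)
```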


**Takeaways**
1. Casting a model to a lower‑precision dtype is *one line*, but **kernels for that dtype must exist** for every op on your device.
2. **FP16 on CPU may raise** (`not implemented for 'Half'`) for some layers, depending on the PyTorch version → prefer BF16 or keep FP32.
3. Numerical error between BF16 and FP32 is often small (see the diff above).
4. Always verify model accuracy after down‑casting, especially on tasks with tight error budgets.


## 6 · Using Generative Models in Different Data Types

**Focus:** Measure memory savings and qualitative accuracy when loading a large vision–language model from Hugging Face (BLIP image captioner) in **FP32 vs BF16**. Learn how `torch_dtype` (at load time) and `torch.set_default_dtype()` (for newly created modules) control precision.

In [None]:
def print_param_dtype_lines(model, max_lines=None):
    """
    Print the dtype of each parameter.
    If max_lines is set, show only the first `max_lines`.
    """
    iterator = model.named_parameters()
    if max_lines is not None:
        iterator = islice(iterator, max_lines)

    shown = 0
    for name, param in iterator:
        print(f"{name:<35} → {param.dtype}")
        shown += 1

    total = sum(1 for _ in model.parameters())
    if max_lines is not None and total > shown:
        print(f"... ({total - shown} more parameters not shown)")

In [None]:
## Choose & Load the Model (FP32)
MODEL_ID = "Salesforce/blip-image-captioning-base"

# Load full‑precision model & processor
auto_processor = BlipProcessor.from_pretrained(MODEL_ID)
model_fp32     = BlipForConditionalGeneration.from_pretrained(MODEL_ID)

### Inspect parameter dtypes & memory footprint
print("FP32 parameter dtypes:\n")
print_param_dtype_lines(model_fp32, max_lines=10)

fp32_bytes = model_fp32.get_memory_footprint()
print(f"\nMemory footprint (FP32): {fp32_bytes/1e6:.1f} MB")

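*In the Transformers versions this lesson targets, `from_pretrained` loads weights in **float32** by default unless you pass `torch_dtype`, whatever dtype the checkpoint was stored in.*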

In [None]:
## Load the Same Model Directly in **BF16**
model_bf16 = BlipForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)

print("BF16 parameter dtypes:\n")
print_param_dtype_lines(model_bf16, max_lines=10)

bf16_bytes = model_bf16.get_memory_footprint()
print(f"\nMemory footprint (BF16): {bf16_bytes/1e6:.1f} MB")
print(f"Relative size vs FP32 : {bf16_bytes/fp32_bytes:.2f}×")

BF16 halves the memory requirement (2 bytes per value instead of 4) while keeping the *same numeric range* as FP32.
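The halving follows from simple arithmetic: parameter memory is the number of elements times bytes per element (4 for FP32, 2 for BF16). A sketch using a toy `nn.Linear` rather than BLIP; `param_bytes` is an illustrative helper, and Transformers' `get_memory_footprint()` also counts buffers, so its ratio can differ slightly from exactly 0.5.

```python
import torch
import torch.nn as nn


def param_bytes(model: nn.Module) -> int:
    """Bytes held by the model's parameters: numel × bytes-per-element, summed."""
    return sum(p.numel() * p.element_size() for p in model.parameters())


toy = nn.Linear(1024, 1024)                      # 1024*1024 weights + 1024 biases
fp32_size = param_bytes(toy)
bf16_size = param_bytes(toy.to(torch.bfloat16))  # .to() casts in place
print(fp32_size, bf16_size, bf16_size / fp32_size)
```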


## 7 · Caption an Image – Qualitative Comparison

In [None]:
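# helper to load an image from a URL or a local file (based on the helper provided in course assets)
def load_image(img_source):
    if str(img_source).startswith(("http://", "https://")):
        response = requests.get(img_source, stream=True, timeout=30)
        response.raise_for_status()  # fail loudly on HTTP errors instead of a confusing PIL error
        return Image.open(response.raw).convert("RGB")
    return Image.open(img_source).convert("RGB")

img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
image   = load_image(img_url)
image.resize((500, 350))  # displays a resized copy; `image` itself keeps its original size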

In [None]:
### Generate with FP32
inputs = auto_processor(images=image, text="a picture of", return_tensors="pt")
output_ids = model_fp32.generate(**inputs, max_new_tokens=20)
caption_fp32 = auto_processor.decode(output_ids[0], skip_special_tokens=True)
print("FP32 caption →", caption_fp32)

### Generate with BF16
# Same inputs, but floating-point tensors (pixel_values) are cast to BF16 to match
# the weights; integer tensors (input_ids, attention_mask) keep their dtype.
inputs_bf16 = {k: v.to(torch.bfloat16) if v.is_floating_point() else v
               for k, v in inputs.items()}
output_ids_bf16 = model_bf16.generate(**inputs_bf16, max_new_tokens=20)
caption_bf16 = auto_processor.decode(output_ids_bf16[0], skip_special_tokens=True)
print("BF16 caption →", caption_bf16)

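**Observation:** In this example the two captions are identical or nearly so. BF16's FP32‑sized exponent rules out the overflow/underflow problems FP16 can hit, and the reduced mantissa precision rarely changes which token scores highest; over longer generations, though, a single early token flip can send the two outputs down different paths.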
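Whichever way you run the BF16 model, its floating-point inputs (here `pixel_values`) must end up in BF16: outside of `torch.autocast`, PyTorch does not reconcile a dtype mismatch between activations and weights for you. A minimal sketch with a toy layer:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2).to(torch.bfloat16)  # weights in BF16
x_fp32 = torch.randn(1, 4)                   # activations still in FP32

try:
    layer(x_fp32)
    mismatch_raised = False
except RuntimeError as err:                  # dtype mismatch between input and weight
    mismatch_raised = True
    print("FP32 input:", err)

y = layer(x_fp32.to(torch.bfloat16))         # cast floating inputs to match the weights
print("BF16 input OK:", y.dtype)
```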


## 8 · Controlling the *Global* Default dtype
Sometimes you’d like *every module you instantiate* to be created in lower precision without passing `dtype=` to each layer constructor.



In [None]:
torch.set_default_dtype(torch.bfloat16)

# any new layers / models created *after* this line inherit BF16
bf16_dummy = DummyModel()
print_param_dtype_lines(bf16_dummy, max_lines=10)

In [None]:
# Always restore default to FP32 afterwards to avoid surprises
torch.set_default_dtype(torch.float32)

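**Best practice:** Keep the set‑and‑reset tightly scoped, e.g. wrap it in `try`/`finally` or a small context manager of your own (`torch.set_default_dtype` is not itself a context manager); don’t leave the global default at BF16 for unrelated code.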
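A minimal sketch of such a scoped set‑and‑reset, written as a hypothetical `default_dtype` context manager with `contextlib`. It assumes your PyTorch build accepts `torch.bfloat16` in `set_default_dtype`, as the cell above does.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def default_dtype(dtype: torch.dtype):
    """Hypothetical helper: temporarily change torch's default floating dtype."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)  # restored even if the body raises


with default_dtype(torch.bfloat16):
    inner = nn.Linear(2, 2)  # created while the default is BF16

outer = nn.Linear(2, 2)      # created after the default was restored
print(inner.weight.dtype, outer.weight.dtype, torch.get_default_dtype())
```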

**Takeaways**
1. The **`torch_dtype=`** parameter lets you load Hugging Face models directly in lower precision, which is faster and lighter than loading in FP32 and casting afterwards.
2. **BF16 ≈ 50 % memory cut**, with little visible accuracy loss in the captioning example above, thanks in part to its full FP32 exponent range.
3. FP16 kernels are plentiful on GPUs; BF16 kernels increasingly so on both CPUs & GPUs (Ampere+ / Sapphire Rapids). PyTorch does **not** silently fall back to FP32: on hardware without native BF16 instructions the ops typically still run (often more slowly), and an op with no kernel for the requested dtype raises an error.
4. **`torch.set_default_dtype()`** changes the dtype only for *new tensors and layers you create afterward*. It doesn’t shrink the checkpoint you download; pretrained weights are still converted locally. Use it when *initializing* fresh models, but prefer the **`torch_dtype=`** flag (or dtype‑specific checkpoints) when loading Hugging Face models to save RAM right away.

**Practical recommendations**
1. Loading pretrained models: pass **torch_dtype=torch.bfloat16** (or float16) to from_pretrained.
2. Building new modules from scratch: temporarily call **torch.set_default_dtype(desired_dtype)** then reset so the rest of your code isn’t surprised.
3. Casting after load **(model.half())**: fine for small models or quick tests, but avoid for multi‑GB checkpoints on limited RAM/VRAM, since the full FP32 copy must fit in memory before the cast.


