# PyTorch Buffers

PyTorch buffers are tensor attributes of a `torch.nn.Module` that are not considered parameters, but are still part of the module's state. They are typically used to store tensors that should be saved and loaded with the model, but do not require gradients for optimization. Buffers can be registered using the `register_buffer` method of a `torch.nn.Module`.

Buffers are especially useful for GPU computations: because they are part of the module's state, calling `.to(device)` on the module moves them between devices together with the model's parameters, so no manual transfer is required.

## Example without Buffers

In [1]:
import torch
import torch.nn as nn

class CausalAttentionWithoutBuffers(nn.Module):
    """Single-head causal self-attention whose mask is a plain tensor attribute.

    The upper-triangular mask is assigned directly to ``self.mask`` rather than
    being registered with ``register_buffer``, so it is neither a parameter
    nor a registered buffer of the module.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Output size of the query, key and value projections.
        context_length: Side length of the square causal mask.
        dropout: Probability handed to ``nn.Dropout`` for the attention weights.
        qkv_bias: Whether the three projections get a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()

        self.d_out = d_out

        # Projections are created in query, key, value order so that seeded
        # weight initialisation stays reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the positions to hide.
        all_ones = torch.ones(context_length, context_length)
        self.mask = torch.triu(all_ones, diagonal=1)

    def forward(self, x):
        """Return the context vectors for ``x``.

        ``x`` is unpacked as ``(batch, num_tokens, d_in)``; the result keeps the
        first two dimensions and has ``d_out`` as its last dimension.
        """
        _, seq_len, _ = x.shape

        k = self.W_key(x)
        q = self.W_query(x)
        v = self.W_value(x)

        scores = torch.matmul(q, k.transpose(1, 2))

        # Crop the mask to the current sequence length, then overwrite the
        # masked positions in place with -inf.
        hidden = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(hidden, -torch.inf)

        # Scale by the square root of the key dimension before normalising.
        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        return weights @ v

In [2]:
torch.manual_seed(0)

# Dummy inputs: one row of three values per token of
# "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Stack the same sequence twice along a new leading axis:
# (b, num_tokens, d_in) = (2, 6, 3).
batch = torch.stack([inputs, inputs])

In [3]:
# Derive the model dimensions from the (b, num_tokens, d_in) batch.
context_length, d_in = batch.shape[1:]
d_out = 2

ca_without_buffer = CausalAttentionWithoutBuffers(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
)

# Forward pass only, so gradient tracking is switched off.
with torch.no_grad():
    context_vecs = ca_without_buffer(batch)

print("Context Vectors without Buffers:", context_vecs, sep="\n")

Context Vectors without Buffers:
tensor([[[-0.5063,  0.3518],
         [-0.6503,  0.3955],
         [-0.6976,  0.4064],
         [-0.6289,  0.3677],
         [-0.6131,  0.3179],
         [-0.5870,  0.3259]],

        [[-0.5063,  0.3518],
         [-0.6503,  0.3955],
         [-0.6976,  0.4064],
         [-0.6289,  0.3677],
         [-0.6131,  0.3179],
         [-0.5870,  0.3259]]])


This works fine on the CPU. However, consider what happens when we transfer the module between devices, for example, moving it from CPU to GPU:

In [None]:
has_cuda = torch.cuda.is_available()
has_mps = torch.backends.mps.is_available()
print("Machine has GPU:", has_cuda or has_mps)

# Preference order: Apple Silicon GPU (Metal), then NVIDIA GPU, then CPU fallback.
device = torch.device(
    "mps" if has_mps else ("cuda" if has_cuda else "cpu")
)
print(f"Using device: {device}")

# Move the input and the module; the returned objects are bound back to the same names.
batch = batch.to(device)
ca_without_buffer = ca_without_buffer.to(device)

In [None]:
# Re-run the forward pass now that `batch` and the module were moved with .to(device).
# NOTE(review): `mask` is a plain attribute of this module, so on a non-CPU
# device this call presumably raises a device-mismatch error — confirm.
with torch.no_grad():
    context_vecs = ca_without_buffer(batch)

print("Context Vectors without Buffers (on device):")
print(context_vecs)

In [None]:
# Compare where a parameter (the query weight) and the plain `mask` attribute
# live after the module's .to(device) call.
print("W_query.device:", ca_without_buffer.W_query.weight.device)
print("mask.device:", ca_without_buffer.mask.device)

In [None]:

# Display the Python type of the `mask` attribute (the cell's last expression is shown as output).
type(ca_without_buffer.mask)

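The `mask` tensor is not transferred to the GPU because it is a plain tensor attribute: it is neither a PyTorch parameter (like the weights) nor a registered buffer, so calling `.to(device)` on the module does not move it. We therefore have to move it to the GPU manually via `mask.to(device)`.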

In [None]:
# Move the mask by hand: the tensor returned by .to(device) is assigned back
# to the attribute, then its device is printed to confirm the move.
ca_without_buffer.mask = ca_without_buffer.mask.to(device)
print("mask.device:", ca_without_buffer.mask.device)

In [None]:
# Forward pass again, after the mask has been moved to `device` manually,
# so the module's weights, its mask and `batch` have all been sent to the same device.
with torch.no_grad():
    context_vecs = ca_without_buffer(batch)

print("Context Vectors without Buffers (on device):")
print(context_vecs)

## Example with Buffers

In [6]:
import torch
import torch.nn as nn

class CausalAttentionWithBuffers(nn.Module):
    """Single-head causal self-attention whose mask is a registered buffer.

    The upper-triangular mask is attached with ``register_buffer`` under the
    name ``mask`` instead of being assigned as a plain attribute.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Output size of the query, key and value projections.
        context_length: Side length of the square causal mask.
        dropout: Probability handed to ``nn.Dropout`` for the attention weights.
        qkv_bias: Whether the three projections get a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()

        self.d_out = d_out

        # Projections are created in query, key, value order so that seeded
        # weight initialisation stays reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the positions to hide.
        # Registering the tensor (rather than writing `self.mask = ...`)
        # makes it one of the module's buffers.
        upper_triangle = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer("mask", upper_triangle)

    def forward(self, x):
        """Return the context vectors for ``x``.

        ``x`` is unpacked as ``(batch, num_tokens, d_in)``; the result keeps the
        first two dimensions and has ``d_out`` as its last dimension.
        """
        _, seq_len, _ = x.shape

        k = self.W_key(x)
        q = self.W_query(x)
        v = self.W_value(x)

        scores = torch.matmul(q, k.transpose(1, 2))

        # Crop the mask to the current sequence length, then overwrite the
        # masked positions in place with -inf.
        hidden = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(hidden, -torch.inf)

        # Scale by the square root of the key dimension before normalising.
        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        return weights @ v

In [8]:
# Build the buffer-based module and move it in one step; Module.to returns the
# module itself, so the bound name refers to the moved module.
ca_with_buffer = CausalAttentionWithBuffers(
    d_in, d_out, context_length, 0.0
).to(device)

# Both the weight and the registered mask report their device.
print("W_query.device:", ca_with_buffer.W_query.weight.device)
print("mask.device:", ca_with_buffer.mask.device)

W_query.device: cpu
mask.device: cpu


In [9]:
# Forward pass through the buffer-based module. No manual move of the mask
# precedes this call; only the module's own .to(device) was invoked.
with torch.no_grad():
    context_vecs = ca_with_buffer(batch)

print(context_vecs)

tensor([[[-0.3550, -0.6560],
         [-0.1536, -0.7514],
         [-0.0853, -0.7803],
         [-0.0297, -0.7015],
         [-0.0417, -0.6247],
         [-0.0040, -0.6322]],

        [[-0.3550, -0.6560],
         [-0.1536, -0.7514],
         [-0.0853, -0.7803],
         [-0.0297, -0.7015],
         [-0.0417, -0.6247],
         [-0.0040, -0.6322]]])


## Buffers and `state_dict`

PyTorch buffers are included in a model's `state_dict`, a dictionary containing the model's parameters and its (persistent) buffers. This means that when you save a model using `torch.save(model.state_dict(), 'model.pth')`, both the parameters and the buffers are saved. When you load the model using `model.load_state_dict(torch.load('model.pth'))`, both the parameters and the buffers are restored to their saved state.

In [10]:
# Display the state_dict of the module whose mask is a plain attribute;
# the output shown below lists only the three projection weights.
ca_without_buffer.state_dict()

OrderedDict([('W_query.weight',
              tensor([[-0.0043,  0.3097, -0.4752],
                      [-0.4249, -0.2224,  0.1548]])),
             ('W_key.weight',
              tensor([[-0.0114,  0.4578, -0.0512],
                      [ 0.1528, -0.1745, -0.1135]])),
             ('W_value.weight',
              tensor([[-0.5516, -0.3824, -0.2380],
                      [ 0.0214,  0.2282,  0.3464]]))])

In [11]:
# Display the state_dict of the module whose mask is a registered buffer;
# the output shown below contains a `mask` entry alongside the weights.
ca_with_buffer.state_dict()

OrderedDict([('mask',
              tensor([[0., 1., 1., 1., 1., 1.],
                      [0., 0., 1., 1., 1., 1.],
                      [0., 0., 0., 1., 1., 1.],
                      [0., 0., 0., 0., 1., 1.],
                      [0., 0., 0., 0., 0., 1.],
                      [0., 0., 0., 0., 0., 0.]])),
             ('W_query.weight',
              tensor([[-0.3914, -0.2514,  0.2097],
                      [ 0.4794, -0.1188,  0.4320]])),
             ('W_key.weight',
              tensor([[-0.0931,  0.0611,  0.5228],
                      [-0.5356, -0.3635, -0.1462]])),
             ('W_value.weight',
              tensor([[-0.2251,  0.4988, -0.3742],
                      [-0.2658, -0.4034, -0.5407]]))])