Buffers in PyTorch are useful when dealing with GPUs. Unlike parameters, buffers do not require gradient computation, but they still need to be on the same device as the rest of the module, and registering a tensor as a buffer lets PyTorch move it there for us.


In [1]:
# setup input
import torch
import torch.nn as nn


torch.manual_seed(211)

# One 3-dimensional embedding per token of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Two copies of the same sentence form a batch of shape (2, 6, 3).
batch = torch.stack([inputs, inputs])

# Sizes the attention modules are built from.
context_length = batch.size(1)          # tokens per sequence
input_embedding_dim = inputs.size(-1)   # features per token
output_embedding_dim = 5

print(batch.shape)

torch.Size([2, 6, 3])


# Causal Self-Attention Class Without Buffers

In [2]:
class CausalAttentionWithoutBuffers(nn.Module):
  """Single-head causal self-attention with the mask kept as a plain tensor.

  The causal mask is assigned directly to ``self.mask`` rather than being
  registered as a buffer, so it is not a parameter and is not tracked by the
  module.

  Args:
    input_embedding_dim: Size of the last dimension of the input tensor.
    output_embedding_dim: Size of the query/key/value projections and of
      every returned context vector.
    context_length: Side length of the square causal mask, i.e. the longest
      sequence that can be masked.
    dropout: Drop probability applied to the attention weights.
    qkv_bias: Whether the three projection layers carry a bias term.
  """

  def __init__(self,
               input_embedding_dim,
               output_embedding_dim,
               context_length,
               dropout,
               qkv_bias=False):
    super().__init__()
    self.output_embedding_dim = output_embedding_dim

    def projection():
      return nn.Linear(input_embedding_dim,
                       output_embedding_dim,
                       bias=qkv_bias)

    # Creation order (query, key, value, dropout) is kept fixed: it drives
    # the random initialisation under a given seed and the printed repr.
    self.W_query = projection()
    self.W_key = projection()
    self.W_value = projection()
    self.dropout = nn.Dropout(dropout)

    # Ones strictly above the diagonal flag the "future" positions.
    all_ones = torch.ones(context_length, context_length)
    self.mask = torch.triu(all_ones, diagonal=1)

  def forward(self, inputs):
    """Compute causally masked attention context vectors.

    Args:
      inputs: Tensor of shape (batch, num_tokens, input_embedding_dim).

    Returns:
      Tensor of shape (batch, num_tokens, output_embedding_dim).
    """
    # Unpacking into three names also rejects inputs that are not 3-D.
    _, num_tokens, _ = inputs.shape

    queries = self.W_query(inputs)
    keys = self.W_key(inputs)
    values = self.W_value(inputs)

    scores = queries @ keys.transpose(1, 2)

    # Crop the mask to the actual sequence length, then hide future tokens
    # by setting their scores to -inf before the softmax.
    future_positions = self.mask.bool()[:num_tokens, :num_tokens]
    scores.masked_fill_(future_positions, -torch.inf)

    scale = keys.shape[-1] ** 0.5
    weights = torch.softmax(scores / scale, dim=-1)
    weights = self.dropout(weights)

    return weights @ values

In [3]:
# Build the attention module from the sizes computed in the setup cell.
ca_wo_buffers = CausalAttentionWithoutBuffers(
    input_embedding_dim,
    output_embedding_dim,
    context_length,
    dropout=0.1,
)

# no_grad only switches off gradient tracking; the module is never put into
# eval mode, so dropout is still applied during this forward pass.
with torch.no_grad():
  context_vectors = ca_wo_buffers(batch)

print(context_vectors)

tensor([[[ 0.6402,  0.5157, -0.2856,  0.0129,  0.2789],
         [ 0.4965,  0.6147, -0.5197, -0.0234,  0.3758],
         [ 0.4503,  0.6470, -0.5922, -0.0287,  0.4065],
         [ 0.3693,  0.5715, -0.5502, -0.0445,  0.3656],
         [ 0.3410,  0.5678, -0.5272,  0.0224,  0.3588],
         [ 0.3110,  0.5293, -0.5249, -0.0352,  0.3423]],

        [[ 0.6402,  0.5157, -0.2856,  0.0129,  0.2789],
         [ 0.4965,  0.6147, -0.5197, -0.0234,  0.3758],
         [ 0.4503,  0.6470, -0.5922, -0.0287,  0.4065],
         [ 0.3693,  0.5715, -0.5502, -0.0445,  0.3656],
         [ 0.2671,  0.4230, -0.3759,  0.0342,  0.2633],
         [ 0.2814,  0.4623, -0.4287,  0.0150,  0.2920]]])


Everything works fine.

Now let's transfer the `CausalAttentionWithoutBuffers` module to a GPU device

In [5]:
# Report whether a CUDA device is available before moving anything to it.
print("Machine has GPU:", torch.cuda.is_available())

# The tensor is rebound to the result of .to("cuda"); the module is moved
# with a bare call, and its repr is what the cell displays.
batch = batch.to("cuda")
ca_wo_buffers.to("cuda")

Machine has GPU: True


CausalAttentionWithoutBuffers(
  (W_query): Linear(in_features=3, out_features=5, bias=False)
  (W_key): Linear(in_features=3, out_features=5, bias=False)
  (W_value): Linear(in_features=3, out_features=5, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [6]:
# Same forward pass as before, now with the inputs and module on CUDA.
# `mask` is still on the CPU at this point, so this call raises the
# RuntimeError shown in the cell output.
with torch.no_grad():
  context_vectors = ca_wo_buffers(batch)

print(context_vectors)

RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0

It seems like we attempted an operation (the in-place `masked_fill_`) between a tensor on a GPU and a tensor on a CPU. But we moved the module to the GPU!?

Let's double check the device locations of some of the tensors:

In [7]:
# Compare where a registered parameter and the plain `mask` attribute live.
print("W_query.device:", ca_wo_buffers.W_query.weight.device)
print("mask.device:", ca_wo_buffers.mask.device)

W_query.device: cuda:0
mask.device: cpu


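 `mask` was not moved onto the GPU. That's because it is a plain tensor attribute: it is neither a PyTorch parameter (like the weights) nor a registered buffer, so moving the module does not touch it.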

 Let's manually move it to the GPU via `.to("cuda")`:

In [8]:
# Rebind the attribute to the tensor returned by .to("cuda"); calling .to()
# without assigning the result would leave `ca_wo_buffers.mask` on the CPU.
ca_wo_buffers.mask = ca_wo_buffers.mask.to("cuda")
print("mask.device:", ca_wo_buffers.mask.device)

mask.device: cuda:0


In [10]:
# With the mask moved by hand, inputs, weights and mask are all on CUDA,
# so the forward pass runs.
with torch.no_grad():
    context_vecs = ca_wo_buffers(batch)

print(context_vecs)

tensor([[[ 0.6402,  0.5157, -0.2856,  0.0129,  0.2789],
         [ 0.3104,  0.2500, -0.1385,  0.0063,  0.1352],
         [ 0.4503,  0.6470, -0.5922, -0.0287,  0.4065],
         [ 0.2135,  0.4460, -0.4807, -0.0476,  0.2977],
         [ 0.3139,  0.4931, -0.4356,  0.0416,  0.3064],
         [ 0.3110,  0.5293, -0.5249, -0.0352,  0.3423]],

        [[ 0.6402,  0.5157, -0.2856,  0.0129,  0.2789],
         [ 0.4965,  0.6147, -0.5197, -0.0234,  0.3758],
         [ 0.4503,  0.6470, -0.5922, -0.0287,  0.4065],
         [ 0.2795,  0.3944, -0.3671, -0.0348,  0.2492],
         [ 0.3410,  0.5678, -0.5272,  0.0224,  0.3588],
         [ 0.3110,  0.5293, -0.5249, -0.0352,  0.3423]]], device='cuda:0')


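It worked!!! (Some rows differ from the earlier CPU run because the module is still in training mode, so dropout randomly zeroes a different set of attention weights on every forward pass.)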

But remembering to move individual tensors to the GPU can be tedious.

Let's use `register_buffer` to register the `mask` as a buffer:

# Causal Self-Attention Class With Buffers

In [13]:
class CausalAttentionWithBuffer(nn.Module):
  """Single-head causal self-attention with the mask registered as a buffer.

  Args:
    input_embedding_dim: Size of the last dimension of the input tensor.
    output_embedding_dim: Size of the query/key/value projections and of
      every returned context vector.
    context_length: Side length of the square causal mask, i.e. the longest
      sequence that can be masked.
    dropout: Drop probability applied to the attention weights.
    qkv_bias: Whether the three projection layers carry a bias term.
  """

  def __init__(self,
               input_embedding_dim,
               output_embedding_dim,
               context_length,
               dropout,
               qkv_bias=False):
    super().__init__()
    self.output_embedding_dim = output_embedding_dim

    def projection():
      return nn.Linear(input_embedding_dim,
                       output_embedding_dim,
                       bias=qkv_bias)

    # Creation order (query, key, value, dropout) is kept fixed: it drives
    # the random initialisation under a given seed and the printed repr.
    self.W_query = projection()
    self.W_key = projection()
    self.W_value = projection()
    self.dropout = nn.Dropout(dropout)

    # Registering the mask (instead of a plain `self.mask = ...` assignment)
    # makes it a buffer of the module. Ones strictly above the diagonal flag
    # the "future" positions.
    causal_mask = torch.triu(
        torch.ones(context_length, context_length), diagonal=1)
    self.register_buffer("mask", causal_mask)

  def forward(self, inputs):
    """Compute causally masked attention context vectors.

    Args:
      inputs: Tensor of shape (batch, num_tokens, input_embedding_dim).

    Returns:
      Tensor of shape (batch, num_tokens, output_embedding_dim).
    """
    # Unpacking into three names also rejects inputs that are not 3-D.
    _, num_tokens, _ = inputs.shape

    queries = self.W_query(inputs)
    keys = self.W_key(inputs)
    values = self.W_value(inputs)

    scores = queries @ keys.transpose(1, 2)

    # Crop the mask to the actual sequence length, then hide future tokens
    # by setting their scores to -inf before the softmax.
    future_positions = self.mask.bool()[:num_tokens, :num_tokens]
    scores.masked_fill_(future_positions, -torch.inf)

    scale = keys.shape[-1] ** 0.5
    weights = torch.softmax(scores / scale, dim=-1)
    weights = self.dropout(weights)

    return weights @ values

Now, conveniently, if we move the module to the GPU, the mask will be located on the GPU as well:

In [14]:
# Build the buffer-based module from the same sizes as before.
ca_with_buffer = CausalAttentionWithBuffer(
    input_embedding_dim=input_embedding_dim,
    output_embedding_dim=output_embedding_dim,
    context_length=context_length,
    dropout=0.1,
)

# A single call on the module; no separate move of `mask` is made.
ca_with_buffer.to("cuda")

# Check where a weight and the registered mask ended up.
print("W_query.device:", ca_with_buffer.W_query.weight.device)
print("mask.device:", ca_with_buffer.mask.device)

W_query.device: cuda:0
mask.device: cuda:0


In [15]:
# Inputs, weights and the registered mask are all on CUDA already, so the
# forward pass needs no manual device handling.
with torch.no_grad():
    context_vecs = ca_with_buffer(batch)

print(context_vecs)

tensor([[[ 0.6924, -0.1047,  0.2637, -0.5263,  0.0409],
         [ 0.7756, -0.0087,  0.5004, -0.2959, -0.0401],
         [ 0.8001,  0.0166,  0.5725, -0.2217, -0.0680],
         [ 0.7084,  0.0397,  0.5361, -0.1610, -0.0656],
         [ 0.6688, -0.0150,  0.4947, -0.1358, -0.0907],
         [ 0.4174,  0.0600,  0.3324, -0.0908, -0.0254]],

        [[ 0.6924, -0.1047,  0.2637, -0.5263,  0.0409],
         [ 0.7756, -0.0087,  0.5004, -0.2959, -0.0401],
         [ 0.5785,  0.0501,  0.4881, -0.0532, -0.0810],
         [ 0.4953,  0.0194,  0.3557, -0.1413, -0.0365],
         [ 0.3220, -0.0451,  0.2021, -0.1039, -0.0421],
         [ 0.6422,  0.0302,  0.5043, -0.1089, -0.0783]]], device='cuda:0')


# Buffers and `state_dict`
Another advantage of PyTorch buffers, over regular tensors, is that they get included in a model's `state_dict`:

In [17]:
# without buffer: `mask` is a plain attribute, so only the three projection
# weights are listed
ca_wo_buffers.state_dict()

OrderedDict([('W_query.weight',
              tensor([[-0.5424,  0.0719, -0.1568],
                      [-0.5603, -0.3087, -0.5218],
                      [-0.2370, -0.5260,  0.4572],
                      [-0.0705, -0.0621, -0.1545],
                      [ 0.3552, -0.3376, -0.2241]], device='cuda:0')),
             ('W_key.weight',
              tensor([[ 0.1450, -0.1837,  0.5103],
                      [ 0.5660, -0.1473, -0.2244],
                      [-0.2593, -0.1432,  0.1937],
                      [ 0.4595, -0.4400, -0.5255],
                      [ 0.4428, -0.0302, -0.0299]], device='cuda:0')),
             ('W_value.weight',
              tensor([[ 0.2712, -0.2171,  0.5530],
                      [ 0.5426,  0.2207,  0.2221],
                      [-0.3114, -0.5317, -0.0487],
                      [ 0.4521, -0.2174, -0.1687],
                      [ 0.2986,  0.2177,  0.1011]], device='cuda:0'))])

In [18]:
# with buffer: the registered `mask` is listed together with the weights
ca_with_buffer.state_dict()

OrderedDict([('mask',
              tensor([[0., 1., 1., 1., 1., 1.],
                      [0., 0., 1., 1., 1., 1.],
                      [0., 0., 0., 1., 1., 1.],
                      [0., 0., 0., 0., 1., 1.],
                      [0., 0., 0., 0., 0., 1.],
                      [0., 0., 0., 0., 0., 0.]], device='cuda:0')),
             ('W_query.weight',
              tensor([[-0.1119,  0.5082, -0.3537],
                      [ 0.2697, -0.0218, -0.5054],
                      [-0.2409, -0.1694, -0.1194],
                      [ 0.1612, -0.0203, -0.1356],
                      [-0.4135,  0.2456, -0.5693]], device='cuda:0')),
             ('W_key.weight',
              tensor([[-0.2781, -0.2590, -0.3284],
                      [-0.0922, -0.3013,  0.3599],
                      [ 0.3255,  0.2196, -0.0739],
                      [ 0.1620,  0.5560,  0.4008],
                      [-0.5715, -0.4526,  0.2410]], device='cuda:0')),
             ('W_value.weight',
              tensor([[ 0.

Saving and loading the `mask` may not seem useful, since it is never trained. But if the mask is ever modified, the modified values are preserved when we save and load the `state_dict`:

In [19]:
# Overwrite every 1 in the registered mask with 2, in place, so the change
# can be recognised after a save/load round trip.
ca_with_buffer.mask[ca_with_buffer.mask == 1.] = 2.
ca_with_buffer.mask

tensor([[0., 2., 2., 2., 2., 2.],
        [0., 0., 2., 2., 2., 2.],
        [0., 0., 0., 2., 2., 2.],
        [0., 0., 0., 0., 2., 2.],
        [0., 0., 0., 0., 0., 2.],
        [0., 0., 0., 0., 0., 0.]], device='cuda:0')

In [21]:
# Persist the module state; the modified `mask` buffer is part of it.
torch.save(ca_with_buffer.state_dict(), "model.pth")

# A fresh module starts with the default 0/1 mask and lives on the CPU.
new_ca_with_buffer = CausalAttentionWithBuffer(input_embedding_dim,
                                               output_embedding_dim,
                                               context_length,
                                               0.1)

# weights_only=True restricts unpickling to tensors and plain containers
# instead of arbitrary Python objects. map_location="cpu" lets the checkpoint,
# which was saved from CUDA tensors, load on a machine without a GPU;
# load_state_dict copies the values into the existing CPU tensors either way.
new_ca_with_buffer.load_state_dict(
    torch.load("model.pth", map_location="cpu", weights_only=True))

new_ca_with_buffer.mask

tensor([[0., 2., 2., 2., 2., 2.],
        [0., 0., 2., 2., 2., 2.],
        [0., 0., 0., 2., 2., 2.],
        [0., 0., 0., 0., 2., 2.],
        [0., 0., 0., 0., 0., 2.],
        [0., 0., 0., 0., 0., 0.]])

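Without buffers, this is not true: the modified mask is not part of the `state_dict`, so the reloaded module keeps the freshly constructed mask: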

In [22]:
# Apply the same in-place edit to the plain (unregistered) mask attribute.
ca_wo_buffers.mask[ca_wo_buffers.mask == 1.] = 2.

# Only the projection weights are in this state_dict; the edited mask is not.
torch.save(ca_wo_buffers.state_dict(), "model.pth")

new_ca_wo_buffer = CausalAttentionWithoutBuffers(input_embedding_dim,
                                                 output_embedding_dim,
                                                 context_length,
                                                 0.1)

# weights_only=True restricts unpickling to tensors and plain containers
# instead of arbitrary Python objects. map_location="cpu" lets the checkpoint,
# which was saved from CUDA tensors, load on a machine without a GPU;
# load_state_dict copies the values into the existing CPU tensors either way.
new_ca_wo_buffer.load_state_dict(
    torch.load("model.pth", map_location="cpu", weights_only=True))

new_ca_wo_buffer.mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])