In [1]:
import fastcore.all as fc
import torch
import torch.nn as nn

In [4]:
# 5x5 matrix with ones strictly above the main diagonal and zeros elsewhere.
a = torch.ones(5, 5).triu(diagonal=1)
a

tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

In [5]:
# Row-wise softmax of the raw 0/1 matrix: the zero entries still receive a
# non-zero weight, so 0/1 values alone do not hide any position.
a.softmax(dim=-1)

tensor([[0.0842, 0.2289, 0.2289, 0.2289, 0.2289],
        [0.0985, 0.0985, 0.2677, 0.2677, 0.2677],
        [0.1185, 0.1185, 0.1185, 0.3222, 0.3222],
        [0.1488, 0.1488, 0.1488, 0.1488, 0.4046],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])

In [6]:
class CausalAttentionWithoutBuffers(nn.Module):
    """Single-head causal self-attention whose mask is a plain tensor attribute.

    The causal mask is assigned directly to ``self.mask`` instead of being
    registered with ``register_buffer``. The module therefore does not track
    it: it is not relocated by ``module.to(...)`` and it does not show up in
    ``state_dict()``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Side length of the square causal mask.
        dropout: Probability handed to ``nn.Dropout`` for the attention weights.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # The creation order of the layers is kept on purpose: it determines
        # the RNG draws used for initialisation and the state_dict key order.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal flag the "future" positions.
        square_of_ones = torch.ones(context_length, context_length)
        self.mask = square_of_ones.triu(diagonal=1)

    def forward(self, x):
        """Apply causally masked scaled dot-product attention.

        Args:
            x: 3-D tensor of shape (batch, tokens, d_in). ``tokens`` is
                expected to be at most ``context_length``, the size of the
                stored mask.

        Returns:
            Tensor of shape (batch, tokens, d_out).
        """
        _, num_tokens, _ = x.shape
        # Each projection has shape (batch, tokens, d_out).
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)

        # Pairwise scores, shape (batch, tokens, tokens).
        scores = queries @ keys.transpose(1, 2)
        # Crop the mask to the actual sequence length and overwrite the future
        # positions in place with -inf so softmax assigns them zero weight.
        future = self.mask.bool()[:num_tokens, :num_tokens]
        scores.masked_fill_(future, -torch.inf)

        scale = self.d_out**0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))
        return weights @ values

In [9]:
# Fix PyTorch's RNG seed so randomly initialised layers are reproducible.
torch.manual_seed(123)

# One row of three values per word of "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])
inputs.shape

torch.Size([6, 3])

In [10]:
# Two copies of the same sequence, stacked along a new leading batch dimension
# (torch.stack inserts the new dimension at position 0 by default).
batch = torch.stack([inputs, inputs])
batch.shape

torch.Size([2, 6, 3])

In [12]:
# Attention module sized for the demo input: input width taken from `inputs`,
# 2 output features, a 6x6 causal mask and no dropout.
ca_without_buffer = CausalAttentionWithoutBuffers(
    d_in=inputs.shape[1],
    d_out=2,
    context_length=6,
    dropout=0.0,
)
# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    context_vecs = ca_without_buffer(batch)

context_vecs.shape

torch.Size([2, 6, 2])

In [13]:
# Check that a CUDA device is usable before moving anything to the GPU.
torch.cuda.is_available()

True

In [14]:
# Move the input batch and the module to the GPU. Module.to() relocates the
# module's parameters; the plain-tensor `mask` attribute is left where it is.
batch = batch.to("cuda")
ca_without_buffer.to("cuda")

CausalAttentionWithoutBuffers(
  (W_query): Linear(in_features=3, out_features=2, bias=False)
  (W_key): Linear(in_features=3, out_features=2, bias=False)
  (W_value): Linear(in_features=3, out_features=2, bias=False)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [15]:
# Expected to fail when the module has been moved to the GPU: its parameters
# now live on the GPU, but `mask` is a plain tensor attribute and stayed on the
# CPU. The error is the point of this cell, so it is caught and displayed
# instead of aborting a Restart & Run All.
try:
    with torch.no_grad():
        context_vecs = ca_without_buffer(batch)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")
else:
    # Only reached when the mask already sits on the same device as the input.
    print(context_vecs.shape)

RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0

In [16]:
# The projection weights are parameters, so moving the module moved them too.
ca_without_buffer.W_key.weight.device

device(type='cuda', index=0)

In [17]:
# The mask is an ordinary attribute, so it was not moved along with the module.
ca_without_buffer.mask.device

device(type='cpu')

In [21]:
# Manual workaround: move the mask by hand and re-assign it on the module.
ca_without_buffer.mask = ca_without_buffer.mask.to("cuda")
ca_without_buffer.mask.device

device(type='cuda', index=0)

In [22]:
# With the mask on the same device as the scores, the forward pass now runs.
with torch.no_grad():
    context_vecs = ca_without_buffer(batch)

context_vecs.shape

torch.Size([2, 6, 2])

In [23]:
class CausalAttentionWithBuffers(nn.Module):
    """Single-head causal self-attention whose mask is a registered buffer.

    The causal mask is stored through ``register_buffer`` rather than as a
    plain attribute, so the module tracks it: it travels with the module on
    ``module.to(...)`` and is included in ``state_dict()`` under ``"mask"``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections and of the output.
        context_length: Side length of the square causal mask.
        dropout: Probability handed to ``nn.Dropout`` for the attention weights.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # The creation order of the layers is kept on purpose: it determines
        # the RNG draws used for initialisation and the state_dict key order.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal flag the "future" positions. The
        # tensor is registered as a buffer (not assigned as a bare attribute)
        # so that the module manages its device placement and serialisation.
        square_of_ones = torch.ones(context_length, context_length)
        self.register_buffer("mask", square_of_ones.triu(diagonal=1))

    def forward(self, x):
        """Apply causally masked scaled dot-product attention.

        Args:
            x: 3-D tensor of shape (batch, tokens, d_in). ``tokens`` is
                expected to be at most ``context_length``, the size of the
                stored mask.

        Returns:
            Tensor of shape (batch, tokens, d_out).
        """
        _, num_tokens, _ = x.shape
        # Each projection has shape (batch, tokens, d_out).
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)

        # Pairwise scores, shape (batch, tokens, tokens).
        scores = queries @ keys.transpose(1, 2)
        # Crop the mask to the actual sequence length and overwrite the future
        # positions in place with -inf so softmax assigns them zero weight.
        future = self.mask.bool()[:num_tokens, :num_tokens]
        scores.masked_fill_(future, -torch.inf)

        scale = self.d_out**0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))
        return weights @ values

In [24]:
# Same configuration as the buffer-less module, then moved to the GPU.
ca_with_buffer = CausalAttentionWithBuffers(
    d_in=inputs.shape[1],
    d_out=2,
    context_length=6,
    dropout=0.0,
)
ca_with_buffer.to("cuda")

CausalAttentionWithBuffers(
  (W_query): Linear(in_features=3, out_features=2, bias=False)
  (W_key): Linear(in_features=3, out_features=2, bias=False)
  (W_value): Linear(in_features=3, out_features=2, bias=False)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [26]:
# The registered buffer followed the module to the GPU without a manual move.
print(ca_with_buffer.mask.device)

cuda:0


In [27]:
# The forward pass on the GPU works without relocating the mask by hand.
with torch.no_grad():
    context_vecs = ca_with_buffer(batch)

context_vecs.shape

torch.Size([2, 6, 2])

In [28]:
# Only the three projection weights are listed; the plain-attribute mask is absent.
ca_without_buffer.state_dict()

OrderedDict([('W_query.weight',
              tensor([[-0.2354,  0.0191, -0.2867],
                      [ 0.2177, -0.4919,  0.4232]], device='cuda:0')),
             ('W_key.weight',
              tensor([[-0.4196, -0.4590, -0.3648],
                      [ 0.2615, -0.2133,  0.2161]], device='cuda:0')),
             ('W_value.weight',
              tensor([[-0.4900, -0.3503, -0.2120],
                      [-0.1135, -0.4404,  0.3780]], device='cuda:0'))])

In [29]:
# The registered buffer "mask" is listed together with the projection weights.
ca_with_buffer.state_dict()

OrderedDict([('mask',
              tensor([[0., 1., 1., 1., 1., 1.],
                      [0., 0., 1., 1., 1., 1.],
                      [0., 0., 0., 1., 1., 1.],
                      [0., 0., 0., 0., 1., 1.],
                      [0., 0., 0., 0., 0., 1.],
                      [0., 0., 0., 0., 0., 0.]], device='cuda:0')),
             ('W_query.weight',
              tensor([[-0.1362,  0.1853,  0.4083],
                      [ 0.1076,  0.1579,  0.5573]], device='cuda:0')),
             ('W_key.weight',
              tensor([[-0.2604,  0.1829, -0.2569],
                      [ 0.4126,  0.4611, -0.5323]], device='cuda:0')),
             ('W_value.weight',
              tensor([[ 0.4929,  0.2757,  0.2516],
                      [ 0.2377,  0.4800, -0.0762]], device='cuda:0'))])