In [1]:
import torch
# Toy input sequence: one 3-value row per word of "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1  "Your"
    [0.55, 0.87, 0.66],  # x^2  "journey"
    [0.57, 0.85, 0.64],  # x^3  "starts"
    [0.22, 0.58, 0.33],  # x^4  "with"
    [0.77, 0.25, 0.10],  # x^5  "one"
    [0.05, 0.80, 0.55],  # x^6  "step"
])

In [2]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention built from raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` tensors of
    shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the queries/keys/values and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute context vectors for every token in ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)``; any number of leading
                batch dimensions, ``(..., num_tokens, d_in)``, is also accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Side effects:
            Prints the shape of each intermediate tensor.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        print(f"key shape : {keys.shape}")
        print(f"queries shape : {queries.shape}")
        print(f"values shape : {values.shape}")
        # Swap only the token and feature axes. `.T` would reverse *all* axes
        # (and is deprecated for tensors that are not 2-D), which breaks batched
        # input; for 2-D input the two forms are identical.
        attn_scores = queries @ keys.transpose(-2, -1) # omega
        print(f"attn_scores shape : {attn_scores.shape}")
        # Scale by sqrt(d_out) before the softmax over the key axis.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        print(f"attn_weights shape : {attn_weights.shape}")
        context_vec = attn_weights @ values
        print(f"context_vec shape : {context_vec.shape}")
        return context_vec

In [9]:
# Pick out one token and derive the layer sizes used by the attention modules.
x_2 = inputs[1]  # second row of `inputs` (the row labelled "journey (x^2)")
d_in = inputs.shape[1]  # input size = number of columns of `inputs`
d_out = 2  # output size passed to the attention modules constructed later
print(d_in, d_out)
print(inputs[1])
print(inputs)
print(x_2.shape)

3 2
tensor([0.5500, 0.8700, 0.6600])
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])
torch.Size([3])


In [12]:
# Seed before constructing the module: SelfAttention_v1 draws its weights with
# torch.rand in __init__, so the seed must be set first for repeatable output.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
# The forward pass also prints the shape of every intermediate tensor.
print(sa_v1(inputs))

key shape : torch.Size([6, 2])
queries shape : torch.Size([6, 2])
values shape : torch.Size([6, 2])
attn_scores shape : torch.Size([6, 6])
attn_weights shape : torch.Size([6, 6])
context_vec shape : torch.Size([6, 2])
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [26]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask and dropout.

    Args:
        d_in: size of the last dimension of the input.
        d_out: output size of the query/key/value projections (and of the result).
        context_length: side length of the causal-mask buffer; sequences passed
            to ``forward`` may be this long or shorter.
        dropout: probability handed to ``nn.Dropout``, applied to the attention weights.
        qkv_bias: whether the three projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        print(f"dim_in={d_in}, dim_out={d_out}, context_len={context_length}, dropout rate={dropout}, qkv_bias={qkv_bias}")
        self.query_w = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.key_w = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.value_w = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Register the mask as a buffer instead of a plain tensor attribute.
        # The mask is not a parameter, so as a plain attribute it would not be
        # moved along with the module by model.to("cuda"). As a buffer it moves
        # with the module and is also stored in the state_dict, so it can be
        # restored on load.
        # Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute causally-masked context vectors.

        Args:
            x: tensor of shape ``(batch, num_token, d_in)`` with
                ``num_token <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_token, d_out)``.

        Side effects:
            Prints the input dimensions and the key shape.
        """
        b, num_token, dim = x.shape
        print(f"batch_size = {b}, num_token： {num_token}, dim={dim}")
        query = self.query_w(x)
        key = self.key_w(x)
        value = self.value_w(x)
        print(f"key shape: {key.shape}")

        attention_scores = query @ key.transpose(1,2)

        # Crop the mask to the actual sequence length; using the full
        # context_length x context_length mask fails for shorter sequences.
        causal_mask = self.mask.bool()[:num_token, :num_token]
        # In-place fill: future positions get -inf so softmax gives them weight 0.
        attention_scores.masked_fill_(causal_mask, -torch.inf)

        attention_weight = torch.softmax(attention_scores/key.shape[-1]**0.5, dim=-1)
        attention_weight = self.dropout(attention_weight)

        context_vec = attention_weight @ value
        return context_vec


In [28]:
# Stack the same sequence twice along a new leading dimension to form a batch.
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) 
# Seed before construction so the nn.Linear weights are repeatable.
torch.manual_seed(123)
# Size the causal mask to the number of tokens per sequence.
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
print(context_vecs)

torch.Size([2, 6, 3])
dim_in=3, dim_out=2, context_len=6, dropout rate=0.0, qkv_bias=False
batch_size = 2, num_token： 6, dim=3
key shape: torch.Size([2, 6, 2])
context_vecs.shape: torch.Size([2, 6, 2])
tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


In [48]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with one shared projection per Q, K and V.

    Each projection maps ``dim_in`` to ``dim_out``; the result is split into
    ``heads_num`` heads of ``dim_out // heads_num`` features each. Head outputs
    are concatenated again and passed through a final linear layer.

    Args:
        dim_in: size of the last dimension of the input.
        dim_out: total projection size across all heads; must be divisible by ``heads_num``.
        heads_num: number of attention heads.
        context_length: side length of the causal-mask buffer.
        dropout: probability handed to ``nn.Dropout``, applied to the attention weights.
        qkv_bias: whether the Q/K/V projections carry a bias term.
    """

    def __init__(self, dim_in, dim_out, heads_num, context_length, dropout, qkv_bias=False):
        super().__init__()

        # Every head gets an equal share of the projected features.
        assert dim_out % heads_num == 0

        self.dim_out = dim_out
        self.heads_num = heads_num
        self.head_dim = dim_out // heads_num

        # Layers are created in a fixed order so seeded initialisation is stable.
        self.query_w = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.key_w = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.value_w = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the positions to be masked out.
        upper_triangle = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", upper_triangle)

        self.proj_out = nn.Linear(dim_out, dim_out)

    def _split_heads(self, projected):
        """Turn (batch, tokens, dim_out) into (batch, heads, tokens, head_dim)."""
        batch_size, seq_len, _ = projected.shape
        per_head = projected.view(batch_size, seq_len, self.heads_num, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Compute causally-masked multi-head context vectors.

        Args:
            x: tensor of shape ``(batch, tokens, dim_in)``.

        Returns:
            Tensor of shape ``(batch, tokens, dim_out)``.

        Side effects:
            Prints the shapes of several intermediate tensors.
        """
        batch_size, seq_len, _ = x.shape
        print(f"input shape: {x.shape}")

        # Break the trailing dim_out axis into heads_num x head_dim, then swap
        # the token and head axes so attention is computed per head.
        query = self._split_heads(self.query_w(x))

        # Same transformation for the keys, spelled out step by step so each
        # intermediate shape can be printed.
        key = self.key_w(x)
        print(f"key shape: {key.shape}")
        key = key.view(batch_size, seq_len, self.heads_num, self.head_dim)
        print(f"key shape after unsqueeze: {key.shape}")
        key = key.transpose(1, 2)
        print(f"key shape after transpose: {key.shape}")

        value = self._split_heads(self.value_w(x))

        attention_scores = query @ key.transpose(2, 3)
        print(f"attention_scores shape: {attention_scores.shape}")

        # Methods with a trailing underscore modify the tensor in place.
        future_positions = self.mask[:seq_len, :seq_len].bool()
        attention_scores.masked_fill_(future_positions, -torch.inf)

        scale = self.head_dim ** 0.5
        weights = self.dropout(torch.softmax(attention_scores / scale, dim=-1))

        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads, head_dim)
        # -> (batch, tokens, dim_out): heads are concatenated on the last axis.
        context_vecs = (weights @ value).transpose(1, 2)
        context_vecs = context_vecs.reshape(batch_size, seq_len, self.dim_out)
        print(f"context_vecs shape after squeeze: {context_vecs.shape}")

        return self.proj_out(context_vecs)


In [51]:
# Seed before construction so the nn.Linear weights are repeatable.
torch.manual_seed(123)
# Re-derive the sizes from the batch; this rebinds d_in and context_length.
batch_size, context_length, d_in = batch.shape
d_out = 2
print(f"d_in {d_in}, d_out {d_out}, batch_size {batch_size}, context_length {context_length}")
# With d_out=2 and heads_num=2 each head works on a single feature (head_dim = 1).
mha = MultiHeadAttention(d_in, d_out, heads_num=2, context_length=context_length, dropout=0.0 )
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

d_in 3, d_out 2, batch_size 2, context_length 6
input shape: torch.Size([2, 6, 3])
key shape: torch.Size([2, 6, 2])
key shape after unsqueeze: torch.Size([2, 6, 2, 1])
key shape after transpose: torch.Size([2, 2, 6, 1])
attention_scores shape: torch.Size([2, 2, 6, 6])
context_vecs shape: torch.Size([2, 2, 6, 1])
context_vecs shape after transpose: torch.Size([2, 6, 2, 1])
context_vecs shape after squeeze: torch.Size([2, 6, 2])
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


In [55]:
# Larger configuration: 12 heads over 768 features gives head_dim = 768 // 12 = 64.
context_length = 1024
d_in, d_out = 768, 768
num_heads = 12

# Dropout is 0.0; qkv_bias is left at its default (False).
mha = MultiHeadAttention(d_in, d_out,num_heads, context_length, 0.0 )
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Parameters whose ``requires_grad`` is False are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Expected: three bias-free 768x768 projections (qkv_bias defaults to False) plus
# proj_out with its bias: 3*768*768 + (768*768 + 768) = 2,360,064.
# The mask is a buffer, not a parameter, so it is not counted.
count_parameters(mha)

2360064

In [1]:
import torch

# Dimensions of the random input batch.
batch_size = 8
context_len = 1024
embed_dim = 768

# Seed before sampling so the random batch is repeatable.
torch.manual_seed(123)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"PyTorch version: {torch.__version__}")

embeddings = torch.randn(batch_size, context_len, embed_dim, device=device)

PyTorch version: 2.6.0+cu124


In [2]:
# Sanity-check the random batch: its shape first, then the values
# (PyTorch abbreviates the printout of a tensor this large with "...").
print(embeddings.shape)
print(embeddings)

torch.Size([8, 1024, 768])
tensor([[[ 1.3391e+00,  2.0517e-01, -1.6879e+00,  ..., -4.2419e-01,
          -5.8824e-02,  7.8626e-01],
         [ 4.0166e-01, -2.8328e-01, -7.3094e-01,  ...,  5.2304e-01,
           2.2982e+00,  6.3116e-01],
         [ 5.2773e-01,  6.7984e-02, -3.2776e-01,  ..., -2.8288e-01,
          -1.5578e+00, -8.6155e-01],
         ...,
         [ 2.9147e+00, -3.2614e-02, -6.2381e-01,  ...,  9.1058e-01,
          -1.2182e+00, -4.7430e-02],
         [ 4.2607e-01, -3.5098e-01, -1.3139e+00,  ...,  1.1188e+00,
           1.6521e+00,  1.0859e+00],
         [ 2.6405e-01,  8.3405e-01,  1.4404e+00,  ..., -8.5109e-01,
          -1.4092e+00, -1.7833e-01]],

        [[-1.4454e+00, -2.7590e+00,  3.8863e-01,  ...,  2.4145e-01,
           3.2685e-02, -7.5191e-02],
         [-8.1008e-01,  8.0733e-01,  1.0608e-01,  ..., -3.7774e-01,
          -9.7854e-01,  7.4685e-01],
         [-2.1887e-01,  2.5251e-01, -2.1778e-01,  ..., -1.4378e+00,
           1.7645e-01, -1.0513e+00],
         ...