In [1]:
import torch
import torch.nn as nn

In [40]:
# masked_scatter copies `source` elements, in order, into the slots where `mask` is True.
target = torch.tensor([1, 2, 3, 4, 5])
source = torch.tensor([10, 20])  # values consumed left-to-right by the True entries
mask = torch.tensor([False, True, False, True, False])

result = torch.Tensor.masked_scatter(target, mask, source)  # out-of-place: `target` is not modified
result

tensor([ 1, 10,  3, 20,  5])

In [44]:
# A (4, 12) int64 mask of ones: no position is masked out (== 0).
attention_mask = torch.ones(4, 12, dtype=torch.long)
attention_mask.shape

torch.Size([4, 12])

In [49]:
# Last column of the running sum = number of unmasked entries per row, giving a 1-D tensor.
attention_mask.cumsum(dim=-1)[:, -1].ndim

1

In [48]:
# Position ids from the mask: a running count of unmasked entries, so they start at 1 here.
# Entries where the mask is 0 are overwritten with 1. Out-of-place masked_fill on the
# cumsum gives the same values as masked_fill_ on that temporary.
attention_mask.cumsum(dim=-1).masked_fill(attention_mask == 0, 1)

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

In [39]:
# `hidden_state` is the model width (hidden size), not an activation tensor.
vocab_size, hidden_state = 10000, 256
# nn.Linear stores weight as (out_features, in_features) = (vocab_size, hidden_state).
lm_head = nn.Linear(hidden_state, vocab_size)
# nn.Embedding stores weight as (num_embeddings, embedding_dim) = (vocab_size, hidden_state).
embed_tokens = nn.Embedding(vocab_size, hidden_state)
tuple(module.weight.shape for module in (lm_head, embed_tokens))

(torch.Size([10000, 256]), torch.Size([10000, 256]))

In [36]:
# Before tying, the two layers were constructed (and randomly initialized) separately, so values differ.
torch.equal(lm_head.weight, embed_tokens.weight)

False

In [37]:
# Weight tying: rebind lm_head.weight to the very same Parameter object as embed_tokens.weight
# (both are (vocab_size, hidden_state)). Only .weight is rebound; lm_head.bias keeps its own parameter.
lm_head.weight = embed_tokens.weight

In [38]:
# torch.equal only compares values, which an independent copy would also satisfy.
# An identity check proves the weights are truly tied: one shared Parameter, so an update
# through either module is seen by both.
lm_head.weight is embed_tokens.weight

True

In [2]:
# RoPE inverse frequencies: base ** (-2i / dim) for i = 0 .. dim/2 - 1 (one per channel pair).
# NOTE(review): base=1000 here; the RoFormer paper and common LLM configs use 10000 — confirm intended.
base, dim = 1000, 128
exponents = torch.arange(0, dim, 2).float() / dim
inv_freq = (base ** exponents).reciprocal()
inv_freq.shape, inv_freq.dtype

(torch.Size([64]), torch.float32)

In [3]:
# (dim/2,) -> (1, dim/2, 1) so it can be batch-matmul'd with (bs, 1, seq_len) positions.
inv_freq_expanded = inv_freq.unsqueeze(0).unsqueeze(-1)
inv_freq_expanded.shape

torch.Size([1, 64, 1])

In [4]:
bs, seq_len, n_heads = 2, 10, 4
# The same 0 .. seq_len-1 row for every batch element -> (bs, seq_len), int64, contiguous.
position_ids = torch.arange(seq_len).repeat(bs, 1)
position_ids

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [5]:
# (bs, seq_len) -> (bs, 1, seq_len), cast to float for the matmul with inv_freq.
position_ids_expanded = position_ids.unsqueeze(1).float()
position_ids_expanded.shape

torch.Size([2, 1, 10])

In [6]:
# Batched outer product: (1, dim/2, 1) @ (bs, 1, seq_len) broadcasts to (bs, dim/2, seq_len).
# After the transpose, freqs[b, t, i] = t * inv_freq[i], with shape (bs, seq_len, dim/2).
freqs = torch.matmul(inv_freq_expanded, position_ids_expanded).transpose(1, 2)
freqs.shape

torch.Size([2, 10, 64])

In [7]:
# Row t holds t * inv_freq: row 0 is all zeros and the first column equals the position itself.
freqs

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 8.9769e-01, 8.0584e-01,  ..., 1.3824e-03,
          1.2409e-03, 1.1140e-03],
         [2.0000e+00, 1.7954e+00, 1.6117e+00,  ..., 2.7647e-03,
          2.4819e-03, 2.2279e-03],
         ...,
         [7.0000e+00, 6.2838e+00, 5.6409e+00,  ..., 9.6766e-03,
          8.6866e-03, 7.7978e-03],
         [8.0000e+00, 7.1815e+00, 6.4467e+00,  ..., 1.1059e-02,
          9.9275e-03, 8.9118e-03],
         [9.0000e+00, 8.0792e+00, 7.2526e+00,  ..., 1.2441e-02,
          1.1168e-02, 1.0026e-02]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 8.9769e-01, 8.0584e-01,  ..., 1.3824e-03,
          1.2409e-03, 1.1140e-03],
         [2.0000e+00, 1.7954e+00, 1.6117e+00,  ..., 2.7647e-03,
          2.4819e-03, 2.2279e-03],
         ...,
         [7.0000e+00, 6.2838e+00, 5.6409e+00,  ..., 9.6766e-03,
          8.686

In [8]:
# Dummy all-ones query laid out as (bs, n_heads, seq_len, dim).
q = torch.ones((bs, n_heads, seq_len, dim))
q.shape

torch.Size([2, 4, 10, 128])

In [9]:
# Split the last dimension into two contiguous halves (the same split rotate_half performs).
x1, x2 = q[..., : q.shape[-1] // 2], q[..., q.shape[-1] // 2 :]
x1.shape, x2.shape

(torch.Size([2, 4, 10, 64]), torch.Size([2, 4, 10, 64]))

In [10]:
# Manual rotate-half: negated second half goes first, the first half goes last.
a = torch.cat([x2.neg(), x1], dim=-1)
a.shape

torch.Size([2, 4, 10, 128])

In [11]:
# With q all ones, every row is 64 copies of -1 followed by 64 copies of 1.
a

tensor([[[[-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          ...,
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.]],

         [[-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          ...,
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.]],

         [[-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          ...,
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1.,  1.,  1.]],

         [[-1., -1., -1.,  ...,  1.,  1.,  1.],
          [-1., -1., -1.,  ...,  1., 

In [12]:
# (bs, seq_len, dim/2): one angle per channel pair.
freqs.shape

torch.Size([2, 10, 64])

In [13]:
# Duplicate the angles along the last dim so channel i and channel i + dim/2 share a frequency,
# matching the half-split pairing used by rotate_half.
emb = torch.cat([freqs] * 2, dim=-1)
emb.shape

torch.Size([2, 10, 128])

In [14]:
cos = torch.cos(emb)  # elementwise; still (bs, seq_len, dim)
cos.shape

torch.Size([2, 10, 128])

In [15]:
# Derive from `emb` rather than from `cos` itself so re-running this cell stays idempotent
# (rebinding `cos = cos.unsqueeze(1)` added another axis on every re-run).
# Result: (bs, 1, seq_len, dim), where the size-1 axis broadcasts over attention heads.
cos = emb.cos().unsqueeze(1)
cos.shape

torch.Size([2, 1, 10, 128])

In [16]:
# The head axis here (size n_heads) lines up with the size-1 axis of cos for broadcasting.
a.shape

torch.Size([2, 4, 10, 128])

In [17]:
def rotate_half(x):
    """Rotate the last dimension by half: ``[x1, x2] -> [-x2, x1]``.

    ``x1`` is the first ``x.shape[-1] // 2`` elements of the last dimension and ``x2``
    is the remainder (one element longer when the size is odd). Leading dimensions
    are left untouched, so any batch/head layout works.
    """
    split = x.shape[-1] // 2
    first, second = x[..., :split], x[..., split:]
    return torch.cat((second.neg(), first), dim=-1)

In [18]:
# Small int64 example: rows [1..4] and [5..8].
x = torch.arange(1, 9).view(2, 4)
x

tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

In [19]:
# rotate_half acts on the last axis, of size 4 here.
x.shape

torch.Size([2, 4])

In [20]:
# Each row [a, b, c, d] becomes [-c, -d, a, b].
rotate_half(x)

tensor([[-3, -4,  1,  2],
        [-7, -8,  5,  6]])

In [21]:
# NOTE(review): this rebinds `a` (previously the rotated q tensor); a distinct name would be clearer.
# Lift the 1-D arange to a (1, 10) row by adding a leading size-1 axis.
a = torch.arange(10)[None, :]
a.shape

torch.Size([1, 10])

In [22]:
num_positions, emb_dim, B = 100, 128, 4  # number of positions, embedding width, batch size

In [23]:
# Learned table with one emb_dim-wide vector per position id.
emb_layer = nn.Embedding(num_embeddings=num_positions, embedding_dim=emb_dim)

In [24]:
# Random stand-in features of shape (B, num_positions, emb_dim); the equality check later does not depend on the values.
conv_output = torch.randn((B, num_positions, emb_dim))

In [25]:
# NOTE(review): rebinds `position_ids` (earlier the (bs, seq_len) RoPE positions).
position_ids = torch.arange(0, num_positions)  # 0 .. num_positions-1
position_emb = emb_layer(position_ids)          # (num_positions, emb_dim)
position_emb.shape

torch.Size([100, 128])

In [26]:
# (num_positions, emb_dim) broadcasts against (B, num_positions, emb_dim) across the batch axis.
out_1 = torch.add(conv_output, position_emb)

In [27]:
# Same sum, but the ids carry an explicit leading size-1 axis -> embeddings (1, num_positions, emb_dim).
out_2 = conv_output + emb_layer(position_ids.unsqueeze(0))
out_2.shape
torch.Size([4, 100, 128])

In [28]:
# Exactly equal: adding a (num_positions, emb_dim) or a (1, num_positions, emb_dim) tensor broadcasts identically.
torch.equal(out_1, out_2)

True

In [29]:
# NOTE(review): rebinds B, seq_len and emb_dim (to the same values used earlier) and adds num_heads.
B = 4
num_heads = 2
seq_len = 10
emb_dim = 128

In [30]:
head_dim = emb_dim // num_heads
# Raw random scores (no softmax applied); only the shapes matter in this demo.
attn_scores = torch.randn(B, num_heads, seq_len, seq_len)
values = torch.randn(B, num_heads, seq_len, head_dim)
attn_scores.shape, values.shape

(torch.Size([4, 2, 10, 10]), torch.Size([4, 2, 10, 64]))

In [31]:
# (B, H, T, T) @ (B, H, T, d) -> (B, H, T, d): each output row is a score-weighted mix of value rows.
out = torch.matmul(attn_scores, values)
out.shape

torch.Size([4, 2, 10, 64])

In [32]:
# Dummy activations laid out as (B, seq_len, emb_dim).
act = torch.ones((B, seq_len, emb_dim))
act.shape

torch.Size([4, 10, 128])

In [33]:
# Mean over the last (feature) axis; keepdim leaves a trailing size-1 axis that broadcasts back onto act.
act.mean(dim=-1, keepdim=True).shape

torch.Size([4, 10, 1])