In [33]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import BertLMHeadModel, AutoConfig

##### Self-Attention
$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V$

<!-- ![alternative text](attention.png) -->
<div>
<center>
<img src="attention.png" width="300"/>
</center>
</div>

In [34]:
def scaled_dot_product(q, k, v, dropout_p=0.1, training=True):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q, k, v: tensors whose last two dims are (seq, head_dim); in this notebook
            they are (bs, head, seq, hs // head).
        dropout_p: dropout probability applied to the attention probabilities.
        training: when False, dropout is skipped. Callers inside an nn.Module
            should pass ``self.training`` so that ``model.eval()`` disables it.

    Returns:
        Tensor of shape (..., seq_q, head_dim_v).
    """
    # (bs, head, seq, hs // head)
    d_k = q.shape[-1]
    # d_k is a Python int and torch.sqrt only accepts tensors, so scale by a float.
    attn_score = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    # (bs, head, seq, seq)
    attn_probs = F.softmax(attn_score, dim=-1)
    attn_probs = F.dropout(attn_probs, p=dropout_p, training=training)
    # (bs, head, seq, hs // head)
    return torch.matmul(attn_probs, v)

In [35]:
class SelfAttention(nn.Module):
    """Multi-head self-attention: project to Q/K/V, split into heads, attend, merge heads.

    Args:
        hidden_size: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``n_heads``.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        if hidden_size % n_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by n_heads ({n_heads})"
            )
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.n_heads = n_heads

    def permute_for_scores(self, x):
        """Split the last dim into heads: (bs, seq, hs) -> (bs, head, seq, hs // head)."""
        new_shape = x.shape[:-1] + (self.n_heads, -1)
        return x.view(new_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Attend over ``hidden_states`` (bs, seq, hs) and return a (bs, seq, hs) tensor."""
        q = self.permute_for_scores(self.q_proj(hidden_states))
        k = self.permute_for_scores(self.k_proj(hidden_states))
        v = self.permute_for_scores(self.v_proj(hidden_states))
        # core attention -> (bs, head, seq, hs // head)
        # NOTE(review): scaled_dot_product is called with its default dropout settings,
        # which may apply dropout even in eval mode; confirm against the helper.
        output = scaled_dot_product(q, k, v)
        # permute/view return new tensors, so their results must be reassigned.
        # (bs, head, seq, d) -> (bs, seq, head, d); contiguous() is required before view.
        output = output.permute(0, 2, 1, 3).contiguous()
        # merge heads -> (bs, seq, hs)
        return output.view(output.shape[0], output.shape[1], -1)

##### Attention Layer
<div>
<center>
<img src="transformer.png" width="400"/>
</center>
</div>

In [36]:
class Projection(nn.Module):
    """Output sub-layer: Linear -> Dropout(p=0.1) -> LayerNorm(projected + residual).

    Same dense / dropout / LayerNorm pattern as ``BertSelfOutput`` printed in this notebook.

    Args:
        hidden_size: width of both the input and the residual tensor.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Registration order is kept (dense, dropout, layer_norm) so printed layout,
        # state_dict keys and seeded initialisation stay the same.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, then normalise it together with the residual ``input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.layer_norm(projected + input_tensor)

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Args:
        hidden_size: input and output width.
        intermediate_size: width of the hidden layer. Defaults to ``hidden_size``,
            which is the original behaviour. The BERT-large layer printed in this
            notebook uses 4096 for a 1024-wide model.
    """

    def __init__(self, hidden_size, intermediate_size=None):
        super().__init__()
        if intermediate_size is None:
            intermediate_size = hidden_size
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, data):
        """Apply the two-layer MLP to the last dimension of ``data``."""
        return self.linear2(self.activation(self.linear1(data)))

class Attention(nn.Module):
    """One transformer encoder block built from the pieces above.

    The pipeline is:
        hidden_states -> self-attention
                      -> proj1(dense, dropout, LayerNorm + residual hidden_states)
                      -> MLP
                      -> proj2(dense, dropout, LayerNorm + residual attention_output)

    Args:
        hidden_size: model width.
        n_heads: number of attention heads.
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.self_attn = SelfAttention(hidden_size, n_heads)
        self.proj1 = Projection(hidden_size)
        self.linear_net = MLP(hidden_size)
        self.proj2 = Projection(hidden_size)

    def forward(self, hidden_states):
        """Run attention and the feed-forward sub-layer, each with a residual connection."""
        self_output = self.self_attn(hidden_states)
        # Residual around attention: Projection adds `hidden_states` before its LayerNorm.
        attention_output = self.proj1(self_output, hidden_states)
        linear_out = self.linear_net(attention_output)
        # Residual around the MLP. Projection already applies dropout and adds the
        # residual, so it needs `attention_output` as its second argument. The previous
        # code referenced a nonexistent `self.dropout` and called proj2 with one argument.
        return self.proj2(linear_out, attention_output)

In [37]:
# One encoder block sized like the BERT-large layer inspected below (hidden 1024, 16 heads).
HIDDEN_SIZE = 1024
N_HEADS = 16
model = Attention(hidden_size=HIDDEN_SIZE, n_heads=N_HEADS)
print(model)

Attention(
  (self_attn): SelfAttention(
    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (proj1): Projection(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (linear_net): MLP(
    (linear1): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): GELU(approximate=none)
    (linear2): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (proj2): Projection(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
)


In [48]:


# Reference model: HF BERT-large, built from its config only (randomly initialised weights).
# AutoConfig.from_pretrained still needs network access or a local HF cache for config.json.
# is_decoder=True addresses the warning this cell printed ("If you want to use
# `BertLMHeadModel` as a standalone, add `is_decoder=True.`"). BertLMHeadModel is the
# causal-LM head, so the encoder should be configured as a decoder.
config = AutoConfig.from_pretrained("bert-large-uncased", is_decoder=True)
bert_model = BertLMHeadModel(config)

print(list(bert_model.named_modules())[4])
bert_model.bert.encoder.layer[0]

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


('bert.embeddings.position_embeddings', Embedding(512, 1024))


BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=1024, out_features=1024, bias=True)
      (key): Linear(in_features=1024, out_features=1024, bias=True)
      (value): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=1024, out_features=1024, bias=True)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=1024, out_features=4096, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=4096, out_features=1024, bias=True)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [47]:
# NOTE(review): the TypeError recorded under this cell looks stale (In[47] ran before In[48]).
# The same expression, bert_model.bert.encoder.layer[0], evaluates fine in the cell above.
first_encoder_layer = bert_model.bert.encoder.layer[0]
print(first_encoder_layer)

TypeError: 'generator' object is not subscriptable

In [39]:
def train(model, device="cuda", bs=8, seq_length=512, steps=100, lr=1e-3, log_every=10):
    """Smoke-test training loop on a synthetic all-ones batch.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...,
            token_type_ids=..., labels=...)`` that returns an object with a
            ``.loss`` tensor (HF-style, e.g. ``BertLMHeadModel``).
        device: device to place the model and the batch on.
        bs: batch size of the synthetic batch.
        seq_length: sequence length of the synthetic batch.
        steps: number of optimizer steps.
        lr: AdamW learning rate.
        log_every: print the loss every ``log_every`` steps.
    """
    input_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    attention_mask = torch.ones(bs, seq_length, dtype=torch.float16, device=device)
    token_type_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    labels = input_ids.clone()
    model.to(device)
    # Make sure dropout etc. are active even if the model was left in eval mode.
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for step in range(steps):
        # Clear gradients from the previous step; otherwise they accumulate.
        optimizer.zero_grad()
        # Keyword arguments avoid depending on the model's positional parameter order.
        loss = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        ).loss
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            print(f"step {step} loss: {loss.item()}")

In [40]:
# NOTE(review): `model` here is the custom `Attention` block. Its forward takes only
# `hidden_states` and returns a plain tensor, so train() would fail on it: train()
# passes `labels=` and reads `.loss`. Presumably `bert_model` (BertLMHeadModel) is
# the intended argument; confirm before enabling.
#train(model, device="cpu")