In [0]:
import torch
import math
from torch import nn
import torch.nn.functional as F

In [0]:
class EncoderLayer(nn.Module):
  def __init__(self, word_depth, num_heads, head_depth,
               ff_activation = nn.ReLU, dropout_prob=0.5):
    super().__init__()

    self.num_heads = num_heads
    self.word_depth = word_depth
    self.head_depth = head_depth
    self.output_depth = num_heads * head_depth

    self.query = nn.Linear(word_depth, self.output_depth)
    self.key = nn.Linear(word_depth, self.output_depth)
    self.value = nn.Linear(word_depth, self.output_depth)

    self.head_combination = nn.Linear(self.output_depth, self.word_depth)
    self.feedforwards = nn.Sequential(
        nn.Linear(self.word_depth, self.word_depth),
        ff_activation(),
        nn.Linear(self.word_depth, self.word_depth)
    )

    self.norm = nn.LayerNorm(self.word_depth)
    self.dropout = nn.Dropout(dropout_prob)
  
  def forward(self, inp):

    # Applying the query, key, and value layers.
    q = self.query(inp)
    k = self.key(inp)
    v = self.value(inp)

    # Shape <- (batch, sequence, num_heads, head_depth)
    def reshape_per_head(x):
      initial_shape = x.shape
      final_shape = initial_shape[:-1] + (self.num_heads, self.head_depth)
      return x.view(final_shape)
    
    q = reshape_per_head(q)
    k = reshape_per_head(k)
    v = reshape_per_head(v)

    # (batch, sequence, num_heads)
    attention_values = q @ k.transpose(-1, -2)
    print(f"att val", attention_values.shape)
    scaled_attention_values = attention_values / math.sqrt(self.head_depth)

    normalized_attention_values = torch.softmax(scaled_attention_values, dim = -1)

    values = normalized_attention_values @ v
    new_value_shape = values.shape[:-2] + (self.output_depth,) 
    values_reshaped = values.view(*new_value_shape).contiguous()

    combined_heads = self.head_combination(values_reshaped)
    final_attention = self.norm(self.dropout(combined_heads))

    feedforward = self.feedforwards(final_attention)

    skip_connection = self.norm(inp + self.dropout(feedforward))

    final = self.norm(skip_connection)

    return final
    

In [0]:
class Embeddings(nn.Module):
  def __init__(self, vocab_size, embedding_depth, max_length=512):
    super().__init__()

    self.word_embedding = nn.Embedding(vocab_size, embedding_depth)
    self.positional_embedding = nn.Embedding(max_length, embedding_depth)
    self.bitext_embedding = nn.Embedding(2, embedding_depth)

    self.norm = nn.LayerNorm(embedding_depth)

  def forward(self, inp, token_types=None):
    
    if token_types is None:
      token_types = torch.zeros(inp.shape).long()
    
    we = self.word_embedding(inp)
    pe = self.positional_embedding(token_types)
    be = self.bitext_embedding(inp)

    embedding = we + pe + be

    return self.norm(embedding)

In [0]:
class BERT(nn.Module):
  def __init__(self, num_layers, heads_per_layer, head_depth, vocab_size, embedding_depth, max_length):
    super().__init__()

    self.embeddings = Embeddings(vocab_size, embedding_depth, max_length=max_length)
    layers = [EncoderLayer(embedding_depth, heads_per_layer, head_depth=head_depth) for _ in range(num_layers)]
    self.layers = nn.Sequential(*layers)
  
  def forward(self, inp, token_types=None):

    embedded = self.embeddings(inp, token_types=token_types)
    out = self.layers(embedded)

    return out
  

In [0]:
b = BERT(2, 3, 4, 10, 12, 5)

In [0]:
b(torch.zeros(7, 2).long())

att val torch.Size([7, 2, 3, 3])
att val torch.Size([7, 2, 3, 3])


tensor([[[-1.9141, -0.0635, -0.5667, -0.2885, -0.3414,  2.0442, -1.0042,
           1.3162, -0.4922,  0.2312,  0.6234,  0.4554],
         [-1.9143, -0.3391,  0.0848, -0.1083,  0.6554,  1.9112, -1.2758,
           0.3930, -1.1434,  0.4879,  0.7014,  0.5473]],

        [[-1.1996, -0.9819, -1.0433,  0.3721,  0.6000,  1.6952, -0.8568,
           0.0877, -1.3681,  0.7744,  1.1593,  0.7608],
         [-1.6925, -0.5312, -0.1901, -0.1470,  0.4121,  1.5486, -1.2408,
           0.6970, -1.3836,  0.6325,  1.1619,  0.7332]],

        [[-1.7565,  0.0051, -0.1654,  0.4266, -0.3064,  1.8467, -1.3330,
           0.2359, -1.2515,  0.6251,  0.9122,  0.7612],
         [-1.1979, -0.5619, -0.4143, -0.1075,  0.5098,  2.1466, -1.4576,
           1.0072, -1.2004,  0.4544,  0.6030,  0.2185]],

        [[-0.9966,  0.3235, -0.8787,  0.4582, -0.6787,  1.7770, -1.3124,
          -0.0932, -1.0960,  0.2931,  1.7558,  0.4479],
         [-1.9144, -0.3934, -1.1846,  0.0758,  1.0159,  1.2933, -0.4631,
          -0.0276,