In [126]:
import torch
import torch.nn as nn
from math import sqrt
import torch.nn.functional as F

# Encoder

In [106]:
class Configuration():
  """Hyper-parameters shared by the encoder and decoder modules.

  Defaults live on the class (12 layers, 12 heads, hidden size 768).
  Keyword arguments override individual fields on the new instance only,
  leaving the class defaults untouched. For example,
  ``Configuration(hidden_size=64, num_attention_heads=4, num_hidden_layers=2)``
  gives a small model for quick experiments. ``Configuration()`` with no
  arguments yields exactly the defaults below.

  Raises:
    TypeError: if an override names a field that is not defined below
      (catches typos such as ``hiden_size``).
    ValueError: if ``hidden_size`` is not divisible by
      ``num_attention_heads``. The attention modules give each head
      ``hidden_size // num_attention_heads`` features and expect the
      concatenated heads to add back up to ``hidden_size``.
  """
  dim_token_emb = 768
  attention_probs_dropout_prob = 0.1
  classifier_dropout = None
  gradient_checkpointing = False
  hidden_act = "gelu"
  hidden_dropout_prob = 0.1
  hidden_size = 768
  initializer_range = 0.02
  intermediate_size = 3072
  layer_norm_eps = 1e-12
  max_position_embeddings = 512
  model_type = "encoder"
  num_attention_heads = 12
  num_hidden_layers = 12
  pad_token_id = 0
  position_embedding_type = "absolute"
  type_vocab_size = 2
  use_cache = True
  vocab_size = 30522

  def __init__(self, **overrides):
    for name, value in overrides.items():
      # Only the public fields declared on the class may be overridden.
      if name.startswith("_") or not hasattr(type(self), name):
        raise TypeError(f"Configuration got an unexpected field {name!r}")
      # Instance attribute shadows the class default for this object only.
      setattr(self, name, value)
    if self.hidden_size % self.num_attention_heads != 0:
      raise ValueError(
          f"hidden_size ({self.hidden_size}) must be divisible by "
          f"num_attention_heads ({self.num_attention_heads})"
      )

In [107]:
# Instantiate the hyper-parameters shared by every module below.
config = Configuration()

In [108]:
# Display the token-embedding width.
# NOTE(review): no module in this notebook reads dim_token_emb; embedding
# widths are taken from config.hidden_size instead.
config.dim_token_emb

768

In [109]:
def scaled_dot_product_attention(query, key, value, mask=None):
  """Batched attention: softmax(Q K^T / sqrt(d_k)) V.

  Args:
    query: 3-D tensor (batch, seq_q, d_k).
    key: 3-D tensor (batch, seq_k, d_k).
    value: 3-D tensor (batch, seq_k, d_v).
    mask: optional tensor broadcastable to (batch, seq_q, seq_k); score
      positions where ``mask == 0`` are set to -inf before the softmax.

  Returns:
    Tensor of shape (batch, seq_q, d_v).
  """
  # Scale by the per-head feature size (last dim), not the sequence length.
  dim_k = key.size(-1)
  score_matrix = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
  if mask is not None:
    score_matrix = score_matrix.masked_fill(mask == 0, float('-inf'))
  weight = F.softmax(score_matrix, dim=-1)
  return torch.bmm(weight, value)

In [110]:
class AttentionHead(nn.Module):
    """A single self-attention head.

    Args:
        embed_dim: size of the last dimension of the incoming hidden state.
        head_dim: output size of each of the query/key/value projections.
    """

    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state):
        """Project hidden_state to query, key and value, then attend."""
        # Projections are applied in q, k, v order and unpacked positionally.
        projected = (proj(hidden_state) for proj in (self.q, self.k, self.v))
        return scaled_dot_product_attention(*projected)

In [111]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent AttentionHead modules.

    ``config.hidden_size`` is split evenly across
    ``config.num_attention_heads`` heads. The head outputs are concatenated
    on the last dimension and mixed by ``output_linear``.

    Raises:
        ValueError: if ``config.hidden_size`` is not divisible by
            ``config.num_attention_heads``. Otherwise the concatenated heads
            would be narrower than ``output_linear`` expects and ``forward``
            would fail with an opaque shape error.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"hidden_size ({embed_dim}) must be divisible by "
                f"num_attention_heads ({num_heads})"
            )
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        """Run every head on hidden_state, concatenate and project."""
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        x = self.output_linear(x)
        return x

In [112]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Expands from ``config.hidden_size`` to ``config.intermediate_size``,
    projects back to ``config.hidden_size``, then applies dropout with
    probability ``config.hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.linear_1 = nn.Linear(hidden, inner)
        self.linear_2 = nn.Linear(inner, hidden)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        return self.dropout(self.linear_2(self.gelu(self.linear_1(x))))

In [113]:
class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer.

    Computes ``x + Attn(LN1(x))`` followed by ``x + FFN(LN2(x))``. Each
    sub-layer sees a layer-normalised copy of its input, and its output is
    added back through a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        # Use the configured epsilon (config.layer_norm_eps) so every LayerNorm
        # in the model agrees; fall back to PyTorch's default for configs
        # that do not define the field.
        eps = getattr(config, "layer_norm_eps", 1e-5)
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x):
        # Normalise once; the attention heads derive query, key and value
        # from this same normalised tensor.
        hidden_state = self.layer_norm_1(x)
        # Attention sub-layer with a skip connection.
        x = x + self.attention(hidden_state)
        # Feed-forward sub-layer with a skip connection.
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x

In [114]:
# Standalone encoder layer for inspection; TransformerEncoder builds its own layers.
encoder_layer = TransformerEncoderLayer(config)

In [115]:
class Embeddings(nn.Module):
    """Token plus learned absolute position embeddings, then LayerNorm and dropout.

    Args:
        config: reads ``vocab_size``, ``max_position_embeddings``,
            ``hidden_size``, ``hidden_dropout_prob`` and, if present,
            ``layer_norm_eps`` (default 1e-12).
    """

    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size,
                                             config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-12))
        # nn.Dropout() would default to p=0.5; use the configured rate.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Embed a (batch, seq_len) tensor of token ids.

        Returns a (batch, seq_len, hidden_size) tensor. seq_len must not
        exceed ``max_position_embeddings``, otherwise the position lookup
        is out of range.
        """
        seq_length = input_ids.size(1)
        # Create position ids on the inputs' device so the module keeps
        # working after being moved to a GPU.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        # (1, seq_len, hidden) position embeddings broadcast over the batch.
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [116]:
class TransformerEncoder(nn.Module):
    """Embedding layer followed by a stack of TransformerEncoderLayer modules.

    ``config.num_hidden_layers`` layers are applied in order. ``forward``
    maps token ids to the hidden state produced by the last layer.
    """

    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        stack = [TransformerEncoderLayer(config)
                 for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        hidden = self.embeddings(x)
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [117]:
# Full encoder: embeddings plus config.num_hidden_layers encoder layers.
encoder = TransformerEncoder(config)

# Decoder

In [118]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): none of the modules or tensors in this notebook are moved
# to `device` yet.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [119]:
# Toy batch of target token ids: 2 sequences of 6 tokens each (int64).
y = torch.tensor([[0, 1, 2, 3, 4, 5], [3, 4, 5, 6, 7, 8]], dtype=torch.long)
y.shape

torch.Size([2, 6])

In [120]:
def scaled_dot_product_attention(query, key, value, mask=None):
  """Batched attention with an optional mask: softmax(Q K^T / sqrt(d_k)) V.

  NOTE(review): this redefines the function from the Encoder section; a
  single shared definition would avoid the two drifting apart.

  Args:
    query: 3-D tensor (batch, seq_q, d_k).
    key: 3-D tensor (batch, seq_k, d_k).
    value: 3-D tensor (batch, seq_k, d_v).
    mask: optional tensor broadcastable to (batch, seq_q, seq_k); score
      positions where ``mask == 0`` are set to -inf before the softmax, so
      they receive zero attention weight.

  Returns:
    Tensor of shape (batch, seq_q, d_v).
  """
  # Scale by the per-head feature size (last dim), not the sequence length.
  dim_k = key.size(-1)
  score_matrix = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
  if mask is not None:
    score_matrix = score_matrix.masked_fill(mask == 0, float('-inf'))
  weight = F.softmax(score_matrix, dim=-1)
  return torch.bmm(weight, value)

In [121]:
class DecoderEmbeddings(nn.Module):
  """Target-side token plus learned position embeddings, then LayerNorm and dropout.

  Args:
    config: reads ``vocab_size``, ``max_position_embeddings``,
      ``hidden_size``, ``hidden_dropout_prob`` and, if present,
      ``layer_norm_eps`` (default 1e-12).
  """
  def __init__(self, config):
    super().__init__()
    self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
    self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.layer_norm = nn.LayerNorm(
        config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-12))
    # nn.Dropout() would default to p=0.5; use the configured rate.
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, input_ids):
    """Embed a (batch, seq_len) tensor of token ids into (batch, seq_len, hidden_size).

    seq_len must not exceed ``max_position_embeddings``.
    """
    seq_length = input_ids.size(1)
    # Position ids on the inputs' device (long, as used for the encoder).
    position_ids = torch.arange(
        seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
    token_embeddings = self.token_embedding(input_ids)
    position_embeddings = self.position_embedding(position_ids)
    # (1, seq_len, hidden) position embeddings broadcast over the batch.
    embeddings = token_embeddings + position_embeddings
    embeddings = self.layer_norm(embeddings)
    embeddings = self.dropout(embeddings)
    return embeddings

In [122]:
# Embed the toy batch `y` with a freshly initialised decoder embedding layer.
embedding_layer = DecoderEmbeddings(config)
embeddings = embedding_layer(y)
embeddings.shape

torch.Size([2, 6, 768])

In [123]:
class MaskedAttentionHead(nn.Module):
  """A single causal self-attention head for the decoder.

  A lower-triangular mask is passed to ``scaled_dot_product_attention``, so
  each position attends only to itself and to earlier positions.

  Args:
    embedding_dim: size of the last dimension of the incoming hidden state.
    head_dim: output size of each of the query/key/value projections.
  """
  def __init__(self, embedding_dim, head_dim):
    super().__init__()
    self.q = nn.Linear(embedding_dim, head_dim)
    self.k = nn.Linear(embedding_dim, head_dim)
    self.v = nn.Linear(embedding_dim, head_dim)

  def forward(self, hidden_state):
    query = self.q(hidden_state)
    key = self.k(hidden_state)
    value = self.v(hidden_state)
    seq_len = query.size(1)
    # Build the mask on the activations' device; a CPU mask cannot be
    # combined with CUDA scores in masked_fill.
    causal_mask = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.int, device=hidden_state.device)
    ).unsqueeze(0)
    attn_outputs = scaled_dot_product_attention(query, key, value, mask=causal_mask)
    return attn_outputs

In [124]:
class MaskedMultiAttentionHead(nn.Module):
  """Causal multi-head self-attention built from MaskedAttentionHead modules.

  ``config.hidden_size`` is split evenly across
  ``config.num_attention_heads`` heads. The head outputs are concatenated
  on the last dimension and mixed by ``output_layer``.

  Raises:
    ValueError: if ``config.hidden_size`` is not divisible by
      ``config.num_attention_heads``, since the concatenated heads must add
      back up to ``hidden_size`` for ``output_layer``.
  """
  def __init__(self, config):
    super().__init__()
    embedding_dim = config.hidden_size
    num_head = config.num_attention_heads
    if embedding_dim % num_head != 0:
      raise ValueError(
          f"hidden_size ({embedding_dim}) must be divisible by "
          f"num_attention_heads ({num_head})"
      )
    head_dim = embedding_dim // num_head
    self.attn_heads = nn.ModuleList(
        [MaskedAttentionHead(embedding_dim, head_dim) for _ in range(num_head)]
    )
    self.output_layer = nn.Linear(embedding_dim, embedding_dim)

  def forward(self, hidden_state):
    maskedattn_outputs = torch.cat([h(hidden_state) for h in self.attn_heads], dim=-1)
    maskedattn_outputs = self.output_layer(maskedattn_outputs)
    return maskedattn_outputs

In [127]:
# Causal multi-head self-attention over the embedded toy batch.
masked_multi_attn = MaskedMultiAttentionHead(config)
attn_outputs = masked_multi_attn(embeddings)
attn_outputs.shape

torch.Size([2, 6, 768])

In [134]:
class EncoderDecoderAttentionHead(nn.Module):
  """A single cross-attention head.

  Queries are projected from the decoder hidden state; keys and values are
  projected from the encoder hidden state. No mask is applied, so every
  decoder position can attend to every encoder position.
  """
  def __init__(self, embedding_dim, head_dim):
    super().__init__()
    self.q = nn.Linear(embedding_dim, head_dim)
    self.k = nn.Linear(embedding_dim, head_dim)
    self.v = nn.Linear(embedding_dim, head_dim)

  def forward(self, encoder_hidden_state, decoder_hidden_state):
    query = self.q(decoder_hidden_state)
    key = self.k(encoder_hidden_state)
    value = self.v(encoder_hidden_state)
    return scaled_dot_product_attention(query, key, value)

In [135]:
class MultiEncoderDecoderAttentionHead(nn.Module):
  """Multi-head cross-attention built from EncoderDecoderAttentionHead modules.

  ``config.hidden_size`` is split evenly across
  ``config.num_attention_heads`` heads. The head outputs are concatenated
  on the last dimension and mixed by ``output_layer``.

  Raises:
    ValueError: if ``config.hidden_size`` is not divisible by
      ``config.num_attention_heads``, since the concatenated heads must add
      back up to ``hidden_size`` for ``output_layer``.
  """
  def __init__(self, config):
    super().__init__()
    embedding_dim = config.hidden_size
    num_head = config.num_attention_heads
    if embedding_dim % num_head != 0:
      raise ValueError(
          f"hidden_size ({embedding_dim}) must be divisible by "
          f"num_attention_heads ({num_head})"
      )
    head_dim = embedding_dim // num_head
    self.attn_heads = nn.ModuleList(
        [EncoderDecoderAttentionHead(embedding_dim, head_dim) for _ in range(num_head)]
    )
    self.output_layer = nn.Linear(embedding_dim, embedding_dim)

  def forward(self, encoder_hidden_state, decoder_hidden_state):
    attn_outputs = torch.cat([h(encoder_hidden_state, decoder_hidden_state) for h in self.attn_heads], dim=-1)
    attn_outputs = self.output_layer(attn_outputs)
    return attn_outputs

In [136]:
# Cross-attention demo: the masked self-attention output stands in for both
# the encoder output and the decoder state. The result gets a new name so
# that re-running this cell does not feed its own output back in.
multi_encoder_decoder_attn = MultiEncoderDecoderAttentionHead(config)
cross_attn_outputs = multi_encoder_decoder_attn(attn_outputs, attn_outputs)
cross_attn_outputs.size()

torch.Size([2, 6, 768])

In [137]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block used by the decoder layers.

    NOTE(review): identical to the FeedForward defined in the Encoder
    section; this cell re-declares the class.

    Applies Linear(hidden_size -> intermediate_size), GELU,
    Linear(intermediate_size -> hidden_size) and dropout with probability
    ``config.hidden_dropout_prob``, in that order.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.hidden_size
        d_inner = config.intermediate_size
        self.linear_1 = nn.Linear(d_model, d_inner)
        self.linear_2 = nn.Linear(d_inner, d_model)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        for stage in (self.linear_1, self.gelu, self.linear_2, self.dropout):
            x = stage(x)
        return x

In [142]:
class TransformerDecoderLayer(nn.Module):
  """Pre-norm Transformer decoder layer with three residual sub-layers.

  Each sub-layer is fed a layer-normalised copy of its input:
    1. causal multi-head self-attention over the decoder state;
    2. multi-head cross-attention, with queries from the decoder state and
       keys/values from ``encoder_hidden_state``;
    3. the position-wise feed-forward network.
  ``encoder_hidden_state`` is used as given (it is not normalised here).
  """
  def __init__(self, config):
    super().__init__()
    # Use the configured epsilon (config.layer_norm_eps) so every LayerNorm
    # in the model agrees; fall back to PyTorch's default for configs that
    # do not define the field.
    eps = getattr(config, "layer_norm_eps", 1e-5)
    self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=eps)
    self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=eps)
    self.layer_norm3 = nn.LayerNorm(config.hidden_size, eps=eps)

    self.Maskedattention = MaskedMultiAttentionHead(config)
    self.EncoderDecoderattention = MultiEncoderDecoderAttentionHead(config)

    self.feed_forward = FeedForward(config)

  def forward(self, encoder_hidden_state, decoder_hidden_state):
    # 1. Causal self-attention with a skip connection.
    hidden_state1 = self.layer_norm1(decoder_hidden_state)
    masked_outputs = decoder_hidden_state + self.Maskedattention(hidden_state1)

    # 2. Cross-attention over the encoder output with a skip connection.
    hidden_state2 = self.layer_norm2(masked_outputs)
    final_hidden = masked_outputs + self.EncoderDecoderattention(encoder_hidden_state, hidden_state2)

    # 3. Feed-forward with a skip connection.
    final_hidden = final_hidden + self.feed_forward(self.layer_norm3(final_hidden))
    return final_hidden

In [143]:
class TransformerDecoder(nn.Module):
  """Decoder embeddings followed by config.num_hidden_layers decoder layers.

  ``forward(encoder_final_hidden, y)`` embeds the target token ids ``y``.
  It then threads the running hidden state through every layer, passing
  each one the same, unchanged ``encoder_final_hidden``.
  """
  def __init__(self, config):
    super().__init__()
    self.embeddings = DecoderEmbeddings(config)
    layers = [TransformerDecoderLayer(config)
              for _ in range(config.num_hidden_layers)]
    self.decoder_layer = nn.ModuleList(layers)

  def forward(self, encoder_final_hidden, y):
    hidden = self.embeddings(y)
    for layer in self.decoder_layer:
      hidden = layer(encoder_final_hidden, hidden)
    return hidden

In [144]:
# Full decoder: decoder embeddings plus config.num_hidden_layers decoder layers.
decoder = TransformerDecoder(config)

In [147]:
# Feed the decoder the output of the full encoder built above, rather than
# the intermediate `attn_outputs` tensor from the demo cells (which those
# cells rebind). `y` doubles as a stand-in source sequence here.
encoder_output = encoder(y)
decoder_output = decoder(encoder_output, y)

In [148]:
# Expected: (batch, target_len, hidden_size).
decoder_output.shape

torch.Size([2, 6, 768])