<a href="https://colab.research.google.com/github/hsong-77/transformer/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.nn.functional import log_softmax, pad
from torch import Tensor
import math
import warnings

warnings.filterwarnings("ignore")

In [None]:
def scaled_dot_product_attention(query, key, value, mask = None, drop_rate = None, training = True):
  """Compute softmax(Q K^T / sqrt(d_k)) V.

  Args:
    query: tensor whose last dimension is d_k (assumed (..., seq_q, d_k)).
    key: tensor with the same last dimension as ``query`` (assumed
      (..., seq_k, d_k)).
    value: tensor attended over (assumed (..., seq_k, d_v)).
    mask: optional tensor broadcastable to the score shape. Positions where
      ``mask == 0`` are filled with -1e9 before the softmax, so they receive
      (near) zero attention weight.
    drop_rate: optional dropout probability applied to the attention weights;
      ``None`` disables dropout.
    training: whether dropout is active. Defaults to ``True`` so existing
      callers keep the previous behaviour; pass ``False`` (e.g.
      ``module.training``) at inference time.

  Returns:
    The attention-weighted combination of ``value``.
  """
  d_k = query.size(-1)
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
  if mask is not None:
    # A large negative constant (instead of -inf) keeps fully-masked rows
    # finite: they softmax to a uniform distribution rather than NaN.
    scores = scores.masked_fill(mask == 0, -1e9)
  p_attn = scores.softmax(dim = -1)
  if drop_rate is not None:
    # Functional dropout avoids building a throwaway nn.Dropout (which is
    # always in training mode) on every call and honours ``training``.
    p_attn = nn.functional.dropout(p_attn, p = drop_rate, training = training)
  return torch.matmul(p_attn, value)

def padding_mask(x, pad = 2):
  """Return a boolean mask that is True wherever ``x`` differs from ``pad``.

  A size-1 axis is inserted before the last dimension, so a (batch, seq)
  input yields a (batch, 1, seq) mask that broadcasts over the query axis
  of the attention scores.
  """
  not_pad = x.ne(pad)
  return not_pad.unsqueeze(dim = -2)

def subsequent_mask(size):
  """Return a (1, size, size) boolean causal mask.

  Entry [0, i, j] is True when j <= i: position i may attend to itself and
  to earlier positions, never to later ones.
  """
  ones = torch.ones(1, size, size)
  lower_triangle = torch.tril(ones)
  return lower_triangle == 1

def position_encoding(d_model, seq_len):
  """Return sinusoidal positional encodings of shape (seq_len, d_model).

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  The result has no batch axis; adding it to a (batch, seq_len, d_model)
  tensor broadcasts over the batch. Odd ``d_model`` is supported (the last
  column is then a sine).
  """
  pos = torch.arange(0, seq_len, dtype = torch.float32).reshape(-1, 1)
  dim = torch.arange(0, d_model, 2, dtype = torch.float32).reshape(1, -1)
  # True division: floor division (`//`) is always 0 here and would give
  # every column the same frequency.
  phase = pos / (1e4 ** (dim / d_model))

  pe = torch.zeros(seq_len, d_model)
  pe[:, 0::2] = torch.sin(phase)
  # With odd d_model there is one fewer cosine column than sine columns.
  pe[:, 1::2] = torch.cos(phase[:, : d_model // 2])
  return pe


In [None]:
class LayerNorm(nn.Module):
  """Layer normalisation over the last dimension with a learnable gain and bias.

  Uses ``Tensor.std`` (unbiased by default) and adds ``eps`` to the standard
  deviation rather than to the variance.
  """

  def __init__(self, features: int, eps: float = 1e-6):
    super().__init__()
    # Names a_2 / b_2 are kept so state_dict keys stay stable.
    self.a_2 = nn.Parameter(torch.ones(features))   # gain
    self.b_2 = nn.Parameter(torch.zeros(features))  # bias
    self.eps = eps

  def forward(self, x):
    centered = x - x.mean(dim = -1, keepdim = True)
    denom = x.std(dim = -1, keepdim = True) + self.eps
    return self.a_2 * centered / denom + self.b_2


class Residual(nn.Module):
  """Post-norm residual wrapper computing ``norm(x + dropout(sublayer(*args)))``.

  Every positional argument is forwarded to ``sublayer``; the first one also
  serves as the residual input.
  """

  def __init__(self, sublayer: nn.Module, d_model: int, drop_rate: float = 0.1):
    super().__init__()
    self.sublayer = sublayer
    self.norm = LayerNorm(d_model)
    self.dropout = nn.Dropout(drop_rate)

  def forward(self, *x):
    residual = x[0]
    update = self.dropout(self.sublayer(*x))
    return self.norm(residual + update)


class AttentionHead(nn.Module):
  """A single attention head: linearly project query/key/value, then attend.

  Query and key are projected to ``d_query`` features, value to ``d_value``.
  """

  def __init__(self, d_model: int, d_query: int, d_value: int):
    super().__init__()
    self.q = nn.Linear(d_model, d_query)
    self.k = nn.Linear(d_model, d_query)
    self.v = nn.Linear(d_model, d_value)

  def forward(self, query, key, value, mask, drop_rate):
    # Attention dropout is a training-time regulariser; skip it in eval mode
    # so inference is deterministic.
    effective_drop = drop_rate if self.training else None
    return scaled_dot_product_attention(
        self.q(query), self.k(key), self.v(value), mask, effective_drop
    )


class MultiHeadedAttention(nn.Module):
  """Run ``num_heads`` independent attention heads and mix their outputs.

  The ``d_value``-wide output of each head is concatenated along the last
  axis and projected back to ``d_model`` by ``self.linear``.
  """

  def __init__(self, num_heads: int, d_model: int, d_query: int, d_value: int, drop_rate: float = 0.1):
    super().__init__()
    self.drop_rate = drop_rate
    self.heads = nn.ModuleList()
    for _ in range(num_heads):
      self.heads.append(AttentionHead(d_model, d_query, d_value))
    self.linear = nn.Linear(num_heads * d_value, d_model)

  def forward(self, query, key, value, mask = None):
    per_head = [head(query, key, value, mask, self.drop_rate) for head in self.heads]
    concatenated = torch.cat(per_head, dim = -1)
    return self.linear(concatenated)


class PositionwiseFeedForward(nn.Module):
  """Two-layer MLP applied independently at every position.

  Computes ``w_2(dropout(relu(w_1(x))))``: expand to ``d_ff`` features, apply
  ReLU and dropout, then project back to ``d_model``.
  """

  def __init__(self, d_model: int, d_ff: int, drop_rate: float = 0.1):
    super().__init__()
    self.w_1 = nn.Linear(d_model, d_ff)
    self.w_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(drop_rate)

  def forward(self, x):
    hidden = torch.relu(self.w_1(x))
    hidden = self.dropout(hidden)
    return self.w_2(hidden)


class Embeddings(nn.Module):
  """Token-embedding lookup whose output is scaled by sqrt(d_model)."""

  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.lut = nn.Embedding(vocab_size, d_model)
    self.d_model = d_model

  def forward(self, x):
    scale = math.sqrt(self.d_model)
    return scale * self.lut(x)



In [None]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention followed by a feed-forward network.

  Each sub-layer is wrapped in ``Residual`` (dropout, residual add, LayerNorm).
  """

  def __init__(self, d_model: int, num_heads: int, d_ff: int, drop_rate: float = 0.1):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_query = d_value = d_model // num_heads

    # Pass drop_rate to the sub-layers as well, so they honour the configured
    # rate instead of their own 0.1 default.
    self.self_attn = Residual(
        MultiHeadedAttention(num_heads, d_model, d_query, d_value, drop_rate),
        d_model,
        drop_rate
    )
    self.feed_forward = Residual(
        PositionwiseFeedForward(d_model, d_ff, drop_rate),
        d_model,
        drop_rate
    )

  def forward(self, x, mask):
    """Apply self-attention (keys restricted by ``mask``), then the FFN."""
    x = self.self_attn(x, x, x, mask)
    return self.feed_forward(x)


class Encoder(nn.Module):
  """Token embedding plus sinusoidal positions, then a stack of EncoderLayers."""

  def __init__(self, num_layers: int, d_model: int, num_heads: int, d_ff: int, vocab_size: int, drop_rate: float = 0.1):
    super().__init__()
    self.embedding = Embeddings(d_model, vocab_size)
    self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, drop_rate) for _ in range(num_layers)])

  def forward(self, src, mask):
    """Encode token ids ``src``; ``mask`` is passed to every layer.

    Assumes ``src`` is (batch, seq_len) — TODO confirm — so the output is
    (batch, seq_len, d_model).
    """
    x = self.embedding(src)
    seq_len, d_model = x.size(1), x.size(2)
    # position_encoding builds a CPU float32 tensor; match x's device and
    # dtype so the encoder also works on GPU / in reduced precision.
    pe = position_encoding(d_model, seq_len).to(device = x.device, dtype = x.dtype)
    x = x + pe

    for layer in self.encoder_layers:
      x = layer(x, mask)

    return x


class DecoderLayer(nn.Module):
  """One decoder block: masked self-attention, source attention, then FFN.

  Each sub-layer is wrapped in ``Residual`` (dropout, residual add, LayerNorm).
  """

  def __init__(self, d_model: int, num_heads: int, d_ff: int, drop_rate: float = 0.1):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_query = d_value = d_model // num_heads

    # Pass drop_rate to the sub-layers as well, so they honour the configured
    # rate instead of their own 0.1 default.
    self.self_attn = Residual(
        MultiHeadedAttention(num_heads, d_model, d_query, d_value, drop_rate),
        d_model,
        drop_rate
    )
    self.src_attn = Residual(
        MultiHeadedAttention(num_heads, d_model, d_query, d_value, drop_rate),
        d_model,
        drop_rate
    )
    self.feed_forward = Residual(
        PositionwiseFeedForward(d_model, d_ff, drop_rate),
        d_model,
        drop_rate
    )

  def forward(self, tgt, memory, src_mask, tgt_mask):
    """Attend over ``tgt`` (``tgt_mask``), then over ``memory`` (``src_mask``), then FFN."""
    tgt = self.self_attn(tgt, tgt, tgt, tgt_mask)
    # Queries come from the decoder state; keys and values from ``memory``.
    tgt = self.src_attn(tgt, memory, memory, src_mask)
    return self.feed_forward(tgt)


class Decoder(nn.Module):
  """Embedding plus positions, a stack of DecoderLayers, and a vocabulary head.

  ``forward`` returns a softmax distribution over the target vocabulary.
  """

  def __init__(self, num_layers: int, d_model: int, num_heads: int, d_ff: int, vocab_size: int, drop_rate: float = 0.1):
    super().__init__()
    self.embedding = Embeddings(d_model, vocab_size)
    self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, drop_rate) for _ in range(num_layers)])
    self.linear = nn.Linear(d_model, vocab_size)

  def forward(self, tgt, memory, src_mask, tgt_mask):
    """Decode target ids ``tgt`` against the encoder output ``memory``.

    Returns probabilities (softmax over the last axis, of size ``vocab_size``).
    NOTE(review): losses like NLLLoss / KLDivLoss expect log-probabilities and
    CrossEntropyLoss expects logits — confirm what the training code needs.
    """
    x = self.embedding(tgt)
    seq_len, d_model = x.size(1), x.size(2)
    # position_encoding builds a CPU float32 tensor; match x's device and
    # dtype so the decoder also works on GPU / in reduced precision.
    pe = position_encoding(d_model, seq_len).to(device = x.device, dtype = x.dtype)
    x = x + pe

    for layer in self.decoder_layers:
      x = layer(x, memory, src_mask, tgt_mask)
    return torch.softmax(self.linear(x), dim = -1)



In [None]:
class EncoderDecoder(nn.Module):
  """Full encoder-decoder Transformer.

  The default hyper-parameters match the base model of Vaswani et al. (2017).
  """

  def __init__(
      self,
      src_vocab_size: int,
      tgt_vocab_size: int,
      num_encoder_layers: int = 6,
      num_decoder_layers: int = 6,
      d_model: int = 512,
      num_heads: int = 8,
      d_ff: int = 2048,
      drop_rate: float = 0.1
  ):
    super().__init__()
    self.encoder = Encoder(
        num_encoder_layers,
        d_model,
        num_heads,
        d_ff,
        src_vocab_size,
        drop_rate
    )
    self.decoder = Decoder(
        num_decoder_layers,
        d_model,
        num_heads,
        d_ff,
        tgt_vocab_size,
        drop_rate
    )

  def forward(self, src, tgt, pad):
    """Run the model on source ids ``src`` and target ids ``tgt``.

    ``pad`` is the padding token id used to build both padding masks; the
    target mask additionally hides future positions. Returns the decoder
    output.
    """
    src_mask = padding_mask(src, pad)
    # subsequent_mask is created on the CPU; move it next to tgt so the AND
    # also works for GPU inputs.
    causal = subsequent_mask(tgt.size(1)).to(tgt.device)
    tgt_mask = padding_mask(tgt, pad) & causal
    memory = self.encoder(src, src_mask)
    return self.decoder(tgt, memory, src_mask, tgt_mask)


In [None]:
# Smoke run: a batch of 4 identical sequences, padding id 0, small model.
src = torch.tensor([1, 2, 3, 4, 0]).repeat(4, 1)
tgt = torch.tensor([4, 3, 2, 1, 0]).repeat(4, 1)

out = EncoderDecoder(
    src_vocab_size = 5,
    tgt_vocab_size = 5,
    num_encoder_layers = 2,
    num_decoder_layers = 2,
    d_model = 64,
    num_heads = 8,
    d_ff = 128,
    drop_rate = 0.1,
)(src, tgt, pad = 0)
out