# Decoder 구현

In [None]:
class Embedding(nn.Module):
  """Sum of learned token embeddings and learned absolute position embeddings.

  Args:
    config: object exposing ``num_word`` (number of token ids),
      ``max_position_embeddings`` (number of distinct positions) and
      ``dim_token_emb`` (embedding width).
  """

  def __init__(self,config):
    super().__init__()
    self.token_embedding = nn.Embedding(config.num_word,config.dim_token_emb)
    self.position_emb = nn.Embedding(config.max_position_embeddings,config.dim_token_emb)

  def forward(self,input_idx):
    """Embed integer token ids and add position embeddings.

    Args:
      input_idx: integer token ids whose LAST dimension is the sequence axis,
        e.g. ``(batch, seq_len)`` or unbatched ``(seq_len,)``.

    Returns:
      Tensor of shape ``input_idx.shape + (dim_token_emb,)``.
    """
    word_emb = self.token_embedding(input_idx)
    # The sequence axis is the last one. For batched input shape[0] is the
    # batch size, which would give position embeddings of the wrong length.
    seq_len = input_idx.shape[-1]
    # Build the positions on the ids' device so that GPU inputs also work.
    position_idx = torch.arange(seq_len, dtype=torch.long, device=input_idx.device)
    # (seq_len, dim) broadcasts over any leading batch dimensions.
    pos_emb = self.position_emb(position_idx)
    return word_emb + pos_emb

In [None]:
def scaled_dot_product_attention(query,key,value,mask=None):
  """Compute ``softmax(query @ key^T / sqrt(d_k)) @ value``.

  Args:
    query: tensor of shape ``(..., seq_q, d_k)``.
    key: tensor of shape ``(..., seq_k, d_k)``.
    value: tensor of shape ``(..., seq_k, d_v)``.
    mask: optional tensor broadcastable to ``(..., seq_q, seq_k)``. Positions
      where it equals 0 get a score of -inf and receive no attention weight.

  Returns:
    Tensor of shape ``(..., seq_q, d_v)``.

  Note:
    A query row whose mask is entirely 0 produces NaN, because softmax is
    taken over a row that is all -inf.
  """
  dim_k = key.shape[-1]
  # `@` and transpose(-2, -1) handle any number of leading batch dims;
  # bmm/transpose(1, 2) only accept exactly 3-D tensors.
  scores = (query @ key.transpose(-2, -1)) / math.sqrt(dim_k)
  if mask is not None:
    scores = scores.masked_fill(mask == 0, float('-inf'))
  weights = F.softmax(scores, dim=-1)
  return weights @ value

In [None]:
class AttentionHead(nn.Module):
  """A single attention head: linear query/key/value projections plus attention.

  Args:
    embed_dim: feature width of the incoming hidden state.
    head_dim: feature width of each projected query, key and value.
  """

  def __init__(self,embed_dim,head_dim):
    super().__init__()
    # Projections are created in q, k, v order (keeps init and state_dict order).
    self.q = nn.Linear(embed_dim,head_dim)
    self.k = nn.Linear(embed_dim,head_dim)
    self.v = nn.Linear(embed_dim,head_dim)

  def forward(self,hidden_state,mask=None):
    """Project ``hidden_state`` and apply scaled dot-product self-attention.

    ``mask`` is forwarded unchanged to ``scaled_dot_product_attention``.
    """
    projected_q = self.q(hidden_state)
    projected_k = self.k(hidden_state)
    projected_v = self.v(hidden_state)
    return scaled_dot_product_attention(projected_q, projected_k, projected_v, mask=mask)

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head self-attention built from independent ``AttentionHead`` modules.

  Each head projects the ``hidden_size``-wide input to
  ``hidden_size // num_heads`` features. The head outputs are concatenated back
  to ``hidden_size`` features and mixed by a final linear layer.

  Args:
    config: object exposing ``hidden_size`` and ``num_heads``.

  Raises:
    ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
  """

  def __init__(self,config):
    super().__init__()
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_heads
    # The split must be exact. Otherwise the concatenated heads are narrower
    # than embed_dim and the output projection fails at forward time.
    if self.embed_dim % self.num_heads != 0:
      raise ValueError(
          f"hidden_size ({self.embed_dim}) must be divisible by "
          f"num_heads ({self.num_heads})")
    self.head_dim = self.embed_dim // self.num_heads
    self.heads = nn.ModuleList(
        [AttentionHead(self.embed_dim,self.head_dim) for _ in range(self.num_heads)]
    )
    self.linear = nn.Linear(self.embed_dim,self.embed_dim)

  def forward(self,hidden_state,mask=None):
    """Run every head on ``hidden_state`` (same ``mask`` for all) and project.

    Returns:
      Tensor with ``hidden_size`` features in the last dimension.
    """
    output = torch.cat([head(hidden_state, mask) for head in self.heads], dim=-1)
    return self.linear(output)

In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> GELU -> Linear.

  Args:
    config: object exposing ``hidden_size`` (input/output width) and
      ``middle_size`` (inner width).
  """

  def __init__(self,config):
    super().__init__()

    self.linear_1 = nn.Linear(config.hidden_size,config.middle_size)
    self.linear_2 = nn.Linear(config.middle_size,config.hidden_size)
    self.act = nn.GELU()

  def forward(self,hidden_state):
    """Apply the network independently at every position of ``hidden_state``."""
    x = self.linear_1(hidden_state)
    # The non-linearity must sit between the two projections. Applied after
    # linear_2, the two linears would collapse into a single affine map.
    x = self.act(x)
    return self.linear_2(x)

In [None]:
class TransformerDecoderLayer(nn.Module):
  """Decoder block: embedding followed by three residual sub-layers.

  Each sub-layer is wrapped as ``LayerNorm(sublayer(x) + x)`` (post-norm):
    1. masked multi-head self-attention,
    2. a second masked multi-head self-attention (this layer receives no
       encoder output, so it cannot do encoder-decoder cross-attention),
    3. position-wise feed-forward network.

  NOTE(review): the layer embeds raw token ids itself, so it cannot be stacked
  on top of another layer's hidden states as written.

  Args:
    config: object exposing the fields read by ``Embedding``,
      ``MultiHeadAttention`` and ``FeedForward``, plus ``hidden_size``.
      Assumes ``dim_token_emb == hidden_size`` so the residual sums line up;
      TODO confirm.
  """

  def __init__(self,config):
    super().__init__()
    self.embedding = Embedding(config)
    self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
    self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
    # LayerNorm init is deterministic, so adding it here does not shift the
    # random initialisation of the modules created after it.
    self.layer_norm_3 = nn.LayerNorm(config.hidden_size)
    self.masked_attn = MultiHeadAttention(config)
    self.attn = MultiHeadAttention(config)
    self.ff = FeedForward(config)

  def forward(self,input_idx,mask):
    """Run the decoder block on integer token ids.

    Args:
      input_idx: integer token ids, passed to ``Embedding``.
      mask: attention mask applied in both attention sub-layers. Entries equal
        to 0 are blocked, e.g. a lower-triangular causal mask.

    Returns:
      Hidden states with ``hidden_size`` features per position.
    """
    hidden_state = self.embedding(input_idx)
    masked_attn_output = self.layer_norm_1(
        self.masked_attn(hidden_state,mask) + hidden_state)
    # The mask must be applied here too. Without it, every position would
    # attend to later positions and leak future tokens.
    attn_output = self.layer_norm_2(
        self.attn(masked_attn_output,mask) + masked_attn_output)
    # Residual + LayerNorm around the feed-forward sub-layer, like the others.
    return self.layer_norm_3(self.ff(attn_output) + attn_output)

In [None]:
def generate_mask(input_idx):
  """Build a causal (lower-triangular) attention mask for ``input_idx``.

  Args:
    input_idx: token ids whose last dimension is the sequence axis,
      e.g. ``(batch, seq_len)``.

  Returns:
    ``(seq_len, seq_len)`` long tensor on ``input_idx.device``. Entry (i, j) is
    1 when j <= i (query i may attend to key j) and 0 otherwise. It broadcasts
    over the batch dimension of the attention scores.
  """
  seq_len = input_idx.shape[-1]
  # torch.tril operates on a tensor, not an int: start from an all-ones square.
  ones = torch.ones(seq_len, seq_len, dtype=torch.long, device=input_idx.device)
  return torch.tril(ones)