In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
pip install sentencepiece

Note: you may need to restart the kernel to use updated packages.


In [3]:
import sentencepiece as spm

# Trained SentencePiece model files (German / English).
de_vocab_file = '../vocab/de.model'
en_vocab_file = '../vocab/en.model'

# One tokenizer per language.
de_vocab, en_vocab = spm.SentencePieceProcessor(), spm.SentencePieceProcessor()

# Load the German model, then the English one. The cell displays the return
# value of the final load() call (True in the saved output).
de_vocab.load(de_vocab_file)
en_vocab.load(en_vocab_file)

True

In [4]:
import pandas as pd

# Training corpus; later cells read its 'en' and 'de' text columns.
train_df = pd.read_csv(filepath_or_buffer='../dataset/train.csv')

In [5]:
# data.py

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

# mt Dataset
class MtDataset(Dataset):
  """Parallel-corpus dataset of (source ids, target ids) pairs.

  Every row of ``df`` whose ``src_name`` and ``trg_name`` cells are both
  strings is tokenized once, up front. Rows with a non-string cell (e.g. NaN
  for a missing value in the CSV) are skipped.

  Args:
      src_vocab: source-side tokenizer exposing ``encode_as_ids(str) -> list[int]``.
      trg_vocab: target-side tokenizer with the same interface.
      df: DataFrame holding the parallel sentences.
      src_name: column of ``df`` with the source sentences.
      trg_name: column of ``df`` with the target sentences.
  """
  def __init__(self, src_vocab, trg_vocab, df, src_name, trg_name):
    self.src_vocab = src_vocab
    self.trg_vocab = trg_vocab
    self.src_train = []
    self.trg_train = []

    # Walk the two columns directly; df.iterrows() builds a Series per row
    # and is much slower on large corpora.
    for src_line, trg_line in zip(df[src_name], df[trg_name]):
      if not (isinstance(src_line, str) and isinstance(trg_line, str)):
        continue
      # Tokenize the source and target sentences separately.
      self.src_train.append(src_vocab.encode_as_ids(src_line))
      self.trg_train.append(trg_vocab.encode_as_ids(trg_line))

  def __len__(self):
    assert len(self.src_train) == len(self.trg_train)
    return len(self.src_train)

  def __getitem__(self, idx):
    """Return ``(src_ids, trg_ids)`` as 1-D int64 tensors."""
    # dtype is explicit so an empty id list (e.g. an empty-string sentence)
    # still yields an integer tensor: torch.tensor([]) defaults to float32,
    # which would turn the whole padded batch into floats and break the
    # embedding lookup.
    return (torch.tensor(self.src_train[idx], dtype=torch.long),
            torch.tensor(self.trg_train[idx], dtype=torch.long))


# mt data collate_fn
# Processes the data one batch at a time (passed to the DataLoader as collate_fn).
def mt_collate_fn(inputs):
  """Collate ``(src_ids, trg_ids)`` pairs into a padded batch.

  Sentences have different lengths, so each side is right-padded with 0 (the
  default padding value of ``pad_sequence``) up to the longest sequence on
  that side of the batch.

  Args:
      inputs: sequence of ``(src_tensor, trg_tensor)`` pairs of 1-D tensors.

  Returns:
      ``[enc_inputs, dec_inputs]``: a two-element list of
      ``(batch_size, max_len)`` tensors. This is what iterating the DataLoader yields.
  """
  # Transpose [(src, trg), ...] into (src, src, ...) and (trg, trg, ...).
  src_seqs, trg_seqs = zip(*inputs)
  pad = torch.nn.utils.rnn.pad_sequence
  return [pad(src_seqs, batch_first=True), pad(trg_seqs, batch_first=True)]


# DataLoader
def build_mt_data_loader(src_vocab, trg_vocab, df, src_name, trg_name, args, shuffle=True):
  """Build the machine-translation DataLoader (and its sampler, if any).

  Args:
      src_vocab, trg_vocab: tokenizers forwarded to ``MtDataset``.
      df, src_name, trg_name: corpus DataFrame and its source/target columns.
      args: dict with at least ``'n_gpu'`` and ``'batch_size'``.
      shuffle: whether to shuffle. With more than one GPU, shuffling is done
          by a ``DistributedSampler``; otherwise by the DataLoader itself.

  Returns:
      ``(loader, sampler)``. ``sampler`` is the ``DistributedSampler`` in the
      multi-GPU + shuffle case and ``None`` otherwise.
  """
  dataset = MtDataset(src_vocab, trg_vocab, df, src_name, trg_name)

  # NOTE(review): with n_gpu > 1 and shuffle=False no DistributedSampler is
  # used, so every process would iterate the full dataset. Confirm that is intended.
  distributed = args['n_gpu'] > 1 and shuffle
  sampler = DistributedSampler(dataset) if distributed else None

  loader_kwargs = {'batch_size': args['batch_size'], 'sampler': sampler, 'collate_fn': mt_collate_fn}
  if not distributed:
    # DataLoader rejects shuffle=True together with a sampler, so shuffle is
    # only forwarded when no sampler is in use.
    loader_kwargs['shuffle'] = shuffle
  loader = DataLoader(dataset, **loader_kwargs)

  return loader, sampler

In [6]:
# Hyper-parameters for the Transformer and the data loader.
# Vocabulary sizes are read from the loaded SentencePiece models, so the
# embedding tables always cover every id the tokenizers can emit. The loader
# below uses English as the source (encoder) and German as the target (decoder).
# NOTE(review): i_pad must equal the value pad_sequence pads with (0). It must
# also not be a real token id. SentencePiece assigns id 0 to <unk> unless the
# model was trained with pad_id=0. Confirm against the vocab models.
tmp_config = {
    "n_gpu": 1, # tmp
    "n_layer": 6,
    "batch_size": 256,
    "n_enc_vocab": en_vocab.get_piece_size(), # source (encoder) side: English
    "n_dec_vocab": de_vocab.get_piece_size(), # target (decoder) side: German
    "n_enc_seq": 80, # tmp; re-sized from the data before the model is built
    "n_dec_seq": 80, # tmp; re-sized from the data before the model is built
    "d_model": 512,
    "d_ff": 2048,
    "h": 8,
    "d_h": 64,
    "dropout": 0.1,
    "layer_norm_epsilon": 1e-12,
    "i_pad": 0,
}

In [7]:
# Only the loader-related settings are forwarded to the DataLoader builder.
args = {name: tmp_config[name] for name in ('n_gpu', 'batch_size')}

# English is the source (encoder) language, German the target (decoder) one.
loader, sampler = build_mt_data_loader(
    en_vocab, de_vocab, train_df, src_name='en', trg_name='de', args=args
)

In [8]:
# Sinusoidal position representations
def get_sinusoidal(n_seq, d_model):
  '''Sinusoidal position-encoding table.

  Entry ``[pos, i]`` is ``sin(angle)`` for even ``i`` and ``cos(angle)`` for
  odd ``i``, where ``angle = pos / 10000 ** (2 * (i // 2) / d_model)``.

  Args:
      n_seq: number of positions (rows), i.e. tokens per sentence.
      d_model: encoding width (columns), e.g. 512.

  Returns:
      float64 ``np.ndarray`` of shape ``(n_seq, d_model)``. ``n_seq == 0``
      yields an empty ``(0, d_model)`` table instead of raising.
  '''
  # Build the whole angle grid by broadcasting instead of a Python double
  # loop: (n_seq, 1) / (d_model,) -> (n_seq, d_model). This also keeps the
  # table 2-D when n_seq == 0, so the column slices below stay valid.
  positions = np.arange(n_seq)[:, np.newaxis]
  dims = np.arange(d_model)
  angle_rates = np.power(10000, 2 * (dims // 2) / d_model)
  pos_enc_table = positions / angle_rates

  pos_enc_table[:, 0::2] = np.sin(pos_enc_table[:, 0::2]) # even idx
  pos_enc_table[:, 1::2] = np.cos(pos_enc_table[:, 1::2]) # odd idx
  return pos_enc_table

In [9]:
class FFN(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.

  Each linear map is a ``Conv1d`` with ``kernel_size=1``. It applies the same
  ``d_model -> d_ff -> d_model`` projection independently at every sequence
  position.

  Config keys read: ``d_model``, ``d_ff``, ``dropout``.
  """
  def __init__(self, config):
    super().__init__()
    self.config = config
    d_model, d_ff = config["d_model"], config["d_ff"]

    self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
    self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
    self.active = F.relu
    self.dropout = nn.Dropout(config["dropout"])

  def forward(self, inputs):
    """Map ``(batch_size, n_seq, d_model)`` to ``(batch_size, n_seq, d_model)``."""
    # Conv1d expects channels in dim 1: (B, n_seq, d_model) -> (B, d_model, n_seq).
    channels_first = inputs.transpose(1, 2)
    hidden = self.active(self.conv1(channels_first))   # (B, d_ff, n_seq)
    projected = self.conv2(hidden)                     # (B, d_model, n_seq)
    # Back to (B, n_seq, d_model) before dropout.
    return self.dropout(projected.transpose(1, 2))

In [10]:
# attention pad mask
def get_attn_pad_mask(query, key, i_pad):
  '''Mask of the key positions that hold the padding id.

  Args:
      query: query token ids, (batch_size, len_q), padded to a common length.
      key: key token ids, (batch_size, len_k), padded to a common length.
      i_pad: padding id (0 in this notebook).

  Returns:
      Bool tensor (batch_size, len_q, len_k), True where ``key`` is padding.
      Every query row is identical; the result is an expanded view.
  '''
  _, len_q = query.size()
  batch_size, len_k = key.size()
  # (batch_size, len_k) -> (batch_size, 1, len_k) -> (batch_size, len_q, len_k)
  key_is_pad = key.eq(i_pad)
  return key_is_pad.unsqueeze(1).expand(batch_size, len_q, len_k)


# attention decoder mask
def get_attn_decoder_mask(seq):
  '''Causal ("no looking ahead") mask for decoder self-attention.

  Args:
      seq: token ids, (batch_size, len_seq).

  Returns:
      Tensor with ``seq``'s dtype, shape (batch_size, len_seq, len_seq). It is
      1 strictly above the diagonal (future positions) and 0 elsewhere.
  '''
  n_batch, n_tok = seq.size(0), seq.size(1)
  all_ones = torch.ones_like(seq).unsqueeze(-1).expand(n_batch, n_tok, n_tok)
  # Keep only the strict upper triangle: query i may not see key j > i.
  return torch.triu(all_ones, diagonal=1)

In [11]:
class ScaledDotProductAttention(nn.Module):
  """softmax(Q K^T / sqrt(d_h)) V, with masking and dropout on the weights.

  Config keys read: ``dropout``, ``d_h``.
  """
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.dropout = nn.Dropout(self.config["dropout"])
    self.scale = 1 / (self.config["d_h"] ** 0.5)

  def forward(self, Q, K, V, attn_mask):
    '''
    Args:
        Q: (batch_size, h, len_q, d_h)
        K: (batch_size, h, len_k, d_h)
        V: (batch_size, h, len_k, d_h)
        attn_mask: bool (batch_size, h, len_q, len_k). True marks positions
            that must not be attended to. ``None`` disables masking.

    Returns:
        (output (batch_size, h, len_q, d_h),
         attn_weights (batch_size, h, len_q, len_k))
    '''
    # (batch_size, h, len_q, len_k)
    affinities = torch.matmul(Q, K.transpose(-1, -2)).mul_(self.scale)
    if attn_mask is not None:
      # -1e9 acts as "minus infinity". Clamp it to the score dtype's range:
      # filling float16 scores (min -65504) with -1e9 raises an overflow error.
      mask_value = max(-1e9, torch.finfo(affinities.dtype).min)
      affinities.masked_fill_(attn_mask, mask_value)
    # (batch_size, h, len_q, len_k)
    attn_weights = torch.softmax(affinities, dim=-1)
    attn_weights = self.dropout(attn_weights)
    # (batch_size, h, len_q, d_h)
    output = torch.matmul(attn_weights, V)
    return output, attn_weights

In [12]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention.

  Projects Q/K/V into ``h`` heads of width ``d_h`` and attends per head.
  The heads are then concatenated and projected back to ``d_model``.

  Config keys read: ``d_model``, ``h``, ``d_h``, ``dropout`` (plus those
  read by ``ScaledDotProductAttention``).
  """
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.W_Q = nn.Linear(self.config['d_model'], self.config['h'] * self.config['d_h'])
    self.W_K = nn.Linear(self.config['d_model'], self.config['h'] * self.config['d_h'])
    self.W_V = nn.Linear(self.config['d_model'], self.config['h'] * self.config['d_h'])
    self.scaled_dot_attn = ScaledDotProductAttention(self.config)
    self.linear = nn.Linear(self.config['h'] * self.config['d_h'], self.config['d_model'])
    self.dropout = nn.Dropout(self.config['dropout'])

  def _split_heads(self, x, batch_size):
    """(batch_size, len, h * d_h) -> (batch_size, h, len, d_h)."""
    return x.view(batch_size, -1, self.config['h'], self.config['d_h']).transpose(1, 2)

  def forward(self, Q, K, V, attn_mask):
    '''
    Args:
        Q: (batch_size, len_q, d_model)
        K: (batch_size, len_k, d_model)
        V: (batch_size, len_k, d_model)
        attn_mask: bool (batch_size, len_q, len_k); True = do not attend.

    Returns:
        (output (batch_size, len_q, d_model),
         attn_weights (batch_size, h, len_q, len_k))
    '''
    # Take the batch size from the actual input, not config['batch_size'].
    # The last batch of an epoch, or any inference batch, can be smaller.
    # Viewing with the configured size would then fail, or silently mix
    # sequences when the element count happens to divide evenly.
    batch_size = Q.size(0)
    # Linear projections, then split into heads: -> (batch_size, h, len, d_h)
    pjted_Q = self._split_heads(self.W_Q(Q), batch_size)
    pjted_K = self._split_heads(self.W_K(K), batch_size)
    pjted_V = self._split_heads(self.W_V(V), batch_size)
    # (batch_size, len_q, len_k) -> (batch_size, h, len_q, len_k)
    attn_mask = attn_mask.unsqueeze(1).repeat(1, self.config['h'], 1, 1)
    # (batch_size, h, len_q, d_h), (batch_size, h, len_q, len_k)
    context, attn_weights = self.scaled_dot_attn(pjted_Q, pjted_K, pjted_V, attn_mask)
    # Concatenate heads: (batch_size, h, len_q, d_h) -> (batch_size, len_q, h * d_h)
    context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config['h'] * self.config['d_h'])
    # (batch_size, len_q, h * d_h) -> (batch_size, len_q, d_model)
    output = self.dropout(self.linear(context))
    return output, attn_weights

In [13]:
# encoder layer
class EncoderLayer(nn.Module):
  """One Transformer encoder layer with post-LayerNorm residual sub-layers.

      x = LayerNorm(x + SelfAttention(x))
      x = LayerNorm(x + FFN(x))

  Config keys read: ``d_model``, ``layer_norm_epsilon`` (plus those read by
  ``MultiHeadAttention`` and ``FFN``).
  """
  def __init__(self, config):
    super().__init__()
    self.config = config
    d_model = self.config["d_model"]
    eps = self.config["layer_norm_epsilon"]

    self.self_attn = MultiHeadAttention(self.config)
    self.layer_norm1 = nn.LayerNorm(d_model, eps=eps)
    self.ffn = FFN(self.config)
    self.layer_norm2 = nn.LayerNorm(d_model, eps=eps)

  def forward(self, inputs, attn_mask):
    '''
    Args:
        inputs: (batch_size, len_seq, d_model)
        attn_mask: (batch_size, len_q, len_k); True = masked.

    Returns:
        (outputs (batch_size, len_seq, d_model),
         self-attention weights (batch_size, h, len_q, len_k))
    '''
    # Sub-layer 1: self-attention, residual connection, LayerNorm.
    attended, attn_weights = self.self_attn(inputs, inputs, inputs, attn_mask)
    hidden = self.layer_norm1(inputs + attended)
    # Sub-layer 2: position-wise FFN, residual connection, LayerNorm.
    out = self.layer_norm2(self.ffn(hidden) + hidden)
    return out, attn_weights

In [14]:
# encoder
class Encoder(nn.Module):
  """Transformer encoder.

  Adds a token embedding and a frozen sinusoidal position embedding, then
  runs ``n_layer`` ``EncoderLayer``s.

  Config keys read: ``n_enc_vocab``, ``n_enc_seq``, ``d_model``, ``n_layer``,
  ``i_pad`` (plus those read by ``EncoderLayer``).
  """
  def __init__(self, config):
    super().__init__()
    self.config = config

    self.enc_emb = nn.Embedding(self.config["n_enc_vocab"], self.config["d_model"])
    # n_enc_seq + 1 rows: pad tokens use position 0, real tokens 1..len.
    sinusoid = get_sinusoidal(self.config["n_enc_seq"] + 1, self.config["d_model"])
    self.pos_emb = nn.Embedding.from_pretrained(torch.FloatTensor(sinusoid), freeze=True)

    self.layers = nn.ModuleList(
        [EncoderLayer(self.config) for _ in range(self.config["n_layer"])]
    )

  def forward(self, inputs):
    '''
    Args:
        inputs: source token ids, (batch_size, len_enc_seq).

    Returns:
        (outputs (batch_size, len_enc_seq, d_model),
         list with one self-attention weight tensor
         (batch_size, h, len_enc_seq, len_enc_seq) per layer)
    '''
    batch_size, seq_len = inputs.size(0), inputs.size(1)
    # Positions 1..seq_len for every sentence; pad tokens get position 0.
    positions = torch.arange(seq_len, device=inputs.device, dtype=inputs.dtype).expand(batch_size, seq_len).contiguous() + 1
    positions.masked_fill_(inputs.eq(self.config["i_pad"]), 0)

    # (batch_size, len_enc_seq, d_model)
    output = self.enc_emb(inputs) + self.pos_emb(positions)
    # (batch_size, len_enc_seq, len_enc_seq)
    attn_mask = get_attn_pad_mask(inputs, inputs, self.config["i_pad"])

    attn_weights_history = []
    for layer in self.layers:
      output, attn_weights = layer(output, attn_mask)
      attn_weights_history.append(attn_weights)
    return output, attn_weights_history

In [15]:
class DecoderLayer(nn.Module):
  """One Transformer decoder layer with post-LayerNorm residual sub-layers.

      x = LayerNorm(x + MaskedSelfAttention(x))
      x = LayerNorm(x + EncoderDecoderAttention(x, enc_outputs))
      x = LayerNorm(x + FFN(x))

  Config keys read: ``d_model``, ``layer_norm_epsilon`` (plus those read by
  ``MultiHeadAttention`` and ``FFN``).
  """
  def __init__(self, config):
    super().__init__()
    self.config = config
    d_model = self.config["d_model"]
    eps = self.config["layer_norm_epsilon"]

    self.self_attn = MultiHeadAttention(self.config)
    self.layer_norm1 = nn.LayerNorm(d_model, eps=eps)
    self.dec_enc_attn = MultiHeadAttention(self.config)
    self.layer_norm2 = nn.LayerNorm(d_model, eps=eps)
    self.ffn = FFN(self.config)
    self.layer_norm3 = nn.LayerNorm(d_model, eps=eps)

  def forward(self, dec_inputs, enc_outputs, self_attn_mask, dec_enc_attn_mask):
    '''
    Args:
        dec_inputs: (batch_size, len_dec_seq, d_model)
        enc_outputs: (batch_size, len_enc_seq, d_model)
        self_attn_mask: (batch_size, len_dec_seq, len_dec_seq); True = masked.
        dec_enc_attn_mask: (batch_size, len_dec_seq, len_enc_seq); True = masked.

    Returns:
        (outputs (batch_size, len_dec_seq, d_model),
         self-attention weights (batch_size, h, len_dec_seq, len_dec_seq),
         enc-dec attention weights (batch_size, h, len_dec_seq, len_enc_seq))
    '''
    # 1) Masked self-attention over the decoder sequence.
    attended, self_attn_weights = self.self_attn(dec_inputs, dec_inputs, dec_inputs, self_attn_mask)
    x = self.layer_norm1(dec_inputs + attended)
    # 2) Decoder positions (queries) attend to encoder outputs (keys/values).
    cross, dec_enc_attn_weights = self.dec_enc_attn(x, enc_outputs, enc_outputs, dec_enc_attn_mask)
    x = self.layer_norm2(x + cross)
    # 3) Position-wise FFN.
    x = self.layer_norm3(x + self.ffn(x))
    return x, self_attn_weights, dec_enc_attn_weights

In [17]:
class Decoder(nn.Module):
  """Transformer decoder.

  Adds a token embedding and a frozen sinusoidal position embedding, then
  runs a stack of ``n_layer`` ``DecoderLayer``s.

  Config keys read: ``n_dec_vocab``, ``n_dec_seq``, ``d_model``, ``n_layer``,
  ``i_pad`` (plus those read by ``DecoderLayer``).
  """
  def __init__(self, config):
    super().__init__()
    self.config = config

    self.dec_emb = nn.Embedding(self.config["n_dec_vocab"], self.config["d_model"])
    # n_dec_seq + 1 rows: pad tokens use position 0, real tokens 1..len.
    pos_enc_table = torch.FloatTensor(get_sinusoidal(self.config["n_dec_seq"] + 1, self.config["d_model"]))
    self.pos_emb = nn.Embedding.from_pretrained(pos_enc_table, freeze=True)

    self.layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(self.config["n_layer"])])

  def forward(self, dec_inputs, enc_inputs, enc_outputs):
    '''
    Args:
        dec_inputs: target token ids, (batch_size, len_dec_seq).
        enc_inputs: source token ids, (batch_size, len_enc_seq). Only used
            to build the padding mask for encoder-decoder attention.
        enc_outputs: encoder output, (batch_size, len_enc_seq, d_model).

    Returns:
        (outputs (batch_size, len_dec_seq, d_model),
         per-layer self-attention weights [(batch_size, h, len_dec_seq, len_dec_seq)],
         per-layer enc-dec attention weights [(batch_size, h, len_dec_seq, len_enc_seq)])
        With n_layer == 0 the outputs are the summed embeddings.
    '''
    # (batch_size, len_dec_seq): positions 1..len, pad tokens -> 0
    positions = torch.arange(dec_inputs.size(1), device=dec_inputs.device, dtype=dec_inputs.dtype).expand(dec_inputs.size(0), dec_inputs.size(1)).contiguous() + 1
    pos_mask = dec_inputs.eq(self.config["i_pad"])
    positions.masked_fill_(pos_mask, 0)

    # (batch_size, len_dec_seq, d_model)
    dec_output = self.dec_emb(dec_inputs) + self.pos_emb(positions)

    # (batch_size, len_dec_seq, len_dec_seq)
    attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.config["i_pad"])
    # (batch_size, len_dec_seq, len_dec_seq)
    attn_decoder_mask = get_attn_decoder_mask(dec_inputs)
    # A key position is masked if it is padding OR lies in the future.
    self_attn_mask = torch.gt((attn_pad_mask + attn_decoder_mask), 0)
    # (batch_size, len_dec_seq, len_enc_seq)
    dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.config["i_pad"])

    self_attn_weights_history, dec_enc_attn_weights_history = [], []
    for layer in self.layers:
      # Each layer consumes the previous layer's output, so the layers form a stack.
      dec_output, self_attn_weights, dec_enc_attn_weights = layer(dec_output, enc_outputs, self_attn_mask, dec_enc_attn_mask)
      self_attn_weights_history.append(self_attn_weights)
      dec_enc_attn_weights_history.append(dec_enc_attn_weights)
    return dec_output, self_attn_weights_history, dec_enc_attn_weights_history

In [18]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer that returns decoder hidden states.

  No projection to the target vocabulary is applied here. The first output
  has shape (batch_size, len_dec_seq, d_model).
  """
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.encoder = Encoder(self.config)
    self.decoder = Decoder(self.config)

  def forward(self, enc_inputs, dec_inputs):
    '''
    Args:
        enc_inputs: source token ids, (batch_size, len_enc_seq).
        dec_inputs: target token ids, (batch_size, len_dec_seq).

    Returns:
        (dec_outputs (batch_size, len_dec_seq, d_model),
         enc self-attn weights [(batch_size, h, len_enc_seq, len_enc_seq)],
         dec self-attn weights [(batch_size, h, len_dec_seq, len_dec_seq)],
         enc-dec attn weights  [(batch_size, h, len_dec_seq, len_enc_seq)])
    '''
    enc_outputs, enc_self_attn_weights_history = self.encoder(enc_inputs)
    decoded = self.decoder(dec_inputs, enc_inputs, enc_outputs)
    dec_outputs, dec_self_attn_weights_history, dec_enc_attn_weights_history = decoded
    return (dec_outputs, enc_self_attn_weights_history,
            dec_self_attn_weights_history, dec_enc_attn_weights_history)
        

In [19]:
def returnExampleBatch(data_loader=None):
  """Return the first ``(enc_inputs, dec_inputs)`` batch of a loader.

  Args:
      data_loader: any iterable of two-element batches. Defaults to the
          notebook-global ``loader``.

  Returns:
      ``(enc, dec)`` tuple taken from the first batch.

  Raises:
      ValueError: if the loader yields no batches.
  """
  if data_loader is None:
    data_loader = loader
  first_batch = next(iter(data_loader), None)
  if first_batch is None:
    # Fail loudly instead of returning None, which would break the caller's unpacking.
    raise ValueError("data loader is empty; cannot take an example batch")
  enc, dec = first_batch
  return enc, dec

# One batch to sanity-check shapes. q holds the encoder (source) ids and k the
# decoder (target) ids; they are not attention queries/keys.
(q, k) = returnExampleBatch()

In [21]:
# Size the positional-encoding tables to the longest sentence in the whole
# dataset, not just this example batch. pos_emb has n_*_seq + 1 rows, so a
# later batch with a longer sentence would otherwise index past the table.
tmp_config["n_enc_seq"] = max(map(len, loader.dataset.src_train), default=q.size(1))
tmp_config["n_dec_seq"] = max(map(len, loader.dataset.trg_train), default=k.size(1))

In [22]:
# Build the full encoder-decoder model from tmp_config.
transformer = Transformer(config=tmp_config)

In [23]:
# Only outputs and shapes are inspected below and no backward pass is run, so
# skip autograd bookkeeping. Otherwise every intermediate activation, which is
# referenced by the outputs and the attention-weight histories, stays alive.
with torch.no_grad():
  dec_outputs, enc_self_attn_weights_history, dec_self_attn_weights_history, dec_enc_attn_weights_history = transformer(q, k)

In [24]:
# Decoder hidden states for the example batch. Values differ between runs:
# no seed is set in this notebook, and the model is never switched to eval(),
# so dropout is active.
dec_outputs

tensor([[[-0.6684, -0.1634, -0.4582,  ...,  0.0436, -0.3301,  0.2300],
         [-0.3527, -0.8143,  0.5640,  ..., -1.5524, -2.7957,  1.1874],
         [ 0.4643, -1.4126, -1.0146,  ...,  2.4775, -0.8689,  1.4743],
         ...,
         [ 1.1099,  0.8237, -1.3496,  ..., -0.4410, -0.5183,  1.2187],
         [ 1.1340,  0.8649, -1.3344,  ..., -0.7411, -0.5506,  1.1182],
         [ 1.1284,  0.8886, -1.3049,  ..., -0.3644, -0.5333,  1.1483]],

        [[-0.6268, -0.5419, -0.8031,  ...,  0.1300, -0.1310,  0.5507],
         [-0.5160, -0.5924, -2.3623,  ...,  0.4685, -1.3419,  1.6644],
         [-1.1154, -0.5856, -0.5961,  ...,  2.3031,  1.1585,  1.1080],
         ...,
         [ 1.1860,  0.5093, -1.0463,  ..., -0.0348, -0.4950,  1.4005],
         [ 1.1327,  0.5200, -1.2506,  ..., -0.0713, -0.5912,  1.1135],
         [ 1.1049,  0.6109, -0.8801,  ..., -0.4539, -0.6391,  1.1499]],

        [[ 0.3794,  0.9222, -1.7307,  ..., -0.7747, -0.1287,  0.4227],
         [-0.4887, -0.1143, -2.4089,  ...,  0

In [25]:
# (batch_size, len_dec_seq, d_model)
dec_outputs.shape

torch.Size([256, 136, 512])

In [27]:
# First encoder layer: (batch_size, h, len_enc_seq, len_enc_seq)
enc_self_attn_weights_history[0].shape

torch.Size([256, 8, 111, 111])

In [28]:
# First decoder layer, self-attention: (batch_size, h, len_dec_seq, len_dec_seq)
dec_self_attn_weights_history[0].shape

torch.Size([256, 8, 136, 136])

In [29]:
# First decoder layer, enc-dec attention: (batch_size, h, len_dec_seq, len_enc_seq)
dec_enc_attn_weights_history[0].shape

torch.Size([256, 8, 136, 111])