In [81]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [82]:
import json

class Config(dict):
    """Dictionary whose entries can also be read and written as attributes.

    ``cfg.key`` is equivalent to ``cfg["key"]`` for both reading and writing,
    so values assigned through attribute syntax show up in the mapping itself
    (and therefore in ``print(cfg)`` and ``json.dumps(cfg)``).
    """

    # NOTE: attribute reads are delegated straight to dict.__getitem__, so a
    # missing key surfaces as KeyError rather than AttributeError.
    __getattr__ = dict.__getitem__
    # Attribute writes must land in the mapping; dict.__setattr__ would store
    # them in the instance __dict__, invisible to item access and printing.
    __setattr__ = dict.__setitem__

    @classmethod
    def load(cls, file):
        """Read a JSON object from the path ``file`` and wrap it in ``cls``."""
        with open(file, 'r') as f:
            return cls(json.load(f))

In [83]:
# Path of the SentencePiece model file that is loaded into `vocab` below.
vocab_file = "kowiki.model"
vocab = spm.SentencePieceProcessor()
# Last expression of the cell, so its return value is echoed as the cell output.
vocab.load(vocab_file)

True

In [84]:
# Hyper-parameters read by the model classes defined in the following cells.
config = Config(dict(
    n_dec_vocab=len(vocab),    # embedding / output vocabulary size
    n_dec_seq=256,             # position table is built with n_dec_seq + 1 rows
    n_layer=6,                 # number of stacked decoder layers
    d_hidn=256,                # embedding / hidden width
    i_pad=0,                   # token id treated as padding by the masks
    d_ff=1024,                 # inner width of the position-wise feed-forward net
    n_head=4,                  # attention heads
    d_head=64,                 # width of each attention head
    dropout=0.1,
    layer_norm_epsilon=1e-12,
))
print(config)

{'n_dec_vocab': 8007, 'n_dec_seq': 256, 'n_layer': 6, 'd_hidn': 256, 'i_pad': 0, 'd_ff': 1024, 'n_head': 4, 'd_head': 64, 'dropout': 0.1, 'layer_norm_epsilon': 1e-12}


In [85]:
def get_sinusoid_encoding_table(n_seq, d_hidn):
    """Build the fixed sinusoidal position-encoding table.

    Entry ``[pos, i]`` holds ``sin(angle)`` for even ``i`` and ``cos(angle)``
    for odd ``i``, where ``angle = pos / 10000 ** (2 * (i // 2) / d_hidn)``.

    Args:
        n_seq: number of positions (rows of the table).
        d_hidn: encoding width (columns of the table).

    Returns:
        ``numpy.ndarray`` of shape ``(n_seq, d_hidn)`` with dtype float64.
        ``n_seq == 0`` yields an empty ``(0, d_hidn)`` table.
    """
    # Column vector of positions broadcast against a row vector of divisors;
    # this replaces the per-element Python loops and keeps the result 2-D
    # even when there are no rows.
    positions = np.arange(n_seq, dtype=np.float64)[:, np.newaxis]
    hidden_idx = np.arange(d_hidn)
    divisors = np.power(10000, 2 * (hidden_idx // 2) / d_hidn)

    sinusoid_table = positions / divisors
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even columns
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd columns
    return sinusoid_table

In [86]:
class ScaledDotProductAttention(nn.Module):
    """Attention core: softmax(Q K^T * scale, masked) followed by dropout, times V."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Scores are multiplied by 1 / sqrt(d_head).
        self.scale = 1 / (config.d_head ** 0.5)
        self.drop_out = nn.Dropout(config.dropout)

    def forward(self, Q, K, V, attn_mask):
        """Return ``(context, attn_prob)``.

        Positions where ``attn_mask`` is True have their score replaced by
        -1e9 before the softmax over the last dimension.
        """
        raw_scores = torch.matmul(Q, K.transpose(-1, -2)) * self.scale
        masked_scores = raw_scores.masked_fill(attn_mask, -1e9)
        attn_prob = self.drop_out(torch.softmax(masked_scores, dim=-1))
        return torch.matmul(attn_prob, V), attn_prob

In [87]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, merge and project back."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(config.dropout)

        # All heads are projected by one Linear of width n_head * d_head.
        heads_width = config.n_head * config.d_head
        self.W_Q = nn.Linear(config.d_hidn, heads_width)
        self.W_K = nn.Linear(config.d_hidn, heads_width)
        self.W_V = nn.Linear(config.d_hidn, heads_width)
        self.scaled_dot_attn = ScaledDotProductAttention(config)
        self.linear = nn.Linear(heads_width, config.d_hidn)

    def forward(self, Q, K, V, attn_mask):
        """Return ``(output, attn_prob)`` for the given inputs and 3-D mask."""
        n_batch = Q.size(0)
        n_head = self.config.n_head
        d_head = self.config.d_head

        def split_heads(projected):
            # (batch, seq, n_head * d_head) -> (batch, n_head, seq, d_head)
            return projected.view(n_batch, -1, n_head, d_head).transpose(1, 2)

        q_s = split_heads(self.W_Q(Q))
        k_s = split_heads(self.W_K(K))
        v_s = split_heads(self.W_V(V))

        # Give every head its own copy of the mask along a new head axis.
        per_head_mask = attn_mask.unsqueeze(1).repeat(1, n_head, 1, 1)

        context, attn_prob = self.scaled_dot_attn(q_s, k_s, v_s, per_head_mask)

        # (batch, n_head, seq, d_head) -> (batch, seq, n_head * d_head)
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, n_head * d_head)
        return self.dropout(self.linear(merged)), attn_prob
    


In [88]:
def get_attn_decoder_mask(seq):
    """Return a (batch, len, len) mask that is 1 strictly above the diagonal.

    The result has the same dtype as ``seq``; entry ``[b, i, j]`` is 1 when
    ``j > i`` and 0 otherwise.
    """
    n_batch, n_len = seq.size(0), seq.size(1)
    ones_column = torch.ones_like(seq).unsqueeze(-1)
    all_ones = ones_column.expand(n_batch, n_len, n_len)
    # Keep only the strictly upper triangle of each (len, len) slice.
    return torch.triu(all_ones, diagonal=1)

In [89]:
class PoswiseFeedForwardNet(nn.Module):
    """Two kernel-size-1 convolutions with a GELU in between, then dropout."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model, d_inner = config.d_hidn, config.d_ff
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_inner, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_inner, out_channels=d_model, kernel_size=1)
        self.active = F.gelu
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, inputs):
        """Apply the feed-forward net; dims 1 and 2 are swapped for the convolutions."""
        # Conv1d convolves over the last axis, so move channels to dim 1 first.
        channels_first = inputs.transpose(1, 2)
        hidden = self.active(self.conv1(channels_first))
        restored = self.conv2(hidden).transpose(1, 2)
        return self.dropout(restored)

In [90]:
def get_attn_pad_mask(seq_q, seq_k, i_pad):
    """Return a boolean (batch, len_q, len_k) mask marking pad keys.

    Entry ``[b, i, j]`` is True when ``seq_k[b, j] == i_pad``; the same row
    is repeated for each of the ``len_q`` query positions.
    """
    n_batch, n_query = seq_q.size()
    n_batch, n_key = seq_k.size()
    key_is_pad = seq_k.data.eq(i_pad).unsqueeze(1)
    return key_is_pad.expand(n_batch, n_query, n_key)

In [91]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention and a feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self, config):
        # `super` must be called to obtain the proxy object; `super.__init__()`
        # raises TypeError before nn.Module is initialised.
        super().__init__()
        self.config = config

        self.self_attn = MultiHeadAttention(self.config)
        self.layer_norm1 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)
        self.pos_ffn = PoswiseFeedForwardNet(self.config)
        # The hidden width key is `d_hidn`; `n_hidn` does not exist in the config.
        self.layer_norm2 = nn.LayerNorm(self.config.d_hidn, eps=self.config.layer_norm_epsilon)

    def forward(self, dec_inputs, dec_attn_mask):
        """Run the block.

        Args:
            dec_inputs: hidden states fed to self-attention as Q, K and V.
            dec_attn_mask: mask forwarded to the attention module.

        Returns:
            ``(outputs, attn_prob)`` where ``attn_prob`` comes from self-attention.
        """
        dec_attn_outputs, dec_attn_prob = self.self_attn(dec_inputs, dec_inputs, dec_inputs, dec_attn_mask)
        # LayerNorm takes a single tensor: normalise the residual sum.
        dec_attn_outputs = self.layer_norm1(dec_inputs + dec_attn_outputs)

        ffn_outputs = self.pos_ffn(dec_attn_outputs)
        # Residual connection around the feed-forward sub-layer as well.
        ffn_outputs = self.layer_norm2(dec_attn_outputs + ffn_outputs)

        return ffn_outputs, dec_attn_prob
            

In [92]:
class Decoder(nn.Module):
    """Token + position embeddings followed by a stack of ``DecoderLayer``s."""

    def __init__(self, config):
        # `super.__init__()` (without calling `super`) raises TypeError.
        super().__init__()
        self.config = config

        self.dec_embs = nn.Embedding(self.config.n_dec_vocab, self.config.d_hidn)
        # One extra row: positions are 1-based, index 0 is used for pad tokens.
        sinusoid_table = torch.FloatTensor(get_sinusoid_encoding_table(self.config.n_dec_seq + 1, self.config.d_hidn))
        self.pos_embs = nn.Embedding.from_pretrained(sinusoid_table, freeze=True)

        self.layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(self.config.n_layer)])

    def forward(self, dec_inputs):
        """Encode a batch of token ids.

        Args:
            dec_inputs: integer tensor of token ids, indexed as (batch, seq).

        Returns:
            ``(dec_outputs, self_attn_probs)`` where ``self_attn_probs`` is a
            list holding the attention probabilities of every layer.
        """
        n_batch, n_len = dec_inputs.size(0), dec_inputs.size(1)
        positions = torch.arange(n_len, device=dec_inputs.device, dtype=dec_inputs.dtype).expand(n_batch, n_len).contiguous() + 1
        # Pad tokens are found in the *inputs*; the 1-based positions never
        # equal the pad id, so testing `positions` would mask nothing.
        pos_mask = dec_inputs.eq(self.config.i_pad)
        positions.masked_fill_(pos_mask, 0)

        dec_outputs = self.dec_embs(dec_inputs) + self.pos_embs(positions)

        dec_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.config.i_pad)
        dec_attn_decoder_mask = get_attn_decoder_mask(dec_inputs)

        # A position is masked when it is a pad key OR lies in the future.
        dec_self_attn_mask = torch.gt((dec_attn_pad_mask + dec_attn_decoder_mask), 0)

        self_attn_probs = []
        for layer in self.layers:
            dec_outputs, self_attn_prob = layer(dec_outputs, dec_self_attn_mask)
            self_attn_probs.append(self_attn_prob)
        return dec_outputs, self_attn_probs

In [93]:
class GPT(nn.Module):
    """Thin wrapper around ``Decoder`` that adds checkpoint save/load helpers."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.decoder = Decoder(self.config)

    def forward(self, dec_inputs):
        """Return ``(dec_outputs, dec_self_attn_probs)`` from the decoder."""
        dec_outputs, dec_self_attn_probs = self.decoder(dec_inputs)

        return dec_outputs, dec_self_attn_probs

    def save(self, epoch, loss, path):
        """Write ``epoch``, ``loss`` and this module's state dict to ``path``."""
        torch.save({
            "epoch": epoch,
            "loss": loss,
            "state_dict": self.state_dict()
        }, path)

    def load(self, path):
        """Restore weights from ``path`` and return the stored ``(epoch, loss)``."""
        # Map storages to CPU so a checkpoint written on another device can be
        # read anywhere; load_state_dict then copies the values into the
        # module's existing parameters on whatever device they live.
        save = torch.load(path, map_location="cpu")
        self.load_state_dict(save["state_dict"])

        return save["epoch"], save["loss"]


In [94]:
class GPTPretrain(nn.Module):
    """GPT with a language-model head whose weight is the token-embedding matrix."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.gpt = GPT(self.config)

        # Bias-free output projection sharing its weight with the input embedding.
        lm_head = nn.Linear(self.config.d_hidn, self.config.n_dec_vocab, bias=False)
        lm_head.weight = self.gpt.decoder.dec_embs.weight
        self.projection_lm = lm_head

    def forward(self, dec_inputs):
        """Return ``(logits_lm, dec_self_attn_probs)``.

        The logits of the last sequence position are dropped before returning.
        """
        hidden, attn_probs = self.gpt(dec_inputs)
        vocab_logits = self.projection_lm(hidden)
        trimmed = vocab_logits[:, :-1, :].contiguous()
        return trimmed, attn_probs

In [95]:
def create_pretrain_instances(doc, n_seq):
    """Split one document into pre-training instances.

    ``doc`` is a sequence of token lists. Consecutive token lists are
    accumulated until their combined length reaches ``n_seq - 2`` (or the
    document ends); the accumulated tokens are flattened, truncated to
    ``n_seq - 2`` and wrapped in "[BOS]" / "[EOS]". Groups that yield fewer
    than two tokens are dropped.

    Returns:
        A list of ``{"tokens": [...]}`` dicts.
    """
    token_budget = n_seq - 2
    last_index = len(doc) - 1

    results = []
    pending, pending_len = [], 0
    for index, sentence in enumerate(doc):
        pending.append(sentence)
        pending_len += len(sentence)

        # Keep accumulating until the budget is reached or the doc ends.
        if index != last_index and pending_len < token_budget:
            continue

        flat = [token for part in pending for token in part][:token_budget]
        if len(flat) > 1:
            results.append({"tokens": ["[BOS]", *flat, "[EOS]"]})
        pending, pending_len = [], 0
    return results

In [96]:
def make_pretrain_data(vocab, in_file, out_file, n_seq):
    """Tokenise a text corpus and write pre-training instances as JSON lines.

    Blank lines in ``in_file`` separate documents. Every non-blank line is
    tokenised with ``vocab.EncodeAsPieces``; each document is then passed to
    ``create_pretrain_instances`` and every resulting instance is written to
    ``out_file`` as one JSON object per line.

    Args:
        vocab: tokenizer object providing ``EncodeAsPieces``.
        in_file: path of the input text file (read as UTF-8).
        out_file: path of the JSON-lines file to create.
        n_seq: sequence length passed on to ``create_pretrain_instances``.
    """
    # Explicit UTF-8 so the corpus is decoded the same way on every platform
    # instead of relying on the locale's default encoding.
    with open(in_file, "r", encoding="utf-8") as in_f:
        line_cnt = sum(1 for _ in in_f)  # first pass only sizes the progress bar

    docs = []
    doc = []
    with open(in_file, "r", encoding="utf-8") as in_f:
        with tqdm(total=line_cnt, desc="Loading") as pbar:
            for line in in_f:
                line = line.strip()
                if line == "":
                    # Blank line closes the current document.
                    if doc:
                        docs.append(doc)
                        doc = []
                else:
                    pieces = vocab.EncodeAsPieces(line)
                    if pieces:
                        doc.append(pieces)
                pbar.update(1)

    # The file may end without a trailing blank line.
    if doc:
        docs.append(doc)

    with open(out_file, "w", encoding="utf-8") as out_f:
        with tqdm(total=len(docs), desc="Making") as pbar:
            for doc in docs:
                for instance in create_pretrain_instances(doc, n_seq):
                    out_f.write(json.dumps(instance))
                    out_f.write("\n")
                pbar.update(1)

In [97]:
class PretrainDataSet(torch.utils.data.Dataset):
    """Dataset of token-id sequences read from a JSON-lines file.

    NOTE: a class with this same name is defined again in the following cell;
    that later definition shadows this one when the notebook is run top to bottom.
    """

    def __init__(self, vocab, infile):
        """Load every line of ``infile`` and convert its "tokens" to ids.

        Args:
            vocab: tokenizer object providing ``piece_to_id``.
            infile: path of a file holding one JSON object with a "tokens"
                list per line.
        """
        self.vocab = vocab
        self.sentences = []

        # First pass only counts lines so the progress bar has a total.
        line_cnt = 0
        with open(infile, "r") as f:
            for line in f:
                line_cnt += 1

        with open(infile, "r") as f:
            for line in tqdm(f, total=line_cnt, desc="Make Pretrain Dataset", unit=" lines"):
                instance = json.loads(line)
                # Append inside the loop so every line is kept, not only the last.
                self.sentences.append([vocab.piece_to_id(p) for p in instance["tokens"]])

    def __len__(self):
        """Number of loaded sequences."""
        return len(self.sentences)

    def __getitem__(self, item):
        """Return ``(token_ids, index)`` as two tensors."""
        return (torch.tensor(self.sentences[item]), torch.tensor(item))

In [98]:
class PretrainDataSet(torch.utils.data.Dataset):
    """Dataset of token-id sequences read from a JSON-lines file."""

    def __init__(self, vocab, infile):
        """Read ``infile`` line by line and map each "tokens" list to ids via ``vocab``."""
        self.vocab = vocab
        self.sentences = []

        # Count the lines up front so the progress bar knows its total.
        with open(infile, "r") as f:
            n_lines = sum(1 for _ in f)

        with open(infile, "r") as f:
            progress = tqdm(f, total=n_lines, desc="Make Pretrain Dataset", unit=" lines")
            for raw_line in progress:
                pieces = json.loads(raw_line)["tokens"]
                self.sentences.append([vocab.piece_to_id(piece) for piece in pieces])

    def __len__(self):
        """Number of loaded sequences."""
        return len(self.sentences)

    def __getitem__(self, item):
        """Return the ``item``-th sequence and ``item`` itself, both as tensors."""
        token_ids = torch.tensor(self.sentences[item])
        return (token_ids, torch.tensor(item))

In [99]:
def pretrin_collate_fn(inputs):
    """Collate ``(token_ids, index)`` pairs into a padded batch.

    Sequences are right-padded with 0 to the longest one in the batch.

    Returns:
        A two-element list: the padded (batch, max_len) id tensor and the
        stacked index tensor.
    """
    sequences, indices = zip(*inputs)
    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
    return [padded, torch.stack(indices, dim=0)]

In [100]:
def train_epoch(config, epoch, model, criterion_lm, optimizer, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        config: object exposing ``device``, the device batches are moved to.
        epoch: epoch number, used only in the progress-bar label.
        model: module whose first output is the language-model logits.
        criterion_lm: loss called as ``criterion_lm(logits, labels)``.
        optimizer: optimizer stepping the model's parameters.
        train_loader: sized iterable yielding ``(dec_inputs, index)`` batches.

    Returns:
        Mean of the per-batch loss values.
    """
    losses = []
    model.train()

    # The bar length is the number of batches in the loader.
    with tqdm(total=len(train_loader), desc=f"Train({epoch})") as pbar:
        for value in train_loader:
            dec_inputs, _ = map(lambda v: v.to(config.device), value)
            # Next-token targets: the inputs shifted left by one position.
            labels_lm = dec_inputs[:, 1:].contiguous()

            optimizer.zero_grad()
            outputs = model(dec_inputs)
            logits_lm = outputs[0]

            # Flatten to (batch * seq, vocab) vs (batch * seq,) for the loss.
            loss = criterion_lm(logits_lm.view(-1, logits_lm.size(2)), labels_lm.view(-1))

            loss_val = loss.item()
            losses.append(loss_val)

            loss.backward()
            optimizer.step()

            pbar.update(1)
            pbar.set_postfix_str(f"Loss: {loss_val:.3f}({np.mean(losses):.3f})")
    return np.mean(losses)