In [1]:
from typing import Dict, List, Optional
import os
import json
import glob

In [4]:
import torch  # this cell runs before the cell that imports torch, so import it here

torch.__version__

'1.5.1'

In [6]:
import torch

# Token ids are integer indices, so build a LongTensor explicitly;
# torch.Tensor(...) would silently produce float32 values.
token_ids = torch.tensor([12, 1, 0, 3, 3, 0, 2, 0], dtype=torch.long)
token_ids == 0

tensor([False, False,  True, False, False,  True, False,  True])

In [143]:
# Separator used to join the list of summary sentences into one target string.
SEP_TOKEN = '。'
# Reserved vocabulary entries. Vocab assigns them ids in list order (0, 1, 2, 3).
UNK_TOKEN, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = '[UNK]', '[PAD]', '[SOS]', '[EOS]'
SPECIAL_TOKENS = [UNK_TOKEN, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]

In [144]:
import MeCab

class Vocab():
    """Word <-> id mapping built from a one-word-per-line vocabulary file.

    The special tokens occupy the first ids, in SPECIAL_TOKENS order. File words
    follow in file order. Blank lines and words already present are skipped.
    """

    def __init__(self, vocab_path: str):
        """ constructor """
        self.word2idx = {word: idx for idx, word in enumerate(SPECIAL_TOKENS)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.build_vocab(vocab_path)

    def read_vocab_file(self, vocab_path: str) -> List[str]:
        """Return the stripped lines of a UTF-8 vocabulary file, in order."""
        # explicit encoding: the vocabulary is Japanese and the locale default may not be UTF-8
        with open(vocab_path, encoding='utf-8') as vocab_file:
            return [line.strip() for line in vocab_file]

    def build_vocab(self, vocab_path: str) -> None:
        """Append every new, non-empty word from the file to the mapping."""
        for word in self.read_vocab_file(vocab_path):
            # Re-adding an existing word (e.g. '[PAD]' or a duplicate line) would
            # point it at len(word2idx) without growing the dict. That clobbers
            # idx2word for the next word, so skip such lines.
            if not word or word in self.word2idx:
                continue
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word

    def size(self) -> int:
        """Number of ids, special tokens included."""
        return len(self.word2idx)

    def get_oovs(self, tokens: List[str]) -> List[str]:
        """Distinct tokens missing from the vocabulary, in first-occurrence order."""
        return list(dict.fromkeys(token for token in tokens if token not in self.word2idx))

    def get_mask(self, tokens: List[str]) -> List[int]:
        """1 for every non-PAD token, 0 for PAD."""
        return [1 if token != PAD_TOKEN else 0 for token in tokens]

    def encode(self, tokens: List[str]) -> List[int]:
        """Map tokens to ids; unknown tokens become the UNK id."""
        unk_idx = self.word2idx[UNK_TOKEN]
        return [self.word2idx.get(token, unk_idx) for token in tokens]

    def decode(self, ids: List[int]) -> List[str]:
        """Map ids back to words; ids outside the vocabulary raise KeyError."""
        return [self.idx2word[idx] for idx in ids]

class Tokenizer():
    """MeCab '-Owakati' (space-separated) tokenizer.

    Optionally adds special tokens, pads and truncates the result.
    """

    def __init__(self):
        """ constructor """
        self.mc = MeCab.Tagger('-Owakati')

    def __call__(self, text: str, sos_token: Optional[str] = None, eos_token: Optional[str] = None,
                 padding: bool = False, max_length: Optional[int] = None,
                 truncation: bool = False) -> List[str]:
        """Tokenize ``text`` and post-process the token list.

        Args:
            text: raw input text.
            sos_token: token string to prepend, or None for no prefix.
            eos_token: token string to append, or None for no suffix.
            padding: right-pad with PAD_TOKEN up to ``max_length``.
            max_length: target length for padding/truncation (ignored when None).
            truncation: cut the list to ``max_length``. An appended EOS token is
                dropped when the text is too long.
        """
        tokens = self.tokenize(text)
        if sos_token is not None:
            tokens = [sos_token] + tokens
        if eos_token is not None:
            tokens = tokens + [eos_token]
        if padding and max_length is not None:
            # a negative count yields no padding when the list is already too long
            tokens += [PAD_TOKEN] * (max_length - len(tokens))
        if truncation and max_length is not None:
            # if the length of tokens is over max_length, the eos token is truncated
            tokens = tokens[:max_length]
        return tokens

    def tokenize(self, text: str) -> List[str]:
        """Split MeCab's wakati output on whitespace."""
        return self.mc.parse(text).strip().split()

In [145]:
# One MeCab tokenizer for both sides; separate vocabularies for source and target text.
tokenizer = Tokenizer()
src_vocab, tgt_vocab = Vocab('data/source_vocab.txt'), Vocab('data/target_vocab.txt')

In [146]:
# Token budgets: the article body (encoder input) and the summary (decoder input/target).
SRC_MAX_LENGTH, TGT_MAX_LENGTH = 500, 50

In [147]:
import torch
from torch.utils.data import DataLoader

class BaseExample():
    """Token ids and padding mask of one sequence, kept as plain lists."""

    def __init__(self, input_ids: List[int], mask: List[int]):
        self.input_ids, self.mask = input_ids, mask

class EncExample(BaseExample):
    """Encoder-side example: adds OOV-extended ids and the article's OOV count."""

    def __init__(self, input_ids: List[int], extended_ids: List[int], mask: List[int], n_oovs: int):
        super().__init__(input_ids, mask)
        self.extended_ids, self.n_oovs = extended_ids, n_oovs

class DecExample(BaseExample):
    """Decoder-side example: adds the (teacher-forcing) target ids."""

    def __init__(self, input_ids: List[int], target_ids: List[int], mask: List[int]):
        super().__init__(input_ids, mask)
        self.target_ids = target_ids

class SummExample():
    """Pairs the encoder and decoder halves of one article."""

    def __init__(self, enc_example: EncExample, dec_example: DecExample):
        self.enc, self.dec = enc_example, dec_example

class BaseBatch():
    """Batched token ids and mask tensors shared by encoder and decoder batches."""

    def __init__(self, input_ids: torch.Tensor, mask: torch.Tensor):
        self.input_ids = input_ids
        self.mask = mask

class EncBatch(BaseBatch):
    """Encoder batch: adds extended ids and the largest per-article OOV count."""

    def __init__(self, input_ids: torch.Tensor, extended_ids: torch.Tensor, mask: torch.Tensor, max_n_oovs: int):
        # input_ids and mask are stored by BaseBatch; assigning them again is redundant
        super(EncBatch, self).__init__(input_ids, mask)
        self.extended_ids = extended_ids
        self.max_n_oovs = max_n_oovs

class DecBatch(BaseBatch):
    """Decoder batch: adds the target ids."""

    def __init__(self, input_ids: torch.Tensor, target_ids: torch.Tensor, mask: torch.Tensor):
        super(DecBatch, self).__init__(input_ids, mask)
        self.target_ids = target_ids

class SummBatch():
    """Pairs an encoder batch with its decoder batch."""

    def __init__(self, enc_batch: EncBatch, dec_batch: DecBatch):
        self.enc = enc_batch
        self.dec = dec_batch

In [148]:
# `tokenizer` was already built together with the vocabularies; reuse it
# instead of constructing a second MeCab.Tagger.
assert isinstance(tokenizer, Tokenizer)

In [149]:
sample_text = '暑い夏の日にはアイスを食べたくなります'
# decoder input: [SOS] prepended; decoder target: [EOS] appended
for special_token in ({'sos_token': SOS_TOKEN}, {'eos_token': EOS_TOKEN}):
    print(tokenizer(sample_text, truncation=True, max_length=20, padding=True, **special_token))

['[SOS]', '暑い', '夏', 'の', '日', 'に', 'は', 'アイス', 'を', '食べ', 'たく', 'なり', 'ます', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['暑い', '夏', 'の', '日', 'に', 'は', 'アイス', 'を', '食べ', 'たく', 'なり', 'ます', '[EOS]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']


In [150]:
class Trainer():
    """Converts article dicts into pointer-generator examples and batches them.

    NOTE(review): encoder extended ids are offset by ``src_vocab.size()``, but
    decoder target ids are offset by ``tgt_vocab.size()``. The Decoder scatters
    attention into a distribution of width ``config.vocab_size``. These only
    line up when source and target share one vocabulary of that size — confirm.
    """

    def __init__(self, model, tokenizer, src_vocab, tgt_vocab, config):
        """ constructor """
        self.model = model
        self.tokenizer = tokenizer
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        # self.model.to(self.device)
        self.config = config

    @staticmethod
    def _collect_oovs(tokens: List[str], vocab: Vocab) -> List[str]:
        """Return the distinct tokens missing from ``vocab``, in first-occurrence order.

        A word's position in this list is its OOV slot. The word receives the
        extended id ``vocab.size() + position``.
        """
        # De-duplicate while keeping order. With duplicates, positions would
        # disagree with the extended ids given to the encoder tokens.
        return list(dict.fromkeys(vocab.get_oovs(tokens)))

    def _prepare_enc_example(self, tokens: List[str], vocab: Vocab) -> EncExample:
        """Encode source tokens; OOV tokens also get per-article extended ids.

        ``n_oovs`` is the number of *distinct* OOV words, so every extended id is
        ``< vocab.size() + n_oovs``. This value sizes ``extra_zeros`` in the decoder.
        """
        input_ids = vocab.encode(tokens)
        oov_words = self._collect_oovs(tokens, vocab)
        oov_slot = {word: slot for slot, word in enumerate(oov_words)}
        vocab_size = vocab.size()
        extended_ids = [vocab_size + oov_slot[token] if token in oov_slot else input_id
                        for token, input_id in zip(tokens, input_ids)]
        mask = vocab.get_mask(tokens)
        return EncExample(input_ids, extended_ids, mask, len(oov_words))

    def _prepare_dec_example(self, input_tokens: List[str], target_tokens: List[str], vocab: Vocab, oov_words: List[str]) -> DecExample:
        """Encode decoder input and target tokens.

        A target token is remapped to ``vocab.size() + index`` when it is UNK in
        ``vocab`` and appears in the source article's ``oov_words``; this lets
        the pointer copy it. Other unknown targets stay UNK. The mask is
        computed from ``input_tokens``.
        """
        input_ids = vocab.encode(input_tokens)
        unk_idx = vocab.word2idx[UNK_TOKEN]
        vocab_size = vocab.size()
        oov_slot = {}
        for slot, word in enumerate(oov_words):
            oov_slot.setdefault(word, slot)  # first occurrence wins, like list.index
        target_ids = [vocab_size + oov_slot[token] if target_id == unk_idx and token in oov_slot else target_id
                      for token, target_id in zip(target_tokens, vocab.encode(target_tokens))]
        mask = vocab.get_mask(input_tokens)
        return DecExample(input_ids, target_ids, mask)

    def convert_article_to_example(self, article_dict: Dict[str, str]) -> SummExample:
        """Tokenize and encode one article dict (keys ``'body'`` and ``'summary'``).

        The summary list is joined with SEP_TOKEN. The decoder input gets SOS
        prepended, and the decoder target gets EOS appended.
        """
        src_text, tgt_text = article_dict['body'], self.concat_summary(article_dict['summary'])
        enc_input_tokens = self.tokenizer(src_text, padding=True, max_length=SRC_MAX_LENGTH, truncation=True)
        # the same de-duplicated list _prepare_enc_example builds, so OOV slots agree
        oov_words = self._collect_oovs(enc_input_tokens, self.src_vocab)
        dec_input_tokens = self.tokenizer(tgt_text, sos_token=SOS_TOKEN, padding=True, max_length=TGT_MAX_LENGTH, truncation=True)
        dec_target_tokens = self.tokenizer(tgt_text, eos_token=EOS_TOKEN, padding=True, max_length=TGT_MAX_LENGTH, truncation=True)
        # store data as an example instance
        enc_example = self._prepare_enc_example(enc_input_tokens, self.src_vocab)
        dec_example = self._prepare_dec_example(dec_input_tokens, dec_target_tokens, self.tgt_vocab, oov_words)
        return SummExample(enc_example, dec_example)

    @staticmethod
    def load_jsons(data_path_pattern: str) -> List[Dict[str, str]]:
        """Load and concatenate the lists stored in every JSON file matching the glob pattern.

        Files are read in sorted path order, so the resulting article order is deterministic.
        """
        article_dicts = list()
        for json_path in sorted(glob.glob(data_path_pattern)):
            with open(json_path, encoding='utf-8') as json_file:
                article_dicts.extend(json.load(json_file))
        return article_dicts

    @staticmethod
    def concat_summary(summaries: List[str]) -> str:
        """ concat summaries to a chunk """
        return SEP_TOKEN.join(summaries)

    def build_data_loader(self, examples: List[SummExample], shuffle: bool = False) -> DataLoader:
        """ build data loader from examples"""
        return DataLoader(examples,
                          batch_size=self.config.batch_size,
                          shuffle=shuffle,
                          collate_fn=self._collate_examples)

    def _collate_examples(self, examples: List[SummExample]) -> SummBatch:
        """ collate lists of samples into batch """
        enc_batch = self._enc_examples_to_batch([example.enc for example in examples])
        dec_batch = self._dec_examples_to_batch([example.dec for example in examples])
        return SummBatch(enc_batch, dec_batch)

    @staticmethod
    def _enc_examples_to_batch(examples: List[EncExample]) -> EncBatch:
        """ convert encoder examples to tensor """
        enc_dict = dict()
        enc_dict['input_ids'] = torch.tensor([example.input_ids for example in examples], dtype=torch.long)
        enc_dict['extended_ids'] = torch.tensor([example.extended_ids for example in examples], dtype=torch.long)
        enc_dict['mask'] = torch.tensor([example.mask for example in examples], dtype=torch.long)
        # determine the max number of in-article OOVs in this batch
        enc_dict['max_n_oovs'] = max((example.n_oovs for example in examples), default=0)
        return EncBatch(**enc_dict)

    @staticmethod
    def _dec_examples_to_batch(examples: List[DecExample]) -> DecBatch:
        """ convert decoder examples to tensor """
        dec_dict = dict()
        dec_dict['input_ids'] = torch.tensor([example.input_ids for example in examples], dtype=torch.long)
        dec_dict['target_ids'] = torch.tensor([example.target_ids for example in examples], dtype=torch.long)
        dec_dict['mask'] = torch.tensor([example.mask for example in examples], dtype=torch.long)
        return DecBatch(**dec_dict)

    @staticmethod
    def to_device(device, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Move every tensor in ``inputs`` to ``device``; mutates and returns the dict."""
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        return inputs

class TrainerConfig():
    """Settings read by Trainer; currently only the DataLoader batch size."""

    def __init__(self, batch_size):
        self.batch_size = batch_size


trainer_config = TrainerConfig(32)

In [151]:
# 'model' is a placeholder: Trainer only stores it (moving it to the device is commented out).
trainer = Trainer(model='model', tokenizer=tokenizer, src_vocab=src_vocab,
                  tgt_vocab=tgt_vocab, config=trainer_config)

In [152]:
# article_dict = {'body': '暑い夏の日にはアイスとニョロが食べたくなります。そんな日には冷えたジュースも良いでしょう。', 'summary':['暑い夏にはアイスとニョロとジュースが欲しい']}
# example = trainer.convert_article_to_example(article_dict)
# print(example.enc.input_ids)
# print(example.enc.extended_ids)
# print(example.dec.input_ids)
# print(example.dec.target_ids)
# print(tgt_vocab.decode(example.dec.input_ids))
# print(tgt_vocab.decode(example.dec.target_ids))


## 学習（イテレーション）

---

In [153]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Config():
    """Model hyper-parameters.

    Read through the module-level ``config`` by Encoder, ReduceState, Attention
    and Decoder.
    """

    def __init__(self):
        settings = dict(
            batch_size=32,
            vocab_size=50004,   # embedding / output size; presumably 50000 words + 4 specials — confirm
            emb_dim=128,
            hidden_dim=256,
            pointer_gen=True,   # enables p_gen and the copy (scatter) distribution in Decoder
            is_coverage=True,   # enables the coverage feature in Attention
        )
        for name, value in settings.items():
            setattr(self, name, value)

class Encoder(nn.Module):
    """Embedding plus a single-layer bidirectional LSTM over source ids.

    The LSTM outputs are also projected with ``W_h`` for use by the attention.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)

        self.lstm = nn.LSTM(config.emb_dim,
                            config.hidden_dim,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)

        self.W_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)

    def forward(self, inputs, seq_lens: Optional[torch.Tensor] = None):
        """Encode a batch of token ids.

        Args:
            inputs: LongTensor of ids, B x t_k (batch_first).
            seq_lens: optional true (unpadded) length of each row. When given,
                the sequence is packed so that trailing PAD positions neither feed
                the backward direction nor the final (h, c). Outputs at padded
                positions are then zero. Lengths are clamped to >= 1.

        Returns:
            outputs: B x t_k x 2*hidden_dim
            feature: B * t_k x 2*hidden_dim (``W_h`` applied to outputs)
            hidden:  (h, c) from the LSTM
        """
        embedded = self.embedding(inputs)
        if seq_lens is None:
            outputs, hidden = self.lstm(embedded)
        else:
            lengths = torch.as_tensor(seq_lens, dtype=torch.long).cpu().clamp(min=1)
            packed = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
            packed_outputs, hidden = self.lstm(packed)
            # total_length keeps t_k equal to the input width so padding masks still line up
            outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True,
                                             total_length=inputs.size(1))
        outputs = outputs.contiguous()
        feature = outputs.view(-1, 2 * config.hidden_dim)  # B * t_k x 2*hidden_dim
        feature = self.W_h(feature)

        return outputs, feature, hidden

class ReduceState(nn.Module):
    """Turn the encoder's final (h, c) into the decoder's initial state.

    Each of h and c passes through its own Linear layer followed by a ReLU.
    """

    def __init__(self):
        super(ReduceState, self).__init__()
        width = config.hidden_dim
        self.reduce_h = nn.Linear(width * 2, width)
        self.reduce_c = nn.Linear(width * 2, width)

    def forward(self, hidden):
        def squash(state, layer):
            # 2 x B x hidden_dim -> B x 2*hidden_dim -> (Linear, ReLU) -> 1 x B x hidden_dim
            flat = state.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)
            return F.relu(layer(flat)).unsqueeze(0)

        h, c = hidden
        return (squash(h, self.reduce_h), squash(c, self.reduce_c))


In [154]:
SEED = 42
torch.manual_seed(SEED)  # fix weight initialisation so re-runs are reproducible

# Encoder / ReduceState read the module-level `config`, so it must exist first.
config = Config()
enc = Encoder()
reduce_state = ReduceState()

In [155]:
# article_dicts = trainer.load_jsons('data/train1.json')
# examples = list(map(trainer.convert_article_to_example, article_dicts))

# data_loader = trainer.build_data_loader(examples)

# for batch in data_loader:
#     enc_outputs, enc_feature, enc_hidden = enc(batch.enc.input_ids)
#     s_t_1 = reduce_state(enc_hidden)
#     assert False

In [156]:
class Attention(nn.Module):
    """Additive attention over encoder states, with an optional coverage feature."""

    def __init__(self):
        super(Attention, self).__init__()
        # attention
        if config.is_coverage:
            self.W_c = nn.Linear(1, config.hidden_dim * 2, bias=False)
        self.decode_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)
        self.v = nn.Linear(config.hidden_dim * 2, 1, bias=False)

    def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):
        """Compute the context vector and attention distribution for one decoder step.

        Args:
            s_t_hat: decoder state, B x 2*hidden_dim.
            encoder_outputs: B x t_k x n.
            encoder_feature: projected encoder outputs, B * t_k x n.
            enc_padding_mask: B x t_k, 1 for real tokens and 0 for padding (any numeric dtype).
            coverage: running attention sum, B x t_k (used only when ``config.is_coverage``).

        Returns:
            c_t: B x 2*hidden_dim context vector.
            attn_dist: B x t_k distribution over unmasked positions. A fully
                masked row yields all zeros instead of NaN.
            coverage: ``coverage + attn_dist`` when coverage is enabled, else unchanged.
        """
        b, t_k, n = list(encoder_outputs.size())

        dec_fea = self.decode_proj(s_t_hat)  # B x 2*hidden_dim
        dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous()  # B x t_k x 2*hidden_dim
        dec_fea_expanded = dec_fea_expanded.view(-1, n)  # B * t_k x 2*hidden_dim

        att_features = encoder_feature + dec_fea_expanded  # B * t_k x 2*hidden_dim
        if config.is_coverage:
            coverage_input = coverage.view(-1, 1)  # B * t_k x 1
            coverage_feature = self.W_c(coverage_input)  # B * t_k x 2*hidden_dim
            att_features = att_features + coverage_feature

        e = torch.tanh(att_features)  # F.tanh is deprecated
        scores = self.v(e).view(-1, t_k)  # B x t_k

        # Renormalise the softmax over non-padded positions only. The clamp keeps
        # an all-padding row at zero instead of 0/0 = NaN.
        mask = enc_padding_mask.to(scores.dtype)
        attn_dist_ = F.softmax(scores, dim=1) * mask
        normalization_factor = attn_dist_.sum(1, keepdim=True).clamp(min=1e-12)
        attn_dist = attn_dist_ / normalization_factor  # B x t_k

        c_t = torch.bmm(attn_dist.unsqueeze(1), encoder_outputs)  # B x 1 x n
        c_t = c_t.view(-1, config.hidden_dim * 2)  # B x 2*hidden_dim

        if config.is_coverage:
            coverage = coverage.view(-1, t_k) + attn_dist

        return c_t, attn_dist, coverage


class Decoder(nn.Module):
    """One step of an attentional LSTM decoder with an optional pointer-generator mix.

    ``forward`` is called once per target position.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.attention_network = Attention()
        # decoder
        self.embedding = nn.Embedding(config.vocab_size, config.emb_dim)

        # merges the previous context vector with the input embedding into the LSTM input
        self.x_context = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)

        self.lstm = nn.LSTM(config.emb_dim,
                            config.hidden_dim,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=False)

        if config.pointer_gen:
            # input is [c_t (2*hidden), s_t_hat (2*hidden), x (emb)]
            self.p_gen_linear = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)

        # p_vocab
        self.out1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)
        self.out2 = nn.Linear(config.hidden_dim, config.vocab_size)

    def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab, coverage, step):
        """Run one decoding step.

        Args:
            y_t_1: previous token ids, shape B.
            s_t_1: previous LSTM state (h, c).
            encoder_outputs, encoder_feature, enc_padding_mask: passed to the attention.
            c_t_1: previous context vector, B x 2*hidden_dim.
            extra_zeros: B x max_n_oovs zeros appended to the vocab distribution, or None.
            enc_batch_extend_vocab: B x t_k extended source ids for the copy scatter.
            coverage: running attention sum.
            step: index of this step in the target sequence.

        Returns:
            (final_dist, s_t, c_t, attn_dist, p_gen, coverage). ``final_dist`` has
            width ``vocab_size`` plus the extra columns when pointer_gen is on.
            ``p_gen`` is None when pointer_gen is off.
        """
        if not self.training and step == 0:
            # At inference, step 0 advances coverage from the initial state.
            # NOTE(review): the c_t computed here is not used below (x uses c_t_1) — confirm intended.
            h_decoder, c_decoder = s_t_1
            s_t_hat = torch.cat(
                (h_decoder.view(-1, config.hidden_dim), c_decoder.view(-1, config.hidden_dim)),
                1)  # B x 2*hidden_dim
            c_t, _, coverage_next = self.attention_network(s_t_hat, encoder_outputs,
                                                           encoder_feature, enc_padding_mask,
                                                           coverage)
            coverage = coverage_next

        y_t_1_embd = self.embedding(y_t_1)
        x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))
        lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)

        h_decoder, c_decoder = s_t
        s_t_hat = torch.cat(
            (h_decoder.view(-1, config.hidden_dim), c_decoder.view(-1, config.hidden_dim)),
            1)  # B x 2*hidden_dim
        c_t, attn_dist, coverage_next = self.attention_network(s_t_hat, encoder_outputs,
                                                               encoder_feature, enc_padding_mask,
                                                               coverage)

        if self.training or step > 0:
            coverage = coverage_next

        p_gen = None
        if config.pointer_gen:
            p_gen_input = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)
            p_gen = torch.sigmoid(self.p_gen_linear(p_gen_input))  # F.sigmoid is deprecated

        output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1)  # B x hidden_dim * 3
        output = self.out1(output)  # B x hidden_dim

        #output = F.relu(output)

        output = self.out2(output)  # B x vocab_size
        vocab_dist = F.softmax(output, dim=1)

        if config.pointer_gen:
            vocab_dist_ = p_gen * vocab_dist
            attn_dist_ = (1 - p_gen) * attn_dist

            if extra_zeros is not None:
                # room for this batch's in-article OOV ids
                vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)

            # add copy probability mass onto the (extended) ids of the source tokens
            final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
        else:
            final_dist = vocab_dist

        return final_dist, s_t, c_t, attn_dist, p_gen, coverage

In [157]:
# Training mode is already the nn.Module default; stated explicitly because Decoder.forward branches on it.
dec = Decoder().train()

In [158]:
article_dicts = trainer.load_jsons('data/train1.json')
examples = list(map(trainer.convert_article_to_example, article_dicts))

data_loader = trainer.build_data_loader(examples)

# Smoke test: push one batch through encoder and decoder with teacher forcing,
# then compute the pointer-generator (+ coverage) loss. No optimizer step yet.
loss = None
for batch in data_loader:
    batch_size = batch.enc.input_ids.size(0)
    enc_outputs, enc_feature, enc_hidden = enc(batch.enc.input_ids)
    s_t_1 = reduce_state(enc_hidden)
    c_t_1 = torch.zeros((batch_size, 2 * config.hidden_dim))
    extra_zeros = torch.zeros((batch_size, batch.enc.max_n_oovs))
    coverage = torch.zeros(batch.enc.input_ids.size())
    step_losses = []
    for di in range(TGT_MAX_LENGTH):
        y_t_1 = batch.dec.input_ids[:, di]  # Teacher forcing
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = dec(
                y_t_1, s_t_1, enc_outputs, enc_feature, batch.enc.mask, c_t_1,
                extra_zeros, batch.enc.extended_ids, coverage, di)
        target = batch.dec.target_ids[:, di]
        # squeeze(1), not squeeze(): keep the batch dim when batch_size == 1
        gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)
        step_loss = -torch.log(gold_probs + 1e-12)
        if config.is_coverage:
            # coverage loss uses the coverage *before* this step's attention is added
            step_loss = step_loss + torch.sum(torch.min(attn_dist, coverage), 1)
        coverage = next_coverage
        step_mask = batch.dec.mask[:, di]
        step_losses.append(step_loss * step_mask)
    # per-example loss averaged over its real (unpadded) target length, then over the batch
    dec_lens = batch.dec.mask.sum(1).clamp(min=1).to(torch.float)
    loss = (torch.stack(step_losses, 1).sum(1) / dec_lens).mean()
    break  # only the first batch while the training loop is being developed

loss

In [None]:
# torch.gather requires every index to be < the size of the gathered dimension.
# Picking class 2 from a 2-class distribution fails. This is the same failure
# as a target id falling outside final_dist's width in the training loop.
true = torch.tensor([1, 0, 2])
out = torch.softmax(torch.rand(3, 2), -1)
try:
    torch.gather(out, 1, true.unsqueeze(1))
except RuntimeError as err:
    print(f'RuntimeError: {err}')

RuntimeError: index 2 is out of bounds for dimension 1 with size 2