In [None]:
import subprocess
import threading

from datasets import load_dataset

import os
import argparse
import re

import razdel
import nltk


In [None]:
from typing import Dict, Tuple, List
import numpy as np

import torch
import torch.nn.functional as F
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTMCell
from torch.nn.functional import relu

from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.vocabulary import Vocabulary, DEFAULT_OOV_TOKEN
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
from allennlp.models.model import Model
from allennlp.modules.token_embedders import Embedding
from allennlp.modules import Attention
from allennlp.nn.beam_search import BeamSearch
from allennlp.nn import util

from collections import Counter
from statistics import mean

from true_rouge import Rouge
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.chrf_score import corpus_chrf
import torch

In [None]:
class Meteor:
    """Wrapper around a long-running METEOR scorer subprocess.

    The jar is started once with the ``-stdio`` flag and then driven through its
    stdin/stdout pipes with ``SCORE ||| ...`` and ``EVAL ||| ...`` lines.
    All pipe traffic is serialized with ``self.lock``.
    """

    def __init__(self, meteor_jar, language):
        """Start the METEOR process.

        Args:
            meteor_jar: path to the METEOR jar file passed to ``java -jar``.
            language: language code passed to METEOR via ``-l``.
        """
        # Used to guarantee thread safety
        self.lock = threading.Lock()

        self.meteor_cmd = ['java', '-jar', '-Xmx2G', meteor_jar, '-', '-', '-stdio', '-l', language, '-norm']
        self.meteor_p = subprocess.Popen(self.meteor_cmd,
                                         stdin=subprocess.PIPE,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.STDOUT,
                                         encoding='utf-8',
                                         bufsize=0)

    def compute_score(self, hyps, refs):
        """Return the mean METEOR score over all (hypothesis, references) pairs.

        Args:
            hyps: iterable of hypothesis strings.
            refs: iterable aligned with ``hyps``; each item is a list of reference strings.

        Returns:
            The arithmetic mean of the per-pair scores, or ``0.0`` when there are no pairs.
        """
        scores = []
        # `with` guarantees the lock is released even when the subprocess replies with
        # something that cannot be parsed; otherwise every later call (and __del__)
        # would block forever on a lock nobody can release.
        with self.lock:
            for hyp, ref in zip(hyps, refs):
                stat = self._stat(hyp, ref)
                # EVAL ||| stats
                eval_line = 'EVAL ||| {}'.format(" ".join(map(str, map(int, map(float, stat.split())))))
                self.meteor_p.stdin.write('{}\n'.format(eval_line))
                scores.append(float(self.meteor_p.stdout.readline().strip()))

        if not scores:
            # Nothing to average; avoid ZeroDivisionError on empty input.
            return 0.0
        return sum(scores) / len(scores)

    def _stat(self, hypothesis_str, reference_list):
        """Send one SCORE request and return the raw statistics line from METEOR.

        The caller is expected to hold ``self.lock``.
        """
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        # '|||' is the field separator of the protocol, so it is removed from the hypothesis.
        hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        self.meteor_p.stdin.write('{}\n'.format(score_line))
        return self.meteor_p.stdout.readline().strip()

    def __del__(self):
        """Shut the subprocess down when the wrapper is garbage collected."""
        # If Popen raised inside __init__ the attribute does not exist; there is nothing to clean up.
        process = getattr(self, "meteor_p", None)
        if process is None:
            return
        with self.lock:
            process.stdin.close()
            process.kill()
            process.wait()

In [None]:

def calc_duplicate_n_grams_rate(documents):
    """Measure how repetitive the documents are for n-gram orders 1 through 4.

    Each document is split on single spaces. For every order, the number of
    n-gram occurrences that repeat an n-gram already present in the same
    document is summed over all documents and divided by the total number of
    n-gram occurrences.

    Args:
        documents: iterable of strings.

    Returns:
        Dict mapping each order (1..4) to its duplicate rate; ``0.0`` when no
        n-gram of that order exists.
    """
    orders = range(1, 5)
    totals = Counter()
    repeats = Counter()
    for document in documents:
        tokens = document.split(" ")
        for order in orders:
            # Number of sliding windows of this size (zero when the document is too short).
            window_count = max(len(tokens) - order + 1, 0)
            distinct = {tuple(tokens[start:start + order]) for start in range(window_count)}
            totals[order] += window_count
            repeats[order] += window_count - len(distinct)

    rates = {}
    for order in orders:
        total = totals[order]
        rates[order] = repeats[order] / total if total else 0.0
    return rates


def _corpus_bert_score(
    hyps,
    refs,
    lang="ru",
    bert_score_model=None,
    num_layers=None,
    idf=False,
    batch_size=32
):
    """Compute corpus-averaged BERTScore precision / recall / F.

    This private name exists because a later cell of this notebook defines
    another function called ``calc_bert_score`` (record-level evaluation) that
    shadows the public one; ``calc_metrics`` must not depend on that name.

    Args:
        hyps: list of hypothesis strings.
        refs: list of references aligned with ``hyps``.
        lang, bert_score_model, num_layers, idf, batch_size: forwarded to
            ``bert_score.score`` (as ``lang``, ``model_type``, ``num_layers``,
            ``idf`` and ``batch_size``).

    Returns:
        Tuple ``({"p": float, "r": float, "f": float}, hash_code)``.
    """
    # Imported lazily so the notebook works without bert_score unless this metric is requested.
    import bert_score
    all_preds, hash_code = bert_score.score(
        hyps,
        refs,
        lang=lang,
        model_type=bert_score_model,
        num_layers=num_layers,
        verbose=False,
        idf=idf,
        batch_size=batch_size,
        return_hash=True
    )
    # all_preds holds three per-example score tensors (unpacked below as p, r, f); average each.
    avg_scores = [s.mean(dim=0) for s in all_preds]
    return {
        "p": avg_scores[0].cpu().item(),
        "r": avg_scores[1].cpu().item(),
        "f": avg_scores[2].cpu().item()
    }, hash_code


def calc_bert_score(
    hyps,
    refs,
    lang="ru",
    bert_score_model=None,
    num_layers=None,
    idf=False,
    batch_size=32
):
    """Public entry point for corpus-level BERTScore; see ``_corpus_bert_score``."""
    return _corpus_bert_score(
        hyps,
        refs,
        lang=lang,
        bert_score_model=bert_score_model,
        num_layers=num_layers,
        idf=idf,
        batch_size=batch_size
    )


def calc_metrics(
    refs, hyps,
    language,
    metric="all",
    meteor_jar=None
):
    """Compute the requested summarization metrics.

    Args:
        refs: list of references aligned with ``hyps``; an item may be a single
            string or a list of strings (multiple references).
        hyps: list of hypothesis strings.
        language: language code passed to METEOR.
        metric: one of "bleu", "rouge", "meteor", "duplicate_ngrams",
            "bert_score", "chrf", "length" or "all". Note that "all" does not
            include "bert_score"; it has to be requested explicitly and is
            computed only when CUDA is available.
        meteor_jar: path to the METEOR jar; METEOR is skipped when it is None
            or the file does not exist.

    Returns:
        Dict with "count", "ref_example", "hyp_example" and one entry per
        computed metric. For empty ``hyps`` only the first three keys are
        returned (examples set to None).
    """
    metrics = dict()
    metrics["count"] = len(hyps)
    if not hyps:
        # Nothing to score; indexing [-1] below would raise IndexError.
        metrics["ref_example"] = None
        metrics["hyp_example"] = None
        return metrics
    metrics["ref_example"] = refs[-1]
    metrics["hyp_example"] = hyps[-1]
    # Normalize to "list of reference lists". The check must be isinstance():
    # `r is not list` compares against the type object and is always True.
    many_refs = [r if isinstance(r, list) else [r] for r in refs]
    if metric in ("bleu", "all"):
        t_hyps = [hyp.split(" ") for hyp in hyps]
        t_refs = [[r.split(" ") for r in rs] for rs in many_refs]
        metrics["bleu"] = corpus_bleu(t_refs, t_hyps)
    if metric in ("rouge", "all"):
        rouge = Rouge()
        scores = rouge.get_scores(hyps, refs, avg=True)
        metrics.update(scores)
    if metric in ("meteor", "all") and meteor_jar is not None and os.path.exists(meteor_jar):
        meteor = Meteor(meteor_jar, language=language)
        metrics["meteor"] = meteor.compute_score(hyps, many_refs)
    if metric in ("duplicate_ngrams", "all"):
        metrics["duplicate_ngrams"] = dict()
        metrics["duplicate_ngrams"].update(calc_duplicate_n_grams_rate(hyps))
    if metric in ("bert_score",) and torch.cuda.is_available():
        # Call the private helper: the public name is rebound by a later cell.
        bert_scores, hash_code = _corpus_bert_score(hyps, refs)
        metrics["bert_score_{}".format(hash_code)] = bert_scores
    if metric in ("chrf", "all"):
        metrics["chrf"] = corpus_chrf(refs, hyps, beta=1.0)
    if metric in ("length", "all"):
        # Average hypothesis length in characters.
        metrics["length"] = mean([len(h) for h in hyps])
    return metrics


def print_metrics(refs, hyps, language, metric="all", meteor_jar=None):
    """Compute metrics with ``calc_metrics`` and print a plain-text report.

    Args:
        refs: references aligned with ``hyps``.
        hyps: hypothesis strings.
        language: language code forwarded to ``calc_metrics``.
        metric: metric selector forwarded to ``calc_metrics``.
        meteor_jar: METEOR jar path forwarded to ``calc_metrics``.

    Only metrics present in the computed dict are printed; ratio metrics are
    shown multiplied by 100.
    """
    metrics = calc_metrics(refs, hyps, language=language, metric=metric, meteor_jar=meteor_jar)

    def show(template, value):
        # Fill a single-placeholder template and print it.
        print(template.format(value))

    print("-------------METRICS-------------")
    for label, key in (("Count:\t", "count"), ("Ref:\t", "ref_example"), ("Hyp:\t", "hyp_example")):
        print(label, metrics[key])

    if "bleu" in metrics:
        show("BLEU:     \t{:3.1f}", metrics["bleu"] * 100.0)
    if "chrf" in metrics:
        show("chrF:     \t{:3.1f}", metrics["chrf"] * 100.0)
    if "rouge-1" in metrics:
        rouge_rows = (
            ("ROUGE-1-F:\t{:3.1f}", "rouge-1"),
            ("ROUGE-2-F:\t{:3.1f}", "rouge-2"),
            ("ROUGE-L-F:\t{:3.1f}", "rouge-l"),
        )
        for template, rouge_key in rouge_rows:
            show(template, metrics[rouge_key]['f'] * 100.0)
    if "meteor" in metrics:
        show("METEOR:   \t{:3.1f}", metrics["meteor"] * 100.0)
    if "duplicate_ngrams" in metrics:
        duplicate_rows = (
            ("Dup 1-grams:\t{:3.1f}", 1),
            ("Dup 2-grams:\t{:3.1f}", 2),
            ("Dup 3-grams:\t{:3.1f}", 3),
        )
        for template, order in duplicate_rows:
            show(template, metrics["duplicate_ngrams"][order] * 100.0)
    if "length" in metrics:
        show("Avg length:\t{:3.1f}", metrics["length"])
    # BERTScore entries are stored under keys that contain the model hash.
    for key, value in metrics.items():
        if "bert_score" in key:
            print("{}:\t{:3.1f}".format(key, value["f"] * 100.0))

In [None]:
def punct_detokenize(text):
    """Undo whitespace tokenization around punctuation, brackets and quotes.

    Steps, in order:
      1. remove the space before ``,.!?:;%`` and before closing brackets ``)]}``;
      2. remove the space after opening brackets ``([{``;
      3. drop the inner padding of ``" ... "`` and ``' ... '`` quoted spans;
      4. re-attach the possessive ``'s``.

    Args:
        text: whitespace-tokenized string.

    Returns:
        The detokenized string with leading/trailing whitespace stripped.
    """
    text = text.strip()
    punctuation = ",.!?:;%"
    closing_punctuation = ")]}"
    # Was "([}" — a typo that glued text following "}" and never handled "{".
    opening_punctuation = "([{"
    for ch in punctuation + closing_punctuation:
        text = text.replace(" " + ch, ch)
    for ch in opening_punctuation:
        text = text.replace(ch + " ", ch)
    # Quoted spans padded with one whitespace character on each side.
    quoted_patterns = [r'"\s[^"]+\s"', r"'\s[^']+\s'"]
    for pattern in quoted_patterns:
        for found in re.findall(pattern, text, re.U):
            # Keep the quote characters, drop the single padding character next to each.
            text = text.replace(found, found[0] + found[2:-2] + found[-1])
    text = text.replace("' s", "'s").replace(" 's", "'s")
    text = text.strip()
    return text


def postprocess(ref, hyp, language, is_multiple_ref=False, detokenize_after=False, tokenize_after=False, lower=False):
    """Normalize one (reference, hypothesis) pair before scoring.

    Args:
        ref: reference string.
        hyp: hypothesis string.
        language: "ru" selects razdel for tokenization, anything else nltk.
        is_multiple_ref: treat ``s_s`` as a sentence separator, HTML-escape
            angle brackets in every sentence and re-join with spaces.
        detokenize_after: run ``punct_detokenize`` on both strings.
        tokenize_after: replace ``@@UNKNOWN@@`` in the hypothesis with
            ``<unk>`` and re-tokenize both strings.
        lower: lowercase both strings.

    Returns:
        Tuple ``(ref, hyp)`` of processed strings.
    """
    def escape_and_join(sentences):
        # HTML-escape angle brackets, trim each sentence, glue back with single spaces.
        return " ".join(s.replace("<", "&lt;").replace(">", "&gt;").strip() for s in sentences)

    if is_multiple_ref:
        # The reference is split on the space-padded marker, the hypothesis on the bare one.
        ref_sentences = ref.split(" s_s ")
        hyp_sentences = hyp.split("s_s")
        hyp = escape_and_join(hyp_sentences)
        ref = escape_and_join(ref_sentences)

    ref, hyp = ref.strip(), hyp.strip()

    if detokenize_after:
        hyp = punct_detokenize(hyp)
        ref = punct_detokenize(ref)

    if tokenize_after:
        hyp = hyp.replace("@@UNKNOWN@@", "<unk>")
        if language == "ru":
            def split_words(text):
                return [token.text for token in razdel.tokenize(text)]
        else:
            split_words = nltk.word_tokenize
        hyp = " ".join(split_words(hyp))
        ref = " ".join(split_words(ref))

    if lower:
        hyp, ref = hyp.lower(), ref.lower()
    return ref, hyp


def evaluate(predicted_path,
             gold_path,
             metric,
             language,
             max_count=None,
             is_multiple_ref=False,
             detokenize_after=False,
             tokenize_after=False,
             lower=False,
             meteor_jar=None):
    """Score a file of predictions against a line-aligned file of references.

    Both files are read as UTF-8 text, one example per line. Pairs with an
    empty hypothesis or an empty reference (after post-processing) are skipped.
    The resulting metrics are printed with ``print_metrics``.

    Args:
        predicted_path: path to the hypotheses file.
        gold_path: path to the references file.
        metric: metric selector forwarded to ``print_metrics``.
        language: language code forwarded to ``postprocess`` and ``print_metrics``.
        max_count: evaluate at most this many leading lines; when None, all
            lines are used and both files must have the same number of lines.
        is_multiple_ref, detokenize_after, tokenize_after, lower: forwarded to
            ``postprocess``.
        meteor_jar: METEOR jar path forwarded to ``print_metrics``.

    Raises:
        AssertionError: if a path does not exist, or line counts differ when
            ``max_count`` is None.
    """
    assert os.path.exists(gold_path)
    assert os.path.exists(predicted_path)
    if max_count is None:
        # Explicit UTF-8: the platform default encoding is not guaranteed to decode non-ASCII text.
        with open(gold_path, encoding="utf-8") as gold:
            gold_num_lines = sum(1 for _ in gold)
        with open(predicted_path, encoding="utf-8") as pred:
            pred_num_lines = sum(1 for _ in pred)
        msg = "Number of lines in files differ: {} vs {}".format(gold_num_lines, pred_num_lines)
        assert gold_num_lines == pred_num_lines, msg

    hyps = []
    refs = []
    with open(gold_path, "r", encoding="utf-8") as gold, open(predicted_path, "r", encoding="utf-8") as pred:
        for i, (ref, hyp) in enumerate(zip(gold, pred)):
            if max_count is not None and i >= max_count:
                break
            ref, hyp = postprocess(ref, hyp, language, is_multiple_ref, detokenize_after, tokenize_after, lower)
            if not hyp:
                print("Empty hyp for ref: ", ref)
                continue
            if not ref:
                continue
            refs.append(ref)
            hyps.append(hyp)
    print_metrics(refs, hyps, metric=metric, meteor_jar=meteor_jar, language=language)



In [None]:
def calc_method_score(records, predict_func, nrows=None, meteor_jar="meteor-1.5/meteor-1.5.jar"):
    """Run a summarization method over records and print its metrics.

    Args:
        records: iterable of mappings with "text" and "summary" keys.
        predict_func: callable ``(text, summary) -> predicted summary``.
        nrows: use only the first ``nrows`` records; None means all.
        meteor_jar: METEOR jar path forwarded to ``print_metrics``.

    Pairs are post-processed as Russian text (re-tokenized and lowercased)
    before scoring.
    """
    pairs = []
    for index, record in enumerate(records):
        if nrows is not None and index >= nrows:
            break
        pairs.append((record["summary"], predict_func(record["text"], record["summary"])))

    processed = [
        postprocess(ref, hyp, language="ru", tokenize_after=True, lower=True)
        for ref, hyp in pairs
    ]
    references = [ref for ref, _ in processed]
    predictions = [hyp for _, hyp in processed]
    print_metrics(references, predictions, language="ru", meteor_jar=meteor_jar)


def calc_bert_score(records, predict_func, nrows=None):
    """Run a summarization method over records and print its BERTScore.

    NOTE(review): this definition reuses the name of the corpus-level
    ``calc_bert_score(hyps, refs, ...)`` defined earlier in the notebook and
    replaces it once this cell has run.

    Args:
        records: iterable of mappings with "text" and "summary" keys.
        predict_func: callable ``(text, summary) -> predicted summary``.
        nrows: use only the first ``nrows`` records; None means all.

    Pairs are only stripped by ``postprocess`` (no re-tokenization, no
    lowercasing) before being passed to ``print_metrics``.
    """
    raw_pairs = []
    for position, record in enumerate(records):
        if nrows is not None and position >= nrows:
            break
        raw_pairs.append((record["summary"], predict_func(record["text"], record["summary"])))

    references, predictions = [], []
    for ref, hyp in raw_pairs:
        clean_ref, clean_hyp = postprocess(ref, hyp, language="ru", tokenize_after=False, lower=False)
        references.append(clean_ref)
        predictions.append(clean_hyp)
    print_metrics(references, predictions, language="ru", meteor_jar=None, metric="bert_score")

In [None]:
from summa.summarizer import summarize


def predict_text_rank(text, summary, summary_part=0.1):
    """Summarize ``text`` with summa's ``summarize`` and return it on one line.

    Args:
        text: source document.
        summary: unused; present so the function matches the
            ``predict_func(text, summary)`` call shape used by the scoring helpers.
        summary_part: value passed as ``ratio`` to ``summarize``.

    Returns:
        The summary with newlines replaced by spaces.
    """
    extracted = summarize(text, ratio=summary_part, language='russian')
    return extracted.replace("\n", " ")

In [None]:
@Model.register("pgn")
class PointerGeneratorNetwork(Model):
    """
    Based on https://arxiv.org/pdf/1704.04368.pdf

    Encoder / attention / LSTM-cell decoder whose output distribution mixes a
    softmax over the target vocabulary with the attention distribution over
    source tokens, weighted by a learned gate ``p_gen`` (see ``_get_final_dist``).
    Source tokens that map to the OOV index get per-batch ids starting at the
    vocabulary size so that they can be copied (see ``_prepare``).
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_decoding_steps: int,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 projection_dim: int = None,
                 use_coverage: bool = False,
                 coverage_shift: float = 0.,
                 coverage_loss_weight: float = None,
                 embed_attn_to_output: bool = False) -> None:
        """Build the model.

        Args:
            vocab: vocabulary; START/END/OOV indices and the vocabulary size are
                looked up in ``target_namespace``.
            source_embedder: embeds the source tokens.
            encoder: sequence encoder; its output dim is also used as the decoder
                hidden size.
            attention: attention module, called with
                ``(decoder_hidden, encoder_outputs, source_mask)`` and additionally
                with the coverage vector when coverage is enabled.
            max_decoding_steps: step limit for beam search and for the forward
                loop when no targets are given.
            beam_size: beam width; a falsy value means 1.
            target_namespace: vocabulary namespace of the target tokens.
            target_embedding_dim: size of target embeddings; defaults to the
                source embedder output dim.
            scheduled_sampling_ratio: during training, probability of feeding the
                model's previous prediction instead of the gold token at a step.
            projection_dim: size of the layer between decoder state and vocabulary
                logits; defaults to the source embedder output dim.
            use_coverage: accumulate attention into a coverage vector and add a
                coverage term to the loss.
            coverage_shift: shift used when transforming the coverage loss.
            coverage_loss_weight: multiplier of the coverage term.
                NOTE(review): the default None is multiplied with a tensor in
                ``_forward_loop`` when ``use_coverage`` is True — confirm callers
                always set it in that case.
            embed_attn_to_output: if True, the previous step's attention context
                is fed to the decoder cell and the current attention context is
                concatenated with the hidden state before projection.
        """
        super(PointerGeneratorNetwork, self).__init__(vocab)

        self._target_namespace = target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL, target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, target_namespace)
        self._unk_index = self.vocab.get_token_index(DEFAULT_OOV_TOKEN, target_namespace)
        self._vocab_size = self.vocab.get_vocab_size(target_namespace)
        assert self._vocab_size > 2, \
            "Target vocabulary is empty. Make sure 'target_namespace' option of the model is correct."

        # Encoder
        self._source_embedder = source_embedder
        self._encoder = encoder
        self._encoder_output_dim = self._encoder.get_output_dim()

        # Decoder
        self._target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        self._target_embedder = Embedding(self._target_embedding_dim, self._num_classes)

        # The decoder cell consumes [attention context; embedded previous token].
        self._decoder_input_dim = self._encoder_output_dim + self._target_embedding_dim
        self._decoder_output_dim = self._encoder_output_dim
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        self._projection_dim = projection_dim or self._source_embedder.get_output_dim()
        # With embed_attn_to_output the projection input is [attention context; hidden].
        hidden_projection_dim = self._decoder_output_dim if not embed_attn_to_output else self._decoder_output_dim * 2
        self._hidden_projection_layer = Linear(hidden_projection_dim, self._projection_dim)
        self._output_projection_layer = Linear(self._projection_dim, self._num_classes)

        # p_gen is computed from [attention context; hidden; cell state; decoder input].
        self._p_gen_layer = Linear(self._decoder_output_dim * 3 + self._decoder_input_dim, 1)
        self._attention = attention
        self._use_coverage = use_coverage
        self._coverage_loss_weight = coverage_loss_weight
        # Added to probabilities before taking the logarithm.
        self._eps = 1e-31
        self._embed_attn_to_output = embed_attn_to_output
        self._coverage_shift = coverage_shift

        # Metrics
        self._p_gen_sum = 0.0
        self._p_gen_iterations = 0
        self._coverage_loss_sum = 0.0
        self._coverage_iterations = 0

        # Decoding
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size or 1)

    def forward(self,
                source_tokens: Dict[str, Dict[str, torch.LongTensor]],
                source_token_ids: torch.Tensor,
                source_to_target: torch.LongTensor,
                target_tokens: Dict[str, Dict[str, torch.LongTensor]] = None,
                target_token_ids: torch.Tensor = None,
                metadata=None) -> Dict[str, torch.Tensor]:
        """Encode the source and, depending on inputs and mode, compute loss and/or decode.

        When ``target_tokens`` is given, the step-by-step loop produces "loss" and
        "predictions". When the module is not in training mode, beam search is run
        and its "predictions" / "class_log_probabilities" overwrite or extend the
        output dict. "metadata" and "source_to_target" are always copied into the
        output for later decoding to tokens.
        """
        state = self._encode(source_tokens)
        target_tokens_tensor = target_tokens["tokens"]["tokens"].long() if target_tokens else None
        extra_zeros, modified_source_tokens, modified_target_tokens = self._prepare(
            source_to_target, source_token_ids, target_tokens_tensor, target_token_ids)

        state["tokens"] = modified_source_tokens
        state["extra_zeros"] = extra_zeros

        output_dict = {}
        if target_tokens:
            state["target_tokens"] = modified_target_tokens
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, target_tokens)
        output_dict["metadata"] = metadata
        output_dict["source_to_target"] = source_to_target

        if not self.training:
            # Decoder state is re-initialized so beam search starts from scratch.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

        return output_dict

    def _prepare(self,
                 source_tokens: torch.LongTensor,
                 source_token_ids: torch.Tensor,
                 target_tokens: torch.LongTensor = None,
                 target_token_ids: torch.Tensor = None):
        """Assign extended ids (>= vocab size) to unknown tokens.

        Args:
            source_tokens: source token indices; ``forward`` passes
                ``source_to_target`` here.
            source_token_ids: per-token ids used to tell different unknown
                tokens apart within a batch row.
            target_tokens: optional target token indices.
            target_token_ids: optional per-token ids for the target, in the same
                id space as ``source_token_ids``.

        Returns:
            Tuple ``(extra_zeros, modified_source_tokens, modified_target_tokens)``;
            ``extra_zeros`` is a float tensor with one column per extended id,
            ``modified_target_tokens`` is None when no targets were given.
        """
        batch_size = source_tokens.size(0)
        source_max_length = source_tokens.size(1)

        tokens = source_tokens
        token_ids = source_token_ids.long()

        # Concat target tokens if exist
        if target_tokens is not None:
            tokens = torch.cat((tokens, target_tokens), 1)
            token_ids = torch.cat((token_ids, target_token_ids.long()), 1)

        is_unk = torch.eq(tokens, self._unk_index).long()
        # Create tensor with ids of unknown tokens only.
        # Those ids are batch-local.
        unk_only = token_ids * is_unk

        # Recalculate batch-local ids to range [1, count_of_unique_unk_tokens].
        # All known tokens have zero id.
        # NOTE(review): the inverse index of torch.unique maps the smallest value to 0, so this
        # presumably relies on every row containing at least one known token (value 0) — confirm.
        unk_token_nums = token_ids.new_zeros((batch_size, token_ids.size(1)))
        for i in range(batch_size):
            unique = torch.unique(unk_only[i, :], return_inverse=True, sorted=True)[1]
            unk_token_nums[i, :] = unique

        # Replace DEFAULT_OOV_TOKEN id with new batch-local ids starting from vocab_size
        # For example, if vocabulary size is 50000, the first unique unknown token will have 50000 index,
        # the second will have 50001 index and so on.
        tokens = tokens - tokens * is_unk + (self._vocab_size - 1) * is_unk + unk_token_nums

        modified_target_tokens = None
        modified_source_tokens = tokens
        if target_tokens is not None:
            # Remove target unknown tokens that do not exist in source tokens
            max_source_num = torch.max(tokens[:, :source_max_length], dim=1)[0]
            vocab_size = max_source_num.new_full((1,), self._vocab_size-1)
            max_source_num = torch.max(max_source_num, other=vocab_size).unsqueeze(1).expand((-1, tokens.size(1)))
            # Any id above the largest source id of the row is mapped back to the OOV index.
            unk_target_tokens_mask = torch.gt(tokens, max_source_num).long()
            tokens = tokens - tokens * unk_target_tokens_mask + self._unk_index * unk_target_tokens_mask
            modified_target_tokens = tokens[:, source_max_length:]
            modified_source_tokens = tokens[:, :source_max_length]

        # Count unique unknown source tokens to create enough zeros for final distribution
        source_unk_count = torch.max(unk_token_nums[:, :source_max_length])
        extra_zeros = tokens.new_zeros((batch_size, source_unk_count), dtype=torch.float32)
        return extra_zeros, modified_source_tokens, modified_target_tokens

    def _encode(self, source_tokens: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source; returns "source_mask" and "encoder_outputs"."""
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder.forward(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder.forward(embedded_input, source_mask)

        return {
                "source_mask": source_mask,
                "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Add initial decoder entries to ``state`` (mutated in place and returned).

        Sets "decoder_hidden" to the final encoder state and "decoder_context" to
        zeros; also zero-initializes "attn_context" and "coverage" when the
        corresponding options are enabled.
        """
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
                state["encoder_outputs"],
                state["source_mask"],
                self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output

        encoder_outputs = state["encoder_outputs"]
        # Used as the second element of the (hidden, cell) pair passed to the LSTM cell.
        state["decoder_context"] = encoder_outputs.new_zeros(batch_size, self._decoder_output_dim)
        if self._embed_attn_to_output:
            state["attn_context"] = encoder_outputs.new_zeros(encoder_outputs.size(0), encoder_outputs.size(2))
        if self._use_coverage:
            state["coverage"] = encoder_outputs.new_zeros(batch_size, encoder_outputs.size(1))
        return state

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run one decoder step.

        Args:
            last_predictions: token indices fed to the decoder at this step;
                indices >= vocab size are replaced by the OOV index before embedding.
            state: decoder state dict; updated in place with "decoder_input",
                "decoder_hidden", "decoder_context", "attn_scores",
                "attn_context" and, with coverage, "coverage".

        Returns:
            Tuple of (vocabulary logits, updated state).
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]
        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]
        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]
        # shape: (group_size, decoder_output_dim)
        attn_context = state.get("attn_context", None)

        # Extended (copied) ids have no embedding; embed them as the OOV token.
        is_unk = (last_predictions >= self._vocab_size).long()
        last_predictions_fixed = last_predictions - last_predictions * is_unk + self._unk_index * is_unk
        embedded_input = self._target_embedder(last_predictions_fixed)

        coverage = state.get("coverage", None)

        def get_attention_context(decoder_hidden_inner):
            """Return (attention scores, attention-weighted sum of encoder outputs)."""
            if coverage is None:
                attention_scores = self._attention(decoder_hidden_inner, encoder_outputs, source_mask)
            else:
                attention_scores = self._attention(decoder_hidden_inner, encoder_outputs, source_mask, coverage)
            attention_context = util.weighted_sum(encoder_outputs, attention_scores)
            return attention_scores, attention_context

        if not self._embed_attn_to_output:
            # Attention is computed from the previous hidden state and fed into the cell.
            attn_scores, attn_context = get_attention_context(decoder_hidden)
            decoder_input = torch.cat((attn_context, embedded_input), -1)
            decoder_hidden, decoder_context = self._decoder_cell(decoder_input, (decoder_hidden, decoder_context))
            projection = self._hidden_projection_layer(decoder_hidden)
        else:
            # The cell consumes the attention context stored in the state; attention is then
            # recomputed from the new hidden state and concatenated for the projection.
            decoder_input = torch.cat((attn_context, embedded_input), -1)
            decoder_hidden, decoder_context = self._decoder_cell(decoder_input, (decoder_hidden, decoder_context))
            attn_scores, attn_context = get_attention_context(decoder_hidden)
            projection = self._hidden_projection_layer(torch.cat((attn_context, decoder_hidden), -1))

        output_projections = self._output_projection_layer(projection)
        if self._use_coverage:
            # Coverage is the running sum of attention scores.
            state["coverage"] = coverage + attn_scores
        state["decoder_input"] = decoder_input
        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        state["attn_scores"] = attn_scores
        state["attn_context"] = attn_context

        return output_projections, state

    def _get_final_dist(self, state: Dict[str, torch.Tensor], output_projections):
        """Mix the vocabulary and copy distributions into one normalized distribution.

        The vocabulary softmax is scaled by ``p_gen`` and padded with
        ``extra_zeros``; attention scores scaled by ``1 - p_gen`` are then added at
        the positions given by the (extended) source token ids. Also accumulates
        the mean ``p_gen`` into the running metric counters on every call.
        """
        attn_dist = state["attn_scores"]
        tokens = state["tokens"]
        extra_zeros = state["extra_zeros"]
        attn_context = state["attn_context"]
        decoder_input = state["decoder_input"]
        decoder_hidden = state["decoder_hidden"]
        decoder_context = state["decoder_context"]

        decoder_state = torch.cat((decoder_hidden, decoder_context), 1)
        p_gen = self._p_gen_layer(torch.cat((attn_context, decoder_state, decoder_input), 1))
        p_gen = torch.sigmoid(p_gen)
        self._p_gen_sum += torch.mean(p_gen).item()
        self._p_gen_iterations += 1

        vocab_dist = F.softmax(output_projections, dim=-1)

        vocab_dist = vocab_dist * p_gen
        attn_dist = attn_dist * (1.0 - p_gen)
        if extra_zeros.size(1) != 0:
            vocab_dist = torch.cat((vocab_dist, extra_zeros), 1)
        final_dist = vocab_dist.scatter_add(1, tokens, attn_dist)
        normalization_factor = final_dist.sum(1, keepdim=True)
        final_dist = final_dist / normalization_factor

        return final_dist

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, Dict[str, torch.LongTensor]] = None) -> Dict[str, torch.Tensor]:
        """Decode step by step, greedily, and compute the loss when targets are given.

        With targets, the number of steps is the target length minus one and the
        gold token of the step is fed to the decoder, except that during training
        the previous prediction is fed with probability
        ``scheduled_sampling_ratio``. Without targets, ``max_decoding_steps`` steps
        are run on the model's own predictions.

        Returns:
            Dict with "predictions" and, when targets are given, "loss".
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size(0)

        num_decoding_steps = self._max_decoding_steps
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]["tokens"]
            _, target_sequence_length = targets.size()
            num_decoding_steps = target_sequence_length - 1

        if self._use_coverage:
            coverage_loss = source_mask.new_zeros(1, dtype=torch.float32)

        last_predictions = state["tokens"].new_full((batch_size,), fill_value=self._start_index)
        step_proba: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                input_choices = last_predictions
            elif not target_tokens:
                input_choices = last_predictions
            else:
                # Teacher forcing with the original (non-extended) target indices.
                input_choices = targets[:, timestep]

            if self._use_coverage:
                # Coverage before this step's attention is added.
                old_coverage = state["coverage"]

            output_projections, state = self._prepare_output_projections(input_choices, state)
            final_dist = self._get_final_dist(state, output_projections)
            step_proba.append(final_dist)
            last_predictions = torch.max(final_dist, 1)[1]
            step_predictions.append(last_predictions.unsqueeze(1))

            if self._use_coverage:
                step_coverage_loss = torch.sum(torch.min(state["attn_scores"], old_coverage), 1)
                coverage_loss = coverage_loss + step_coverage_loss

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_classes, num_decoding_steps)
            num_classes = step_proba[0].size(1)
            proba = step_proba[0].new_zeros((batch_size, num_classes, len(step_proba)))
            for i, p in enumerate(step_proba):
                proba[:, :, i] = p

            loss = self._get_loss(proba, state["target_tokens"], self._eps)
            if self._use_coverage:
                coverage_loss = torch.mean(coverage_loss / num_decoding_steps)
                self._coverage_loss_sum += coverage_loss.item()
                self._coverage_iterations += 1
                # Only the part of the coverage loss above the shift contributes a gradient;
                # the constant offset does not.
                modified_coverage_loss = relu(coverage_loss - self._coverage_shift) + self._coverage_shift - 1.0
                loss = loss + self._coverage_loss_weight * modified_coverage_loss
            output_dict["loss"] = loss

        return output_dict

    @staticmethod
    def _get_loss(proba: torch.Tensor,
                  targets: torch.LongTensor,
                  eps: float) -> torch.Tensor:
        """Negative log-likelihood of the targets under the per-step distributions.

        Args:
            proba: probabilities laid out as (batch, classes, steps), as built in
                ``_forward_loop``.
            targets: target indices; the first column is dropped so that the
                remaining columns line up with the decoding steps.
            eps: added to the probabilities before the logarithm.

        Targets equal to 0 are ignored (presumably the padding index — confirm
        against the vocabulary).
        """
        targets = targets[:, 1:]
        proba = torch.log(proba + eps)
        loss = torch.nn.NLLLoss(ignore_index=0)(proba, targets)
        return loss

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run beam search from the start symbol using ``take_step``.

        Returns:
            Dict with "class_log_probabilities" and "predictions".
        """
        batch_size = state["tokens"].size()[0]
        start_predictions = state["tokens"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Step function handed to beam search: one decoder step returning log-probabilities."""
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)
        final_dist = self._get_final_dist(state, output_projections)
        log_probabilities = torch.log(final_dist + self._eps)
        return log_probabilities, state

    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert predicted indices into tokens and store them as "predicted_tokens".

        Reads "predictions", "metadata" and "source_to_target" from
        ``output_dict``; the dict is modified in place and returned.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, np.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        all_meta = output_dict["metadata"]
        all_source_to_target = output_dict["source_to_target"]
        for (indices, metadata), source_to_target in zip(zip(predicted_indices, all_meta), all_source_to_target):
            all_predicted_tokens.append(self._decode_sample(indices, metadata, source_to_target))
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _decode_sample(self, indices, metadata, source_to_target):
        """Turn the predicted indices of one example into lists of tokens.

        Args:
            indices: array of predicted indices, either one sequence or one
                sequence per beam.
            metadata: mapping with the original "source_tokens" of the example.
            source_to_target: target-vocabulary index of every source token.

        Returns:
            List with one token list per decoded sequence.
        """
        all_predicted_tokens = []
        if len(indices.shape) == 1:
            # Single sequence: wrap so the loop below handles both layouts.
            indices = [indices]
        for sample_indices in indices:
            sample_indices = list(sample_indices)
            # Collect indices till the first end_symbol
            if self._end_index in sample_indices:
                sample_indices = sample_indices[:sample_indices.index(self._end_index)]
            # Get all unknown tokens from source
            # NOTE(review): tokens are collected in order of first appearance, which presumably
            # matches the order in which `_prepare` numbers extended ids — verify against the reader.
            original_source_tokens = metadata["source_tokens"]
            unk_tokens = list()
            for i, token_vocab_index in enumerate(source_to_target):
                if token_vocab_index != self._unk_index:
                    continue
                token = original_source_tokens[i]
                if token in unk_tokens:
                    continue
                unk_tokens.append(token)
            predicted_tokens = []
            for token_vocab_index in sample_indices:
                if token_vocab_index < self._vocab_size:
                    token = self.vocab.get_token_from_index(token_vocab_index, namespace=self._target_namespace)
                else:
                    # Extended id: look the token up among the unknown source tokens.
                    unk_number = token_vocab_index - self._vocab_size
                    token = unk_tokens[unk_number]
                predicted_tokens.append(token)
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return running averages of the coverage loss and ``p_gen``.

        Returns an empty dict when coverage is disabled; in that case the
        ``p_gen`` counters keep accumulating and are neither reported nor reset.
        With ``reset=True`` all counters are zeroed after the averages are computed.
        """
        if not self._use_coverage:
            return {}
        avg_coverage_loss = 0.0
        if self._coverage_iterations != 0:
            avg_coverage_loss = self._coverage_loss_sum / self._coverage_iterations
        avg_p_gen = self._p_gen_sum / self._p_gen_iterations if self._p_gen_iterations != 0 else 0.0
        metrics = {"coverage_loss": avg_coverage_loss, "p_gen": avg_p_gen}
        if reset:
            self._p_gen_sum = 0.0
            self._p_gen_iterations = 0
            self._coverage_loss_sum = 0.0
            self._coverage_iterations = 0
        return metrics

In [None]:
@Model.register("pgn")
class PointerGeneratorNetwork(Model):
    """
    Based on https://arxiv.org/pdf/1704.04368.pdf
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_decoding_steps: int,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 projection_dim: int = None,
                 use_coverage: bool = False,
                 coverage_shift: float = 0.,
                 coverage_loss_weight: float = None,
                 embed_attn_to_output: bool = False) -> None:
        """
        Build the encoder/decoder layers, the p_gen layer and the beam search helper.

        Parameters
        ----------
        vocab : vocabulary; START/END/OOV indices and the vocabulary size are read
            from ``target_namespace``.
        source_embedder : embeds the source text field.
        encoder : encodes the embedded source; its output dim is also used as the
            decoder LSTM hidden size.
        attention : attention module called on the decoder hidden state, the encoder
            outputs and the source mask (plus the coverage vector when coverage is on).
        max_decoding_steps : maximum number of steps for beam search, and for the
            decoding loop when no targets are given.
        beam_size : beam width; ``None`` (or 0) falls back to 1.
        target_namespace : vocabulary namespace of the target tokens.
        target_embedding_dim : size of target embeddings; defaults to the source
            embedder output dim.
        scheduled_sampling_ratio : per-step probability, during training, of feeding
            the model's previous prediction instead of the gold token.
        projection_dim : size of the hidden projection before the output layer;
            defaults to the source embedder output dim.
        use_coverage : enables the coverage vector and the coverage loss.
        coverage_shift : shift applied inside the coverage loss term.
        coverage_loss_weight : weight of the coverage loss term; it is multiplied into
            the loss whenever ``use_coverage`` is True, so it must not be ``None`` then.
        embed_attn_to_output : if True, the previous attention context is fed to the
            LSTM cell and the new context is concatenated with the hidden state
            before the hidden projection.
        """
        super(PointerGeneratorNetwork, self).__init__(vocab)

        self._target_namespace = target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL, target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, target_namespace)
        self._unk_index = self.vocab.get_token_index(DEFAULT_OOV_TOKEN, target_namespace)
        self._vocab_size = self.vocab.get_vocab_size(target_namespace)
        assert self._vocab_size > 2, \
            "Target vocabulary is empty. Make sure 'target_namespace' option of the model is correct."

        # Encoder
        self._source_embedder = source_embedder
        self._encoder = encoder
        self._encoder_output_dim = self._encoder.get_output_dim()

        # Decoder
        self._target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        # NOTE(review): the positional order (embedding_dim, num_embeddings) is assumed to match
        # the installed AllenNLP version of Embedding -- confirm.
        self._target_embedder = Embedding(self._target_embedding_dim, self._num_classes)

        # The LSTM cell input is the attention context concatenated with the embedded previous token.
        self._decoder_input_dim = self._encoder_output_dim + self._target_embedding_dim
        self._decoder_output_dim = self._encoder_output_dim
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        self._projection_dim = projection_dim or self._source_embedder.get_output_dim()
        # With embed_attn_to_output the projection input is [attention context; decoder hidden].
        hidden_projection_dim = self._decoder_output_dim if not embed_attn_to_output else self._decoder_output_dim * 2
        self._hidden_projection_layer = Linear(hidden_projection_dim, self._projection_dim)
        self._output_projection_layer = Linear(self._projection_dim, self._num_classes)

        # p_gen is computed from [attention context; decoder hidden; decoder cell state; decoder input].
        self._p_gen_layer = Linear(self._decoder_output_dim * 3 + self._decoder_input_dim, 1)
        self._attention = attention
        self._use_coverage = use_coverage
        self._coverage_loss_weight = coverage_loss_weight
        # Added to probabilities before taking the logarithm.
        self._eps = 1e-31
        self._embed_attn_to_output = embed_attn_to_output
        self._coverage_shift = coverage_shift

        # Metrics: running sums and step counts reported by get_metrics().
        self._p_gen_sum = 0.0
        self._p_gen_iterations = 0
        self._coverage_loss_sum = 0.0
        self._coverage_iterations = 0

        # Decoding
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size or 1)

    def forward(self,
                source_tokens: Dict[str, Dict[str, torch.LongTensor]],
                source_token_ids: torch.Tensor,
                source_to_target: torch.LongTensor,
                target_tokens: Dict[str, Dict[str, torch.LongTensor]] = None,
                target_token_ids: torch.Tensor = None,
                metadata=None) -> Dict[str, torch.Tensor]:
        """
        Encode the source, run the teacher-forced loop when targets are given and
        run beam search whenever the model is not in training mode.

        The returned dict always carries ``metadata`` and ``source_to_target``.
        With targets it also holds the output of ``_forward_loop``; outside of
        training mode it is updated with the beam search output, which overrides
        keys of the same name.
        """
        state = self._encode(source_tokens)

        if target_tokens:
            gold_tensor = target_tokens["tokens"]["tokens"].long()
        else:
            gold_tensor = None
        extra_zeros, extended_source, extended_target = self._prepare(
            source_to_target, source_token_ids, gold_tensor, target_token_ids)
        state.update({"tokens": extended_source, "extra_zeros": extra_zeros})

        if target_tokens:
            state["target_tokens"] = extended_target
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}
        output_dict["metadata"] = metadata
        output_dict["source_to_target"] = source_to_target

        if self.training:
            return output_dict

        # The decoder state is re-initialised so beam search starts from scratch.
        state = self._init_decoder_state(state)
        output_dict.update(self._forward_beam_search(state))
        return output_dict

    def _prepare(self,
                 source_tokens: torch.LongTensor,
                 source_token_ids: torch.Tensor,
                 target_tokens: torch.LongTensor = None,
                 target_token_ids: torch.Tensor = None):
        """
        Map unknown tokens to extended-vocabulary indices (``vocab_size`` and above).

        ``forward`` passes ``source_to_target`` as ``source_tokens``. Both token tensors
        are indexed as (batch_size, length); the ``*_token_ids`` tensors are cast to
        long and must have matching shapes.
        NOTE(review): the rank computation below assumes that ids of unknown tokens are
        positive and that every row contains at least one known position, so that
        zero is the smallest value in each row of ``unk_only`` -- confirm against the
        dataset reader.

        Returns a tuple ``(extra_zeros, modified_source_tokens, modified_target_tokens)``:
        ``extra_zeros`` is a float32 zero tensor with one column per extended index
        needed by the batch; the token tensors have unknown tokens replaced by extended
        indices. ``modified_target_tokens`` is ``None`` when no targets are given, and
        target unknowns that do not occur in the source are reset to the OOV index.
        """
        batch_size = source_tokens.size(0)
        source_max_length = source_tokens.size(1)

        tokens = source_tokens
        token_ids = source_token_ids.long()

        # Concatenate target tokens (if given) after the source along the length axis,
        # so that source and target share one numbering of unknown tokens.
        if target_tokens is not None:
            tokens = torch.cat((tokens, target_tokens), 1)
            token_ids = torch.cat((token_ids, target_token_ids.long()), 1)

        # 1 where the token is the OOV token, 0 elsewhere.
        is_unk = torch.eq(tokens, self._unk_index).long()
        # Create tensor with ids of unknown tokens only (known positions become 0).
        # Those ids are batch-local.
        unk_only = token_ids * is_unk

        # Recalculate batch-local ids to range [1, count_of_unique_unk_tokens].
        # All known tokens have zero id.
        unk_token_nums = token_ids.new_zeros((batch_size, token_ids.size(1)))
        for i in range(batch_size):
            # Inverse indices of torch.unique: the rank of each value among the row's sorted unique values.
            unique = torch.unique(unk_only[i, :], return_inverse=True, sorted=True)[1]
            unk_token_nums[i, :] = unique

        # Replace DEFAULT_OOV_TOKEN id with new batch-local ids starting from vocab_size
        # For example, if vocabulary size is 50000, the first unique unknown token will have 50000 index,
        # the second will have 50001 index and so on.
        # Unknown positions become (vocab_size - 1) + rank; known positions keep their index plus their rank.
        tokens = tokens - tokens * is_unk + (self._vocab_size - 1) * is_unk + unk_token_nums

        modified_target_tokens = None
        modified_source_tokens = tokens
        if target_tokens is not None:
            # Remove target unknown tokens that do not exist in source tokens:
            # any index above the largest source index of the row (and above vocab_size - 1)
            # is replaced by the OOV index.
            max_source_num = torch.max(tokens[:, :source_max_length], dim=1)[0]
            vocab_size = max_source_num.new_full((1,), self._vocab_size-1)
            max_source_num = torch.max(max_source_num, other=vocab_size).unsqueeze(1).expand((-1, tokens.size(1)))
            unk_target_tokens_mask = torch.gt(tokens, max_source_num).long()
            tokens = tokens - tokens * unk_target_tokens_mask + self._unk_index * unk_target_tokens_mask
            # Split the concatenated tensor back into its target and source parts.
            modified_target_tokens = tokens[:, source_max_length:]
            modified_source_tokens = tokens[:, :source_max_length]

        # Count unique unknown source tokens to create enough zeros for final distribution
        # (the maximum is taken over the whole batch, so every row gets the same number of extra columns).
        source_unk_count = torch.max(unk_token_nums[:, :source_max_length])
        extra_zeros = tokens.new_zeros((batch_size, source_unk_count), dtype=torch.float32)
        return extra_zeros, modified_source_tokens, modified_target_tokens

    def _encode(self, source_tokens: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """
        Embed and encode the source text field.

        Returns a dict with ``"source_mask"`` (the text field mask of ``source_tokens``)
        and ``"encoder_outputs"`` (the encoder output for the embedded source).
        """
        # The modules are called directly instead of through ``.forward`` so that
        # ``nn.Module.__call__`` runs and any registered forward hooks are honoured.
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)

        return {
                "source_mask": source_mask,
                "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Add the initial decoder tensors to ``state`` (in place) and return it.

        Sets ``decoder_hidden`` to the final encoder states and ``decoder_context``
        to zeros; additionally sets a zero ``attn_context`` when the attention
        context is fed into the decoder, and a zero ``coverage`` vector when
        coverage is enabled.
        """
        memory = state["encoder_outputs"]
        mask = state["source_mask"]
        rows = mask.size(0)

        # The decoder hidden state starts from the final output of the encoder.
        # shape: (batch_size, encoder_output_dim)
        state["decoder_hidden"] = util.get_final_encoder_states(
            memory, mask, self._encoder.is_bidirectional())
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = memory.new_zeros(rows, self._decoder_output_dim)

        if self._embed_attn_to_output:
            # One zero context vector per batch row, sized like an encoder output vector.
            state["attn_context"] = memory.new_zeros(memory.size(0), memory.size(2))
        if self._use_coverage:
            # One coverage value per source position.
            state["coverage"] = memory.new_zeros(rows, memory.size(1))
        return state

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Run one decoder step and return the vocabulary logits with the updated state.

        Predictions that point into the extended vocabulary (index >= vocab size)
        are embedded as the OOV token. The state is updated in place with the
        decoder input, the new LSTM hidden/cell tensors, the attention scores and
        context, and (when coverage is enabled) the accumulated coverage.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        memory = state["encoder_outputs"]
        # shape: (group_size, max_input_sequence_length)
        memory_mask = state["source_mask"]
        # shape: (group_size, decoder_output_dim)
        hidden = state["decoder_hidden"]
        # shape: (group_size, decoder_output_dim)
        cell = state["decoder_context"]
        previous_context = state.get("attn_context", None)

        # Copied (extended-vocabulary) indices have no embedding of their own.
        copied = (last_predictions >= self._vocab_size).long()
        embedded_input = self._target_embedder(last_predictions * (1 - copied) + self._unk_index * copied)

        coverage = state.get("coverage", None)

        def attend(query):
            # The coverage vector is passed to the attention module only when it is tracked.
            extra_args = () if coverage is None else (coverage,)
            weights = self._attention(query, memory, memory_mask, *extra_args)
            return weights, util.weighted_sum(memory, weights)

        if self._embed_attn_to_output:
            # Feed the previous attention context, then attend with the new hidden state.
            decoder_input = torch.cat((previous_context, embedded_input), -1)
            hidden, cell = self._decoder_cell(decoder_input, (hidden, cell))
            attn_scores, attn_context = attend(hidden)
            projection = self._hidden_projection_layer(torch.cat((attn_context, hidden), -1))
        else:
            # Attend with the old hidden state and feed the fresh context to the LSTM cell.
            attn_scores, attn_context = attend(hidden)
            decoder_input = torch.cat((attn_context, embedded_input), -1)
            hidden, cell = self._decoder_cell(decoder_input, (hidden, cell))
            projection = self._hidden_projection_layer(hidden)

        output_projections = self._output_projection_layer(projection)

        if self._use_coverage:
            state["coverage"] = coverage + attn_scores
        state.update({
            "decoder_input": decoder_input,
            "decoder_hidden": hidden,
            "decoder_context": cell,
            "attn_scores": attn_scores,
            "attn_context": attn_context,
        })
        return output_projections, state

    def _get_final_dist(self, state: Dict[str, torch.Tensor], output_projections):
        """
        Mix the generation and copy distributions into one extended-vocabulary distribution.

        ``p_gen`` is a sigmoid over the concatenation of the attention context, the
        decoder hidden and cell tensors and the decoder input. The softmax over the
        logits is scaled by ``p_gen``, padded with ``extra_zeros``, and the attention
        scores scaled by ``1 - p_gen`` are scatter-added at the source token indices.
        The result is renormalised so that every row sums to one. As a side effect
        the running ``p_gen`` statistics are updated.
        """
        attn_dist, tokens, extra_zeros = state["attn_scores"], state["tokens"], state["extra_zeros"]
        gate_features = torch.cat(
            (state["attn_context"], state["decoder_hidden"], state["decoder_context"], state["decoder_input"]), 1)

        p_gen = torch.sigmoid(self._p_gen_layer(gate_features))
        self._p_gen_sum += p_gen.mean().item()
        self._p_gen_iterations += 1

        generated = F.softmax(output_projections, dim=-1) * p_gen
        copied = attn_dist * (1.0 - p_gen)
        if extra_zeros.size(1) != 0:
            # Make room for the batch-local indices of copied unknown tokens.
            generated = torch.cat((generated, extra_zeros), 1)

        mixture = generated.scatter_add(1, tokens, copied)
        return mixture / mixture.sum(1, keepdim=True)

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, Dict[str, torch.LongTensor]] = None) -> Dict[str, torch.Tensor]:
        """
        Decode step by step, greedily recording predictions, and compute the loss.

        With targets the loop runs for ``target_length - 1`` steps and feeds the gold
        token at each step, except that during training the previous prediction is
        fed instead with probability ``scheduled_sampling_ratio``. Without targets it
        runs for ``max_decoding_steps`` steps on its own predictions.

        Returns a dict with ``"predictions"`` and, when targets are given, ``"loss"``
        (negative log-likelihood plus the weighted coverage term if coverage is on).
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size(0)
        has_targets = bool(target_tokens)
        track_coverage = self._use_coverage

        if has_targets:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]["tokens"]
            _, target_sequence_length = targets.size()
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        if track_coverage:
            coverage_loss = source_mask.new_zeros(1, dtype=torch.float32)

        last_predictions = state["tokens"].new_full((batch_size,), fill_value=self._start_index)
        step_proba: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            # The random draw happens only in training mode, once per step.
            sampled = self.training and torch.rand(1).item() < self._scheduled_sampling_ratio
            if sampled or not has_targets:
                input_choices = last_predictions
            else:
                input_choices = targets[:, timestep]

            # Coverage accumulated before this step is needed for the coverage loss.
            coverage_before = state["coverage"] if track_coverage else None

            output_projections, state = self._prepare_output_projections(input_choices, state)
            final_dist = self._get_final_dist(state, output_projections)
            _, last_predictions = torch.max(final_dist, 1)
            step_proba.append(final_dist)
            step_predictions.append(last_predictions.unsqueeze(1))

            if track_coverage:
                coverage_loss = coverage_loss + torch.min(state["attn_scores"], coverage_before).sum(1)

        # shape: (batch_size, num_decoding_steps)
        output_dict = {"predictions": torch.cat(step_predictions, 1)}
        if not has_targets:
            return output_dict

        # shape: (batch_size, num_classes, num_decoding_steps)
        proba = torch.cat([step.unsqueeze(2) for step in step_proba], 2)
        loss = self._get_loss(proba, state["target_tokens"], self._eps)
        if track_coverage:
            coverage_loss = torch.mean(coverage_loss / num_decoding_steps)
            self._coverage_loss_sum += coverage_loss.item()
            self._coverage_iterations += 1
            shift = self._coverage_shift
            loss = loss + self._coverage_loss_weight * (relu(coverage_loss - shift) + shift - 1.0)
        output_dict["loss"] = loss
        return output_dict

    @staticmethod
    def _get_loss(proba: torch.LongTensor,
                  targets: torch.LongTensor,
                  eps: float) -> torch.Tensor:
        """
        Negative log-likelihood of the targets under the step distributions.

        ``proba`` holds probabilities (despite its annotation, it is built from the
        floating-point step distributions) with the class axis second; ``eps`` is
        added before the logarithm. The first target column is dropped, and target
        index 0 is ignored by the loss.
        """
        criterion = torch.nn.NLLLoss(ignore_index=0)
        gold = targets[:, 1:]
        return criterion(torch.log(proba + eps), gold)

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run beam search from the start symbol, using ``take_step`` as the step function.

        Returns a dict with ``"predictions"`` (the top-k sequences) and
        ``"class_log_probabilities"`` (their log probabilities).
        """
        source_ids = state["tokens"]
        start_predictions = source_ids.new_full((source_ids.size()[0],), fill_value=self._start_index)

        # shape (top_k): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probs): (batch_size, beam_size)
        top_k, log_probs = self._beam_search.search(start_predictions, state, self.take_step)
        return {"class_log_probabilities": log_probs, "predictions": top_k}

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Beam search step function: advance the decoder by one token.

        Returns the log of the final (extended-vocabulary) distribution, with
        ``self._eps`` added before the logarithm, together with the updated state.
        """
        logits, state = self._prepare_output_projections(last_predictions, state)
        # shape: (group_size, num_classes)
        mixture = self._get_final_dist(state, logits)
        return (mixture + self._eps).log(), state

    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Add ``"predicted_tokens"`` to ``output_dict`` and return the same dict.

        Predictions are converted to a NumPy array if they are not one already; each
        sample is then decoded together with its metadata and its
        ``source_to_target`` row by ``_decode_sample``.
        """
        predictions = output_dict["predictions"]
        if not isinstance(predictions, np.ndarray):
            predictions = predictions.detach().cpu().numpy()

        samples = zip(predictions, output_dict["metadata"], output_dict["source_to_target"])
        output_dict["predicted_tokens"] = [
            self._decode_sample(indices, metadata, source_to_target)
            for indices, metadata, source_to_target in samples
        ]
        return output_dict

    def _decode_sample(self, indices, metadata, source_to_target):
        """
        Convert predicted indices of one sample into lists of token strings.

        ``indices`` is either a single sequence (1-D) or several candidate sequences
        (2-D, e.g. one per beam); the result always contains one token list per
        candidate. Each sequence is cut at the first end symbol. Indices below the
        vocabulary size are looked up in the target namespace; larger indices refer
        to the unknown source tokens, numbered in order of first appearance, and are
        resolved through ``metadata["source_tokens"]``.
        """
        if len(indices.shape) == 1:
            indices = [indices]
        if len(indices) == 0:
            return []

        # The unknown source tokens depend only on the sample, not on the candidate,
        # so they are collected once instead of once per candidate sequence.
        original_source_tokens = metadata["source_tokens"]
        unk_tokens = []
        for position, token_vocab_index in enumerate(source_to_target):
            if token_vocab_index != self._unk_index:
                continue
            token = original_source_tokens[position]
            if token not in unk_tokens:
                unk_tokens.append(token)

        all_predicted_tokens = []
        for sample_indices in indices:
            sample_indices = list(sample_indices)
            # Collect indices till the first end_symbol
            if self._end_index in sample_indices:
                sample_indices = sample_indices[:sample_indices.index(self._end_index)]
            predicted_tokens = []
            for token_vocab_index in sample_indices:
                if token_vocab_index < self._vocab_size:
                    token = self.vocab.get_token_from_index(token_vocab_index, namespace=self._target_namespace)
                else:
                    # Extended index: offset into the list of unknown source tokens.
                    token = unk_tokens[token_vocab_index - self._vocab_size]
                predicted_tokens.append(token)
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Report the running averages of the coverage loss and of ``p_gen``.

        Returns an empty dict when coverage is disabled; in that case the counters
        are left untouched even if ``reset`` is True. Otherwise returns
        ``{"coverage_loss": ..., "p_gen": ...}``, where each value is the accumulated
        sum divided by its iteration count (``0.0`` when the count is zero).
        When ``reset`` is True all four accumulators are zeroed after the averages
        have been computed.
        """
        if not self._use_coverage:
            return {}

        def _average(total, count):
            # Guard against division by zero before any step has been recorded.
            return total / count if count != 0 else 0.0

        report = {
            "coverage_loss": _average(self._coverage_loss_sum, self._coverage_iterations),
            "p_gen": _average(self._p_gen_sum, self._p_gen_iterations),
        }
        if reset:
            self._p_gen_sum, self._p_gen_iterations = 0.0, 0
            self._coverage_loss_sum, self._coverage_iterations = 0.0, 0
        return report

In [None]:
from summa.summarizer import summarize

In [None]:
@Model.register("pgn_textrank")
class PointerGeneratorNetworkTextRank(PointerGeneratorNetwork):
    """
    Based on https://link.springer.com/chapter/10.1007/978-3-319-99495-6_39

    Pointer-generator network with a TextRank helper (``predict_text_rank``) that
    produces an extractive summary of a raw text. The constructor takes exactly the
    same arguments as ``PointerGeneratorNetwork`` and forwards them unchanged.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_decoding_steps: int,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 projection_dim: int = None,
                 use_coverage: bool = False,
                 coverage_shift: float = 0.,
                 coverage_loss_weight: float = None,
                 embed_attn_to_output: bool = False) -> None:
        # Zero-argument super() resolves to PointerGeneratorNetwork.__init__.
        # The previous super(PointerGeneratorNetwork, self) skipped it and handed
        # all of these arguments to the base Model initialiser instead.
        super().__init__(
            vocab=vocab,
            source_embedder=source_embedder,
            encoder=encoder,
            attention=attention,
            max_decoding_steps=max_decoding_steps,
            beam_size=beam_size,
            target_namespace=target_namespace,
            target_embedding_dim=target_embedding_dim,
            scheduled_sampling_ratio=scheduled_sampling_ratio,
            projection_dim=projection_dim,
            use_coverage=use_coverage,
            coverage_shift=coverage_shift,
            coverage_loss_weight=coverage_loss_weight,
            embed_attn_to_output=embed_attn_to_output,
        )
        # NOTE: the vocabulary is intentionally left untouched. The former
        # ``self.vocab = self.predict_text_rank(self.vocab)`` called a helper that was
        # not a method yet and would have replaced the vocabulary with a string.

    def predict_text_rank(self, text, summary=None, summary_part=0.1):
        """
        Return a TextRank extractive summary of ``text`` as a single line.

        Parameters
        ----------
        text : raw source text to summarise.
        summary : unused; kept (now optional) for compatibility with the
            original signature.
        summary_part : value passed to ``summarize`` as ``ratio``.

        Returns
        -------
        The output of ``summarize`` for Russian text with newlines replaced by spaces.
        """
        return summarize(text, ratio=summary_part, language='russian').replace("\n", " ")

In [None]:
from allennlp.data.tokenizers.token_class import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer


@Tokenizer.register("razdel")
class RazdelTokenizer(Tokenizer):
    """Tokenizer that wraps ``razdel.tokenize``, with optional lowercasing of tokens."""

    def __init__(self, lowercase: bool = False):
        # When True, every token text is lowercased before being wrapped in a Token.
        self._lowercase = lowercase

    def tokenize(self, text: str) -> List[Token]:
        """Split ``text`` with razdel and wrap each piece of text in a ``Token``."""
        pieces = (substring.text for substring in razdel.tokenize(text))
        if self._lowercase:
            pieces = (piece.lower() for piece in pieces)
        return [Token(piece) for piece in pieces]

    def batch_tokenize(self, texts: List[str]) -> List[List[Token]]:
        """Tokenize every text of ``texts`` independently, preserving order."""
        return list(map(self.tokenize, texts))