In [1]:
#%tensorflow_version 2.x
#%load_ext tensorboard
!pip3 -q install tensorflow==2.1.0 tensorflow-gpu==2.1.0 tensorflow-datasets==2.1.0 tensorflow-text==2.1.1 tensorflow-hub==0.7.0 nltk sklearn transformers tensorflow-addons
#!ulimit -n 1024

In [2]:
from typing import List, Tuple
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text
from nltk.tokenize import sent_tokenize
import nltk
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import BertTokenizer

nltk.download('punkt')


class BaseSummarizer(object):
    ROUND_DIGITS = 5

    def __text2sentences__(self, text: str) -> List[str]:
        raise NotImplementedError

    def __embeddings__(self, sentences: List[str]) -> tf.Tensor:
        raise NotImplementedError

    def __sim_mat__(self, vec: tf.Tensor) -> tf.Tensor:
        normalize = tf.math.l2_normalize(vec, 1)
        cosine = tf.linalg.matmul(normalize, normalize, transpose_b=True)
        rounded = tf.math.round(cosine * 10 ** BaseSummarizer.ROUND_DIGITS) / 10 ** BaseSummarizer.ROUND_DIGITS
        return rounded

    @staticmethod
    def __ranks__(sent_sim_mat: tf.Tensor) -> tf.Tensor:
        eig_val, eig_vec = tf.linalg.eigh(sent_sim_mat)
        best_vector_idx = tf.math.argmax(eig_val)
        return eig_vec[best_vector_idx]

    @staticmethod
    def __z_score__(vec: tf.Tensor) -> tf.Tensor:
        return (vec - tf.math.reduce_min(vec)) / (tf.math.reduce_max(vec) - tf.math.reduce_min(vec))

    def bleu(self, references: List[List[str]], texts: List[str]):
        score = 0.
        smoothie = SmoothingFunction().method1

        for refs, txt in zip(references, texts):
            hyp = self.the_most_important(txt, k=1)[0]
            score += sentence_bleu([ nltk.word_tokenize(s) for s in refs ], nltk.word_tokenize(hyp), smoothing_function=smoothie)

        score /= len(references)
        return score

    def scored_sentences(self, text: str) -> List[Tuple[str, float]]:
        sents = self.__text2sentences__(text)
        if not sents:
            return []
        sim_mat = self.__sim_mat__(self.__embeddings__(sents))
        ranks = BaseSummarizer.__z_score__(BaseSummarizer.__ranks__(sim_mat))
        return list(zip(sents, ranks.numpy()))

    def the_most_important(self, text, k=1):
        return [ p[0] for p in sorted(self.scored_sentences(text), key=lambda p: p[1], reverse=True)[:k] ]


class USETextRank(BaseSummarizer):
    __embed__ = hub.load("https://tfhub.dev/google/universal-sentence-encoder-multilingual/3")

    def __embeddings__(self, sentences: List[str]) -> tf.Tensor:
        return self.__embed__(sentences)

    def __text2sentences__(self, text: str) -> List[str]:
        return sent_tokenize(text)


class TFIDFTextRank(BaseSummarizer):
    __vectorizer__ = TfidfVectorizer()

    def __embeddings__(self, sentences: List[str]) -> tf.Tensor:
        return tf.constant(self.__vectorizer__.fit_transform(sentences).todense())

    def __text2sentences__(self, text: str) -> List[str]:
        return sent_tokenize(text)


class BERTFTextRank(BaseSummarizer):
      #__tokenizer__ = BertTokenizer.from_pretrained('bert-base-uncased')
      #__embed__ = hub.Module("https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1", trainable=False)

    def __sim_mat__(self, sentances: List[str]) -> tf.Tensor:
        pairs = []
        for s1 in sentances:
            for s2 in sentances:
                pairs.append((s1, s2))

        input_ids, segment_ids, input_mask = tokenizer.batch_encode_plus(pairs, max_length=256, return_attention_mask = True)

        return None


    def scored_sentences(self, text: str) -> List[Tuple[str, float]]:
        sents = self.__text2sentences__(text)
        if not sents:
            return []

        bert_inputs = dict(
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids)
        sim_mat = self.__sim_mat__(self.__embeddings__(sents))
        ranks = BaseSummarizer.__z_score__(BaseSummarizer.__ranks__(sim_mat))
        return list(zip(sents, ranks.numpy()))

    def __text2sentences__(self, text: str) -> List[str]:
        return sent_tokenize(text)


summarizerUSE = USETextRank()
summarizerTFIDF = TFIDFTextRank()

[nltk_data] Downloading package punkt to /home/vad/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
INFO:absl:Using /tmp/tfhub_modules to cache modules.


In [3]:
import tensorflow_datasets as tfds
ds_test = tfds.load(name="cnn_dailymail", split='validation')

INFO:absl:No config specified, defaulting to first: cnn_dailymail/plain_text
INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset cnn_dailymail (/home/vad/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0)
INFO:absl:Constructing tf.data.Dataset for split validation, from /home/vad/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0


In [4]:
from typing import List, Tuple
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text
from transformers import BertTokenizer

MAX_SEQ_LENGTH = 256
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
bert_subnet = hub.KerasLayer("https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1", 
                      signature="tokens", output_key="pooled_output", trainable=True)

def embedding4pair(s1: List[str], s2: List[str]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    input_ids, segment_ids, input_mask = tokenizer.batch_encode_plus(pairs, 
                                  max_length=MAX_SEQ_LENGTH, return_attention_mask = True)
    return input_ids, segment_ids, input_mask


def create_ruler() -> tf.keras.Model:
    i_id = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="input_ids", dtype=tf.int32)
    i_mask = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="input_masks", dtype=tf.int32)
    i_segment = tf.keras.layers.Input(shape=(MAX_SEQ_LENGTH,), name="segment_ids", dtype=tf.int32)

    bert_inputs = {"input_ids": i_id, "input_mask": i_mask, "segment_ids": i_segment}
  
    embedding = bert_subnet(bert_inputs)
    dense = tf.keras.layers.Dense(256, input_shape=(768,), activation='relu')(embedding)
    d = tf.keras.layers.Dense(1, input_shape=(256,))(dense)

    return tf.keras.models.Model(inputs=bert_inputs, outputs=d)


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [6]:
import numpy as np

nli_validation = tfds.load(name="multi_nli", split='validation_matched')
nli_train = tfds.load(name="multi_nli", split='train')

def process_dataset(ds):
    premises = []
    hypothesis = []
    input_ids = []
    input_mask = []
    segment_ids = []
    labels = []
    for x in ds:
        p = x['premise'].numpy().decode('utf8')
        l = x['label'].numpy()
        h = x['hypothesis'].numpy().decode('utf8')
        r = tokenizer.encode_plus(
          pad_to_max_length='right',
          text=p,
          text_pair=h,
          max_length=MAX_SEQ_LENGTH)
  
        input_ids.append(r['input_ids'])
        input_mask.append(r['attention_mask'])
        segment_ids.append(r['token_type_ids'])
        if l == 1:
            labels.append(1)
        else:
            labels.append(0)

        premises.append(p)
        hypothesis.append(h)

    input_ids = np.array(input_ids, dtype=np.int32)
    input_mask = np.array(input_mask, dtype=np.int32)
    segment_ids = np.array(segment_ids, dtype=np.int32)

    labels = np.array(labels, dtype=np.float16)

    dataset = tf.data.Dataset.from_tensor_slices(((input_ids, input_mask, segment_ids), labels))

    return dataset

ds_nli_train = process_dataset(nli_train)
ds_nli_valid = process_dataset(nli_validation)

INFO:absl:No config specified, defaulting to first: multi_nli/plain_text
INFO:absl:Load pre-computed datasetinfo (eg: splits) from bucket.
INFO:absl:Loading info from GCS for multi_nli/plain_text/1.0.0
INFO:absl:Generating dataset multi_nli (/home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0)


[1mDownloading and preparing dataset multi_nli/plain_text/1.0.0 (download: 216.34 MiB, generated: Unknown size, total: 216.34 MiB) to /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…

INFO:absl:Downloading http://storage.googleapis.com/tfds-data/downloads/multi_nli/multinli_1.0.zip into /home/vad/tensorflow_datasets/downloads/tfds-data_downloads_multi_nli_multinli_1.0HMUsk5OVZAJ-rEiUmzbrHqIXnZ_lNC_BY3bkXFsAYtY.zip.tmp.7fd8713105144addbf05102921940b1a...
INFO:absl:Generating split train










HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0.incomplete0EB9KW/multi_nli-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=392702.0), HTML(value='')))

INFO:absl:Done writing /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0.incomplete0EB9KW/multi_nli-train.tfrecord. Shard lengths: [392702]
INFO:absl:Generating split validation_matched


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0.incomplete0EB9KW/multi_nli-validation_matched.tfrecord


HBox(children=(FloatProgress(value=0.0, max=9815.0), HTML(value='')))

INFO:absl:Done writing /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0.incomplete0EB9KW/multi_nli-validation_matched.tfrecord. Shard lengths: [9815]
INFO:absl:Generating split validation_mismatched


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0.incomplete0EB9KW/multi_nli-validation_mismatched.tfrecord


HBox(children=(FloatProgress(value=0.0, max=9832.0), HTML(value='')))

INFO:absl:Done writing /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0.incomplete0EB9KW/multi_nli-validation_mismatched.tfrecord. Shard lengths: [9832]
INFO:absl:Skipping computing stats for mode ComputeStatsMode.AUTO.
INFO:absl:Constructing tf.data.Dataset for split validation_matched, from /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0
INFO:absl:No config specified, defaulting to first: multi_nli/plain_text
INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset multi_nli (/home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0)
INFO:absl:Constructing tf.data.Dataset for split train, from /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0


[1mDataset multi_nli downloaded and prepared to /home/vad/tensorflow_datasets/multi_nli/plain_text/1.0.0. Subsequent calls will reuse this data.[0m


In [None]:
import tensorflow_addons as tfa
import os

model = create_ruler()

es_cb = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

model.compile(
    optimizer=tfa.optimizers.LAMB(),
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.losses.MeanSquaredError()]
)
cp_cb = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                            save_weights_only=True,
                                            verbose=1)

model.fit(
    ds_nli_train.shuffle(32868).batch(32868), 
    validation_data=ds_nli_valid.batch(32868), 
    callbacks=[es_cb, cp_cb], 
    epochs=15)

Train for 12 steps, validate for 1 steps
Epoch 1/15


In [7]:
next(iter(ds_nli_valid))

((<tf.Tensor: shape=(256,), dtype=int32, numpy=
  array([  101, 24625,   117, 10817,   117, 16938,   112,   188, 13028,
         21852, 13028,   112, 10323, 10590, 13507, 10741, 10142, 23457,
           136,   102, 10357, 12482, 10108, 16624, 39784, 10662, 11223,
         12153, 10106, 10105, 56538, 10155, 10226, 11795, 10108, 21997,
           119,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,
             0, 

In [0]:
import pandas as pd
from tqdm import tqdm

txts = []
references = []
for example in tqdm(ds_valid, total=len(list(ds_valid))):
  references.append(example['highlights'].numpy().decode("utf-8").split('\n'))
  txts.append(example['article'].numpy().decode("utf-8"))

In [0]:
print('use', summarizerUSE.bleu(references, txts))

In [0]:
print('tfidf', summarizerTFIDF.bleu(references, txts))