
<a href="https://colab.research.google.com/github/google-research/bigbird/blob/master/bigbird/summarization/eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The BigBird Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2020 The BigBird Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [None]:
# Install the BigBird package straight from GitHub (-q keeps pip output quiet).
# NOTE(review): unpinned install; consider pinning a commit/tag for reproducibility.
!pip install git+https://github.com/google-research/bigbird.git -q

In [None]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
# `tft` is not referenced by name anywhere below; presumably it is imported for its
# side effect of registering TF Text ops used inside the SavedModel — TODO confirm.
import tensorflow_text as tft
from tqdm import tqdm

# Turn on TF2 behavior for the compat.v2 `tf` module used throughout this notebook.
tf.enable_v2_behavior()

## Load Saved Model

In [None]:
# SavedModel location (judging by the path: BigBird summarization, PubMed, RoBERTa variant).
path = 'gs://bigbird-transformer/summarization/pubmed/roberta/saved_model'
SIGNATURE_KEY = 'serving_default'

# Keep the loaded object bound to a name; the signature below is taken from it.
imported_model = tf.saved_model.load(export_dir=path, tags='serve')

# `summerize` (sic) is the name later cells call to summarize an article.
summerize = imported_model.signatures[SIGNATURE_KEY]

## Setup Data

In [None]:
# PubMed test split. With as_supervised=True each element is a 2-tuple that later
# cells treat as (article, ground-truth summary) via ex[0] / ex[1].
dataset = tfds.load(
    name='scientific_papers/pubmed',
    split='test',
    shuffle_files=False,  # fixed file order, so take(N) picks the same examples each run
    as_supervised=True,
)

In [None]:
# Look at a few examples. After the loop, `ex` is still bound to the last one shown.
for ex in dataset.take(count=3):
  print(ex)

## Print a prediction

In [None]:
# Pick the example explicitly (the 3rd test example, the same one the inspection loop
# ends on) instead of relying on `ex` leaking out of another cell. Otherwise the
# result changes if cells run out of order, e.g. after the evaluation loop rebinds `ex`.
ex = next(iter(dataset.skip(2).take(1)))
predicted_summary = summerize(ex[0])['pred_sent'][0]

In [None]:
# The tensors hold UTF-8 byte strings. Decode them (as the ROUGE evaluation does) so the
# text renders with real newlines instead of a b'...' repr full of escape sequences.
print('Article:\n {}\n\n Predicted summary:\n {}\n\n Ground truth summary:\n {}\n\n'.format(
    ex[0].numpy().decode('utf-8'),
    predicted_summary.numpy().decode('utf-8'),
    ex[1].numpy().decode('utf-8')))

Article:
 b'\n hepatitis c virus ( hcv ) infection is reported to have a prevalence of approximately 3% worldwide .\nmajority of patients with chronic hcv have a mild , asymptomatic elevation in serum transaminase levels with no significant clinical symptoms . around 25% of patients with chronic hcv have persistently normal alanine aminotransferase ( pnalt ) .\ndefinition of normal alanine aminotransferase ( alt ) has changed over time and reference range for normal alt differs based on different laboratory cutoffs .\nprati et al .   in 2002 suggested new cutoffs with 30  u / l ( international unit ) for men and 19  u / l for women compared to 40 \nu / l and 30  u / l for men and women , respectively .\na 2009 american association for the study of liver disease ( aasld ) practice guideline suggested an alt value of 40  u / l on 2 - 3 different occasions separated by at least a month over a period of 6 months .\nothers have used 3 different alt levels equal to or below upper limit of no

## Evaluate ROUGE Scores

In [None]:
from rouge_score import rouge_scorer
from rouge_score import scoring

In [None]:
# ROUGE variants to report. NOTE(review): rougeLsum treats newlines as sentence
# boundaries — confirm predictions and references are newline-separated as intended.
ROUGE_TYPES = ["rouge1", "rouge2", "rougeLsum"]
scorer = rouge_scorer.RougeScorer(ROUGE_TYPES, use_stemmer=True)

# Collects per-example scores and later reports bootstrap low/mid/high estimates.
aggregator = scoring.BootstrapAggregator()

In [None]:
# Only the first NUM_EVAL_EXAMPLES examples of the test split are scored, so the
# reported ROUGE is for this subset. Raise it for a fuller (slower) evaluation.
NUM_EVAL_EXAMPLES = 100

# Start from a fresh aggregator so re-running this cell does not double-count scores.
aggregator = scoring.BootstrapAggregator()

# Use loop-local names so this cell does not overwrite the `ex` / `predicted_summary`
# globals that the single-example cells above display.
for article, reference in tqdm(dataset.take(NUM_EVAL_EXAMPLES), position=0):
  prediction = summerize(article)['pred_sent'][0]
  # RougeScorer.score takes (target, prediction), in that order.
  score = scorer.score(reference.numpy().decode('utf-8'),
                       prediction.numpy().decode('utf-8'))
  aggregator.add_scores(score)

In [None]:
import numpy as np

# NOTE(review): rouge_score's BootstrapAggregator appears to draw its bootstrap resamples
# from NumPy's global RNG. Seed it right before aggregating so the reported confidence
# bounds are reproducible across runs. Confirm this for the installed rouge_score version.
BOOTSTRAP_SEED = 0
np.random.seed(BOOTSTRAP_SEED)
aggregator.aggregate()