
##### Copyright 2021 The Google AI Language Team Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [1]:
# Copyright 2021 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Set-Up

In [None]:
!pip install tf-models-official
!pip install tensorflow-text

In [3]:
import tensorflow as tf
import tensorflow_text as text

from official.nlp.modeling import networks
from official.nlp.modeling import models
from official.nlp.bert import configs

import json
import requests

Set up the TPU (make sure you are using a TPU runtime). Run this cell before any other TensorFlow code, since the TPU system has to be initialized at the start of the program.

In [None]:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

In [5]:
#@title Load JSONL Data

LABELS = {'SUPPORTS': '0', 'REFUTES': '1'}

def json_to_dataset(url_path):
  def _load_json():
    data_tuples = []

    data = requests.get(url_path)
    # Fail early on HTTP errors instead of trying to parse an error page.
    data.raise_for_status()

    for line in data.content.decode('utf-8').split('\n'):
      if not line:
        continue
      json_line = json.loads(line)
      # To work with Cloud TPU, we cannot use a generator but have to create the
      # data using `from_tensor_slices`.  Because this requires homogeneous
      # tuples, the label is a string at this point, and we only parse it to
      # int later.
      data_tuples.append(
          (json_line['text'],  ' '.join(
              x['text'] for x in json_line['gold_evidence']),
              LABELS[json_line['label']]))

    return data_tuples

  data_tuples = _load_json()
  print(f'Loaded {len(data_tuples)} examples from {url_path}')

  dataset = tf.data.Dataset.from_tensor_slices(data_tuples).map(
      lambda x: {
          'hypothesis': x[0],
          'premise': x[1],
          'label': tf.strings.to_number(x[2], tf.int32),
      })
  
  # Hand-holding the dataset to know its own size.  This will make Keras'  
  # logging more informative, as it will know how many batches to expect per
  # epoch.
  dataset = dataset.apply(tf.data.experimental.assert_cardinality(len(data_tuples)))
  return dataset

fm2_train_dataset = json_to_dataset('https://raw.githubusercontent.com/google-research/fool-me-twice/main/dataset/train.jsonl')
fm2_dev_dataset = json_to_dataset('https://raw.githubusercontent.com/google-research/fool-me-twice/main/dataset/dev.jsonl')

fever_train_dataset = json_to_dataset('https://storage.googleapis.com/fool-me-twice-media/data/fever/train.jsonl')
fever_dev_dataset = json_to_dataset('https://storage.googleapis.com/fool-me-twice-media/data/fever/dev.jsonl')

Loaded 10419 examples from https://raw.githubusercontent.com/google-research/fool-me-twice/main/dataset/train.jsonl
Loaded 1169 examples from https://raw.githubusercontent.com/google-research/fool-me-twice/main/dataset/dev.jsonl
Loaded 109810 examples from https://storage.googleapis.com/fool-me-twice-media/data/fever/train.jsonl
Loaded 13332 examples from https://storage.googleapis.com/fool-me-twice-media/data/fever/dev.jsonl
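
To make the record handling above concrete, here is a minimal pure-Python sketch of how a single JSONL line becomes a tuple of three strings. The helper name `record_to_tuple` and the sample record are made up for illustration; only the field names (`text`, `gold_evidence`, `label`) and the `LABELS` mapping come from the cell above.

```python
import json

# Same mapping as in the cell above: labels stay strings until the tf.data map.
LABELS = {'SUPPORTS': '0', 'REFUTES': '1'}

def record_to_tuple(line):
  """Turns one JSONL line into a (claim, evidence, label) tuple of strings."""
  record = json.loads(line)
  # All gold evidence sentences are joined into one space-separated string.
  evidence = ' '.join(x['text'] for x in record['gold_evidence'])
  return (record['text'], evidence, LABELS[record['label']])

# Made-up record that only mimics the fields read by `json_to_dataset`.
sample_line = json.dumps({
    'text': 'The Eiffel Tower is in Paris.',
    'gold_evidence': [
        {'text': 'The Eiffel Tower is a tower in Paris.'},
        {'text': 'It was completed in 1889.'},
    ],
    'label': 'SUPPORTS',
})

claim, evidence, label = record_to_tuple(sample_line)
print(claim)
print(evidence)
print(label, int(label))  # the string label is parsed to an int only later
```

Because every element of the tuple is a string, a list of such tuples can be passed to `from_tensor_slices` as one homogeneous string tensor.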


We use tf-text to map over these "plain text" examples on the fly, turning them into the format expected by a BERT-based classifier.

We consider three "modes":

- *normal*, i.e. ```[CLS] evidences [SEP] claim [SEP]```
- *evidence_only*, i.e. ```[CLS] evidences [SEP]```
- *hypothesis_only*, i.e. ```[CLS] claim [SEP]```

*normal* is the proper setting, whereas *evidence_only* and *hypothesis_only* are diagnostic settings that measure how many dataset artefacts a simple BERT-based classifier (and similar models) could exploit when it sees only one of the two inputs.
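
As a rough illustration of the three modes, the sketch below packs plain Python lists of token ids instead of `RaggedTensor`s. It follows the segment order used by `_normal` in the next cell (evidence first, then claim). The helpers `pack` and `select_segments` and the token ids are hypothetical; 101 and 102 are the same illustrative `[CLS]` and `[SEP]` ids used in the `combine_segments` docstring.

```python
CLS_ID, SEP_ID = 101, 102  # illustrative [CLS] and [SEP] ids

def pack(segments):
  """Joins token-id lists into one BERT input plus per-token segment ids."""
  input_ids, segment_ids = [CLS_ID], [0]
  for i, segment in enumerate(segments):
    input_ids += segment + [SEP_ID]
    segment_ids += [i] * (len(segment) + 1)  # +1 covers the trailing [SEP]
  return input_ids, segment_ids

def select_segments(claim_ids, evidence_ids, mode):
  """Picks the segments for a mode; evidence comes first, as in `_normal`."""
  return {
      'normal': [evidence_ids, claim_ids],
      'evidence_only': [evidence_ids],
      'hypothesis_only': [claim_ids],
  }[mode]

claim_ids = [1, 2]           # hypothetical token ids for a claim
evidence_ids = [10, 20, 30]  # hypothetical token ids for its evidence

for mode in ('normal', 'evidence_only', 'hypothesis_only'):
  print(mode, pack(select_segments(claim_ids, evidence_ids, mode)))
```

In the two diagnostic modes there is only one segment, so every segment id is 0.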

In [6]:
#@title Model Definition

NUM_CLASSES = 2  # SUPPORTS, REFUTES
MAX_SEQ_LENGTH = 512  #@param {type:"integer"}
BATCH_SIZE = 32  #@param {type:"integer"}
CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_model.ckpt'  #@param {type:"string"}
VOCAB_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/vocab.txt'  #@param {type:"string"}
CONFIG_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json'  #@param {type:"string"}

def combine_segments(segments, start_of_sequence_id, end_of_segment_id):
  """Combine one or more input segments for a model's input sequence.

  `combine_segments` combines the tokens of one or more input segments to a
  single sequence of token values and generates matching segment ids.
  `combine_segments` can follow a `Trimmer`, which limits segment lengths and
  emits `RaggedTensor` outputs, and can be followed up by `ModelInputPacker`.

  See `Detailed Experimental Setup` in `BERT: Pre-training of Deep Bidirectional
  Transformers for Language Understanding`
  (https://arxiv.org/pdf/1810.04805.pdf) for more examples of combined
  segments.


  `combine_segments` first flattens and combines a list of one or more
  segments (`RaggedTensor`s of n dimensions) together along the 1st axis, then
  packages any special tokens into a final n dimensional `RaggedTensor`.

  And finally `combine_segments` generates another `RaggedTensor` (with the
  same rank as the final combined `RaggedTensor`) that contains a distinct int
  id for each segment.

  Example usage:

  ```
  segment_a = [[1, 2],
               [3, 4,],
               [5, 6, 7, 8, 9]]

  segment_b = [[10, 20,],
               [30, 40, 50, 60,],
               [70, 80]]
  expected_combined, expected_ids = combine_segments(
      [segment_a, segment_b], start_of_sequence_id=101, end_of_segment_id=102)

  # segment_a and segment_b have been combined, with special tokens marking
  # the beginning of the sequence and the end of each segment inserted.
  expected_combined=[
   [101, 1, 2, 102, 10, 20, 102],
   [101, 3, 4, 102, 30, 40, 50, 60, 102],
   [101, 5, 6, 7, 8, 9, 102, 70, 80, 102],
  ]

  # ids describing which items belong to which segment.
  expected_ids=[
   [0, 0, 0, 0, 1, 1, 1],
   [0, 0, 0, 0, 1, 1, 1, 1, 1],
   [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]
  ```

  Args:
    segments: A list of `RaggedTensor`s with the tokens of the input segments.
      All elements must have the same dtype (int32 or int64), same rank, and
      same dimension 0 (namely batch size). Slice `segments[i][j, ...]`
      contains the tokens of the i-th input segment to the j-th example in the
      batch.
    start_of_sequence_id: a python int or scalar Tensor containing the id used
      to denote the start of a sequence (e.g. `[CLS]` token in BERT
      terminology).
    end_of_segment_id: a python int or scalar Tensor containing the id used to
      denote end of a segment (e.g. the `[SEP]` token in BERT terminology).

  Returns:
    a tuple of (combined_segments, segment_ids), where:

    combined_segments: A `RaggedTensor` with segments combined and special
      tokens inserted.
    segment_ids:  A `RaggedTensor` w/ the same shape as `combined_segments`
      and containing int ids for each item detailing the segment that they
      correspond to.
  """
  start_of_sequence_id = tf.convert_to_tensor(
      start_of_sequence_id, dtype=tf.int64)
  end_of_segment_id = tf.convert_to_tensor(
      end_of_segment_id, dtype=tf.int64)

  # Create special tokens ([CLS] and [SEP]) that will be combined with the
  # segments
  if len(segments) <= 0:
    raise ValueError("`segments` must be a nonempty list.")
  segment_dtype = segments[0].dtype
  if segment_dtype not in (tf.int32, tf.int64):
    raise ValueError("`segments` must have elements with dtype of int32 or " +
                     "int64")
  start_sequence_id = tf.cast(start_of_sequence_id, segment_dtype)
  end_segment_id = tf.cast(end_of_segment_id, segment_dtype)
  start_seq_tokens = tf.tile([start_sequence_id], [segments[0].nrows()])
  end_segment_tokens = tf.tile([end_segment_id], [segments[0].nrows()])
  for _ in range(segments[0].ragged_rank):
    start_seq_tokens = tf.expand_dims(start_seq_tokens, 1)
    end_segment_tokens = tf.expand_dims(end_segment_tokens, 1)
  special_token_segment_template = tf.ones_like(start_seq_tokens)

  # Combine all segments w/ special tokens
  segments_to_combine = [start_seq_tokens]
  for seg in segments:
    segments_to_combine.append(seg)
    segments_to_combine.append(end_segment_tokens)
  segments_combined = tf.concat(segments_to_combine, 1)

  # Create the segment ids, making sure to account for special tokens.
  segment_ids_to_combine = []
  segment_ids_to_combine.append(special_token_segment_template * 0)
  for i, item in enumerate(segments):
    # Add segment id
    segment_id = tf.ones_like(item) * i
    segment_ids_to_combine.append(segment_id)

    # Add for SEP
    special_token_segment_id = special_token_segment_template * i
    segment_ids_to_combine.append(special_token_segment_id)

  segment_ids = tf.concat(segment_ids_to_combine, 1)
  return segments_combined, segment_ids

# Sets up the BERT tokenizer using tf-text.
vocab_table = tf.lookup.StaticVocabularyTable(
        tf.lookup.TextFileInitializer(
            filename=VOCAB_PATH,
            key_dtype=tf.string,
            key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
            value_dtype=tf.int64,
            value_index=tf.lookup.TextFileIndex.LINE_NUMBER
        ), 
        num_oov_buckets=1)
cls_id, sep_id = vocab_table.lookup(tf.convert_to_tensor(['[CLS]', '[SEP]']))
bert_tokenizer = text.BertTokenizer(vocab_lookup_table=vocab_table, 
                                    token_out_type=tf.int64, 
                                    preserve_unused_token=True, 
                                    lower_case=True)

# The three different settings.
def _normal(inputs):
  # <string>[batch_size]
  hypothesis = inputs['hypothesis']
  # <string>[batch_size]
  premise = inputs['premise']

  # <int>[batch_size, (hyp_words), (subwords)]
  tokenized_hypothesis = bert_tokenizer.tokenize(hypothesis)
  # <int>[batch_size, (prem_words), (subwords)]
  tokenized_premise = bert_tokenizer.tokenize(premise)

  # Get rid of the subword dimensions.
  # <int>[batch_size, (hyp_length)]
  flat_tokenized_hypothesis = tokenized_hypothesis.merge_dims(1, 2)
  # <int>[batch_size, (prem_length)]
  flat_tokenized_premise = tokenized_premise.merge_dims(1, 2)

  # The premise (evidence) comes first, followed by the hypothesis (claim).
  return (flat_tokenized_premise, flat_tokenized_hypothesis)

def _evidence_only(inputs):
  return (_normal(inputs)[0],)

def _hypothesis_only(inputs):
  return (_normal(inputs)[1],)

def input_pipeline(dataset, mode, shuffle):
  """Maps the `plain` examples in `dataset` to classifier inputs.

  Args:
    dataset:  A tf.data.Dataset yielding (unbatched) examples of the form
      {'hypothesis': string, 'premise': string, 'label': int}
    mode:  One of `normal`, `evidence_only`, or `hypothesis_only`, see above.
    shuffle:  Whether or not the dataset gets shuffled.

  Returns:
    A batched dataset (shuffled if `shuffle` is set) yielding
    (classifier input, label) tuples.
  """
  segment_fn = {
      'normal': _normal,
      'evidence_only': _evidence_only,
      'hypothesis_only': _hypothesis_only,
    }[mode]

  def to_example(inputs):
    segments = segment_fn(inputs)

    # BERT input encoding.
    # input_ids: <int>[batch_size, hyp_length + prem_length + 3]
    # segment_ids: <int>[batch_size, hyp_length + prem_length + 3]
    input_ids, segment_ids = combine_segments(
        segments=segments, 
        start_of_sequence_id=cls_id, 
        end_of_segment_id=sep_id)
    
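    # Pad (or truncate) to a fixed length: [batch_size, max_seq_length].
    padded_input_ids = input_ids.to_tensor(shape=(None, MAX_SEQ_LENGTH))
    padded_segment_ids = segment_ids.to_tensor(shape=(None, MAX_SEQ_LENGTH))
    # Padding positions are filled with 0, so the mask is 1 for real tokens only.
    input_mask = tf.cast(padded_input_ids != 0, tf.int64)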

    return ({
        'input_word_ids': padded_input_ids, 
        'input_type_ids': padded_segment_ids,
        'input_mask': input_mask,
    }, tf.ensure_shape(inputs['label'], (None, )))

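  dataset = dataset.batch(BATCH_SIZE).map(to_example)
  if shuffle:
    # Shuffling happens after batching, so this reorders whole batches; the
    # examples inside each batch stay together.
    dataset = dataset.shuffle(1000)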
 
  return dataset


cfg = configs.BertConfig.from_json_file(json_file=CONFIG_PATH)

with strategy.scope():
  bert_encoder = networks.TransformerEncoder(
        vocab_size=cfg.vocab_size,
        hidden_size=cfg.hidden_size,
        num_layers=cfg.num_hidden_layers,
        num_attention_heads=cfg.num_attention_heads,
        max_sequence_length=MAX_SEQ_LENGTH,
        type_vocab_size=cfg.type_vocab_size,
        intermediate_size=cfg.intermediate_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=cfg.initializer_range))
  bert_checkpoint = tf.train.Checkpoint(model=bert_encoder)
  classifier = models.BertClassifier(
      network=bert_encoder,
      num_classes=NUM_CLASSES)
  
  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)  
  
  classifier.compile(optimizer=optimizer, 
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
                    metrics=tf.keras.metrics.SparseCategoricalAccuracy())     

Run the experiments you want. `mode` can be `normal`, `evidence_only`, or `hypothesis_only` (see above).

In [7]:
def run_experiment(training_data: tf.data.Dataset,
                   dev_data: tf.data.Dataset,
                   mode: str,
                   num_epochs: int):
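  # Reset the BERT encoder to the pre-trained checkpoint.  Note that this only
  # restores `bert_encoder`; the classification head and the optimizer state
  # are not reinitialized between runs.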
  with strategy.scope():
    bert_checkpoint.restore(CHECKPOINT_PATH).assert_consumed().run_restore_ops()

  train_dataset = input_pipeline(dataset=training_data, mode=mode, shuffle=True)
  dev_dataset = input_pipeline(dataset=dev_data, mode=mode, shuffle=False)
  classifier.fit(x=train_dataset, validation_data=dev_dataset, epochs=num_epochs)

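def predict(data: tf.data.Dataset):
  # Always predicts in `normal` mode.  `data` must be the unbatched, unshuffled
  # dataset so that examples and predictions line up when zipped below.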
  predictions = classifier.predict(input_pipeline(dataset=data, 
                                                  mode='normal', 
                                                  shuffle=False))
  results = []
  for (example, prediction) in zip(data, predictions):
    results.append({
        'hypothesis': example['hypothesis'].numpy(),
        'premise': example['premise'].numpy(),
        'gold_label': example['label'].numpy(),
        'predictions': prediction,
    })
  return results
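
The `predictions` entries returned by `predict` are per-class scores. The sketch below shows one way to turn such a score pair into a label name, assuming the scores are unnormalized logits (which is how the loss above treats them, via `from_logits=True`). The helpers `softmax` and `to_label`, the `ID_TO_LABEL` mapping (the inverse of `LABELS`), and the example scores are illustrative and not part of the notebook.

```python
import math

# Inverse of the LABELS mapping used when loading the data.
ID_TO_LABEL = {0: 'SUPPORTS', 1: 'REFUTES'}

def softmax(logits):
  """Numerically stable softmax over a short list of floats."""
  largest = max(logits)
  exps = [math.exp(x - largest) for x in logits]
  total = sum(exps)
  return [e / total for e in exps]

def to_label(logits):
  """Returns the most likely label name and its probability."""
  probs = softmax(logits)
  best = max(range(len(probs)), key=probs.__getitem__)
  return ID_TO_LABEL[best], probs[best]

# Hypothetical logits for one example: class 0 scores higher than class 1.
label, confidence = to_label([2.0, -1.0])
print(label, round(confidence, 3))
```

The same helper could be applied to each `result['predictions']` in the list returned by `predict`.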

## Model Training

Train on FM2, eval on FM2.

In [None]:
run_experiment(training_data=fm2_train_dataset,
               dev_data=fm2_dev_dataset,
               mode='normal',
               num_epochs=10)

Train on FEVER, eval on FEVER.

In [None]:
run_experiment(training_data=fever_train_dataset,
               dev_data=fever_dev_dataset,
               mode='normal',
               num_epochs=10)