
<a href="https://colab.research.google.com/github/google-research/bigbird/blob/master/bigbird/classifier/imdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The BigBird Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2020 The BigBird Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

## Set Up

In [None]:
# Install the bigbird package straight from its GitHub repository (-q keeps pip quiet).
# NOTE(review): no commit or tag is pinned, so the installed code follows the default branch.
!pip install git+https://github.com/google-research/bigbird.git -q

In [None]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.classifier import run_classifier
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
import sys

# Register a throwaway "f" flag when it is missing -- presumably so that the
# `-f <connection-file>` argument a notebook kernel is launched with does not
# break flag parsing (TODO confirm) -- and then parse the command line.
FLAGS = flags.FLAGS
if not hasattr(FLAGS, "f"):
  flags.DEFINE_string("f", "", "")
FLAGS(sys.argv)

# Switch TensorFlow to its 2.x behavior.
tf.enable_v2_behavior()

## Set options

In [None]:
# Data source and tokenizer.
FLAGS.data_dir = "tfds://imdb_reviews/plain_text"
FLAGS.vocab_model_file = "gpt2"

# Attention variant and sequence length.
# 3072 is used here because free Colab typically offers only lower-memory GPUs
# such as the T4; 4096 fits on 16GB GPUs such as the V100.
FLAGS.attention_type = "block_sparse"
FLAGS.max_encoder_length = 3072

# Optimization settings; both dropout rates are set to zero.
FLAGS.learning_rate = 1e-5
FLAGS.num_train_steps = 10000
FLAGS.attention_probs_dropout_prob = 0.0
FLAGS.hidden_dropout_prob = 0.0

In [None]:
# Collect the flag values into `bert_config`. Later cells index it by flag name
# (e.g. bert_config["num_labels"]) and pass it to the model constructor.
# NOTE(review): presumably this reflects the FLAGS values set above -- confirm
# against bigbird.core.flags.as_dictionary.
bert_config = flags.as_dictionary()

## Define classification model

In [None]:
# Encoder, configured entirely from the flag dictionary.
model = modeling.BertModel(bert_config)

# Classification head that also computes the loss; it is applied to the
# encoder's pooled output in the forward functions below.
headl = run_classifier.ClassifierLossLayer(
    bert_config["num_labels"],
    bert_config["hidden_dropout_prob"],
    utils.create_initializer(bert_config["initializer_range"]),
    name=bert_config["scope"] + "/classifier",
)

In [None]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """Runs one training-mode forward pass and back-propagates the loss.

  Args:
    features: Batch of inputs, handed unchanged to `model`.
    labels: Batch of targets, handed unchanged to the classifier head.

  Returns:
    A `(loss, log_probs, grads)` tuple. `grads` is ordered like
    `model.trainable_weights + headl.trainable_weights`.
  """
  with tf.GradientTape() as tape:
    # Only the second model output (the pooled one) feeds the classifier head.
    _, pooled_output = model(features, training=True)
    loss, log_probs = headl(pooled_output, labels, True)
  trainable_vars = model.trainable_weights + headl.trainable_weights
  grads = tape.gradient(loss, trainable_vars)
  return loss, log_probs, grads

## Dataset pipeline

In [None]:
# Build the training input function from the flag values, then materialize a
# dataset with two examples per batch.
train_input_fn = run_classifier.input_fn_builder(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=True,
)
dataset = train_input_fn({"batch_size": 2})




In [None]:
# Inspect a few examples: print the first three batches of the training dataset.
# Note: `ex` stays bound to the last of these batches once the loop finishes.
for ex in dataset.take(3):
  print(ex)

(<tf.Tensor: shape=(2, 4096), dtype=int32, numpy=
array([[   65,   733,   474, ...,     0,     0,     0],
       [   65,   415, 26500, ...,     0,     0,     0]], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>)
(<tf.Tensor: shape=(2, 4096), dtype=int32, numpy=
array([[   65,   484, 20677, ...,     0,     0,     0],
       [   65,   871,  3908, ...,     0,     0,     0]], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 1], dtype=int32)>)
(<tf.Tensor: shape=(2, 4096), dtype=int32, numpy=
array([[  65,  415, 6506, ...,    0,    0,    0],
       [  65,  418, 1150, ...,    0,    0,    0]], dtype=int32)>, <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 0], dtype=int32)>)


## Check outputs

In [None]:
# Fetch one batch explicitly instead of reusing a loop variable left over from
# an earlier cell, so this check only needs `dataset` and `fwd_bwd` to exist.
features, labels = next(iter(dataset.take(1)))

# One forward/backward pass as a smoke test; only the loss is reported.
loss, log_probs, grads = fwd_bwd(features, labels)
print('Loss: ', loss.numpy())


Loss:  0.6977416





## (Optionally) Load pretrained model

In [None]:
# Pretrained checkpoint to restore the encoder from.
ckpt_path = 'gs://bigbird-transformer/pretrain/bigbr_base/model.ckpt-0'
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(ckpt_path)

# Assign each trainable variable directly from the checkpoint.
# `model.set_weights` expects one value per entry of `model.weights`, so feeding
# it a list derived from `trainable_weights` would fail (or misalign) as soon as
# the model owns any non-trainable variable; per-variable assignment cannot.
for v in tqdm(model.trainable_weights, position=0):
  # v.name[:-2] drops the last two characters of the variable name
  # (presumably the ":0" suffix) to obtain the checkpoint key.
  v.assign(ckpt_reader.get_tensor(v.name[:-2]))

100%|██████████| 199/199 [00:34<00:00,  4.94it/s]


## Train

In [None]:
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

# Take the class count from the same config entry the classifier head is built
# from, rather than hard-coding 2 in the one-hot encoding below.
num_labels = bert_config["num_labels"]

for i, (features, labels) in enumerate(
    tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
  loss, log_probs, grads = fwd_bwd(features, labels)
  # Variable order must match the order used for the gradients in fwd_bwd.
  opt.apply_gradients(zip(grads, model.trainable_weights+headl.trainable_weights))
  train_loss(loss)
  train_accuracy(tf.one_hot(labels, num_labels), log_probs)
  # The metrics are never reset, so these are running averages over all steps
  # seen so far, reported every 1000 steps (including step 0).
  if i % 1000 == 0:
    print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))

Loss = 0.7094929218292236  Accuracy = 0.5

  0%|          | 0/10000 [00:06<1:32:59,  1.79it/s]


Loss = 0.4131925702095032  Accuracy = 0.8123108148574829

 10%|█         | 1000/10000 [08:26<1:16:08,  1.97it/s]


Loss = 0.32566359639167786  Accuracy = 0.8608739376068115

 20%|██        | 2000/10000 [16:52<1:08:17,  1.95it/s]


Loss = 0.28784531354904175  Accuracy = 0.882480800151825

 30%|███       | 3000/10000 [25:18<58:58,  1.98it/s]


Loss = 0.2657429575920105  Accuracy = 0.8936356902122498

 40%|████      | 4000/10000 [33:44<50:41,  1.97it/s]


Loss = 0.24971100687980652  Accuracy = 0.9020236134529114

 50%|█████     | 5000/10000 [42:10<42:03,  1.98it/s]


Loss = 0.23958759009838104  Accuracy = 0.9069437384605408

 60%|██████    | 6000/10000 [50:36<33:43,  1.98it/s]


Loss = 0.2304597944021225  Accuracy = 0.9108854532241821

 70%|███████   | 7000/10000 [59:02<25:20,  1.97it/s]


Loss = 0.2243848443031311  Accuracy = 0.9135903120040894

 80%|████████  | 8000/10000 [1:07:30<17:23,  1.92it/s]


Loss = 0.21911397576332092  Accuracy = 0.9155822396278381

 90%|█████████ | 9000/10000 [1:16:05<08:34,  1.94it/s]


Loss = 0.21378542482852936  Accuracy = 0.9180262088775635

100%|██████████| 10000/10000 [1:24:39<00:00,  1.94it/s]







## Eval

In [None]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Runs an inference-mode forward pass without computing gradients.

  Args:
    features: Batch of inputs, handed unchanged to `model`.
    labels: Batch of targets, handed unchanged to the classifier head.

  Returns:
    A `(loss, log_probs)` tuple as produced by the classifier head.
  """
  # Only the second model output (the pooled one) feeds the classifier head.
  _, pooled_output = model(features, training=False)
  eval_step_loss, eval_step_log_probs = headl(pooled_output, labels, False)
  return eval_step_loss, eval_step_log_probs

In [None]:
# Same pipeline settings as for training, but built with is_training=False;
# batches again hold two examples each.
eval_input_fn = run_classifier.input_fn_builder(
    data_dir=FLAGS.data_dir,
    vocab_model_file=FLAGS.vocab_model_file,
    max_encoder_length=FLAGS.max_encoder_length,
    substitute_newline=FLAGS.substitute_newline,
    is_training=False,
)
eval_dataset = eval_input_fn({"batch_size": 2})

In [None]:
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

# Take the class count from the same config entry the classifier head is built
# from, rather than hard-coding 2 in the one-hot encoding below.
num_labels = bert_config["num_labels"]

# Single pass over the evaluation dataset, accumulating loss and accuracy.
for features, labels in tqdm(eval_dataset, position=0):
  loss, log_probs = fwd_only(features, labels)
  eval_loss(loss)
  eval_accuracy(tf.one_hot(labels, num_labels), log_probs)
print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))


Loss = 0.16173037886619568  Accuracy = 0.9459513425827026