In [None]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.classifier import run_classifier
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
import sys

FLAGS = flags.FLAGS
if not hasattr(FLAGS, "f"): flags.DEFINE_string("f", "", "")
FLAGS(sys.argv)

tf.enable_v2_behavior()

In [None]:
FLAGS.data_dir = "tfds://imdb_reviews/plain_text"
FLAGS.attention_type = "original_full"
FLAGS.max_encoder_length = 512  # reduce for quicker demo on free colab
FLAGS.learning_rate = 1e-5
FLAGS.num_train_steps = 1000
FLAGS.attention_probs_dropout_prob = 0.0
FLAGS.hidden_dropout_prob = 0.0
FLAGS.use_gradient_checkpointing = True
FLAGS.vocab_model_file = "gpt2"

## Define classification model

In [None]:
model = modeling.BertModel(bert_config)
headl = run_classifier.ClassifierLossLayer(
        bert_config["hidden_size"], bert_config["num_labels"],
        bert_config["hidden_dropout_prob"],
        utils.create_initializer(bert_config["initializer_range"]),
        name=bert_config["scope"]+"/classifier")

In [None]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  with tf.GradientTape() as g:
    _, pooled_output = model(features, training=True)
    loss, log_probs = headl(pooled_output, labels, True)
    # print("loss:",loss)
  grads = g.gradient(loss, model.trainable_weights+headl.trainable_weights)
  return loss, log_probs, grads

## Dataset pipeline

In [None]:
train_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=True)
dataset = train_input_fn({'batch_size': 32})

[1mDownloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-train.tfrecord...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-test.tfrecord...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-unsupervised.tfrecord...:   0%|          | 0/50000 [00:00<?, ? examples/s]

[1mDataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.[0m


## Train

In [None]:
import time
start_time = time.time()

In [None]:
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
  loss, log_probs, grads = fwd_bwd(ex[0], ex[1])
  opt.apply_gradients(zip(grads, model.trainable_weights+headl.trainable_weights))
  train_loss(loss)
  train_accuracy(tf.one_hot(ex[1], 2), log_probs)
  if (i+1)% 50 == 0:
    print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))


  5%|▌         | 50/1000 [07:11<1:58:40,  7.50s/it]Loss = 0.7017228603363037  Accuracy = 0.5093749761581421
 10%|█         | 100/1000 [13:26<1:52:16,  7.48s/it]Loss = 0.7004165649414062  Accuracy = 0.52009375214576721
 15%|█▌        | 150/1000 [19:41<1:46:11,  7.50s/it]Loss = 0.6983187794685364  Accuracy = 0.55077083110809326
 20%|██        | 200/1000 [25:56<1:39:51,  7.49s/it]Loss = 0.6972595453262329  Accuracy = 0.6067187547683716
 25%|██▌       | 250/1000 [32:11<1:33:50,  7.51s/it]Loss = 0.6956593990325928  Accuracy = 0.6143749713897705
 30%|███       | 300/1000 [38:26<1:27:33,  7.50s/it]Loss = 0.6941934823989868  Accuracy = 0.6620624995231628
 35%|███▌      | 350/1000 [44:46<1:22:39,  7.63s/it]Loss = 0.6885170340538025  Accuracy = 0.6932767832279205
 40%|████      | 400/1000 [51:08<1:16:29,  7.65s/it]Loss = 0.6801888942718506  Accuracy = 0.7154765623807907
 45%|████▌     | 450/1000 [57:29<1:09:50,  7.62s/it]Loss = 0.6720452308654785  Accuracy = 0.7456048613786697
 50%|█████     | 

In [None]:
print("--- %s seconds ---" % (time.time() - start_time))

--- 7944.823274374008 seconds ---


## Eval

In [None]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  _, pooled_output = model(features, training=False)
  loss, log_probs = headl(pooled_output, labels, False)
  return loss, log_probs

In [None]:
eval_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=False)
eval_dataset = eval_input_fn({'batch_size': 32})

In [None]:
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

for i, ex in enumerate(tqdm(eval_dataset, position=0)):
  loss, log_probs = fwd_only(ex[0], ex[1])
  eval_loss(loss)
  eval_accuracy(tf.one_hot(ex[1], 2), log_probs)
  # if (i+1)% 50 == 0:
  #   print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))

print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))

100%|██████████| 781/781 [03:11<00:00,  4.08it/s]Loss = 0.378199964761734  Accuracy = 0.933053507804871


In [None]:
# eval_input_fn = run_classifier.input_fn_builder(
#         data_dir=FLAGS.data_dir,
#         vocab_model_file=FLAGS.vocab_model_file,
#         max_encoder_length=FLAGS.max_encoder_length,
#         substitute_newline=FLAGS.substitute_newline,
#         is_training=False, )
# eval_dataset = eval_input_fn({'batch_size': 32})

In [None]:
student_bert_config = flags.as_dictionary()

In [None]:
student_bert_config['intermediate_size'] = 512
# student_bert_config['hidden_size']= 256
student_bert_config['iterations_per_loop']= '800'
student_bert_config['num_attention_heads']= 8
student_bert_config['num_hidden_layers'] = 3

In [None]:
student_bert_config

{'attention_probs_dropout_prob': 0.0,
 'attention_type': 'block_sparse',
 'block_size': 16,
 'data_dir': 'tfds://imdb_reviews/plain_text',
 'do_eval': False,
 'do_export': False,
 'do_train': True,
 'eval_batch_size': 8,
 'gcp_project': None,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.0,
 'hidden_size': 768,
 'init_checkpoint': None,
 'initializer_range': 0.02,
 'intermediate_size': 512,
 'iterations_per_loop': '1000',
 'learning_rate': 1e-05,
 'master': None,
 'max_encoder_length': 512,
 'max_position_embeddings': 4096,
 'norm_type': 'postnorm',
 'num_attention_heads': 8,
 'num_hidden_layers': 3,
 'num_labels': 2,
 'num_rand_blocks': 3,
 'num_tpu_cores': 8,
 'num_train_steps': 1000,
 'num_warmup_steps': 1000,
 'optimizer': 'AdamWeightDecay',
 'optimizer_beta1': 0.9,
 'optimizer_beta2': 0.999,
 'optimizer_epsilon': 1e-06,
 'output_dir': '/tmp/bigb',
 'rescale_embedding': False,
 'save_checkpoints_steps': 1000,
 'scope': 'bert',
 'substitute_newline': None,
 'tpu_job_name': None,


In [None]:
student_model = modeling.BertModel(student_bert_config)
student_headl = run_classifier.ClassifierLossLayer(
        student_bert_config["hidden_size"], student_bert_config["num_labels"],
        student_bert_config["hidden_dropout_prob"],
        utils.create_initializer(student_bert_config["initializer_range"]),
        name=student_bert_config["scope"]+"/classifier")

In [None]:
import numpy as np
alphas = [0.2, 0.3, 0.4,0.5,0.6,0.7,0.8]
for alpha in alphas:  
  @tf.function(experimental_compile=True)
  def student_fwd_bwd(features, labels):
    with tf.GradientTape() as g:
      _, pooled_output_student = student_model(features, training=True)
      student_loss, log_probs_student = student_headl(pooled_output_student, labels, True)
      # print("log_probs_student: ",log_probs_student)
      # one_hot_labels = tf.one_hot(labels, depth=student_bert_config["num_labels"],dtype=tf.float32)
      # student_loss = -tf.reduce_sum(one_hot_labels * log_probs_student, axis=-1)
      # student_loss = - one_hot_labels * tf.math.log(log_probs_student) - (1 - one_hot_labels) * tf.math.log(1 - log_probs_student)

      _, pooled_output_teacher = model(features, training=False)
      teacher_loss, log_probs_teacher = headl(pooled_output_teacher, labels, False)
      # print("log_probs_teacher: ",log_probs_teacher)
      
      # distil_loss = - log_probs_teacher * tf.math.log(log_probs_student) - (1 - log_probs_teacher) * tf.math.log(1 - log_probs_student)
      
      # mse = tf.keras.losses.MeanSquaredError()
      # distil_loss = mse(log_probs_teacher, log_probs_student)
      # print("distil_loss.numpy(): ",distil_loss.numpy())
      # print("student_loss: ",student_loss.shape)
      # print("distil_loss: ",distil_loss.shape)
      distil_loss = abs(student_loss-teacher_loss)
      loss = alpha * student_loss + (1-alpha) * distil_loss
      # loss = student_loss
      print('\nloss:',loss,'\nstudent_loss:',student_loss)
      # loss = tf.reduce_mean(loss, axis=-1)
      # tf.print(loss)
    grads = g.gradient(loss, student_model.trainable_weights+headl.trainable_weights)
    return loss, log_probs_student, grads

  train_input_fn = run_classifier.input_fn_builder(
          data_dir=FLAGS.data_dir,
          vocab_model_file=FLAGS.vocab_model_file,
          max_encoder_length=FLAGS.max_encoder_length,
          substitute_newline=FLAGS.substitute_newline,
          is_training=True)
  dataset = train_input_fn({'batch_size': 32})

  import time
  start_time = time.time()
  opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

  for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
    loss, log_probs, grads = student_fwd_bwd(ex[0], ex[1])
    opt.apply_gradients(zip(grads, student_model.trainable_weights+student_headl.trainable_weights))
    train_loss(loss)
    # print(train_loss(loss))
    train_accuracy(tf.one_hot(ex[1], 2), log_probs)
    if (i+1)% 50 == 0:
      print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))

  print("--- %s seconds ---" % (time.time() - start_time))

  @tf.function(experimental_compile=True)
  def fwd_only(features, labels):
    _, pooled_output = student_model(features, training=False)
    loss, log_probs = student_headl(pooled_output, labels, False)
    return loss, log_probs


  eval_input_fn = run_classifier.input_fn_builder(
          data_dir=FLAGS.data_dir,
          vocab_model_file=FLAGS.vocab_model_file,
          max_encoder_length=FLAGS.max_encoder_length,
          substitute_newline=FLAGS.substitute_newline,
          is_training=False)
  eval_dataset = eval_input_fn({'batch_size': 32})

  eval_loss = tf.keras.metrics.Mean(name='eval_loss')
  eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

  for ex in tqdm(eval_dataset, position=0):
    loss, log_probs = fwd_only(ex[0], ex[1])
    eval_loss(loss)
    eval_accuracy(tf.one_hot(ex[1], 2), log_probs)
  print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))

  0%|          | 0/1000 [00:00<?, ?it/s]


loss: Tensor("add:0", shape=(), dtype=float32) 
student_loss: Tensor("bert/classifier/Mean:0", shape=(), dtype=float32)


  5%|▌         | 50/1000 [02:53<43:52,  2.77s/it]

Loss = 0.43946707248687744  Accuracy = 0.4987500011920929


 10%|█         | 100/1000 [05:12<41:36,  2.77s/it]

Loss = 0.43933141231536865  Accuracy = 0.49781250953674316


 15%|█▌        | 150/1000 [07:30<39:21,  2.78s/it]

Loss = 0.4384157359600067  Accuracy = 0.5087500214576721


 20%|██        | 200/1000 [09:49<37:00,  2.78s/it]

Loss = 0.44259029626846313  Accuracy = 0.5223437547683716


 25%|██▌       | 250/1000 [12:07<34:19,  2.75s/it]

Loss = 0.4414173364639282  Accuracy = 0.5237500071525574


 30%|███       | 300/1000 [14:24<32:06,  2.75s/it]

Loss = 0.4424702525138855  Accuracy = 0.5390625


 35%|███▌      | 350/1000 [16:43<30:01,  2.77s/it]

Loss = 0.4390639066696167  Accuracy = 0.5479464530944824


 40%|████      | 400/1000 [19:01<27:46,  2.78s/it]

Loss = 0.4311911463737488  Accuracy = 0.5603125095367432


 45%|████▌     | 450/1000 [21:20<25:28,  2.78s/it]

Loss = 0.4258018434047699  Accuracy = 0.5695139169692993


 50%|█████     | 500/1000 [23:39<23:07,  2.77s/it]

Loss = 0.4167731702327728  Accuracy = 0.5831249952316284


 55%|█████▌    | 550/1000 [25:58<20:47,  2.77s/it]

Loss = 0.4075758457183838  Accuracy = 0.5969886183738708


 60%|██████    | 600/1000 [28:15<18:18,  2.75s/it]

Loss = 0.3963661789894104  Accuracy = 0.6101041436195374


 65%|██████▌   | 650/1000 [30:33<16:11,  2.78s/it]

Loss = 0.38260477781295776  Accuracy = 0.624471127986908


 70%|███████   | 700/1000 [32:52<13:52,  2.78s/it]

Loss = 0.3696557283401489  Accuracy = 0.6360714435577393


 75%|███████▌  | 750/1000 [35:11<11:33,  2.78s/it]

Loss = 0.3569420278072357  Accuracy = 0.6474999785423279


 80%|████████  | 800/1000 [37:29<09:15,  2.78s/it]

Loss = 0.3441997468471527  Accuracy = 0.6587890386581421


 85%|████████▌ | 850/1000 [39:48<06:55,  2.77s/it]

Loss = 0.3324805796146393  Accuracy = 0.669301450252533


 90%|█████████ | 900/1000 [42:07<04:37,  2.77s/it]

Loss = 0.32140111923217773  Accuracy = 0.6796180605888367


 95%|█████████▌| 950/1000 [44:25<02:18,  2.77s/it]

Loss = 0.31447988748550415  Accuracy = 0.6861842274665833


100%|██████████| 1000/1000 [46:44<00:00,  2.80s/it]


Loss = 0.30588796734809875  Accuracy = 0.6937812566757202
--- 2804.711443901062 seconds ---


100%|██████████| 781/781 [03:30<00:00,  3.71it/s]


Loss = 0.34927695989608765  Accuracy = 0.8494718074798584


  0%|          | 0/1000 [00:00<?, ?it/s]


loss: Tensor("add:0", shape=(), dtype=float32) 
student_loss: Tensor("bert/classifier/Mean:0", shape=(), dtype=float32)


  5%|▌         | 50/1000 [02:49<43:56,  2.77s/it]

Loss = 0.12745913863182068  Accuracy = 0.8631250262260437


 10%|█         | 100/1000 [05:08<41:39,  2.78s/it]

Loss = 0.1273157298564911  Accuracy = 0.8646875023841858


 15%|█▌        | 150/1000 [07:27<39:20,  2.78s/it]

Loss = 0.12683804333209991  Accuracy = 0.8658333420753479


 20%|██        | 200/1000 [09:46<36:59,  2.77s/it]

Loss = 0.12288951128721237  Accuracy = 0.8701562285423279


 25%|██▌       | 250/1000 [12:05<34:37,  2.77s/it]

Loss = 0.12334799766540527  Accuracy = 0.8693749904632568


 30%|███       | 300/1000 [14:24<32:23,  2.78s/it]

Loss = 0.12109459191560745  Accuracy = 0.8727083206176758


 35%|███▌      | 350/1000 [16:42<30:03,  2.77s/it]

Loss = 0.12048765271902084  Accuracy = 0.8721428513526917


 40%|████      | 400/1000 [19:01<27:30,  2.75s/it]

Loss = 0.12084038555622101  Accuracy = 0.87109375


 45%|████▌     | 450/1000 [21:18<25:07,  2.74s/it]

Loss = 0.12041318416595459  Accuracy = 0.871874988079071


 50%|█████     | 500/1000 [23:35<22:50,  2.74s/it]

Loss = 0.12158437073230743  Accuracy = 0.8712499737739563


 55%|█████▌    | 550/1000 [25:53<20:34,  2.74s/it]

Loss = 0.11994599550962448  Accuracy = 0.8726704716682434


 60%|██████    | 600/1000 [28:10<18:18,  2.75s/it]

Loss = 0.12066374719142914  Accuracy = 0.8720312714576721


 65%|██████▌   | 650/1000 [30:27<15:59,  2.74s/it]

Loss = 0.11991095542907715  Accuracy = 0.8730769157409668


 70%|███████   | 700/1000 [32:44<13:43,  2.74s/it]

Loss = 0.12010673433542252  Accuracy = 0.8730803728103638


 75%|███████▌  | 750/1000 [35:01<11:26,  2.74s/it]

Loss = 0.11942759901285172  Accuracy = 0.8740833401679993


 80%|████████  | 800/1000 [37:19<09:08,  2.74s/it]

Loss = 0.1186467707157135  Accuracy = 0.8753125071525574


 85%|████████▌ | 850/1000 [39:36<06:51,  2.75s/it]

Loss = 0.11829531192779541  Accuracy = 0.8766176700592041


 85%|████████▌ | 852/1000 [39:41<06:46,  2.74s/it]

In [None]:
student_model = modeling.BertModel(student_bert_config)
student_headl = run_classifier.ClassifierLossLayer(
        student_bert_config["hidden_size"], student_bert_config["num_labels"],
        student_bert_config["hidden_dropout_prob"],
        utils.create_initializer(student_bert_config["initializer_range"]),
        name=student_bert_config["scope"]+"/classifier")

import numpy as np
alphas = [0.2, 0.3,0.4,0.5,0.6,0.7,0.8]
for alpha in alphas:  
  @tf.function(experimental_compile=True)
  def student_fwd_bwd(features, labels):
    with tf.GradientTape() as g:
      _, pooled_output_student = student_model(features, training=True)
      loss, log_probs_student = student_headl(pooled_output_student, labels, True)      
    grads = g.gradient(loss, student_model.trainable_weights+student_headl.trainable_weights)
    return loss, log_probs_student, grads

  train_input_fn = run_classifier.input_fn_builder(
          data_dir=FLAGS.data_dir,
          vocab_model_file=FLAGS.vocab_model_file,
          max_encoder_length=FLAGS.max_encoder_length,
          substitute_newline=FLAGS.substitute_newline,
          is_training=True)
  dataset = train_input_fn({'batch_size': 32})

  import time
  start_time = time.time()
  opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

  for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
    loss, log_probs, grads = student_fwd_bwd(ex[0], ex[1])
    opt.apply_gradients(zip(grads, student_model.trainable_weights+student_headl.trainable_weights))
    train_loss(loss)
    # print(train_loss(loss))
    train_accuracy(tf.one_hot(ex[1], 2), log_probs)
    if (i+1)% 50 == 0:
      print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))

  print("--- %s seconds ---" % (time.time() - start_time))

  @tf.function(experimental_compile=True)
  def fwd_only(features, labels):
    _, pooled_output = student_model(features, training=False)
    loss, log_probs = student_headl(pooled_output, labels, False)
    return loss, log_probs


  eval_input_fn = run_classifier.input_fn_builder(
          data_dir=FLAGS.data_dir,
          vocab_model_file=FLAGS.vocab_model_file,
          max_encoder_length=FLAGS.max_encoder_length,
          substitute_newline=FLAGS.substitute_newline,
          is_training=False)
  eval_dataset = eval_input_fn({'batch_size': 32})

  eval_loss = tf.keras.metrics.Mean(name='eval_loss')
  eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

  for ex in tqdm(eval_dataset, position=0):
    loss, log_probs = fwd_only(ex[0], ex[1])
    eval_loss(loss)
    eval_accuracy(tf.one_hot(ex[1], 2), log_probs)
  print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))

In [None]:
print("student_model params:",student_model.count_params())
print("teacher_model params:",model.count_params())
(model.count_params()/student_model.count_params())*100