In [None]:
!pip install git+https://github.com/google-research/bigbird.git -q

[K     |████████████████████████████████| 1.2 MB 7.6 MB/s 
[K     |████████████████████████████████| 4.9 MB 33.0 MB/s 
[K     |████████████████████████████████| 1.4 MB 35.7 MB/s 
[K     |████████████████████████████████| 4.0 MB 31.5 MB/s 
[K     |████████████████████████████████| 48 kB 5.1 MB/s 
[K     |████████████████████████████████| 367 kB 43.1 MB/s 
[K     |████████████████████████████████| 5.8 MB 48.8 MB/s 
[K     |████████████████████████████████| 79 kB 7.7 MB/s 
[K     |████████████████████████████████| 1.1 MB 39.7 MB/s 
[K     |████████████████████████████████| 191 kB 48.1 MB/s 
[K     |████████████████████████████████| 981 kB 27.0 MB/s 
[K     |████████████████████████████████| 352 kB 46.8 MB/s 
[K     |████████████████████████████████| 366 kB 30.8 MB/s 
[K     |████████████████████████████████| 251 kB 49.1 MB/s 
[K     |████████████████████████████████| 191 kB 50.1 MB/s 
[K     |████████████████████████████████| 178 kB 52.0 MB/s 
[?25h  Building wheel for bi

In [None]:
from bigbird.core import flags
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.classifier import run_classifier
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
import sys

# Shared flag registry exposed by bigbird.core.flags.
FLAGS = flags.FLAGS

# Register a placeholder "f" flag before parsing sys.argv.
# NOTE(review): presumably needed because the notebook kernel is started with
# a `-f <connection file>` argument that flag parsing would otherwise reject.
# The hasattr guard keeps the cell from defining the flag twice on re-run.
if not hasattr(FLAGS, "f"):
  flags.DEFINE_string("f", "", "")

# Parse the process arguments so flag values can be read and assigned below.
FLAGS(sys.argv)

tf.enable_v2_behavior()

In [None]:
# Run configuration: overrides applied on top of the BigBird flag defaults.
# These values are later snapshotted into the model config dictionaries.
FLAGS.data_dir = "tfds://imdb_reviews/plain_text"  # forwarded to input_fn_builder as data_dir
FLAGS.attention_type = "block_sparse"
FLAGS.max_encoder_length = 512  # reduce for quicker demo on free colab
FLAGS.learning_rate = 1e-5  # used to construct the Adam optimizers in the training cells
FLAGS.num_train_steps = 1000  # number of batches each training loop takes from the dataset
FLAGS.attention_probs_dropout_prob = 0.0
FLAGS.hidden_dropout_prob = 0.0
FLAGS.use_gradient_checkpointing = True
FLAGS.vocab_model_file = "gpt2"  # forwarded to input_fn_builder as vocab_model_file
FLAGS.do_export = True  # NOTE(review): not read by any cell visible in this notebook

In [None]:
# Snapshot the current flag values into a dictionary; this is the teacher's
# model configuration (passed to modeling.BertModel below).
bert_config = flags.as_dictionary()

In [None]:
# bert_config['intermediate_size'] = 2048
# bert_config['hidden_size']
# bert_config['iterations_per_loop']= '500'
# bert_config['num_attention_heads']= 8
# bert_config['num_hidden_layers'] = 8

In [None]:
# Display the resolved teacher configuration for inspection.
bert_config

{'attention_probs_dropout_prob': 0.0,
 'attention_type': 'block_sparse',
 'block_size': 16,
 'data_dir': 'tfds://imdb_reviews/plain_text',
 'do_eval': False,
 'do_export': True,
 'do_train': True,
 'eval_batch_size': 8,
 'gcp_project': None,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.0,
 'hidden_size': 768,
 'init_checkpoint': None,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'iterations_per_loop': '1000',
 'learning_rate': 1e-05,
 'master': None,
 'max_encoder_length': 512,
 'max_position_embeddings': 4096,
 'norm_type': 'postnorm',
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'num_labels': 2,
 'num_rand_blocks': 3,
 'num_tpu_cores': 8,
 'num_train_steps': 1000,
 'num_warmup_steps': 1000,
 'optimizer': 'AdamWeightDecay',
 'optimizer_beta1': 0.9,
 'optimizer_beta2': 0.999,
 'optimizer_epsilon': 1e-06,
 'output_dir': '/tmp/bigb',
 'rescale_embedding': False,
 'save_checkpoints_steps': 1000,
 'scope': 'bert',
 'substitute_newline': None,
 'tpu_job_name': None

## Define classification model

In [None]:
# Teacher network: the BigBird encoder built from the full configuration.
model = modeling.BertModel(bert_config)

# Classification head that turns the pooled encoder output into a loss and
# per-class scores. Arguments are positional, in the order used by the
# original call: input size, number of labels, dropout probability,
# initializer.
headl = run_classifier.ClassifierLossLayer(
    bert_config["hidden_size"],
    bert_config["num_labels"],
    bert_config["hidden_dropout_prob"],
    utils.create_initializer(bert_config["initializer_range"]),
    name=bert_config["scope"] + "/classifier",
)

In [None]:
@tf.function(experimental_compile=True)
def fwd_bwd(features, labels):
  """Forward and backward pass of the teacher for one training batch.

  Args:
    features: batch of inputs, passed unchanged to `model`.
    labels: batch of targets, passed unchanged to `headl`.

  Returns:
    A tuple `(loss, log_probs, grads)`. `grads` follows the order of
    `model.trainable_weights + headl.trainable_weights`, which is the list
    the training loop zips it with.
  """
  with tf.GradientTape() as tape:
    _unused, pooled_output = model(features, training=True)
    # Third positional argument is True here and False in evaluation.
    # NOTE(review): presumably the head's training switch - confirm.
    head_outputs = headl(pooled_output, labels, True)
    loss, log_probs = head_outputs
  trainable = model.trainable_weights + headl.trainable_weights
  grads = tape.gradient(loss, trainable)
  return loss, log_probs, grads

## Dataset pipeline

In [None]:
# Build the training input function from the flag values (is_training=True).
train_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=True)
# The input function is called with a params dict; the batch size is 32.
dataset = train_input_fn({'batch_size': 32})

In [None]:
import time
# Reference point for the elapsed-time printout after the teacher training
# loop. The clock starts when this cell runs, so any pause before starting the
# training cell is included in the reported duration.
start_time = time.time()

In [None]:
# Teacher training: Adam at the configured learning rate, with running
# averages of loss and accuracy across all steps seen so far.
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

# Each element of `dataset` is indexed as ex[0] (features) and ex[1] (labels);
# only the first FLAGS.num_train_steps batches are consumed.
for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
  loss, log_probs, grads = fwd_bwd(ex[0], ex[1])
  # The weight list must be in the same order fwd_bwd used to compute `grads`.
  opt.apply_gradients(zip(grads, model.trainable_weights+headl.trainable_weights))
  train_loss(loss)
  # Labels are one-hot encoded over 2 classes to match CategoricalAccuracy.
  train_accuracy(tf.one_hot(ex[1], 2), log_probs)
  # Report the running metrics every 50 steps.
  if (i+1)% 50 == 0:
    # Fix: the original line was missing the closing parenthesis of print(...),
    # which made the whole cell a SyntaxError.
    print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))


 5%|▌         | 50/1000 [01:06<15:43,  1.01it/s]Loss = 0.69440758228302  Accuracy = 0.4975000023841858
 10%|█         | 100/1000 [01:56<14:54,  1.01it/s]Loss = 0.6949686408042908  Accuracy = 0.4987500011920929
 15%|█▌        | 150/1000 [02:45<14:01,  1.01it/s]Loss = 0.6937657594680786  Accuracy = 0.5091666579246521
 20%|██        | 200/1000 [03:35<13:12,  1.01it/s]Loss = 0.6920759081840515  Accuracy = 0.5221874713897705
 25%|██▌       | 250/1000 [04:24<12:28,  1.00it/s]Loss = 0.6906613111495972  Accuracy = 0.5267500281333923
 30%|███       | 300/1000 [05:14<11:35,  1.01it/s]Loss = 0.6866089105606079  Accuracy = 0.5407291650772095
 35%|███▌      | 350/1000 [06:03<10:48,  1.00it/s]Loss = 0.6828024387359619  Accuracy = 0.5516071319580078
 40%|████      | 400/1000 [06:53<09:57,  1.00it/s]Loss = 0.6743711829185486  Accuracy = 0.567578136920929
 45%|████▌     | 450/1000 [07:43<09:03,  1.01it/s]Loss = 0.6637136340141296  Accuracy = 0.614583301544109
 50%|█████     | 500/1000 [08:32<08:13,  1

In [None]:
# Wall-clock seconds elapsed since the cell that set `start_time` was run.
print("--- %s seconds ---" % (time.time() - start_time))

--- 6935.456838846207 seconds ---


In [None]:
@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Inference-mode forward pass of the teacher (`model` + `headl`).

  Args:
    features: batch of inputs, passed unchanged to `model`.
    labels: batch of targets, passed unchanged to `headl`.

  Returns:
    A tuple `(loss, log_probs)` as produced by the teacher head.
  """
  encoder_outputs = model(features, training=False)
  _unused, pooled_output = encoder_outputs
  loss, log_probs = headl(pooled_output, labels, False)
  return loss, log_probs

In [None]:
# Build the evaluation input function; identical to the training builder call
# except for is_training=False.
eval_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=False)
# Evaluation batches also use a batch size of 32.
eval_dataset = eval_input_fn({'batch_size': 32})

In [None]:
# Running aggregates over the evaluation dataset.
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

# Single pass over the evaluation dataset. The loop index was never used, so
# the dataset is iterated directly, matching the later evaluation loops.
# Note: this calls whichever `fwd_only` is currently bound; later cells rebind
# that name to the student model.
for ex in tqdm(eval_dataset, position=0):
  loss, log_probs = fwd_only(ex[0], ex[1])
  eval_loss(loss)
  # Labels are one-hot encoded over 2 classes to match CategoricalAccuracy.
  eval_accuracy(tf.one_hot(ex[1], 2), log_probs)

print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))

100%|██████████| 781/781 [03:11<00:00,  4.08it/s]Loss = 0.378199964761734  Accuracy = 0.90625


In [None]:
# Start the student configuration from a fresh snapshot of the same flags the
# teacher configuration was built from; the next cell shrinks it.
student_bert_config = flags.as_dictionary()

In [None]:
# Shrink the student relative to the teacher configuration.
student_bert_config['intermediate_size'] = 512
# The hidden_size override is disabled, so the student keeps the value from
# the flag snapshot (the same hidden size as the teacher).
# NOTE(review): the recorded output of the next cell shows hidden_size 256,
# which this code as written does not produce - confirm which was intended.
# student_bert_config['hidden_size']= 256
# student_bert_config['iterations_per_loop']= '800'
student_bert_config['num_attention_heads']= 4 # 12 for teacher model
student_bert_config['num_hidden_layers'] = 3 # 12 for teacher model

In [None]:
# Display the resolved student configuration for inspection.
student_bert_config

{'attention_probs_dropout_prob': 0.0,
 'attention_type': 'block_sparse',
 'block_size': 16,
 'data_dir': 'tfds://imdb_reviews/plain_text',
 'do_eval': False,
 'do_export': True,
 'do_train': True,
 'eval_batch_size': 8,
 'gcp_project': None,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.0,
 'hidden_size': 256,
 'init_checkpoint': None,
 'initializer_range': 0.02,
 'intermediate_size': 512,
 'iterations_per_loop': '1000',
 'learning_rate': 1e-05,
 'master': None,
 'max_encoder_length': 512,
 'max_position_embeddings': 4096,
 'norm_type': 'postnorm',
 'num_attention_heads': 4,
 'num_hidden_layers': 3,
 'num_labels': 2,
 'num_rand_blocks': 3,
 'num_tpu_cores': 8,
 'num_train_steps': 1000,
 'num_warmup_steps': 1000,
 'optimizer': 'AdamWeightDecay',
 'optimizer_beta1': 0.9,
 'optimizer_beta2': 0.999,
 'optimizer_epsilon': 1e-06,
 'output_dir': '/tmp/bigb',
 'rescale_embedding': False,
 'save_checkpoints_steps': 1000,
 'scope': 'bert',
 'substitute_newline': None,
 'tpu_job_name': None,
 

In [None]:
import numpy as np
# Mixing weight of the distillation objective used by student_fwd_bwd:
# `alpha` scales the student's own label loss and (1 - alpha) scales the
# teacher-matching term.
alpha = 0.2

# Student encoder and classification head built from the reduced config.
student_model = modeling.BertModel(student_bert_config)
student_headl = run_classifier.ClassifierLossLayer(
        student_bert_config["hidden_size"], student_bert_config["num_labels"],
        student_bert_config["hidden_dropout_prob"],
        utils.create_initializer(student_bert_config["initializer_range"]),
        name=student_bert_config["scope"]+"/classifier")
        
@tf.function(experimental_compile=True)
def student_fwd_bwd(features, labels):
  """Forward and backward pass for one knowledge-distillation step.

  The objective is `alpha * student_loss + (1 - alpha) * distil_loss`, where
  `student_loss` is the student head's loss on the labels and `distil_loss`
  is the mean squared error between the teacher head's and the student
  head's second output (`log_probs`).

  Args:
    features: batch of inputs, fed to both `student_model` and the teacher
      `model`.
    labels: batch of targets, passed to both classification heads.

  Returns:
    A tuple `(loss, log_probs_student, grads)`. `grads` follows the order of
    `student_model.trainable_weights + student_headl.trainable_weights`,
    which is the list the training loop zips it with in `apply_gradients`.
  """
  # The teacher only supplies targets. It runs in inference mode and outside
  # the tape, so nothing is recorded for it and none of its weights can
  # receive gradients.
  _, pooled_output_teacher = model(features, training=False)
  _, log_probs_teacher = headl(pooled_output_teacher, labels, False)

  with tf.GradientTape() as g:
    _, pooled_output_student = student_model(features, training=True)
    student_loss, log_probs_student = student_headl(pooled_output_student, labels, True)

    # Distillation term: pull the student's outputs towards the teacher's.
    mse = tf.keras.losses.MeanSquaredError()
    distil_loss = mse(log_probs_teacher, log_probs_student)

    loss = alpha * student_loss + (1-alpha) * distil_loss
  # Fix: differentiate w.r.t. the student's own head. The original used the
  # teacher head (`headl`) here, so the gradients that the training loop
  # applied to `student_headl` were actually the teacher head's gradients.
  grads = g.gradient(loss, student_model.trainable_weights+student_headl.trainable_weights)
  return loss, log_probs_student, grads

# Rebuild the training input pipeline so the student starts from a fresh
# dataset object (same arguments as for the teacher).
train_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=True)
dataset = train_input_fn({'batch_size': 32})

import time
# Timer for the distillation training loop only (reset here, printed below).
start_time = time.time()
# Fresh optimizer and running metrics for the student.
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
  loss, log_probs, grads = student_fwd_bwd(ex[0], ex[1])
  # Gradients are paired with the student's encoder and head weights, so
  # student_fwd_bwd must return them in exactly this order.
  opt.apply_gradients(zip(grads, student_model.trainable_weights+student_headl.trainable_weights))
  # `loss` here is the combined distillation objective, not plain label loss.
  train_loss(loss)
  train_accuracy(tf.one_hot(ex[1], 2), log_probs)
  # Report the running metrics every 50 steps.
  if (i+1)% 50 == 0:
    print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))

print("--- %s seconds ---" % (time.time() - start_time))
print("student_model params:",student_model.count_params())

@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Inference-mode forward pass of the student (`student_model` + `student_headl`).

  This rebinds the notebook-level name `fwd_only`, which an earlier cell
  defined for the teacher.

  Args:
    features: batch of inputs, passed unchanged to `student_model`.
    labels: batch of targets, passed unchanged to `student_headl`.

  Returns:
    A tuple `(loss, log_probs)` as produced by the student head.
  """
  encoder_outputs = student_model(features, training=False)
  _unused, pooled_output = encoder_outputs
  loss, log_probs = student_headl(pooled_output, labels, False)
  return loss, log_probs


# Rebuild the evaluation input pipeline (is_training=False, batch size 32).
eval_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=False)
eval_dataset = eval_input_fn({'batch_size': 32})

# Fresh running aggregates for the distilled student's evaluation.
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

# Single pass over the evaluation dataset with the student's fwd_only; the
# loss reported is the student head's own loss, without a distillation term.
for ex in tqdm(eval_dataset, position=0):
  loss, log_probs = fwd_only(ex[0], ex[1])
  eval_loss(loss)
  eval_accuracy(tf.one_hot(ex[1], 2), log_probs)
print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))



  0%|          | 0/1000 [00:00<?, ?it/s]
  5%|▌         | 50/1000 [02:53<43:52,  2.77s/it]Loss = 0.43946707248687744  Accuracy = 0.4987500011920929
 10%|█         | 100/1000 [05:12<41:36,  2.77s/it]Loss = 0.43933141231536865  Accuracy = 0.49781250953674316
 15%|█▌        | 150/1000 [07:30<39:21,  2.78s/it]Loss = 0.4384157359600067  Accuracy = 0.5087500214576721
 20%|██        | 200/1000 [09:49<37:00,  2.78s/it]Loss = 0.44259029626846313  Accuracy = 0.5223437547683716
 25%|██▌       | 250/1000 [12:07<34:19,  2.75s/it]Loss = 0.4414173364639282  Accuracy = 0.5237500071525574
 30%|███       | 300/1000 [14:24<32:06,  2.75s/it]Loss = 0.4424702525138855  Accuracy = 0.5390625
 35%|███▌      | 350/1000 [16:43<30:01,  2.77s/it]Loss = 0.4390639066696167  Accuracy = 0.5479464530944824
 40%|████      | 400/1000 [19:01<27:46,  2.78s/it]Loss = 0.4311911463737488  Accuracy = 0.5603125095367432
 45%|████▌     | 450/1000 [21:20<25:28,  2.78s/it]Loss = 0.4258018434047699  Accuracy = 0.6095139169692993
 

In [None]:
import numpy as np

# Baseline without distillation: re-create the student encoder and head from
# the same reduced config. Rebinding these names discards the distilled
# student built in the previous cell.
student_model = modeling.BertModel(student_bert_config)
student_headl = run_classifier.ClassifierLossLayer(
      student_bert_config["hidden_size"], student_bert_config["num_labels"],
      student_bert_config["hidden_dropout_prob"],
      utils.create_initializer(student_bert_config["initializer_range"]),
      name=student_bert_config["scope"]+"/classifier")

@tf.function(experimental_compile=True)
def student_fwd_bwd(features, labels):
  """Forward and backward pass of the student trained on labels only.

  Unlike the distillation variant defined in the previous cell (which this
  definition replaces), the teacher is not involved: the loss is the student
  head's own loss.

  Args:
    features: batch of inputs, passed unchanged to `student_model`.
    labels: batch of targets, passed unchanged to `student_headl`.

  Returns:
    A tuple `(loss, log_probs_student, grads)`. `grads` follows the order of
    `student_model.trainable_weights + student_headl.trainable_weights`.
  """
  with tf.GradientTape() as tape:
    _unused, pooled_output_student = student_model(features, training=True)
    head_outputs = student_headl(pooled_output_student, labels, True)
    loss, log_probs_student = head_outputs
  student_weights = student_model.trainable_weights + student_headl.trainable_weights
  grads = tape.gradient(loss, student_weights)
  return loss, log_probs_student, grads

# Rebuild the training input pipeline for the baseline student run.
train_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=True)
dataset = train_input_fn({'batch_size': 32})

import time
# Timer for the baseline training loop only (reset here, printed below).
start_time = time.time()
# Fresh optimizer and running metrics for the baseline student.
opt = tf.keras.optimizers.Adam(FLAGS.learning_rate)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

for i, ex in enumerate(tqdm(dataset.take(FLAGS.num_train_steps), position=0)):
  loss, log_probs, grads = student_fwd_bwd(ex[0], ex[1])
  # The weight list must be in the same order student_fwd_bwd used for `grads`.
  opt.apply_gradients(zip(grads, student_model.trainable_weights+student_headl.trainable_weights))
  train_loss(loss)
  train_accuracy(tf.one_hot(ex[1], 2), log_probs)
  # Report the running metrics every 50 steps.
  if (i+1)% 50 == 0:
    print('Loss = {}  Accuracy = {}'.format(train_loss.result().numpy(), train_accuracy.result().numpy()))

print("--- %s seconds ---" % (time.time() - start_time))

@tf.function(experimental_compile=True)
def fwd_only(features, labels):
  """Inference-mode forward pass of the baseline student.

  Uses the `student_model` and `student_headl` bound at call time, which in
  this cell are the student trained without distillation.

  Args:
    features: batch of inputs, passed unchanged to `student_model`.
    labels: batch of targets, passed unchanged to `student_headl`.

  Returns:
    A tuple `(loss, log_probs)` as produced by the student head.
  """
  encoder_outputs = student_model(features, training=False)
  _unused, pooled_output = encoder_outputs
  loss, log_probs = student_headl(pooled_output, labels, False)
  return loss, log_probs


# Rebuild the evaluation input pipeline (is_training=False, batch size 32).
eval_input_fn = run_classifier.input_fn_builder(
        data_dir=FLAGS.data_dir,
        vocab_model_file=FLAGS.vocab_model_file,
        max_encoder_length=FLAGS.max_encoder_length,
        substitute_newline=FLAGS.substitute_newline,
        is_training=False)
eval_dataset = eval_input_fn({'batch_size': 32})

# Fresh running aggregates for the baseline student's evaluation.
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
eval_accuracy = tf.keras.metrics.CategoricalAccuracy(name='eval_accuracy')

# Single pass over the evaluation dataset with the baseline student.
for ex in tqdm(eval_dataset, position=0):
  loss, log_probs = fwd_only(ex[0], ex[1])
  eval_loss(loss)
  eval_accuracy(tf.one_hot(ex[1], 2), log_probs)
print('Loss = {}  Accuracy = {}'.format(eval_loss.result().numpy(), eval_accuracy.result().numpy()))


  5%|▌         | 50/1000 [01:06<15:43,  1.01it/s]Loss = 0.69440758228302  Accuracy = 0.4975000023841858
 10%|█         | 100/1000 [01:56<14:54,  1.01it/s]Loss = 0.6949686408042908  Accuracy = 0.4987500011920929
 15%|█▌        | 150/1000 [02:45<14:01,  1.01it/s]Loss = 0.6937657594680786  Accuracy = 0.5091666579246521
 20%|██        | 200/1000 [03:35<13:12,  1.01it/s]Loss = 0.6920759081840515  Accuracy = 0.5221874713897705
 25%|██▌       | 250/1000 [04:24<12:28,  1.00it/s]Loss = 0.6906613111495972  Accuracy = 0.5267500281333923
 30%|███       | 300/1000 [05:14<11:35,  1.01it/s]Loss = 0.6866089105606079  Accuracy = 0.5407291650772095
 35%|███▌      | 350/1000 [06:03<10:48,  1.00it/s]Loss = 0.6828024387359619  Accuracy = 0.5516071319580078
 40%|████      | 400/1000 [06:53<09:57,  1.00it/s]Loss = 0.6743711829185486  Accuracy = 0.567578136920929
 45%|████▌     | 450/1000 [07:43<09:03,  1.01it/s]Loss = 0.6637136340141296  Accuracy = 0.581458330154419
 50%|█████     | 500/1000 [08:32<08:13,  

In [None]:
# Parameter counts of the student and teacher encoders. count_params() is
# called on the encoder models only, not on the classification heads.
print("student_model params:",student_model.count_params())
print("teacher_model params:",model.count_params())
# Student size as a percentage of the teacher's parameter count.
(student_model.count_params()/model.count_params())*100

student_model params: 51873792
teacher_model params: 127468800
40.6952854345534


In [None]:
# Seconds-to-minutes conversion; the result matches the 46.74 mins quoted for
# the KD student in the summary below.
# NOTE(review): 2804.71 is typed in by hand, presumably copied from the
# elapsed-time printout of the distillation run - confirm.
2804.71/60

46.74516666666667


* Test set accuracy: IMDB dataset 
  * Full attention - 0.93305
  * Sparse attention - 0.90625
  * Sparse attention with KD (student model) - 0.87947
  * Sparse attention without KD (student model) - 0.703053




* Number of parameters in student model : 51 million
* Number of parameters in teacher model : 127 million
* The student model has about 41% as many parameters as the teacher model (40.7%)



* Time taken to train original full attention: 132.4 mins
* Time taken to train sparse attention: 115.5 mins
* Time taken to train sparse attention with KD (student model): 46.74 mins
* Time taken to train sparse attention without KD (student model): 16.77 mins


* Sequence length - 512
* Number of training steps - 1000 (batches of 32, not epochs)
* Batch size - 32


In [None]:
# Teacher accuracy: 0.97687
# Student accuracy: 0.9456
# KT accuracy: 0.95370