# Finetuning a pretrained BERT model on MRPC task

WIP

- [x] Test on Colab
- [ ] Add exercises
- [ ] Add references and explanations
- [ ] Include original code

In [None]:
!pip install tensorflow>=2.0.0 tensorflow_datasets transformers numpy

In [None]:
import tensorflow_datasets

In [None]:
import os
import tensorflow as tf
import tensorflow_datasets
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig, glue_convert_examples_to_features, BertForSequenceClassification, glue_processors

# script parameters
BATCH_SIZE = 32
EVAL_BATCH_SIZE = BATCH_SIZE * 2
USE_XLA = False
USE_AMP = False
EPOCHS = 1

TASK = "mrpc"

num_labels = len(glue_processors[TASK]().get_labels())

tf.config.optimizer.set_jit(USE_XLA)
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": USE_AMP})

# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)
config = BertConfig.from_pretrained("bert-base-cased", num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')


In [None]:
# Load dataset via TensorFlow Datasets
data, info = tensorflow_datasets.load('glue/mrpc', with_info=True)
train_examples = info.splits['train'].num_examples

# MNLI expects either validation_matched or validation_mismatched
valid_examples = info.splits['validation'].num_examples

# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, 128, TASK)

# MNLI expects either validation_matched or validation_mismatched
valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, 128, TASK)
train_dataset = train_dataset.shuffle(128).batch(BATCH_SIZE).repeat(-1)
valid_dataset = valid_dataset.batch(EVAL_BATCH_SIZE)

In [4]:
model = TFBertForSequenceClassification.from_pretrained('bert-base-cased', config=config)

# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule 
opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
if USE_AMP:
    # loss scaling is currently required when using mixed precision
    opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')


if num_labels == 1:
    loss = tf.keras.losses.MeanSquaredError()
else:
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=opt, loss=loss, metrics=[metric])

INFO:absl:Load pre-computed datasetinfo (eg: splits) from bucket.
INFO:absl:Loading info from GCS for glue/mrpc/0.0.2
INFO:absl:Generating dataset glue (/home/collion/tensorflow_datasets/glue/mrpc/0.0.2)


[1mDownloading and preparing dataset glue (1.43 MiB) to /home/collion/tensorflow_datasets/glue/mrpc/0.0.2...[0m


INFO:absl:Downloading https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt into /home/collion/tensorflow_datasets/downloads/dl.fbaip.com_sente_sente_msr_parap_test0PdekMcyqYR-w4Rx_d7OTryq0J3RlYRn4rAMajy9Mak.txt.tmp.1324833565e54201bb8d0f7b2e60886a...
INFO:absl:Downloading https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt into /home/collion/tensorflow_datasets/downloads/dl.fbaip.com_sente_sente_msr_parap_trainfGxPZuQWGBti4Tbd1YNOwQr-OqxPejJ7gcp0Al6mlSk.txt.tmp.50490adb56d445d1a014a6f06a2d68a9...
INFO:absl:Downloading https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc into /home/collion/tensorflow_datasets/downloads/fire.goog.com_v0_b_mtl-sent-repr.apps.com_o_2FjSIMlCiqs1QSmIykr4IRPnEHjPuGwAz5i40v8K9U0Z8.tsvalt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc.tmp.f9f9d579ac874bc8b5f440b7b6b33736...
INFO:absl:Generating spli








Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`


Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`


INFO:absl:Generating split validation
INFO:absl:Writing TFRecords






INFO:absl:Generating split test
INFO:absl:Writing TFRecords






INFO:absl:Skipping computing stats for mode ComputeStatsMode.AUTO.
INFO:absl:Constructing tf.data.Dataset for split None, from /home/collion/tensorflow_datasets/glue/mrpc/0.0.2


[1mDataset glue downloaded and prepared to /home/collion/tensorflow_datasets/glue/mrpc/0.0.2. Subsequent calls will reuse this data.[0m


In [None]:
# Train and evaluate using tf.keras.Model.fit()
train_steps = train_examples//BATCH_SIZE
valid_steps = valid_examples//EVAL_BATCH_SIZE

history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=train_steps,
                    validation_data=valid_dataset, validation_steps=valid_steps)

# Save TF2 model
os.makedirs('./save/', exist_ok=True)
model.save_pretrained('./save/')

Train for 114 steps, validate for 6 steps


Compatibility between TF and Pytorch (from HuggingFace Transformers, useful?)

In [3]:
if TASK == "mrpc":
    # Load the TensorFlow model in PyTorch for inspection
    # This is to demo the interoperability between the two frameworks, you don't have to 
    # do this in real life (you can run the inference on the TF model).
    pytorch_model = BertForSequenceClassification.from_pretrained('./save/', from_tf=True)

    # Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task
    sentence_0 = 'This research was consistent with his findings.'
    sentence_1 = 'His findings were compatible with this research.'
    sentence_2 = 'His findings were not compatible with this research.'
    inputs_1 = tokenizer.encode_plus(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')
    inputs_2 = tokenizer.encode_plus(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')

    del inputs_1["special_tokens_mask"]
    del inputs_2["special_tokens_mask"]

    pred_1 = pytorch_model(**inputs_1)[0].argmax().item()
    pred_2 = pytorch_model(**inputs_2)[0].argmax().item()
    print('sentence_1 is', 'a paraphrase' if pred_1 else 'not a paraphrase', 'of sentence_0')
    print('sentence_2 is', 'a paraphrase' if pred_2 else 'not a paraphrase', 'of sentence_0')

2


100%|██████████| 313/313 [00:00<00:00, 138264.05B/s]
100%|██████████| 213450/213450 [00:00<00:00, 814692.17B/s]
100%|██████████| 526681800/526681800 [01:29<00:00, 5882059.80B/s] 


DatasetNotFoundError: Dataset glue_data not found. Available datasets:
	- abstract_reasoning
	- aeslc
	- aflw2k3d
	- amazon_us_reviews
	- bair_robot_pushing_small
	- big_patent
	- bigearthnet
	- billsum
	- binarized_mnist
	- binary_alpha_digits
	- c4
	- caltech101
	- caltech_birds2010
	- caltech_birds2011
	- cars196
	- cassava
	- cats_vs_dogs
	- celeb_a
	- celeb_a_hq
	- chexpert
	- cifar10
	- cifar100
	- cifar10_1
	- cifar10_corrupted
	- clevr
	- cmaterdb
	- cnn_dailymail
	- coco
	- coco2014
	- coil100
	- colorectal_histology
	- colorectal_histology_large
	- curated_breast_imaging_ddsm
	- cycle_gan
	- deep_weeds
	- definite_pronoun_resolution
	- diabetic_retinopathy_detection
	- downsampled_imagenet
	- dsprites
	- dtd
	- dummy_dataset_shared_generator
	- dummy_mnist
	- emnist
	- eurosat
	- fashion_mnist
	- flores
	- food101
	- gap
	- gigaword
	- glue
	- groove
	- higgs
	- horses_or_humans
	- image_label_folder
	- imagenet2012
	- imagenet2012_corrupted
	- imagenet_resized
	- imdb_reviews
	- iris
	- kitti
	- kmnist
	- lfw
	- lm1b
	- lsun
	- malaria
	- mnist
	- mnist_corrupted
	- moving_mnist
	- multi_news
	- multi_nli
	- multi_nli_mismatch
	- newsroom
	- nsynth
	- omniglot
	- open_images_v4
	- oxford_flowers102
	- oxford_iiit_pet
	- para_crawl
	- patch_camelyon
	- pet_finder
	- places365_small
	- quickdraw_bitmap
	- reddit_tifu
	- resisc45
	- rock_paper_scissors
	- rock_you
	- scene_parse150
	- scientific_papers
	- shapes3d
	- smallnorb
	- snli
	- so2sat
	- squad
	- stanford_dogs
	- stanford_online_products
	- starcraft_video
	- sun397
	- super_glue
	- svhn_cropped
	- ted_hrlr_translate
	- ted_multi_translate
	- tf_flowers
	- the300w_lp
	- titanic
	- trivia_qa
	- uc_merced
	- ucf101
	- visual_domain_decathlon
	- voc
	- wider_face
	- wikihow
	- wikipedia
	- wmt14_translate
	- wmt15_translate
	- wmt16_translate
	- wmt17_translate
	- wmt18_translate
	- wmt19_translate
	- wmt_t2t_translate
	- wmt_translate
	- xnli
	- xsum
Check that:
    - the dataset name is spelled correctly
    - dataset class defines all base class abstract methods
    - dataset class is not in development, i.e. if IN_DEVELOPMENT=True
    - the module defining the dataset class is imported
