# Fine-tune a pre-trained BERT model

In this notebook:

1. Load train dataset
2. Load pre-trained BERT model
3. Define model architecture
4. Compile model
5. Train model
6. Save model

## 1. Load train dataset

In [33]:
import tensorflow as tf

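# Load the datasets previously saved to disk with tf.data.Dataset.save
train_ds = tf.data.Dataset.load('train')
# The saved 'test' split is used as the validation set during training
validation_ds = tf.data.Dataset.load('test')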

In [2]:
train_ds.take(1)

<TakeDataset element_spec=({'input_ids': TensorSpec(shape=(16, 512), dtype=tf.int64, name=None), 'attention_mask': TensorSpec(shape=(16, 512), dtype=tf.int64, name=None)}, TensorSpec(shape=(16, 5), dtype=tf.float64, name=None))>
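
The element spec above shows that each batch holds 16 sequences of 512 token ids, an attention mask of the same shape, and a 5-column label tensor. As a rough illustration of how one such row is typically shaped, the pure-Python sketch below pads a short list of ids to a fixed length, builds the matching mask, and one-hot encodes a label. The helper names `pad_and_mask` and `one_hot`, the pad id of 0, and the sample ids are assumptions for illustration; the preprocessing that actually produced the saved dataset is not shown in this notebook.

```python
# Illustrative sketch only: helper names, pad id and sample values are assumptions,
# not the preprocessing used to create the saved dataset.

def pad_and_mask(token_ids, max_len=512, pad_id=0):
    """Truncate or right-pad token ids to max_len and build the attention mask."""
    ids = list(token_ids)[:max_len]
    mask = [1] * len(ids)            # 1 marks a real token
    padding = max_len - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 marks padding


def one_hot(label, num_classes=5):
    """Encode an integer class label as a one-hot float vector."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


input_ids, attention_mask = pad_and_mask([101, 1188, 1110, 102])  # made-up token ids
print(len(input_ids), len(attention_mask), sum(attention_mask))
print(one_hot(3))
```

The mask tells the model which positions to attend to, so padded positions do not influence the result.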

## 2. Load pre-trained BERT model

In [3]:
from transformers import TFAutoModel

In [19]:
pretrained_bert = TFAutoModel.from_pretrained('bert-base-cased')

Some layers from the model checkpoint at bert-base-cased were not used when initializing TFBertModel: ['nsp___cls', 'mlm___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertModel were initialized from the model checkpoint at bert-base-cased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [20]:
pretrained_bert.summary()

Model: "tf_bert_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 bert (TFBertMainLayer)      multiple                  108310272 
                                                                 
Total params: 108,310,272
Trainable params: 108,310,272
Non-trainable params: 0
_________________________________________________________________


## 3. Define model architecture

In [6]:
import tensorflow as tf

In [22]:
# Input Layer (2 inputs)
# Names must match the feature keys defined when the dataset was created
input_ids = tf.keras.layers.Input(shape=(512,), name='input_ids', dtype='int32')
mask = tf.keras.layers.Input(shape=(512,), name='attention_mask', dtype='int32')

# Transformer layer
embeddings = pretrained_bert.bert(input_ids, attention_mask=mask).pooler_output

# Classification head
X = tf.keras.layers.Dense(1024, activation='relu')(embeddings)
y = tf.keras.layers.Dense(5, activation='softmax', name='outputs')(X)

# Create the model
model = tf.keras.Model(inputs=[input_ids, mask], outputs=y)

In [24]:
# Freeze the BERT layer (index 2 in model.layers) so only the classification head is trained
model.layers[2].trainable = False
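
With the BERT layer frozen, only the two Dense layers of the classification head should remain trainable. The back-of-the-envelope sketch below estimates those counts by hand. It assumes the pooled output of BERT-base has width 768 (this value is not read from the model here), and `dense_params` is a hypothetical helper. The results should be comparable to the trainable and non-trainable totals that `model.summary()` prints.

```python
# Back-of-the-envelope parameter count for the classification head.
# hidden_size=768 is assumed to be the pooled output width of BERT-base.

def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units


hidden_size = 768
bert_params = 108_310_272  # taken from pretrained_bert.summary()

head_params = dense_params(hidden_size, 1024) + dense_params(1024, 5)
total_params = bert_params + head_params

print(f"trainable (head only): {head_params:,}")
print(f"non-trainable (frozen BERT): {bert_params:,}")
print(f"total: {total_params:,}")
```

Under these assumptions the head accounts for well under 1% of all parameters, which is why freezing BERT makes each training step much cheaper.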

In [25]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, 512)]        0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, 512)]        0           []                               
                                                                                                  
 bert (TFBertMainLayer)         TFBaseModelOutputWi  108310272   ['input_ids[0][0]',              
                                thPoolingAndCrossAt               'attention_mask[0][0]']         
                                tentions(last_hidde                                               
                                n_state=(None, 512,                                         

## 4. Compile the model

In [30]:
# Hyperparameters recommended in the BERT documentation
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5, decay=1e-6)

# Loss function
loss = tf.keras.losses.CategoricalCrossentropy()

# Metrics
metrics = [tf.keras.metrics.CategoricalAccuracy('accuracy')]
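
To get a feel for what `decay=1e-6` does to the learning rate, the sketch below evaluates an inverse time decay schedule in plain Python. It assumes the `decay` argument is applied as `lr / (1 + decay * step)`, which is how the older Keras optimizers treat it; newer TensorFlow releases may reject this argument or require the legacy optimizer, so check the version in use. The function name `decayed_learning_rate` is hypothetical.

```python
# Sketch of inverse time decay: lr_t = lr / (1 + decay * step).
# Assumes this is how the `decay` argument is applied; verify against your Keras version.

def decayed_learning_rate(step, initial_lr=5e-5, decay=1e-6):
    """Learning rate after `step` optimizer updates."""
    return initial_lr / (1.0 + decay * step)


for step in (0, 1_000, 10_000, 100_000):
    print(f"step {step:>7}: lr = {decayed_learning_rate(step):.3e}")
```

Under this assumption the rate drops by about 1% after 10,000 updates, so the schedule is very mild for a short run.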

In [31]:
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
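
The loss passed to `compile` compares the softmax output of the model with the label vectors, which are presumably one-hot given their 5-column shape. The pure-Python sketch below reproduces that calculation for a single example, as a way to interpret the loss values printed during training. It is a simplified illustration, not the Keras implementation, and the logits and label are made up.

```python
import math

# Simplified illustration of softmax followed by categorical cross-entropy.
# The logits and label below are made-up values.

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between a one-hot target and predicted probabilities."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))


y_true = [0.0, 0.0, 1.0, 0.0, 0.0]              # one-hot label for class index 2
uniform = softmax([0.0] * 5)                     # a head guessing uniformly
confident = softmax([0.0, 0.0, 4.0, 0.0, 0.0])   # a head favouring the right class

print(categorical_crossentropy(y_true, uniform))    # equals ln(5)
print(categorical_crossentropy(y_true, confident))  # smaller than the uniform case
```

A loss near ln(5) ≈ 1.609 in the first steps therefore suggests predictions close to uniform over the five classes, and the value should fall as the head learns.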

## 5. Train the model

In [None]:
training_results = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=3
)