# Fine-tune a pre-trained BERT model

In this notebook:

1. Load train dataset
2. Load pre-trained BERT model
3. Define model architecture
4. Compile model
5. Train model
6. Save model

## 1. Load train dataset

In [14]:
import tensorflow as tf

# tf.data.Dataset.load requires TensorFlow 2.10 or later;
# on older TensorFlow 2 versions, use tf.data.experimental.load instead
train_ds = tf.data.Dataset.load('train')
validation_ds = tf.data.Dataset.load('test')

In [15]:
train_ds.take(1)

<TakeDataset element_spec=({'attention_mask': TensorSpec(shape=(16, 512), dtype=tf.int64, name=None), 'input_ids': TensorSpec(shape=(16, 512), dtype=tf.int64, name=None)}, TensorSpec(shape=(16, 5), dtype=tf.float64, name=None))>
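
To make the element spec concrete, the following sketch builds one batch with the same layout using NumPy: a dictionary of two `int64` arrays of shape `(16, 512)` and a one-hot `float64` label array of shape `(16, 5)`. The token ids, sequence lengths, and labels are random placeholders, and `make_dummy_batch` is a hypothetical helper that is not part of the original notebook.

```python
import numpy as np

BATCH_SIZE, SEQ_LEN, NUM_CLASSES = 16, 512, 5  # taken from the element spec above


def make_dummy_batch(seed=0):
    """Return (features, labels) shaped like one element of train_ds.

    All values are random placeholders, not real tokenizer output.
    """
    rng = np.random.default_rng(seed)
    # Pretend each example has between 10 and SEQ_LEN real tokens.
    lengths = rng.integers(10, SEQ_LEN + 1, size=BATCH_SIZE)
    positions = np.arange(SEQ_LEN)[None, :]  # shape (1, 512)
    attention_mask = (positions < lengths[:, None]).astype(np.int64)
    # Random ids for real tokens, 0 for padded positions (assumed padding id).
    input_ids = rng.integers(1, 1000, size=(BATCH_SIZE, SEQ_LEN)) * attention_mask
    # One-hot labels stored as float64, as in the element spec.
    classes = rng.integers(0, NUM_CLASSES, size=BATCH_SIZE)
    labels = np.eye(NUM_CLASSES, dtype=np.float64)[classes]
    features = {
        "attention_mask": attention_mask,
        "input_ids": input_ids.astype(np.int64),
    }
    return features, labels


features, labels = make_dummy_batch()
print({name: (arr.shape, arr.dtype) for name, arr in features.items()})
print(labels.shape, labels.dtype)
```

The dictionary keys matter: Keras routes each array to the input layer with the same name, which is why the model definition later reuses `input_ids` and `attention_mask`.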

## 2. Load pre-trained BERT model

In [69]:
from transformers import TFAutoModel

In [73]:
pretrained_bert = TFAutoModel.from_pretrained('distilbert-base-cased')

Some layers from the model checkpoint at distilbert-base-cased were not used when initializing TFDistilBertModel: ['vocab_layer_norm', 'vocab_projector', 'activation_13', 'vocab_transform']
- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-cased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


## 3. Define model architecture

In [74]:
import tensorflow as tf

In [88]:
# Input Layer (2 inputs)
# Names have to match the name defined during the Dataset creation
input_ids = tf.keras.layers.Input(shape=(512,), name='input_ids', dtype='int32')
mask = tf.keras.layers.Input(shape=(512,), name='attention_mask', dtype='int32')

# Transformer layer
embeddings = pretrained_bert.distilbert(input_ids, attention_mask=mask)[0][:,0,:]

# Classification head
X = tf.keras.layers.Dense(1024, activation='relu')(embeddings)
y = tf.keras.layers.Dense(5, activation='softmax', name='outputs')(X)

# Create the model
model = tf.keras.Model(inputs=[input_ids, mask], outputs=y)
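
The indexing `[0][:,0,:]` in the transformer step first selects `last_hidden_state`, of shape `(batch, 512, 768)`, and then keeps only the vector at sequence position 0 (the `[CLS]` token) for every example. A small NumPy sketch of the same slicing, with a random array standing in for the real DistilBERT output:

```python
import numpy as np

batch, seq_len, hidden = 2, 512, 768  # small batch, for illustration only

# Stand-in for the model's last_hidden_state; values are random placeholders.
rng = np.random.default_rng(0)
last_hidden_state = rng.standard_normal((batch, seq_len, hidden))

# Same slice as in the notebook: every example, position 0, all hidden units.
cls_embeddings = last_hidden_state[:, 0, :]

print(cls_embeddings.shape)  # (2, 768)
```

The result has one 768-dimensional vector per example, which is what the dense classification head receives.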

In [89]:
# Freeze the DistilBERT layer (index 2 in model.layers) so that
# only the classification head is trained
model.layers[2].trainable = False

In [90]:
model.summary()

Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, 512)]        0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, 512)]        0           []                               
                                                                                                  
 distilbert (TFDistilBertMainLa  TFBaseModelOutput(l  65190912   ['input_ids[0][0]',              
 yer)                           ast_hidden_state=(N               'attention_mask[0][0]']         
                                one, 512, 768),                                                   
                                 hidden_states=None                                         
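
As a rough cross-check of the summary, the size of the classification head can be worked out from the layer widths. The sketch below assumes the standard dense-layer count (weights plus biases). The DistilBERT figure is the one shown in the `Param #` column, while the head and total figures are computed by hand here and are not copied from notebook output.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units


HIDDEN_SIZE = 768               # last dimension of last_hidden_state in the summary
DISTILBERT_PARAMS = 65_190_912  # 'Param #' reported for the distilbert layer

# Dense(1024) on the 768-dim [CLS] vector, then Dense(5) on top of it.
head_params = dense_params(HIDDEN_SIZE, 1024) + dense_params(1024, 5)
total_params = DISTILBERT_PARAMS + head_params

print(f"trainable (head only): {head_params:,}")
print(f"frozen (DistilBERT):   {DISTILBERT_PARAMS:,}")
print(f"total:                 {total_params:,}")
```

With the transformer layer frozen, only about 1% of the parameters are updated during training.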

## 4. Compile the model

In [91]:
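# Parameters recommended in the BERT documentation
# Note: `decay` is a legacy-optimizer argument; on TensorFlow 2.11+ it is
# only accepted by tf.keras.optimizers.legacy.Adam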
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5, decay=1e-6)

# Loss function
loss = tf.keras.losses.CategoricalCrossentropy()

# Metrics
metrics = [tf.keras.metrics.CategoricalAccuracy('accuracy')]
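
To get a feel for what `decay=1e-6` does, the sketch below assumes the inverse-time rule used by the legacy Keras optimizers, `lr / (1 + decay * step)`. The helper `effective_lr` is hypothetical and only mirrors that formula. It does not call Keras.

```python
def effective_lr(step, learning_rate=5e-5, decay=1e-6):
    """Inverse-time decay, as assumed for legacy Keras optimizers."""
    return learning_rate / (1.0 + decay * step)


for step in (0, 1_000, 100_000, 1_000_000):
    print(f"step {step:>9,}: lr = {effective_lr(step):.3e}")
```

Under this assumption the learning rate changes very little over a few thousand steps and only halves after about one million updates.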

In [92]:
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
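
The loss and metric chosen above both expect one-hot labels, which matches the `(16, 5)` float labels in the dataset. The following NumPy sketch follows the textbook definitions of categorical cross-entropy and categorical accuracy. It is an illustration with made-up predictions, not the Keras implementation, and both function names are hypothetical.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over the batch of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))


def categorical_accuracy(y_true, y_pred):
    """Fraction of rows whose predicted class matches the one-hot label."""
    return float(np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)))


# Two made-up examples with 5 classes each.
y_true = np.array([[0, 0, 1, 0, 0],
                   [1, 0, 0, 0, 0]], dtype=np.float64)
y_pred = np.array([[0.1, 0.1, 0.6, 0.1, 0.1],   # correct: class 2
                   [0.2, 0.5, 0.1, 0.1, 0.1]])  # wrong: predicts class 1

print(categorical_crossentropy(y_true, y_pred))
print(categorical_accuracy(y_true, y_pred))  # 0.5
```

If the labels were stored as integer class ids instead of one-hot vectors, the sparse variants of the loss and metric would be the matching choice.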

## 5. Train the model

In [None]:
training_results = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=3
)

## 6. Save the model

In [None]:
model.save('sentimental_bert')