# Fine-tuning a model with Keras

In [None]:
import numpy as np
import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer

# Checkpoint name used to load the tokenizer here (and the model later on).
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# The "mrpc" configuration of the "glue" dataset; it is indexed by split name
# further down.
raw_datasets = load_dataset("glue", "mrpc")

def tokenize_dataset(dataset):
    """Tokenize the sentence pairs of one dataset split.

    Args:
        dataset: Object exposing "sentence1" and "sentence2" columns via
            item access.

    Returns:
        The ``data`` attribute of the tokenizer output. With
        ``return_tensors='np'`` this is presumably a plain dict of NumPy
        arrays (per the transformers ``BatchEncoding`` API).
    """
    sentence_pair = (dataset["sentence1"], dataset["sentence2"])
    # Pad and truncate so the whole split can be returned as NumPy arrays.
    tokenizer_options = {
        "padding": True,
        "truncation": True,
        "return_tensors": 'np',
    }
    batch = tokenizer(*sentence_pair, **tokenizer_options)
    return batch.data

# Tokenize every split once, keyed by the same split names as raw_datasets.
tokenized_datasets = {
    split_name: tokenize_dataset(split_data)
    for split_name, split_data in raw_datasets.items()
}

In [None]:
from transformers import TFAutoModelForSequenceClassification

# Load the checkpoint with a sequence-classification head configured for
# two labels.
model = TFAutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=2,
)

In [None]:
# The predictions are read from the model's 'logits' output, i.e. raw scores
# with no softmax applied. The string shortcut 'sparse_categorical_crossentropy'
# makes Keras treat them as probabilities, so the loss has to be built
# explicitly with from_logits=True.
# NOTE(review): 'adam' uses Keras' default learning rate, which is likely too
# high for fine-tuning a pretrained model - consider passing an Adam instance
# with a smaller learning rate.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train on the tokenized training split; labels come from the raw dataset and
# are converted to NumPy arrays to match the tokenized inputs.
model.fit(
    tokenized_datasets['train'],
    np.array(raw_datasets['train']['label']),
    validation_data=(
        tokenized_datasets['validation'],
        np.array(raw_datasets['validation']['label']),
    ),
    batch_size=8,
)

In [None]:
# Run the model on the validation split and keep only the 'logits' entry of
# its output.
validation_outputs = model.predict(tokenized_datasets['validation'])
preds = validation_outputs['logits']

In [None]:
# For each row of preds, take the index of the highest score along the last
# axis (the class axis of the 2-D logits array).
class_preds = np.argmax(preds, axis=-1)
print(f"{preds.shape} {class_preds.shape}")

(408, 2) (408,)

In [None]:
from datasets import load_metric

# Load the metric registered for the "glue"/"mrpc" configuration and score
# the predicted classes against the validation labels.
metric = load_metric("glue", "mrpc")
metric.compute(
    predictions=class_preds,
    references=raw_datasets['validation']['label'],
)

{'accuracy': 0.8578431372549019, 'f1': 0.8996539792387542}

In [None]:
class F1_metric(tf.keras.metrics.Metric):
    """Keras metric that reports the F1 score of a two-class classifier.

    The score is derived from a ``Precision`` and a ``Recall`` metric: every
    update is forwarded to both, and ``result`` combines their current
    values into their harmonic mean.
    """

    def __init__(self, name='f1_score', **kwargs):
        """Create the metric along with the two metrics it is based on.

        Args:
            name: Name under which the metric is reported.
            **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
        """
        super().__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch of labels and model outputs.

        Args:
            y_true: Integer class labels.
            y_pred: Per-class scores (logits) with the classes on the last
                axis.
            sample_weight: Optional weights, forwarded to both metrics.
        """
        # Precision and Recall compare binary predictions against the
        # labels, so the per-class scores are first reduced to a class id.
        class_preds = tf.math.argmax(y_pred, axis=-1)
        self.precision.update_state(y_true, class_preds, sample_weight)
        self.recall.update_state(y_true, class_preds, sample_weight)

    def reset_state(self):
        """Clear the accumulated state of both underlying metrics."""
        self.precision.reset_state()
        self.recall.reset_state()

    def result(self):
        """Return the harmonic mean of the current precision and recall."""
        precision = self.precision.result()
        recall = self.recall.result()
        # Equivalent to 2 / (1/p + 1/r); divide_no_nan yields 0 when both
        # precision and recall are 0 instead of relying on inf arithmetic.
        return tf.math.divide_no_nan(2 * precision * recall, precision + recall)

In [None]:
# Load a fresh copy of the checkpoint and compile it with the custom F1
# metric next to accuracy.
model = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
# The predictions are read from the model's 'logits' output, i.e. raw scores
# with no softmax applied, so the loss must be created with from_logits=True;
# the string shortcut would treat the logits as probabilities and report a
# wrong loss.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy', F1_metric()],
)

# Report the loss and the compiled metrics on the validation split.
model.evaluate(
    tokenized_datasets['validation'],
    np.array(raw_datasets['validation']['label']),
)