# A token classification model

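__Objective:__ build a custom token classification model made of a pretrained transformer body and a custom classification head. The source builds it on an XLM-R body; this notebook uses DistilBERT (`distilbert-base-uncased`) instead.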

__Source:__ [here](https://www.oreilly.com/library/view/natural-language-processing/9781098136789/)

In [None]:
import tensorflow as tf
from transformers import AutoTokenizer

## Load tokenizer

In [None]:
# Checkpoint name, used for the tokenizer here and for the model loaded further down.
distilbert_model_name = 'distilbert-base-uncased'
distilbert_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=distilbert_model_name
)

Test the tokenizer.

In [None]:
# A small batch of sentences with different lengths.
test_text = [
    "Splinter taught them to be ninja teens (He's a radical rat!)",
    "Leonardo leads, Donatello does machines (That's a fact, Jack!)",
    "Raphael is cool but crude (Gimme a break!)",
    "Michaelangelo is a party dude (Party!)",
    "Teenage Mutant Ninja Turtles"
]

# padding=True makes every sequence in the batch the same length, so the ids
# can later be stacked into a single rectangular tensor.
test_tokens = distilbert_tokenizer(test_text, padding=True)

test_tokens

## Build model as a subclass of a pretrained model

In [None]:
from transformers import DistilBertConfig
from transformers.models.distilbert.modeling_tf_distilbert import TFDistilBertModel, TFDistilBertPreTrainedModel
from tensorflow.keras.layers import Dropout, Dense

In [None]:
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
    """DistilBERT body with a dropout + dense softmax head for token classification.

    For every input token the model outputs a probability distribution over
    ``config.num_labels`` classes.
    """
    def __init__(self, config, *inputs, **kwargs):
        """Build the DistilBERT body and the token classification head.

        Args:
            config: DistilBERT configuration; ``config.num_labels`` sets the
                number of output classes per token.
            *inputs, **kwargs: forwarded to ``TFDistilBertPreTrainedModel``
                (e.g. a Keras ``name``).
        """
        # Local import: the main layer is not imported elsewhere in this notebook.
        from transformers.models.distilbert.modeling_tf_distilbert import TFDistilBertMainLayer

        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        # Model body. from_pretrained matches checkpoint weights by layer name,
        # and DistilBERT checkpoints store the body under "distilbert". A nested
        # TFDistilBertModel would get the auto-generated name
        # "tf_distil_bert_model" and its pretrained weights would not be loaded,
        # so use the main layer with an explicit name (the same layout as
        # Hugging Face's own TFDistilBertForTokenClassification).
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")

        # Token classification head (freshly initialized): dropout on the
        # hidden states, then a per-token dense layer with softmax.
        self.dropout = Dropout(0.1, name="dropout")
        self.classifier = Dense(
            units=config.num_labels,
            activation='softmax',
            name="classifier",
        )

    def call(
        self,
        input_ids,
        **kwargs
    ):
        """Return per-token class probabilities.

        Args:
            input_ids: token ids, e.g. the padded ``input_ids`` from the tokenizer.
            **kwargs: forwarded unchanged to the DistilBERT body
                (e.g. ``attention_mask``).

        Returns:
            Softmax probabilities of shape ``(batch_size, seq_len, num_labels)``.
        """
        # Use DistilBERT to get the hidden states.
        outputs = self.distilbert(
            input_ids=input_ids,
            **kwargs
        )

        # outputs[0] is the last hidden state; apply the classification head.
        sequence_output = self.dropout(outputs[0])
        # Softmax is applied inside the Dense layer, so these are
        # probabilities, not logits.
        probs = self.classifier(sequence_output)

        return probs

In [None]:
# Pretrained DistilBERT body; the classification head is initialized from scratch.
token_classification_model = TFDistilBertForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=distilbert_model_name
)

Test generating predictions. Output shape: `(batch_size, seq_len, num_labels)`, where `num_labels` is read from the pretrained model's config object (for the `distilbert-base-uncased` checkpoint, there are 2 classes).

In [None]:
# Stack the padded token ids into a tensor and run a forward pass.
test_input_ids = tf.convert_to_tensor(test_tokens['input_ids'])
pred = token_classification_model(input_ids=test_input_ids)
pred

Check: since the head uses a softmax activation, summing over the last dimension should give 1 for every token (normalized output probabilities, up to floating-point error).

In [None]:
# One sum per token: expected to be (numerically) 1 everywhere.
tf.reduce_sum(pred, axis=-1)

## As a pure TensorFlow model

As an experiment, let's redo the same with less of Hugging Face Transformers' machinery: this time the pretrained body and the classification head are wrapped in a plain Keras `Layer`.

In [None]:
from tensorflow.keras.layers import Layer
from tensorflow.keras import Input, Model
from transformers import TFAutoModel

In [None]:
class TFDistilBertForTokenClassificationLayer(Layer):
    """Keras layer: pretrained transformer body plus a token classification head.

    Args:
        config: dict with keys ``'body_dropout_rate'`` (dropout rate applied to
            the body's hidden states) and ``'num_classes'`` (number of classes
            of the per-token softmax output).
        model_ckpt: checkpoint name or path loaded with
            ``TFAutoModel.from_pretrained``.
        **kwargs: standard Keras ``Layer`` arguments (e.g. ``name``), forwarded
            to ``Layer.__init__``.
    """
    def __init__(self, config, model_ckpt='distilbert-base-uncased', **kwargs):
        """Load the pretrained body and create the classification head."""
        super().__init__(**kwargs)

        # Model body (pretrained).
        self.distilbert = TFAutoModel.from_pretrained(model_ckpt)

        # Model head (freshly initialized): dropout + per-token softmax classifier.
        self.dropout = Dropout(config['body_dropout_rate'])
        self.classification = Dense(units=config['num_classes'], activation='softmax')

    def call(self, input_ids):
        """Return per-token class probabilities for a batch of token ids.

        The output has one softmax vector of length ``num_classes`` per token.
        """
        # Element 0 of the body's output is the last hidden state.
        hidden_states = self.distilbert(input_ids=input_ids)[0]

        x = self.dropout(hidden_states)
        x = self.classification(x)

        return x

In [None]:
# Hyperparameters of the token classification head.
config = dict(
    body_dropout_rate=0.1,  # dropout applied to the body's hidden states
    num_classes=10,         # number of classes per token
)

In [None]:
# Uses the layer's default checkpoint ('distilbert-base-uncased').
token_classification_layer = TFDistilBertForTokenClassificationLayer(config)

In [None]:
# Eager forward pass on the test batch: one softmax vector of length
# config['num_classes'] per token.
token_classification_layer(test_input_ids)

Build a Keras `Model` object and train it on fake data.

In [None]:
# Symbolic batch of token ids. The sequence dimension is left as None so the
# model accepts batches padded to any length, not only the padded length of
# test_input_ids.
inputs = Input(shape=(None,), dtype=tf.int32)
outputs = token_classification_layer(inputs)

token_classification_model_2 = Model(
    inputs=inputs,
    outputs=outputs
)

In [None]:
# Mean squared error between the softmax outputs and the targets, Adam optimizer.
token_classification_model_2.compile(loss='mse', optimizer='adam')

In [None]:
# Fake targets shaped like the model output, (batch_size, seq_len, num_classes).
# The shape is derived from the input shape and the head size rather than by
# running an extra DistilBERT forward pass just to read the output shape.
# The uniform values are not normalized distributions; this only exercises
# the training loop.
fake_targets = tf.random.uniform(
    shape=test_input_ids.shape.as_list() + [config['num_classes']]
)

token_classification_model_2.fit(
    x=test_input_ids,
    y=fake_targets,
    epochs=1
)