# BiLSTM-CRF with keras-crf

This notebook builds a Bidirectional LSTM + CRF model for sequence labeling using the standalone `keras_crf` package.
It trains on a small synthetic dataset and evaluates decoding accuracy.

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from keras_crf import CRF, text as kcrf
print(tf.__version__, keras.__version__)

2025-08-22 15:24:37.532918: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-22 15:24:37.547404: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-22 15:24:37.679181: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-22 15:24:37.821484: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755833077.936308  454677 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755833077.96

2.18.1 3.10.0


## Create a synthetic tagging dataset
We create sequences of token IDs. Tags depend on simple rules over tokens so the model has a learnable pattern.

In [5]:
def make_dataset(num_samples=1000, seq_len=20, vocab_size=100, num_tags=3, seed=42):
    rng = np.random.default_rng(seed)
    X = rng.integers(low=1, high=vocab_size, size=(num_samples, seq_len), dtype=np.int32)
    # Tag rule example:
    # tag 2 if token%10 in {7,8,9}, tag 1 if token%10 in {3,4,5,6}, else tag 0
    Y = np.zeros((num_samples, seq_len), dtype=np.int32)
    mod = X % 10
    Y[mod >= 7] = 2
    Y[(mod >= 3) & (mod <= 6)] = 1
    # Add some noise
    flip_idx = rng.random((num_samples, seq_len)) < 0.05
    Y[flip_idx] = rng.integers(0, num_tags, size=np.count_nonzero(flip_idx))
    return X, Y

num_tags = 3
vocab_size = 200
seq_len = 30
X_train, Y_train = make_dataset(2000, seq_len, vocab_size, num_tags, seed=1)
X_val, Y_val     = make_dataset(400, seq_len, vocab_size, num_tags, seed=2)
X_test, Y_test   = make_dataset(400, seq_len, vocab_size, num_tags, seed=3)
X_train.shape, Y_train.shape

((2000, 30), (2000, 30))

## Build BiLSTM-CRF
We use an Embedding followed by a bidirectional LSTM (returning sequences). The CRF layer consumes the LSTM features and decodes/learns transitions.

In [8]:
embedding_dim = 64
lstm_units = 64

class BiLstmCrfModel(keras.Model):
    def __init__(self, vocab_size, num_tags, embedding_dim=64, lstm_units=64):
        super().__init__()
        self.embedding = layers.Embedding(input_dim=vocab_size+1, output_dim=embedding_dim, mask_zero=True)
        self.bilstm = layers.Bidirectional(layers.LSTM(lstm_units, return_sequences=True))
        # CRF will apply an internal Dense to project to num_tags (use_kernel=True by default)
        self.crf = CRF(units=num_tags)

    def call(self, inputs, training=False):
        x = self.embedding(inputs)
        x = self.bilstm(x)
        decoded, potentials, seq_len, kernel = self.crf(x, mask=self.embedding.compute_mask(inputs))
        return decoded, potentials, seq_len, kernel

class ModelWithCRFLoss(keras.Model):
    def __init__(self, core):
        super().__init__()
        self.core = core

    def call(self, inputs, training=False):
        return self.core(inputs, training=training)

    def train_step(self, data):
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
        with tf.GradientTape() as tape:
            decoded, potentials, seq_len, kernel = self(x, training=True)
            ll, _ = kcrf.crf_log_likelihood(potentials, y, seq_len, kernel)
            loss = -tf.reduce_mean(ll)
            if sample_weight is not None:
                sw = tf.cast(sample_weight, loss.dtype)
                # broadcast to batch
                if sw.shape.rank == 0:
                    sw = tf.fill(tf.shape(ll), sw)
                loss = tf.reduce_mean(sw * (-ll))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

    def test_step(self, data):
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
        decoded, potentials, seq_len, kernel = self(x, training=False)
        ll, _ = kcrf.crf_log_likelihood(potentials, y, seq_len, kernel)
        loss = -tf.reduce_mean(ll)
        if sample_weight is not None:
            sw = tf.cast(sample_weight, loss.dtype)
            if sw.shape.rank == 0:
                sw = tf.fill(tf.shape(ll), sw)
            loss = tf.reduce_mean(sw * (-ll))
        return {"loss": loss}

core = BiLstmCrfModel(vocab_size=vocab_size, num_tags=num_tags, embedding_dim=embedding_dim, lstm_units=lstm_units)
model = ModelWithCRFLoss(core)
model.compile(optimizer=keras.optimizers.Adam(1e-3))
model

<ModelWithCRFLoss name=model_with_crf_loss_2, built=False>

## Train

In [9]:
history = model.fit(X_train, Y_train, validation_data=(X_val, Y_val), epochs=5, batch_size=64)
history.history

Epoch 1/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 51ms/step - loss: 30.5184 - val_loss: 0.0000e+00
Epoch 2/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 38ms/step - loss: 11.5061 - val_loss: 0.0000e+00
Epoch 3/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 39ms/step - loss: 5.1077 - val_loss: 0.0000e+00
Epoch 4/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 36ms/step - loss: 4.9165 - val_loss: 0.0000e+00
Epoch 5/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 36ms/step - loss: 4.8703 - val_loss: 0.0000e+00


{'loss': [0.0, 0.0, 0.0, 0.0, 0.0], 'val_loss': [0.0, 0.0, 0.0, 0.0, 0.0]}

## Evaluate decoding accuracy
We get decoded tags from the model (first output) and compare to ground-truth.

In [10]:
decoded_test, potentials_test, seq_len_test, kernel_test = model.predict(X_test, batch_size=64, verbose=0)
acc = np.mean((decoded_test == Y_test).astype(np.float32))
print(f'Decode accuracy: {acc:.4f}')

Decode accuracy: 0.9670


## Inspect a sample

In [11]:
i = 0
print('Tokens:', X_test[i])
print('True  :', Y_test[i])
print('Pred  :', decoded_test[i])

Tokens: [162  18  36  48  37 160 173 116   8  19  67  87 124  96  53  32 138 147
   7  23  90  78 177 103  84  86 133 117  35 147]
True  : [0 2 1 2 2 0 1 1 2 0 2 2 1 1 1 0 2 2 2 1 0 2 2 1 1 1 1 2 1 2]
Pred  : [0 2 1 2 2 0 1 1 2 2 2 2 1 1 1 0 2 2 2 1 0 2 2 1 1 1 1 2 1 2]
