# Keras 3 CRF on Torch or JAX

This notebook runs the CRF layer, which uses Keras 3 backend-agnostic ops (`keras.ops`), on the Torch and JAX backends instead of TensorFlow.

Setup:
- Choose a backend before importing Keras: set the environment variable `KERAS_BACKEND=torch` or `KERAS_BACKEND=jax`.
- Install backend packages if needed (CPU by default):
  - Torch: `pip install ".[torch]"` (or follow https://pytorch.org/get-started/locally/ for CUDA)
  - JAX: `pip install ".[jax]"` (or follow https://jax.readthedocs.io/en/latest/installation.html for CUDA)

Note: Keras reads `KERAS_BACKEND` once, when it is first imported. If you change the variable within this notebook, restart the kernel and set it before importing Keras.
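
The ordering constraint above can be made explicit with a small guard. This is a minimal sketch using only the standard library; `select_backend` and its `loaded_modules` parameter are hypothetical names introduced here for illustration, and the restriction to `torch` and `jax` reflects this notebook rather than the full set of backends Keras accepts.

```python
import os
import sys

def select_backend(name, loaded_modules=None):
    """Set KERAS_BACKEND, refusing if Keras has already been imported.

    `loaded_modules` defaults to sys.modules; it is a parameter only so
    that the check can be exercised without importing Keras.
    """
    if loaded_modules is None:
        loaded_modules = sys.modules
    if name not in ("torch", "jax"):
        raise ValueError(f"this notebook covers 'torch' and 'jax', got {name!r}")
    if "keras" in loaded_modules:
        # The backend was already fixed at import time; changing the
        # environment variable now would have no effect.
        raise RuntimeError("Keras is already imported; restart the kernel first.")
    os.environ["KERAS_BACKEND"] = name
    return name
```

Calling `select_backend("jax")` in the first cell of a fresh kernel, before any `import keras`, sets the variable; calling it after the import raises instead of silently doing nothing.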

In [None]:
# Optional: set backend in-code (prefer setting the environment before starting the kernel)
import os
# os.environ["KERAS_BACKEND"] = "torch"  # or "jax"
print("KERAS_BACKEND=", os.environ.get("KERAS_BACKEND", "<not set>"))

In [None]:
import numpy as np
import keras
from keras import layers, ops as K
from keras_crf import CRF

# Build a simple model: Embedding -> BiLSTM -> CRF
vocab_size = 100
num_tags = 5
max_len = 10

inputs = keras.Input(shape=(max_len,), dtype="int32")
emb = layers.Embedding(vocab_size + 1, 16, mask_zero=True)(inputs)
seq = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(emb)
crf = CRF(units=num_tags)
decoded, potentials, seq_len, kernel = crf(seq)
model = keras.Model(inputs, [decoded, potentials, seq_len, kernel])
model.summary()
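
The `mask_zero=True` embedding above treats token id 0 as padding, and the CRF layer returns a `seq_len` output. The following is a minimal NumPy sketch of how per-sequence lengths can be derived from right-padded ids. It assumes, without confirmation from this notebook, that `seq_len` counts the unmasked steps; `lengths_from_ids` is a hypothetical helper, not part of `keras_crf`.

```python
import numpy as np

def lengths_from_ids(token_ids, pad_id=0):
    """Count non-padding positions in each row of a [batch, time] id array."""
    return (np.asarray(token_ids) != pad_id).sum(axis=1).astype("int32")

padded = np.array([[7, 3, 9, 0, 0],
                   [4, 0, 0, 0, 0],
                   [1, 2, 3, 4, 5]])
print(lengths_from_ids(padded))  # [3 1 5]
```

The dummy batch in the next cell draws ids from 1 upward, so it contains no padding and every sequence there has the full length `max_len`.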

In [None]:
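# Dummy data: token ids start at 1 because 0 is reserved for padding (mask_zero=True)
X = np.random.randint(1, vocab_size, (4, max_len)).astype("int32")
# Random gold tags, one per time step
Y = np.random.randint(0, num_tags, (4, max_len)).astype("int32")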

decoded_out, potentials_out, seq_len_out, kernel_out = model(X)
print("decoded shape:", K.shape(decoded_out))
print("potentials shape:", K.shape(potentials_out))
print("seq_len shape:", K.shape(seq_len_out))

# Compute backend-agnostic CRF log-likelihood
ll = crf.log_likelihood(potentials_out, Y, seq_len_out)
print("mean negative log-likelihood:", float(K.mean(-ll)))
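
For reference, the quantity a linear-chain CRF log-likelihood measures can be written in plain NumPy. This is a sketch under stated assumptions, not the `keras_crf` implementation: it assumes `transitions[i, j]` scores moving from tag `i` to tag `j`, uses no start or end boundary terms, and requires every sequence to have at least one valid step. `crf_log_likelihood` is a hypothetical name introduced here.

```python
import numpy as np

def crf_log_likelihood(potentials, tags, seq_len, transitions):
    """Per-sequence log p(tags | potentials) for a linear-chain CRF.

    potentials:  [batch, time, num_tags] unary scores
    tags:        [batch, time] integer tag ids
    seq_len:     [batch] number of valid steps per sequence (>= 1)
    transitions: [num_tags, num_tags]; transitions[i, j] scores tag i -> tag j
    """
    batch = potentials.shape[0]
    out = np.empty(batch)
    for b in range(batch):
        n = int(seq_len[b])
        unary, y = potentials[b, :n], tags[b, :n]
        # Score of the given tag path: unary terms plus transition terms.
        path = unary[np.arange(n), y].sum() + transitions[y[:-1], y[1:]].sum()
        # Forward algorithm in log space for the partition function log Z.
        alpha = unary[0]
        for t in range(1, n):
            # scores[i, j]: best-so-far mass ending in tag i, then i -> j at step t.
            scores = alpha[:, None] + transitions + unary[t][None, :]
            m = scores.max(axis=0)
            alpha = m + np.log(np.exp(scores - m).sum(axis=0))
        m = alpha.max()
        log_z = m + np.log(np.exp(alpha - m).sum())
        out[b] = path - log_z
    return out
```

If the layer follows the same parameterization, converting `potentials_out`, `seq_len_out`, and `kernel_out` to NumPy arrays and taking `-crf_log_likelihood(...).mean()` should give a value comparable to the one printed above. A quick sanity check for any such function is that the probabilities `exp(ll)` summed over all possible tag sequences equal 1.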