# BERT Masked Language Modeling (Pure JAX)

Ce notebook présente une implémentation pédagogique de BERT en JAX pur, sans bibliothèque de haut niveau pour la modélisation (comme Flax ou Haiku) ; la bibliothèque Hugging Face `transformers` ne sert qu'au tokenizer et à la récupération des poids pré-entraînés. L'objectif est de comprendre en détail la construction du modèle, le chargement de poids pré-entraînés et l'entraînement par modélisation de langage masqué (MLM).

In [1]:
import os
import time
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from datasets import load_dataset
from transformers import BertForMaskedLM

# List the devices JAX can see (CPU/GPU/TPU).
try:
    print("Devices:", jax.devices())
except RuntimeError:
    # jax.devices() reports backend-initialisation failures as RuntimeError.
    # Catching only that keeps KeyboardInterrupt/SystemExit and genuine bugs
    # visible (the previous bare `except:` swallowed everything).
    print("No generic JAX devices found.")

# Section 1: Présentation et Préparation des Données

Nous utilisons le tokenizer de `bert-base-uncased` et préparons quelques exemples de texte. Pour ce notebook, nous simulons un petit jeu de données en dupliquant quelques phrases, plutôt que de charger un vrai corpus.

In [2]:
def prepare_data_demo(max_len=128, model_name="bert-base-uncased"):
    """Build the toy MLM dataset used by this notebook.

    Loads the tokenizer ``model_name``, prints the tokenisation of one example
    sentence, then tokenises a small corpus (four sentences repeated 100 times).

    Args:
        max_len: every sequence is padded / truncated to this length. It should
            not exceed the model's number of position embeddings
            (``max_position_embeddings`` in BERT_CONFIG).
        model_name: Hugging Face checkpoint whose tokenizer is loaded.

    Returns:
        (input_ids, tokenizer): ``input_ids`` is a NumPy array of shape
        (num_texts, max_len) (padding="max_length", return_tensors="np").
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Example texts, duplicated to get a small dataset.
    texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "JAX is a high-performance numerical computing library.",
        "BERT uses a transformer architecture to understand language."
    ] * 100

    print("Exemple de tokenisation :")
    encoded = tokenizer(texts[0])
    print(f"Texte: {texts[0]}")
    print(f"IDs: {encoded['input_ids']}")
    print(f"Tokens: {tokenizer.convert_ids_to_tokens(encoded['input_ids'])}")

    # Batch tokenisation to a fixed length.
    tokenized = tokenizer(
        texts,
        max_length=max_len,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    return tokenized["input_ids"], tokenizer


input_ids, tokenizer = prepare_data_demo()

# Section 2: Construction du Modèle BERT en Pure JAX

Nous allons définir couche par couche les composants de BERT : Linear, LayerNorm, Embeddings, Attention, et enfin l'Encodeur complet.

In [3]:
### Modules de Base (Linear, Embedding, LayerNorm, Dropout, MLP)


def Linear(in_features, out_features):
    """Dense layer computing ``x @ weight + bias``.

    Returns an ``(init, apply)`` pair:
      * ``init(rkey)`` -> ``{"weight": (in_features, out_features), "bias": (out_features,)}``
        with Glorot-uniform weights and a zero bias;
      * ``apply(params, x)`` -> ``jnp.dot(x, weight) + bias``.
    """
    def model_init(rkey):
        # Only the first half of the split is consumed: the bias is zero-initialised.
        weight_key, _ = jr.split(rkey)
        bound = jnp.sqrt(6.0 / (in_features + out_features))
        return {
            "weight": jr.uniform(
                weight_key, (in_features, out_features), minval=-bound, maxval=bound
            ),
            "bias": jnp.zeros((out_features,)),
        }

    def model_apply(params, x):
        w, b = params["weight"], params["bias"]
        return jnp.dot(x, w) + b

    return model_init, model_apply

def Embedding(num_embeddings, embedding_size):
    """Learned lookup table from integer ids to vectors.

    Returns an ``(init, apply)`` pair:
      * ``init(rkey)`` -> ``{"weight": (num_embeddings, embedding_size)}`` drawn from
        a normal distribution scaled by 0.02;
      * ``apply(params, ids)`` -> ``weight[ids]`` (shape ``ids.shape + (embedding_size,)``).
    """
    init_std = 0.02

    def model_init(rkey):
        table = jr.normal(rkey, (num_embeddings, embedding_size)) * init_std
        return {"weight": table}

    def model_apply(params, x):
        table = params["weight"]
        return table[x]

    return model_init, model_apply

def LayerNormalization(dummy_input, axis=-1, epsilon=1e-5):
    """Layer normalisation over ``axis`` with a learned per-feature scale and shift.

    Args:
        dummy_input: array whose shape is used only to size the parameters.
        axis: int or iterable of ints; mean and variance are taken over these axes.
        epsilon: added to the variance for numerical stability.

    Returns:
        ``(init, apply)``. ``init(rkey)`` returns ``{"gamma": ones, "beta": zeros}``
        whose shape keeps the size of the normalised axes and is 1 on every other
        axis, so the parameters broadcast over the non-normalised dimensions.
        ``apply(params, x)`` returns ``gamma * (x - mean) / sqrt(var + eps) + beta``.
    """
    shape = tuple(dummy_input.shape)
    ndim = len(shape)
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    normalised = {a % ndim for a in axes}
    # Bug fix: gamma/beta must span the normalised axes (one value per feature).
    # The previous code set those axes to 1, so a (1, hidden) dummy produced a
    # single scalar gamma/beta instead of `hidden` of them.
    param_shape = tuple(size if i in normalised else 1 for i, size in enumerate(shape))

    def model_init(rkey):
        return {"gamma": jnp.ones(param_shape), "beta": jnp.zeros(param_shape)}

    def model_apply(params, x):
        # `axes` is always a tuple, so list-valued `axis` arguments also work here.
        mean = jnp.mean(x, axis=axes, keepdims=True)
        var = jnp.var(x, axis=axes, keepdims=True)
        x_hat = (x - mean) / jnp.sqrt(var + epsilon)
        return params["gamma"] * x_hat + params["beta"]

    return model_init, model_apply

def Dropout(rate=0.5):
    """Inverted dropout.

    Returns an ``(init, apply)`` pair. ``init`` returns an empty parameter dict.
    ``apply(params, x, inference=False, rkey=None)`` zeroes each element with
    probability ``rate`` and rescales the survivors by ``1 / (1 - rate)``.
    It is the identity when ``inference`` is true, when ``rate == 0``, or when no
    ``rkey`` is supplied (deliberate fallback).
    """
    keep_prob = 1.0 - rate

    def model_init(key):
        return {}

    def model_apply(params, x, inference=False, rkey=None):
        if inference or rate == 0.0:
            return x
        if rkey is None:
            return x  # Safety fallback: no key, no randomness.
        if keep_prob <= 0.0:
            # rate >= 1 drops everything; the rescaling below would divide by 0
            # and yield NaN (the previous behaviour).
            return jnp.zeros_like(x)
        mask = jax.random.bernoulli(rkey, keep_prob, x.shape)
        return jnp.where(mask, x / keep_prob, 0.0)

    return model_init, model_apply

def MLP(layer_sizes, activation_name="relu"):
    """Fully connected network with an activation between layers (none after the last).

    Args:
        layer_sizes: ``[in, hidden_1, ..., out]``; one ``(W, b)`` pair is created per
            consecutive pair of sizes.
        activation_name: ``"relu"``, ``"gelu"`` (exact, erf-based GELU) or
            ``"gelu_approximate"`` (tanh approximation).

    Returns:
        ``(init, apply)``. ``init(rkey)`` returns a list of ``(W, b)`` tuples with
        Glorot-uniform ``W`` and zero ``b``; ``apply(params, x)`` runs ``x`` through
        every layer.

    Raises:
        ValueError: for an unknown ``activation_name`` (these previously fell back
            to ReLU silently, hiding typos).
    """
    activations = {
        "relu": lambda x: jnp.maximum(0, x),
        # "gelu" is the exact GELU. It used to be the tanh approximation,
        # i.e. indistinguishable from "gelu_approximate".
        "gelu": lambda x: jax.nn.gelu(x, approximate=False),
        "gelu_approximate": lambda x: jax.nn.gelu(x, approximate=True),
    }
    try:
        activation = activations[activation_name]
    except KeyError:
        raise ValueError(
            f"Unknown activation_name {activation_name!r}; "
            f"expected one of {sorted(activations)}"
        ) from None

    def linear_init(key, in_dim, out_dim):
        k_w, _ = jr.split(key)  # bias is zero-initialised; only the weight key is used
        lim = jnp.sqrt(6.0 / (in_dim + out_dim))
        return (jr.uniform(k_w, (in_dim, out_dim), minval=-lim, maxval=lim), jnp.zeros((out_dim,)))

    def model_init(rkey):
        keys = jr.split(rkey, len(layer_sizes))
        # zip stops at the shortest input: one layer per consecutive size pair.
        return [
            linear_init(k, d_in, d_out)
            for k, d_in, d_out in zip(keys, layer_sizes[:-1], layer_sizes[1:])
        ]

    def model_apply(params, inp):
        *hidden_layers, (final_w, final_b) = params
        out = inp
        for w, b in hidden_layers:
            out = activation(jnp.dot(out, w) + b)
        return jnp.dot(out, final_w) + final_b

    return model_init, model_apply

In [4]:
### Multi-Head Attention

def MultiHeadAttention(hidden_size, num_heads, dropout_rate=0.1):
    """Multi-head scaled dot-product attention for a single (unbatched) sequence.

    Returns an ``(init, apply)`` pair.

    ``init(rkey)`` returns ``{"query", "key", "value", "output"}``, each the
    parameters of a ``Linear(hidden_size, hidden_size)``.

    ``apply(params, query, key_, value, mask=None, inference=False, rkey=None)``:
      * ``query`` has shape (q_len, hidden_size); ``key_`` and ``value`` share a
        length kv_len, which may differ from q_len;
      * ``mask`` (optional) is added to the scores of shape
        (num_heads, q_len, kv_len) before the softmax, so large negative entries
        block positions;
      * dropout is applied to the attention weights only when ``inference`` is
        false and ``rkey`` is given;
      * the result has shape (q_len, hidden_size).

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """
    if hidden_size % num_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )
    head_dim = hidden_size // num_heads

    q_proj_init, q_proj_apply = Linear(hidden_size, hidden_size)
    k_proj_init, k_proj_apply = Linear(hidden_size, hidden_size)
    v_proj_init, v_proj_apply = Linear(hidden_size, hidden_size)
    out_proj_init, out_proj_apply = Linear(hidden_size, hidden_size)

    def model_init(rkey):
        k_q, k_k, k_v, k_o = jr.split(rkey, 4)
        return {
            "query": q_proj_init(k_q), "key": k_proj_init(k_k),
            "value": v_proj_init(k_v), "output": out_proj_init(k_o)
        }

    def split_heads(x):
        # (length, hidden) -> (num_heads, length, head_dim). Each tensor uses its own
        # length; the previous version reshaped K and V with the *query* length,
        # which failed whenever key/value length differed from query length.
        return jnp.transpose(x.reshape(x.shape[0], num_heads, head_dim), (1, 0, 2))

    def model_apply(params, query, key_, value, mask=None, inference=False, rkey=None):
        q_len = query.shape[0]
        Q = split_heads(q_proj_apply(params["query"], query))
        K = split_heads(k_proj_apply(params["key"], key_))
        V = split_heads(v_proj_apply(params["value"], value))

        scores = jnp.matmul(Q, jnp.transpose(K, (0, 2, 1))) / jnp.sqrt(head_dim)
        if mask is not None:
            scores = scores + mask
        weights = jax.nn.softmax(scores, axis=-1)

        if not inference and rkey is not None and dropout_rate > 0.0:
            keep_prob = 1.0 - dropout_rate
            weights = jax.random.bernoulli(rkey, keep_prob, weights.shape) * weights / keep_prob

        # (num_heads, q_len, head_dim) -> (q_len, hidden_size)
        context = jnp.transpose(jnp.matmul(weights, V), (1, 0, 2)).reshape(q_len, hidden_size)
        return out_proj_apply(params["output"], context)

    return model_init, model_apply

In [5]:
### BERT Blocks (Encoder Layer)

def EmbedderBlock(vocab_size, max_length, type_vocab_size, embedding_size, hidden_size,
                  dropout_rate, layer_norm_eps=1e-12):
    """BERT input embeddings: word + position + segment, then LayerNorm and dropout.

    There is no projection from ``embedding_size`` to ``hidden_size`` (the LayerNorm
    is sized with ``hidden_size``), so the two are expected to be equal.

    ``layer_norm_eps`` defaults to 1e-12, the LayerNorm epsilon of Hugging Face BERT.
    The previous code used LayerNormalization's generic default (1e-5), so a model
    loaded with pretrained weights deviated slightly from the reference.

    ``apply(params, token_ids, position_ids, segment_ids, inference=False, rkey=None)``
    returns the summed, normalised (and, in training, dropped-out) embeddings.
    """
    word_emb_init, word_emb_apply = Embedding(vocab_size, embedding_size)
    pos_emb_init, pos_emb_apply = Embedding(max_length, embedding_size)
    type_emb_init, type_emb_apply = Embedding(type_vocab_size, embedding_size)
    ln_init, ln_apply = LayerNormalization(jnp.ones((1, hidden_size)), axis=-1, epsilon=layer_norm_eps)
    drop_init, drop_apply = Dropout(dropout_rate)

    def model_init(rkey):
        ks = jr.split(rkey, 5)
        return {
            "word": word_emb_init(ks[0]),
            "pos": pos_emb_init(ks[1]),
            "type": type_emb_init(ks[2]),
            "ln": ln_init(ks[3]),
            "drop": drop_init(ks[4]),
        }

    def model_apply(params, token_ids, position_ids, segment_ids, inference=False, rkey=None):
        embeddings = (
            word_emb_apply(params["word"], token_ids)
            + pos_emb_apply(params["pos"], position_ids)
            + type_emb_apply(params["type"], segment_ids)
        )
        normalised = ln_apply(params["ln"], embeddings)
        return drop_apply(params["drop"], normalised, inference=inference, rkey=rkey)

    return model_init, model_apply

def FeedForwardBlock(hidden_size, intermediate_size, dropout_rate, layer_norm_eps=1e-12):
    """Position-wise feed-forward sub-block (post-LN).

    Computes Dense -> GELU -> Dense -> dropout -> + residual -> LayerNorm.

    The parameter tree keeps the MLP layout ``{"mlp": [(W_in, b_in), (W_out, b_out)],
    "drop": {}, "ln": {...}}`` that the weight converter fills in. The forward pass
    uses the exact (erf) GELU, which is what Hugging Face BERT uses for
    ``hidden_act="gelu"``; the previous version used the tanh approximation.
    ``layer_norm_eps`` defaults to HF BERT's 1e-12 (previously 1e-5).
    """
    # MLP is only used to initialise the parameters; its activation is not used here.
    mlp_init, _ = MLP([hidden_size, intermediate_size, hidden_size], activation_name="gelu_approximate")
    ln_init, ln_apply = LayerNormalization(jnp.ones((1, hidden_size)), axis=-1, epsilon=layer_norm_eps)
    drop_init, drop_apply = Dropout(dropout_rate)

    def model_init(rkey):
        k_m, k_d, k_l = jr.split(rkey, 3)
        return {"mlp": mlp_init(k_m), "drop": drop_init(k_d), "ln": ln_init(k_l)}

    def model_apply(params, x, inference=False, rkey=None):
        (w_in, b_in), (w_out, b_out) = params["mlp"]
        hidden = jax.nn.gelu(jnp.dot(x, w_in) + b_in, approximate=False)
        output = jnp.dot(hidden, w_out) + b_out
        output = drop_apply(params["drop"], output, inference=inference, rkey=rkey)
        return ln_apply(params["ln"], output + x)

    return model_init, model_apply

def AttentionBlock(hidden_size, num_heads, dropout_rate, attention_dropout_rate,
                   layer_norm_eps=1e-12):
    """Self-attention sub-block (post-LN).

    Computes self-attention -> dropout -> + residual -> LayerNorm.

    ``layer_norm_eps`` defaults to 1e-12 to match Hugging Face BERT's LayerNorm
    epsilon. Previously LayerNormalization's generic 1e-5 default was used.

    ``apply(params, x, mask=None, inference=False, rkey=None)`` returns an array with
    the shape of ``x``.
    """
    att_init, att_apply = MultiHeadAttention(hidden_size, num_heads, attention_dropout_rate)
    ln_init, ln_apply = LayerNormalization(jnp.ones((1, hidden_size)), axis=-1, epsilon=layer_norm_eps)
    drop_init, drop_apply = Dropout(dropout_rate)

    def model_init(rkey):
        k_a, k_d, k_l = jr.split(rkey, 3)
        return {"att": att_init(k_a), "drop": drop_init(k_d), "ln": ln_init(k_l)}

    def model_apply(params, x, mask=None, inference=False, rkey=None):
        # Independent keys: k1 for attention-weight dropout, k2 for output dropout.
        k1, k2 = (None, None) if rkey is None else jr.split(rkey)
        att_out = att_apply(params["att"], x, x, x, mask=mask, inference=inference, rkey=k1)
        att_out = drop_apply(params["drop"], att_out, inference=inference, rkey=k2)
        return ln_apply(params["ln"], att_out + x)

    return model_init, model_apply

def TransformerLayer(hidden_size, intermediate_size, num_heads, dropout_rate, attention_dropout_rate):
    """One BERT encoder layer: the self-attention sub-block followed by the feed-forward one.

    ``init(rkey)`` -> ``{"att": ..., "ff": ...}``.
    ``apply(params, x, mask=None, inference=False, rkey=None)`` -> transformed ``x``;
    when ``rkey`` is given, it is split into one key per sub-block.
    """
    att_init, att_apply = AttentionBlock(hidden_size, num_heads, dropout_rate, attention_dropout_rate)
    ff_init, ff_apply = FeedForwardBlock(hidden_size, intermediate_size, dropout_rate)

    def model_init(rkey):
        att_key, ff_key = jr.split(rkey)
        return {"att": att_init(att_key), "ff": ff_init(ff_key)}

    def model_apply(params, x, mask=None, inference=False, rkey=None):
        if rkey is None:
            att_key = ff_key = None
        else:
            att_key, ff_key = jr.split(rkey)
        attended = att_apply(params["att"], x, mask=mask, inference=inference, rkey=att_key)
        return ff_apply(params["ff"], attended, inference=inference, rkey=ff_key)

    return model_init, model_apply

In [6]:
### BERT Encoder & MLM Model

def BertEncoder(config):
    """Embedding block followed by ``config["num_hidden_layers"]`` TransformerLayers.

    ``init(rkey)`` -> ``{"embedder": ..., "layers": [layer_params, ...]}``.
    ``apply(params, token_ids, position_ids, segment_ids, inference=False, rkey=None)``
    returns the final hidden states. Token id 0 is treated as padding: those
    positions receive a -1e9 additive mask as attention keys.
    """
    hidden = config["hidden_size"]
    dropout = config["hidden_dropout_prob"]
    emb_init, emb_apply = EmbedderBlock(
        config["vocab_size"], config["max_position_embeddings"], config["type_vocab_size"],
        hidden, hidden, dropout,
    )
    layer_init, layer_apply = TransformerLayer(
        hidden, config["intermediate_size"], config["num_attention_heads"],
        dropout, config["attention_probs_dropout_prob"],
    )

    def model_init(rkey):
        emb_key, layers_key = jr.split(rkey)
        layer_keys = jr.split(layers_key, config["num_hidden_layers"])
        return {
            "embedder": emb_init(emb_key),
            "layers": [layer_init(k) for k in layer_keys],
        }

    def model_apply(params, token_ids, position_ids, segment_ids, inference=False, rkey=None):
        layers = params["layers"]
        if rkey is None:
            emb_key, layer_keys = None, [None] * len(layers)
        else:
            emb_key, layers_key = jr.split(rkey)
            layer_keys = jr.split(layers_key, len(layers))

        x = emb_apply(params["embedder"], token_ids, position_ids, segment_ids,
                      inference=inference, rkey=emb_key)

        # Additive key mask: 0 for real tokens, -1e9 for padding (id 0). The leading
        # axis of size 1 broadcasts over query positions.
        is_real = (token_ids != 0).astype(jnp.float32)
        attention_mask = (1.0 - is_real[None, :]) * -1e9

        for layer_params, layer_key in zip(layers, layer_keys):
            x = layer_apply(layer_params, x, mask=attention_mask, inference=inference, rkey=layer_key)
        return x

    return model_init, model_apply

def BertForMaskedLM_JAX(config):
    """BERT encoder plus masked-language-model head.

    The head computes Dense -> GELU -> LayerNorm -> vocabulary projection.

    ``init(rkey)`` -> ``{"encoder": ..., "mlm": {"transform", "ln", "decoder"}}``.
    ``apply(params, token_ids, position_ids, segment_ids, inference=False, rkey=None)``
    returns one row of ``config["vocab_size"]`` logits per position.

    Fidelity fixes w.r.t. Hugging Face BERT, whose weights this model loads:
      * the head uses the exact (erf) GELU. ``jax.nn.gelu`` defaults to the tanh
        approximation, which the previous code used;
      * the head LayerNorm uses ``config.get("layer_norm_eps", 1e-12)`` instead of
        LayerNormalization's generic 1e-5 default.
    """
    layer_norm_eps = config.get("layer_norm_eps", 1e-12)
    encoder_init, encoder_apply = BertEncoder(config)
    transform_init, transform_apply = Linear(config["hidden_size"], config["hidden_size"])
    ln_init, ln_apply = LayerNormalization(
        jnp.ones((1, config["hidden_size"])), axis=-1, epsilon=layer_norm_eps
    )
    decoder_init, decoder_apply = Linear(config["hidden_size"], config["vocab_size"])

    def model_init(rkey):
        k_enc, k_trans, k_ln, k_dec = jr.split(rkey, 4)
        return {
            "encoder": encoder_init(k_enc),
            "mlm": {
                "transform": transform_init(k_trans),
                "ln": ln_init(k_ln),
                "decoder": decoder_init(k_dec),
            },
        }

    def model_apply(params, token_ids, position_ids, segment_ids, inference=False, rkey=None):
        x = encoder_apply(params["encoder"], token_ids, position_ids, segment_ids,
                          inference=inference, rkey=rkey)
        hidden = jax.nn.gelu(transform_apply(params["mlm"]["transform"], x), approximate=False)
        hidden = ln_apply(params["mlm"]["ln"], hidden)
        return decoder_apply(params["mlm"]["decoder"], hidden)

    return model_init, model_apply

In [7]:
# Architecture hyper-parameters of BERT-base. The weight converter below loads
# `bert-base-uncased` into a model built from this config, so they must agree.
BERT_CONFIG = dict(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    hidden_act="gelu",
    intermediate_size=3072,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
)

In [10]:
# Directory holding the converted checkpoint (created if missing) and the
# pickle path that load_and_convert writes and load_weights reads.
CHECKPOINTS_DIR = "TP_bert"
PRETRAIN_PARAMS_PATH = os.path.join(CHECKPOINTS_DIR, "bert_base_pure_params.pkl")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

In [12]:
def load_and_convert():
    """Convert Hugging Face's PyTorch `bert-base-uncased` MLM weights to this notebook's
    pure-JAX parameter tree and pickle the result into CHECKPOINTS_DIR.

    ``BertForMaskedLM`` here is the Hugging Face (PyTorch) class imported from
    ``transformers``; the JAX model is ``BertForMaskedLM_JAX``.
    """
    print("Loading BERT-Base MaskedLM from Hugging Face (PyTorch)...")
    pt_model = BertForMaskedLM.from_pretrained("bert-base-uncased")

    print("Initializing JAX model...")
    model_init, _ = BertForMaskedLM_JAX(BERT_CONFIG)
    # Build the throw-away random tree on CPU, so the soon-overwritten init does not
    # occupy accelerator memory. Previously only the PRNG key was created inside this
    # context, and model_init itself ran on the default device.
    with jax.default_device(jax.devices("cpu")[0]):
        jax_params = model_init(jr.PRNGKey(0))

    state_dict = pt_model.state_dict()

    def get_pt(name):
        """Return the PyTorch tensor `name` as a NumPy array."""
        return state_dict[name].numpy()

    def dense(name):
        """(weight, bias) of the PyTorch nn.Linear `name`.

        The weight is transposed from PyTorch's (out, in) to the (in, out) layout
        that Linear/MLP multiply with (`x @ W`).
        """
        return jnp.array(get_pt(f"{name}.weight").T), jnp.array(get_pt(f"{name}.bias"))

    def linear_params(name):
        """Linear-style {"weight", "bias"} dict for the PyTorch nn.Linear `name`."""
        weight, bias = dense(name)
        return {"weight": weight, "bias": bias}

    def layer_norm_params(name):
        """LayerNormalization-style {"gamma", "beta"} dict for PyTorch LayerNorm `name`."""
        return {"gamma": jnp.array(get_pt(f"{name}.weight")), "beta": jnp.array(get_pt(f"{name}.bias"))}

    print("Converting weights...")
    # Prefix mapping: JAX "encoder" <-> PyTorch "bert".

    # --- Embeddings ---
    embedder = jax_params["encoder"]["embedder"]
    embedder["word"]["weight"] = jnp.array(get_pt("bert.embeddings.word_embeddings.weight"))
    embedder["pos"]["weight"] = jnp.array(get_pt("bert.embeddings.position_embeddings.weight"))
    embedder["type"]["weight"] = jnp.array(get_pt("bert.embeddings.token_type_embeddings.weight"))
    embedder["ln"] = layer_norm_params("bert.embeddings.LayerNorm")

    # --- Layers --- (one per JAX layer, i.e. BERT_CONFIG["num_hidden_layers"], not a hard-coded 12)
    for i, layer_params in enumerate(jax_params["encoder"]["layers"]):
        prefix = f"bert.encoder.layer.{i}"

        attention = layer_params["att"]
        attention["att"]["query"] = linear_params(f"{prefix}.attention.self.query")
        attention["att"]["key"] = linear_params(f"{prefix}.attention.self.key")
        attention["att"]["value"] = linear_params(f"{prefix}.attention.self.value")
        attention["att"]["output"] = linear_params(f"{prefix}.attention.output.dense")
        attention["ln"] = layer_norm_params(f"{prefix}.attention.output.LayerNorm")

        feed_forward = layer_params["ff"]
        feed_forward["mlp"][0] = dense(f"{prefix}.intermediate.dense")
        feed_forward["mlp"][1] = dense(f"{prefix}.output.dense")
        feed_forward["ln"] = layer_norm_params(f"{prefix}.output.LayerNorm")

    # --- Pooler ---
    # The JAX MLM model has no pooler, so pooler weights (if any) are intentionally
    # not converted. The previous code wrote them to jax_params["encoder"]["pooler"],
    # a key that does not exist (KeyError).
    if "bert.pooler.dense.weight" in state_dict:
        print("Pooler weights found but ignored: the JAX MLM model has no pooler.")

    # --- MLM Head ---
    mlm = jax_params["mlm"]
    mlm["transform"] = linear_params("cls.predictions.transform.dense")
    mlm["ln"] = layer_norm_params("cls.predictions.transform.LayerNorm")

    try:
        decoder_weight = get_pt("cls.predictions.decoder.weight").T
    except KeyError:
        print("Decoder weight not found directly, using embeddings (tied weights).")
        decoder_weight = get_pt("bert.embeddings.word_embeddings.weight").T

    mlm["decoder"] = {
        "weight": jnp.array(decoder_weight),
        "bias": jnp.array(get_pt("cls.predictions.bias")),
    }

    print("Conversion complete.")
    save_path = os.path.join(CHECKPOINTS_DIR, "bert_base_pure_params.pkl")
    print(f"Saving to {save_path}...")
    with open(save_path, "wb") as f:
        pickle.dump(jax_params, f)
    print("Done.")


load_and_convert()

In [13]:
### Chargement des Poids pré-entraînés
def load_weights(path):
    """Read a pickled parameter tree from ``path`` and return it.

    Intended for checkpoints written by this notebook's converter. Security note:
    unpickling can execute arbitrary code, so only load files you created yourself.
    """
    print(f"Chargement depuis {path}...")
    with open(path, "rb") as handle:
        payload = handle.read()
    loaded = pickle.loads(payload)
    print("Poids chargés.")
    return loaded


# Load the converted checkpoint (the pickle written by the conversion cell);
# these parameters are the starting point for fine-tuning below.
params = load_weights(PRETRAIN_PARAMS_PATH)

# Section 3: Entraînement (Fine-tuning MLM)

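Nous définissons la fonction de masquage : environ 15 % des tokens, hors tokens spéciaux (`[PAD]`, `[CLS]`, `[SEP]`), sont remplacés par `[MASK]` (ID 103), et le modèle doit prédire le token original. C'est une simplification de la recette de BERT, où les tokens sélectionnés sont remplacés par `[MASK]` dans 80 % des cas, par un token aléatoire dans 10 % des cas, et laissés inchangés dans les 10 % restants.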

In [15]:
def mask_tokens(inputs, key, mask_prob=0.15, mask_token_id=103,
                special_token_ids=(0, 101, 102), mask_replace_prob=1.0,
                random_replace_prob=0.0, vocab_size=30522):
    """Select and corrupt tokens for masked-language-model training.

    Each non-special token is selected independently with probability
    ``mask_prob``. Selected positions are labelled with their original id; every
    other position is labelled -100 (ignored by the loss). A selected token is
    then replaced by ``mask_token_id`` with probability ``mask_replace_prob``, by
    a uniformly random id in ``[0, vocab_size)`` with probability
    ``random_replace_prob``, and left unchanged otherwise. The defaults
    (1.0 / 0.0) always use [MASK]; the original BERT recipe is 0.8 / 0.1.

    Args:
        inputs: integer token ids (any shape).
        key: JAX PRNG key.
        mask_prob: selection probability per non-special token.
        mask_token_id: id written at selected positions (103 = [MASK] for
            bert-base-uncased).
        special_token_ids: ids that are never selected (defaults: 0=PAD,
            101=CLS, 102=SEP).
        mask_replace_prob, random_replace_prob, vocab_size: see above.

    Returns:
        ``(inputs_masked, labels)``, both with the shape of ``inputs``.
    """
    inputs = jnp.asarray(inputs)
    mask_key, replace_key = jax.random.split(key)
    draw_key, random_token_key = jax.random.split(replace_key)

    special_tokens_mask = jnp.zeros(inputs.shape, dtype=bool)
    for token_id in special_token_ids:
        special_tokens_mask = special_tokens_mask | (inputs == token_id)

    # Bug fix: the previous code set the uniform draw to 0.0 at special positions
    # and then tested `draw < mask_prob`, which selected *every* PAD/CLS/SEP token
    # (0.0 < 0.15). Padding was therefore replaced by [MASK], attended to, and
    # dominated the loss. Special positions are now excluded explicitly.
    selection_draw = jax.random.uniform(mask_key, inputs.shape)
    masked_indices = (selection_draw < mask_prob) & ~special_tokens_mask
    labels = jnp.where(masked_indices, inputs, -100)  # -100 = ignored by the loss

    replace_draw = jax.random.uniform(draw_key, inputs.shape)
    use_mask_token = masked_indices & (replace_draw < mask_replace_prob)
    use_random_token = (
        masked_indices
        & ~use_mask_token
        & (replace_draw < mask_replace_prob + random_replace_prob)
    )
    random_tokens = jax.random.randint(
        random_token_key, inputs.shape, 0, vocab_size, dtype=inputs.dtype
    )

    inputs_masked = jnp.where(use_mask_token, mask_token_id, inputs)
    inputs_masked = jnp.where(use_random_token, random_tokens, inputs_masked)
    return inputs_masked, labels

# Optimizer & model initialisation.
# Only model_apply is used for training (the parameters come from the loaded
# checkpoint); model_init is returned by the constructor and kept alongside it.
model_init, model_apply = BertForMaskedLM_JAX(BERT_CONFIG)
# AdamW at lr 2e-5. NOTE(review): no `mask` is passed, so optax's weight decay
# presumably applies to every parameter, including biases and LayerNorm scales —
# confirm whether that is intended.
optimizer = optax.adamw(learning_rate=2e-5)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, batch_inputs, rkey):
    """Run one AdamW step of masked-language-model fine-tuning.

    Uses the module-level ``model_apply``, ``mask_tokens`` and ``optimizer``.

    Args:
        params: model parameter tree.
        opt_state: state of ``optimizer`` for ``params``.
        batch_inputs: integer token ids of shape (batch, seq_len).
        rkey: PRNG key for this step (drives both masking and dropout).

    Returns:
        ``(new_params, new_opt_state, loss)``. ``loss`` is the mean cross-entropy
        over all masked tokens in the batch.
    """
    batch_size = batch_inputs.shape[0]
    # Separate streams for masking and dropout. Previously the same per-example key
    # drove both, which correlated the two sources of randomness.
    mask_key, dropout_key = jr.split(rkey)
    mask_keys = jr.split(mask_key, batch_size)
    dropout_keys = jr.split(dropout_key, batch_size)

    # Dynamic masking, one key per sequence.
    masked_inputs, labels = jax.vmap(mask_tokens)(batch_inputs, mask_keys)

    def per_example(p, inp, lbl, k):
        """(sum of masked-token losses, number of masked tokens) for one sequence."""
        seq_len = inp.shape[0]
        logits = model_apply(p, inp, jnp.arange(seq_len), jnp.zeros_like(inp),
                             inference=False, rkey=k)
        is_target = lbl != -100
        # -100 is not a valid class index. Substitute 0; those terms are zeroed below.
        safe_labels = jnp.where(is_target, lbl, 0)
        token_loss = optax.softmax_cross_entropy_with_integer_labels(logits, safe_labels)
        weights = is_target.astype(token_loss.dtype)
        return (token_loss * weights).sum(), weights.sum()

    def loss_fn(p):
        sums, counts = jax.vmap(per_example, in_axes=(None, 0, 0, 0))(
            p, masked_inputs, labels, dropout_keys
        )
        # Token-level mean over the whole batch. The previous mean of per-sequence
        # means let sequences with no masked token pull the loss toward 0.
        return sums.sum() / jnp.maximum(counts.sum(), 1.0)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# jax.jit compiles lazily: tracing and compilation happen on the first call to
# train_step, not when it is defined (the previous message claimed otherwise).
print("Fonction d'entraînement définie (compilée par JIT au premier appel).")

In [16]:
# Training loop (demo): 10 optimisation steps on a fixed batch.
print("Démarrage de l'entraînement (Demo - 10 steps)...")
rkey = jr.PRNGKey(42)
batch_size = 4
# Demo only: the same first `batch_size` sequences are reused at every step.
batch = input_ids[:batch_size]

for step in range(1, 11):
    rkey, step_key = jr.split(rkey)
    params, opt_state, loss = train_step(params, opt_state, batch, step_key)
    print(f"Step {step}, Loss: {loss:.4f}")

# Section 4: Évaluation

On teste le modèle : on masque un seul mot manuellement et on regarde si BERT le retrouve.

In [17]:
def predict_masked_sentence(text, model_params, top_k=5):
    """Print the model's ``top_k`` guesses for the first [MASK] token in ``text``.

    Uses the module-level ``tokenizer``, ``model_apply`` and ``BERT_CONFIG``. The
    text is truncated to the model's number of position embeddings; if no [MASK]
    remains, a message is printed and the function returns early.

    Args:
        text: sentence containing at least one ``[MASK]``.
        model_params: parameters for ``model_apply``.
        top_k: number of candidates to print (default 5).
    """
    # 1. Tokenisation (truncated so positions stay within the embedding table).
    encoded = tokenizer(
        text,
        return_tensors="np",
        truncation=True,
        max_length=BERT_CONFIG["max_position_embeddings"],
    )
    token_ids = np.asarray(encoded["input_ids"][0])
    seq_len = token_ids.shape[0]

    # 2. Locate the [MASK] token (only the first one is decoded).
    mask_positions = np.flatnonzero(token_ids == tokenizer.mask_token_id)
    if mask_positions.size == 0:
        print("Pas de token [MASK] trouvé.")
        return
    mask_idx = int(mask_positions[0])

    # 3. Inference (no dropout).
    logits = model_apply(
        model_params,
        jnp.asarray(token_ids),
        jnp.arange(seq_len),
        jnp.zeros(seq_len, dtype=int),
        inference=True,
        rkey=None,
    )

    # 4. Decode the highest-scoring vocabulary entries, best first.
    mask_logits = np.asarray(logits[mask_idx])
    top_ids = np.argsort(mask_logits)[::-1][:top_k]

    print(f"Phrase : {text}")
    print("Prédictions :")
    for token_id in top_ids:
        print(f"- {tokenizer.decode([int(token_id)])} ({float(mask_logits[token_id]):.2f})")


# Test
predict_masked_sentence("The capital of France is [MASK].", params)
predict_masked_sentence("The doctor said the patient needs [MASK].", params)