# Import Required Libraries
Import torch and the necessary classes from the transformers library.

In [1]:
import torch
from transformers import AutoTokenizer, AutoModel

# Load Tokenizer and Model
Load the AutoTokenizer and AutoModel for 'dicta-il/dictabert-large-char-menaked'. Set the model to evaluation mode.

In [2]:
model_name = "dicta-il/dictabert-large-char-menaked"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

BertForDiacritization(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(1024, 1024, padding_idx=0)
      (position_embeddings): Embedding(2048, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), e

In [3]:
def predict_with_probs(sentences, tokenizer, model, mark_matres_lectionis=None, top_k=5):
    # Use model.predict for decoded output
    decoded = model.predict(sentences, tokenizer, mark_matres_lectionis=mark_matres_lectionis)
    
    # Manual forward pass
    inputs = tokenizer(sentences, padding='longest', truncation=True,
                       return_tensors='pt', return_offsets_mapping=True)
    offset_mapping = inputs.pop('offset_mapping')
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model.forward(**inputs, return_dict=True)
    logits = outputs.logits  # MenakedLogitsOutput
    nikud_logits = logits.nikud_logits  # [batch, seq_len, num_nikud]
    shin_logits = logits.shin_logits    # [batch, seq_len, num_shin]
    
    results = []
    for sent_idx, (sentence, offsets) in enumerate(zip(sentences, offset_mapping)):
        # For each character token, collect top-k probabilities
        sent_data = {'decoded': decoded[sent_idx], 'chars': []}
        probs = torch.softmax(nikud_logits[sent_idx], dim=-1)  # probabilities
        
        for i, (start, end) in enumerate(offsets):
            if end - start != 1:
                continue
            char = sentence[start:end]
            dist = probs[i]
            top_p, top_ids = torch.topk(dist, top_k)
            sent_data['chars'].append({
                'char': char,
                'predictions': {
                    model.config.nikud_classes[label_id.item()]: float(p.item())
                    for p, label_id in zip(top_p, top_ids)
                }
            })
        results.append(sent_data)
    return results


In [10]:
res = predict_with_probs(['מסכת עָרְלָה היא המסכת העשירית בסדר זרעים'], tokenizer, model, top_k=3)
import pprint; pprint.pprint(res[0])

{'chars': [{'char': 'מ',
            'predictions': {'ְ': 0.0015557006699964404,
                            'ַ': 0.9960278272628784,
                            'ָ': 0.0013132020831108093}},
           {'char': 'ס',
            'predictions': {'ֶ': 0.000505781383253634,
                            'ֵּ': 0.001638155896216631,
                            'ֶּ': 0.9971392154693604}},
           {'char': 'כ',
            'predictions': {'ֶ': 0.999430239200592,
                            'ֶּ': 0.00017696806753519922,
                            'ַּ': 0.0001835094444686547}},
           {'char': 'ת',
            'predictions': {'': 0.9999827146530151,
                            'ּ': 2.610164528960013e-06,
                            'ְּ': 4.128542968828697e-06}},
           {'char': ' ',
            'predictions': {'': 0.4220789074897766,
                            'ֵ': 0.13712063431739807,
                            'ָ': 0.24417512118816376}},
           {'char': 'ע',
            'predi

In [5]:
from math import log2

def nikud_uncertainty(text, model, tokenizer, 
                      top_k=5, 
                      entropy_threshold=1.0, 
                      margin_threshold=0.2, 
                      maxprob_threshold=0.7):
    """
    Identify ambiguous characters in the text according to nikud predictions.
    """
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", return_offsets_mapping=True, truncation=True)
    offsets = inputs.pop("offset_mapping")[0]

    with torch.no_grad():
        outputs = model(**inputs)   # MenakedOutput
        probs = torch.softmax(outputs.logits.nikud_logits[0], dim=-1)  # shape [seq_len, num_nikud_classes]

    id2label = model.config.nikud_classes  # list of all nikud symbols

    ambiguous = []
    for i, (start, end) in enumerate(offsets):
        if end - start != 1:  # skip special tokens / padding
            continue
        char = text[start:end]
        dist = probs[i]

        # Sort probabilities
        sorted_probs, sorted_ids = torch.sort(dist, descending=True)
        p1, p2 = sorted_probs[0].item(), sorted_probs[1].item()

        # --- criteria ---
        entropy = -sum(p.item() * log2(p.item()) for p in dist if p.item() > 0)
        margin = p1 - p2
        max_prob = p1

        # decide ambiguity
        is_ambig = (entropy > entropy_threshold) or (margin < margin_threshold) or (max_prob < maxprob_threshold)

        if is_ambig:
            ambiguous.append({
                "char": char,
                "position": (start, end),
                "entropy": entropy,
                "margin": margin,
                "max_prob": max_prob,
                "top_candidates": [
                    (id2label[sorted_ids[j].item()], sorted_probs[j].item())
                    for j in range(min(top_k, len(sorted_ids)))
                ]
            })
    return ambiguous

In [6]:
# Example usage
ambig = nikud_uncertainty("שלום עולם", model, tokenizer)
for a in ambig:
    print(a)
    print(a["position"][0])

{'char': ' ', 'position': (tensor(4), tensor(5)), 'entropy': 2.520866388614735, 'margin': 0.10133625566959381, 'max_prob': 0.3462553322315216, 'top_candidates': [('ָ', 0.3462553322315216), ('ֵ', 0.2449190765619278), ('', 0.2309339940547943), ('ִ', 0.0470949187874794), ('<MAT_LECT>', 0.0413050502538681)]}
tensor(4)


In [7]:
%pip install matplotlib seaborn

Note: you may need to restart the kernel to use updated packages.


In [8]:
from IPython.display import HTML, display
import matplotlib
import math

def colorize_text_by_certainty(text, model, tokenizer,
                               certainty_metric="max_prob",
                               combine="max",   # "max" or "avg"
                               scale="linear",
                               low_conf=0.6, high_conf=0.95):
    """
    Display text with characters color-coded by combined certainty
    from both nikud_logits and shin_logits.
    
    - certainty_metric: "max_prob" or "entropy"
    - combine: how to combine nikud & shin uncertainty ("max" or "avg")
    - scale: "linear", "sqrt", or "log"
    """
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", return_offsets_mapping=True, truncation=True)
    offsets = inputs.pop("offset_mapping")[0]

    with torch.no_grad():
        outputs = model(**inputs)
        nikud_probs = torch.softmax(outputs.logits.nikud_logits[0], dim=-1)
        shin_probs  = torch.softmax(outputs.logits.shin_logits[0], dim=-1)

    html_chars = []

    for i, (start, end) in enumerate(offsets):
        if end - start != 1:
            continue
        char = text[start:end]
        dist = nikud_probs[i]

        # ---- nikud certainty ----
        if certainty_metric == "max_prob":
            nikud_conf = dist.max().item()
        elif certainty_metric == "entropy":
            entropy = -sum(p.item() * math.log2(p.item()) for p in dist if p.item() > 0)
            nikud_conf = 1 - entropy / math.log2(len(dist))
        else:
            raise ValueError("Unknown certainty metric")

        # ---- shin certainty (only for ש) ----
        if char == "ש":
            s_probs = shin_probs[i]
            if certainty_metric == "max_prob":
                shin_conf = s_probs.max().item()
            elif certainty_metric == "entropy":
                entropy = -sum(p.item() * math.log2(p.item()) for p in s_probs if p.item() > 0)
                shin_conf = 1 - entropy / math.log2(len(s_probs))
        else:
            shin_conf = None

        # ---- combine uncertainties ----
        if shin_conf is not None:
            if combine == "max":
                conf = min(nikud_conf, shin_conf)  # lower = less certain
            elif combine == "avg":
                conf = (nikud_conf + shin_conf) / 2
        else:
            conf = nikud_conf

        # Normalize into [0,1]
        norm = (conf - low_conf) / (high_conf - low_conf)
        norm = min(max(norm, 0.0), 1.0)

        # Scaling
        if scale == "sqrt":
            norm = norm**0.5
        elif scale == "log":
            norm = (math.log1p(norm * 9) / math.log1p(9)) if norm > 0 else 0

        # Bright colormap: yellow → orange → red
        cmap = matplotlib.cm.get_cmap("YlOrRd")
        rgba = cmap(1 - norm)  # invert so red = uncertain
        color = matplotlib.colors.rgb2hex(rgba)

        html_chars.append(f"<span style='color:{color}'>{char}</span>")

    display(HTML("".join(html_chars)))


# Example usage:
colorize_text_by_certainty("דה הוושי, שהיה כימאי פיזיקלי הונגרי ממוצא יהודי,", model, tokenizer,
                           certainty_metric="max_prob",
                           scale="sqrt", low_conf=0.6, high_conf=0.95)
colorize_text_by_certainty("שמחה האח חבר טוב של כדרלעומר", model, tokenizer,
                           certainty_metric="max_prob",
                           scale="sqrt", low_conf=0.6, high_conf=0.95)


  cmap = matplotlib.cm.get_cmap("YlOrRd")
