In [None]:
import html
import re

import numpy
import numpy as np
from IPython.display import HTML, display

def convert_clean_text(clean_text, k=1, tokens_left=30, tokens_right=5):
    """
    Wrap the k highest-scoring tokens of `clean_text` in `<< >>` and keep only
    a window of context around each wrapped token.

    The input is split on " | ". The first entry is always discarded
    (presumably a leading special token -- confirm against the producer of
    `clean_text`). Each remaining entry is expected to look like
    "token (0.12)"; entries that do not match that shape (no score, or a score
    that is not a non-negative decimal) are kept as context with a score of 0.

    Top tokens are chosen by position, so exactly min(k, number of
    positive-score tokens) tokens are wrapped even when the same token string
    occurs several times. Ties are broken in favour of the earlier token.

    Parameters:
    - clean_text (str): Tokens and their scores, separated by " | ".
    - k (int): The number of top scoring tokens to wrap. Defaults to 1.
    - tokens_left (int): Number of tokens to keep before each top token. Defaults to 30.
    - tokens_right (int): Number of tokens to keep after each top token. Defaults to 5.

    Returns:
    - str: The kept tokens joined by single spaces, with the top tokens wrapped
      in `<< >>`. Empty string when no token has a positive score.
    """
    # Split on the separator and drop the first entry.
    token_score_pairs = clean_text.split(" | ")[1:]

    # "token (1.23)" -> ("token", 1.23)
    token_score_pattern = re.compile(r"^(.+?) \((\d+\.\d+)\)$")

    tokens_with_scores = []
    for token_score in token_score_pairs:
        match = token_score_pattern.match(token_score.strip())
        if match:
            tokens_with_scores.append((match.group(1), float(match.group(2))))
        else:
            # Score is absent or not in the expected format: treat it as zero.
            token = token_score.split(' (')[0].strip()
            tokens_with_scores.append((token, 0.0))

    # Rank *positions* (not token strings) by score. list.sort is stable, so
    # equal scores keep their original left-to-right order.
    positive_indices = [
        i for i, (_, score) in enumerate(tokens_with_scores) if score > 0
    ]
    positive_indices.sort(key=lambda i: tokens_with_scores[i][1], reverse=True)
    top_k_indices = set(positive_indices[:max(k, 0)])

    # Union of the context windows around every top token. Collecting into a
    # set makes overlapping or adjacent windows merge automatically.
    last_index = len(tokens_with_scores) - 1
    selected_indices = set()
    for idx in top_k_indices:
        start = max(0, idx - tokens_left)
        end = min(last_index, idx + tokens_right)
        selected_indices.update(range(start, end + 1))

    # Emit the kept tokens in their original order, wrapping the top ones.
    converted_tokens = []
    for i, (token, _) in enumerate(tokens_with_scores):
        if i not in selected_indices:
            continue
        converted_tokens.append(f"<<{token}>>" if i in top_k_indices else token)

    return " ".join(converted_tokens)

def highlight_scores_in_html(
    token_strs,
    scores,
    seq_idx,
    max_color="#ff8c00",
    zero_color="#ffffff",
    show_score=True,
):
    """
    Render tokens as HTML spans whose background colour encodes their score.

    Scores are min-max normalised over the sequence; the lowest score is
    painted with `zero_color`, the highest with `max_color`, and everything in
    between is a linear blend of the two. When all scores are equal (or the
    sequence is empty) every token is painted with `zero_color`.

    Parameters:
    - token_strs (sequence of str): Token texts; HTML-escaped before rendering.
    - scores (sequence of numbers): One score per token.
    - seq_idx: Accepted for interface compatibility; not used by this function.
    - max_color (str): "#rrggbb" colour for the highest score.
    - zero_color (str): "#rrggbb" colour for the lowest score.
    - show_score (bool): If True, append " (score)" with two decimals to each
      token, both in the HTML and in the text handed to `convert_clean_text`.

    Returns:
    - tuple[str, str]: (style block + token spans, result of
      `convert_clean_text` on the " | "-joined tokens). Returns ("", "") after
      printing a message when `token_strs` and `scores` differ in length.
      Note that with `show_score=False` the joined text carries no scores.
    """
    if len(token_strs) != len(scores):
        print("Length mismatch between tokens and scores")
        return "", ""

    # Min-max normalise, guarding the degenerate cases that would otherwise
    # raise (empty input) or produce NaN colours (zero score range).
    if len(scores) == 0:
        scores_normalized = np.zeros(0)
    else:
        scores_min = min(scores)
        scores_max = max(scores)
        score_range = scores_max - scores_min
        if score_range == 0:
            scores_normalized = np.zeros(len(scores))
        else:
            scores_normalized = (np.array(scores) - scores_min) / score_range

    # "#rrggbb" -> [r, g, b]
    max_color_vec = np.array([int(max_color[p:p + 2], 16) for p in (1, 3, 5)])
    zero_color_vec = np.array([int(zero_color[p:p + 2], 16) for p in (1, 3, 5)])

    # Per-token linear blend between the two colours (outer products).
    color_vecs = np.einsum("i, j -> ij", scores_normalized, max_color_vec) + np.einsum(
        "i, j -> ij", 1 - scores_normalized, zero_color_vec
    )
    color_strs = [f"#{int(x[0]):02x}{int(x[1]):02x}{int(x[2]):02x}" for x in color_vecs]

    token_spans = []
    clean_parts = []
    for token_str, score, color in zip(token_strs, scores, color_strs):
        if show_score:
            score_html = f"<span class='feature_val'> ({score:.2f})</span>"
            clean_parts.append(f"{token_str} ({score:.2f})")
        else:
            score_html = ""
            clean_parts.append(token_str)
        token_spans.append(
            f"<span class='token' style='background-color: {color}'>"
            f"{html.escape(token_str)}{score_html}</span>"
        )
    tokens_html = "".join(token_spans)
    clean_text = " | ".join(clean_parts)

    head = """
    <style>
        span.token {
            font-family: monospace;
            border-style: solid;
            border-width: 1px;
            border-color: #dddddd;
        }
    </style>
    """
    return head + tokens_html, convert_clean_text(clean_text)

In [None]:
# Build one highlighted example per index in range(k).
# `k`, `top_k_tokens_str`, `top_k_scores_per_seq` and `top_k_seq_indices` are
# not defined in this cell and must come from earlier cells.
examples_html = []
examples_clean_text = []
# Fixed second-level index into top_k_scores_per_seq[i] used for every example.
j = 14
for i in range(k):
    try:
        example_html, clean_text = highlight_scores_in_html(
            top_k_tokens_str[i],
            top_k_scores_per_seq[i][j],
            top_k_seq_indices[i],
            show_score=True,
        )
    except Exception as e:
        # Best effort: keep going, but say why this example was skipped so
        # real errors (e.g. a missing import) are not silently hidden.
        print(f"Skipping i={i}, j={j}: {type(e).__name__}: {e}")
        continue
    if not example_html:
        # highlight_scores_in_html returns ("", "") on a token/score length
        # mismatch (it prints its own message); that is not a usable example.
        continue
    examples_html.append(example_html)
    examples_clean_text.append(clean_text)
    print(f"Got one! i={i}, j={j}")

In [None]:
# Show every collected HTML snippet as rich output, in collection order.
for example in examples_html:
    rendered = HTML(example)
    display(rendered)