If you are interested in learning more about the applications of submodularity/diminishing returns, see this blog post: https://jina.ai/news/submodular-optimization-for-diverse-query-generation-in-deepresearch

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from transformers import AutoModel, AutoProcessor, AutoTokenizer
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import heapq

# Load the embedding model and its processor once, up front (this may take a moment)
print("Loading model...")
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
# Use the GPU when torch reports one, otherwise stay on the CPU
if torch.cuda.is_available():
    model.to("cuda")
else:
    model.to("cpu")
preprocessor = AutoProcessor.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
)
print("Model loaded successfully!")

In [None]:
def compute_marginal_gain_diversity(new_idx, selected, embeddings, similarity_matrix):
    """Return how much coverage is gained by adding ``new_idx`` to ``selected``.

    The coverage of a selection is the sum, over every column of
    ``similarity_matrix``, of the largest similarity between that column and
    any selected row. The value returned here is the increase in that sum
    when row ``new_idx`` is added.

    Args:
        new_idx: Row index of the candidate item.
        selected: List of row indices already chosen (may be empty).
        embeddings: Not used by the computation; kept in the signature so
            existing callers keep working.
        similarity_matrix: 2-D NumPy array of pairwise similarities.

    Returns:
        The marginal gain as a NumPy scalar. With an empty selection this is
        simply the sum of the candidate's similarity row.
    """
    candidate_row = similarity_matrix[new_idx]

    # Nothing chosen yet: the candidate contributes its entire row.
    if not selected:
        return np.sum(candidate_row)

    # Best similarity each column already gets from the current selection.
    current_coverage = np.max(similarity_matrix[selected], axis=0)

    # Only columns where the candidate beats the current best add anything.
    new_coverage = np.maximum(current_coverage, candidate_row)
    return np.sum(new_coverage - current_coverage)


def lazy_greedy_token_selection(embeddings, k, similarity_matrix=None):
    """Pick up to ``k`` diverse items with lazy greedy selection.

    Every item starts in a max-heap keyed by its marginal gain against the
    empty selection. In each round the top entry is popped; if its gain was
    computed in the current round it is selected, otherwise the gain is
    recomputed against the current selection and the entry is pushed back.

    Args:
        embeddings: Sequence of item embeddings (one row per item).
        k: Maximum number of items to select.
        similarity_matrix: Optional pre-computed pairwise similarity matrix
            for ``embeddings``. When omitted it is computed with
            ``cosine_similarity``.

    Returns:
        List of selected item indices in selection order. Empty when there
        are no items or ``k`` is not positive.
    """
    n = len(embeddings)
    # Nothing to select: skip the similarity computation, which would
    # otherwise be wasted work (k <= 0) or fail outright (no rows).
    if n == 0 or k <= 0:
        return []

    if similarity_matrix is None:
        similarity_matrix = cosine_similarity(embeddings)

    # heapq is a min-heap, so gains are stored negated. The second field is
    # the round in which the gain was last computed.
    pq = [
        (-compute_marginal_gain_diversity(i, [], embeddings, similarity_matrix), 0, i)
        for i in range(n)
    ]
    heapq.heapify(pq)

    selected = []
    # Each index lives in the heap at most once (it is either selected and
    # dropped, or re-pushed exactly once), so no separate "remaining" set is
    # needed to filter out already-selected entries.
    for iteration in range(min(k, n)):
        while pq:
            _, last_updated, best_idx = heapq.heappop(pq)

            # Gain is fresh for this round, so this really is the best item.
            if last_updated == iteration:
                selected.append(best_idx)
                break

            # Stale gain: refresh it against the current selection and retry.
            current_gain = compute_marginal_gain_diversity(
                best_idx, selected, embeddings, similarity_matrix
            )
            heapq.heappush(pq, (-current_gain, iteration, best_idx))

    return selected

# Module-level cache for the most recently processed input.
# process_text() assigns these; the display helpers read them.
current_text = ""                 # stripped text the cached values belong to
current_embeddings = None         # embeddings produced for current_text
current_similarity_matrix = None  # pairwise cosine similarities of current_embeddings
current_input_ids = None          # token ids (token mode only)
current_token_strings = None      # display strings, one per cached item
current_sentences = None          # sentence list (sentence mode only)

def split_by_punctuation(text, punctuation_chars):
    """Split ``text`` after any of ``punctuation_chars``, keeping the punctuation.

    Args:
        text: The text to split.
        punctuation_chars: String whose individual characters each mark the
            end of a piece. May be empty, in which case no splitting happens.

    Returns:
        List of non-empty pieces with surrounding whitespace stripped. Each
        piece keeps the punctuation character that terminated it.
    """
    import re

    # An empty character set would produce the invalid pattern "(?<=[])" and
    # raise re.error; with no split points the whole text is a single piece.
    if not punctuation_chars:
        whole = text.strip()
        return [whole] if whole else []

    # Escape so characters such as "]", "^", "-" or "\" are taken literally
    # inside the character class.
    escaped_chars = re.escape(punctuation_chars)
    # Zero-width lookbehind: split right after a punctuation character so the
    # character itself stays attached to the preceding piece.
    pieces = re.split(f'(?<=[{escaped_chars}])', text)
    stripped_pieces = (piece.strip() for piece in pieces)
    return [piece for piece in stripped_pieces if piece]

def process_text(text):
    """Embed ``text`` and refresh the module-level cache.

    In sentence mode (``sentence_mode_toggle.value`` is true) the text is
    split with ``split_by_punctuation`` and one embedding per sentence is
    requested. Otherwise multi-vector embeddings are requested for the whole
    text and the matching input ids / token strings are cached.

    The cache is keyed on the stripped text only; a change of mode or of the
    punctuation setting is not detected here.

    Args:
        text: Raw input text.

    Returns:
        ``(embeddings, input_ids)`` for the text, or ``(None, None)`` when the
        text is blank or yields no sentences. ``input_ids`` is ``None`` in
        sentence mode.
    """
    global current_embeddings, current_input_ids, current_text, current_similarity_matrix, current_token_strings, current_sentences

    stripped_text = text.strip()
    if not stripped_text:
        return None, None

    # Reuse the cache only when it actually holds results for this text.
    if current_text == stripped_text and current_embeddings is not None:
        return current_embeddings, current_input_ids

    # Invalidate everything derived from the previous text first. If any step
    # below raises, the cache is left empty instead of pairing the new text
    # with the old text's embeddings (which would later be returned as a hit).
    current_text = stripped_text
    current_embeddings = None
    current_input_ids = None
    current_similarity_matrix = None
    current_token_strings = None
    current_sentences = None

    if sentence_mode_toggle.value:
        # Sentence mode
        sentences = split_by_punctuation(stripped_text, punctuation_input.value)

        if not sentences:
            # Keep a list (not None) so callers can still take its length.
            current_sentences = sentences
            return None, None

        # Generate sentence-level embeddings
        embeddings = model.encode_text(
            texts=sentences,
            task="text-matching",
            prompt_name="query",
            return_numpy=True
        )

        input_ids = None  # Not used in sentence mode
        token_strings = sentences  # Use sentences as "tokens" for display

    else:
        # Token mode
        sentences = None

        # Generate multi-vector embeddings
        mv_embed = model.encode_text(
            texts=[stripped_text],
            task="text-matching",
            return_multivector=True,
            prompt_name="query"
        )[0][2:]  # Skip first 2 tokens

        # Get input IDs
        preprocessor_results = preprocessor.process_texts(
            texts=[stripped_text],
            prefix="query"
        )
        input_ids = preprocessor_results["input_ids"][0].tolist()[2:]  # Skip first 2 tokens

        embeddings = mv_embed.cpu().numpy()

        # Pre-compute token strings to avoid repeated tokenizer calls
        token_strings = preprocessor.tokenizer.convert_ids_to_tokens(input_ids)

    # Pre-compute the similarity matrix; it only depends on the embeddings.
    similarity_matrix = cosine_similarity(embeddings)

    # Publish all results together now that every step has succeeded.
    current_embeddings = embeddings
    current_input_ids = input_ids
    current_token_strings = token_strings
    current_sentences = sentences
    current_similarity_matrix = similarity_matrix

    return current_embeddings, current_input_ids

def update_display(k_value):
    """Select ``k_value`` diverse items from the cache and show them.

    Reads the cached embeddings, sentences / token ids and token strings,
    runs ``lazy_greedy_token_selection_cached`` and writes the result to
    ``combined_text_output`` and ``output_area``. Does nothing when no
    embeddings are cached.

    Args:
        k_value: Requested number of items; clamped to the number of cached
            items for the active mode.
    """
    if current_embeddings is None:
        return

    sentence_mode = sentence_mode_toggle.value

    # The cache for the active mode can still be None (e.g. the toggle was
    # flipped but reprocessing failed), so treat None like an empty list
    # instead of calling len() on it.
    items = current_sentences if sentence_mode else current_input_ids
    max_items = len(items) if items else 0
    k_value = min(k_value, max_items)

    if k_value <= 0:
        combined_text_output.value = ""
        with output_area:
            clear_output()
            print("❌ No items available for selection")
        return

    # Select diverse items using the cached similarity matrix, then restore
    # original reading order.
    selected_indices_sorted = sorted(
        lazy_greedy_token_selection_cached(current_embeddings, k_value)
    )
    # Drop any index that has no matching cached item.
    valid_indices = [i for i in selected_indices_sorted if i < len(items)]

    if sentence_mode:
        selected_sentences = [items[i] for i in valid_indices]
        combined_text = ' '.join(selected_sentences)

        # Update debug output
        with output_area:
            clear_output()
            print(f"📊 Selected {len(valid_indices)} out of {len(items)} sentences")
            print("📝 Selected sentences:")
            for position, sentence in enumerate(selected_sentences, start=1):
                print(f"  {position}. {sentence}")
            print(f"📍 Sentence positions: {valid_indices}")
    else:
        selected_strings = [current_token_strings[i] for i in valid_indices]
        combined_text = preprocessor.tokenizer.convert_tokens_to_string(selected_strings)

        # Update debug output
        with output_area:
            clear_output()
            print(f"📊 Selected {len(valid_indices)} out of {len(items)} tokens")
            print(f"🔤 Selected tokens: {selected_strings}")
            print(f"📍 Token positions: {valid_indices}")

    # Update the combined text output box
    combined_text_output.value = combined_text

def lazy_greedy_token_selection_cached(embeddings, k):
    """Lazy greedy selection that reuses the cached similarity matrix.

    Uses the module-level ``current_similarity_matrix`` when it is present
    and has one row per embedding; otherwise the matrix is recomputed from
    ``embeddings`` with ``cosine_similarity``.

    Args:
        embeddings: Sequence of item embeddings (one row per item).
        k: Maximum number of items to select.

    Returns:
        List of selected item indices in selection order. Empty when there
        are no items or ``k`` is not positive.
    """
    n = len(embeddings)
    if n == 0 or k <= 0:
        return []

    similarity_matrix = current_similarity_matrix
    # A missing cache would crash below, and a cache of a different size
    # would describe some other input, so fall back to a fresh computation.
    if similarity_matrix is None or similarity_matrix.shape[0] != n:
        similarity_matrix = cosine_similarity(embeddings)

    # heapq is a min-heap, so gains are stored negated. The second field is
    # the round in which the gain was last computed.
    pq = [
        (-compute_marginal_gain_diversity(i, [], embeddings, similarity_matrix), 0, i)
        for i in range(n)
    ]
    heapq.heapify(pq)

    selected = []
    # Each index is in the heap at most once (selected entries are dropped,
    # stale ones are re-pushed exactly once), so popped entries never refer
    # to an already-selected item.
    for iteration in range(min(k, n)):
        while pq:
            _, last_updated, best_idx = heapq.heappop(pq)

            # Gain is fresh for this round, so this is the best remaining item.
            if last_updated == iteration:
                selected.append(best_idx)
                break

            # Stale gain: refresh it against the current selection and retry.
            current_gain = compute_marginal_gain_diversity(
                best_idx, selected, embeddings, similarity_matrix
            )
            heapq.heappush(pq, (-current_gain, iteration, best_idx))

    return selected

def on_text_change(change):
    """React to edits of the input text area.

    Blank text is ignored. Otherwise the text is processed, the slider range
    is matched to the number of cached items and the display is refreshed.
    Any exception is reported in ``output_area``.
    """
    new_text = change['new']
    if not new_text.strip():
        return

    try:
        process_text(new_text)

        # Number of selectable items for the active mode
        if sentence_mode_toggle.value:
            item_count = len(current_sentences)
        elif current_input_ids:
            item_count = len(current_input_ids)
        else:
            item_count = 0

        if item_count <= 0:
            return

        # Keep the slider within the available range before redrawing
        k_slider.max = item_count
        k_slider.value = min(k_slider.value, item_count)
        update_display(k_slider.value)
    except Exception as err:
        with output_area:
            clear_output()
            print(f"❌ Error processing text: {str(err)}")

def on_slider_change(change):
    """Redraw the selection for the slider's new value, if anything is cached."""
    if current_embeddings is None:
        return
    update_display(change['new'])

def on_mode_change(change):
    """Reprocess the cached text after the selection mode was toggled.

    Does nothing when no text is cached. ``change`` is not inspected.
    Failures are reported, with a traceback, in ``output_area``.
    """
    global current_text
    if not current_text:
        return

    try:
        # process_text() skips work when the text is unchanged, so blank the
        # cached text first to force a fresh run.
        text_to_reprocess, current_text = current_text, ""
        process_text(text_to_reprocess)

        # Number of selectable items for the (new) active mode
        if sentence_mode_toggle.value:
            item_count = len(current_sentences)
        elif current_input_ids:
            item_count = len(current_input_ids)
        else:
            item_count = 0

        if item_count > 0:
            # Fit the slider to the new range, then redraw with that value
            k_slider.max = item_count
            k_slider.value = min(k_slider.value, item_count)
            update_display(k_slider.value)
        else:
            with output_area:
                clear_output()
                print("❌ No items found after mode switch")
    except Exception as err:
        with output_area:
            clear_output()
            print(f"❌ Error switching mode: {str(err)}")
            import traceback
            traceback.print_exc()

def on_punctuation_change(change):
    """Re-split the cached text when the punctuation set changes in sentence mode."""
    if not (sentence_mode_toggle.value and current_text):
        return
    # Same reprocessing path as a mode toggle; it picks up the new punctuation
    on_mode_change(None)

# ---- Widgets ----

# Free-form input, pre-filled with a sample paragraph
text_input = widgets.Textarea(
    description='Input Text:',
    placeholder='Enter your text here...',
    value="Founded in 2020, Jina AI is a leading search AI company. Our Search Foundation platform combines Embeddings, Rerankers, and Small Language Models to help businesses build reliable and high-quality GenAI and multimodal search applications.",
    layout=widgets.Layout(width='100%', height='120px'),
)

# Switch between token selection (unchecked) and sentence selection (checked)
sentence_mode_toggle = widgets.Checkbox(
    description='Sentence Selection Mode',
    value=False,
    style=dict(description_width='initial'),
)

# Characters after which the text is split in sentence mode
punctuation_input = widgets.Text(
    description='Punctuation:',
    placeholder='Punctuation characters for splitting...',
    value=',.!?，。！？',
    style=dict(description_width='initial'),
    layout=widgets.Layout(width='300px'),
)

# How many items to select
k_slider = widgets.IntSlider(
    description='Items (k):',
    min=1,
    max=100,
    step=1,
    value=30,
    style=dict(description_width='initial'),
    layout=widgets.Layout(width='100%'),
)

# Shows the text rebuilt from the selected items; left enabled so users
# can copy or edit it
combined_text_output = widgets.Textarea(
    description='Selected:',
    placeholder='Selected text will appear here...',
    value='',
    disabled=False,
    layout=widgets.Layout(width='100%', height='120px'),
)

# Captures the printed debug information
output_area = widgets.Output()

# Wire each widget's value changes to its handler
for _widget, _handler in (
    (text_input, on_text_change),
    (k_slider, on_slider_change),
    (sentence_mode_toggle, on_mode_change),
    (punctuation_input, on_punctuation_change),
):
    _widget.observe(_handler, names='value')

# Stack the page top to bottom: heading, input, controls, result, debug output
ui = widgets.VBox(children=[
    widgets.HTML(value="<h2>Submodular Optimization for Token/Sentence Selection</h2>"),
    widgets.HTML(value="<p>Enter text below and use the controls to select tokens or sentences:</p>"),
    text_input,
    widgets.HBox(children=[sentence_mode_toggle, punctuation_input]),
    k_slider,
    combined_text_output,
    widgets.HTML(value="<h3>📋 Debug Info:</h3>"),
    output_area,
])

# Run the pipeline once on the pre-filled text so the UI opens populated
if text_input.value.strip():
    process_text(text_input.value)

    # Number of selectable items for the active mode
    if sentence_mode_toggle.value:
        max_items = len(current_sentences)
    elif current_input_ids:
        max_items = len(current_input_ids)
    else:
        max_items = 0

    if max_items > 0:
        # Fit the slider to the available range, then draw the first selection
        k_slider.max = max_items
        k_slider.value = min(k_slider.value, max_items)
        update_display(k_slider.value)

# Render the assembled UI
display(ui)