In [1]:
import os

# Point Jupyter's config, data and runtime directories at explicit locations.
# These variables are set before ipywidgets / IPython are imported below.
# NOTE(review): the paths are hard-coded to one user's home directory —
# adjust them (or remove this cell) before running on another machine.
os.environ["JUPYTER_CONFIG_DIR"] = "/home/morg/students/gottesman3/.jupyter"
os.environ["JUPYTER_DATA_DIR"] = "/home/morg/students/gottesman3/.local/share/jupyter"
os.environ["JUPYTER_RUNTIME_DIR"] = "/home/morg/students/gottesman3/.local/share/jupyter/runtime"

# Then import widgets etc.
import ipywidgets as widgets
from IPython.display import display


# Refined LMDB
## For retrieving metadata about QIDs such as title and url.

In [2]:
import logging
import os
import warnings
from typing import List, Any, Dict, Mapping, Iterable
from typing import TypeVar

import lmdb
import ujson as json
from tqdm import tqdm


K = TypeVar('K')
V = TypeVar('V')


def batch_items(iterable: Iterable[Any], n: int = 1):
    """
    Group the items of an iterable into consecutive lists of ``n`` items.

    A trailing group is still yielded when it holds fewer than ``n`` items.
    :param iterable: any iterable
    :param n: batch size (final batch may be truncated)
    """
    pending = []
    for element in iterable:
        pending.append(element)
        if len(pending) != n:
            # Batch not full yet: keep collecting.
            continue
        yield pending
        pending = []
    # Flush whatever is left over after the iterable is exhausted.
    if len(pending) > 0:
        yield pending


class LmdbImmutableDict(Mapping[K, V]):
    """
    Mapping-style wrapper around a single-file LMDB environment.

    Keys are encoded as UTF-8 bytes; values are stored as UTF-8 encoded JSON.
    In write mode the environment lives at ``<path>.incomplete`` and is meant to
    be finalised with :meth:`write_to_compacted_file`; otherwise the environment
    at ``path`` is opened read-only.
    """

    # Suffix appended to the path of an environment that is still being written.
    _INCOMPLETE_SUFFIX = '.incomplete'
    # Value passed to lmdb.open as ``map_size`` (2 * 2**40 bytes).
    _MAP_SIZE = 1099511627776 * 2

    def __init__(self, path: str, write_mode: bool = False):
        """
        Open the LMDB environment.
        :param path: location of the (final) LMDB file
        :param write_mode: when True, open ``<path>.incomplete`` for writing;
                           when False, open ``path`` read-only
        """
        self.path = f'{path}{self._INCOMPLETE_SUFFIX}' if write_mode else path
        self.env = lmdb.open(self.path, max_dbs=1, readonly=not write_mode, create=True,
                             writemap=write_mode, subdir=False, map_size=self._MAP_SIZE,
                             meminit=False, map_async=True, mode=0o755,
                             lock=False)

    def __iter__(self):
        """
        Iteration is not supported.
        :raises NotImplementedError: always
        """
        # The exception has to be raised (not merely constructed); otherwise
        # iter() receives None and fails with an unrelated TypeError.
        raise NotImplementedError('Iterating over an LmdbImmutableDict is not supported.')

    def __len__(self) -> int:
        """Return the number of entries reported by the LMDB statistics."""
        with self.env.begin() as txn:
            return txn.stat()["entries"]

    def encode(self, key: K) -> bytes:
        """
        Encode a key as UTF-8 bytes.
        :return: the encoded key, or None (after emitting a warning) when the key cannot be encoded
        """
        try:
            return key.encode("utf-8")
        except UnicodeEncodeError as err:
            warnings.warn(f'Unable to encode key {key}, err: {err}')
            return None

    def decode(self, value: bytes) -> Dict[Any, Any]:
        """Decode a stored value: UTF-8 bytes -> JSON -> Python object."""
        return json.loads(value.decode("utf-8"))

    def _read_raw(self, key: K):
        """Return the raw stored bytes for ``key``, or None when the key is empty, unencodable or absent."""
        if not key:
            return None
        encoded_key = self.encode(key)
        if not encoded_key:
            return None
        with self.env.begin() as txn:
            return txn.get(encoded_key)

    def __getitem__(self, key: K) -> V:
        """
        Look up ``key``.
        :raises KeyError: when the key is empty, cannot be encoded, or is not stored
        """
        value = self._read_raw(key)
        if value is None:
            raise KeyError(key)
        return self.decode(value)

    def __contains__(self, key: K) -> bool:
        """Return True when ``key`` is stored in the environment."""
        return self._read_raw(key) is not None

    def get(self, key: K, default_value=None) -> V:
        """
        Look up ``key`` without raising.
        :return: the decoded value, or ``default_value`` when the key is empty,
                 cannot be encoded, or is not stored
        """
        value = self._read_raw(key)
        if value is None:
            return default_value
        return self.decode(value)

    def close(self) -> None:
        """Close the underlying LMDB environment."""
        self.env.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def put(self, key: K, value: V):
        """
        Store a single key/value pair in its own write transaction.
        Pairs with a None key or value, or with a key that cannot be encoded, are skipped.
        """
        if key is None or value is None:
            return
        encoded_key = self.encode(key)
        if encoded_key is None:
            return
        with self.env.begin(write=True) as txn:
            txn.put(key=encoded_key, value=json.dumps(value).encode())

    def put_batch(self, keys: List[K], values: List[V]):
        """
        Store several key/value pairs inside one write transaction.
        Pairs that cannot be encoded or that LMDB rejects are skipped and logged at debug level.
        """
        with self.env.begin(write=True) as txn:
            for key, value in zip(keys, values):
                encoded_key = self.encode(key)
                if encoded_key is None:
                    logging.debug(f'skipping {key}, error: key could not be encoded')
                    continue
                try:
                    txn.put(key=encoded_key, value=json.dumps(value).encode())
                except lmdb.Error as err:
                    logging.debug(f'skipping {key}, error: {err}')

    def write_to_compacted_file(self):
        """
        Writes memmap-based data structure to disk in compacted format
        and deletes original over-allocated file.
        Only call this method when this object is finished with: the
        environment is closed as part of this call.
        """
        suffix = self._INCOMPLETE_SUFFIX
        # Strip only the trailing suffix; str.replace would also mangle a path
        # that happens to contain '.incomplete' somewhere else.
        target_path = self.path[:-len(suffix)] if self.path.endswith(suffix) else self.path
        self.env.copy(path=target_path, compact=True)
        # Release the environment before deleting the file that backs it.
        self.env.close()
        os.remove(self.path)

    @classmethod
    def from_dict(cls, input_dict: Dict[str, Any], output_file_path: str) -> 'LmdbImmutableDict[Any, Any]':
        """
        Build an LMDB file from ``input_dict`` (unless ``output_file_path`` already
        exists) and return a read-only instance over it.
        :param input_dict: mapping of string keys to JSON-serialisable values
        :param output_file_path: location of the final LMDB file
        """
        if os.path.exists(output_file_path):
            print(f'Skipping conversion as {output_file_path} already exists.')
        else:
            batch_size = 250000
            output_lmdb_dict = cls(output_file_path, write_mode=True)
            # Give tqdm the batch count instead of materialising every batch in a list.
            num_batches = (len(input_dict) + batch_size - 1) // batch_size
            for batch in tqdm(batch_items(input_dict.items(), n=batch_size),
                              total=num_batches,
                              desc=f'Writing {output_file_path}'):
                keys, values = zip(*batch)
                output_lmdb_dict.put_batch(keys=keys, values=values)
            output_lmdb_dict.write_to_compacted_file()
        return cls(path=output_file_path, write_mode=False)


In [3]:
# Open the QID lookup store read-only (write_mode=False). Later cells query it
# with qcode_to_wiki.get(<qid>, <default>).
# NOTE(review): hard-coded absolute path — adjust for your environment.
qcode_to_wiki = LmdbImmutableDict(path="/home/morg/dataset/refined/organised_data_dir/wikidata_data/qcode_to_wiki.lmdb", write_mode=False)

# Sample docs

In [4]:
import json
from tqdm import tqdm

# Which per-GPU file to read; interpolated into the file name below.
gpu_id = 6
# NOTE(review): despite its name, this path is only read from in this cell.
output_path = f"/home/morg/dataset/maverick/maverick_{gpu_id}.json"

# Sampled documents: text length >= max_chars goes to long_documents,
# anything shorter to short_documents.
long_documents = []
short_documents = []
max_chars = 50000
# Only records whose "line_index" is >= start_index are considered.
start_index = 5114886

def stream_ndjson(file_path, start):
    """
    Lazily read a newline-delimited JSON file and yield the records at or after ``start``.

    :param file_path: path of the NDJSON file (one JSON object per line, read as UTF-8)
    :param start: smallest ``line_index`` value to yield
    :return: generator of ``(line_index, record)`` tuples
    Blank lines and records without a ``line_index`` field are skipped, because
    neither can be compared against ``start``.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            if not raw_line.strip():
                # json.loads would raise on an empty line (e.g. a trailing newline).
                continue
            record = json.loads(raw_line)
            index = record.get("line_index")
            if index is None:
                continue
            if index >= start:
                yield (index, record)

# Collect up to five long and five short documents, then stop reading the file.
# NOTE: the loop variables `line_index` and `doc` remain defined as notebook
# globals after this cell runs.
for line_index, doc in tqdm(stream_ndjson(output_path, start_index)):
    if len(doc["text"]) >= max_chars and len(long_documents) < 5:
        long_documents.append(doc)
    elif len(doc["text"]) < max_chars and len(short_documents) < 5:
        short_documents.append(doc)
    # Both samples are full: no need to scan the rest of the file.
    if len(short_documents) == 5 and len(long_documents) == 5:
        break

0it [00:00, ?it/s]

7696it [00:03, 2016.58it/s]


In [33]:
# Work on the first sampled long document and pull out the annotation layers
# used by the cells below.
# NOTE(review): raises IndexError if the sampling cell found no long document.
doc = long_documents[0]
hyperlinks = doc["hyperlinks_clean"]
coref = doc["coref"]
entity_linking = doc["entity_linking"]
title = doc["title"]

# Merge metadata into clusters and score dominant entities

In [39]:
from collections import defaultdict
import re

def is_pronoun(text):
    """Return True when ``text``, lower-cased, is exactly one of the pronouns listed in the pattern."""
    lowered = text.lower()
    pronoun_match = re.fullmatch(r"\b(?:he|she|it|they|we|i|you|him|her|us|them|me|my|your|his|their|our|its|mine|yours|hers|ours|theirs)\b", lowered)
    return bool(pronoun_match)

def _overlap_coverage(span_start, span_end, other_start, other_end):
    """Return the fraction of [span_start, span_end) covered by the other range, or None when they do not overlap."""
    if other_end <= span_start or other_start >= span_end:
        return None
    span_length = span_end - span_start
    if span_length <= 0:
        # Degenerate span: avoid dividing by zero.
        return 0.0
    overlap_length = min(span_end, other_end) - max(span_start, other_start)
    return overlap_length / span_length


def enrich_coref_clusters(coref, entity_linking, hyperlinks):
    """
    Attach overlapping entity-linking mentions and hyperlinks to every coreference span.

    :param coref: dict with parallel "clusters_char_offsets" (per cluster, a list of
                  (start, end) pairs) and "clusters_char_text" (per cluster, the span texts)
    :param entity_linking: iterable of mention dicts with "start", "ln", "text" and an
                           optional "predicted_entity" dict holding "wikidata_entity_id"
    :param hyperlinks: iterable of link dicts with "start", "end", "surface_form" and "qcode"
    :return: defaultdict(list) mapping cluster index -> list of span entries; each entry has
             "start", "end", "coref_text", "entities" and "links", where every attached item
             records its id, text, offsets, "coverage_ratio" (overlap length / span length)
             and whether its offsets match the span exactly
    """
    enriched_clusters = defaultdict(list)

    for cluster_id, cluster in enumerate(coref["clusters_char_offsets"]):
        for i, span in enumerate(cluster):
            span_start, span_end = span
            span_entry = {
                "start": span_start,
                "end": span_end,
                "coref_text": coref["clusters_char_text"][cluster_id][i],
                "entities": [],
                "links": [],
            }

            for mention in entity_linking:
                mention_start = mention["start"]
                mention_end = mention_start + mention["ln"]
                coverage_ratio = _overlap_coverage(span_start, span_end, mention_start, mention_end)
                if coverage_ratio is None:
                    continue
                # "predicted_entity" may be missing or explicitly None for unlinked mentions.
                entity_id = (mention.get("predicted_entity") or {}).get("wikidata_entity_id")
                if not entity_id:
                    continue
                span_entry["entities"].append({
                    "id": entity_id,
                    "el_text": mention["text"],
                    "start": mention_start,
                    "end": mention_end,
                    "coverage_ratio": coverage_ratio,
                    "exact": mention_end == span_end and mention_start == span_start
                })

            for link in hyperlinks:
                link_start = link["start"]
                link_end = link["end"]
                coverage_ratio = _overlap_coverage(span_start, span_end, link_start, link_end)
                if coverage_ratio is None:
                    continue
                entity_id = link.get("qcode")
                if not entity_id:
                    continue
                span_entry["links"].append({
                    "id": entity_id,
                    "surface_form": link["surface_form"],
                    "start": link_start,
                    "end": link_end,
                    "coverage_ratio": coverage_ratio,
                    "exact": link_start == span_start and link_end == span_end
                })

            enriched_clusters[cluster_id].append(span_entry)

    return enriched_clusters


def longest_common_substring(s1, s2):
    """
    Return the length of the longest contiguous substring shared by ``s1`` and ``s2``.

    The comparison is case-insensitive (both inputs are lower-cased first).
    Only two rows of the dynamic-programming table are kept at a time, so
    memory use is proportional to ``len(s2)`` rather than ``len(s1) * len(s2)``.
    """
    s1 = s1.lower()
    s2 = s2.lower()
    row_width = len(s2) + 1
    # previous_row[j + 1] holds the length of the common run ending at the
    # previous character of s1 and at s2[j].
    previous_row = [0] * row_width
    longest = 0

    for char_1 in s1:
        current_row = [0] * row_width
        for j, char_2 in enumerate(s2):
            if char_1 == char_2:
                run_length = previous_row[j] + 1
                current_row[j + 1] = run_length
                if run_length > longest:
                    longest = run_length
        previous_row = current_row

    return longest

def lcs_overlap_ratio(unlinked_text, anchor_text):
    """
    Return the longest-common-substring length divided by ``len(anchor_text)``.

    :return: 0.0 when either text is empty; otherwise the ratio described above
    """
    # Both texts must be non-empty: the anchor length is the divisor, so an
    # empty anchor would otherwise raise ZeroDivisionError.
    if not unlinked_text or not anchor_text:
        return 0.0
    lcs_len = longest_common_substring(unlinked_text, anchor_text)
    return lcs_len / len(anchor_text)

def best_entity_or_link_match(span_text, anchor_spans):
    """
    Find the link or entity, among all anchor spans, whose text is most similar to ``span_text``.

    Similarity is ``lcs_overlap_ratio`` against a link's "surface_form" or an
    entity's "el_text". On equal similarity a link replaces any non-link best,
    while an entity only replaces the initial empty best.
    :return: dict with "match_type" ("link" or "entity"), "anchor", "item" and
             "similarity", or None when the best similarity does not exceed 0.3
    """
    needle = span_text.lower()
    top_kind = None
    top_anchor = None
    top_item = None
    top_similarity = 0.0

    for anchor in anchor_spans:
        for link in anchor.get("links", []):
            similarity = lcs_overlap_ratio(needle, link["surface_form"])
            tie_goes_to_link = similarity == top_similarity and top_kind != "link"
            if similarity > top_similarity or tie_goes_to_link:
                top_kind, top_anchor, top_item, top_similarity = "link", anchor, link, similarity

        for entity in anchor.get("entities", []):
            similarity = lcs_overlap_ratio(needle, entity["el_text"])
            tie_goes_to_entity = similarity == top_similarity and top_kind is None
            if similarity > top_similarity or tie_goes_to_entity:
                top_kind, top_anchor, top_item, top_similarity = "entity", anchor, entity, similarity

    if top_similarity > 0.3:
        return {
            "match_type": top_kind,
            "anchor": top_anchor,
            "item": top_item,
            "similarity": top_similarity,
        }
    return None

def score_entities_by_subject_likelihood(enriched_clusters):
    """
    Score, per cluster, how strongly each entity id is supported by the cluster's spans.

    :param enriched_clusters: mapping cluster id -> list of span entries as produced by
                              ``enrich_coref_clusters``
    :return: dict mapping cluster id -> {entity id: score}

    Pronoun spans are ignored. A span carrying entities or links ("anchor") contributes
    ``weight * coverage_ratio`` per distinct (id, start, end) item, with weight 0.85 for
    entity-linking mentions and 1.0 for hyperlinks. A span with neither contributes to the
    item returned by ``best_entity_or_link_match``, additionally scaled by the match
    similarity. Scores are divided by ``max(1.0, total contribution)``, so they only sum
    to 1 when the total contribution is at least 1.
    """
    entity_weight = 0.85
    link_weight = 1.0
    cluster_entity_scores = {}

    for cluster_id, spans in enriched_clusters.items():
        entity_counts = defaultdict(float)
        total_contribution = 0.0
        seen_mentions = set()

        anchor_spans = []
        unlinked_spans = []

        for span in spans:
            if is_pronoun(span["coref_text"]):
                continue

            entities = span.get("entities") or []
            links = span.get("links") or []
            if not entities and not links:
                unlinked_spans.append(span)
                continue

            anchor_spans.append(span)
            # Entities are processed before links; the shared seen-set means an item
            # with the same (id, start, end) is only counted the first time.
            for base_weight, items in ((entity_weight, entities), (link_weight, links)):
                for item in items:
                    key = (item["id"], item["start"], item["end"])
                    if key in seen_mentions:
                        continue
                    seen_mentions.add(key)

                    weight = base_weight * item["coverage_ratio"]
                    entity_counts[item["id"]] += weight
                    total_contribution += weight

        # Infer contribution for unlinked spans via char-level overlap
        for span in unlinked_spans:
            result = best_entity_or_link_match(span["coref_text"], anchor_spans)
            if not result:
                continue

            if result["match_type"] == "entity":
                base_weight = entity_weight
            elif result["match_type"] == "link":
                base_weight = link_weight
            else:
                continue

            item = result["item"]
            weight = result["similarity"] * base_weight * item.get("coverage_ratio", 0.5)
            entity_counts[item["id"]] += weight
            total_contribution += weight

        # The divisor is never below 1.0, so weakly supported clusters keep low scores
        # and no zero-division guard is needed.
        divisor = max(1.0, total_contribution)
        cluster_entity_scores[cluster_id] = {
            entity_id: count / divisor
            for entity_id, count in entity_counts.items()
        }

    return cluster_entity_scores


# Display clusters + entities in text

In [40]:
from collections import defaultdict, Counter

def analyze_coref_clusters_interactive_with_merge(doc):
    """
    Display an interactive widget for browsing the coreference clusters of one document.

    :param doc: dict with "text", "coref", "entity_linking" and "hyperlinks_clean"
    :return: tuple ``(enriched_clusters, entity_scores)`` where ``enriched_clusters`` keeps
             only clusters that have at least one span with an entity or a link, and
             ``entity_scores`` is ``score_entities_by_subject_likelihood`` applied to them

    The widget offers a text/QID search box, a cluster dropdown, a merge threshold
    (clusters whose similarity to the selected one reaches the threshold are merged
    into the view) and a QID score filter used by QID searches.
    """
    import ipywidgets as widgets
    from html import escape as html_escape
    from IPython.display import display, HTML, clear_output

    text = doc["text"]
    coref = doc["coref"]
    entity_linking = doc["entity_linking"]
    hyperlinks = doc["hyperlinks_clean"]

    enriched_clusters = enrich_coref_clusters(coref, entity_linking, hyperlinks)

    # Keep only clusters that contain at least one span with an entity or link
    enriched_clusters = {
        cid: spans for cid, spans in enriched_clusters.items()
        if any(span.get("entities") or span.get("links") for span in spans)
    }

    cluster_ids = sorted(enriched_clusters.keys())
    entity_scores = score_entities_by_subject_likelihood(enriched_clusters)

    # Set while the search handler updates the dropdown programmatically, so that the
    # dropdown observer does not render the view a second time.
    suppress_dropdown_render = False

    def mention_weight_by_text(spans):
        """Per normalized non-pronoun span text, sum max(1, number of attached entity/link ids)."""
        weights = Counter()
        for span in spans:
            if is_pronoun(span["coref_text"]):
                continue
            text_key = span["coref_text"].strip().lower()
            qid_count = len(span["entities"]) + len(span["links"])
            weights[text_key] += max(qid_count, 1)
        return weights

    def cluster_similarity_score(cluster_a_id, cluster_b_id, enriched_clusters):
        """Share of the two clusters' total mention weight that falls on span texts present in both."""
        weights_a = mention_weight_by_text(enriched_clusters[cluster_a_id])
        weights_b = mention_weight_by_text(enriched_clusters[cluster_b_id])

        total = sum(weights_a.values()) + sum(weights_b.values())
        if total <= 0:
            return 0

        shared_text = weights_a.keys() & weights_b.keys()
        overlap_score = sum(weights_a[key] + weights_b[key] for key in shared_text)
        return overlap_score / total

    def render_cluster_html(text, cluster_id, spans, scores, merged_info):
        """Build the HTML for one (possibly merged) cluster: header, scored entity list, highlighted text."""
        sorted_entities = sorted(scores.items(), key=lambda x: -x[1])

        html_chunks = []

        merged_str = ", ".join(
            f"{cid} ({score:.2f})" if cid != cluster_id else f"{cid} (base)"
            for cid, score in merged_info
        )
        html_chunks.append(f"<h3>Cluster {cluster_id} (merged with: {merged_str})</h3>")
        html_chunks.append("<ul>")
        for entity_id, score in sorted_entities:
            name = qcode_to_wiki.get(entity_id, "Unknown")
            # Values coming from data are escaped before being placed into markup.
            html_chunks.append(
                f"<li><b>{html_escape(str(entity_id))}</b> ({html_escape(str(name))}): {score:.3f}</li>"
            )
        html_chunks.append("</ul>")

        tags_to_insert = []

        for span in spans:
            for entity in span.get("entities", []):
                tags_to_insert.append((entity["start"], '<span style="background-color: #aaf;">'))
                tags_to_insert.append((entity["end"], '</span>'))

            for link in span.get("links", []):
                tags_to_insert.append((link["start"], '<span style="background-color: #afa;">'))
                tags_to_insert.append((link["end"], '</span>'))

            tags_to_insert.append((span["start"], '<span style="background-color: #aaf;">'))
            tags_to_insert.append((span["end"], '</span>'))

        # Walk the tags in ascending (position, tag) order and emit the escaped text
        # between them. At equal positions '</span>' sorts before '<span ...', so a span
        # is closed before the next one opens.
        tags_to_insert.sort()
        pieces = []
        cursor = 0
        for pos, tag in tags_to_insert:
            pieces.append(html_escape(text[cursor:pos]))
            pieces.append(tag)
            cursor = pos
        pieces.append(html_escape(text[cursor:]))

        annotated_string = "".join(pieces)
        html_chunks.append(f"<div style='border:1px solid #ccc; padding:10px; margin:10px 0;'>{annotated_string}</div>")

        return "\n".join(html_chunks)

    def update_view(cluster_id, threshold):
        """Merge every cluster whose similarity to ``cluster_id`` reaches ``threshold`` and render the result."""
        base_cluster = enriched_clusters[cluster_id]
        merged_spans = list(base_cluster)
        merged_info = [(cluster_id, 1.0)]

        debug_cluster_similarity = []
        for other_id in enriched_clusters:
            if other_id == cluster_id:
                continue
            score = cluster_similarity_score(cluster_id, other_id, enriched_clusters)
            if score > 0:
                debug_cluster_similarity.append((other_id, score))
            if score >= threshold:
                merged_spans.extend(enriched_clusters[other_id])
                merged_info.append((other_id, score))

        debug_cluster_similarity = sorted(debug_cluster_similarity, key=lambda x: -x[1])
        debug_cluster_similarity_str = "\n".join(
            f"Overlap similarity with Cluster {cid}: ({score:.2f})"
            for cid, score in debug_cluster_similarity
        )
        print(debug_cluster_similarity_str)

        # Drop spans that share the same (start, end) after merging.
        seen = set()
        unique_spans = []
        for span in merged_spans:
            key = (span["start"], span["end"])
            if key not in seen:
                seen.add(key)
                unique_spans.append(span)
        unique_spans.sort(key=lambda s: s["start"])

        temp_cluster = {0: unique_spans}
        scores = score_entities_by_subject_likelihood(temp_cluster)
        return render_cluster_html(text, cluster_id, unique_spans, scores[0], merged_info)

    # Widgets
    dropdown = widgets.Dropdown(
        options=cluster_ids,
        description='Cluster:',
        style={'description_width': 'initial'}
    )

    threshold_slider = widgets.FloatSlider(
        value=0.6,
        min=0.0,
        max=1.0,
        step=0.05,
        description='Merge Threshold:',
        continuous_update=False
    )

    search_box = widgets.Text(
        value='',
        placeholder='Type to search text or QID...',
        description='Search:',
        style={'description_width': 'initial'},
        continuous_update=False
    )

    qid_score_threshold_slider = widgets.FloatSlider(
        value=0.0,
        min=0.0,
        max=1.0,
        step=0.01,
        description='QID Score Threshold:',
        style={'description_width': 'initial'},
        continuous_update=False
    )

    qid_score_filter_direction = widgets.Dropdown(
        options=[('≥ (Above)', 'above'), ('≤ (Below)', 'below')],
        value='above',
        description='Filter Direction:',
        style={'description_width': 'initial'}
    )

    output = widgets.Output()

    def find_relevant_clusters(query):
        """Return the cluster ids matching ``query`` (a QID such as Q42, or free text), best match first."""
        query = query.strip().lower()
        if not query:
            return cluster_ids

        # QID search (e.g., Q42)
        if query.upper().startswith("Q") and query[1:].isdigit():
            qid = query.upper()
            direction = qid_score_filter_direction.value
            threshold_val = qid_score_threshold_slider.value

            scored_clusters = [
                (cid, scores[qid])
                for cid, scores in entity_scores.items()
                if qid in scores
            ]

            if direction == 'above':
                scored_clusters = [item for item in scored_clusters if item[1] >= threshold_val]
            else:
                scored_clusters = [item for item in scored_clusters if item[1] <= threshold_val]

            if scored_clusters:
                return [cid for cid, _ in sorted(scored_clusters, key=lambda x: -x[1])]

        scored_clusters = []
        fallback_clusters = []

        for cid, spans in enriched_clusters.items():
            if not any(query in span["coref_text"].lower() for span in spans):
                continue

            scores = entity_scores.get(cid, {})
            match_score = None
            for qid, score in scores.items():
                # NOTE(review): the type of the stored value is not visible here; str()
                # keeps this comparison from failing if it is not a plain string.
                qid_name = str(qcode_to_wiki.get(qid, "")).lower()
                if qid_name == query:
                    match_score = score
                    break

            if match_score is not None:
                scored_clusters.append((cid, match_score))
            else:
                fallback_clusters.append(cid)

        sorted_scored = [cid for cid, _ in sorted(scored_clusters, key=lambda x: (-x[1], x[0]))]
        fallback_clusters.sort()
        return sorted_scored + fallback_clusters

    def update_output_with_search(*args):
        """Re-run the search, refresh the dropdown options and render the first matching cluster."""
        nonlocal suppress_dropdown_render
        with output:
            if not cluster_ids:
                # Nothing to select: cluster_ids[0] below would raise IndexError.
                clear_output()
                display(HTML("<p>No clusters with linked entities or hyperlinks were found.</p>"))
                return

            query = search_box.value.strip()
            relevant_ids = find_relevant_clusters(query)
            threshold = threshold_slider.value

            suppress_dropdown_render = True
            try:
                if not query:
                    dropdown.options = cluster_ids
                    dropdown.value = cluster_ids[0]
                elif relevant_ids:
                    dropdown.options = relevant_ids
                    dropdown.value = relevant_ids[0]
                else:
                    dropdown.options = cluster_ids
                    clear_output()
                    display(HTML(f"<p>No clusters found for search '<b>{html_escape(query)}</b>'.</p>"))
                    return
            finally:
                suppress_dropdown_render = False

            clear_output()
            html = update_view(dropdown.value, threshold)
            display(HTML(html))

    def update_output_with_dropdown(change):
        """Render the currently selected cluster with the current merge threshold."""
        if suppress_dropdown_render or dropdown.value is None:
            return
        with output:
            clear_output()
            cluster_id = dropdown.value
            threshold = threshold_slider.value
            html = update_view(cluster_id, threshold)
            display(HTML(html))

    # Bind widget listeners
    dropdown.observe(update_output_with_dropdown, names='value')
    # The merge threshold only affects how the selected cluster is rendered, so it
    # re-renders the current selection instead of re-running the search (which would
    # reset the dropdown to its first entry).
    threshold_slider.observe(update_output_with_dropdown, names='value')
    search_box.observe(update_output_with_search, names='value')
    qid_score_threshold_slider.observe(update_output_with_search, names='value')
    qid_score_filter_direction.observe(update_output_with_search, names='value')

    # Initial display
    display(widgets.VBox([
        search_box,
        widgets.HBox([dropdown, threshold_slider]),
        widgets.HBox([qid_score_threshold_slider, qid_score_filter_direction]),
        output
    ]))
    update_output_with_search()

    return enriched_clusters, entity_scores


In [41]:
# Show the interactive cluster browser for the selected document and keep the
# filtered clusters and their per-cluster entity scores for the cells below.
enriched_clusters, entity_scores = analyze_coref_clusters_interactive_with_merge(doc)

VBox(children=(Text(value='', continuous_update=False, description='Search:', placeholder='Type to search text…

In [32]:
# Re-score one cluster on its own.
# NOTE(review): the cluster id 227 is hard-coded and raises KeyError if that cluster
# was filtered out. The stored output below also shows a bare float, although the
# print in score_entities_by_subject_likelihood is commented out — presumably the
# output is from an earlier run; re-run the notebook top to bottom to refresh it.
score_entities_by_subject_likelihood({227: enriched_clusters[227]})

0.8468518518518519


{227: {'Q27980167': 0.595,
  'Q3894681': 0.07345679012345678,
  'Q1196645': 0.0839506172839506,
  'Q3065717': 0.09444444444444444}}

In [55]:
from collections import defaultdict, Counter

def analyze_coref_clusters_interactive_with_merge(doc):
    import ipywidgets as widgets
    from IPython.display import display, HTML, clear_output

    text = doc["text"]
    coref = doc["coref"]
    entity_linking = doc["entity_linking"]
    hyperlinks = doc["hyperlinks_clean"]

    enriched_clusters = enrich_coref_clusters(coref, entity_linking, hyperlinks)

    # Keep only clusters that contain at least one span with an entity or link
    enriched_clusters = {
        cid: spans for cid, spans in enriched_clusters.items()
        if any(span.get("entities") or span.get("links") for span in spans)
    }

    cluster_ids = sorted(enriched_clusters.keys())
    entity_scores = score_entities_by_subject_likelihood(enriched_clusters)

    def cluster_similarity_score(cluster_a_id, cluster_b_id, enriched_clusters):
        spans_a = enriched_clusters[cluster_a_id]
        spans_b = enriched_clusters[cluster_b_id]

        text_to_entities_a = defaultdict(Counter)
        text_to_entities_b = defaultdict(Counter)

        for span in spans_a:
            if not is_pronoun(span["coref_text"]):
                text_key = span["coref_text"].strip().lower()
                qids = [ent["id"] for ent in span["entities"]] + [link["id"] for link in span["links"]]
                if qids:
                    text_to_entities_a[text_key].update(qids)
                else:
                    text_to_entities_a[text_key]["<NONE>"] += 1

        for span in spans_b:
            if not is_pronoun(span["coref_text"]):
                text_key = span["coref_text"].strip().lower()
                qids = [ent["id"] for ent in span["entities"]] + [link["id"] for link in span["links"]]
                if qids:
                    text_to_entities_b[text_key].update(qids)
                else:
                    text_to_entities_b[text_key]["<NONE>"] += 1

        shared_text = set(text_to_entities_b.keys()).intersection(set(text_to_entities_a.keys()))

        total_text_freq = Counter()
        for text, counter in text_to_entities_a.items():
            total_text_freq[text] += sum(counter.values())
        for text, counter in text_to_entities_b.items():
            total_text_freq[text] += sum(counter.values())

        overlap_score = 0
        for text in shared_text:
            entities_a = text_to_entities_a[text]
            entities_b = text_to_entities_b[text]

            qids_a = set(entities_a.keys())
            qids_b = set(entities_b.keys())

            qids_a.discard("<NONE>")
            qids_b.discard("<NONE>")

            freq_a = sum(entities_a.values())
            freq_b = sum(entities_b.values())

            overlap_score += (freq_a + freq_b)

        total = sum(total_text_freq.values())
        return overlap_score / total if total > 0 else 0

    def render_cluster_html(text, cluster_id, spans, scores, merged_info):
        sorted_entities = sorted(scores.items(), key=lambda x: -x[1])

        html_chunks = []

        merged_str = ", ".join(
            f"{cid} ({score:.2f})" if cid != cluster_id else f"{cid} (base)"
            for cid, score in merged_info
        )
        html_chunks.append(f"<h3>Cluster {cluster_id} (merged with: {merged_str})</h3>")
        html_chunks.append("<ul>")
        for entity_id, score in sorted_entities:
            name = qcode_to_wiki.get(entity_id, "Unknown")
            html_chunks.append(f"<li><b>{entity_id}</b> ({name}): {score:.3f}</li>")
        html_chunks.append("</ul>")

        annotated_text = list(text)
        tags_to_insert = []

        for span in spans:
            for entity in span.get("entities", []):
                tags_to_insert.append((entity["start"], '<span style="background-color: #aaf;">'))
                tags_to_insert.append((entity["end"], '</span>'))

            for link in span.get("links", []):
                tags_to_insert.append((link["start"], '<span style="background-color: #afa;">'))
                tags_to_insert.append((link["end"], '</span>'))

            tags_to_insert.append((span["start"], '<span style="background-color: #aaf;">'))
            tags_to_insert.append((span["end"], '</span>'))

        tags_to_insert.sort(reverse=True)
        for pos, tag in tags_to_insert:
            annotated_text.insert(pos, tag)

        annotated_string = "".join(annotated_text)
        html_chunks.append(f"<div style='border:1px solid #ccc; padding:10px; margin:10px 0;'>{annotated_string}</div>")

        return "\n".join(html_chunks)

    def update_view(cluster_id, threshold):
        base_cluster = enriched_clusters[cluster_id]
        merged_spans = list(base_cluster)
        merged_info = [(cluster_id, 1.0)]

        debug_cluster_similarity = []
        for other_id in enriched_clusters:
            if other_id == cluster_id:
                continue
            score = cluster_similarity_score(cluster_id, other_id, enriched_clusters)
            if score > 0:
                debug_cluster_similarity.append((other_id, score))
            if score >= threshold:
                merged_spans.extend(enriched_clusters[other_id])
                merged_info.append((other_id, score))

        debug_cluster_similarity = sorted(debug_cluster_similarity, key=lambda x: -x[1])

        seen = set()
        unique_spans = []
        for span in merged_spans:
            key = (span["start"], span["end"])
            if key not in seen:
                seen.add(key)
                unique_spans.append(span)
        unique_spans.sort(key=lambda s: s["start"])

        temp_cluster = {0: unique_spans}
        scores = score_entities_by_subject_likelihood(temp_cluster)
        return render_cluster_html(text, cluster_id, unique_spans, scores[0], merged_info)

    # Widgets
    dropdown = widgets.Dropdown(
        options=cluster_ids,
        description='Cluster:',
        style={'description_width': 'initial'}
    )

    threshold_slider = widgets.FloatSlider(
        value=0.6,
        min=0.0,
        max=1.0,
        step=0.05,
        description='Merge Threshold:',
        continuous_update=False
    )

    search_box = widgets.Text(
        value='',
        placeholder='Type to search text or QID...',
        description='Search:',
        style={'description_width': 'initial'},
        continuous_update=False
    )

    qid_score_threshold_slider = widgets.FloatSlider(
        value=0.0,
        min=0.0,
        max=1.0,
        step=0.01,
        description='QID Score Threshold:',
        style={'description_width': 'initial'},
        continuous_update=False
    )

    qid_score_filter_direction = widgets.Dropdown(
        options=[('≥ (Above)', 'above'), ('≤ (Below)', 'below')],
        value='above',
        description='Filter Direction:',
        style={'description_width': 'initial'}
    )

    view_all_qid_mentions_toggle = widgets.Checkbox(
        value=False,
        description='Show all QID mentions',
        style={'description_width': 'initial'}
    )

    output = widgets.Output()

    def find_relevant_clusters(query):
        """
        Resolve a search string to an ordered list of cluster ids.

        An empty query yields ``cluster_ids`` unchanged. A query shaped like "Q<digits>"
        is first tried as a QID: clusters whose entity scores contain it are filtered by
        the QID score threshold/direction widgets and returned highest score first. If
        that produces nothing (or the query is not a QID), clusters are matched by
        substring against the lower-cased ``coref_text`` of their spans. Among those,
        clusters holding an entity whose ``qcode_to_wiki`` name equals the query come
        first (by descending score, then id), followed by the rest in ascending id order.

        :param query: raw search text
        :return: list of matching cluster ids (may be empty)
        """
        needle = query.strip().lower()
        if not needle:
            return cluster_ids

        # --- QID lookup -----------------------------------------------------
        if needle.upper().startswith("Q") and needle[1:].isdigit():
            qid = needle.upper()
            keep_above = qid_score_filter_direction.value == 'above'
            cutoff = qid_score_threshold_slider.value

            qid_hits = []
            for cid, cluster_scores in entity_scores.items():
                if qid not in cluster_scores:
                    continue
                qid_score = cluster_scores[qid]
                passes = (qid_score >= cutoff) if keep_above else (qid_score <= cutoff)
                if passes:
                    qid_hits.append((cid, qid_score))

            if qid_hits:
                qid_hits.sort(key=lambda pair: -pair[1])
                return [cid for cid, _ in qid_hits]
            # No cluster passed the filter: fall through to the text search below.

        # --- Text search ----------------------------------------------------
        title_hits = []   # (cid, score) where an entity name equals the query
        text_hits = []    # cids that only match on span text

        for cid, spans in enriched_clusters.items():
            mentioned = False
            for span in spans:
                if needle in span["coref_text"].lower():
                    mentioned = True
                    break
            if not mentioned:
                continue

            # Score of the first entity in this cluster whose name is exactly the query.
            title_score = next(
                (
                    score
                    for qid, score in entity_scores.get(cid, {}).items()
                    if qcode_to_wiki.get(qid, "").lower() == needle
                ),
                None,
            )

            if title_score is None:
                text_hits.append(cid)
            else:
                title_hits.append((cid, title_score))

        title_hits.sort(key=lambda pair: (-pair[1], pair[0]))
        text_hits.sort()
        return [cid for cid, _ in title_hits] + text_hits

    def update_output_with_search(*args):
        """
        Redraw the output area from the current state of the search controls.

        Aggregate mode (the query is "Q<digits>" and the "Show all QID mentions" box is
        ticked): the spans of every cluster whose score for that QID passes the QID score
        threshold/direction filter are pooled, de-duplicated by (start, end), sorted by
        start and rendered as one view. The dropdown is left untouched on success.

        Browse mode (anything else): the dropdown is narrowed to the clusters returned by
        ``find_relevant_clusters`` and the first of them is rendered via ``update_view``.
        When nothing matches, the dropdown is reset to all clusters and a message is shown.

        :param args: ignored; present so the function can be used as a widget observer
        """
        # Function-scope import keeps this closure self-contained; aliased so it cannot be
        # confused with the rendered-markup variables below.
        from html import escape as escape_html

        with output:
            query = search_box.value.strip()
            threshold = threshold_slider.value
            view_all = view_all_qid_mentions_toggle.value

            if query.upper().startswith("Q") and query[1:].isdigit() and view_all:
                qid = query.upper()
                direction = qid_score_filter_direction.value
                threshold_val = qid_score_threshold_slider.value

                scored_clusters = [
                    (cid, scores[qid])
                    for cid, scores in entity_scores.items()
                    if qid in scores
                ]

                if direction == 'above':
                    scored_clusters = [item for item in scored_clusters if item[1] >= threshold_val]
                else:
                    scored_clusters = [item for item in scored_clusters if item[1] <= threshold_val]

                if not scored_clusters:
                    # Reset the dropdown before clearing so anything its observer draws
                    # is wiped by the clear below.
                    dropdown.options = cluster_ids
                    clear_output()
                    # `qid` is "Q" + digits here, so it needs no escaping.
                    display(HTML(f"<p>No clusters found for QID '<b>{qid}</b>'.</p>"))
                    return

                # Aggregate spans
                spans = []
                merged_info = []
                for cid, score in scored_clusters:
                    spans.extend(enriched_clusters[cid])
                    merged_info.append((cid, score))

                # Deduplicate on (start, end), keeping the first span seen for each range
                seen = set()
                unique_spans = []
                for span in spans:
                    key = (span["start"], span["end"])
                    if key not in seen:
                        seen.add(key)
                        unique_spans.append(span)
                unique_spans.sort(key=lambda s: s["start"])

                # Re-score the pooled spans as if they formed a single cluster (id 0).
                temp_cluster = {0: unique_spans}
                scores = score_entities_by_subject_likelihood(temp_cluster)
                rendered = render_cluster_html(text, qid, unique_spans, scores[0], merged_info)

                clear_output()
                display(HTML(rendered))
                return

            relevant_ids = find_relevant_clusters(query)
            if not query:
                dropdown.options = cluster_ids
                if len(cluster_ids) == 0:
                    # Nothing to select or render; avoid indexing into an empty sequence.
                    clear_output()
                    display(HTML("<p>No clusters available.</p>"))
                    return
                dropdown.value = cluster_ids[0]
            elif relevant_ids:
                dropdown.options = relevant_ids
                dropdown.value = relevant_ids[0]
            else:
                dropdown.options = cluster_ids
                clear_output()
                # The query is free text typed by the user: escape it so characters such
                # as '<' or '&' are shown literally instead of being parsed as markup.
                display(HTML(f"<p>No clusters found for search '<b>{escape_html(query)}</b>'.</p>"))
                return

            clear_output()
            rendered = update_view(dropdown.value, threshold)
            display(HTML(rendered))

    def update_output_with_dropdown(change):
        """
        Observer for the cluster dropdown: render the currently selected cluster.

        The view is built first and the output area is cleared only afterwards, then
        the new markup is displayed.

        :param change: traitlets change notification (unused; widget state is read directly)
        """
        with output:
            rendered = update_view(dropdown.value, threshold_slider.value)
            clear_output()
            display(HTML(rendered))

    # Bind widget listeners
    dropdown.observe(update_output_with_dropdown, names='value')
    threshold_slider.observe(update_output_with_search, names='value')
    search_box.observe(update_output_with_search, names='value')
    qid_score_threshold_slider.observe(update_output_with_search, names='value')
    qid_score_filter_direction.observe(update_output_with_search, names='value')
    view_all_qid_mentions_toggle.observe(update_output_with_search, names='value')

    # Initial display
    display(widgets.VBox([
        search_box,
        widgets.HBox([dropdown, threshold_slider]),
        widgets.HBox([qid_score_threshold_slider, qid_score_filter_direction]),
        view_all_qid_mentions_toggle,
        output
    ]))
    update_output_with_search()

    return enriched_clusters, entity_scores


In [56]:
# Display the interactive cluster explorer for `doc` and bind the two values it returns
# (the enriched clusters and the per-cluster entity scores) at notebook level.
enriched_clusters, entity_scores = analyze_coref_clusters_interactive_with_merge(doc)

VBox(children=(Text(value='', continuous_update=False, description='Search:', placeholder='Type to search text…