In [1]:
from typing import Any, Iterable
from collections import Counter
import numpy as np
import pandas as pd
from datasets import load_dataset
import json
from numpy.typing import NDArray
import sys
from pathlib import Path
from tqdm import tqdm
import ZODB
import ZODB.FileStorage
import transaction
from features import FeatureExtractorPipeline, ExtCtx, SentenceToken

sys.path.append(str(Path.cwd().parent))
from book_segmenting import TextSegmenter
from utils import DATA_DIR

feature_extractor = FeatureExtractorPipeline()

SEGMENT_CHARS_MIN = 150
SEGMENT_CHARS_MAX = 500
segmenter = TextSegmenter(chunk_size=(SEGMENT_CHARS_MIN, SEGMENT_CHARS_MAX))


class Dataset:
    MIN_TEXT_LENGTH = 60
    MAX_TEXT_LENGTH = 500

    def __init__(
        self,
        src: Iterable[Any],
        take: int,
        skip: int = 0,
        text_getter=None,
        deduplicate=False,
        segment=False,
        check_length=True,
    ):
        self.src = iter(src)
        self.take = take
        self.skip = skip
        self.contexts: list[ExtCtx] | None = None
        self.features: list[NDArray[np.float32]] | None = None
        self.text_getter = text_getter
        self.deduplicate = deduplicate
        self.segment = segment
        self.check_length = check_length

    def process(
        self, deduplicate: bool | None = None, segment: bool | None = None
    ) -> list[ExtCtx]:
        if deduplicate is None:
            deduplicate = self.deduplicate
        if segment is None:
            segment = self.segment

        self.contexts = []
        self.features = []
        if deduplicate:
            seen = set()
        taken = 0
        to_skip = self.skip

        with tqdm(total=self.take, desc="Processing texts", unit="text") as pbar:
            while taken < self.take:
                try:
                    text = next(self.src)
                except StopIteration:
                    break
                if to_skip > 0:
                    to_skip -= 1
                    continue

                if self.text_getter is not None:
                    text = self.text_getter(text)
                if not text or (
                    self.check_length and len(text.strip()) < Dataset.MIN_TEXT_LENGTH
                ):
                    continue
                if deduplicate:
                    if text in seen:
                        continue
                    seen.add(text)

                text = FeatureExtractorPipeline.preprocess(text)
                if self.check_length and len(text.strip()) < Dataset.MIN_TEXT_LENGTH:
                    continue

                if segment:
                    segments = [
                        seg
                        for seg in segmenter.segment_text(text)
                        if seg
                        and (
                            not self.check_length
                            or (seg_len := len(seg.strip())) >= Dataset.MIN_TEXT_LENGTH
                            and seg_len <= Dataset.MAX_TEXT_LENGTH
                        )
                    ]
                    if len(segments) == 0:
                        continue
                    example = segments[len(segments) // 2]
                    ctx = feature_extractor.get_ctx(example)
                    self.contexts.append(ctx)
                    self.features.append(
                        feature_extractor.extract(example, preprocess=False, ctx=ctx)
                    )
                else:
                    if self.check_length and len(text) > Dataset.MAX_TEXT_LENGTH:
                        continue
                    ctx = feature_extractor.get_ctx(text)
                    self.contexts.append(ctx)
                    self.features.append(
                        feature_extractor.extract(text, preprocess=False, ctx=ctx)
                    )
                taken += 1
                pbar.update(1)

        return self.contexts

    def __iter__(self):
        if self.contexts is None:
            raise ValueError("Dataset not processed yet. Call process() first.")
        return iter(self.contexts)

  from .autonotebook import tqdm as notebook_tqdm
2025-10-12 13:30:43.899483: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760268643.934909 1215969 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760268643.944707 1215969 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-12 13:30:43.979588: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


using device cpu


In [2]:
ds_high_flickr = load_dataset(
    "CaptionEmporium/flickr-megalith-10m-internvl2-multi-caption",
    split="train",
    streaming=True,
)

In [3]:
ds_flickr30k = load_dataset("embedding-data/flickr30k_captions_quintets", split="train")

In [4]:
ds_coco = load_dataset("sentence-transformers/coco-captions", split="train")

In [5]:
ds_sbu = load_dataset("vicenteor/sbu_captions", split="train", trust_remote_code=True)

In [None]:
with open(DATA_DIR / "datasets" / "large" / "movie_summaries.txt") as f:
    ds_movie_summaries = [line.strip() for line in f.readlines()]

In [7]:
ds_book_summaries = load_dataset("textminr/cmu-book-summaries", split="train")

In [None]:
with open(DATA_DIR / "datasets" / "large" / "book_dialogs.txt") as f:
    ds_book_dialogs = [line.strip() for line in f.read().split("\n\n")]

In [9]:
ds_wiki = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")

In [10]:
ds_news = load_dataset("EdinburghNLP/xsum", split="validation")

In [11]:
ds_hotels = load_dataset("argilla/tripadvisor-hotel-reviews", split="train")

In [12]:
ds_yelp = load_dataset("Yelp/yelp_review_full", split="test")

In [13]:
ds_arxiv = load_dataset(
    "armanc/scientific_papers",
    "arxiv",
    split="validation",
    trust_remote_code=True,
    streaming=True,
)

In [14]:
AMAZON_CATEGORIES = [
    "Cell_Phones_and_Accessories",
    "Beauty_and_Personal_Care",
    "Electronics",
    "Grocery_and_Gourmet_Food",
    "CDs_and_Vinyl",
    "Musical_Instruments",
    "Magazine_Subscriptions",
    "Industrial_and_Scientific",
    "Software",
]
ds_amazon_reviews = []
N_TOTAL = 2000
for category in AMAZON_CATEGORIES:
    ds = iter(
        load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_review_{category}",
            split="full",
            trust_remote_code=True,
            streaming=True,
        )
    )
    for i in range(N_TOTAL // len(AMAZON_CATEGORIES)):
        ds_amazon_reviews.append(next(ds))
while len(ds_amazon_reviews) < N_TOTAL:
    ds_amazon_reviews.append(next(ds))

In [None]:
with open(DATA_DIR / "datasets" / "large" / "batch_10k.json") as f:
    books_10k_dataset = json.load(f)

In [15]:
datasets = [
    Dataset(ds_high_flickr, take=1500, text_getter=lambda x: x["caption_internlm2"]),
    Dataset(
        ds_high_flickr, take=1500, text_getter=lambda x: x["caption_internlm2_short"]
    ),
    Dataset(ds_flickr30k, take=1500, text_getter=lambda x: x["set"][0]),
    Dataset(ds_coco, take=500, text_getter=lambda x: x["caption1"]),
    Dataset(ds_sbu, take=500, text_getter=lambda x: x["caption"]),
    Dataset(ds_movie_summaries, take=500, segment=True),
    Dataset(
        ds_book_summaries, take=500, text_getter=lambda x: x["summary"], segment=True
    ),
    Dataset(ds_book_dialogs, take=500),
    Dataset(
        ds_wiki,
        take=1000,
        text_getter=lambda x: x["text"].replace(" @-@ ", "-").replace(" @,@ ", ","),
        segment=True,
    ),
    Dataset(ds_news, take=500, text_getter=lambda x: x["document"], segment=True),
    Dataset(ds_hotels, take=200, text_getter=lambda x: x["text"]),
    Dataset(ds_yelp, take=300, text_getter=lambda x: x["text"]),
    Dataset(ds_arxiv, take=500, text_getter=lambda x: x["abstract"], segment=True),
    Dataset(ds_amazon_reviews, take=500, text_getter=lambda x: x["text"]),
    # Dataset(books_10k_dataset, take=10000, text_getter=lambda x: x["text"], check_length=False),
]

for i, dataset in enumerate(datasets):
    print(f"--- Done: {i}/{len(datasets)} ---")
    dataset.process(deduplicate=True)
print("DONE")

--- Done: 0/14 ---


Processing texts: 100%|██████████| 1500/1500 [04:06<00:00,  6.08text/s]


--- Done: 1/14 ---


Processing texts:   0%|          | 0/1500 [00:00<?, ?text/s]'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 9f8e6406-b986-4152-94d3-d821a5a34bdc)')' thrown while requesting GET https://huggingface.co/datasets/CaptionEmporium/flickr-megalith-10m-internvl2-multi-caption/resolve/75b33ce72533023bf907f8a0bf160099883f1bae/train/train_0000.parquet
Retrying in 1s [Retry 1/5].
Processing texts: 100%|██████████| 1500/1500 [02:21<00:00, 10.56text/s]


--- Done: 2/14 ---


Processing texts: 100%|██████████| 1500/1500 [00:56<00:00, 26.68text/s]


--- Done: 3/14 ---


Processing texts: 100%|██████████| 500/500 [00:17<00:00, 29.10text/s]


--- Done: 4/14 ---


Processing texts: 100%|██████████| 500/500 [00:19<00:00, 25.07text/s]


--- Done: 5/14 ---


Processing texts: 100%|██████████| 500/500 [00:32<00:00, 15.26text/s]


--- Done: 6/14 ---


Processing texts: 100%|██████████| 500/500 [00:34<00:00, 14.51text/s]


--- Done: 7/14 ---


Processing texts: 100%|██████████| 500/500 [00:46<00:00, 10.68text/s]


--- Done: 8/14 ---


Processing texts: 100%|██████████| 1000/1000 [01:22<00:00, 12.10text/s]


--- Done: 9/14 ---


Processing texts: 100%|██████████| 500/500 [00:34<00:00, 14.38text/s]


--- Done: 10/14 ---


Processing texts: 100%|██████████| 200/200 [00:15<00:00, 12.71text/s]


--- Done: 11/14 ---


Processing texts: 100%|██████████| 300/300 [00:27<00:00, 10.91text/s]


--- Done: 12/14 ---


Processing texts: 100%|██████████| 500/500 [00:57<00:00,  8.67text/s]


--- Done: 13/14 ---


Processing texts: 100%|██████████| 500/500 [00:30<00:00, 16.14text/s]

DONE





In [16]:
storage = ZODB.FileStorage.FileStorage("../data/db/mydata.fs")
db = ZODB.DB(storage)
connection = db.open()
root = connection.root

# root.datasets = BTrees.OOBTree.BTree()

In [None]:
# ZODB Load
datasets_features: list[NDArray[np.float32]] = root.datasets["ref_features"]

In [17]:
# ZODB Save
def save_datasets(datasets: list[Dataset], name: str):
    root.datasets[f"{name}_contexts"] = [ds.contexts for ds in datasets]
    root.datasets[f"{name}_features"] = [ds.features for ds in datasets]
    transaction.commit()


save_datasets(datasets, name="ref")

object you're saving is large. (1622460330 bytes.)

Perhaps you're storing media which should be stored in blobs.

Perhaps you're using a non-scalable data structure, such as a
PersistentMapping or PersistentList.

Perhaps you're storing data in objects that aren't persistent at
all. In cases like that, the data is stored in the record of the
containing persistent object.

In any case, storing records this big is probably a bad idea.

large_record_size option of the ZODB.DB constructor (or the
large-record-size option in a configuration file) to specify a larger
size.



In [None]:
def extract_char_ngrams(ctx: ExtCtx, ctr: Counter, document_ctr: Counter) -> Counter:
    """Required ctx: 'text'

    Time complexity: O(mn), where m is number of characters and n is n-gram size
    """
    MAX_N = 5

    document_set = set()
    text = ctx.text.casefold()
    for n in range(2, MAX_N + 1):
        for i in range(len(text) - n + 1):
            ngram = text[i : i + n]
            ctr[ngram] += 1
            document_set.add(ngram)
    document_ctr.update(document_set)


def extract_pos_ngrams(ctx: ExtCtx, ctr: Counter, document_ctr: Counter) -> Counter:
    """Required ctx: 'tokens'

    Time complexity: O(mn), where m is number of tokens and n is n-gram size

    Uses coarse UD tags:
    ADJ: adjective
    ADP: adposition
    ADV: adverb
    AUX: auxiliary
    CCONJ: coordinating conjunction
    DET: determiner
    INTJ: interjection
    NOUN: noun
    NUM: numeral
    PART: particle
    PRON: pronoun
    PROPN: proper noun
    PUNCT: punctuation
    SCONJ: subordinating conjunction
    SYM: symbol
    VERB: verb
    X: other
    """
    MAX_N = 4

    document_set = set()
    tokens = ctx.tokens
    for n in range(2, MAX_N + 1):
        for i in range(len(tokens) - n + 1):
            ngram = tuple(token.pos for token in tokens[i : i + n])
            ctr[ngram] += 1
            document_set.add(ngram)
    document_ctr.update(document_set)


def extract_dependency_tree_structure(
    ctx: ExtCtx, depth_ctr: Counter, branching_factor_ctr: Counter, width_ctr: Counter
):
    """Required ctx: 'sents'

    Complexity: O(n), where n is number of tokens in all sentences
    """

    def get_tree_depth(token: SentenceToken):
        """Compute the longest path from root to any leaf node."""
        if not list(token.children):  # leaf node
            return 0
        return 1 + max(get_tree_depth(child) for child in token.children)

    def get_branching_factors(token: SentenceToken, factors=None):
        """Get branching factor for each non-leaf node."""
        if factors is None:
            factors = []

        children = list(token.children)
        if children:
            factors.append(len(children))
            for child in children:
                get_branching_factors(child, factors)

        return factors

    def count_leaf_nodes(token: SentenceToken):
        """Count the number of leaf nodes in the tree."""
        children = list(token.children)
        if not children:
            return 1
        return sum(count_leaf_nodes(child) for child in children)

    for sent in ctx.sents:
        root = sent.root

        depth = get_tree_depth(root)
        depth_ctr[depth] += 1

        factors = get_branching_factors(root)
        branching_factor_ctr.update(factors)

        width = count_leaf_nodes(root)
        width_ctr[width] += 1


def extract_dependency_tree_relations(
    ctx: ExtCtx,
    node_ngrams: Counter,
    relation_ngrams: Counter,
    complete_ngrams: Counter,
    node_doc_freq: Counter,
    relation_doc_freq: Counter,
    complete_doc_freq: Counter,
):
    """Required ctx: 'sents'

    Returns three counters:
    1. Node n-grams (2-4-grams) - ascending path of node labels (POS tags)
    2. Relation n-grams (1-4-grams) - ascending path of edge labels (dependency relations)
    3. Complete n-grams (2-4-grams) - path with both node and edge labels
    """
    node_ngrams_set = set()
    relation_ngrams_set = set()
    complete_ngrams_set = set()

    def get_ascending_paths(
        token: SentenceToken,
        current_path_nodes=None,
        current_path_rels=None,
        visited=None,
    ):
        """Get all ascending paths starting from this token"""
        if current_path_nodes is None:
            current_path_nodes = []
        if current_path_rels is None:
            current_path_rels = []
        if visited is None:
            visited = set()

        if token.text in visited:  # Avoid cycles
            return

        visited.add(token.text)
        current_path_nodes.append(token.pos)

        # Process current path for node n-grams (2-4)
        path_len = len(current_path_nodes)
        for n in range(2, min(5, path_len + 1)):
            if path_len >= n:
                ngram = tuple(current_path_nodes[-n:])
                node_ngrams[ngram] += 1
                node_ngrams_set.add(ngram)

        # Process relation n-grams (1-4)
        if current_path_rels:
            for n in range(1, min(5, len(current_path_rels) + 1)):
                if len(current_path_rels) >= n:
                    ngram = tuple(current_path_rels[-n:])
                    relation_ngrams[ngram] += 1
                    relation_ngrams_set.add(ngram)

        # Process complete n-grams (2-4) - alternating node-rel-node
        if len(current_path_nodes) >= 2 and len(current_path_rels) >= 1:
            for n in range(2, min(5, len(current_path_nodes) + 1)):
                if len(current_path_nodes) >= n and len(current_path_rels) >= n - 1:
                    complete_path = []
                    for i in range(n):
                        complete_path.append(current_path_nodes[-(n - i)])
                        if i < n - 1 and len(current_path_rels) > (n - 2 - i):
                            complete_path.append(current_path_rels[-(n - 1 - i)])
                    complete_ngrams[tuple(complete_path)] += 1
                    complete_ngrams_set.add(tuple(complete_path))

        for child in token.children:
            new_path_rels = current_path_rels + [child.dep]
            get_ascending_paths(
                child, current_path_nodes[:], new_path_rels[:], visited.copy()
            )

        visited.remove(token.text)

    for sent in ctx.sents:
        get_ascending_paths(sent.root)

    node_doc_freq.update(node_ngrams_set)
    relation_doc_freq.update(relation_ngrams_set)
    complete_doc_freq.update(complete_ngrams_set)

    return node_ngrams, relation_ngrams, complete_ngrams


def extract_noun_phrase_lengths(ctx: ExtCtx, np_length_ctr: Counter):
    """Required ctx: 'noun_chunks'

    Time complexity: O(m), where m is number of noun phrases
    """
    for chunk in ctx.noun_chunks:
        np_length_ctr[chunk.length] += 1

In [None]:
char_counter = Counter()
char_doc_freq = Counter()
pos_counter = Counter()
pos_doc_freq = Counter()
dep_tree_depth_counter = Counter()
dep_tree_branching_factor_counter = Counter()
dep_tree_width_counter = Counter()
node_ngrams = Counter()  # 2-4-grams of POS tags
node_doc_freq = Counter()
relation_ngrams = Counter()  # 1-4-grams of dependency relations
relation_doc_freq = Counter()
complete_ngrams = Counter()  # 2-4-grams of alternating POS-rel-POS patterns
complete_doc_freq = Counter()
np_length_ctr = Counter()
concr_matches = 0
concr_effective_word_count = 0

for dataset in datasets:
    for ctx in dataset:
        extract_char_ngrams(ctx, char_counter, char_doc_freq)
        extract_pos_ngrams(ctx, pos_counter, pos_doc_freq)
        extract_dependency_tree_structure(
            ctx,
            dep_tree_depth_counter,
            dep_tree_branching_factor_counter,
            dep_tree_width_counter,
        )
        extract_dependency_tree_relations(
            ctx,
            node_ngrams,
            relation_ngrams,
            complete_ngrams,
            node_doc_freq,
            relation_doc_freq,
            complete_doc_freq,
        )
        extract_noun_phrase_lengths(ctx, np_length_ctr)
        _, _, match_count, effective_word_count = (
            feature_extractor.extract_word_concreteness(ctx)
        )
        concr_matches += match_count
        concr_effective_word_count += effective_word_count

In [None]:
total_example_count = sum(ds.take for ds in datasets)

# Get top 1000 ngrams closest to 50% document frequency in each dataset
df1 = pd.DataFrame(char_doc_freq.most_common(), columns=["ngram", "doc_freq"])
df1["doc_freq_ratio"] = df1["doc_freq"] / total_example_count
df1["doc_freq_diff"] = np.abs(df1["doc_freq_ratio"] - 0.5)
df1 = df1.sort_values("doc_freq_diff").reset_index(drop=True)
df1 = df1.head(1000)
# df1.to_csv("char_ngrams_features.csv", index=False)

# Print length of pos ngrams with >= 2% document frequency
df_pos1 = pd.DataFrame(pos_doc_freq.most_common(), columns=["ngram", "doc_freq"])
df_pos1["doc_freq_ratio"] = df_pos1["doc_freq"] / total_example_count
df_pos1 = df_pos1[df_pos1["doc_freq_ratio"] >= 0.02]
print(df_pos1["ngram"].str.len().value_counts().sort_index())
# Get all pos ngrams with >= 2% document frequency
# df_pos1.to_csv("pos_ngrams_features.csv", index=False)

# Print frequencies of depths, branching factors, and widths sorted in descending order by size
print("Depth frequencies (descending):")
for depth, count in sorted(
    dep_tree_depth_counter.items(), key=lambda x: x[0], reverse=True
):
    print(
        f"Depth {depth}: {count} ({count / sum(dep_tree_depth_counter.values()):.2%})"
    )
"""
Depth 25: 1 (0.00%)
Depth 23: 1 (0.00%)
Depth 22: 1 (0.00%)
Depth 21: 1 (0.00%)
Depth 19: 1 (0.00%)
Depth 18: 7 (0.01%)
Depth 17: 12 (0.02%)
Depth 16: 22 (0.04%)
Depth 15: 35 (0.06%)
Depth 14: 61 (0.11%)
Depth 13: 141 (0.26%)
Depth 12: 198 (0.37%)
Depth 11: 388 (0.72%)
Depth 10: 777 (1.44%)
Depth 9: 1435 (2.66%)
Depth 8: 2496 (4.63%)
Depth 7: 4220 (7.82%)
Depth 6: 6437 (11.93%)
Depth 5: 8967 (16.62%)
Depth 4: 9694 (17.97%)
Depth 3: 8546 (15.84%)
Depth 2: 6464 (11.98%)
Depth 1: 3907 (7.24%)
Depth 0: 128 (0.24%)

-> 18 depth levels (0-17+)
"""
print("Branching factor frequencies (descending):")
for factor, count in sorted(
    dep_tree_branching_factor_counter.items(), key=lambda x: x[0], reverse=True
):
    print(
        f"Branching factor {factor}: {count} ({count / sum(dep_tree_branching_factor_counter.values()):.2%})"
    )
"""
Branching factor 38: 1 (0.00%)
Branching factor 26: 1 (0.00%)
Branching factor 24: 2 (0.00%)
Branching factor 20: 1 (0.00%)
Branching factor 19: 2 (0.00%)
Branching factor 18: 3 (0.00%)
Branching factor 17: 5 (0.00%)
Branching factor 16: 12 (0.00%)
Branching factor 15: 22 (0.01%)
Branching factor 14: 54 (0.01%)
Branching factor 13: 115 (0.03%)
Branching factor 12: 219 (0.05%)
Branching factor 11: 545 (0.13%)
Branching factor 10: 1134 (0.27%)
Branching factor 9: 2218 (0.53%)
Branching factor 8: 4308 (1.04%)
Branching factor 7: 7768 (1.87%)
Branching factor 6: 13773 (3.31%)
Branching factor 5: 21412 (5.15%)
Branching factor 4: 30381 (7.30%)
Branching factor 3: 50672 (12.18%)
Branching factor 2: 82596 (19.85%)
Branching factor 1: 200918 (48.28%)

-> 18 branching factor levels (1-18+)
"""
print("Width frequencies (descending):")
for width, count in sorted(
    dep_tree_width_counter.items(), key=lambda x: x[0], reverse=True
):
    print(
        f"Width {width}: {count} ({count / sum(dep_tree_width_counter.values()):.2%})"
    )
"""
Width 68: 2 (0.00%)
Width 66: 2 (0.00%)
Width 65: 1 (0.00%)
Width 64: 2 (0.00%)
Width 63: 3 (0.01%)
Width 61: 2 (0.00%)
Width 60: 1 (0.00%)
Width 59: 3 (0.01%)
Width 58: 3 (0.01%)
Width 57: 2 (0.00%)
Width 56: 7 (0.01%)
Width 55: 8 (0.01%)
Width 54: 6 (0.01%)
Width 53: 7 (0.01%)
Width 52: 6 (0.01%)
Width 51: 5 (0.01%)
Width 50: 14 (0.03%)
Width 49: 19 (0.04%)
Width 48: 8 (0.01%)
Width 47: 17 (0.03%)
Width 46: 18 (0.03%)
Width 45: 13 (0.02%)
Width 44: 35 (0.06%)
Width 43: 29 (0.05%)
Width 42: 37 (0.07%)
Width 41: 44 (0.08%)
Width 40: 55 (0.10%)
Width 39: 65 (0.12%)
Width 38: 67 (0.12%)
Width 37: 63 (0.12%)
Width 36: 79 (0.15%)
Width 35: 76 (0.14%)
Width 34: 110 (0.20%)
Width 33: 123 (0.23%)
Width 32: 133 (0.25%)
Width 31: 155 (0.29%)
Width 30: 157 (0.29%)
Width 29: 221 (0.41%)
Width 28: 233 (0.43%)
Width 27: 288 (0.53%)
Width 26: 306 (0.57%)
Width 25: 372 (0.69%)
Width 24: 405 (0.75%)
Width 23: 549 (1.02%)
Width 22: 562 (1.04%)
Width 21: 658 (1.22%)
Width 20: 855 (1.59%)
Width 19: 956 (1.77%)
Width 18: 1201 (2.23%)
Width 17: 1408 (2.61%)
Width 16: 1575 (2.92%)
Width 15: 1901 (3.52%)
Width 14: 2233 (4.14%)
Width 13: 2439 (4.52%)
Width 12: 2748 (5.09%)
Width 11: 3123 (5.79%)
Width 10: 3440 (6.38%)
Width 9: 3806 (7.06%)
Width 8: 4227 (7.84%)
Width 7: 4260 (7.90%)
Width 6: 4158 (7.71%)
Width 5: 3661 (6.79%)
Width 4: 3107 (5.76%)
Width 3: 2088 (3.87%)
Width 2: 1181 (2.19%)
Width 1: 602 (1.12%)

-> 56 width levels (1-56+)
"""

# Get length of node, relation, complete ngrams with >= 2% document frequency
df_node1 = pd.DataFrame(node_doc_freq.most_common(), columns=["ngram", "doc_freq"])
df_node1["doc_freq_ratio"] = df_node1["doc_freq"] / total_example_count
df_node1 = df_node1[df_node1["doc_freq_ratio"] >= 0.02]
print(df_node1["ngram"].str.len().value_counts().sort_index())
# df_node1.to_csv("dep_tree_node_ngrams_features.csv", index=False)

df_relation1 = pd.DataFrame(
    relation_doc_freq.most_common(), columns=["ngram", "doc_freq"]
)
df_relation1["doc_freq_ratio"] = df_relation1["doc_freq"] / total_example_count
df_relation1 = df_relation1[df_relation1["doc_freq_ratio"] >= 0.02]
print(df_relation1["ngram"].str.len().value_counts().sort_index())
# df_relation1.to_csv("dep_tree_relation_ngrams_features.csv", index=False)

df_complete1 = pd.DataFrame(
    complete_doc_freq.most_common(), columns=["ngram", "doc_freq"]
)
df_complete1["doc_freq_ratio"] = df_complete1["doc_freq"] / total_example_count
df_complete1 = df_complete1[df_complete1["doc_freq_ratio"] >= 0.02]
print(df_complete1["ngram"].str.len().value_counts().sort_index())
# df_complete1.to_csv("dep_tree_complete_ngrams_features.csv", index=False)

df_noun_phrase_lengths = pd.DataFrame(
    np_length_ctr.most_common(), columns=["length", "count"]
)
print(df_noun_phrase_lengths.sort_values("length"))

print(
    f"Concreteness matches: {concr_matches}, effective word count: {concr_effective_word_count}, ratio: {concr_matches / concr_effective_word_count:.2%}"
)

In [None]:
# Load saved csv's as a list of ngrams and print the counts
char_features = pd.read_csv(DATA_DIR / "features" / "char_ngrams_features.csv")[
    "ngram"
].tolist()
print(f"Loaded {len(char_features)} char ngrams")
pos_features = pd.read_csv(DATA_DIR / "features" / "pos_ngrams_features.csv")[
    "ngram"
].tolist()
print(f"Loaded {len(pos_features)} pos ngrams")
node_features = pd.read_csv(
    DATA_DIR / "features" / "dep_tree_node_ngrams_features.csv"
)["ngram"].tolist()
print(f"Loaded {len(node_features)} node ngrams")
relation_features = pd.read_csv(
    DATA_DIR / "features" / "dep_tree_relation_ngrams_features.csv"
)["ngram"].tolist()
print(f"Loaded {len(relation_features)} relation ngrams")
complete_features = pd.read_csv(
    DATA_DIR / "features" / "dep_tree_complete_ngrams_features.csv"
)["ngram"].tolist()
print(f"Loaded {len(complete_features)} complete ngrams")

# Big dataset heuristic

Get training sets

In [None]:
import gradio as gr
import random
import json

dataset_to_scores = [
    [5],
    [4],
    [4, 3, 2],
    [3, 2],
    [4, 3, 2, 1],
    [3, 2, 1, 0],
    [3, 2, 1, 0],
    [2, 1, 0],
    [2, 1, 0],
    [2, 1, 0],
    [2, 1, 0],
    [2, 1, 0],
    [1, 0],
    [1, 0],
    [],
]


def create_labeling_interface(
    output_filename: str,
    samples_per_dataset: int = 50,
    seed: int = 42,
    exclude_samples: dict[int, set[int]] = None,
):
    """Create a Gradio interface for labeling dataset examples.

    Args:
        output_filename: Name of the file to save progress and results (without extension)
        samples_per_dataset: Number of samples to label per dataset
        seed: Random seed for reproducibility
    """
    random.seed(seed)

    # Filter datasets that have more than one possible score
    datasets_to_label = [
        (i, ds, scores)
        for i, (ds, scores) in enumerate(zip(datasets, dataset_to_scores))
        if len(scores) > 1
    ]

    # Get random samples
    samples_to_label = []
    for dataset_idx, dataset, scores in datasets_to_label:
        if dataset.processed and len(dataset.processed) >= samples_per_dataset:
            population = set(range(len(dataset.processed))) - (
                exclude_samples.get(dataset_idx, set()) if exclude_samples else set()
            )
            sampled_indices = random.sample(
                sorted(population), min(samples_per_dataset, len(population))
            )
            for idx in sampled_indices:
                samples_to_label.append(
                    {
                        "dataset_idx": dataset_idx,
                        "example_idx": idx,
                        "text": dataset.processed[idx].text,
                        "possible_scores": scores,
                        "score": None,
                    }
                )

    # Load progress if exists
    progress_file = DATA_DIR / "datasets" / "large" / f"{output_filename}_progress.json"
    if progress_file.exists():
        with open(progress_file, "r") as f:
            progress_data = json.load(f)
            samples_to_label = progress_data["samples"]
            # Update possible scores in case they changed
            for sample in samples_to_label:
                dataset_idx = sample["dataset_idx"]
                sample["possible_scores"] = dataset_to_scores[dataset_idx]
            current_idx = progress_data.get("current_idx", 0)
    else:
        current_idx = 0

    # Count labeled examples
    labeled_count = sum(1 for s in samples_to_label if s["score"] is not None)

    def save_progress():
        with open(progress_file, "w") as f:
            json.dump(
                {"samples": samples_to_label, "current_idx": current_idx}, f, indent=2
            )

    def get_next_unlabeled():
        nonlocal current_idx
        for i in range(current_idx, len(samples_to_label)):
            if samples_to_label[i]["score"] is None:
                current_idx = i
                return i
        return None

    def label_example(score):
        nonlocal current_idx, labeled_count
        if current_idx < len(samples_to_label):
            samples_to_label[current_idx]["score"] = score
            labeled_count += 1
            save_progress()

            # Move to next unlabeled
            next_idx = get_next_unlabeled()
            if next_idx is not None:
                current_idx = next_idx
                sample = samples_to_label[current_idx]
                progress_text = f"Example {current_idx + 1} / {len(samples_to_label)} (Labeled: {labeled_count})"
                return (
                    sample["text"],
                    gr.update(choices=sample["possible_scores"]),
                    progress_text,
                )
            else:
                # All done - save to CSV
                df_results = pd.DataFrame(
                    [
                        {
                            "dataset_idx": s["dataset_idx"],
                            "example_idx": s["example_idx"],
                            "text": s["text"],
                            "possible_scores": s["possible_scores"],
                            "score": s["score"],
                        }
                        for s in samples_to_label
                        if s["score"] is not None
                    ]
                )
                df_results.to_csv(
                    DATA_DIR / "datasets" / "large" / f"{output_filename}.csv",
                    index=False,
                )
                return (
                    "All examples labeled! Results saved.",
                    gr.update(choices=[]),
                    f"Complete: {labeled_count} / {len(samples_to_label)}",
                )

        return "No more examples", gr.update(choices=[]), "Complete"

    # Initialize interface
    with gr.Blocks() as demo:
        gr.Markdown(f"# Dataset Labeling Interface - {output_filename}")

        progress = gr.Textbox(
            label="Progress",
            value=f"Example {current_idx + 1} / {len(samples_to_label)} (Labeled: {labeled_count})",
            interactive=False,
        )

        initial_sample = (
            samples_to_label[current_idx]
            if current_idx < len(samples_to_label)
            else None
        )

        text_display = gr.Textbox(
            label="Text to Label",
            value=initial_sample["text"] if initial_sample else "",
            lines=10,
            interactive=False,
        )

        score_radio = gr.Radio(
            choices=initial_sample["possible_scores"] if initial_sample else [],
            label="Select Score",
        )

        submit_btn = gr.Button("Submit and Next")

        submit_btn.click(
            fn=label_example,
            inputs=[score_radio],
            outputs=[text_display, score_radio, progress],
        )

    return demo

In [None]:
demo = create_labeling_interface("heuristic_train_set")
demo.launch()

Get validation set

In [None]:
train_set = pd.read_csv(DATA_DIR / "datasets" / "large" / "heuristic_train_set.csv")
exclude_samples = {}
for _, row in train_set.iterrows():
    ds_idx = int(row["dataset_idx"])
    ex_idx = int(row["example_idx"])
    if ds_idx not in exclude_samples:
        exclude_samples[ds_idx] = set()
    exclude_samples[ds_idx].add(ex_idx)

demo = create_labeling_interface(
    "heuristic_validation_set", samples_per_dataset=15, exclude_samples=exclude_samples
)
demo.launch()

### Training

In [None]:
# Tried:
# from mord import LogisticAT, LogisticIT, OrdinalRidge
# from sklearn.svm import LinearSVR
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.naive_bayes import GaussianNB
import pickle

train_set = pd.read_csv(
    DATA_DIR / "datasets" / "large" / "heuristic_train_set_combined.csv"
)
cv_scores = []

# Train ordinal regression model for each dataset
for dataset_idx, group in train_set.groupby("dataset_idx"):
    X = np.array([feature_extractor.extract(text) for text in group["text"]])
    y = group["score"]

    model = GaussianNB()
    model.fit(X, y, sample_weight=compute_sample_weight("balanced", y))
    with open(
        DATA_DIR / "models" / f"ordinal_model_dataset_{dataset_idx}.pkl", "wb"
    ) as f:
        pickle.dump(model, f)

    # scores = cross_val_score(GaussianNB(), X, y, cv=min(5, len(X)), scoring='accuracy', params={'sample_weight': compute_sample_weight('balanced', y)})
    # mean_score = scores.mean()
    # std_score = scores.std()
    # cv_scores.append(mean_score)
    # print(f"Dataset {dataset_idx}: CV Accuracy = {mean_score:.2%} (+/- {std_score:.2%})")

print(f"\nOverall CV Accuracy: {np.mean(cv_scores):.2%}")


Overall CV Accuracy: nan%


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [None]:
large_train_set = root.datasets["large.v1"]

In [None]:
from sklearn.naive_bayes import GaussianNB
import warnings
import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_sample_weight

# Suppress warnings from sklearn about features with zero variance
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

train_set = pd.read_csv(DATA_DIR / "datasets" / "large" / "heuristic_train_set.csv")
validation_set = pd.read_csv(
    DATA_DIR / "datasets" / "large" / "heuristic_validation_set.csv"
)
X_val = np.array([feature_extractor.extract(text) for text in validation_set["text"]])
y_val = validation_set["score"].values
cv_scores = []

# Unlabeled set (U) - all data not in the training or validation sets
examples_indices = set(train_set["example_idx"].tolist())
X_unlabeled = np.array(
    [
        ex
        for ei, ex in enumerate(datasets[dataset_idx].features)
        if ei not in examples_indices
    ]
)
X_unlabeled = np.concatenate([X_unlabeled, large_train_set[dataset_idx]])
print(f"Unlabeled samples: {len(X_unlabeled)}")

for dataset_idx, group in train_set.groupby("dataset_idx"):
    # Labeled set (L)
    X_labeled = np.array([feature_extractor.extract(text) for text in group["text"]])
    y_labeled = group["score"].values
    print(f"Labeled samples: {len(X_labeled)}")

    # --- 2. EM Algorithm Implementation ---
    MAX_ITER = 10
    model = None
    all_classes = np.unique(y_labeled)
    prev_predictions = None
    base_weights = compute_sample_weight("balanced", y_labeled)

    for i in range(MAX_ITER):
        print(f"--- EM Iteration {i + 1}/{MAX_ITER} ---")

        # --- E-Step: Estimate labels for unlabeled data ---
        if i == 0:
            # Initial model trained only on labeled data
            model = GaussianNB()
            model.fit(X_labeled, y_labeled, sample_weight=base_weights)

        # Predict probabilities for the unlabeled set
        y_unlabeled_probs = model.predict_proba(X_unlabeled)
        y_unlabeled_pred = model.classes_[np.argmax(y_unlabeled_probs, axis=1)]

        # Monitor confidence/certainty of predictions
        avg_confidence = np.mean(np.max(y_unlabeled_probs, axis=1))
        print(f"Average confidence on unlabeled data: {avg_confidence:.4f}")

        # --- M-Step: Retrain model with labeled and "soft"-labeled data ---
        # Combine original labeled data with newly predicted labels for unlabeled data
        X_combined = np.vstack((X_labeled, X_unlabeled))
        y_combined = np.hstack((y_labeled, y_unlabeled_pred))

        # Create sample weights. Labeled data has weight 1.
        # Unlabeled data has weight equal to the probability of its predicted class.
        unlabeled_weights = np.max(y_unlabeled_probs, axis=1)
        sample_weights_combined = np.hstack((base_weights, unlabeled_weights))

        # Retrain the model on the combined dataset with sample weights
        model = GaussianNB()
        model.fit(X_combined, y_combined, sample_weight=sample_weights_combined)

        # Evaluate on validation set
        y_val_pred = model.predict(X_val)
        val_accuracy = np.mean(y_val_pred == y_val)
        print(f"Validation Accuracy after iteration {i + 1}: {val_accuracy:.2%}")

        # Convergence based on prediction stability
        if i > 0 and np.array_equal(y_unlabeled_pred, prev_predictions):
            print("Convergence: predictions stabilized.")
            break
        prev_predictions = y_unlabeled_pred.copy()

    cv_scores.append(val_accuracy)
    print(
        f"\nEM training complete for dataset {dataset_idx}. Final CV Accuracy: {val_accuracy:.2%}\n"
    )

print(f"\nOverall CV Accuracy across datasets: {np.mean(cv_scores):.2%}")

Unlabeled samples: 16109
Labeled samples: 50
--- EM Iteration 1/10 ---
Average confidence on unlabeled data: 1.0000
Validation Accuracy after iteration 1: 8.89%
--- EM Iteration 2/10 ---
Average confidence on unlabeled data: 0.9998
Validation Accuracy after iteration 2: 8.89%
--- EM Iteration 3/10 ---
Average confidence on unlabeled data: 0.9999
Validation Accuracy after iteration 3: 8.89%
--- EM Iteration 4/10 ---
Average confidence on unlabeled data: 1.0000
Validation Accuracy after iteration 4: 8.89%
--- EM Iteration 5/10 ---
Average confidence on unlabeled data: 1.0000
Validation Accuracy after iteration 5: 8.89%
--- EM Iteration 6/10 ---
Average confidence on unlabeled data: 1.0000
Validation Accuracy after iteration 6: 9.44%
--- EM Iteration 7/10 ---
Average confidence on unlabeled data: 1.0000
Validation Accuracy after iteration 7: 9.44%
--- EM Iteration 8/10 ---
Average confidence on unlabeled data: 1.0000
Validation Accuracy after iteration 8: 9.44%
--- EM Iteration 9/10 ---
A