In [None]:
from typing import Any, Iterable
import numpy as np
from numpy.typing import NDArray
import pandas as pd
from datasets import load_dataset
import sys
from pathlib import Path
from tqdm import tqdm
from features import FeatureExtractorPipeline, ExtCtx

sys.path.append(str(Path.cwd().parent))
from book_segmenting import TextSegmenter
from utils import DATA_DIR

# Shared pipeline instance; Dataset.process() calls its get_ctx() and extract()
# for every accepted text.
feature_extractor = FeatureExtractorPipeline()

# Lower/upper bounds handed to TextSegmenter as its chunk_size tuple
# (presumably measured in characters, per the constant names — confirm against
# TextSegmenter).
SEGMENT_CHARS_MIN = 150
SEGMENT_CHARS_MAX = 500
# Shared segmenter; Dataset.process() calls segment_text() on it when segment=True.
segmenter = TextSegmenter(chunk_size=(SEGMENT_CHARS_MIN, SEGMENT_CHARS_MAX))


class Dataset:
    """Collect up to ``take`` filtered texts from a source, with contexts and features.

    Records are pulled from ``src`` one at a time. The first ``skip`` records are
    discarded, the remaining ones are converted to text (via ``text_getter`` when
    given), filtered, optionally de-duplicated and optionally segmented. Every
    accepted text is stored together with the context and feature vector that
    the module-level ``feature_extractor`` produces for it.

    After ``process()`` the three lists ``texts``, ``contexts`` and ``features``
    are index-aligned. Before ``process()`` they are ``None``.
    """

    # Bounds applied to ``len(text.strip())`` when ``check_length`` is enabled.
    MIN_TEXT_LENGTH = 60
    MAX_TEXT_LENGTH = 500

    def __init__(
        self,
        name: str,
        src: Iterable[Any],
        take: int,
        skip: int = 0,
        text_getter=None,
        deduplicate=False,
        segment=False,
        check_length=True,
    ):
        """Store the configuration; no records are consumed here.

        Args:
            name: Dataset name; also the stem of the parquet file written by
                ``save_as_parquet``.
            src: Iterable of raw records.
            take: Maximum number of examples to collect.
            skip: Number of leading records to discard before collecting.
            text_getter: Optional callable mapping a raw record to its text.
                When ``None`` the record itself is used as the text.
            deduplicate: Default for ``process(deduplicate=...)``.
            segment: Default for ``process(segment=...)``.
            check_length: Whether the MIN/MAX_TEXT_LENGTH filters are applied.
        """
        self.name = name
        # Keep the iterable itself so process() can start over from the first
        # record on every call instead of resuming a half-consumed iterator.
        self._source = src
        self.src = iter(src)
        self.take = take
        self.skip = skip
        self.texts: list[str] | None = None
        self.contexts: list[ExtCtx] | None = None
        self.features: list[NDArray[np.float32]] | None = None
        self.text_getter = text_getter
        self.deduplicate = deduplicate
        self.segment = segment
        self.check_length = check_length

    def process(
        self, deduplicate: bool | None = None, segment: bool | None = None
    ) -> list[ExtCtx]:
        """Read the source and (re)build ``texts``, ``contexts`` and ``features``.

        Args:
            deduplicate: Skip texts already seen in this run. ``None`` falls
                back to the value given to the constructor.
            segment: Split each text with the module-level ``segmenter`` and
                keep only the middle surviving segment. ``None`` falls back to
                the value given to the constructor.

        Returns:
            The list of contexts, one per collected example. It holds fewer
            than ``take`` entries when the source runs out first (the progress
            bar then stops short of 100%).
        """
        if deduplicate is None:
            deduplicate = self.deduplicate
        if segment is None:
            segment = self.segment

        # Restart from the beginning of the source so that calling process()
        # again (e.g. re-running a notebook cell) yields the same examples.
        # For a one-shot iterator iter() returns the same object, so that case
        # behaves exactly as before.
        self.src = iter(self._source)
        self.texts = []
        self.contexts = []
        self.features = []
        seen: set[str] = set()  # raw texts already accepted for de-duplication
        taken = 0
        to_skip = self.skip

        with tqdm(total=self.take, desc="Processing texts", unit="text") as pbar:
            while taken < self.take:
                try:
                    record = next(self.src)
                except StopIteration:
                    break
                if to_skip > 0:
                    to_skip -= 1
                    continue

                example = self._select_example(record, deduplicate, segment, seen)
                if example is None:
                    continue
                self._append_example(example)
                taken += 1
                pbar.update(1)

        return self.contexts

    def _within_bounds(self, text: str) -> bool:
        """Return True when the stripped length is within [MIN, MAX] or checks are off."""
        if not self.check_length:
            return True
        stripped_len = len(text.strip())
        return Dataset.MIN_TEXT_LENGTH <= stripped_len <= Dataset.MAX_TEXT_LENGTH

    def _select_example(
        self, record: Any, deduplicate: bool, segment: bool, seen: set
    ) -> str | None:
        """Turn one raw record into the text to store, or None if it is rejected."""
        text = record
        if self.text_getter is not None:
            text = self.text_getter(text)
        if not text or (
            self.check_length and len(text.strip()) < Dataset.MIN_TEXT_LENGTH
        ):
            return None
        if deduplicate:
            # De-duplication compares the text as it was before preprocessing.
            if text in seen:
                return None
            seen.add(text)

        text = FeatureExtractorPipeline.preprocess(text)
        # Preprocessing may change the length, so the lower bound is re-checked.
        if self.check_length and len(text.strip()) < Dataset.MIN_TEXT_LENGTH:
            return None

        if not segment:
            # NOTE(review): this upper bound uses len(text) while every other
            # length check uses len(text.strip()); kept as-is to preserve which
            # texts are accepted.
            if self.check_length and len(text) > Dataset.MAX_TEXT_LENGTH:
                return None
            return text

        segments = [
            seg
            for seg in segmenter.segment_text(text)
            if seg and self._within_bounds(seg)
        ]
        if not segments:
            return None
        # One example per source text: the middle of the surviving segments.
        return segments[len(segments) // 2]

    def _append_example(self, example: str) -> None:
        """Compute context and features for ``example`` and store all three."""
        ctx = feature_extractor.get_ctx(example)
        features = feature_extractor.extract(example, preprocess=False, ctx=ctx)
        # Append only after both calls succeeded so the lists stay aligned
        # even if feature extraction raises.
        self.texts.append(example)
        self.contexts.append(ctx)
        self.features.append(features)

    def __iter__(self):
        """Iterate over the collected contexts.

        Raises:
            ValueError: If ``process()`` has not been called yet.
        """
        if self.contexts is None:
            raise ValueError("Dataset not processed yet. Call process() first.")
        return iter(self.contexts)

    def __len__(self):
        """Return the number of collected contexts.

        Raises:
            ValueError: If ``process()`` has not been called yet.
        """
        if self.contexts is None:
            raise ValueError("Dataset not processed yet. Call process() first.")
        return len(self.contexts)

    def save_as_parquet(self, labels: list[Any] | NDArray[Any] | None = None):
        """Write texts and features (and optional labels) to ``<name>.parquet``.

        The file is written to ``DATA_DIR / "datasets" / "large"``; the
        directory is created when missing.

        Args:
            labels: Optional sequence (list or numpy array) with one label per
                collected text; stored in a ``label`` column.

        Raises:
            ValueError: If ``process()`` has not been called yet, or if
                ``labels`` does not have one entry per text.
        """
        if self.texts is None or self.features is None:
            raise ValueError("Dataset not processed yet. Call process() first.")
        df = pd.DataFrame(
            {"text": self.texts, "features": [feat.tolist() for feat in self.features]}
        )
        if labels is not None:
            if len(labels) != len(self.texts):
                raise ValueError("Labels length does not match texts length.")
            df["label"] = labels
        out_dir = DATA_DIR / "datasets" / "large"
        out_dir.mkdir(parents=True, exist_ok=True)
        df.to_parquet(out_dir / f"{self.name}.parquet", index=False)

  from .autonotebook import tqdm as notebook_tqdm
2025-10-12 20:38:43.112136: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760294323.192802    4402 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760294323.216713    4402 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-12 20:38:43.391637: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


using device cpu


In [None]:
# --- Hugging Face sources and local text files, loaded in a fixed order ---
ds_high_flickr = load_dataset(
    "CaptionEmporium/flickr-megalith-10m-internvl2-multi-caption",
    streaming=True,
    split="train",
)
ds_flickr30k = load_dataset(
    "embedding-data/flickr30k_captions_quintets",
    split="train",
)
ds_coco = load_dataset(
    "sentence-transformers/coco-captions",
    split="train",
)
ds_sbu = load_dataset(
    "vicenteor/sbu_captions",
    trust_remote_code=True,
    split="train",
)

# One movie summary per line of the file.
with open(DATA_DIR / "datasets" / "large" / "movie_summaries.txt") as f:
    ds_movie_summaries = [line.strip() for line in f]

ds_book_summaries = load_dataset("textminr/cmu-book-summaries", split="train")

# Dialogs are separated by blank lines.
with open(DATA_DIR / "datasets" / "large" / "book_dialogs.txt") as f:
    raw_dialogs = f.read()
ds_book_dialogs = [dialog.strip() for dialog in raw_dialogs.split("\n\n")]

ds_wiki = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
ds_news = load_dataset("EdinburghNLP/xsum", split="validation")
ds_hotels = load_dataset("argilla/tripadvisor-hotel-reviews", split="train")
ds_yelp = load_dataset("Yelp/yelp_review_full", split="test")
ds_arxiv = load_dataset(
    "armanc/scientific_papers",
    "arxiv",
    streaming=True,
    trust_remote_code=True,
    split="validation",
)

# --- Amazon reviews: an equal share from each category, streamed ---
AMAZON_CATEGORIES = [
    "Cell_Phones_and_Accessories",
    "Beauty_and_Personal_Care",
    "Electronics",
    "Grocery_and_Gourmet_Food",
    "CDs_and_Vinyl",
    "Musical_Instruments",
    "Magazine_Subscriptions",
    "Industrial_and_Scientific",
    "Software",
]
N_TOTAL = 15000
per_category = N_TOTAL // len(AMAZON_CATEGORIES)

ds_amazon_reviews = []
for category in AMAZON_CATEGORIES:
    ds = iter(
        load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_review_{category}",
            streaming=True,
            trust_remote_code=True,
            split="full",
        )
    )
    for _ in range(per_category):
        ds_amazon_reviews.append(next(ds))

# The integer division can leave a shortfall; top it up from the stream of the
# last category, which `ds` still refers to after the loop.
for _ in range(N_TOTAL - len(ds_amazon_reviews)):
    ds_amazon_reviews.append(next(ds))

In [None]:
def _column(key):
    """Build a text_getter that reads ``record[key]`` from a raw record."""
    return lambda record: record[key]


# The position of each entry matters: a later cell maps model files to
# datasets by their index in this list.
datasets = [
    Dataset(
        name="artif_5",
        src=ds_high_flickr,
        take=15000,
        skip=2 * 1500,
        text_getter=_column("caption_internlm2"),
    ),
    Dataset(
        name="artif_4",
        src=ds_high_flickr,
        take=15000,
        skip=2 * 1500,
        text_getter=_column("caption_internlm2_short"),
    ),
    Dataset(
        name="flickr30k",
        src=ds_flickr30k,
        take=15000,
        skip=2 * 1500,
        # First caption of each caption set.
        text_getter=lambda record: record["set"][0],
    ),
    Dataset(
        name="coco",
        src=ds_coco,
        take=5000,
        skip=2 * 500,
        text_getter=_column("caption1"),
    ),
    Dataset(
        name="sbu",
        src=ds_sbu,
        take=5000,
        skip=2 * 500,
        text_getter=_column("caption"),
    ),
    Dataset(
        name="movie_summaries",
        src=ds_movie_summaries,
        take=5000,
        skip=2 * 500,
        segment=True,
    ),
    Dataset(
        name="book_summaries",
        src=ds_book_summaries,
        take=5000,
        skip=2 * 500,
        text_getter=_column("summary"),
        segment=True,
    ),
    Dataset(
        name="book_dialogs",
        src=ds_book_dialogs,
        take=5000,
        skip=2 * 500,
    ),
    Dataset(
        name="wiki",
        src=ds_wiki,
        take=10000,
        skip=2 * 1000,
        # Undo the " @-@ " / " @,@ " escapes found in the raw text.
        text_getter=lambda record: (
            record["text"].replace(" @-@ ", "-").replace(" @,@ ", ",")
        ),
        segment=True,
    ),
    Dataset(
        name="news",
        src=ds_news,
        take=5000,
        skip=2 * 500,
        text_getter=_column("document"),
        segment=True,
    ),
    Dataset(
        name="hotels",
        src=ds_hotels,
        take=2000,
        skip=2 * 200,
        text_getter=_column("text"),
    ),
    Dataset(
        name="yelp",
        src=ds_yelp,
        take=3000,
        skip=2 * 300,
        text_getter=_column("text"),
    ),
    Dataset(
        name="arxiv",
        src=ds_arxiv,
        take=5000,
        skip=2 * 500,
        text_getter=_column("abstract"),
        segment=True,
    ),
    Dataset(
        name="amazon_reviews",
        src=ds_amazon_reviews,
        take=5000,
        skip=2 * 500,
        text_getter=_column("text"),
    ),
]

# The counter is printed before each dataset starts, i.e. it reports how many
# datasets have already been completed.
for i, dataset in enumerate(datasets):
    print(f"--- Done: {i}/{len(datasets)} ---")
    dataset.process(deduplicate=True)
print("DONE")

--- Done: 0/14 ---


  return length_counts / total_ngrams
Processing texts: 100%|██████████| 15000/15000 [25:33<00:00,  9.78text/s]


--- Done: 1/14 ---


Processing texts: 100%|██████████| 15000/15000 [14:14<00:00, 17.56text/s]  


--- Done: 2/14 ---


Processing texts:  90%|█████████ | 13504/15000 [07:24<00:49, 30.41text/s] 


--- Done: 3/14 ---


Processing texts: 100%|██████████| 5000/5000 [03:46<00:00, 22.06text/s]


--- Done: 4/14 ---


Processing texts: 100%|██████████| 5000/5000 [03:36<00:00, 23.05text/s]


--- Done: 5/14 ---


Processing texts: 100%|██████████| 5000/5000 [06:00<00:00, 13.86text/s] 


--- Done: 6/14 ---


Processing texts: 100%|██████████| 5000/5000 [06:08<00:00, 13.57text/s]


--- Done: 7/14 ---


Processing texts: 100%|██████████| 5000/5000 [08:31<00:00,  9.78text/s]  


--- Done: 8/14 ---


Processing texts: 100%|██████████| 10000/10000 [15:25<00:00, 10.81text/s] 


--- Done: 9/14 ---


Processing texts: 100%|██████████| 5000/5000 [05:51<00:00, 14.23text/s]


--- Done: 10/14 ---


Processing texts: 100%|██████████| 2000/2000 [02:41<00:00, 12.39text/s]


--- Done: 11/14 ---


Processing texts: 100%|██████████| 3000/3000 [04:26<00:00, 11.25text/s]


--- Done: 12/14 ---


Processing texts: 100%|██████████| 5000/5000 [07:49<00:00, 10.65text/s]  


--- Done: 13/14 ---


Processing texts: 100%|██████████| 5000/5000 [06:10<00:00, 13.50text/s]

DONE





In [None]:
# flickr30k (index 2) may come up short after filtering on the first caption of
# each set; top it up with examples built from the second caption instead.
flickr_main = datasets[2]
to_fill = 15000 - len(flickr_main.features)
if to_fill > 0:
    ds_flickr30k_2 = Dataset(
        name="flickr30k",
        src=ds_flickr30k,
        take=to_fill,
        skip=2 * 1500,
        text_getter=lambda record: record["set"][1],
    )
    ds_flickr30k_2.process(deduplicate=True)
    # Append the extra examples to all three index-aligned lists.
    for attr in ("texts", "contexts", "features"):
        getattr(flickr_main, attr).extend(getattr(ds_flickr30k_2, attr))

# Total number of examples across all datasets.
print(sum(len(entry.features) for entry in datasets))

100000


In [6]:
# The first two datasets get one constant label each (index 0 -> 5, index 1 -> 4)
# instead of model predictions.
for fixed_idx, fixed_label in ((0, 5), (1, 4)):
    n_examples = len(datasets[fixed_idx].features)
    datasets[fixed_idx].save_as_parquet(
        labels=np.full(n_examples, fixed_label, dtype=int)
    )

In [None]:
import pickle

# Label each dataset with its own model. The trailing number in the file name
# ("ordinal_model_dataset_<idx>.pkl") is the dataset's index in `datasets`.
model_dir = DATA_DIR / "models"
for path in model_dir.glob("ordinal_model_dataset_*.pkl"):
    # NOTE: unpickling executes arbitrary code; only load model files you trust.
    with path.open("rb") as f:
        model = pickle.load(f)

    dataset_idx = int(path.stem.rsplit("_", 1)[-1])
    target = datasets[dataset_idx]

    # Predict one label per example from the stacked feature vectors.
    X = np.array(target.features)
    preds = model.predict(X)

    print("saving", dataset_idx, target.name, len(preds))
    target.save_as_parquet(labels=preds)

saving 2 flickr30k 15000
saving 13 amazon_reviews 5000
saving 10 hotels 2000
saving 8 wiki 10000
saving 11 yelp 3000
saving 4 sbu 5000
saving 12 arxiv 5000


In [None]:
from utils import DATA_DIR

# Combine the individual parquet files into a single one.
large_dir = DATA_DIR / "datasets" / "large"
combined_path = large_dir / "combined.parquet"

# A previous run leaves combined.parquet in the same directory. It must not be
# read back as an input, otherwise every re-run would append the old combined
# data again and tag it with dataset == "combined". Sorting makes the row
# order independent of the file-system listing order.
parquet_files = sorted(
    p for p in large_dir.glob("*.parquet") if p.name != combined_path.name
)
if not parquet_files:
    raise ValueError(f"No per-dataset parquet files found in {large_dir}")

dfs = []
for p in parquet_files:
    name = p.stem
    df = pd.read_parquet(p)
    # Record which dataset each row came from (the file stem).
    df["dataset"] = name
    dfs.append(df)

df_combined = pd.concat(dfs, ignore_index=True)
df_combined.to_parquet(combined_path, index=False)

print("Label distribution:")
print(df_combined["label"].value_counts().sort_index())

Label distribution:
label
0    26442
1    15519
2     4509
3    20808
4    17722
5    15000
Name: count, dtype: int64


In [4]:
# DANGER
# Irreversibly removes every per-dataset parquet file listed in `parquet_files`;
# only combined.parquet is kept.
stale_files = [p for p in parquet_files if p.stem != "combined"]
for p in stale_files:
    p.unlink()