In [1]:
import numpy as np

## 1. Page break indices

In [24]:
#numpy.ndarray.cumsum â€” NumPy v2.2 Manual
#

page_texts = ["Hello World.", "This is so me.", "Thank you.", "For this.", "Honestly."]
full_text = "".join(page_texts)

# line that converts page text
page_text_lens = [len(page_t) for page_t in page_texts]
page_text_lens = np.concatenate([np.array([0]), np.cumsum(page_text_lens[:-1])]).tolist()

page_text_lens

[0, 12, 26, 36, 45]

In [29]:
for i in range(len(page_text_lens)-1):
    start_idx, end_idx = page_text_lens[i], page_text_lens[i+1]
    print(full_text[start_idx : end_idx])
    print('- - - - - - -')

Hello World.
- - - - - - -
This is so me.
- - - - - - -
Thank you.
- - - - - - -
For this.
- - - - - - -


## 2. Flatten list of logits, predict, re-strucutre to list of list
Nested qualities

In [9]:
from itertools import chain, islice
from typing import List, Sequence, Any

def predict_grouped(classifier, document_page_texts: Sequence[Sequence[str]]) -> List[List[Any]]:
    # record original lengths (empties ok)
    lengths = [len(pages) for pages in document_page_texts]

    # flatten once
    flat_inputs = list(chain.from_iterable(document_page_texts))  # preserves order

    # single inference call (your model can batch/chunk internally)
    flat_preds = classifier.predict(flat_inputs)  # must return len(flat_inputs) items

    if len(flat_preds) != len(flat_inputs):
        raise ValueError(f"Prediction length mismatch: got {len(flat_preds)} for {len(flat_inputs)} inputs")

    # regroup by original lengths (iter + islice avoids copying large slices repeatedly)
    it = iter(flat_preds)
    regrouped = [list(islice(it, n)) for n in lengths]

    return regrouped


In [None]:
# record original lengths (empties ok)
lengths = [len(page_texts) for page_texts in document_page_texts]
# flatten input to SciBERT/Specter
flat_document_texts= list(chain.from_iterable(document_page_texts))
flat_qualities = self.classifier.predict(flat_document_texts)
# regroup into qualities:list[list[int]]
it = iter(flat_qualities)
qualities = [list(islice(it, n)) for n in lengths]


In [7]:
def classifier_predict(s:list[str]) -> int:
    """Faux-prediction mode"""
    int_list = []
    for s_0 in s:
        try:
            int_list.append(int(s_0[0]))
        except:
            int_list.append(-1)
    return int_list

#document_page_texts
document_page_texts = [["15abc", "2ruhr"], ["3gfgf", "4gfs", "5dgfd"], ["6xdg", "7vd", "8kkk", "9jt"], ["0sfg"]]

# record original lengths (empties ok)
lengths = [len(page_list) for page_list in document_page_texts]
# flatten once
flat_inputs = list(chain.from_iterable(document_page_texts))
flat_preds = classifier_predict(flat_inputs)
# regroup
it = iter(flat_preds)
regrouped = [list(islice(it, n)) for n in lengths]

In [8]:
regrouped

[[1, 2], [3, 4, 5], [6, 7, 8, 9], [0]]