# RefinedWeb の MDR（MacroData Refinement）処理を実装した FineWeb
* URL filter：

* repetition：

# base

In [None]:
import contextlib
from abc import ABC, abstractmethod
from typing import Tuple

from datatrove.data import Document, DocumentsPipeline
from datatrove.pipeline.base import PipelineStep
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.typeshelper import StatHints


def get_filter_result(res):
    """Normalize a filter's return value into a ``(keep, reason)`` pair.

    A filter may return either a bare keep value or a ``(keep, reason)`` tuple.
    A bare value is paired with a ``None`` reason; a tuple is unpacked as-is and
    must therefore contain exactly two items.
    """
    if not isinstance(res, tuple):
        return res, None
    keep, reason = res
    return keep, reason


class BaseFilter(PipelineStep, ABC):
    """Base module for Filters. Filters remove documents.

    Subclasses implement ``filter``. ``run`` applies it to every document in the
    incoming stream, yields the kept documents, and records kept/dropped stats
    (plus a ``dropped_<reason>`` stat when the filter supplies a reason).

    Args:
        exclusion_writer: optionally pass in a writer that will save the dropped documents
    """

    type = "🔻 - FILTER"

    def __init__(self, exclusion_writer: DiskWriter = None):
        super().__init__()
        self.exclusion_writer = exclusion_writer

    @abstractmethod
    def filter(self, doc: Document) -> bool | Tuple[bool, str]:
        """Filter modules main method.
        Returns true if a sample should be KEPT, false if it should be REMOVED.

        Args:
            doc: sample to filter

        Returns:
            bool - whether the doc should be kept
            or (False, str), to drop with a specific reason
        """
        raise NotImplementedError

    def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
        """Apply ``filter`` to each document and yield the ones that are kept.

        Args:
            data: stream of input documents
            rank: rank of this worker; forwarded to the exclusion writer for dropped documents
            world_size: total number of workers (not used by this method)

        Yields:
            each document whose filter result is truthy
        """
        # Keep the exclusion writer open for the whole stream; nullcontext lets the
        # same `with` statement work when no writer is configured.
        with self.exclusion_writer if self.exclusion_writer else contextlib.nullcontext() as writer:
            for doc in data:
                self.stat_update(StatHints.total)
                # Only the filtering decision runs inside the timer; the `yield` below is
                # outside it (presumably so downstream work is not billed to this step).
                with self.track_time():
                    # `filter` may return a bare bool or a (keep, reason) tuple.
                    filter_result, reason = get_filter_result(self.filter(doc))
                    if filter_result:
                        self.stat_update(StatHints.forwarded)
                        self.update_doc_stats(doc)
                    else:
                        self.stat_update(StatHints.dropped)
                        if reason:
                            self.stat_update(f"dropped_{reason}")
                        if self.exclusion_writer:
                            if reason:
                                # Store why the document was dropped alongside the saved copy.
                                doc.metadata["filter_reason"] = reason
                            writer.write(doc, rank)
                        continue
                yield doc

# quality filter

In [None]:

from datatrove.pipeline.filters.base_filter import BaseFilter

class FWQualityFilter(BaseFilter):
    """FineWeb line-level quality heuristics.

    A document is dropped when any of the following holds:

    - the fraction of lines ending in a stop character (``. ' " ! ?``) is at most
      ``line_punct_thr`` (a ratio of exactly 0 is tolerated when
      ``line_punct_exclude_zero`` is True);
    - the fraction of lines with at most ``short_line_length`` characters
      (empty lines included) is at least ``short_line_thr``;
    - characters in repeated non-empty lines make up at least
      ``char_duplicates_ratio`` of the document's non-newline characters.

    Args:
        exclusion_writer: optional writer that saves the dropped documents
        line_punct_thr: threshold for the fraction of stop-character-terminated lines
        line_punct_exclude_zero: do not drop documents where that fraction is exactly 0
        short_line_thr: threshold for the fraction of short lines
        short_line_length: maximum length (in characters) of a "short" line
        char_duplicates_ratio: threshold for the duplicated-line character fraction
    """

    name = "fineweb quality filter"
    _requires_dependencies = ["nltk"]

    def __init__(
        self,
        exclusion_writer: DiskWriter = None,
        line_punct_thr: float = 0.12,
        line_punct_exclude_zero: bool = False,
        short_line_thr: float = 0.67,
        short_line_length: int = 30,
        char_duplicates_ratio: float = 0.01,
    ):
        super().__init__(exclusion_writer)
        self.line_punct_thr = line_punct_thr
        self.line_punct_exclude_zero = line_punct_exclude_zero
        self.short_line_threshold = short_line_thr
        self.short_line_length = short_line_length
        self.char_duplicates_ratio = char_duplicates_ratio

    def filter(self, doc: Document) -> bool | tuple[bool, str]:
        """Return True to keep ``doc``, or ``(False, reason)`` to drop it."""
        from datatrove.pipeline.filters.gopher_repetition_filter import find_duplicates

        stop_chars = (".", "'", '"', "!", "?")

        # str.split always returns at least one element, so n_lines >= 1.
        lines = doc.text.split("\n")
        n_lines = len(lines)

        ratio = sum(1 for line in lines if line.endswith(stop_chars)) / n_lines
        if ratio <= self.line_punct_thr and not (ratio == 0 and self.line_punct_exclude_zero):
            return False, "line_punct_ratio"

        ratio = sum(1 for line in lines if len(line) <= self.short_line_length) / n_lines
        if ratio >= self.short_line_threshold:
            return False, "short_line_ratio"

        non_empty_lines = [line for line in lines if line.strip() != ""]
        n_chars = len(doc.text.replace("\n", ""))
        # A document consisting only of newlines has no characters that could be
        # duplicated; use a ratio of 0 instead of dividing by zero.
        ratio = find_duplicates(non_empty_lines)[1] / n_chars if n_chars else 0.0
        if ratio >= self.char_duplicates_ratio:
            return False, "char_dup_ratio"

        return True

In [None]:
import numpy as np

from datatrove.data import Document
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.text import PUNCTUATION_SET


STOP_WORDS = ["the", "be", "to", "of", "and", "that", "have", "with"]


class GopherQualityFilter(BaseFilter):
    name = "🥇 Gopher Quality"
    _requires_dependencies = ["nltk"]

    def __init__(
        self,
        min_doc_words: int | None = 50,
        max_doc_words: int | None = 100000,
        min_avg_word_length: int | None = 3,
        max_avg_word_length: int | None = 10,
        max_symbol_word_ratio: float | None = 0.1,
        max_bullet_lines_ratio: float | None = 0.9,
        max_ellipsis_lines_ratio: float | None = 0.3,
        max_non_alpha_words_ratio: float | None = 0.8,
        min_stop_words: int | None = 2,
        stop_words: list[str] | None = None,
        exclusion_writer: DiskWriter = None,
    ):
        """
        Filter to apply Gopher's quality heuristic rules.
        Reference: https://arxiv.org/pdf/2112.11446.pdf

        Setting any threshold to ``None`` (or 0) disables that check. A "non-symbol
        word" is a token with at least one character outside ``PUNCTUATION_SET``.

        Args:
            min_doc_words: drop documents with fewer non-symbol words than this
            max_doc_words: drop documents with more non-symbol words than this
            min_avg_word_length: drop documents whose mean non-symbol word length is below this
            max_avg_word_length: drop documents whose mean non-symbol word length is above this
            max_symbol_word_ratio: drop documents where the number of "#" characters, or of
                ellipses ("..." / "…"), divided by the token count exceeds this
            max_bullet_lines_ratio: drop documents where the fraction of lines starting
                (after leading whitespace) with "•" or "-" exceeds this
            max_ellipsis_lines_ratio: drop documents where the fraction of lines ending
                (before trailing whitespace) with "..." or "…" exceeds this
            max_non_alpha_words_ratio: despite the name, drop documents where the fraction
                of tokens containing at least one alphabetic character is BELOW this
            min_stop_words: drop documents with fewer tokens found in ``stop_words``
                (exact, case-sensitive match)
            stop_words: stop-word list; defaults to ``STOP_WORDS``
            exclusion_writer: optional writer that saves the dropped documents
        """
        super().__init__(exclusion_writer)
        self.min_doc_words = min_doc_words
        self.max_doc_words = max_doc_words
        self.min_avg_word_length = min_avg_word_length
        self.max_avg_word_length = max_avg_word_length
        self.max_symbol_word_ratio = max_symbol_word_ratio
        self.max_bullet_lines_ratio = max_bullet_lines_ratio
        self.max_ellipsis_lines_ratio = max_ellipsis_lines_ratio
        self.max_non_alpha_words_ratio = max_non_alpha_words_ratio
        self.min_stop_words = min_stop_words
        self.stop_words = set(STOP_WORDS if stop_words is None else stop_words)

    def filter(self, doc: Document) -> bool | tuple[bool, str]:
        """Apply the Gopher heuristics to ``doc.text``.

        Args:
            doc: document to check

        Returns:
            True if the document passes every enabled check, otherwise
            ``(False, reason)`` for the first check it fails.
        """
        from nltk.tokenize import word_tokenize

        text = doc.text
        words = word_tokenize(text)  # TODO we should use language id filter
        n_words = len(words)
        # Denominator for per-token ratios. When the word-count checks are disabled a
        # token-less document would otherwise raise ZeroDivisionError; every numerator
        # is 0 in that case, so dividing by 1 yields the natural ratio of 0.
        word_denom = max(n_words, 1)

        non_symbol_words = [w for w in words if any(ch not in PUNCTUATION_SET for ch in w)]
        n_non_symbol_words = len(non_symbol_words)

        # words < min_doc_words or words > max_doc_words
        if self.min_doc_words and n_non_symbol_words < self.min_doc_words:
            return False, "gopher_short_doc"
        if self.max_doc_words and n_non_symbol_words > self.max_doc_words:
            return False, "gopher_long_doc"

        # Mean word length outside [min_avg_word_length, max_avg_word_length]. The mean
        # is undefined without non-symbol words; skipping matches the old NaN result
        # (NaN fails both comparisons) without numpy's empty-mean RuntimeWarning.
        if non_symbol_words:
            avg_word_length = np.mean([len(w) for w in non_symbol_words])
            if self.min_avg_word_length and avg_word_length < self.min_avg_word_length:
                return False, "gopher_below_avg_threshold"
            if self.max_avg_word_length and avg_word_length > self.max_avg_word_length:
                return False, "gopher_above_avg_threshold"

        # symbol-to-word ratio greater than max_symbol_word_ratio for the hash symbol or the ellipsis
        if self.max_symbol_word_ratio and text.count("#") / word_denom > self.max_symbol_word_ratio:
            return False, "gopher_too_many_hashes"
        if (
            self.max_symbol_word_ratio
            and (text.count("...") + text.count("…")) / word_denom > self.max_symbol_word_ratio
        ):
            return False, "gopher_too_many_ellipsis"

        # too many lines starting with a bullet point, or ending with an ellipsis
        lines = text.splitlines()
        line_denom = max(len(lines), 1)  # splitlines() is empty only for empty text
        if (
            self.max_bullet_lines_ratio
            and sum(s.lstrip().startswith(("•", "-")) for s in lines) / line_denom > self.max_bullet_lines_ratio
        ):
            return False, "gopher_too_many_bullets"
        if (
            self.max_ellipsis_lines_ratio
            and sum(s.rstrip().endswith(("...", "…")) for s in lines) / line_denom > self.max_ellipsis_lines_ratio
        ):
            return False, "gopher_too_many_end_ellipsis"

        # enough tokens must contain at least one alphabetic character
        if (
            self.max_non_alpha_words_ratio
            and sum(any(c.isalpha() for c in w) for w in words) / word_denom < self.max_non_alpha_words_ratio
        ):
            return False, "gopher_below_alpha_threshold"

        # stop word filter
        if self.min_stop_words and sum(w in self.stop_words for w in words) < self.min_stop_words:
            return False, "gopher_enough_stop_words"

        return True

In [3]:
import numpy as np
#from datatrove.utils.text import PUNCTUATION_SET

# Characters treated as punctuation/symbols: a fixed set of ASCII, CJK and typographic
# punctuation, followed by the C0 control characters except tab (9) and line feed (10),
# DEL (127) and the C1 control characters (128-159).
PUNCTUATION = "!/—”:％１〈&(、━\\【#%「」，】；+^]~“《„';’{|∶´[=-`*．（–？！：$～«〉,><》)?）。…@_.\"}►»" + "".join(
    chr(code_point) for code_point in (*range(0, 9), *range(11, 32), *range(127, 160))
)
PUNCTUATION_SET = set(PUNCTUATION)



STOP_WORDS = ["the", "be", "to", "of", "and", "that", "have", "with"]


class GopherQualityFilter:
    """Standalone (no datatrove pipeline) version of Gopher's quality heuristics.

    Operates on raw text instead of ``Document`` objects and has no exclusion writer.
    Reference: https://arxiv.org/pdf/2112.11446.pdf
    """

    def __init__(
        self,
        min_doc_words: int | None = 50,
        max_doc_words: int | None = 100000,
        min_avg_word_length: int | None = 3,
        max_avg_word_length: int | None = 10,
        max_symbol_word_ratio: float | None = 0.1,
        max_bullet_lines_ratio: float | None = 0.9,
        max_ellipsis_lines_ratio: float | None = 0.3,
        max_non_alpha_words_ratio: float | None = 0.8,
        min_stop_words: int | None = 2,
        stop_words: list[str] | None = None,
    ):
        """
        Configure the Gopher quality heuristics.

        Setting any threshold to ``None`` (or 0) disables that check. A "non-symbol
        word" is a token with at least one character outside ``PUNCTUATION_SET``.

        Args:
            min_doc_words: drop texts with fewer non-symbol words than this
            max_doc_words: drop texts with more non-symbol words than this
            min_avg_word_length: drop texts whose mean non-symbol word length is below this
            max_avg_word_length: drop texts whose mean non-symbol word length is above this
            max_symbol_word_ratio: drop texts where the number of "#" characters, or of
                ellipses ("..." / "…"), divided by the token count exceeds this
            max_bullet_lines_ratio: drop texts where the fraction of lines starting
                (after leading whitespace) with "•" or "-" exceeds this
            max_ellipsis_lines_ratio: drop texts where the fraction of lines ending
                (before trailing whitespace) with "..." or "…" exceeds this
            max_non_alpha_words_ratio: despite the name, drop texts where the fraction
                of tokens containing at least one alphabetic character is BELOW this
            min_stop_words: drop texts with fewer tokens found in ``stop_words``
                (exact, case-sensitive match)
            stop_words: stop-word list; defaults to ``STOP_WORDS``
        """
        self.min_doc_words = min_doc_words
        self.max_doc_words = max_doc_words
        self.min_avg_word_length = min_avg_word_length
        self.max_avg_word_length = max_avg_word_length
        self.max_symbol_word_ratio = max_symbol_word_ratio
        self.max_bullet_lines_ratio = max_bullet_lines_ratio
        self.max_ellipsis_lines_ratio = max_ellipsis_lines_ratio
        self.max_non_alpha_words_ratio = max_non_alpha_words_ratio
        self.min_stop_words = min_stop_words
        self.stop_words = set(STOP_WORDS if stop_words is None else stop_words)

    def filter(self, text: str) -> bool | tuple[bool, str]:
        """Apply the Gopher heuristics to ``text``.

        Args:
            text: raw document text

        Returns:
            True if the text passes every enabled check, otherwise ``(False, reason)``
            for the first failing check. The rejection is a 2-tuple so that it matches
            the annotation and ``get_filter_result``. Any tuple is truthy, so callers
            must unpack the result rather than test it with a bare ``if``.
        """
        from nltk.tokenize import word_tokenize

        words = word_tokenize(text)  # TODO we should use language id filter
        n_words = len(words)
        # Avoid ZeroDivisionError for token-less text when the word-count checks are
        # disabled; all numerators are 0 then, so the ratios come out as 0.
        word_denom = max(n_words, 1)
        non_symbol_words = [w for w in words if any(ch not in PUNCTUATION_SET for ch in w)]
        n_non_symbol_words = len(non_symbol_words)

        # words < min_doc_words or words > max_doc_words
        if self.min_doc_words and n_non_symbol_words < self.min_doc_words:
            return False, "gopher_short_doc"
        if self.max_doc_words and n_non_symbol_words > self.max_doc_words:
            return False, "gopher_long_doc"

        # Mean word length outside the configured range. Undefined without non-symbol
        # words; skipping matches the NaN result np.mean([]) would give (NaN fails both
        # comparisons) without the RuntimeWarning.
        if non_symbol_words:
            avg_word_length = np.mean([len(w) for w in non_symbol_words])
            if self.min_avg_word_length and avg_word_length < self.min_avg_word_length:
                return False, "gopher_below_avg_threshold"
            if self.max_avg_word_length and avg_word_length > self.max_avg_word_length:
                return False, "gopher_above_avg_threshold"

        # symbol-to-word ratio above the limit for either the hash symbol or the ellipsis
        if self.max_symbol_word_ratio and text.count("#") / word_denom > self.max_symbol_word_ratio:
            return False, "gopher_too_many_hashes"
        if (
            self.max_symbol_word_ratio
            and (text.count("...") + text.count("…")) / word_denom > self.max_symbol_word_ratio
        ):
            return False, "gopher_too_many_ellipsis"

        # too many lines starting with a bullet point, or ending with an ellipsis
        lines = text.splitlines()
        line_denom = max(len(lines), 1)  # splitlines() is empty only for empty text
        if (
            self.max_bullet_lines_ratio
            and sum(s.lstrip().startswith(("•", "-")) for s in lines) / line_denom > self.max_bullet_lines_ratio
        ):
            return False, "gopher_too_many_bullets"
        if (
            self.max_ellipsis_lines_ratio
            and sum(s.rstrip().endswith(("...", "…")) for s in lines) / line_denom > self.max_ellipsis_lines_ratio
        ):
            return False, "gopher_too_many_end_ellipsis"

        # enough tokens must contain at least one alphabetic character
        if (
            self.max_non_alpha_words_ratio
            and sum(any(c.isalpha() for c in w) for w in words) / word_denom < self.max_non_alpha_words_ratio
        ):
            return False, "gopher_below_alpha_threshold"

        # stop word filter
        if self.min_stop_words and sum(w in self.stop_words for w in words) < self.min_stop_words:
            return False, "gopher_enough_stop_words"

        return True

In [None]:
#c4
import heapq
import re

from datatrove.data import Document
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter


# Wikipedia-style citation markers: "[12]", "[]", "[edit]", "[citation needed]".
CITATION_REGEX = re.compile(r"\[\d*]|\[edit]|\[citation needed]")

# Characters a line may end with to count as terminally punctuated.
END_PUNCTUATION = tuple([".", "?", "!", '"', "'"])
# Lines ending in an ellipsis are not considered terminally punctuated.
ELLIPSIS = "..."

# Lower-cased phrases that mark legal / cookie-banner boilerplate lines.
POLICY_SUBSTRINGS = list(
    (
        "terms of use",
        "privacy policy",
        "cookie policy",
        "uses cookies",
        "use of cookies",
        "use cookies",
    )
)


class C4QualityFilter(BaseFilter):
    """Applies heuristic rules from C4 https://jmlr.org/papers/volume21/20-074/20-074.pdf

    - We only retained lines that ended in a terminal punctuation mark (! . " ?)
      (this implementation also accepts a trailing ' and rejects lines ending with "...")
    - We discarded any page with fewer than 5 sentences and only retained lines that contained at least 3 words
    - [NOT IMPLEMENTED] We removed any page that contained any word on the “List of Dirty, Naughty, Obscene or Otherwise Bad Words”
    - We removed any line with the word Javascript.
    - We removed any page where the phrase “lorem ipsum” appeared
    - We removed any pages that contained a curly bracket
    Additional filters not mentioned on the list from the paper but on the code:
    - Remove lines with one word over 1000 chars
    - Remove lines with cookies and terms of use keywords

    The surviving lines are joined back together and written to ``doc.text``, so a
    kept document is modified in place.

    Reference implementation: https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/text/c4_utils.py#L197
    Args:
        exclusion_writer: optionally pass in a writer that will save the dropped documents
        tokenizer_language: load a diff language specific punkt tokenizer from nltk
        split_paragraph: by default (as in the paper) split on "\n".
            Set to "False" to apply the filters to each sentence instead of to each line
        remove_citations: remove wikipedia style citations from the text
        filter_no_terminal_punct: remove lines without terminal punctuation marks
        min_num_sentences: remove documents that do not have at least this number of sentences (after line filtering).
            set to -1 to disable
        min_words_per_line: drop lines without this min number of words (set to -1 to disable)
        max_word_length: drop lines where at least one word has more than this number of characters
            (set to -1 to disable)
        filter_lorem_ipsum: drop documents that contain "lorem ipsum"
        filter_javascript: drop lines mentioning "javascript"
        filter_curly_bracket: drop documents containing {
        filter_policy: drop lines containing any of the phrases in POLICY_SUBSTRINGS
    """

    name = "⛰ C4 Quality"
    _requires_dependencies = ["nltk"]

    def __init__(
        self,
        exclusion_writer: DiskWriter = None,
        tokenizer_language: str = "english",
        split_paragraph: bool = True,  # default as used on c4. Set to "False" to split with sent_tokenize
        remove_citations: bool = True,
        filter_no_terminal_punct: bool = True,
        min_num_sentences: int = 5,  # set to -1 to disable
        min_words_per_line: int = 3,  # set to -1 to disable
        max_word_length: int = 1000,  # set to -1 to disable
        filter_lorem_ipsum: bool = True,
        filter_javascript: bool = True,
        filter_curly_bracket: bool = True,
        filter_policy: bool = True,
    ):
        super().__init__(exclusion_writer)
        self.tokenizer_language = tokenizer_language
        self.split_paragraph = split_paragraph
        self.remove_citations = remove_citations
        self.filter_no_terminal_punct = filter_no_terminal_punct
        self.min_num_sentences = min_num_sentences
        self.min_words_per_line = min_words_per_line
        self.max_word_length = max_word_length
        self.filter_lorem_ipsum = filter_lorem_ipsum
        self.filter_javascript = filter_javascript
        self.filter_curly_bracket = filter_curly_bracket
        self.filter_policy = filter_policy

    def filter(self, doc: Document) -> bool | tuple[bool, str]:
        """Filter ``doc`` unit by unit and rewrite its text with the surviving units.

        Returns:
            True if the document is kept (``doc.text`` is then replaced by the kept
            lines), or ``(False, reason)`` with reason "lorem_ipsum", "curly_bracket"
            or "too_few_sentences" when the whole document is dropped.
        """
        from nltk.tokenize import sent_tokenize

        # Units to filter: newline-separated lines (C4 default) or nltk sentences.
        lines = (
            doc.text.splitlines()
            if self.split_paragraph
            else sent_tokenize(doc.text, language=self.tokenizer_language)
        )

        num_sentences = 0
        kept_lines = []

        for line in lines:
            line = line.strip()
            # NOTE: words are counted here, before citations are stripped below.
            words = line.split()
            self.stat_update("line-total")
            # check line has too long word
            if self.max_word_length != -1 and any(len(word) > self.max_word_length for word in words):
                self.stat_update("line-filter-too_long_word")
                continue
            # remove citation
            if self.remove_citations:
                line = CITATION_REGEX.sub("", line)
            # end punctuation (a trailing "..." counts as no terminal punctuation)
            if self.filter_no_terminal_punct and (not line.endswith(END_PUNCTUATION) or line.endswith(ELLIPSIS)):
                self.stat_update("line-filter-no_terminal_punc")
                continue
            # min words per line
            if len(words) < self.min_words_per_line:
                self.stat_update("line-filter-too_few_words")
                continue
            line_l = line.lower()
            # lorem ipsum -- only checked on lines that survived the line-level checks above
            if self.filter_lorem_ipsum and "lorem ipsum" in line_l:
                return False, "lorem_ipsum"  # drop entire doc
            # javascript
            if self.filter_javascript and "javascript" in line_l:
                self.stat_update("line-filter-javascript")
                continue
            # bracket -- a "{" on a line already dropped above does not drop the document
            if self.filter_curly_bracket and "{" in line:
                return False, "curly_bracket"  # drop entire doc
            # policy
            if self.filter_policy and any(p in line_l for p in POLICY_SUBSTRINGS):
                self.stat_update("line-filter-policy")
                continue
            # Sentence count: nltk sentences within the line when splitting by line;
            # otherwise each kept unit already is one sentence.
            num_sentences += len(sent_tokenize(line, language=self.tokenizer_language)) if self.split_paragraph else 1
            kept_lines.append(line)
            self.stat_update("line-kept")
        if num_sentences < self.min_num_sentences:
            return False, "too_few_sentences"

        # Rebuild the text from the kept units with the matching separator.
        doc.text = ("\n" if self.split_paragraph else " ").join(kept_lines).strip()
        return True


class C4ParagraphFilter(BaseFilter):
    """Applies paragraph filtering from mC4

    A document is kept only if it has at least ``min_paragraphs`` lines and each
    of its ``min_paragraphs`` longest lines has at least ``min_paragraph_len``
    characters.

    https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/text/c4_utils.py#L551

    Args:
        exclusion_writer: optional writer that saves the dropped documents
        min_paragraphs: minimum number of sufficiently long "paragraphs" (lines)
        min_paragraph_len: minimum length in characters of each of those paragraphs
        line_delimiter: separator used to split the text into paragraphs
    """

    name = "⛰ C4 Paragraph"

    def __init__(
        self,
        exclusion_writer: DiskWriter = None,
        min_paragraphs: int = 3,
        min_paragraph_len: int = 200,
        line_delimiter: str = "\n",
    ):
        super().__init__(exclusion_writer)

        self.min_paragraphs = min_paragraphs
        self.min_paragraph_len = min_paragraph_len
        self.line_delimiter = line_delimiter

    def paragraph_filter(self, page):
        """Returns False iff a page has too few or too short paragraphs."""
        lines = page.split(self.line_delimiter)
        # Filter out docs that don't have at least `min_paragraphs` "paragraphs"
        # (lines >= `min_paragraph_len` chars).
        if len(lines) < self.min_paragraphs:
            return False
        if self.min_paragraphs <= 0:
            # Nothing to check; also avoids min() on an empty list below.
            return True
        # Use the configured count (the reference uses min_paragraphs here, not a literal 3):
        # the shortest of the `min_paragraphs` longest lines must be long enough.
        longest = heapq.nlargest(self.min_paragraphs, (len(line) for line in lines))
        return min(longest) >= self.min_paragraph_len

    def filter(self, doc: Document) -> bool | tuple[bool, str]:
        """Keep ``doc`` if ``paragraph_filter`` accepts its text, else drop it with a reason."""
        if not self.paragraph_filter(doc.text):
            return False, f"< {self.min_paragraphs} paragraphs"
        return True

# repetition 

In [None]:
# repetition 
import re
from collections import Counter

from datatrove.data import Document
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter


"""
Table A1 from https://arxiv.org/pdf/2112.11446.pdf
    duplicate line fraction                 0.30
    duplicate paragraph fraction            0.30
    duplicate line character fraction       0.20
    duplicate paragraph character fraction  0.20

    top 2-gram character fraction           0.20
    top 3-gram character fraction           0.18
    top 4-gram character fraction           0.16

    duplicate 5-gram character fraction     0.15
    duplicate 6-gram character fraction     0.14
    duplicate 7-gram character fraction     0.13
    duplicate 8-gram character fraction     0.12
    duplicate 9-gram character fraction     0.11
    duplicate 10-gram character fraction    0.10
"""


def get_n_grams(words: list[str], n: int) -> list[str]:
    """Return every contiguous run of ``n`` words, each joined with single spaces.

    The result is empty when ``n`` exceeds ``len(words)``.
    """
    last_start = len(words) - n
    return list(map(" ".join, (words[start : start + n] for start in range(last_start + 1))))


def find_duplicates(x: list[str]) -> tuple[int, int]:
    """Count repeated elements of ``x`` and the characters they contain.

    The first occurrence of each distinct element is not counted. Every later
    occurrence adds 1 to the element count and ``len(element)`` to the char count.

    Returns:
        ``(duplicate_elements, duplicate_chars)``
    """
    seen = set()
    dup_count = dup_chars = 0
    for item in x:
        if item not in seen:
            seen.add(item)
            continue
        dup_count += 1
        dup_chars += len(item)
    return dup_count, dup_chars

def find_top_duplicate(x: list[str]) -> int:
    """Return the number of characters covered by the most frequent element of ``x``.

    Computed as ``len(element) * count`` for the element that
    ``Counter.most_common(1)`` returns (ties go to the element encountered first).
    Returns 0 for an empty input instead of raising ``IndexError``.
    """
    most_common = Counter(x).most_common(1)
    if not most_common:
        return 0
    element, count = most_common[0]
    return len(element) * count


def find_all_duplicate(words: list[str], n: int) -> int:
    """Count characters inside repeated, non-overlapping ``n``-grams of ``words``.

    Scans left to right. A previously unseen n-gram is recorded and the window
    advances by one word. A repeated n-gram adds its length to the total, and the
    window then jumps past it so that its words are not counted again. N-grams are
    joined without a separator, so different word splits with the same characters
    compare equal.
    """
    seen = set()
    repeated = 0
    pos = 0
    last_start = len(words) - n
    while pos <= last_start:
        gram = "".join(words[pos : pos + n])
        if gram not in seen:
            seen.add(gram)
            pos += 1
            continue
        repeated += len(gram)
        pos += n
    assert repeated <= len("".join(words))
    return repeated


class GopherRepetitionFilter(BaseFilter):
    """Drop documents with too much repeated content, following Gopher's repetition rules.

    Reference: Table A1 of https://arxiv.org/pdf/2112.11446.pdf

    The scalar fraction checks are skipped when their threshold is falsy (``None``
    or 0). The n-gram checks run for every ``(n, frac)`` pair given.
    """

    name = "👯 Gopher Repetition"
    _requires_dependencies = ["nltk"]

    def __init__(
        self,
        dup_line_frac: float | None = 0.3,
        dup_para_frac: float | None = 0.3,
        dup_line_char_frac: float | None = 0.2,
        dup_para_char_frac: float | None = 0.2,
        top_n_grams: tuple[tuple[int, float]] = ((2, 0.2), (3, 0.18), (4, 0.16)),
        dup_n_grams: tuple[tuple[int, float]] = ((5, 0.15), (6, 0.14), (7, 0.13), (8, 0.12), (9, 0.11), (10, 0.10)),
        exclusion_writer: DiskWriter = None,
        language: str = "english",
    ):
        """

        Args:
            dup_line_frac: max fraction of lines that repeat an earlier line
            dup_para_frac: max fraction of paragraphs (split on 2+ newlines) that repeat an earlier one
            dup_line_char_frac: max fraction of the text's characters inside repeated lines
            dup_para_char_frac: max fraction of the text's characters inside repeated paragraphs
            top_n_grams: ``(n, frac)`` pairs; drop if ``len(top n-gram) * its count``
                divided by the text length exceeds ``frac``
            dup_n_grams: ``(n, frac)`` pairs; drop if the characters in repeated,
                non-overlapping n-grams divided by the text length exceed ``frac``
            exclusion_writer: optional writer that saves the dropped documents
            language: language passed to nltk's ``word_tokenize``
        """
        super().__init__(exclusion_writer)

        self.dup_line_frac = dup_line_frac
        self.dup_para_frac = dup_para_frac
        self.dup_line_char_frac = dup_line_char_frac
        self.dup_para_char_frac = dup_para_char_frac
        self.top_n_grams = top_n_grams
        self.dup_n_grams = dup_n_grams
        self.language = language
        self.paragraph_exp = re.compile(r"\n{2,}")

    def filter(self, doc: Document) -> bool | tuple[bool, str]:
        """Return True to keep ``doc``, or ``(False, reason)`` naming the failed check."""
        from nltk.tokenize import word_tokenize

        text = doc.text
        # Every character ratio below divides by len(text). An empty document has
        # nothing repeated, so keep it instead of raising ZeroDivisionError.
        if not text:
            return True
        text_len = len(text)

        # split() always returns at least one element, so len(paragraphs) >= 1.
        paragraphs = self.paragraph_exp.split(text.strip())
        paragraphs_duplicates, char_duplicates = find_duplicates(paragraphs)
        if self.dup_para_frac and paragraphs_duplicates / len(paragraphs) > self.dup_para_frac:
            return False, "dup_para_frac"
        if self.dup_para_char_frac and char_duplicates / text_len > self.dup_para_char_frac:
            return False, "dup_para_char_frac"

        # Non-empty because `text` is non-empty.
        lines = text.splitlines()
        line_duplicates, char_duplicates = find_duplicates(lines)
        if self.dup_line_frac and line_duplicates / len(lines) > self.dup_line_frac:
            return False, "dup_line_frac"
        if self.dup_line_char_frac and char_duplicates / text_len > self.dup_line_char_frac:
            return False, "dup_line_char_frac"

        words = word_tokenize(text, language=self.language)  # TODO we should use language id filter

        for n, n_frac in self.top_n_grams:
            n_grams = get_n_grams(words, n)
            if not n_grams:
                continue
            top_char_length = find_top_duplicate(n_grams)
            if top_char_length / text_len > n_frac:
                return False, f"top_{n}_gram"

        for n, n_frac in self.dup_n_grams:
            n_duplicates_char = find_all_duplicate(words, n)
            if n_duplicates_char / text_len > n_frac:
                return False, f"duplicated_{n}_n_grams"

        return True

# fasttext

In [None]:
from collections import defaultdict
from typing import Tuple

import numpy as np

from datatrove.data import Document
from datatrove.io import cached_asset_path_or_download
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.text import SPLIT_TEXT_DOCUMENTS, split_into_parts


class FastTextClassifierFilter(BaseFilter):
    """
    Only keeps documents that have
    - AT LEAST ONE of the labels in `keep_labels` with a score above the configured threshold, or
    - NONE of the labels in `remove_labels` with a score above the configured threshold.

    You can only supply one of these, to avoid conflicts. Use multiple filters if you need to. If you supply
    neither, the block will simply annotate each document with the labels (set `save_labels_in_metadata=True`)

    Example:
        for `keep_labels=[("math", 0.9)]` will only keep samples with a score on __label__math of at least 0.9
        for `remove_labels=[("math", 0.9)]` will remove samples with a score on __label__math of at least 0.9

    Info to train your own classifier: https://fasttext.cc/docs/en/supervised-tutorial.html

    Args:
        model_url: url to download the model from or local path
        keep_labels: tuple of (label name without "__label__", min score) (or list of such tuples)
        remove_labels: tuple of (label name without "__label__", min score) (or list of such tuples)
        save_labels_in_metadata: whether to save all the label scores in the document metadata
        exclusion_writer: optional writer that saves the dropped documents
        newline_replacement: str to replace \n with before predicting scores
        filter_mode: predict and filter on DOCUMENT, PARAGRAPH or SENTENCE level
    """

    name = "🤖 fastText"
    _requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"]

    def __init__(
        self,
        model_url: str,
        keep_labels: Tuple[str, float] | list[Tuple[str, float]] | None = None,
        remove_labels: Tuple[str, float] | list[Tuple[str, float]] | None = None,
        save_labels_in_metadata: bool = True,
        exclusion_writer: DiskWriter | None = None,
        newline_replacement="",
        filter_mode: str = SPLIT_TEXT_DOCUMENTS,
    ):
        super().__init__(exclusion_writer)
        self.model_url = model_url
        self.keep_labels = keep_labels
        self.remove_labels = remove_labels
        self.filter_mode = filter_mode
        if keep_labels and remove_labels:
            raise ValueError("You can only supply one of `keep_labels` or `remove_labels`.")
        self.newline_replacement = newline_replacement
        # A single (label, score) pair is wrapped into a one-element list.
        if keep_labels and isinstance(keep_labels[0], str):
            self.keep_labels = [keep_labels]
        if remove_labels and isinstance(remove_labels[0], str):
            self.remove_labels = [remove_labels]
        self.save_labels_in_metadata = save_labels_in_metadata
        self._model = None

    @property
    def model(self):
        """Load the fastText model on first access and validate the configured labels.

        Raises:
            ValueError: if a label in ``keep_labels``/``remove_labels`` is not one of
                the model's labels.
        """
        if self._model is None:
            from fasttext.FastText import _FastText

            model_file = cached_asset_path_or_download(
                self.model_url, namespace="filters", subfolder="fasttext", desc="fast-text model"
            )
            model = _FastText(model_file)
            # check label values
            available_labels = [x.removeprefix("__label__") for x in model.labels]
            # Each operand is parenthesized explicitly: in `a or [] + b or []` the `+` binds
            # tighter than `or`, which evaluated `[] + None` (TypeError) when no labels
            # were configured. Unpacking also accepts tuples of (label, score) pairs.
            configured_labels = [*(self.keep_labels or []), *(self.remove_labels or [])]
            for label, _ in configured_labels:
                if label not in available_labels:
                    raise ValueError(
                        f"Label '{label}' passed as keep_labels or remove_labels is not available in this "
                        f"FastText model. Available labels: {available_labels}"
                    )
            # Cache only after validation, so a failed check is raised again on the next access.
            self._model = model
        return self._model

    def filter(self, doc: Document) -> bool:
        """Score each unit of ``doc`` and keep only the units that pass the label rules.

        ``doc.text`` is replaced by the concatenation of the kept units. When
        ``save_labels_in_metadata`` is set, the mean score of every predicted label
        (over all units) is stored in ``doc.metadata`` under the label names
        returned by the model.

        Returns:
            True if any non-whitespace text remains after dropping units.
        """

        def check_label_scores(unit_scores):
            if self.keep_labels:
                return any(
                    unit_scores.get(f"__label__{label}", -9e9) >= min_score for label, min_score in self.keep_labels
                )
            else:
                return not self.remove_labels or not any(
                    unit_scores.get(f"__label__{label}", -9e9) >= min_score for label, min_score in self.remove_labels
                )

        units = split_into_parts(doc.text, mode=self.filter_mode)
        kept_spans = []
        label_scores = defaultdict(list)
        for unit in units:
            # k=-1 requests scores for every label.
            labels, scores = self.model.predict(unit.strip().replace("\n", self.newline_replacement), k=-1)
            if self.save_labels_in_metadata:
                for label, score in zip(labels, scores):
                    label_scores[label].append(score)
            if check_label_scores(dict(zip(labels, scores))):
                kept_spans.append(unit)
                self.stat_update("kept_span")
            else:
                self.stat_update("removed_span")
        doc.text = "".join(kept_spans)
        if self.save_labels_in_metadata:
            doc.metadata.update({label: np.mean(scores).item() for label, scores in label_scores.items()})
        return bool(doc.text.strip())

# lambda filter

In [None]:
from typing import Callable
from datatrove.data import Document
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter

class LambdaFilter(BaseFilter):
    """Filter that delegates the keep/drop decision to a user-supplied callable."""

    name = "👤 Lambda"

    def __init__(self, filter_function: Callable[[Document], bool], exclusion_writer: DiskWriter = None):
        """
        Args:
            filter_function: called once per document. A truthy result keeps the document and a falsy
                result drops it. As with any BaseFilter, it may also return ``(False, reason)`` to drop
                the document with a named reason.
            exclusion_writer: optionally pass in a writer that will save the dropped documents
        """
        super().__init__(exclusion_writer)
        self.filter_function = filter_function

    def filter(self, doc: Document) -> bool:
        """Return whatever ``filter_function(doc)`` returns (truthy means keep).

        Args:
            doc: document to evaluate
        """
        predicate = self.filter_function
        return predicate(doc)

# regex

In [None]:
import re

from datatrove.data import Document
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter


class RegexFilter(BaseFilter):
    """Drop every document whose text contains at least one match of a regular expression."""

    name = "🕵 Regex"

    def __init__(self, regex_exp: str, exclusion_writer: DiskWriter = None):
        """
        Args:
            regex_exp: pattern passed to ``re.compile``; a match anywhere in the text drops the document
            exclusion_writer: optionally pass in a writer that will save the dropped documents
        """
        super().__init__(exclusion_writer)
        self.regex = re.compile(regex_exp)

    def filter(self, doc: Document) -> bool:
        """Return True (keep) when the pattern does not occur in ``doc.text``, False (drop) otherwise.

        Args:
            doc: document to check
        """
        match = self.regex.search(doc.text)
        return match is None

# URL_filter

In [None]:
import os
import re
import tarfile
from typing import Iterable

from huggingface_hub import cached_assets_path
from loguru import logger

from datatrove.data import Document
from datatrove.io import safely_create_file
from datatrove.utils._import_utils import ASSETS_PATH

from ..writers.disk_base import DiskWriter
from .base_filter import BaseFilter


normalizer = re.compile(r"[^a-zA-Z0-9]+")


def normalize(text, replace=""):
    """Replace each run of non-ASCII-alphanumeric characters in ``text`` with ``replace``, then lowercase."""
    return normalizer.sub(replace, text).lower()


def parse_list(line, do_normalize=True):
    """Build a set of entries from an iterable of lines (e.g. an open text file).

    Lines that are blank, or whose first non-whitespace character is ``#``, are skipped. Every other line
    is either normalized with :func:`normalize` or only stripped of surrounding whitespace. Entries that end
    up empty (e.g. a punctuation-only line when normalizing) are discarded. An empty entry would match any
    string in a substring test (``"" in s`` is always True), and would match a split URL that starts or
    ends with a separator.

    Args:
        line: iterable of strings, one entry per item
        do_normalize: normalize entries (True) or only strip them (False)

    Returns:
        set of non-empty entries
    """
    entries = set()
    for raw in line:
        stripped = raw.strip()
        # checking `stripped` (not raw[0]) avoids IndexError on "" and also skips indented comments
        if not stripped or stripped.startswith("#"):
            continue
        entry = normalize(stripped) if do_normalize else stripped
        if entry:
            entries.add(entry)
    return entries


def get_list(abs_path: str, file_name: str, extra: set = None, do_normalize: bool = True):
    """Load the entries of ``abs_path/file_name`` with :func:`parse_list` and merge in ``extra``.

    Args:
        abs_path: directory containing the list file
        file_name: path of the list file relative to ``abs_path``
        extra: optional additional entries (any iterable of strings), parsed the same way as the file lines
        do_normalize: forwarded to :func:`parse_list`

    Returns:
        set with the union of the file entries and the extra entries
    """
    # explicit encoding: the default text encoding is platform/locale-dependent
    with open(os.path.join(abs_path, file_name), encoding="utf-8") as f:
        entries = parse_list(f, do_normalize)
    if extra:
        entries |= parse_list(extra, do_normalize)
    return entries


class URLFilter(BaseFilter):
    """
    Performs filtering based on samples urls.
    Samples are removed if:
    - their domain is present on `block_listed_domains`
    - if their subdomain is present on `block_listed_domains`
    - if the full url is present on `block_listed_url`
    - if any word from `banned_words` is in the url
    - if there are at least `soft_word_threshold` words from `soft_banned_words` in the url
    - if any word from `banned_subwords` is a substring of the url

    Word and substring matching is case-insensitive. The word lists are lowercased by `normalize`, so the
    url is lowercased before it is split into words.
    """

    name = "😈 Url-filter"
    _requires_dependencies = ["tldextract", "fasteners"]

    def __init__(
        self,
        soft_word_threshold: int = 2,
        extra_domains: Iterable = None,
        extra_urls: Iterable = None,
        banned_words: Iterable = None,
        banned_subwords: Iterable = None,
        soft_banned_words: Iterable = None,
        exclusion_writer: DiskWriter = None,
    ):
        """
        Args:
            soft_word_threshold: minimum number of distinct soft-banned words in the url to drop the document
            extra_domains: domains added to the bundled "adult/domains" list
            extra_urls: full urls added to the bundled "adult/urls" list
            banned_words: words added to banned_words.txt
            banned_subwords: substrings added to banned_subwords.txt
            soft_banned_words: words added to soft_banned_words.txt
            exclusion_writer: optionally pass in a writer that will save the dropped documents
        """
        from tldextract import TLDExtract

        super().__init__(exclusion_writer)
        self.soft_word_threshold = soft_word_threshold
        # these hold the user-provided extras until download_data() replaces them with the merged sets
        self.block_listed_domains = extra_domains
        self.block_listed_url = extra_urls
        self.banned_words = banned_words
        self.banned_subwords = banned_subwords
        self.soft_banned_words = soft_banned_words
        self._downloaded = False
        self.tldextractor = TLDExtract()

    def download_data(self):
        """Extract the bundled blacklists (if needed) and load every block/ban list into a set.

        Runs its body once per instance. Afterwards each list attribute is a set containing the bundled
        entries plus the extras given to ``__init__``. Extraction is delegated to ``safely_create_file``,
        presumably so that concurrent workers extract only once (TODO confirm).
        """
        if self._downloaded:
            return
        download_dir = cached_assets_path(library_name="datatrove", namespace="filters", subfolder="url_filter")
        file_to_lock = os.path.join(download_dir, "url_filterblacklists.tar.gz")

        def do_extract():
            logger.info("💥 Extracting url filter blacklists...")
            with tarfile.open(os.path.join(ASSETS_PATH, "url_filterblacklists.tar.gz"), "r:gz") as tar:
                tar.extractall(download_dir)
            logger.info("💥 Extracted url filter blacklists.")

        safely_create_file(file_to_lock, do_extract)

        self.block_listed_domains = get_list(
            download_dir, "adult/domains", self.block_listed_domains, do_normalize=False
        )
        self.block_listed_url = get_list(download_dir, "adult/urls", self.block_listed_url, do_normalize=False)
        self.banned_words = get_list(ASSETS_PATH, "banned_words.txt", self.banned_words)
        self.banned_subwords = get_list(ASSETS_PATH, "banned_subwords.txt", self.banned_subwords)
        self.soft_banned_words = get_list(ASSETS_PATH, "soft_banned_words.txt", self.soft_banned_words)
        self._downloaded = True

    def filter(self, document: Document) -> bool | tuple[bool, str]:
        """Check ``document.metadata["url"]`` against the block lists.

        Args:
            document: document whose metadata must contain a non-empty "url"

        Returns:
            True to keep the document, or ``(False, reason)`` where reason is one of "domain", "subdomain",
            "url", "hard_blacklisted", "soft_blacklisted" or "blacklisted_subword".
        """
        self.download_data()
        url = document.metadata.get("url")

        # NOTE: assert is stripped under `python -O`; kept as-is so the raised exception type is unchanged
        assert url, "Document does not have url in its metadata"
        url_info = self.tldextractor(url)

        if url_info.registered_domain in self.block_listed_domains:
            return False, "domain"

        if url_info.fqdn in self.block_listed_domains:
            return False, "subdomain"

        if url in self.block_listed_url:
            return False, "url"

        # the word lists were lowercased by `normalize`, so split a lowercased url to match them
        url_words = set(normalizer.split(url.lower()))
        if not url_words.isdisjoint(self.banned_words):
            return False, "hard_blacklisted"

        # both sides are sets, so this counts each distinct soft-banned word once
        nb_soft_words = len(url_words.intersection(self.soft_banned_words))
        if nb_soft_words >= self.soft_word_threshold:
            return False, "soft_blacklisted"

        normalized_space = normalize(url)
        if any(word in normalized_space for word in self.banned_subwords):
            return False, "blacklisted_subword"

        return True

# unigram
このフィルタは、英語の単語頻度データに基づいて文書中の単語のユニグラム対数確率の平均を計算し、その平均が閾値を上回る文書だけを残します（詳細はクラスのドキュメント文字列を参照）。

In [None]:
import csv
import os
import urllib.request

import numpy as np
from huggingface_hub import cached_assets_path
from loguru import logger

from datatrove.data import Document
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter


# CSV of English unigram counts read by UnigramLogProbFilter.get_frequencies (columns "word" and "count");
# fetched once into the datatrove asset cache when the cached copy is missing.
UNIGRAM_DOWNLOAD = "https://ai2-s2-research-public.s3-us-west-2.amazonaws.com/lucas/google-1T-unigram/unigram_freq.csv"


class UnigramLogProbFilter(BaseFilter):
    """
    Computes average unigram log probability based on word frequencies from
    https://www.kaggle.com/datasets/rtatman/english-word-frequency

    Idea taken from https://huggingface.co/datasets/allenai/peS2o

    Documents whose average word log-probability is not strictly above `logprobs_threshold` are removed.
    """

    name = "🧑‍🍳 Unigram log-prob filter"
    _requires_dependencies = ["nltk"]

    def __init__(
        self,
        logprobs_threshold: float = -10,
        exclusion_writer: DiskWriter = None,
    ):
        """
        Args:
            logprobs_threshold: the minimum average unigram logprobs needed to keep a document
            exclusion_writer: optionally pass in a writer that will save the dropped documents
        """
        super().__init__(exclusion_writer)
        self.logprobs_threshold = logprobs_threshold
        # word -> relative frequency; loaded (downloading the CSV if needed) eagerly at construction time
        self.unigram_frequencies = self.get_frequencies()

    def get_frequencies(self):
        """Return a dict mapping every word of the unigram CSV to ``count / total_count``.

        The CSV is downloaded into the datatrove asset cache the first time it is needed.
        """
        download_dir = cached_assets_path(
            library_name="datatrove", namespace="filters", subfolder="unigram_logprob_filter"
        )
        unigram_freq_file = os.path.join(download_dir, "unigram_freq.csv")
        if not os.path.isfile(unigram_freq_file):
            self._download_frequencies(download_dir, unigram_freq_file)

        words = []
        counts = []
        with open(unigram_freq_file, encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                words.append(row["word"])
                counts.append(int(row["count"]))
        total_count = sum(counts)
        return {word: count / total_count for word, count in zip(words, counts)}

    @staticmethod
    def _download_frequencies(download_dir: str, target_file: str) -> None:
        """Download the unigram CSV to ``target_file`` atomically.

        The data is written to a temporary file in the same directory and renamed to ``target_file`` only
        after the download completes. An interrupted download therefore never leaves behind a truncated
        file that the ``os.path.isfile`` cache check would later accept as valid.
        """
        import tempfile

        logger.info("⬇️ Downloading unigram-frequencies ...")
        fd, tmp_path = tempfile.mkstemp(dir=download_dir, prefix="unigram_freq.", suffix=".part")
        os.close(fd)
        try:
            urllib.request.urlretrieve(UNIGRAM_DOWNLOAD, tmp_path)
            # same directory => same filesystem, so the rename is atomic
            os.replace(tmp_path, target_file)
        except BaseException:
            # clean up the partial download (also on KeyboardInterrupt), then propagate the error
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def get_logprob(self, doc):
        """Return the mean natural-log frequency of the document's NLTK word tokens.

        Tokens are lowercased before lookup, and tokens missing from the table get frequency 1e-9. A
        document with no tokens returns 0.
        """
        from nltk.tokenize import word_tokenize

        words = word_tokenize(doc.text)
        freqs = [self.unigram_frequencies.get(word.lower(), 1e-9) for word in words]

        if len(freqs) == 0:
            return 0
        return sum([np.log(f) for f in freqs]) / len(freqs)

    def filter(self, doc: Document) -> bool:
        """Keep the document when its average unigram log-probability is strictly above the threshold.

        This assumes the text is in English.

        Args:
            doc: document to check

        Returns:
            True to keep the document, False to drop it
        """
        return self.get_logprob(doc) > self.logprobs_threshold

# language filter

In [None]:
from datatrove.data import Document
from datatrove.io import cached_asset_path_or_download
from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.utils.typeshelper import Languages


# fastText language-identification model, resolved through cached_asset_path_or_download the first time
# LanguageFilter.model is accessed.
LANGUAGE_ID_MODEL_URL = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"


class LanguageFilter(BaseFilter):
    """Keep documents whose fastText-predicted language is in `languages` with a score above `language_threshold`.

    The predicted language id and its score are written to ``doc.metadata["language"]`` and
    ``doc.metadata["language_score"]`` for every document, whether it is kept or not.
    """

    name = "🌍 Language ID"
    _requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"]

    def __init__(
        self,
        languages: tuple = (Languages.english,),
        language_threshold: float = 0.65,
        exclusion_writer: DiskWriter = None,
    ):
        """
        Filters out a document if its predicted language is not among `languages`, or if the language
        score is not above `language_threshold`.

        Args:
            languages: language ids to keep
            language_threshold: minimum score (exclusive) needed to accept a document
            exclusion_writer: optionally pass in a writer that will save the dropped documents
        """
        super().__init__(exclusion_writer)
        self.language_threshold = language_threshold
        self.languages = languages
        self._model = None  # loaded lazily by the `model` property

    @property
    def model(self):
        """The fastText language-identification model, loaded on first access."""
        if self._model is None:
            from fasttext.FastText import _FastText

            model_file = cached_asset_path_or_download(
                LANGUAGE_ID_MODEL_URL,
                namespace="filters",
                subfolder="language_filter",
                desc="fast-text language identifier model",
            )
            self._model = _FastText(model_file)
        return self._model

    def filter(self, doc: Document) -> bool:
        """Predict the document's language and decide whether to keep it.

        Args:
            doc: document to classify; its metadata receives "language" and "language_score"

        Returns:
            True to keep the document, False to drop it
        """
        # fastText's predict rejects text containing newlines; replace them with a space (not "") so that
        # words on adjacent lines are not glued together into a single token
        labels, scores = self.model.predict(doc.text.replace("\n", " "))
        # language label is given in the form __label__<language_id>
        language = labels[0].removeprefix("__label__")
        # use the scalar top score: comparing the whole score array would return a numpy array, not a bool
        language_score = float(scores[0])
        doc.metadata["language"] = language
        doc.metadata["language_score"] = language_score
        return language_score > self.language_threshold and language in self.languages