# core

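> Decontaminate training datasets: remove records that near-duplicate samples of code benchmarks (e.g. HumanEval, MBPP). Candidates are found with a MinHash LSH index and then confirmed with an exact token Jaccard similarity check.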

In [None]:
#| default_exp core

In [None]:
#| export
import logging
import multiprocessing
import os
import re
import time

import numpy as np

from datasets import Dataset, load_dataset, Features, Sequence, Value
from datasketch import LeanMinHash, MinHash, MinHashLSH
from rich.logging import RichHandler
from tqdm.auto import tqdm

In [None]:
#| export
# Force the "fork" start method. The benchmark LSH index is looked up through
# module globals inside `Dataset.map` workers, which presumably relies on the
# workers inheriting this module's state via fork (TODO confirm). Note that
# "fork" is unavailable on Windows, where this call raises.
multiprocessing.set_start_method("fork", force=True)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Re-executing this cell (or reloading the module) must not stack another
# RichHandler on the same logger, otherwise every record is printed repeatedly.
if not any(isinstance(handler, RichHandler) for handler in logger.handlers):
    logger.addHandler(RichHandler(rich_tracebacks=True))
# Keep records from also reaching the root logger's handlers.
logger.propagate = False

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
# Seed shared by every MinHash / LeanMinHash in this module; signatures are
# only comparable when they were produced with the same seed.
MINHASH_SEED = 42
# Tokenizer: every character outside [A-Za-z0-9_] acts as a separator.
NON_ALPHA = re.compile(r"[^A-Za-z_0-9]")

In [None]:
#| export
def hash_content(idx: int, content: str, *, num_perm: int):
    """
    Compute the MinHash signature of a record's content.

    The function only reads its arguments and module constants, so it can be
    used as a ``Dataset.map`` function with several processes.

    The content is split into tokens on every character outside
    ``[A-Za-z0-9_]``. Empty tokens are dropped, and each distinct token is
    hashed once. The MinHash is always seeded with the module-level
    ``MINHASH_SEED``; there is no per-call seed. This keeps signatures
    comparable with the ``LeanMinHash`` objects rebuilt from them in
    ``query_content``.
    Parameters
    ----------
    idx : int
        The index of the record; returned unchanged under ``"__id__"``.
    content : str
        The text to fingerprint.
    num_perm : int
        The number of permutations to use in the MinHash object.
    Returns
    -------
    Dict[str, Any]
        ``{"__signature__": <MinHash hash values>, "__id__": idx}``.
    Examples
    --------
    >>> result = hash_content(0, "Hello world!", num_perm=128)
    >>> result["__id__"]
    0
    >>> result["__signature__"].shape
    (128,)
    >>> result["__signature__"].dtype
    dtype('uint64')
    """
    m = MinHash(num_perm=num_perm, seed=MINHASH_SEED)
    # A set: repeated tokens do not change a MinHash, so hash each one once.
    tokens = {t for t in NON_ALPHA.split(content) if t}
    m.update_batch([token.encode("utf-8") for token in tokens])
    return {"__signature__": m.hashvalues, "__id__": idx}

def query_content(idx: int, signature: np.ndarray, *, index: MinHashLSH):
    """
    Query a MinHashLSH index with a record's MinHash signature.

    Can be used as a ``Dataset.map`` function with multiprocessing, as long
    as every worker process has access to ``index``.
    Parameters
    ----------
    idx : int
        The index of the record; returned unchanged under ``"__id__"``.
    signature : np.ndarray
        The record's MinHash hash values (array-like), e.g. the
        ``"__signature__"`` produced by ``hash_content``. They are wrapped in a
        ``LeanMinHash`` seeded with ``MINHASH_SEED``. They should therefore
        come from a MinHash with that seed and the index's ``num_perm``.
    index : MinHashLSH
        The index to query.
    Returns
    -------
    Dict[str, Any]
        ``{"__neighbors__": [...], "__id__": idx}``, where ``"__neighbors__"``
        lists the keys returned by ``index.query`` converted to ``str``.
        These are candidate near-duplicates, not confirmed matches.
    """
    candidate = LeanMinHash(seed=MINHASH_SEED, hashvalues=signature)
    return {
        "__neighbors__": [str(key) for key in index.query(candidate)],
        "__id__": idx,
    }

def jaccard_similarity(s1: str, s2: str) -> float:
    """
    Return the exact Jaccard similarity of the token sets of two snippets.

    Tokens are the non-blank pieces obtained by splitting on every character
    outside ``[A-Za-z0-9_]``. Two snippets without any tokens give ``0.0``,
    because the denominator is clamped to at least 1.
    Parameters
    ----------
    s1 : str
        The first code snippet.
    s2 : str
        The second code snippet.
    Returns
    -------
    float
        ``|tokens(s1) & tokens(s2)| / max(1, |tokens(s1) | tokens(s2)|)``.
    Examples
    --------
    >>> jaccard_similarity("a = 1", "a = 2")
    0.3333333333333333
    >>> jaccard_similarity("a = 1", "a = 1")
    1.0
    """
    def _tokens(text):
        return {piece for piece in NON_ALPHA.split(text) if piece.strip()}

    left, right = _tokens(s1), _tokens(s2)
    shared = len(left & right)
    combined = len(left | right)
    return shared / max(1, combined)

In [None]:
#| export
class BenchmarkCleaner:
    """
    Remove records of a dataset that near-duplicate samples of benchmarks.

    For every benchmark, a MinHashLSH index over its samples proposes
    candidate matches. A candidate is then confirmed when the exact token
    Jaccard similarity reaches ``threshold``.
    Parameters
    ----------
    benchmarks : list of dict
        One spec per benchmark with keys:

        - ``"name"``: passed to ``load_dataset``.
        - ``"splits"``: joined with ``"+"`` into one split expression.
        - ``"columns"``: their values are concatenated into the compared
          text; list-valued columns are joined with spaces first.
    threshold : float
        Used both as the MinHashLSH threshold and as the minimum exact
        Jaccard similarity for a record to be removed.
    num_perm : int
        Number of MinHash permutations.
    """

    def __init__(self, benchmarks, threshold: float = 0.5, num_perm: int = 128):
        self.benchmarks = benchmarks
        self.threshold = threshold
        self.num_perm = num_perm

    def clean(self, ds, column):
        """
        Return ``ds`` without the records whose ``column`` matches any benchmark.

        Parameters
        ----------
        ds : Dataset
            The dataset to decontaminate.
        column : str
            Name of the text column compared against the benchmarks.
        Returns
        -------
        Dataset
            The kept records. They carry an extra ``__id__`` column holding
            each record's position in the input ``ds``.
        """
        start_time = time.time()
        data_size = len(ds)
        ds = ds.map(
            lambda _, idx: {"__id__": idx},
            with_indices=True,
            num_proc=os.cpu_count(),
            desc="Adding index...",
        )
        hashed_ds = ds.map(
            function=hash_content,
            fn_kwargs={"num_perm": self.num_perm},
            input_columns=["__id__", column],
            remove_columns=[column],
            num_proc=os.cpu_count(),
            desc="Fingerprinting...",
        )
        # Only the id and the signature are needed for querying.
        hashed_ds = hashed_ds.remove_columns(
            [c for c in hashed_ds.column_names if c not in ("__id__", "__signature__")]
        )

        # Accumulate matches over ALL benchmarks before filtering once.
        dup_ids = set()
        for bm in self.benchmarks:
            found = self._find_contaminated(ds, hashed_ds, column, bm)
            logger.info(f"{bm['name']:<30}: {len(found)} contaminated records")
            dup_ids |= found

        final_data = ds.filter(
            lambda idx: idx not in dup_ids,
            input_columns=["__id__"],
            num_proc=os.cpu_count(),
            desc="Filtering duplicates...",
        )

        final_data_size = len(final_data)
        dup_size = data_size - final_data_size
        dup_rate = dup_size / data_size if data_size else 0.0

        logger.info(f"{'Data Number':<30}: {data_size}")
        logger.info(f"{'Duplicate Number':<30}: {dup_size}")
        logger.info(f"{'Duplicate Rate':<30}: {dup_rate:.2%}")
        logger.info(f"{'Total Time':<30}: {time.time() - start_time:.2f} seconds")

        return final_data

    def _load_benchmark(self, bm):
        """Load benchmark ``bm`` and add ``__signature__``, ``__id__`` and ``__content__``."""
        benchmark_ds = load_dataset(bm["name"], split="+".join(bm["splits"]))
        columns = bm["columns"]
        benchmark_ds = benchmark_ds.remove_columns(
            [c for c in benchmark_ds.column_names if c not in columns]
        )
        num_perm = self.num_perm

        def fingerprint(x, idx):
            # List-valued columns (e.g. lists of test strings) are flattened
            # with spaces. The text is built once and reused for the signature
            # and the stored content.
            text = " ".join(
                [x[col] if isinstance(x[col], str) else " ".join(x[col]) for col in columns]
            )
            return {**hash_content(idx, text, num_perm=num_perm), "__content__": text}

        return benchmark_ds.map(
            function=fingerprint,
            num_proc=4,
            with_indices=True,
            desc="Fingerprinting...",
        )

    def _build_index(self, benchmark_ds):
        """Insert every benchmark signature into a new MinHashLSH keyed by ``__id__``."""
        index = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
        with index.insertion_session() as session:
            for record in benchmark_ds:
                session.insert(
                    record["__id__"],
                    LeanMinHash(seed=MINHASH_SEED, hashvalues=record["__signature__"]),
                )
        return index

    def _find_contaminated(self, ds, hashed_ds, column, bm):
        """Return the ``__id__`` values of ``ds`` records that match benchmark ``bm``."""
        benchmark_ds = self._load_benchmark(bm)
        # The index is published in module globals rather than captured by the
        # lambda, presumably so that forked `map` workers find it without it
        # being serialized (TODO confirm). The prefix keeps it from clobbering
        # module-level names.
        index_key = f"__benchmark_lsh__{bm['name']}"
        globals()[index_key] = self._build_index(benchmark_ds)
        try:
            queried = hashed_ds.map(
                function=lambda x, y: query_content(x, y, index=globals()[index_key]),
                num_proc=os.cpu_count(),
                input_columns=["__id__", "__signature__"],
                remove_columns=["__signature__"],
                desc="Querying...",
                features=Features(
                    {
                        "__id__": Value("uint64"),
                        "__neighbors__": Sequence(Value("string")),
                    }
                ),
                # The index is not part of the datasets cache fingerprint, so a
                # cached result may stem from another threshold or benchmark
                # version. Always recompute.
                load_from_cache_file=False,
            ).filter(
                lambda x: len(x["__neighbors__"]) > 0,
                num_proc=os.cpu_count(),
                desc="Filtering...",
                load_from_cache_file=False,
            )
        finally:
            # The index is only needed for the query pass above.
            globals().pop(index_key, None)

        contaminated = set()
        for record in tqdm(queried, desc="Checking for false positives..."):
            curr_text = ds[record["__id__"]][column]
            # LSH neighbours are candidates only; confirm with the exact score.
            if any(
                jaccard_similarity(curr_text, benchmark_ds[int(neighbor)]["__content__"])
                >= self.threshold
                for neighbor in set(record["__neighbors__"])
            ):
                contaminated.add(record["__id__"])
        return contaminated

In [None]:
# Benchmarks whose samples must not leak into the training data. Each spec
# names a Hub dataset, the splits to load, and the columns whose text is
# concatenated and compared.
DATASETS_TO_CHECK = [
    dict(
        name="openai_humaneval",
        splits=["test"],
        columns=["prompt", "canonical_solution", "test"],
    ),
    dict(
        name="mbpp",
        splits=["validation", "test"],
        columns=["text", "code", "test_list"],
    ),
]

# Dataset to decontaminate: the Python files of bigcode/the-stack-smol.
ds = load_dataset("bigcode/the-stack-smol", data_dir="data/python", split="train")
# threshold=0.1: records whose token Jaccard similarity to any benchmark
# sample is >= 0.1 are dropped.
bench_cleaner = BenchmarkCleaner(DATASETS_TO_CHECK, threshold=0.1, num_perm=128)
ds = bench_cleaner.clean(ds, "content")

Using custom data configuration bigcode--the-stack-smol-7b51f8bde3058781
Found cached dataset json (/home/nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-7b51f8bde3058781/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-7b51f8bde3058781/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-5cec16f7c8c87bb2.arrow
Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-7b51f8bde3058781/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-63fe0b2405153769.arrow
Loading cached processed dataset at /home/nathan/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-7b51f8bde3058781/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-ae0db36e3a8aac70.arrow
Loading cached processed dataset

In [None]:
# Inspect the cleaned dataset; `__id__` is the helper index column that
# `BenchmarkCleaner.clean` adds and keeps in its result.
ds

Dataset({
    features: ['content', 'avg_line_length', 'max_line_length', 'alphanum_fraction', 'licenses', 'repository_name', 'path', 'size', 'lang', '__id__'],
    num_rows: 9972
})

In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#| export` cells.
import nbdev; nbdev.nbdev_export()