This notebook tests the decontamination code in EleutherAI's `janitor.py` on small section(s) of Dolma to estimate how long decontaminating the full dataset would take.

To run ``janitor.py`` with C++ on Linux:
1. From the ``lm-evaluation-harness/scripts/clean_training_data`` directory, run ``c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix)``
2. Rename the resulting ``.so`` file to ``janitor_util.so``
3. Make ``janitor_util.so`` importable as ``janitor_util`` by adding its directory to Python's module search path: ``sys.path.append(harness_dir + "/scripts/clean_training_data")``

In [2]:
import os

# numexpr reads NUMEXPR_MAX_THREADS / NUMEXPR_NUM_THREADS only once, when it is
# first imported -- and `import pandas` already imports numexpr if it is
# installed. These must therefore be set before pandas is imported, otherwise
# they are silently ignored.
os.environ["NUMEXPR_MAX_THREADS"] = "256"
os.environ["NUMEXPR_NUM_THREADS"] = "200"

import pyarrow.parquet as pq
from pathlib import Path
import pandas as pd
import sys
import datetime
import pyarrow
from tqdm import tqdm
import copy
import numexpr as ne

# Notebooks have no `__file__`, so the harness checkout is located relative to
# the kernel's working directory: the `lm-evaluation-harness` directory inside
# the directory three levels above it.
harness_dir = str(Path.cwd().parents[2] / "lm-evaluation-harness")
sys.path.append(harness_dir)

# Directory holding the compiled `janitor_util` extension (see the build steps
# above); janitor.py imports it as `janitor_util`.
sys.path.append(harness_dir + "/scripts/clean_training_data")
from lm_eval.decontamination.janitor import Janitor

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(




Traceback (most recent call last):
  File "/data/tir/projects/tir7/user_data/mchen5/llm-pretraining-behaviours/lm-evaluation-harness/lm_eval/decontamination/janitor.py", line 11, in <module>
    import janitor_util
ModuleNotFoundError: No module named 'janitor_util'


In [3]:
# Read the whole contaminant file (~3.4 GB) into memory and register it with a
# fresh Janitor, so later `janitor.clean_python(...)` calls check documents
# against it.
with open("./contaminant.txt", "r") as file:
    contaminant: str = file.read()  # 3.4G
janitor = Janitor()
janitor.register_contaminant(contaminant)



In [4]:
# Load the small Dolma sample. Note: despite the `.arrow` extension, the file
# is read with the Parquet reader.
data_mini: pyarrow.lib.Table = pq.read_table("data_mini.arrow")
data_mini_size = sys.getsizeof(data_mini)
print(f"Size of data_mini: {data_mini_size} bytes")

Size of data_mini: 112133501 bytes


In [5]:
# Convert the Arrow table into a pandas DataFrame and preview its first rows.
df: pd.DataFrame = data_mini.to_pandas()
df.head(n=5)

Unnamed: 0,id,text
0,09c6eceb562caeba5b94489087fb1e8d,"TAMPA, Fla., Nov. 03, 2016 (GLOBE NEWSWIRE) --..."
1,7378e5a823604985555d1d9267827368,"It was brimming with midges. Everywhere, these..."
2,43088e9ab3bdb2236fc493594b99f72f,We encourage all our employees to be ambitious...
3,14b802b07c5b0685470f5c87fc60e394,The first road assignment is coming this weeke...
4,954f973826676c5a9421c0286f964bd3,Course to upgrade skills for experienced Hr pr...


In [6]:
def decontaminate(df: pd.DataFrame, janitor: "Janitor") -> pd.DataFrame:
    """Strip registered contaminants from every document in ``df["text"]``.

    Each text is passed to ``janitor.clean_python``, which is expected to
    return a ``(cleaned_chunks, num_contaminated)`` pair. Texts with a non-zero
    count are replaced by the concatenation of their cleaned chunks. If
    cleaning a text raises, the original text is kept and the row is flagged
    in the ``thrown`` column.

    Args:
        df: Frame with a ``text`` column. It is modified in place (``text`` is
            rewritten; ``num_contaminated`` and ``thrown`` columns are added or
            overwritten) and also returned for convenience.
        janitor: Janitor whose contaminants have already been registered.

    Returns:
        The same ``df``, with an int64 ``num_contaminated`` column and a bool
        ``thrown`` column.
    """
    texts = df["text"].tolist()
    # Results are collected by position and written back as whole columns.
    # This avoids a per-row ``iterrows`` + ``df.at`` label lookup, which is
    # slow and ambiguous when the index contains duplicate labels.
    num_contaminated_per_row = [0] * len(texts)
    thrown_per_row = [False] * len(texts)

    for i, text in enumerate(texts):
        try:
            cleaned, num_contaminated = janitor.clean_python(text)
            num_contaminated_per_row[i] = num_contaminated
            if num_contaminated != 0:
                texts[i] = "".join(cleaned)
        except Exception:
            # Best effort: keep the original text and flag the row instead of
            # aborting the whole frame. Unlike a bare `except:`, this still
            # lets KeyboardInterrupt / SystemExit propagate.
            thrown_per_row[i] = True

    # Explicit dtypes keep the columns int64/bool even for an empty frame.
    df["text"] = pd.Series(texts, index=df.index, dtype=df["text"].dtype)
    df["num_contaminated"] = pd.Series(
        num_contaminated_per_row, index=df.index, dtype="int64"
    )
    df["thrown"] = pd.Series(thrown_per_row, index=df.index, dtype=bool)
    return df

In [7]:
# Decontaminate the sample and save it. The output directory is created first
# so this cell also works on a fresh checkout without ./dev_outputs.
# NOTE: `decontaminate` rewrites `df` in place, so re-running this cell
# without re-running the load cells re-processes already-cleaned text.
# The output is written as Parquet despite the `.arrow` extension.
Path("./dev_outputs").mkdir(parents=True, exist_ok=True)
df = decontaminate(df, janitor)
df.to_parquet("./dev_outputs/result.arrow")

In [8]:
# Read the decontaminated sample back from disk and check its shape and the
# added `num_contaminated` / `thrown` columns.
result_table = pq.read_table("./dev_outputs/result.arrow")
result = result_table.to_pandas()
print(result.shape)
result.head(n=5)

(50000, 4)


Unnamed: 0,id,text,num_contaminated,thrown
0,09c6eceb562caeba5b94489087fb1e8d,"TAMPA, Fla., Nov. 03, 2016 (GLOBE NEWSWIRE) --...",0,False
1,7378e5a823604985555d1d9267827368,"It was brimming with midges. Everywhere, these...",0,False
2,43088e9ab3bdb2236fc493594b99f72f,We encourage all our employees to be ambitious...,0,False
3,14b802b07c5b0685470f5c87fc60e394,The first road assignment is coming this weeke...,0,False
4,954f973826676c5a9421c0286f964bd3,Course to upgrade skills for experienced Hr pr...,0,False


In [16]:
import os
import sys
import multiprocessing


contamination_indices = 0

print(f"{multiprocessing.cpu_count()} CPUs go brrr")


# Deduplicates the file at this path and saves the output to dolma_100B_deduped
def process_file(file_path, directory_name, file_name):
    print(f"Processing {file_path}")
    sys.stdout.flush()
    global contamination_indices
    df: pd.DataFrame = pq.read_table(file_path).to_pandas()
    df = decontaminate(df, janitor)
    contamination_indices += df["num_contaminated"].sum()
    df.to_parquet(
        f"/data/tir/projects/tir7/user_data/mchen5/dolma_100B_deduped/{directory_name}/{file_name}"
    )


# Start a new process for each file, so we deduplicate fully in parallel
def process_directory(directory_path, directory_name):
    for root, _, files in os.walk(directory_path):
        for file_name in files:
            file_path = os.path.join(root, file_name)
            p = multiprocessing.Process(
                target=process_file, args=(file_path, directory_name, file_name)
            )
            p.start()


def main():
    global contamination_indices
    base_dir = "/data/tir/projects/tir7/user_data/mchen5/dolma_100B"
    for directory_name in os.listdir(base_dir):
        directory_path = os.path.join(base_dir, directory_name)
        if os.path.isdir(directory_path):
            print(f"Processing {directory_path}")
            process_directory(directory_path, directory_name)
    print("Finished decontamination")
    print(f"{contamination_indices} total contamination indices")


if __name__ == "__main__":
    main()

256 CPUs go brrr
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/peS2o
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/peS2o/part_4.arrow
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/peS2o/part_5.arrow
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/peS2o/part_1.arrow
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/peS2o/part_2.arrow
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/stack-code
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/peS2o/part_3.arrow
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/stack-code/part_9.arrowProcessing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/stack-code/part_12.arrow

Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/stack-code/part_8.arrow
Processing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/stack-code/part_4.arrowProcessing /data/tir/projects/tir7/user_data/mchen5/dolma_100B/stack-code/part_11.