# UNFILTERED 
# Deduplicating Text in Common Crawl for LLM Training

In this notebook, we will walk through MinHash deduplication of HTML documents from the Common Crawl dataset.

The Common Crawl corpus contains petabytes of data, with its oldest entries dating back to 2008, including raw web page data, metadata extracts, and text extracts.

LLMs require massive amounts of training data. Early foundation models like GPT-3 and T5 saw improvements in model performance due to deduplication efforts. Deduplication also makes it far less likely that a model regurgitates memorized text, which leads to better responses.

*See [Deduplicating Training Data Makes Language Models Better (Lee et al.)](https://aclanthology.org/2022.acl-long.577.pdf)*

---

## The MinHash Deduplication algorithm

If you search for "minhash deduplication" you'll find a variety of sources that walk through the algorithm. [Finding Near Duplicates with Jaccard Similarity and MinHash by Nelson Elhage](https://blog.nelhage.com/post/fuzzy-dedup/) is a great place to start. The canonical reference for MinHash is the seminal 1997 paper by Andrei Z. Broder:

```text
"On the resemblance and containment of documents"
Published in: Proceedings of the Compression and Complexity of Sequences 1997 (SEQUENCES '97)
Publisher: IEEE Computer Society
DOI: 10.1109/SEQUEN.1997.666900
```

A video walkthrough of the algorithm is also available in [Mike Mull's presentation on YouTube](https://www.youtube.com/watch?v=KKNPmvELUP4). He also provides a [Jupyter notebook](https://github.com/papers-we-love/san-diego/blob/master/presentations/2016-11-03-resemblance-containment-documents/Broder97.ipynb) detailing the core primitives and how they are calculated in pure Python.

In this notebook, we will adopt a more practical approach, leveraging Daft primitives and small user-defined functions to accelerate development and process documents at scale.
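
Before switching to Daft, it can help to see the core idea in a few lines of pure Python. The sketch below builds word n-gram shingles, computes a MinHash signature, and estimates Jaccard similarity from two signatures. The helper names (`shingles`, `minhash_signature`, `estimate_jaccard`) and the salted `blake2b` hash family are assumptions made for illustration; they are not what Daft uses internally.

```python
import hashlib


def shingles(text: str, n: int = 3) -> set[str]:
    """Return the set of word n-grams in `text` (the whole text if it is shorter than n words)."""
    words = text.split()
    if len(words) < n:
        return {" ".join(words)} if words else set()
    return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}


def _hash(shingle: str, seed: int) -> int:
    """One member of a hash family: blake2b salted with the seed (illustrative choice)."""
    digest = hashlib.blake2b(
        shingle.encode("utf-8"), digest_size=8, salt=seed.to_bytes(8, "little")
    ).digest()
    return int.from_bytes(digest, "little")


def minhash_signature(items: set[str], num_hashes: int = 64) -> list[int]:
    """For each hash function, keep the minimum hash over all shingles (items must be non-empty)."""
    return [min(_hash(s, seed) for s in items) for seed in range(num_hashes)]


def jaccard(a: set[str], b: set[str]) -> float:
    """Exact Jaccard similarity |A ∩ B| / |A ∪ B|."""
    return len(a & b) / len(a | b)


def estimate_jaccard(sig_a: list[int], sig_b: list[int]) -> float:
    """The fraction of matching signature positions estimates the Jaccard similarity."""
    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)


doc_a = shingles("the quick brown fox jumps over the lazy dog")
doc_b = shingles("the quick brown fox jumps over the lazy cat")
print("exact:", jaccard(doc_a, doc_b))
print("estimate:", estimate_jaccard(minhash_signature(doc_a), minhash_signature(doc_b)))
```

The estimate gets closer to the exact value as `num_hashes` grows, which is the trade-off the `K` parameter controls later in this notebook.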


### First we will need to authenticate with AWS to access S3

Common Crawl data is free to access by anyone from anywhere. The data is hosted by Amazon Web Services’ Open Data Sets Sponsorships program in the bucket `s3://commoncrawl/`, located in the US-East-1 (Northern Virginia) AWS Region. The most performant way to access Common Crawl is through S3, so you'll need to authenticate with an `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`.

Common Crawl data can also be accessed anonymously, without authentication, via its HTTP endpoint, but for the purposes of this walkthrough we are going to stick with S3.
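
If you do want to try the anonymous route, the mapping from an S3 URI to an HTTPS URL is mechanical. The helper below is a hypothetical convenience function (not part of Daft or any Common Crawl tooling), and it assumes the public endpoint is `https://data.commoncrawl.org/`; the example path is only illustrative.

```python
S3_PREFIX = "s3://commoncrawl/"
HTTPS_PREFIX = "https://data.commoncrawl.org/"  # assumed public HTTP endpoint


def to_https_url(s3_uri: str) -> str:
    """Rewrite an s3://commoncrawl/ URI as the equivalent HTTPS URL."""
    if not s3_uri.startswith(S3_PREFIX):
        raise ValueError(f"not a Common Crawl S3 URI: {s3_uri}")
    return HTTPS_PREFIX + s3_uri[len(S3_PREFIX):]


print(to_https_url("s3://commoncrawl/crawl-data/CC-MAIN-2018-17/warc.paths.gz"))
```

Note that glob patterns such as `*.warc.gz` are an S3 listing feature, so over HTTP you would have to enumerate file paths yourself.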


In [None]:
import daft
from daft.io import IOConfig, S3Config
import os
from dotenv import load_dotenv

# Make sure to define your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY in your environment variables or in a .env file
load_dotenv()

s3_config = S3Config(
    region_name="us-east-1",
    requester_pays=True,
    key_id=os.environ["AWS_ACCESS_KEY_ID"],
    access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    anonymous=False,
)

IO_CONFIG = IOConfig(s3=s3_config)
daft.set_planning_config(default_io_config=IO_CONFIG)

## Loading Common Crawl Documents 

We will access Common Crawl through [WARC files](https://commoncrawl.org/blog/navigating-the-warc-file-format), since Daft supports the format natively with `daft.read_warc(uri)`.

In [None]:
NUM_ROWS = 1000 # We'll limit this demo to a small number of rows for our initial walkthrough

In [None]:
df_warc = daft.read_warc("s3://commoncrawl/crawl-data/CC-MAIN-2018-17/segments/*/warc/*.warc.gz").limit(NUM_ROWS)
df_warc.show(3) # Inspect the first 3 rows

In [None]:
# Let's investigate the different types of payloads we have:
df_warc.select("WARC-Identified-Payload-Type").distinct().show()

### Step 1: Preprocessing
Since we are primarily concerned with text, we will focus on `text/html` payloads, extracting the text content from the HTML body and normalizing the text itself.


In [None]:
from daft import col
from daft.functions import monotonically_increasing_id


# Define a UDF to remove HTTP headers from the payload
@daft.func()
def remove_http_headers(x: str) -> str:
    if x is None:
        return ""
    # Headers and body are separated by the first blank line (CRLF CRLF).
    # Splitting only once keeps blank lines inside the body intact.
    _, sep, body = x.partition("\r\n\r\n")
    return body if sep else ""

# Filter the dataframe to only include text/html payloads
df_html = df_warc.where(col("WARC-Identified-Payload-Type") == "text/html")

# Separate the HTTP headers from the payloads
df_html = (
    df_html
    .with_column("content_raw", remove_http_headers(col("warc_content").try_decode("utf-8")))
    .where(col("content_raw") != "")
)

# Simplify the dataframe to just the content and add a monotonically increasing int id
df_html = (
    df_html
    .with_columns_renamed({"WARC-Identified-Payload-Type": "content_type",})
    .with_column("id", monotonically_increasing_id())
    .select("WARC-Record-ID", "id", "content_type", "content_raw")
)

df_html.show(3)

### Extracting Text from HTML body

In [None]:
from selectolax.parser import HTMLParser

# Define a UDF to extract text from HTML content, specifically from
# (article, main, p, h1, h2, h3, li) elements
@daft.func()
def extract_blocks(html: str) -> list[str]:
    tree = HTMLParser(html)
    # Drop elements that never contain readable text
    for n in tree.css("script,style,noscript"):
        n.decompose()
    blocks = []
    for node in tree.css("article, main, p, h1, h2, h3, li"):
        txt = node.text(separator=" ", strip=True)
        if not txt:
            continue
        blocks.append(txt)
    return blocks

df_text = df_html.with_column("content_text", extract_blocks(col("content_raw")).list.join(" "))
df_text.show(3)

## MinHash

Now that we have extracted the text from the HTML, we normalize the inputs and calculate our MinHash vectors using Daft's `minhash` expression. No need to build shingles ourselves!

In [None]:
K = 64 # Number of Permutations
SEED = 43 # Seed for the hash function
NGRAM_SIZE = 5 # Size of the n-grams
index_col = "id" 
content_col = "content_text"

In [None]:
# Normalize text 
df_norm = df_text.with_column("content_normalized", 
    col(content_col).str.normalize(
        remove_punct=True, 
        lowercase=True, 
        nfd_unicode=True, 
        white_space=True
    ) 
)
df_norm.show(3)


In [None]:
# Calculate the minhash vectors
df_minhash = (
    df_norm
    .with_column("min_hashes", col("content_normalized").minhash(
        num_hashes = K,
        ngram_size = NGRAM_SIZE,
        seed = SEED, 
        hash_function = 'xxhash'
        )
    )
)
df_minhash.show(3)

### Band Generation and Bucketing

Next, we will:
1. Use the optimal_param function to determine the best band (b) and row (r) parameters for our LSH bucketing
2. Split each document's minhash vector into b bands of r rows each
3. Create buckets by hashing each band's signature, grouping similar documents together
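
To build some intuition for how `b` and `r` shape the outcome, the sketch below evaluates the standard LSH banding formula: two documents with Jaccard similarity `s` share at least one band with probability `1 - (1 - s^r)^b`, and the curve rises steeply around `(1/b)^(1/r)`. The function names are illustrative, and the `b = 8`, `r = 8` values are just an example split of 64 hashes, not necessarily what `optimal_param` returns.

```python
def candidate_probability(s: float, b: int, r: int) -> float:
    """Probability that two documents with Jaccard similarity s collide in at least one band."""
    return 1.0 - (1.0 - s ** r) ** b


def approximate_threshold(b: int, r: int) -> float:
    """Similarity around which the S-curve rises most steeply."""
    return (1.0 / b) ** (1.0 / r)


b, r = 8, 8  # example split of 64 hashes into 8 bands of 8 rows
print("approximate threshold:", round(approximate_threshold(b, r), 3))
for s in (0.5, 0.6, 0.7, 0.8, 0.9):
    print(f"s={s:.1f} -> P(candidate)={candidate_probability(s, b, r):.3f}")
```

More bands (larger `b`) make collisions more likely and lower the effective threshold; more rows per band (larger `r`) make each band stricter and raise it.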


In [None]:
from scipy.integrate import quad as integrate

def optimal_param(
    threshold: float,
    num_perm: int,
    false_positive_weight: float = 0.5,
    false_negative_weight: float = 0.5,
):
    """
    Compute the optimal `MinHashLSH` parameter that minimizes the weighted sum
    of probabilities of false positive and false negative, taken from datasketch.

    Parameters
    ----------
    threshold : float
        The threshold for similarity.
    num_perm : int
        The number of permutations.
    false_positive_weight : float
        The weight of false positive.
    false_negative_weight : float
        The weight of false negative.

    Returns
    -------
    Tuple[int, int]
        The optimal `b` and `r` parameters.
        The number of bands, and the number of rows per band respectively.

    Examples
    --------
    >>> optimal_param(0.7, 256)
    (25, 10)
    """

    def false_positive_area(threshold: float, b: int, r: int):
        """Source: `datasketch.lsh`"""

        def area(s):
            return 1 - (1 - s ** float(r)) ** float(b)

        a, _ = integrate(area, 0.0, threshold)
        return a

    def false_negative_area(threshold: float, b: int, r: int):
        """Source: `datasketch.lsh`"""

        def area(s):
            return 1 - (1 - (1 - s ** float(r)) ** float(b))

        a, _ = integrate(area, threshold, 1.0)
        return a

    min_error = float("inf")
    opt = (0, 0)
    for b in range(1, num_perm + 1):
        max_r = int(num_perm / b)
        for r in range(1, max_r + 1):
            fp = false_positive_area(threshold, b, r)
            fn = false_negative_area(threshold, b, r)
            error = fp * false_positive_weight + fn * false_negative_weight
            if error < min_error:
                min_error = error
                opt = (b, r)
    return opt

In [None]:
# Choose B bands and R rows per band such that B * R <= K (the number of hashes).
B, R = optimal_param(0.717, K)

In [None]:
from daft import lit

# Band Generation
df_bands = (
    df_minhash
    .with_column("bands", col("min_hashes").list.chunk(R))
    .with_column("band_idx", lit(list(range(B))))
    .explode("bands", "band_idx")
    .select(index_col, "band_idx", "bands")
)
df_bands.show(20)    

In [None]:
# Grouping Bands
df_grouped = (
    df_bands
    .groupby(col("band_idx"), col("bands"))
    .agg(col(index_col).agg_list().alias("nodes"))
)
df_grouped.show()

## Connected Components
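Every bucket whose `nodes` list has more than one entry now holds candidate near-duplicates.
Next we are going to leverage a few tricks from graph theory: we treat documents as nodes and shared buckets as edges, then find the connected components of that graph, so that each component is one cluster of near-duplicates.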
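
As a small pure-Python illustration of the edge-generation idea, the sketch below groups documents by `(band_idx, band)` and links every document in a bucket to the bucket's smallest ID. The function name `bucket_edges` and the toy rows are made up for this example; the Daft cell that follows does the same thing at scale.

```python
from collections import defaultdict


def bucket_edges(band_rows):
    """band_rows: iterable of (doc_id, band_idx, band_values). Returns a set of (anchor, other) edges."""
    buckets = defaultdict(list)
    for doc_id, band_idx, band in band_rows:
        # Documents collide only if the same band position holds the same values
        buckets[(band_idx, tuple(band))].append(doc_id)

    edges = set()
    for nodes in buckets.values():
        anchor = min(nodes)  # the smallest ID in the bucket acts as the hub
        edges.update((anchor, n) for n in nodes if n != anchor)
    return edges


rows = [
    (0, 0, [1, 2]), (1, 0, [1, 2]), (2, 0, [9, 9]),  # band 0: docs 0 and 1 collide
    (1, 1, [7, 7]), (2, 1, [5, 5]), (3, 1, [5, 5]),  # band 1: docs 2 and 3 collide
]
print(sorted(bucket_edges(rows)))
```

Linking to the minimum instead of emitting every pair keeps the edge count linear in the bucket size, which matters when a bucket holds thousands of documents.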

In [None]:
# Generate Graph Edges
df_edges = (
    df_grouped
    .with_column("left_edge", col("nodes").list.min())
    .explode("nodes")
    .select(col("left_edge"), col("nodes").alias("right_edge"))
    .filter(col("left_edge") != col("right_edge"))
    .distinct()
)
df_edges.show()

### 1. Canonicalize edges to undirected form

- Remove nulls.
- Remove self-loops (rows where u == v).

In [None]:
# prep edges for connected components
df_edges_clean = (
    df_edges.select(col("left_edge").alias("u"), col("right_edge").alias("v"))
    .where(~col("u").is_null())
    .where(~col("v").is_null())
    .where(col("u") != col("v"))
)


In [None]:
import igraph as ig

df_pd_edges = df_edges_clean.select(col("u").cast(daft.DataType.int64()), col("v").cast(daft.DataType.int64())).to_pandas()

# using igraph
g = ig.Graph.DataFrame(df_pd_edges, directed=False)
strong_components = {frozenset(c) for c in g.connected_components(mode="strong")}
weak_components = {frozenset(c) for c in g.connected_components(mode="weak")}

print(strong_components)
print(weak_components)
assert strong_components == weak_components


In [None]:
import matplotlib.pyplot as plt

# Cluster the graph once and reuse the clustering for both the plot and the colors
clustering = g.connected_components(mode="weak")

fig, ax = plt.subplots()
ig.plot(
    clustering,
    target=ax,
    palette=ig.RainbowPalette(),
    vertex_size=7,
    vertex_color=list(map(int, ig.rescale(clustering.membership, (0, 200), clamp=True))),
    edge_width=0.7,
)
plt.show()

### Star Contraction with Daft
Now we will iteratively compress the graph using two alternating phases until convergence:
- Large-star: Every node connects its strictly larger neighbors to the minimum ID in its neighborhood (including itself). This quickly pulls nodes toward low-ID “hubs.”
- Small-star: Orient each edge from its larger endpoint to its smaller one, then connect every node and its smaller neighbors to the minimum ID among them, which merges local hubs together.
- Repeat large-star then small-star until nothing changes. The “parent” each node ends up pointing to is its component representative.
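
The two phases are easier to follow on plain Python sets before reading the Daft version. The sketch below is a minimal reference implementation under our own assumptions: the function names are illustrative, edges are `(child, parent)` tuples after the first round, and convergence is checked by comparing the edge set before and after a full large-star plus small-star round.

```python
from collections import defaultdict


def large_star(edges):
    """Connect each node's strictly larger neighbors to the minimum of its neighborhood."""
    nbrs = defaultdict(set)
    for u, v in edges:  # treat every edge as undirected
        nbrs[u].add(v)
        nbrs[v].add(u)
    out = set()
    for u, ns in nbrs.items():
        m = min(ns | {u})
        out.update((n, m) for n in ns if n > u)
    return out


def small_star(edges):
    """Connect each node and its smaller neighbors to the minimum among them."""
    nbrs = defaultdict(set)
    for u, v in edges:
        hi, lo = (u, v) if u > v else (v, u)
        nbrs[hi].add(lo)
    out = set()
    for u, ns in nbrs.items():
        nodes = ns | {u}
        m = min(nodes)
        out.update((n, m) for n in nodes if n != m)
    return out


def star_contraction(edges):
    """Alternate the two phases until a full round leaves the edge set unchanged."""
    current = {(u, v) for u, v in edges if u != v}
    while True:
        nxt = small_star(large_star(current))
        if nxt == current:
            return nxt
        current = nxt


# A path 1-2-3-4 plus a separate pair 7-8
print(sorted(star_contraction({(1, 2), (2, 3), (3, 4), (7, 8)})))
```

In the result every non-representative node points directly at the smallest ID in its component, which is exactly the shape the final assignment step relies on.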

### 2. Large-star phase
- Emit every edge in both directions, then group neighbors by u.
- Compute min_neighbor = min(neighbors).
- Use min(u, min_neighbor) as the node’s “parent.”
- Emit edges (v, parent) only for neighbors v > u, which avoids self-loops and duplicates.

In [None]:
from daft import DataFrame


def large_star_phase(edges: DataFrame) -> DataFrame:
    return (
        edges
        # large_star_map: emit each edge in both directions (upper and lower triangles)
        .select("u", "v")
        .union_all(edges.select(col("v").alias("u"), col("u").alias("v")))
        # large_star_reduce: the parent is the minimum of u and all of its neighbors
        .groupby("u").agg(col("v").agg_list().alias("nbrs"))
        .with_column("min_edge", col("nbrs").list.min())
        .with_column("min_edge", (col("u") <= col("min_edge")).if_else(col("u"), col("min_edge")))
        # Link every strictly larger neighbor to the parent
        .explode("nbrs")
        .where(col("nbrs") > col("u"))
        .select(col("nbrs").alias("u"), col("min_edge").alias("v"))
        .distinct()
    )


a = large_star_phase(df_edges_clean)

In [None]:
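# Sanity-check one large-star iteration against igraph. A single pass is usually
# not enough to reach the final stars, so we only check that every emitted edge
# stays inside one igraph component.
component_of = {node: comp for comp in strong_components for node in comp}
for d in a.to_pylist():
    assert component_of[d["u"]] == component_of[d["v"]]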

### 3. Small-star phase
- Re-orient all edges so u > v (canonical).
- Group the smaller neighbors by u and compute parent = min(neighbors).
- Emit (n, parent) for u and for each of its neighbors, skipping the parent itself.
- This step merges local minima across previously separate stars.

In [None]:
from daft import DataFrame


def small_star_phase(df: DataFrame) -> DataFrame:
    grouped = (
        df
        # small_star_map: orient every edge so that u is the larger endpoint
        .select(
            (col("u") > col("v")).if_else(col("u"), col("v")).alias("u"),
            (col("u") > col("v")).if_else(col("v"), col("u")).alias("v"),
        )
        # small_star_reduce: the parent is the smallest of u's (smaller) neighbors
        .groupby("u").agg(col("v").agg_list().alias("nbrs"))
        .with_column("min_edge", col("nbrs").list.min())
    )
    return (
        # Link each neighbor, and u itself, to the parent
        grouped.explode("nbrs").select(col("nbrs").alias("u"), col("min_edge").alias("v"))
        .union_all(grouped.select(col("u"), col("min_edge").alias("v")))
        .where(col("u") != col("v"))  # drop the parent's self-loop
        .distinct()
        .collect()  # Materialize
    )

### 4. Convergence check
- Compare a stable summary of edges before/after (hash sum is fine).
- If stable, stop; otherwise repeat.

In [None]:
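from daft import DataFrame

def check_convergence(a: DataFrame, b: DataFrame) -> bool:
    """Return True when both edge lists have the same order-independent hash sum."""
    # Hash both endpoints so that a change in either u or v is detected
    edge = col("u").cast(daft.DataType.string()) + daft.lit("-") + col("v").cast(daft.DataType.string())
    a_hash = a.select(edge.hash().alias("hash")).sum("hash").to_pydict()["hash"][0]
    b_hash = b.select(edge.hash().alias("hash")).sum("hash").to_pydict()["hash"][0]
    return a_hash == b_hash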

### Combining Stages

In [None]:
from daft import DataFrame


def connected_components(
    edges: DataFrame,
    left_id_col="left_edge",
    right_id_col="right_edge",
    output_index_col=index_col,
    output_component_col="__component__",
):
    # Convert column names to u, v
    b = (
        edges.select(col(left_id_col).alias("u"), col(right_id_col).alias("v"))
        .where(~col("u").is_null())
        .where(~col("v").is_null())
        .collect()
    )
    while True:
        a = large_star_phase(b)
        b_next = small_star_phase(a)
        # Converged once a full large-star + small-star round leaves the edges unchanged
        converged = check_convergence(b, b_next)
        b = b_next
        if converged:
            break

    # Revert column names and return contracted star edges
    return (
        b
        .select(col("u").alias(output_index_col), col("v").alias(output_component_col))
        .collect()
    )
    

In [None]:

from daft import DataFrame, lit


def components(
    df: DataFrame,
    left_id_col: str = "u",
    right_id_col: str = "v",
    output_index_col: str = "u",
    output_component_col: str = "component",
) -> DataFrame:
    """Single-function variant of the star contraction loop."""
    b = (
        df.select(col(left_id_col).alias("u"), col(right_id_col).alias("v"))
        .where(~col("u").is_null())
        .where(~col("v").is_null())
        .collect()
    )
    # Order-independent summary of an edge list: hash both endpoints of every edge
    edge_key = col("u").cast(daft.DataType.string()) + lit("-") + col("v").cast(daft.DataType.string())
    prev_hash = None
    while True:
        a = (
            b
            # large_star_map
            .select("u", "v")
            .union_all(b.select(col("v").alias("u"), col("u").alias("v")))
            # large_star_reduce
            .groupby("u").agg(col("v").agg_list().alias("nbrs"))
            .with_column("min_edge", col("nbrs").list.min())
            .with_column("min_edge", (col("u") <= col("min_edge")).if_else(col("u"), col("min_edge")))
            .explode("nbrs")
            .where(col("nbrs") > col("u"))
            .select(col("nbrs").alias("u"), col("min_edge").alias("v"))
            .distinct()
            .collect()
        )
        grouped = (
            a
            # small_star_map
            .select(
                (col("u") > col("v")).if_else(col("u"), col("v")).alias("u"),
                (col("u") > col("v")).if_else(col("v"), col("u")).alias("v"),
            )
            # small_star_reduce
            .groupby("u").agg(col("v").agg_list().alias("nbrs"))
            .with_column("min_edge", col("nbrs").list.min())
        )
        b = (
            grouped.explode("nbrs").select(col("nbrs").alias("u"), col("min_edge").alias("v"))
            .union_all(grouped.select(col("u"), col("min_edge").alias("v")))
            .where(col("u") != col("v"))
            .distinct()
            .collect()
        )
        # check convergence: stop once a full round leaves the edge summary unchanged
        b_hash = b.select(edge_key.hash().alias("hash")).sum("hash").to_pydict()["hash"][0]
        if b_hash == prev_hash:
            return (
                b
                .select(col("u").alias(output_index_col), col("v").alias(output_component_col))
                .collect()
            )
        prev_hash = b_hash

In [None]:
assignment = components(
    df_edges,
    left_id_col="left_edge",
    right_id_col="right_edge",
    output_index_col=index_col,
    output_component_col="__component__",
)


In [None]:
# Running the Star Contraction
df_star_edges = connected_components(
    edges=df_edges,
    left_id_col="left_edge",
    right_id_col="right_edge",
    output_index_col=index_col,
    output_component_col="__component__",
)
df_star_edges.show(3)


### 5. Final assignment
- Treat the final v for each u as the component representative.
- Join back to your documents and keep the representative per component.
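
In plain Python terms, the final filter keeps a document when it either has no assignment at all (it was never a duplicate candidate) or is its own representative. The sketch below uses a hypothetical `assignment` dict that maps each non-representative ID to its representative; the name `keep_representatives` is ours.

```python
def keep_representatives(doc_ids, assignment):
    """Keep IDs that are unassigned or that are their own component representative."""
    return [d for d in doc_ids if assignment.get(d, d) == d]


# Docs 2 and 3 are near-duplicates of doc 1; docs 0 and 4 are unique
print(keep_representatives([0, 1, 2, 3, 4], {2: 1, 3: 1}))
```

This mirrors the left join followed by the `is_null() | (component == id)` filter in the Daft cell that follows.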

In [None]:
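# Keep one document per component: rows without a component were never
# duplicate candidates, and rows whose component equals their own id are
# the representatives.
df_deduped = (
    df_norm  # assumption: the document DataFrame that carries `index_col`
    .join(assignment.select(col(index_col), col("__component__")), on=index_col, how="left")
    .filter(col("__component__").is_null() | (col("__component__") == col(index_col)))
    .exclude("__component__")
)
df_deduped.show(3)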