In [1]:
import sys
# Extend the module search path with the project's near-duplicate sources
# (presumably where the `utils` module imported below lives — confirm).
# The path is relative, so it resolves against the current working directory.
sys.path.append("py_src/near_duplicates")

In [3]:
from utils import merge_results, bytes_to_ints
import os
import pickle
from tqdm.notebook import tqdm
from transformers import GPT2Tokenizer
import matplotlib.pyplot as plt
from collections import defaultdict
import traceback
import copy
import numpy as np

In [4]:
# Base directory that holds both the tokenized dataset and the scan output.
ROOT = "/working/dir"

# Directory of pickled scan result files (loaded and merged further below).
RESULTS_DIR = os.path.join(ROOT, "scan")

# Raw tokenized dataset file and its companion ".size" file.
DS_PATH = os.path.join(ROOT, "tokenized/slimpajama_0_of_20.train")
DS_SIZE_PATH = f"{DS_PATH}.size"

In [5]:
# GPT-2 tokenizer; used later in the notebook to decode token ids back to text.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
# Read the ".size" companion file and decode it with bytes_to_ints
# (presumably a sequence of 8-byte integers — confirm against utils).
with open(DS_SIZE_PATH, "rb") as size_file:
    ds_size_raw = size_file.read()
ds_size = bytes_to_ints(ds_size_raw, 8)
print("loaded sizes")

# Read the whole tokenized dataset into memory as one bytes object;
# later cells slice it directly by position.
with open(DS_PATH, "rb") as ds_file:
    ds = ds_file.read()
print("loaded full dataset")

In [None]:
# Fold every pickled result file in RESULTS_DIR into a single `results` object.
# `total` ends up as the number of files that were merged.
results = None
total = 0

scan_files = sorted(os.listdir(RESULTS_DIR))
for filename in tqdm(scan_files):
    scan_path = os.path.join(RESULTS_DIR, filename)
    # NOTE: pickle.load executes arbitrary code; only point this at trusted scan output.
    with open(scan_path, "rb") as f:
        new_results = pickle.load(f)
    # The first file seeds the accumulator; later files are merged into it.
    results = new_results if results is None else merge_results(results, new_results)
    total += 1

In [16]:
# Distinct bucket values present in the merged results, in ascending order.
all_buckets = sorted({r.bucket for r in results})

In [17]:
# cum_counts[bucket][metric][d] = scaled number of near duplicates whose
# distance is <= d, for d = 0..50, summed over all results in that bucket.
cum_counts = {}
for bucket in all_buckets:
    results_bucket = [res for res in results if res.bucket == bucket]
    per_metric = defaultdict(list)
    for metric in ("hamming", "edit"):
        for dist in range(51):
            # Each near-duplicate entry is indexable; element 1 is its distance.
            cnt = sum(
                1
                for res in results_bucket
                for nd in res.near_duplicates[metric]
                if nd[1] <= dist
            )
            # Extrapolating the total count from 5% of the data (x20).
            # The /100 presumably averages over 100 sampled sequences per
            # bucket — TODO confirm.
            cnt = (cnt / 100) * 20
            per_metric[metric].append(cnt)

    cum_counts[bucket] = per_metric

In [None]:
# Cumulative near-duplicate counts by Levenshtein distance, one line per bucket,
# drawn on a logarithmic y axis.
plt.figure(figsize=(6, 6))
colors = ['#74c476',
          '#31a354',
          '#006d2c',
          "green",]
y_limits = (10, 1000000)

for idx, bucket in enumerate(all_buckets):
    # Wrap around the palette so more buckets than colours cannot raise IndexError.
    color = colors[idx % len(colors)]
    plt.plot(range(51), cum_counts[bucket]['edit'],
             color=color,
             marker='o',
             markersize=5,
             markeredgecolor='white',
             markerfacecolor=color,
             markeredgewidth=0.5,
             label=f"bucket {bucket}")

plt.grid(True, linestyle='-', color='#E5E7EB', alpha=0.9, zorder=0)
plt.xlabel('Levenshtein distance', fontsize=10)
plt.tick_params(labelsize=9)

plt.ylabel('Cumulative number of near duplicates', fontsize=10)
plt.legend(fontsize=10, loc="upper left")

plt.yscale('log')
plt.ylim(*y_limits)
plt.yticks(plt.yticks()[0].tolist() + [100])
# Setting explicit ticks expands the view limits to include every tick, and the
# locator's tick list can reach beyond the requested range, so the limits are
# re-applied after the ticks are fixed.
plt.ylim(*y_limits)
# Compute the layout last, once scale, limits and tick labels are final.
plt.tight_layout()
plt.show()

In [None]:
# One panel per bucket, comparing cumulative near-duplicate counts under the
# edit and Hamming metrics.
# squeeze=False always yields a 2-D array of Axes, so indexing also works when
# there is only a single bucket (plt.subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(1, len(all_buckets), figsize=(15, 5), squeeze=False)
axes = axes[0]

# (key in cum_counts[bucket], line/marker colour, legend label)
panel_series = (
    ('edit', '#3B82F6', "Edit distance"),
    ('hamming', '#F97316', "Hamming distance"),
)

for idx, bucket in enumerate(all_buckets):
    panel = axes[idx]
    for series_key, series_color, series_label in panel_series:
        panel.plot(range(51), cum_counts[bucket][series_key],
                   color=series_color,
                   marker='o',
                   markersize=5,
                   markeredgecolor='white',
                   markerfacecolor=series_color,
                   markeredgewidth=0.5,
                   label=series_label)

    panel.grid(True, linestyle='-', color='#E5E7EB', alpha=0.9, zorder=0)
    panel.set_title(f'Bucket: {bucket}', fontsize=12, pad=10, fontweight='bold')
    panel.set_xlabel('Distance', fontsize=10)
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)
    panel.tick_params(labelsize=9)

    # Scale the y range with the bucket value; skipped for bucket 0, which
    # would otherwise collapse the range to (0, 0).
    if bucket > 0:
        panel.set_ylim(0, 22*bucket)

# Only the left-most panel carries the shared y label and the legend.
axes[0].set_ylabel('Cumulative number of Near Duplicates', fontsize=10)
axes[0].legend(fontsize=9, loc="upper left")

plt.tight_layout()
plt.show()

In [None]:
# Edit vs. Hamming cumulative near-duplicate counts for the 1000 bucket, with a
# dashed reference line at y=1000 labelled "Exact duplicates".
plt.figure(figsize=(6, 6))
bucket = 1000

# Marker styling shared by both curves.
marker_style = dict(marker='o',
                    markersize=5,
                    markeredgecolor='white',
                    markeredgewidth=0.5)

plt.plot(range(51), cum_counts[bucket]['edit'],
         color='#3B82F6',
         markerfacecolor='#3B82F6',
         label="Levenshtein distance",
         **marker_style)

plt.plot(range(51), cum_counts[bucket]['hamming'],
         color='#F97316',
         markerfacecolor='#F97316',
         label="Hamming distance",
         **marker_style)

plt.axhline(y=1000, xmin=0, xmax=1, color='dimgray', linestyle='--', linewidth=1, label="Exact duplicates")
# Append an explicit tick at the reference line; it is appended last, so
# get_major_ticks()[-1] is that tick and its label is coloured to match the line.
plt.yticks(list(plt.yticks()[0]) + [1000])
plt.gca().get_yaxis().get_major_ticks()[-1].label1.set_color('dimgray')


plt.grid(True, linestyle='-', color='#E5E7EB', alpha=0.9, zorder=0)
plt.xlabel('Distance', fontsize=13)
# Limits are set after the ticks, so the tick list cannot widen them again.
plt.ylim(0, 23*bucket)
plt.ylabel('Cumulative number of near duplicates', fontsize=13)
plt.legend(fontsize=13, loc="upper left")
plt.tight_layout()
plt.show()

## Dedup

In [141]:
def has_ngram_overlap(arr1, arr2, n):
    """Return True if the two sequences share at least one contiguous n-gram.

    Args:
        arr1: Sliceable sequence of hashable items (e.g. token ids).
        arr2: Sliceable sequence of hashable items.
        n: Length of the n-grams (windows) to compare.

    Returns:
        bool: True when some length-``n`` window of ``arr1`` equals some
        length-``n`` window of ``arr2``. False when either sequence is
        shorter than ``n``, since it then has no n-gram at all.
    """
    # Only the first sequence's n-grams are materialised as a set.
    ngrams1 = {tuple(arr1[i:i + n]) for i in range(len(arr1) - n + 1)}
    if not ngrams1:
        return False

    # Stream the second sequence's windows and stop at the first match instead
    # of building a second set and a full intersection.
    return any(
        tuple(arr2[i:i + n]) in ngrams1
        for i in range(len(arr2) - n + 1)
    )

In [None]:
# Decode every result's tokens to text and count the pieces obtained by
# splitting on a single space; the cell displays their mean.
decoded_texts = (tokenizer.decode(r.tokens) for r in results)
words = [len(text.split(" ")) for text in decoded_texts]
np.mean(words)

In [None]:
# For every result in the chosen bucket, keep only those near duplicates that
# share NO n-gram of a given size with the result's own tokens. The surviving
# (pos, dist) pairs are stored per n-gram size and per metric on a deep copy.
dedup_results = []
ngrams = range(10,110,10)
bucket = 1000

for r in tqdm(results):
    if r.bucket != bucket:
        continue

    new_r = copy.deepcopy(r)
    # near_duplicates_dedup[ngram_size][metric] -> list of (pos, dist)
    new_r.near_duplicates_dedup = {size: defaultdict(list) for size in ngrams}

    for metric in ("hamming", "edit"):
        for pos, dist in r.near_duplicates[metric]:
            # 200 raw bytes starting at pos, decoded with width 2
            # (presumably 100 two-byte token ids — confirm against utils).
            nd_tokens = bytes_to_ints(ds[pos:pos+200],2)

            for ngram_size in ngrams:
                if has_ngram_overlap(r.tokens, nd_tokens, n=ngram_size):
                    continue
                kept = new_r.near_duplicates_dedup[ngram_size][metric]
                kept.append((pos, dist))

    dedup_results.append(new_r)

In [235]:
# cum_counts_dedup[ngram_size][bucket][d] = scaled number of surviving
# near duplicates (edit metric) with distance <= d, for d = 0..50.
# NOTE: dedup_results is filled for a single bucket (1000) in the cell above,
# so the 100 and 10000 series come out as all zeros.
cum_counts_dedup = {}

metric = "edit"

for ngram_size in ngrams:
    per_bucket = defaultdict(list)
    cum_counts_dedup[ngram_size] = per_bucket
    for bucket in (100, 1000, 10000):
        results_bucket = [res for res in dedup_results if res.bucket == bucket]
        for dist in range(51):
            cnt = sum(
                1
                for res in results_bucket
                for nd in res.near_duplicates_dedup[ngram_size][metric]
                if nd[1] <= dist
            )

            # correction for average (same /100 * 20 scaling as cum_counts)
            cnt = (cnt / 100) * 20

            per_bucket[bucket].append(cnt)

In [None]:
# Effect of n-gram based dedup on the cumulative near-duplicate counts for the
# 1000 bucket: one curve per n-gram size (coloured along the "winter" colormap)
# plus the un-deduplicated baseline in grey.
fig, ax = plt.subplots(figsize=(7, 6))
# One colour per n-gram size, derived from `ngrams` instead of a hard-coded 10.
colors = plt.cm.winter(np.linspace(0, 1, len(ngrams)))
bucket = 1000

plt.plot(range(51), cum_counts[bucket]['edit'],
     color='dimgrey',
     markersize=5,
     markeredgecolor='white',
     markerfacecolor='dimgrey',
     markeredgewidth=0.5,
     label="No dedup",
     alpha=1.,
     linewidth=3,
     )

# Only these sizes get a legend entry and a thicker line.
highlighted = {20: "$n$ = 20", 50: "$n$ = 50"}

# Draw from the largest n-gram size down, as before. Each size is paired with
# its own colour, so line and marker face colour always agree (the marker face
# previously used the reversed index).
for color, ngram_size in reversed(list(zip(colors, ngrams))):
    plt.plot(range(51), cum_counts_dedup[ngram_size][bucket],
             linewidth=3 if ngram_size in highlighted else 1.,
             color=color,
             markersize=5,
             markeredgecolor='white',
             markerfacecolor=color,
             markeredgewidth=0.5,
             label=highlighted.get(ngram_size),
    )
plt.grid(True, linestyle='-', color='#E5E7EB', alpha=0.9, zorder=0)
plt.xlabel('Levenshtein distance', fontsize=13)
plt.tick_params(labelsize=9)
plt.ylim(0, 22*bucket)

plt.ylabel('Cumulative number of near duplicates', fontsize=13)
plt.tight_layout()

# Add colorbar: ticks are spaced evenly over [0, 1] and labelled with the
# n-gram sizes, matching the colours sampled above.
sm = plt.cm.ScalarMappable(cmap=plt.cm.winter)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label('n-gram size', fontsize=13)
cbar.set_ticks(np.linspace(0, 1, len(ngrams)))
cbar.set_ticklabels([str(x) for x in ngrams])
# Shrink the colorbar height to leave room above it for the legend.
cbar_box = cbar.ax.get_position()
cbar.ax.set_position([cbar_box.x0, cbar_box.y0,
                      cbar_box.width, cbar_box.height * 0.85])
plt.legend(loc='center left', bbox_to_anchor=(1.0,0.94), fontsize=9)
plt.show()