In [1]:
# first - to start ScyllaDB run the following in a terminal:
# sudo mkdir -p ./data/scylla/data ./scylla/commitlog
# docker run --rm -ti -p 127.0.0.1:9042:9042 -v ${PWD}/data/scylla:/var/lib/scylla scylladb/scylla --overprovisioned 1 --listen-address 0.0.0.0 --broadcast-rpc-address 127.0.0.1


# then wait for the following message to appear before proceeding:
# init - Scylla version 5.1.3-0.20230112.addc4666d502 initialization completed.

In [1]:
import time
from pathlib import Path
import narrow_down as nd
import ray
from ray.util import ActorPool
from datasets import load_from_disk
from narrow_down import _tokenize
from narrow_down.similarity_store import SimilarityStore
from narrow_down.storage import StorageLevel
from narrow_down.scylladb import ScyllaDBStore

import cassandra

ray.init(ignore_reinit_error=True)

similarity_threshold = 0.85
word_ngrams = 5

@ray.remote(num_cpus=1)
class Worker(object):
    def __init__(self):
        return

    # initialize the similarity store - we'll call first and only once per worker
    async def initialize(self):
        try:
            cassandra_cluster = cassandra.cluster.Cluster(contact_points=["localhost"], port=9042) # type: ignore
            session = cassandra_cluster.connect()
            session.execute(
                "CREATE KEYSPACE IF NOT EXISTS dedupe "
                "WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 1} "
                "AND durable_writes = False"
            )
            cassandra_storage = ScyllaDBStore(session, keyspace="dedupe")
            self.similarity_store = await nd.similarity_store.SimilarityStore.create(
                storage_level=StorageLevel.Minimal,
                similarity_threshold=similarity_threshold,
                tokenize=lambda s: _tokenize.word_ngrams(s, word_ngrams),
                storage=cassandra_storage
            )
        except Exception as e:
            print(e)

    # process a document - we'll call this for each document
    async def process(self, document_text, document_id):
        try:            
            result = await self.similarity_store.insert(document=document_text, document_id=document_id)
            return result
        except Exception as e:
            print(e)

    # find duplicates for a document - we'll call this for each document
    async def find_duplicates(self, document_text, document_id):
        try:
            result = await self.similarity_store.query(document=document_text)
            return {"query_doc_id": document_id, "duplicate_indices": [r.id_ for r in result if r.id_ != document_id], 'has_duplicates': len(result) > 1}
        except Exception as e:
            print(e)




2023-01-31 05:15:54,642	INFO worker.py:1529 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8266 [39m[22m


In [2]:
# connect to ScyllaDB and create a keyspace
cassandra_cluster = cassandra.cluster.Cluster(contact_points=["localhost"], port=9042)
session = cassandra_cluster.connect()
session.execute(
    "CREATE KEYSPACE IF NOT EXISTS dedupe "
    "WITH replication = {'class': 'SimpleStrategy', 'replication_factor' : 1} "
    "AND durable_writes = False"
)

# create a similarity store - this configures ScyllaDB and allows us to ad-hoc query later on
cassandra_storage = ScyllaDBStore(session, keyspace="dedupe")
similarity_store = await nd.similarity_store.SimilarityStore.create(
    storage_level=StorageLevel.Minimal,
    similarity_threshold=similarity_threshold,
    tokenize=lambda s: _tokenize.word_ngrams(s, word_ngrams),
    storage=cassandra_storage,
)

In [3]:
# create the workers - allocate a worker per CPU core
actors = [Worker.remote() for _ in range(int(ray.available_resources()['CPU']))]
[actor.initialize.remote() for actor in actors]

[ObjectRef(28c7376153a43fb1cb6514c3d1e3b8bf147c52e10100000001000000),
 ObjectRef(7109b8141612f944ab5ba956f2de881a273513690100000001000000),
 ObjectRef(8c4854248414f6335f4e34eb9968d2bc9857d5910100000001000000),
 ObjectRef(9a0afb4ce5b46f16c617878de8cb393a258bc0340100000001000000),
 ObjectRef(261bd10b0466d7e85996dbd01a62d3c2f3baff5d0100000001000000),
 ObjectRef(b9a008d165a7e804f3c88909a29dfcf37bfe30c60100000001000000),
 ObjectRef(a8485d936ac2e7cc009eea7c29774293966f6a6c0100000001000000),
 ObjectRef(b58f0ee91e0a959907e83b7a6552c7f4c013d3fa0100000001000000),
 ObjectRef(0cb7b64917b5af44a3acfab4d57e562cea58e9dc0100000001000000),
 ObjectRef(b5f40f7c7d38fc7935a28c9741da349700a9ca3e0100000001000000),
 ObjectRef(0c025bfe7d0aed89c891d923fce4953d3157f89a0100000001000000),
 ObjectRef(a98b912db1b8ed143fc2d338f84b7b3af25576880100000001000000),
 ObjectRef(ec502c4fdc3aeab09001e053cedf3cdb5d5d11220100000001000000),
 ObjectRef(b1d906d2acc455b1284a17d1eb845536bd897fe90100000001000000),
 ObjectRef(04a86d731

In [4]:
# create an actor pool - this allows us to distribute the work across the workers similar to a multiprocessing pool
pool = ActorPool(actors)

In [5]:
paths = Path("./data/pile-v2-eda/cache_ds").glob('*')

total_len = 0
start = time.time()
for path in paths:
    print('------------------')
    print(path)
    ds = load_from_disk(path)
    # note - this is assuming that the id is unique across datasets, if not we'll want to generate a unique id for each document
    # (e.g. even a simple incrementing counter will do)
    res = pool.map(lambda actor, row: actor.process.remote(row['text'], int(row['id'])), ds)
    resolved = list(res)
    total_len += len(ds)
end = time.time()
print(f"Total time: {end - start}, len: {total_len}")

------------------
data/pile-v2-eda/cache_ds/CodePileReddit2022
------------------
data/pile-v2-eda/cache_ds/EuroParliamentProceedings




------------------
data/pile-v2-eda/cache_ds/ASFPublicMail
------------------
data/pile-v2-eda/cache_ds/CodePilePosts
------------------
data/pile-v2-eda/cache_ds/CodePileReddit2021
------------------
data/pile-v2-eda/cache_ds/S2ORC
------------------
data/pile-v2-eda/cache_ds/arXiv
------------------
data/pile-v2-eda/cache_ds/TED2020
------------------
data/pile-v2-eda/cache_ds/Enwiki
------------------
data/pile-v2-eda/cache_ds/TheStack
------------------
data/pile-v2-eda/cache_ds/AMPS
------------------
data/pile-v2-eda/cache_ds/PileV2Reddit2020
------------------
data/pile-v2-eda/cache_ds/AI4Code
------------------
data/pile-v2-eda/cache_ds/USPTO
------------------
data/pile-v2-eda/cache_ds/PileOfLaw
------------------
data/pile-v2-eda/cache_ds/OtherWiki
------------------
data/pile-v2-eda/cache_ds/USENET
------------------
data/pile-v2-eda/cache_ds/PileV2RedditPosts
------------------
data/pile-v2-eda/cache_ds/DMMath
------------------
data/pile-v2-eda/cache_ds/GNOME
-------------

In [6]:
all_dupes = {}

paths = Path("./data/pile-v2-eda/cache_ds").glob('*')
start = time.time()
for path in paths:
    print('------------------')
    print(path)
    ds = load_from_disk(path)
    res = pool.map(lambda actor, row: actor.find_duplicates.remote(row['text'], int(row['id'])), ds)
    resolved = list(res)
    for r in resolved:
        if r['has_duplicates']:
            all_dupes[r['query_doc_id']] = r['duplicate_indices']
end = time.time()
print(f"Total time: {end - start}, dupes: {len(all_dupes)}")

------------------
data/pile-v2-eda/cache_ds/CodePileReddit2022
------------------
data/pile-v2-eda/cache_ds/EuroParliamentProceedings
------------------
data/pile-v2-eda/cache_ds/ASFPublicMail
------------------
data/pile-v2-eda/cache_ds/CodePilePosts
------------------
data/pile-v2-eda/cache_ds/CodePileReddit2021
------------------
data/pile-v2-eda/cache_ds/S2ORC
------------------
data/pile-v2-eda/cache_ds/arXiv
------------------
data/pile-v2-eda/cache_ds/TED2020
------------------
data/pile-v2-eda/cache_ds/Enwiki
------------------
data/pile-v2-eda/cache_ds/TheStack
------------------
data/pile-v2-eda/cache_ds/AMPS
------------------
data/pile-v2-eda/cache_ds/PileV2Reddit2020
------------------
data/pile-v2-eda/cache_ds/AI4Code
------------------
data/pile-v2-eda/cache_ds/USPTO
------------------
data/pile-v2-eda/cache_ds/PileOfLaw
------------------
data/pile-v2-eda/cache_ds/OtherWiki
------------------
data/pile-v2-eda/cache_ds/USENET
------------------
data/pile-v2-eda/cache_ds

In [7]:
all_dupes

{1624678: [3010697],
 3010697: [1624678],
 19931: [16898,
  14345,
  3594,
  36875,
  33804,
  3085,
  32269,
  32276,
  43031,
  14362,
  32794,
  4636,
  36383,
  6176,
  36897,
  32290,
  14372,
  12837,
  12842,
  32811,
  16428,
  39979,
  43562,
  5680,
  32305,
  2102,
  42551,
  12857,
  32826,
  14399,
  6208,
  3649,
  34368,
  27715,
  36421,
  27719,
  22602,
  12363,
  24650,
  27726,
  37455,
  8272,
  22608,
  11858,
  14422,
  6233,
  5210,
  6237,
  34398,
  597089,
  14438,
  22631,
  4713,
  37481,
  35437,
  22642,
  45172,
  12406,
  30327,
  34934,
  35958,
  13434,
  45176,
  13948,
  6782,
  38019,
  30340,
  37001,
  650,
  32396,
  37521,
  30356,
  33431,
  22169,
  32413,
  22177,
  34979,
  12453,
  30377,
  4779,
  30383,
  36528,
  34995,
  6324,
  32954,
  44730,
  28348,
  34493,
  702,
  4798,
  392897,
  708,
  8902,
  31431,
  11976,
  3275,
  181964,
  13005,
  9422,
  12495,
  12496,
  38095,
  45777,
  45779,
  3284,
  22740,
  33494,
  33501,
  1