In [3]:
import pandas as pd
import sqlalchemy as sa

from transformers import T5Tokenizer
from IPython.utils import io

import dask
from dask import dataframe as dd
from dask.diagnostics import ProgressBar
import dask.distributed as distributed

import warnings
import logging

# Silence every warning for the rest of the session.
warnings.filterwarnings('ignore')

# Execute the project's base script through IPython's %run magic.
# NOTE(review): the script's contents are not visible here — confirm which names it
# is expected to provide before reordering anything below it.
%run /home/ubuntu/work/therapeutic_accelerator/scripts/base.py

# Handed to the tokenizer below as `model_max_length`.
max_sequence_length = 1200
# NOTE(review): not referenced anywhere else in the visible notebook — confirm it is still needed.
embedding_size = 200

# Create tokenizer for T5 model
T5tokens = T5Tokenizer.from_pretrained('t5-base', model_max_length = max_sequence_length)

In [None]:
# from IPython.utils import io

# with io.capture_output() as captured:

In [None]:
# Create dask cluster
dask.config.set(scheduler='processes')  # overwrite default with multiprocessing scheduler

# Look for an already-registered client BEFORE starting anything: previously a new
# 7-worker LocalCluster was launched on every run of this cell and then left unused
# (but still running) whenever a global client already existed.
client = distributed.client._get_global_client()
if client is None:
    cluster = distributed.LocalCluster(name='local', n_workers=7, memory_limit = '4GiB', threads_per_worker=2)  # Launches a scheduler and workers locally
    client = distributed.Client(cluster)
else:
    # Reuse whatever cluster object the existing client holds.
    # NOTE(review): this can be None if that client was connected by address — confirm nothing relies on `cluster`.
    cluster = client.cluster
client

# Create Embeddings

In [None]:
from langchain.text_splitter import CharacterTextSplitter
# import tiktoken

# @dask.delayed
def token_len(text):
    """ Return the number of token ids `T5tokens.encode` produces for `text`. """
    return len(T5tokens.encode(text))
    
# Maximum chunk length given to the splitter; lengths are measured with `token_len`.
chunk_size = 512

# create text splitters for processing the texts
text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=chunk_size, chunk_overlap=20, length_function=token_len)

## Now with Dask

Helper functions to split text into chunks, build the documents for each chunk, and add them to the collection

In [None]:
def split_text(df):
    """ Swap the `text` column for `split_text`.

    Each value of `df['text']` is passed through `text_splitter.split_text`; the
    results are stored in a new `split_text` column and the `text` column is dropped.
    Returns the new dataframe.
    """
    chunked = df['text'].apply(text_splitter.split_text)
    return df.assign(split_text=chunked).drop(columns='text')

# ------------------------------------------------------------------------------------
def create_doc(split_text, corpusid):
    """ Create the documents / ids / metadatas payload for the chunks of one text.

    Args:
        split_text: sequence of text chunks (must support ``len``).
        corpusid: identifier of the source text; must be convertible with ``int()``
            whenever there is at least one chunk.

    Returns:
        A dict with the keys ``documents`` (the chunks as given), ``ids``
        (``"<corpusid>-<chunk index>"``) and ``metadatas`` (one
        ``{'corpusid': int, 'chunk': index}`` dict per chunk), or ``None`` when the
        payload could not be built. Failures are logged, never raised.
    """
    try:
        chunk_indices = range(len(split_text))
        return {
            "documents": split_text, # list of all documents [doc1, doc2, doc3, ...]
            'ids': [f'{corpusid}-{i}' for i in chunk_indices], # list of all ids [id1, id2, id3, ...]
            'metadatas': [{'corpusid': int(corpusid), 'chunk': i} for i in chunk_indices] # list of dictionaries with metadata for each document
        }
    except Exception:
        # Keep the best-effort contract (return None), but record the traceback and
        # which corpusid failed so the dropped row can be traced afterwards.
        logging.exception("create_doc failed for corpusid=%r", corpusid)
        return None
        
def df_create_doc(df):
    """ Row-wise wrapper around `create_doc`, used for mapping partitions.

    Takes a dataframe with `split_text` and `corpusid` columns.

    Returns the result of applying `create_doc` to every row (`axis=1`).
    """
    def _row_to_doc(row):
        return create_doc(row['split_text'], row['corpusid'])

    return df.apply(_row_to_doc, axis=1)

def mp_create_doc(ddf):
    """ Build documents for every partition of a dask dataframe.

    Args:
        ddf: dask dataframe whose partitions have `split_text` and `corpusid` columns.

    Returns:
        The lazy dask result of running `df_create_doc` on each partition.
    """
    # `ddf.apply(df_create_doc, axis=1)` passed df_create_doc one ROW at a time, but
    # that function expects a whole dataframe (it does its own row-wise apply).
    # map_partitions hands it each pandas partition instead.
    # NOTE(review): no `meta` is supplied, so dask has to infer the output type — confirm or pass `meta` explicitly.
    return ddf.map_partitions(df_create_doc)

# ------------------------------------------------------------------------------------
# Add documents to collection
def add_to_collection(docs):
    """ Unpack `docs` as keyword arguments to `collection.add`.

    Any exception is logged with `logging.error` and swallowed; nothing is returned.
    """
    try:
        collection.add(**docs)
    except Exception as err:
        logging.error(err)
        
def ddf_add_to_collection(series):
    """ Run `add_to_collection` on each element of `series` and return the `.apply` result. """
    applied = series.apply(add_to_collection)
    return applied

In [None]:
# Read in fulltext from the parquet files for dask.
# The read_csv-only options that used to be passed here (sample, sample_rows,
# lineterminator, dtype) were removed: they are not read_parquet parameters, and
# parquet files carry their own column types.
ft = dd.read_parquet('/home/ubuntu/work/data/fulltext_parquets/fulltext-*.parquet')

# Cleanup dataframes
# Each step runs independently on every partition, so duplicates are only removed
# within a partition, not across the whole dataset.
ft = ft.map_partitions(pd.DataFrame.dropna, subset='text')

ft = ft.map_partitions(pd.DataFrame.drop_duplicates, subset='text')

ft = ft.map_partitions(pd.DataFrame.reset_index, drop=True)

ft = ft.persist()


In [None]:
# creates futures to then act on to create tree of dependencies
# One future per partition of `ft`. The previous `ft.compute(...)` returned a single
# pandas DataFrame, and iterating over a DataFrame yields its column labels rather
# than chunks of rows, so the per-item `client.submit` calls that follow received strings.
results = client.compute(ft.to_delayed())

In [None]:
# Submit one `split_text` task for every item obtained by iterating `results`.
futures_split_text = list(map(lambda part: client.submit(split_text, part), results))

In [None]:
# Submit one `df_create_doc` task per entry of `futures_split_text`.
docs = list(map(lambda fut: client.submit(df_create_doc, fut), futures_split_text))

In [None]:
import os

# get number of files in a folder
def get_num_files(path):
    """ Count the entries directly inside `path` that are files.

    Sub-directories are neither counted nor descended into.
    """
    file_count = 0
    for entry_name in os.listdir(path):
        if os.path.isfile(os.path.join(path, entry_name)):
            file_count += 1
    return file_count

get_num_files('/home/ubuntu/work/data/fulltext_docs_csvs')


In [None]:
def create_csv(d, out_dir='/home/ubuntu/work/data/fulltext_docs_csvs'):
    """ Write one batch of documents to its own numbered csv.

    Args:
        d: anything `pd.DataFrame` accepts; written without the index.
        out_dir: directory that receives `fulltext_doc_<i>.csv`. Defaults to the
            location this notebook has always used.

    Returns:
        The path of the file that was written.
    """
    import os  # local import so the function does not depend on another cell having run

    df = pd.DataFrame(d)
    # Starting guess for the file number, as before: current file count + 1.
    i = get_num_files(out_dir) + 1

    # Tasks submitted in parallel can all compute the same `i`; writing
    # unconditionally let the last writer overwrite the others. Mode 'x' fails
    # if the file already exists, so on a clash we move on to the next number.
    while True:
        csv_path = os.path.join(out_dir, f'fulltext_doc_{i}.csv')
        try:
            handle = open(csv_path, 'x', newline='', encoding='utf-8')
        except FileExistsError:
            i += 1
            continue
        with handle:
            df.to_csv(handle, index=False)
        return csv_path

In [None]:
# Write out documents for easier uploading later
# (one `create_csv` task per entry of `docs`)
csvs = list(map(lambda doc: client.submit(create_csv, doc), docs))

In [None]:
from dask.distributed import as_completed

# Construct Chroma Collection

In [4]:
# Create dask cluster
dask.config.set(scheduler='processes')  # overwrite default with multiprocessing scheduler

# Check for an existing global client first so that re-running this cell does not
# start another 7-worker LocalCluster that is then never used (the old code built
# the cluster unconditionally and only afterwards looked for a client).
client = distributed.client._get_global_client()
if client is None:
    cluster = distributed.LocalCluster(name='local', n_workers=7, memory_limit = '4GiB', threads_per_worker=2)  # Launches a scheduler and workers locally
    client = distributed.Client(cluster)
else:
    # Reuse whatever cluster object the existing client holds.
    # NOTE(review): this can be None if that client was connected by address — confirm nothing relies on `cluster`.
    cluster = client.cluster
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 7
Total threads: 14,Total memory: 28.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:35907,Workers: 7
Dashboard: http://127.0.0.1:8787/status,Total threads: 14
Started: Just now,Total memory: 28.00 GiB

0,1
Comm: tcp://127.0.0.1:37149,Total threads: 2
Dashboard: http://127.0.0.1:39429/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:35257,
Local directory: /tmp/dask-scratch-space/worker-xvdjq980,Local directory: /tmp/dask-scratch-space/worker-xvdjq980

0,1
Comm: tcp://127.0.0.1:41085,Total threads: 2
Dashboard: http://127.0.0.1:43747/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:42651,
Local directory: /tmp/dask-scratch-space/worker-pw970hnp,Local directory: /tmp/dask-scratch-space/worker-pw970hnp

0,1
Comm: tcp://127.0.0.1:36017,Total threads: 2
Dashboard: http://127.0.0.1:35705/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:46085,
Local directory: /tmp/dask-scratch-space/worker-0zhgesh3,Local directory: /tmp/dask-scratch-space/worker-0zhgesh3

0,1
Comm: tcp://127.0.0.1:43001,Total threads: 2
Dashboard: http://127.0.0.1:37969/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:35499,
Local directory: /tmp/dask-scratch-space/worker-5m9b8422,Local directory: /tmp/dask-scratch-space/worker-5m9b8422

0,1
Comm: tcp://127.0.0.1:34675,Total threads: 2
Dashboard: http://127.0.0.1:34241/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:40277,
Local directory: /tmp/dask-scratch-space/worker-guhbqmn9,Local directory: /tmp/dask-scratch-space/worker-guhbqmn9

0,1
Comm: tcp://127.0.0.1:39769,Total threads: 2
Dashboard: http://127.0.0.1:43697/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:35739,
Local directory: /tmp/dask-scratch-space/worker-tgr91n4n,Local directory: /tmp/dask-scratch-space/worker-tgr91n4n

0,1
Comm: tcp://127.0.0.1:42233,Total threads: 2
Dashboard: http://127.0.0.1:40235/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:37165,
Local directory: /tmp/dask-scratch-space/worker-an3fy8ly,Local directory: /tmp/dask-scratch-space/worker-an3fy8ly


In [10]:
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

# Connection settings for the Chroma server reached over its REST API.
chroma_settings = Settings(
    chroma_api_impl="rest",
    chroma_server_host="18.233.156.143",  # EC2 instance public IPv4
    chroma_server_http_port=8000,
)
chroma_client = chromadb.Client(chroma_settings)

# Fetch the "fulltext" collection, creating it if it does not exist yet.
collection = chroma_client.get_or_create_collection("fulltext")

In [5]:
# Build a dask DataFrame over every cleaned document csv matching the glob.
ddf = dd.read_csv('/home/ubuntu/work/data/fulltext_docs_csvs_cleaned/fulltext_doc_*.csv')
# Rebind `ddf` to whatever `client.scatter` returns for that DataFrame object.
# NOTE(review): this scatters the dask collection object itself rather than its
# computed partitions — confirm that is what the upload step expects.
ddf = client.scatter(ddf)

In [6]:
ddf

In [8]:
# ------------------------------------------------------------------------------------
# Add documents to collection
def add_to_collection(docs):
    """ Pass `docs` to `collection.add` as keyword arguments.

    Exceptions are reported through `logging.error` and not re-raised; the
    function has no return value.
    """
    try:
        collection.add(**docs)
    except Exception as failure:
        logging.error(failure)
        
def ddf_add_to_collection(series):
    """ Apply `add_to_collection` over `series` with `axis = 1` and return the result. """
    outcome = series.apply(add_to_collection, axis=1)
    return outcome

In [11]:
# Submit a single task that calls `ddf_add_to_collection(ddf)`; the returned future is kept in `test`.
# NOTE(review): `test` is never inspected in the visible cells, so a failure inside
# the task would go unnoticed — confirm the upload actually ran (e.g. check its result).
test = client.submit(ddf_add_to_collection, ddf)

In [None]:
ddf

In [13]:
collection.count()

83675