In [1]:
import pandas as pd
import sqlalchemy as sa

from transformers import T5Tokenizer

from dask import dataframe as dd
from dask.diagnostics import ProgressBar

import warnings
# NOTE(review): this silences *all* warnings for the whole session, including
# pandas FutureWarning/DeprecationWarning messages (e.g. about positional Series
# indexing) and tokenizer sequence-length warnings; consider narrowing it to
# specific categories/messages.
warnings.filterwarnings('ignore')

# Execute the project's base script inside this notebook's namespace; presumably
# it defines shared objects such as `keys` / `config` referenced by commented-out
# cells below -- TODO confirm.
%run /home/ubuntu/work/therapeutic_accelerator/scripts/base.py

# Sizing constants (embedding_size appears unused in the cells visible here).
max_sequence_length = 1200
embedding_size = 200

# t5-base tokenizer, used by token_len() to count tokens when sizing text chunks.
T5tokens = T5Tokenizer.from_pretrained(
    't5-base',
    model_max_length=max_sequence_length,
)

In [2]:
from dask.distributed import Client, LocalCluster, progress

# Local Dask cluster: 12 worker processes x 4 threads each, 2 GiB memory cap per worker.
cluster = LocalCluster(
    name='local',
    n_workers=12,
    threads_per_worker=4,
    memory_limit='2GiB',
)
# Connecting a Client makes this cluster the default scheduler for later
# .persist()/.compute() calls on dask collections.
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 12
Total threads: 48,Total memory: 24.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:34081,Workers: 12
Dashboard: http://127.0.0.1:8787/status,Total threads: 48
Started: Just now,Total memory: 24.00 GiB

0,1
Comm: tcp://127.0.0.1:44047,Total threads: 4
Dashboard: http://127.0.0.1:34033/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:45225,
Local directory: /tmp/dask-scratch-space/worker-_6kqwgi9,Local directory: /tmp/dask-scratch-space/worker-_6kqwgi9

0,1
Comm: tcp://127.0.0.1:44679,Total threads: 4
Dashboard: http://127.0.0.1:40629/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:34805,
Local directory: /tmp/dask-scratch-space/worker-ayrikkl4,Local directory: /tmp/dask-scratch-space/worker-ayrikkl4

0,1
Comm: tcp://127.0.0.1:45779,Total threads: 4
Dashboard: http://127.0.0.1:46313/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:32901,
Local directory: /tmp/dask-scratch-space/worker-qkywnnh8,Local directory: /tmp/dask-scratch-space/worker-qkywnnh8

0,1
Comm: tcp://127.0.0.1:40465,Total threads: 4
Dashboard: http://127.0.0.1:46505/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:37943,
Local directory: /tmp/dask-scratch-space/worker-vmnoo8xq,Local directory: /tmp/dask-scratch-space/worker-vmnoo8xq

0,1
Comm: tcp://127.0.0.1:46501,Total threads: 4
Dashboard: http://127.0.0.1:39247/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:39127,
Local directory: /tmp/dask-scratch-space/worker-oeyu_xls,Local directory: /tmp/dask-scratch-space/worker-oeyu_xls

0,1
Comm: tcp://127.0.0.1:35797,Total threads: 4
Dashboard: http://127.0.0.1:43549/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:36617,
Local directory: /tmp/dask-scratch-space/worker-5iop2qci,Local directory: /tmp/dask-scratch-space/worker-5iop2qci

0,1
Comm: tcp://127.0.0.1:39813,Total threads: 4
Dashboard: http://127.0.0.1:46745/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:42877,
Local directory: /tmp/dask-scratch-space/worker-fgnlx41r,Local directory: /tmp/dask-scratch-space/worker-fgnlx41r

0,1
Comm: tcp://127.0.0.1:38159,Total threads: 4
Dashboard: http://127.0.0.1:42007/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:34217,
Local directory: /tmp/dask-scratch-space/worker-03ztge2i,Local directory: /tmp/dask-scratch-space/worker-03ztge2i

0,1
Comm: tcp://127.0.0.1:41009,Total threads: 4
Dashboard: http://127.0.0.1:39591/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:35745,
Local directory: /tmp/dask-scratch-space/worker-1sc794u6,Local directory: /tmp/dask-scratch-space/worker-1sc794u6

0,1
Comm: tcp://127.0.0.1:36137,Total threads: 4
Dashboard: http://127.0.0.1:35631/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:43503,
Local directory: /tmp/dask-scratch-space/worker-dg197g2y,Local directory: /tmp/dask-scratch-space/worker-dg197g2y

0,1
Comm: tcp://127.0.0.1:36113,Total threads: 4
Dashboard: http://127.0.0.1:46553/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:38397,
Local directory: /tmp/dask-scratch-space/worker-syovp6ea,Local directory: /tmp/dask-scratch-space/worker-syovp6ea

0,1
Comm: tcp://127.0.0.1:42653,Total threads: 4
Dashboard: http://127.0.0.1:43751/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:37507,
Local directory: /tmp/dask-scratch-space/worker-4zepuqam,Local directory: /tmp/dask-scratch-space/worker-4zepuqam


In [None]:
# from dask import dataframe as dd

# ddf = dd.read_sql_table('fulltext', 
#                         con = f'postgresql://postgres:{keys["postgres"]}@{config["database"]["host"]}:5432/postgres',
#                         index_col = 'id',
#                         head_rows = 10,
#                         npartitions = 400)
# ft = ddf.loc[:, ['corpusid', 'text']]
# # write out dask series to parquet
# ft.to_csv(f'/home/ubuntu/work/data/fulltext_csvs/fulltext-*.csv')

# Chunk Full Texts (preparation for embeddings)

In [3]:
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
import tiktoken

def token_len(text):
    """Return how many tokens ``T5tokens.encode`` produces for ``text``.

    Used as the text splitter's length function, so chunk sizes are measured
    in T5 tokens rather than characters. The count includes any special tokens
    that ``encode`` adds by default.
    """
    return len(T5tokens.encode(text))
    
chunk_size = 1200  # maximum chunk length, as measured by token_len

# Split on blank lines (paragraph boundaries), merging paragraphs into chunks of
# up to `chunk_size` T5 tokens with a 20-token overlap between neighbours.
text_splitter = CharacterTextSplitter(
    separator="\n\n",
    chunk_size=chunk_size,
    chunk_overlap=20,
    length_function=token_len,
)

In [None]:
# recursive_splitter = RecursiveCharacterTextSplitter(
#     separators = ["\n\n", "\n"],
#     chunk_size = chunk_size,
#     chunk_overlap  = 20,
#     length_function = token_len,
# )

In [None]:
# %%capture --no-stdout --no-stderr

# # Read in fulltext from csvs for dask
# ft = dd.read_csv('/home/ubuntu/work/data/fulltext_csvs/fulltext-000.csv', sample = 10000000,
#                 sample_rows=10,
#                 lineterminator=None,
#                 dtype={'corpusid': 'int', 'text': 'object'})


# ft = ft.dropna(subset='text')
# ft = ft.compute()
# ft_temp = dd.from_pandas(ft, npartitions=30)
# ft_temp['text'] = ft_temp['text'].apply(text_splitter.split_text, meta=('text', 'object'))
# res = ft_scat.compute()
# # df = pd.DataFrame(data = {'corpusid': ft['corpusid'], 'text': res})

# Upload to Chroma

In [4]:
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

In [5]:
import os

# Create chroma client. The server address can be overridden with the CHROMA_HOST /
# CHROMA_PORT environment variables instead of editing the notebook; the defaults
# are the previously hard-coded values.
# NOTE(review): no SSL or auth settings are passed here -- confirm the server is
# not reachable unauthenticated from the public internet.
CHROMA_HOST = os.environ.get("CHROMA_HOST", "54.175.241.78")  # EC2 instance public IPv4
CHROMA_PORT = int(os.environ.get("CHROMA_PORT", "8000"))

chroma = chromadb.Client(Settings(chroma_api_impl="rest",
                                  chroma_server_host=CHROMA_HOST,
                                  chroma_server_http_port=CHROMA_PORT))

print("Nanosecond heartbeat on server", chroma.heartbeat()) # returns a nanosecond heartbeat. Useful for making sure the client remains connected.

# List existing collections on the server
chroma.list_collections()

Nanosecond heartbeat on server 1688718739683432044000


[Collection(name=langchain_store),
 Collection(name=abstracts),
 Collection(name=fulltext)]

In [6]:
# Sentence Transformers all-MiniLM-L6-v2
# Chroma's default embedding function; also the default embedding function for
# create_collection in this cell.
default_ef = embedding_functions.DefaultEmbeddingFunction()

# Create a collection to store embeddings (cosine HNSW space by default).
def create_collection(chroma, name, metadata=None, embedding_function=None):
    """Create a new Chroma collection, logging (not raising) any failure.

    Parameters
    ----------
    chroma : chromadb client
        Client whose ``create_collection`` method is called.
    name : str
        Collection name.
    metadata : dict, optional
        Collection metadata. Defaults to a fresh ``{"hnsw:space": "cosine"}`` on
        every call, instead of one shared mutable default dict.
    embedding_function : optional
        Embedding function for the collection. Defaults to the module-level
        ``default_ef``.

    Returns
    -------
    The collection returned by ``chroma.create_collection``, or ``None`` if that
    call raised. The error is logged.
    """
    # Local import: the cell that imports `logging` runs after this one, so the
    # error path would otherwise hit a NameError on a fresh kernel.
    import logging

    if metadata is None:
        metadata = {"hnsw:space": "cosine"}
    if embedding_function is None:
        embedding_function = default_ef

    try:
        return chroma.create_collection(name=name, metadata=metadata,
                                        embedding_function=embedding_function)
    except Exception as e:
        logging.error(e)
        return None

# Fetch the "fulltext" collection, creating it if missing. This call does not go
# through create_collection, so no cosine-space metadata is passed here.
collection = chroma.get_or_create_collection(name="fulltext")

In [7]:
import logging

def create_document(texts, corpusid):
    """Build a Chroma ``collection.add`` payload for one paper's text chunks.

    Parameters
    ----------
    texts : iterable of str
        Text chunks for a single paper (e.g. the output of
        ``text_splitter.split_text``).
    corpusid : int-like
        Paper identifier. Stored as ``int`` in every chunk's metadata and used
        to build chunk ids of the form ``"<corpusid>-<chunk index>"``.

    Returns
    -------
    dict
        ``{"documents": [...], "ids": [...], "metadatas": [...]}`` with one
        entry per chunk, in chunk order.
    """
    # Materialize once so generators / Series work and all three lists line up.
    # (Building the dict cannot raise, so no try/except is needed here.)
    texts = list(texts)
    corpusid_int = int(corpusid)

    return {
        "documents": texts,  # list of all documents [doc1, doc2, doc3, ...]
        'ids': [f'{corpusid}-{i}' for i in range(len(texts))],  # [id1, id2, ...]
        'metadatas': [{'corpusid': corpusid_int, 'chunk': i}  # one dict per document
                      for i in range(len(texts))],
    }

In [8]:
# def add_to_collection(text, corpusid):
    
#     doc = create_document(text, corpusid)
    
#     try:
#         collection.add(**doc)
#     except Exception as e:
#         logging.error(e)

In [9]:
# temp = df.apply(lambda x: create_document(x['text'], x['corpusid']), axis=1)
# temp.apply(add_to_collection)

## Now with Dask

In [10]:
import logging
from dask import delayed         
            
def add_to_collection(doc, target_collection=None):
    """Upload one or more Chroma document payloads, logging (not raising) failures.

    Deliberately *not* wrapped in ``dask.delayed``. The function is passed to
    ``Series.map_partitions``, which already runs it lazily once per partition.
    A delayed-wrapped function would only build and return an un-executed
    ``Delayed`` object per partition, so nothing would ever be uploaded.

    Parameters
    ----------
    doc : dict or iterable of dict
        Either a single payload from ``create_document`` (keys ``documents``,
        ``ids``, ``metadatas``) or an iterable of payloads. Under
        ``map_partitions`` it is a pandas Series with one payload per row.
    target_collection : optional
        Object exposing ``add(documents=..., ids=..., metadatas=...)``. Defaults
        to the module-level ``collection``.

    Returns
    -------
    bool, list of bool, or pandas Series (object dtype)
        ``True`` for each payload that was added and ``False`` for each whose
        ``add`` raised. A Series input yields a Series with the same index and
        name, which fits the ``meta=('docs', 'object')`` used at the call site.

    NOTE(review): when run on distributed workers, the module-level
    ``collection`` client is serialized to each worker -- confirm it pickles
    cleanly.
    """
    import logging
    from collections.abc import Mapping

    import pandas as pd

    target = target_collection if target_collection is not None else collection

    def _add_one(payload):
        # Per-payload try/except so one bad document doesn't abort the partition.
        try:
            target.add(**payload)
        except Exception as e:
            logging.error(e)
            return False
        return True

    if isinstance(doc, Mapping):
        return _add_one(doc)

    results = [_add_one(payload) for payload in doc]
    if isinstance(doc, pd.Series):
        return pd.Series(results, index=doc.index, name=doc.name, dtype=object)
    return results

In [11]:
# Read in fulltext from csvs for dask
ft = dd.read_csv('/home/ubuntu/work/data/fulltext_csvs/fulltext-*.csv', sample = 10000000,
                sample_rows=10,
                lineterminator=None,
                dtype={'corpusid': 'int', 'text': 'object'})

# Per-partition cleanup: drop the `id` column (presumably the index written by the
# CSV export cell -- confirm), drop rows without text, and renumber rows.
ft = ft.map_partitions(pd.DataFrame.drop, columns='id')
ft = ft.map_partitions(pd.DataFrame.dropna, subset='text')
ft = ft.map_partitions(pd.DataFrame.reset_index, drop=True)

# split the text in partitions: each row's `text` becomes the splitter's list of chunks
ft = ft.map_partitions(lambda df: df.assign(text=df['text'].apply(text_splitter.split_text)),
                       meta={'corpusid': 'int', 'text': 'object'})


def _with_documents(df):
    """Return ``df`` with a ``docs`` column holding one ``create_document`` payload per row."""
    # Pair the columns by name and build a plain list. Positional ``row[0]`` on a
    # labelled row Series is deprecated in pandas (and the warning is silenced in
    # this notebook). ``apply(axis=1)`` on an empty partition can return a
    # DataFrame, which ``assign`` cannot store in a single column.
    docs = [create_document(texts, corpusid)
            for texts, corpusid in zip(df['text'], df['corpusid'])]
    return df.assign(docs=docs)


# create documents
ft = ft.map_partitions(_with_documents,
                       meta={'corpusid': 'int', 'text': 'object', 'docs': 'object'})

In [12]:
# Apply add_to_collection to each partition of the `docs` column; this only builds
# a lazy Dask Series -- nothing runs until it is persisted/computed.
upload = ft['docs'].map_partitions(
    add_to_collection,
    meta=('docs', 'object'),
)

In [13]:
# Display the lazy Dask Series structure; no partitions are computed (or uploaded) yet.
upload

Dask Series Structure:
npartitions=400
    object
       ...
     ...  
       ...
       ...
Name: docs, dtype: object
Dask Name: add_to_collection, 8 graph layers

In [14]:
from dask.distributed import progress

# Start the upload on the cluster. With a distributed client active, persist()
# submits the graph and returns right away; progress() renders a live widget.
result = upload.persist()
progress(result)

VBox()

# Embed the text

In [None]:
from transformers import AutoTokenizer, AutoModel

# Tokenizer and encoder both come from the same Hugging Face checkpoint.
specter_checkpoint = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(specter_checkpoint)
model = AutoModel.from_pretrained(specter_checkpoint)

def tokenize(text, max_length=512, hf_tokenizer=None):
    """Tokenize a string (or a batch of strings) into PyTorch tensors.

    Parameters
    ----------
    text : str or list of str
        A single text or a batch of texts.
    max_length : int, default 512
        Truncation length in tokens (512 was previously hard-coded).
    hf_tokenizer : callable, optional
        Tokenizer to use; defaults to the module-level ``tokenizer``.

    Returns
    -------
    The tokenizer output with ``return_tensors="pt"``, padded to the longest
    item of the batch and truncated to ``max_length``.
    """
    tok = hf_tokenizer if hf_tokenizer is not None else tokenizer
    return tok(text, padding=True, truncation=True, return_tensors="pt", max_length=max_length)

def get_embeddings(inputs, encoder=None):
    """Return the hidden state at the first token position for each item in a batch.

    Parameters
    ----------
    inputs : mapping
        Tokenizer output (e.g. from ``tokenize``), unpacked as keyword
        arguments into the encoder.
    encoder : callable, optional
        Model to run; defaults to the module-level ``model``. Its output must
        expose ``last_hidden_state`` indexed as ``[batch, position, hidden]``.

    Returns
    -------
    list of list of float
        ``last_hidden_state[:, 0, :]`` as nested Python lists, one vector per
        batch item. (Position 0 is assumed to be the [CLS] token for BERT-style
        models such as SPECTER -- confirm.)
    """
    import torch

    enc = encoder if encoder is not None else model
    # Inference only: no_grad avoids building an autograd graph and keeping
    # activations alive for a backward pass that never happens.
    with torch.no_grad():
        output = enc(**inputs)
    return output.last_hidden_state[:, 0, :].tolist()

In [None]:
# `res` is never defined on a fresh kernel (it only appears in a commented-out
# cell), so sample directly from the pipeline instead: the chunk list of the first
# document in the first partition.
first_rows = ft['text'].head(1)
assert len(first_rows) == 1, "first partition has no rows; choose another sample"
inputs = tokenize(first_rows.iloc[0])

In [None]:
# Encode the sample batch; get_embeddings returns one vector (list of floats) per input text.
embed = get_embeddings(inputs=inputs)