In [2]:
import pandas as pd
import numpy as np
import dask
import dask.dataframe as dd

import logging

from dask import distributed, dataframe as dd
import re

# with open("/home/ubuntu/work/therapeutic_accelerator/scripts/base.py") as f:
#     exec(f.read())

In [3]:
# Create embeddings function with specter model
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
model = AutoModel.from_pretrained('allenai/specter')

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

class specter_ef(EmbeddingFunction):
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def embed_documents(self, texts: Documents) -> Embeddings:
        
        text_list = [re.sub("\n", " ", p) for p in texts]
        texts = [re.sub("\s\s+", " ", t) for t in text_list]
        
        # embed the documents somehow
        embeddings = []
        
        for text in texts:
            inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
            result = model(**inputs)
            embeddings.append(result.last_hidden_state[:, 0, :])
        
        return embeddings
    
specter_embeder = specter_ef(model, tokenizer)

In [4]:
from langchain.text_splitter import CharacterTextSplitter
# import tiktoken

# @dask.delayed
def token_len(text): 
    """ Get the length of tokens from text"""
    tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)['input_ids'][0]
    return len(tokens)
    
chunk_size = 2000

# create text splitters for processing the texts
text_splitter = CharacterTextSplitter(
    separator = "\n\n",
    chunk_size = chunk_size,
    chunk_overlap  = 20,
    length_function = token_len
)

# Chroma DB Connection

In [6]:
import chromadb
from chromadb.config import Settings

# Create chroma client
chroma = chromadb.Client(Settings(chroma_api_impl="rest",
                                  chroma_server_host="34.238.51.66", # EC2 instance public IPv4
                                  chroma_server_http_port=8000))

print("Nanosecond heartbeat on server", chroma.heartbeat()) # returns a nanosecond heartbeat. Useful for making sure the client remains connected.

# Check Existing connections
chroma.list_collections()

Nanosecond heartbeat on server 1689443485320180857000


[Collection(name=langchain_store),
 Collection(name=abstracts),
 Collection(name=fulltext),
 Collection(name=specter_abstracts)]

In [7]:
collection = chroma.get_or_create_collection("specter_abstracts", embedding_function=specter_ef(model, tokenizer))

In [8]:
collection.count()

259

# Text Processing Functions

In [None]:
def split_text(row):
    """ Split text into chunks """
    return text_splitter.split_text(row['abstract'])

def create_doc(splited_text, corpusid):
    """ Create documents for each chunk """
        
    try:
        docs = {
            "documents": splited_text, # list of all documents [doc1, doc2, doc3, ...]
            'ids': [f'{corpusid}-{i}' for i in range(len(splited_text))], # list of all ids [id1, id2, id3, ...]
            'metadatas': [{'corpusid': int(corpusid), 'chunk': i} for i in range(len(splited_text))] # list of dictionaries with metadata for each document
        }
        return docs

    except Exception as e:
        logging.error(e)
        
def add_to_collection(docs, collection):
    """ Add documents to collection """
    
    try:
        collection.add(**docs)
        return True
    
    except Exception as e:
        logging.error(e)
        return False

# Dask functions

In [None]:
# Create dask cluster
dask.config.set(scheduler='processes')  # overwrite default with multiprocessing scheduler

cluster = distributed.LocalCluster(name='local', n_workers=7, memory_limit = '4GiB', threads_per_worker=4)  # Launches a scheduler and workers locally
client = distributed.client._get_global_client() or distributed.Client(cluster)

client

In [None]:
def df_main(row, collection): 
    """ Main workflow """
    
    splited_text = split_text(row)
    
    docs = create_doc(splited_text, row['corpusId'])
    
    docs['embeddings'] = specter_embeder.embed_documents(docs['documents'])[0][0].tolist()
    
    addition_results = add_to_collection(docs, collection)
    
    return docs

In [None]:
def over_ddf(ddf, collection): 
    return ddf.apply(lambda x: x.apply(main), axis=1, meta=('x', 'object'), collection=collection)

# Abstract Processing

In [None]:
abstracts = dd.read_csv("/home/ubuntu/work/data/abstracts.csv")
abstracts.shape

In [None]:
abstracts = abstracts.map_partitions(over_ddf, args = (collection), meta=('docs', 'object'))

In [None]:
results = abstracts.apply(main, axis=1, args=(collection,))

In [12]:
collection.get(
    include=['documents']
)

{'ids': ['149108350-0',
  '71997817-0',
  '86179664-0',
  '7882437-0',
  '23708908-0',
  '13232625-0',
  '73484844-0',
  '229159752-0',
  '219603447-0',
  '15826244-0',
  '235536812-0',
  '233722185-0',
  '28537746-0',
  '207236403-0',
  '12757505-0',
  '214215182-0',
  '6350414-0',
  '25071297-0',
  '19367812-0',
  '205991753-0',
  '108660208-0',
  '129942654-0',
  '37599676-0',
  '41722145-0',
  '25730773-0',
  '943031-0',
  '248218154-0',
  '145831977-0',
  '143827095-0',
  '206670982-0',
  '10756415-0',
  '10554268-0',
  '2359269-0',
  '101135827-0',
  '20881238-0',
  '281441-0',
  '20003423-0',
  '43408092-0',
  '27094053-0',
  '32849717-0',
  '235698891-0',
  '86827093-0',
  '55088074-0',
  '54562693-0',
  '50775657-0',
  '11858060-0',
  '97910232-0',
  '67846863-0',
  '94343189-0',
  '85067485-0',
  '12200250-0',
  '88552083-0',
  '75269537-0',
  '149129283-0',
  '31998843-0',
  '22174058-0',
  '28827379-0',
  '32186116-0',
  '23611602-0',
  '35468332-0',
  '93713199-0',
  '2592

In [None]:
import boto3

In [None]:
s3 = boto3.resource('s3')

# Print out bucket names
for bucket in s3.buckets.all():
    print(bucket.name)

In [None]:
import torch
torch.save(test, '/home/ubuntu/work/bucket/tensors_abstracts/tensor0-0.pt')

In [None]:
# import dask processingbar
from dask.diagnostics import ProgressBar

with ProgressBar():
    tokens = df['abstract'].apply(tokenize_abstracts, meta=('abstract', 'object')).compute()

In [None]:
tokenized = client.map(tokenize_abstracts, df['abstract'])
inputs = client.map(run_inputs, tokenized)
embeddings = client.submit(get_embeddings, inputs)

In [None]:
# import dask processingbar
from dask.diagnostics import ProgressBar

with ProgressBar():
    abstract_embeddings = ddf['abstract'].apply(get_embeddings, meta=('abstract', 'object')).compute()