In [1]:
import logging
import os
import re

import dask
import dask.dataframe as dd
import numpy as np
import pandas as pd
import torch

# with open("/home/ubuntu/work/therapeutic_accelerator/scripts/base.py") as f:
#     exec(f.read())

In [24]:
# Create embeddings function with specter model
from transformers import AutoTokenizer, AutoModel

# Load the pretrained 'allenai/specter' tokenizer and model.
# Both are module-level globals: `token_len` and the embedding class in this notebook use them.
tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
model = AutoModel.from_pretrained('allenai/specter')

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

class specter_ef(EmbeddingFunction):
    """Embedding function that encodes documents with a SPECTER transformer.

    Args:
        model: transformer model whose output exposes ``last_hidden_state``.
        tokenizer: tokenizer matching ``model``.

    The instance can be used in two ways:

    * ``instance(texts)`` returns one plain list of floats per document, which is
      the form a Chroma collection can store directly.
    * ``instance.embed_documents(texts)`` returns the raw tensors (kept for
      existing callers that index into them).
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, texts: Documents) -> Embeddings:
        """Embed ``texts`` and return one list of floats per document.

        NOTE(review): the parameter is named ``texts`` to mirror ``embed_documents``;
        confirm the name expected by the installed chromadb version.
        """
        return [vector[0].tolist() for vector in self.embed_documents(texts)]

    def embed_documents(self, texts: Documents) -> Embeddings:
        """Embed each document separately.

        Newlines and runs of whitespace are collapsed to single spaces before
        tokenizing. Inputs longer than 512 tokens are truncated by the tokenizer.

        Returns:
            A list with one tensor per document. Each tensor is the hidden state at
            sequence position 0 and has a single row, because documents are encoded
            one at a time.
        """
        cleaned = [re.sub(r"\s\s+", " ", re.sub("\n", " ", text)) for text in texts]

        embeddings = []

        # Inference only: skip autograd bookkeeping so no graph is kept per document.
        with torch.no_grad():
            for text in cleaned:
                # Use the instance's own tokenizer/model rather than notebook globals,
                # so the object behaves according to what it was constructed with.
                inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
                result = self.model(**inputs)
                embeddings.append(result.last_hidden_state[:, 0, :])

        return embeddings
    
    
# Shared embedder instance; `main` calls `specter_embeder.embed_documents(...)` on it.
specter_embeder = specter_ef(model, tokenizer)

In [None]:
# # Create dask cluster
# dask.config.set(scheduler='processes')  # overwrite default with multiprocessing scheduler

# cluster = distributed.LocalCluster(name='local', n_workers=7, memory_limit = '4GiB', threads_per_worker=4)  # Launches a scheduler and workers locally
# client = distributed.client._get_global_client() or distributed.Client(cluster)

# client

In [4]:
from langchain.text_splitter import CharacterTextSplitter
# import tiktoken

# @dask.delayed
def token_len(text):
    """Return how many tokens the notebook-level ``tokenizer`` produces for ``text``.

    This is the ``length_function`` handed to the text splitter, so it has to report
    the real length of the text.

    Args:
        text: a single string.

    Returns:
        Number of entries in the tokenizer's ``input_ids`` for ``text``.
    """
    # Truncation must stay off here: a truncated encoding can never be longer than
    # the truncation limit, which would hide every over-long chunk from the splitter.
    encoded = tokenizer(text, truncation=False)
    return len(encoded['input_ids'])
    
# Chunk length is measured in tokens (see `length_function` below). The embedding step
# truncates every input at max_length=512, so a larger chunk size would silently drop
# the tail of each chunk. Keep the chunk size at that limit.
chunk_size = 512

# Create the text splitter used to break abstracts into chunks.
# It only breaks at the "\n\n" separator.
# NOTE(review): a single paragraph longer than `chunk_size` tokens presumably stays in one piece — confirm.
text_splitter = CharacterTextSplitter(
    separator = "\n\n",
    chunk_size = chunk_size,
    chunk_overlap  = 20,
    length_function = token_len
)

In [5]:
import chromadb
from chromadb.config import Settings

# Connection details default to the server used so far, but can be overridden through
# environment variables so the notebook does not need editing when the server moves.
chroma_host = os.environ.get("CHROMA_SERVER_HOST", "54.210.84.192")  # default: EC2 instance public IPv4
chroma_port = int(os.environ.get("CHROMA_SERVER_HTTP_PORT", 8000))

# Create chroma client
chroma = chromadb.Client(Settings(chroma_api_impl="rest",
                                  chroma_server_host=chroma_host,
                                  chroma_server_http_port=chroma_port))

# The heartbeat is a cheap way to confirm that the client can reach the server.
print("Nanosecond heartbeat on server", chroma.heartbeat())

# Check existing collections
chroma.list_collections()

Nanosecond heartbeat on server 1689387218119775774000


[Collection(name=langchain_store),
 Collection(name=abstracts),
 Collection(name=fulltext),
 Collection(name=specter_abstracts)]

In [6]:
# Open the "specter_abstracts" collection (created if missing) with a SPECTER embedder attached.
collection = chroma.get_or_create_collection(
    name="specter_abstracts",
    embedding_function=specter_ef(model, tokenizer),
)

In [7]:
def create_doc(splited_text, corpusid):
    """Build the keyword arguments for ``collection.add`` from one paper's chunks.

    Args:
        splited_text: list of text chunks belonging to one paper.
        corpusid: identifier of the paper; must be convertible with ``int()``.

    Returns:
        A dict of three parallel lists:
        'documents' (the chunks), 'ids' ('<corpusid>-<chunk index>') and
        'metadatas' ({'corpusid': int, 'chunk': index}).
        Returns None when the dict cannot be built; the failure is logged with the
        corpus id and traceback so the offending row can be found.
    """
    try:
        n_chunks = len(splited_text)
        return {
            "documents": splited_text,  # list of all documents [doc1, doc2, doc3, ...]
            'ids': [f'{corpusid}-{i}' for i in range(n_chunks)],  # list of all ids [id1, id2, id3, ...]
            'metadatas': [{'corpusid': int(corpusid), 'chunk': i} for i in range(n_chunks)],  # one metadata dict per document
        }
    except Exception:
        # Best effort: a bad row must not abort a bulk run, but it has to be identifiable.
        logging.exception("Could not build documents for corpusid %r", corpusid)
        return None

In [8]:
def add_to_collection(docs, collection):
    """Insert ``docs`` into ``collection``.

    ``docs`` is unpacked as keyword arguments for ``collection.add``.
    Returns True when the call succeeds; any exception is logged and reported as False.
    """
    try:
        collection.add(**docs)
    except Exception as err:
        logging.error(err)
        return False
    else:
        return True

In [10]:
def split_text(row):
    """Chunk the 'abstract' entry of ``row`` with the notebook-level ``text_splitter``."""
    abstract = row['abstract']
    chunks = text_splitter.split_text(abstract)
    return chunks

In [31]:
# Inspect the chunk texts held in `docs`.
# NOTE(review): no visible cell assigns a module-level `docs` (it is only a local inside
# `create_doc` / `main`), so this relies on state left in the kernel — confirm before re-running.
docs['documents']

['Constitutive JAK/STAT3 signaling contributes to disease progression in many lymphoproliferative disorders. Recent genetic analyses have revealed gain-of-function STAT3 mutations in lymphoid cancers leading to hyperactivation of STAT3, which may represent a potential therapeutic target. Using a functional reporter assay, we screened 306 compounds with selective activity against various target molecules to identify drugs capable of inhibiting the cellular activity of STAT3. Top hits were further validated with additional models including STAT3-mutated natural killer (NK)-cell leukemia/lymphoma cell lines and primary large granular lymphocytic (LGL) leukemia cells to assess their ability to inhibit STAT3 phosphorylation and STAT3 dependent cell viability. We identified JAK, mTOR, Hsp90 and CDK inhibitors as potent inhibitors of both WT and mutant STAT3 activity. The Hsp90 inhibitor luminespib was highly effective at reducing the viability of mutant STAT3 NK cell lines and LGL leukemia p

In [32]:
# Sanity check: embed the chunk texts and display the resulting tensors.
# NOTE(review): `docs` is not assigned at module level in any visible cell — confirm where it comes from.
specter_embeder.embed_documents(docs['documents'])

[tensor([[ 3.4218e-01,  6.8138e-02,  5.3174e-02, -4.8539e-01,  6.4126e-01,
           8.8838e-01,  5.6956e-01,  3.4740e-01,  8.1531e-01,  1.0447e+00,
          -1.1157e+00,  6.9248e-02, -1.5451e-01,  5.6292e-01, -3.9001e-01,
          -2.5722e-01, -1.7069e-01,  2.7875e-01,  6.7626e-01, -2.1069e-01,
          -1.5390e+00,  5.4673e-01,  7.2121e-01,  1.4600e+00, -3.4173e-01,
          -3.2670e-01,  1.2596e-01, -2.8560e-01,  2.8500e-01, -1.3940e-01,
          -7.2112e-01,  1.5694e+00, -1.0804e-01,  6.3248e-01, -1.7919e+00,
           6.2874e-01, -2.5977e-01,  1.1278e+00,  8.9014e-01,  3.8433e-01,
          -1.4571e-02, -1.8413e-01, -2.5524e-01, -3.3604e-01, -1.3239e-01,
          -3.4614e-01, -2.9149e-01, -2.0888e-02,  6.4878e-02,  7.6201e-01,
           2.4094e-01,  5.6828e-01, -2.1536e-01, -2.5431e-01, -3.8628e-02,
           5.6838e-01, -5.0036e-02, -6.8892e-01, -7.2793e-02, -1.1216e-01,
           8.6185e-01,  6.2680e-01, -1.8049e-02,  8.4125e-01,  1.2957e+00,
           1.6533e-01,  1

In [51]:
def main(row, collection):
    """Chunk, embed and store the abstract of one row.

    Args:
        row: mapping (e.g. a DataFrame row) with 'abstract' and 'corpusId' entries.
        collection: collection the chunks are added to.

    Returns:
        The dict built by ``create_doc`` extended with an 'embeddings' list holding
        one vector per chunk. When there is nothing to embed, the value returned by
        ``create_doc`` (None or a dict with empty lists) is passed through unchanged.
    """
    splited_text = split_text(row)

    docs = create_doc(splited_text, row['corpusId'])

    # create_doc returns None on failure, and an abstract may yield no chunks;
    # in both cases there is nothing to embed or upload.
    if not docs or not docs['documents']:
        logging.warning("No chunks to embed for corpusId %s", row['corpusId'])
        return docs

    vectors = specter_embeder.embed_documents(docs['documents'])

    # One embedding per chunk, parallel to 'documents' / 'ids' / 'metadatas'.
    # Each element returned by the embedder holds a single row, hence the [0].
    docs['embeddings'] = [vector[0].tolist() for vector in vectors]

    if not add_to_collection(docs, collection):
        logging.warning("Chunks of corpusId %s were not added to the collection", row['corpusId'])

    return docs

In [12]:
# Load the abstracts table from CSV and display its (rows, columns) shape.
# NOTE(review): the path is absolute and machine-specific.
abstracts = pd.read_csv("/home/ubuntu/work/data/abstracts.csv")
abstracts.shape

In [13]:
# Preview the first rows to check the columns that `main` reads ('corpusId', 'abstract').
abstracts.head()

Unnamed: 0,id,paperId,corpusId,abstract
0,1,6ec7c156b4173ad7ca0dbc654da9267474644a41,23708908,Constitutive JAK/STAT3 signaling contributes t...
1,2,c856627242a754d2d756b32843523e6d7a089148,13232625,Summary: The current work characterizes young ...
2,3,d4c9b2fa2b760b5cf90ce8635a7dede5b4cd58a2,73484844,Ionotropic glutamate receptors (iGluRs) mediat...
3,4,7f13abe2c82bf0c66ca423e905d8f5967c4517b1,229159752,Aim The current pandemic of coronavirus diseas...
4,7,9b9659a4e9a4a48d7c532c76dd14ee9ccd696025,219603447,Background The Brain Injury Guidelines provide...


In [54]:
# Run the workflow row by row (axis=1); `collection` is passed to `main` as its second argument.
# NOTE(review): the logged output of this cell shows that IDs already present in the collection
# are rejected, so re-running re-processes rows that were already stored.
results = abstracts.apply(main, axis=1, args=(collection,))

ERROR:root:IDs ['23708908-0'] already exist in collection ac7095cd-348d-4c16-bb3f-975081402b45


: 

: 

In [53]:
# Show what is currently stored in the collection, requesting only the document texts.
collection.get(include=['documents'])

{'ids': ['23708908-0'],
 'embeddings': None,
 'metadatas': None,
 'documents': ['Constitutive JAK/STAT3 signaling contributes to disease progression in many lymphoproliferative disorders. Recent genetic analyses have revealed gain-of-function STAT3 mutations in lymphoid cancers leading to hyperactivation of STAT3, which may represent a potential therapeutic target. Using a functional reporter assay, we screened 306 compounds with selective activity against various target molecules to identify drugs capable of inhibiting the cellular activity of STAT3. Top hits were further validated with additional models including STAT3-mutated natural killer (NK)-cell leukemia/lymphoma cell lines and primary large granular lymphocytic (LGL) leukemia cells to assess their ability to inhibit STAT3 phosphorylation and STAT3 dependent cell viability. We identified JAK, mTOR, Hsp90 and CDK inhibitors as potent inhibitors of both WT and mutant STAT3 activity. The Hsp90 inhibitor luminespib was highly effec

In [None]:
import boto3

In [None]:
# Open an S3 resource handle.
# NOTE(review): no credentials are passed here; presumably they come from the environment — confirm.
s3 = boto3.resource('s3')

# Print out bucket names
for bucket in s3.buckets.all():
    print(bucket.name)

In [None]:
import torch
# Persist one tensor to disk.
# NOTE(review): `test` is not defined in any visible cell — confirm where it comes from before re-running.
torch.save(test, '/home/ubuntu/work/bucket/tensors_abstracts/tensor0-0.pt')

In [None]:
# Import dask's progress bar
from dask.diagnostics import ProgressBar

# Tokenize every abstract while displaying progress.
# NOTE(review): `df` and `tokenize_abstracts` are not defined in any visible cell — confirm before re-running.
with ProgressBar():
    tokens = df['abstract'].apply(tokenize_abstracts, meta=('abstract', 'object')).compute()

In [None]:
# Three-stage pipeline submitted through `client`: tokenize, run inputs, collect embeddings.
# NOTE(review): `client` is only created in a commented-out cell of this notebook, and `df`,
# `tokenize_abstracts`, `run_inputs` and `get_embeddings` are not defined in any visible cell.
tokenized = client.map(tokenize_abstracts, df['abstract'])
inputs = client.map(run_inputs, tokenized)
embeddings = client.submit(get_embeddings, inputs)

In [None]:
# Import dask's progress bar
from dask.diagnostics import ProgressBar

# Compute an embedding for every abstract while displaying progress.
# NOTE(review): `ddf` and `get_embeddings` are not defined in any visible cell — confirm before re-running.
with ProgressBar():
    abstract_embeddings = ddf['abstract'].apply(get_embeddings, meta=('abstract', 'object')).compute()