In [None]:
# Create embeddings function with specter model
from transformers import AutoTokenizer, AutoModel

# The tokenizer and the encoder are both loaded from the same pretrained checkpoint,
# so the checkpoint name is spelled once and reused.
SPECTER_CHECKPOINT = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(SPECTER_CHECKPOINT)
model = AutoModel.from_pretrained(SPECTER_CHECKPOINT)

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

class specter_ef(EmbeddingFunction):
    """Embedding function that encodes documents with a transformer model.

    The model and tokenizer passed to the constructor are the ones used for
    encoding; the class does not depend on notebook-level globals.
    """

    def __init__(self, model, tokenizer):
        """Store the encoder and its tokenizer.

        Args:
            model: Callable model whose output exposes ``last_hidden_state``.
            tokenizer: Callable tokenizer that returns the keyword inputs
                expected by ``model``.
        """
        self.model = model
        self.tokenizer = tokenizer

    def embed_documents(self, texts: Documents) -> Embeddings:
        """Embed each document separately and return one embedding per document.

        Newlines are replaced by spaces and runs of two or more whitespace
        characters are collapsed to a single space before tokenization. Each
        document is truncated to at most 512 tokens.

        Args:
            texts: Iterable of document strings.

        Returns:
            A list with one entry per input document. Each entry is the
            first-token slice of the model's last hidden state
            (``last_hidden_state[:, 0, :]``); because documents are encoded one
            at a time, its leading (batch) dimension is 1. Callers that need
            plain floats must convert it themselves, e.g. ``entry[0].tolist()``.
        """
        # Imported locally so the method works even if no other cell imported them.
        import re
        import torch

        cleaned = [re.sub(r"\s\s+", " ", doc.replace("\n", " ")) for doc in texts]

        embeddings = []
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for text in cleaned:
                inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
                result = self.model(**inputs)
                # Keep only the hidden state at sequence position 0 for each document.
                embeddings.append(result.last_hidden_state[:, 0, :])

        return embeddings
    
# Shared embedder instance used by the query helpers in later cells.
specter_embeder = specter_ef(model=model, tokenizer=tokenizer)

In [None]:
import chromadb
from chromadb.config import Settings

# Create chroma client
import os

# Connection details default to the original EC2 instance but can be overridden
# through environment variables, so the notebook can target another server
# without editing this cell.
CHROMA_SERVER_HOST = os.environ.get("CHROMA_SERVER_HOST", "34.238.51.66")  # EC2 instance public IPv4
CHROMA_SERVER_HTTP_PORT = int(os.environ.get("CHROMA_SERVER_HTTP_PORT", "8000"))

chroma = chromadb.Client(Settings(chroma_api_impl="rest",
                                  chroma_server_host=CHROMA_SERVER_HOST,
                                  chroma_server_http_port=CHROMA_SERVER_HTTP_PORT))

# Returns a nanosecond heartbeat. Useful for making sure the client remains connected.
print("Nanosecond heartbeat on server", chroma.heartbeat())

# Check existing collections
display(chroma.list_collections())

# Collection that the query helpers in later cells read from.
collection = chroma.get_or_create_collection("specter_abstracts")

In [None]:
def get_question_embeddings(question):
    """Embed a single question and return its vector as a plain Python list.

    The question is wrapped in a one-element list for the shared
    ``specter_embeder``; the first row of the first returned entry is then
    converted with ``.tolist()``.
    """
    per_document = specter_embeder.embed_documents([question])
    first_document = per_document[0]
    return first_document[0].tolist()

def query_chroma(question_embeddings, n_results=10):
    """Query the module-level ``collection`` with one precomputed embedding.

    Args:
        question_embeddings: Embedding vector for a single question. It is
            wrapped in a list because ``collection.query`` takes a batch of
            query embeddings.
        n_results: Value forwarded as ``n_results`` to ``collection.query``.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        Whatever ``collection.query`` returns.
    """
    results = collection.query(
        query_embeddings=[question_embeddings],
        n_results=n_results
        # Optional filters left disabled by the original author:
        # where={"metadata_field": "is_equal_to_this"},
        # where_document={"$contains":"search_string"}
    )

    return results

In [None]:
# Prompt the user for a free-text question.
question = input("What is your question? ")

# Turn the question into an embedding vector using the helper defined in an earlier cell.
question_embeddings = get_question_embeddings(question)

# Look up the closest entries in the Chroma collection for that embedding.
results = query_chroma(question_embeddings)

# Bare expression so the notebook displays the query results as the cell output.
results