## Embedding with BERT

BERT doesn't have a `SentenceTransformers` implementation, so we have to extract the embedding manually from the model's last hidden state. We use the `[CLS]` (classification) token's vector, which is at position 0.

In [1]:
from transformers import AutoModel, AutoTokenizer
import torch

In [2]:
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]

# Prefer CUDA, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [3]:
# Load the pretrained uncased BERT tokenizer and encoder; the encoder is moved to `device`.
model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)
model = model.to(device)
print(model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [4]:
def get_embeddings(sentences: list[str]):
    """Embed texts with BERT, returning the last-hidden-state vector at token position 0.

    Args:
        sentences: Texts to embed. They are tokenized together as one padded,
            truncated batch (a single string is also accepted by the tokenizer).

    Returns:
        A tensor with one row per input text (e.g. ``(3, 768)`` for three
        sentences with bert-base), located on ``device``.
    """
    inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
    # Inference only: disable autograd so intermediate activations are not kept
    # alive, which matters when embedding many long documents.
    with torch.no_grad():
        outputs = model(**inputs)
    # last_hidden_state is (batch, seq_len, hidden); position 0 is the [CLS] token.
    return outputs.last_hidden_state[:, 0, :]

In [38]:
embeddings = get_embeddings(sentences)
print(embeddings.size())

torch.Size([3, 768])


## Set up Chroma DB

In [23]:
import torch
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings

# On-disk Chroma store for the BERT experiments.
VECTOR_STORE_DIR = "./vector_stores/bert/"
client = chromadb.PersistentClient(path=VECTOR_STORE_DIR)
print(client)

# Adapter that lets Chroma call our BERT embedding function
class BertChromaEmbedder(EmbeddingFunction):
    """Wrap a text-to-tensor encoder so Chroma can use it as an EmbeddingFunction."""

    def __init__(self, embedding_fn):
        # Callable mapping a batch of texts to a tensor of embeddings.
        self._encode = embedding_fn

    def __call__(self, input: Documents) -> Embeddings:
        # Chroma wants plain Python lists rather than a tensor.
        return self._encode(input).tolist()

# Prefer CUDA, then Apple MPS, then CPU, using the documented MPS availability check.
# NOTE(review): this duplicates `device` from the first setup cell; get_embeddings
# reads `device`, not `DEVICE`.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

<chromadb.api.client.Client object at 0x35d857fd0>
Using device: mps


In [24]:
embedder = BertChromaEmbedder(get_embeddings)
# Chroma's `Documents` type is a list of strings, so pass a one-element list
# rather than a bare string.
result = embedder(['Hi there'])


In [25]:
first_vector = result[0]
first_vector.shape

(768,)

In [31]:
# Start from a clean store by deleting every existing collection.
# NOTE: destructive. This removes ALL collections under the client's persistent
# path, not only the ones this notebook creates.
# (Named `existing_collections` to avoid shadowing the stdlib `collections` module.)
existing_collections = client.list_collections()

for existing in existing_collections:
    # Depending on the chromadb version, list_collections() returns Collection
    # objects or plain collection-name strings; accept either.
    collection_name = existing if isinstance(existing, str) else existing.name
    client.delete_collection(name=collection_name)

print("All collections have been deleted.")

All collections have been deleted.


In [32]:
# Use cosine distance for the HNSW index; Chroma embeds documents and queries via `embedder`.
hnsw_metadata = {"hnsw:space": "cosine"}
collection = client.create_collection(
    name="bert_test",
    metadata=hnsw_metadata,
    embedding_function=embedder,
)
print(f"Created {collection}")

Created Collection(name=bert_test)


### Add documents and test that Chroma uses our embedding function

In [33]:
import os
import json

PATH_TO_DATA = 'data/json/'
# Skip hidden entries (e.g. .DS_Store) and subdirectories, which json.load cannot
# read, and sort so the resulting dict order (and printout) is deterministic;
# os.listdir returns entries in arbitrary order.
FILENAMES = sorted(
    name for name in os.listdir(PATH_TO_DATA)
    if not name.startswith('.') and os.path.isfile(os.path.join(PATH_TO_DATA, name))
)
data = dict()
for filename in FILENAMES:
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(os.path.join(PATH_TO_DATA, filename), 'r', encoding='utf-8') as file:
        # Key each dataset by its file name without the extension.
        data[os.path.splitext(filename)[0]] = json.load(file)

print("Found files:")
for filename in data:
    print(f"  {filename}")

Found files:
  Earth_Science_Reviews
  Earth_Science_Research
  Planetary_Research
  Planetary_Reviews
  Astro_Reviews
  Astro_Research


In [34]:
def preprocess_papers(papers):
    """Return copies of ``papers`` whose ``title`` is a plain string.

    Raw records store ``title`` as a list, and only its first element is kept.
    Titles that are already strings are left unchanged, so running this twice
    (e.g. re-executing the cell) no longer truncates a title to its first
    character. The input records are not modified.

    Args:
        papers: Iterable of paper dicts, each with a ``title`` key.

    Returns:
        A new list of shallow-copied paper dicts with string titles. An empty
        title list becomes ``''``.
    """
    processed = []
    for paper in papers:
        title = paper['title']
        if isinstance(title, list):
            # Raw records wrap the title in a list; keep the first entry.
            title = title[0] if title else ''
        processed.append({**paper, 'title': title})
    return processed


def construct_document(record, fields):
    """
    Join the values of ``fields`` from ``record``, in order, one per line
    """
    parts = (record[name] for name in fields)
    return "\n".join(parts)


def prep_metadata(record):
    """
    Return a copy of ``record`` with list/dict values serialized to JSON strings,
    because Chroma metadata values must be primitives
    """
    flattened = {}
    for key, value in record.items():
        if isinstance(value, (list, dict)):
            flattened[key] = json.dumps(value)
        else:
            flattened[key] = value
    return flattened

data = {name: preprocess_papers(papers) for name, papers in data.items()}

# Only the research splits are indexed in this cell; the review splits are not used here.
all_papers = [
    *data['Astro_Research'],
    *data['Earth_Science_Research'],
    *data['Planetary_Research'],
]
print(f"Number of records: {len(all_papers)}")

DOCUMENT_FIELDS = ['title', 'abstract', 'body']
documents = [construct_document(paper, DOCUMENT_FIELDS) for paper in all_papers]
metadatas = [prep_metadata(paper) for paper in all_papers]
ids = [paper['id'] for paper in all_papers]

# Chroma needs the three parallel lists to line up one-to-one.
assert len(documents) == len(metadatas) == len(ids)

Number of records: 3000


In [36]:
# Index a small sample first to confirm the embedding function is wired up.
N_SAMPLE_DOCS = 3
result = collection.add(
    ids=ids[:N_SAMPLE_DOCS],
    documents=documents[:N_SAMPLE_DOCS],
    metadatas=metadatas[:N_SAMPLE_DOCS],
)
print(result)

In [37]:
probe_queries = ["the sun is a star", "black holes are interesting"]
results = collection.query(
    query_texts=probe_queries,
    n_results=3,
    include=["embeddings", "documents", "metadatas"],
)

# Stored embeddings of the documents retrieved for the first query.
first_query_hits = results['embeddings'][0]
first_query_hits.shape

(3, 768)

In [44]:
# Embed the probe text directly, then store the same text through Chroma so the
# two embeddings can be compared.
probe_text = "the sun is a star"
direct_embedding = get_embeddings([probe_text])
foo = collection.add(
    documents=[probe_text],
    ids=['foo_id'],
)
print(foo)

None


In [45]:
# Retrieve the single nearest stored document, including its stored embedding.
chroma_embedding = collection.query(
    query_texts=["the sun is a star"],
    include=["embeddings", "documents"],
    n_results=1,
)
print(chroma_embedding)

{'ids': [['foo_id']], 'embeddings': [array([[-2.02732235e-01,  1.33627385e-01, -1.02120608e-01,
        -3.99959236e-02, -6.64236486e-01, -2.67020136e-01,
         1.55756801e-01,  6.37282431e-01, -7.70211443e-02,
        -5.36820531e-01, -1.85379744e-01, -2.44763508e-01,
        -1.23746999e-01,  6.37905121e-01,  3.39354932e-01,
        -1.77032262e-01, -2.64158137e-02,  6.09097719e-01,
        -4.06437479e-02, -1.73810557e-01, -1.63882673e-01,
         5.66992834e-02, -1.27703965e-01, -2.82312721e-01,
         3.25120427e-02, -1.52548462e-01, -1.77558005e-01,
         4.34586331e-02,  5.07494211e-01,  2.06840962e-01,
         4.92889173e-02,  2.18657538e-01,  9.32179540e-02,
        -1.54707789e-01,  1.91068366e-01, -1.48803711e-01,
         6.60459772e-02,  3.60294059e-02,  1.59142032e-01,
         1.19473696e-01,  1.57703400e-01,  2.64891416e-01,
         2.58314520e-01,  2.46966839e-01, -1.84186935e-01,
        -4.04319048e-01, -2.05175376e+00, -9.97385196e-03,
         1.43912867

In [48]:
hit_embeddings = chroma_embedding['embeddings'][0]
print(type(hit_embeddings))
print(hit_embeddings[0].shape)

<class 'numpy.ndarray'>
(768,)


In [57]:
import numpy as np
# The stored vector only corresponds to the probe text if the top hit is the probe document.
assert chroma_embedding['ids'][0][0] == 'foo_id', chroma_embedding['ids']

# Detach from autograd before moving off the accelerator and converting to NumPy.
direct_numpy = direct_embedding.detach().cpu().numpy()
chroma_numpy = chroma_embedding['embeddings'][0]
print(direct_numpy.shape)
print(chroma_numpy.shape)
# Compare with a float tolerance rather than bit-for-bit: padding in a different
# batch or a different device kernel can perturb the last few bits.
print(np.allclose(direct_numpy, chroma_numpy))

(1, 768)
(1, 768)
True
