# Populating Embedding Vectors in MongoDB Atlas

In this Python notebook, we use embedding models downloaded to our local machine to generate embedding attributes for our movies dataset. We then use `pymongo` to add these new attributes to the dataset in MongoDB Atlas.

All embeddings are generated locally (i.e., no external API calls are needed).

In [2]:
# Load settings from .env file
import sys
from dotenv import find_dotenv, dotenv_values

# Change system path to root directory
sys.path.insert(0, '../')

# _ = load_dotenv(find_dotenv()) # read local .env file
config = dotenv_values(find_dotenv())

# For debugging purposes
# print (config)

ATLAS_URI = config.get('ATLAS_URI')

if not ATLAS_URI:
    raise Exception ("'ATLAS_URI' is not set.")
else:
    print("ATLAS_URI Connection string found:")

ATLAS_URI Connection string found:
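
If you do enable the debugging `print(config)` above, keep in mind that a connection string usually embeds a password. A minimal sketch of masking it before printing, assuming the URI follows the usual `scheme://user:password@host` shape; the helper name `mask_uri` is hypothetical and not part of this notebook:

```python
import re

def mask_uri(uri: str) -> str:
    """Return the URI with the password portion replaced by '***'."""
    # Matches '://<user>:<password>@' and keeps only the user part.
    return re.sub(r"(://[^:/@]+):[^@]+@", r"\1:***@", uri)

# Example with a made-up connection string (not a real cluster)
print(mask_uri("mongodb+srv://demo_user:s3cret@cluster0.example.mongodb.net/"))
```

A URI without credentials passes through unchanged, so the helper is safe to call on any connection string.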


In [4]:
DB_NAME = 'sample_mflix'
COLLECTION_NAME = 'embedded_movies'

## Initialize MongoDB Atlas Client


In [7]:
from AtlasClient import AtlasClient

atlas_client = AtlasClient (ATLAS_URI, DB_NAME)
print("Connected to the Mongo Atlas database!")


collection = atlas_client.get_collection(COLLECTION_NAME)
document_count = collection.count_documents({})

print (f"Document count = {document_count:,} movies")

Connected to the Mongo Atlas database!
Document count = 3,483 movies
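
`AtlasClient` is a local helper module whose source is not shown in this notebook. A rough sketch of what such a wrapper could look like, inferred only from the calls made here (`get_collection` and `find`). The attribute names and the optional `client` argument are assumptions for illustration, not the actual implementation:

```python
class AtlasClient:
    """Hypothetical thin wrapper around a pymongo client (illustrative only)."""

    def __init__(self, atlas_uri, dbname, client=None):
        if client is None:
            # Imported here so the class can also be exercised with a stand-in client
            from pymongo import MongoClient
            client = MongoClient(atlas_uri)
        self.mongodb_client = client
        self.database = self.mongodb_client[dbname]

    def get_collection(self, collection_name):
        # Collections are looked up by name on the selected database
        return self.database[collection_name]

    def find(self, collection_name, filter=None, limit=0):
        # Pass the filter and limit straight through to the collection's find()
        collection = self.get_collection(collection_name)
        return list(collection.find(filter=filter or {}, limit=limit))
```

Wrapping the client this way keeps the connection details in one place, so the rest of the notebook only deals with collection names and filters.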


## Step 3: Generate Embeddings

Now for the fun part: we're going to generate all the embeddings locally on our own computer, using open-source models. No API calls or API keys needed! 😄

As mentioned, we'll be using the following models:

| model name                              | overall score | model params | model size | embedding length | url                                                            |
|-----------------------------------------|---------------|--------------|------------|------------------|----------------------------------------------------------------|
| BAAI/bge-small-en-v1.5                  | 62.x          | 33.5 M       | 133 MB     | 384              | https://huggingface.co/BAAI/bge-small-en-v1.5                  |
| sentence-transformers/all-mpnet-base-v2 | 57.8          |              | 438 MB     | 768              | https://huggingface.co/sentence-transformers/all-mpnet-base-v2 |
| sentence-transformers/all-MiniLM-L6-v2  | 56.x          |              | 91 MB      | 384              | https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2  |

In [8]:
import os
# Set the LlamaIndex cache dir to ../cache (the default is the system tmp dir)
# This way, we can easily see downloaded artifacts
os.environ['LLAMA_INDEX_CACHE_DIR'] = os.path.join(os.path.abspath('../'), 'cache')

In [9]:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import time

# Helper function to calculate embeddings, given a model.
# Mutates each movie dict in place, storing the vector under `embedding_attr`.
def create_embeddings(movies, embedding_model, embedding_attr):
    # Load the embedding model by its Hugging Face name
    embed_model = HuggingFaceEmbedding(model_name=embedding_model)

    # Embed each movie's plot; timing is measured by the caller
    for movie in movies:
        movie[embedding_attr] = embed_model.get_text_embedding(movie['plot'])
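
The helper above embeds one plot per call. For larger collections it may be faster to hand texts to the model in batches. A minimal sketch of that pattern, written against a generic `embed_batch` callable so it does not depend on a particular library; `embed_in_batches`, `batch_size`, and the stand-in embedder are hypothetical names. With LlamaIndex, `embed_model.get_text_embedding_batch` should be usable as the callable, though whether it is actually faster depends on the model and hardware.

```python
from typing import Callable, List, Sequence

def embed_in_batches(
    texts: Sequence[str],
    embed_batch: Callable[[List[str]], List[List[float]]],
    batch_size: int = 32,
) -> List[List[float]]:
    """Embed texts in fixed-size batches, preserving input order."""
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        embeddings.extend(embed_batch(batch))
    return embeddings

# Stand-in embedder for demonstration: one 2-dimensional vector per text
def fake_embed_batch(batch: List[str]) -> List[List[float]]:
    return [[float(len(text)), 0.0] for text in batch]

vectors = embed_in_batches(["a", "bb", "ccc"], fake_embed_batch, batch_size=2)
print(len(vectors))
```

Because the output order matches the input order, the resulting vectors can be zipped back onto the movies list.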

In [None]:
# Fetch all movies that have a plot
t1a = time.perf_counter()
movies = list(atlas_client.find(collection_name=COLLECTION_NAME, filter={'plot': {'$exists': True}}, limit=0))
t1b = time.perf_counter()

print (f'Fetched {len(movies):,} movies from Atlas in {(t1b-t1a)*1000:,.0f} ms')

In [None]:
# Embedding models we want to use
model_mappings = {
    'BAAI/bge-small-en-v1.5' : {'embedding_attr' : 'plot_embedding_bge_small', 'index_name' : 'idx_plot_embedding_bge_small'},
    'sentence-transformers/all-mpnet-base-v2' : {'embedding_attr' : 'plot_embedding_mpnet_base_v2', 'index_name' : 'idx_plot_embedding_mpnet_base_v2'},
    'sentence-transformers/all-MiniLM-L6-v2' : {'embedding_attr' : 'plot_embedding_minilm_l6_v2', 'index_name' : 'idx_plot_embedding_minilm_l6_v2'},
}
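
The `index_name` values are not used in this notebook; they are presumably meant for the Atlas Vector Search indexes that will sit on top of these attributes. As a sketch, each index could be described with a definition like the one built below. The dimensions come from the model table above; the `cosine` similarity and the helper name `vector_index_definition` are assumptions, so check them against the Atlas Vector Search documentation before relying on them.

```python
# Embedding attribute -> (index name, vector dimensions from the model table)
index_specs = {
    'plot_embedding_bge_small': ('idx_plot_embedding_bge_small', 384),
    'plot_embedding_mpnet_base_v2': ('idx_plot_embedding_mpnet_base_v2', 768),
    'plot_embedding_minilm_l6_v2': ('idx_plot_embedding_minilm_l6_v2', 384),
}

def vector_index_definition(path, num_dimensions, similarity='cosine'):
    """Build a definition dict for a vector index on a single field."""
    return {
        'fields': [{
            'type': 'vector',
            'path': path,
            'numDimensions': num_dimensions,
            'similarity': similarity,
        }]
    }

# One definition per index name
definitions = {
    index_name: vector_index_definition(attr, dims)
    for attr, (index_name, dims) in index_specs.items()
}
print(definitions['idx_plot_embedding_mpnet_base_v2'])
```

Keeping one index per attribute mirrors the one-attribute-per-model layout, so a query against a given index only ever sees vectors from the matching model.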

In [None]:
# For each embedding model selected above, create vectors for the plot field.
# Each model gets its own 'plot_embedding_*' attribute (i.e. we don't want to mix them up).

for embedding_model, mapping in model_mappings.items():
    embedding_attr = mapping['embedding_attr']

    print (f'\n------- Embedding Model = {embedding_model} ---------')
    t1a = time.perf_counter()
    create_embeddings(movies=movies, embedding_model=embedding_model, embedding_attr=embedding_attr)
    t1b = time.perf_counter()
    avg_time_per_movie = (t1b-t1a)*1000 / len(movies)
    print (f'model={embedding_model}, created embeddings for {len(movies):,} movies in {(t1b-t1a)*1000:,.0f} ms, avg_time_per_movie={avg_time_per_movie:,.0f} ms')

## Step 4: Inspect Generated Embeddings
We have successfully generated three sets of embeddings, one for each embedding model we used.

In [None]:
import random

movie = random.choice(movies)
print ('Randomly selected movie: ', movie['title'])
print ('Movie plot: ', movie['plot'], '\n')
print (f'plot_embedding (existing, OpenAI-generated), len={len(movie["plot_embedding"])} , {movie["plot_embedding"][:5]}...\n')
print (f'plot_embedding_bge_small , len={len(movie["plot_embedding_bge_small"])} , {movie["plot_embedding_bge_small"][:5]}...\n')
print (f'plot_embedding_mpnet_base_v2 , len={len(movie["plot_embedding_mpnet_base_v2"])} , {movie["plot_embedding_mpnet_base_v2"][:5]}...\n')
print (f'plot_embedding_minilm_l6_v2 , len={len(movie["plot_embedding_minilm_l6_v2"])} , {movie["plot_embedding_minilm_l6_v2"][:5]}...')
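
Beyond eyeballing one random movie, it can be worth checking that every vector has the length listed in the model table. A small sketch of such a check; `find_bad_dimensions` is a hypothetical helper, and the sample data at the bottom is made up so the block runs on its own:

```python
# Expected vector lengths, taken from the "embedding length" column of the table
EXPECTED_DIMS = {
    'plot_embedding_bge_small': 384,
    'plot_embedding_mpnet_base_v2': 768,
    'plot_embedding_minilm_l6_v2': 384,
}

def find_bad_dimensions(movies, expected_dims=EXPECTED_DIMS):
    """Return (title, attr, actual_len) for each missing or wrongly sized vector."""
    problems = []
    for movie in movies:
        for attr, dims in expected_dims.items():
            vector = movie.get(attr)
            actual = len(vector) if vector is not None else 0
            if actual != dims:
                problems.append((movie.get('title'), attr, actual))
    return problems

# Made-up sample: the second movie has a truncated MiniLM vector
sample_movies = [
    {'title': 'A', 'plot_embedding_bge_small': [0.0] * 384,
     'plot_embedding_mpnet_base_v2': [0.0] * 768,
     'plot_embedding_minilm_l6_v2': [0.0] * 384},
    {'title': 'B', 'plot_embedding_bge_small': [0.0] * 384,
     'plot_embedding_mpnet_base_v2': [0.0] * 768,
     'plot_embedding_minilm_l6_v2': [0.0] * 100},
]
print(find_bad_dimensions(sample_movies))
```

Running the same function on the real `movies` list should return an empty list if all three models produced vectors of the expected size.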

## Step 5: Add Embeddings to MongoDB Atlas


In [None]:
from pymongo import ReplaceOne

collection = atlas_client.get_collection(COLLECTION_NAME)
replacements = [ReplaceOne ({"_id" : movie["_id"]}, movie) for movie in movies]

# Perform bulk replacement
print (f'About to update {len(replacements)} movies in Atlas...')
t1a = time.perf_counter()
result = collection.bulk_write(replacements)
t1b = time.perf_counter()

# Print result
print(f"Update matched count: {result.matched_count}")
print(f"Update modified count: {result.modified_count}")
print (f'Updated {len(movies):,} movies in Atlas in {(t1b-t1a)*1000:,.0f} ms')
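
`ReplaceOne` rewrites each whole document, including fields that did not change. One alternative is to send only the new embedding attributes with `$set`. The sketch below builds the filter/update pairs as plain dicts so it runs without a database; `build_embedding_updates` and `EMBEDDING_ATTRS` are hypothetical names, and the sample document is made up.

```python
EMBEDDING_ATTRS = [
    'plot_embedding_bge_small',
    'plot_embedding_mpnet_base_v2',
    'plot_embedding_minilm_l6_v2',
]

def build_embedding_updates(movies, attrs=EMBEDDING_ATTRS):
    """Return (filter, update) pairs that set only the embedding attributes."""
    updates = []
    for movie in movies:
        fields = {attr: movie[attr] for attr in attrs if attr in movie}
        if fields:
            updates.append(({'_id': movie['_id']}, {'$set': fields}))
    return updates

# Made-up sample document with a short vector for readability
sample = [{'_id': 1, 'title': 'A', 'plot_embedding_bge_small': [0.1, 0.2]}]
pairs = build_embedding_updates(sample)
print(pairs)

# With pymongo available, the pairs could be turned into bulk operations:
# from pymongo import UpdateOne
# result = collection.bulk_write([UpdateOne(f, u) for f, u in pairs])
```

This keeps the payload limited to the vectors themselves, which may matter when documents carry large fields that are unrelated to the embeddings.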