In [1]:
import chromadb
import tiktoken
from pyprojroot import here
from langchain_openai import OpenAIEmbeddings
import pandas as pd

# Embedding and saving as a vector DB using ChromaDB

In [3]:
# Persistent Chroma client; its storage path "data/db/chroma" is resolved with pyprojroot's `here`.
chroma_client = chromadb.PersistentClient(path=str(here("data/db/chroma")))


In [5]:
# Reuse the "imdb_sample" collection when it already exists, otherwise create it.
# get_or_create_collection replaces the earlier bare `except:`, which treated *any*
# failure of create_collection (not only "already exists") as a cue to fetch the
# collection and would also have swallowed KeyboardInterrupt.
collection = chroma_client.get_or_create_collection(name="imdb_sample")

In [6]:
# Load the movie table from data/imdb_movies_100k.csv (path resolved via `here`).
df_movie_data = pd.read_csv(here("data/imdb_movies_100k.csv"))

In [7]:
# LangChain OpenAI embeddings client configured for the "text-embedding-3-small" model.
embedding_client = OpenAIEmbeddings(model="text-embedding-3-small")

In [8]:
def count_tokens(text, model="text-embedding-3-small"):
    """Return how many tokens `text` occupies under the tokenizer for `model`.

    Args:
        text: String to tokenize.
        model: Model name handed to ``tiktoken.encoding_for_model``.

    Returns:
        The number of token ids produced by encoding `text`.
    """
    tokenizer = tiktoken.encoding_for_model(model)
    token_ids = tokenizer.encode(text)
    return len(token_ids)

# Quick check on a short sentence; the cell displays the resulting token count.
count_tokens("This is a test.")

5

In [9]:
def format_movie_data(movie):
    """Render one movie record as the multi-line text block that gets embedded.

    Args:
        movie: Mapping-like record (e.g. a DataFrame row) indexable by the keys
            'primaryTitle', 'startYear', 'titleType', 'runtimeMinutes',
            'genres', 'averageRating', 'numVotes' and 'cast_info'.

    Returns:
        A string that starts and ends with a newline and lists title/year,
        type, runtime, genres and rating/votes, followed by a blank line and
        the "Cast & Crew:" section.
    """
    lines = [
        "",  # leading newline
        f"Title: {movie['primaryTitle']} ({movie['startYear']})",
        f"Type: {movie['titleType']}",
        f"Runtime: {movie['runtimeMinutes']} minutes",
        f"Genres: {movie['genres']}",
        f"IMDb Rating: {movie['averageRating']} (Votes: {movie['numVotes']})",
        "",  # blank separator before the cast section
        "Cast & Crew:",
        f"{movie['cast_info']}",
        "",  # trailing newline
    ]
    return "\n".join(lines)

### Checking how many tokens the data contains

In [10]:
# Token count of the formatted text for every row of the movie table.
token_count = [
    count_tokens(format_movie_data(row))
    for _, row in df_movie_data.iterrows()
]

total_tokens = sum(token_count)
print("Average per movie:", total_tokens / len(token_count))
print("Total tokens:", total_tokens)

Average per movie: 416.1012658227848
Total tokens: 41583080


#### Embedding one row at a time takes a really long time -> Batch embeddings
Here, we embed 20 rows as a sample.

In [11]:
# Embed a 20-row sample. The texts are embedded with one batched
# `embed_documents` call instead of one `embed_query` request per row, which is
# what made row-by-row embedding slow.
sample_rows = df_movie_data.iloc[:20]

docs = [format_movie_data(row) for _, row in sample_rows.iterrows()]
ids = [f"{index}" for index in sample_rows.index]  # DataFrame index labels as string ids
metadatas = [{"source": "imdb_movies_100k"} for _ in docs]

# One embedding vector per entry of `docs`, in the same order.
embeddings = embedding_client.embed_documents(docs)

In [12]:
# Store the sample in the collection: ids, source texts, metadata and the precomputed embeddings.
collection.add(ids=ids, documents=docs, metadatas=metadatas, embeddings=embeddings)


In [13]:
# Report how many records the collection currently holds.
print("Number of vectors in vectordb:", collection.count())

Number of vectors in vectordb: 20
