In [23]:
import pickle

import numpy as np
from fastembed import TextEmbedding
from pydantic_settings import BaseSettings, SettingsConfigDict
from sklearn.decomposition import PCA

from create_dummy_tags import tags
from tagmatch.vec_db import Embedder
from tagmatch.vec_db import VecDB

class Settings(BaseSettings):
    """Notebook configuration loaded from the environment and the local ``.env`` file.

    None of the fields declares a default, so each one must be supplied by the
    environment or by ``.env``.
    """

    # pydantic-settings runs on pydantic v2, where the nested ``class Config``
    # is deprecated; ``model_config`` is the supported way to point at ``.env``.
    # ``protected_namespaces`` is narrowed because pydantic releases that
    # reserve the whole ``model_`` prefix warn about the ``model_name`` field.
    model_config = SettingsConfigDict(
        env_file=".env",
        protected_namespaces=("settings_",),
    )

    model_name: str  # passed to Embedder as ``model_name``
    cache_dir: str  # passed to Embedder as ``cache_dir``
    qdrant_host: str  # not referenced by the cells visible in this notebook
    qdrant_port: int  # not referenced by the cells visible in this notebook
    qdrant_collection: str  # not referenced by the cells visible in this notebook
    reduced_embed_dim: int  # PCA ``n_components`` and reduced collection vector size


settings = Settings()





In [51]:
# Build the embedder from the configured model name and cache directory.
embedder = Embedder(
    model_name=settings.model_name,
    cache_dir=settings.cache_dir,
)

# Embed the tags one at a time, then gather the results into a single array
# (one row per tag, in the same order as `tags`).
tag_embeddings = list(map(embedder.embed, tags))
embed_vec = np.asarray(tag_embeddings)

Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 53773.13it/s]


In [50]:
embed_vec = embedder.embed(tags[0])

In [45]:
embed_vec[None,:].shape

(1, 384)

In [54]:
# Fit a PCA with the configured number of components on the tag embeddings;
# fit() hands back the fitted estimator, so construction and fitting chain.
pca = PCA(n_components=settings.reduced_embed_dim).fit(embed_vec)

# Sum of the per-component explained-variance ratios of the fitted PCA.
sum(pca.explained_variance_ratio_)

0.9513890146045014

In [13]:
# Project the embeddings in `embed_vec` through the fitted PCA; these reduced
# vectors are what the later cells store in, and query against, `vec_db`.
red_embed_vecs = pca.transform(embed_vec)

In [14]:
# Connection parameters for the collection that holds the PCA-reduced vectors;
# its vector size is the configured reduced dimensionality.
reduced_db_params = {
    "host": "http://localhost",
    "port": 6333,
    "collection": "pca_train_test_reduced",
    "vector_size": settings.reduced_embed_dim,
}
vec_db = VecDB(**reduced_db_params)

ResponseHandlingException: [Errno 61] Connection refused

In [46]:
# Store each reduced vector together with its tag name as the payload.
for tag_name, reduced_vec in zip(tags, red_embed_vecs):
    payload = {"name": tag_name}
    vec_db.store(reduced_vec, payload)

In [47]:
# Query the reduced collection once per reduced vector, asking for 5 results each.
reduced_vector_matches = [
    vec_db.find_closest(reduced_query, 5)
    for reduced_query in red_embed_vecs
]

In [60]:
# Show the matches returned for the first reduced vector.
reduced_vector_matches[0]

[ScoredPoint(id=10356360874192234358, version=0, score=0.99999994, payload={'name': 'Apollo 11'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=9793711092478664242, version=46, score=0.3996141, payload={'name': 'Moon Mission'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=12207098595657543863, version=1, score=0.3779359, payload={'name': 'Moon'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=12894050092224864253, version=80, score=0.37239692, payload={'name': 'Lunar Base'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=11172108735014080718, version=49, score=0.28158426, payload={'name': 'Saturn Mission'}, vector=None, shard_key=None, order_value=None)]

In [56]:
# Connection parameters for the collection that holds the unreduced embeddings;
# its vector size is taken from the embedder itself.
full_db_params = {
    "host": "http://localhost",
    "port": 6333,
    "collection": "train_test",
    "vector_size": embedder.embedding_dim,
}
vec_db_full = VecDB(**full_db_params)

In [57]:
# Store each unreduced embedding together with its tag name as the payload.
for tag_name, full_vec in zip(tags, embed_vec):
    payload = {"name": tag_name}
    vec_db_full.store(full_vec, payload)

In [58]:

# Query the full-dimension collection once per embedding, asking for 5 results each.
vector_matches = [
    vec_db_full.find_closest(full_query, 5)
    for full_query in embed_vec
]


In [59]:
# Show the matches returned for the first unreduced embedding, for comparison
# with the reduced-vector matches displayed earlier.
vector_matches[0]

[ScoredPoint(id=13583038290980744031, version=0, score=0.9999999, payload={'name': 'Apollo 11'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=9869776390968345825, version=46, score=0.7471009, payload={'name': 'Moon Mission'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=9306135850958137005, version=1, score=0.72500336, payload={'name': 'Moon'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=13627112133842308861, version=80, score=0.7157799, payload={'name': 'Lunar Base'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=13746131591790535275, version=14, score=0.6981933, payload={'name': 'NASA'}, vector=None, shard_key=None, order_value=None)]

In [61]:
# Persist the fitted PCA to disk. The context manager guarantees the file is
# flushed and closed; the bare open() call left the handle dangling.
with open("pca.pkl", "wb") as pca_file:
    pickle.dump(pca, pca_file)