In [1]:
import pysolr
import uuid
from typing import List, Dict
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

In [2]:
class DocumentEmbedder:
    """Encode text into L2-normalised, mean-pooled transformer embeddings.

    The tokenizer and model are loaded once in the constructor and moved to
    CUDA when available, otherwise they stay on the CPU.
    """

    def __init__(self, model_name: str = "answerdotai/ModernBERT-base", max_length: int = 768):
        """Load the tokenizer and model.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
            max_length: Maximum number of *tokens* kept per input; longer
                inputs are truncated. This is a sequence-length limit, not the
                embedding width (the default merely preserves the value this
                class has always used).
        """
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference only: make the evaluation mode explicit (dropout etc. disabled).
        self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring masked-out positions.

        Args:
            model_output: Model result whose first element holds the per-token
                embeddings (assumed shape: batch x tokens x hidden).
            attention_mask: Mask with one entry per token (assumed shape:
                batch x tokens); zero entries are excluded from the average.

        Returns:
            Tensor with one pooled vector per batch row.
        """
        token_embeddings = model_output[0]
        # Broadcast the mask over the hidden dimension so it can weight each token vector.
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp the divisor so a fully masked row cannot cause a division by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def get_embedding(self, text: str) -> np.ndarray:
        """Embed ``text`` and return the unit-length vector(s) as a NumPy array.

        Args:
            text: Input to tokenize (truncated to ``self.max_length`` tokens).

        Returns:
            2-D array with one L2-normalised row per input sequence; a single
            string therefore yields an array of shape ``(1, hidden)``.
        """
        encoded_input = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Tensors must live on the same device as the model.
        encoded_input = {key: tensor.to(self.device) for key, tensor in encoded_input.items()}

        # No gradients are needed for inference.
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        pooled = self.mean_pooling(model_output, encoded_input['attention_mask'])
        normalized = torch.nn.functional.normalize(pooled, p=2, dim=1)

        return normalized.cpu().numpy()

In [3]:
class SolrIndexer:
    """Build Solr documents carrying an embedding vector and send them to Solr."""

    def __init__(self, solr_url: str = 'http://localhost:8983/solr/embeddings', embedder=None):
        """Create the Solr client and set up the embedder.

        Args:
            solr_url: URL of the Solr collection to write to.
            embedder: Optional object exposing ``get_embedding(text)``. When
                omitted, a new ``DocumentEmbedder`` is created (which loads the
                model); pass an existing one to share an already loaded model.
        """
        # always_commit=True is forwarded to pysolr so updates are committed
        # without a separate commit call.
        self.solr = pysolr.Solr(solr_url, always_commit=True)
        self.embedder = embedder if embedder is not None else DocumentEmbedder()

    def create_document(self, name: str, tags: str, category: str) -> Dict:
        """Build one Solr document for a product.

        Args:
            name: Display name stored in the ``name`` field.
            tags: Comma-separated tag string. The full string is embedded; the
                individual tags are stored as a list.
            category: Value stored in the ``category`` field.

        Returns:
            Dict with a random UUID ``id``, ``name``, ``tags`` (list of
            strings), ``category`` and ``vector`` (flat list of floats).
        """
        embedding = self.embedder.get_embedding(tags)

        # Split on bare commas and strip whitespace so "a,b", "a, b" and
        # "a ,  b" all yield the same tags; empty fragments (e.g. from a
        # trailing comma or an empty string) are dropped.
        tag_list = [tag.strip() for tag in tags.split(',') if tag.strip()]

        doc = {
            'id': str(uuid.uuid4()),
            'name': name,
            'tags': tag_list,
            'category': category,
            # The embedder returns a 2-D array; Solr needs a flat list of numbers.
            'vector': embedding.flatten().tolist()
        }
        return doc

    def index_documents(self, documents: List[Dict]):
        """Send ``documents`` to Solr and report the outcome on stdout.

        Failures are deliberately reported with a printed message rather than
        re-raised, so a notebook run continues after an indexing error.

        Args:
            documents: Documents as produced by ``create_document``.
        """
        try:
            self.solr.add(documents)
            print(f"Successfully indexed {len(documents)} documents")
        except Exception as e:
            print(f"Error indexing documents: {str(e)}")

In [4]:
# Sample catalogue. Each entry's keys (name, tags, category) line up with the
# parameters of SolrIndexer.create_document.
documents = [
    {"name": "Apple iPhone 13", "tags": "phone, smartphone, screen, iOS", "category": "phone"},
    {"name": "Apple iPhone 14", "tags": "phone, smartphone, screen, iOS", "category": "phone"},
    {"name": "Apple iPhone 15", "tags": "phone, smartphone, screen, iOS", "category": "phone"},
    {"name": "Samsung Galaxy S24", "tags": "phone, smartphone, screen, Android", "category": "phone"},
    {"name": "Apple iPod", "tags": "music, screen, iOS", "category": "music player"},
    {"name": "Samsung Microwave", "tags": "kitchen, cooking, electric", "category": "household"}
]

indexer = SolrIndexer()

# Convert every catalogue entry into a Solr document (with its embedding
# vector), reporting progress as each one is built.
solr_documents = []
for doc in documents:
    solrdoc = indexer.create_document(
        name=doc['name'],
        tags=doc['tags'],
        category=doc['category'],
    )
    solr_documents.append(solrdoc)
    print(f"Created document: {solrdoc['name']}")
    print(f"Vector length: {len(solrdoc['vector'])}")

# Push the whole batch to Solr in a single request.
indexer.index_documents(solr_documents)

Created document: Apple iPhone 13
Vector length: 768
Created document: Apple iPhone 14
Vector length: 768
Created document: Apple iPhone 15
Vector length: 768
Created document: Samsung Galaxy S24
Vector length: 768
Created document: Apple iPod
Vector length: 768
Created document: Samsung Microwave
Vector length: 768
Successfully indexed 6 documents


In [5]:
# Embed an ad-hoc query string with the same model that produced the indexed vectors.
# Reuse the embedder already held by `indexer` (created in the indexing cell above)
# rather than constructing a second DocumentEmbedder, which would load the model again.
embedder = indexer.embedder
embeddings = embedder.get_embedding("song player")
# Flat list-of-floats form of the query vector (not used within this cell;
# presumably consumed by a later cell — verify).
formatted = embeddings.flatten().tolist()
print(f"Embeddings: {str(embeddings)}")

Embeddings: [[ 2.10360941e-02 -6.01991033e-03  1.02332737e-02 -1.43522304e-02
  -1.48830963e-02 -5.99567953e-04  8.11345223e-03 -2.08218452e-02
   4.00092686e-03  3.62962601e-03  2.69588232e-02  3.00253220e-02
   5.63282939e-03  1.53024243e-02  4.71011288e-02 -1.31746279e-02
   5.24408044e-03  3.03235352e-02  1.89648606e-02 -1.29690152e-02
  -5.34556434e-03  2.84999353e-03  8.94962717e-03 -3.18741128e-02
  -6.99078944e-03 -4.15501138e-03  7.65044428e-03  1.12900967e-02
   4.11990890e-03 -3.42138065e-03 -1.10447789e-02 -6.84759095e-02
  -1.39015552e-03 -4.56782756e-03 -7.46989995e-03  1.83845696e-03
  -3.77940461e-02 -3.00570508e-03 -6.83693588e-03 -4.14641835e-02
  -1.23182312e-02  1.72455274e-02 -1.65808350e-02 -7.51126837e-03
   2.68548299e-02 -6.53659599e-03  2.09815186e-02 -2.43661962e-02
  -2.91947671e-03  4.42355359e-03 -1.23623293e-02 -3.28787137e-03
   6.68136263e-03  1.52024627e-02 -1.85656399e-02  4.40689176e-03
   2.11724676e-02  8.32620356e-03 -2.95622624e-03  1.72240566e-0