In [None]:
import pysolr
import uuid
from typing import List, Dict
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

In [None]:
class DocumentEmbedder:
    """Turn text into L2-normalised, mean-pooled embeddings from a Hugging Face model.

    The tokenizer and model are loaded once in ``__init__``; the model is moved
    to CUDA when ``torch.cuda.is_available()`` is true, otherwise to the CPU.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        max_length: int = 512,
    ):
        """Load the tokenizer/model pair and place the model on the chosen device.

        Args:
            model_name: Identifier handed to ``AutoTokenizer.from_pretrained`` and
                ``AutoModel.from_pretrained``.
            max_length: Token limit passed to the tokenizer in ``get_embedding``;
                truncation is enabled, so longer inputs are cut to this length.
        """
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def mean_pooling(self, model_output, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average the token embeddings, counting only positions the mask marks as 1.

        Args:
            model_output: Model result whose first element holds the per-token
                embeddings.
            attention_mask: Mask with one entry per token; it is broadcast across
                the embedding dimension before weighting.

        Returns:
            The mask-weighted sum over dimension 1 divided by the mask count.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # The clamp keeps the division finite if a row of the mask is all zeros.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def get_embedding(self, text: str) -> np.ndarray:
        """Tokenize ``text``, run the model and return the normalised embedding.

        Args:
            text: Text to embed.

        Returns:
            A 2-D NumPy array (one row per tokenized input) whose rows have unit
            L2 norm.
        """
        encoded_input = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # The inputs must live on the same device as the model.
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

        # Inference only: no gradients are needed for any of these steps.
        with torch.no_grad():
            model_output = self.model(**encoded_input)
            sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
            sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

        return sentence_embeddings.cpu().numpy()

In [None]:
class SolrIndexer:
    """Build embedding-enriched documents and add them to a Solr core."""

    def __init__(self, solr_url: str = 'http://localhost:8983/solr/embeddings'):
        """Create the Solr client and the embedder used for the ``vector`` field.

        Args:
            solr_url: URL given to ``pysolr.Solr``.
        """
        self.solr = pysolr.Solr(solr_url, always_commit=True)
        self.embedder = DocumentEmbedder()

    def create_document(self, name: str, tags: str, category: str) -> Dict:
        """Create a document with embeddings.

        Args:
            name: Value stored in the ``name`` field.
            tags: Comma-separated tags. The full string is embedded as-is; the
                individual tags are stored as a list.
            category: Value stored in the ``category`` field.

        Returns:
            A dict with ``id`` (a fresh random UUID string), ``name``, ``tags``,
            ``category`` and ``vector`` (the embedding flattened to a list).
        """
        embedding = self.embedder.get_embedding(tags)

        # Split on bare commas and strip each piece so "a,b", "a, b" and "a ,  b"
        # all produce the same tags; empty pieces (e.g. from "" or a trailing
        # comma) are dropped instead of being stored as '' tags.
        stripped = (piece.strip() for piece in tags.split(','))
        tag_list = [piece for piece in stripped if piece]

        doc = {
            'id': str(uuid.uuid4()),
            'name': name,
            'tags': tag_list,
            'category': category,
            'vector': embedding.flatten().tolist()
        }
        return doc

    def index_documents(self, documents: List[Dict]):
        """Index multiple documents to Solr.

        Best-effort: any exception raised by the add call is caught and
        reported with ``print`` rather than propagated.

        Args:
            documents: Documents to send in a single ``add`` call.
        """
        try:
            self.solr.add(documents)
            print(f"Successfully indexed {len(documents)} documents")
        except Exception as e:
            print(f"Error indexing documents: {str(e)}")

In [None]:
# Sample catalogue: every entry carries the name, comma-separated tags and
# category that SolrIndexer.create_document takes as arguments.
documents = [
    dict(name="Apple iPhone 13", tags="phone, smartphone, screen, iOS", category="phone"),
    dict(name="Apple iPhone 14", tags="phone, smartphone, screen, iOS", category="phone"),
    dict(name="Apple iPhone 15", tags="phone, smartphone, screen, iOS", category="phone"),
    dict(name="Samsung Galaxy S24", tags="phone, smartphone, screen, Android", category="phone"),
    dict(name="Apple iPod", tags="music, screen, iOS", category="music player"),
    dict(name="Samsung Microwave", tags="kitchen, cooking, electric", category="household"),
]

indexer = SolrIndexer()

# Build each Solr document (embedding included), reporting progress as we go,
# then send the whole batch with one index call.
solr_documents = []
for item in documents:
    built = indexer.create_document(**item)
    solr_documents.append(built)
    print(f"Created document: {built['name']}")
    print(f"Vector length: {len(built['vector'])}")

indexer.index_documents(solr_documents)

In [None]:
# Embed an ad-hoc query string and show the raw vector.
embedder = DocumentEmbedder()
embeddings = embedder.get_embedding("song player")
print("Embeddings: " + str(embeddings))