In [1]:
import lancedb
from lancedb.embeddings import with_embeddings
from lancedb.pydantic import vector, pydantic_to_schema
import numpy as np
import openai
import pyarrow as pa
import pydantic

In [18]:
from abc import ABC, abstractmethod
from uuid import uuid4

class LanceDBMixin(ABC):
    """Mixin that backs a pydantic model with a LanceDB table.

    Concrete subclasses must provide the ``uri`` and ``table_name``
    classmethods. The mixin contributes an ``embedding`` vector field and
    class-level helpers to create, populate, search and drop the table.
    """

    # Vector column stored next to the model's own fields.
    embedding: vector(1536)

    @classmethod
    @abstractmethod
    def uri(cls):
        """Return the location handed to ``lancedb.connect``."""
        pass

    @classmethod
    @abstractmethod
    def table_name(cls):
        """Return the name of the table that stores this model's rows."""
        pass

    @classmethod
    def get_table(cls) -> lancedb.table.LanceTable:
        """Open the backing table, creating it from the model schema if absent."""
        db = lancedb.connect(cls.uri())
        name = cls.table_name()
        if name in db:
            return db.open_table(name)
        return db.create_table(name, schema=cls.get_arrow_schema())

    @classmethod
    def get_arrow_schema(cls) -> pa.Schema:
        """Return the Arrow schema derived from this pydantic model."""
        return pydantic_to_schema(cls)

    @classmethod
    def generate_embedding(cls, batch: str | list[str]) -> list[np.ndarray]:
        """Embed ``batch`` with the ``text-embedding-ada-002`` engine.

        Returns one entry per record in the response's ``data`` list, in
        response order.
        NOTE(review): the entries are whatever the API stores under
        ``"embedding"`` (presumably plain lists of floats, not ndarrays) -
        confirm before relying on the annotation.
        """
        rs = openai.Embedding.create(input=batch, engine="text-embedding-ada-002")
        return [record["embedding"] for record in rs["data"]]

    @classmethod
    def create_instances(cls, data: list[dict]) -> list[pydantic.BaseModel]:
        """Embed, store and instantiate the given rows.

        Args:
            data: One dict per row, keyed by model field name.

        Returns:
            One model instance per stored row; an empty list when ``data``
            is empty (nothing is embedded or written in that case).
        """
        # An empty batch would yield a column-less Arrow table and fail when
        # looking up the column to embed, so short-circuit instead.
        if not data:
            return []

        to_embed_name = cls._find_raw_data_column()  # could be None
        schema = cls.get_arrow_schema()
        if to_embed_name is not None:
            arrow_table = with_embeddings(cls.generate_embedding,
                                          pa.Table.from_pylist(data),
                                          column=to_embed_name)
            # with_embeddings' output column is named "vector"; the model
            # field is called "embedding".
            arrow_table = arrow_table.rename_columns(
                [name if name != "vector" else "embedding"
                 for name in arrow_table.schema.names])
            # Align column order and types with the model schema.
            arrow_table = arrow_table.select(schema.names).cast(schema)
        else:
            arrow_table = pa.Table.from_pylist(data, schema=schema)
        cls.get_table().add(arrow_table)
        return [cls(**row) for row in arrow_table.to_pylist()]

    @classmethod
    def _find_raw_data_column(cls):
        """Return the first field flagged ``vector_input_column``, or ``None``."""
        for name, field in cls.model_fields.items():
            extra = field.json_schema_extra
            # json_schema_extra may be None or a callable; only a dict can
            # carry the flag.
            if isinstance(extra, dict) and extra.get("vector_input_column"):
                return name
        return None

    @classmethod
    def retrieve(cls, query, n=3, filter=None) -> list[pydantic.BaseModel]:
        """Run a vector search for ``query`` against the backing table.

        Args:
            query: Text to embed and search with.
            n: Maximum number of rows to return.
            filter: Optional filter expression passed to the query's
                ``where``; skipped when ``None``.

        Returns:
            Model instances built from the matching rows.
        """
        tbl = cls.get_table()
        embedding = cls.generate_embedding(query)[0]
        # Apply the requested limit; without it the builder's own default
        # would be used and ``n`` would have no effect.
        search = tbl.search(embedding, vector_column_name="embedding").limit(n)
        if filter is not None:
            search = search.where(filter)
        arrow_table = search.to_arrow()
        return [cls(**row) for row in arrow_table.to_pylist()]

    @classmethod
    def reset(cls):
        """Drop the backing table from the database."""
        db = lancedb.connect(cls.uri())
        db.drop_table(cls.table_name())
        
    
def lancedb_model(uri, table_name):
    """Create a ``LanceDBMixin`` subclass bound to one database and table.

    The returned class implements the mixin's ``uri`` and ``table_name``
    classmethods so that they return the arguments given here. It is meant
    to be used as a base class next to ``pydantic.BaseModel``.
    """
    # A random hex suffix gives every generated class its own name.
    generated_name = f'LanceDBMixin_{uuid4().hex}'
    namespace = dict(
        uri=classmethod(lambda cls: uri),
        table_name=classmethod(lambda cls: table_name),
    )
    return type(generated_name, (LanceDBMixin,), namespace)

In [19]:
import pydantic
from lancedb.pydantic import vector

class MyModel(pydantic.BaseModel, lancedb_model("~/.lancedb", "test")):
    """Example model stored in the "test" table of the "~/.lancedb" database."""

    # Flag this field as the text to embed. The flag goes through
    # json_schema_extra because passing arbitrary keyword arguments to
    # Field() is deprecated in pydantic v2.
    text: str = pydantic.Field(
        "string", json_schema_extra={"vector_input_column": True}
    )


In [None]:
# Embed three sample texts, add them to the table, and keep the model instances.
inst = MyModel.create_instances(
    [
        {"text": "hamlet"},
        {"text": "the bible"},
        {"text": "lord of the rings"},
    ]
)

In [21]:
# Run a vector search for the query and display the text of the first result.
MyModel.retrieve(
    "play written by shakespeare",
    n=1,
)[0].text

'hamlet'