In [1]:
from tqdm import tqdm
from typing import List, Dict, Any, Union
from pydantic import BaseModel
from lancedb.pydantic import LanceModel, Vector
import pandas as pd
from lancedb.embeddings.utils import api_key_not_found_help
from lancedb.embeddings.registry import register
from lancedb.embeddings import TextEmbeddingFunction, get_registry
import numpy as np
from functools import cached_property
from openai import AzureOpenAI
from azure_openai.setup import AzureOpenAiConfig
import nest_asyncio
import lancedb
from azure.identity import DefaultAzureCredential
import logging
import asyncio
import sys
import os
import hashlib

# Make the notebook's parent directory importable.
# A notebook has no __file__, and abspath(__name__) only ever resolved relative
# to the working directory, so the working directory is used explicitly.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:  # re-running the cell must not pile up duplicates
    sys.path.append(parent_dir)

## LanceDBManager

It is responsible for all interactions with the blob storage.


Before that, we need a LanceDB embedder.

### LanceDBEmbedder

It is responsible for all interactions with the tables.
There is a built-in one, but it had issues and could not be used. Luckily, we could write a custom one.

The problem with the built-in one is that it does not support the following:

- Microsoft AD (Azure AD) authentication for enterprises

It worked fine for a normal user who did not have to go through Azure AD.


In [2]:

# Load the Azure OpenAI settings and authenticate once; the same credential
# object feeds the token provider instead of constructing a second one.
openai_config = AzureOpenAiConfig()
credentials = DefaultAzureCredential()
token_provdier = openai_config.get_token_provider(credentials)

# Define a custom embedding function because the one provided by LanceDB is not working


@register("azure_openai")
class AzureOpenAIEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the Azure OpenAI API.

    It is registered in the LanceDB embedding registry under "azure_openai".
    Every field default below is read from the module-level ``openai_config``
    at the moment the class is defined.
    """

    # Deployment name; also sent as the ``model`` argument of every request.
    name: str = openai_config.text_embedder_lagre.deployment_name
    # NOTE(review): evaluated once, at class-definition time. The client below
    # requests a fresh key and only uses this field for its presence check.
    azure_api_key: str = openai_config.get_openai_api_key(
        credential=credentials)
    azure_endpoint: str = openai_config.endpoint
    azure_deployment: str = openai_config.text_embedder_lagre.deployment_name
    azure_api_version: str = openai_config.text_embedder_lagre.api_version

    @cached_property
    def _ndims(self):
        """Look up the embedding size for ``self.name`` (computed once per instance)."""
        if self.name == openai_config.text_embedder_lagre.deployment_name:
            return openai_config.text_embedder_lagre.ndims
        raise ValueError(f"Unknown model name {self.name}")

    def ndims(self):
        """
        Return the dimensionality of the embeddings.

        Raises
        ------
        ValueError
            If ``self.name`` is not the configured deployment name.
        """
        # Single definition: the class used to define ndims twice and the
        # second silently replaced the first. The name check lives in _ndims.
        return self._ndims

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[List[float]]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One embedding per input text, in the order the API returned them.
        """
        # TODO retry, rate limit, token limit
        if isinstance(texts, np.ndarray):
            # Hand the client a plain Python list rather than a numpy array.
            texts = texts.tolist()

        request = {"input": texts, "model": self.name}
        if self.name != openai_config.text_embedder_lagre.deployment_name:
            # Same as before: ndims() raises ValueError for any other name,
            # so this path never reaches the API with an unknown model.
            request["dimensions"] = self.ndims()

        rs = self._azure_openai_client.embeddings.create(**request)
        return [v.embedding for v in rs.data]

    @cached_property
    def _azure_openai_client(self):
        """Build the AzureOpenAI client on first use and cache it on the instance."""
        if not os.environ.get("OPENAI_API_KEY") and not self.azure_api_key:
            api_key_not_found_help("openai")
        # NOTE(review): the client reads endpoint and api_version from
        # openai_config, not from this instance's azure_endpoint /
        # azure_api_version fields - confirm which source is intended.
        return AzureOpenAI(
            azure_endpoint=openai_config.endpoint,
            # azure_ad_token_provider=token_provdier,
            api_version=openai_config.api_version,
            api_key=openai_config.get_openai_api_key(DefaultAzureCredential()),
            max_retries=5,
        )


def get_embedder():
    """Create a new embedding function instance from the "azure_openai" registry entry."""
    registry = get_registry()
    return registry.get("azure_openai").create()


# Shared instance used by the cells below
embedder = get_embedder()


To show how to use the embedder and what it returns, there is a test below.


In [3]:
# TEST: embed two short texts
embeddings = embedder.generate_embeddings(["hello world", "goodbye world"])

# Tabulate the result: one row per text, one "dim_<i>" column per dimension
embeddings_df = pd.DataFrame(
    embeddings,
    columns=["dim_" + str(i) for i, _ in enumerate(embeddings[0])],
)
embeddings_df.head()

Unnamed: 0,dim_0,dim_1,dim_2,dim_3,dim_4,dim_5,dim_6,dim_7,dim_8,dim_9,...,dim_3062,dim_3063,dim_3064,dim_3065,dim_3066,dim_3067,dim_3068,dim_3069,dim_3070,dim_3071
0,-0.002786,-0.022649,0.005107,0.02694,0.015299,0.009778,-0.005585,0.085468,0.011564,0.047623,...,-0.001351,0.007141,-0.017521,0.003406,0.003008,0.002577,0.001279,0.015119,0.001032,-0.003605
1,-0.040522,-0.031269,0.004365,0.023705,-0.010782,0.034536,0.012536,0.066818,0.004377,0.02591,...,-0.0048,0.036563,0.008642,0.005443,-0.023126,-0.002327,-0.00308,0.001875,0.012593,0.001872


### Model

We need models to store the data in the tables. A model checks whether the data is valid and throws an error if it is not.

The important part is the vector dimension: LanceDB will block you if you try to store a vector with a different dimension from the one specified when the table was created.

This means that if we move to a different embedding model, we might have to change the table, possibly recreating it.

!!! note
    For some reason LanceDB could not accept a pydantic model containing a list of dictionaries, so we had to use a list of strings instead.

    Dictionaries also have to be provided as a `BaseModel`: create a model for the dictionary and then use that model in the main model.


In [22]:
# Row schema for the 'user' table.
class User(LanceModel):
    email: str
    # Declared through embedder.SourceField(): the text column tied to the embedder
    name: str = embedder.SourceField()
    # Embedding vector for semantic searches; its size comes from embedder.ndims()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

In [23]:

# nest_asyncio.apply()
def get_lance_db_azure_credentials(credentials: "DefaultAzureCredential" = None):
    """
    Build the storage options LanceDB needs to reach the Azure blob store.

    Args:
        credentials: Credential used to request a storage access token.
            When omitted, a new DefaultAzureCredential is created for this call.

    Returns:
        dict: Storage account name, tenant id and an access token for the
        "https://storage.azure.com/.default" scope.
    """
    if credentials is None:
        # Created per call: a default written in the signature is evaluated
        # only once, when the function is defined.
        credentials = DefaultAzureCredential()

    access_token = credentials.get_token("https://storage.azure.com/.default")
    return {
        "azure_storage_account_name": "some_account_name",
        "azure_tenant_id": "some_tenant_id",
        "azure_storage_token": access_token.token,
    }


class LanceDBManager:
    """
    Asynchronous helper around a LanceDB database.

    Apart from ``get_table``, every operation opens its own connection through
    the ``connection`` property and closes it again when it finishes.
    """

    def __init__(
        self,
        database_url: str = "az://dev-vector-db/",
        azure_credentials: Dict[str, Any] = None,
        embedder: TextEmbeddingFunction = embedder,
        async_mode: bool = True,
    ):
        """
        Initialize LanceDBManager with the database URL, Azure credentials, and an embedder.

        Args:
            database_url (str): The path or URL to the LanceDB database.
            azure_credentials (dict): Azure storage credentials. When omitted,
                they are requested from get_lance_db_azure_credentials().
            embedder (TextEmbeddingFunction): Embedder instance for generating embeddings.
            async_mode (bool): Stored on the instance; not read by any method of this class.
        """
        if azure_credentials is None:
            # Resolved per instance. As a signature default the token was
            # fetched once, when the class was defined, and then reused by
            # every manager created afterwards.
            azure_credentials = get_lance_db_azure_credentials()

        self.database_url = database_url
        self.azure_credentials = azure_credentials
        self.embedder = embedder
        self.async_mode = async_mode

    @staticmethod
    def _close(connection):
        """Close ``connection`` if one was opened; ``None`` means connecting failed."""
        if connection is not None:
            connection.close()

    @staticmethod
    async def _column_names(async_table, columns_to_exclude):
        """Return the table's column names, minus ``columns_to_exclude``, read from its schema."""
        # Reads the schema instead of loading the whole table into pandas
        # just to look at the column labels.
        schema = await async_table.schema()
        excluded = set(columns_to_exclude or ())
        return [name for name in schema.names if name not in excluded]

    def get_sync_manager(self):
        """Open and return a synchronous LanceDB connection to the same database."""
        return lancedb.connect(self.database_url, storage_options=self.azure_credentials)

    @property
    async def connection(self):
        """
        Establish a connection to the LanceDB database.

        Returns:
            An asynchronous connection object to LanceDB.
        """
        return await lancedb.connect_async(
            uri=self.database_url,
            storage_options=self.azure_credentials
        )

    @property
    async def table_names(self):
        """
        Retrieve a list of table names in the database.

        Returns:
            List of table names.
        """
        connection = None
        try:
            connection = await self.connection
            return await connection.table_names()
        finally:
            self._close(connection)

    async def get_table(self, table_name: str):
        """
        Get a table object from the database.

        The connection used to open the table is not closed here.

        Args:
            table_name (str): Name of the table to retrieve.

        Returns:
            Table: LanceDB table object.
        """
        try:
            connection = await self.connection
            return await connection.open_table(table_name)
        except Exception as e:
            logging.error(f"Error getting table '{table_name}': {e}")
            raise

    async def create_schema(self, table_name: str, schema: Any):
        """
        Create a schema-based table in LanceDB, seed it with sample rows and run a test search.

        Args:
            table_name (str): Name of the table.
            schema (Any): Schema of the table.
        """
        connection = None
        try:
            connection = await self.connection
            table = await connection.create_table(
                table_name, schema=schema, exist_ok=True
            )

            # Example: Adding sample data for testing
            sample_data = [{"text": "hello world"}, {"text": "goodbye world"}]
            existing_rows = await table.to_pandas()
            existing_texts = (
                set(existing_rows["text"])
                if "text" in existing_rows.columns
                else set()
            )

            # Avoid adding duplicates
            new_data = [
                row for row in sample_data if row["text"] not in existing_texts
            ]
            if new_data:
                await table.add(new_data)

            query = "greetings"
            results = await self.vector_search(table_name, query)
            print(f"Search results for query '{query}': {results}")

        except Exception as e:
            logging.error(f"Error creating schema for table '{table_name}': {e}")
            raise
        finally:
            self._close(connection)

    async def create_table(self, table_name: str, schema: Any, overwrite: bool = False):
        """
        Create a new table in LanceDB.

        Failures (for example, the table already exists) are logged and not
        re-raised, so the caller can carry on.

        Args:
            table_name (str): Name of the table.
            schema (Any): Schema of the table.
            overwrite (bool): Whether to overwrite the table if it exists.
        """
        connection = None
        try:
            connection = await self.connection
            mode = "overwrite" if overwrite else "create"
            await connection.create_table(table_name, schema=schema, mode=mode)
            logging.info(f"Table '{table_name}' created successfully.")
        except Exception as e:
            logging.error(f"Error creating table '{table_name}': {e}")
        finally:
            self._close(connection)

    async def add_data(
        self, table_name: str, data: List[Dict[str, Any]], unique_field: str
    ):
        """
        Add data to a LanceDB table, skipping rows whose ``unique_field`` value already exists.

        Args:
            table_name (str): Name of the table.
            data (List[Dict[str, Any]]): List of data entries to add.
            unique_field (str): Field used for the uniqueness check.

        Returns:
            int: Number of rows added.

        Raises:
            ValueError: If ``unique_field`` is empty.
        """
        if not unique_field:
            raise ValueError(
                "Unique field must be specified to check for duplicates.")

        connection = None
        try:
            connection = await self.connection
            async_table = await connection.open_table(table_name)

            rows_before = await async_table.count_rows()

            # The merge_insert function fails if the table is empty
            if rows_before == 0:
                await async_table.add(data)
                logging.info(f"Added {len(data)} entries to table '{table_name}'.")
                return len(data)

            await async_table.merge_insert(on=unique_field).when_not_matched_insert_all().execute(data)
            rows_after = await async_table.count_rows()
            new_rows = rows_after - rows_before
            logging.info(f"Added {new_rows} new entries to table '{table_name}'.")
            return new_rows
        except Exception as e:
            logging.error(f"Error adding data to table '{table_name}': {e}")
            raise
        finally:
            self._close(connection)

    async def update_data(
        self, table_name: str, data: List[Dict[str, Any]], unique_field: str
    ):
        """
        Update rows in a LanceDB table, matched on ``unique_field``, using
        merge_insert with when_matched_update_all(). An empty table is simply
        filled with ``data``.

        Args:
            table_name (str): Name of the table.
            data (List[Dict[str, Any]]): Rows carrying the new values.
            unique_field (str): Field used to match incoming rows to existing ones.

        Returns:
            int: ``len(data)``.

        Raises:
            ValueError: If ``unique_field`` is empty.
        """
        if not unique_field:
            raise ValueError(
                "Unique field must be specified to check for duplicates.")

        connection = None
        try:
            connection = await self.connection
            async_table = await connection.open_table(table_name)

            if await async_table.count_rows() == 0:
                await async_table.add(data)
                logging.info(f"Added {len(data)} entries to table '{table_name}'.")
                return len(data)

            await async_table.merge_insert(on=unique_field).when_matched_update_all().execute(data)
            logging.info(f"Updated {len(data)} entries in table '{table_name}'.")
            return len(data)

        except Exception as e:
            logging.error(f"Error updating data in table '{table_name}': {e}")
            raise
        finally:
            self._close(connection)

    async def fetch_data(
        self,
        table_name: str,
        as_pandas: bool = True,
        page: int = 1,
        per_page: int = 10,
        filter: str = None,
        columns_to_exclude: List[str] = None,
    ):
        """
        Fetch data from a LanceDB table with pagination and optional filtering.

        Args:
            table_name (str): Name of the table.
            as_pandas (bool): Whether to return data as a pandas DataFrame.
            page (int): Page number for pagination.
            per_page (int): Number of items per page. Use -1 to fetch all data.
            filter (str): SQL filter expression. these are the filters that can be used - https://lancedb.github.io/lancedb/sql/#sql-filters
            columns_to_exclude (List[str]): List of columns to exclude from the results.

        Returns:
            DataFrame or List[Dict]: Fetched data.
            List[Dict]: Fetched data as a list of dictionaries if as_pandas is set to False.
        """
        # docs used to make this function: https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering
        connection = None
        try:
            connection = await self.connection
            async_table = await connection.open_table(table_name)

            # Select explicit columns so that callers can leave out e.g. the vector column
            columns_to_include = await self._column_names(
                async_table, columns_to_exclude
            )
            query = async_table.query().select(columns_to_include).with_row_id()

            if filter:
                query = query.where(filter)

            if per_page != -1:
                query = query.limit(per_page).offset((page - 1) * per_page)
            else:
                # limit(count_rows()) instead of limit(-1): lancedb v0.17.0 does not
                # respect limit(-1) when used with a where clause
                # https://github.com/lancedb/lancedb/issues/1852
                query = query.limit(await async_table.count_rows())

            df = await query.to_pandas()
            return df if as_pandas else df.to_dict(orient="records")
        except Exception as e:
            logging.error(f"Error fetching data from table '{table_name}': {e}")
            raise
        finally:
            self._close(connection)

    async def vector_search(
        self,
        table_name: str,
        query: str,
        limit: int = 5,
        as_pandas: bool = True,
        columns_to_exclude: List[str] = None,
    ):
        """
        Perform a vector search on a LanceDB table.

        Args:
            table_name (str): Name of the table.
            query (str): Query text to search for.
            limit (int): Number of search results to return.
            as_pandas (bool): Whether to return data as a pandas DataFrame.
            columns_to_exclude (List[str]): List of columns to exclude from the results.

        Returns:
            DataFrame: Search results.
            List[Dict]: Search results as a list of dictionaries. if as_pandas is set to False
        """
        connection = None
        try:
            connection = await self.connection
            async_table = await connection.open_table(table_name)

            # Select explicit columns so that callers can leave out e.g. the vector column
            columns_to_include = await self._column_names(
                async_table, columns_to_exclude
            )

            # Generate embedding for the query
            embedding = self.embedder.generate_embeddings([query])[0]

            # Perform vector search
            results = (
                await async_table.query()
                .select(columns_to_include)
                .with_row_id()
                .nearest_to(embedding)
                .limit(limit)
                .to_pandas()
            )

            return results if as_pandas else results.to_dict(orient="records")
        except Exception as e:
            logging.error(
                f"Error performing vector search on table '{table_name}': {e}"
            )
            raise
        finally:
            self._close(connection)

    async def delete_table(self, table_name: str):
        """
        Delete a table from LanceDB.

        Args:
            table_name (str): Name of the table to delete.

        Returns:
            bool: True once the table has been dropped.
        """
        connection = None
        try:
            connection = await self.connection
            await connection.drop_table(table_name)
            logging.info(f"Table '{table_name}' deleted successfully.")
            return True
        except Exception as e:
            logging.error(f"Error deleting table '{table_name}': {e}")
            raise
        finally:
            self._close(connection)

    async def delete_rows(self, table_name: str, condition: str):
        """
        Delete rows from a table based on a condition.

        Args:
            table_name (str): Name of the table.
            condition (str): Condition to match rows for deletion.
        """
        connection = None
        try:
            connection = await self.connection
            async_table = await connection.open_table(table_name)
            await async_table.delete(where=condition)
            logging.info(
                f"Rows matching condition '{condition}' deleted from table '{table_name}'."
            )
        except Exception as e:
            logging.error(f"Error deleting rows from table '{table_name}': {e}")
            raise
        finally:
            self._close(connection)

    async def delete_duplicates(self, table_name: str, subset: List[str]):
        """
        Remove duplicate rows from a LanceDB table based on specified columns.

        The table is read into pandas, de-duplicated there and, if anything
        was dropped, written back in overwrite mode.

        Args:
            table_name (str): Name of the table.
            subset (List[str]): List of column names to check for duplicates.

        Returns:
            int: Number of duplicate rows removed.
        """
        connection = None
        try:
            connection = await self.connection
            async_table = await connection.open_table(table_name)

            # Fetch the table's data
            df = await async_table.to_pandas()

            # Drop duplicates based on the specified subset
            df_unique = df.drop_duplicates(subset=subset)
            duplicates_removed = len(df) - len(df_unique)

            if duplicates_removed > 0:
                # Overwrite the table with the unique data
                await async_table.add(
                    df_unique.to_dict(orient="records"), mode="overwrite"
                )
                logging.info(
                    f"Removed {duplicates_removed} duplicate rows from table '{table_name}'."
                )
            else:
                logging.info(f"No duplicates found in table '{table_name}'.")

            return duplicates_removed
        except Exception as e:
            logging.error(f"Error deleting duplicates from table '{table_name}': {e}")
            raise
        finally:
            self._close(connection)


### Example usage

Below is an example of how to use the manager. The manager uses the embedder to interact with the tables.

- credentials and initialisation
- create a table
  - will silently fail if the table already exists
- make the data fit the model
- insert the data


### !!! Danger !!! Delete all tables in the LanceDB database

This deletes all the tables in the database. This is a dangerous operation and should only be used in development environments.

- Keep it commented out in the commit history.


In [24]:
async def delete_db_schema():
    credentials = DefaultAzureCredential()
    azure_credentials = {
        "azure_storage_account_name": "some_storage_account_name",
        "azure_tenant_id": "some_tenant_id",
        "azure_storage_token": credentials.get_token("https://storage.azure.com/.default").token,
    }
    database_url = "az://dev-vector-db/"
    db_manager = LanceDBManager(database_url=database_url)

    deleted_user = await db_manager.delete_table("user")
    print(deleted_user)

await delete_db_schema()

True
True
True


## Create all tables in the LanceDB database

This uses the model to create all the tables in the database. If a table already exists, the error is logged and the run continues.



In [26]:
async def create_db():
    """
    Create the database schema for the local news AI application
    """
    # Initialize manager
    database_url = "az://dev-vector-db/"
    # Use default database URL and Azure credentials and embedder
    manager = LanceDBManager(database_url=database_url)
    # Create tables
    await manager.create_table("user", User)
await create_db()

ERROR:root:Error creating table 'user': Table 'user' already exists
ERROR:root:Error creating table 'labels': Table 'labels' already exists
ERROR:root:Error creating table 'content_labels': Table 'content_labels' already exists


## Add data to all current tables in the LanceDB database

An example of how data is added to the tables.

It adds fake data to the tables.

## Ways to do it

- You can type-check the data and create the embeddings manually before adding them to the database.
  - Make sure the embeddings have the same dimension as the table's vector column.
- You can just send the rows to the database and let it handle the type checking and embedding.
  - Some fields in the pydantic model are declared with `SourceField()`; these are the fields that get embedded.
  - TODO: check whether we can have more than one source field in the model.


In [23]:
async def add_db_schema():
    """
    Add sample rows to the 'user' table.

    Each row gets a ``vector`` embedded from its ``name`` and is checked
    against the ``User`` model before the rows are written.
    """
    credentials = DefaultAzureCredential()
    azure_credentials = {
        "azure_storage_account_name": "some_storage_account_name",
        "azure_tenant_id": "some_tenant_id",
        "azure_storage_token": credentials.get_token("https://storage.azure.com/.default").token,
    }
    database_url = "az://dev-vector-db/"
    db_manager = LanceDBManager(
        database_url=database_url, azure_credentials=azure_credentials)
    # Raw data to be added to the 'user' table; the keys are the User model's fields
    more_user_data = [
        {"email": "lancedb@viewer.com", "name": "Lance Vine"},
    ]

    for data in tqdm(more_user_data):
        # generate_embeddings returns one vector per input text; keep the only one
        data["vector"] = embedder.generate_embeddings([data["name"]])[0]
        # Validate the row against the model; an invalid row raises here,
        # before anything is written.
        User(**data)

    # The User model declares no user_id column, so email is the
    # de-duplication key.
    await db_manager.add_data("user", more_user_data, unique_field="email")


# Run the schema setup function.
# The notebook kernel already runs an event loop, and a plain asyncio.run()
# refuses to start inside one; nest_asyncio patches asyncio to allow it.
nest_asyncio.apply()
asyncio.run(add_db_schema())

100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
100%|██████████| 2/2 [00:00<00:00, 14.94it/s]
100%|██████████| 23/23 [00:00<00:00, 62117.83it/s]


## Fetch data

An example of how data is fetched from the tables.


In [27]:
async def fetch_all_data():
    # Use default database URL and Azure credentials and embedder
    db_manager = LanceDBManager()
    # Fetch data from 'user' table
    user_df = await db_manager.fetch_data("user")
    return user_df

# Run the async functions
user_df = await fetch_all_data()

In [28]:
# Display the rows fetched in the previous cell
user_df

Unnamed: 0,email,name,query,vector,newsletter,sources,saved_content,_rowid


## Vector search

Find the most similar vectors in the database.


In [None]:
async def vector_search_data():
    # Use default database URL and Azure credentials and embedder
    db_manager = LanceDBManager()

    # Perform a vector search
    results = await db_manager.vector_search("user", query="Vance", limit=50)
    print("Search Results:", results)
await vector_search_data()