---
title: How to use azureopenai embeddings in lancedb
author: fastdaima
date: '2024-12-05'
description: azureopenai embeddings working solution

categories:
- embeddings
- vectordb
- lancedb


toc: true
hide: false
search: true

---




## Introduction 
- I was unable to use azure openai embeddings inside lancedb 
- However, lancedb has fixed in their recent [commit](https://github.com/lancedb/lancedb/commit/bd82e1f66d184bc13869030f03783cb4cb71da18)
- Still I cannot able to install the dev version in my project, hence copied the changes from the commit, did an minor change on top of that, like added attribute api_version which worked for me.
- let's go.

## References 
- https://github.com/lancedb/lancedb/commit/bd82e1f66d184bc13869030f03783cb4cb71da18

In [1]:
# !pip install lancedb openai -Uq

## Implementation

In [25]:
from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union
import logging
import os 
from lancedb.util import attempt_import_or_raise
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register, EmbeddingFunctionRegistry

if TYPE_CHECKING:
    import numpy as np


@register("azure-openai")
class AzureOpenAIEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the OpenAI API
    https://platform.openai.com/docs/guides/embeddings
    This can also be used for open source models that
    are compatible with the OpenAI API.
    Notes
    -----
    If you're running an Ollama server locally,
    you can just override the `base_url` parameter
    and provide the Ollama embedding model you want
    to use (https://ollama.com/library):
    ```python
    from lancedb.embeddings import get_registry
    openai = get_registry().get("openai")
    embedding_function = openai.create(
        name="<ollama-embedding-model-name>",
        base_url="http://localhost:11434",
        )
    ```
    """

    name: str = "text-embedding-ada-002"
    dim: Optional[int] = None
    azure_endpoint: Optional[str] = None
    default_headers: Optional[dict] = None
    azure_deployment: Optional[str] = None
    api_key: Optional[str] = None
    api_version: Optional[str] = None 
    use_azure: bool = False
    
    def ndims(self):
        return self._ndims

    @staticmethod
    def model_names():
        return [
            "text-embedding-ada-002",
            "text-embedding-3-large",
            "text-embedding-3-small",
        ]

    @cached_property
    def _ndims(self):
        if self.name == "text-embedding-ada-002":
            return 1536
        elif self.name == "text-embedding-3-large":
            return self.dim or 3072
        elif self.name == "text-embedding-3-small":
            return self.dim or 1536
        else:
            raise ValueError(f"Unknown model name {self.name}")

    def generate_embeddings(
        self, texts: Union[List[str], "np.ndarray"]
    ) -> List["np.array"]:
        """
        Get the embeddings for the given texts
        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        openai = attempt_import_or_raise("openai")

        valid_texts = []
        valid_indices = []
        for idx, text in enumerate(texts):
            if text:
                valid_texts.append(text)
                valid_indices.append(idx)

        # TODO retry, rate limit, token limit
        try:
            kwargs = {
                "input": valid_texts,
                "model": self.name,
            }
            if self.name != "text-embedding-ada-002":
                kwargs["dimensions"] = self.ndims()

            rs = self._openai_client.embeddings.create(**kwargs)
            valid_embeddings = {
                idx: v.embedding for v, idx in zip(rs.data, valid_indices)
            }
        except openai.BadRequestError:
            logging.exception("Bad request: %s", texts)
            return [None] * len(texts)
        except Exception:
            logging.exception("OpenAI embeddings error")
            raise
        return [valid_embeddings.get(idx, None) for idx in range(len(texts))]

    @cached_property
    def _openai_client(self):
        openai = attempt_import_or_raise("openai")
        kwargs = {}
        if self.azure_endpoint:
            kwargs["azure_endpoint"] = self.azure_endpoint
        if self.default_headers:
            kwargs["default_headers"] = self.default_headers
        if self.azure_deployment:
            kwargs["azure_deployment"] = self.azure_deployment
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_version: 
            kwargs['api_version'] = self.api_version 
        if self.use_azure:
            return openai.AzureOpenAI(**kwargs)
        else:
            return openai.OpenAI(**kwargs)

In [26]:
# REQUIRED ENV VARIABLES, ALWAYS LOAD IT FROM .env

AZURE_EMBEDDING_MODEL= os.environ['AZURE_EMBEDDING_MODEL']
AZURE_EMBEDDING_ENDPOINT=os.environ['AZURE_EMBEDDING_ENDPOINT']
AZURE_EMBEDDING_DEPLOYMENT=os.environ['AZURE_EMBEDDING_DEPLOYMENT']
AZURE_EMBEDDING_API_KEY=os.environ['AZURE_EMBEDDING_API_KEY']
AZURE_API_VERSION=os.environ['AZURE_API_VERSION']

In [27]:
reg = EmbeddingFunctionRegistry.get_instance() 
azure_openai = reg.get('azure-openai')

- Here by default, organization, base_url parameter is not working as expected, so I am passing azure_endpoint, azure_deployment and api_version parameters which are required for openai.AzureOpenAI API's

In [28]:
embed_fn = azure_openai.create(
    api_version=AZURE_API_VERSION,
    name=AZURE_EMBEDDING_MODEL,
    azure_endpoint=AZURE_EMBEDDING_ENDPOINT,
    azure_deployment=AZURE_EMBEDDING_DEPLOYMENT,
    api_key=AZURE_EMBEDDING_API_KEY,
    use_azure=True
)

In [29]:
embed_fn.generate_embeddings('what is 2+3')

[[-0.011524682864546776,
  -0.012060285545885563,
  0.008892844431102276,
  0.023437215015292168,
  0.02672470547258854,
  0.010296491906046867,
  0.044251829385757446,
  0.05115925148129463,
  -0.0017534048529341817,
  0.04366081953048706,
  -0.002740344265475869,
  0.057808104902505875,
  -0.030270762741565704,
  0.007553838659077883,
  -0.023197118192911148,
  0.02506249211728573,
  -0.03551597148180008,
  -0.0005330050480552018,
  -0.028534671291708946,
  0.026780113577842712,
  0.017739515751600266,
  -0.028405388817191124,
  -0.012217272073030472,
  0.03211766481399536,
  0.013944127596914768,
  0.003534513060003519,
  0.03867417573928833,
  -0.009520791471004486,
  -0.030123010277748108,
  -0.021036241203546524,
  -2.9561291739810258e-05,
  0.016492854803800583,
  -0.014018003828823566,
  0.033114995807409286,
  -0.05585038661956787,
  0.03331815451383591,
  0.055222440510988235,
  0.002488703466951847,
  0.00489891367033124,
  -0.016178881749510765,
  0.01125688198953867,
  -0.