Skip to content

Commit

Permalink
Mock embedding model to address flaky vector tests (#1822) (#1823)
Browse files Browse the repository at this point in the history
(cherry picked from commit b5435a8)

Co-authored-by: Miguel Grinberg <miguel.grinberg@gmail.com>
  • Loading branch information
github-actions[bot] and miguelgrinberg committed May 10, 2024
1 parent ab70d6f commit fcb1fe3
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 12 deletions.
23 changes: 23 additions & 0 deletions tests/async_sleep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import asyncio


async def sleep(secs):
    """Suspend the calling test coroutine for ``secs`` seconds.

    Async helper that tests can await when they need to pause; it
    delegates to ``asyncio.sleep``.
    """
    pause = asyncio.sleep(secs)
    await pause
23 changes: 23 additions & 0 deletions tests/sleep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time


def sleep(secs):
    """Block the calling test for ``secs`` seconds.

    Synchronous helper that tests can call when they need to pause; it
    delegates to ``time.sleep``.
    """
    pause = time.sleep
    pause(secs)
30 changes: 24 additions & 6 deletions tests/test_integration/test_examples/_async/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,38 @@
# specific language governing permissions and limitations
# under the License.

from hashlib import md5
from unittest import SkipTest

import pytest

from ..async_examples.vectors import create, search
from tests.async_sleep import sleep

from ..async_examples import vectors


# NOTE(review): this listing appears to come from a rendered diff, so removed
# and added lines sit back to back (e.g. the two signatures below and the old
# create()/search() calls) -- confirm against the actual file before relying on it.
@pytest.mark.asyncio
async def test_vector_search(async_write_client, es_version):
async def test_vector_search(async_write_client, es_version, mocker):
# this test only runs on Elasticsearch >= 8.11 because the example uses
# a dense vector without giving them an explicit size
# a dense vector without specifying an explicit size
if es_version < (8, 11):
raise SkipTest("This test requires Elasticsearch 8.11 or newer")

await create()
results = await (await search("work from home")).execute()
assert results[0].name == "Work From Home Policy"
# Stand-in for the embedding model: the constructor ignores its argument and
# encode() turns the MD5 digest of the text into a list of floats, each byte
# divided by the sum of all bytes, so the same text always yields the same vector.
class MockModel:
def __init__(self, model):
pass

def encode(self, text):
vector = [int(ch) for ch in md5(text.encode()).digest()]
total = sum(vector)
return [float(v) / total for v in vector]

# Swap the SentenceTransformer name in the example module for the mock above.
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)

await vectors.create()
# Run the search up to 10 times, pausing 0.1s after each empty result, and stop
# at the first attempt that returns hits (presumably the indexed documents are
# not searchable immediately -- TODO confirm).
for i in range(10):
results = await (await vectors.search("Welcome to our team!")).execute()
if len(results.hits) > 0:
break
await sleep(0.1)
assert results[0].name == "New Employee Onboarding Guide"
30 changes: 24 additions & 6 deletions tests/test_integration/test_examples/_sync/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,38 @@
# specific language governing permissions and limitations
# under the License.

from hashlib import md5
from unittest import SkipTest

import pytest

from ..examples.vectors import create, search
from tests.sleep import sleep

from ..examples import vectors


# NOTE(review): this listing appears to come from a rendered diff, so removed
# and added lines sit back to back (e.g. the two signatures below and the old
# create()/search() calls) -- confirm against the actual file before relying on it.
@pytest.mark.sync
def test_vector_search(write_client, es_version):
def test_vector_search(write_client, es_version, mocker):
# this test only runs on Elasticsearch >= 8.11 because the example uses
# a dense vector without giving them an explicit size
# a dense vector without specifying an explicit size
if es_version < (8, 11):
raise SkipTest("This test requires Elasticsearch 8.11 or newer")

create()
results = (search("work from home")).execute()
assert results[0].name == "Work From Home Policy"
# Stand-in for the embedding model: the constructor ignores its argument and
# encode() turns the MD5 digest of the text into a list of floats, each byte
# divided by the sum of all bytes, so the same text always yields the same vector.
class MockModel:
def __init__(self, model):
pass

def encode(self, text):
vector = [int(ch) for ch in md5(text.encode()).digest()]
total = sum(vector)
return [float(v) / total for v in vector]

# Swap the SentenceTransformer name in the example module for the mock above.
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)

vectors.create()
# Run the search up to 10 times, pausing 0.1s after each empty result, and stop
# at the first attempt that returns hits (presumably the indexed documents are
# not searchable immediately -- TODO confirm).
for i in range(10):
results = (vectors.search("Welcome to our team!")).execute()
if len(results.hits) > 0:
break
sleep(0.1)
assert results[0].name == "New Employee Onboarding Guide"
1 change: 1 addition & 0 deletions utils/run-unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def main(check=False):
"async_write_client": "write_client",
"async_pull_request": "pull_request",
"async_examples": "examples",
"async_sleep": "sleep",
"assert_awaited_once_with": "assert_called_once_with",
"pytest_asyncio": "pytest",
}
Expand Down

0 comments on commit fcb1fe3

Please sign in to comment.