|
14 | 14 | import threading |
15 | 15 | from typing import Union |
16 | 16 |
|
| 17 | +import numpy as np |
17 | 18 | import oracledb |
18 | 19 | import pytest |
19 | 20 | from langchain_community.embeddings import HuggingFaceEmbeddings |
20 | 21 | from langchain_community.vectorstores.utils import DistanceStrategy |
21 | 22 |
|
22 | 23 | from langchain_oracledb.embeddings import OracleEmbeddings |
23 | 24 | from langchain_oracledb.vectorstores.oraclevs import ( |
| 25 | + INTERNAL_ID_KEY, |
24 | 26 | OracleVS, |
25 | 27 | _acreate_table, |
26 | 28 | _aindex_exists, |
@@ -1165,6 +1167,7 @@ def test_add_texts_test() -> None: |
1165 | 1167 | vs_obj = OracleVS(connection, model, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE) |
1166 | 1168 | ids6 = ['"Good afternoon"', '"India"'] |
1167 | 1169 | vs_obj.add_texts(texts2, ids=ids6) |
| 1170 | + assert len(vs_obj.add_texts(texts2, ids=ids6)) == 0 |
1168 | 1171 | drop_table_purge(connection, "TB7") |
1169 | 1172 |
|
1170 | 1173 | # 4. Add records with ids and metadatas |
@@ -1208,30 +1211,6 @@ def add(val: str) -> None: |
1208 | 1211 | thread_2.join() |
1209 | 1212 | drop_table_purge(connection, "TB10") |
1210 | 1213 |
|
1211 | | - # 7. Add 2 same record concurrently |
1212 | | - # Expectation:Successful, For one of the insert,get primary key violation error |
1213 | | - def add1(val: str) -> None: |
1214 | | - model = HuggingFaceEmbeddings( |
1215 | | - model_name="sentence-transformers/all-mpnet-base-v2" |
1216 | | - ) |
1217 | | - vs_obj = OracleVS( |
1218 | | - connection, model, "TB11", DistanceStrategy.EUCLIDEAN_DISTANCE |
1219 | | - ) |
1220 | | - texts = [val] |
1221 | | - ids10 = texts |
1222 | | - vs_obj.add_texts(texts, ids=ids10) |
1223 | | - |
1224 | | - try: |
1225 | | - thread_1 = threading.Thread(target=add1, args=("Sri Ram")) |
1226 | | - thread_2 = threading.Thread(target=add1, args=("Sri Ram")) |
1227 | | - thread_1.start() |
1228 | | - thread_2.start() |
1229 | | - thread_1.join() |
1230 | | - thread_2.join() |
1231 | | - except Exception: |
1232 | | - pass |
1233 | | - drop_table_purge(connection, "TB11") |
1234 | | - |
1235 | 1214 | # 8. create object with table name of type <schema_name.table_name> |
1236 | 1215 | # Expectation:U1 does not exist |
1237 | 1216 | with pytest.raises(RuntimeError): |
@@ -1317,6 +1296,7 @@ async def test_add_texts_test_async() -> None: |
1317 | 1296 | ) |
1318 | 1297 | ids6 = ['"Good afternoon"', '"India"'] |
1319 | 1298 | await vs_obj.aadd_texts(texts2, ids=ids6) |
| 1299 | + assert len(await vs_obj.aadd_texts(texts2, ids=ids6)) == 0 |
1320 | 1300 | await adrop_table_purge(connection, "TB7") |
1321 | 1301 |
|
1322 | 1302 | # 4. Add records with ids and metadatas |
@@ -1362,26 +1342,6 @@ async def add(val: str) -> None: |
1362 | 1342 | await asyncio.gather(task_1, task_2) |
1363 | 1343 | await adrop_table_purge(connection, "TB10") |
1364 | 1344 |
|
1365 | | - # 7. Add 2 same record concurrently |
1366 | | - # Expectation:Successful, For one of the insert,get primary key violation error |
1367 | | - async def add1(val: str) -> None: |
1368 | | - model = HuggingFaceEmbeddings( |
1369 | | - model_name="sentence-transformers/all-mpnet-base-v2" |
1370 | | - ) |
1371 | | - vs_obj = await OracleVS.acreate( |
1372 | | - connection, model, "TB11", DistanceStrategy.EUCLIDEAN_DISTANCE |
1373 | | - ) |
1374 | | - texts = [val] |
1375 | | - ids10 = texts |
1376 | | - await vs_obj.aadd_texts(texts, ids=ids10) |
1377 | | - |
1378 | | - with pytest.raises(RuntimeError): |
1379 | | - task_1 = asyncio.create_task(add1("Sri Ram")) |
1380 | | - task_2 = asyncio.create_task(add1("Sri Ram")) |
1381 | | - await asyncio.gather(task_1, task_2) |
1382 | | - |
1383 | | - await adrop_table_purge(connection, "TB11") |
1384 | | - |
1385 | 1345 | # 8. create object with table name of type <schema_name.table_name> |
1386 | 1346 | # Expectation:U1 does not exist |
1387 | 1347 | with pytest.raises(RuntimeError): |
@@ -1695,7 +1655,9 @@ def test_perform_search_test() -> None: |
1695 | 1655 | vs.similarity_search(query, 2, filter=db_filter) |
1696 | 1656 |
|
1697 | 1657 | # Similarity search with relevance score |
1698 | | - vs.similarity_search_with_score(query, 2) |
| 1658 | + res = vs.similarity_search_with_score(query, 2) |
| 1659 | + assert all(isinstance(_r[1], float) for _r in res) |
| 1660 | + assert res[0][1] <= res[1][1] |
1699 | 1661 |
|
1700 | 1662 | # Similarity search with relevance score with filter |
1701 | 1663 | vs.similarity_search_with_score(query, 2, filter=db_filter) |
@@ -1787,7 +1749,9 @@ async def test_perform_search_test_async() -> None: |
1787 | 1749 | await vs.asimilarity_search(query, 2, filter=db_filter) |
1788 | 1750 |
|
1789 | 1751 | # Similarity search with relevance score |
1790 | | - await vs.asimilarity_search_with_score(query, 2) |
| 1752 | + res = await vs.asimilarity_search_with_score(query, 2) |
| 1753 | + assert all(isinstance(_r[1], float) for _r in res) |
| 1754 | + assert res[0][1] <= res[1][1] |
1791 | 1755 |
|
1792 | 1756 | # Similarity search with relevance score with filter |
1793 | 1757 | await vs.asimilarity_search_with_score(query, 2, filter=db_filter) |
@@ -2518,6 +2482,14 @@ def test_oracle_embeddings() -> None: |
2518 | 2482 | res = vs_obj.similarity_search("database", 1) |
2519 | 2483 |
|
2520 | 2484 | assert "Database" in res[0].page_content |
| 2485 | + assert "100" == res[0].id |
| 2486 | + |
| 2487 | + embedding = model.embed_query("Database Document") |
| 2488 | + res = vs_obj.similarity_search_by_vector_returning_embeddings(embedding, 1) # type: ignore |
| 2489 | + |
| 2490 | + # distance |
| 2491 | + assert all(np.isclose([res[0][1]], [0])) # type: ignore |
| 2492 | + assert all(np.isclose(res[0][2], embedding)) # type: ignore |
2521 | 2493 |
|
2522 | 2494 | drop_table_purge(connection, "TB1") |
2523 | 2495 |
|
@@ -2556,6 +2528,14 @@ async def test_oracle_embeddings_async(caplog: pytest.LogCaptureFixture) -> None |
2556 | 2528 | res = await vs_obj.asimilarity_search("database", 1) |
2557 | 2529 |
|
2558 | 2530 | assert "Database" in res[0].page_content |
| 2531 | + assert "100" == res[0].id |
| 2532 | + |
| 2533 | + embedding = model.embed_query("Database Document") |
| 2534 | + res = await vs_obj.asimilarity_search_by_vector_returning_embeddings(embedding, 1) # type: ignore |
| 2535 | + |
| 2536 | + # distance |
| 2537 | + assert all(np.isclose([res[0][1]], [0])) # type: ignore |
| 2538 | + assert all(np.isclose(res[0][2], embedding)) # type: ignore |
2559 | 2539 |
|
2560 | 2540 | await adrop_table_purge(connection, "TB1") |
2561 | 2541 |
|
@@ -2987,3 +2967,75 @@ def model1(_) -> list[float]: # type: ignore[no-untyped-def] |
2987 | 2967 | result = await vs.asimilarity_search("Hello", k=3, filter=_f) |
2988 | 2968 |
|
2989 | 2969 | await adrop_table_purge(connection, "TB10") |
| 2970 | + |
| 2971 | + |
| 2972 | +################################## |
| 2973 | +####### test_reserved ###### |
| 2974 | +################################## |
| 2975 | + |
| 2976 | + |
| 2977 | +def test_reserved() -> None: |
| 2978 | + try: |
| 2979 | + connection = oracledb.connect(user=username, password=password, dsn=dsn) |
| 2980 | + except Exception: |
| 2981 | + sys.exit(1) |
| 2982 | + |
| 2983 | + drop_table_purge(connection, "TB1") |
| 2984 | + |
| 2985 | + embedder_params = {"provider": "database", "model": "allminilm"} |
| 2986 | + proxy = "" |
| 2987 | + |
| 2988 | + # instance |
| 2989 | + model = OracleEmbeddings(conn=connection, params=embedder_params, proxy=proxy) |
| 2990 | + |
| 2991 | + vs_obj = OracleVS(connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE) |
| 2992 | + |
| 2993 | + texts = ["Database Document", "Code Document"] |
| 2994 | + metadata = [ |
| 2995 | + {"id": "100", "link": "Document Example Test 1", INTERNAL_ID_KEY: "my_temp_id"}, |
| 2996 | + {"id": "101", "link": "Document Example Test 2"}, |
| 2997 | + ] |
| 2998 | + |
| 2999 | + with pytest.raises(ValueError, match="reserved"): |
| 3000 | + vs_obj.add_texts(texts, metadata, ids=["1", "2"]) |
| 3001 | + |
| 3002 | + drop_table_purge(connection, "TB1") |
| 3003 | + |
| 3004 | + connection.close() |
| 3005 | + |
| 3006 | + |
| 3007 | +@pytest.mark.asyncio |
| 3008 | +async def test_reserved_async() -> None: |
| 3009 | + try: |
| 3010 | + connection = await oracledb.connect_async( |
| 3011 | + user=username, password=password, dsn=dsn |
| 3012 | + ) |
| 3013 | + |
| 3014 | + connection_sync = oracledb.connect(user=username, password=password, dsn=dsn) |
| 3015 | + except Exception: |
| 3016 | + sys.exit(1) |
| 3017 | + |
| 3018 | + await adrop_table_purge(connection, "TB1") |
| 3019 | + |
| 3020 | + embedder_params = {"provider": "database", "model": "allminilm"} |
| 3021 | + proxy = "" |
| 3022 | + |
| 3023 | + # instance |
| 3024 | + model = OracleEmbeddings(conn=connection_sync, params=embedder_params, proxy=proxy) |
| 3025 | + |
| 3026 | + vs_obj = await OracleVS.acreate( |
| 3027 | + connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE |
| 3028 | + ) |
| 3029 | + |
| 3030 | + texts = ["Database Document", "Code Document"] |
| 3031 | + metadata = [ |
| 3032 | + {"id": "100", "link": "Document Example Test 1", INTERNAL_ID_KEY: "my_temp_id"}, |
| 3033 | + {"id": "101", "link": "Document Example Test 2"}, |
| 3034 | + ] |
| 3035 | + |
| 3036 | + with pytest.raises(ValueError, match="reserved"): |
| 3037 | + await vs_obj.aadd_texts(texts, metadata, ids=["1", "2"]) |
| 3038 | + |
| 3039 | + await adrop_table_purge(connection, "TB1") |
| 3040 | + |
| 3041 | + connection.close() |
0 commit comments