Skip to content

Commit 0fca982

Browse files
authored
Pass Standard Tests (#35)
* Make changes to pass standard tests * Add standard tests * Version change * Use values clause in MERGE and add more comments * Add new versions * Fix error handling * Add error messages for input length mismatch * Increase minor version * Update versions
1 parent 6a2315c commit 0fca982

File tree

9 files changed

+4404
-2053
lines changed

9 files changed

+4404
-2053
lines changed

libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py

Lines changed: 397 additions & 144 deletions
Large diffs are not rendered by default.

libs/oracledb/poetry.lock

Lines changed: 3712 additions & 1861 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/oracledb/pyproject.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "langchain-oracledb"
3-
version = "1.0.2"
3+
version = "1.1.0"
44
description = "An integration package connecting Oracle Database and LangChain"
55
authors = []
66
readme = "README.md"
@@ -12,7 +12,10 @@ license = "UPL"
1212

1313
[tool.poetry.dependencies]
1414
python = ">=3.9,<4.0"
15-
langchain-core = ">=0.3.15,<0.4"
15+
langchain-core = [
16+
{ version = "^0.3.15", python = "<3.10" },
17+
{ version = "^1.0.0", python = ">=3.10" }
18+
]
1619
langchain-community = ">=0.3.0"
1720
oracledb = ">=2.2.0"
1821
pydantic = ">=2,<3"
@@ -28,6 +31,10 @@ syrupy = "^4.0.2"
2831
pytest-asyncio = "^0.23.2"
2932
pytest-watcher = "^0.3.4"
3033
sentence-transformers = "^5.0.0"
34+
langchain-tests = [
35+
{ version = "^0.3.21", python = "<3.10" },
36+
{ version = "^1.0.0", python = ">=3.10" }
37+
]
3138

3239
[tool.poetry.group.codespell]
3340
optional = true
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Type
2+
3+
import oracledb
4+
import pytest
5+
from langchain_tests.integration_tests import EmbeddingsIntegrationTests
6+
7+
from langchain_oracledb import OracleEmbeddings
8+
9+
username = ""
10+
password = ""
11+
dsn = ""
12+
13+
try:
14+
oracledb.connect(user=username, password=password, dsn=dsn)
15+
except Exception as e:
16+
pytest.skip(
17+
allow_module_level=True,
18+
reason=f"Database connection failed: {e}, skipping tests.",
19+
)
20+
21+
22+
class TestOracleEmbeddingsModelIntegration(EmbeddingsIntegrationTests):
23+
@property
24+
def embeddings_class(self) -> Type[OracleEmbeddings]:
25+
# Return the embeddings model class to test here
26+
return OracleEmbeddings
27+
28+
@property
29+
def embedding_model_params(self) -> dict:
30+
# Return initialization parameters for the model.
31+
conn = oracledb.connect(user=username, password=password, dsn=dsn)
32+
return {"conn": conn, "params": {"provider": "database", "model": "allminilm"}}

libs/oracledb/tests/integration_tests/vectorstores/test_oraclevs.py

Lines changed: 98 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
import threading
1515
from typing import Union
1616

17+
import numpy as np
1718
import oracledb
1819
import pytest
1920
from langchain_community.embeddings import HuggingFaceEmbeddings
2021
from langchain_community.vectorstores.utils import DistanceStrategy
2122

2223
from langchain_oracledb.embeddings import OracleEmbeddings
2324
from langchain_oracledb.vectorstores.oraclevs import (
25+
INTERNAL_ID_KEY,
2426
OracleVS,
2527
_acreate_table,
2628
_aindex_exists,
@@ -1165,6 +1167,7 @@ def test_add_texts_test() -> None:
11651167
vs_obj = OracleVS(connection, model, "TB7", DistanceStrategy.EUCLIDEAN_DISTANCE)
11661168
ids6 = ['"Good afternoon"', '"India"']
11671169
vs_obj.add_texts(texts2, ids=ids6)
1170+
assert len(vs_obj.add_texts(texts2, ids=ids6)) == 0
11681171
drop_table_purge(connection, "TB7")
11691172

11701173
# 4. Add records with ids and metadatas
@@ -1208,30 +1211,6 @@ def add(val: str) -> None:
12081211
thread_2.join()
12091212
drop_table_purge(connection, "TB10")
12101213

1211-
# 7. Add 2 same record concurrently
1212-
# Expectation:Successful, For one of the insert,get primary key violation error
1213-
def add1(val: str) -> None:
1214-
model = HuggingFaceEmbeddings(
1215-
model_name="sentence-transformers/all-mpnet-base-v2"
1216-
)
1217-
vs_obj = OracleVS(
1218-
connection, model, "TB11", DistanceStrategy.EUCLIDEAN_DISTANCE
1219-
)
1220-
texts = [val]
1221-
ids10 = texts
1222-
vs_obj.add_texts(texts, ids=ids10)
1223-
1224-
try:
1225-
thread_1 = threading.Thread(target=add1, args=("Sri Ram"))
1226-
thread_2 = threading.Thread(target=add1, args=("Sri Ram"))
1227-
thread_1.start()
1228-
thread_2.start()
1229-
thread_1.join()
1230-
thread_2.join()
1231-
except Exception:
1232-
pass
1233-
drop_table_purge(connection, "TB11")
1234-
12351214
# 8. create object with table name of type <schema_name.table_name>
12361215
# Expectation:U1 does not exist
12371216
with pytest.raises(RuntimeError):
@@ -1317,6 +1296,7 @@ async def test_add_texts_test_async() -> None:
13171296
)
13181297
ids6 = ['"Good afternoon"', '"India"']
13191298
await vs_obj.aadd_texts(texts2, ids=ids6)
1299+
assert len(await vs_obj.aadd_texts(texts2, ids=ids6)) == 0
13201300
await adrop_table_purge(connection, "TB7")
13211301

13221302
# 4. Add records with ids and metadatas
@@ -1362,26 +1342,6 @@ async def add(val: str) -> None:
13621342
await asyncio.gather(task_1, task_2)
13631343
await adrop_table_purge(connection, "TB10")
13641344

1365-
# 7. Add 2 same record concurrently
1366-
# Expectation:Successful, For one of the insert,get primary key violation error
1367-
async def add1(val: str) -> None:
1368-
model = HuggingFaceEmbeddings(
1369-
model_name="sentence-transformers/all-mpnet-base-v2"
1370-
)
1371-
vs_obj = await OracleVS.acreate(
1372-
connection, model, "TB11", DistanceStrategy.EUCLIDEAN_DISTANCE
1373-
)
1374-
texts = [val]
1375-
ids10 = texts
1376-
await vs_obj.aadd_texts(texts, ids=ids10)
1377-
1378-
with pytest.raises(RuntimeError):
1379-
task_1 = asyncio.create_task(add1("Sri Ram"))
1380-
task_2 = asyncio.create_task(add1("Sri Ram"))
1381-
await asyncio.gather(task_1, task_2)
1382-
1383-
await adrop_table_purge(connection, "TB11")
1384-
13851345
# 8. create object with table name of type <schema_name.table_name>
13861346
# Expectation:U1 does not exist
13871347
with pytest.raises(RuntimeError):
@@ -1695,7 +1655,9 @@ def test_perform_search_test() -> None:
16951655
vs.similarity_search(query, 2, filter=db_filter)
16961656

16971657
# Similarity search with relevance score
1698-
vs.similarity_search_with_score(query, 2)
1658+
res = vs.similarity_search_with_score(query, 2)
1659+
assert all(isinstance(_r[1], float) for _r in res)
1660+
assert res[0][1] <= res[1][1]
16991661

17001662
# Similarity search with relevance score with filter
17011663
vs.similarity_search_with_score(query, 2, filter=db_filter)
@@ -1787,7 +1749,9 @@ async def test_perform_search_test_async() -> None:
17871749
await vs.asimilarity_search(query, 2, filter=db_filter)
17881750

17891751
# Similarity search with relevance score
1790-
await vs.asimilarity_search_with_score(query, 2)
1752+
res = await vs.asimilarity_search_with_score(query, 2)
1753+
assert all(isinstance(_r[1], float) for _r in res)
1754+
assert res[0][1] <= res[1][1]
17911755

17921756
# Similarity search with relevance score with filter
17931757
await vs.asimilarity_search_with_score(query, 2, filter=db_filter)
@@ -2518,6 +2482,14 @@ def test_oracle_embeddings() -> None:
25182482
res = vs_obj.similarity_search("database", 1)
25192483

25202484
assert "Database" in res[0].page_content
2485+
assert "100" == res[0].id
2486+
2487+
embedding = model.embed_query("Database Document")
2488+
res = vs_obj.similarity_search_by_vector_returning_embeddings(embedding, 1) # type: ignore
2489+
2490+
# distance
2491+
assert all(np.isclose([res[0][1]], [0])) # type: ignore
2492+
assert all(np.isclose(res[0][2], embedding)) # type: ignore
25212493

25222494
drop_table_purge(connection, "TB1")
25232495

@@ -2556,6 +2528,14 @@ async def test_oracle_embeddings_async(caplog: pytest.LogCaptureFixture) -> None
25562528
res = await vs_obj.asimilarity_search("database", 1)
25572529

25582530
assert "Database" in res[0].page_content
2531+
assert "100" == res[0].id
2532+
2533+
embedding = model.embed_query("Database Document")
2534+
res = await vs_obj.asimilarity_search_by_vector_returning_embeddings(embedding, 1) # type: ignore
2535+
2536+
# distance
2537+
assert all(np.isclose([res[0][1]], [0])) # type: ignore
2538+
assert all(np.isclose(res[0][2], embedding)) # type: ignore
25592539

25602540
await adrop_table_purge(connection, "TB1")
25612541

@@ -2987,3 +2967,75 @@ def model1(_) -> list[float]: # type: ignore[no-untyped-def]
29872967
result = await vs.asimilarity_search("Hello", k=3, filter=_f)
29882968

29892969
await adrop_table_purge(connection, "TB10")
2970+
2971+
2972+
##################################
2973+
####### test_reserved ######
2974+
##################################
2975+
2976+
2977+
def test_reserved() -> None:
2978+
try:
2979+
connection = oracledb.connect(user=username, password=password, dsn=dsn)
2980+
except Exception:
2981+
sys.exit(1)
2982+
2983+
drop_table_purge(connection, "TB1")
2984+
2985+
embedder_params = {"provider": "database", "model": "allminilm"}
2986+
proxy = ""
2987+
2988+
# instance
2989+
model = OracleEmbeddings(conn=connection, params=embedder_params, proxy=proxy)
2990+
2991+
vs_obj = OracleVS(connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE)
2992+
2993+
texts = ["Database Document", "Code Document"]
2994+
metadata = [
2995+
{"id": "100", "link": "Document Example Test 1", INTERNAL_ID_KEY: "my_temp_id"},
2996+
{"id": "101", "link": "Document Example Test 2"},
2997+
]
2998+
2999+
with pytest.raises(ValueError, match="reserved"):
3000+
vs_obj.add_texts(texts, metadata, ids=["1", "2"])
3001+
3002+
drop_table_purge(connection, "TB1")
3003+
3004+
connection.close()
3005+
3006+
3007+
@pytest.mark.asyncio
3008+
async def test_reserved_async() -> None:
3009+
try:
3010+
connection = await oracledb.connect_async(
3011+
user=username, password=password, dsn=dsn
3012+
)
3013+
3014+
connection_sync = oracledb.connect(user=username, password=password, dsn=dsn)
3015+
except Exception:
3016+
sys.exit(1)
3017+
3018+
await adrop_table_purge(connection, "TB1")
3019+
3020+
embedder_params = {"provider": "database", "model": "allminilm"}
3021+
proxy = ""
3022+
3023+
# instance
3024+
model = OracleEmbeddings(conn=connection_sync, params=embedder_params, proxy=proxy)
3025+
3026+
vs_obj = await OracleVS.acreate(
3027+
connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE
3028+
)
3029+
3030+
texts = ["Database Document", "Code Document"]
3031+
metadata = [
3032+
{"id": "100", "link": "Document Example Test 1", INTERNAL_ID_KEY: "my_temp_id"},
3033+
{"id": "101", "link": "Document Example Test 2"},
3034+
]
3035+
3036+
with pytest.raises(ValueError, match="reserved"):
3037+
await vs_obj.aadd_texts(texts, metadata, ids=["1", "2"])
3038+
3039+
await adrop_table_purge(connection, "TB1")
3040+
3041+
connection.close()

0 commit comments

Comments
 (0)