From 5bc9033114a0ab011de3fcc5225cd2397450e58a Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 10 Nov 2025 08:27:44 +0000 Subject: [PATCH] Fix for max_chunk_size; restructure databases (server); update mxbai to 512 --- src/client/content/config/tabs/models.py | 13 +- src/server/api/core/databases.py | 58 -- src/server/api/utils/databases.py | 54 +- src/server/api/utils/models.py | 20 +- src/server/api/v1/databases.py | 3 +- src/server/bootstrap/models.py | 2 +- .../integration/test_endpoints_models.py | 46 ++ .../unit/api/core/test_core_databases.py | 410 -------------- .../unit/api/utils/test_utils_databases.py | 502 ++++++++++++++++-- .../unit/api/utils/test_utils_models.py | 63 +++ 10 files changed, 630 insertions(+), 541 deletions(-) delete mode 100644 src/server/api/core/databases.py delete mode 100644 tests/server/unit/api/core/test_core_databases.py diff --git a/src/client/content/config/tabs/models.py b/src/client/content/config/tabs/models.py index 65bcbaab..ae71185e 100644 --- a/src/client/content/config/tabs/models.py +++ b/src/client/content/config/tabs/models.py @@ -200,16 +200,19 @@ def _render_model_specific_config(model: dict, model_type: str, provider_models: value=max_tokens, ) else: - output_vector_size = next( - (m.get("output_vector_size", 8191) for m in provider_models if m.get("key") == model["id"]), - model.get("output_vector_size", 8191), - ) + # First try to get max_chunk_size from the model, then fall back to output_vector_size from provider + max_chunk_size = model.get("max_chunk_size") + if max_chunk_size is None: + max_chunk_size = next( + (m.get("max_chunk_size", 8192) for m in provider_models if m.get("key") == model["id"]), + 8192, + ) model["max_chunk_size"] = st.number_input( "Max Chunk Size:", help=help_text.help_dict["chunk_size"], min_value=0, key="add_model_max_chunk_size", - value=output_vector_size, + value=max_chunk_size, ) return model diff --git a/src/server/api/core/databases.py b/src/server/api/core/databases.py deleted file mode 100644 index e74ec78b..00000000 --- a/src/server/api/core/databases.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" - -from typing import Optional, Union -from server.bootstrap import bootstrap - -from common.schema import Database, DatabaseNameType -from common import logging_config - -logger = logging_config.logging.getLogger("api.core.database") - - -##################################################### -# Functions -##################################################### -def get_database(name: Optional[DatabaseNameType] = None) -> Union[list[Database], None]: - """ - Return all Database objects if `name` is not provided, - or the single Database if `name` is provided. - If a `name` is provided and not found, raise exception - """ - database_objects = bootstrap.DATABASE_OBJECTS - - logger.debug("%i databases are defined", len(database_objects)) - database_filtered = [db for db in database_objects if (name is None or db.name == name)] - logger.debug("%i databases after filtering", len(database_filtered)) - - if name and not database_filtered: - raise ValueError(f"{name} not found") - - return database_filtered - - -def create_database(database: Database) -> Database: - """Create a new Database definition""" - database_objects = bootstrap.DATABASE_OBJECTS - - try: - existing = get_database(name=database.name) - if existing: - raise ValueError(f"Database {database.name} already exists") - except ValueError as ex: - if "not found" not in str(ex): - raise - - if any(not getattr(database, key) for key in ("user", "password", "dsn")): - raise ValueError("'user', 'password', and 'dsn' are required") - - database_objects.append(database) - return get_database(name=database.name) - - -def delete_database(name: DatabaseNameType) -> None: - """Remove database from database objects""" - database_objects = bootstrap.DATABASE_OBJECTS - bootstrap.DATABASE_OBJECTS = [d for d in database_objects if d.name != name] diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index 3b85da05..29304f22 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -9,8 +9,8 @@ import oracledb from langchain_community.vectorstores import oraclevs as LangchainVS -import server.api.core.databases as core_databases import server.api.core.settings as core_settings +from server.bootstrap.bootstrap import DATABASE_OBJECTS from common.schema import ( Database, @@ -38,6 +38,56 @@ def __init__(self, status_code: int, detail: str): super().__init__(detail) +class ExistsDatabaseError(ValueError): + """Raised when the database already exist.""" + + +class UnknownDatabaseError(ValueError): + """Raised when the database doesn't exist.""" + + +##################################################### +# CRUD Functions +##################################################### +def create(database: Database) -> Database: + """Create a new Database definition""" + + try: + _ = get(name=database.name) + raise ExistsDatabaseError(f"Database: {database.name} already exists") + except UnknownDatabaseError: + pass + + if any(not getattr(database, key) for key in ("user", "password", "dsn")): + raise ValueError("'user', 'password', and 'dsn' are required") + + DATABASE_OBJECTS.append(database) + return get(name=database.name) + + +def get(name: Optional[DatabaseNameType] = None) -> Union[list[Database], None]: + """ + Return all Database objects if `name` is not provided, + or the single Database if `name` is provided. + If a `name` is provided and not found, raise exception + """ + database_objects = DATABASE_OBJECTS + + logger.debug("%i databases are defined", len(database_objects)) + database_filtered = [db for db in database_objects if (name is None or db.name == name)] + logger.debug("%i databases after filtering", len(database_filtered)) + + if name and not database_filtered: + raise UnknownDatabaseError(f"{name} not found") + + return database_filtered + + +def delete(name: DatabaseNameType) -> None: + """Remove database from database objects""" + DATABASE_OBJECTS[:] = [d for d in DATABASE_OBJECTS if d.name != name] + + ##################################################### # Protected Functions ##################################################### @@ -231,7 +281,7 @@ def get_databases( db_name: Optional[DatabaseNameType] = None, validate: bool = False ) -> Union[list[Database], Database, None]: """Return list of Database Objects""" - databases = core_databases.get_database(db_name) + databases = get(db_name) if validate: for db in databases: try: diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index ab9f70d7..c01f02b7 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -89,18 +89,18 @@ def get( def update(payload: schema.Model) -> schema.Model: """Update an existing Model definition""" - (model_update,) = get(model_provider=payload.provider, model_id=payload.id) - if payload.enabled and model_update.api_base and not is_url_accessible(model_update.api_base)[0]: - model_update.enabled = False - raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.") + # Get the existing model from MODEL_OBJECTS (this is a reference to the object in the list) + (model_existing,) = get(model_provider=payload.provider, model_id=payload.id) - for key, value in payload: - if hasattr(model_update, key): - setattr(model_update, key, value) - else: - raise InvalidModelError(f"Model: Invalid setting - {key}.") + # Check URL accessibility if enabling the model + if payload.enabled and payload.api_base and not is_url_accessible(payload.api_base)[0]: + model_existing.enabled = False + raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.") - return model_update + # Update all fields from payload in place + for key, value in payload.model_dump().items(): + setattr(model_existing, key, value) + return model_existing def delete(model_provider: schema.ModelProviderType, model_id: schema.ModelIdType) -> None: diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index 1032187b..d17ffcf4 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -4,6 +4,7 @@ """ from fastapi import APIRouter, HTTPException +import oracledb import server.api.utils.databases as utils_databases @@ -15,7 +16,7 @@ # Validate the DEFAULT Databases try: _ = utils_databases.get_databases(db_name="DEFAULT", validate=True) -except Exception: +except (ValueError, PermissionError, ConnectionError, LookupError, oracledb.DatabaseError): pass auth = APIRouter() diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index 9551638e..c73268c1 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -144,7 +144,7 @@ def _get_base_models_list() -> list[dict]: "provider": "ollama", "api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), "api_key": "", - "max_chunk_size": 8192, + "max_chunk_size": 512, }, ] diff --git a/tests/server/integration/test_endpoints_models.py b/tests/server/integration/test_endpoints_models.py index 472b2402..48c6e737 100644 --- a/tests/server/integration/test_endpoints_models.py +++ b/tests/server/integration/test_endpoints_models.py @@ -321,6 +321,52 @@ def test_models_update_edge_cases(self, client, auth_headers): ) assert response.status_code == 404 + def test_models_update_max_chunk_size(self, client, auth_headers): + """Test updating max_chunk_size for embedding models (regression test)""" + # Create an embedding model with default max_chunk_size + payload = { + "id": "test-embed-chunk-size", + "enabled": False, + "type": "embed", + "provider": "test_provider", + "api_base": "http://127.0.0.1:11434", + "max_chunk_size": 8192, + } + + # Create the model + response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) + assert response.status_code == 201 + assert response.json()["max_chunk_size"] == 8192 + + # Update the max_chunk_size to 512 + payload["max_chunk_size"] = 512 + response = client.patch( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload + ) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 512 + + # Verify the update persists by fetching the model again + response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 512 + + # Update to a different value to ensure it's not cached + payload["max_chunk_size"] = 1024 + response = client.patch( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload + ) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 1024 + + # Verify again + response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 1024 + + # Clean up + client.delete(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) + def test_models_response_schema_validation(self, client, auth_headers): """Test response schema validation for all endpoints""" # Test /v1/models response schema diff --git a/tests/server/unit/api/core/test_core_databases.py b/tests/server/unit/api/core/test_core_databases.py deleted file mode 100644 index 9aacd384..00000000 --- a/tests/server/unit/api/core/test_core_databases.py +++ /dev/null @@ -1,410 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=attribute-defined-outside-init - -from unittest.mock import patch, MagicMock -import pytest - -from server.api.core import databases -from server.bootstrap import bootstrap -from common.schema import Database - - -class TestDatabases: - """Test databases module functionality""" - - def setup_method(self): - """Setup test data before each test""" - self.sample_database = Database(name="test_db", user="test_user", password="test_password", dsn="test_dsn") - self.sample_database_2 = Database( - name="test_db_2", user="test_user_2", password="test_password_2", dsn="test_dsn_2" - ) - - @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") - def test_get_database_all(self, mock_database_objects): - """Test getting all databases when no name is provided""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) - mock_database_objects.__len__ = MagicMock(return_value=2) - - result = databases.get_database() - - assert result == [self.sample_database, self.sample_database_2] - assert len(result) == 2 - - @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") - def test_get_database_by_name_found(self, mock_database_objects): - """Test getting database by name when it exists""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) - mock_database_objects.__len__ = MagicMock(return_value=2) - - result = databases.get_database(name="test_db") - - assert result == [self.sample_database] - assert len(result) == 1 - - @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") - def test_get_database_by_name_not_found(self, mock_database_objects): - """Test getting database by name when it doesn't exist""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database])) - mock_database_objects.__len__ = MagicMock(return_value=1) - - with pytest.raises(ValueError, match="nonexistent not found"): - databases.get_database(name="nonexistent") - - @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") - def test_get_database_empty_list(self, mock_database_objects): - """Test getting databases when list is empty""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([])) - mock_database_objects.__len__ = MagicMock(return_value=0) - - result = databases.get_database() - - assert result == [] - - @patch("server.api.core.databases.bootstrap.DATABASE_OBJECTS") - def test_get_database_empty_list_with_name(self, mock_database_objects): - """Test getting database by name when list is empty""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([])) - mock_database_objects.__len__ = MagicMock(return_value=0) - - with pytest.raises(ValueError, match="test_db not found"): - databases.get_database(name="test_db") - - def test_create_database_success(self, db_container): - """Test successful database creation when database doesn't exist""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - # Clear the list to start fresh - bootstrap.DATABASE_OBJECTS.clear() - - # Create a new database - new_database = Database(name="new_test_db", user="test_user", password="test_password", dsn="test_dsn") - - result = databases.create_database(new_database) - - # Verify database was added - assert len(bootstrap.DATABASE_OBJECTS) == 1 - assert bootstrap.DATABASE_OBJECTS[0].name == "new_test_db" - assert result == [new_database] - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_create_database_already_exists(self, db_container): - """Test database creation when database already exists""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - # Add a database to the list - bootstrap.DATABASE_OBJECTS.clear() - existing_db = Database(name="existing_db", user="test_user", password="test_password", dsn="test_dsn") - bootstrap.DATABASE_OBJECTS.append(existing_db) - - # Try to create a database with the same name - duplicate_db = Database(name="existing_db", user="other_user", password="other_password", dsn="other_dsn") - - # Should raise an error for duplicate database - with pytest.raises(ValueError, match="Database existing_db already exists"): - databases.create_database(duplicate_db) - - # Verify only original database exists - assert len(bootstrap.DATABASE_OBJECTS) == 1 - assert bootstrap.DATABASE_OBJECTS[0] == existing_db - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_create_database_missing_user(self, db_container): - """Test database creation with missing user field""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - bootstrap.DATABASE_OBJECTS.clear() - - # Create database with missing user - incomplete_db = Database(name="incomplete_db", password="test_password", dsn="test_dsn") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create_database(incomplete_db) - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_create_database_missing_password(self, db_container): - """Test database creation with missing password field""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - bootstrap.DATABASE_OBJECTS.clear() - - # Create database with missing password - incomplete_db = Database(name="incomplete_db", user="test_user", dsn="test_dsn") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create_database(incomplete_db) - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_create_database_missing_dsn(self, db_container): - """Test database creation with missing dsn field""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - bootstrap.DATABASE_OBJECTS.clear() - - # Create database with missing dsn - incomplete_db = Database(name="incomplete_db", user="test_user", password="test_password") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create_database(incomplete_db) - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_create_database_multiple_missing_fields(self, db_container): - """Test database creation with multiple missing required fields""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - bootstrap.DATABASE_OBJECTS.clear() - - # Create database with multiple missing fields - incomplete_db = Database(name="incomplete_db") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create_database(incomplete_db) - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_delete_database(self, db_container): - """Test database deletion""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - # Setup test data - db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") - db2 = Database(name="test_db_2", user="user2", password="pass2", dsn="dsn2") - db3 = Database(name="test_db_3", user="user3", password="pass3", dsn="dsn3") - - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend([db1, db2, db3]) - - # Delete middle database - databases.delete_database("test_db_2") - - # Verify deletion - assert len(bootstrap.DATABASE_OBJECTS) == 2 - names = [db.name for db in bootstrap.DATABASE_OBJECTS] - assert "test_db_1" in names - assert "test_db_2" not in names - assert "test_db_3" in names - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_delete_database_nonexistent(self, db_container): - """Test deleting non-existent database""" - assert db_container is not None - - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - # Setup test data - db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.append(db1) - - original_length = len(bootstrap.DATABASE_OBJECTS) - - # Try to delete non-existent database (should not raise error) - databases.delete_database("nonexistent") - - # Verify no change - assert len(bootstrap.DATABASE_OBJECTS) == original_length - assert bootstrap.DATABASE_OBJECTS[0].name == "test_db_1" - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_delete_database_empty_list(self, db_container): - """Test deleting from empty database list""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - bootstrap.DATABASE_OBJECTS.clear() - - # Try to delete from empty list (should not raise error) - databases.delete_database("any_name") - - # Verify still empty - assert len(bootstrap.DATABASE_OBJECTS) == 0 - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_delete_database_multiple_same_name(self, db_container): - """Test deleting when multiple databases have the same name""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - # Setup test data with duplicate names - db1 = Database(name="duplicate", user="user1", password="pass1", dsn="dsn1") - db2 = Database(name="duplicate", user="user2", password="pass2", dsn="dsn2") - db3 = Database(name="other", user="user3", password="pass3", dsn="dsn3") - - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend([db1, db2, db3]) - - # Delete databases with duplicate name - databases.delete_database("duplicate") - - # Verify all duplicates are removed - assert len(bootstrap.DATABASE_OBJECTS) == 1 - assert bootstrap.DATABASE_OBJECTS[0].name == "other" - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(databases, "logger") - assert databases.logger.name == "api.core.database" - - def test_get_database_filters_correctly(self, db_container): - """Test that get_database correctly filters by name""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - # Setup test data - db1 = Database(name="alpha", user="user1", password="pass1", dsn="dsn1") - db2 = Database(name="beta", user="user2", password="pass2", dsn="dsn2") - db3 = Database(name="alpha", user="user3", password="pass3", dsn="dsn3") # Duplicate name - - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend([db1, db2, db3]) - - # Test getting all - all_dbs = databases.get_database() - assert len(all_dbs) == 3 - - # Test getting by specific name - alpha_dbs = databases.get_database(name="alpha") - assert len(alpha_dbs) == 2 - assert all(db.name == "alpha" for db in alpha_dbs) - - beta_dbs = databases.get_database(name="beta") - assert len(beta_dbs) == 1 - assert beta_dbs[0].name == "beta" - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) - - def test_database_model_validation(self, db_container): - """Test Database model validation and optional fields""" - assert db_container is not None - # Test with all required fields - complete_db = Database(name="complete", user="test_user", password="test_password", dsn="test_dsn") - assert complete_db.name == "complete" - assert complete_db.user == "test_user" - assert complete_db.password == "test_password" - assert complete_db.dsn == "test_dsn" - assert complete_db.connected is False # Default value - assert complete_db.tcp_connect_timeout == 5 # Default value - assert complete_db.vector_stores == [] # Default value - - # Test with optional fields - complete_db_with_options = Database( - name="complete_with_options", - user="test_user", - password="test_password", - dsn="test_dsn", - wallet_location="/path/to/wallet", - wallet_password="wallet_pass", - tcp_connect_timeout=10, - ) - assert complete_db_with_options.wallet_location == "/path/to/wallet" - assert complete_db_with_options.wallet_password == "wallet_pass" - assert complete_db_with_options.tcp_connect_timeout == 10 - - def test_create_database_real_scenario(self, db_container): - """Test create_database with realistic data using container DB""" - assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() - - try: - bootstrap.DATABASE_OBJECTS.clear() - - # Create database with realistic configuration - test_db = Database( - name="container_test", - user="PYTEST", - password="OrA_41_3xPl0d3r", - dsn="//localhost:1525/FREEPDB1", - tcp_connect_timeout=10, - ) - - result = databases.create_database(test_db) - - # Verify creation - assert len(bootstrap.DATABASE_OBJECTS) == 1 - created_db = bootstrap.DATABASE_OBJECTS[0] - assert created_db.name == "container_test" - assert created_db.user == "PYTEST" - assert created_db.dsn == "//localhost:1525/FREEPDB1" - assert created_db.tcp_connect_timeout == 10 - assert result == [test_db] - - finally: - # Restore original state - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) diff --git a/tests/server/unit/api/utils/test_utils_databases.py b/tests/server/unit/api/utils/test_utils_databases.py index c6db7c1d..786d2ef0 100644 --- a/tests/server/unit/api/utils/test_utils_databases.py +++ b/tests/server/unit/api/utils/test_utils_databases.py @@ -14,10 +14,404 @@ from server.api.utils import databases from server.api.utils.databases import DbException -from server.bootstrap import bootstrap from common.schema import Database +class TestDatabases: + """Test databases module functionality""" + def setup_method(self): + """Setup test data before each test""" + self.sample_database = Database(name="test_db", user="test_user", password="test_password", dsn="test_dsn") + self.sample_database_2 = Database( + name="test_db_2", user="test_user_2", password="test_password_2", dsn="test_dsn_2" + ) + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_all(self, mock_database_objects): + """Test getting all databases when no name is provided""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) + mock_database_objects.__len__ = MagicMock(return_value=2) + + result = databases.get() + + assert result == [self.sample_database, self.sample_database_2] + assert len(result) == 2 + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_by_name_found(self, mock_database_objects): + """Test getting database by name when it exists""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database, self.sample_database_2])) + mock_database_objects.__len__ = MagicMock(return_value=2) + + result = databases.get(name="test_db") + + assert result == [self.sample_database] + assert len(result) == 1 + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_by_name_not_found(self, mock_database_objects): + """Test getting database by name when it doesn't exist""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([self.sample_database])) + mock_database_objects.__len__ = MagicMock(return_value=1) + + with pytest.raises(ValueError, match="nonexistent not found"): + databases.get(name="nonexistent") + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_empty_list(self, mock_database_objects): + """Test getting databases when list is empty""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([])) + mock_database_objects.__len__ = MagicMock(return_value=0) + + result = databases.get() + + assert result == [] + + @patch("server.api.utils.databases.DATABASE_OBJECTS") + def test_get_empty_list_with_name(self, mock_database_objects): + """Test getting database by name when list is empty""" + mock_database_objects.__iter__ = MagicMock(return_value=iter([])) + mock_database_objects.__len__ = MagicMock(return_value=0) + + with pytest.raises(ValueError, match="test_db not found"): + databases.get(name="test_db") + + def test_create_success(self, db_container): + """Test successful database creation when database doesn't exist""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + # Clear the list to start fresh + databases.DATABASE_OBJECTS.clear() + + # Create a new database + new_database = Database(name="new_test_db", user="test_user", password="test_password", dsn="test_dsn") + + result = databases.create(new_database) + + # Verify database was added + assert len(databases.DATABASE_OBJECTS) == 1 + assert databases.DATABASE_OBJECTS[0].name == "new_test_db" + assert result == [new_database] + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_already_exists(self, db_container): + """Test database creation when database already exists""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + # Add a database to the list + databases.DATABASE_OBJECTS.clear() + existing_db = Database(name="existing_db", user="test_user", password="test_password", dsn="test_dsn") + databases.DATABASE_OBJECTS.append(existing_db) + + # Try to create a database with the same name + duplicate_db = Database(name="existing_db", user="other_user", password="other_password", dsn="other_dsn") + + # Should raise an error for duplicate database + with pytest.raises(ValueError, match="Database: existing_db already exists"): + databases.create(duplicate_db) + + # Verify only original database exists + assert len(databases.DATABASE_OBJECTS) == 1 + assert databases.DATABASE_OBJECTS[0] == existing_db + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_missing_user(self, db_container): + """Test database creation with missing user field""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + databases.DATABASE_OBJECTS.clear() + + # Create database with missing user + incomplete_db = Database(name="incomplete_db", password="test_password", dsn="test_dsn") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create(incomplete_db) + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_missing_password(self, db_container): + """Test database creation with missing password field""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + databases.DATABASE_OBJECTS.clear() + + # Create database with missing password + incomplete_db = Database(name="incomplete_db", user="test_user", dsn="test_dsn") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create(incomplete_db) + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_missing_dsn(self, db_container): + """Test database creation with missing dsn field""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + databases.DATABASE_OBJECTS.clear() + + # Create database with missing dsn + incomplete_db = Database(name="incomplete_db", user="test_user", password="test_password") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create(incomplete_db) + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_create_multiple_missing_fields(self, db_container): + """Test database creation with multiple missing required fields""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + databases.DATABASE_OBJECTS.clear() + + # Create database with multiple missing fields + incomplete_db = Database(name="incomplete_db") + + with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): + databases.create(incomplete_db) + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete(self, db_container): + """Test database deletion""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + # Setup test data + db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") + db2 = Database(name="test_db_2", user="user2", password="pass2", dsn="dsn2") + db3 = Database(name="test_db_3", user="user3", password="pass3", dsn="dsn3") + + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend([db1, db2, db3]) + + # Delete middle database + databases.delete("test_db_2") + + # Verify deletion + assert len(databases.DATABASE_OBJECTS) == 2 + names = [db.name for db in databases.DATABASE_OBJECTS] + assert "test_db_1" in names + assert "test_db_2" not in names + assert "test_db_3" in names + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete_nonexistent(self, db_container): + """Test deleting non-existent database""" + assert db_container is not None + + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + # Setup test data + db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.append(db1) + + original_length = len(databases.DATABASE_OBJECTS) + + # Try to delete non-existent database (should not raise error) + databases.delete("nonexistent") + + # Verify no change + assert len(databases.DATABASE_OBJECTS) == original_length + assert databases.DATABASE_OBJECTS[0].name == "test_db_1" + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete_empty_list(self, db_container): + """Test deleting from empty database list""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + databases.DATABASE_OBJECTS.clear() + + # Try to delete from empty list (should not raise error) + databases.delete("any_name") + + # Verify still empty + assert len(databases.DATABASE_OBJECTS) == 0 + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_delete_multiple_same_name(self, db_container): + """Test deleting when multiple databases have the same name""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + # Setup test data with duplicate names + db1 = Database(name="duplicate", user="user1", password="pass1", dsn="dsn1") + db2 = Database(name="duplicate", user="user2", password="pass2", dsn="dsn2") + db3 = Database(name="other", user="user3", password="pass3", dsn="dsn3") + + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend([db1, db2, db3]) + + # Delete databases with duplicate name + databases.delete("duplicate") + + # Verify all duplicates are removed + assert len(databases.DATABASE_OBJECTS) == 1 + assert databases.DATABASE_OBJECTS[0].name == "other" + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_logger_exists(self): + """Test that logger is properly configured""" + assert hasattr(databases, "logger") + assert databases.logger.name == "api.utils.database" + + def test_get_filters_correctly(self, db_container): + """Test that get correctly filters by name""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + # Setup test data + db1 = Database(name="alpha", user="user1", password="pass1", dsn="dsn1") + db2 = Database(name="beta", user="user2", password="pass2", dsn="dsn2") + db3 = Database(name="alpha", user="user3", password="pass3", dsn="dsn3") # Duplicate name + + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend([db1, db2, db3]) + + # Test getting all + all_dbs = databases.get() + assert len(all_dbs) == 3 + + # Test getting by specific name + alpha_dbs = databases.get(name="alpha") + assert len(alpha_dbs) == 2 + assert all(db.name == "alpha" for db in alpha_dbs) + + beta_dbs = databases.get(name="beta") + assert len(beta_dbs) == 1 + assert beta_dbs[0].name == "beta" + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + + def test_database_model_validation(self, db_container): + """Test Database model validation and optional fields""" + assert db_container is not None + # Test with all required fields + complete_db = Database(name="complete", user="test_user", password="test_password", dsn="test_dsn") + assert complete_db.name == "complete" + assert complete_db.user == "test_user" + assert complete_db.password == "test_password" + assert complete_db.dsn == "test_dsn" + assert complete_db.connected is False # Default value + assert complete_db.tcp_connect_timeout == 5 # Default value + assert complete_db.vector_stores == [] # Default value + + # Test with optional fields + complete_db_with_options = Database( + name="complete_with_options", + user="test_user", + password="test_password", + dsn="test_dsn", + wallet_location="/path/to/wallet", + wallet_password="wallet_pass", + tcp_connect_timeout=10, + ) + assert complete_db_with_options.wallet_location == "/path/to/wallet" + assert complete_db_with_options.wallet_password == "wallet_pass" + assert complete_db_with_options.tcp_connect_timeout == 10 + + def test_create_real_scenario(self, db_container): + """Test create with realistic data using container DB""" + assert db_container is not None + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() + + try: + databases.DATABASE_OBJECTS.clear() + + # Create database with realistic configuration + test_db = Database( + name="container_test", + user="PYTEST", + password="OrA_41_3xPl0d3r", + dsn="//localhost:1525/FREEPDB1", + tcp_connect_timeout=10, + ) + + result = databases.create(test_db) + + # Verify creation + assert len(databases.DATABASE_OBJECTS) == 1 + created_db = databases.DATABASE_OBJECTS[0] + assert created_db.name == "container_test" + assert created_db.user == "PYTEST" + assert created_db.dsn == "//localhost:1525/FREEPDB1" + assert created_db.tcp_connect_timeout == 10 + assert result == [test_db] + + finally: + # Restore original state + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) + class TestDbException: """Test custom database exception class""" @@ -551,36 +945,36 @@ def test_drop_vs_calls_langchain(self, mock_drop_table): mock_drop_table.assert_called_once_with(mock_connection, vs_name) - def test_get_databases_without_validation(self, db_container): - """Test get_databases without validation""" + def test_get_without_validation(self, db_container): + """Test get without validation""" assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() try: - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.append(self.sample_database) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.append(self.sample_database) # Test getting all databases - result = databases.get_databases() + result = databases.get() assert isinstance(result, list) assert len(result) == 1 assert result[0].name == "test_db" assert result[0].connected is False # No validation, so not connected finally: - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) - def test_get_databases_with_validation(self, db_container): - """Test get_databases with validation using real database""" + def test_get_with_validation(self, db_container): + """Test get with validation using real database""" assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() try: - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.append(self.sample_database) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.append(self.sample_database) # Test getting all databases with validation result = databases.get_databases(validate=True) @@ -592,24 +986,24 @@ def test_get_databases_with_validation(self, db_container): finally: # Clean up connections - for db in bootstrap.DATABASE_OBJECTS: + for db in databases.DATABASE_OBJECTS: if db.connection: databases.disconnect(db.connection) - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) - def test_get_databases_by_name(self, db_container): - """Test get_databases by specific name""" + def test_get_by_name(self, db_container): + """Test get by specific name""" assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() try: - bootstrap.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.clear() db1 = Database(name="db1", user="user1", password="pass1", dsn="dsn1") db2 = Database(name="db2", user=TEST_CONFIG["db_username"], password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"]) - bootstrap.DATABASE_OBJECTS.extend([db1, db2]) + databases.DATABASE_OBJECTS.extend([db1, db2]) # Test getting specific database result = databases.get_databases(db_name="db2") @@ -617,20 +1011,20 @@ def test_get_databases_by_name(self, db_container): assert result.name == "db2" finally: - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) - def test_get_databases_validation_failure(self, db_container): - """Test get_databases with validation when connection fails""" + def test_get_validation_failure(self, db_container): + """Test get with validation when connection fails""" assert db_container is not None - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() try: - bootstrap.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.clear() # Add database with invalid credentials invalid_db = Database(name="invalid", user="invalid", password="invalid", dsn="invalid") - bootstrap.DATABASE_OBJECTS.append(invalid_db) + databases.DATABASE_OBJECTS.append(invalid_db) # Test validation with invalid database (should continue without error) result = databases.get_databases(validate=True) @@ -639,8 +1033,8 @@ def test_get_databases_validation_failure(self, db_container): assert result[0].connected is False # Should remain False due to connection failure finally: - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) @patch("server.api.core.settings.get_client_settings") def test_get_client_database_default(self, mock_get_settings, db_container): @@ -652,22 +1046,22 @@ def test_get_client_database_default(self, mock_get_settings, db_container): mock_settings.selectai = None mock_get_settings.return_value = mock_settings - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() try: - bootstrap.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.clear() default_db = Database(name="DEFAULT", user=TEST_CONFIG["db_username"], password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"]) - bootstrap.DATABASE_OBJECTS.append(default_db) + databases.DATABASE_OBJECTS.append(default_db) result = databases.get_client_database("test_client") assert isinstance(result, Database) assert result.name == "DEFAULT" finally: - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) @patch("server.api.core.settings.get_client_settings") def test_get_client_database_with_vector_search(self, mock_get_settings, db_container): @@ -681,22 +1075,22 @@ def test_get_client_database_with_vector_search(self, mock_get_settings, db_cont mock_settings.selectai = None mock_get_settings.return_value = mock_settings - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() try: - bootstrap.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.clear() vector_db = Database(name="VECTOR_DB", user=TEST_CONFIG["db_username"], password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"]) - bootstrap.DATABASE_OBJECTS.append(vector_db) + databases.DATABASE_OBJECTS.append(vector_db) result = databases.get_client_database("test_client") assert isinstance(result, Database) assert result.name == "VECTOR_DB" finally: - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) @patch("server.api.core.settings.get_client_settings") def test_get_client_database_with_validation(self, mock_get_settings, db_container): @@ -708,14 +1102,14 @@ def test_get_client_database_with_validation(self, mock_get_settings, db_contain mock_settings.selectai = None mock_get_settings.return_value = mock_settings - # Use real bootstrap DATABASE_OBJECTS - original_db_objects = bootstrap.DATABASE_OBJECTS.copy() + # Use real DATABASE_OBJECTS + original_db_objects = databases.DATABASE_OBJECTS.copy() try: - bootstrap.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.clear() default_db = Database(name="DEFAULT", user=TEST_CONFIG["db_username"], password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"]) - bootstrap.DATABASE_OBJECTS.append(default_db) + databases.DATABASE_OBJECTS.append(default_db) result = databases.get_client_database("test_client", validate=True) assert isinstance(result, Database) @@ -725,11 +1119,11 @@ def test_get_client_database_with_validation(self, mock_get_settings, db_contain finally: # Clean up connections - for db in bootstrap.DATABASE_OBJECTS: + for db in databases.DATABASE_OBJECTS: if db.connection: databases.disconnect(db.connection) - bootstrap.DATABASE_OBJECTS.clear() - bootstrap.DATABASE_OBJECTS.extend(original_db_objects) + databases.DATABASE_OBJECTS.clear() + databases.DATABASE_OBJECTS.extend(original_db_objects) def test_logger_exists(self): """Test that logger is properly configured""" diff --git a/tests/server/unit/api/utils/test_utils_models.py b/tests/server/unit/api/utils/test_utils_models.py index cd3331a2..822fb1f4 100644 --- a/tests/server/unit/api/utils/test_utils_models.py +++ b/tests/server/unit/api/utils/test_utils_models.py @@ -230,6 +230,69 @@ def test_update_success(self, mock_url_check): assert result.temperature == 0.8 + @patch("server.api.utils.models.MODEL_OBJECTS", []) + @patch("server.api.utils.models.is_url_accessible") + def test_update_embedding_model_max_chunk_size(self, mock_url_check): + """Test updating max_chunk_size for embedding model (regression test for bug)""" + # Create an embedding model with default max_chunk_size + embed_model = Model( + id="test-embed-model", + provider="ollama", + type="embed", + enabled=True, + api_base="http://127.0.0.1:11434", + max_chunk_size=8192, + ) + models.MODEL_OBJECTS.append(embed_model) + mock_url_check.return_value = (True, None) + + # Update the max_chunk_size to 512 + update_payload = Model( + id="test-embed-model", + provider="ollama", + type="embed", + enabled=True, + api_base="http://127.0.0.1:11434", + max_chunk_size=512, + ) + + result = models.update(update_payload) + + # Verify the update was successful + assert result.max_chunk_size == 512 + assert result.id == "test-embed-model" + assert result.provider == "ollama" + + # Verify the model in MODEL_OBJECTS was updated + (updated_model,) = models.get(model_provider="ollama", model_id="test-embed-model") + assert updated_model.max_chunk_size == 512 + + @patch("server.api.utils.models.MODEL_OBJECTS", []) + @patch("server.api.utils.models.is_url_accessible") + def test_update_multiple_fields(self, mock_url_check): + """Test updating multiple fields at once""" + # Create a model + models.MODEL_OBJECTS.append(self.sample_model) + mock_url_check.return_value = (True, None) + + # Update multiple fields + update_payload = Model( + id="test-model", + provider="openai", + type="ll", + enabled=False, # Changed from True + api_base="https://api.openai.com/v2", # Changed + temperature=0.5, # Changed + max_tokens=2048, # Changed + ) + + result = models.update(update_payload) + + assert result.enabled is False + assert result.api_base == "https://api.openai.com/v2" + assert result.temperature == 0.5 + assert result.max_tokens == 2048 + @patch("server.api.utils.models.get") def test_get_full_config_success(self, mock_get_model): """Test successful full config retrieval"""