From cea2410b15c0c34951f7635a9e47374d3c2766df Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 25 Jun 2025 23:51:14 -0500 Subject: [PATCH] fix: restart watch service when project configuration changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves issue where new projects created via MCP tools weren't being watched by the file synchronization service. The watch service now automatically restarts when projects are added, removed, or when the project configuration changes. Implementation: - Add file-based signaling mechanism for watch service restarts - ProjectService signals restart when projects are added/removed - WatchService checks for restart signal on each file change batch - Watch loop recreates service instances to pick up new projects - Add comprehensive tests for the restart signal mechanism Note: New projects will only be picked up for watching after the next file change occurs in any existing project, as the restart signal is checked during file change processing. Fixes #156 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/basic_memory/services/initialization.py | 33 ++++--- src/basic_memory/services/project_service.py | 24 +++++ src/basic_memory/sync/watch_service.py | 20 +++++ tests/services/test_project_service.py | 93 ++++++++++++++++++++ tests/sync/test_watch_service.py | 60 +++++++++++++ 5 files changed, 216 insertions(+), 14 deletions(-) diff --git a/src/basic_memory/services/initialization.py b/src/basic_memory/services/initialization.py index 0d4c48b89..bcb04e73e 100644 --- a/src/basic_memory/services/initialization.py +++ b/src/basic_memory/services/initialization.py @@ -150,13 +150,6 @@ async def initialize_file_sync( ) project_repository = ProjectRepository(session_maker) - # Initialize watch service - watch_service = WatchService( - app_config=app_config, - project_repository=project_repository, - quiet=True, - ) - # Get active projects active_projects = await project_repository.get_active_projects() @@ -196,14 +189,26 @@ async def initialize_file_sync( except Exception as e: # pragma: no cover logger.warning(f"Could not update migration status: {e}") - # Then start the watch service in the background + # Then start the watch service in the background with restart capability logger.info("Starting watch service for all projects") - # run the watch service - try: - await watch_service.run() - logger.info("Watch service started") - except Exception as e: # pragma: no cover - logger.error(f"Error starting watch service: {e}") + + while True: + try: + # Create a fresh watch service instance to pick up new projects + watch_service = WatchService( + app_config=app_config, + project_repository=project_repository, + quiet=True, + ) + + await watch_service.run() + logger.info("Watch service exited normally") + break # Normal exit + + except Exception as e: # pragma: no cover + logger.error(f"Error in watch service: {e}") + # Don't restart on unexpected errors + break return None diff --git a/src/basic_memory/services/project_service.py b/src/basic_memory/services/project_service.py index df6b6311d..be39c0677 100644 --- a/src/basic_memory/services/project_service.py +++ b/src/basic_memory/services/project_service.py @@ -108,6 +108,9 @@ async def add_project(self, name: str, path: str, set_default: bool = False) -> logger.info(f"Project '{name}' added at {resolved_path}") + # Signal watch service restart since project list changed + self._signal_watch_service_restart() + async def remove_project(self, name: str) -> None: """Remove a project from configuration and database. @@ -130,6 +133,9 @@ async def remove_project(self, name: str) -> None: logger.info(f"Project '{name}' removed from configuration and database") + # Signal watch service restart since project list changed + self._signal_watch_service_restart() + async def set_default_project(self, name: str) -> None: """Set the default project in configuration and database. @@ -283,6 +289,14 @@ async def synchronize_projects(self) -> None: # pragma: no cover logger.info("Project synchronization complete") + # Signal watch service restart if projects changed + projects_changed = len(config_projects) != len(db_projects_by_permalink) or set( + config_projects.keys() + ) != set(db_projects_by_permalink.keys()) + + if projects_changed: + self._signal_watch_service_restart() + # Refresh MCP session to ensure it's in sync with current config try: from basic_memory.mcp.project_session import session @@ -292,6 +306,16 @@ async def synchronize_projects(self) -> None: # pragma: no cover # MCP components might not be available in all contexts logger.debug("MCP session not available, skipping session refresh") + def _signal_watch_service_restart(self) -> None: + """Signal the watch service to restart by creating a restart signal file.""" + restart_signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + try: + restart_signal_path.parent.mkdir(parents=True, exist_ok=True) + restart_signal_path.write_text(str(datetime.now())) + logger.info("Signaled watch service restart") + except Exception as e: + logger.warning(f"Failed to signal watch service restart: {e}") + async def update_project( # pragma: no cover self, name: str, updated_path: Optional[str] = None, is_active: Optional[bool] = None ) -> None: diff --git a/src/basic_memory/sync/watch_service.py b/src/basic_memory/sync/watch_service.py index 022ef3ac3..e38cb39a6 100644 --- a/src/basic_memory/sync/watch_service.py +++ b/src/basic_memory/sync/watch_service.py @@ -85,6 +85,9 @@ def __init__( # quiet mode for mcp so it doesn't mess up stdout self.console = Console(quiet=quiet) + # Restart signal file path + self.restart_signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + async def run(self): # pragma: no cover """Watch for file changes and sync them""" @@ -109,6 +112,11 @@ async def run(self): # pragma: no cover watch_filter=self.filter_changes, recursive=True, ): + # Check for restart signal + if await self.check_restart_signal(): + logger.info("Restart signal detected, exiting watch service...") + return + # group changes by project project_changes = defaultdict(list) for change, path in changes: @@ -161,6 +169,18 @@ def filter_changes(self, change: Change, path: str) -> bool: # pragma: no cover return True + async def check_restart_signal(self) -> bool: + """Check if a restart signal file exists and remove it if found.""" + try: + if self.restart_signal_path.exists(): + # Remove the signal file + self.restart_signal_path.unlink() + logger.info("Found and removed restart signal file") + return True + except Exception as e: + logger.warning(f"Error checking restart signal: {e}") + return False + async def write_status(self): """Write current state to status file""" self.status_path.write_text(WatchServiceState.model_dump_json(self.state, indent=2)) diff --git a/tests/services/test_project_service.py b/tests/services/test_project_service.py index 6695c4ee4..7d8ff328f 100644 --- a/tests/services/test_project_service.py +++ b/tests/services/test_project_service.py @@ -601,3 +601,96 @@ async def test_synchronize_projects_handles_case_sensitivity_bug( db_project = await project_service.repository.get_by_name(name) if db_project: await project_service.repository.delete(db_project.id) + + +@pytest.mark.asyncio +async def test_signal_watch_service_restart(project_service: ProjectService, tmp_path): + """Test that _signal_watch_service_restart creates the correct signal file.""" + # Call the signal method + project_service._signal_watch_service_restart() + + # Check that the signal file was created + restart_signal_path = tmp_path.parent / ".basic-memory" / "restart-watch-service" + # Since it uses Path.home(), we need to check the actual location + from pathlib import Path + + actual_signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + + assert actual_signal_path.exists() + + # Verify it contains a timestamp + content = actual_signal_path.read_text() + assert content # Should have some content (timestamp) + + # Clean up + actual_signal_path.unlink() + + +@pytest.mark.asyncio +async def test_add_project_signals_restart(project_service: ProjectService, tmp_path): + """Test that adding a project signals watch service restart.""" + test_project_name = f"test-signal-restart-{os.urandom(4).hex()}" + test_project_path = str(tmp_path / "test-signal-restart") + + # Make sure the test directory exists + os.makedirs(test_project_path, exist_ok=True) + + # Remove any existing signal file + from pathlib import Path + + signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + if signal_path.exists(): + signal_path.unlink() + + try: + # Add the project - this should create the signal file + await project_service.add_project(test_project_name, test_project_path) + + # Verify signal file was created + assert signal_path.exists() + + finally: + # Clean up + if signal_path.exists(): + signal_path.unlink() + if test_project_name in project_service.projects: + await project_service.remove_project(test_project_name) + + +@pytest.mark.asyncio +async def test_remove_project_signals_restart(project_service: ProjectService, tmp_path): + """Test that removing a project signals watch service restart.""" + test_project_name = f"test-remove-signal-{os.urandom(4).hex()}" + test_project_path = str(tmp_path / "test-remove-signal") + + # Make sure the test directory exists + os.makedirs(test_project_path, exist_ok=True) + + from pathlib import Path + + signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + + try: + # Add the project first + await project_service.add_project(test_project_name, test_project_path) + + # Remove any existing signal file (from the add) + if signal_path.exists(): + signal_path.unlink() + + # Remove the project - this should create the signal file + await project_service.remove_project(test_project_name) + + # Verify signal file was created + assert signal_path.exists() + + finally: + # Clean up + if signal_path.exists(): + signal_path.unlink() + # Project should already be removed, but double-check + if test_project_name in project_service.projects: + try: + await project_service.remove_project(test_project_name) + except: + pass diff --git a/tests/sync/test_watch_service.py b/tests/sync/test_watch_service.py index e70ca87b8..08d6b7397 100644 --- a/tests/sync/test_watch_service.py +++ b/tests/sync/test_watch_service.py @@ -448,3 +448,63 @@ def test_is_project_path(watch_service, tmp_path): # Test the project path itself assert watch_service.is_project_path(project, project_path) is False + + +@pytest.mark.asyncio +async def test_check_restart_signal_no_signal(app_config, project_repository): + """Test check_restart_signal when no signal file exists.""" + watch_service = WatchService(app_config, project_repository) + + # Ensure no signal file exists + if watch_service.restart_signal_path.exists(): + watch_service.restart_signal_path.unlink() + + # Should return False when no signal file exists + result = await watch_service.check_restart_signal() + assert result is False + + +@pytest.mark.asyncio +async def test_check_restart_signal_with_signal(app_config, project_repository): + """Test check_restart_signal when signal file exists.""" + watch_service = WatchService(app_config, project_repository) + + try: + # Create signal file + watch_service.restart_signal_path.parent.mkdir(parents=True, exist_ok=True) + watch_service.restart_signal_path.write_text("test signal") + + # Should return True and remove the signal file + result = await watch_service.check_restart_signal() + assert result is True + assert not watch_service.restart_signal_path.exists() + + finally: + # Cleanup + if watch_service.restart_signal_path.exists(): + watch_service.restart_signal_path.unlink() + + +@pytest.mark.asyncio +async def test_check_restart_signal_error_handling(app_config, project_repository): + """Test check_restart_signal handles errors gracefully.""" + watch_service = WatchService(app_config, project_repository) + + # Create a directory where the signal file should be (to cause an error) + watch_service.restart_signal_path.parent.mkdir(parents=True, exist_ok=True) + if watch_service.restart_signal_path.exists(): + watch_service.restart_signal_path.unlink() + watch_service.restart_signal_path.mkdir() + + try: + # Should handle the error and return False + result = await watch_service.check_restart_signal() + assert result is False + + finally: + # Cleanup + if watch_service.restart_signal_path.exists(): + if watch_service.restart_signal_path.is_dir(): + watch_service.restart_signal_path.rmdir() + else: + watch_service.restart_signal_path.unlink()