diff --git a/src/basic_memory/services/initialization.py b/src/basic_memory/services/initialization.py index 0d4c48b89..bcb04e73e 100644 --- a/src/basic_memory/services/initialization.py +++ b/src/basic_memory/services/initialization.py @@ -150,13 +150,6 @@ async def initialize_file_sync( ) project_repository = ProjectRepository(session_maker) - # Initialize watch service - watch_service = WatchService( - app_config=app_config, - project_repository=project_repository, - quiet=True, - ) - # Get active projects active_projects = await project_repository.get_active_projects() @@ -196,14 +189,26 @@ async def initialize_file_sync( except Exception as e: # pragma: no cover logger.warning(f"Could not update migration status: {e}") - # Then start the watch service in the background + # Then start the watch service in the background with restart capability logger.info("Starting watch service for all projects") - # run the watch service - try: - await watch_service.run() - logger.info("Watch service started") - except Exception as e: # pragma: no cover - logger.error(f"Error starting watch service: {e}") + + while True: + try: + # Create a fresh watch service instance to pick up new projects + watch_service = WatchService( + app_config=app_config, + project_repository=project_repository, + quiet=True, + ) + + await watch_service.run() + logger.info("Watch service exited normally") + break # Normal exit + + except Exception as e: # pragma: no cover + logger.error(f"Error in watch service: {e}") + # Don't restart on unexpected errors + break return None diff --git a/src/basic_memory/services/project_service.py b/src/basic_memory/services/project_service.py index df6b6311d..be39c0677 100644 --- a/src/basic_memory/services/project_service.py +++ b/src/basic_memory/services/project_service.py @@ -108,6 +108,9 @@ async def add_project(self, name: str, path: str, set_default: bool = False) -> logger.info(f"Project '{name}' added at {resolved_path}") + # Signal watch service restart since project list changed + self._signal_watch_service_restart() + async def remove_project(self, name: str) -> None: """Remove a project from configuration and database. @@ -130,6 +133,9 @@ async def remove_project(self, name: str) -> None: logger.info(f"Project '{name}' removed from configuration and database") + # Signal watch service restart since project list changed + self._signal_watch_service_restart() + async def set_default_project(self, name: str) -> None: """Set the default project in configuration and database. @@ -283,6 +289,14 @@ async def synchronize_projects(self) -> None: # pragma: no cover logger.info("Project synchronization complete") + # Signal watch service restart if projects changed + projects_changed = len(config_projects) != len(db_projects_by_permalink) or set( + config_projects.keys() + ) != set(db_projects_by_permalink.keys()) + + if projects_changed: + self._signal_watch_service_restart() + # Refresh MCP session to ensure it's in sync with current config try: from basic_memory.mcp.project_session import session @@ -292,6 +306,16 @@ async def synchronize_projects(self) -> None: # pragma: no cover # MCP components might not be available in all contexts logger.debug("MCP session not available, skipping session refresh") + def _signal_watch_service_restart(self) -> None: + """Signal the watch service to restart by creating a restart signal file.""" + restart_signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + try: + restart_signal_path.parent.mkdir(parents=True, exist_ok=True) + restart_signal_path.write_text(str(datetime.now())) + logger.info("Signaled watch service restart") + except Exception as e: + logger.warning(f"Failed to signal watch service restart: {e}") + async def update_project( # pragma: no cover self, name: str, updated_path: Optional[str] = None, is_active: Optional[bool] = None ) -> None: diff --git a/src/basic_memory/sync/watch_service.py b/src/basic_memory/sync/watch_service.py index 022ef3ac3..e38cb39a6 100644 --- a/src/basic_memory/sync/watch_service.py +++ b/src/basic_memory/sync/watch_service.py @@ -85,6 +85,9 @@ def __init__( # quiet mode for mcp so it doesn't mess up stdout self.console = Console(quiet=quiet) + # Restart signal file path + self.restart_signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + async def run(self): # pragma: no cover """Watch for file changes and sync them""" @@ -109,6 +112,11 @@ async def run(self): # pragma: no cover watch_filter=self.filter_changes, recursive=True, ): + # Check for restart signal + if await self.check_restart_signal(): + logger.info("Restart signal detected, exiting watch service...") + return + # group changes by project project_changes = defaultdict(list) for change, path in changes: @@ -161,6 +169,18 @@ def filter_changes(self, change: Change, path: str) -> bool: # pragma: no cover return True + async def check_restart_signal(self) -> bool: + """Check if a restart signal file exists and remove it if found.""" + try: + if self.restart_signal_path.exists(): + # Remove the signal file + self.restart_signal_path.unlink() + logger.info("Found and removed restart signal file") + return True + except Exception as e: + logger.warning(f"Error checking restart signal: {e}") + return False + async def write_status(self): """Write current state to status file""" self.status_path.write_text(WatchServiceState.model_dump_json(self.state, indent=2)) diff --git a/tests/services/test_project_service.py b/tests/services/test_project_service.py index 6695c4ee4..7d8ff328f 100644 --- a/tests/services/test_project_service.py +++ b/tests/services/test_project_service.py @@ -601,3 +601,96 @@ async def test_synchronize_projects_handles_case_sensitivity_bug( db_project = await project_service.repository.get_by_name(name) if db_project: await project_service.repository.delete(db_project.id) + + +@pytest.mark.asyncio +async def test_signal_watch_service_restart(project_service: ProjectService, tmp_path): + """Test that _signal_watch_service_restart creates the correct signal file.""" + # Call the signal method + project_service._signal_watch_service_restart() + + # Check that the signal file was created + restart_signal_path = tmp_path.parent / ".basic-memory" / "restart-watch-service" + # Since it uses Path.home(), we need to check the actual location + from pathlib import Path + + actual_signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + + assert actual_signal_path.exists() + + # Verify it contains a timestamp + content = actual_signal_path.read_text() + assert content # Should have some content (timestamp) + + # Clean up + actual_signal_path.unlink() + + +@pytest.mark.asyncio +async def test_add_project_signals_restart(project_service: ProjectService, tmp_path): + """Test that adding a project signals watch service restart.""" + test_project_name = f"test-signal-restart-{os.urandom(4).hex()}" + test_project_path = str(tmp_path / "test-signal-restart") + + # Make sure the test directory exists + os.makedirs(test_project_path, exist_ok=True) + + # Remove any existing signal file + from pathlib import Path + + signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + if signal_path.exists(): + signal_path.unlink() + + try: + # Add the project - this should create the signal file + await project_service.add_project(test_project_name, test_project_path) + + # Verify signal file was created + assert signal_path.exists() + + finally: + # Clean up + if signal_path.exists(): + signal_path.unlink() + if test_project_name in project_service.projects: + await project_service.remove_project(test_project_name) + + +@pytest.mark.asyncio +async def test_remove_project_signals_restart(project_service: ProjectService, tmp_path): + """Test that removing a project signals watch service restart.""" + test_project_name = f"test-remove-signal-{os.urandom(4).hex()}" + test_project_path = str(tmp_path / "test-remove-signal") + + # Make sure the test directory exists + os.makedirs(test_project_path, exist_ok=True) + + from pathlib import Path + + signal_path = Path.home() / ".basic-memory" / "restart-watch-service" + + try: + # Add the project first + await project_service.add_project(test_project_name, test_project_path) + + # Remove any existing signal file (from the add) + if signal_path.exists(): + signal_path.unlink() + + # Remove the project - this should create the signal file + await project_service.remove_project(test_project_name) + + # Verify signal file was created + assert signal_path.exists() + + finally: + # Clean up + if signal_path.exists(): + signal_path.unlink() + # Project should already be removed, but double-check + if test_project_name in project_service.projects: + try: + await project_service.remove_project(test_project_name) + except: + pass diff --git a/tests/sync/test_watch_service.py b/tests/sync/test_watch_service.py index e70ca87b8..08d6b7397 100644 --- a/tests/sync/test_watch_service.py +++ b/tests/sync/test_watch_service.py @@ -448,3 +448,63 @@ def test_is_project_path(watch_service, tmp_path): # Test the project path itself assert watch_service.is_project_path(project, project_path) is False + + +@pytest.mark.asyncio +async def test_check_restart_signal_no_signal(app_config, project_repository): + """Test check_restart_signal when no signal file exists.""" + watch_service = WatchService(app_config, project_repository) + + # Ensure no signal file exists + if watch_service.restart_signal_path.exists(): + watch_service.restart_signal_path.unlink() + + # Should return False when no signal file exists + result = await watch_service.check_restart_signal() + assert result is False + + +@pytest.mark.asyncio +async def test_check_restart_signal_with_signal(app_config, project_repository): + """Test check_restart_signal when signal file exists.""" + watch_service = WatchService(app_config, project_repository) + + try: + # Create signal file + watch_service.restart_signal_path.parent.mkdir(parents=True, exist_ok=True) + watch_service.restart_signal_path.write_text("test signal") + + # Should return True and remove the signal file + result = await watch_service.check_restart_signal() + assert result is True + assert not watch_service.restart_signal_path.exists() + + finally: + # Cleanup + if watch_service.restart_signal_path.exists(): + watch_service.restart_signal_path.unlink() + + +@pytest.mark.asyncio +async def test_check_restart_signal_error_handling(app_config, project_repository): + """Test check_restart_signal handles errors gracefully.""" + watch_service = WatchService(app_config, project_repository) + + # Create a directory where the signal file should be (to cause an error) + watch_service.restart_signal_path.parent.mkdir(parents=True, exist_ok=True) + if watch_service.restart_signal_path.exists(): + watch_service.restart_signal_path.unlink() + watch_service.restart_signal_path.mkdir() + + try: + # Should handle the error and return False + result = await watch_service.check_restart_signal() + assert result is False + + finally: + # Cleanup + if watch_service.restart_signal_path.exists(): + if watch_service.restart_signal_path.is_dir(): + watch_service.restart_signal_path.rmdir() + else: + watch_service.restart_signal_path.unlink()