Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions src/basic_memory/services/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,6 @@ async def initialize_file_sync(
)
project_repository = ProjectRepository(session_maker)

# Initialize watch service
watch_service = WatchService(
app_config=app_config,
project_repository=project_repository,
quiet=True,
)

# Get active projects
active_projects = await project_repository.get_active_projects()

Expand Down Expand Up @@ -196,14 +189,26 @@ async def initialize_file_sync(
except Exception as e: # pragma: no cover
logger.warning(f"Could not update migration status: {e}")

# Then start the watch service in the background
# Then start the watch service in the background with restart capability
logger.info("Starting watch service for all projects")
# run the watch service
try:
await watch_service.run()
logger.info("Watch service started")
except Exception as e: # pragma: no cover
logger.error(f"Error starting watch service: {e}")

while True:
try:
# Create a fresh watch service instance to pick up new projects
watch_service = WatchService(
app_config=app_config,
project_repository=project_repository,
quiet=True,
)

await watch_service.run()
logger.info("Watch service exited normally")
break # Normal exit

except Exception as e: # pragma: no cover
logger.error(f"Error in watch service: {e}")
# Don't restart on unexpected errors
break

return None

Expand Down
24 changes: 24 additions & 0 deletions src/basic_memory/services/project_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ async def add_project(self, name: str, path: str, set_default: bool = False) ->

logger.info(f"Project '{name}' added at {resolved_path}")

# Signal watch service restart since project list changed
self._signal_watch_service_restart()

async def remove_project(self, name: str) -> None:
"""Remove a project from configuration and database.

Expand All @@ -130,6 +133,9 @@ async def remove_project(self, name: str) -> None:

logger.info(f"Project '{name}' removed from configuration and database")

# Signal watch service restart since project list changed
self._signal_watch_service_restart()

async def set_default_project(self, name: str) -> None:
"""Set the default project in configuration and database.

Expand Down Expand Up @@ -283,6 +289,14 @@ async def synchronize_projects(self) -> None: # pragma: no cover

logger.info("Project synchronization complete")

# Signal watch service restart if projects changed
projects_changed = len(config_projects) != len(db_projects_by_permalink) or set(
config_projects.keys()
) != set(db_projects_by_permalink.keys())

if projects_changed:
self._signal_watch_service_restart()

# Refresh MCP session to ensure it's in sync with current config
try:
from basic_memory.mcp.project_session import session
Expand All @@ -292,6 +306,16 @@ async def synchronize_projects(self) -> None: # pragma: no cover
# MCP components might not be available in all contexts
logger.debug("MCP session not available, skipping session refresh")

def _signal_watch_service_restart(self) -> None:
"""Signal the watch service to restart by creating a restart signal file."""
restart_signal_path = Path.home() / ".basic-memory" / "restart-watch-service"
try:
restart_signal_path.parent.mkdir(parents=True, exist_ok=True)
restart_signal_path.write_text(str(datetime.now()))
logger.info("Signaled watch service restart")
except Exception as e:
logger.warning(f"Failed to signal watch service restart: {e}")

async def update_project( # pragma: no cover
self, name: str, updated_path: Optional[str] = None, is_active: Optional[bool] = None
) -> None:
Expand Down
20 changes: 20 additions & 0 deletions src/basic_memory/sync/watch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __init__(
# quiet mode for mcp so it doesn't mess up stdout
self.console = Console(quiet=quiet)

# Restart signal file path
self.restart_signal_path = Path.home() / ".basic-memory" / "restart-watch-service"

async def run(self): # pragma: no cover
"""Watch for file changes and sync them"""

Expand All @@ -109,6 +112,11 @@ async def run(self): # pragma: no cover
watch_filter=self.filter_changes,
recursive=True,
):
# Check for restart signal
if await self.check_restart_signal():
logger.info("Restart signal detected, exiting watch service...")
return

# group changes by project
project_changes = defaultdict(list)
for change, path in changes:
Expand Down Expand Up @@ -161,6 +169,18 @@ def filter_changes(self, change: Change, path: str) -> bool: # pragma: no cover

return True

async def check_restart_signal(self) -> bool:
"""Check if a restart signal file exists and remove it if found."""
try:
if self.restart_signal_path.exists():
# Remove the signal file
self.restart_signal_path.unlink()
logger.info("Found and removed restart signal file")
return True
except Exception as e:
logger.warning(f"Error checking restart signal: {e}")
return False

async def write_status(self):
"""Write current state to status file"""
self.status_path.write_text(WatchServiceState.model_dump_json(self.state, indent=2))
Expand Down
93 changes: 93 additions & 0 deletions tests/services/test_project_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,96 @@ async def test_synchronize_projects_handles_case_sensitivity_bug(
db_project = await project_service.repository.get_by_name(name)
if db_project:
await project_service.repository.delete(db_project.id)


@pytest.mark.asyncio
async def test_signal_watch_service_restart(project_service: ProjectService, tmp_path):
"""Test that _signal_watch_service_restart creates the correct signal file."""
# Call the signal method
project_service._signal_watch_service_restart()

# Check that the signal file was created
restart_signal_path = tmp_path.parent / ".basic-memory" / "restart-watch-service"
# Since it uses Path.home(), we need to check the actual location
from pathlib import Path

actual_signal_path = Path.home() / ".basic-memory" / "restart-watch-service"

assert actual_signal_path.exists()

# Verify it contains a timestamp
content = actual_signal_path.read_text()
assert content # Should have some content (timestamp)

# Clean up
actual_signal_path.unlink()


@pytest.mark.asyncio
async def test_add_project_signals_restart(project_service: ProjectService, tmp_path):
"""Test that adding a project signals watch service restart."""
test_project_name = f"test-signal-restart-{os.urandom(4).hex()}"
test_project_path = str(tmp_path / "test-signal-restart")

# Make sure the test directory exists
os.makedirs(test_project_path, exist_ok=True)

# Remove any existing signal file
from pathlib import Path

signal_path = Path.home() / ".basic-memory" / "restart-watch-service"
if signal_path.exists():
signal_path.unlink()

try:
# Add the project - this should create the signal file
await project_service.add_project(test_project_name, test_project_path)

# Verify signal file was created
assert signal_path.exists()

finally:
# Clean up
if signal_path.exists():
signal_path.unlink()
if test_project_name in project_service.projects:
await project_service.remove_project(test_project_name)


@pytest.mark.asyncio
async def test_remove_project_signals_restart(project_service: ProjectService, tmp_path):
"""Test that removing a project signals watch service restart."""
test_project_name = f"test-remove-signal-{os.urandom(4).hex()}"
test_project_path = str(tmp_path / "test-remove-signal")

# Make sure the test directory exists
os.makedirs(test_project_path, exist_ok=True)

from pathlib import Path

signal_path = Path.home() / ".basic-memory" / "restart-watch-service"

try:
# Add the project first
await project_service.add_project(test_project_name, test_project_path)

# Remove any existing signal file (from the add)
if signal_path.exists():
signal_path.unlink()

# Remove the project - this should create the signal file
await project_service.remove_project(test_project_name)

# Verify signal file was created
assert signal_path.exists()

finally:
# Clean up
if signal_path.exists():
signal_path.unlink()
# Project should already be removed, but double-check
if test_project_name in project_service.projects:
try:
await project_service.remove_project(test_project_name)
except:
pass
60 changes: 60 additions & 0 deletions tests/sync/test_watch_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,63 @@ def test_is_project_path(watch_service, tmp_path):

# Test the project path itself
assert watch_service.is_project_path(project, project_path) is False


@pytest.mark.asyncio
async def test_check_restart_signal_no_signal(app_config, project_repository):
"""Test check_restart_signal when no signal file exists."""
watch_service = WatchService(app_config, project_repository)

# Ensure no signal file exists
if watch_service.restart_signal_path.exists():
watch_service.restart_signal_path.unlink()

# Should return False when no signal file exists
result = await watch_service.check_restart_signal()
assert result is False


@pytest.mark.asyncio
async def test_check_restart_signal_with_signal(app_config, project_repository):
"""Test check_restart_signal when signal file exists."""
watch_service = WatchService(app_config, project_repository)

try:
# Create signal file
watch_service.restart_signal_path.parent.mkdir(parents=True, exist_ok=True)
watch_service.restart_signal_path.write_text("test signal")

# Should return True and remove the signal file
result = await watch_service.check_restart_signal()
assert result is True
assert not watch_service.restart_signal_path.exists()

finally:
# Cleanup
if watch_service.restart_signal_path.exists():
watch_service.restart_signal_path.unlink()


@pytest.mark.asyncio
async def test_check_restart_signal_error_handling(app_config, project_repository):
"""Test check_restart_signal handles errors gracefully."""
watch_service = WatchService(app_config, project_repository)

# Create a directory where the signal file should be (to cause an error)
watch_service.restart_signal_path.parent.mkdir(parents=True, exist_ok=True)
if watch_service.restart_signal_path.exists():
watch_service.restart_signal_path.unlink()
watch_service.restart_signal_path.mkdir()

try:
# Should handle the error and return False
result = await watch_service.check_restart_signal()
assert result is False

finally:
# Cleanup
if watch_service.restart_signal_path.exists():
if watch_service.restart_signal_path.is_dir():
watch_service.restart_signal_path.rmdir()
else:
watch_service.restart_signal_path.unlink()
Loading