Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ dev = [
"pyright>=1.1.408",
"pytest-testmon>=2.2.0",
"ty>=0.0.18",
"cst-lsp>=0.1.3",
"libcst>=1.8.6",
]

[tool.hatch.version]
Expand All @@ -115,6 +117,7 @@ ignore = ["test/"]
defineConstant = { DEBUG = true }
reportMissingImports = "error"
reportMissingTypeStubs = false
reportUnusedImport = "none"
pythonVersion = "3.12"


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Rename entity_type column to note_type

Revision ID: j3d4e5f6g7h8
Revises: i2c3d4e5f6g7
Create Date: 2026-02-22 12:00:00.000000

"""

from typing import Sequence, Union

from alembic import op
from sqlalchemy import text

# revision identifiers, used by Alembic.
revision: str = "j3d4e5f6g7h8"
down_revision: Union[str, None] = "i2c3d4e5f6g7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def table_exists(connection, table_name: str) -> bool:
    """Return True when *table_name* exists in the connected database.

    Used to keep the migration idempotent: steps that touch optional
    tables (e.g. search_index) are skipped when the table is absent.
    """
    if connection.dialect.name == "postgresql":
        query = text(
            "SELECT 1 FROM information_schema.tables "
            "WHERE table_name = :table_name"
        )
    else:
        # SQLite keeps its catalog in sqlite_master.
        query = text(
            "SELECT 1 FROM sqlite_master WHERE type='table' AND name = :table_name"
        )
    row = connection.execute(query, {"table_name": table_name}).fetchone()
    return row is not None


def index_exists(connection, index_name: str) -> bool:
    """Return True when an index named *index_name* exists.

    Supports idempotent re-runs: index drop/create steps consult this
    before acting.
    """
    if connection.dialect.name == "postgresql":
        query = text("SELECT 1 FROM pg_indexes WHERE indexname = :index_name")
    else:
        # SQLite records indexes in sqlite_master alongside tables.
        query = text(
            "SELECT 1 FROM sqlite_master WHERE type='index' AND name = :index_name"
        )
    row = connection.execute(query, {"index_name": index_name}).fetchone()
    return row is not None


def column_exists(connection, table: str, column: str) -> bool:
    """Return True when *column* exists on *table*.

    Drives the idempotency guard in upgrade(): the rename is skipped when
    the new column name is already present.
    """
    if connection.dialect.name != "postgresql":
        # SQLite: PRAGMA statements do not accept bound parameters, so the
        # table name is interpolated directly. It is a trusted, hard-coded
        # value supplied by this migration, never user input.
        rows = connection.execute(text(f"PRAGMA table_info({table})"))
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt, pk);
        # index 1 is the column name.
        return any(row[1] == column for row in rows)
    found = connection.execute(
        text(
            "SELECT 1 FROM information_schema.columns "
            "WHERE table_name = :table AND column_name = :column"
        ),
        {"table": table, "column": column},
    ).fetchone()
    return found is not None


def upgrade() -> None:
    """Rename entity_type → note_type on the entity table.

    Steps (order matters — the column must exist under its new name
    before the index and JSON metadata are rewritten):
      1. Bail out if ``note_type`` already exists (idempotent re-run).
      2. Rename the column via ``ALTER TABLE ... RENAME COLUMN`` on both
         dialects.
      3. Recreate the supporting index under its new name.
      4. Rewrite the ``entity_type`` key to ``note_type`` inside
         ``search_index.metadata`` JSON, when that table exists.
    """
    connection = op.get_bind()
    dialect = connection.dialect.name

    # Skip if already migrated (idempotent)
    if column_exists(connection, "entity", "note_type"):
        return

    if dialect == "postgresql":
        # Postgres supports direct column rename
        op.execute("ALTER TABLE entity RENAME COLUMN entity_type TO note_type")

        # Recreate the index with new name; IF EXISTS makes the drop safe
        # on databases where the old index was never created.
        op.execute("DROP INDEX IF EXISTS ix_entity_type")
        op.execute("CREATE INDEX ix_note_type ON entity (note_type)")
    else:
        # SQLite 3.25.0+ supports ALTER TABLE RENAME COLUMN directly.
        # Avoids batch_alter_table which fails on tables with generated columns
        # (duplicate column name error when recreating the table).
        op.execute("ALTER TABLE entity RENAME COLUMN entity_type TO note_type")

        # Recreate the index with new name (SQLite has no DROP INDEX IF
        # EXISTS via op.drop_index, so probe the catalog first).
        if index_exists(connection, "ix_entity_type"):
            op.drop_index("ix_entity_type", table_name="entity")
        op.create_index("ix_note_type", "entity", ["note_type"])

    # Update search index metadata: rename entity_type → note_type in JSON
    # This updates the stored metadata so search results use the new field name
    # Guard: search_index may not exist on a fresh DB (created by an earlier migration)
    if not table_exists(connection, "search_index"):
        return

    if dialect == "postgresql":
        # jsonb: delete the old key, then merge in the new key carrying the
        # old value. The `?` operator restricts the UPDATE to rows that
        # actually have an entity_type key.
        op.execute(
            text("""
                UPDATE search_index
                SET metadata = metadata - 'entity_type' || jsonb_build_object('note_type', metadata->'entity_type')
                WHERE metadata ? 'entity_type'
            """)
        )
    else:
        # SQLite JSON1: remove the old key, then set the new one with the
        # value extracted from the original document.
        op.execute(
            text("""
                UPDATE search_index
                SET metadata = json_set(
                    json_remove(metadata, '$.entity_type'),
                    '$.note_type',
                    json_extract(metadata, '$.entity_type')
                )
                WHERE json_extract(metadata, '$.entity_type') IS NOT NULL
            """)
        )


def downgrade() -> None:
    """Rename note_type → entity_type on the entity table.

    Exact mirror of upgrade(): rename the column back, restore the old
    index name, then revert the JSON key in search_index metadata.
    Unlike upgrade(), there is no column-existence guard here — a
    downgrade is assumed to run against a database in the upgraded state.
    """
    connection = op.get_bind()
    dialect = connection.dialect.name

    if dialect == "postgresql":
        op.execute("ALTER TABLE entity RENAME COLUMN note_type TO entity_type")
        op.execute("DROP INDEX IF EXISTS ix_note_type")
        op.execute("CREATE INDEX ix_entity_type ON entity (entity_type)")
    else:
        # SQLite 3.25.0+ direct rename; see upgrade() for why
        # batch_alter_table is avoided.
        op.execute("ALTER TABLE entity RENAME COLUMN note_type TO entity_type")

        if index_exists(connection, "ix_note_type"):
            op.drop_index("ix_note_type", table_name="entity")
        op.create_index("ix_entity_type", "entity", ["entity_type"])

    # Revert search index metadata; skip when the table is absent
    # (fresh DB where the table's creating migration never ran).
    if not table_exists(connection, "search_index"):
        return

    if dialect == "postgresql":
        # jsonb: drop 'note_type' and merge back an 'entity_type' key with
        # the preserved value; `?` limits the UPDATE to affected rows.
        op.execute(
            text("""
                UPDATE search_index
                SET metadata = metadata - 'note_type' || jsonb_build_object('entity_type', metadata->'note_type')
                WHERE metadata ? 'note_type'
            """)
        )
    else:
        # SQLite JSON1 equivalent of the jsonb rewrite above.
        op.execute(
            text("""
                UPDATE search_index
                SET metadata = json_set(
                    json_remove(metadata, '$.note_type'),
                    '$.entity_type',
                    json_extract(metadata, '$.note_type')
                )
                WHERE json_extract(metadata, '$.note_type') IS NOT NULL
            """)
        )
2 changes: 1 addition & 1 deletion src/basic_memory/api/v2/routers/knowledge_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ async def create_entity(
Created entity with generated external_id (UUID) and file content
"""
logger.info(
"API v2 request", endpoint="create_entity", entity_type=data.entity_type, title=data.title
"API v2 request", endpoint="create_entity", note_type=data.note_type, title=data.title
)

if fast:
Expand Down
8 changes: 4 additions & 4 deletions src/basic_memory/api/v2/routers/resource_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,14 @@ async def create_resource(
# Determine file details
file_name = PathLib(data.file_path).name
content_type = file_service.content_type(data.file_path)
entity_type = "canvas" if data.file_path.endswith(".canvas") else "file"
note_type = "canvas" if data.file_path.endswith(".canvas") else "file"

# Create a new entity model
# Explicitly set external_id to ensure NOT NULL constraint is satisfied (fixes #512)
entity = EntityModel(
external_id=str(uuid.uuid4()),
title=file_name,
entity_type=entity_type,
note_type=note_type,
content_type=content_type,
file_path=data.file_path,
checksum=checksum,
Expand Down Expand Up @@ -253,14 +253,14 @@ async def update_resource(
# Determine file details
file_name = PathLib(target_file_path).name
content_type = file_service.content_type(target_file_path)
entity_type = "canvas" if target_file_path.endswith(".canvas") else "file"
note_type = "canvas" if target_file_path.endswith(".canvas") else "file"

# Update entity using internal ID
updated_entity = await entity_repository.update(
entity.id,
{
"title": file_name,
"entity_type": entity_type,
"note_type": note_type,
"content_type": content_type,
"file_path": target_file_path,
"checksum": checksum,
Expand Down
52 changes: 26 additions & 26 deletions src/basic_memory/api/v2/routers/schema_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _entity_relations(entity: Entity) -> list[RelationData]:
RelationData(
relation_type=rel.relation_type,
target_name=rel.to_name,
target_entity_type=rel.to_entity.entity_type if rel.to_entity else None,
target_note_type=rel.to_entity.note_type if rel.to_entity else None,
)
for rel in entity.outgoing_relations
]
Expand All @@ -69,8 +69,8 @@ def _entity_to_note_data(entity: Entity) -> NoteData:
def _entity_frontmatter(entity: Entity) -> dict:
"""Build a frontmatter dict from an entity for schema resolution."""
frontmatter = dict(entity.entity_metadata) if entity.entity_metadata else {}
if entity.entity_type:
frontmatter.setdefault("type", entity.entity_type)
if entity.note_type:
frontmatter.setdefault("type", entity.note_type)
return frontmatter


Expand All @@ -81,7 +81,7 @@ def _entity_frontmatter(entity: Entity) -> dict:
async def validate_schema(
entity_repository: EntityRepositoryV2ExternalDep,
project_id: str = Path(..., description="Project external UUID"),
entity_type: str | None = Query(None, description="Entity type to validate"),
note_type: str | None = Query(None, description="Note type to validate"),
identifier: str | None = Query(None, description="Specific note identifier"),
):
"""Validate notes against their resolved schemas.
Expand All @@ -95,7 +95,7 @@ async def validate_schema(
if identifier:
entity = await entity_repository.get_by_permalink(identifier)
if not entity:
return ValidationReport(entity_type=entity_type, total_notes=0, results=[])
return ValidationReport(note_type=note_type, total_notes=0, results=[])

frontmatter = _entity_frontmatter(entity)
schema_ref = frontmatter.get("schema")
Expand All @@ -120,16 +120,16 @@ async def search_fn(query: str) -> list[dict]:
results.append(_to_note_validation_response(result))

return ValidationReport(
entity_type=entity_type or entity.entity_type,
note_type=note_type or entity.note_type,
total_notes=1,
valid_count=1 if (results and results[0].passed) else 0,
warning_count=sum(len(r.warnings) for r in results),
error_count=sum(len(r.errors) for r in results),
results=results,
)

# --- Batch validation by entity type ---
entities = await _find_by_entity_type(entity_repository, entity_type) if entity_type else []
# --- Batch validation by note type ---
entities = await _find_by_note_type(entity_repository, note_type) if note_type else []

for entity in entities:
frontmatter = _entity_frontmatter(entity)
Expand All @@ -156,7 +156,7 @@ async def search_fn(query: str) -> list[dict]:

valid = sum(1 for r in results if r.passed)
return ValidationReport(
entity_type=entity_type,
note_type=note_type,
total_notes=len(results),
total_entities=len(entities),
valid_count=valid,
Expand All @@ -173,21 +173,21 @@ async def search_fn(query: str) -> list[dict]:
async def infer_schema_endpoint(
entity_repository: EntityRepositoryV2ExternalDep,
project_id: str = Path(..., description="Project external UUID"),
entity_type: str = Query(..., description="Entity type to analyze"),
note_type: str = Query(..., description="Note type to analyze"),
threshold: float = Query(0.25, description="Minimum frequency for optional fields"),
):
"""Infer a schema from existing notes of a given type.

Examines observation categories and relation types across all notes
of the given type. Returns frequency analysis and suggested Picoschema.
"""
entities = await _find_by_entity_type(entity_repository, entity_type)
entities = await _find_by_note_type(entity_repository, note_type)
notes_data = [_entity_to_note_data(entity) for entity in entities]

result = infer_schema(entity_type, notes_data, optional_threshold=threshold)
result = infer_schema(note_type, notes_data, optional_threshold=threshold)

return InferenceReport(
entity_type=result.entity_type,
note_type=result.note_type,
notes_analyzed=result.notes_analyzed,
field_frequencies=[
FieldFrequencyResponse(
Expand All @@ -212,10 +212,10 @@ async def infer_schema_endpoint(
# --- Drift Detection ---


@router.get("/schema/diff/{entity_type}", response_model=DriftReport)
@router.get("/schema/diff/{note_type}", response_model=DriftReport)
async def diff_schema_endpoint(
entity_repository: EntityRepositoryV2ExternalDep,
entity_type: str = Path(..., description="Entity type to check for drift"),
note_type: str = Path(..., description="Note type to check for drift"),
project_id: str = Path(..., description="Project external UUID"),
):
"""Show drift between a schema definition and actual note usage.
Expand All @@ -229,21 +229,21 @@ async def search_fn(query: str) -> list[dict]:
entities = await _find_schema_entities(entity_repository, query)
return [_entity_frontmatter(e) for e in entities]

# Resolve schema by entity type
schema_frontmatter = {"type": entity_type}
# Resolve schema by note type
schema_frontmatter = {"type": note_type}
schema_def = await resolve_schema(schema_frontmatter, search_fn)

if not schema_def:
return DriftReport(entity_type=entity_type, schema_found=False)
return DriftReport(note_type=note_type, schema_found=False)

# Collect all notes of this type
entities = await _find_by_entity_type(entity_repository, entity_type)
entities = await _find_by_note_type(entity_repository, note_type)
notes_data = [_entity_to_note_data(entity) for entity in entities]

result = diff_schema(schema_def, notes_data)

return DriftReport(
entity_type=entity_type,
note_type=note_type,
new_fields=[
DriftFieldResponse(
name=f.name,
Expand Down Expand Up @@ -271,19 +271,19 @@ async def search_fn(query: str) -> list[dict]:
# --- Helpers ---


async def _find_by_entity_type(
async def _find_by_note_type(
entity_repository: EntityRepositoryV2ExternalDep,
entity_type: str,
note_type: str,
) -> list[Entity]:
"""Find all entities of a given type using the repository's select pattern."""
query = entity_repository.select().where(Entity.entity_type == entity_type)
query = entity_repository.select().where(Entity.note_type == note_type)
result = await entity_repository.execute_query(query)
return list(result.scalars().all())


async def _find_schema_entities(
entity_repository: EntityRepositoryV2ExternalDep,
target_entity_type: str,
target_note_type: str,
*,
allow_reference_match: bool = False,
) -> list[Entity]:
Expand All @@ -295,11 +295,11 @@ async def _find_schema_entities(
2) Only when allow_reference_match=True and no entity match was found, try
exact reference matching by title/permalink (explicit schema references)
"""
query = entity_repository.select().where(Entity.entity_type == "schema")
query = entity_repository.select().where(Entity.note_type == "schema")
result = await entity_repository.execute_query(query)
entities = list(result.scalars().all())

normalized_target = generate_permalink(target_entity_type)
normalized_target = generate_permalink(target_note_type)

entity_matches = [
e
Expand Down
Loading
Loading