Skip to content

Commit

Permalink
Enable custom context lookups in index
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed Mar 21, 2024
1 parent d68b24f commit 5f944ca
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Unreleased

- Enable custom context lookups in index
- Adds `utils.file_hash`
- Adds `Execution().deferred()` to prevent automatic dispatch
- Respect CLI context order
Expand Down
47 changes: 23 additions & 24 deletions src/machinable/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def __init__(self, version: VersionType = None):
self.__mixins__ = {}
self._config: Optional[DictConfig] = None
self._predicate: Optional[DictConfig] = None
self._context: Optional[DictConfig] = None
self._cache = {}
self._kwargs = {}

Expand Down Expand Up @@ -461,6 +462,15 @@ def predicate(self) -> DictConfig:

return self._predicate

@property
def context(self) -> Optional[DictConfig]:
if self._context is None:
if self.__model__.context is None:
return None
self._context = OmegaConf.create(self.__model__.context)

return self._context

@property
def config(self) -> DictConfig:
"""Element configuration"""
Expand Down Expand Up @@ -580,7 +590,7 @@ def model(cls, element: Optional[Any] = None) -> schema.Element:
return getattr(schema, cls.kind)

def matches(self, context: Optional[Dict] = None) -> bool:
if context is None:
if context is None or self.context is None:
# full constraint, match none
return False

Expand All @@ -589,31 +599,20 @@ def matches(self, context: Optional[Dict] = None) -> bool:
return True

for field, value in context.items():
if field in (
"uuid",
"kind",
"module",
):
if not equaljson(getattr(self, field), value):
if field == "predicate":
if hasattr(value, "items") and len(value) > 0:
for p, v in value.items():
if field not in self.context:
return False
if p not in self.context[field]:
return False
if not equaljson(self.context[field][p], v):
return False
else:
if field not in self.context:
return False
elif field == "config":
if not equaljson(
{
k: v
for k, v in self.config.items()
if k not in ["_default_", "_version_", "_update_"]
},
value,
):
if not equaljson(self.context[field], value):
return False
elif field == "predicate":
for p, v in value.items():
if p not in self.predicate:
return False
if not equaljson(self.predicate[p], v):
return False
else:
raise ValueError("Invalid context field: {field}")

return True

Expand Down
24 changes: 19 additions & 5 deletions src/machinable/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def interface_row_factory(cursor, row) -> schema.Interface:
config=json.loads(row[3]),
version=json.loads(row[6]),
predicate=json.loads(row[7]),
context=json.loads(row[11] or "null"),
lineage=json.loads(row[8]),
timestamp=int(row[9]),
**json.loads(row[10]),
Expand Down Expand Up @@ -76,6 +77,12 @@ def migrate(db: sqlite3.Connection) -> None:
db.commit()
version += 1
if version == 1:
# updates
cur.execute("""ALTER TABLE 'index' ADD COLUMN 'context' json""")
cur.execute("PRAGMA user_version = 2;")
db.commit()
version += 1
if version == 2:
# future migrations
...

Expand Down Expand Up @@ -163,10 +170,11 @@ def commit(self, model: schema.Interface) -> bool:
config_default,
version,
predicate,
context,
lineage,
'timestamp',
extra
) VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)""",
(
model.uuid,
model.kind,
Expand All @@ -176,6 +184,7 @@ def commit(self, model: schema.Interface) -> bool:
_jn(default),
_jn(version),
_jn(model.predicate),
_jn(model.context),
_jn(model.lineage),
model.timestamp,
_jn(model.extra()),
Expand Down Expand Up @@ -264,11 +273,16 @@ def find_by_context(self, context: Dict) -> List[schema.Interface]:
equals = []
for field, value in context.items():
if field == "predicate":
for p, v in value.items():
keys.append(f"json_extract(predicate, '$.{p}')=?")
equals.append(v)
if hasattr(value, "items") and len(value) > 0:
# empty dict is a wildcard, so we only add
# condition if len(value) > 0
for p, v in value.items():
keys.append(
f"json_extract(context, '$.{field}.{p}')=?"
)
equals.append(v)
else:
keys.append(f"{field}=?")
keys.append(f"json_extract(context, '$.{field}')=?")
equals.append(value)

if len(keys) > 0:
Expand Down
2 changes: 1 addition & 1 deletion src/machinable/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def commit(self) -> Self:
if index.find_by_id(self.uuid) is not None:
return self

context = self.compute_context()
self.__model__.context = context = self.compute_context()
self.__model__.uuid = update_uuid_payload(self.__model__.uuid, context)

# ensure that configuration and predicate has been computed
Expand Down
1 change: 1 addition & 0 deletions src/machinable/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Element(BaseModel):
version: List[Union[str, Dict]] = []
config: Optional[Dict] = None
predicate: Optional[Dict] = None
context: Optional[Dict] = None
lineage: Tuple[str, ...] = ()

@property
Expand Down
7 changes: 5 additions & 2 deletions tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def _is_migrated(db):
return db.cursor().execute("PRAGMA user_version;").fetchone()[0] == 1
return db.cursor().execute("PRAGMA user_version;").fetchone()[0] == 2


def _matches(q, v):
Expand Down Expand Up @@ -51,6 +51,7 @@ def test_index_commit(tmp_path):
"[]",
v.timestamp,
"{}",
"null",
)
assert i.commit(v) is True
with index.db(i.config.database) as db:
Expand Down Expand Up @@ -98,7 +99,9 @@ def test_index_find(tmp_path):

def test_index_find_by_context(tmp_path):
i = index.Index({"database": str(tmp_path / "index.sqlite")})
v = schema.Interface(module="machinable", predicate={"a": 0, "b": 0})
v = schema.Interface(
context=dict(module="machinable", predicate={"a": 0, "b": 0})
)
i.commit(v)
assert len(i.find_by_context(dict(module="machinable"))) == 1
assert (
Expand Down

0 comments on commit 5f944ca

Please sign in to comment.