In [1]:
from dataclasses import dataclass, field

from pycozo.builder import *

In [2]:
class _Query:
    def __lshift__(self, other):
        match other:
            case int() | float() | str() | bool() | None:
                return Const(other)
            case list() if all(
                type(item) in (ConstantRule, InlineRule, FixedRule) for item in other
            ):
                return InputProgram(rules=list(other))
            case tuple() if all(
                type(item) in (ConstantRule, InlineRule, FixedRule) for item in other
            ):
                return InputProgram(rules=list(other))
            case [*items]:
                return InputList([self << item for item in items])
            case _:
                raise Exception("Invalid input type")


Query = _Query()

In [3]:
from functools import wraps
import inspect


def query(spec_fn):
    """Decorate a query spec so that calling it builds a query via ``Query << ...``.

    Each keyword-only parameter of ``spec_fn`` is bound to a ``Var`` of the
    same name, so the spec body can use those parameters as query variables.
    Whatever ``spec_fn`` returns is passed through ``Query << ...``.

    If ``spec_fn`` also has a positional parameter, the decorated function
    takes a single ``input`` argument and forwards it positionally. Otherwise
    the decorated function takes no arguments.

    Args:
        spec_fn: function whose keyword-only parameters name the query
            variables, optionally preceded by one positional input parameter.

    Returns:
        The wrapped function, with ``spec_fn``'s metadata copied by ``wraps``.
    """
    params = inspect.signature(spec_fn).parameters.values()

    # One Var per keyword-only parameter. These are created once, at
    # decoration time, and shared by every call.
    query_vars = {p.name: Var(p.name) for p in params if p.kind == p.KEYWORD_ONLY}

    # Any positional slot means the spec expects an input value.
    takes_input = any(
        p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD, p.VAR_POSITIONAL)
        for p in params
    )

    if takes_input:

        @wraps(spec_fn)
        def wrapper(input):
            return Query << spec_fn(input, **query_vars)

        return wrapper

    @wraps(spec_fn)
    def wrapper_no_input():
        return Query << spec_fn(**query_vars)

    return wrapper_no_input

In [4]:
@dataclass
class RuleThunk:
    rule_name: str
    vars: list[str] = field(default_factory=list)

    def __call__(self):
        return RuleHead(self.rule_name, self.vars)

    def __le__(self, input_list: InputList | list[list[Any]]):
        match input_list:
            case list() if all(type(item) == list for item in input_list):
                return ConstantRule(self(), Query << input_list)
            case InputList():
                return ConstantRule(self(), input_list)


class _Q:
    """Subscriptable factory for heads of the ``?`` rule.

    ``Q["a", "b"]``, ``Q[a, b]`` (with ``Var`` objects), ``Q[[a, b]]`` and
    ``Q[a]`` all produce a ``RuleThunk`` named "?" whose variables are the
    given names.
    """

    def __getitem__(self, items: str | Var | list[str | Var] | tuple[str | Var]):
        """Return a ``RuleThunk`` for ``?`` with the subscripted variable names.

        Strings are used as-is; any other item contributes its ``.name``.
        """
        # Normalise the subscript to a list: a tuple comes from `Q[a, b]`,
        # a list is taken as given, and anything else is a single item.
        if isinstance(items, tuple):
            head_items = list(items)
        elif isinstance(items, list):
            head_items = items
        else:
            head_items = [items]

        names: list[str] = [
            item if isinstance(item, str) else item.name for item in head_items
        ]
        return RuleThunk(rule_name="?", vars=names)


Q = _Q()

In [7]:
@query
def example(*, artist, last_name):
    """Single-rule query binding two constant (artist, last_name) rows.

    Equivalent CozoScript:
        ?[artist, last_name] <- [["lady", "gaga"], ["michael", "jackson"]]
    """
    rows = [["lady", "gaga"], ["michael", "jackson"]]
    head = Q[artist, last_name]
    return (head <= rows,)

In [6]:
# Render the query built by `example` as text; the resulting CozoScript is the cell output below.
str(example())

'?[artist, last_name] <-\n    [["lady", "gaga"], ["michael", "jackson"]]'