# microKanren
- Lightly edited version of https://github.com/jtauber/pykanren/blob/master/microkanren.py
- Major modifications to jtauber's implementation:
    - `occurs_check` is used in `extend_substitution` now to avoid illegal states
    - `take_all` actually works with a stream now
    - `take_all_generator` returns a generator instead of consing a list together
    - `neq` is a disequality goal constructor

**Data**

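1. Substitution = {Var: Term}. This is conceptually the association list [(Var, Term)] of the original microKanren; here it is implemented as a Python dict.
2. Goal is a function (State -> Stream State).
3. Stream is a lazily generated list of states. It takes one of three forms:
    - the empty stream `()`;
    - a cons pair `(State, Stream)`;
    - a zero-argument function (an "immature" stream) that produces the rest of the stream when called.
4. State
    - Type Signature: State = (Substitution, Counter)
    - Description: Represents the current state of the logic program. It consists of a substitution and a counter for generating fresh variables.
    - N.B.: EMPTY_STATE = ({}, 0)
5. Variables (Var)
    - Type Signature: Var = Var(Int)
    - Description: Represents logic variables by integer indices. `Var` is a class wrapping an integer; two variables are equal exactly when their indices are equal.
6. Terms
    - Type Signature: Term = Var | Value | (Term, Term)
    - Description: A term can be a variable, a value, or a pair of terms (a cons cell, represented as a 2-tuple).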

**Main Functions**

1. walk
    - Type Signature: walk :: Term -> Substitution -> Term
    - Description: If the term is a variable, follows its bindings in the substitution until it reaches a non-variable term or an unbound variable. Any other term is returned unchanged.
2. occurs_check
    - Type Signature: occurs_check :: Var -> Term -> Substitution -> Bool
    - Description: Checks whether the variable occurs inside the term (after walking). This prevents circular references in substitutions, which would otherwise cause infinite loops.
3. extend_substitution
    - Type Signature: extend_substitution :: Var -> Term -> Substitution -> Maybe Substitution
    - Description: Extends the substitution with a new variable-term binding if no circular reference is detected. Otherwise it returns `False`, which plays the role of Nothing.
4. unify
    - Type Signature: unify :: Term -> Term -> Substitution -> Maybe Substitution
    - Description: Attempts to unify two terms under a given substitution. Returns the (possibly extended) substitution, or `False` on failure.
5. mzero
    - Type Signature: mzero :: Stream State
    - Description: Represents an empty stream of states. It's equivalent to an empty list in Haskell and acts as the identity element for the mplus operation.
6. unit
    - Type Signature: unit :: State -> Stream State
    - Description: A function that takes a single state and returns a stream containing only that state. It's akin to the return function in monads, wrapping a value into the monadic context. Lifts the state into the stream monad.
7. eq
    - Type Signature: eq :: Term -> Term -> Goal
    - Description: A goal constructor that takes two terms and returns a goal that attempts to unify u and v.
        - On success it applies unit to the new state (the new substitution with the old counter).
        - If the unification fails, it returns mzero.
8. neq
    - Type Signature: neq :: Term -> Term -> Goal
    - Description: A goal constructor that takes two terms and returns a goal that attempts to unify u and v.
        - If the unification fails, it applies unit to the original state: the old substitution with the old counter, since the point is that the terms don't unify.
        - If the unification succeeds, it returns mzero.
        - This is a check made at the moment the goal runs, not a stored constraint. If either term is still an unbound variable, unification succeeds and the goal fails.
9. call_fresh
    - Type Signature: call_fresh :: (Var -> Goal) -> Goal
    - Description: Introduces a fresh logic variable and constructs a goal with it.
10. disj
    - Type Signature: disj :: Goal -> Goal -> Goal
    - Description: Represents logical disjunction (OR) of two goals.
11. conj
    - Type Signature: conj :: Goal -> Goal -> Goal
    - Description: Represents logical conjunction (AND) of two goals.
12. mplus
    - Type Signature: mplus :: Stream State -> Stream State -> Stream State
    - Description: Merges two streams of states and is used for implementing disjunction. When the first stream is immature, the two streams are swapped, which interleaves their results.
13. bind
    - Type Signature: bind :: Stream State -> (State -> Stream State) -> Stream State
    - Description: Binds a stream of states to a goal, used for implementing conjunction.
14. run
    - Type Signature: run :: Goal -> Maybe Int -> [State]
    - Description: Combines run* and run from the original microKanren. Runs the goal from EMPTY_STATE and returns its answer states:
        - all of them when the count is Nothing (None);
        - the first n when it is Just n.
    - Raises RuntimeError if fewer than n answers exist.

**Connections to Monads**

The bind and unit (return) functions form a monad for the stream of states, which is key to structuring the execution of logic programs.

```
class Monad m where
    (>>=)  :: m a -> (a -> m b) -> m b
    return :: a -> m a

instance Monad Stream where
    -- in microKanren the element type is always State
    (>>=)  :: Stream a -> (a -> Stream b) -> Stream b
    return :: a -> Stream a
```

Note that `return` is called `unit` in the microKanren code. The name `unit` comes from category theory (the unit of a monad), whereas Haskell calls it `return`.

```
class Monad m => MonadPlus m where
    mzero :: m a
    mplus :: m a -> m a -> m a

instance MonadPlus Stream where
    mzero :: Stream a
    mplus :: Stream a -> Stream a -> Stream a
```

In [179]:
# LISP-like cons structures: a cons cell is simply a Python 2-tuple.

def cons(a, b):
    """Build the cons cell (a . b) as the pair (a, b)."""
    pair = a, b
    return pair

def is_cons(c):
    """Return True if c is a cons cell, i.e. a tuple of exactly two items."""
    if not isinstance(c, tuple):
        return False
    return len(c) == 2

def car(c):
    """Return the head (first component) of cons cell c."""
    head = c[0]
    return head

def cdr(c):
    """Return the tail (second component) of cons cell c."""
    tail = c[1]
    return tail

# helper function for creating nested cons lists out of Python arguments

def l(*lst):
    """
    Build a proper cons list from the arguments: l(1, 2) == (1, (2, ())).

    The list is terminated by the empty tuple (); calling l() with no
    arguments returns () itself. The list is built back-to-front in a loop,
    so long argument lists need neither deep recursion nor repeated slicing.
    """
    result = ()
    for item in reversed(lst):
        result = (item, result)  # same cell shape as cons(item, result)
    return result

class Var:
    """
    A logic variable identified by an integer index.

    Scheme correspondence:
    (var c) becomes Var(c)
    (var? x) becomes isinstance(x, Var)
    (var=? x_1 x_2) becomes x_1 == x_2
    """
    def __init__(self, index):
        self.index = index

    def __eq__(self, other):
        # Only another Var can be equal; two Vars are equal iff their indices are.
        if not isinstance(other, Var):
            return False
        return self.index == other.index

    def __hash__(self):
        # Consistent with __eq__, so Vars can be used as substitution dict keys.
        return hash(self.index)

    def __repr__(self):
        return "<Var %s>" % self.index

def is_var(v):
    """Return True when v is a logic variable, i.e. an instance of Var."""
    answer = isinstance(v, Var)
    return answer

def walk(u, s):
    """
    Resolve term u against substitution s (a dict mapping Var -> term).

    If u is a variable bound in s, follow the chain of bindings. Stop on
    either a non-variable term or an unbound variable, and return that. Any
    other term is returned unchanged. walk is shallow: variables nested inside
    a cons pair are not resolved here.

    Membership is tested with `in` rather than the truthiness of s.get(u).
    Otherwise a variable bound to a falsy value (0, False, None, (), [], {},
    "") would be reported as unbound. It could then be silently re-bound to
    something else, e.g. unify(Var(0), 1, {Var(0): 0}) would succeed.
    """
    while is_var(u) and u in s:
        u = s[u]
    return u

def occurs_check(var, value, s):
    """
    Return True if variable `var` appears anywhere inside term `value`.

    Variables are resolved through substitution `s` at every level, so a
    binding chain that leads back to `var` is also detected.
    """
    term = walk(value, s)
    if var == term:
        return True
    if not is_cons(term):
        # Atoms and other non-pair values cannot contain var.
        return False
    # Search the head first, then the tail.
    return occurs_check(var, car(term), s) or occurs_check(var, cdr(term), s)

def extend_substitution(x, v, s):
    """
    Return a copy of substitution s with variable x bound to term v.

    Returns False instead if x occurs in v under s. Such a binding would
    describe an infinite, cyclic term. The input s is never modified.
    """
    if not occurs_check(x, v, s):
        return {**s, x: v}
    return False

def unify(u, v, s):
    """
    Unify terms u and v under substitution s.

    Returns the (possibly extended) substitution on success, or False on
    failure. New bindings are added through extend_substitution, which
    returns a new dict rather than mutating s.

    Cons pairs unify componentwise. The cars are unified first; only if that
    succeeds are the cdrs unified under the resulting substitution. The pair
    unifies only if both halves do. Any other non-variable terms unify only
    if they are equal (==).
    """
    u = walk(u, s)
    v = walk(v, s)
    if is_var(u) and is_var(v) and u == v:
        return s
    if is_var(u):
        return extend_substitution(u, v, s)
    if is_var(v):
        return extend_substitution(v, u, s)
    if is_cons(u) and is_cons(v):
        s = unify(car(u), car(v), s)
        if s is False:
            # Heads clash: fail now instead of walking the tails with s=False.
            return False
        # A failure in the tails must fail the whole pair.
        return unify(cdr(u), cdr(v), s)
    return s if u == v else False

# monad: mzero is the empty stream, the identity element for mplus
mzero = tuple()
def unit(state):
    """
    Lift a single state into a stream: the one-element cons list (state . mzero).

    State = (Substitution, Counter)
    where Substitution is a dict of { var: term } and Counter is the index of
    the next fresh variable.
    """
    singleton = cons(state, mzero)
    return singleton

def eq(u, v):
    """
    Goal constructor: succeed when u and v unify.

    The returned goal unifies u and v under the state's substitution. On
    success it yields a single state made of the extended substitution and
    the unchanged counter. On failure it yields the empty stream.
    """
    def goal(state):
        substitution, counter = car(state), cdr(state)
        unified = unify(u, v, substitution)
        if unified is False:
            return mzero
        return unit(cons(unified, counter))

    return goal

def neq(u, v):
    """
    Goal constructor: succeed, with the state unchanged, when u and v cannot
    be unified under the current substitution; otherwise fail.

    NOTE: this is a one-shot check, not a stored constraint. If u or v is
    still an unbound variable when the goal runs, unification succeeds, so
    the goal fails even though later bindings could have made them differ.
    Order neq after the goals that bind its arguments.
    """
    def goal(state):
        if unify(u, v, car(state)) is not False:
            # u and v are (or can be made) equal: the disequality is violated.
            return mzero
        # Unification failed, so u and v are distinct: keep the original state.
        return unit(state)

    return goal

def call_fresh(f):
    """
    Goal constructor: allocate a fresh variable and run the goal f builds from it.

    The state's counter is the index of the next unused variable. The fresh
    variable is Var(counter), and the goal returned by f runs on a state
    whose counter is incremented by one.
    """
    def goal(state):
        substitution, counter = car(state), cdr(state)
        fresh = Var(counter)
        return f(fresh)(cons(substitution, counter + 1))

    return goal

# Initial state: empty substitution, and the next fresh variable index is 0.
EMPTY_STATE = ({}, 0)

def mplus(stream1, stream2):
    """
    Merge two streams (used for disjunction).

    The mature prefix of stream1 is emitted first. When stream1 is immature
    (a zero-argument thunk), the result is a new thunk. Once forced, that
    thunk continues with the two streams swapped, which interleaves their
    results.
    """
    if callable(stream1):
        return lambda: mplus(stream2, stream1())
    if stream1 == mzero:
        return stream2
    return cons(car(stream1), mplus(cdr(stream1), stream2))

def bind(stream, goal):
    """
    Run goal on every state of stream and merge the resulting streams (used
    for conjunction).

    An immature stream produces an immature result that continues the bind
    once forced.
    """
    if stream == mzero:
        return mzero
    if callable(stream):
        return lambda: bind(stream(), goal)
    head, tail = car(stream), cdr(stream)
    return mplus(goal(head), bind(tail, goal))

def disj(goal_1, goal_2):
    """Goal constructor for logical OR: the merged (mplus) streams of both goals."""
    return lambda state: mplus(goal_1(state), goal_2(state))

def conj(goal_1, goal_2):
    """Goal constructor for logical AND: run goal_1, then feed each of its states to goal_2."""
    return lambda state: bind(goal_1(state), goal_2)

if __name__ == "__main__":
    # Self-tests and example programs; they run only when executed as a script/notebook.

    # walk follows variable -> value chains through a substitution dict.
    s1 = {Var(0): 5, Var(1): True}
    s2 = {Var(1): 5, Var(0): Var(1)}

    assert walk(Var(0), s1) == 5
    assert walk(Var(0), s2) == 5

    # unify returns the (possibly extended) substitution on success and False on failure.
    # Python lists are not cons cells (is_cons requires a 2-tuple), so they only match by ==.
    assert unify(None, 1, {}) is False
    assert unify(None, Var(0), {}) == {Var(0): None}
    assert unify(None, [1, Var(0)], {}) is False
    assert unify(1, None, {}) is False
    assert unify(1, 1, {}) == {}
    assert unify(1, 2, {}) is False
    assert unify(1, Var(0), {}) == {Var(0): 1}
    assert unify(1, [], {}) is False
    assert unify(Var(0), 1, {}) == {Var(0): 1}
    assert unify(Var(0), Var(1), {}) == {Var(0): Var(1)}
    assert unify(Var(0), [], {}) == {Var(0): []}
    assert unify(Var(0), l(1, 2, 3), {}) == {Var(0): l(1, 2, 3)}
    assert unify(l(1, 2, 3), l(1, 2, 3), {}) == {}
    assert unify(l(1, 2, 3), l(3, 2, 1), {}) is False
    assert unify(l(Var(0), Var(1)), l(1, 2), {}) == {Var(0): 1, Var(1): 2}
    assert unify(l(l(1, 2), l(3, 4)), l(l(1, 2), l(3, 4)), {}) == {}
    assert unify(l(l(Var(0), 2), l(3, 4)), l(l(1, 2), l(3, 4)), {}) == {Var(0): 1}

    # Improper list: the final tail (3, 4) is itself a pair, so Var(0) binds to it whole.
    assert unify((1, (2, (3, 4))), (1, (2, Var(0))), {}) == {Var(0): (3, 4)}

    # Non-variable, non-pair terms (here dicts) unify when they are ==.
    assert unify({}, {}, {}) == {}

    # eq yields a one-state stream on success and mzero (()) on failure.
    assert eq(1, 1)(EMPTY_STATE) == (EMPTY_STATE, ())
    assert eq(1, 2)(EMPTY_STATE) == ()

    # stream helpers: pull (force), take (prefix), take_all (everything)

    def pull(stream):
        """
        Force an immature stream until it is mature.

        A stream is mzero, a cons of (state, stream), or a zero-argument
        callable (thunk) producing the rest of the stream. Thunks are
        unwrapped in a loop rather than by recursion. A long run of
        consecutive immature steps (e.g. a search that fails many times
        before finding an answer) therefore cannot hit Python's recursion
        limit.
        """
        while callable(stream):
            stream = stream()
        return stream
    
    def take(n, stream):
        """
        Return the first n elements of stream as a cons list (fewer if it ends).

        The tail beyond the n-th element is never forced. The loop replaces
        recursion so that large n cannot exhaust the recursion limit; results
        are unchanged. Like the recursive form, a negative n never counts
        down to zero and so takes everything.
        """
        taken = []
        while n != 0:
            stream = pull(stream)
            if stream == mzero:
                break
            taken.append(car(stream))
            stream = cdr(stream)
            n -= 1
        # Rebuild the cons list back-to-front so it ends in mzero.
        result = mzero
        for item in reversed(taken):
            result = cons(item, result)
        return result

    # take on a plain (already mature) cons list: n bounds how many elements come back.
    assert take(0, l(1, 2, 3)) == ()
    assert take(1, l(1, 2, 3)) == (1, ())
    assert take(2, l(1, 2, 3)) == (1, (2, ()))
    assert take(3, l(1, 2, 3)) == (1, (2, (3, ())))

    def take_all(stream):
        """
        Return every element of a finite stream as a cons list.

        Does not terminate on an infinite stream. The loop replaces recursion
        so that streams with many answers cannot exhaust the recursion limit.
        """
        taken = []
        stream = pull(stream)
        while stream != mzero:
            taken.append(car(stream))
            stream = pull(cdr(stream))
        # Rebuild the cons list back-to-front so it ends in mzero.
        result = mzero
        for item in reversed(taken):
            result = cons(item, result)
        return result

    def take_all_generator(stream):
        """
        Lazily yield every element of stream, forcing thunks only on demand.

        Safe on infinite streams as long as the consumer stops iterating. The
        loop is deliberate. A recursive `yield from` formulation nests one
        generator per answer, so resuming it costs time proportional to the
        number of answers already produced, and long streams eventually hit
        the recursion limit.
        """
        stream = pull(stream)  # mzero marks the end of the stream
        while stream != mzero:
            yield car(stream)
            stream = pull(cdr(stream))
            
    # microKanren test programs

    # Two fresh variables a = Var(0), b = Var(1): a must be 7, and b is 5 or 6.
    a_and_b = call_fresh(
        lambda a: call_fresh(
            lambda b: conj(eq(a, 7), disj(eq(b, 5), eq(b, 6)))
        )
    )

    def fives(x):
        """Infinite goal: x == 5, then (lazily) x == 5 again, forever."""
        def delayed_fives(state):
            # Return a thunk so the recursive goal is only expanded when pulled.
            return lambda: fives(x)(state)

        return disj(eq(x, 5), delayed_fives)
    
    def appendo(l, s, out):
        """
        Relation: out is the cons list l followed by the cons list s.

        Every call introduces three fresh variables, in this order:
        a (head of l), d (tail of l), and res (tail of out).
        """
        def clauses(a, d, res):
            def recurse(state):
                # Suspend the recursive call behind a thunk (an immature stream)
                # so it is only expanded when the stream is pulled.
                return lambda: appendo(d, s, res)(state)

            # Base case: l is empty, so out must equal s.
            empty_l = conj(eq((), l), eq(s, out))
            # Recursive case: l = (a . d) and out = (a . res), where res is d appended to s.
            nonempty_l = conj(eq(cons(a, d), l),
                              conj(eq(cons(a, res), out), recurse))
            return disj(empty_l, nonempty_l)

        return call_fresh(lambda a:
                          call_fresh(lambda d:
                                     call_fresh(lambda res: clauses(a, d, res))))

    # run* and run-n
    def run(query, n=None):
        """
        Run the goal `query` from EMPTY_STATE and return its answer states.

        Args:
        query: a goal (State -> Stream State).
        n: number of answers wanted. None means all of them (run*), which
           only terminates if the stream is finite.

        Returns:
        A list of states (substitution, counter), in stream order.

        Raises:
        RuntimeError: if n is given and the stream has fewer than n answers.
        """
        answers = []
        if n is not None and n <= 0:
            return answers  # nothing requested; don't force the stream at all
        for state in take_all_generator(query(EMPTY_STATE)):
            answers.append(state)
            # Stop as soon as n answers are in hand. Asking the generator for
            # one more could diverge on a stream that is infinite but has no
            # further answers.
            if n is not None and len(answers) >= n:
                break
        if n is not None and len(answers) < n:
            raise RuntimeError("Requested more solutions than are available.")
        return answers
    
    def reify(s, v):
        """
        Substitute bound variables throughout term v, recursively.

        Args:
        s (dict): The substitution dictionary mapping variables to terms.
        v: The term to reify: a variable, a literal, or a cons pair (2-tuple).

        Returns:
        The term with every bound variable replaced by its (reified) value.
        Unbound variables are left in place.
        """
        # walk leaves non-variables untouched and follows variable bindings.
        term = walk(v, s)
        if is_cons(term):
            # Rebuild the pair from its reified head and tail.
            return (reify(s, car(term)), reify(s, cdr(term)))
        # Literals (including the empty tuple ()) and unbound variables.
        return term

    # microKanren tests
    # Each state is (substitution, counter); the counter equals the number of
    # fresh variables allocated so far.

    # second-set t1
    assert car(call_fresh(lambda q: eq(q, 5))(EMPTY_STATE)) == ({Var(0): 5}, 1)

    # second-set t2
    assert cdr(call_fresh(lambda q: eq(q, 5))(EMPTY_STATE)) == mzero

    # second-set t3
    assert car(a_and_b(EMPTY_STATE)) == ({Var(0): 7, Var(1): 5}, 2)

    # run with n=1 returns a one-element list of states
    assert run(a_and_b, n=1)[0] == ({Var(0): 7, Var(1): 5}, 2)

    # second-set t3, take
    assert take(1, a_and_b(EMPTY_STATE)) == (({Var(0): 7, Var(1): 5}, 2), mzero)

    # second-set t4
    assert car(cdr(a_and_b(EMPTY_STATE))) == ({Var(0): 7, Var(1): 6}, 2)

    # second-set t5
    assert cdr(cdr(a_and_b(EMPTY_STATE))) == mzero

    # who cares (take 1 from the infinite fives stream)
    assert take(1, call_fresh(lambda q: fives(q))(EMPTY_STATE)) == (({Var(0): 5}, 1), mzero)

    # take 2 a-and-b stream
    assert take(2, a_and_b(EMPTY_STATE)) == (({Var(0): 7, Var(1): 5}, 2), (({Var(0): 7, Var(1): 6}, 2), mzero))

    # take-all a-and-b stream
    assert take_all(a_and_b(EMPTY_STATE)) == (({Var(0): 7, Var(1): 5}, 2), (({Var(0): 7, Var(1): 6}, 2), mzero))

    # infinite stream: check the first answer, then break before forcing the rest
    for solution in take_all_generator(call_fresh(fives)(EMPTY_STATE)):
        assert solution == ({Var(0): 5}, 1)
        break

    # disequality
    def goal_neq_ab(state):
        """Goal over fresh a, b: a == 6, b == 5, then neq(a, b) once both are bound."""
        def not_equal_goal(a, b):
            a_is_six = eq(a, 6)
            b_is_five = eq(b, 5)
            # neq comes last so it sees a and b already bound.
            return conj(a_is_six, conj(b_is_five, neq(a, b)))

        fresh_pair = call_fresh(lambda a: call_fresh(lambda b: not_equal_goal(a, b)))
        return fresh_pair(state)

    # a = 6 and b = 5 are bound before neq runs, so neq succeeds; there is exactly one answer.
    assert run(goal_neq_ab,
            n=1
            #   n=2 # raises RuntimeError: only one answer exists
            ) == [({Var(0): 6, Var(1): 5}, 2)]

In [227]:
from pprint import pprint

list1 = l(1, 2)
list2 = l(3, 4)


def res_query(list1, list2, result):
    """
    Goal: `result` is the cons list `list1` followed by `list2`.

    A thin, named wrapper around appendo for this example.
    """
    goal = appendo(list1, list2, result)
    return goal


# Only the first answer is requested (n=1). Var(0) is the `res` variable
# created by this call_fresh.
result = run(
    call_fresh(lambda res: res_query(list1, list2, res)),
    n=1,
)

for res in result:
    subst, _ = res
    pprint(subst)
    # The binding of Var(0) still contains variables (see the printed
    # substitution); reify resolves them into a ground cons list.
    print(reify(subst, subst[Var(0)]))

{<Var 2>: (2, ()),
 <Var 4>: 2,
 <Var 1>: 1,
 <Var 0>: (<Var 1>, <Var 3>),
 <Var 5>: (),
 <Var 6>: (3, (4, ())),
 <Var 3>: (<Var 4>, <Var 6>)}
(1, (2, (3, (4, ()))))
