## Copyright (c) 2024 James Litsios

In [1]:
from dataclasses import dataclass, field, InitVar
from typing import Tuple

## State V1: mutable

In [2]:
@dataclass(frozen=False)
class State:
    kv: dict
    counter: int

state = State(kv={'a':1, 'b':2}, counter=100) 
state

State(kv={'a': 1, 'b': 2}, counter=100)

## Code V1: mutable, no signature constraints

In [3]:
def get(state: State, k: str) -> str:
    return state.kv[k]

def set_(state: State, k: str, v: int) -> None:
    state.kv[k] = v

def increment(state: State) -> int:
    state.counter += 1
    return state.counter

def test1(state: State) -> Tuple[int, State]:
    av = get(state, 'a')
    bv = get(state, 'b')
    set_(state, 'c', av+bv)
    incremented = increment(state)
    return incremented, state

test1(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## State V2: immutable

In [4]:
@dataclass(frozen=True)
class State:
    kv: dict
    counter: int

state = State(kv={'a':1, 'b':2}, counter=100) 
state

State(kv={'a': 1, 'b': 2}, counter=100)

## Code V2: immutable state is passed around

In [5]:
def get(state: State, k: str) -> str:
    return state.kv[k]

def set_(state: State, k: str, v: int) -> State:
    return State(dict(state.kv | {k:v}), state.counter)

def increment(state: State) -> Tuple[int, State]:
    new_counter = state.counter + 1
    return new_counter, State(state.kv, new_counter)

def test(state: State) -> Tuple[int, State]:
    av = get(state, 'a')
    bv = get(state, 'b')
    state = set_(state, 'c', av+bv)
    incremented, state = increment(state)
    return incremented, state

test(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Code V3: immutable state is always passed around (even when not changed)

In [6]:
from typing import Tuple, Any

In [7]:
def get(state: State, k: str) -> Tuple[str, State]:
    return state.kv[k], state

def set_(state: State, k: str, v: int) -> Tuple[None, State]:
    return None, State(dict(state.kv | {k:v}), state.counter)

def increment(state: State) -> Tuple[int, State]:
    new_counter = state.counter + 1
    return new_counter, State(state.kv, new_counter)

def test(state: State) -> Tuple[int, State]:
    av, state = get(state, 'a')
    bv, state = get(state, 'b')
    _, state = set_(state, 'c', av+bv)
    incremented, state = increment(state)
    return incremented, state

test(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Code V4: State "handling" is normalized and explicitly given its own function that can be called later, always in the same manner

In [8]:
from collections.abc import Callable

def get(k: str) -> Callable[[State], Tuple[str, State]]:
    return lambda state: (state.kv[k], state)

def set_(k: str, v: int) -> Callable[[State], Tuple[None, State]]:
    return lambda state: (None, State(dict(state.kv | {k:v}), state.counter))

def increment() -> Callable[[State], Tuple[int, State]]:
    return (lambda state:
        (new_counter := state.counter + 1, State(state.kv, new_counter)))

def test(state: State):
    av, state = get('a')(state)
    bv, state = get('b')(state)
    _, state = set_('c', av+bv)(state)
    incremented, state = increment()(state)
    return incremented, state

test(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Code V5: Make it monadic by "binding" statements into one big expression with the help of a bind function

In [9]:
def bind(m, fm):
    return (lambda state:
        fm((r_state := m(state))[0])(r_state[1]))
    
def test():
    return bind(get('a'), lambda av:
        bind(get('b'), lambda bv:
            bind(set_('c', av+bv), lambda _:
                increment())))

test()(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Case V6: Use the ">>" operator to bind, and explicitly define objects to capture the monadic functions

In [10]:
class SM:
    def __rshift__(self, fm):
        return (lambda state:
            fm((r_state := self.__call__(state))[0])(r_state[1]))
        

class Get(SM):
    def __init__(self, k:str):
        self.k = k
        
    def __call__(self, state) -> Tuple[str, State]:
        return (state.kv[self.k], state)

class Set(SM):
    def __init__(self, k: str, v: int):
        self.k = k
        self.v = v
        
    def __call__(self, state) -> Tuple[None, State]:
        return (None, State(dict(state.kv | {self.k:self.v}), state.counter))

class Increment(SM):
    def __init__(self): 
        pass
        
    def __call__(self, state) -> Tuple[int, State]:
        return (new_counter := state.counter + 1, State(state.kv, new_counter))

def test():
    return Get('a') >> (lambda av:
        Get('b') >> (lambda bv:
            Set('c', av+bv) >> (lambda _:
                Increment())))

test()(state) 

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Code V7: Change bind logic to make a maybe monad

In [11]:
class MSM:
    def __rshift__(self, fm):
        return (lambda state:
            (r_state2 if (r_state2 := fm(r_state1[0])(r_state1[1]))[0] is not None else (None, state)
               ) if (r_state1 := self.__call__(state))[0] is not None else (None, state))
        

class Get(MSM):
    def __init__(self, k:str):
        self.k = k
        
    def __call__(self, state) -> Tuple[str, State]:
        return (state.kv.get(self.k, None), state)

class Set(MSM):
    def __init__(self, kv: Tuple[str, int]):
        self.kv = kv
        
    def __call__(self, state) -> Tuple[None, State]:
        return (self.kv[1], State(dict(state.kv | {self.kv[0]:self.kv[1]}), state.counter))

class Increment(MSM):
    def __init__(self): 
        pass
        
    def __call__(self, state) -> Tuple[int, State]:
        return (new_counter := state.counter + 1, State(state.kv, new_counter))

def test_a():
    return  Get('a') >> (lambda av:
            Get('b') >> (lambda bv:
            Set(('c', av+bv)) >> (lambda _:
            Increment())))

def test_b():
    return  Get('a') >> (lambda av:
            Get('b') >> (lambda bv:
            Set(('c', av+bv)) >> (lambda _:
            Increment() >> (lambda _:
            Get('x')))))

test_a()(state), test_b()(state)

((101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101)),
 (None, State(kv={'a': 1, 'b': 2}, counter=100)))

## Code V8: "Wrap" expressions and lambda as data to make a free monad

In [12]:
from abc import ABC, abstractmethod

class FSM:
    def __rshift__(self, fm):
        return Bind(self, fm)

    @abstractmethod
    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        pass
        
class Expr:
    @abstractmethod
    def eval(self, var_list=None) -> Any:
        pass

@dataclass(init=True)
class Bind(FSM):
    m: FSM
    fm: 'Lambda'

    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        return self.fm.body.runfs((r_state := self.m.runfs(state, var_list))[1], ((self.fm.var, (r_state[0])),var_list))

@dataclass(init=True)
class Lambda(Expr):
    var: 'Var'
    body: Expr

    def eval(self, var_list=None) -> Any:
        # must be applied to change
        return self

@dataclass(init=True)
class Add(Expr):
    lhs: Expr
    rhs: Expr

    def eval(self, var_list=None) -> Any:
        return self.lhs.eval(var_list)+self.rhs.eval(var_list)

@dataclass(init=False, eq=True)
class Var(Expr):
    name: str

    def __init__(self, name=None):
        self.name = name

    def eval(self, var_list=None) -> Any:
        while var_list is not None:
            if var_list[0][0] == self:
                return var_list[0][1]
            else:
                var_list = var_list[1]
        assert(False)
    
    def __call__(self, body):
        return Lambda(self, body)

    def __add__(self, other):
        return Add(self, other)        

@dataclass(init=True)
class Get(FSM):
    k: str
            
    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        return (state.kv.get(self.k, None), state)

@dataclass(init=True)
class Set(FSM):
    k: str
    v: int
        
    def runfs(self, state, var_list=None) -> Tuple[None, State]:
        return (v_eval:=self.v.eval(var_list), State(dict(state.kv | {self.k:v_eval}), state.counter))

class Increment(FSM):
    def __init__(self): 
        pass
        
    def runfs(self, state, var_list=None) -> Tuple[int, State]:
        return (new_counter := state.counter + 1, State(state.kv, new_counter))

def test():
    return  Get('a') >> Var("av")(
            Get('b') >> Var("bv")(
            Set('c', Var("av")+Var("bv")) >> Var()(
            Increment())))

test().runfs(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Code V9: Use the "**" operator to bind, use \__getattr__ to capture the managed lambda variables

In [13]:
from abc import ABC, abstractmethod

@dataclass(init=True)
class NamedFSM:
    m: 'FSM'
    var_id: str
    
    def __pow__(self, lm):
        return Bind(self.m, Lambda(Var(self.var_id), lm))


class FSM:

    def __getattr__(self, id):
        return NamedFSM(self, id)

    def __pow__(self, lm):
        return Bind(self, Lambda(None, lm))
        
    @abstractmethod
    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        pass
        
class Expr:
    @abstractmethod
    def eval(self, var_list=None) -> Any:
        pass

@dataclass(init=True)
class Bind(FSM):
    m: FSM
    fm: 'Lambda'

    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        return self.fm.body.runfs((r_state := self.m.runfs(state, var_list))[1], ((self.fm.var, (r_state[0])),var_list) if self.fm.var else var_list)


@dataclass(init=True)
class Lambda(Expr):
    var: 'Var'
    body: Expr

    def eval(self, var_list=None) -> Any:
        # must be applied to change
        return self

@dataclass(init=True)
class Add(Expr):
    lhs: Expr
    rhs: Expr

    def eval(self, var_list=None) -> Any:
        return self.lhs.eval(var_list)+self.rhs.eval(var_list)

@dataclass(init=False, eq=True)
class Var(Expr):
    name: str

    def __init__(self, name=None):
        self.name = name

    def eval(self, var_list=None) -> Any:
        while var_list is not None:
            if var_list[0][0] == self:
                return var_list[0][1]
            else:
                var_list = var_list[1]
        assert(False)
    
    def __call__(self, body):
        return Lambda(self, body)

    def __add__(self, other):
        return Add(self, other)        

@dataclass(init=True)
class Get(FSM):
    k: str
            
    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        return (state.kv.get(self.k, None), state)

@dataclass(init=True)
class Set(FSM):
    k: str
    v: int
        
    def runfs(self, state, var_list=None) -> Tuple[None, State]:
        return (v_eval:=self.v.eval(var_list), State(dict(state.kv | {self.k:v_eval}), state.counter))

class Increment(FSM):
    def __init__(self): 
        pass
        
    def runfs(self, state, var_list=None) -> Tuple[int, State]:
        return (new_counter := state.counter + 1, State(state.kv, new_counter))

def test():
    return (Get('a').av ** 
            Get('b').bv ** 
            Set('c', Var('av')+Var('bv')) ** 
            Increment())

    
test().runfs(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Code V10: Clean separation between monadic and non-monadic expressions, improved bringing monadic terms into non-monadic expressions

In [14]:
from abc import ABC, abstractmethod

@dataclass(init=True)
class NamedFSM:
    m: 'FSM'
    ids: Tuple[str, ...] 
    
    def __getattr__(self, id_):
        return NamedFSM(self.m, self.ids+(id_,))
        
    def __pow__(self, lm):
        return Bind(self.m, Lambda(self.ids, lm))



def set_(ids, args, var_list):
    if isinstance(args, tuple):
        if ids == () and args == ():
            return var_list
        else:
            set_(ids[1:], args[1:], ((ids[0], args[0]), var_list))
    else:
        assert(len(ids) == 1)
        return ((ids[0], args), var_list)

def lookup(id_, var_list):
    while var_list is not None:
        if id_ == var_list[0][0]:
            return var_list[0][1]
        var_list = var_list[1]
    assert(False)

def get_(var_list, ids):
    if isinstance(ids, tuple):
        if ids != ():
            return (lookup(ids[0], var_list),)+get_(var_list, ids[1:])
        else:
            return ()
    else:
        return lookup(ids, var_list)

class FSM:

    def __getattr__(self, id_):
        return NamedFSM(self, (id_,))

    def __pow__(self, lm):
        return Bind(self, Lambda((), lm))
        
    @abstractmethod
    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        pass


class ExprMeta(type):
    def __getattr__(cls, id_):
        return Expr((id_,))

class Expr(metaclass=ExprMeta):
    def __init__(self, ids):
        self.ids = ids
        
    def __getattr__(self, id_):
        return Expr(self.ids+(id_,))

    def __call__(self, expr_lambda):
        return ExprFSM(self.ids, expr_lambda)
        
@dataclass(init=True)
class ExprFSM(FSM):
    ids: Tuple[str, ...]
    expr_lambda: Any

    
    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        return (self.expr_lambda(*get_(var_list, self.ids)), state)
        
        
@dataclass(init=True)
class Bind(FSM):
    m: FSM
    fm: 'Lambda'

    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        return self.fm.body.runfs((r_state := self.m.runfs(state, var_list))[1], set_(self.fm.ids, r_state[0], var_list))


@dataclass(init=True)
class Lambda:
    ids: Tuple[str, ...] 
    body: Expr


@dataclass(init=True)
class Get(FSM):
    k: str
            
    def runfs(self, state, var_list=None) -> Tuple[Any, State]:
        return (state.kv.get(self.k, None), state)

@dataclass(init=True)
class Set(FSM):
    k: str
    v: int
        
    def runfs(self, state, var_list=None) -> Tuple[None, State]:
        return ((), State(dict((v_runfs:=self.v.runfs(state, var_list))[1].kv | {self.k:v_runfs[0]}), v_runfs[1].counter))

class Increment(FSM):
    def __init__(self, _=None): 
        pass
        
    def runfs(self, state, var_list=None) -> Tuple[int, State]:
        return (new_counter := state.counter + 1, State(state.kv, new_counter))

import operator as op

def test():
    return (Get('a').av ** 
            Get('b').bv ** 
            Set('c', Expr.av.bv(op.add)) ** 
            Increment())

    
test().runfs(state)

(101, State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101))

## Code V11: Define basic monadic functions and "lift" them to work on a specific state

In [15]:
from abc import abstractmethod
from typing import TypeVar, Generic, ClassVar
from collections.abc import Callable

T = TypeVar('T')
S = TypeVar('S')
M = TypeVar('M')
ML = TypeVar('ML')
MR = TypeVar('MR')


@dataclass(frozen=True)
class State:
    kv: dict
    counter: int

        
    @dataclass(frozen=True)
    class _KV:
        counter: int
    
        @classmethod
        def toField(cls, state): 
            return (state.kv, cls(state.counter))
            
        @classmethod
        def fromField(cls, kv, rest):
            return State(kv, rest.counter)
    
        
    
    @dataclass(frozen=True)
    class _Counter:
        kv: dict

        @classmethod
        def toField(cls, state): 
            return (state.counter, cls(state.kv))

        @classmethod
        def fromField(cls, counter, rest):
            return State(counter, rest.kv)
    

state = State(kv={'a':1, 'b':2}, counter=100) 

@dataclass(init=True)
class NamedFSM(Generic[ML]):
    m: ML
    ids: Tuple[str, ...] 
    
    def __getattr__(self, id_) -> 'NamedFSM[MR]':
        return NamedFSM(self.m, self.ids+(id_,))
        
    def __pow__(self, mr: MR) -> 'Bind[ML, Lambda[MR]]':
        return Bind(self.m, Lambda(self.ids, mr))



def set_(ids, args, var_list):
    if isinstance(args, tuple):
        if ids == () and args == ():
            return var_list
        else:
            set_(ids[1:], args[1:], ((ids[0], args[0]), var_list))
    else:
        assert(len(ids) == 1)
        return ((ids[0], args), var_list)

def lookup(id_, var_list):
    while var_list is not None:
        if id_ == var_list[0][0]:
            return var_list[0][1]
        var_list = var_list[1]
    assert(False)

def get_(var_list, ids):
    if isinstance(ids, tuple):
        if ids != ():
            return (lookup(ids[0], var_list),)+get_(var_list, ids[1:])
        else:
            return ()
    else:
        return lookup(ids, var_list)

class FSM:

    def __getattr__(self, id_) -> NamedFSM['FSM']:
        return NamedFSM(self, (id_,))

    def __pow__(self, mr) -> 'Bind[ML, Lambda[MR]]':
        return Bind(self, Lambda((), mr))
        
    @abstractmethod
    def runfs(self, state: S, var_list=None) -> Tuple[Any, S]:
        pass

class ExprMeta(type):
    def __getattr__(cls, id_):
        return Expr((id_,))

class Expr(metaclass=ExprMeta):
    def __init__(self, ids):
        self.ids = ids
        
    def __getattr__(self, id_):
        return Expr(self.ids+(id_,))

    def __call__(self, expr_lambda: Callable[[Any, ...], Any]):
        return ExprFSM(self.ids, expr_lambda)
        
@dataclass(init=True)
class ExprFSM(FSM):
    ids: Tuple[str, ...]
    expr_lambda: Any

    
    def runfs(self, state: S, var_list=None) -> Tuple[Any, S]:
        return (self.expr_lambda(*get_(var_list, self.ids)), state)
        
        
@dataclass(init=True)
class Bind(FSM):
    m: FSM
    fm: 'Lambda'

    def runfs(self, state: S, var_list=None) -> Tuple[Any, S]:
        return self.fm.body.runfs((r_state := self.m.runfs(state, var_list))[1], set_(self.fm.ids, r_state[0], var_list))


@dataclass(init=True)
class Lambda:
    ids: Tuple[str, ...] 
    body: ExprFSM


@dataclass(init=True)
class LiftS(FSM):
    loweredCls: Any
    m: FSM
    
    def runfs(self, state: S, var_list=None) -> Tuple[Any, S]:
        return ((ret_state := self.m.runfs((lower:=self.loweredCls.toField(state))[0], var_list))[0], self.loweredCls.fromField((ret_state)[1], lower[1]))

@dataclass(init=True)
class LiftSC:
    loweredCls: Any
    mc: object

    def __call__(self, *args):
        return liftS(self.loweredCls, self.mc(*args))

def liftS(cls, mc_or_m: object | FSM) -> LiftSC | LiftS:
    if isinstance(mc_or_m, NamedFSM):
        return NamedFSM(liftS(cls, mc_or_m.m), mc_or_m.ids)
    elif isinstance(mc_or_m, FSM):
        return LiftS(cls, mc_or_m)
    else:
        return LiftSC(cls, mc_or_m)

@dataclass(init=True)
class Get(FSM):
    k: str
            
    def runfs(self, state: S, var_list=None) -> Tuple[Any, S]:
        return (state.get(self.k, None), state)

@dataclass(init=True)
class Set(FSM):
    k: str
    v: int
        
    def runfs(self, state: S, var_list=None) -> Tuple[None, S]:
        return ((), dict((v_runfs:=self.v.runfs(state, var_list))[1] | {self.k:v_runfs[0]}))

class Increment(FSM):
    def __init__(self, _=None): 
        pass
        
    def runfs(self, state: S, var_list=None) -> Tuple[int, S]:
        return (new_counter := state + 1, new_counter)

import operator as op

def test_a():
    return (liftS(State._KV, Get)('a').av ** 
            liftS(State._KV, Get)('b').bv ** 
            liftS(State._KV, Set)('c', Expr.av.bv(op.add)) ** 
            liftS(State._Counter, Increment)())

def test_b():
    return (liftS(State._KV, Get('a').av) ** 
            liftS(State._KV, Get('b').bv) ** 
            liftS(State._KV, Set('c', Expr.av.bv(op.add))) ** 
            liftS(State._Counter, Increment()))

def test_c():
    return (liftS(State._KV, Get('a')).av ** 
            liftS(State._KV, Get('b')).bv ** 
            liftS(State._KV, Set('c', Expr.av.bv(op.add))) ** 
            liftS(State._Counter, Increment()))


def test_d():
    return (liftS(State._KV,  Get('a').av ** 
                              Get('b').bv ** 
                              Set('c', Expr.av.bv(op.add))) ** 
            liftS(State._Counter, Increment()))

(test_a().runfs(state), 
 test_b().runfs(state),
 test_c().runfs(state),
 test_d().runfs(state))

((101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3})),
 (101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3})),
 (101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3})),
 (101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3})))