## Copyright (c) 2024 James Litsios

## Code V12: Make access of monadic variables pointfree to help JIT¶

In [9]:
from abc import abstractmethod
from typing import TypeVar, Generic, ClassVar, Tuple, Any, Dict
from collections.abc import Callable
from dataclasses import dataclass, field, InitVar

R = TypeVar('R')
S = TypeVar('S')
M = TypeVar('M')
ML = TypeVar('ML')
MR = TypeVar('MR')

@dataclass(frozen=True)
class State:
    kv: dict
    counter: int
        
    @dataclass(frozen=True)
    class _KV:
        counter: int
    
        @classmethod
        def toField(cls, state): 
            return (state.kv, cls(state.counter))
            
        @classmethod
        def fromField(cls, kv, rest):
            return State(kv, rest.counter)    
    
    @dataclass(frozen=True)
    class _Counter:
        kv: dict

        @classmethod
        def toField(cls, state): 
            return (state.counter, cls(state.kv))

        @classmethod
        def fromField(cls, counter, rest):
            return State(counter, rest.kv)
    
state = State(kv={'a':1, 'b':2}, counter=100) 

IDSTACK = Tuple[str | Tuple[str, ...], 'IDSTACK'] | None
VARSTACK = Tuple[Any | Tuple[Any, ...], 'VARSTACK'] | None
ACCESSOR = Callable[[VARSTACK], Any]

@dataclass(init=True)
class NamedFSM(Generic[ML]):
    m: ML
    ids: str | Tuple[str, ...]
    
    def __getattr__(self, id_) -> 'NamedFSM[MR]':
        return NamedFSM(self.m, (self.ids+(id_,)) if isinstance(self.ids, tuple) else (self.ids, id_)) 
        
    def __pow__(self, mr: MR) -> 'Bind[ML, Lambda[MR]]':
        return Bind(self.m, Lambda(self.ids, mr))

def build_access(id_:str, id_stack: IDSTACK) -> ACCESSOR:
    assert id_stack, f"{id_} not defined)"
    h_stack, t_stack = id_stack
    if id_ == h_stack:
        return lambda var_stack: var_stack[0]
    elif id_ in h_stack:
        idx = h_stack.index(id_)
        return lambda var_stack: var_stack[0][idx]
    else:
        t_access = build_access(id_, t_stack)
        return lambda var_stack: t_access(var_stack[1])


class InVarMeta(type):
    @classmethod
    def __getattr__(cls, invar_id):
        return M(invar_id)

class M(metaclass=InVarMeta):
    def __init__(self, invar_ids):
        self.invar_ids = invar_ids
        
    def __getattr__(self, invar_id):
        return M((self.invar_ids+(invar_id,)) if isinstance(self.invar_ids, tuple) else (self.invar_ids, invar_id)) 

    def Expr(self, expr_lambda: Callable[[Any, ...], Any]):
        return ExprFSM(self.invar_ids, expr_lambda)


class FSM(Generic[R, S]):

    def __getattr__(self, id_) -> NamedFSM['FSM']:
        return NamedFSM(self, id_)

    def __pow__(self, mr) -> 'Bind[ML, Lambda[MR]]':
        return Bind(self, Lambda(None, mr))
    
    @abstractmethod
    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        pass

class ExprMeta(type):
    @classmethod
    def __getattr__(cls, id_):
        print(cls.__name__, id_)
        return Expr({id_:None})

        
@dataclass(init=True)
class ExprFSM(FSM):
    ids: Dict[str, Any]
    expr_lambda: Any
    
    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        accessors = tuple([build_access(id_, id_stack) for id_ in self.ids])
        return (lambda state, var_stack:
                    (self.expr_lambda(*[access(var_stack) for access in accessors]), state))
        
        
@dataclass(init=True)
class Bind(FSM):
    m: FSM
    fm: 'Lambda'
    
    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        # first try always push result to var_stack
        m_runfs = self.m.mksm(id_stack)
        fm_runfs = self.fm.body.mksm((self.fm.ids, id_stack))
        return (lambda state, var_stack:
                    fm_runfs((r_state := m_runfs(state, var_stack))[1], (r_state[0], var_stack)))


@dataclass(init=True)
class Lambda:
    ids: str | Tuple[str, ...] 
    body: FSM

@dataclass(init=True)
class LiftS(FSM):
    loweredCls: Any
    m: FSM
    
    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        m_runfs = self.m.mksm(id_stack)
        return (lambda state, var_stack:
                    ((ret_state := m_runfs((lower:=self.loweredCls.toField(state))[0], var_stack))[0], self.loweredCls.fromField((ret_state)[1], lower[1])))

@dataclass(init=True)
class LiftSC:
    loweredCls: Any
    mc: object

    def __call__(self, *args):
        return liftS(self.loweredCls, self.mc(*args))

def liftS(cls, mc_or_m: object | FSM) -> LiftSC | LiftS:
    if isinstance(mc_or_m, NamedFSM):
        return NamedFSM(liftS(cls, mc_or_m.m), mc_or_m.ids)
    elif isinstance(mc_or_m, FSM):
        return LiftS(cls, mc_or_m)
    else:
        return LiftSC(cls, mc_or_m)

@dataclass(init=True)
class Get(FSM):
    k: str
            
    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        return (lambda state, var_stack:
                    (state.get(self.k, None), state))

@dataclass(init=True)
class Set(FSM):
    k: str
    v: FSM

    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        v_runfs = self.v.mksm(id_stack)
        return (lambda state, var_stack:
                    ((), dict((v_ret:=v_runfs(state, var_stack))[1] | {self.k:v_ret[0]})))


@dataclass(init=True)
class Increment(FSM):
        
    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        return (lambda state, var_stack:
                (new_counter := state + 1, new_counter))

import operator as op

def test_a():
    return (liftS(State._KV, Get)('a').av ** 
            liftS(State._KV, Get)('b').bv ** 
            liftS(State._KV, Set)('c', M.av.bv.Expr(op.add)) ** 
            liftS(State._Counter, Increment)())

def test_b():
    return (liftS(State._KV, Get('a').av) ** 
            liftS(State._KV, Get('b').bv) ** 
            liftS(State._KV, Set('c', M.av.bv.Expr(op.add))) ** 
            liftS(State._Counter, Increment()))

def test_c():
    return (liftS(State._KV, Get('a')).av ** 
            liftS(State._KV, Get('b')).bv ** 
            liftS(State._KV, Set('c', M.av.bv.Expr(op.add))) ** 
            liftS(State._Counter, Increment()))


def test_d():
    return (liftS(State._KV,  Get('a').av ** 
                              Get('b').bv ** 
                              Set('c', M.av.bv.Expr(op.add))) ** 
            liftS(State._Counter, Increment()))

tms = [test_a(), test_b(), test_c(), test_d()]
runs = list([m.mksm() for m in tms])
rets = list([run(state, None) for run in runs])
rets

[(101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3})),
 (101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3})),
 (101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3})),
 (101, State(kv=101, counter={'a': 1, 'b': 2, 'c': 3}))]

## Code V13: Check that JAX's JIT is not confused

In [10]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
from jax import jit
import jax.numpy as jnp
import jax.random as jrnd
from timeit import timeit 

# Example taken from: https://jax.readthedocs.io/en/latest/quickstart.html

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)

key = jrnd.key(1701)
x = jrnd.normal(key, (1_00, ))

print(f"Original non-JIT: {timeit(lambda: selu(x).block_until_ready()):.0f}")
print(f"Original JIT: {timeit(lambda: selu_jit(x).block_until_ready()):.0f}")

@dataclass(init=True)
class GetState(FSM):
            
    def mksm(self, id_stack:IDSTACK=None) -> Callable[[S, VARSTACK], Tuple[R, S]]:
        return (lambda state, var_stack:
                    (state, state))


def test():
    return (GetState().x **
            M.x.Expr(selu))

t = test()
run = t.mksm()
run_jit = jit(lambda x: run(x, None))

print(f"Monadic non-JIT: {timeit(lambda: run(x, None)):.0f}")
print(f"Monadic JIT: {timeit(lambda: run_jit(x)):.0f}")

Original non-JIT: 219
Original JIT: 36
Monadic non-JIT: 261
Monadic JIT: 31


In [11]:
print(t)

Bind(m=GetState(), fm=Lambda(ids='x', body=ExprFSM(ids='x', expr_lambda=<function selu at 0x1038ada20>)))
