### Copyright (c) 2024 James Litsios

## Code V14: Give State a comonadic API

In [1]:
from abc import abstractmethod
from typing import TypeVar, Generic, ClassVar, Tuple, Any, Dict, List
from collections.abc import Callable
from dataclasses import dataclass, field, fields

@dataclass(frozen=True)
class State:
    """Immutable record with a key/value dict and a counter.

    ``State.Access`` groups one accessor class per field.  An accessor
    instance stores the *other* field of the state, and its static ``lower``
    steps from a path whose head is a ``State`` down to the accessor's field.
    """
    kv: dict
    counter: int

    class Access:
        """Namespace holding the per-field accessor classes."""

        @dataclass(frozen=True)
        class kv:
            """Context kept while the ``kv`` field is under focus."""
            counter: int

            @staticmethod
            def lower(path):
                """Return ``(kv value, (context, path))`` for the state at ``path[0]``."""
                owner = path[0]
                context = State.Access.kv(owner.counter)
                return (owner.kv, (context, path))

        @dataclass(frozen=True)
        class counter:
            """Context kept while the ``counter`` field is under focus."""
            kv: dict

            @staticmethod
            def lower(path):
                """Return ``(counter value, (context, path))`` for the state at ``path[0]``."""
                owner = path[0]
                context = State.Access.counter(owner.kv)
                return (owner.counter, (context, path))
    
            
def cobind(path, f):
    """Build a new ``State`` by applying ``f`` to the focus of each field.

    ``path`` is a tuple whose first item is the state being rebuilt.  ``f``
    receives ``(field value, (accessor context, path))`` once for ``kv`` and
    once for ``counter``, in that order, and returns the new field value.
    """
    kv_focus = State.Access.kv.lower(path)
    new_kv = f(kv_focus)
    counter_focus = State.Access.counter.lower(path)
    new_counter = f(counter_focus)
    return State(new_kv, new_counter)

def current(focus):
    """Return the value under focus, i.e. the first item of ``focus``."""
    value = focus[0]
    return value

# Sample state used by the demonstrations below.
state1 = State({'a':1, 'b':2}, 100)

 
# cobind hands each field's focus `(value, (context, path))` to the function.
# The dict dispatches on the class of the context object: the `kv` field gains
# key 'c' (sum of 'a' and 'b') and the `counter` field is incremented by one.
print(cobind((state1, None), lambda t2: {
    State.Access.kv:(lambda kv_: (kv:=current(kv_))|{'c': kv['a']+kv['b']}),
    State.Access.counter:(lambda counter_: current(counter_)+1)
    }[t2[1][0].__class__](t2)))


State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101)


## Code V15: Use right-sided division as the cobind operator and prepare choices

In [2]:
@dataclass(frozen=True)
class Choice:
    """Wraps a function so it can be cobound with the division operators.

    Both operators are the reflected variants, so the ``Choice`` sits on the
    right-hand side: ``path / Choice(f)`` cobinds an existing path tuple,
    while ``value // Choice(f)`` first wraps ``value`` as ``(value, None)``.
    """
    f: Callable[[Any, ...], Any]

    def __rfloordiv__(self, cobindable):
        """Cobind ``f`` over ``cobindable`` treated as the root of a new path."""
        root_path = (cobindable, None)
        return cobind(root_path, self.f)

    def __rtruediv__(self, cobindable):
        """Cobind ``f`` over ``cobindable``, which is already a path tuple."""
        return cobind(cobindable, self.f)

# `//` wraps the bare state into the root path `(state1, None)` before cobinding.
print(state1 // Choice(lambda t2: {
    State.Access.kv:(lambda kv_: (kv:=current(kv_))|{'c': kv['a']+kv['b']}),
    State.Access.counter:(lambda counter_: current(counter_)+1)
    }[t2[1][0].__class__](t2)))

# `/` passes its left operand to cobind unchanged, so it must already be a path tuple.
print((state1, None) / Choice(lambda t2: {
    State.Access.kv:(lambda kv_: (kv:=current(kv_))|{'c': kv['a']+kv['b']}),
    State.Access.counter:(lambda counter_: current(counter_)+1)
    }[t2[1][0].__class__](t2)))

State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101)
State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101)


## Code V16: Bring in conditional cobinding of dataclass fields

In [3]:
class CoBindable:
    """Interface for objects that implement their own ``cobind``."""

    @abstractmethod
    def cobind(self, path: Tuple[Any, ...], f: Callable[[Tuple[Any, ...]], Any]) -> Any:
        """Rebuild this object by applying ``f`` to each of its foci.

        ``path`` is the path leading to this object.  This base version does
        nothing and returns ``None``.
        """
        return None

@dataclass(frozen=True)
class Choice:
    """A function ``f`` that is applied only to the foci it selects.

    Subclasses supply ``mk_filter`` (the function actually handed to
    ``cobind``) and ``mk_choices`` (how several choices are grouped).
    ``a | b`` groups two choices; ``path / choice`` and ``value // choice``
    cobind the filtered function.
    """
    f: Callable[[Any, ...], Any]

    @abstractmethod
    def mk_filter(self):
        """Return the focus function for this single choice (base: ``None``)."""
        return None

    @abstractmethod
    def mk_choices(self, cs):
        """Return a grouping of the choices in ``cs`` (base: ``None``)."""
        return None

    def __or__(self, other: 'Choice'):
        """Group this choice with ``other`` through ``mk_choices``."""
        pair = (self, other)
        return self.mk_choices(pair)

    def __rtruediv__(self, cobindable):
        """Cobind over ``cobindable``, which is already a path tuple."""
        selector = self.mk_filter()
        return cobind(cobindable, selector)

    def __rfloordiv__(self, cobindable):
        """Cobind over ``cobindable`` wrapped as the root path ``(cobindable, None)``."""
        selector = self.mk_filter()
        return cobind((cobindable, None), selector)

@dataclass(frozen=True)
class Choices:
    """An ordered group of choices that is cobound as one unit.

    Subclasses supply ``mk_chooser``, the function handed to ``cobind``.
    ``group | choice`` returns a new group of the same class with ``choice``
    appended.
    """
    cs: Tuple[Choice, ...]

    @abstractmethod
    def mk_chooser(self):
        """Return the focus function for the whole group (base: ``None``)."""
        return None

    def __or__(self, other: 'Choice'):
        """Return a new group of the same class with ``other`` appended."""
        extended = self.cs + (other,)
        return self.__class__(extended)

    def __rtruediv__(self, cobindable):
        """Cobind over ``cobindable``, which is already a path tuple."""
        chooser = self.mk_chooser()
        return cobind(cobindable, chooser)

    def __rfloordiv__(self, cobindable):
        """Cobind over ``cobindable`` wrapped as the root path ``(cobindable, None)``."""
        chooser = self.mk_chooser()
        return cobind((cobindable, None), chooser)


@dataclass(frozen=True)
class FieldChoice(Choice):
    """Choice that fires only on the dataclass field it is named after.

    The match compares this choice's class name with the class name of the
    context object stored at ``focus[1][0]``.
    """

    def mk_filter(self):
        """Return a function applying ``f`` on a name match, else passing the value through."""
        def apply_if_own_field(focus):
            if self.__class__.__name__ == focus[1][0].__class__.__name__:
                return self.f(focus)
            return focus[0]

        return apply_if_own_field

    def mk_choices(self, cs):
        """Group field choices into a ``FieldChoices``."""
        return FieldChoices(cs)

                
@dataclass(frozen=True)
class FieldChoices(Choices):
    """Group of field choices; the first one matching the focused field wins."""

    def mk_chooser(self):
        """Return a function that dispatches a focus to the first matching choice.

        A choice matches when its class name equals the class name of the
        context object at ``focus[1][0]``.  Without a match the focused value
        ``focus[0]`` is returned unchanged.
        """
        def choose(focus):
            for choice in self.cs:
                if choice.__class__.__name__ == focus[1][0].__class__.__name__:
                    return choice.f(focus)
            return focus[0]

        return choose

@dataclass(frozen=True)
class State(CoBindable):
    """Immutable record with a key/value dict and a counter, cobindable per field.

    ``State.Lower`` holds the per-field accessors used by ``cobind`` and
    ``State.K`` holds the per-field choice classes.  The class attribute
    ``accessors`` read by ``cobind`` is assigned after the class body.
    """
    kv: dict
    counter: int

    def cobind(self, path: Tuple[Any, ...], f: Callable[[Tuple[Any, ...]], Any]) -> Any:
        """Return a new instance whose fields are ``f`` applied to each field focus.

        The path is extended with ``self`` and each accessor in
        ``accessors`` lowers it to one field, in constructor-argument order.
        """
        cls = self.__class__
        here = (self, path)
        rebuilt = []
        for accessor in cls.accessors:
            rebuilt.append(f(accessor.lower(here)))
        return cls(*rebuilt)

    class Lower:
        """Namespace holding the per-field accessor classes."""

        @dataclass(frozen=True)
        class kv:
            """Context kept while the ``kv`` field is under focus."""
            counter: int

            @staticmethod
            def lower(path):
                """Return ``(kv value, (context, path))`` for the state at ``path[0]``."""
                owner = path[0]
                context = State.Lower.kv(owner.counter)
                return (owner.kv, (context, path))

        @dataclass(frozen=True)
        class counter:
            """Context kept while the ``counter`` field is under focus."""
            kv: dict

            @staticmethod
            def lower(path):
                """Return ``(counter value, (context, path))`` for the state at ``path[0]``."""
                owner = path[0]
                context = State.Lower.counter(owner.kv)
                return (owner.counter, (context, path))

    class K:
        """Namespace holding the per-field choice classes."""

        @dataclass(frozen=True)
        class kv(FieldChoice):
            """Choice that applies to the ``kv`` field."""

        @dataclass(frozen=True)
        class counter(FieldChoice):
            """Choice that applies to the ``counter`` field."""

# Accessors listed in constructor-argument order; State.cobind iterates this
# list and passes the results positionally to the constructor.
State.accessors=[State.Lower.kv, State.Lower.counter]

            
def cobind(focus: Tuple[Any, ...], f: Callable[[Tuple[Any, ...]], Any]) -> Any:
    """Delegate to the ``cobind`` method of the object at the head of ``focus``.

    ``focus[0]`` is the object to rebuild and ``focus[1]`` is the path that
    led to it; both ``focus[1]`` and ``f`` are forwarded to the method.
    """
    target = focus[0]
    method = target.cobind
    return method(focus[1], f)



# Sample state rebuilt with the cobindable State class defined above.
state1 = State({'a':1, 'b':2}, 100)

# A single kv choice: only the `kv` field is transformed, `counter` passes through.
print(state1 // (
    State.K.kv(lambda kv_: (kv:=kv_[0])|{'c': kv['a']+kv['b']})
    ))

# A single counter choice: only the `counter` field is transformed.
print(state1 // (
    State.K.counter(lambda counter_: counter_[0]+1)
    ))

# `|` groups both choices (via mk_choices) so each field gets its own function.
print(state1 // (
    State.K.kv(lambda kv_: (kv:=kv_[0])|{'c': kv['a']+kv['b']})
    | State.K.counter(lambda counter_: counter_[0]+1)
    ))


State(kv={'a': 1, 'b': 2, 'c': 3}, counter=100)
State(kv={'a': 1, 'b': 2}, counter=101)
State(kv={'a': 1, 'b': 2, 'c': 3}, counter=101)


## Code V17: We can cobind dictionaries too, and have a choice model

In [4]:
@dataclass(frozen=True, init=False)
class DictChoice(Choice):
    """Choice that fires only on the dictionary entry stored under key ``k``.

    When a dict is cobound, each focus has the form
    ``(value, (key, dict_focus))``; the choice matches when ``k`` equals that
    key.
    """
    k: Any

    def __init__(self, k, f):
        """Store the key ``k`` and the function ``f`` (note the key comes first)."""
        # The dataclass is frozen, so the attribute is written through
        # object.__setattr__ rather than plain assignment.
        object.__setattr__(self, 'k', k)
        super().__init__(f)

    def mk_filter(self):
        """Return a function applying ``f`` on a key match, else passing the value through."""
        return (lambda focus:
            self.f(focus)
                if self.k == focus[1][0]
                else focus[0])

    def mk_choices(self, cs):
        """Group dict choices into a ``DictChoices``."""
        # Previously referenced the undefined name ``FieldMappings``, which
        # made ``DC(...) | DC(...)`` fail with NameError.
        return DictChoices(cs)

# Short alias for DictChoice, used in the examples.
DC = DictChoice

@dataclass(frozen=True)
class DictChoices(Choices):
    """Group of dict choices; the first one whose key matches the focus wins."""

    def mk_chooser(self):
        """Return a function that dispatches a focus to the first matching choice.

        A choice matches when its ``k`` equals the key at ``focus[1][0]``.
        Without a match the focused value ``focus[0]`` is returned unchanged.
        """
        def choose(focus):
            for choice in self.cs:
                if choice.k == focus[1][0]:
                    return choice.f(focus)
            return focus[0]

        return choose


def dict_cobind(focus, f):
    """Return a new dict with ``f`` applied to the focus of every entry.

    ``focus[0]`` is the dict being rebuilt.  For each entry ``f`` receives
    ``(value, (key, focus))`` and its result becomes the new value for ``key``.
    """
    rebuilt = {}
    for key, value in focus[0].items():
        rebuilt[key] = f((value, (key, focus)))
    return rebuilt


def cobind(focus: Tuple[Any, ...], f: Callable[[Tuple[Any, ...]], Any]) -> Any:
    """Cobind ``f`` over the value at the head of ``focus``.

    Dictionaries are handled by ``dict_cobind``; ``CoBindable`` objects are
    asked to cobind themselves with the remaining path ``focus[1]``.

    Raises:
        TypeError: if ``focus[0]`` is neither a dict nor a ``CoBindable``.
    """
    target = focus[0]
    if isinstance(target, dict):
        return dict_cobind(focus, f)
    if isinstance(target, CoBindable):
        return target.cobind(focus[1], f)
    # An explicit raise (instead of ``assert False``) keeps the failure visible
    # when Python runs with -O, where assert statements are removed.
    raise TypeError(f"cannot cobind a value of type {type(target).__name__!r}")


# Key 'c' starts as None and is filled in by the nested cobind below.
state2 = State({'a':1, 'b':2, 'c':None}, 100)


# The outer choice selects the `kv` field; inside it, `kvf / DC('c', ...)`
# cobinds the dict itself and replaces only the value stored under 'c'.
print(state2 // (
    State.K.kv(lambda kvf: kvf / DC('c', lambda df: kvf[0]['a']+kvf[0]['b']))
    ))



State(kv={'a': 1, 'b': 2, 'c': 3}, counter=100)


## Code V18: Make binding operators composable

In [5]:
# Replace __init__ so that `f` becomes optional, and define `**` so that
# `State.K.kv() ** inner` builds a kv choice whose function cobinds the
# field focus with `inner` (via `/`).
setattr(State.K.kv, '__init__', lambda self, f=None: super(State.K.kv, self).__init__(f))
setattr(State.K.kv, '__pow__', lambda self, other: State.K.kv(lambda kv_f: kv_f / other))

# The same two additions for the `counter` choice.
setattr(State.K.counter, '__init__', lambda self, f=None: super(State.K.counter, self).__init__(f))
setattr(State.K.counter, '__pow__', lambda self, other: State.K.counter(lambda counter_f: counter_f / other))


def dcinit(self, k, f=None):
    """Initializer for ``DictChoice`` in which the function ``f`` is optional.

    Stores the key ``k`` on the instance, then lets the class after
    ``DictChoice`` in the MRO store ``f``.
    """
    # Written through object.__setattr__, bypassing the instance's own __setattr__.
    object.__setattr__(self, 'k', k)
    parent = super(DictChoice, self)
    parent.__init__(f)
    
# Install the initializer with optional `f`, and define `**` so that
# `DC(key) ** inner` builds a choice for the same key whose function cobinds
# the entry focus with `inner`.
setattr(DictChoice, '__init__', dcinit)
# The key must be forwarded as the first argument: the initializer is
# (k, f=None), so passing only the lambda would store it as the key and leave
# `f` unset.
setattr(DictChoice, '__pow__', lambda self, other: DictChoice(self.k, lambda kvf: kvf / other))

state2 = State({'a':1, 'b':2, 'c':None}, 100)

# `**` binds tighter than `//`, so the kv choice is composed with the dict
# choice first and the result is then cobound over state2.
# Inside the dict choice `df` is `(value, (key, dict_focus))`, hence
# `df[1][1][0]` is the dictionary that contains the entry.
print(state2 // State.K.kv() ** DC('c', (lambda df: df[1][1][0]['a']+df[1][1][0]['b'])))

State(kv={'a': 1, 'b': 2, 'c': 3}, counter=100)


## Code V19: We can cobind JAX arrays

In [6]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
from jax import lax, jit
import jax.numpy as jnp
import jax.random as jrnd



@dataclass(frozen=True)
class JNP(Choice):
    """Choice for array elements: ``f`` is applied to every focus unconditionally."""

    def mk_filter(self):
        """Return a function that forwards each focus straight to ``f``."""
        def apply_everywhere(focus):
            return self.f(focus)

        return apply_everywhere

    def mk_choices(self, cs):
        """Group array choices into a ``JNPChoices``."""
        return JNPChoices(cs)



@dataclass(frozen=True)
class JNPChoices(Choices):
    """Group of ``JNP`` choices; no chooser is implemented in this version."""

    def mk_chooser(self):
        """Return ``None``: there is no chooser function for array choices here."""
        chooser = None
        return chooser


def jnp_cobind(focus, f):
    """Return a new array with ``f`` applied to the focus of every element.

    ``focus[0]`` is the array being rebuilt.  For each index ``i`` the
    function receives ``(element, (i, focus))``.  The per-element function
    takes a single index, so only one-dimensional arrays are handled.
    """
    array = focus[0]

    def at_index(i):
        return f((array[i], (i, focus)))

    # ``dtype`` is the type of the generated index values.  int32 is used
    # because int16 cannot represent indices above 32767, which would address
    # the wrong elements in longer arrays.
    return jnp.fromfunction(at_index, shape=array.shape, dtype=jnp.int32)


def cobind(focus: Tuple[Any, ...], f: Callable[[Tuple[Any, ...]], Any]) -> Any:
    """Cobind ``f`` over the value at the head of ``focus``.

    JAX arrays are handled by ``jnp_cobind``, dictionaries by ``dict_cobind``,
    and ``CoBindable`` objects are asked to cobind themselves with the
    remaining path ``focus[1]``.

    Raises:
        TypeError: if ``focus[0]`` is none of the supported kinds.
    """
    target = focus[0]
    if isinstance(target, jnp.ndarray):
        return jnp_cobind(focus, f)
    if isinstance(target, dict):
        return dict_cobind(focus, f)
    if isinstance(target, CoBindable):
        return target.cobind(focus[1], f)
    # An explicit raise (instead of ``assert False``) keeps the failure visible
    # when Python runs with -O, where assert statements are removed.
    raise TypeError(f"cannot cobind a value of type {type(target).__name__!r}")


# Ten float32 values 0..9; the cobind below adds one to every element.
a = jnp.arange(0, 10, dtype=jnp.float32)
print(a)
# `jef` is the element focus `(element, (index, array_focus))`.
print(a // JNP(lambda jef: jef[0]+1))

[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]


## Code V20: Show that shallow JAX JIT performance on comonadic code is still good

In [7]:
from timeit import timeit

def timeit2(f, number=1000000):
    """Return the mean time of one call to ``f`` in microseconds.

    ``f`` is executed ``number`` times with ``timeit``; the total (seconds)
    is scaled to microseconds and divided by the number of runs.
    """
    total_seconds = timeit(f, number=number)
    return 1e6 * total_seconds / number

# Example taken from: https://jax.readthedocs.io/en/latest/quickstart.html

# Fixed PRNG key so the benchmark input is reproducible.
key = jrnd.key(1701)
# `1_00` is the integer 100: `x` is a one-dimensional sample of 100 values.
x = jrnd.normal(key, (1_00, ))

def selu1(x, alpha=1.67, lmbda=1.05):
    """Vectorised SELU: ``lmbda * x`` where ``x > 0``, else ``lmbda * (alpha * exp(x) - alpha)``."""
    negative_branch = alpha * jnp.exp(x) - alpha
    selected = jnp.where(x > 0, x, negative_branch)
    return lmbda * selected

# JIT-compiled variant of the vectorised implementation.
selu1_jit = jit(selu1)

# timeit2 reports microseconds per call; block_until_ready() is called inside
# the timed lambda on every result.
print(f"Original non-JIT: {timeit2(lambda: selu1(x).block_until_ready(), number=100000):.0f}")
print(f"Original JIT: {timeit2(lambda: selu1_jit(x).block_until_ready()):.0f}")

def selu2(x, alpha=1.67, lmbda=1.05):
    """Scalar SELU written with ``lax.cond`` instead of ``jnp.where``.

    Returns ``lmbda * x`` when ``x > 0`` and
    ``lmbda * (alpha * exp(x) - alpha)`` otherwise.
    """
    def keep(y):
        return y

    def saturate(y):
        return alpha * jnp.exp(y) - alpha

    return lmbda * lax.cond(x > 0, keep, saturate, x)

def f(x, alpha=1.67, lmbda=1.05):
    """Apply ``selu2`` to every element of ``x`` through a comonadic cobind.

    ``alpha`` and ``lmbda`` are forwarded to ``selu2``; previously they were
    accepted but ignored, so non-default values had no effect.
    """
    # `ef` is the element focus; `ef[0]` is the element value itself.
    return x // (JNP(lambda ef: selu2(ef[0], alpha, lmbda)))

# JIT-compiled variant of the comonadic implementation.
f_jit = jit(f)

# The non-JIT run uses only 1000 repetitions; the JIT run uses timeit2's default count.
print(f"Comonadic non-JIT: {timeit2((lambda: f(x).block_until_ready()), number=1000):.0f}")
print(f"Comonadic JIT: {timeit2((lambda: f_jit(x).block_until_ready())):.0f}")


Original non-JIT: 108
Original JIT: 11
Comonadic non-JIT: 7752
Comonadic JIT: 11


The next step is to ensure that JAX JIT also works for deep cobinds.