In [None]:
#|hide
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.2
%load_ext autoreload
%autoreload 2

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.2


In [None]:
#|export
from __future__ import annotations
from typing import Callable, Optional, Set, Any, NamedTuple, Protocol, Tuple, Dict, TypeVar
from fastcore.test import test_eq, test_fail
from sveltish.utils import Bunch

In [None]:
#| default_exp signals

In [None]:
#|export
    
T = Optional[TypeVar("T")]  # a value held by a signal; may be None
Getter = Callable[[], T]  # reads the current value
Setter = Callable[[T], None]  # replaces the current value
Subscriber = Callable[[T], None]  # a callback
Unsubscriber = Callable[[], None]  # ends a subscription when it is no longer needed
class Observable(Protocol):
    ''' Structural type: anything observers can subscribe to. '''
    def subscribe(self, observer: Callable[[Any], None]) -> None:
        ''' Register `observer`. '''
class Observer(Protocol):
    ''' Structural type: anything that can be (re)run and cancelled. '''
    def run(self) -> None:
        ''' Execute the observer. '''
    def cancel(self) -> None:
        ''' Stop observing. '''
class Readable(Bunch):
    ''' An observable that can be read.

    Holds a getter in `read`. Subclasses add more callables as attributes;
    `asTuple()` returns them in assignment order so they can be unpacked,
    e.g. ``count, setCount = writable(0).asTuple()``.
    '''
    def __init__(self, 
                 read: Getter # returns the current value
                 ): self.read = read
    def __repr__(self): 
        # Read once: every read made inside a running reaction subscribes it again.
        current = self.read()
        # NOTE(review): the leading `$` is emitted literally (Python f-strings have no
        # `${}` syntax) — presumably a Svelte-style store prefix; confirm it is intended.
        return f'${current.__class__.__name__}: {current!r}'
    def asTuple(self):
        ''' Return the instance attributes' values as a real tuple (a snapshot). '''
        # dict.values() is a view, not a tuple: it cannot be indexed, so convert it.
        return tuple(self.__dict__.values())
    
class Writable(Readable):
    ''' An observable that can be read and set.

    Stores the setter as `write` right after the inherited getter `read`,
    so `asTuple()` unpacks as (read, write).
    '''
    def __init__(self, read,
                 write: Setter,  # stores a new value
                 ):
        super().__init__(read)
        self.write = write

class Signal(Writable, Observable):
    ''' A signal (aka store) stores a value and runs callbacks when the value changes.

    Holds three callables as instance attributes, in this order: `read`,
    `write` and `subscribe` (the closures built by `signal()`).
    '''
    def __init__(self,
                 read: Getter,
                 write: Setter,
                 subscribe: Callable):
        super().__init__(read, write)
        self.subscribe = subscribe

    def subscribe(self, observer):
        # Class-level member required by `Observable`. Instances shadow it with the
        # closure assigned in __init__, so this only forwards to that closure.
        return self.subscribe(observer)


# Signal = Tuple[Getter, Setter, Callable[[Subscriber], None]]
class Callback(Bunch, Observer):
    ''' A callback is a function that is called when a signal changes.

    Bundles a reaction's `run` and `cancel` closures with the set of
    unsubscribe functions collected while it ran.
    '''
    def __init__(self,
                 run: Callable[[], None],  # the function to be called
                 cancel: Callable[[], None],  # cancels the subscription
                 subscriptions: Set[Any]  # unsubscribe functions, added by signals when read
                 ):
        # tuple assignment binds left to right: run, cancel, subscriptions
        self.run, self.cancel, self.subscriptions = run, cancel, subscriptions

    def run(self):
        # Protocol fallback; shadowed by the instance attribute set in __init__.
        return self.run()

    def cancel(self):
        # Protocol fallback; shadowed by the instance attribute set in __init__.
        return self.cancel()

In [None]:
# |export
context = [] # stack of the reactions (Callback objects) currently running; a signal read subscribes the innermost one, context[-1]


def signal(
    value: T = None # initial value
    ) -> Signal:
    ''' Signal factory.

    The value lives in this closure. The returned Signal bundles three closures:
    `read` returns the value and, when called while a reaction is running,
    subscribes that reaction; `write` stores a new value and re-runs every
    subscriber; `subscribe` registers a reaction explicitly.'''
    listeners = set()  # `run` functions of the reactions that read this signal

    def subscribe(
        callback: Subscriber # callback to be called when the store value changes
        ) -> None:
        '''Register `callback.run` for change notifications and hand `callback`
        the matching unsubscribe function (kept in `callback.subscriptions`).
        '''
        def unsubscribe():
            return listeners.discard(callback.run)
        listeners.add(callback.run)
        callback.subscriptions.add(unsubscribe)

    def read() -> T: # signal getter
        if context:
            active = context[-1]  # innermost running reaction
            if active:
                subscribe(active)
        return value

    def write(newValue: T) -> None: # signal setter
        nonlocal value
        value = newValue
        # Running a subscriber unsubscribes and resubscribes it, mutating
        # `listeners`, so iterate over a snapshot rather than the live set.
        snapshot = listeners.copy()
        for notify in snapshot:
            notify()

    return Signal(read, write, subscribe)


def reaction(fn: Callable) -> Callback:
    ''' Reaction factory. A reaction is a callback that is called when a signal changes.

    `fn` runs immediately. While it runs, the reaction sits on top of `context`,
    so every signal it reads subscribes it. When one of those signals is written,
    the reaction drops all its subscriptions and runs `fn` again, re-collecting
    its dependencies from scratch (so conditional reads are tracked correctly).

    Also known as: effect, observer, callback, computed, formula, derived.

    Returns the Callback handle; call `.cancel()` on it to stop reacting.
    '''
    def cancel():
        # Empty the handle's own set in place (instead of swapping in a new
        # Callback) so the object returned to the caller always holds the live
        # subscriptions, even when a re-run is triggered while `fn` executes.
        live = callback.subscriptions
        for unsubscribe in live.copy(): unsubscribe()
        live.clear()

    def run():
        cancel()  # forget previous dependencies before collecting new ones
        context.append(callback)
        try: fn()
        finally: context.pop()  # restore the outer reaction even if fn raises

    callback = Callback(run, cancel, set())
    run()
    return callback

def writable(value:T=None) -> Writable:
    ''' Writable factory. A writable is an interface to a signal: its getter and setter, without `subscribe`.'''
    backing = signal(value)
    return Writable(read=backing.read, write=backing.write)

from sveltish.utils import compose
from functools import reduce

def pipe(*fns):
    ''' Composes a list of functions into a reactive, cached value.

    The functions are combined with `compose`, and the composition is tracked the
    same way as `computed`. It is evaluated now, inside a reaction, and again
    whenever a signal it reads changes. Returns the getter of the cached result.

    NOTE(review): application order follows `sveltish.utils.compose`, and the
    composed function is called with no arguments — confirm against `compose`.
    Raises TypeError when no function is given.
    '''
    if not fns:
        raise TypeError('pipe() requires at least one function')
    w = writable()
    fn = compose(*fns)
    # functools.reduce needs a two-argument function, and this is not a fold anyway:
    # re-evaluate the whole composition whenever its dependencies change.
    _ = reaction(lambda: w.write(fn()))
    return w.read

def computed(fn) -> Getter:
    ''' A computed is a signal that is derived from other signals. It is a kind of cache of a reaction.

    `fn` is evaluated right away inside a reaction, so it is re-evaluated whenever a
    signal it reads changes. Each result is stored in a private writable, whose
    getter is returned.
    NOTE(review): the reaction handle is discarded, so the computed can never be cancelled.
    '''
    cache = writable()

    def refresh():
        cache.write(fn())

    reaction(refresh)
    return cache.read



def readonly(value:T=None) -> Getter:
    ''' Return only the getter of a new writable initialised with `value`; the setter is not exposed.'''
    return writable(value).read


# Alternative names for the same factories.
observable = signal
cell = signal
observer = reaction
callback = reaction
effect = reaction
view = reaction
derived = computed
formula = computed


## Tests

In [None]:
# unpack the getter/setter pair of a writable that starts at 0
count, setCount = writable(0).asTuple()

In [None]:
history = []  # logging for testing

def record(x): 
    '''Append `x` to `history` and print the whole log so far.'''
    history.append(x)
    print(history)
# creating the reaction runs it immediately, so the initial value 0 is logged right away
logger = reaction(lambda: record(count()))

test_eq(history, [0])

[0]


In [None]:
# helpers that read-modify-write the counter through its getter/setter pair
def increment(): setCount(count()+1)
def decrement(): setCount(count()-1)
def reset(): setCount(0)

# every write re-runs `logger`, which appends the new value and prints the log
setCount(3)
increment()
decrement()
decrement()
reset()
setCount(42)

test_eq(history, [0, 3, 4, 3, 2, 0, 42])

[0, 3]
[0, 3, 4]
[0, 3, 4, 3]
[0, 3, 4, 3, 2]
[0, 3, 4, 3, 2, 0]
[0, 3, 4, 3, 2, 0, 42]


In [None]:
logger.cancel()
# writes after cancel() no longer reach the logger, so history is unchanged
reset()
setCount(22)
test_eq(history, [0, 3, 4, 3, 2, 0, 42])

In [None]:
# two independent effects on the same signal; each runs once as soon as it is created
effect1 = reaction(lambda: print(f'Count is now {count()}'))
effect2 = reaction(lambda: print(f'double of the count is {count()*2}'))

Count is now 22
double of the count is 44


In [None]:
# a single write re-runs both effects
reset()

Count is now 0
double of the count is 0


In [None]:
# stop both effects
effect1.cancel()
effect2.cancel()

In [None]:
history = []
# a writable created without an initial value holds None
s, set_fn = writable().asTuple()

In [None]:
# creating the reaction runs it once (appending None); cancel() then detaches it
logger = reaction(lambda: history.append(s()))
logger.cancel()
test_eq(history, [None])

In [None]:
from threading import Event, Thread
import time

In [None]:
def start(set, interval: float = 1): # the start function is the publisher
    ''' Call `set(time.localtime())` every `interval` seconds from a background thread.

    Returns a function that stops the loop. The worker is a daemon thread, so a
    forgotten `stop()` cannot keep the interpreter (or kernel) alive.
    NOTE(review): `set` is called on the worker thread, so any reactions it triggers
    run there too — confirm the reactive core tolerates that.
    '''
    stopped = Event()
    def loop(): # needs to be in a separate thread
        # Event.wait doubles as an interruptible sleep: it returns True as soon as stop() is called.
        while not stopped.wait(interval):
            set(time.localtime())
    Thread(target=loop, daemon=True).start()
    return stopped.set

In [None]:
# a writable seeded with the current local time; the reaction prints it as HH:MM:SS now and on every write
timer, set_fn = writable(time.localtime()).asTuple()
log = reaction(lambda: print(time.strftime(f"%H:%M:%S", timer())))

15:41:35


In [None]:
# publish the time for about 2 seconds; every tick re-runs `log`
stop = start(set_fn)
time.sleep(2)
stop()

15:41:36
15:41:37


In [None]:
# fresh counter signal (rebinds `count`/`setCount`) with a logging reaction
count, setCount = writable(0).asTuple()
logCount = reaction(lambda: print(f'count is {count()}'))

count is 0


In [None]:
# `double` caches count()*2 and is recomputed whenever count is written
double = computed(lambda: count()*2) #type: ignore
setCount(7)
test_eq(double(), 14)

count is 7


In [None]:
# stop logging count changes
logCount.cancel()

In [None]:
elapsing = None  # first timestamp seen by calc_elapsed; recorded lazily on the first call
def calc_elapsed(now):
    '''Return the seconds between `now` and the first timestamp passed to this function.

    `now` is anything `time.mktime` accepts (e.g. a `time.struct_time`). The first
    call records it as the reference point and therefore returns 0.0.
    '''
    global elapsing
    if not elapsing:
        elapsing = now
    origin, current = time.mktime(elapsing), time.mktime(now)
    return current - origin

In [None]:
# current value of the timer signal, shown as the cell output
timer()

time.struct_time(tm_year=2023, tm_mon=3, tm_mday=15, tm_hour=15, tm_min=41, tm_sec=37, tm_wday=2, tm_yday=74, tm_isdst=0)

In [None]:
# seconds since the first timer value seen by calc_elapsed, kept up to date by a computed
elapsed = computed(lambda: calc_elapsed(timer()))
elapsed()

0.0

In [None]:
# publish for about 2 more seconds; each tick re-runs `log` and recomputes `elapsed`
stop = start(set_fn)
time.sleep(2)
stop()

15:41:39
15:41:40


In [None]:
# a computed that depends on two signals updates when either one is written
a = writable([1,2,3,4])
b = writable([5,6,7,8])
zipper = computed(lambda: list(zip(a.read(), b.read())))
test_eq(zipper(), [(1, 5), (2, 6), (3, 7), (4, 8)])
a.write([4,3,2,1])
test_eq(zipper(), [(4, 5), (3, 6), (2, 7), (1, 8)])
b.write([8,7,6,5])
test_eq(zipper(), [(4, 8), (3, 7), (2, 6), (1, 5)])

In [None]:
history = []
firstName, setFirstName = writable("John").asTuple()
lastName, setLastName = writable("Smith").asTuple()
fullName = lambda: f'{firstName()} {lastName()}'  # plain function: its reads are tracked by whichever reaction calls it
showFullName, setShowFullName = writable(True).asTuple()
# dependencies are re-collected on every run, so only the signals actually read are tracked
displayName = reaction(lambda: history.append(fullName() if showFullName() else firstName()))

In [None]:
test_eq(history, ['John Smith'])
# hiding the full name re-runs the reaction, which now reads only showFullName and firstName
setShowFullName(False)
test_eq(history, ['John Smith', 'John'])
setShowFullName(True)
test_eq(history, ['John Smith', 'John', 'John Smith'])

In [None]:
#|hide
# regenerate the library module from the `#|export` cells (nbdev; presumably `signals`, per `#| default_exp`)
import nbdev; nbdev.nbdev_export()