In [None]:
#|export
from __future__ import annotations
from typing import Callable, TypeVar,  Generic, Union, Optional, Set, Protocol, Any, overload
from typing_extensions import Annotated, TypeAlias
from fastcore.test import test_eq, test_fail

In [None]:
#| default_exp stores

# stores



## The Svelte Store contract

1. A store must contain a `.subscribe` method, which must accept as its argument a `subscription function` (aka Subscriber or Callback). This `subscription function` must be immediately and synchronously called with the store's current value upon calling `subscribe`. All of a store's active subscription functions must later be synchronously called whenever the store's value changes.

1. The `.subscribe` method must return an `unsubscribe function` (aka Unsubscriber). Calling an `unsubscribe function` must `stop` its subscription, and its corresponding `subscription function` must not be called again by the store.

1. A store may optionally contain a `.set` method, which must accept as its argument a new value for the store, and which synchronously calls all of the store's active subscription functions. Such a store is called a writable store.


~~For interoperability with RxJS Observables, the .subscribe method is also allowed to return an object with an .unsubscribe method, rather than return the unsubscription function directly. Note however that unless .subscribe synchronously calls the subscription (which is not required by the Observable spec), Svelte will see the value of the store as undefined until it does.~~

[Store Contract Documentation](https://svelte.dev/docs#component-format-script-4-prefix-stores-with-$-to-access-their-values-store-contract)


#### Type Definitions

In [None]:
#| exports

T = TypeVar("T")
covT = TypeVar("covT", covariant=True)
Subscriber = Callable[[T], None] # a callback
Unsubscriber = Callable[[], None] # a callback to be used upon termination of the subscription
Updater = Callable[[T], T]
Notifier = Callable[[Subscriber], Union[Unsubscriber, None]]

class StoreProtocol(Protocol, Generic[covT]):
    ''' The Svelte Store ~~contract~~ protocol. '''
    def subscribe(self, subscriber: Subscriber[T]) -> Unsubscriber: ...

Readable: TypeAlias = StoreProtocol[T]

class Writable(Readable[T]):
    ''' Writable protocol'''
    def set(self, value: T) -> None: ...
    def update(self, updater: Updater[T]) -> None: ...

## Implementation

#### Writable Store

In [None]:
#|export
import sveltish.utils as utils

In [None]:
#| export
class Store(Readable[T]):
    ''' A Writable Store.

    Holds `value` and a set of subscriber callbacks (subscribing the same callable
    twice registers it once). `start` is called with the store's internal setter when
    the first subscriber is added; the callable it returns (or a no-op) is kept as
    `stop` and called when the last subscriber is removed.
    '''
    def __init__(self,
                initial_value: Any = None, # initial value of the store
                start: Notifier = utils.noop # A Notifier (Optional)
                ) -> None:
        self.value = initial_value
        self.subscribers: Set[Subscriber] = set() # callbacks to be called when the value changes
        self.start: Notifier = start # function called when the first subscriber is added
        self.stop: Optional[Unsubscriber] = None  # function called when the last subscriber is removed; None while the store is inactive

    def get(self) -> T:
        ''' Current value of the store.'''
        return self.value

    def __call__(self) -> T:
        ''' `store()` is shorthand for `store.get()`; dispatching through `get` respects subclass overrides.'''
        return self.get()

    def subscribe(self,
                  callback: Subscriber # callback to be called when the store value changes
                  ) -> Unsubscriber:
        ''' Adds callback to the list of subscribers, calls it with the current value and returns its unsubscriber.'''
        # Decide "first subscriber" before adding: re-subscribing a callback that is
        # already in the set must not call `start` (and leak the previous `stop`) again.
        is_first = not self.subscribers
        self.subscribers.add(callback)
        if is_first:
            # `stop` is still None while `start` runs, so values it sets synchronously are
            # stored without being broadcast; `callback` receives the final value just below.
            self.stop = self.start(self.__set) or (lambda: None) #type: ignore
        callback(self.value)
        def unsubscribe() -> None:
            ''' Removes callback from the list of subscribers; stops the store when none remain. Safe to call more than once.'''
            self.subscribers.discard(callback)
            if not self.subscribers and self.stop is not None:
                self.stop()
                self.stop = None
        return unsubscribe

    def __set(self,
            new_value: T # The new value of the store
            ) -> None:
        ''' Internal setter handed to `start`; lets Readable Stores, which do not expose `set`, receive values.'''
        if not utils.safe_not_equal(self.value, new_value):
            return
        self.value = new_value
        if self.stop is None:
            # Inactive store (no subscribers, or `start` still running): keep the value, notify nobody.
            return
        # Iterate over a snapshot so callbacks may subscribe/unsubscribe while being notified,
        # and skip callbacks that were unsubscribed earlier in this round.
        for subscriber in list(self.subscribers):
            if subscriber in self.subscribers:
                subscriber(new_value)

    def set(self,
            new_value: T # The new value of the store
            ) -> None:
        ''' Set value of store; subscribers are notified when `utils.safe_not_equal` reports a change.'''
        self.__set(new_value)

    def update(self,
               fn: Callable[[T], T] # a callback that takes the existing store value and updates it
               ) -> None:
        ''' Update the store value by applying `fn` to the existing value.'''
        self.set(fn(self.value))

    def __len__(self) -> int:
        ''' The length of the store is the number of subscribers.'''
        return len(self.subscribers)

    def __repr__(self) -> str:
        return f"w<{len(self)}> ${self.value.__class__.__name__}: {self.value}"


#### Writable Factory

In [None]:
#|export
def writable(value: T = None, # initial value of the store
             start: Notifier = utils.noop # Optional Notifier, a function called when the first subscriber is added
             ) -> Writable[T]: # Writable Store
    ''' Writable factory: builds a Store holding `value`, activated by `start` on its first subscription.'''
    store = Store(initial_value=value, start=start)
    return store

In [None]:
#| hide
class Bunch:
    ''' Minimal attribute bag: keyword arguments become instance attributes.'''
    def __init__(self, **kw): self.__dict__ = kw

# a subscriber sees the initial value and every change until it unsubscribes
count = writable(0)
values = []
unsubscribe = count.subscribe(lambda x: values.append(x))
count.set(1)
count.update(lambda x: x+1)
unsubscribe()
count.set(3)                 # not observed: already unsubscribed
count.update(lambda x: x+1)  # not observed either
test_eq(values, [0,1,2])

# a store created without a value holds None; unsubscribing twice is harmless
store = writable()
values = []
unsubscribe = store.subscribe(lambda x: values.append(x))
unsubscribe()
test_eq(values, [None])
unsubscribe()
test_eq(unsubscribe(), None)

# setting the same (mutated) object again still notifies subscribers
obj = Bunch()
called = 0
store = writable(obj)
def callback(x):
    global called
    called = called + 1
store.subscribe(callback)
obj.a = 1
store.set(obj)
test_eq(called, 2)

#### Readable Store

In [None]:
#|export
class ReadableStore(Store[T]):
    ''' A Readable Store: `set` and `update` raise; new values arrive through the setter handed to `start`.'''
    def __init__(self,
                 initial_value: T, # initial value of the store
                 start: Notifier # function called when the first subscriber is added
                ) -> None:
        super().__init__(initial_value, start)

    def set(self, *args, **kwargs):
        ''' Always raises: a Readable Store cannot be set from outside.'''
        raise Exception("Cannot set a Readable Store.")

    def update(self, *args, **kwargs):
        ''' Always raises: a Readable Store cannot be updated from outside.'''
        raise Exception("Cannot update a Readable Store.")

    def __repr__(self) -> str:
        # Reuse the Store representation, swapping its leading "w" marker for "r".
        base = super().__repr__()
        return f"r{base[1:]}"

#### Readable Factory

In [None]:
#|export
def readable(value: T, # initial value of the store
             start: Notifier  # function called when the first subscriber is added
             ) -> Readable[T]:  # Readable Store
    ''' Readable factory: builds a ReadableStore holding `value`; `start` receives the setter used to change it.'''
    store = ReadableStore(initial_value=value, start=start)
    return store

In [None]:
#|hide
# Demonstrates that `readable` requires a `start` Notifier; the error message is printed below.
try:
    a:Readable = readable(0) # should fail: `start` has no default
except Exception as error:
    print(error)

readable() missing 1 required positional argument: 'start'


In [None]:
#|hide
test_fail(lambda: Readable(0))  # `Readable` is a Protocol alias, not a constructor
test_fail(lambda: readable(0))  # `start` is required
class Publisher:
    ''' Test helper: captures the setter a Readable Store hands to its `start` function.'''
    def __init__(self): self.set = lambda x: None  # placeholder until a store hands over its setter
    def set_set(self, set):  # used as the store's `start`
        # keep the store's setter and return a no-op stop function
        self.set = set
        return lambda: None
    def use_set(self, value): self.set(value)
p = Publisher()
r:Readable = readable(0, p.set_set)
test_eq(r.get(), 0)
p.use_set(1) # lost forever: no subscriber yet, so `set_set` was not called and `p.set` is still the placeholder
test_eq(r.get(), 0) # a Readable Store only updates when it has subscribers
stop = r.subscribe(utils.noop)
test_eq(r.get(), 0)
p.set(1)  # now `p.set` is the store's setter
test_eq(r.get(), 1)
stop()

#### Derived Store


In [None]:
#|export
from sveltish.utils import compose
from fastcore.foundation import L

In [None]:
#|export
class DerivedStore(Store[T]):
    ''' A Derived Store: a read-only store whose value is computed from one or more source stores.

    The derived value lives in `self.target`, a Readable Store. While `target` has
    subscribers it is kept in sync with the sources; without subscribers its value is
    not recomputed. `Store.__init__` is not called, so `get`, `__call__`, `__len__` and
    `__repr__` all delegate to `target`.
    '''
    def __init__(self,
                 s: Union[Store, list[Store]], # source store(s)
                 *functions: Callable, # a callback that takes the source store(s) values and returns the derived value
             ) -> None:
        self.sources = L(s)
        if not all(isinstance(src, Store) for src in self.sources):
            raise Exception("s must be a Store or a list of Stores")
        self.fn = compose(*functions)

        def current():
            # Apply the composed function to the current values of the sources, in order.
            return self.fn(*self.sources.map(lambda src: src.get()))

        def start(set_fn: Subscriber):
            def sync(_=None): # the notified value is ignored; all sources are re-read
                set_fn(current())
            sync() # sync target with source values, they can have changed since Derived creation
            unsubscribers = self.sources.map(lambda src: src.subscribe(sync))
            def stop():
                for unsubscribe in unsubscribers: unsubscribe()
            return stop
        self.target = readable(current(), start)

    def get(self): return self.target.get()

    def __call__(self):
        ''' `store()` is shorthand for `store.get()` (the inherited version reads a `value` attribute this class never sets).'''
        return self.get()

    def __len__(self) -> int:
        ''' Number of subscribers of the underlying target store.'''
        return len(self.target)

    def __repr__(self) -> str:
        # Same layout as the target's representation, with a "d" marker instead of "r".
        return "d" + repr(self.target)[1:]

    def set(self, *args, **kwargs): raise Exception("Cannot set a Derived Store.")
    def update(self, *args, **kwargs): raise Exception("Cannot update a Derived Store.")
    def subscribe(self,
                  callback: Subscriber # callback to be called when any of the source stores change
                  ) -> Unsubscriber:
        ''' Adds callback to the list of subscribers.'''
        return self.target.subscribe(callback)

#### Derived Factory

In [None]:
#| export
def derived(s: Union[Store, list[Store]], # source store(s)
            *functions: Callable[...,T] # callbacks composed left to right: the first takes the source store(s) values, the last returns the derived value
            ) -> Readable[T]: # Derived Store
    ''' Creates a new Derived Store (A Derived factory).

    Returns the Readable Store holding the derived value; it only tracks the sources
    while it has subscribers. `*functions` is annotated with its element type, as is
    required for star-args (`list(Callable[...,T])` is not a valid type expression).
    '''
    return DerivedStore(s, *functions).target

In [None]:
#| hide
a:Writable = writable('foo')
b = writable('bar')
d = derived([a,b], lambda a,b: f"{a}_{b}")
test_eq(d.get(), "foo_bar")
a.set('fonzie')
test_eq(d.get(), "foo_bar") #won't change if derived has no subscribers
u = d.subscribe(lambda x: None)
b.set('bach')
test_eq(d.get(), "fonzie_bach")
test_fail(lambda: d.set('baz'))
test_fail(lambda: d.update(lambda x: x))
u()

#### Pipe Operator

In [None]:
#|export
import fastcore.all as fc

In [None]:
#|export
@fc.patch
def pipe(self:Store, # source store
         *functions: Callable[...,T] # functions that transform the source store, applied left to right
         )->Readable[T]: # returned store
    ''' Unix-like Pipe operator: returns a Derived Store of `self` transformed by `functions`.'''
    return derived(self, *functions)

In [None]:
#|hide
# chained pipes and one multi-function pipe compose the same way: (1 + 1) * 2
for piped in (writable(1).pipe(lambda x: x+1).pipe(lambda x: x*2),
              writable(1).pipe(lambda x: x+1, lambda x: x*2)):
    test_eq(piped.get(), 4)

In [None]:
#|export
@fc.patch
def __or__(self:Store, # source store
           other: Callable[...,T] # function that transforms the source store
           ) -> Readable[T]: # returned store
    ''' `self | other` mirrors a Unix pipe: the result is a Derived Store of `other` applied to `self`.'''
    piped = self.pipe(other)
    return piped

In [None]:
#|hide
# `|` chains left to right: (1 + 1) * 2
test_eq((writable(1) | (lambda x: x+1) | (lambda x: x*2)).get(), 4)

In [None]:
#|hide
# write the `#| export`ed cells out to the package's Python modules (nbdev)
import nbdev; nbdev.nbdev_export()