diff --git a/examples/meeting_room/tests/test_dav.py b/examples/meeting_room/tests/test_dav.py index 9d03c74..f131783 100644 --- a/examples/meeting_room/tests/test_dav.py +++ b/examples/meeting_room/tests/test_dav.py @@ -245,7 +245,7 @@ def get_from_user_id(cls, user_id, once=False, **kwargs): **kwargs) def propagate_change(self): - self.user.save() + self.store.user.save() class UserViewMediatorDAV(ViewMediatorDAV): @@ -328,7 +328,7 @@ def test_propagate_change(users, tickets, location, meeting): # time.sleep(3) - user_view.user.last_name = "M." + user_view.store.user.last_name = "M." user_view.propagate_change() user_ref = Context.db.collection("users").document(user_id) diff --git a/examples/meeting_room/view_models/meeting_session.py b/examples/meeting_room/view_models/meeting_session.py index b571316..79ec145 100644 --- a/examples/meeting_room/view_models/meeting_session.py +++ b/examples/meeting_room/view_models/meeting_session.py @@ -4,7 +4,9 @@ from examples.meeting_room.domain_models.location import Location from examples.meeting_room.domain_models.meeting import Meeting from flask_boiler import fields, schema, view_model, view +from flask_boiler.business_property_store import BPSchema from flask_boiler.mutation import Mutation, PatchMutation +from flask_boiler.struct import Struct from flask_boiler.view import DocumentAsView from flask_boiler.view_model import ViewModelMixin @@ -18,12 +20,11 @@ class MeetingSessionSchema(schema.Schema): num_hearing_aid_requested = fields.Raw() -class MeetingSessionBpStoreSchema: - - _users = fields.BusinessPropertyFieldMany(referenced_cls=User) - _tickets = fields.BusinessPropertyFieldMany(referenced_cls=Ticket) - _meeting = fields.BusinessPropertyFieldOne(referenced_cls=Meeting) - _location = fields.BusinessPropertyFieldOne(referenced_cls=Location) +class MeetingSessionBpss(BPSchema): + tickets = fields.StructuralRef(dm_cls=Ticket, many=True) + users = fields.StructuralRef(dm_cls=User, many=True) + meeting = fields.StructuralRef(dm_cls=Meeting) + location = fields.StructuralRef(dm_cls=Location) class MeetingSessionMixin: @@ -35,75 +36,54 @@ def __init__(self, *args, meeting_id=None, **kwargs): super().__init__(*args, **kwargs) self._meeting_id = meeting_id - @property - def _users(self): - user_ids = [user_ref.id for user_ref in self._meeting.users] - return { - user_id: self.business_properties[user_id] for user_id in user_ids - } - @property def _tickets(self): return { - self.business_properties[ticket_ref.id].user.id: - self.business_properties[ticket_ref.id] - for ticket_ref in self._meeting.tickets + ticket.user.id: ticket + for _, ticket in self.store.tickets.items() } - @property - def _meeting(self): - """ - TODO: fix evaluation order in source code (add priority flag to some - TODO: view models to be instantiated first) - :return: - """ - return self.business_properties[self._meeting_id] - @property def meeting_id(self): - return self._meeting.doc_id - - @property - def _location(self): - return self.business_properties[self._meeting.location.id] + return self.store.meeting.doc_id @property def in_session(self): - return self._meeting.status == "in-session" + return self.store.meeting.status == "in-session" @in_session.setter def in_session(self, in_session): - cur_status = self._meeting.status + cur_status = self.store.meeting.status if cur_status == "in-session" and not in_session: - self._meeting.status = "closed" + self.store.meeting.status = "closed" elif cur_status == "closed" and in_session: - self._meeting.status = "in-session" + self.store.meeting.status = "in-session" else: raise ValueError @property def latitude(self): - return self._location.latitude + return self.store.location.latitude @property def longitude(self): - return self._location.longitude + return self.store.location.longitude @property def address(self): - return self._location.address + return self.store.location.address @property def attending(self): - user_ids = [uid for uid in self._users.keys()] + user_ids = [uid for uid in self.store.users.keys()] - if self._meeting.status == "not-started": + if self.store.meeting.status == "not-started": return list() res = list() for user_id in sorted(user_ids): ticket = self._tickets[user_id] - user = self._users[user_id] + user = self.store.users[user_id] if ticket.attendance: d = { "name": user.display_name, @@ -140,33 +120,33 @@ def new(cls, doc_id=None): @classmethod def get_from_meeting_id(cls, meeting_id, once=False, **kwargs): - struct = dict() + struct = Struct(schema_obj=MeetingSessionBpss()) m: Meeting = Meeting.get(doc_id=meeting_id) - struct[m.doc_id] = (Meeting, m.doc_ref.id) + struct["meeting"] = (Meeting, m.doc_ref.id) for user_ref in m.users: obj_type = User doc_id = user_ref.id - struct[doc_id] = (obj_type, user_ref.id) + struct["users"][doc_id] = (obj_type, user_ref.id) for ticket_ref in m.tickets: obj_type = Ticket doc_id = ticket_ref.id + struct["tickets"][doc_id] = (obj_type, ticket_ref.id) - struct[doc_id] = (obj_type, ticket_ref.id) - - struct[m.location.id] = (Location, m.location.id) + struct["location"] = (Location, m.location.id) obj = cls.get(struct_d=struct, once=once, meeting_id=m.doc_ref.id, **kwargs) # time.sleep(2) # TODO: delete after implementing sync + return obj def propagate_change(self): - self._meeting.save() + self.store.propagate_back() class MeetingSession(MeetingSessionMixin, view.FlaskAsView): diff --git a/examples/meeting_room/view_models/user_view.py b/examples/meeting_room/view_models/user_view.py index 0c7c2c6..b475e72 100644 --- a/examples/meeting_room/view_models/user_view.py +++ b/examples/meeting_room/view_models/user_view.py @@ -1,6 +1,8 @@ from examples.meeting_room.domain_models.user import User from examples.meeting_room.domain_models.meeting import Meeting from flask_boiler import fields, schema, view_model, view +from flask_boiler.business_property_store import BPSchema +from flask_boiler.struct import Struct class UserViewSchema(schema.Schema): @@ -13,6 +15,10 @@ class UserViewSchema(schema.Schema): meetings = fields.Relationship(many=True, dump_only=True) +class UserBpss(BPSchema): + user = fields.StructuralRef(dm_cls=User) + + class UserViewMixin: class Meta: @@ -20,51 +26,42 @@ class Meta: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.user = None @classmethod def new(cls, doc_id=None): return cls.get_from_user_id(user_id=doc_id) - def set_user(self, user): - self.user = user - @property def first_name(self): - return self.user.first_name + return self.store.user.first_name @property def last_name(self): - return self.user.last_name + return self.store.user.last_name @last_name.setter def last_name(self, new_last_name): - self.user.last_name = new_last_name + self.store.user.last_name = new_last_name @property def organization(self): - return self.user.organization + return self.store.user.organization @property def hearing_aid_requested(self): - return self.user.hearing_aid_requested + return self.store.user.hearing_aid_requested @property def meetings(self): - return list(Meeting.where(users=("array_contains", self.user.doc_ref))) - - def get_vm_update_callback(self, dm_cls): - def user_update_func(vm: UserView, dm): - vm.set_user(dm) - return user_update_func + return list(Meeting.where(users=("array_contains", self.store.user.doc_ref))) @classmethod def get_from_user_id(cls, user_id, once=False, **kwargs): - struct = dict() + struct = Struct(schema_obj=UserBpss()) u: User = User.get(doc_id=user_id) - struct[u.doc_id] = (User, u.doc_ref.id) + struct["user"] = (User, u.doc_ref.id) return super().get(struct_d=struct, once=once, **kwargs) diff --git a/flask_boiler/business_property_store.py b/flask_boiler/business_property_store.py index 05cc853..299408a 100644 --- a/flask_boiler/business_property_store.py +++ b/flask_boiler/business_property_store.py @@ -29,29 +29,41 @@ def structural_ref_fields(self): return [fd for _, fd in self.fields.items() if isinstance(fd, StructuralRef)] -class BusinessPropertyStore(Schemed): +class BusinessPropertyStore: - def __init__(self, struct): + def __init__(self, struct, snapshot_container, schema_obj): super().__init__() - - self._container = SnapshotContainer() + self._container = snapshot_container self.struct = struct + self.schema_obj = struct.schema_obj self._g, self._gr, self._manifest = \ self._get_manifests(self.struct, self.schema_obj) + self.objs = dict() @property def bprefs(self): return self._manifest.copy() + def refresh(self): + for doc_ref in self._manifest: + self.objs[doc_ref] = snapshot_to_obj(self._container.get(doc_ref)) + def __getattr__(self, item): + if item not in self._g: + raise AttributeError + if isinstance(self._g[item], dict): return { - k: snapshot_to_obj(self._container.get(v)) + k: self.objs[v] for k, v in self._g[item].items() } else: - return snapshot_to_obj(self._container.get(self._g[item])) + return self.objs[self._g[item]] + + def propagate_back(self): + for _, obj in self.objs.items(): + obj.save() @staticmethod def _get_manifests(struct, schema_obj) -> Tuple: @@ -62,18 +74,17 @@ def _get_manifests(struct, schema_obj) -> Tuple: key = fd.attribute val = struct[key] - dm_cls = fd.dm_cls if fd.many: g[key] = dict() for k, v in val.items(): if "." in k: raise ValueError - doc_ref = to_ref(dm_cls, v) + doc_ref = to_ref(*v) g[key][k] = doc_ref gr[doc_ref].append("{}.{}".format(key, k)) manifest.add(doc_ref) else: - doc_ref = to_ref(dm_cls, val) + doc_ref = to_ref(*val) g[key] = doc_ref gr[doc_ref].append(key) manifest.add(doc_ref) diff --git a/flask_boiler/fields.py b/flask_boiler/fields.py index 7667e10..eba55c8 100644 --- a/flask_boiler/fields.py +++ b/flask_boiler/fields.py @@ -250,24 +250,6 @@ class Remainder(fields.Dict, Field): pass -class BusinessPropertyFieldBase(fields.Raw, Field): - pass - - -class BusinessPropertyFieldMany(BusinessPropertyFieldBase): - - @property - def default_value(self): - return set() - - -class BusinessPropertyFieldOne(BusinessPropertyFieldBase): - - @property - def default_value(self): - return None - - # class BpStoreField(fields.Raw, Field): # # def __init__(self, *args, **kwargs): diff --git a/flask_boiler/struct.py b/flask_boiler/struct.py new file mode 100644 index 0000000..6856e4c --- /dev/null +++ b/flask_boiler/struct.py @@ -0,0 +1,30 @@ +from collections import UserDict + + +class Struct(UserDict): + + def __init__(self, schema_obj): + super().__init__() + self.schema_obj = schema_obj + + @property + def vals(self): + for _, val in self.data.items(): + if isinstance(val, dict): + for _, v in val.items(): + yield v + else: + yield val + + def __getitem__(self, key): + """ + Initializes a field to dict if it was not declared before + + :param item: + :return: + """ + + if key not in self.data.keys(): + self.data[key] = dict() + return super().__getitem__(key) + diff --git a/flask_boiler/view_model.py b/flask_boiler/view_model.py index c1e52da..ffa3efd 100644 --- a/flask_boiler/view_model.py +++ b/flask_boiler/view_model.py @@ -3,13 +3,14 @@ from dictdiffer import diff, patch from google.cloud.firestore import DocumentReference -from flask_boiler.business_property_store import BusinessPropertyStore +from flask_boiler.business_property_store import BusinessPropertyStore, to_ref from flask_boiler.snapshot_container import SnapshotContainer from flask_boiler.watch import DataListener from .context import Context as CTX from .domain_model import DomainModel from flask_boiler.referenced_object import ReferencedObject from .serializable import Serializable +from .struct import Struct from .utils import random_id, snapshot_to_obj @@ -67,15 +68,22 @@ def get(cls, struct_d=None, once=False, **kwargs): :return: """ obj = cls(struct_d=struct_d, **kwargs) - for key, val in obj._structure.items(): - obj_type, doc_id = val - obj.bind_to(key=key, obj_type=obj_type, doc_id=doc_id) + obj.bind_all() + if once: obj.listen_once() else: obj.register_listener() return obj + def bind_all(self): + + for obj_type, doc_id in self._struct_d.vals: + self.__subscribe_to( + dm_cls=obj_type, + dm_doc_id=doc_id, + ) + @classmethod def get_many(cls, struct_d_iterable=None, once=False): """ Gets a list of view models from a list of @@ -88,7 +96,7 @@ def get_many(cls, struct_d_iterable=None, once=False): return [cls.get(struct_d=struct_d, once=once) for struct_d in struct_d_iterable] - def __init__(self, f_notify=None, *args, **kwargs): + def __init__(self, struct_d=None, f_notify=None, *args, **kwargs): """ :param f_notify: callback to notify that view model's @@ -99,6 +107,12 @@ def __init__(self, f_notify=None, *args, **kwargs): super().__init__(*args, **kwargs) self.business_properties: Dict[str, DomainModel] = dict() self.snapshot_container = SnapshotContainer() + if not isinstance(struct_d, Struct): + raise ValueError + self._struct_d = struct_d + self.store = BusinessPropertyStore( + struct=self._struct_d, schema_obj=struct_d.schema_obj, + snapshot_container=self.snapshot_container) self._on_update_funcs: Dict[str, Tuple] = dict() self.listener = None self.f_notify = f_notify @@ -120,27 +134,16 @@ def _bind_to_domain_model(self, *, key, obj_type, doc_id): :param doc_id: :return: """ - # obj_cls: DomainModel = Serializable.get_cls_from_name(obj_type) - obj_cls: DomainModel = obj_type - if key in self._structure: - a, b = self._structure[key] - if a != obj_type or b != doc_id: - raise ValueError("Values disagree. ") - else: - # update_func = self.get_update_func(dm_cls=obj_cls) - self._structure[key] = (obj_type, doc_id) + self._struct_d[key] = (obj_type, doc_id) self.__subscribe_to( - key=key, - dm_cls=obj_cls, + dm_cls=obj_type, dm_doc_id=doc_id, ) - # _, doc_watch = self._on_update_funcs[key] - # assert isinstance(doc_watch, Watch) def get_on_update(self, - dm_cls=None, dm_doc_id=None, dm_doc_ref_str=None, key=None): + dm_cls=None, dm_doc_id=None, dm_doc_ref_str=None): # do something with this ViewModel def _on_update(docs, changes, readtime): @@ -151,7 +154,7 @@ def _on_update(docs, changes, readtime): raise NotImplementedError doc = docs[0] - self.snapshot_container.set( (dm_cls, dm_doc_id), doc ) + self.snapshot_container.set(to_ref(dm_cls, dm_doc_id), doc ) return _on_update @@ -162,23 +165,30 @@ def propagate_change(self): """ raise NotImplementedError - def __subscribe_to(self, *, key, dm_cls, - dm_doc_id): + def __subscribe_to(self, *, dm_cls,dm_doc_id): + """ + + :param dm_cls: + :param dm_doc_id: + :return: + """ - # if key in self._on_update_funcs: - # # Release the previous on_snapshot functions - # # https://firebase.google.com/docs/firestore/query-data/listen - # f, doc_watch = self._on_update_funcs[key] - # # TODO: add back, see: - # # https://github.com/googleapis/google-cloud-python/issues/9008 - # # https://github.com/googleapis/google-cloud-python/issues/7826 - # # doc_watch.unsubscribe() + """ + if key in self._on_update_funcs: + # Release the previous on_snapshot functions + # https://firebase.google.com/docs/firestore/query-data/listen + f, doc_watch = self._on_update_funcs[key] + # TODO: add back, see: + # https://github.com/googleapis/google-cloud-python/issues/9008 + # https://github.com/googleapis/google-cloud-python/issues/7826 + # doc_watch.unsubscribe() + """ dm_ref: DocumentReference = dm_cls._get_collection().document(dm_doc_id) on_update = self.get_on_update( dm_cls=dm_cls, dm_doc_id=dm_doc_id, dm_doc_ref_str=dm_ref._document_path, - key=key) + ) # doc_watch = dm_ref.on_snapshot(on_update) self._on_update_funcs[dm_ref._document_path] = on_update @@ -236,7 +246,7 @@ def snapshot_callback(docs, changes, read_time): # TODO: restore parameter "changes" on_update([doc], None, read_time) - self._refresh_business_property() + self.store.refresh() self._invoke_vm_callbacks() self.listener = DataListener( @@ -278,7 +288,7 @@ def snapshot_callback(docs, changes, read_time): # TODO: restore parameter "changes" on_update([doc], None, read_time) - self._refresh_business_property() + self.store.refresh() self._invoke_vm_callbacks() with self.snapshot_container.lock: self._notify() @@ -292,19 +302,27 @@ def snapshot_callback(docs, changes, read_time): self.listener.wait_for_once_done() - def _refresh_business_property(self, ): - with self.snapshot_container.lock: - for key, val in self._structure.items(): - obj_type, doc_id = val - snapshot = self.snapshot_container.get((obj_type, doc_id), ) - self.business_properties[key] = snapshot_to_obj(snapshot=snapshot) + # def _refresh_business_property(self, ): + # with self.snapshot_container.lock: + # for key, val in self._structure.items(): + # obj_type, doc_id = val + # snapshot = self.snapshot_container.get((obj_type, doc_id), ) + # self.business_properties[key] = snapshot_to_obj(snapshot=snapshot) def _invoke_vm_callbacks(self): - for key, val in self._structure.items(): - obj_type, doc_id = val - vm_update_callback = self.get_vm_update_callback(dm_cls=obj_type) - dm = self.business_properties[key] - vm_update_callback(vm=self, dm=dm) + for key, val in self._struct_d.items(): + b = getattr(self.store, key) + if isinstance(val, dict): + for k, v in val.items(): + obj_type, doc_id = v + vm_update_callback = self.get_vm_update_callback( + dm_cls=obj_type) + vm_update_callback(vm=self, dm=b[k]) + else: + obj_type, doc_id = val + vm_update_callback = self.get_vm_update_callback(dm_cls=obj_type) + vm_update_callback(vm=self, dm=b) + def _notify(self): """ Notify that this object has been changed by underlying view models