-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Accept Meta class inside entity models to override datastore.Client instance params #4
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,66 @@ | ||
from typing import Callable, Dict, Optional, Union | ||
from functools import partial | ||
|
||
from google.cloud import datastore | ||
from google.cloud.datastore.client import _CLIENT_INFO | ||
from google.cloud.datastore.entity import Entity as GoogleEntity | ||
from google.auth.credentials import Credentials as GoogleCredentials | ||
from google.api_core.gapic_v1.client_info import ClientInfo as GoogleClientIngo | ||
from google.api_core.client_options import ClientOptions as GoogleClientOptions | ||
from google.api_core.retry import Retry as GoogleRetry | ||
from requests import Session | ||
|
||
|
||
class DataStoreClient: | ||
_instance = None | ||
|
||
def __init__(self) -> None: | ||
self._client = datastore.Client() | ||
|
||
def __new__(cls): | ||
if cls._instance is None: | ||
cls._instance = super().__new__(cls) | ||
return cls._instance | ||
def __init__(self, | ||
project:Optional[str]=None, | ||
namespace:Optional[str]=None, | ||
credentials:Optional[GoogleCredentials]=None, | ||
client_info:Optional[GoogleClientIngo]=None, | ||
client_options:Optional[GoogleClientOptions]=None, | ||
_http:Optional[Session]=None, | ||
_use_grpc:Optional[bool]=None | ||
) -> None: | ||
self._project = project | ||
self._namespace = namespace | ||
self._client = datastore.Client( | ||
project=project, | ||
namespace=namespace, | ||
credentials=credentials, | ||
client_info=(client_info or _CLIENT_INFO), | ||
client_options=client_options, | ||
_http=_http, | ||
_use_grpc=_use_grpc | ||
) | ||
|
||
def _get_partial_query(self, kind): | ||
def _get_partial_query(self, kind: Union[str, int]) -> Callable: | ||
return partial( | ||
self._client.query, | ||
kind=kind | ||
) | ||
|
||
def save(self, entity, retry=None, timeout=None): | ||
def save(self, entity:GoogleEntity, retry:Optional[GoogleRetry]=None, timeout:Optional[float]=None) -> Union[str, int]: | ||
self._client.put_multi( | ||
entities=[entity], | ||
retry=retry, | ||
timeout=timeout | ||
) | ||
return entity.key.id | ||
|
||
@staticmethod | ||
def _mount_google_entity(entity_dict: Dict, key_case: Callable) -> GoogleEntity: | ||
entity_key = entity_dict.pop("id") | ||
entity = GoogleEntity(entity_key) | ||
for key, value in entity_dict.items(): | ||
key_name = key_case(key) | ||
entity[key_name] = value | ||
|
||
return entity | ||
|
||
@property | ||
def project(self): | ||
return self._client.project | ||
def project(self) -> str: | ||
return self._project or self._client.project | ||
|
||
@property | ||
def namespace(self): | ||
return self._client.namespace | ||
def namespace(self) -> str: | ||
return self._namespace or self._client.namespace |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,29 +12,63 @@ def __init__(self, name, bases, attrs): | |
super().__init__(name, bases, attrs) | ||
|
||
self.kind = attrs.get("__kind__") or name | ||
self._parent_entity = attrs.get("__parent__") | ||
self._project = attrs.get("__project__") or self._client.project | ||
self._namespace = attrs.get("__namespace__") or self._client.namespace | ||
|
||
_case_style = attrs.get("__case_style__") or {} | ||
self._convert_property_name = CaseStyle( | ||
from_case=_case_style.get("from") or "snake_case", | ||
to_case=_case_style.get("to") or "camel_case", | ||
) | ||
|
||
self.query = Query( | ||
partial_query=self._client._get_partial_query( | ||
kind=self.kind | ||
), | ||
entity_instance=self | ||
) | ||
self._process_meta(attrs) | ||
self._define_datastore_client() | ||
self._define_case_style(attrs) | ||
self._mount_query() | ||
|
||
self._partial_key = KeyField( | ||
entity_kind=self.kind, | ||
project=self._project, | ||
namespace=self._namespace | ||
project=self.project, | ||
namespace=self.namespace | ||
) | ||
|
||
self._handle_properties_validation(attrs) | ||
self._handle_required_properties(attrs) | ||
self._handle_properties_default_value(attrs) | ||
self._handle_parent_key(attrs) | ||
Comment on lines
+16
to
+30
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've broken the steps/processes into functions, to make it easier to understand |
||
|
||
self._entity_fields = list(self._entity_types.keys()) | ||
|
||
@property | ||
def project(self) -> str: | ||
return self._client.project | ||
|
||
@property | ||
def namespace(self) -> str: | ||
return self._client.namespace | ||
|
||
def _define_datastore_client(self) -> None: | ||
self._client = DataStoreClient( | ||
**self._client_args | ||
) | ||
|
||
def _handle_parent_key(self, attrs) -> None: | ||
self._parent_entity = attrs.get("__parent__") | ||
|
||
if self._parent_entity: | ||
self._required.append("parent_id") | ||
self._entity_types.update({ | ||
"parent_id": self._parent_entity._validate | ||
}) | ||
self._defaults.update({ | ||
"parent_id": None | ||
}) | ||
|
||
def _handle_properties_default_value(self, attrs) -> None: | ||
self._defaults = { | ||
key: value.default_value | ||
for key, value in attrs.items() | ||
if (isinstance(value, BaseField) and | ||
value.default_value is not None) | ||
} | ||
|
||
self._defaults.update({ | ||
"id": None | ||
}) | ||
|
||
def _handle_properties_validation(self, attrs) -> None: | ||
self._entity_types = { | ||
key: value._validate | ||
for key, value in attrs.items() | ||
|
@@ -45,38 +79,49 @@ def __init__(self, name, bases, attrs): | |
"id": self._partial_key._validate | ||
}) | ||
|
||
def _mount_query(self) -> None: | ||
self.query = Query( | ||
partial_query=self._client._get_partial_query( | ||
kind=self.kind | ||
), | ||
entity_instance=self | ||
) | ||
|
||
def _handle_required_properties(self, attrs) -> None: | ||
self._required = [ | ||
key for key, value in attrs.items() | ||
if (isinstance(value, BaseField) and | ||
value.required) | ||
] | ||
|
||
self._defaults = { | ||
key: value.default_value | ||
for key, value in attrs.items() | ||
if (isinstance(value, BaseField) and | ||
value.default_value is not None) | ||
} | ||
def _define_case_style(self, attrs) -> None: | ||
_case_style = attrs.get("__case_style__") or {} | ||
self._convert_property_name = CaseStyle( | ||
from_case=_case_style.get("from") or "snake_case", | ||
to_case=_case_style.get("to") or "camel_case", | ||
) | ||
|
||
self._defaults.update({ | ||
"id": None | ||
}) | ||
def _process_meta(self, attrs) -> None: | ||
_client_args = ( | ||
"project", | ||
"namespace", | ||
"credentials", | ||
"client_info", | ||
"client_options", | ||
"_http", | ||
"_use_grpc" | ||
) | ||
|
||
if self._parent_entity: | ||
self._required.append("parent_id") | ||
self._entity_types.update({ | ||
"parent_id": self._parent_entity._validate | ||
}) | ||
self._defaults.update({ | ||
"parent_id": None | ||
}) | ||
meta_class = attrs.get("Meta") | ||
|
||
self._entity_fields = list(self._entity_types.keys()) | ||
self._client_args = { | ||
arg: meta_attr | ||
for arg in _client_args | ||
if ((meta_attr := getattr(meta_class, arg, None)) is not None) | ||
} if isinstance(meta_class, type) else {} | ||
|
||
|
||
class Entity(metaclass=EntityMetaClass): | ||
_client = DataStoreClient() | ||
|
||
def __init__(self, **kwargs): | ||
self._data = { | ||
"id": None | ||
|
@@ -101,13 +146,6 @@ def __setattr__(self, key: str, value: Any) -> None: | |
if key in super().__getattribute__("_entity_types"): | ||
self._data[key] = self._entity_types[key](value) | ||
|
||
@staticmethod | ||
def _to_camel_case(key: str) -> str: | ||
splited_key = key.split('_') | ||
return splited_key[0] + ''.join( | ||
[x.title() for x in splited_key[1:]] | ||
) | ||
|
||
Comment on lines
-104
to
-110
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't need it anymore |
||
def save(self): | ||
for required_property in self._required: | ||
if self._data[required_property] is None: | ||
|
@@ -134,7 +172,10 @@ def save(self): | |
del entity["parent_id"] | ||
|
||
self.id = self._client.save( | ||
entity=self._mount_google_entity(entity) | ||
entity=self._client._mount_google_entity( | ||
entity, | ||
self._convert_property_name | ||
) | ||
) | ||
|
||
def to_dict(self): | ||
|
@@ -161,15 +202,6 @@ def _create_from_google_entity(cls, entity: GoogleEntity) -> "Entity": | |
)) | ||
return instance | ||
|
||
def _mount_google_entity(self, entity_dict: Dict) -> GoogleEntity: | ||
entity_key = entity_dict.pop("id") | ||
entity = GoogleEntity(entity_key) | ||
for key, value in entity_dict.items(): | ||
key_name = self._convert_property_name(key) | ||
entity[key_name] = value | ||
|
||
return entity | ||
|
||
Comment on lines
-164
to
-172
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moved |
||
def __repr__(self) -> str: | ||
return ( | ||
f"<{self.__class__.__name__} - id: {self.id}>" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are the default client settings (from google lib)