Skip to content

Commit

Permalink
fix various issues. some remain.
Browse files Browse the repository at this point in the history
  • Loading branch information
halcy authored and halcy committed Jun 20, 2023
1 parent 826a6f4 commit d891589
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 87 deletions.
2 changes: 1 addition & 1 deletion mastodon/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def admin_trending_links(self) -> NonPaginatableList[PreviewCard]:

@api_version("4.0.0", "4.0.0", _DICT_VERSION_ADMIN_DOMAIN_BLOCK)
def admin_domain_blocks(self, id: Optional[IdType] = None, max_id: Optional[IdType] = None, min_id: Optional[IdType] = None,
since_id: Optional[IdType] = None, limit: Optional[int] = None) -> PaginatableList[AdminDomainBlock]:
since_id: Optional[IdType] = None, limit: Optional[int] = None) -> Union[AdminDomainBlock, PaginatableList[AdminDomainBlock]]:
"""
Fetches a list of blocked domains. Requires scope `admin:read:domain_blocks`.
Expand Down
30 changes: 6 additions & 24 deletions mastodon/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
MastodonGatewayTimeoutError, MastodonServerError, MastodonAPIError, MastodonMalformedEventError
from mastodon.compat import urlparse, magic, PurePath, Path
from mastodon.defaults import _DEFAULT_STREAM_TIMEOUT, _DEFAULT_STREAM_RECONNECT_WAIT_SEC
from mastodon.types import AttribAccessDict, try_cast_recurse
from mastodon.types import AttribAccessDict, PaginatableList, try_cast_recurse
from mastodon.types import *

###
Expand Down Expand Up @@ -271,10 +271,9 @@ def __api_request(self, method, endpoint, params={}, files={}, headers={}, acces
response = response_object.content

# Parse link headers
if isinstance(response, list) and \
'Link' in response_object.headers and \
response_object.headers['Link'] != "":
response = AttribAccessList(response)
if isinstance(response, list) and 'Link' in response_object.headers and response_object.headers['Link'] != "":
if not isinstance(response, PaginatableList):
response = PaginatableList(response)
tmp_urls = requests.utils.parse_header_links(
response_object.headers['Link'].rstrip('>').replace('>,<', ',<'))
for url in tmp_urls:
Expand All @@ -301,18 +300,12 @@ def __api_request(self, method, endpoint, params={}, files={}, headers={}, acces
del next_params['min_id']
response._pagination_next = next_params

# Maybe other API users rely on the pagination info in the last item
# Will be removed in future
if isinstance(response[-1], AttribAccessDict):
response[-1]._pagination_next = next_params

if url['rel'] == 'prev':
# Be paranoid and extract since_id or min_id specifically
prev_url = url['url']

# Old and busted (pre-2.6.0): since_id pagination
matchgroups = re.search(
r"[?&]since_id=([^&]+)", prev_url)
matchgroups = re.search(r"[?&]since_id=([^&]+)", prev_url)
if matchgroups:
prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method
Expand All @@ -326,14 +319,8 @@ def __api_request(self, method, endpoint, params={}, files={}, headers={}, acces
del prev_params['max_id']
response._pagination_prev = prev_params

# Maybe other API users rely on the pagination info in the first item
# Will be removed in future
if isinstance(response[0], AttribAccessDict):
response[0]._pagination_prev = prev_params

# New and fantastico (post-2.6.0): min_id pagination
matchgroups = re.search(
r"[?&]min_id=([^&]+)", prev_url)
matchgroups = re.search(r"[?&]min_id=([^&]+)", prev_url)
if matchgroups:
prev_params = copy.deepcopy(params)
prev_params['_pagination_method'] = method
Expand All @@ -346,11 +333,6 @@ def __api_request(self, method, endpoint, params={}, files={}, headers={}, acces
if "max_id" in prev_params:
del prev_params['max_id']
response._pagination_prev = prev_params

# Maybe other API users rely on the pagination info in the first item
# Will be removed in future
if isinstance(response[0], AttribAccessDict):
response[0]._pagination_prev = prev_params

return response

Expand Down
121 changes: 66 additions & 55 deletions mastodon/types_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,72 +137,28 @@ def __repr__(self) -> str:
"""
return str(self.__val)

"""
IDs returned from Mastodon.py ar either primitive (int or str) or snowflake
(still int or str, but potentially convertible to datetime).
"""
IdType = Union[PrimitiveIdType, MaybeSnowflakeIdType]

T = TypeVar('T')
class PaginatableList(List[T]):
"""
This is a list with pagination information attached.
It is returned by the API when a list of items is requested, and the response contains
a Link header with pagination information.
"""
def __getattr__(self, attr):
if attr in self:
return self[attr]
else:
raise AttributeError(f"Attribute not found: {attr}")

def __setattr__(self, attr, val):
if attr in self:
raise AttributeError("Attribute-style access is read only")
super(NonPaginatableList, self).__setattr__(attr, val)
# TODO add the pagination housekeeping stuff

class NonPaginatableList(List[T]):
"""
This is just a list. I am subclassing the regular list out of pure paranoia about
potential oversights that might require me to add things to it later.
"""
pass

# Lists in Mastodon.py are either regular or paginatable
EntityList = Union[NonPaginatableList[T], PaginatableList[T]]

# Backwards compat alias
AttribAccessList = EntityList

# Helper functions for typecasting attempts
def try_cast(t, value, retry = True):
"""
Base case casting function. Handles:
* Casting to any AttribAccessDict subclass (directly, no special handling)
* Casting to MaybeSnowflakeIdType (directly, no special handling)
* Casting to bool (with possible conversion from json bool strings)
* Casting to datetime (with possible conversion from all kinds of funny date formats because unfortunately this is the world we live in)
* Casting to whatever t is
* Trying once again to AttribAccessDict as a fallback
Gives up and returns as-is if none of the above work.
"""
try:
if issubclass(t, AttribAccessDict) or t is MaybeSnowflakeIdType:
try:
value = t(**value)
except:
try:
value = AttribAccessDict(**value)
except:
pass
elif isinstance(t, bool):
if issubclass(t, AttribAccessDict):
value = t(**value)
elif issubclass(t, bool):
if isinstance(value, str):
if value.lower() == 'true':
value = True
elif value.lower() == 'false':
value = False
value = bool(value)
elif isinstance(t, datetime):
elif issubclass(t, datetime):
if isinstance(value, int):
value = datetime.fromtimestamp(value, timezone.utc)
elif isinstance(value, str):
Expand All @@ -211,8 +167,11 @@ def try_cast(t, value, retry = True):
value = datetime.fromtimestamp(value_int, timezone.utc)
except:
value = dateutil.parser.parse(value)
except:
value = try_cast(AttribAccessDict, value, False)
else:
value = t(**value)
except Exception as e:
if retry:
value = try_cast(AttribAccessDict, value, False)
return value

def try_cast_recurse(t, value):
Expand Down Expand Up @@ -241,6 +200,38 @@ def try_cast_recurse(t, value):
pass
return try_cast(t, value)

"""
IDs returned from Mastodon.py ar either primitive (int or str) or snowflake
(still int or str, but potentially convertible to datetime).
"""
IdType = Union[PrimitiveIdType, MaybeSnowflakeIdType]

T = TypeVar('T')
class PaginatableList(List[T]):
"""
This is a list with pagination information attached.
It is returned by the API when a list of items is requested, and the response contains
a Link header with pagination information.
"""
def __init__(self, *args, **kwargs):
"""
Initializes basic list and adds empty pagination information.
"""
super(PaginatableList, self).__init__(*args, **kwargs)
self._pagination_next = None
self._pagination_prev = None

class NonPaginatableList(List[T]):
"""
This is just a list. I am subclassing the regular list out of pure paranoia about
potential oversights that might require me to add things to it later.
"""
pass

"""Lists in Mastodon.py are either regular or paginatable"""
EntityList = Union[NonPaginatableList[T], PaginatableList[T]]

class AttribAccessDict(OrderedDict[str, Any]):
"""
Base return object class for Mastodon.py.
Expand All @@ -256,12 +247,12 @@ def __init__(self, **kwargs):
Constructor that calls through to dict constructor and then sets attributes for all keys.
"""
super(AttribAccessDict, self).__init__()
if __annotations__ in self.__class__.__dict__:
if "__annotations__" in self.__class__.__dict__:
for attr, _ in self.__class__.__annotations__.items():
attr_name = attr
if hasattr(self.__class__, "_rename_map"):
attr_name = getattr(self.__class__, "_rename_map").get(attr, attr)
if attr_name in kwargs:
if attr_name in kwargs:
self[attr] = kwargs[attr_name]
assert not attr in kwargs, f"Duplicate attribute {attr}"
elif attr in kwargs:
Expand Down Expand Up @@ -337,11 +328,31 @@ def __setitem__(self, key, val):
super(AttribAccessDict, self).__setattr__(key, val)
super(AttribAccessDict, self).__setitem__(key, val)

def __eq__(self, other):
"""
Equality checker with casting
"""
if isinstance(other, self.__class__):
return super(AttribAccessDict, self).__eq__(other)
else:
try:
casted = try_cast_recurse(self.__class__, other)
if isinstance(casted, self.__class__):
return super(AttribAccessDict, self).__eq__(casted)
else:
return False
except Exception as e:
pass
return False

"""An entity returned by the Mastodon API is either a dict or a list"""
Entity = Union[AttribAccessDict, EntityList]

"""A type containing the parameters for a encrypting webpush data. Considered opaque / implementation detail."""
WebpushCryptoParamsPubkey = Dict[str, str]

"""A type containing the parameters for a derypting webpush data. Considered opaque / implementation detail."""
WebpushCryptoParamsPrivkey = Dict[str, str]
WebpushCryptoParamsPrivkey = Dict[str, str]

"""Backwards compatibility alias"""
AttribAccessList = PaginatableList
37 changes: 30 additions & 7 deletions tests/test_hooks.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,42 @@
import pytest
from datetime import datetime
from mastodon.types import IdType
import typing

def get_type_class(typ):
try:
return typ.__extra__
except AttributeError:
try:
return typ.__origin__
except AttributeError:
pass
return typ


def real_issubclass(obj1, type2orig):
type1 = get_type_class(type(obj1))
type2 = get_type_class(type2orig)
valid_types = []
if type2 is typing.Union:
valid_types = type2orig.__args__
elif type2 is typing.Generic:
valid_types = [type2orig.__args__[0]]
else:
valid_types = [type2orig]
return issubclass(type1, tuple(valid_types))

@pytest.mark.vcr()
def test_id_hook(status):
assert isinstance(status['id'], int)
assert real_issubclass(status['id'], IdType)


@pytest.mark.vcr()
def test_id_hook_in_reply_to(api, status):
reply = api.status_post('Reply!', in_reply_to_id=status['id'])
try:
assert isinstance(reply['in_reply_to_id'], int)
assert isinstance(reply['in_reply_to_account_id'], int)
assert real_issubclass(reply['in_reply_to_id'], IdType)
assert real_issubclass(reply['in_reply_to_account_id'], IdType)
finally:
api.status_delete(reply['id'])

Expand All @@ -21,18 +45,17 @@ def test_id_hook_in_reply_to(api, status):
def test_id_hook_within_reblog(api, status):
reblog = api.status_reblog(status['id'])
try:
assert isinstance(reblog['reblog']['id'], int)
assert real_issubclass(reblog['reblog']['id'], IdType)
finally:
api.status_delete(reblog['id'])


@pytest.mark.vcr()
def test_date_hook(status):
assert isinstance(status['created_at'], datetime)
assert real_issubclass(status['created_at'], datetime)

@pytest.mark.vcr()
def test_attribute_access(status):
assert status.id is not None
with pytest.raises(AttributeError):
status.id = 420
status.id = 420

0 comments on commit d891589

Please sign in to comment.