Skip to content

Commit

Permalink
Setup strict type for state attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
denisenkom committed Nov 23, 2023
1 parent 218a7c6 commit 3d151f0
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions src/twain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@
}


# Corresponding states are defined in TWAIN spec 2.11 paragraph
# Data Source Manager states
_DsmStates = typing.Literal[
"closed", # TWAIN state 2
"open", # TWAIN state 3
]

# Data Source states
_SourceStates = typing.Literal[
"closed", # TWAIN state 2
"open", # TWAIN state 4
"enabled", # TWAIN state 5
"ready", # TWAIN state 6
]


def _is_good_type(type_id: int) -> bool:
return type_id in list(_mapping.keys())

Expand Down Expand Up @@ -149,7 +165,7 @@ class Source:
def __init__(self, sm: SourceManager, ds_id: structs.TW_IDENTITY):
self._sm = sm
self._id = ds_id
self._state = "open"
self._state: _SourceStates = "open"
self._version2 = bool(ds_id.SupportedGroups & constants.DF_DS2)
if self._version2:
self._alloc = sm._alloc
Expand Down Expand Up @@ -253,8 +269,8 @@ def _get_capability(self, cap: int, current: int):
raise exceptions.CapabilityFormatNotSupported(msg)
ctype = _mapping[enum.ItemType]
item_p = ct.cast(
ptr + ct.sizeof(structs.TW_ENUMERATION), ct.POINTER(ctype)
) # type: ignore # needs fixing
ptr + ct.sizeof(structs.TW_ENUMERATION), ct.POINTER(ctype) # type: ignore # needs fixing
)
values = [el for el in item_p[0 : enum.NumItems]]
return enum.ItemType, (enum.CurrentIndex, enum.DefaultIndex, values)
elif twCapability.ConType == constants.TWON_ARRAY:
Expand All @@ -264,8 +280,8 @@ def _get_capability(self, cap: int, current: int):
raise exceptions.CapabilityFormatNotSupported(msg)
ctype = _mapping[arr.ItemType]
item_p = ct.cast(
ptr + ct.sizeof(structs.TW_ARRAY), ct.POINTER(ctype)
) # type: ignore # needs fixing
ptr + ct.sizeof(structs.TW_ARRAY), ct.POINTER(ctype) # type: ignore # needs fixing
)
return arr.ItemType, [el for el in item_p[0 : arr.NumItems]]
else:
msg = (
Expand Down Expand Up @@ -660,7 +676,11 @@ def image_info(self) -> dict:
"Compression": ii.Compression,
}

def _get_native_image(self) -> tuple[int, ct.c_void_p]:
def _get_native_image(self) -> ct.c_void_p | None:
"""
Transfer image via memory. Should only be called when image is ready for transfer.
Returns handle to image or None if transfer was cancelled.
"""
hbitmap = ct.c_void_p()
logger.info("Calling DAT_IMAGENATIVEXFER")
rv = self._call(
Expand All @@ -670,7 +690,11 @@ def _get_native_image(self) -> tuple[int, ct.c_void_p]:
ct.byref(hbitmap),
(constants.TWRC_XFERDONE, constants.TWRC_CANCEL),
)
return rv, hbitmap
if rv == constants.TWRC_XFERDONE:
return hbitmap
if rv == constants.TWRC_CANCEL:
return None
raise RuntimeError(f"Unexpected result returned from DAT_IMAGENATIVEXFER: {rv}")

def _get_file_image(self) -> int:
logger.info("Calling DAT_IMAGEFILEXFER")
Expand All @@ -693,6 +717,13 @@ def _get_file_audio(self) -> int:
)

def _end_xfer(self) -> int:
"""
This method should be called after every image transfer, successful or cancelled
Returns information about additional transfers:
* 0 - no more transfers available
* -1 - more images available but how many is unknown
* >0 - indicates how many more images are available
"""
px = structs.TW_PENDINGXFERS()
logger.info("Calling DAT_PENDINGXFERS/MSG_ENDXFER")
self._call(
Expand Down Expand Up @@ -760,9 +791,11 @@ def xfer_image_natively(self) -> tuple[typing.Any, int]:
Valid states: 6
"""
rv, handle = self._get_native_image()
handle = self._get_native_image()
# _end_xfer should be called even if transfer was cancelled
more = self._end_xfer()
if rv == constants.TWRC_CANCEL:
# get_native_image returns None if transfer was cancelled
if not handle:
raise exceptions.DSTransferCancelled
return handle.value, more

Expand Down Expand Up @@ -860,10 +893,7 @@ def acquire_natively(

def callback() -> int:
before(self.image_info)
rv, handle = self._get_native_image()
more = self._end_xfer()
if rv == constants.TWRC_CANCEL:
raise exceptions.DSTransferCancelled
handle, more = self.xfer_image_natively()
after(_Image(handle), more) # type: ignore # needs fixing
return more

Expand Down Expand Up @@ -1099,7 +1129,7 @@ def __init__(
"""
self._sources: weakref.WeakSet[Source] = weakref.WeakSet()
self._cb: collections.abc.Callable[[int], None] | None = None
self._state = "closed"
self._state : _DsmStates = "closed"
self._parent_window = parent_window
self._hwnd = 0
if utils.is_windows():
Expand Down

0 comments on commit 3d151f0

Please sign in to comment.