Skip to content

Commit

Permalink
types: Add types to api/client.py
Browse files Browse the repository at this point in the history
Signed-off-by: Victorien Plot <65306057+Viicos@users.noreply.github.com>
  • Loading branch information
Viicos committed Jan 3, 2023
1 parent ff1e743 commit 3a11fb0
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 32 deletions.
110 changes: 78 additions & 32 deletions docker/api/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import json
import struct
import urllib
import ssl
from functools import partial
from typing import Any, AnyStr, Optional, Union, Dict, overload, NoReturn, Iterator

from typing_extensions import Literal

import requests
import requests.exceptions
import requests.adapters
import websocket

from .. import auth
Expand All @@ -20,6 +25,7 @@
from ..utils.json_stream import json_stream
from ..utils.proxy import ProxyConfig
from ..utils.socket import consume_socket_output, demux_adaptor, frames_iter
from ..utils.typing import BytesOrDict
from .build import BuildApiMixin
from .config import ConfigApiMixin
from .container import ContainerApiMixin
Expand Down Expand Up @@ -102,11 +108,11 @@ class APIClient(
'base_url',
'timeout']

def __init__(self, base_url=None, version=None,
timeout=DEFAULT_TIMEOUT_SECONDS, tls=False,
user_agent=DEFAULT_USER_AGENT, num_pools=None,
credstore_env=None, use_ssh_client=False,
max_pool_size=DEFAULT_MAX_POOL_SIZE):
def __init__(self, base_url: Optional[str] = None, version: Optional[str] = None,
timeout: int = DEFAULT_TIMEOUT_SECONDS, tls: Optional[Union[bool, TLSConfig]] = False,
user_agent: str = DEFAULT_USER_AGENT, num_pools: Optional[int] = None,
credstore_env: Optional[Dict[str, Any]] = None, use_ssh_client: bool = False,
max_pool_size: int = DEFAULT_MAX_POOL_SIZE) -> None:
super().__init__()

if tls and not base_url:
Expand Down Expand Up @@ -209,7 +215,7 @@ def __init__(self, base_url=None, version=None,
'library.'.format(MINIMUM_DOCKER_API_VERSION)
)

def _retrieve_server_version(self):
def _retrieve_server_version(self) -> str:
try:
return self.version(api_version=False)["ApiVersion"]
except KeyError:
Expand All @@ -222,29 +228,29 @@ def _retrieve_server_version(self):
f'Error while fetching server API version: {e}'
)

def _set_request_timeout(self, kwargs):
def _set_request_timeout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare the kwargs for an HTTP request by inserting the timeout
parameter, if not already present."""
kwargs.setdefault('timeout', self.timeout)
return kwargs

@update_headers
def _post(self, url, **kwargs):
def _post(self, url: str, **kwargs: Any) -> requests.Response:
return self.post(url, **self._set_request_timeout(kwargs))

@update_headers
def _get(self, url, **kwargs):
def _get(self, url: str, **kwargs: Any) -> requests.Response:
return self.get(url, **self._set_request_timeout(kwargs))

@update_headers
def _put(self, url, **kwargs):
def _put(self, url: str, **kwargs: Any) -> requests.Response:
return self.put(url, **self._set_request_timeout(kwargs))

@update_headers
def _delete(self, url, **kwargs):
def _delete(self, url: str, **kwargs: Any) -> requests.Response:
return self.delete(url, **self._set_request_timeout(kwargs))

def _url(self, pathfmt, *args, **kwargs):
def _url(self, pathfmt: str, *args: str, **kwargs: Any) -> str:
for arg in args:
if not isinstance(arg, str):
raise ValueError(
Expand All @@ -262,14 +268,30 @@ def _url(self, pathfmt, *args, **kwargs):
else:
return f'{self.base_url}{pathfmt.format(*args)}'

def _raise_for_status(self, response):
def _raise_for_status(self, response: requests.Response) -> None:
"""Raises stored :class:`APIError`, if one occurred."""
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise create_api_error_from_http_exception(e) from e

def _result(self, response, json=False, binary=False):
@overload
def _result(self, response: requests.Response, json: Literal[True], binary: Literal[True]) -> NoReturn:
...

@overload
def _result(self, response: requests.Response, json: Literal[False] = ..., binary: Literal[False] = ...) -> str:
...

@overload
def _result(self, response: requests.Response, json: Literal[True], binary: bool = ...) -> Any:
...

@overload
def _result(self, response: requests.Response, json: bool = ..., binary: Literal[True] = ...) -> bytes:
...

def _result(self, response: requests.Response, json: bool = False, binary: bool = False) -> Any:
assert not (json and binary)
self._raise_for_status(response)

Expand All @@ -279,7 +301,7 @@ def _result(self, response, json=False, binary=False):
return response.content
return response.text

def _post_json(self, url, data, **kwargs):
def _post_json(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
# Go <1.1 can't unserialize null to a string
# so we do this disgusting thing here.
data2 = {}
Expand All @@ -295,26 +317,26 @@ def _post_json(self, url, data, **kwargs):
kwargs['headers']['Content-Type'] = 'application/json'
return self._post(url, data=json.dumps(data2), **kwargs)

def _attach_params(self, override=None):
def _attach_params(self, override: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
return override or {
'stdout': 1,
'stderr': 1,
'stream': 1
}

@check_resource('container')
def _attach_websocket(self, container, params=None):
def _attach_websocket(self, container: str, params: Optional[Dict[str, Any]] = None) -> websocket.WebSocket:
url = self._url("/containers/{0}/attach/ws", container)
req = requests.Request("POST", url, params=self._attach_params(params))
full_url = req.prepare().url
full_url = full_url.replace("http://", "ws://", 1)
full_url = full_url.replace("https://", "wss://", 1)
return self._create_websocket_connection(full_url)

def _create_websocket_connection(self, url):
def _create_websocket_connection(self, url: str) -> websocket.WebSocket:
return websocket.create_connection(url)

def _get_raw_response_socket(self, response):
def _get_raw_response_socket(self, response: requests.Response) -> ssl.SSLSocket:
self._raise_for_status(response)
if self.base_url == "http+docker://localnpipe":
sock = response.raw._fp.fp.raw.sock
Expand All @@ -336,7 +358,15 @@ def _get_raw_response_socket(self, response):

return sock

def _stream_helper(self, response, decode=False):
@overload
def _stream_helper(self, response: requests.Response, decode: Literal[True]) -> Iterator[Dict[str, Any]]:
...

@overload
def _stream_helper(self, response: requests.Response, decode: Literal[False] = ...) -> Iterator[bytes]:
...

def _stream_helper(self, response: requests.Response, decode: bool = False) -> Iterator[BytesOrDict]:
"""Generator for data coming from a chunked-encoded HTTP response."""

if response.raw._fp.chunked:
Expand All @@ -357,7 +387,7 @@ def _stream_helper(self, response, decode=False):
# encountered an error immediately
yield self._result(response, json=decode)

def _multiplexed_buffer_helper(self, response):
def _multiplexed_buffer_helper(self, response: requests.Response) -> Iterator[bytes]:
"""A generator of multiplexed data blocks read from a buffered
response."""
buf = self._result(response, binary=True)
Expand All @@ -373,7 +403,7 @@ def _multiplexed_buffer_helper(self, response):
walker = end
yield buf[start:end]

def _multiplexed_response_stream_helper(self, response):
def _multiplexed_response_stream_helper(self, response: requests.Response) -> Iterator[bytes]:
"""A generator of multiplexed data blocks coming from a response
stream."""

Expand All @@ -394,7 +424,15 @@ def _multiplexed_response_stream_helper(self, response):
break
yield data

def _stream_raw_result(self, response, chunk_size=1, decode=True):
@overload
def _stream_raw_result(self, response: requests.Response, chunk_size: int = ..., decode: Literal[False] = ...) -> Iterator[bytes]:
...

@overload
def _stream_raw_result(self, response: requests.Response, chunk_size: int = ..., decode: Literal[True] = ...) -> Iterator[str]:
...

def _stream_raw_result(self, response: requests.Response, chunk_size: int = 1, decode: bool = True) -> Iterator[AnyStr]:
''' Stream result for TTY-enabled container and raw binary data'''
self._raise_for_status(response)

Expand All @@ -405,7 +443,7 @@ def _stream_raw_result(self, response, chunk_size=1, decode=True):

yield from response.iter_content(chunk_size, decode)

def _read_from_socket(self, response, stream, tty=True, demux=False):
def _read_from_socket(self, response: requests.Response, stream: bool, tty: bool = True, demux: bool = False) -> Any:
socket = self._get_raw_response_socket(response)

gen = frames_iter(socket, tty)
Expand All @@ -423,7 +461,7 @@ def _read_from_socket(self, response, stream, tty=True, demux=False):
# Wait for all the frames, concatenate them, and return the result
return consume_socket_output(gen, demux=demux)

def _disable_socket_timeout(self, socket):
def _disable_socket_timeout(self, socket: ssl.SSLSocket) -> None:
""" Depending on the combination of python version and whether we're
connecting over http or https, we might need to access _sock, which
may or may not exist; or we may need to just settimeout on socket
Expand Down Expand Up @@ -452,14 +490,22 @@ def _disable_socket_timeout(self, socket):
s.settimeout(None)

@check_resource('container')
def _check_is_tty(self, container):
def _check_is_tty(self, container: str) -> bool:
cont = self.inspect_container(container)
return cont['Config']['Tty']

def _get_result(self, container, stream, res):
def _get_result(self, container: str, stream: bool, res: requests.Response):
return self._get_result_tty(stream, res, self._check_is_tty(container))

def _get_result_tty(self, stream, res, is_tty):
@overload
def _get_result_tty(self, stream: Literal[True], res: requests.Response, is_tty: bool) -> Iterator[bytes]:
...

@overload
def _get_result_tty(self, stream: Literal[False], res: requests.Response, is_tty: bool) -> bytes:
...

def _get_result_tty(self, stream: bool, res: requests.Response, is_tty: bool):
# We should also use raw streaming (without keep-alives)
# if we're dealing with a tty-enabled container.
if is_tty:
Expand All @@ -475,11 +521,11 @@ def _get_result_tty(self, stream, res, is_tty):
[x for x in self._multiplexed_buffer_helper(res)]
)

def _unmount(self, *args):
def _unmount(self, *args: str) -> None:
for proto in args:
self.adapters.pop(proto)

def get_adapter(self, url):
def get_adapter(self, url: str) -> requests.adapters.BaseAdapter:
try:
return super().get_adapter(url)
except requests.exceptions.InvalidSchema as e:
Expand All @@ -489,10 +535,10 @@ def get_adapter(self, url):
raise e

@property
def api_version(self):
def api_version(self) -> str:
return self._version

def reload_config(self, dockercfg_path=None):
def reload_config(self, dockercfg_path: Optional[str] = None) -> None:
"""
Force a reload of the auth configuration
Expand Down
3 changes: 3 additions & 0 deletions docker/utils/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Any, Dict, TypeVar

BytesOrDict = TypeVar("BytesOrDict", bytes, Dict[str, Any])
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pywin32==304; sys_platform == 'win32'
requests==2.28.1
urllib3==1.26.11
websocket-client==1.3.3
typing_extensions>=3.10.0.0

0 comments on commit 3a11fb0

Please sign in to comment.