diff --git a/.gitignore b/.gitignore index 3f03a81..e92aa58 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ adb.egg-info/ .tox/ /adb.zip /fastboot.zip +.idea/ +*.DS_Store* \ No newline at end of file diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 2fc5d46..ddb520b 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -6,3 +6,5 @@ Marc-Antoine Ruel Max Borghino Mohammad Abu-Garbeyyeh Josip Delic +Greg E. + diff --git a/adb/adb_commands.py b/adb/adb_commands.py index eb454ce..c9ca631 100644 --- a/adb/adb_commands.py +++ b/adb/adb_commands.py @@ -25,6 +25,7 @@ import io import os import socket +import posixpath from adb import adb_protocol from adb import common @@ -54,52 +55,127 @@ class AdbCommands(object): protocol_handler = adb_protocol.AdbMessage filesync_handler = filesync_protocol.FilesyncProtocol - @classmethod - def ConnectDevice( - cls, port_path=None, serial=None, default_timeout_ms=None, **kwargs): - """Convenience function to get an adb device from usb path or serial. + def __init__(self): + + self.__reset() + + def __reset(self): + self.build_props = None + self._handle = None + self._device_state = None + + # Connection table tracks each open AdbConnection objects per service type for program functions + # that choose to persist an AdbConnection object for their functionality, using + # self._get_service_connection + self._service_connections = {} + + def _get_service_connection(self, service, service_command=None, create=True, timeout_ms=None): + """ + Based on the service, get the AdbConnection for that service or create one if it doesnt exist + + :param service: + :param service_command: Additional service parameters to append + :param create: If False, dont create a connection if it does not exist + :return: + """ + + connection = self._service_connections.get(service, None) + + if connection: + return connection + + if not connection and not create: + return None + + if service_command: + destination_str = b'%s:%s' % (service, service_command) + else: + destination_str = service + + connection = self.protocol_handler.Open( + self._handle, destination=destination_str, timeout_ms=timeout_ms) + + self._service_connections.update({service: connection}) + + return connection + + def ConnectDevice(self, port_path=None, serial=None, default_timeout_ms=None, **kwargs): + """Convenience function to setup a transport handle for the adb device from + usb path or serial then connect to it. Args: port_path: The filename of usb port to use. serial: The serial number of the device to use. default_timeout_ms: The default timeout in milliseconds to use. + kwargs: handle: Device handle to use (instance of common.TcpHandle or common.UsbHandle) + banner: Connection banner to pass to the remote device + rsa_keys: List of AuthSigner subclass instances to be used for + authentication. The device can either accept one of these via the Sign + method, or we will send the result of GetPublicKey from the first one + if the device doesn't accept any of them. + auth_timeout_ms: Timeout to wait for when sending a new public key. This + is only relevant when we send a new public key. The device shows a + dialog and this timeout is how long to wait for that dialog. If used + in automation, this should be low to catch such a case as a failure + quickly; while in interactive settings it should be high to allow + users to accept the dialog. We default to automation here, so it's low + by default. If serial specifies a TCP address:port, then a TCP connection is used instead of a USB connection. """ - if serial and b':' in serial: - handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms) + + # If there isnt a handle override (used by tests), build one here + if 'handle' in kwargs: + self._handle = kwargs.pop('handle') + elif serial and b':' in serial: + self._handle = common.TcpHandle(serial, timeout_ms=default_timeout_ms) else: - handle = common.UsbHandle.FindAndOpen( + self._handle = common.UsbHandle.FindAndOpen( DeviceIsAvailable, port_path=port_path, serial=serial, timeout_ms=default_timeout_ms) - return cls.Connect(handle, **kwargs) - def __init__(self, handle, device_state): - self.handle = handle - self._device_state = device_state + self._Connect(**kwargs) + + return self def Close(self): - self.handle.Close() + for conn in list(self._service_connections.values()): + if conn: + try: + conn.Close() + except: + pass - @classmethod - def Connect(cls, usb, banner=None, **kwargs): + if self._handle: + self._handle.Close() + + self.__reset() + + def _Connect(self, banner=None, **kwargs): """Connect to the device. Args: - usb: UsbHandle or TcpHandle instance to use. banner: See protocol_handler.Connect. - **kwargs: See protocol_handler.Connect for kwargs. Includes rsa_keys, - and auth_timeout_ms. + **kwargs: See protocol_handler.Connect and adb_commands.ConnectDevice for kwargs. + Includes handle, rsa_keys, and auth_timeout_ms. Returns: An instance of this class if the device connected successfully. """ + if not banner: banner = socket.gethostname().encode() - device_state = cls.protocol_handler.Connect(usb, banner=banner, **kwargs) + + conn_str = self.protocol_handler.Connect(self._handle, banner=banner, **kwargs) + # Remove banner and colons after device state (state::banner) - device_state = device_state.split(b':')[0] - return cls(usb, device_state) + parts = conn_str.split(b'::') + self._device_state = parts[0] + + # Break out the build prop info + self.build_props = str(parts[1].split(b';')) + + return True @classmethod def Devices(cls): @@ -129,13 +205,13 @@ def Install(self, apk_path, destination_dir='', timeout_ms=None, replace_existin if not destination_dir: destination_dir = '/data/local/tmp/' basename = os.path.basename(apk_path) - destination_path = destination_dir + basename + destination_path = posixpath.join(destination_dir, basename) self.Push(apk_path, destination_path, timeout_ms=timeout_ms, progress_callback=transfer_progress_callback) cmd = ['pm install'] if replace_existing: cmd.append('-r') - cmd.append('"%s"' % destination_path) + cmd.append('"{}"'.format(destination_path)) return self.Shell(' '.join(cmd), timeout_ms=timeout_ms) def Uninstall(self, package_name, keep_data=False, timeout_ms=None): @@ -167,7 +243,6 @@ def Push(self, source_file, device_filename, mtime='0', timeout_ms=None, progres progress_callback: callback method that accepts filename, bytes_written and total_bytes, total_bytes will be -1 for file-like objects """ - should_close = False if isinstance(source_file, str): if os.path.isdir(source_file): self.Shell("mkdir " + device_filename) @@ -175,42 +250,49 @@ def Push(self, source_file, device_filename, mtime='0', timeout_ms=None, progres self.Push(os.path.join(source_file, f), device_filename + '/' + f, progress_callback=progress_callback) return source_file = open(source_file, "rb") - should_close = True - connection = self.protocol_handler.Open( - self.handle, destination=b'sync:', timeout_ms=timeout_ms) - self.filesync_handler.Push(connection, source_file, device_filename, + with source_file: + connection = self.protocol_handler.Open( + self._handle, destination=b'sync:', timeout_ms=timeout_ms) + self.filesync_handler.Push(connection, source_file, device_filename, mtime=int(mtime), progress_callback=progress_callback) - if should_close: - source_file.close() connection.Close() - def Pull(self, device_filename, dest_file='', progress_callback=None, timeout_ms=None): + def Pull(self, device_filename, dest_file=None, timeout_ms=None, progress_callback=None): """Pull a file from the device. Args: device_filename: Filename on the device to pull. dest_file: If set, a filename or writable file-like object. timeout_ms: Expected timeout for any part of the pull. + progress_callback: callback method that accepts filename, bytes_written and total_bytes, + total_bytes will be -1 for file-like objects Returns: - The file data if dest_file is not set. + The file data if dest_file is not set. Otherwise, True if the destination file exists """ if not dest_file: dest_file = io.BytesIO() elif isinstance(dest_file, str): - dest_file = open(dest_file, 'wb') - connection = self.protocol_handler.Open( - self.handle, destination=b'sync:', - timeout_ms=timeout_ms) - self.filesync_handler.Pull(connection, device_filename, dest_file, progress_callback) - connection.Close() + dest_file = open(dest_file, 'w') + else: + raise ValueError("destfile is of unknown type") + + conn = self.protocol_handler.Open( + self._handle, destination=b'sync:', timeout_ms=timeout_ms) + + self.filesync_handler.Pull(conn, device_filename, dest_file, progress_callback) + + conn.Close() if isinstance(dest_file, io.BytesIO): return dest_file.getvalue() + else: + dest_file.close() + return os.path.exists(dest_file) def Stat(self, device_filename): """Get a file's stat() information.""" - connection = self.protocol_handler.Open(self.handle, destination=b'sync:') + connection = self.protocol_handler.Open(self._handle, destination=b'sync:') mode, size, mtime = self.filesync_handler.Stat( connection, device_filename) connection.Close() @@ -222,7 +304,7 @@ def List(self, device_path): Args: device_path: Directory to list. """ - connection = self.protocol_handler.Open(self.handle, destination=b'sync:') + connection = self.protocol_handler.Open(self._handle, destination=b'sync:') listing = self.filesync_handler.List(connection, device_path) connection.Close() return listing @@ -233,7 +315,7 @@ def Reboot(self, destination=b''): Args: destination: Specify 'bootloader' for fastboot. """ - self.protocol_handler.Open(self.handle, b'reboot:%s' % destination) + self.protocol_handler.Open(self._handle, b'reboot:%s' % destination) def RebootBootloader(self): """Reboot device into fastboot.""" @@ -241,24 +323,29 @@ def RebootBootloader(self): def Remount(self): """Remount / as read-write.""" - return self.protocol_handler.Command(self.handle, service=b'remount') + return self.protocol_handler.Command(self._handle, service=b'remount') def Root(self): """Restart adbd as root on the device.""" - return self.protocol_handler.Command(self.handle, service=b'root') + return self.protocol_handler.Command(self._handle, service=b'root') def EnableVerity(self): """Re-enable dm-verity checking on userdebug builds""" - return self.protocol_handler.Command(self.handle, service=b'enable-verity') + return self.protocol_handler.Command(self._handle, service=b'enable-verity') def DisableVerity(self): """Disable dm-verity checking on userdebug builds""" - return self.protocol_handler.Command(self.handle, service=b'disable-verity') + return self.protocol_handler.Command(self._handle, service=b'disable-verity') def Shell(self, command, timeout_ms=None): - """Run command on the device, returning the output.""" + """Run command on the device, returning the output. + + Args: + command: Shell command to run + timeout_ms: Maximum time to allow the command to run. + """ return self.protocol_handler.Command( - self.handle, service=b'shell', command=command, + self._handle, service=b'shell', command=command, timeout_ms=timeout_ms) def StreamingShell(self, command, timeout_ms=None): @@ -272,7 +359,7 @@ def StreamingShell(self, command, timeout_ms=None): The responses from the shell command. """ return self.protocol_handler.StreamingCommand( - self.handle, service=b'shell', command=command, + self._handle, service=b'shell', command=command, timeout_ms=timeout_ms) def Logcat(self, options, timeout_ms=None): @@ -280,5 +367,26 @@ def Logcat(self, options, timeout_ms=None): Args: options: Arguments to pass to 'logcat'. + timeout_ms: Maximum time to allow the command to run. """ return self.StreamingShell('logcat %s' % options, timeout_ms) + + def InteractiveShell(self, cmd=None, strip_cmd=True, delim=None, strip_delim=True): + """Get stdout from the currently open interactive shell and optionally run a command + on the device, returning all output. + + Args: + command: Optional. Command to run on the target. + strip_cmd: Optional (default True). Strip command name from stdout. + delim: Optional. Delimiter to look for in the output to know when to stop expecting more output + (usually the shell prompt) + strip_delim: Optional (default True): Strip the provided delimiter from the output + + Returns: + The stdout from the shell command. + """ + conn = self._get_service_connection(b'shell:') + + return self.protocol_handler.InteractiveShellCommand( + conn, cmd=cmd, strip_cmd=strip_cmd, + delim=delim, strip_delim=strip_delim) diff --git a/adb/adb_debug.py b/adb/adb_debug.py old mode 100755 new mode 100644 index 9384b3c..4fa974f --- a/adb/adb_debug.py +++ b/adb/adb_debug.py @@ -34,7 +34,12 @@ from adb import sign_pythonrsa rsa_signer = sign_pythonrsa.PythonRSASigner.FromRSAKeyPath except ImportError: - rsa_signer = None + try: + from adb import sign_pycryptodome + + rsa_signer = sign_pycryptodome.PycryptodomeAuthSigner + except ImportError: + rsa_signer = None def Devices(args): @@ -53,13 +58,13 @@ def Devices(args): return 0 -def List(self, device_path): +def List(device, device_path): """Prints a directory listing. Args: device_path: Directory to list. """ - files = adb_commands.AdbCommands.List(self, device_path) + files = device.List(device_path) files.sort(key=lambda x: x.filename) maxname = max(len(f.filename) for f in files) maxsize = max(len(str(f.size)) for f in files) @@ -83,18 +88,40 @@ def List(self, device_path): @functools.wraps(adb_commands.AdbCommands.Logcat) -def Logcat(self, *options): - return adb_commands.AdbCommands.Logcat( +def Logcat(device, *options): + return device.Logcat( self, ' '.join(options), timeout_ms=0) -def Shell(self, *command): +def Shell(device, *command): """Runs a command on the device and prints the stdout. Args: command: Command to run on the target. """ - return adb_commands.AdbCommands.StreamingShell(self, ' '.join(command)) + if command: + return device.StreamingShell(' '.join(command)) + else: + # Retrieve the initial terminal prompt to use as a delimiter for future reads + terminal_prompt = device.InteractiveShell() + print(terminal_prompt.decode('utf-8')) + + # Accept user input in a loop and write that into the interactive shells stdin, then print output + while True: + cmd = input('> ') + if not cmd: + continue + elif cmd == 'exit': + break + else: + stdout = device.InteractiveShell(cmd, strip_cmd=True, delim=terminal_prompt, strip_delim=True) + if stdout: + if isinstance(stdout, bytes): + stdout = stdout.decode('utf-8') + print(stdout) + + + device.Close() def main(): @@ -159,7 +186,8 @@ def main(): if os.path.isfile(default): args.rsa_key_path = [default] if args.rsa_key_path and not rsa_signer: - parser.error('Please install either M2Crypto or python-rsa') + parser.error('Please install either M2Crypto, python-rsa, or PycryptoDome') + # Hacks so that the generated doc is nicer. if args.command_name == 'devices': return Devices(args) @@ -173,8 +201,8 @@ def main(): return common_cli.StartCli( args, - adb_commands.AdbCommands.ConnectDevice, - auth_timeout_ms=args.auth_timeout_s * 1000, + adb_commands.AdbCommands, + auth_timeout_ms=int(args.auth_timeout_s * 1000), rsa_keys=[rsa_signer(path) for path in args.rsa_key_path]) diff --git a/adb/adb_protocol.py b/adb/adb_protocol.py index 8047f48..32ea651 100644 --- a/adb/adb_protocol.py +++ b/adb/adb_protocol.py @@ -19,10 +19,9 @@ import struct import time - +from io import BytesIO from adb import usb_exceptions - # Maximum amount of data in an ADB packet. MAX_ADB_DATA = 4096 # ADB protocol version. @@ -34,6 +33,23 @@ AUTH_RSAPUBLICKEY = 3 +def find_backspace_runs(stdout_bytes, start_pos): + + first_backspace_pos = stdout_bytes[start_pos:].find(b'\x08') + if first_backspace_pos == -1: + return -1, 0 + + end_backspace_pos = (start_pos + first_backspace_pos) + 1 + while True: + if chr(stdout_bytes[end_backspace_pos]) == '\b': + end_backspace_pos += 1 + else: + break + + num_backspaces = end_backspace_pos - (start_pos + first_backspace_pos) + + return (start_pos + first_backspace_pos), num_backspaces + class InvalidCommandError(Exception): """Got an invalid command over USB.""" @@ -241,6 +257,8 @@ def Read(cls, usb, expected_cmds, timeout_ms=None, total_timeout_ms=None): data = bytearray() while data_length > 0: temp = usb.BulkRead(data_length, timeout_ms) + if len(temp) != data_length: + print("Data_length {} does not match actual number of bytes read: {}".format(data_length, len(temp))) data += temp data_length -= len(temp) @@ -298,7 +316,8 @@ def Connect(cls, usb, banner=b'notadb', rsa_keys=None, auth_timeout_ms=100): raise InvalidResponseError( 'Unknown AUTH response: %s %s %s' % (arg0, arg1, banner)) - signed_token = rsa_key.Sign(str(banner)) + # Do not mangle the banner property here by converting it to a string + signed_token = rsa_key.Sign(banner) msg = cls( command=b'AUTH', arg0=AUTH_SIGNATURE, arg1=0, data=signed_token) msg.Send(usb) @@ -349,7 +368,7 @@ def Open(cls, usb, destination, timeout_ms=None): timeout_ms=timeout_ms) if local_id != their_local_id: raise InvalidResponseError( - 'Expected the local_id to be %s, got %s' % (local_id, their_local_id)) + 'Expected the local_id to be {}, got {}'.format(local_id, their_local_id)) if cmd == b'CLSE': # Some devices seem to be sending CLSE once more after a request, this *should* handle it cmd, remote_id, their_local_id, _ = cls.Read(usb, [b'CLSE', b'OKAY'], @@ -358,7 +377,7 @@ def Open(cls, usb, destination, timeout_ms=None): if cmd == b'CLSE': return None if cmd != b'OKAY': - raise InvalidCommandError('Expected a ready response, got %s' % cmd, + raise InvalidCommandError('Expected a ready response, got {}'.format(cmd), cmd, (remote_id, their_local_id)) return _AdbConnection(usb, local_id, remote_id, timeout_ms) @@ -413,3 +432,128 @@ def StreamingCommand(cls, usb, service, command='', timeout_ms=None): timeout_ms=timeout_ms) for data in connection.ReadUntilClose(): yield data.decode('utf8') + + @classmethod + def InteractiveShellCommand(cls, conn, cmd=None, strip_cmd=True, delim=None, strip_delim=True, clean_stdout=True): + """Retrieves stdout of the current InteractiveShell and sends a shell command if provided + TODO: Should we turn this into a yield based function so we can stream all output? + + Args: + conn: Instance of AdbConnection + cmd: Optional. Command to run on the target. + strip_cmd: Optional (default True). Strip command name from stdout. + delim: Optional. Delimiter to look for in the output to know when to stop expecting more output + (usually the shell prompt) + strip_delim: Optional (default True): Strip the provided delimiter from the output + clean_stdout: Cleanup the stdout stream of any backspaces and the characters that were deleted by the backspace + Returns: + The stdout from the shell command. + """ + + if isinstance(delim, str): + delimiter = delim.encode('utf-8') + + # Delimiter may be shell@hammerhead:/ $ + # The user or directory could change, making the delimiter somthing like root@hammerhead:/data/local/tmp $ + # Handle a partial delimiter to search on and clean up + if delim: + user_pos = delim.find(b'@') + dir_pos = delim.rfind(b':/') + if user_pos != -1 and dir_pos != -1: + partial_delim = delim[user_pos:dir_pos+1] # e.g. @hammerhead: + else: + partial_delim = delim + else: + partial_delim = None + + stdout = '' + stdout_stream = BytesIO() + original_cmd = '' + + try: + + if cmd: + original_cmd = str(cmd) + cmd += '\r' # Required. Send a carriage return right after the cmd + cmd = cmd.encode('utf8') + + # Send the cmd raw + bytes_written = conn.Write(cmd) + + if delim: + # Expect multiple WRTE cmds until the delim (usually terminal prompt) is detected + + data = b'' + while partial_delim not in data: + + cmd, data = conn.ReadUntil(b'WRTE') + stdout_stream.write(data) + + else: + # Otherwise, expect only a single WRTE + cmd, data = conn.ReadUntil(b'WRTE') + + # WRTE cmd from device will follow with stdout data + stdout_stream.write(data) + + else: + + # No cmd provided means we should just expect a single line from the terminal. Use this sparingly + cmd, data = conn.ReadUntil(b'WRTE') + if cmd == b'WRTE': + # WRTE cmd from device will follow with stdout data + stdout_stream.write(data) + else: + print("Unhandled cmd: {}".format(cmd)) + + cleaned_stdout_stream = BytesIO() + if clean_stdout: + stdout_bytes = stdout_stream.getvalue() + + bsruns = {} # Backspace runs tracking + next_start_pos = 0 + last_run_pos, last_run_len = find_backspace_runs(stdout_bytes, next_start_pos) + + if last_run_pos != -1 and last_run_len != 0: + bsruns.update({last_run_pos: last_run_len}) + cleaned_stdout_stream.write(stdout_bytes[next_start_pos:(last_run_pos-last_run_len)]) + next_start_pos += last_run_pos + last_run_len + + while last_run_pos != -1: + last_run_pos, last_run_len = find_backspace_runs(stdout_bytes[next_start_pos:], next_start_pos) + + if last_run_pos != -1: + bsruns.update({last_run_pos: last_run_len}) + cleaned_stdout_stream.write(stdout_bytes[next_start_pos:(last_run_pos - last_run_len)]) + next_start_pos += last_run_pos + last_run_len + + cleaned_stdout_stream.write(stdout_bytes[next_start_pos:]) + + else: + cleaned_stdout_stream.write(stdout_stream.getvalue()) + + stdout = cleaned_stdout_stream.getvalue() + + # Strip original cmd that will come back in stdout + if original_cmd and strip_cmd: + findstr = original_cmd.encode('utf-8') + b'\r\r\n' + pos = stdout.find(findstr) + while pos >= 0: + stdout = stdout.replace(findstr, b'') + pos = stdout.find(findstr) + + if b'\r\r\n' in stdout: + stdout = stdout.split(b'\r\r\n')[1] + + # Strip delim if requested + # TODO: Handling stripping partial delims here - not a deal breaker the way we're handling it now + if delim and strip_delim: + + stdout = stdout.replace(delim, b'') + + stdout = stdout.rstrip() + + except Exception as e: + print("InteractiveShell exception (most likely timeout): {}".format(e)) + + return stdout diff --git a/adb/common.py b/adb/common.py index 9e1fc17..2c95f8e 100644 --- a/adb/common.py +++ b/adb/common.py @@ -27,7 +27,7 @@ from adb import usb_exceptions -DEFAULT_TIMEOUT_MS = 1000 +DEFAULT_TIMEOUT_MS = 10000 _LOG = logging.getLogger('android_usb') @@ -76,7 +76,8 @@ def __init__(self, device, setting, usb_info=None, timeout_ms=None): self._handle = None self._usb_info = usb_info or '' - self._timeout_ms = timeout_ms or DEFAULT_TIMEOUT_MS + self._timeout_ms = timeout_ms if timeout_ms else DEFAULT_TIMEOUT_MS + self._max_read_packet_len = 0 @property def usb_info(self): @@ -192,10 +193,14 @@ def BulkRead(self, length, timeout_ms=None): 'Could not receive data from %s (timeout %sms)' % ( self.usb_info, self.Timeout(timeout_ms)), e) + def BulkReadAsync(self, length, timeout_ms=None): + # See: https://pypi.python.org/pypi/libusb1 "Asynchronous I/O" section + return + @classmethod def PortPathMatcher(cls, port_path): """Returns a device matcher for the given port path.""" - if isinstance(port_path, basestring): + if isinstance(port_path, str): # Convert from sysfs path to port_path. port_path = [int(part) for part in SYSFS_PORT_SPLIT_RE.split(port_path)] return lambda device: device.port_path == port_path @@ -284,8 +289,7 @@ class TcpHandle(object): Provides same interface as UsbHandle. """ def __init__(self, serial, timeout_ms=None): - """Initialize the Tcp Handle. - + """Initialize the TCP Handle. Arguments: serial: Android device serial of the form host or host:port. @@ -314,7 +318,7 @@ def BulkWrite(self, data, timeout=None): return self._connection.send(data) msg = 'Sending data to {} timed out after {}s. No data was sent.'.format( self.serial_number, t) - raise usb_exceptions.TcpTimeoutException(msg) + raise usb_exceptions.TcpTimeoutException(msg) def BulkRead(self, numbytes, timeout=None): t = self.TimeoutSeconds(timeout) @@ -322,7 +326,7 @@ def BulkRead(self, numbytes, timeout=None): if readable: return self._connection.recv(numbytes) msg = 'Reading from {} timed out (Timeout {}s)'.format( - self._serial_number,t) + self._serial_number, t) raise usb_exceptions.TcpTimeoutException(msg) def Timeout(self, timeout_ms): diff --git a/adb/common_cli.py b/adb/common_cli.py index 22dfdd8..f3d2be6 100644 --- a/adb/common_cli.py +++ b/adb/common_cli.py @@ -19,13 +19,13 @@ outputting the results. """ +from __future__ import print_function import argparse import io import inspect import logging import re import sys -import textwrap import types from adb import usb_exceptions @@ -142,17 +142,16 @@ def _RunMethod(dev, args, extra): return 0 -def StartCli(args, device_factory, extra=None, **device_kwargs): +def StartCli(args, adb_commands, extra=None, **device_kwargs): """Starts a common CLI interface for this usb path and protocol.""" try: - dev = device_factory( - port_path=args.port_path, serial=args.serial, - default_timeout_ms=args.timeout_ms, **device_kwargs) + dev = adb_commands() + dev.ConnectDevice(port_path=args.port_path, serial=args.serial, default_timeout_ms=args.timeout_ms, **device_kwargs) except usb_exceptions.DeviceNotFoundError as e: - print >> sys.stderr, 'No device found: %s' % e + print('No device found: {}'.format(e), file=sys.stderr) return 1 except usb_exceptions.CommonUsbError as e: - print >> sys.stderr, 'Could not connect to device: %s' % e + print('Could not connect to device: {}'.format(e), file=sys.stderr) return 1 try: return _RunMethod(dev, args, extra or {}) diff --git a/adb/fastboot.py b/adb/fastboot.py index 5642a3e..e9028a7 100644 --- a/adb/fastboot.py +++ b/adb/fastboot.py @@ -206,30 +206,62 @@ def _Write(self, data, length, progress_callback=None): class FastbootCommands(object): """Encapsulates the fastboot commands.""" - def __init__(self, usb, chunk_kb=1024): + def __init__(self): """Constructs a FastbootCommands instance. Args: usb: UsbHandle instance. """ - self._usb = usb - self._protocol = FastbootProtocol(usb, chunk_kb) + self.__reset() + + def __reset(self): + self._handle = None + self._protocol = None @property def usb_handle(self): - return self._usb + return self._handle def Close(self): - self._usb.Close() + self._handle.Close() - @classmethod - def ConnectDevice( - cls, port_path=None, serial=None, default_timeout_ms=None, chunk_kb=1024): - """Convenience function to get an adb device from usb path or serial.""" - usb = common.UsbHandle.FindAndOpen( - DeviceIsAvailable, port_path=port_path, serial=serial, - timeout_ms=default_timeout_ms) - return cls(usb, chunk_kb=chunk_kb) + def ConnectDevice(self, port_path=None, serial=None, default_timeout_ms=None, chunk_kb=1024, **kwargs): + """Convenience function to get an adb device from usb path or serial. + + Args: + port_path: The filename of usb port to use. + serial: The serial number of the device to use. + default_timeout_ms: The default timeout in milliseconds to use. + chunk_kb: Amount of data, in kilobytes, to break fastboot packets up into + kwargs: handle: Device handle to use (instance of common.TcpHandle or common.UsbHandle) + banner: Connection banner to pass to the remote device + rsa_keys: List of AuthSigner subclass instances to be used for + authentication. The device can either accept one of these via the Sign + method, or we will send the result of GetPublicKey from the first one + if the device doesn't accept any of them. + auth_timeout_ms: Timeout to wait for when sending a new public key. This + is only relevant when we send a new public key. The device shows a + dialog and this timeout is how long to wait for that dialog. If used + in automation, this should be low to catch such a case as a failure + quickly; while in interactive settings it should be high to allow + users to accept the dialog. We default to automation here, so it's low + by default. + + If serial specifies a TCP address:port, then a TCP connection is + used instead of a USB connection. + """ + + if 'handle' in kwargs: + self._handle = kwargs['handle'] + + else: + self._handle = common.UsbHandle.FindAndOpen( + DeviceIsAvailable, port_path=port_path, serial=serial, + timeout_ms=default_timeout_ms) + + self._protocol = FastbootProtocol(self._handle, chunk_kb) + + return self @classmethod def Devices(cls): @@ -285,15 +317,16 @@ def Download(self, source_file, source_len=0, source_len = os.stat(source_file).st_size source_file = open(source_file) - if source_len == 0: - # Fall back to storing it all in memory :( - data = source_file.read() - source_file = io.BytesIO(data.encode('utf8')) - source_len = len(data) - - self._protocol.SendCommand(b'download', b'%08x' % source_len) - return self._protocol.HandleDataSending( - source_file, source_len, info_cb, progress_callback=progress_callback) + with source_file: + if source_len == 0: + # Fall back to storing it all in memory :( + data = source_file.read() + source_file = io.BytesIO(data.encode('utf8')) + source_len = len(data) + + self._protocol.SendCommand(b'download', b'%08x' % source_len) + return self._protocol.HandleDataSending( + source_file, source_len, info_cb, progress_callback=progress_callback) def Flash(self, partition, timeout_ms=0, info_cb=DEFAULT_MESSAGE_CALLBACK): """Flashes the last downloaded file to the given partition. diff --git a/adb/fastboot_debug.py b/adb/fastboot_debug.py index 7f25c44..f904b90 100755 --- a/adb/fastboot_debug.py +++ b/adb/fastboot_debug.py @@ -115,7 +115,9 @@ def SetProgress(current, total): kwargs['progress_callback'] = SetProgress return common_cli.StartCli( - args, fastboot.FastbootCommands.ConnectDevice, chunk_kb=args.chunk_kb, + args, + fastboot.FastbootCommands, + chunk_kb=args.chunk_kb, extra=kwargs) diff --git a/adb/filesync_protocol.py b/adb/filesync_protocol.py index b6a77c1..fdecd4a 100644 --- a/adb/filesync_protocol.py +++ b/adb/filesync_protocol.py @@ -31,7 +31,7 @@ # Default mode for pushed files. DEFAULT_PUSH_MODE = stat.S_IFREG | stat.S_IRWXU | stat.S_IRWXG # Maximum size of a filesync DATA packet. -MAX_PUSH_DATA = 2*1024 +MAX_PUSH_DATA = 2 * 1024 class InvalidChecksumError(Exception): @@ -95,7 +95,12 @@ def Pull(cls, connection, filename, dest_file, progress_callback): @classmethod def _HandleProgress(cls, progress_callback): - """Calls the callback with the current progress and total .""" + """Calls the callback with the current progress and total bytes written/received. + + Args: + progress_callback: callback method that accepts filename, bytes_written and total_bytes, + total_bytes will be -1 for file-like objects + """ current = 0 while True: current += yield @@ -115,14 +120,13 @@ def Push(cls, connection, datafile, filename, filename: Filename to push to st_mode: stat mode for filename mtime: modification time - progress_callback: callback method that accepts filename, bytes_written and total_bytes, + progress_callback: callback method that accepts filename, bytes_written and total_bytes Raises: PushFailedError: Raised on push failure. """ - if not isinstance(filename, bytes): - filename = filename.encode('utf8') - fileinfo = b'%s,%d' % (filename, st_mode) + + fileinfo = ('{},{}'.format(filename, int(st_mode))).encode('utf-8') cnxn = FileSyncConnection(connection, b'<2I') cnxn.Send(b'SEND', fileinfo) @@ -211,7 +215,10 @@ def Read(self, expected_ids, read_data=True): if command_id not in expected_ids: if command_id == b'FAIL': - raise usb_exceptions.AdbCommandFailureException('Command failed.') + reason = '' + if self.recv_buffer: + reason = self.recv_buffer.decode('utf-8', errors='ignore') + raise usb_exceptions.AdbCommandFailureException('Command failed: {}'.format(reason)) raise adb_protocol.InvalidResponseError( 'Expected one of %s, got %s' % (expected_ids, command_id)) @@ -252,4 +259,3 @@ def _ReadBuffered(self, size): result = self.recv_buffer[:size] self.recv_buffer = self.recv_buffer[size:] return result - diff --git a/adb/sign_pycryptodome.py b/adb/sign_pycryptodome.py new file mode 100644 index 0000000..1a56f6a --- /dev/null +++ b/adb/sign_pycryptodome.py @@ -0,0 +1,25 @@ +from adb import adb_protocol + +from Crypto.Hash import SHA256 +from Crypto.PublicKey import RSA +from Crypto.Signature import pkcs1_15 + +class PycryptodomeAuthSigner(adb_protocol.AuthSigner): + + def __init__(self, rsa_key_path=None): + + super(PycryptodomeAuthSigner, self).__init__() + + if rsa_key_path: + with open(rsa_key_path + '.pub', 'rb') as rsa_pub_file: + self.public_key = rsa_pub_file.read() + + with open(rsa_key_path, 'rb') as rsa_priv_file: + self.rsa_key = RSA.import_key(rsa_priv_file.read()) + + def Sign(self, data): + h = SHA256.new(data) + return pkcs1_15.new(self.rsa_key).sign(h) + + def GetPublicKey(self): + return self.public_key diff --git a/setup.py b/setup.py index 47c88f8..3cd715b 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,23 @@ from setuptools import setup +# Figure out if the system already has a supported Crypto library +rsa_signer_library = 'M2Crypto>=0.21.1,<=0.26.4' +try: + import rsa + + rsa_signer_library = 'python-rsa' +except ImportError: + try: + from Crypto.Hash import SHA256 + from Crypto.PublicKey import RSA + from Crypto.Signature import pkcs1_15 + + rsa_signer_library = 'pycryptodome' + except ImportError: + pass + + setup( name = 'adb', packages = ['adb'], @@ -43,7 +60,10 @@ keywords = ['android', 'adb', 'fastboot'], - install_requires = ['libusb1>=1.0.16', 'M2Crypto>=0.21.1,<=0.26.4'], + install_requires = [ + 'libusb1>=1.0.16', + rsa_signer_library + ], extra_requires = { 'fastboot': 'progressbar>=2.3' diff --git a/test/adb_test.py b/test/adb_test.py index 1b8800c..bdbfce5 100755 --- a/test/adb_test.py +++ b/test/adb_test.py @@ -14,7 +14,7 @@ # limitations under the License. """Tests for adb.""" -import io +from io import BytesIO import struct import unittest @@ -94,15 +94,17 @@ def testConnect(self): usb = common_stub.StubUsb() self._ExpectConnection(usb) - adb_commands.AdbCommands.Connect(usb, BANNER) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) def testSmallResponseShell(self): command = b'keepin it real' response = 'word.' usb = self._ExpectCommand(b'shell', command, response) - adb_commands = self._Connect(usb) - self.assertEqual(response, adb_commands.Shell(command)) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + self.assertEqual(response, dev.Shell(command)) def testBigResponseShell(self): command = b'keepin it real big' @@ -112,9 +114,10 @@ def testBigResponseShell(self): usb = self._ExpectCommand(b'shell', command, *responses) - adb_commands = self._Connect(usb) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) self.assertEqual(b''.join(responses).decode('utf8'), - adb_commands.Shell(command)) + dev.Shell(command)) def testUninstall(self): package_name = "com.test.package" @@ -122,8 +125,9 @@ def testUninstall(self): usb = self._ExpectCommand(b'shell', ('pm uninstall "%s"' % package_name).encode('utf8'), response) - adb_commands = self._Connect(usb) - self.assertEquals(response, adb_commands.Uninstall(package_name)) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + self.assertEqual(response, dev.Uninstall(package_name)) def testStreamingResponseShell(self): command = b'keepin it real big' @@ -133,42 +137,49 @@ def testStreamingResponseShell(self): usb = self._ExpectCommand(b'shell', command, *responses) - adb_commands = self._Connect(usb) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) response_count = 0 - for (expected,actual) in zip(responses, adb_commands.StreamingShell(command)): + for (expected,actual) in zip(responses, dev.StreamingShell(command)): self.assertEqual(expected, actual) response_count = response_count + 1 self.assertEqual(len(responses), response_count) def testReboot(self): usb = self._ExpectCommand(b'reboot', b'', b'') - adb_commands = self._Connect(usb) - adb_commands.Reboot() + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + dev.Reboot() def testRebootBootloader(self): usb = self._ExpectCommand(b'reboot', b'bootloader', b'') - adb_commands = self._Connect(usb) - adb_commands.RebootBootloader() + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + dev.RebootBootloader() def testRemount(self): usb = self._ExpectCommand(b'remount', b'', b'') - adb_commands = self._Connect(usb) - adb_commands.Remount() + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + dev.Remount() def testRoot(self): usb = self._ExpectCommand(b'root', b'', b'') - adb_commands = self._Connect(usb) - adb_commands.Root() + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + dev.Root() def testEnableVerity(self): usb = self._ExpectCommand(b'enable-verity', b'', b'') - adb_commands = self._Connect(usb) - adb_commands.EnableVerity() + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + dev.EnableVerity() def testDisableVerity(self): usb = self._ExpectCommand(b'disable-verity', b'', b'') - adb_commands = self._Connect(usb) - adb_commands.DisableVerity() + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + dev.DisableVerity() class FilesyncAdbTest(BaseAdbTest): @@ -202,7 +213,7 @@ def _ExpectSyncCommand(cls, write_commands, read_commands): return usb def testPush(self): - filedata = u'alo there, govnah' + filedata = b'alo there, govnah' mtime = 100 send = [ @@ -213,8 +224,9 @@ def testPush(self): data = b'OKAY\0\0\0\0' usb = self._ExpectSyncCommand([b''.join(send)], [data]) - adb_commands = self._Connect(usb) - adb_commands.Push(io.StringIO(filedata), '/data', mtime=mtime) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + dev.Push(BytesIO(filedata), '/data', mtime=mtime) def testPull(self): filedata = b"g'ddayta, govnah" @@ -225,8 +237,9 @@ def testPull(self): self._MakeWriteSyncPacket(b'DONE'), ] usb = self._ExpectSyncCommand([recv], [b''.join(data)]) - adb_commands = self._Connect(usb) - self.assertEqual(filedata, adb_commands.Pull('/data')) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=usb, banner=BANNER) + self.assertEqual(filedata, dev.Pull('/data')) class TcpTimeoutAdbTest(BaseAdbTest): @@ -244,13 +257,15 @@ def _ExpectCommand(cls, service, command, *responses): def _run_shell(self, cmd, timeout_ms=None): tcp = self._ExpectCommand(b'shell', cmd) - adb_commands = self._Connect(tcp) - adb_commands.Shell(cmd, timeout_ms=timeout_ms) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=tcp, banner=BANNER) + dev.Shell(cmd, timeout_ms=timeout_ms) def testConnect(self): tcp = common_stub.StubTcp() self._ExpectConnection(tcp) - adb_commands.AdbCommands.Connect(tcp, BANNER) + dev = adb_commands.AdbCommands() + dev.ConnectDevice(handle=tcp, banner=BANNER) def testTcpTimeout(self): timeout_ms = 1 diff --git a/test/fastboot_test.py b/test/fastboot_test.py index 6862c2e..32c96fa 100755 --- a/test/fastboot_test.py +++ b/test/fastboot_test.py @@ -61,9 +61,10 @@ def testDownload(self): data = io.StringIO(raw) self.ExpectDownload([raw]) - commands = fastboot.FastbootCommands(self.usb) + dev = fastboot.FastbootCommands() + dev.ConnectDevice(handle=self.usb) - response = commands.Download(data) + response = dev.Download(data) self.assertEqual(b'Result', response) def testDownloadFail(self): @@ -71,26 +72,28 @@ def testDownloadFail(self): data = io.StringIO(raw) self.ExpectDownload([raw], succeed=False) - commands = fastboot.FastbootCommands(self.usb) + dev = fastboot.FastbootCommands() + dev.ConnectDevice(handle=self.usb) with self.assertRaises(fastboot.FastbootRemoteFailure): - commands.Download(data) + dev.Download(data) data = io.StringIO(raw) self.ExpectDownload([raw], accept_data=False) with self.assertRaises(fastboot.FastbootTransferError): - commands.Download(data) + dev.Download(data) def testFlash(self): partition = b'yarr' self.ExpectFlash(partition) - commands = fastboot.FastbootCommands(self.usb) + dev = fastboot.FastbootCommands() + dev.ConnectDevice(handle=self.usb) output = io.BytesIO() def InfoCb(message): if message.header == b'INFO': output.write(message.message) - response = commands.Flash(partition, info_cb=InfoCb) + response = dev.Flash(partition, info_cb=InfoCb) self.assertEqual(b'Done', response) self.assertEqual(b'Random info from the bootloader', output.getvalue()) @@ -98,10 +101,11 @@ def testFlashFail(self): partition = b'matey' self.ExpectFlash(partition, succeed=False) - commands = fastboot.FastbootCommands(self.usb) + dev = fastboot.FastbootCommands() + dev.ConnectDevice(handle=self.usb) with self.assertRaises(fastboot.FastbootRemoteFailure): - commands.Flash(partition) + dev.Flash(partition) def testFlashFromFile(self): partition = b'somewhere' @@ -122,51 +126,54 @@ def testFlashFromFile(self): cb = lambda progress, total: progresses.append((progress, total)) - commands = fastboot.FastbootCommands(self.usb) - commands.FlashFromFile( + dev = fastboot.FastbootCommands() + dev.ConnectDevice(handle=self.usb) + dev.FlashFromFile( partition, tmp.name, progress_callback=cb) self.assertEqual(len(pieces), len(progresses)) os.remove(tmp.name) def testSimplerCommands(self): - commands = fastboot.FastbootCommands(self.usb) + dev = fastboot.FastbootCommands() + dev.ConnectDevice(handle=self.usb) self.usb.ExpectWrite(b'erase:vector') self.usb.ExpectRead(b'OKAY') - commands.Erase('vector') + dev.Erase('vector') self.usb.ExpectWrite(b'getvar:variable') self.usb.ExpectRead(b'OKAYstuff') - self.assertEqual(b'stuff', commands.Getvar('variable')) + self.assertEqual(b'stuff', dev.Getvar('variable')) self.usb.ExpectWrite(b'continue') self.usb.ExpectRead(b'OKAY') - commands.Continue() + dev.Continue() self.usb.ExpectWrite(b'reboot') self.usb.ExpectRead(b'OKAY') - commands.Reboot() + dev.Reboot() self.usb.ExpectWrite(b'reboot-bootloader') self.usb.ExpectRead(b'OKAY') - commands.RebootBootloader() + dev.RebootBootloader() self.usb.ExpectWrite(b'oem a little somethin') self.usb.ExpectRead(b'OKAYsomethin') - self.assertEqual(b'somethin', commands.Oem('a little somethin')) + self.assertEqual(b'somethin', dev.Oem('a little somethin')) def testVariousFailures(self): - commands = fastboot.FastbootCommands(self.usb) + dev = fastboot.FastbootCommands() + dev.ConnectDevice(handle=self.usb) self.usb.ExpectWrite(b'continue') self.usb.ExpectRead(b'BLEH') with self.assertRaises(fastboot.FastbootInvalidResponse): - commands.Continue() + dev.Continue() self.usb.ExpectWrite(b'continue') self.usb.ExpectRead(b'DATA000000') with self.assertRaises(fastboot.FastbootStateMismatch): - commands.Continue() + dev.Continue() if __name__ == '__main__':