Skip to content

Commit

Permalink
Run black changes on the code
Browse files Browse the repository at this point in the history
  • Loading branch information
jborean93 committed Jun 5, 2024
1 parent 6120cda commit ecc55db
Show file tree
Hide file tree
Showing 11 changed files with 682 additions and 747 deletions.
47 changes: 21 additions & 26 deletions winrm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,35 @@

from winrm.protocol import Protocol

__version__ = '0.5.0'
__version__ = "0.5.0"

# Feature support attributes for multi-version clients.
# These values can be easily checked for with hasattr(winrm, "FEATURE_X"),
# "'auth_type' in winrm.FEATURE_SUPPORTED_AUTHTYPES", etc for clients to sniff features
# supported by a particular version of pywinrm
FEATURE_SUPPORTED_AUTHTYPES = ['basic', 'certificate', 'ntlm', 'kerberos', 'plaintext', 'ssl', 'credssp']
FEATURE_SUPPORTED_AUTHTYPES = ["basic", "certificate", "ntlm", "kerberos", "plaintext", "ssl", "credssp"]
FEATURE_READ_TIMEOUT = True
FEATURE_OPERATION_TIMEOUT = True
FEATURE_PROXY_SUPPORT = True


class Response(object):
"""Response from a remote command execution"""

def __init__(self, args):
self.std_out, self.std_err, self.status_code = args

def __repr__(self):
# TODO put tree dots at the end if out/err was truncated
return '<Response code {0}, out "{1}", err "{2}">'.format(
self.status_code, self.std_out[:20], self.std_err[:20])
return '<Response code {0}, out "{1}", err "{2}">'.format(self.status_code, self.std_out[:20], self.std_err[:20])


class Session(object):
# TODO implement context manager methods
def __init__(self, target, auth, **kwargs):
username, password = auth
self.url = self._build_url(target, kwargs.get('transport', 'plaintext'))
self.protocol = Protocol(self.url,
username=username, password=password, **kwargs)
self.url = self._build_url(target, kwargs.get("transport", "plaintext"))
self.protocol = Protocol(self.url, username=username, password=password, **kwargs)

def run_cmd(self, command, args=()):
# TODO optimize perf. Do not call open/close shell every time
Expand All @@ -52,17 +51,16 @@ def run_ps(self, script):
encoded script command
"""
# must use utf16 little endian on windows
encoded_ps = b64encode(script.encode('utf_16_le')).decode('ascii')
rs = self.run_cmd('powershell -encodedcommand {0}'.format(encoded_ps))
encoded_ps = b64encode(script.encode("utf_16_le")).decode("ascii")
rs = self.run_cmd("powershell -encodedcommand {0}".format(encoded_ps))
if len(rs.std_err):
# if there was an error message, clean it it up and make it human
# readable
rs.std_err = self._clean_error_msg(rs.std_err)
return rs

def _clean_error_msg(self, msg):
"""converts a Powershell CLIXML message to a more human readable string
"""
"""converts a Powershell CLIXML message to a more human readable string"""
# TODO prepare unit test, beautify code
# if the msg does not start with this, return it as is
if msg.startswith(b"#< CLIXML\r\n"):
Expand All @@ -83,41 +81,38 @@ def _clean_error_msg(self, msg):
except Exception as e:
# if any of the above fails, the msg was not true xml
# print a warning and return the original string
warnings.warn(
"There was a problem converting the Powershell error "
"message: %s" % (e))
warnings.warn("There was a problem converting the Powershell error " "message: %s" % (e))
else:
# if new_msg was populated, that's our error message
# otherwise the original error message will be used
if len(new_msg):
# remove leading and trailing whitespace while we are here
return new_msg.strip().encode('utf-8')
return new_msg.strip().encode("utf-8")

# either failed to decode CLIXML or there was nothing to decode
# just return the original message
return msg

def _strip_namespace(self, xml):
"""strips any namespaces from an xml string"""
p = re.compile(b"xmlns=*[\"\"][^\"\"]*[\"\"]")
p = re.compile(b'xmlns=*[""][^""]*[""]')
allmatches = p.finditer(xml)
for match in allmatches:
xml = xml.replace(match.group(), b"")
return xml

@staticmethod
def _build_url(target, transport):
match = re.match(
r'(?i)^((?P<scheme>http[s]?)://)?(?P<host>[0-9a-z-_.]+)(:(?P<port>\d+))?(?P<path>(/)?(wsman)?)?', target) # NOQA
scheme = match.group('scheme')
match = re.match(r"(?i)^((?P<scheme>http[s]?)://)?(?P<host>[0-9a-z-_.]+)(:(?P<port>\d+))?(?P<path>(/)?(wsman)?)?", target) # NOQA
scheme = match.group("scheme")
if not scheme:
# TODO do we have anything other than HTTP/HTTPS
scheme = 'https' if transport == 'ssl' else 'http'
host = match.group('host')
port = match.group('port')
scheme = "https" if transport == "ssl" else "http"
host = match.group("host")
port = match.group("port")
if not port:
port = 5986 if transport == 'ssl' else 5985
path = match.group('path')
port = 5986 if transport == "ssl" else 5985
path = match.group("path")
if not path:
path = 'wsman'
return '{0}://{1}:{2}/{3}'.format(scheme, host, port, path.lstrip('/'))
path = "wsman"
return "{0}://{1}:{2}/{3}".format(scheme, host, port, path.lstrip("/"))
74 changes: 36 additions & 38 deletions winrm/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Encryption(object):

SIXTEN_KB = 16384
MIME_BOUNDARY = b'--Encrypted Boundary'
MIME_BOUNDARY = b"--Encrypted Boundary"

def __init__(self, session, protocol):
"""
Expand All @@ -36,15 +36,15 @@ def __init__(self, session, protocol):
self.protocol = protocol
self.session = session

if protocol == 'ntlm': # Details under Negotiate [2.2.9.1.1] in MS-WSMV
if protocol == "ntlm": # Details under Negotiate [2.2.9.1.1] in MS-WSMV
self.protocol_string = b"application/HTTP-SPNEGO-session-encrypted"
self._build_message = self._build_ntlm_message
self._decrypt_message = self._decrypt_ntlm_message
elif protocol == 'credssp': # Details under CredSSP [2.2.9.1.3] in MS-WSMV
elif protocol == "credssp": # Details under CredSSP [2.2.9.1.3] in MS-WSMV
self.protocol_string = b"application/HTTP-CredSSP-session-encrypted"
self._build_message = self._build_credssp_message
self._decrypt_message = self._decrypt_credssp_message
elif protocol == 'kerberos':
elif protocol == "kerberos":
self.protocol_string = b"application/HTTP-SPNEGO-session-encrypted"
self._build_message = self._build_kerberos_message
self._decrypt_message = self._decrypt_kerberos_message
Expand All @@ -63,23 +63,22 @@ def prepare_encrypted_request(self, session, endpoint, message):
"""
host = urlsplit(endpoint).hostname

if self.protocol == 'credssp' and len(message) > self.SIXTEN_KB:
content_type = 'multipart/x-multi-encrypted'
encrypted_message = b''
message_chunks = [message[i:i+self.SIXTEN_KB] for i in range(0, len(message), self.SIXTEN_KB)]
if self.protocol == "credssp" and len(message) > self.SIXTEN_KB:
content_type = "multipart/x-multi-encrypted"
encrypted_message = b""
message_chunks = [message[i : i + self.SIXTEN_KB] for i in range(0, len(message), self.SIXTEN_KB)]
for message_chunk in message_chunks:
encrypted_chunk = self._encrypt_message(message_chunk, host)
encrypted_message += encrypted_chunk
else:
content_type = 'multipart/encrypted'
content_type = "multipart/encrypted"
encrypted_message = self._encrypt_message(message, host)
encrypted_message += self.MIME_BOUNDARY + b"--\r\n"

request = requests.Request('POST', endpoint, data=encrypted_message)
request = requests.Request("POST", endpoint, data=encrypted_message)
prepared_request = session.prepare_request(request)
prepared_request.headers['Content-Length'] = str(len(prepared_request.body))
prepared_request.headers['Content-Type'] = '{0};protocol="{1}";boundary="Encrypted Boundary"'\
.format(content_type, self.protocol_string.decode())
prepared_request.headers["Content-Length"] = str(len(prepared_request.body))
prepared_request.headers["Content-Type"] = '{0};protocol="{1}";boundary="Encrypted Boundary"'.format(content_type, self.protocol_string.decode())

return prepared_request

Expand All @@ -90,7 +89,7 @@ def parse_encrypted_response(self, response):
:param response: The response that needs to be decrypted
:return: The unencrypted message from the server
"""
content_type = response.headers['Content-Type']
content_type = response.headers["Content-Type"]
if 'protocol="{0}"'.format(self.protocol_string.decode()) in content_type:
host = urlsplit(response.request.url).hostname
msg = self._decrypt_response(response, host)
Expand All @@ -103,19 +102,19 @@ def _encrypt_message(self, message, host):
message_length = str(len(message)).encode()
encrypted_stream = self._build_message(message, host)

message_payload = self.MIME_BOUNDARY + b"\r\n" \
b"\tContent-Type: " + self.protocol_string + b"\r\n" \
b"\tOriginalContent: type=application/soap+xml;charset=UTF-8;Length=" + message_length + b"\r\n" + \
self.MIME_BOUNDARY + b"\r\n" \
b"\tContent-Type: application/octet-stream\r\n" + \
encrypted_stream
message_payload = (
self.MIME_BOUNDARY + b"\r\n"
b"\tContent-Type: " + self.protocol_string + b"\r\n"
b"\tOriginalContent: type=application/soap+xml;charset=UTF-8;Length=" + message_length + b"\r\n" + self.MIME_BOUNDARY + b"\r\n"
b"\tContent-Type: application/octet-stream\r\n" + encrypted_stream
)

return message_payload

def _decrypt_response(self, response, host):
parts = response.content.split(self.MIME_BOUNDARY + b'\r\n')
parts = response.content.split(self.MIME_BOUNDARY + b"\r\n")
parts = list(filter(None, parts)) # filter out empty parts of the split
message = b''
message = b""

for i in range(0, len(parts)):
if i % 2 == 1:
Expand All @@ -124,27 +123,26 @@ def _decrypt_response(self, response, host):
header = parts[i].strip()
payload = parts[i + 1]

expected_length = int(header.split(b'Length=')[1])
expected_length = int(header.split(b"Length=")[1])

# remove the end MIME block if it exists
if payload.endswith(self.MIME_BOUNDARY + b'--\r\n'):
payload = payload[:len(payload) - 24]
if payload.endswith(self.MIME_BOUNDARY + b"--\r\n"):
payload = payload[: len(payload) - 24]

encrypted_data = payload.replace(b'\tContent-Type: application/octet-stream\r\n', b'')
encrypted_data = payload.replace(b"\tContent-Type: application/octet-stream\r\n", b"")
decrypted_message = self._decrypt_message(encrypted_data, host)
actual_length = len(decrypted_message)

if actual_length != expected_length:
raise WinRMError('Encrypted length from server does not match the '
'expected size, message has been tampered with')
raise WinRMError("Encrypted length from server does not match the " "expected size, message has been tampered with")
message += decrypted_message

return message

def _decrypt_ntlm_message(self, encrypted_data, host):
signature_length = struct.unpack("<i", encrypted_data[:4])[0]
signature = encrypted_data[4:signature_length + 4]
encrypted_message = encrypted_data[signature_length + 4:]
signature = encrypted_data[4 : signature_length + 4]
encrypted_message = encrypted_data[signature_length + 4 :]

message = self.session.auth.session_security.unwrap(encrypted_message, signature)

Expand All @@ -161,8 +159,8 @@ def _decrypt_credssp_message(self, encrypted_data, host):

def _decrypt_kerberos_message(self, encrypted_data, host):
signature_length = struct.unpack("<i", encrypted_data[:4])[0]
signature = encrypted_data[4:signature_length + 4]
encrypted_message = encrypted_data[signature_length + 4:]
signature = encrypted_data[4 : signature_length + 4]
encrypted_message = encrypted_data[signature_length + 4 :]

message = self.session.auth.unwrap_winrm(host, encrypted_message, signature)

Expand Down Expand Up @@ -195,25 +193,25 @@ def _get_credssp_trailer_length(self, message_length, cipher_suite):
# but there is no GSSAPI/OpenSSL equivalent so we need to calculate it
# ourselves

if re.match(r'^.*-GCM-[\w\d]*$', cipher_suite):
if re.match(r"^.*-GCM-[\w\d]*$", cipher_suite):
# We are using GCM for the cipher suite, GCM has a fixed length of 16
# bytes for the TLS trailer making it easy for us
trailer_length = 16
else:
# We are not using GCM so need to calculate the trailer size. The
# trailer length is equal to the length of the hmac + the length of the
# padding required by the block cipher
hash_algorithm = cipher_suite.split('-')[-1]
hash_algorithm = cipher_suite.split("-")[-1]

# while there are other algorithms, SChannel doesn't support them
# as of yet https://msdn.microsoft.com/en-us/library/windows/desktop/aa374757(v=vs.85).aspx
if hash_algorithm == 'MD5':
if hash_algorithm == "MD5":
hash_length = 16
elif hash_algorithm == 'SHA':
elif hash_algorithm == "SHA":
hash_length = 20
elif hash_algorithm == 'SHA256':
elif hash_algorithm == "SHA256":
hash_length = 32
elif hash_algorithm == 'SHA384':
elif hash_algorithm == "SHA384":
hash_length = 48
else:
hash_length = 0
Expand Down
9 changes: 6 additions & 3 deletions winrm/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@


class WinRMError(Exception):
""""Generic WinRM error"""
""" "Generic WinRM error"""

code = 500


Expand All @@ -19,7 +20,7 @@ def code(self):

@property
def message(self):
return 'Bad HTTP response returned from server. Code {0}'.format(self.code)
return "Bad HTTP response returned from server. Code {0}".format(self.code)

@property
def response_text(self):
Expand All @@ -35,16 +36,18 @@ class WinRMOperationTimeoutError(Exception):
considered a normal error that should be retried transparently by the client when waiting for output from
a long-running process.
"""

code = 500


class AuthenticationError(WinRMError):
"""Authorization Error"""

code = 401


class BasicAuthDisabledError(AuthenticationError):
message = 'WinRM/HTTP Basic authentication is not enabled on remote host'
message = "WinRM/HTTP Basic authentication is not enabled on remote host"


class InvalidCredentialsError(AuthenticationError):
Expand Down
Loading

0 comments on commit ecc55db

Please sign in to comment.