Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address WMI hanging #3978

Merged
merged 1 commit into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions monkey/agent_plugins/exploiters/wmi/src/wmi_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Optional

from impacket.dcerpc.v5 import transport
from impacket.dcerpc.v5.dcom import wmi
from impacket.dcerpc.v5.dcomrt import DCOMConnection
from impacket.dcerpc.v5.dtypes import NULL
Expand All @@ -11,6 +12,8 @@

logger = logging.getLogger(__name__)

DCOM_CONNECT_TIMEOUT = 30


def secret_of_type(credentials, type) -> Optional[SecretStr]:
if type is Password and isinstance(credentials.secret, Password):
Expand All @@ -29,6 +32,38 @@ def get_plaintext(secret: Optional[SecretStr]) -> str:
return secret.get_secret_value()


def check_dcom_connection(iInterface, timeout: float):
stringBinding = None
stringBindings = iInterface.get_cinstance().get_string_bindings()
for strBinding in stringBindings:
if strBinding["wTowerId"] == 7:
if strBinding["aNetworkAddr"].find("[") >= 0:
binding, _, bindingPort = strBinding["aNetworkAddr"].partition("[")
bindingPort = "[" + bindingPort
else:
binding = strBinding["aNetworkAddr"]
bindingPort = ""

if binding.upper().find(iInterface.get_target().upper()) >= 0:
stringBinding = "ncacn_ip_tcp:" + strBinding["aNetworkAddr"][:-1]
break
elif (
iInterface.is_fqdn()
and binding.upper().find(iInterface.get_target().upper().partition(".")[0]) >= 0
):
stringBinding = f"ncacn_ip_tcp:{iInterface.get_target()}{bindingPort}"
if stringBinding is None:
raise Exception("Exception occured defining string binding")
try:
rpctransport = transport.DCERPCTransportFactory(stringBinding)
rpctransport.set_connect_timeout(timeout)
rpctransport.connect()
rpctransport.disconnect()
except Exception as err:
logger.debug(f"Exception while connecting to {stringBinding}: {err}")
raise


class WMIClient:
def __init__(self):
self._wbem_services: Optional[wmi.IWbemServices] = None
Expand Down Expand Up @@ -60,6 +95,7 @@ def login(self, host: TargetHost, credentials: Credentials):
iInterface = self._dcom.CoCreateInstanceEx(
wmi.CLSID_WbemLevel1Login, wmi.IID_IWbemLevel1Login
)
check_dcom_connection(iInterface, DCOM_CONNECT_TIMEOUT)
except Exception:
try:
self._dcom.disconnect() # type: ignore[union-attr]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def mock_wbem_login(monkeypatch):
return wbem_login


@pytest.fixture
def mock_dcom_firewall_checker(monkeypatch):
monkeypatch.setattr(
"agent_plugins.exploiters.wmi.src.wmi_client.check_dcom_connection", MagicMock()
)


def test_login__raises_on_dcom_error(mock_dcom_connection, mock_wbem_login):
mock_dcom_connection.CoCreateInstanceEx.side_effect = Exception
wmi_client = WMIClient()
Expand All @@ -41,7 +48,9 @@ def test_login__raises_on_dcom_error(mock_dcom_connection, mock_wbem_login):
assert mock_dcom_connection.CoCreateInstanceEx.called


def test_login__raises_on_wbem_error(mock_dcom_connection, mock_wbem_login, monkeypatch):
def test_login__raises_on_wbem_error(
mock_dcom_connection, mock_wbem_login, monkeypatch, mock_dcom_firewall_checker
):
mock_wbem_login.NTLMLogin.side_effect = Exception

wmi_client = WMIClient()
Expand All @@ -51,7 +60,14 @@ def test_login__raises_on_wbem_error(mock_dcom_connection, mock_wbem_login, monk
assert mock_wbem_login.NTLMLogin.called


def test_login__success(mock_dcom_connection, mock_wbem_login):
def test_login__success(mock_dcom_connection, mock_wbem_login, mock_dcom_firewall_checker):
wmi_client = WMIClient()

wmi_client.login(TARGET_HOST, CREDENTIALS)


def test_login__fail_when_victim_unaccessable(mock_dcom_connection, mock_wbem_login):
wmi_client = WMIClient()

with pytest.raises(Exception):
wmi_client.login(TARGET_HOST, CREDENTIALS)