diff --git a/gvm/protocols/gmp.py b/gvm/protocols/gmp.py index 830e27bd4..de252719a 100644 --- a/gvm/protocols/gmp.py +++ b/gvm/protocols/gmp.py @@ -19,7 +19,8 @@ """ Module for communication with gvmd """ -from typing import Any, Optional, Callable, Union +from types import TracebackType +from typing import Any, Optional, Callable, Union, Type from gvm.errors import GvmError @@ -114,8 +115,17 @@ def determine_supported_gmp(self) -> SUPPORTED_GMP_VERSIONS: return gmp_class(self._connection, transform=self._gmp_transform) def __enter__(self): - gmp = self.determine_supported_gmp() + self._gmp = self.determine_supported_gmp() - gmp.connect() + self._gmp.connect() - return gmp + return self._gmp + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Any: + self._gmp.disconnect() + self._gmp = None diff --git a/tests/protocols/gmp/test_context_manager.py b/tests/protocols/gmp/test_context_manager.py index e55d5d6f8..5c01a5093 100644 --- a/tests/protocols/gmp/test_context_manager.py +++ b/tests/protocols/gmp/test_context_manager.py @@ -17,6 +17,7 @@ # along with this program. If not, see . import unittest +from unittest.mock import MagicMock, patch from tests.protocols import GmpTestCase @@ -115,6 +116,21 @@ def test_invalid_response(self): with self.gmp: pass + @patch("gvm.protocols.gmp.Gmpv214") + def test_connect_disconnect(self, gmp_mock: MagicMock): + self.connection.read.return_value( + '' + '21.04' + '' + ) + + with self.gmp: + gmp_mock.assert_called_once() + + mock_instance = gmp_mock.return_value + mock_instance.connect.assert_called_once() + mock_instance.disconnect.assert_called_once() + if __name__ == '__main__': unittest.main()