Skip to content

Commit

Permalink
Add tests for uncovered cases
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <widdis@gmail.com>
  • Loading branch information
dbwiddis committed Sep 9, 2023
1 parent c31b5b8 commit 14343b4
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 15 deletions.
1 change: 1 addition & 0 deletions tests/rest/test_extension_rest_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_initialize_extension_request(self) -> None:
self.assertEqual(err.content, bytes("{}", "ascii"))
self.assertEqual(err.principal_identifier_token, "token")
self.assertEqual(err.http_version, HttpVersion.HTTP_1_1)
self.assertIn("method=RestMethod.GET, uri=/hello?v, path=/hello", str(err))

output = StreamOutput()
err.write_to(output)
Expand Down
2 changes: 2 additions & 0 deletions tests/transport/test_acknowledged_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ class TestAcknowledgedResponse(unittest.TestCase):
def test_initialize_extension_response(self) -> None:
ar = AcknowledgedResponse()
self.assertFalse(ar.status)
self.assertEqual(str(ar), "status=False")

ar = AcknowledgedResponse(True)
self.assertTrue(ar.status)
self.assertEqual(str(ar), "status=True")

out = StreamOutput()
ar.write_to(out)
Expand Down
1 change: 1 addition & 0 deletions tests/transport/test_dicovery_extension_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_discovery_extension_node(self) -> None:
for i in range(2):
self.assertEqual(dependencies[i].unique_id, den.dependencies[i].unique_id)
self.assertEqual(dependencies[i].version.id, den.dependencies[i].version.id)
self.assertIn("id=id, version=0.0.0.0, name=, host=foo.bar, addr=1.2.3.4", str(den))

output = StreamOutput()
den.write_to(output)
Expand Down
2 changes: 2 additions & 0 deletions tests/transport/test_initialize_extension_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def test_initialize_extension_request(self) -> None:
ier = request.read_from(input)
self.assertEqual(ier.source_node.node_id, "opensearch_node")
self.assertEqual(ier.extension.node_id, "extension_node")
self.assertIn("opensearch_node", str(ier))
self.assertIn("extension_node", str(ier))

def test_read_write(self) -> None:
data = NettyTraceData.load("tests/transport/data/initialize_extension_request.txt").data
Expand Down
1 change: 1 addition & 0 deletions tests/transport/test_initialize_extension_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_initialize_extension_response(self) -> None:
ier = InitializeExtensionResponse("test", ["Extension", "ActionExtension"])
self.assertEqual(ier.name, "test")
self.assertListEqual(ier.implemented_interfaces, ["Extension", "ActionExtension"])
self.assertEqual(str(ier), "name=test, interfaces=['Extension', 'ActionExtension']")

out = StreamOutput()
ier.write_to(out)
Expand Down
12 changes: 3 additions & 9 deletions tests/transport/test_outbound_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from opensearch_sdk_py.transport.stream_input import StreamInput
from opensearch_sdk_py.transport.stream_output import StreamOutput
from opensearch_sdk_py.transport.tcp_header import TcpHeader
from opensearch_sdk_py.transport.transport_request import TransportRequest
from opensearch_sdk_py.transport.transport_status import TransportStatus
from opensearch_sdk_py.transport.version import Version

Expand All @@ -31,11 +32,7 @@ def test_outbound_message(self) -> None:
self.assertFalse(om.is_handshake)

def test_outbound_message_stream(self) -> None:
om = OutboundMessage(
request_id=2,
version=Version(3000099),
status=TransportStatus.STATUS_HANDSHAKE,
)
om = OutboundMessage(request_id=2, version=Version(3000099), status=TransportStatus.STATUS_HANDSHAKE, message=TransportRequest())
out = StreamOutput()
om.variable_bytes = b"\x01\x02\x03"
om.write_to(out)
Expand All @@ -54,7 +51,4 @@ def test_outbound_message_stream(self) -> None:
len(out.getvalue()) - TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE,
)
self.assertEqual(om.tcp_header.variable_header_size, 5) # 2 for context, 3 for subclass
self.assertEqual(
om.tcp_header.variable_header_size,
om.tcp_header.size + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE - TcpHeader.HEADER_SIZE,
)
self.assertEqual(om.tcp_header.size + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE - TcpHeader.HEADER_SIZE, om.tcp_header.variable_header_size + len(bytes(TransportRequest())))
1 change: 1 addition & 0 deletions tests/transport/test_outbound_message_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def test_outbound_message_request_stream(self) -> None:
len(out.getvalue()),
omr.tcp_header.size + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE,
)
self.assertIn("internal:test/handshake", str(omr))

omr = OutboundMessageRequest()
omr.read_from(StreamInput(out.getvalue()))
Expand Down
14 changes: 8 additions & 6 deletions tests/transport/test_stream_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,18 @@ def test_read_optional_string_array(self) -> None:
def test_read_string_to_string_dict(self) -> None:
input = StreamInput(b"\x02\x03foo\x03bar\x03baz\x03qux")
dict = input.read_string_to_string_dict()
self.assertEqual(len(dict), 2)
self.assertEqual(dict["foo"], "bar")
self.assertEqual(dict["baz"], "qux")
self.assertDictEqual(dict, {"foo": "bar", "baz": "qux"})
input = StreamInput(b"\x00")
dict = input.read_string_to_string_dict()
self.assertDictEqual(dict, {})

def test_read_string_to_string_array_dict(self) -> None:
input = StreamInput(b"\x02\x03foo\x02\x03bar\x03baz\x03qux\x00")
dict = input.read_string_to_string_array_dict()
self.assertEqual(len(dict), 2)
self.assertEqual(dict["foo"], ["bar", "baz"])
self.assertEqual(dict["qux"], [])
self.assertDictEqual(dict, {"foo": ["bar", "baz"], "qux": []})
input = StreamInput(b"\x00")
dict = input.read_string_to_string_array_dict()
self.assertDictEqual(dict, {})

def test_read_string_to_string_set_dict(self) -> None:
input = StreamInput(b"\x02\x03foo\x03\x03bar\x03baz\x03bar\x03qux\x00")
Expand Down
8 changes: 8 additions & 0 deletions tests/transport/test_stream_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,38 @@ def test_write_v_int(self) -> None:
out.seek(0, 0)
out.write_v_int(127)
self.assertEqual(out.getvalue(), b"\x7f")
self.assertEqual(StreamOutput.v_int_size(127), 1)
out.seek(0, 0)
out.write_v_int(128)
self.assertEqual(out.getvalue(), b"\x80\x01")
self.assertEqual(StreamOutput.v_int_size(128), 2)
# 14 bit max
out.seek(0, 0)
out.write_v_int(16383)
self.assertEqual(out.getvalue(), b"\xff\x7f")
self.assertEqual(StreamOutput.v_int_size(16383), 2)
out.seek(0, 0)
out.write_v_int(16384)
self.assertEqual(out.getvalue(), b"\x80\x80\x01")
self.assertEqual(StreamOutput.v_int_size(16384), 3)
# 21 bit max
out.seek(0, 0)
out.write_v_int(2097151)
self.assertEqual(out.getvalue(), b"\xff\xff\x7f")
self.assertEqual(StreamOutput.v_int_size(2097151), 3)
out.seek(0, 0)
out.write_v_int(2097152)
self.assertEqual(out.getvalue(), b"\x80\x80\x80\x01")
self.assertEqual(StreamOutput.v_int_size(2097152), 4)
# 28 bit max
out.seek(0, 0)
out.write_v_int(268435455)
self.assertEqual(out.getvalue(), b"\xff\xff\xff\x7f")
self.assertEqual(StreamOutput.v_int_size(268435455), 4)
out.seek(0, 0)
out.write_v_int(268435456)
self.assertEqual(out.getvalue(), b"\x80\x80\x80\x80\x01")
self.assertEqual(StreamOutput.v_int_size(268435456), 5)

def test_write_version(self) -> None:
out = StreamOutput()
Expand Down
2 changes: 2 additions & 0 deletions tests/transport/test_task_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_task_id(self) -> None:
ti2.read_from(input=StreamInput(out.getvalue()))
self.assertEqual(ti.node_id, "test")
self.assertEqual(ti.id, 42)
self.assertEqual(str(ti), "node=test, id=42")

def test_empty_task_id(self) -> None:
ti = TaskId()
Expand All @@ -35,3 +36,4 @@ def test_empty_task_id(self) -> None:
out = StreamOutput()
ti.write_to(out)
self.assertEqual(out.getvalue(), b"\x00")
self.assertEqual(str(ti), "node=, id=None")
5 changes: 5 additions & 0 deletions tests/transport/test_tcp_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,20 @@ def test_tcp_header(self) -> None:
header.is_response = True
self.assertFalse(header.is_request)
self.assertTrue(header.is_response)
self.assertIn("response", str(header))
header.is_request = True
self.assertTrue(header.is_request)
self.assertFalse(header.is_response)
self.assertIn("request", str(header))
header.is_error = True
self.assertTrue(header.is_error)
self.assertIn("error", str(header))
header.is_compress = True
self.assertTrue(header.is_compress)
self.assertIn("compressed", str(header))
header.is_handshake = True
self.assertTrue(header.is_handshake)
self.assertIn("handshake", str(header))

def test_tcp_header_stream(self) -> None:
out = StreamOutput()
Expand Down
2 changes: 2 additions & 0 deletions tests/transport/test_transport_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ def test_address_with_host(self) -> None:

def test_address_invalid(self) -> None:
self.assertRaises(AddressValueError, TransportAddress, "1.2.3.4.5", 1234)
ta = TransportAddress()
self.assertRaises(Exception, ta.read_from, input=StreamInput(b"\x05\x01\x02\x03\x04\x05\x09host.name\x00\x00\x04\xd2"))
1 change: 1 addition & 0 deletions tests/transport/test_transport_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ def test_transport_request(self) -> None:
tr.read_from(input=StreamInput(out.getvalue()))
self.assertEqual(tr.parent_task_id.node_id, "test")
self.assertEqual(tr.parent_task_id.id, 42)
self.assertEqual(str(tr), "node=test, id=42")
10 changes: 10 additions & 0 deletions tests/transport/test_transport_service_handshake_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def test_transport_service_handshake_response(self) -> None:
self.assertEqual(tshr.discovery_node.node_id, "id")
self.assertEqual(tshr.cluster_name, "hello-world")
self.assertEqual(tshr.version.id, 136317827)
self.assertIn(str(tshr.discovery_node), str(tshr))
self.assertIn("cluster name=hello-world, version=2.10.0.99", str(tshr))

out = StreamOutput()
tshr.write_to(out)
Expand All @@ -36,6 +38,14 @@ def test_transport_service_handshake_response(self) -> None:
self.assertEqual(tshr.cluster_name, "hello-world")
self.assertEqual(tshr.version.id, 136317827)

tshr.discovery_node = None
out = StreamOutput()
tshr.write_to(out)

input = StreamInput(out.getvalue())
tshr = TransportServiceHandshakeResponse().read_from(input)
self.assertEqual(tshr.discovery_node, None)

def test_read_write_transport_handshake_response(self) -> None:
data = NettyTraceData.load("tests/transport/data/transport_service_handshake_response.txt").data

Expand Down

0 comments on commit 14343b4

Please sign in to comment.