diff --git a/src/ai/backend/common/service_ports.py b/src/ai/backend/common/service_ports.py index 9716c881..aeeaa663 100644 --- a/src/ai/backend/common/service_ports.py +++ b/src/ai/backend/common/service_ports.py @@ -1,6 +1,9 @@ import re from typing import ( - Sequence, List, Set, Any + List, + Sequence, + Set, + Type, ) from .types import ServicePort, ServicePortProtocols @@ -13,8 +16,10 @@ r'^(?P[\w-]+):(?P\w+):(?P\[\d+(?:,\d+)*\]|\d+)(?:,|$)') -def parse_service_ports(s: str, exception: Any) -> Sequence[ServicePort]: +def parse_service_ports(s: str, exception_cls: Type[Exception] = None) -> Sequence[ServicePort]: items: List[ServicePort] = [] + if exception_cls is None: + exception_cls = ValueError used_ports: Set[int] = set() while True: match = _rx_service_ports.search(s) @@ -22,22 +27,25 @@ def parse_service_ports(s: str, exception: Any) -> Sequence[ServicePort]: s = s[len(match.group(0)):] name = match.group('name') if not name: - raise exception('Service port name must be not empty.') + raise exception_cls('Service port name must be not empty.') protocol = match.group('proto') if protocol == 'pty': # unsupported, skip continue if protocol not in ('tcp', 'http', 'preopen'): - raise exception(f'Unsupported service port protocol: {protocol}') + raise exception_cls(f'Unsupported service port protocol: {protocol}') ports = tuple(map(int, match.group('ports').strip('[]').split(','))) for p in ports: if p in used_ports: - raise exception(f'The port {p} is already used by another service port.') + raise exception_cls(f'The port {p} is already used by another service port.') + if p <= 1024: + raise exception_cls(f'The service port number {p} must be ' + f'larger than 1024 to run without the root privilege.') if p >= 65535: - raise exception(f'The service port number {p} must be smaller than 65535.') + raise exception_cls(f'The service port number {p} must be smaller than 65535.') if p in (2000, 2001, 2002, 2003, 2200, 7681): - raise exception('The service ports 2000 to 2003, 2200 and 7681 ' - 'are reserved for internal use.') + raise exception_cls('The service ports 2000 to 2003, 2200 and 7681 ' + 'are reserved for internal use.') used_ports.add(p) items.append({ 'name': name, @@ -46,7 +54,7 @@ def parse_service_ports(s: str, exception: Any) -> Sequence[ServicePort]: 'host_ports': (None,) * len(ports), }) else: - break - if not s: + if len(s) > 0: + raise exception_cls('Invalid format') break return items diff --git a/tests/test_service_ports.py b/tests/test_service_ports.py new file mode 100644 index 00000000..ac1b490c --- /dev/null +++ b/tests/test_service_ports.py @@ -0,0 +1,82 @@ +import pytest + +from ai.backend.common.service_ports import parse_service_ports + + +def test_parse_service_ports(): + result = parse_service_ports('') + assert len(result) == 0 + + result = parse_service_ports('a:http:1230') + assert len(result) == 1 + assert result[0] == { + 'name': 'a', 'protocol': 'http', + 'container_ports': (1230,), + 'host_ports': (None,), + } + + result = parse_service_ports('a:tcp:[5000,5005]') + assert len(result) == 1 + assert result[0] == { + 'name': 'a', 'protocol': 'tcp', + 'container_ports': (5000, 5005), + 'host_ports': (None, None), + } + + result = parse_service_ports('a:tcp:[1230,1240,9000],x:http:3000,t:http:[5000,5001]') + assert len(result) == 3 + assert result[0] == { + 'name': 'a', 'protocol': 'tcp', + 'container_ports': (1230, 1240, 9000), + 'host_ports': (None, None, None), + } + assert result[1] == { + 'name': 'x', 'protocol': 'http', + 'container_ports': (3000,), + 'host_ports': (None,), + } + assert result[2] == { + 'name': 't', 'protocol': 'http', + 'container_ports': (5000, 5001), + 'host_ports': (None, None), + } + + +def test_parse_service_ports_invalid_values(): + with pytest.raises(ValueError, match="Unsupported"): + parse_service_ports('x:unsupported:1234') + + with pytest.raises(ValueError, match="smaller than"): + parse_service_ports('x:http:65536') + + with pytest.raises(ValueError, match="larger than"): + parse_service_ports('x:http:1000') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('x:http:-1') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('abcdefg') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('x:tcp:1234,abcdefg') + + with pytest.raises(ValueError, match="Invalid format"): + parse_service_ports('abcdefg,x:tcp:1234') + + with pytest.raises(ValueError, match="already used"): + parse_service_ports('x:tcp:1234,y:tcp:1234') + + with pytest.raises(ValueError, match="reserved"): + parse_service_ports('y:tcp:7711,x:tcp:2200') + + +def test_parse_service_ports_custom_exception(): + with pytest.raises(ZeroDivisionError): + parse_service_ports('x:unsupported:1234', ZeroDivisionError) + + +def test_parse_service_ports_ignore_pty(): + result = parse_service_ports('x:pty:1234,y:tcp:1235') + assert len(result) == 1 + assert result[0]['name'] == 'y'