Skip to content

Commit

Permalink
chore(pipelined): Make mypy green (#12950)
Browse files Browse the repository at this point in the history
Fix various type issues for pipelined.
- Correct a few wrong argument or variable types
- Add type stubs for mixins
- Ignore some false positives
- Raise ValueError instead of NPE
- Extract method to handle all types of error in tc_ops_pyroute2
- Add null checks in test code

Signed-off-by: Sebastian Thomas <sebastian.thomas@tngtech.com>
  • Loading branch information
sebathomas committed Jun 10, 2022
1 parent a06a965 commit 1abd2ae
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 32 deletions.
1 change: 1 addition & 0 deletions lte/gateway/python/magma/pipelined/app/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ py_library(
py_library(
name = "restart_mixin",
srcs = ["restart_mixin.py"],
deps = [":startup_flows"],
)

py_library(
Expand Down
19 changes: 10 additions & 9 deletions lte/gateway/python/magma/pipelined/app/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import socket
import subprocess
from collections import namedtuple
from typing import Optional
from typing import Optional, Union

import grpc
from lte.protos.mobilityd_pb2 import IPAddress
Expand Down Expand Up @@ -347,12 +347,12 @@ def _install_downlink_arp_flows(

def add_tunnel_flows(
self, precedence: int, i_teid: int,
o_teid: int, enodeb_ip_addr: str,
ue_ip_adr: IPAddress = None, sid: Optional[int] = None,
o_teid: int, enodeb_ip_addr: Union[str, IPAddress],
ue_ip_adr: Optional[IPAddress] = None, sid: Optional[int] = None,
ng_flag: bool = True,
ue_ipv6_address: IPAddress = None,
unused_apn: str = None, unused_vlan: int = 0,
ip_flow_dl: IPFlowDL = None,
ue_ipv6_address: Optional[IPAddress] = None,
unused_apn: Optional[str] = None, unused_vlan: int = 0,
ip_flow_dl: Optional[IPFlowDL] = None,
) -> bool:

priority = Utils.get_of_priority(precedence)
Expand Down Expand Up @@ -472,9 +472,10 @@ def _add_tunnel_ip_flow_dl(
)

def delete_tunnel_flows(
self, i_teid: int, ue_ip_adr: IPAddress = None,
enodeb_ip_addr: str = None,
ip_flow_dl: IPFlowDL = None, ue_ipv6_adr: IPAddress = None,
self, i_teid: int, ue_ip_adr: Optional[IPAddress] = None,
enodeb_ip_addr: Union[None, str, IPAddress] = None,
ip_flow_dl: Optional[IPFlowDL] = None,
ue_ipv6_adr: Optional[IPAddress] = None,
) -> bool:

# Delete flow for gtp port
Expand Down
2 changes: 1 addition & 1 deletion lte/gateway/python/magma/pipelined/app/he.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, *args, **kwargs):
self._ue_rule_counter = UeProxyRuleCounter()
self.logger.info("Header Enrichment app config: %s", self.config)

def _get_config(self, config_dict, mconfig) -> namedtuple:
def _get_config(self, config_dict, mconfig) -> UplinkHEConfig:
he_enabled = config_dict.get('he_enabled', True)
uplink_port = config_dict.get('uplink_port', None)
proxy_port_name = config_dict.get('proxy_port_name')
Expand Down
6 changes: 5 additions & 1 deletion lte/gateway/python/magma/pipelined/app/policy_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
limitations under the License.
"""
from abc import ABCMeta, abstractmethod
from typing import List
from logging import Logger
from typing import Callable, List

from lte.protos.mobilityd_pb2 import IPAddress
from lte.protos.pipelined_pb2 import (
Expand Down Expand Up @@ -56,6 +57,9 @@ class PolicyMixin(metaclass=ABCMeta):
Mixin class for policy enforcement apps that includes common methods
used for rule activation/deactivation.
"""
logger: Logger
tbl_num: int
_get_default_flow_msgs_for_subscriber: Callable

def __init__(self, *args, **kwargs):
super(PolicyMixin, self).__init__(*args, **kwargs)
Expand Down
20 changes: 19 additions & 1 deletion lte/gateway/python/magma/pipelined/app/restart_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,20 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import Dict, List
from asyncio import Future
from logging import Logger
from typing import Callable, Dict, List, Optional

from lte.protos.pipelined_pb2 import SetupFlowsResult
from magma.pipelined.app.base import ControllerNotReadyException
from magma.pipelined.app.startup_flows import StartupFlows
from magma.pipelined.openflow import flows
from magma.pipelined.openflow.messages import MessageHub
from magma.pipelined.policy_converters import ovs_flow_match_to_magma_match
from ryu.controller.controller import Datapath
from ryu.ofproto.ofproto_v1_4_parser import OFPFlowStats

DefaultMsgsMap = Dict[int, List[OFPFlowStats]]
Expand All @@ -28,6 +35,17 @@ class RestartMixin(metaclass=ABCMeta):
Mixin class for controller restart handling
"""
logger: Logger
tbl_num: int
cleanup_state: Callable
delete_all_flows: Callable
finish_init: Callable
_msg_hub: MessageHub
_datapath: Datapath
_startup_flow_controller: Optional[StartupFlows]
_startup_flows_fut: Future[StartupFlows]
_clean_restart: bool
_wait_for_responses: Callable

def handle_restart(self, requests) -> SetupFlowsResult:
"""
Expand Down
2 changes: 1 addition & 1 deletion lte/gateway/python/magma/pipelined/app/uplink_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, *args, **kwargs):
self._sgi_ip_mon = None
self._datapath = None

def _get_config(self, config_dict) -> namedtuple:
def _get_config(self, config_dict) -> UplinkBridgeConfig:

enable_nat = config_dict.get('enable_nat', True)
bridge_name = config_dict.get('uplink_bridge', UPLINK_OVS_BRIDGE_NAME)
Expand Down
10 changes: 6 additions & 4 deletions lte/gateway/python/magma/pipelined/ebpf/ebpf_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
limitations under the License.
"""

from __future__ import annotations

import ctypes
import logging
import socket
Expand Down Expand Up @@ -334,17 +336,17 @@ def _pack_mac_addr(self, mac_addr: str):
mac_bytes = bytes.fromhex(mac_addr.replace(':', ''))
return (ctypes.c_ubyte * 6).from_buffer(bytearray(mac_bytes))

def _unpack_mac_addr(self, mac_addr: ctypes.c_ubyte):
def _unpack_mac_addr(self, mac_addr: ctypes.Array[ctypes.c_ubyte]):
mac_bytes = bytearray(mac_addr)
return mac_bytes.hex(":")

def _pack_user_data(self, imsi: str):
user_data = bytearray(imsi, encoding='utf8')
return (ctypes.c_ubyte * 64)(*user_data)

def _unpack_imsi(self, user_data: ctypes.c_ubyte):
user_data = bytearray(user_data)
imsi_bytes = user_data[0:16]
def _unpack_imsi(self, user_data: ctypes.Array[ctypes.c_ubyte]):
user_data_bytearray = bytearray(user_data)
imsi_bytes = user_data_bytearray[0:16]
return imsi_bytes.decode()


Expand Down
2 changes: 1 addition & 1 deletion lte/gateway/python/magma/pipelined/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def get_hash(s, hash_function) -> bytes:

def encode_str(s: str, encoding_type) -> str:
if encoding_type == PipelineD.HEConfig.BASE64:
s = codecs.encode(codecs.decode(s, 'hex'), 'base64').decode()
s = codecs.encode(codecs.decode(s, 'hex'), 'base64').decode() # type: ignore
elif encoding_type == PipelineD.HEConfig.HEX2BIN:
bits = len(s) * 4
s = bin(int(s, 16))[2:].zfill(bits)
Expand Down
6 changes: 6 additions & 0 deletions lte/gateway/python/magma/pipelined/qos/qos_tc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class TrafficClass:

@staticmethod
def delete_class(intf: str, qid: int, skip_filter=False) -> int:
if not TrafficClass.tc_ops:
raise ValueError("tc_ops not initialized yet")

qid_hex = hex(qid)

if not skip_filter:
Expand All @@ -51,6 +54,9 @@ def create_class(
intf: str, qid: int, max_bw: int, rate=None,
parent_qid=None, skip_filter=False,
) -> int:
if not TrafficClass.tc_ops:
raise ValueError("tc_ops not initialized yet")

if not rate:
rate = DEFAULT_RATE

Expand Down
31 changes: 17 additions & 14 deletions lte/gateway/python/magma/pipelined/qos/tc_ops_pyroute2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging
import pprint
from typing import Union

from pyroute2 import IPRoute, NetlinkError

Expand Down Expand Up @@ -57,19 +58,17 @@ def create_htb(
LOG.debug("Create HTB iface %s qid %s max_bw %s rate %s", iface, qid, max_bw, rate)
try:
# API needs ceiling in bytes per sec.
max_bw = max_bw / 8
max_bw_bytes = max_bw / 8
if_index = self._get_if_index(iface)
htb_queue = QUEUE_PREFIX + qid
ret = self._ipr.tc(
"add-class", "htb", if_index,
htb_queue, parent=parent_qid,
rate=str(rate).lower(), ceil=max_bw, prio=1,
rate=str(rate).lower(), ceil=max_bw_bytes, prio=1,
)
LOG.debug("Return: %s", ret)
except (ValueError, NetlinkError) as ex:
LOG.error("create-htb error : %s", ex.code)
LOG.debug(ex, exc_info=True)
return ex.code
return log_error_and_get_code(ex, "create-htb")
return 0

def del_htb(self, iface: str, qid: str) -> int:
Expand All @@ -91,9 +90,7 @@ def del_htb(self, iface: str, qid: str) -> int:
ret = self._ipr.tc("del-class", "htb", if_index, htb_queue)
LOG.debug("Return: %s", ret)
except (ValueError, NetlinkError) as ex:
LOG.error("del-htb error error : %s", ex.code)
LOG.debug(ex, exc_info=True)
return ex.code
return log_error_and_get_code(ex, "del-htb")
return 0

def create_filter(self, iface: str, mark: str, qid: str, proto: int = PROTOCOL) -> int:
Expand All @@ -116,9 +113,7 @@ def create_filter(self, iface: str, mark: str, qid: str, proto: int = PROTOCOL)
LOG.debug("Return: %s", ret)

except (ValueError, NetlinkError) as ex:
LOG.error("create-filter error : %s", ex.code)
LOG.debug(ex, exc_info=True)
return ex.code
return log_error_and_get_code(ex, "create-filter")
return 0

def del_filter(self, iface: str, mark: str, qid: str, proto: int = PROTOCOL) -> int:
Expand All @@ -141,9 +136,7 @@ def del_filter(self, iface: str, mark: str, qid: str, proto: int = PROTOCOL) ->
)
LOG.debug("Return: %s", ret)
except (ValueError, NetlinkError) as ex:
LOG.error("del-filter error : %s", ex.code)
LOG.debug(ex, exc_info=True)
return ex.code
return log_error_and_get_code(ex, "del-filter")
return 0

def create(
Expand Down Expand Up @@ -186,3 +179,13 @@ def _print_filters(self, iface):
if_index = self._get_if_index(iface)

pprint.pprint(self._ipr.get_filters(if_index))


def log_error_and_get_code(
ex: Union[ValueError, NetlinkError],
error_type: str,
) -> int:
code = getattr(ex, 'code', -1)
LOG.error("%s error : %s", error_type, code)
LOG.debug(ex, exc_info=True)
return code
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def fail(
stdout=subprocess.PIPE,
shell=True,
)
assert p.stdout is not None
ofctl_dump = p.stdout.read().decode("utf-8", 'ignore').strip()
logging.error("cmd ofctl_dump: %s", ofctl_dump)

Expand Down Expand Up @@ -698,6 +699,7 @@ def get_ovsdb_port_tag(port_name: str) -> Optional[str]:
["ovsdb-client", "dump", "Port", "name", "tag"],
stdout=subprocess.PIPE,
)
assert dump1.stdout is not None
for port in dump1.stdout.readlines():
if port_name not in str(port):
continue
Expand Down
13 changes: 13 additions & 0 deletions lte/gateway/python/magma/pipelined/tests/test_inout_non_nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@

from lte.protos.mobilityd_pb2 import GWInfo, IPAddress, IPBlock
from magma.pipelined.app import egress
from magma.pipelined.app.egress import EgressController
from magma.pipelined.app.ingress import IngressController
from magma.pipelined.app.middle import MiddleController
from magma.pipelined.app.testing import TestingController
from magma.pipelined.bridge_util import BridgeTools
from magma.pipelined.service_manager import ServiceManager
from magma.pipelined.tests.app.start_pipelined import (
PipelinedController,
StartThread,
TestSetup,
)
from magma.pipelined.tests.pipelined_test_util import (
Expand Down Expand Up @@ -95,6 +101,13 @@ class InOutNonNatTest(unittest.TestCase):
DHCP_PORT = "tino_dhcp"
UPLINK_VLAN_SW = "vlan_inout"

service_manager: ServiceManager
thread: StartThread
ingress_controller: IngressController
middle_controller: MiddleController
egress_controller: EgressController
testing_controller: TestingController

@classmethod
def setup_uplink_br(cls):
setup_dhcp_server = cls.SCRIPT_PATH + "scripts/setup-test-dhcp-srv.sh"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ def validate_routing_table(dst: str, dev_name: str) -> Optional[str]:
["ip", "r", "get", dst],
stdout=subprocess.PIPE,
)
assert dump1.stdout is not None
for line in dump1.stdout.readlines():
if "dev" not in str(line):
continue
Expand All @@ -1006,6 +1007,7 @@ def validate_routing_table(dst: str, dev_name: str) -> Optional[str]:
["ovs-ofctl", "dump-flows", dev_name],
stdout=subprocess.PIPE,
)
assert dump1.stdout is not None
for line in dump1.stdout.readlines():
print("pbs: %s", line)
assert 0

0 comments on commit 1abd2ae

Please sign in to comment.