diff --git a/.github/workflows/py-workflow.yml b/.github/workflows/py-workflow.yml new file mode 100644 index 000000000..5c38aad78 --- /dev/null +++ b/.github/workflows/py-workflow.yml @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2023 Canonical Ltd. + +name: Python Lint and Test + +on: [push, pull_request] + +jobs: + route-control-tests: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.11' + + - name: Create virtual environment + run: python -m venv venv + + - name: Install dependencies + run: | + source ./venv/bin/activate + pip install -r ./conf/test-requirements.txt + + - name: Run tests + run: | + source ./venv/bin/activate + python -m unittest ./conf/test_route_control.py diff --git a/.gitignore b/.gitignore index 0244ca37e..2c91cf4d9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright 2019 Intel Corporation +# Copyright 2023 Canonical Ltd. + .idea/ .coverage/ *.pyc @@ -8,3 +10,4 @@ output dpdk-devbind.py dpdk-hugepages.py dictionary.dic +venv/ diff --git a/conf/route_control.py b/conf/route_control.py index 70629c575..ed43e82ad 100755 --- a/conf/route_control.py +++ b/conf/route_control.py @@ -1,530 +1,853 @@ #!/usr/bin/env python # SPDX-License-Identifier: Apache-2.0 # Copyright 2019 Intel Corporation +# Copyright 2023 Canonical Ltd. import argparse +import ipaddress +import logging import signal import sys import time -import ipaddress +from collections import defaultdict +from dataclasses import dataclass, field +from threading import Lock, Thread +from typing import Dict, List, Optional, Tuple -# for retrieving neighbor info +from pybess.bess import * from pyroute2 import IPDB, IPRoute +from scapy.all import ICMP, IP, send -from scapy.all import * - -try: - from pybess.bess import * -except ImportError: - print('Cannot import the API module (pybess)') - raise - -MAX_RETRIES = 5 -SLEEP_S = 2 - - -class NeighborEntry: - def __init__(self): - self.neighbor_ip = None - self.iface = None - self.iprange = None - self.prefix_len = None - self.route_count = 0 - self.gate_idx = 0 - self.macstr = None - - def __str__(self): - return ('{neigh: %s, iface: %s, ip-range: %s/%s}' % - (self.neighbor_ip, self.iface, self.iprange, self.prefix_len)) - - -def mac2hex(mac): - return int(mac.replace(':', ''), 16) - - -def send_ping(neighbor_ip): - send(IP(dst=neighbor_ip) / ICMP()) - +LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s" +logging.basicConfig(format=LOG_FORMAT, level=logging.INFO) +logger = logging.getLogger(__name__) -def send_arp(neighbor_ip, src_mac, iface): - pkt = Ether(dst="ff:ff:ff:ff:ff:ff") / ARP(pdst=neighbor_ip, hwsrc=src_mac) - pkt.show() - hexdump(pkt) - sendp(pkt, iface=iface) +KEY_NETWORK_LAYER_DEST_ADDR = "NDA_DST" +KEY_LINK_LAYER_ADDRESS = "NDA_LLADDR" +KEY_NEW_NEIGHBOR_ACTION = "RTM_NEWNEIGH" +KEY_DELETE_ROUTE_ACTION = "RTM_DELROUTE" +KEY_NEW_ROUTE_ACTION = "RTM_NEWROUTE" +KEY_INTERFACE = "RTA_OIF" +KEY_DESTINATION_IP = "RTA_DST" +KEY_DESTINATION_GATEWAY_IP = "RTA_GATEWAY" +KEY_DESTINATION_PREFIX_LENGTH = "dst_len" -def fetch_mac(dip): - ip = '' - _mac = '' - neighbors = ipr.get_neighbours(dst=dip) - for i in range(len(neighbors)): - for att in neighbors[i]['attrs']: - if 'NDA_DST' in att and dip == att[1]: - # ('NDA_DST', dip) - ip = att[1] - if 'NDA_LLADDR' in att: - # ('NDA_LLADDR', _mac) - _mac = att[1] - return _mac +@dataclass +class RouteEntry: + """A representation of a neighbor in route entry.""" + next_hop_ip: str = field(default=None) + interface: str = field(default=None) + dest_prefix: str = field(default=None) + prefix_len: int = field(default=0) -def link_modules(server, module, next_module, ogate=0, igate=0): - print('Linking {} module'.format(next_module)) - - # Pause bess first - bess.pause_all() - # Connect module to next_module - for _ in range(MAX_RETRIES): - try: - server.connect_modules(module, next_module, ogate, igate) - except BESS.Error as e: - bess.resume_all() - if e.code == errno.EBUSY: - break +@dataclass +class NeighborEntry: + """A representation of a neighbor in neighbor cache.""" + gate_idx: int = field(default=0) + mac_address: str = field(default=None) + route_count: int = field(default=0) + + +class BessController: + """Wraps commands from bess client.""" + MAX_RETRIES = 5 + SLEEP_S = 2 + + def __init__(self, bess_ip: str, bess_port: str) -> None: + """Initializes the BESS controller. + + Args: + bess_ip (str): The IP address of the BESS daemon. + bess_port (str): The port of the BESS daemon. + """ + self._bess = self._get_bess(ip=bess_ip, port=bess_port) + + def _get_bess(self, ip: str, port: str) -> "BESS": + """Connects to the BESS daemon.""" + bess = BESS() + for _ in range(self.MAX_RETRIES): + try: + if not bess.is_connected(): + bess.connect(grpc_url=ip + ":" + port) + except BESS.RPCError: + logger.error( + "Error connecting to BESS daemon. Retrying in %s sec...", + self.SLEEP_S, + ) + time.sleep(self.SLEEP_S) + except Exception as e: + logger.exception("Error connecting to BESS daemon") + raise Exception("BESS connection failure.", e) else: - return #raise - except Exception as e: - print( - 'Error connecting module {}:{}->{}:{}: {}. Retrying in {} secs...' - .format(module, ogate, igate, next_module, e, SLEEP_S)) - time.sleep(SLEEP_S) + logger.info("Connected to BESS daemon") + return bess else: - bess.resume_all() - break - else: - bess.resume_all() - print('BESS module connection ({}:{}->{}:{}) failure.'.format( - module, ogate, igate, next_module)) - return - #raise Exception('BESS module connection ({}:{}->{}:{}) failure.'. - # format(module, ogate, igate, next_module)) - - -def link_route_module(server, gateway_mac, item): - iprange = item.iprange - prefix_len = item.prefix_len - route_module = item.iface + 'Routes' - last_module = item.iface + 'Merge' - gateway_mac_str = '{:X}'.format(gateway_mac) - print('Adding route entry {}/{} for {}'.format(iprange, prefix_len, - route_module)) - - print('Trying to retrieve neighbor entry {} from neighbor cache'.format( - item.neighbor_ip)) - neighbor_exists = neighborcache.get(item.neighbor_ip) - - # How many gates does this module have? - # If entry does not exist, then initialize it - if not modgatecnt.get(route_module): - modgatecnt[route_module] = 0 - - # Compute likely index - if neighbor_exists: - # No need to create a new Update module - gate_idx = neighbor_exists.gate_idx - else: - # Need to create a new Update module, - # so get gate_idx from gate count - gate_idx = modgatecnt[route_module] - - # Pause bess first - bess.pause_all() - # Pass routing entry to bessd's route module - for _ in range(MAX_RETRIES): - try: - server.run_module_command(route_module, 'add', - 'IPLookupCommandAddArg', { - 'prefix': iprange, - 'prefix_len': int(prefix_len), - 'gate': gate_idx - }) - except: - print('Error adding route entry {}/{} in {}. Retrying in {}sec...'. - format(iprange, prefix_len, route_module, SLEEP_S)) - time.sleep(SLEEP_S) + raise Exception( + "BESS connection failure after {} attempts.".format(self.MAX_RETRIES) + ) + + def add_route_to_module( + self, route_entry: RouteEntry, gate_idx: int, module_name: str + ) -> None: + """Adds a route entry to BESS. + + Args: + route_entry (RouteEntry): Entry to be added to BESS module. + gate_idx (int): Gate of the module used in the route. + module_name (str): The name of the module. + """ + for _ in range(self.MAX_RETRIES): + try: + self._bess.pause_all() + self._bess.run_module_command( + module_name, + "add", + "IPLookupCommandAddArg", + { + "prefix": route_entry.dest_prefix, + "prefix_len": route_entry.prefix_len, + "gate": gate_idx, + }, + ) + except Exception: + logger.exception( + "Error adding route entry %s/%i in %s. Retrying in %i sec...", + route_entry.dest_prefix, + route_entry.prefix_len, + module_name, + self.SLEEP_S, + ) + time.sleep(self.SLEEP_S) + else: + logger.info( + "Route entry %s/%i added to %s", + route_entry.dest_prefix, + route_entry.prefix_len, + module_name, + ) + break + finally: + self._bess.resume_all() + else: + raise Exception( + "BESS route entry ({}/{}) insertion failure in module {}".format( + route_entry.dest_prefix, + route_entry.prefix_len, + module_name, + ) + ) + + def delete_module_route_entry(self, route_entry: RouteEntry) -> None: + """Deletes a route entry from BESS module. + + Args: + route_entry (RouteEntry): The neighbor entry. + """ + route_module = route_entry.interface + "Routes" + for _ in range(self.MAX_RETRIES): + try: + self._bess.pause_all() + self._bess.run_module_command( + route_module, + "delete", + "IPLookupCommandDeleteArg", + { + "prefix": route_entry.dest_prefix, + "prefix_len": int(route_entry.prefix_len), + }, + ) + except Exception: + logger.exception( + "Error deleting route entry for %s. Retrying in %i sec...", + route_module, + self.SLEEP_S, + ) + time.sleep(self.SLEEP_S) + else: + logger.info("Route entry deleted for %s", route_module) + break + finally: + self._bess.resume_all() else: - bess.resume_all() - break - else: - bess.resume_all() - print('BESS route entry ({}/{}) insertion failure in module {}'.format( - iprange, prefix_len, route_module)) - return - #raise Exception('BESS route entry ({}/{}) insertion failure in module {}'. - # format(iprange, prefix_len, route_module)) - - if not neighbor_exists: - print('Neighbor does not exist') - # Create Update module - update_module = route_module + 'DstMAC' + gateway_mac_str - - # Pause bess first - bess.pause_all() - for _ in range(MAX_RETRIES): + raise Exception( + "BESS route entry ({}/{}) deletion failure in module {}".format( + route_entry.dest_prefix, + route_entry.prefix_len, + route_module, + ) + ) + + def create_module( + self, module_name: str, module_class: str, gateway_mac: int + ) -> None: + """Creates a BESS module. + + Args: + gateway_mac (int): MAC address of the gateway as an int. + update_module_name (str): The name of the module. + module_class (str): The class of the module. + """ + for _ in range(self.MAX_RETRIES): try: - server.create_module('Update', update_module, { - 'fields': [{ - 'offset': 0, - 'size': 6, - 'value': gateway_mac - }] - }) + self._bess.pause_all() + self._bess.create_module( + module_class, + module_name, + {"fields": [{"offset": 0, "size": 6, "value": gateway_mac}]}, + ) except BESS.Error as e: - bess.resume_all() if e.code == errno.EEXIST: + logger.error("Module %s already exists", module_name) break else: - return #raise - except Exception as e: - print( - 'Error creating update module {}: {}. Retrying in {} secs...' - .format(update_module, e, SLEEP_S)) - time.sleep(SLEEP_S) + raise Exception( + "Unknown error when inserting {}: {}".format( + module_name, e + ) + ) + except Exception: + logger.exception( + "Error creating update module %s, retrying in %i secs", + module_name, + self.SLEEP_S, + ) + time.sleep(self.SLEEP_S) else: - bess.resume_all() + logger.info("Add Update module %s successfully", module_name) break + finally: + self._bess.resume_all() else: - bess.resume_all() - print('BESS module {} creation failure.'.format(update_module)) - return #raise Exception('BESS module {} creation failure.'. - # format(update_module)) - - print('Update module created') - - # Connect Update module to route module - link_modules(server, route_module, update_module, gate_idx, 0) - - # Connect Update module to dpdk_out module - link_modules(server, update_module, last_module, 0, 0) - - # Add a new neighbor in neighbor cache - neighborcache[item.neighbor_ip] = item - - # Add a record of the affliated gate id - item.gate_idx = gate_idx - - # Set the mac str - item.macstr = gateway_mac_str - - # Increment global gate count number - modgatecnt[route_module] += 1 - - neighbor_exists = item - - else: - print('Neighbor already exists') - - # Finally increment route count - neighborcache[item.neighbor_ip].route_count += 1 - - -def del_route_entry(server, item): - iprange = item.iprange - prefix_len = item.prefix_len - route_module = item.iface + 'Routes' - - neighbor_exists = neighborcache.get(item.neighbor_ip) - if neighbor_exists: - # Pause bess first - bess.pause_all() - # Delete routing entry from bessd's route module - for i in range(MAX_RETRIES): + raise Exception( + "BESS module {} creation failure.".format(module_name) + ) + + def link_modules(self, module, next_module, ogate, igate) -> None: + """Links two BESS modules together. + + Args: + module (str): The name of the first module. + next_module (str): The name of the second module. + ogate (int, optional): The output gate of the first module. + igate (int, optional): The input gate of the second module. + """ + for _ in range(self.MAX_RETRIES): try: - server.run_module_command(route_module, 'delete', - 'IPLookupCommandDeleteArg', { - 'prefix': iprange, - 'prefix_len': int(prefix_len) - }) - except: - print( - 'Error while deleting route entry for {}. Retrying in {} sec...' - .format(route_module, SLEEP_S)) - time.sleep(SLEEP_S) + self._bess.pause_all() + self._bess.connect_modules(module, next_module, ogate, igate) + except BESS.Error as e: + logger.exception("Got BESS error") + if e.code == errno.EBUSY: + logger.error( + "Got code EBUSY. Retrying in %i secs...", self.SLEEP_S + ) + time.sleep(self.SLEEP_S) + else: + raise Exception( + "Unknown error when linking modules: {}".format(e) + ) + except Exception: + logger.exception( + "Error linking module: %s:%i->%i/%s. Retrying in %s secs...", + module, + ogate, + igate, + next_module, + self.SLEEP_S, + ) + time.sleep(self.SLEEP_S) else: - bess.resume_all() + logger.info( + "Module %s:%i->%i/%s linked", + module, + ogate, + igate, + next_module, + ) break + finally: + self._bess.resume_all() else: - bess.resume_all() - print('Route entry deletion failure.') - return - #raise Exception('Route entry deletion failure.') - - print('Route entry {}/{} deleted from {}'.format( - iprange, prefix_len, route_module)) - - # Decrementing route count for the registered neighbor - neighbor_exists.route_count -= 1 - - # If route count is 0, then delete the whole module - if neighbor_exists.route_count == 0: - update_module = route_module + 'DstMAC' + neighbor_exists.macstr - # Pause bess first - bess.pause_all() - for i in range(MAX_RETRIES): - try: - server.destroy_module(update_module) - except: - print('Error destroying module {}. Retrying in {}sec...'. - format(update_module, SLEEP_S)) - time.sleep(SLEEP_S) - else: - bess.resume_all() - break + raise Exception( + "BESS module connection ({}:{}->{}:{}) failure.".format( + module, ogate, igate, next_module + ) + ) + + def delete_module(self, module_name: str) -> None: + """Deletes a BESS module. + + Args: + update_module (str): The name of the module to delete. + """ + for _ in range(self.MAX_RETRIES): + try: + self._bess.pause_all() + self._bess.destroy_module(module_name) + except Exception: + logger.exception( + "Error destroying module %s. Retrying in %i sec...", + module_name, + self.SLEEP_S, + ) + time.sleep(self.SLEEP_S) else: - bess.resume_all() - print('Module {} deletion failure.'.format(update_module)) - return - #raise Exception('Module {} deletion failure.'. - # format(update_module)) - - print('Module {} destroyed'.format(update_module)) - - # Delete entry from the neighbor cache - del neighborcache[item.neighbor_ip] - print('Deleting item from neighborcache') - del neighbor_exists - else: - print('Route count for {} decremented to {}'.format( - item.neighbor_ip, neighbor_exists.route_count)) - neighborcache[item.neighbor_ip] = neighbor_exists - else: - print('Neighbor {} does not exist'.format(item.neighbor_ip)) - - -def probe_addr(item, src_mac): - # Store entry if entry does not exist in ARP cache - arpcache[item.neighbor_ip] = item - print('Adding entry {} in arp probe table'.format(item)) - - try: - ipb = ipaddress.ip_address(item.neighbor_ip) - if isinstance(ipb, ipaddress.IPv4Address): - print("The IP address {} is valid ipv4 address".format(ipb)) + logger.info("Module %s destroyed", module_name) + break + finally: + self._bess.resume_all() else: - print("The IP address {} is valid ipv6 address. Ignore ".format(ipb)) - return - except: - print("The IP address {} is invalid".format(item.neighbor_ip)) - return - # Probe ARP request by sending ping - send_ping(item.neighbor_ip) - - # Probe ARP request - ##send_arp(neighbor_ip, src_mac, item.iface) - - -def parse_new_route(msg): - item = NeighborEntry() - # Fetch prefix_len - item.prefix_len = msg['dst_len'] - # Default route - if item.prefix_len == 0: - item.iprange = '0.0.0.0' - - for att in msg['attrs']: - if 'RTA_DST' in att: - # Fetch IP range - # ('RTA_DST', iprange) - item.iprange = att[1] - if 'RTA_GATEWAY' in att: - # Fetch gateway MAC address - # ('RTA_GATEWAY', neighbor_ip) - item.neighbor_ip = att[1] - _mac = fetch_mac(att[1]) - if not _mac: - gateway_mac = 0 - else: - gateway_mac = mac2hex(_mac) - if 'RTA_OIF' in att: - # Fetch interface name - # ('RTA_OIF', iface) - item.iface = ipdb.interfaces[int(att[1])].ifname - - if not item.iface in args.i or not item.iprange or not item.neighbor_ip: - # Neighbor info is invalid - del item - return - - # if mac is 0, send ARP request - if gateway_mac == 0: - print('Adding entry {} in arp probe table. Neighbor: {}'.format(item.iface,item.neighbor_ip)) - probe_addr(item, ipdb.interfaces[item.iface].address) - - else: # if gateway_mac is set - print('Linking module {}Routes with {}Merge (Dest MAC: {})'.format( - item.iface, item.iface, _mac)) - - link_route_module(bess, gateway_mac, item) - - -def parse_new_neighbor(msg): - for att in msg['attrs']: - if 'NDA_DST' in att: - # ('NDA_DST', neighbor_ip) - neighbor_ip = att[1] - if 'NDA_LLADDR' in att: - # ('NDA_LLADDR', neighbor_mac) - gateway_mac = att[1] - - item = arpcache.get(neighbor_ip) - if item: - print('Linking module {}Routes with {}Merge (Dest MAC: {})'.format( - item.iface, item.iface, gateway_mac)) - - # Add route entry, and add item in the registered neighbor cache - link_route_module(bess, mac2hex(gateway_mac), item) + raise Exception("Module {} deletion failure.".format(module_name)) + + +class RouteController: + """Provides an interface to manage routes from netlink messages. + + Listens for netlink events and handling them. + Creates BESS modules for route entries.""" + + MAX_GATES = 8192 + def __init__( + self, + bess_controller: BessController, + ipdb: IPDB, + ipr: IPRoute, + interfaces: List[str], + ): + """ + Initializes the route controller. + + Args: + bess_controller (BessController): + Controller for BESS (Berkeley Extensible Software Switch). + route_parser (RouteEntryParser): Parser for route entries. + ipdb (IPDB): IP database to manage IP configurations. + ipr (IPRoute): IP routing control object. + + Attributes: + _unresolved_arp_queries_cache (dict[str, RouteEntry]): + A cache to store unresolved ARP queries. + _neighbor_cache (dict[str, RouteEntry]): + A cache to keep track of entries add in Bess. + _module_gate_count_cache (Dict[str, int]): + A cache for counting module gate occurrences. + """ + self._unresolved_arp_queries_cache: Dict[str, RouteEntry] = {} + self._neighbor_cache: Dict[str, NeighborEntry] = {} + self._module_gate_count_cache: Dict[str, int] = defaultdict(lambda: 0) + + self._lock = Lock() + + self._ipdb = ipdb + self._ipr = ipr + self._bess_controller = bess_controller + self._ping_missing_thread = Thread(target=self._ping_missing_entries, daemon=True) + self._event_callback = None + self._interfaces = interfaces + + def register_callbacks(self) -> None: + """Register callback function.""" + logger.info("Registering netlink event listener callback...") + self._event_callback = self._ipdb.register_callback(self._netlink_event_listener) + + def start_pinging_missing_entries(self) -> None: + """Starts a new thread for ping missing entries.""" + if not self._ping_missing_thread or not self._ping_missing_thread.is_alive(): + self._ping_missing_thread.start() + logger.info("Ping missing entries thread started") + + def bootstrap_routes(self) -> None: + """Goes through all routes and handles new ones..""" + routes = self._ipr.get_routes() + for route in routes: + if route["event"] == KEY_NEW_ROUTE_ACTION: + if route_entry := self._parse_route_entry_msg(route): + with self._lock: + self.add_new_route_entry(route_entry) + + def add_new_route_entry(self, route_entry: RouteEntry) -> None: + """Handles a new route entry. + + Args: + route_entry (RouteEntry): The route entry. + """ + if not (next_hop_mac := fetch_mac(self._ipdb, route_entry.next_hop_ip)): + logger.info( + "mac address of the next hop %s is not stored in ARP table. Probing...", + route_entry.next_hop_ip, + ) + self._probe_addr(route_entry) + return - # Remove entry from unresolved arp cache - del arpcache[neighbor_ip] + self._add_neighbor(route_entry, next_hop_mac) + def _add_neighbor( + self, route_entry: RouteEntry, next_hop_mac: str + ) -> None: + """Adds the route in BESS module. + Creates required BESS modules. -def parse_del_route(msg): - item = NeighborEntry() - for att in msg['attrs']: - if 'RTA_DST' in att: - # Fetch IP range - # ('RTA_DST', iprange) - item.iprange = att[1] - if 'RTA_GATEWAY' in att: - # Fetch gateway MAC address - # ('RTA_GATEWAY', neighbor_ip) - item.neighbor_ip = att[1] - if 'RTA_OIF' in att: - # Fetch interface name - # ('RTA_OIF', iface) - item.iface = ipdb.interfaces[int(att[1])].ifname + Args: + route_entry (RouteEntry) + next_hop_mac (str): The MAC address of the next hop. + """ + route_module_name = self.get_route_module_name(route_entry.interface) + gate_idx = self._get_gate_idx(route_entry, route_module_name) + try: + self._bess_controller.add_route_to_module( + route_entry, + gate_idx=gate_idx, + module_name=route_module_name, + ) + + except Exception: + logger.exception( + "Error adding route entry to BESS: %s", route_entry + ) + return - if not item.iface in args.i or not item.iprange or not item.neighbor_ip: - # Neighbor info is invalid - del item - return + if not self._neighbor_cache.get(route_entry.next_hop_ip): + logger.info("Neighbor entry does not exist, creating modules.") + update_module_name = self.get_update_module_name( + route_entry.interface, + next_hop_mac, + ) + merge_module_name = self.get_merge_module_name( + route_entry.interface + ) + self._create_update_module( + destination_mac=next_hop_mac, + update_module_name=update_module_name, + ) + self._create_module_links( + gate_idx=gate_idx, + update_module_name=update_module_name, + route_module_name=route_module_name, + merge_module_name=merge_module_name, + ) + self._neighbor_cache[route_entry.next_hop_ip] = NeighborEntry( + gate_idx=gate_idx, + mac_address=next_hop_mac, + ) + self._module_gate_count_cache[route_module_name] += 1 + else: + logger.info("Neighbor already exists") - # Fetch prefix_len - item.prefix_len = msg['dst_len'] + self._neighbor_cache[route_entry.next_hop_ip].route_count += 1 - del_route_entry(bess, item) + def _create_update_module( + self, + update_module_name: str, + destination_mac: str, + ) -> None: + """Creates an update module in BESS. - # Delete item - del item + Args: + update_module_name (str): The name of the module. + destination_mac (str): The MAC address of the gateway. + """ + try: + mac_in_hexadecimal = mac_to_int(destination_mac) + self._bess_controller.create_module( + module_name=update_module_name, + module_class="Update", + gateway_mac=mac_in_hexadecimal, + ) + except Exception: + logger.exception( + "Error creating update module %s", update_module_name + ) + return + def add_unresolved_new_neighbor(self, netlink_message: dict) -> None: + """Handle new neighbor event. + + It will add the neighbor if it was in the unresolved ARP queries cache. + + Args: + netlink_message (dict): The netlink message. + """ + attr_dict = dict(netlink_message["attrs"]) + route_entry = self._unresolved_arp_queries_cache.get( + attr_dict[KEY_NETWORK_LAYER_DEST_ADDR] + ) + gateway_mac = attr_dict[KEY_LINK_LAYER_ADDRESS] + if route_entry: + self._add_neighbor( + route_entry, gateway_mac + ) + del self._unresolved_arp_queries_cache[ + route_entry.next_hop_ip + ] + + def _create_module_links( + self, + gate_idx: int, + update_module_name: str, + route_module_name: str, + merge_module_name: str, + ) -> None: + """Create update module and link modules. + + Args: + gate_idx (int): Output gate index. + update_module_name (str): The name of the update module. + route_module_name (str): The name of the route module. + merge_module_name (str): The name of the merge module. + """ + try: + self._bess_controller.link_modules( + route_module_name, update_module_name, gate_idx, 0 + ) + except Exception: + logger.exception( + "Error linking module % s to module % s", + update_module_name, + route_module_name, + ) + return -def netlink_event_listener(ipdb, netlink_message, action): + try: + self._bess_controller.link_modules( + update_module_name, merge_module_name, 0, 0 + ) + except Exception: + logger.exception( + "Error linking module %s to module %s", + update_module_name, + merge_module_name, + ) + return - # If you get a netlink message, parse it - msg = netlink_message + def delete_route_entry(self, route_entry: RouteEntry) -> None: + """Deletes a route entry from BESS and the neighbor cache.""" + next_hop = self._neighbor_cache.get(route_entry.next_hop_ip) - if action == 'RTM_NEWROUTE': - parse_new_route(msg) + if next_hop: + try: + self._bess_controller.delete_module_route_entry(route_entry) + except Exception: + logger.exception( + "Error deleting route entry %s", route_entry + ) + return - if action == 'RTM_NEWNEIGH': - parse_new_neighbor(msg) + next_hop.route_count -= 1 - if action == 'RTM_DELROUTE': - parse_del_route(msg) + if next_hop.route_count == 0: + route_module = self.get_route_module_name( + route_entry.interface + ) + update_module_name = self.get_update_module_name( + route_module_name=route_module, + mac_address=next_hop.mac_address, + ) + try: + self._bess_controller.delete_module(update_module_name) + except Exception: + logger.exception( + "Error deleting update module %s", + update_module_name, + ) + return + + logger.info("Module deleted %s", update_module_name) + + del self._neighbor_cache[route_entry.next_hop_ip] + logger.info("Deleted item from neighbor cache") + else: + logger.info( + "Route count for %s decremented to %i", + route_entry.next_hop_ip, + next_hop.route_count, + ) + self._neighbor_cache[route_entry.next_hop_ip] = next_hop + else: + logger.info("Neighbor %s does not exist", route_entry.next_hop_ip) + + def _ping_missing_entries(self): + """Pings missing entries every 10 seconds. + The goal is to populate the ARP cache. + If the target host does not respond it will be pinged again. + """ + while True: + with self._lock: + missing_arp_entries = list(self._unresolved_arp_queries_cache.keys()) + logger.info("Missing ARP entries: %s", missing_arp_entries) + for ip in missing_arp_entries: + try: + send_ping(ip) + except Exception as e: + logger.exception("Error when pinging %s: %s", ip, e) + logger.info("Finished pinging missing ARP entries. Sleeping...") + time.sleep(10) + + def _probe_addr(self, route_entry: RouteEntry) -> None: + """Probes the MAC address of a neighbor. + Pings the neighbor to trigger the update of the ARP table. + + Args: + neighbor (NeighborEntry): The neighbor entry. + """ + + self._unresolved_arp_queries_cache[ + route_entry.next_hop_ip + ] = route_entry + logger.info("Adding entry %s in arp table by pinging", route_entry) + if not validate_ipv4(route_entry.next_hop_ip): + return + send_ping(route_entry.next_hop_ip) + + def _get_gate_idx(self, route_entry: RouteEntry, module_name: str) -> int: + """Get gate index for a route module. + + If the item is cached, return the cached gate index. + If the item is new, increment the gate count + and return the new gate index. + + Args: + route_entry (RouteEntry) + module_name (str): The name of the module. + Returns: + int: The gate index. + """ + if ( + cached_entry := self._neighbor_cache.get(route_entry.next_hop_ip) + ) is not None: + return cached_entry.gate_idx + return self._module_gate_count_cache[module_name] + + def _netlink_event_listener( + self, ipdb: IPDB, netlink_message: dict, action: str + ) -> None: + """Listens for netlink events and handles them. + + Args: + ipdb (IPDB): The IPDB object. + netlink_message (dict): The netlink message. + action (str): The action. + """ + logger.info("%s netlink event received.", action) + route_entry = self._parse_route_entry_msg(netlink_message) + if action == KEY_NEW_ROUTE_ACTION and route_entry: + with self._lock: + self.add_new_route_entry(route_entry) + + elif action == KEY_DELETE_ROUTE_ACTION and route_entry: + with self._lock: + self.delete_route_entry(route_entry) + + elif action == KEY_NEW_NEIGHBOR_ACTION: + with self._lock: + self.add_unresolved_new_neighbor(netlink_message) + + def cleanup(self, number: int) -> None: + """Unregisters the netlink event listener callback and exits.""" + logger.info("Received: %i Exiting", number) + self._ipdb.unregister_callback(self._event_callback) + logger.info("Unregistered netlink event listener callback") + sys.exit() + + def reconfigure(self, number: int) -> None: + """Reconfigures the route controller. + Clears caches and bootstraps routes. + """ + logger.info("Received: %i Reconfiguring", number) + with self._lock: + self._unresolved_arp_queries_cache.clear() + self._neighbor_cache.clear() + self._module_gate_count_cache.clear() + self.bootstrap_routes() + signal.pause() + + def _parse_route_entry_msg( + self, route_entry: dict + ) -> Optional[RouteEntry]: + """Parses a route entry message. + If the entry passes the checks, it is returned as a RouteEntry object. + + Args: + route_entry (dict): A netlink route entry message. + + Returns: + RouteEntry: A route entry object. + """ + try: + attr_dict = dict(route_entry["attrs"]) + except Exception: + logger.exception("Error parsing route entry message") + return None + + if not (next_hop_ip := attr_dict.get(KEY_DESTINATION_GATEWAY_IP)): + return None + + if not attr_dict.get(KEY_INTERFACE): + return None + interface_index = int(attr_dict.get(KEY_INTERFACE)) + interface = self._ipdb.interfaces[interface_index].ifname + if interface not in self._interfaces: + return None + + dest_prefix = None + if route_entry.get(KEY_DESTINATION_PREFIX_LENGTH) == 0: + dest_prefix = "0.0.0.0" + + if attr_dict.get(KEY_DESTINATION_IP): + dest_prefix = attr_dict.get(KEY_DESTINATION_IP) + + if not dest_prefix: + return None + + return RouteEntry( + dest_prefix=dest_prefix, + next_hop_ip=next_hop_ip, + interface=interface, + prefix_len=route_entry[KEY_DESTINATION_PREFIX_LENGTH], + ) + + def get_route_module_name(self, interface_name: str) -> str: + """Returns the name of the route module. + + Args: + interface_name (str): The name of the interface. + """ + return interface_name + "Routes" + + def get_update_module_name( + self, route_module_name: str, mac_address: str + ) -> str: + """Returns the name of the update module. + + Args: + route_module_name (str): The name of the route module. + gateway_mac_hex (str): The MAC address of the gateway. + """ + return route_module_name + "DstMAC" + mac_to_hex(mac_address) + + def get_merge_module_name(self, interface_name: str) -> str: + """Returns the name of the merge module. + + Args: + interface_name (str): The name of the interface. + """ + return interface_name + "Merge" + + +def validate_ipv4(ip: str) -> bool: + """Validate the given IP address. + + Args: + ip (str): The IP address to validate.""" + try: + return isinstance(ipaddress.ip_address(ip), ipaddress.IPv4Address) + except ValueError: + logger.error( + "The IP address %s is invalid", ip + ) + return False -def bootstrap_routes(): - routes = ipr.get_routes() - for i in routes: - if i['event'] == 'RTM_NEWROUTE': - parse_new_route(i) +def send_ping(neighbor_ip): + """Send an ICMP echo request to neighbor_ip. -def connect_bessd(): - print('Connecting to BESS daemon...'), - # Connect to BESS (assuming host=localhost, port=10514 (default)) - for i in range(MAX_RETRIES): - try: - if not bess.is_connected(): - bess.connect(grpc_url=args.ip + ':' + args.port) - except BESS.RPCError: - print( - 'Error connecting to BESS daemon. Retrying in {}sec...'.format( - SLEEP_S)) - time.sleep(SLEEP_S) - else: - break - else: - raise Exception('BESS connection failure.') - - print('Done.') - - -def reconfigure(number, frame): - print('Received: {} Reloading routes'.format(number)) - # clear arpcache - for ip in list(arpcache): - item = arpcache.get(ip) - del item - arpcache.clear() - for ip in list(neighborcache): - item = neighborcache.get(ip) - del item - neighborcache.clear() - for modname in list(modgatecnt): - item = modgatecnt.get(modname) - del item - modgatecnt.clear() - bootstrap_routes() - signal.pause() + Does not wait for a response. Expected to have the side + effect of populating the arp table entry for neighbor_ip. + """ + logger.info("Sending ping to %s", neighbor_ip) + send(IP(dst=neighbor_ip) / ICMP()) -def cleanup(number, frame): - ipdb.unregister_callback(event_callback) - print('Received: {} Exiting'.format(number)) - sys.exit() +def fetch_mac(ipdb: IPDB, target_ip: str) -> Optional[str]: + """Fetches the MAC address of the target IP from the ARP table using IPDB. + Args: + ipdb (IPDB): The IPDB object. + target_ip (str): The target IP address. -def main(): - global arpcache, neighborcache, modgatecnt, ipdb, event_callback, bess, ipr - # for holding unresolved ARP queries - arpcache = {} - # for holding list of registered neighbors - neighborcache = {} - # for holding gate count per route module - modgatecnt = {} - # for interacting with kernel - ipdb = IPDB() - ipr = IPRoute() - # for bess client - bess = BESS() + Returns: + Optional[str]: The MAC address of the target IP. + """ + neighbors = ipdb.nl.get_neighbours(dst=target_ip) + for neighbor in neighbors: + attrs = dict(neighbor['attrs']) + if attrs.get(KEY_NETWORK_LAYER_DEST_ADDR, '') == target_ip: + logger.info( + "Mac address found for %s, Mac: %s", + target_ip, attrs.get(KEY_LINK_LAYER_ADDRESS, ''), + ) + return attrs.get(KEY_LINK_LAYER_ADDRESS, '') + logger.info("Mac address not found for %s", target_ip) + return None - # connect to bessd - connect_bessd() - # program current routes - bootstrap_routes() +def mac_to_int(mac: str) -> int: + """Converts a MAC address to an integer.""" + try: + return int(mac.replace(":", ""), 16) + except ValueError: + raise ValueError("Invalid MAC address: %s", mac) - # listen for netlink events - print('Registering netlink event listener callback...'), - event_callback = ipdb.register_callback(netlink_event_listener) - print('Done.') - signal.signal(signal.SIGHUP, reconfigure) - signal.signal(signal.SIGINT, cleanup) - signal.signal(signal.SIGTERM, cleanup) - signal.pause() +def mac_to_hex(mac: str) -> str: + """Converts a MAC address to a hexadecimal string.""" + return '{:012X}'.format(mac_to_int(mac)) -if __name__ == '__main__': +def parse_args() -> Tuple[List[str], str, str]: parser = argparse.ArgumentParser( - description='Basic IPv4 Routing Controller') - parser.add_argument('-i', - type=str, - nargs='+', - help='interface(s) to control') - parser.add_argument('--ip', - type=str, - default='localhost', - help='BESSD address') - parser.add_argument('--port', type=str, default='10514', help='BESSD port') - - # for holding command-line arguments - global args + description="Basic IPv4 Routing Controller" + ) + parser.add_argument( + "-i", type=str, nargs="+", help="interface(s) to control" + ) + parser.add_argument( + "--ip", type=str, default="localhost", help="BESSD address" + ) + parser.add_argument( + "--port", type=str, default="10514", help="BESSD port" + ) args = parser.parse_args() - - if args.i: - main() - # if interface list is empty, print help menu and quit - else: - print(parser.print_help()) + if not args.i: + parser.print_help() + raise ValueError("interface must be specified") + return (args.i, args.ip, args.port) + + +def register_signal_handlers(route_controller: RouteController) -> None: + """Register signal handlers for SIGHUP, SIGINT, SIGTERM. + + Args: + controller (RouteController): The route controller. + """ + logger.info("Registering signals handlers.") + signal.signal( + signal.SIGHUP, lambda number, _: route_controller.reconfigure(number) + ) + signal.signal( + signal.SIGINT, lambda number, _: route_controller.cleanup(number) + ) + signal.signal( + signal.SIGTERM, lambda number, _: route_controller.cleanup(number) + ) + + +if __name__ == "__main__": + interface_arg, ip_arg, port_arg = parse_args() + ipr = IPRoute() + ipdb = IPDB() + bess_controller = BessController(ip_arg, port_arg) + route_controller = RouteController( + bess_controller=bess_controller, + ipdb=ipdb, + ipr=ipr, + interfaces=interface_arg, + ) + route_controller.bootstrap_routes() + route_controller.register_callbacks() + route_controller.start_pinging_missing_entries() + register_signal_handlers(route_controller=route_controller) + logger.info("Sleep until a signal is received") + signal.pause() diff --git a/conf/test-requirements.txt b/conf/test-requirements.txt new file mode 100644 index 000000000..18061a9d4 --- /dev/null +++ b/conf/test-requirements.txt @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2023 Canonical Ltd. + +pyroute2 +scapy diff --git a/conf/test_route_control.py b/conf/test_route_control.py new file mode 100644 index 000000000..f3c2f1b1a --- /dev/null +++ b/conf/test_route_control.py @@ -0,0 +1,483 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2023 Canonical Ltd. + +import sys +import unittest +from unittest.mock import MagicMock, Mock, patch + +from pyroute2 import IPDB # type: ignore[import] + +sys.modules["pybess.bess"] = MagicMock() + +from conf.route_control import ( + NeighborEntry, + RouteController, + RouteEntry, + fetch_mac, + mac_to_hex, + mac_to_int, + validate_ipv4, +) + + +class BessControllerMock(object): + """Mock of BessController to avoid using BESS from pybess.bess""" + + def __init__(self): + pass + + def _get_bess(self, *args, **kwargs) -> None: + pass + + def add_route_to_module(self, *args, **kwargs) -> None: + pass + + def delete_module_route_entry(self, *args, **kwargs) -> None: + pass + + def create_module(self, *args, **kwargs) -> None: + pass + + def delete_module(self, *args, **kwargs) -> None: + pass + + def link_modules(self, *args, **kwargs) -> None: + pass + + +@patch("conf.route_control.BessController", BessControllerMock) +class TestUtilityFunctions(unittest.TestCase): + """Tests utility functions in route_control.py.""" + + def test_given_valid_ip_when_validate_ipv4_then_returns_true(self): + self.assertTrue(validate_ipv4("192.168.1.1")) + + def test_given_invalid_ip_when_validate_ipv4_then_returns_false(self): + self.assertFalse(validate_ipv4("192.168.300.1")) + + def test_given_invalid_ip_when_validate_ipv4_then_returns_false(self): + self.assertFalse(validate_ipv4("::1")) + self.assertFalse(validate_ipv4("")) + + def test_given_valid_mac_when_mac_to_int_then_returns_int_representation( + self + ): + self.assertEqual(mac_to_int("00:1a:2b:3c:4d:5e"), 112394521950) + + def test_given_invalid_mac_when_mac_to_int_then_raises_exception(self): + with self.assertRaises(ValueError): + mac_to_int("not a mac") + + def test_given_valid_mac_when_mac_to_hex_then_return_hex_string_representation( + self + ): + self.assertEqual(mac_to_hex("00:1a:2b:3c:4d:5e"), "001A2B3C4D5E") + + def test_given_known_destination_when_fetch_mac_then_returns_mac(self): + ipdb = IPDB() + ipdb.nl.get_neighbours = lambda dst, **kwargs: [ + {"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]} + ] + self.assertEqual(fetch_mac(ipdb, "192.168.1.1"), "00:1a:2b:3c:4d:5e") + + def test_given_unkonw_destination_when_fetch_mac_then_returns_none(self): + ipdb = IPDB() + ipdb.nl.get_neighbours = lambda dst, **kwargs: [] + self.assertIsNone(fetch_mac(ipdb, "192.168.1.1")) + + +class TestRouteController(unittest.TestCase): + def setUp(self): + self.mock_bess_controller = Mock(BessControllerMock) + self.ipdb = Mock() + self.ipr = Mock() + interfaces = ['access', 'core'] + self.route_controller = RouteController( + self.mock_bess_controller, + self.ipdb, + interfaces=interfaces, + ipr=self.ipr, + ) + + @patch("conf.route_control.fetch_mac") + @patch.object(RouteController, "get_merge_module_name") + @patch.object(RouteController, "get_route_module_name") + @patch.object(RouteController, "get_update_module_name") + def add_route_entry( + self, + route_entry, + mock_get_update_module_name, + mock_get_route_module_name, + mock_get_merge_module_name, + mock_fetch_mac, + ) -> None: + """Adds a new route entry using the route controller.""" + self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [ + {"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]} + ] + mock_get_update_module_name.return_value = "merge_module" + mock_get_route_module_name.return_value = "route_module" + mock_get_merge_module_name.return_value = "update_module" + mock_fetch_mac.return_value = "00:1a:2b:3c:4d:5e" + self.route_controller.add_new_route_entry(route_entry) + return route_entry + + def test_given_valid_route_message_when_parse_message_then_parses_message( + self + ): + self.ipdb.interfaces = {2: Mock(ifname='core')} + example_route_entry = { + "family": 2, + "dst_len": 24, + "flags": 0, + "attrs": [ + ("RTA_TABLE", 254), + ("RTA_PRIORITY", 100), + ("RTA_PREFSRC", "172.31.55.52"), + ("RTA_GATEWAY", "172.31.48.1"), + ("RTA_OIF", 2), + ("RTA_DST", "192.168.1.0"), + ], + "header": { + "length": 68, + "type": 24, + "target": "localhost", + "stats": {"qsize": 0, "delta": 0, "delay": 0}, + }, + "event": "RTM_NEWROUTE", + } + result = self.route_controller._parse_route_entry_msg( + example_route_entry + ) + self.assertIsInstance(result, RouteEntry) + self.assertEqual(result.dest_prefix, "192.168.1.0") + self.assertEqual(result.next_hop_ip, "172.31.48.1") + self.assertEqual(result.interface, self.ipdb.interfaces[2].ifname) + self.assertEqual(result.prefix_len, 24) + + def test_given_valid_route_message_and_dst_len_is_zero_when_parse_message_then_parses_message_as_default_route( + self, + ): + self.ipdb.interfaces = {2: Mock(ifname='core')} + example_route_entry = { + "family": 2, + "dst_len": 0, + "flags": 0, + "attrs": [ + ("RTA_TABLE", 254), + ("RTA_PRIORITY", 100), + ("RTA_PREFSRC", "172.31.55.52"), + ("RTA_GATEWAY", "172.31.48.1"), + ("RTA_OIF", 2), + ], + "header": { + "length": 68, + "type": 24, + "target": "localhost", + "stats": {"qsize": 0, "delta": 0, "delay": 0}, + }, + "event": "RTM_NEWROUTE", + } + result = self.route_controller._parse_route_entry_msg( + example_route_entry + ) + self.assertIsInstance(result, RouteEntry) + self.assertEqual(result.dest_prefix, "0.0.0.0") + self.assertEqual(result.next_hop_ip, "172.31.48.1") + self.assertEqual( + result.interface, self.ipdb.interfaces[2].ifname + ) + self.assertEqual(result.prefix_len, 0) + + def test_given_invalid_route_message_when_parse_message_then_returns_none( + self + ): + self.ipdb.interfaces = {2: Mock(ifname='not the needed interface')} + example_route_entry = { + "family": 2, + "flags": 0, + "attrs": [ + ("RTA_TABLE", 254), + ("RTA_PRIORITY", 100), + ("RTA_PREFSRC", "172.31.55.52"), + ("RTA_GATEWAY", "172.31.48.1"), + ("RTA_OIF", 2), + ], + "header": { + "length": 68, + "type": 24, + "target": "localhost", + "stats": {"qsize": 0, "delta": 0, "delay": 0}, + }, + "event": "RTM_NEWROUTE", + } + result = self.route_controller._parse_route_entry_msg( + example_route_entry + ) + self.assertIsNone(result) + + @patch("conf.route_control.send_ping") + def test_given_new_route_when_add_new_route_entry_and_mac_not_known_then_destination_is_pinged( + self, + mock_send_ping, + ): + self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [] + route_entry = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + self.route_controller.add_new_route_entry(route_entry) + mock_send_ping.assert_called_once() + + def test_given_valid_new_route_when_add_new_route_entry_and_mac_known_then_route_is_added_in_bess( + self, + ): + self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [ + {"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]} + ] + mock_routes = [ + {"event": "RTM_NEWROUTE"}, + {"event": "OTHER_ACTION"} + ] + self.ipr.get_routes.return_value = mock_routes + route_entry = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + self.route_controller.add_new_route_entry(route_entry) + self.mock_bess_controller.add_route_to_module.assert_called_once() + + def test_given_valid_new_route_when_add_new_route_entry_and_mac_known_and_neighbor_not_known_then_update_module_is_created_and_modules_are_linked( + self, + ): + self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [ + {"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]} + ] + mock_routes = [ + {"event": "RTM_NEWROUTE"}, + {"event": "OTHER_ACTION"} + ] + self.ipr.get_routes.return_value = mock_routes + route_entry = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + self.route_controller.add_new_route_entry(route_entry) + self.mock_bess_controller.create_module.assert_called() + self.mock_bess_controller.link_modules.assert_called() + + @patch.object(RouteController, "add_new_route_entry") + def test_given_new_route_when_bootstrap_routes_then_add_new_entry_is_called( + self, + mock_add_new_route_entry, + ): + mock_routes = [ + { + "event": "RTM_NEWROUTE", + "attrs": { + "RTA_OIF": 2, + "RTA_GATEWAY": "1.2.3.4", + "RTA_DST": "1.1.1.1", + }, + "dst_len": 24, + }, + {"event": "OTHER_ACTION"} + ] + self.ipr.get_routes.return_value = mock_routes + self.ipdb.interfaces = {2: Mock(ifname='core')} + valid_route_entry = RouteEntry( + next_hop_ip="1.2.3.4", + interface="core", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + self.ipr.get_routes.return_value = mock_routes + self.route_controller.bootstrap_routes() + self.ipr.get_routes.assert_called_once() + mock_add_new_route_entry.assert_called_with(valid_route_entry) + + @patch.object(RouteController, "add_new_route_entry") + def test_given_no_new_route_when_bootstrap_routes_then_add_new_entry_is_not_called( + self, + mock_add_new_route_entry, + ): + mock_routes = [ + { + "event": "Not a new route", + "attrs": { + "RTA_OIF": 2, + "RTA_GATEWAY": "1.2.3.4", + "RTA_DST": "1.1.1.1", + }, + "dst_len": 24, + }, + {"event": "OTHER_ACTION"} + ] + self.ipr.get_routes.return_value = mock_routes + self.route_controller._parse_route_entry_msg = Mock() + self.route_controller.bootstrap_routes() + self.ipr.get_routes.assert_called_once() + mock_add_new_route_entry.assert_not_called() + + @patch.object(RouteController, "add_new_route_entry") + def test_given_new_route_and_invalid_message_when_bootstrap_routes_then_add_new_entry_is_not_called( + self, + mock_add_new_route_entry, + ): + mock_routes = [ + { + "event": "RTM_NEWROUTE", + "attrs": {}, + }, + {"event": "OTHER_ACTION"} + ] + self.ipr.get_routes.return_value = mock_routes + self.route_controller.bootstrap_routes() + + self.ipr.get_routes.assert_called_once() + mock_add_new_route_entry.assert_not_called() + + @patch.object(RouteController, "add_new_route_entry") + def test_given_netlink_message_when_rtm_newroute_event_then_add_new_route_entry_is_called( + self, mock_add_new_route_entry + ): + self.ipdb.interfaces = {2: Mock(ifname='core')} + example_route_entry = { + "family": 2, + "dst_len": 24, + "flags": 0, + "attrs": [ + ("RTA_TABLE", 254), + ("RTA_PRIORITY", 100), + ("RTA_PREFSRC", "172.31.55.52"), + ("RTA_GATEWAY", "172.31.48.1"), + ("RTA_OIF", 2), + ("RTA_DST", "192.168.1.0"), + ], + "header": { + "length": 68, + "type": 24, + "target": "localhost", + "stats": {"qsize": 0, "delta": 0, "delay": 0}, + }, + "event": "RTM_NEWROUTE", + } + self.route_controller._netlink_event_listener( + self.ipdb, example_route_entry, "RTM_NEWROUTE" + ) + mock_add_new_route_entry.assert_called() + + def test_given_existing_neighbor_and_route_count_not_zero_when_delete_route_entry_then_route_entry_deleted_in_bess(self): + route_entry = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + self.add_route_entry(route_entry) + self.route_controller.delete_route_entry(route_entry) + self.mock_bess_controller.delete_module_route_entry.assert_called_once() + + def test_given_existing_neighbor_and_route_count_greater_than_one_when_delete_route_entry_then_module_not_deleted(self): + route_entry_1 = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + route_entry_2 = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface_2", + dest_prefix="1.1.1.2", + prefix_len=24, + ) + self.add_route_entry(route_entry_1) + self.add_route_entry(route_entry_2) + self.route_controller.delete_route_entry(route_entry_1) + self.mock_bess_controller.delete_module.assert_not_called() + + def test_given_existing_neighbor_and_route_count_is_one_when_delete_route_entry_then_module_deleted(self): + route_entry = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + self.add_route_entry(route_entry) + self.route_controller.delete_route_entry(route_entry) + self.mock_bess_controller.delete_module.assert_called_once() + + @patch.object(RouteController, "delete_route_entry") + def test_given_netlink_message_when_rtm_delroute_event_then_delete_route_entry_is_called( + self, mock_delete_route_entry + ): + self.ipdb.interfaces = {2: Mock(ifname='core')} + example_route_entry = { + "family": 2, + "dst_len": 24, + "flags": 0, + "attrs": [ + ("RTA_TABLE", 254), + ("RTA_PRIORITY", 100), + ("RTA_PREFSRC", "172.31.55.52"), + ("RTA_GATEWAY", "172.31.48.1"), + ("RTA_OIF", 2), + ("RTA_DST", "192.168.1.0"), + ], + "header": { + "length": 68, + "type": 24, + "target": "localhost", + "stats": {"qsize": 0, "delta": 0, "delay": 0}, + }, + "event": "RTM_DELROUTE", + } + self.route_controller._netlink_event_listener( + self.ipdb, example_route_entry, "RTM_DELROUTE" + ) + mock_delete_route_entry.assert_called() + + @patch("conf.route_control.send_ping") + def test_given_new_neighbor_in_unresolved_when_add_unresolved_new_neighbor_then_route_added_in_bess( + self, + _, + ): + self.ipdb.nl.get_neighbours = lambda dst, **kwargs: [ + {"attrs": [("NDA_DST", dst), ("NDA_LLADDR", "00:1a:2b:3c:4d:5e")]} + ] + mock_netlink_msg = { + "attrs": { + "NDA_DST": "1.2.3.4", + "NDA_LLADDR": "00:1a:2b:3c:4d:5e", + } + } + mock_routes = [ + {"event": "RTM_NEWROUTE"}, + {"event": "OTHER_ACTION"} + ] + self.ipr.get_routes.return_value = mock_routes + route_entry = RouteEntry( + next_hop_ip="1.2.3.4", + interface="random_interface", + dest_prefix="1.1.1.1", + prefix_len=24, + ) + self.route_controller.add_new_route_entry(route_entry) + self.route_controller.add_unresolved_new_neighbor(mock_netlink_msg) + self.mock_bess_controller.add_route_to_module.assert_called_once() + + @patch.object(RouteController, "add_unresolved_new_neighbor") + def test_given_netlink_message_when_rtm_newneigh_event_then_add_unresolved_new_neighbor_is_called( + self, mock_add_unresolved_new_neighbor + ): + self.route_controller._netlink_event_listener( + self.ipdb, "new neighbour message", "RTM_NEWNEIGH" + ) + mock_add_unresolved_new_neighbor.assert_called()