Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Virtual ports on multiple incoming connections #224

Merged
merged 2 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 26 additions & 18 deletions src/lava/magma/compiler/builders/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
PyInPort,
PyOutPort,
PyRefPort,
PyVarPort
PyVarPort,
VirtualPortTransformer
)
from lava.magma.compiler.channels.interfaces import AbstractCspPort, Channel, \
ChannelType
Expand Down Expand Up @@ -168,7 +169,8 @@ def __init__(
self.py_ports: ty.Dict[str, PortInitializer] = {}
self.ref_ports: ty.Dict[str, PortInitializer] = {}
self.var_ports: ty.Dict[str, VarPortInitializer] = {}
self.csp_ports: ty.Dict[str, ty.List[AbstractCspPort]] = {}
self.csp_ports: ty.Dict[str,
ty.Dict[str, AbstractCspPort]] = {}
self.csp_rs_send_port: ty.Dict[str, CspSendPort] = {}
self.csp_rs_recv_port: ty.Dict[str, CspRecvPort] = {}
self.proc_params = proc_params
Expand Down Expand Up @@ -276,24 +278,30 @@ def set_var_ports(self, var_ports: ty.List[VarPortInitializer]):
self._check_not_assigned_yet(self.var_ports, new_ports.keys(), "ports")
self.var_ports.update(new_ports)

def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
"""Appends the given list of CspPorts to the ProcessModel. Used by the
runtime to configure csp ports during initialization (_build_channels).
def set_csp_ports(self, csp_ports: ty.Dict[str, AbstractCspPort]):
"""Appends the given dictionary of CspPorts to the ProcessModel.
Used by the runtime to configure csp ports during initialization
(_build_channels).

Parameters
----------
csp_ports : ty.List[AbstractCspPort]

csp_ports : ty.Dict[str, AbstractCspPort]
dictionary that associates an ID of the source/destination Port
mathisrichter marked this conversation as resolved.
Show resolved Hide resolved
(constructed from the name of the Process and the name of the Port)
with a CspPort

Raises
------
AssertionError
PyProcessModel has no port of that name
"""
# Create a new dict that maps the name of the port to another dict.
# This in turn maps a string-based ID of the PyPort on the other end
# of the channel to the CSP port: {connected_port_id: csp_port}.
new_ports = {}
for p in csp_ports:
new_ports.setdefault(p.name, []).extend(
p if isinstance(p, list) else [p]
for connected_port_id, port in csp_ports.items():
new_ports.setdefault(port.name, {}).update(
{connected_port_id: port}
)

# Check that there's a PyPort for each new CspPort
Expand All @@ -304,7 +312,7 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
no port named '{}'.".format(proc_name, port_name))

if port_name in self.csp_ports:
self.csp_ports[port_name].extend(new_ports[port_name])
self.csp_ports[port_name].update(new_ports[port_name])
else:
self.csp_ports[port_name] = new_ports[port_name]

Expand Down Expand Up @@ -365,17 +373,17 @@ def build(self):
# Build PyPort
lt = self._get_lava_type(name)
port_cls = ty.cast(ty.Type[AbstractPyIOPort], lt.cls)

csp_ports = []
if name in self.csp_ports:
csp_ports = self.csp_ports[name]
if not isinstance(csp_ports, list):
csp_ports = [csp_ports]
csp_ports = list(self.csp_ports[name].values())

# TODO (MR): This is probably just a temporary hack until the
# interface of PyOutPorts has been adjusted.
if issubclass(port_cls, PyInPort):
port = port_cls(csp_ports, pm, p.shape, lt.d_type,
p.transform_funcs)
transformer = VirtualPortTransformer(
self.csp_ports[name],
p.transform_funcs)
port_cls = ty.cast(ty.Type[PyInPort], lt.cls)
port = port_cls(csp_ports, pm, p.shape, lt.d_type, transformer)
elif issubclass(port_cls, PyOutPort):
port = port_cls(csp_ports, pm, p.shape, lt.d_type)
else:
Expand Down
6 changes: 2 additions & 4 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,11 @@ def _compile_proc_models(
for pt in (list(p.in_ports) + list(p.out_ports)):
# For all InPorts that receive input from
# virtual ports...
transform_funcs = []
transform_funcs = None
if isinstance(pt, InPort):
# ... extract a function pointer to the
# transformation function of each virtual port.
transform_funcs = \
[vp.get_transform_func_fwd()
for vp in pt.get_incoming_virtual_ports()]
transform_funcs = pt.get_incoming_transform_funcs()

pi = PortInitializer(pt.name,
pt.shape,
Expand Down
2 changes: 1 addition & 1 deletion src/lava/magma/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class PortInitializer:
d_type: type
port_type: str
size: int
transform_funcs: ty.List[ft.partial] = None
transform_funcs: ty.Dict[str, ty.List[ft.partial]] = None


# check if can be a subclass of PortInitializer
Expand Down
100 changes: 76 additions & 24 deletions src/lava/magma/core/model/py/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,75 @@ def csp_ports(self) -> ty.List[AbstractCspPort]:
return self._csp_ports


class Transformer(ty.Protocol):
def transform(self,
data: np.ndarray,
csp_port: AbstractCspPort) -> np.ndarray:
...


class IdentityTransformer:
def transform(self,
data: np.ndarray,
_: AbstractCspPort) -> np.ndarray:
return data


class VirtualPortTransformer:
def __init__(self,
csp_ports: ty.Dict[str, AbstractCspPort],
transform_funcs: ty.Dict[str, ty.List[ft.partial]]):
self._csp_port_to_fp = {}

for port_id, csp_port in csp_ports.items():
if port_id not in transform_funcs:
raise AssertionError(
f"no transformation functions found for port "
f"id {port_id}")
self._csp_port_to_fp[csp_port] = transform_funcs[port_id]

def transform(self,
data: np.ndarray,
csp_port: AbstractCspPort) -> np.ndarray:
return self._get_transform(csp_port)(data)

def _get_transform(self,
csp_port: AbstractCspPort) -> ty.Callable[[np.ndarray],
np.ndarray]:
"""For a given CSP port, returns a function that applies, in sequence,
all the function pointers associated with the incoming virtual
ports.

Example:
Let the current PyPort be called C. It receives input from
PyPorts A and B, and the connection from A to C goes through a
sequence of virtual ports V1, V2, V3. Within PyPort C, there is a CSP
port 'csp_port_a', that receives data from a CSP port in PyPort A.
Then, the following call
>>> csp_port_a : AbstractCspPort
>>> data : np.ndarray
>>> self._get_transform(csp_port_a)(data)
takes the data 'data' and applies the function pointers associated
with V1, V2, and V3.

Parameters
----------
csp_port : AbstractCspPort
the CSP port on which the data is received, which is supposed
to be transformed

Returns
-------
transformation_function : ty.Callable
function that transforms a given numpy array, e.g. by calling the
returned function f(data)
"""
return ft.reduce(
lambda f, g: lambda data: g(f(data)),
self._csp_port_to_fp[csp_port]
)


class PyInPort(AbstractPyIOPort):
"""Python implementation of InPort used within AbstractPyProcessModel.

Expand Down Expand Up @@ -143,9 +212,9 @@ def __init__(self,
process_model: AbstractProcessModel,
shape: ty.Tuple[int, ...],
d_type: type,
transform_funcs: ty.Optional[ty.List[ft.partial]] = None):
transformer: Transformer = IdentityTransformer()):

self._transform_funcs = transform_funcs
self._transformer = transformer
super().__init__(csp_ports, process_model, shape, d_type)

@abstractmethod
Expand Down Expand Up @@ -193,25 +262,6 @@ def probe(self) -> bool:
True,
)

def _transform(self, recv_data: np.array) -> np.array:
"""Applies all transformation function pointers to the input data.

Parameters
----------
recv_data : numpy.ndarray
data received on the port that shall be transformed

Returns
-------
recv_data : numpy.ndarray
received data, transformed by the incoming virtual ports
"""
if self._transform_funcs:
# apply all transformation functions to the received data
for f in self._transform_funcs:
recv_data = f(recv_data)
return recv_data


class PyInPortVectorDense(PyInPort):
"""Python implementation of PyInPort for dense vector data."""
Expand All @@ -229,9 +279,10 @@ def recv(self) -> np.ndarray:
fashion.
"""
return ft.reduce(
lambda acc, csp_port: acc + self._transform(csp_port.recv()),
lambda acc, port: acc + self._transformer.transform(port.recv(),
port),
self.csp_ports,
np.zeros(self._shape, self._d_type),
np.zeros(self._shape, self._d_type)
)

def peek(self) -> np.ndarray:
Expand All @@ -246,7 +297,8 @@ def peek(self) -> np.ndarray:
fashion.
"""
return ft.reduce(
lambda acc, csp_port: acc + csp_port.peek(),
lambda acc, port: acc + self._transformer.transform(port.recv(),
port),
self.csp_ports,
np.zeros(self._shape, self._d_type),
)
Expand Down
74 changes: 55 additions & 19 deletions src/lava/magma/core/process/ports/ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ def is_disjoint(a: ty.List, b: ty.List):
return set(a).isdisjoint(set(b))


def create_port_id(proc_name: str, port_name: str) -> str:
"""Generates a string-based ID for a port that makes it identifiable
within a network of Processes.

Parameters
----------
proc_name : str
name of the Process that the Port is associated with
port_name : str
name of the Port

Returns
-------
port_id : str
ID of a port
"""
return proc_name + "." + port_name
mathisrichter marked this conversation as resolved.
Show resolved Hide resolved


class AbstractPort(AbstractProcessMember):
"""Abstract base class for any type of port of a Lava Process.

Expand Down Expand Up @@ -145,7 +164,25 @@ def get_src_ports(self, _include_self=False) -> ty.List["AbstractPort"]:
ports += p.get_src_ports(True)
return ports

def get_incoming_virtual_ports(self) -> ty.List["AbstractVirtualPort"]:
def get_incoming_transform_funcs(self) -> ty.Dict[str, ty.List[ft.partial]]:
"""Returns the list of all incoming transformation functions for the
list of all incoming connections.

Returns
-------
transform_funcs : list(list(functools.partial))
the list of all incoming transformation functions, sorted from
source to destination port, for all incoming connections
"""
transform_funcs = {}
for p in self.in_connections:
src_port_id, vps = p.get_incoming_virtual_ports()
transform_funcs[src_port_id] = \
[vp.get_transform_func_fwd() for vp in vps]
return transform_funcs

def get_incoming_virtual_ports(self) \
-> ty.Tuple[str, ty.List["AbstractVirtualPort"]]:
"""Returns the list of all incoming virtual ports in order from
source to the current port.

Expand All @@ -156,27 +193,26 @@ def get_incoming_virtual_ports(self) -> ty.List["AbstractVirtualPort"]:
destination port
"""
if len(self.in_connections) == 0:
return []
src_port_id = create_port_id(self.process.name, self.name)
return src_port_id, []
else:
virtual_ports = []
num_virtual_ports = 0
src_port_id = None
for p in self.in_connections:
virtual_ports += p.get_incoming_virtual_ports()
if isinstance(p, AbstractVirtualPort):
# TODO (MR): ConcatPorts are not yet supported by the
# compiler - until then, an exception is raised.
if isinstance(p, ConcatPort):
raise NotImplementedError("ConcatPorts are not yet "
"supported.")

virtual_ports.append(p)
num_virtual_ports += 1

if num_virtual_ports > 1:
raise NotImplementedError("Joining multiple virtual ports is "
"not yet supported.")

return virtual_ports
p_id, vps = p.get_incoming_virtual_ports()
virtual_ports += vps
if p_id:
src_port_id = p_id

if isinstance(self, AbstractVirtualPort):
# TODO (MR): ConcatPorts are not yet supported by the
# compiler - until then, an exception is raised.
if isinstance(self, ConcatPort):
raise NotImplementedError("ConcatPorts are not yet "
"supported.")
virtual_ports.append(self)

return src_port_id, virtual_ports

def get_outgoing_virtual_ports(self) -> ty.List["AbstractVirtualPort"]:
"""Returns the list of all outgoing virtual ports in order from
Expand Down
14 changes: 12 additions & 2 deletions src/lava/magma/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from lava.magma.core.run_conditions import RunSteps, RunContinuous
from lava.magma.compiler.executable import Executable
from lava.magma.compiler.node import NodeConfig
from lava.magma.core.process.ports.ports import create_port_id
from lava.magma.core.run_conditions import AbstractRunCondition

"""Defines a Runtime which takes a lava executable and a pluggable message
Expand Down Expand Up @@ -184,12 +185,21 @@ def _build_channels(self):
channel = channel_builder.build(
self._messaging_infrastructure
)

src_port_id = create_port_id(
channel_builder.src_process.name,
channel_builder.src_port_initializer.name)

dst_port_id = create_port_id(
channel_builder.dst_process.name,
channel_builder.dst_port_initializer.name)

self._get_process_builder_for_process(
channel_builder.src_process).set_csp_ports(
[channel.src_port])
{dst_port_id: channel.src_port})
self._get_process_builder_for_process(
channel_builder.dst_process).set_csp_ports(
[channel.dst_port])
{src_port_id: channel.dst_port})

def _build_sync_channels(self):
"""Builds the channels needed for synchronization between runtime
Expand Down
2 changes: 1 addition & 1 deletion tests/lava/magma/core/model/py/test_ports.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def probe_test_routine(self, cls):
# Create PyInPort with current implementation
recv_py_port: PyInPort = \
cls([recv_csp_port_1, recv_csp_port_2], None, data.shape,
data.dtype, None)
data.dtype)

recv_py_port.start()
send_py_port_1.start()
Expand Down