Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Virtual ports on multiple incoming connections #224

Merged
merged 2 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 41 additions & 7 deletions src/lava/magma/compiler/builders/builder.py
Expand Up @@ -39,7 +39,9 @@
PyInPort,
PyOutPort,
PyRefPort,
PyVarPort
PyVarPort,
VirtualPortTransformer,
IdentityTransformer
)
from lava.magma.compiler.channels.interfaces import AbstractCspPort, Channel, \
ChannelType
Expand Down Expand Up @@ -169,6 +171,7 @@ def __init__(
self.ref_ports: ty.Dict[str, PortInitializer] = {}
self.var_ports: ty.Dict[str, VarPortInitializer] = {}
self.csp_ports: ty.Dict[str, ty.List[AbstractCspPort]] = {}
self._csp_port_map: ty.Dict[str, ty.Dict[str, AbstractCspPort]] = {}
self.csp_rs_send_port: ty.Dict[str, CspSendPort] = {}
self.csp_rs_recv_port: ty.Dict[str, CspRecvPort] = {}
self.proc_params = proc_params
Expand Down Expand Up @@ -308,6 +311,24 @@ def set_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
else:
self.csp_ports[port_name] = new_ports[port_name]

def add_csp_port_mapping(self, py_port_id: str, csp_port: AbstractCspPort):
"""Appends a mapping from a PyPort ID to a CSP port. This is used
to associate a CSP port in a PyPort with transformation functions
that implement the behavior of virtual ports.

Parameters
----------
py_port_id : str
ID of the PyPort that contains the CSP on the other side of the
channel of 'csp_port'
csp_port : AbstractCspPort
a CSP port
"""
# Add or update the mapping
self._csp_port_map.setdefault(
csp_port.name, {}
).update({py_port_id: csp_port})

def set_rs_csp_ports(self, csp_ports: ty.List[AbstractCspPort]):
"""Set RS CSP Ports

Expand Down Expand Up @@ -371,11 +392,13 @@ def build(self):
if not isinstance(csp_ports, list):
csp_ports = [csp_ports]

# TODO (MR): This is probably just a temporary hack until the
# interface of PyOutPorts has been adjusted.
if issubclass(port_cls, PyInPort):
port = port_cls(csp_ports, pm, p.shape, lt.d_type,
p.transform_funcs)
transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()
port_cls = ty.cast(ty.Type[PyInPort], lt.cls)
port = port_cls(csp_ports, pm, p.shape, lt.d_type, transformer)
elif issubclass(port_cls, PyOutPort):
port = port_cls(csp_ports, pm, p.shape, lt.d_type)
else:
Expand All @@ -401,8 +424,13 @@ def build(self):
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]

transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()

port = port_cls(csp_send, csp_recv, pm, p.shape, lt.d_type,
p.transform_funcs)
transformer)

# Create dynamic RefPort attribute on ProcModel
setattr(pm, name, port)
Expand All @@ -422,9 +450,15 @@ def build(self):
csp_ports[0], CspRecvPort) else csp_ports[1]
csp_send = csp_ports[0] if isinstance(
csp_ports[0], CspSendPort) else csp_ports[1]

transformer = VirtualPortTransformer(
self._csp_port_map[name],
p.transform_funcs
) if p.transform_funcs else IdentityTransformer()

port = port_cls(
p.var_name, csp_send, csp_recv, pm, p.shape, p.d_type,
p.transform_funcs)
transformer)

# Create dynamic VarPort attribute on ProcModel
setattr(pm, name, port)
Expand Down
15 changes: 5 additions & 10 deletions src/lava/magma/compiler/compiler.py
Expand Up @@ -352,13 +352,11 @@ def _compile_proc_models(
for pt in (list(p.in_ports) + list(p.out_ports)):
# For all InPorts that receive input from
# virtual ports...
transform_funcs = []
transform_funcs = None
if isinstance(pt, InPort):
# ... extract a function pointer to the
# transformation function of each virtual port.
transform_funcs = \
[vp.get_transform_func_fwd()
for vp in pt.get_incoming_virtual_ports()]
transform_funcs = pt.get_incoming_transform_funcs()

pi = PortInitializer(pt.name,
pt.shape,
Expand All @@ -371,9 +369,7 @@ def _compile_proc_models(
# Create RefPort (also use PortInitializers)
ref_ports = []
for pt in list(p.ref_ports):
transform_funcs = \
[vp.get_transform_func_bwd()
for vp in pt.get_outgoing_virtual_ports()]
transform_funcs = pt.get_outgoing_transform_funcs()

pi = PortInitializer(pt.name,
pt.shape,
Expand All @@ -386,9 +382,8 @@ def _compile_proc_models(
# Create VarPortInitializers (contain also the Var name)
var_ports = []
for pt in list(p.var_ports):
transform_funcs = \
[vp.get_transform_func_fwd()
for vp in pt.get_incoming_virtual_ports()]
transform_funcs = pt.get_incoming_transform_funcs()

pi = VarPortInitializer(
pt.name,
pt.shape,
Expand Down
4 changes: 2 additions & 2 deletions src/lava/magma/compiler/utils.py
Expand Up @@ -18,7 +18,7 @@ class PortInitializer:
d_type: type
port_type: str
size: int
transform_funcs: ty.List[ft.partial] = None
transform_funcs: ty.Dict[str, ty.List[ft.partial]] = None


# check if can be a subclass of PortInitializer
Expand All @@ -31,4 +31,4 @@ class VarPortInitializer:
port_type: str
size: int
port_cls: type
transform_funcs: ty.List[ft.partial] = None
transform_funcs: ty.Dict[str, ty.List[ft.partial]] = None