Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduced the number of channels between service and process (#1) #157

Merged
merged 4 commits into from
Jan 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 8 additions & 24 deletions src/lava/magma/compiler/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def check_lava_py_types(self):

Any Py{In/Out/Ref}Ports must be strict sub-types of Py{In/Out/Ref}Ports.
"""

for name, port_init in self.py_ports.items():
lt = self._get_lava_type(name)
if not isinstance(lt.cls, type):
Expand Down Expand Up @@ -421,22 +420,13 @@ def build(self):
setattr(pm, name, port)

for port in self.csp_rs_recv_port.values():
if "service_to_process_cmd" in port.name:
pm.service_to_process_cmd = port
continue
if "service_to_process_req" in port.name:
pm.service_to_process_req = port
continue
if "service_to_process_data" in port.name:
pm.service_to_process_data = port
if "service_to_process" in port.name:
pm.service_to_process = port
continue

for port in self.csp_rs_send_port.values():
if "process_to_service_ack" in port.name:
pm.process_to_service_ack = port
continue
if "process_to_service_data" in port.name:
pm.process_to_service_data = port
if "process_to_service" in port.name:
pm.process_to_service = port
continue

# Initialize Vars
Expand Down Expand Up @@ -535,18 +525,12 @@ def build(self) -> PyRuntimeService:
rs.model_ids = self._model_ids

for port in self.csp_proc_send_port.values():
if "service_to_process_cmd" in port.name:
rs.service_to_process_cmd.append(port)
if "service_to_process_req" in port.name:
rs.service_to_process_req.append(port)
if "service_to_process_data" in port.name:
rs.service_to_process_data.append(port)
if "service_to_process" in port.name:
rs.service_to_process.append(port)

for port in self.csp_proc_recv_port.values():
if "process_to_service_ack" in port.name:
rs.process_to_service_ack.append(port)
if "process_to_service_data" in port.name:
rs.process_to_service_data.append(port)
if "process_to_service" in port.name:
rs.process_to_service.append(port)

for port in self.csp_send_port.values():
if "service_to_runtime_ack" in port.name:
Expand Down
39 changes: 6 additions & 33 deletions src/lava/magma/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,50 +768,23 @@ def _create_sync_channel_builders(
sync_channel_builders.append(runtime_to_service_data)

for process in sync_domain.processes:
service_to_process_cmd = \
service_to_process = \
ServiceChannelBuilderMp(ChannelType.PyPy,
rsb[sync_domain],
process,
self._create_mgmt_port_initializer(
f"service_to_process_cmd_"
f"service_to_process_"
f"{process.id}"))
sync_channel_builders.append(service_to_process_cmd)
sync_channel_builders.append(service_to_process)

process_to_service_ack = \
process_to_service = \
ServiceChannelBuilderMp(ChannelType.PyPy,
process,
rsb[sync_domain],
self._create_mgmt_port_initializer(
f"process_to_service_ack_"
f"process_to_service_"
f"{process.id}"))
sync_channel_builders.append(process_to_service_ack)

service_to_process_req = \
ServiceChannelBuilderMp(ChannelType.PyPy,
rsb[sync_domain],
process,
self._create_mgmt_port_initializer(
f"service_to_process_req_"
f"{process.id}"))
sync_channel_builders.append(service_to_process_req)

process_to_service_data = \
ServiceChannelBuilderMp(ChannelType.PyPy,
process,
rsb[sync_domain],
self._create_mgmt_port_initializer(
f"process_to_service_data_"
f"{process.id}"))
sync_channel_builders.append(process_to_service_data)

service_to_process_data = \
ServiceChannelBuilderMp(ChannelType.PyPy,
rsb[sync_domain],
process,
self._create_mgmt_port_initializer(
f"service_to_process_data_"
f"{process.id}"))
sync_channel_builders.append(service_to_process_data)
sync_channel_builders.append(process_to_service)
return sync_channel_builders

def _create_exec_vars(self,
Expand Down
107 changes: 47 additions & 60 deletions src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@

import numpy as np

from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort,\
from lava.magma.compiler.channels.pypychannel import CspSendPort, CspRecvPort, \
CspSelector
from lava.magma.core.model.model import AbstractProcessModel
from lava.magma.core.model.py.ports import AbstractPyPort, PyVarPort
from lava.magma.runtime.mgmt_token_enums import (
enum_to_np,
enum_equal,
MGMT_COMMAND,
MGMT_RESPONSE, REQ_TYPE,
)
MGMT_RESPONSE, )


class AbstractPyProcessModel(AbstractProcessModel, ABC):
Expand All @@ -33,11 +32,8 @@ class AbstractPyProcessModel(AbstractProcessModel, ABC):
def __init__(self):
super().__init__()
self.model_id: ty.Optional[int] = None
self.service_to_process_cmd: ty.Optional[CspRecvPort] = None
self.process_to_service_ack: ty.Optional[CspSendPort] = None
self.service_to_process_req: ty.Optional[CspRecvPort] = None
self.process_to_service_data: ty.Optional[CspSendPort] = None
self.service_to_process_data: ty.Optional[CspRecvPort] = None
self.service_to_process: ty.Optional[CspRecvPort] = None
self.process_to_service: ty.Optional[CspSendPort] = None
self.py_ports: ty.List[AbstractPyPort] = []
self.var_ports: ty.List[PyVarPort] = []
self.var_id_to_var_map: ty.Dict[int, ty.Any] = {}
Expand All @@ -52,11 +48,8 @@ def __setattr__(self, key: str, value: ty.Any):
self.var_ports.append(value)

def start(self):
self.service_to_process_cmd.start()
self.process_to_service_ack.start()
self.service_to_process_req.start()
self.process_to_service_data.start()
self.service_to_process_data.start()
self.service_to_process.start()
self.process_to_service.start()
for p in self.py_ports:
p.start()
self.run()
Expand All @@ -66,11 +59,8 @@ def run(self):
pass

def join(self):
self.service_to_process_cmd.join()
self.process_to_service_ack.join()
self.service_to_process_req.join()
self.process_to_service_data.join()
self.service_to_process_data.join()
self.service_to_process.join()
self.process_to_service.join()
for p in self.py_ports:
p.join()

Expand Down Expand Up @@ -113,96 +103,93 @@ def run(self):
"""Retrieves commands from the runtime service to iterate through the
phases of Loihi and calls their corresponding methods of the
ProcessModels. The phase is retrieved from runtime service
(service_to_process_cmd). After calling the method of a phase of all
(service_to_process). After calling the method of a phase of all
ProcessModels the runtime service is informed about completion. The
loop ends when the STOP command is received."""
selector = CspSelector()
action = 'cmd'
phase = PyLoihiProcessModel.Phase.SPK
while True:
if action == 'cmd':
phase = self.service_to_process_cmd.recv()
if enum_equal(phase, MGMT_COMMAND.STOP):
self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED)
cmd = self.service_to_process.recv()
if enum_equal(cmd, MGMT_COMMAND.STOP):
self.process_to_service.send(MGMT_RESPONSE.TERMINATED)
self.join()
return
try:
# Spiking phase - increase time step
if enum_equal(phase, PyLoihiProcessModel.Phase.SPK):
if enum_equal(cmd, PyLoihiProcessModel.Phase.SPK):
self.current_ts += 1
phase = PyLoihiProcessModel.Phase.SPK
self.run_spk()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
self.process_to_service.send(MGMT_RESPONSE.DONE)
# Pre-management phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT):
elif enum_equal(cmd,
PyLoihiProcessModel.Phase.PRE_MGMT):
# Enable via guard method
phase = PyLoihiProcessModel.Phase.PRE_MGMT
if self.pre_guard():
self.run_pre_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
self.process_to_service.send(MGMT_RESPONSE.DONE)
# Learning phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN):
elif enum_equal(cmd, PyLoihiProcessModel.Phase.LRN):
# Enable via guard method
phase = PyLoihiProcessModel.Phase.LRN
if self.lrn_guard():
self.run_lrn()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
self.process_to_service.send(MGMT_RESPONSE.DONE)
# Post-management phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
elif enum_equal(cmd,
PyLoihiProcessModel.Phase.POST_MGMT):
# Enable via guard method
phase = PyLoihiProcessModel.Phase.POST_MGMT
if self.post_guard():
self.run_post_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
self.process_to_service.send(MGMT_RESPONSE.DONE)
# Host phase - called at the last time step before STOP
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
elif enum_equal(cmd, PyLoihiProcessModel.Phase.HOST):
phase = PyLoihiProcessModel.Phase.HOST
pass
elif enum_equal(cmd, MGMT_COMMAND.GET_DATA) and \
enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
# Handle get/set Var requests from runtime service
self._handle_get_var()
elif enum_equal(cmd,
MGMT_COMMAND.SET_DATA) and enum_equal(phase,
PyLoihiProcessModel.Phase.HOST):
# Handle get/set Var requests from runtime service
self._handle_set_var()
else:
raise ValueError(f"Wrong Phase Info Received : {phase}")
raise ValueError(
f"Wrong Phase Info Received : {cmd}")
except Exception as inst:
print("Exception happened")
# Inform runtime service about termination
self.process_to_service_ack.send(MGMT_RESPONSE.ERROR)
self.process_to_service.send(MGMT_RESPONSE.ERROR)
self.join()
raise inst

elif action == 'req':
# Handle get/set Var requests from runtime service
self._handle_get_set_var()
else:
# Handle VarPort requests from RefPorts
self._handle_var_port(action)

channel_actions = [(self.service_to_process_cmd, lambda: 'cmd')]
channel_actions = [(self.service_to_process, lambda: 'cmd')]
if enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT) or \
enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
for var_port in self.var_ports:
for csp_port in var_port.csp_ports:
if isinstance(csp_port, CspRecvPort):
channel_actions.append((csp_port, lambda: var_port))
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
channel_actions.append((self.service_to_process_req,
lambda: 'req'))
action = selector.select(*channel_actions)

# FIXME: (PP) might not be able to perform get/set during pause
def _handle_get_set_var(self):
"""Handles all get/set Var requests from the runtime service and calls
the corresponding handling methods. The loop ends upon a
new command from runtime service after all get/set Var requests have
been handled."""
# Get the type of the request
request = self.service_to_process_req.recv()
if enum_equal(request, REQ_TYPE.GET):
self._handle_get_var()
elif enum_equal(request, REQ_TYPE.SET):
self._handle_set_var()
else:
raise RuntimeError(f"Unknown request type {request}")

def _handle_get_var(self):
"""Handles the get Var command from runtime service."""
# 1. Receive Var ID and retrieve the Var
var_id = self.service_to_process_req.recv()[0].item()
var_id = int(self.service_to_process.recv()[0].item())
var_name = self.var_id_to_var_map[var_id]
var = getattr(self, var_name)

# 2. Send Var data
data_port = self.process_to_service_data
data_port = self.process_to_service
# Header corresponds to number of values
# Data is either send once (for int) or one by one (array)
if isinstance(var, int) or isinstance(var, np.integer):
Expand All @@ -219,12 +206,12 @@ def _handle_get_var(self):
def _handle_set_var(self):
"""Handles the set Var command from runtime service."""
# 1. Receive Var ID and retrieve the Var
var_id = self.service_to_process_req.recv()[0].item()
var_id = int(self.service_to_process.recv()[0].item())
var_name = self.var_id_to_var_map[var_id]
var = getattr(self, var_name)

# 2. Receive Var data
data_port = self.service_to_process_data
data_port = self.service_to_process
if isinstance(var, int) or isinstance(var, np.integer):
# First item is number of items (1) - not needed
data_port.recv()
Expand Down
16 changes: 8 additions & 8 deletions src/lava/magma/core/process/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
import typing as ty
from _collections import OrderedDict

from lava.magma.compiler.executable import Executable
from lava.magma.core.process.interfaces import \
AbstractProcessMember, IdGeneratorSingleton
from lava.magma.core.process.message_interface_enum import ActorType
from lava.magma.core.run_conditions import AbstractRunCondition
from lava.magma.core.run_configs import RunConfig
from lava.magma.core.process.ports.ports import \
InPort, OutPort, RefPort, VarPort
from lava.magma.core.process.variable import Var
from lava.magma.core.process.interfaces import \
AbstractProcessMember, IdGeneratorSingleton
from lava.magma.compiler.executable import Executable
from lava.magma.core.run_conditions import AbstractRunCondition
from lava.magma.core.run_configs import RunConfig
from lava.magma.runtime.runtime import Runtime


# Abbreviation for type annotation in Collection class
mem_type = ty.Union[InPort, OutPort, RefPort, VarPort, Var, "AbstractProcess"]

Expand Down Expand Up @@ -358,7 +357,8 @@ def load(self, path: str):

# TODO: (PP) Remove if condition on blocking as soon as non-blocking
# execution is completely implemented
def run(self, condition: AbstractRunCondition, run_cfg: RunConfig):
def run(self, condition: AbstractRunCondition = None, run_cfg:
RunConfig = None):
"""Runs process given RunConfig and RunCondition.

run(..) compiles this and any process connected to this process
Expand Down Expand Up @@ -393,7 +393,7 @@ def run(self, condition: AbstractRunCondition, run_cfg: RunConfig):

if not self._runtime:
executable = self.compile(run_cfg)
self._runtime = Runtime(condition,
self._runtime = Runtime(
executable,
ActorType.MultiProcessing)
self._runtime.initialize()
Expand Down