Skip to content

Commit

Permalink
Use Sparse instead of Dense for Connections (#67)
Browse files Browse the repository at this point in the history
* enable sparse and update unittests

* linting

* codacy

* rename to connection_class, update tests and docstrings

* linting

* linting

* update docstring

* update docstring

* linting

* linting

* linting

* Update poetry.lock

* Update test_connect.py

Fix codacy

* Update test_connect.py

* Update test_connect.py

fix codacy

---------

Co-authored-by: PhilippPlank <32519998+PhilippPlank@users.noreply.github.com>
  • Loading branch information
SveaMeyer13 and PhilippPlank authored Aug 23, 2023
1 parent a62aa70 commit 6da459b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 32 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 25 additions & 11 deletions src/lava/lib/dnf/connect/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import typing as ty
import numpy as np

from lava.magma.core.process.process import AbstractProcess

from lava.proc.dense.models import Dense
from lava.proc.sparse.models import Sparse
from lava.magma.core.process.ports.ports import InPort, OutPort

from lava.lib.dnf.operations.operations import AbstractOperation, Weights
Expand All @@ -17,8 +18,10 @@ def connect(
src_op: OutPort,
dst_ip: InPort,
ops: ty.Optional[ty.Union[ty.List[AbstractOperation],
AbstractOperation]] = None
) -> AbstractProcess:
AbstractOperation]] = None,
connection_class: ty.Optional[ty.Union[ty.Type[Sparse], ty.Type[Dense]]]
= None
) -> ty.Union[Sparse, Dense]:
"""
Creates and returns a Connections Process <conn> and connects the source
OutPort <src_op> to the InPort of <conn> and the OutPort of <conn> to the
Expand All @@ -39,12 +42,15 @@ def connect(
ops : list(AbstractOperation), optional
list of operations that describes how the connection between
<src_op> and <dst_ip> will be created
connection_class : type(Sparse) or type(Dense), optional
Class of the process used between src_op and dst_ip. If connection_class
is None the connection process will be defined automatically
(currently a Sparse Process is used in that case).
Returns
-------
connections : AbstractProcess
process containing the connections between <src_op> and <dst_ip>
connections : Sparse or Dense Process
Process containing the connections between <src_op> and <dst_ip>
"""
# validate the list of operations
ops = _validate_ops(ops, src_op.shape, dst_ip.shape)
Expand All @@ -58,7 +64,7 @@ def connect(

# create Connections process and connect it:
# source -> connections -> destination
connections = _make_connections(src_op, dst_ip, weights)
connections = _make_connections(src_op, dst_ip, weights, connection_class)

return connections

Expand Down Expand Up @@ -195,7 +201,10 @@ def _compute_weights(ops: ty.List[AbstractOperation]) -> np.ndarray:

def _make_connections(src_op: OutPort,
dst_ip: InPort,
weights: np.ndarray) -> AbstractProcess:
weights: np.ndarray,
connection_class: ty.Optional[ty.Union[Sparse, Dense]]
= None
) -> ty.Union[Sparse, Dense]:
"""
Creates a Connections Process with the given weights and connects its
ports such that:
Expand All @@ -210,15 +219,20 @@ def _make_connections(src_op: OutPort,
InPort of the destination Process
weights : numpy.ndarray
connectivity weight matrix used for the Connections Process
connection_class : type(Sparse) or type(Dense), optional
Class of the process used between src_op and dst_ip. If connection_class
is None the connection process will be defined automatically
(currently a Sparse Process is used in that case).
Returns
-------
Connections Process : AbstractProcess
connections : Sparse or Dense Process
Process containing the connections between <src_op> and <dst_ip>
"""

# Create the connections process
connections = Dense(weights=weights)
connection_class = connection_class or Sparse
connections = connection_class(weights=weights)

con_ip = connections.s_in
src_op.reshape(new_shape=con_ip.shape).connect(con_ip)
Expand Down
10 changes: 7 additions & 3 deletions tests/lava/lib/dnf/acceptance/test_connecting_with_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg
from lava.proc.lif.process import LIF
from lava.proc.dense.process import Dense

from lava.lib.dnf.connect.connect import connect
from lava.lib.dnf.kernels.kernels import SelectiveKernel, MultiPeakKernel
Expand Down Expand Up @@ -97,7 +98,8 @@ def test_connect_population_3d_to_2d_with_reduce_dims_and_reorder(self)\
reduce_op = ReduceDims(reduce_dims=dims)
computed = connect(source.s_out,
destination.a_in,
ops=[reduce_op, reorder_op])
ops=[reduce_op, reorder_op],
connection_class=Dense)

self.assertTrue(np.array_equal(computed.weights.get(), expected))

Expand Down Expand Up @@ -168,7 +170,8 @@ def test_connect_population_2d_to_3d_with_expand_dims_and_reorder(self)\
expand_op = ExpandDims(new_dims_shape=(2,))
computed = connect(source.s_out,
destination.a_in,
ops=[expand_op, reorder_op])
ops=[expand_op, reorder_op],
connection_class=Dense)

self.assertTrue(np.array_equal(computed.weights.get(), expected))

Expand Down Expand Up @@ -239,7 +242,8 @@ def test_connect_population_1d_to_3d_with_expand_dims_and_reorder(self) \
expand_op = ExpandDims(new_dims_shape=(2, 2))
computed = connect(source.s_out,
destination.a_in,
ops=[expand_op, reorder_op])
ops=[expand_op, reorder_op],
connection_class=Dense)

self.assertTrue(np.array_equal(computed.weights.get(), expected))

Expand Down
78 changes: 61 additions & 17 deletions tests/lava/lib/dnf/connect/test_connect.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
# Copyright (C) 2021 Intel Corporation
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import unittest
import numpy as np
import typing as ty

from lava.magma.core.process.ports.ports import InPort, OutPort
from lava.magma.core.process.process import AbstractProcess

from lava.proc.dense.process import Dense
from lava.proc.sparse.process import Sparse
from lava.lib.dnf.connect.connect import connect
from lava.lib.dnf.connect.exceptions import MisconfiguredConnectError
from lava.lib.dnf.operations.operations import AbstractOperation
from lava.lib.dnf.operations.shape_handlers import AbstractShapeHandler,\
from lava.lib.dnf.operations.shape_handlers import AbstractShapeHandler, \
KeepShapeShapeHandler
from lava.lib.dnf.operations.exceptions import MisconfiguredOpError
from lava.lib.dnf.utils.convenience import num_neurons


class MockProcess(AbstractProcess):
"""Mock Process with an InPort and OutPort"""

def __init__(self, shape: ty.Tuple[int, ...] = (1,)) -> None:
super().__init__()
self.a_in = InPort(shape)
Expand All @@ -28,6 +29,7 @@ def __init__(self, shape: ty.Tuple[int, ...] = (1,)) -> None:

class MockNoChangeOperation(AbstractOperation):
"""Mock Operation that does not change shape"""

def __init__(self) -> None:
super().__init__(shape_handler=KeepShapeShapeHandler())

Expand All @@ -39,6 +41,7 @@ def _compute_weights(self) -> np.ndarray:

class MockShapeHandler(AbstractShapeHandler):
"""Mock ShapeHandler for an operation that changes the shape."""

def __init__(self, output_shape: ty.Tuple[int, ...]) -> None:
super().__init__()
self._output_shape = output_shape
Expand All @@ -56,6 +59,7 @@ def _validate_input_shape(self, input_shape: ty.Tuple[int, ...]) -> None:

class MockChangeOperation(AbstractOperation):
"""Mock Operation that changes shape"""

def __init__(self, output_shape: ty.Tuple[int, ...]) -> None:
super().__init__(MockShapeHandler(output_shape))

Expand All @@ -79,45 +83,46 @@ def _test_connections(self,
"""For a given source, destination, and connections Processes,
tests whether they have been connected."""

# check whether the connect function returns a process
# Check whether the connect function returns a process
self.assertIsInstance(connections, AbstractProcess)

# check whether 'source' is connected to 'connections'
# Check whether 'source' is connected to 'connections'
src_op = source.out_ports.s_out
con_ip = connections.in_ports.s_in
self.assertEqual(src_op.get_dst_ports(), [con_ip])
self.assertEqual(con_ip.get_src_ports(), [src_op])

# check whether 'connections' is connected to 'target'
# Check whether 'connections' is connected to 'target'
con_op = connections.out_ports.a_out
dst_op = destination.in_ports.a_in
self.assertEqual(con_op.get_dst_ports(), [dst_op])
self.assertEqual(dst_op.get_src_ports(), [con_op])

def test_connecting_with_op_that_does_not_change_shape(self) -> None:
"""Tests connecting a source Process to a destination Process."""
# create mock processes and an operation to connect
source = MockProcess(shape=(1, 2, 3))
# Create mock processes and an operation to connect
destination = MockProcess(shape=(1, 2, 3))
source = MockProcess(shape=(1, 2, 3))
op = MockNoChangeOperation()

# connect source to target
# Connect source to target
connections = connect(source.s_out, destination.a_in, ops=[op])

self._test_connections(source, destination, connections)

def test_connecting_without_ops(self) -> None:
"""Tests connecting a source Process to a destination Process
without specifying any operations."""
# create mock processes of the same shape
# Create mock processes of the same shape
shape = (1, 2, 3)
source = MockProcess(shape=shape)
destination = MockProcess(shape=shape)

# connect source to target
connections = connect(source.s_out, destination.a_in)
# Connect source to target
connections = connect(source.s_out, destination.a_in,
connection_class=Dense)

# default connection weights should be the identity matrix
# Default connection weights should be the identity matrix
np.testing.assert_array_equal(connections.weights.get(),
np.eye(int(np.prod(shape))))

Expand All @@ -127,7 +132,7 @@ def test_connecting_different_shapes_without_ops_raises_error(self) -> None:
"""Tests whether an exception is raised when trying to connect two
Processes that have different shapes while not specifying any
operations."""
# create mock processes of different shapes
# Create mock processes of different shapes
source = MockProcess(shape=(1, 2, 3))
destination = MockProcess(shape=(3, 2, 1))

Expand All @@ -141,6 +146,7 @@ def test_empty_operations_list_raises_value_error(self) -> None:

def test_ops_list_containing_invalid_type_raises_type_error(self) -> None:
"""Tests whether the type of all elements in <ops> is validated."""

class NotAnOperation:
pass

Expand All @@ -163,7 +169,7 @@ def test_operation_that_changes_the_shape(self) -> None:
MockProcess(output_shape).a_in,
ops=MockChangeOperation(output_shape=output_shape))

def test_mismatching_op_output_shape_and_dest_shape_raises_error(self)\
def test_mismatching_op_output_shape_and_dest_shape_raises_error(self) \
-> None:
"""Tests whether an error is raised when the output shape of the
last operation does not match the destination shape."""
Expand Down Expand Up @@ -208,6 +214,7 @@ def test_weights_from_multiple_ops_get_multiplied(self) -> None:
class MockNoChangeOpWeights(MockNoChangeOperation):
"""Mock Operation that generates an identity matrix with a given
weight."""

def __init__(self, weight: float) -> None:
super().__init__()
self.weight = weight
Expand All @@ -224,7 +231,8 @@ def _compute_weights(self) -> np.ndarray:
conn = connect(MockProcess(shape).s_out,
MockProcess(shape).a_in,
ops=[MockNoChangeOpWeights(weight=w1),
MockNoChangeOpWeights(weight=w2)])
MockNoChangeOpWeights(weight=w2)],
connection_class=Dense)

computed_weights = conn.weights.get()
expected_weights = np.eye(num_neurons(shape),
Expand All @@ -233,6 +241,42 @@ def _compute_weights(self) -> np.ndarray:

self.assertTrue(np.array_equal(computed_weights, expected_weights))

def test_connection_is_sparse_if_no_connectionclass_is_specified(self) \
-> None:
# Create mock processes and an operation to connect
source = MockProcess(shape=(1, 2, 3))
op = MockNoChangeOperation()
destination = MockProcess(shape=(1, 2, 3))

# Connect source to target
connections = connect(source.s_out, destination.a_in, ops=[op])
self.assertIsInstance(connections, Sparse)

def test_connection_is_sparse_if_sparse_connectionclass_is_specified(self) \
-> None:
# Create mock processes
shape = (1, 2, 3)
source = MockProcess(shape=shape)
op = MockNoChangeOperation()
destination = MockProcess(shape=shape)

# Connect source to target
connections = connect(source.s_out, destination.a_in, ops=[op],
connection_class=Sparse)
self.assertIsInstance(connections, Sparse)

def test_connection_is_sparse_if_dense_connectionclass_is_specified(self) \
-> None:
# Create mock processes and an operation to connect
op = MockNoChangeOperation()
source = MockProcess(shape=(1, 2, 3))
destination = MockProcess(shape=(1, 2, 3))

# Connect source to target
connections = connect(source.s_out, destination.a_in, ops=[op],
connection_class=Dense)
self.assertIsInstance(connections, Dense)


if __name__ == '__main__':
unittest.main()

0 comments on commit 6da459b

Please sign in to comment.