From 61089b58c9ec6becabd4a8852d011a2a45cc1576 Mon Sep 17 00:00:00 2001 From: Chetan Surpur Date: Sat, 11 Apr 2015 13:27:15 -0700 Subject: [PATCH] Add delete segment/synapse functionality to Connections data structure --- nupic/research/temporal_memory.py | 67 ++++++++++--- .../nupic/research/temporal_memory_test.py | 96 ++++++++++++++----- 2 files changed, 126 insertions(+), 37 deletions(-) diff --git a/nupic/research/temporal_memory.py b/nupic/research/temporal_memory.py index cd833e2b28..992b28dec9 100644 --- a/nupic/research/temporal_memory.py +++ b/nupic/research/temporal_memory.py @@ -132,19 +132,19 @@ def computeFn(self, 'Functional' version of compute. Returns new state. - @param activeColumns (set) Indices of active columns in `t` - @param prevPredictiveCells (set) Indices of predictive cells in `t-1` - @param prevActiveSegments (set) Indices of active segments in `t-1` - @param prevActiveCells (set) Indices of active cells in `t-1` - @param prevWinnerCells (set) Indices of winner cells in `t-1` - @param connections (Connections) Connectivity of layer - @param learn (bool) Whether or not learning is enabled + @param activeColumns (set) Indices of active columns in `t` + @param prevPredictiveCells (set) Indices of predictive cells in `t-1` + @param prevActiveSegments (set) Indices of active segments in `t-1` + @param prevActiveCells (set) Indices of active cells in `t-1` + @param prevWinnerCells (set) Indices of winner cells in `t-1` + @param connections (Connections) Connectivity of layer + @param learn (bool) Whether or not learning is enabled @return (tuple) Contains: - `activeCells` (set), - `winnerCells` (set), - `activeSegments` (set), - `predictiveCells` (set) + `activeCells` (set), + `winnerCells` (set), + `activeSegments` (set), + `predictiveCells` (set) """ activeCells = set() winnerCells = set() @@ -665,7 +665,7 @@ class Connections(object): Class to hold data representing the connectivity of a collection of cells. """ - SynapseData = namedtuple("SyanpseData", ["segment", + SynapseData = namedtuple("SynapseData", ["segment", "presynapticCell", "permanence"]) @@ -727,6 +727,8 @@ def dataForSynapse(self, synapse): @return (SynapseData) Synapse data """ + self._validateSynapse(synapse) + return self._synapses[synapse] @@ -780,6 +782,23 @@ def createSegment(self, cell): return segment + def destroySegment(self, segment): + """ + Destroys a segment. + + @param segment (int) Segment index + """ + synapses = set(self.synapsesForSegment(segment)) + for synapse in synapses: + self.destroySynapse(synapse) + + cell = self._segments[segment] + del self._segments[segment] + + # Update indexes + self._segmentsForCell[cell].remove(segment) + + def createSynapse(self, segment, presynapticCell, permanence): """ Creates a new synapse on a segment. @@ -809,6 +828,20 @@ def createSynapse(self, segment, presynapticCell, permanence): return synapse + def destroySynapse(self, synapse): + """ + Destroys a synapse. + + @param synapse (int) Synapse index + """ + data = self._synapses[synapse] + del self._synapses[synapse] + + # Update indexes + self._synapsesForSegment[data.segment].remove(synapse) + del self._synapsesForPresynapticCell[data.presynapticCell][synapse] + + def updateSynapsePermanence(self, synapse, permanence): """ Updates the permanence for a synapse. @@ -862,6 +895,16 @@ def _validateSegment(self, segment): raise IndexError("Invalid segment") + def _validateSynapse(self, synapse): + """ + Raises an error if synapse index is invalid. + + @param synapse (int) Synapse index + """ + if not synapse in self._synapses: + raise IndexError("Invalid synapse") + + @staticmethod def _validatePermanence(permanence): """ diff --git a/tests/unit/nupic/research/temporal_memory_test.py b/tests/unit/nupic/research/temporal_memory_test.py index d69c890990..fc8c99bee3 100755 --- a/tests/unit/nupic/research/temporal_memory_test.py +++ b/tests/unit/nupic/research/temporal_memory_test.py @@ -208,24 +208,24 @@ def testLearnOnSegments(self): connections) # Check segment 0 - (_, _, permanence) = connections.dataForSynapse(0) - self.assertAlmostEqual(permanence, 0.7) + synapseData = connections.dataForSynapse(0) + self.assertAlmostEqual(synapseData.permanence, 0.7) - (_, _, permanence) = connections.dataForSynapse(1) - self.assertAlmostEqual(permanence, 0.5) + synapseData = connections.dataForSynapse(1) + self.assertAlmostEqual(synapseData.permanence, 0.5) - (_, _, permanence) = connections.dataForSynapse(2) - self.assertAlmostEqual(permanence, 0.8) + synapseData = connections.dataForSynapse(2) + self.assertAlmostEqual(synapseData.permanence, 0.8) # Check segment 1 - (_, _, permanence) = connections.dataForSynapse(3) - self.assertAlmostEqual(permanence, 0.8) + synapseData = connections.dataForSynapse(3) + self.assertAlmostEqual(synapseData.permanence, 0.8) self.assertEqual(len(connections.synapsesForSegment(1)), 2) # Check segment 2 - (_, _, permanence) = connections.dataForSynapse(4) - self.assertAlmostEqual(permanence, 0.9) + synapseData = connections.dataForSynapse(4) + self.assertAlmostEqual(synapseData.permanence, 0.9) self.assertEqual(len(connections.synapsesForSegment(2)), 1) @@ -400,14 +400,14 @@ def testAdaptSegment(self): tm.adaptSegment(0, set([0, 1]), connections) - (_, _, permanence) = connections.dataForSynapse(0) - self.assertAlmostEqual(permanence, 0.7) + synapseData = connections.dataForSynapse(0) + self.assertAlmostEqual(synapseData.permanence, 0.7) - (_, _, permanence) = connections.dataForSynapse(1) - self.assertAlmostEqual(permanence, 0.5) + synapseData = connections.dataForSynapse(1) + self.assertAlmostEqual(synapseData.permanence, 0.5) - (_, _, permanence) = connections.dataForSynapse(2) - self.assertAlmostEqual(permanence, 0.8) + synapseData = connections.dataForSynapse(2) + self.assertAlmostEqual(synapseData.permanence, 0.8) def testAdaptSegmentToMax(self): @@ -418,13 +418,13 @@ def testAdaptSegmentToMax(self): connections.createSynapse(0, 23, 0.9) tm.adaptSegment(0, set([0]), connections) - (_, _, permanence) = connections.dataForSynapse(0) - self.assertAlmostEqual(permanence, 1.0) + synapseData = connections.dataForSynapse(0) + self.assertAlmostEqual(synapseData.permanence, 1.0) # Now permanence should be at max tm.adaptSegment(0, set([0]), connections) - (_, _, permanence) = connections.dataForSynapse(0) - self.assertAlmostEqual(permanence, 1.0) + synapseData = connections.dataForSynapse(0) + self.assertAlmostEqual(synapseData.permanence, 1.0) def testAdaptSegmentToMin(self): @@ -435,13 +435,13 @@ def testAdaptSegmentToMin(self): connections.createSynapse(0, 23, 0.1) tm.adaptSegment(0, set(), connections) - (_, _, permanence) = connections.dataForSynapse(0) - self.assertAlmostEqual(permanence, 0.0) + synapseData = connections.dataForSynapse(0) + self.assertAlmostEqual(synapseData.permanence, 0.0) # Now permanence should be at min tm.adaptSegment(0, set(), connections) - (_, _, permanence) = connections.dataForSynapse(0) - self.assertAlmostEqual(permanence, 0.0) + synapseData = connections.dataForSynapse(0) + self.assertAlmostEqual(synapseData.permanence, 0.0) def testPickCellsToLearnOn(self): @@ -602,6 +602,32 @@ def testCreateSegment(self): self.assertEqual(connections.segmentsForCell(0), set([0, 1])) + def testDestroySegment(self): + connections = self.connections + + self.assertEqual(connections.createSegment(0), 0) + self.assertEqual(connections.createSegment(0), 1) + self.assertEqual(connections.createSegment(10), 2) + + self.assertEqual(connections.createSynapse(0, 254, 0.1173), 0) + self.assertEqual(connections.createSynapse(0, 477, 0.3253), 1) + + connections.destroySegment(0) + + args = [0] + self.assertRaises(IndexError, connections.dataForSynapse, *args) + args = [1] + self.assertRaises(IndexError, connections.dataForSynapse, *args) + + args = [0] + self.assertRaises(IndexError, connections.synapsesForSegment, *args) + + self.assertEqual(connections.synapsesForPresynapticCell(174), {}) + self.assertEqual(connections.synapsesForPresynapticCell(254), {}) + + self.assertEqual(connections.segmentsForCell(0), set([1])) + + def testCreateSegmentInvalidCell(self): connections = self.connections @@ -670,6 +696,26 @@ def testCreateSynapseInvalidParams(self): self.assertRaises(ValueError, connections.createSynapse, *args) + def testDestroySynapse(self): + connections = self.connections + + connections.createSegment(0) + self.assertEqual(connections.synapsesForSegment(0), set()) + + self.assertEqual(connections.createSynapse(0, 254, 0.1173), 0) + self.assertEqual(connections.createSynapse(0, 477, 0.3253), 1) + + connections.destroySynapse(0) + + args = [0] + self.assertRaises(IndexError, connections.dataForSynapse, *args) + + self.assertEqual(connections.synapsesForSegment(0), set([1])) + + self.assertEqual(connections.synapsesForPresynapticCell(174), {}) + self.assertEqual(connections.synapsesForPresynapticCell(254), {}) + + def testDataForSynapseInvalidSynapse(self): connections = self.connections @@ -677,7 +723,7 @@ def testDataForSynapseInvalidSynapse(self): connections.createSynapse(0, 834, 0.1284) args = [1] - self.assertRaises(KeyError, connections.dataForSynapse, *args) + self.assertRaises(IndexError, connections.dataForSynapse, *args) def testSynapsesForSegmentInvalidSegment(self):