Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #2051 from chetan51/issue-2015
Browse files Browse the repository at this point in the history
Add delete segment/synapse functionality to Connections data structure
  • Loading branch information
rhyolight committed Apr 12, 2015
2 parents f2ec172 + 61089b5 commit 8e9674c
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 37 deletions.
67 changes: 55 additions & 12 deletions nupic/research/temporal_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,19 @@ def computeFn(self,
'Functional' version of compute.
Returns new state.
@param activeColumns (set) Indices of active columns in `t`
@param prevPredictiveCells (set) Indices of predictive cells in `t-1`
@param prevActiveSegments (set) Indices of active segments in `t-1`
@param prevActiveCells (set) Indices of active cells in `t-1`
@param prevWinnerCells (set) Indices of winner cells in `t-1`
@param connections (Connections) Connectivity of layer
@param learn (bool) Whether or not learning is enabled
@param activeColumns (set) Indices of active columns in `t`
@param prevPredictiveCells (set) Indices of predictive cells in `t-1`
@param prevActiveSegments (set) Indices of active segments in `t-1`
@param prevActiveCells (set) Indices of active cells in `t-1`
@param prevWinnerCells (set) Indices of winner cells in `t-1`
@param connections (Connections) Connectivity of layer
@param learn (bool) Whether or not learning is enabled
@return (tuple) Contains:
`activeCells` (set),
`winnerCells` (set),
`activeSegments` (set),
`predictiveCells` (set)
`activeCells` (set),
`winnerCells` (set),
`activeSegments` (set),
`predictiveCells` (set)
"""
activeCells = set()
winnerCells = set()
Expand Down Expand Up @@ -665,7 +665,7 @@ class Connections(object):
Class to hold data representing the connectivity of a collection of cells.
"""

SynapseData = namedtuple("SyanpseData", ["segment",
SynapseData = namedtuple("SynapseData", ["segment",
"presynapticCell",
"permanence"])

Expand Down Expand Up @@ -727,6 +727,8 @@ def dataForSynapse(self, synapse):
@return (SynapseData) Synapse data
"""
self._validateSynapse(synapse)

return self._synapses[synapse]


Expand Down Expand Up @@ -780,6 +782,23 @@ def createSegment(self, cell):
return segment


def destroySegment(self, segment):
"""
Destroys a segment.
@param segment (int) Segment index
"""
synapses = set(self.synapsesForSegment(segment))
for synapse in synapses:
self.destroySynapse(synapse)

cell = self._segments[segment]
del self._segments[segment]

# Update indexes
self._segmentsForCell[cell].remove(segment)


def createSynapse(self, segment, presynapticCell, permanence):
"""
Creates a new synapse on a segment.
Expand Down Expand Up @@ -809,6 +828,20 @@ def createSynapse(self, segment, presynapticCell, permanence):
return synapse


def destroySynapse(self, synapse):
"""
Destroys a synapse.
@param synapse (int) Synapse index
"""
data = self._synapses[synapse]
del self._synapses[synapse]

# Update indexes
self._synapsesForSegment[data.segment].remove(synapse)
del self._synapsesForPresynapticCell[data.presynapticCell][synapse]


def updateSynapsePermanence(self, synapse, permanence):
"""
Updates the permanence for a synapse.
Expand Down Expand Up @@ -862,6 +895,16 @@ def _validateSegment(self, segment):
raise IndexError("Invalid segment")


def _validateSynapse(self, synapse):
"""
Raises an error if synapse index is invalid.
@param synapse (int) Synapse index
"""
if not synapse in self._synapses:
raise IndexError("Invalid synapse")


@staticmethod
def _validatePermanence(permanence):
"""
Expand Down
96 changes: 71 additions & 25 deletions tests/unit/nupic/research/temporal_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,24 +208,24 @@ def testLearnOnSegments(self):
connections)

# Check segment 0
(_, _, permanence) = connections.dataForSynapse(0)
self.assertAlmostEqual(permanence, 0.7)
synapseData = connections.dataForSynapse(0)
self.assertAlmostEqual(synapseData.permanence, 0.7)

(_, _, permanence) = connections.dataForSynapse(1)
self.assertAlmostEqual(permanence, 0.5)
synapseData = connections.dataForSynapse(1)
self.assertAlmostEqual(synapseData.permanence, 0.5)

(_, _, permanence) = connections.dataForSynapse(2)
self.assertAlmostEqual(permanence, 0.8)
synapseData = connections.dataForSynapse(2)
self.assertAlmostEqual(synapseData.permanence, 0.8)

# Check segment 1
(_, _, permanence) = connections.dataForSynapse(3)
self.assertAlmostEqual(permanence, 0.8)
synapseData = connections.dataForSynapse(3)
self.assertAlmostEqual(synapseData.permanence, 0.8)

self.assertEqual(len(connections.synapsesForSegment(1)), 2)

# Check segment 2
(_, _, permanence) = connections.dataForSynapse(4)
self.assertAlmostEqual(permanence, 0.9)
synapseData = connections.dataForSynapse(4)
self.assertAlmostEqual(synapseData.permanence, 0.9)

self.assertEqual(len(connections.synapsesForSegment(2)), 1)

Expand Down Expand Up @@ -400,14 +400,14 @@ def testAdaptSegment(self):

tm.adaptSegment(0, set([0, 1]), connections)

(_, _, permanence) = connections.dataForSynapse(0)
self.assertAlmostEqual(permanence, 0.7)
synapseData = connections.dataForSynapse(0)
self.assertAlmostEqual(synapseData.permanence, 0.7)

(_, _, permanence) = connections.dataForSynapse(1)
self.assertAlmostEqual(permanence, 0.5)
synapseData = connections.dataForSynapse(1)
self.assertAlmostEqual(synapseData.permanence, 0.5)

(_, _, permanence) = connections.dataForSynapse(2)
self.assertAlmostEqual(permanence, 0.8)
synapseData = connections.dataForSynapse(2)
self.assertAlmostEqual(synapseData.permanence, 0.8)


def testAdaptSegmentToMax(self):
Expand All @@ -418,13 +418,13 @@ def testAdaptSegmentToMax(self):
connections.createSynapse(0, 23, 0.9)

tm.adaptSegment(0, set([0]), connections)
(_, _, permanence) = connections.dataForSynapse(0)
self.assertAlmostEqual(permanence, 1.0)
synapseData = connections.dataForSynapse(0)
self.assertAlmostEqual(synapseData.permanence, 1.0)

# Now permanence should be at max
tm.adaptSegment(0, set([0]), connections)
(_, _, permanence) = connections.dataForSynapse(0)
self.assertAlmostEqual(permanence, 1.0)
synapseData = connections.dataForSynapse(0)
self.assertAlmostEqual(synapseData.permanence, 1.0)


def testAdaptSegmentToMin(self):
Expand All @@ -435,13 +435,13 @@ def testAdaptSegmentToMin(self):
connections.createSynapse(0, 23, 0.1)

tm.adaptSegment(0, set(), connections)
(_, _, permanence) = connections.dataForSynapse(0)
self.assertAlmostEqual(permanence, 0.0)
synapseData = connections.dataForSynapse(0)
self.assertAlmostEqual(synapseData.permanence, 0.0)

# Now permanence should be at min
tm.adaptSegment(0, set(), connections)
(_, _, permanence) = connections.dataForSynapse(0)
self.assertAlmostEqual(permanence, 0.0)
synapseData = connections.dataForSynapse(0)
self.assertAlmostEqual(synapseData.permanence, 0.0)


def testPickCellsToLearnOn(self):
Expand Down Expand Up @@ -602,6 +602,32 @@ def testCreateSegment(self):
self.assertEqual(connections.segmentsForCell(0), set([0, 1]))


def testDestroySegment(self):
connections = self.connections

self.assertEqual(connections.createSegment(0), 0)
self.assertEqual(connections.createSegment(0), 1)
self.assertEqual(connections.createSegment(10), 2)

self.assertEqual(connections.createSynapse(0, 254, 0.1173), 0)
self.assertEqual(connections.createSynapse(0, 477, 0.3253), 1)

connections.destroySegment(0)

args = [0]
self.assertRaises(IndexError, connections.dataForSynapse, *args)
args = [1]
self.assertRaises(IndexError, connections.dataForSynapse, *args)

args = [0]
self.assertRaises(IndexError, connections.synapsesForSegment, *args)

self.assertEqual(connections.synapsesForPresynapticCell(174), {})
self.assertEqual(connections.synapsesForPresynapticCell(254), {})

self.assertEqual(connections.segmentsForCell(0), set([1]))


def testCreateSegmentInvalidCell(self):
connections = self.connections

Expand Down Expand Up @@ -670,14 +696,34 @@ def testCreateSynapseInvalidParams(self):
self.assertRaises(ValueError, connections.createSynapse, *args)


def testDestroySynapse(self):
connections = self.connections

connections.createSegment(0)
self.assertEqual(connections.synapsesForSegment(0), set())

self.assertEqual(connections.createSynapse(0, 254, 0.1173), 0)
self.assertEqual(connections.createSynapse(0, 477, 0.3253), 1)

connections.destroySynapse(0)

args = [0]
self.assertRaises(IndexError, connections.dataForSynapse, *args)

self.assertEqual(connections.synapsesForSegment(0), set([1]))

self.assertEqual(connections.synapsesForPresynapticCell(174), {})
self.assertEqual(connections.synapsesForPresynapticCell(254), {})


def testDataForSynapseInvalidSynapse(self):
connections = self.connections

connections.createSegment(0)
connections.createSynapse(0, 834, 0.1284)

args = [1]
self.assertRaises(KeyError, connections.dataForSynapse, *args)
self.assertRaises(IndexError, connections.dataForSynapse, *args)


def testSynapsesForSegmentInvalidSegment(self):
Expand Down

0 comments on commit 8e9674c

Please sign in to comment.