Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Added split method to dataset, and special method __len__

  • Loading branch information...
commit c0af486c489524358029793c9a90493572eb8f44 1 parent 9cdf62d
gasagna authored November 26, 2011

Showing 1 changed file with 42 additions and 0 deletions. Show diff stats Hide diff stats

  1. 42  nn.py
42  nn.py
@@ -78,6 +78,48 @@ def __init__ ( self, inputs, targets ):
78 78
         # length of the dataset, number of samples
79 79
         self.n_samples = self.inputs.shape[0]
80 80
         
  81
+    def split( self, fractions=[0.5, 0.5]):
  82
+        """Split randomly the dataset into smaller dataset.
  83
+        
  84
+        Parameters
  85
+        ----------
  86
+        fraction: list of floats, default = [0.5, 0.5]
  87
+            the dataset is split into ``len(fraction)`` smaller
  88
+            dataset, and the ``i``-th dataset has a size
  89
+            which is ``fraction[i]`` of the original dataset.
  90
+            Note that ``sum(fraction)`` can also be smaller than one
  91
+            but not greater.
  92
+            
  93
+        Returns
  94
+        -------
  95
+        subsets: list of :py:class:`nn.Dataset`
  96
+            a list of the subsets of the original datasets
  97
+        """
  98
+        
  99
+
  100
+
  101
+        if sum(fractions) > 1.0 or sum(fractions) <= 0:
  102
+            raise ValueError( "the sum of fractions argument should be between 0 and 1" )
  103
+        
  104
+        # random indices
  105
+        idx = np.arange(self.n_samples)
  106
+        np.random.shuffle(idx)
  107
+        
  108
+        # insert zero
  109
+        fractions.insert(0, 0)
  110
+        
  111
+        # gte limits of the subsets
  112
+        limits = (np.cumsum(fractions)*self.n_samples ).astype(np.int32)
  113
+                
  114
+        subsets = []
  115
+        # create output dataset
  116
+        for i in range(len(fractions)-1):
  117
+            subsets.append( Dataset(self.inputs[idx[limits[i]:limits[i+1]]], self.targets[idx[limits[i]:limits[i+1]]]) )
  118
+        
  119
+        return subsets
  120
+
  121
+    def __len__(self):
  122
+        return len( self.inputs )
81 123
 
82 124
 class MultiLayerPerceptron( ):
83 125
     """A Multi Layer Perceptron feed-forward neural network.

0 notes on commit c0af486

Please sign in to comment.
Something went wrong with that request. Please try again.