diff --git a/fern_AxisAligned_regression.png b/fern_AxisAligned_regression.png new file mode 100644 index 0000000..a8a55d6 Binary files /dev/null and b/fern_AxisAligned_regression.png differ diff --git a/fern_Conic_regression.png b/fern_Conic_regression.png new file mode 100644 index 0000000..aac675b Binary files /dev/null and b/fern_Conic_regression.png differ diff --git a/fern_Linear_regression.png b/fern_Linear_regression.png new file mode 100644 index 0000000..f7a208c Binary files /dev/null and b/fern_Linear_regression.png differ diff --git a/fern_Parabola_regression.png b/fern_Parabola_regression.png new file mode 100644 index 0000000..dd15f92 Binary files /dev/null and b/fern_Parabola_regression.png differ diff --git a/randomferns.py b/randomferns.py index 8c055a5..6f0f32f 100644 --- a/randomferns.py +++ b/randomferns.py @@ -15,10 +15,11 @@ def fit( self, points, responses ): self.tests = np.array( self.test_class.generate_all( points, self.depth ) ) if self.regression: self.target_dim = responses.shape[1] - self.data = np.ones( (2**self.depth, self.target_dim), dtype='float64' ) + self.data = np.zeros( (2**self.depth, self.target_dim), dtype='float64' ) bins = self.apply_tests(points) bincount = np.bincount(bins, minlength=self.data.shape[0]) - self.data[self.apply_tests(points)] += responses + for dim in range(self.target_dim): + self.data[:,dim] += np.bincount(bins, weights=responses[:,dim], minlength=self.data.shape[0]) self.data[bincount>0] /= bincount[bincount>0][...,np.newaxis] else: self.n_classes = responses.max() + 1 diff --git a/randomferns_AxisAligned_regression.png b/randomferns_AxisAligned_regression.png new file mode 100644 index 0000000..c063422 Binary files /dev/null and b/randomferns_AxisAligned_regression.png differ diff --git a/randomferns_Conic_regression.png b/randomferns_Conic_regression.png new file mode 100644 index 0000000..470102e Binary files /dev/null and b/randomferns_Conic_regression.png differ diff --git a/randomferns_Linear_regression.png b/randomferns_Linear_regression.png new file mode 100644 index 0000000..08d513c Binary files /dev/null and b/randomferns_Linear_regression.png differ diff --git a/randomferns_Parabola_regression.png b/randomferns_Parabola_regression.png new file mode 100644 index 0000000..0b618cb Binary files /dev/null and b/randomferns_Parabola_regression.png differ