diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 7e448b9e3af6f..a6f3343f56b86 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -363,7 +363,6 @@ object SparseMatrix {
     var i = 0
     var nnz = 0
     var lastCol = -1
-
     raw.foreach { v =>
       val r = i % numRows
       val c = (i - r) / numRows
@@ -378,7 +377,10 @@ object SparseMatrix {
       }
       i += 1
     }
-    sCols.append(sparseA.length)
+    while (numCols > lastCol){
+      sCols.append(sparseA.length)
+      lastCol += 1
+    }
     new SparseMatrix(numRows, numCols, sCols.toArray, sRows.toArray, sparseA.toArray)
   }
 
@@ -399,11 +401,11 @@ object SparseMatrix {
       s"0.0 < d < 1.0. Currently, density: $density")
     val rand = new XORShiftRandom(seed)
     val length = numRows * numCols
-    val rawA = Array.fill(length)(0.0)
+    val rawA = new Array[Double](length)
     var nnz = 0
     for (i <- 0 until length) {
       val p = rand.nextDouble()
-      if (p < density) {
+      if (p <= density) {
         rawA.update(i, rand.nextDouble())
         nnz += 1
       }
@@ -439,11 +441,11 @@ object SparseMatrix {
       s"0.0 < d < 1.0. Currently, density: $density")
     val rand = new XORShiftRandom(seed)
     val length = numRows * numCols
-    val rawA = Array.fill(length)(0.0)
+    val rawA = new Array[Double](length)
     var nnz = 0
     for (i <- 0 until length) {
       val p = rand.nextDouble()
-      if (p < density) {
+      if (p <= density) {
         rawA.update(i, rand.nextGaussian())
         nnz += 1
       }
@@ -476,7 +478,7 @@ object SparseMatrix {
         val values = sVec.values
         var i = 0
         var lastCol = -1
-        val colPtrs = new ArrayBuffer[Int](n)
+        val colPtrs = new ArrayBuffer[Int](n + 1)
         rows.foreach { r =>
           while (r != lastCol) {
             colPtrs.append(i)
@@ -484,13 +486,16 @@ object SparseMatrix {
           }
           i += 1
         }
-        colPtrs.append(n)
+        while (n > lastCol) {
+          colPtrs.append(i)
+          lastCol += 1
+        }
         new SparseMatrix(n, n, colPtrs.toArray, rows, values)
       case dVec: DenseVector =>
         val values = dVec.values
         var i = 0
         var nnz = 0
-        val sVals = values.filter( v => v != 0.0)
+        val sVals = values.filter(v => v != 0.0)
         var lastCol = -1
         val colPtrs = new ArrayBuffer[Int](n + 1)
         val sRows = new ArrayBuffer[Int](sVals.length)
@@ -687,10 +692,10 @@ object Matrices {
    * Horizontally concatenate a sequence of matrices. The returned matrix will be in the format
    * the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in
    * a dense matrix.
-   * @param matrices sequence of matrices
+   * @param matrices array of matrices
    * @return a single `Matrix` composed of the matrices that were horizontally concatenated
    */
-  private[mllib] def horzCat(matrices: Seq[Matrix]): Matrix = {
+  def horzcat(matrices: Array[Matrix]): Matrix = {
     if (matrices.size == 1) {
       return matrices(0)
     }
@@ -744,7 +749,7 @@ object Matrices {
-   * @param matrices sequence of matrices
-   * @return a single `Matrix` composed of the matrices that were horizontally concatenated
+   * @param matrices array of matrices
+   * @return a single `Matrix` composed of the matrices that were vertically concatenated
    */
-  private[mllib] def vertCat(matrices: Seq[Matrix]): Matrix = {
+  def vertcat(matrices: Array[Matrix]): Matrix = {
     if (matrices.size == 1) {
       return matrices(0)
     }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
new file mode 100644
index 0000000000000..e938071d5c3fb
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg;
+
+import static org.junit.Assert.*;
+import org.junit.Test;
+
+import java.io.Serializable;
+
+public class JavaMatricesSuite implements Serializable {
+
+  @Test
+  public void randMatrixConstruction() {
+    Matrix r = Matrices.rand(3, 4, 24);
+    DenseMatrix dr = DenseMatrix.rand(3, 4, 24);
+    assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+
+    Matrix rn = Matrices.randn(3, 4, 24);
+    DenseMatrix drn = DenseMatrix.randn(3, 4, 24);
+    assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
+
+    Matrix s = Matrices.sprand(3, 4, 0.5, 24);
+    SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, 24);
+    assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
+
+    Matrix sn = Matrices.sprandn(3, 4, 0.5, 24);
+    SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, 24);
+    assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
+  }
+
+  @Test
+  public void identityMatrixConstruction() {
+    Matrix r = Matrices.eye(2);
+    DenseMatrix dr = DenseMatrix.eye(2);
+    SparseMatrix sr = SparseMatrix.speye(2);
+    assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+    assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
+    assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
+  }
+
+  @Test
+  public void diagonalMatrixConstruction() {
+    Vector v = Vectors.dense(1.0, 0.0, 2.0);
+    Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
+
+    Matrix m = Matrices.diag(v);
+    Matrix sm = Matrices.diag(sv);
+    DenseMatrix d = DenseMatrix.diag(v);
+    DenseMatrix sd = DenseMatrix.diag(sv);
+    SparseMatrix s = SparseMatrix.diag(v);
+    SparseMatrix ss = SparseMatrix.diag(sv);
+
+    assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
+    assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
+    assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
+    assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
+    assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
+    assertArrayEquals(s.values(), ss.values(), 0.0);
+    assertEquals(2, s.values().length);
+    assertEquals(2, ss.values().length);
+    // Column pointers hold numCols + 1 entries, so the 3x3 diagonal matrix has 4.
+    assertEquals(4, s.colPtrs().length);
+    assertEquals(4, ss.colPtrs().length);
+  }
+
+  @Test
+  public void zerosMatrixConstruction() {
+    Matrix z = Matrices.zeros(2, 2);
+    Matrix one = Matrices.ones(2, 2);
+    DenseMatrix dz = DenseMatrix.zeros(2, 2);
+    DenseMatrix done = DenseMatrix.ones(2, 2);
+
+    assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+    assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+    assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+    assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+  }
+
+  @Test
+  public void concatenateMatrices() {
+    int m = 3;
+    int n = 2;
+
+    SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, 42);
+    DenseMatrix deMat1 = DenseMatrix.rand(m, n, 42);
+    Matrix deMat2 = Matrices.eye(3);
+    Matrix spMat2 = Matrices.speye(3);
+    Matrix deMat3 = Matrices.eye(2);
+    Matrix spMat3 = Matrices.speye(2);
+
+    Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
+    Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
+    Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
+    Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
+
+    assertEquals(3, deHorz1.numRows());
+    assertEquals(3, deHorz2.numRows());
+    assertEquals(3, deHorz3.numRows());
+    assertEquals(3, spHorz.numRows());
+    assertEquals(5, deHorz1.numCols());
+    assertEquals(5, deHorz2.numCols());
+    assertEquals(5, deHorz3.numCols());
+    assertEquals(5, spHorz.numCols());
+
+    Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
+    Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
+    Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
+    Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
+
+    assertEquals(5, deVert1.numRows());
+    assertEquals(5, deVert2.numRows());
+    assertEquals(5, deVert3.numRows());
+    assertEquals(5, spVert.numRows());
+    assertEquals(2, deVert1.numCols());
+    assertEquals(2, deVert2.numCols());
+    assertEquals(2, deVert3.numCols());
+    assertEquals(2, spVert.numCols());
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 2793e9aaef86d..ef6c4a3974bf3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -132,7 +132,7 @@ class MatricesSuite extends FunSuite {
     assert(deMat1.toArray === deMat2.toArray)
   }
 
-  test("horzCat, vertCat, eye, speye") {
+  test("horzcat, vertcat, eye, speye") {
     val m = 3
     val n = 2
     val values = Array(1.0, 2.0, 4.0, 5.0)
@@ -147,10 +147,10 @@ class MatricesSuite extends FunSuite {
     val deMat3 = Matrices.eye(2)
     val spMat3 = Matrices.speye(2)
 
-    val spHorz = Matrices.horzCat(Seq(spMat1, spMat2))
-    val deHorz1 = Matrices.horzCat(Seq(deMat1, deMat2))
-    val deHorz2 = Matrices.horzCat(Seq(spMat1, deMat2))
-    val deHorz3 = Matrices.horzCat(Seq(deMat1, spMat2))
+    val spHorz = Matrices.horzcat(Array(spMat1, spMat2))
+    val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2))
+    val deHorz2 = Matrices.horzcat(Array(spMat1, deMat2))
+    val deHorz3 = Matrices.horzcat(Array(deMat1, spMat2))
 
     assert(deHorz1.numRows === 3)
     assert(deHorz2.numRows === 3)
@@ -179,17 +179,17 @@ class MatricesSuite extends FunSuite {
     assert(deHorz1(1, 4) === 0.0)
 
     intercept[IllegalArgumentException] {
-      Matrices.horzCat(Seq(spMat1, spMat3))
+      Matrices.horzcat(Array(spMat1, spMat3))
     }
 
     intercept[IllegalArgumentException] {
-      Matrices.horzCat(Seq(deMat1, spMat3))
+      Matrices.horzcat(Array(deMat1, spMat3))
     }
 
-    val spVert = Matrices.vertCat(Seq(spMat1, spMat3))
-    val deVert1 = Matrices.vertCat(Seq(deMat1, deMat3))
-    val deVert2 = Matrices.vertCat(Seq(spMat1, deMat3))
-    val deVert3 = Matrices.vertCat(Seq(deMat1, spMat3))
+    val spVert = Matrices.vertcat(Array(spMat1, spMat3))
+    val deVert1 = Matrices.vertcat(Array(deMat1, deMat3))
+    val deVert2 = Matrices.vertcat(Array(spMat1, deMat3))
+    val deVert3 = Matrices.vertcat(Array(deMat1, spMat3))
 
     assert(deVert1.numRows === 5)
     assert(deVert2.numRows === 5)
@@ -214,11 +214,11 @@ class MatricesSuite extends FunSuite {
     assert(deVert1(4, 1) === 1.0)
 
     intercept[IllegalArgumentException] {
-      Matrices.vertCat(Seq(spMat1, spMat2))
+      Matrices.vertcat(Array(spMat1, spMat2))
     }
 
     intercept[IllegalArgumentException] {
-      Matrices.vertCat(Seq(deMat1, spMat2))
+      Matrices.vertcat(Array(deMat1, spMat2))
     }
   }
 }