diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 4a6cfde0634cf..076aeaf5a987c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -362,46 +362,46 @@ object SparseMatrix { * @return The corresponding `SparseMatrix` */ def fromCOO(numRows: Int, numCols: Int, entries: Array[(Int, Int, Double)]): SparseMatrix = { + val numEntries = entries.size val sortedEntries = entries.sortBy(v => (v._2, v._1)) - val colPtrs = new Array[Int](numCols + 1) - var nnz = 0 - var lastCol = -1 - var lastIndex = -1 - sortedEntries.foreach { case (i, j, v) => - require(i >= 0 && j >= 0, "Negative indices given. Please make sure all indices are " + - s"greater than or equal to zero. i: $i, j: $j, value: $v") - if (v != 0.0) { - while (j != lastCol) { - colPtrs(lastCol + 1) = nnz - lastCol += 1 - } - val index = j * numRows + i - if (lastIndex != index) { - nnz += 1 - lastIndex = index - } + if (sortedEntries.nonEmpty) { + // Since the entries are sorted by column index, we only need to check the first and the last. + for (col <- Seq(sortedEntries.head._2, sortedEntries.last._2)) { + require(col >= 0 && col < numCols, s"Column index out of range [0, $numCols): $col.") } } - while (numCols > lastCol) { - colPtrs(lastCol + 1) = nnz - lastCol += 1 - } - val values = new Array[Double](nnz) - val rowIndices = new Array[Int](nnz) - lastIndex = -1 - var cnt = -1 - sortedEntries.foreach { case (i, j, v) => - if (v != 0.0) { - val index = j * numRows + i - if (lastIndex != index) { - cnt += 1 - lastIndex = index + val colPtrs = new Array[Int](numCols + 1) + val rowIndices = MArrayBuilder.make[Int] + rowIndices.sizeHint(numEntries) + val values = MArrayBuilder.make[Double] + values.sizeHint(numEntries) + var nnz = 0 + var prevCol = 0 + var prevRow = -1 + var prevVal = 0.0 + // Append a dummy entry to include the last one at the end of the loop. + (sortedEntries.view :+ (numRows, numCols, 1.0)).foreach { case (i, j, v) => + if (v != 0) { + if (i == prevRow && j == prevCol) { + prevVal += v + } else { + if (prevVal != 0) { + require(prevRow >= 0 && prevRow < numRows, + s"Row index out of range [0, $numRows): $prevRow.") + nnz += 1 + rowIndices += prevRow + values += prevVal + } + prevRow = i + prevVal = v + while (prevCol < j) { + colPtrs(prevCol + 1) = nnz + prevCol += 1 + } } - values(cnt) += v - rowIndices(cnt) = i } } - new SparseMatrix(numRows, numCols, colPtrs.toArray, rowIndices, values) + new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), values.result()) } /**