Skip to content

Commit

Permalink
simplify fromCOO implementation
Browse files (browse the repository at this point in the history)
  • Loading branch information
mengxr committed Dec 23, 2014
1 parent 10a63a6 commit 4e95e24
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
Original file line number | Diff line number | Diff line change
Expand Up @@ -362,46 +362,46 @@ object SparseMatrix {
* @return The corresponding `SparseMatrix`
*/
def fromCOO(numRows: Int, numCols: Int, entries: Array[(Int, Int, Double)]): SparseMatrix = {
val numEntries = entries.size
val sortedEntries = entries.sortBy(v => (v._2, v._1))
val colPtrs = new Array[Int](numCols + 1)
var nnz = 0
var lastCol = -1
var lastIndex = -1
sortedEntries.foreach { case (i, j, v) =>
require(i >= 0 && j >= 0, "Negative indices given. Please make sure all indices are " +
s"greater than or equal to zero. i: $i, j: $j, value: $v")
if (v != 0.0) {
while (j != lastCol) {
colPtrs(lastCol + 1) = nnz
lastCol += 1
}
val index = j * numRows + i
if (lastIndex != index) {
nnz += 1
lastIndex = index
}
if (sortedEntries.nonEmpty) {
// Since the entries are sorted by column index, we only need to check the first and the last.
for (col <- Seq(sortedEntries.head._2, sortedEntries.last._2)) {
require(col >= 0 && col < numCols, s"Column index out of range [0, $numCols): $col.")
}
}
while (numCols > lastCol) {
colPtrs(lastCol + 1) = nnz
lastCol += 1
}
val values = new Array[Double](nnz)
val rowIndices = new Array[Int](nnz)
lastIndex = -1
var cnt = -1
sortedEntries.foreach { case (i, j, v) =>
if (v != 0.0) {
val index = j * numRows + i
if (lastIndex != index) {
cnt += 1
lastIndex = index
val colPtrs = new Array[Int](numCols + 1)
val rowIndices = MArrayBuilder.make[Int]
rowIndices.sizeHint(numEntries)
val values = MArrayBuilder.make[Double]
values.sizeHint(numEntries)
var nnz = 0
var prevCol = 0
var prevRow = -1
var prevVal = 0.0
// Append a dummy entry to include the last one at the end of the loop.
(sortedEntries.view :+ (numRows, numCols, 1.0)).foreach { case (i, j, v) =>
if (v != 0) {
if (i == prevRow && j == prevCol) {
prevVal += v
} else {
if (prevVal != 0) {
require(prevRow >= 0 && prevRow < numRows,
s"Row index out of range [0, $numRows): $prevRow.")
nnz += 1
rowIndices += prevRow
values += prevVal
}
prevRow = i
prevVal = v
while (prevCol < j) {
colPtrs(prevCol + 1) = nnz
prevCol += 1
}
}
values(cnt) += v
rowIndices(cnt) = i
}
}
new SparseMatrix(numRows, numCols, colPtrs.toArray, rowIndices, values)
new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), values.result())
}

/**
Expand Down

0 comments on commit 4e95e24

Please sign in to comment.