Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Jun 2, 2021
1 parent 90e5f0a commit 15fa45d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SparseDatasetAggregator(columnParams: ColumnParams, chunkSize: Int,
@volatile var threadRowStartIndex = 0L
@volatile var threadInitScoreStartIndex = 0L
@volatile var threadIndexesStartIndex = 0L
@volatile var threadIndptrStartIndex = 0L
@volatile var threadIndptrStartIndex = 1L

def setNumCols(numCols: Int): Unit = {
this.numCols = numCols
Expand All @@ -57,6 +57,10 @@ class SparseDatasetAggregator(columnParams: ColumnParams, chunkSize: Int,
indptrCount: Long): Unit = {
if (synchronized) {
this.synchronized {
// Add extra 0 for start of indptr in parallel case
if (this.indptrCount == 0) {
this.indptrCount += 1
}
innerIncrementCount(rowCount, initScoreCount, indexesCount, indptrCount)
}
} else {
Expand Down Expand Up @@ -86,6 +90,7 @@ class SparseDatasetAggregator(columnParams: ColumnParams, chunkSize: Int,
indexesArray = Some(lightgbmlib.new_intArray(this.indexesCount))
valuesArray = Some(lightgbmlib.new_doubleArray(this.indexesCount))
indptrArray = Some(lightgbmlib.new_intArray(this.indptrCount))
lightgbmlib.intArray_setitem(indptrArray.get, 0, 0)
groupColumnValuesArray = new Array[Row](this.rowCount.toInt)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ object DatasetUtils {
var indptrChunkedArray = new int32ChunkedArray(chunkSize)
val groupColumnValues: ListBuffer[Row] = new ListBuffer[Row]()
var rowCount = 0
indptrChunkedArray.add(0)
if (!useSingleDataset) {
indptrChunkedArray.add(0)
}
while (rowsIter.hasNext) {
rowCount += 1
val row = rowsIter.next()
Expand Down Expand Up @@ -276,6 +278,21 @@ object DatasetUtils {
indexesChunkedArray, valuesChunkedArray, indptrChunkedArray, groupColumnValues)
}

/** Releases the native (SWIG-allocated) memory backing the sparse dataset buffers.
  *
  * Each chunked array wraps off-heap storage that the JVM garbage collector
  * cannot reclaim, so every buffer must be deleted explicitly once the
  * LightGBM dataset has been constructed from it.
  *
  * @param labelsChunkedArray label values buffer to free
  * @param weightChunkedArrayOpt optional per-row weight buffer to free
  * @param initScoreChunkedArrayOpt optional initial-score buffer to free
  * @param indexesChunkedArray CSR column-index buffer to free
  * @param valuesChunkedArray CSR non-zero-value buffer to free
  * @param indptrChunkedArray CSR row-pointer buffer to free
  */
def clearSparseArrays(labelsChunkedArray: floatChunkedArray,
                      weightChunkedArrayOpt: Option[floatChunkedArray],
                      initScoreChunkedArrayOpt: Option[doubleChunkedArray],
                      indexesChunkedArray: int32ChunkedArray,
                      valuesChunkedArray: doubleChunkedArray,
                      indptrChunkedArray: int32ChunkedArray): Unit = {
  // Mandatory buffers: always present, always freed.
  labelsChunkedArray.delete()
  indexesChunkedArray.delete()
  valuesChunkedArray.delete()
  indptrChunkedArray.delete()
  // Optional buffers: freed only when they were allocated.
  weightChunkedArrayOpt.foreach(arr => arr.delete())
  initScoreChunkedArrayOpt.foreach(arr => arr.delete())
}

def aggregateDenseStreamedData(rowsIter: Iterator[Row], columnParams: ColumnParams,
referenceDataset: Option[LightGBMDataset], schema: StructType,
log: Logger, trainParams: TrainParams): Option[LightGBMDataset] = {
Expand Down

0 comments on commit 15fa45d

Please sign in to comment.