Skip to content

Commit

Permalink
Merge pull request #2980 from rcurtin/dataset-info-resize-fix
Browse files Browse the repository at this point in the history
LoadCSV: only reset the DatasetMapper if the dimensionality is wrong.
  • Loading branch information
zoq committed Jun 20, 2021
2 parents 1313b4e + ae17302 commit a888b56
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@

* Fix Julia model serialization bug (#2970).

* Fix `LoadCSV()` to use pre-populated `DatasetInfo` objects (#2980).

### mlpack 3.4.2
###### 2020-10-26
* Added Mean Absolute Percentage Error.
Expand Down
8 changes: 6 additions & 2 deletions src/mlpack/core/data/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,12 @@ bool Load(const std::string& filename,
* mlpack requires column-major matrices, this should be left at its default
* value of 'true'.
*
* The DatasetMapper object passed to this function will be re-created, so any
* mappings from previous loads will be lost.
* If the given `info` has already been used with a different `data::Load()`
* call where the dataset has the same dimensionality, then the mappings and
* dimension types inside of `info` will be *re-used*. If the given `info` is a
* new `DatasetMapper` object (e.g. its dimensionality is 0), then new mappings
* will be created. If the given `info` has a different dimensionality of data
* than what is present in `filename`, an exception will be thrown.
*
* @param filename Name of file to load.
* @param matrix Matrix to load contents of file into.
Expand Down
30 changes: 27 additions & 3 deletions src/mlpack/core/data/load_csv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,20 @@ class LoadCSV
{
++rows;
}
info = DatasetMapper<MapPolicy>(rows);

// Reset the DatasetInfo object, if needed.
if (info.Dimensionality() == 0)
{
info.SetDimensionality(rows);
}
else if (info.Dimensionality() != rows)
{
std::ostringstream oss;
oss << "data::LoadCSV(): given DatasetInfo has dimensionality "
<< info.Dimensionality() << ", but data has dimensionality "
<< rows;
throw std::invalid_argument(oss.str());
}

// Now, jump back to the beginning of the file.
inFile.clear();
Expand Down Expand Up @@ -179,8 +192,19 @@ class LoadCSV
qi::parse(line.begin(), line.end(),
stringRule[findRowSize] % delimiterRule);

// Now that we know the dimensionality, initialize the DatasetMapper.
info.SetDimensionality(rows);
// Reset the DatasetInfo object, if needed.
if (info.Dimensionality() == 0)
{
info.SetDimensionality(rows);
}
else if (info.Dimensionality() != rows)
{
std::ostringstream oss;
oss << "data::LoadCSV(): given DatasetInfo has dimensionality "
<< info.Dimensionality() << ", but data has dimensionality "
<< rows;
throw std::invalid_argument(oss.str());
}
}

// If we need to do a first pass for the DatasetMapper, do it.
Expand Down

0 comments on commit a888b56

Please sign in to comment.