diff --git a/lib/src/matrix/base_matrix.dart b/lib/src/matrix/base_matrix.dart index 48157191..fad2ef08 100644 --- a/lib/src/matrix/base_matrix.dart +++ b/lib/src/matrix/base_matrix.dart @@ -3,7 +3,7 @@ import 'dart:math' as math; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/matrix_norm.dart'; -import 'package:ml_linalg/src/matrix/byte_data_storage/data_manager.dart'; +import 'package:ml_linalg/src/matrix/data_manager/data_manager.dart'; import 'package:ml_linalg/src/matrix/matrix_validator_mixin.dart'; import 'package:ml_linalg/vector.dart'; import 'package:xrange/zrange.dart'; diff --git a/lib/src/matrix/byte_data_storage/data_manager.dart b/lib/src/matrix/data_manager/data_manager.dart similarity index 89% rename from lib/src/matrix/byte_data_storage/data_manager.dart rename to lib/src/matrix/data_manager/data_manager.dart index 281b3be0..99e25a52 100644 --- a/lib/src/matrix/byte_data_storage/data_manager.dart +++ b/lib/src/matrix/data_manager/data_manager.dart @@ -4,8 +4,6 @@ abstract class DataManager { int get rowsNum; int get columnsNum; Iterator> get dataIterator; - List get rowsCache; - List get columnsCache; Vector getColumn(int index, {bool tryCache = true, bool mutable = false}); void setColumn(int columnNum, Iterable columnValues); Vector getRow(int index, {bool tryCache = true, bool mutable = false}); diff --git a/lib/src/matrix/byte_data_storage/float32_data_manager.dart b/lib/src/matrix/data_manager/data_manager_impl.dart similarity index 68% rename from lib/src/matrix/byte_data_storage/float32_data_manager.dart rename to lib/src/matrix/data_manager/data_manager_impl.dart index a44d80b2..316f2696 100644 --- a/lib/src/matrix/byte_data_storage/float32_data_manager.dart +++ b/lib/src/matrix/data_manager/data_manager_impl.dart @@ -1,63 +1,71 @@ import 'dart:typed_data'; -import 'package:ml_linalg/src/matrix/byte_data_storage/data_manager.dart'; +import 'package:ml_linalg/src/matrix/data_manager/data_manager.dart'; import 'package:ml_linalg/src/matrix/float32x4/float32_matrix_iterator.dart'; import 'package:ml_linalg/vector.dart'; -class Float32DataManager implements DataManager { - Float32DataManager.from( +class DataManagerImpl implements DataManager { + DataManagerImpl.from( Iterable> source, int bytesPerElement, + this._dtype, ) : rowsNum = source.length, columnsNum = source.first.length, - rowsCache = List(source.length), - columnsCache = List(source.first.length), + _rowsCache = List(source.length), + _columnsCache = List(source.first.length), _data = ByteData(source.length * source.first.length * - bytesPerElement) { + bytesPerElement), + _bytesPerElement = bytesPerElement { final flattened = _flatten2dimList(source, (i, j) => i * columnsNum + j); updateAll(0, flattened); } - Float32DataManager.fromRows( + DataManagerImpl.fromRows( Iterable source, int bytesPerElement, + this._dtype, ) : rowsNum = source.length, columnsNum = source.first.length, - rowsCache = source.toList(growable: false), - columnsCache = List(source.first.length), + _rowsCache = source.toList(growable: false), + _columnsCache = List(source.first.length), _data = ByteData(source.length * source.first.length * - bytesPerElement) { + bytesPerElement), + _bytesPerElement = bytesPerElement { final flattened = _flatten2dimList(source, (i, j) => i * columnsNum + j); updateAll(0, flattened); } - Float32DataManager.fromColumns( + DataManagerImpl.fromColumns( Iterable source, int bytesPerElement, + this._dtype, ) : rowsNum = source.first.length, columnsNum = source.length, - rowsCache = List(source.first.length), - columnsCache = source.toList(growable: false), + _rowsCache = List(source.first.length), + _columnsCache = source.toList(growable: false), _data = ByteData(source.length * source.first.length * - bytesPerElement) { + bytesPerElement), + _bytesPerElement = bytesPerElement { final flattened = _flatten2dimList(source, (i, j) => j * columnsNum + i); updateAll(0, flattened); } - Float32DataManager.fromFlattened( + DataManagerImpl.fromFlattened( Iterable source, int rowsNum, int colsNum, int bytesPerElement, + this._dtype, ) : rowsNum = rowsNum, columnsNum = colsNum, - rowsCache = List(rowsNum), - columnsCache = List(colsNum), - _data = ByteData(rowsNum * colsNum * bytesPerElement) { + _rowsCache = List(rowsNum), + _columnsCache = List(colsNum), + _data = ByteData(rowsNum * colsNum * bytesPerElement), + _bytesPerElement = bytesPerElement { if (source.length != rowsNum * colsNum) { throw Exception('Invalid matrix dimension has been provided - ' '$rowsNum x $colsNum, but given a collection of length ' @@ -72,13 +80,11 @@ class Float32DataManager implements DataManager { @override final int rowsNum; - @override - final List rowsCache; - - @override - final List columnsCache; - + final List _rowsCache; + final List _columnsCache; + final int _bytesPerElement; final ByteData _data; + final Type _dtype; @override Iterator> get dataIterator => @@ -87,12 +93,12 @@ class Float32DataManager implements DataManager { //TODO consider a check if the index is inside the _data @override Float32List getValues(int index, int length) => - _data.buffer.asFloat32List(index * Float32List.bytesPerElement, length); + _data.buffer.asFloat32List(index * _bytesPerElement, length); //TODO consider a check if the index is inside the _data @override void update(int idx, double value) => - _data.setFloat32(idx * Float32List.bytesPerElement, value, Endian.host); + _data.setFloat32(idx * _bytesPerElement, value, Endian.host); @override void updateAll(int idx, Iterable values) { @@ -102,30 +108,30 @@ class Float32DataManager implements DataManager { @override Vector getRow(int index, {bool tryCache = true, bool mutable = false}) { if (tryCache) { - rowsCache[index] ??= Vector.from(getValues(index * columnsNum, - columnsNum), isMutable: mutable, dtype: Float32x4); - return rowsCache[index]; + _rowsCache[index] ??= Vector.from(getValues(index * columnsNum, + columnsNum), isMutable: mutable, dtype: _dtype); + return _rowsCache[index]; } else { return Vector.from(getValues(index * columnsNum, columnsNum), - isMutable: mutable, dtype: Float32x4); + isMutable: mutable, dtype: _dtype); } } @override Vector getColumn(int index, {bool tryCache = true, bool mutable = false}) { - if (columnsCache[index] == null || !tryCache) { + if (_columnsCache[index] == null || !tryCache) { final result = List(rowsNum); for (int i = 0; i < rowsNum; i++) { //@TODO: find a more efficient way to get the single value result[i] = getValues(i * columnsNum + index, 1).first; } - final column = Vector.from(result, isMutable: mutable, dtype: Float32x4); + final column = Vector.from(result, isMutable: mutable, dtype: _dtype); if (!tryCache) { return column; } - columnsCache[index] = column; + _columnsCache[index] = column; } - return columnsCache[index]; + return _columnsCache[index]; } @override @@ -139,8 +145,8 @@ class Float32DataManager implements DataManager { 'matrix rows number is $rowsNum'); } // clear rows cache - rowsCache.fillRange(0, rowsNum, null); - columnsCache[columnNum] = columnValues is Vector + _rowsCache.fillRange(0, rowsNum, null); + _columnsCache[columnNum] = columnValues is Vector ? columnValues : Vector.from(columnValues); final values = columnValues.toList(growable: false); for (int i = 0, j = 0; i < rowsNum * columnsNum; i++) { diff --git a/lib/src/matrix/float32x4/float32x4_matrix.dart b/lib/src/matrix/float32x4/float32x4_matrix.dart index 31f7b572..984e2b93 100644 --- a/lib/src/matrix/float32x4/float32x4_matrix.dart +++ b/lib/src/matrix/float32x4/float32x4_matrix.dart @@ -1,25 +1,27 @@ import 'dart:core'; import 'dart:typed_data'; -import 'package:ml_linalg/src/matrix/byte_data_storage/float32_data_manager.dart'; import 'package:ml_linalg/src/matrix/base_matrix.dart'; +import 'package:ml_linalg/src/matrix/data_manager/data_manager_impl.dart'; import 'package:ml_linalg/vector.dart'; class Float32x4Matrix extends BaseMatrix { Float32x4Matrix.from(Iterable> source) : - super(Float32DataManager.from(source, Float32List.bytesPerElement)); + super(DataManagerImpl.from(source, Float32List.bytesPerElement, + Float32x4)); Float32x4Matrix.columns(Iterable source) : - super(Float32DataManager - .fromColumns(source, Float32List.bytesPerElement)); + super(DataManagerImpl + .fromColumns(source, Float32List.bytesPerElement, Float32x4)); Float32x4Matrix.rows(Iterable source) : - super(Float32DataManager.fromRows(source, Float32List.bytesPerElement)); + super(DataManagerImpl.fromRows(source, Float32List.bytesPerElement, + Float32x4)); Float32x4Matrix.flattened(Iterable source, int rowsNum, int columnsNum) : - super(Float32DataManager.fromFlattened(source, rowsNum, columnsNum, - Float32List.bytesPerElement)); + super(DataManagerImpl.fromFlattened(source, rowsNum, columnsNum, + Float32List.bytesPerElement, Float32x4)); @override final Type dtype = Float32x4;