diff --git a/lib/matrix.dart b/lib/matrix.dart index 65ac0106..60b60f96 100644 --- a/lib/matrix.dart +++ b/lib/matrix.dart @@ -4,11 +4,16 @@ import 'package:ml_linalg/axis.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix_norm.dart'; import 'package:ml_linalg/sort_direction.dart'; +import 'package:ml_linalg/src/common/cache_manager/cache_manager_factory.dart'; +import 'package:ml_linalg/src/di/dependencies.dart'; import 'package:ml_linalg/src/matrix/data_manager/float64_matrix_data_manager.dart'; +import 'package:ml_linalg/src/matrix/matrix_cache_keys.dart'; import 'package:ml_linalg/src/matrix/matrix_impl.dart'; import 'package:ml_linalg/src/matrix/data_manager/float32_matrix_data_manager.dart'; import 'package:ml_linalg/vector.dart'; +final _cacheManagerFactory = dependencies.getDependency(); + /// An algebraic matrix with extended functionality, adapted for data science /// applications abstract class Matrix implements Iterable> { @@ -45,10 +50,16 @@ abstract class Matrix implements Iterable> { {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.fromList(source)); + return MatrixImpl( + Float32MatrixDataManager.fromList(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.fromList(source)); + return MatrixImpl( + Float64MatrixDataManager.fromList(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -87,10 +98,16 @@ abstract class Matrix implements Iterable> { factory Matrix.fromRows(List source, {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.fromRows(source)); + return MatrixImpl( + Float32MatrixDataManager.fromRows(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.fromRows(source)); + return MatrixImpl( + Float64MatrixDataManager.fromRows(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -133,10 +150,16 @@ abstract class Matrix implements Iterable> { {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.fromColumns(source)); + return MatrixImpl( + Float32MatrixDataManager.fromColumns(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.fromColumns(source)); + return MatrixImpl( + Float64MatrixDataManager.fromColumns(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -165,10 +188,16 @@ abstract class Matrix implements Iterable> { factory Matrix.empty({DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.fromList([])); + return MatrixImpl( + Float32MatrixDataManager.fromList([]), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.fromList([])); + return MatrixImpl( + Float64MatrixDataManager.fromList([]), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -203,12 +232,16 @@ abstract class Matrix implements Iterable> { int columnsNum, {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.fromFlattened( - source, rowsNum, columnsNum)); + return MatrixImpl( + Float32MatrixDataManager.fromFlattened(source, rowsNum, columnsNum), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.fromFlattened( - source, rowsNum, columnsNum)); + return MatrixImpl( + Float64MatrixDataManager.fromFlattened(source, rowsNum, columnsNum), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -241,10 +274,16 @@ abstract class Matrix implements Iterable> { factory Matrix.diagonal(List source, {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.diagonal(source)); + return MatrixImpl( + Float32MatrixDataManager.diagonal(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.diagonal(source)); + return MatrixImpl( + Float64MatrixDataManager.diagonal(source), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -277,10 +316,16 @@ abstract class Matrix implements Iterable> { factory Matrix.scalar(double scalar, int size, {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.scalar(scalar, size)); + return MatrixImpl( + Float32MatrixDataManager.scalar(scalar, size), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.scalar(scalar, size)); + return MatrixImpl( + Float64MatrixDataManager.scalar(scalar, size), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -313,10 +358,16 @@ abstract class Matrix implements Iterable> { factory Matrix.identity(int size, {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.scalar(1.0, size)); + return MatrixImpl( + Float32MatrixDataManager.scalar(1.0, size), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.scalar(1.0, size)); + return MatrixImpl( + Float64MatrixDataManager.scalar(1.0, size), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError('Matrix of type $dtype is not implemented yet'); @@ -344,14 +395,18 @@ abstract class Matrix implements Iterable> { factory Matrix.row(List source, {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.fromRows( - [Vector.fromList(source, dtype: dtype)], - )); + return MatrixImpl( + Float32MatrixDataManager + .fromRows([Vector.fromList(source, dtype: dtype)]), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.fromRows( - [Vector.fromList(source, dtype: dtype)], - )); + return MatrixImpl( + Float64MatrixDataManager + .fromRows([Vector.fromList(source, dtype: dtype)]), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError( @@ -384,14 +439,18 @@ abstract class Matrix implements Iterable> { factory Matrix.column(List source, {DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return MatrixImpl(Float32MatrixDataManager.fromColumns(( - [Vector.fromList(source, dtype: dtype)]), - )); + return MatrixImpl( + Float32MatrixDataManager + .fromColumns(([Vector.fromList(source, dtype: dtype)])), + _cacheManagerFactory.create(matrixCacheKeys), + ); case DType.float64: - return MatrixImpl(Float64MatrixDataManager.fromColumns(( - [Vector.fromList(source, dtype: dtype)]), - )); + return MatrixImpl( + Float64MatrixDataManager + .fromColumns(([Vector.fromList(source, dtype: dtype)])), + _cacheManagerFactory.create(matrixCacheKeys), + ); default: throw UnimplementedError( diff --git a/lib/src/common/cache_manager/cache_manager.dart b/lib/src/common/cache_manager/cache_manager.dart index 56816095..5aad3998 100644 --- a/lib/src/common/cache_manager/cache_manager.dart +++ b/lib/src/common/cache_manager/cache_manager.dart @@ -1,5 +1,5 @@ abstract class CacheManager { - T retrieveValue(String cachedValueName, T calculateIfAbsent(), { + T retrieveValue(String key, T calculateIfAbsent(), { bool skipCaching, }); diff --git a/lib/src/common/cache_manager/cache_manager_factory.dart b/lib/src/common/cache_manager/cache_manager_factory.dart index 36509595..b932d30a 100644 --- a/lib/src/common/cache_manager/cache_manager_factory.dart +++ b/lib/src/common/cache_manager/cache_manager_factory.dart @@ -1,5 +1,5 @@ import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart'; abstract class CacheManagerFactory { - CacheManager create(); + CacheManager create(Set keys); } diff --git a/lib/src/common/cache_manager/cache_manager_factory_impl.dart b/lib/src/common/cache_manager/cache_manager_factory_impl.dart index fe98588a..3c173e19 100644 --- a/lib/src/common/cache_manager/cache_manager_factory_impl.dart +++ b/lib/src/common/cache_manager/cache_manager_factory_impl.dart @@ -6,5 +6,5 @@ class CacheManagerFactoryImpl implements CacheManagerFactory { const CacheManagerFactoryImpl(); @override - CacheManager create() => CacheManagerImpl(); + CacheManager create(Set keys) => CacheManagerImpl(keys); } diff --git a/lib/src/common/cache_manager/cache_manager_impl.dart b/lib/src/common/cache_manager/cache_manager_impl.dart index aa47e914..875e6e35 100644 --- a/lib/src/common/cache_manager/cache_manager_impl.dart +++ b/lib/src/common/cache_manager/cache_manager_impl.dart @@ -1,19 +1,26 @@ import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart'; class CacheManagerImpl implements CacheManager { + CacheManagerImpl(this._keys); + final _cache = {}; + final Set _keys; @override - T retrieveValue(String cachedValueName, T Function() calculateIfAbsent, { + T retrieveValue(String key, T Function() calculateIfAbsent, { bool skipCaching = false, }) { - var value = _cache[cachedValueName] as T; + if (!_keys.contains(key)) { + throw Exception('Cache key `$key` is not registered'); + } + + var value = _cache[key] as T; if (value == null) { value = calculateIfAbsent(); if (!skipCaching) { - _cache[cachedValueName] = value; + _cache[key] = value; } } diff --git a/lib/src/matrix/matrix_cache_keys.dart b/lib/src/matrix/matrix_cache_keys.dart new file mode 100644 index 00000000..87b33edb --- /dev/null +++ b/lib/src/matrix/matrix_cache_keys.dart @@ -0,0 +1,11 @@ +const meansByRowsKey = 'means_by_rows'; +const meansByColumnsKey = 'means_by_columns'; +const deviationByRowsKey = 'deviation_by_rows'; +const deviationByColumnsKey = 'deviation_by_columns'; + +final matrixCacheKeys = Set.from([ + meansByRowsKey, + meansByColumnsKey, + deviationByRowsKey, + deviationByColumnsKey, +]); diff --git a/lib/src/matrix/matrix_impl.dart b/lib/src/matrix/matrix_impl.dart index fbf5c712..127f77e4 100644 --- a/lib/src/matrix/matrix_impl.dart +++ b/lib/src/matrix/matrix_impl.dart @@ -6,7 +6,9 @@ import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/matrix_norm.dart'; import 'package:ml_linalg/sort_direction.dart'; +import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart'; import 'package:ml_linalg/src/matrix/data_manager/matrix_data_manager.dart'; +import 'package:ml_linalg/src/matrix/matrix_cache_keys.dart'; import 'package:ml_linalg/src/matrix/mixin/matrix_validator_mixin.dart'; import 'package:ml_linalg/vector.dart'; import 'package:quiver/iterables.dart'; @@ -14,11 +16,10 @@ import 'package:quiver/iterables.dart'; class MatrixImpl with IterableMixin>, MatrixValidatorMixin implements Matrix { - MatrixImpl(this._dataManager); + MatrixImpl(this._dataManager, this._cacheManager); final MatrixDataManager _dataManager; - final Map _meansCache = {}; - final Map _deviationCache = {}; + final CacheManager _cacheManager; @override DType get dtype => _dataManager.dtype; @@ -172,10 +173,12 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin switch (axis) { case Axis.columns: - return _meansCache[axis] ??= _mean(columns); + return _cacheManager + .retrieveValue(meansByColumnsKey, () => _mean(columns)); case Axis.rows: - return _meansCache[axis] ??= _mean(rows); + return _cacheManager + .retrieveValue(meansByRowsKey, () => _mean(rows)); default: throw UnimplementedError('Mean values calculation for axis $axis is not ' @@ -198,10 +201,12 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin switch (axis) { case Axis.columns: - return _deviationCache[axis] ??= _deviation(rows, means, rowsNum); + return _cacheManager.retrieveValue(deviationByColumnsKey, + () => _deviation(rows, means, rowsNum)); case Axis.rows: - return _deviationCache[axis] ??= _deviation(columns, means, columnsNum); + return _cacheManager.retrieveValue(deviationByRowsKey, + () => _deviation(columns, means, columnsNum)); default: throw UnimplementedError('Deviation calculation for axis $axis is not ' diff --git a/lib/src/vector/float32x4_vector.dart b/lib/src/vector/float32x4_vector.dart index b0342d7d..247171e1 100644 --- a/lib/src/vector/float32x4_vector.dart +++ b/lib/src/vector/float32x4_vector.dart @@ -8,6 +8,7 @@ import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/norm.dart'; import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart'; import 'package:ml_linalg/src/vector/simd_helper/simd_helper.dart'; +import 'package:ml_linalg/src/vector/vector_cache_keys.dart'; import 'package:ml_linalg/vector.dart'; const _bytesPerElement = Float32List.bytesPerElement; @@ -119,7 +120,7 @@ class Float32x4Vector with IterableMixin implements Vector { } @override - int get hashCode => _cacheManager.retrieveValue('hash', () { + int get hashCode => _cacheManager.retrieveValue(hashKey, () { if (isEmpty) { return 0; } @@ -235,7 +236,7 @@ class Float32x4Vector with IterableMixin implements Vector { @override Vector sqrt({bool skipCaching = false}) => - _cacheManager.retrieveValue('sqrt', () { + _cacheManager.retrieveValue(sqrtKey, () { final source = Float32x4List(_numOfBuckets); for (int i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].sqrt(); @@ -251,7 +252,7 @@ class Float32x4Vector with IterableMixin implements Vector { @override Vector abs({bool skipCaching = false}) => - _cacheManager.retrieveValue('abs', () { + _cacheManager.retrieveValue(absKey, () { final source = Float32x4List(_numOfBuckets); for (int i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].abs(); @@ -263,7 +264,7 @@ class Float32x4Vector with IterableMixin implements Vector { double dot(Vector vector) => (this * vector).sum(); @override - double sum({bool skipCaching = false}) => _cacheManager.retrieveValue('sum', + double sum({bool skipCaching = false}) => _cacheManager.retrieveValue(sumKey, () => _simdHelper.sumLanes(_innerSimdList.reduce((a, b) => a + b)), skipCaching: skipCaching); @@ -299,14 +300,14 @@ class Float32x4Vector with IterableMixin implements Vector { if (isEmpty) { throw _emptyVectorException; } - return _cacheManager.retrieveValue('mean', () => sum() / length, + return _cacheManager.retrieveValue(meanKey, () => sum() / length, skipCaching: skipCaching); } @override - double norm([Norm norm = Norm.euclidean, bool skipCaching = false]) => - _cacheManager.retrieveValue('norm_$norm', () { - final power = _getPowerByNormType(norm); + double norm([Norm normType = Norm.euclidean, bool skipCaching = false]) => + _cacheManager.retrieveValue(getCacheKeyForNormByNormType(normType), () { + final power = _getPowerByNormType(normType); if (power == 1) { return abs().sum(); } @@ -315,14 +316,16 @@ class Float32x4Vector with IterableMixin implements Vector { }, skipCaching: skipCaching); @override - double max({bool skipCaching = false}) => _cacheManager.retrieveValue('max', () => - _findExtrema(-double.infinity, _simdHelper.getMaxLane, - (a, b) => a.max(b), math.max), skipCaching: skipCaching); + double max({bool skipCaching = false}) => + _cacheManager.retrieveValue(maxKey, () => + _findExtrema(-double.infinity, _simdHelper.getMaxLane, + (a, b) => a.max(b), math.max), skipCaching: skipCaching); @override - double min({bool skipCaching = false}) => _cacheManager.retrieveValue('min', () => - _findExtrema(double.infinity, _simdHelper.getMinLane, - (a, b) => a.min(b), math.min), skipCaching: skipCaching); + double min({bool skipCaching = false}) => + _cacheManager.retrieveValue(minKey, () => + _findExtrema(double.infinity, _simdHelper.getMinLane, + (a, b) => a.min(b), math.min), skipCaching: skipCaching); double _findExtrema(double initialValue, double getExtremalLane(Float32x4 bucket), @@ -355,7 +358,7 @@ class Float32x4Vector with IterableMixin implements Vector { @override Vector unique({bool skipCaching = false}) => - _cacheManager.retrieveValue('unique', () => Vector.fromList( + _cacheManager.retrieveValue(uniqueKey, () => Vector.fromList( Set.from(this).toList(growable: false), dtype: dtype), skipCaching: skipCaching); @@ -395,18 +398,19 @@ class Float32x4Vector with IterableMixin implements Vector { } final limit = end == null || end > length ? length : end; final collection = _innerTypedList.sublist(start, limit); + return Vector.fromList(collection, dtype: dtype); } @override Vector normalize([Norm normType = Norm.euclidean, bool skipCaching = false]) => - _cacheManager.retrieveValue('normalize_$normType', + _cacheManager.retrieveValue(getCacheKeyForNormalizeByNormType(normType), () => this / norm(normType), skipCaching: skipCaching); @override Vector rescale({bool skipCaching = false}) => - _cacheManager.retrieveValue('rescale', () { + _cacheManager.retrieveValue(rescaleKey, () { final minValue = min(); final maxValue = max(); return (this - minValue) / (maxValue - minValue); diff --git a/lib/src/vector/float64x2_vector.dart b/lib/src/vector/float64x2_vector.dart index 54adf6e4..73e45414 100644 --- a/lib/src/vector/float64x2_vector.dart +++ b/lib/src/vector/float64x2_vector.dart @@ -10,6 +10,7 @@ import 'package:ml_linalg/matrix.dart'; import 'package:ml_linalg/norm.dart'; import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart'; import 'package:ml_linalg/src/vector/simd_helper/simd_helper.dart'; +import 'package:ml_linalg/src/vector/vector_cache_keys.dart'; import 'package:ml_linalg/vector.dart'; const _bytesPerElement = Float64List.bytesPerElement; @@ -121,7 +122,7 @@ class Float64x2Vector with IterableMixin implements Vector { } @override - int get hashCode => _cacheManager.retrieveValue('hash', () { + int get hashCode => _cacheManager.retrieveValue(hashKey, () { if (isEmpty) { return 0; } @@ -237,7 +238,7 @@ class Float64x2Vector with IterableMixin implements Vector { @override Vector sqrt({bool skipCaching = false}) => - _cacheManager.retrieveValue('sqrt', () { + _cacheManager.retrieveValue(sqrtKey, () { final source = Float64x2List(_numOfBuckets); for (int i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].sqrt(); @@ -253,7 +254,7 @@ class Float64x2Vector with IterableMixin implements Vector { @override Vector abs({bool skipCaching = false}) => - _cacheManager.retrieveValue('abs', () { + _cacheManager.retrieveValue(absKey, () { final source = Float64x2List(_numOfBuckets); for (int i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].abs(); @@ -265,7 +266,7 @@ class Float64x2Vector with IterableMixin implements Vector { double dot(Vector vector) => (this * vector).sum(); @override - double sum({bool skipCaching = false}) => _cacheManager.retrieveValue('sum', + double sum({bool skipCaching = false}) => _cacheManager.retrieveValue(sumKey, () => _simdHelper.sumLanes(_innerSimdList.reduce((a, b) => a + b)), skipCaching: skipCaching); @@ -301,14 +302,14 @@ class Float64x2Vector with IterableMixin implements Vector { if (isEmpty) { throw _emptyVectorException; } - return _cacheManager.retrieveValue('mean', () => sum() / length, + return _cacheManager.retrieveValue(meanKey, () => sum() / length, skipCaching: skipCaching); } @override - double norm([Norm norm = Norm.euclidean, bool skipCaching = false]) => - _cacheManager.retrieveValue('norm_$norm', () { - final power = _getPowerByNormType(norm); + double norm([Norm normType = Norm.euclidean, bool skipCaching = false]) => + _cacheManager.retrieveValue(getCacheKeyForNormByNormType(normType), () { + final power = _getPowerByNormType(normType); if (power == 1) { return abs().sum(); } @@ -317,14 +318,16 @@ class Float64x2Vector with IterableMixin implements Vector { }, skipCaching: skipCaching); @override - double max({bool skipCaching = false}) => _cacheManager.retrieveValue('max', () => - _findExtrema(-double.infinity, _simdHelper.getMaxLane, - (a, b) => a.max(b), math.max), skipCaching: skipCaching); + double max({bool skipCaching = false}) => + _cacheManager.retrieveValue(maxKey, () => + _findExtrema(-double.infinity, _simdHelper.getMaxLane, + (a, b) => a.max(b), math.max), skipCaching: skipCaching); @override - double min({bool skipCaching = false}) => _cacheManager.retrieveValue('min', () => - _findExtrema(double.infinity, _simdHelper.getMinLane, - (a, b) => a.min(b), math.min), skipCaching: skipCaching); + double min({bool skipCaching = false}) => + _cacheManager.retrieveValue(minKey, () => + _findExtrema(double.infinity, _simdHelper.getMinLane, + (a, b) => a.min(b), math.min), skipCaching: skipCaching); double _findExtrema(double initialValue, double getExtremalLane(Float64x2 bucket), @@ -357,7 +360,7 @@ class Float64x2Vector with IterableMixin implements Vector { @override Vector unique({bool skipCaching = false}) => - _cacheManager.retrieveValue('unique', () => Vector.fromList( + _cacheManager.retrieveValue(uniqueKey, () => Vector.fromList( Set.from(this).toList(growable: false), dtype: dtype), skipCaching: skipCaching); @@ -369,33 +372,6 @@ class Float64x2Vector with IterableMixin implements Vector { dtype: dtype); } - @override - void fastForEach(void iteratorFn(T element, bool isLast, List remains)) { - if (isEmpty) { - return; - } - - final remains = _isLastBucketNotFull - ? _simdHelper.simdValueToList(_innerSimdList[_numOfBuckets - 1]) - : []; - - final _fullBucketsNumber = _isLastBucketNotFull - ? _numOfBuckets - 1 - : _numOfBuckets; - - int counter = 0; - - _innerSimdList - .take(_fullBucketsNumber) - .forEach( - (element) => iteratorFn( - element as T, - ++counter == _fullBucketsNumber, - remains, - ), - ); - } - @override double operator [](int index) { if (isEmpty) { @@ -424,18 +400,19 @@ class Float64x2Vector with IterableMixin implements Vector { } final limit = end == null || end > length ? length : end; final collection = _innerTypedList.sublist(start, limit); + return Vector.fromList(collection, dtype: dtype); } @override Vector normalize([Norm normType = Norm.euclidean, bool skipCaching = false]) => - _cacheManager.retrieveValue('normalize_$normType', + _cacheManager.retrieveValue(getCacheKeyForNormalizeByNormType(normType), () => this / norm(normType), skipCaching: skipCaching); @override Vector rescale({bool skipCaching = false}) => - _cacheManager.retrieveValue('rescale', () { + _cacheManager.retrieveValue(rescaleKey, () { final minValue = min(); final maxValue = max(); return (this - minValue) / (maxValue - minValue); diff --git a/lib/src/vector/vector_cache_keys.dart b/lib/src/vector/vector_cache_keys.dart new file mode 100644 index 00000000..eee172f9 --- /dev/null +++ b/lib/src/vector/vector_cache_keys.dart @@ -0,0 +1,57 @@ +import 'package:ml_linalg/norm.dart'; + +const hashKey = 'hash'; +const sqrtKey = 'sqrt'; +const absKey = 'abs'; +const sumKey = 'sum'; +const meanKey = 'mean'; +const euclideanNormKey = 'euclidean_norm'; +const manhattanNormKey = 'manhattan_norm'; +const maxKey = 'max'; +const minKey = 'min'; +const uniqueKey = 'unique'; +const euclideanNormalizeKey = 'euclidean_normalize'; +const manhattanNormalizeKey = 'manhattan_normalize'; +const rescaleKey = 'rescale'; + +final vectorCacheKeys = Set.from([ + hashKey, + sqrtKey, + absKey, + sumKey, + meanKey, + euclideanNormKey, + manhattanNormKey, + maxKey, + minKey, + uniqueKey, + euclideanNormalizeKey, + manhattanNormalizeKey, + rescaleKey, +]); + +String getCacheKeyForNormByNormType(Norm normType) { + switch (normType) { + case Norm.euclidean: + return euclideanNormKey; + + case Norm.manhattan: + return manhattanNormKey; + + default: + throw UnsupportedError('Unsupported norm type `$normType`'); + } +} + +String getCacheKeyForNormalizeByNormType(Norm normType) { + switch (normType) { + case Norm.euclidean: + return euclideanNormalizeKey; + + case Norm.manhattan: + return manhattanNormalizeKey; + + default: + throw UnsupportedError('Unsupported norm type `$normType`'); + } +} diff --git a/lib/vector.dart b/lib/vector.dart index 31bafdf6..cb5d8dab 100644 --- a/lib/vector.dart +++ b/lib/vector.dart @@ -8,6 +8,7 @@ import 'package:ml_linalg/src/di/dependencies.dart'; import 'package:ml_linalg/src/vector/float32x4_vector.dart'; import 'package:ml_linalg/src/vector/float64x2_vector.dart'; import 'package:ml_linalg/src/vector/simd_helper/simd_helper_factory.dart'; +import 'package:ml_linalg/src/vector/vector_cache_keys.dart'; final _cacheManagerFactory = dependencies.getDependency(); final _simdHelperFactory = dependencies.getDependency(); @@ -42,12 +43,18 @@ abstract class Vector implements Iterable { }) { switch (dtype) { case DType.float32: - return Float32x4Vector.fromList(source, _cacheManagerFactory.create(), - _simdHelperFactory.createByDType(dtype)); + return Float32x4Vector.fromList( + source, + _cacheManagerFactory.create(vectorCacheKeys), + _simdHelperFactory.createByDType(dtype), + ); case DType.float64: - return Float64x2Vector.fromList(source, _cacheManagerFactory.create(), - _simdHelperFactory.createByDType(dtype)); + return Float64x2Vector.fromList( + source, + _cacheManagerFactory.create(vectorCacheKeys), + _simdHelperFactory.createByDType(dtype), + ); default: throw UnimplementedError('Vector of $dtype type is not implemented yet'); @@ -87,7 +94,7 @@ abstract class Vector implements Iterable { return Float32x4Vector.fromSimdList( source as Float32x4List, actualLength, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), ); @@ -95,7 +102,7 @@ abstract class Vector implements Iterable { return Float64x2Vector.fromSimdList( source as Float64x2List, actualLength, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), ); @@ -129,7 +136,7 @@ abstract class Vector implements Iterable { return Float32x4Vector.filled( length, value, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), ); @@ -137,7 +144,7 @@ abstract class Vector implements Iterable { return Float64x2Vector.filled( length, value, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), ); @@ -170,14 +177,14 @@ abstract class Vector implements Iterable { case DType.float32: return Float32x4Vector.zero( length, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), ); case DType.float64: return Float64x2Vector.zero( length, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), ); @@ -219,7 +226,7 @@ abstract class Vector implements Iterable { return Float32x4Vector.randomFilled( length, seed, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), max: max, min: min, @@ -229,7 +236,7 @@ abstract class Vector implements Iterable { return Float64x2Vector.randomFilled( length, seed, - _cacheManagerFactory.create(), + _cacheManagerFactory.create(vectorCacheKeys), _simdHelperFactory.createByDType(dtype), max: max, min: min, @@ -260,12 +267,16 @@ abstract class Vector implements Iterable { factory Vector.empty({DType dtype = DType.float32}) { switch (dtype) { case DType.float32: - return Float32x4Vector.empty(_cacheManagerFactory.create(), - _simdHelperFactory.createByDType(dtype)); + return Float32x4Vector.empty( + _cacheManagerFactory.create(vectorCacheKeys), + _simdHelperFactory.createByDType(dtype), + ); case DType.float64: - return Float64x2Vector.empty(_cacheManagerFactory.create(), - _simdHelperFactory.createByDType(dtype)); + return Float64x2Vector.empty( + _cacheManagerFactory.create(vectorCacheKeys), + _simdHelperFactory.createByDType(dtype), + ); default: throw UnimplementedError('Vector of $dtype type is not implemented yet'); diff --git a/test/common/cache_manager/cache_manager_factory_impl_test.dart b/test/common/cache_manager/cache_manager_factory_impl_test.dart index 53682950..d88d5aae 100644 --- a/test/common/cache_manager/cache_manager_factory_impl_test.dart +++ b/test/common/cache_manager/cache_manager_factory_impl_test.dart @@ -5,8 +5,10 @@ import 'package:test/test.dart'; void main() { group('CacheManagerFactoryImpl', () { test('should create a CacheManagerImpl instance', () { + final keys = Set.from(['key_1', 'key_2', 'key_3']); + final cacheManagerFactory = const CacheManagerFactoryImpl(); - final manager = cacheManagerFactory.create(); + final manager = cacheManagerFactory.create(keys); expect(manager, isA()); });