diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml index 5fc4cd28..f2598ccc 100644 --- a/.github/workflows/ci_pipeline.yml +++ b/.github/workflows/ci_pipeline.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest container: - image: google/dart:latest + image: google/dart:beta steps: - uses: actions/checkout@v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 4582bd1b..77330d8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 13.0.0-nullsafety.0 +- null-safety supported + ## 12.17.10 - `xrange` v1.0.0 supported diff --git a/benchmark/matrix/float32/matrix_initializing/float32_matrix_diagonal.dart b/benchmark/matrix/float32/matrix_initializing/float32_matrix_diagonal.dart index 6d436028..af82a10f 100644 --- a/benchmark/matrix/float32/matrix_initializing/float32_matrix_diagonal.dart +++ b/benchmark/matrix/float32/matrix_initializing/float32_matrix_diagonal.dart @@ -11,7 +11,7 @@ class Float32MatrixDiagonalBenchmark extends BenchmarkBase { Float32MatrixDiagonalBenchmark() : super('Matrix initialization (diagonal)'); - List _source; + late List _source; static void main() { Float32MatrixDiagonalBenchmark().report(); diff --git a/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_columns.dart b/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_columns.dart index 5f0a9f8d..c21ffe6e 100644 --- a/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_columns.dart +++ b/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_columns.dart @@ -12,7 +12,7 @@ class Float32MatrixFromColumnsBenchmark extends BenchmarkBase { Float32MatrixFromColumnsBenchmark() : super('Matrix initialization (fromColumns)'); - List _source; + late List _source; static void main() { Float32MatrixFromColumnsBenchmark().report(); diff --git a/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_flattened.dart b/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_flattened.dart index d5e263b4..92976ef9 100644 --- a/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_flattened.dart +++ b/benchmark/matrix/float32/matrix_initializing/float32_matrix_from_flattened.dart @@ -12,7 +12,7 @@ class Float32MatrixFromFlattenedBenchmark extends BenchmarkBase { Float32MatrixFromFlattenedBenchmark() : super('Matrix initialization (fromFlattenedList)'); - List _source; + late List _source; static void main() { Float32MatrixFromFlattenedBenchmark().report(); diff --git a/benchmark/vector/baseline/regular_lists_addition.dart b/benchmark/vector/baseline/regular_lists_addition.dart index cf1a4815..72cbb69a 100644 --- a/benchmark/vector/baseline/regular_lists_addition.dart +++ b/benchmark/vector/baseline/regular_lists_addition.dart @@ -10,9 +10,9 @@ class RegularListsAdditionBenchmark extends BenchmarkBase { RegularListsAdditionBenchmark() : super('Regular lists addition; ' '$amountOfElements elements'); - List list1; - List list2; - final result = List(amountOfElements); + late List list1; + late List list2; + final result = List.filled(amountOfElements, 0.0); static void main() { RegularListsAdditionBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_abs.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_abs.dart index f0f5534a..00e9bff3 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_abs.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_abs.dart @@ -10,7 +10,7 @@ class Float32x4VectorAbsBenchmark extends BenchmarkBase { Float32x4VectorAbsBenchmark() : super('Vector `abs` method; $amountOfElements elements'); - Vector vector; + Vector? vector; static void main() { Float32x4VectorAbsBenchmark().report(); @@ -18,7 +18,7 @@ class Float32x4VectorAbsBenchmark extends BenchmarkBase { @override void run() { - vector.abs(skipCaching: true); + vector!.abs(skipCaching: true); } @override diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_equality_operator.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_equality_operator.dart index a344a59c..893c1a24 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_equality_operator.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_equality_operator.dart @@ -11,8 +11,8 @@ class Float32x4VectorEqualityOperatorBenchmark extends BenchmarkBase { : super('Vector `==` operator, operands: vector, vector; ' '$amountOfElements elements'); - Vector vector1; - Vector vector2; + Vector? vector1; + Vector? vector2; static void main() { Float32x4VectorEqualityOperatorBenchmark().report(); @@ -32,7 +32,7 @@ class Float32x4VectorEqualityOperatorBenchmark extends BenchmarkBase { max: 1000, dtype: DType.float32, ); - vector2 = Vector.fromList(vector1.toList()); + vector2 = Vector.fromList(vector1!.toList()); } void tearDown() { diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_exp.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_exp.dart index 34ad48a6..78e7546b 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_exp.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_exp.dart @@ -10,7 +10,7 @@ class Float32x4VectorExpBenchmark extends BenchmarkBase { Float32x4VectorExpBenchmark() : super('Vector `exp` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorExpBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_hash_code.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_hash_code.dart index caf4c298..e50c0e2b 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_hash_code.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_hash_code.dart @@ -10,7 +10,7 @@ class Float32x4VectorHashCodeBenchmark extends BenchmarkBase { Float32x4VectorHashCodeBenchmark() : super('Vector `hashCode`; $amountOfElements elements'); - Vector vector; + Vector? vector; static void main() { Float32x4VectorHashCodeBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_indexed_access_operator.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_indexed_access_operator.dart index 97d7c85a..bd86ee7f 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_indexed_access_operator.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_indexed_access_operator.dart @@ -10,7 +10,7 @@ class Float32x4VectorRandomAccessBenchmark extends BenchmarkBase { Float32x4VectorRandomAccessBenchmark() : super('Vector `[]` operator; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorRandomAccessBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_max.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_max.dart index 20d4602a..47094f38 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_max.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_max.dart @@ -10,7 +10,7 @@ class Float32x4VectorMaxValueBenchmark extends BenchmarkBase { Float32x4VectorMaxValueBenchmark() : super('Vector `max` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorMaxValueBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_min.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_min.dart index 036bca47..cc55a23e 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_min.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_min.dart @@ -10,7 +10,7 @@ class Float32x4VectorMinValueBenchmark extends BenchmarkBase { Float32x4VectorMinValueBenchmark() : super('Vector `min` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorMinValueBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_norm.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_norm.dart index c7972b93..72838e04 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_norm.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_norm.dart @@ -11,7 +11,7 @@ class Float32x4VectorNormBenchmark extends BenchmarkBase { Float32x4VectorNormBenchmark() : super('Vector `norm` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorNormBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_pow.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_pow.dart index f5d87e0d..b3969cde 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_pow.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_pow.dart @@ -10,7 +10,7 @@ class Float32x4VectorPowBenchmark extends BenchmarkBase { Float32x4VectorPowBenchmark() : super('Vector `pow` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorPowBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_addition.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_addition.dart index 8eb86bd1..83f09121 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_addition.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_addition.dart @@ -11,7 +11,7 @@ class Float32x4VectorAndScalarAdditionBenchmark extends BenchmarkBase { : super('Vector `+` operator, operands: vector, scalar; ' '$amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorAndScalarAdditionBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_division.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_division.dart index 0c1a7a32..33d63242 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_division.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_division.dart @@ -11,7 +11,7 @@ class Float32x4VectorAndScalarDivisionBenchmark extends BenchmarkBase { : super('Vector `/` operator, operands: vector, scalar; ' '$amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorAndScalarDivisionBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_multiplication.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_multiplication.dart index a7100c18..5ebfe833 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_multiplication.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_multiplication.dart @@ -11,7 +11,7 @@ class Float32x4VectorAndScalarMultiplicationBenchmark extends BenchmarkBase { : super('Vector `*` operator, operands: vector, scalar; ' '$amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorAndScalarMultiplicationBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_subtraction.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_subtraction.dart index 18caa6fe..9c79f2e3 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_subtraction.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_scalar_subtraction.dart @@ -11,7 +11,7 @@ class Float32x4VectorAndScalarSubtractionBenchmark extends BenchmarkBase { : super('Vector `-` operator, operands: vector, scalar; ' '$amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorAndScalarSubtractionBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_sqrt.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_sqrt.dart index 13ebe2f9..e1b740f6 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_sqrt.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_sqrt.dart @@ -10,7 +10,7 @@ class Float32x4VectorSqrtBenchmark extends BenchmarkBase { Float32x4VectorSqrtBenchmark() : super('Vector `sqrt` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorSqrtBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_sum.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_sum.dart index 6c407594..4756ae21 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_sum.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_sum.dart @@ -10,7 +10,7 @@ class Float32x4VectorSumBenchmark extends BenchmarkBase { Float32x4VectorSumBenchmark() : super('Vector `sum` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorSumBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_unique.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_unique.dart index 15ae964b..07c0f914 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_unique.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_unique.dart @@ -10,7 +10,7 @@ class Float32x4VectorUniqueBenchmark extends BenchmarkBase { Float32x4VectorUniqueBenchmark() : super('Vector `unique` method; $amountOfElements elements'); - Vector vector; + late Vector vector; static void main() { Float32x4VectorUniqueBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_vector_addition.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_vector_addition.dart index b520916c..7612a30a 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_vector_addition.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_vector_addition.dart @@ -11,9 +11,9 @@ class Float32x4VectorAndVectorAdditionBenchmark extends BenchmarkBase { : super('Vector `+` operator, operands: vector, vector; ' '$amountOfElements elements'); - Vector vector1; - Vector vector2; - Vector vector3; + late Vector vector1; + late Vector vector2; + late Vector vector3; static void main() { Float32x4VectorAndVectorAdditionBenchmark().report(); diff --git a/benchmark/vector/float32/vector_operations/float32x4_vector_vector_multiplication.dart b/benchmark/vector/float32/vector_operations/float32x4_vector_vector_multiplication.dart index f3b18610..628f59b7 100644 --- a/benchmark/vector/float32/vector_operations/float32x4_vector_vector_multiplication.dart +++ b/benchmark/vector/float32/vector_operations/float32x4_vector_vector_multiplication.dart @@ -11,8 +11,8 @@ class Float32x4VectorAndVectorMultiplicationBenchmark extends BenchmarkBase { : super('Vector `*` operator, operands: vector, vector; ' '$amountOfElements elements'); - Vector vector1; - Vector vector2; + late Vector vector1; + late Vector vector2; static void main() { Float32x4VectorAndVectorMultiplicationBenchmark().report(); diff --git a/lib/matrix.dart b/lib/matrix.dart index 898eddd6..c0ce2b29 100644 --- a/lib/matrix.dart +++ b/lib/matrix.dart @@ -4,8 +4,7 @@ import 'package:ml_linalg/axis.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix_norm.dart'; import 'package:ml_linalg/sort_direction.dart'; -import 'package:ml_linalg/src/di/dependencies.dart'; -import 'package:ml_linalg/src/matrix/matrix_factory.dart'; +import 'package:ml_linalg/src/matrix/helper/create_matrix.dart'; import 'package:ml_linalg/src/matrix/serialization/from_matrix_json.dart'; import 'package:ml_linalg/vector.dart'; @@ -37,9 +36,10 @@ abstract class Matrix implements Iterable> { /// (1.0, 2.0, 3.0, 4.0, 5.0) /// (6.0, 7.0, 8.0, 9.0, 0.0) /// ``` - factory Matrix.fromList(List> source, { - DType dtype = DType.float32, - }) => dependencies.get().fromList(dtype, source); + factory Matrix.fromList( + List> source, { + DType dtype = DType.float32, + }) => createMatrixFactory().fromList(dtype, source); /// Creates a matrix with predefined row vectors /// @@ -66,8 +66,10 @@ abstract class Matrix implements Iterable> { /// (1.0, 2.0, 3.0, 4.0, 5.0) /// (6.0, 7.0, 8.0, 9.0, 0.0) /// ``` - factory Matrix.fromRows(List source, {DType dtype = DType.float32}) => - dependencies.get().fromRows(dtype, source); + factory Matrix.fromRows( + List source, { + DType dtype = DType.float32 + }) => createMatrixFactory().fromRows(dtype, source); /// Creates a matrix with predefined column vectors /// @@ -97,9 +99,10 @@ abstract class Matrix implements Iterable> { /// (4.0, 9.0) /// (5.0, 0.0) /// ``` - factory Matrix.fromColumns(List source, { - DType dtype = DType.float32, - }) => dependencies.get().fromColumns(dtype, source); + factory Matrix.fromColumns( + List source, { + DType dtype = DType.float32, + }) => createMatrixFactory().fromColumns(dtype, source); /// Creates a matrix of shape 0 x 0 (no rows, no columns) /// @@ -121,7 +124,7 @@ abstract class Matrix implements Iterable> { /// Matrix 0 x 0 /// ``` factory Matrix.empty({DType dtype = DType.float32}) => - dependencies.get().empty(dtype); + createMatrixFactory().empty(dtype); /// Creates a matrix from flattened list of length equal to /// [rowsNum] * [columnsNum] @@ -147,11 +150,17 @@ abstract class Matrix implements Iterable> { /// (1.0, 2.0, 3.0, 4.0, 5.0) /// (6.0, 7.0, 8.0, 9.0, 0.0) /// ``` - factory Matrix.fromFlattenedList(List source, int rowsNum, - int columnsNum, {DType dtype = DType.float32}) => - dependencies - .get() - .fromFlattenedList(dtype, source, rowsNum, columnsNum); + factory Matrix.fromFlattenedList( + List source, + int rowsNum, + int columnsNum, { + DType dtype = DType.float32, + }) => createMatrixFactory().fromFlattenedList( + dtype, + source, + rowsNum, + columnsNum, + ); /// Creates a matrix, where elements from [source] are the elements for the /// matrix main diagonal, the rest of the elements are zero @@ -176,8 +185,10 @@ abstract class Matrix implements Iterable> { /// (0.0, 0.0, 0.0, 4.0, 0.0) /// (0.0, 0.0, 0.0, 0.0, 5.0) /// ``` - factory Matrix.diagonal(List source, {DType dtype = DType.float32}) => - dependencies.get().diagonal(dtype, source); + factory Matrix.diagonal( + List source, { + DType dtype = DType.float32, + }) => createMatrixFactory().diagonal(dtype, source); /// Creates a matrix of [size] * [size] dimension, where all the main /// diagonal elements are equal to [scalar], the rest of the elements are 0 @@ -202,9 +213,15 @@ abstract class Matrix implements Iterable> { /// (0.0, 0.0, 0.0, 3.0, 0.0) /// (0.0, 0.0, 0.0, 0.0, 3.0) /// ``` - factory Matrix.scalar(double scalar, int size, { - DType dtype = DType.float32, - }) => dependencies.get().scalar(dtype, scalar, size); + factory Matrix.scalar( + double scalar, + int size, { + DType dtype = DType.float32, + }) => createMatrixFactory().scalar( + dtype, + scalar, + size, + ); /// Creates a matrix of [size] * [size] dimension, where all the main /// diagonal elements are equal to 1, the rest of the elements are 0 @@ -229,8 +246,10 @@ abstract class Matrix implements Iterable> { /// (0.0, 0.0, 0.0, 1.0, 0.0) /// (0.0, 0.0, 0.0, 0.0, 1.0) /// ``` - factory Matrix.identity(int size, {DType dtype = DType.float32}) => - dependencies.get().identity(dtype, size); + factory Matrix.identity( + int size, { + DType dtype = DType.float32, + }) => createMatrixFactory().identity(dtype, size); /// Creates a matrix, consisting of just one row (aka `Row matrix`) /// @@ -250,11 +269,13 @@ abstract class Matrix implements Iterable> { /// Matrix 1 x 5: /// (1.0, 2.0, 3.0, 4.0, 5.0) /// ``` - factory Matrix.row(List source, {DType dtype = DType.float32}) => - dependencies.get().row(dtype, source); + factory Matrix.row( + List source, { + DType dtype = DType.float32, + }) => createMatrixFactory().row(dtype, source); /// Returns a restored matrix from a serializable map - factory Matrix.fromJson(Map json) => fromMatrixJson(json); + factory Matrix.fromJson(Map json) => fromMatrixJson(json)!; /// Creates a matrix, consisting of just one column (aka `Column matrix`) /// @@ -278,8 +299,10 @@ abstract class Matrix implements Iterable> { /// (4.0) /// (5.0) /// ``` - factory Matrix.column(List source, {DType dtype = DType.float32}) => - dependencies.get().column(dtype, source); + factory Matrix.column( + List source, { + DType dtype = DType.float32 + }) => createMatrixFactory().column(dtype, source); /// A data type of [Matrix] elements DType get dtype; @@ -359,11 +382,11 @@ abstract class Matrix implements Iterable> { /// Reduces all the matrix columns to only column, using [combiner] function Vector reduceColumns(Vector Function(Vector combine, Vector vector) combiner, - {Vector initValue}); + {Vector? initValue}); /// Reduces all the matrix rows to only row, using [combiner] function Vector reduceRows(Vector Function(Vector combine, Vector vector) combiner, - {Vector initValue}); + {Vector? initValue}); /// Performs element-wise mapping of this [Matrix] to a new one via passed /// [mapper] function @@ -448,8 +471,8 @@ abstract class Matrix implements Iterable> { Vector deviation([Axis axis = Axis.columns]); /// Returns a new matrix with sorted elements from this [Matrix] - Matrix sort(double Function(Vector vector) selectSortValue, [Axis axis = Axis.rows, - SortDirection sortDir = SortDirection.asc]); + Matrix sort(double Function(Vector vector) selectSortValue, [ + Axis axis = Axis.rows, SortDirection sortDir = SortDirection.asc]); /// Raise all the elements of the matrix to the power [exponent] and returns /// a new [Matrix] with these elements. Avoid raising a matrix to a float diff --git a/lib/src/common/cache_manager/cache_manager_impl.dart b/lib/src/common/cache_manager/cache_manager_impl.dart index 875e6e35..36874f18 100644 --- a/lib/src/common/cache_manager/cache_manager_impl.dart +++ b/lib/src/common/cache_manager/cache_manager_impl.dart @@ -13,18 +13,14 @@ class CacheManagerImpl implements CacheManager { if (!_keys.contains(key)) { throw Exception('Cache key `$key` is not registered'); } - - var value = _cache[key] as T; - - if (value == null) { - value = calculateIfAbsent(); - if (!skipCaching) { - _cache[key] = value; - } + if (skipCaching) { + return calculateIfAbsent(); } + + _cache[key] ??= calculateIfAbsent(); - return value; + return _cache[key]; } @override diff --git a/lib/src/common/dtype_serializer/dtype_to_json.dart b/lib/src/common/dtype_serializer/dtype_to_json.dart index 22e225b4..f322d16b 100644 --- a/lib/src/common/dtype_serializer/dtype_to_json.dart +++ b/lib/src/common/dtype_serializer/dtype_to_json.dart @@ -2,7 +2,7 @@ import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/src/common/dtype_serializer/dtype_encoded_values.dart'; /// Encodes [dtype] to a json-serializable value -String dTypeToJson(DType dtype) { +String? dTypeToJson(DType? dtype) { switch (dtype) { case DType.float32: return dTypeFloat32EncodedValue; diff --git a/lib/src/common/dtype_serializer/from_dtype_json.dart b/lib/src/common/dtype_serializer/from_dtype_json.dart index 028eb41a..ad42ca72 100644 --- a/lib/src/common/dtype_serializer/from_dtype_json.dart +++ b/lib/src/common/dtype_serializer/from_dtype_json.dart @@ -2,7 +2,7 @@ import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/src/common/dtype_serializer/dtype_encoded_values.dart'; /// Restores the original [DType] value by given encoded string -DType fromDTypeJson(String json) { +DType? fromDTypeJson(String? json) { switch (json) { case dTypeFloat32EncodedValue: return DType.float32; diff --git a/lib/src/di/dependencies.dart b/lib/src/di/dependencies.dart deleted file mode 100644 index 674c45e3..00000000 --- a/lib/src/di/dependencies.dart +++ /dev/null @@ -1,16 +0,0 @@ -import 'package:injector/injector.dart'; -import 'package:ml_linalg/src/common/cache_manager/cache_manager_factory.dart'; -import 'package:ml_linalg/src/common/cache_manager/cache_manager_factory_impl.dart'; -import 'package:ml_linalg/src/di/injector.dart'; -import 'package:ml_linalg/src/matrix/matrix_factory.dart'; -import 'package:ml_linalg/src/matrix/matrix_factory_impl.dart'; - -Injector get dependencies => injector ??= Injector() - ..registerSingleton( - () => const CacheManagerFactoryImpl()) - - ..registerSingleton(() { - final cacheManagerFactory = injector.get(); - - return MatrixFactoryImpl(cacheManagerFactory); - }); diff --git a/lib/src/di/injector.dart b/lib/src/di/injector.dart deleted file mode 100644 index d2a981c8..00000000 --- a/lib/src/di/injector.dart +++ /dev/null @@ -1,3 +0,0 @@ -import 'package:injector/injector.dart'; - -Injector injector; diff --git a/lib/src/matrix/data_manager/float32_matrix_data_manager.dart b/lib/src/matrix/data_manager/float32_matrix_data_manager.dart index 06b029e4..acc3f0b2 100644 --- a/lib/src/matrix/data_manager/float32_matrix_data_manager.dart +++ b/lib/src/matrix/data_manager/float32_matrix_data_manager.dart @@ -16,8 +16,8 @@ class Float32MatrixDataManager implements MatrixDataManager { columnsNum = getLengthOfFirstOrZero(source), rowIndices = getZeroBasedIndices(get2dIterableLength(source)), columnIndices = getZeroBasedIndices(getLengthOfFirstOrZero(source)), - _rowsCache = List(source.length), - _colsCache = List(getLengthOfFirstOrZero(source)), + _rowsCache = List.filled(source.length, null), + _colsCache = List.filled(getLengthOfFirstOrZero(source), null), _data = ByteData(source.length * getLengthOfFirstOrZero(source) * _bytesPerElement), areAllRowsCached = false, @@ -45,7 +45,7 @@ class Float32MatrixDataManager implements MatrixDataManager { rowIndices = getZeroBasedIndices(get2dIterableLength(source)), columnIndices = getZeroBasedIndices(getLengthOfFirstOrZero(source)), _rowsCache = [...source], - _colsCache = List(getLengthOfFirstOrZero(source)), + _colsCache = List.filled(getLengthOfFirstOrZero(source), null), _data = ByteData(source.length * getLengthOfFirstOrZero(source) * _bytesPerElement), areAllRowsCached = true, @@ -80,7 +80,7 @@ class Float32MatrixDataManager implements MatrixDataManager { columnsNum = get2dIterableLength(source), rowIndices = getZeroBasedIndices(getLengthOfFirstOrZero(source)), columnIndices = getZeroBasedIndices(get2dIterableLength(source)), - _rowsCache = List(getLengthOfFirstOrZero(source)), + _rowsCache = List.filled(getLengthOfFirstOrZero(source), null), _colsCache = [...source], _data = ByteData(source.length * getLengthOfFirstOrZero(source) * _bytesPerElement), @@ -117,8 +117,8 @@ class Float32MatrixDataManager implements MatrixDataManager { columnsNum = colsNum, rowIndices = getZeroBasedIndices(rowsNum), columnIndices = getZeroBasedIndices(colsNum), - _rowsCache = List(rowsNum), - _colsCache = List(colsNum), + _rowsCache = List.filled(rowsNum, null), + _colsCache = List.filled(colsNum, null), _data = Float32List.fromList(source).buffer.asByteData(), areAllRowsCached = false, areAllColumnsCached = false @@ -135,8 +135,8 @@ class Float32MatrixDataManager implements MatrixDataManager { columnsNum = source.length, rowIndices = getZeroBasedIndices(source.length), columnIndices = getZeroBasedIndices(source.length), - _rowsCache = List(source.length), - _colsCache = List(source.length), + _rowsCache = List.filled(source.length, null), + _colsCache = List.filled(source.length, null), _data = ByteData(source.length * source.length * _bytesPerElement), areAllRowsCached = false, areAllColumnsCached = false @@ -152,8 +152,8 @@ class Float32MatrixDataManager implements MatrixDataManager { columnsNum = size, rowIndices = getZeroBasedIndices(size), columnIndices = getZeroBasedIndices(size), - _rowsCache = List(size), - _colsCache = List(size), + _rowsCache = List.filled(size, null), + _colsCache = List.filled(size, null), _data = ByteData(size * size * _bytesPerElement), areAllRowsCached = false, areAllColumnsCached = false @@ -185,8 +185,8 @@ class Float32MatrixDataManager implements MatrixDataManager { @override final bool areAllColumnsCached; - final List _rowsCache; - final List _colsCache; + final List _rowsCache; + final List _colsCache; final ByteData _data; @override @@ -212,7 +212,7 @@ class Float32MatrixDataManager implements MatrixDataManager { indexFrom * _bytesPerElement, columnsNum); _rowsCache[index] ??= Vector.fromList(values, dtype: dtype); - return _rowsCache[index]; + return _rowsCache[index]!; } @override @@ -220,14 +220,15 @@ class Float32MatrixDataManager implements MatrixDataManager { if (!hasData) { throw Exception('Matrix is empty'); } + if (_colsCache[index] == null) { - final result = List(rowsNum); - for (var i = 0; i < result.length; i++) { - result[i] = _data.getFloat32( - (i * columnsNum + index) * _bytesPerElement, Endian.host); - } + final result = List.generate(rowsNum, + (i) => _data.getFloat32( + (i * columnsNum + index) * _bytesPerElement, Endian.host)); + _colsCache[index] = Vector.fromList(result, dtype: dtype); } - return _colsCache[index]; + + return _colsCache[index]!; } } diff --git a/lib/src/matrix/data_manager/float64_matrix_data_manager.dart b/lib/src/matrix/data_manager/float64_matrix_data_manager.dart index 81ad56d2..afae46bb 100644 --- a/lib/src/matrix/data_manager/float64_matrix_data_manager.dart +++ b/lib/src/matrix/data_manager/float64_matrix_data_manager.dart @@ -18,8 +18,8 @@ class Float64MatrixDataManager implements MatrixDataManager { columnsNum = getLengthOfFirstOrZero(source), rowIndices = getZeroBasedIndices(get2dIterableLength(source)), columnIndices = getZeroBasedIndices(getLengthOfFirstOrZero(source)), - _rowsCache = List(source.length), - _colsCache = List(getLengthOfFirstOrZero(source)), + _rowsCache = List.filled(source.length, null), + _colsCache = List.filled(getLengthOfFirstOrZero(source), null), _data = ByteData(source.length * getLengthOfFirstOrZero(source) * _bytesPerElement), areAllRowsCached = false, @@ -47,7 +47,7 @@ class Float64MatrixDataManager implements MatrixDataManager { rowIndices = getZeroBasedIndices(get2dIterableLength(source)), columnIndices = getZeroBasedIndices(getLengthOfFirstOrZero(source)), _rowsCache = [...source], - _colsCache = List(getLengthOfFirstOrZero(source)), + _colsCache = List.filled(getLengthOfFirstOrZero(source), null), _data = ByteData(source.length * getLengthOfFirstOrZero(source) * _bytesPerElement), areAllRowsCached = true, @@ -82,7 +82,7 @@ class Float64MatrixDataManager implements MatrixDataManager { columnsNum = get2dIterableLength(source), rowIndices = getZeroBasedIndices(getLengthOfFirstOrZero(source)), columnIndices = getZeroBasedIndices(get2dIterableLength(source)), - _rowsCache = List(getLengthOfFirstOrZero(source)), + _rowsCache = List.filled(getLengthOfFirstOrZero(source), null), _colsCache = [...source], _data = ByteData(source.length * getLengthOfFirstOrZero(source) * _bytesPerElement), @@ -119,8 +119,8 @@ class Float64MatrixDataManager implements MatrixDataManager { columnsNum = colsNum, rowIndices = getZeroBasedIndices(rowsNum), columnIndices = getZeroBasedIndices(colsNum), - _rowsCache = List(rowsNum), - _colsCache = List(colsNum), + _rowsCache = List.filled(rowsNum, null), + _colsCache = List.filled(colsNum, null), _data = Float64List.fromList(source).buffer.asByteData(), areAllRowsCached = false, areAllColumnsCached = false @@ -137,8 +137,8 @@ class Float64MatrixDataManager implements MatrixDataManager { columnsNum = source.length, rowIndices = getZeroBasedIndices(source.length), columnIndices = getZeroBasedIndices(source.length), - _rowsCache = List(source.length), - _colsCache = List(source.length), + _rowsCache = List.filled(source.length, null), + _colsCache = List.filled(source.length, null), _data = ByteData(source.length * source.length * _bytesPerElement), areAllRowsCached = false, areAllColumnsCached = false @@ -154,8 +154,8 @@ class Float64MatrixDataManager implements MatrixDataManager { columnsNum = size, rowIndices = getZeroBasedIndices(size), columnIndices = getZeroBasedIndices(size), - _rowsCache = List(size), - _colsCache = List(size), + _rowsCache = List.filled(size, null), + _colsCache = List.filled(size, null), _data = ByteData(size * size * _bytesPerElement), areAllRowsCached = false, areAllColumnsCached = false @@ -187,8 +187,8 @@ class Float64MatrixDataManager implements MatrixDataManager { @override final bool areAllColumnsCached; - final List _rowsCache; - final List _colsCache; + final List _rowsCache; + final List _colsCache; final ByteData _data; @override @@ -214,7 +214,7 @@ class Float64MatrixDataManager implements MatrixDataManager { indexFrom * _bytesPerElement, columnsNum); _rowsCache[index] ??= Vector.fromList(values, dtype: dtype); - return _rowsCache[index]; + return _rowsCache[index]!; } @override @@ -222,14 +222,15 @@ class Float64MatrixDataManager implements MatrixDataManager { if (!hasData) { throw Exception('Matrix is empty'); } + if (_colsCache[index] == null) { - final result = List(rowsNum); - for (var i = 0; i < result.length; i++) { - result[i] = _data.getFloat64( - (i * columnsNum + index) * _bytesPerElement, Endian.host); - } + final result = List.generate(rowsNum, + (i) => _data.getFloat64( + (i * columnsNum + index) * _bytesPerElement, Endian.host)); + _colsCache[index] = Vector.fromList(result, dtype: dtype); } - return _colsCache[index]; + + return _colsCache[index]!; } } diff --git a/lib/src/matrix/helper/create_matrix.dart b/lib/src/matrix/helper/create_matrix.dart new file mode 100644 index 00000000..bf5093a4 --- /dev/null +++ b/lib/src/matrix/helper/create_matrix.dart @@ -0,0 +1,9 @@ +import 'package:ml_linalg/src/common/cache_manager/cache_manager_factory_impl.dart'; +import 'package:ml_linalg/src/matrix/matrix_factory.dart'; +import 'package:ml_linalg/src/matrix/matrix_factory_impl.dart'; + +MatrixFactory createMatrixFactory() { + const cacheManagerFactory = CacheManagerFactoryImpl(); + + return MatrixFactoryImpl(cacheManagerFactory); +} diff --git a/lib/src/matrix/helper/get_zero_based_indices.dart b/lib/src/matrix/helper/get_zero_based_indices.dart index 253dec83..76386282 100644 --- a/lib/src/matrix/helper/get_zero_based_indices.dart +++ b/lib/src/matrix/helper/get_zero_based_indices.dart @@ -1,6 +1,6 @@ -import 'package:xrange/xrange.dart'; +import 'package:quiver/iterables.dart'; Iterable getZeroBasedIndices(int maxIndex) => maxIndex == 0 ? [] - : integers(0, maxIndex); + : count(0).take(maxIndex).map((value) => value.toInt()); diff --git a/lib/src/matrix/iterator/float32_matrix_iterator.dart b/lib/src/matrix/iterator/float32_matrix_iterator.dart index cc39067c..e088ef48 100644 --- a/lib/src/matrix/iterator/float32_matrix_iterator.dart +++ b/lib/src/matrix/iterator/float32_matrix_iterator.dart @@ -9,7 +9,7 @@ class Float32MatrixIterator implements Iterator> { final int _rowsNum; final int _colsNum; - Float32List _current; + late Float32List _current; int _currentRow = 0; @override @@ -17,14 +17,14 @@ class Float32MatrixIterator implements Iterator> { @override bool moveNext() { - final startIdx = _currentRow * _colsNum; - if (_currentRow >= _rowsNum) { - _current = null; - } else { + final hasNext = _currentRow < _rowsNum; + + if (hasNext) { _current = _data.buffer - .asFloat32List(startIdx * _bytesPerElement, _colsNum); + .asFloat32List(_currentRow * _colsNum * _bytesPerElement, _colsNum); + _currentRow++; } - _currentRow++; - return _current != null; + + return hasNext; } } diff --git a/lib/src/matrix/iterator/float64_matrix_iterator.dart b/lib/src/matrix/iterator/float64_matrix_iterator.dart index c6a5e1fb..228749b5 100644 --- a/lib/src/matrix/iterator/float64_matrix_iterator.dart +++ b/lib/src/matrix/iterator/float64_matrix_iterator.dart @@ -11,7 +11,7 @@ class Float64MatrixIterator implements Iterator> { final int _rowsNum; final int _colsNum; - Float64List _current; + late Float64List _current; int _currentRow = 0; @override @@ -19,14 +19,14 @@ class Float64MatrixIterator implements Iterator> { @override bool moveNext() { - final startIdx = _currentRow * _colsNum; - if (_currentRow >= _rowsNum) { - _current = null; - } else { + final hasNext = _currentRow < _rowsNum; + + if (hasNext) { _current = _data.buffer - .asFloat64List(startIdx * _bytesPerElement, _colsNum); + .asFloat64List(_currentRow * _colsNum * _bytesPerElement, _colsNum); + _currentRow++; } - _currentRow++; - return _current != null; + + return hasNext; } } diff --git a/lib/src/matrix/matrix_impl.dart b/lib/src/matrix/matrix_impl.dart index 6a975de5..34e42403 100644 --- a/lib/src/matrix/matrix_impl.dart +++ b/lib/src/matrix/matrix_impl.dart @@ -16,10 +16,17 @@ import 'package:ml_linalg/src/matrix/serialization/matrix_to_json.dart'; import 'package:ml_linalg/vector.dart'; import 'package:quiver/iterables.dart'; -class MatrixImpl with IterableMixin>, MatrixValidatorMixin - implements Matrix { - - MatrixImpl(this._dataManager, this._cacheManager); +class MatrixImpl + with + IterableMixin>, + MatrixValidatorMixin + implements + Matrix { + + MatrixImpl( + this._dataManager, + this._cacheManager, + ); final MatrixDataManager _dataManager; final CacheManager _cacheManager; @@ -60,12 +67,14 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin Matrix operator -(Object value) { if (value is Matrix) { return _matrixSub(value); - } else if (value is num) { + } + + if (value is num) { return _matrixScalarSub(value.toDouble()); - } else { - throw UnsupportedError( - 'Cannot subtract a ${value.runtimeType} from a ${runtimeType}'); } + + throw UnsupportedError( + 'Cannot subtract a ${value.runtimeType} from a ${runtimeType}'); } /// Mathematical matrix multiplication @@ -78,14 +87,18 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin Matrix operator *(Object value) { if (value is Vector) { return _matrixVectorMul(value); - } else if (value is Matrix) { + } + + if (value is Matrix) { return _matrixMul(value); - } else if (value is num) { + } + + if (value is num) { return _matrixScalarMul(value.toDouble()); - } else { - throw UnsupportedError( - 'Cannot multiple a ${runtimeType} and a ${value.runtimeType}'); } + + throw UnsupportedError( + 'Cannot multiple a ${runtimeType} and a ${value.runtimeType}'); } /// Performs division of the matrix by vector, matrix or scalar @@ -93,14 +106,18 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin Matrix operator /(Object value) { if (value is Vector) { return _matrixByVectorDiv(value); - } else if (value is Matrix) { + } + + if (value is Matrix) { return _matrixByMatrixDiv(value); - } else if (value is num) { + } + + if (value is num) { return _matrixByScalarDiv(value.toDouble()); - } else { - throw UnsupportedError( - 'Cannot divide a ${runtimeType} by a ${value.runtimeType}'); } + + throw UnsupportedError( + 'Cannot divide a ${runtimeType} by a ${value.runtimeType}'); } @override @@ -109,6 +126,7 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin @override Matrix transpose() { final source = List.generate(rowsNum, getRow); + return Matrix.fromColumns(source, dtype: dtype); } @@ -123,37 +141,43 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin Iterable rowIndices = const [], Iterable columnIndices = const [], }) { - final rowsNumber = rowIndices.isEmpty - ? rowsNum - : rowIndices.length; - final targetMatrixSource = List(rowsNumber); - - for (final indexed in enumerate(rowIndices.isEmpty + final indices = rowIndices.isEmpty ? count(0).take(rowsNum).map((i) => i.toInt()) - : rowIndices) - ) { - final targetRowIndex = indexed.index; - final sourceRowIndex = indexed.value; - final sourceRow = getRow(sourceRowIndex); - - targetMatrixSource[targetRowIndex] = columnIndices.isEmpty - ? sourceRow : sourceRow.sample(columnIndices); - } + : rowIndices; + final targetMatrixSource = indices.map((index) { + final sourceRow = getRow(index); + + return columnIndices.isEmpty + ? sourceRow + : sourceRow.sample(columnIndices); + }).toList(growable: false); return Matrix.fromRows(targetMatrixSource, dtype: dtype); } @override Vector reduceColumns( - Vector Function(Vector combine, Vector vector) combiner, - {Vector initValue}) => - _reduce(combiner, columnsNum, getColumn, initValue: initValue); + Vector Function(Vector combine, Vector vector) combiner, + { + Vector? initValue, + }) => _reduce( + combiner, + columnsNum, + getColumn, + initValue: initValue, + ); @override Vector reduceRows( - Vector Function(Vector combine, Vector vector) combiner, - {Vector initValue}) => - _reduce(combiner, rowsNum, getRow, initValue: initValue); + Vector Function(Vector combine, Vector vector) combiner, + { + Vector? initValue + }) => _reduce( + combiner, + rowsNum, + getRow, + initValue: initValue, + ); @override Matrix mapElements(double Function(double element) mapper) => @@ -173,12 +197,14 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin @override Matrix uniqueRows() { - // TODO: consider using Set instead of List final checked = []; for (final i in _dataManager.rowIndices) { final row = getRow(i); - if (!checked.contains(row)) checked.add(row); + + if (!checked.contains(row)) { + checked.add(row); + } } return Matrix.fromRows(checked, dtype: dtype); @@ -244,7 +270,9 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin Vector toVector() { if (columnsNum == 1) { return getColumn(0); - } else if (rowsNum == 1) { + } + + if (rowsNum == 1) { return getRow(0); } @@ -285,7 +313,7 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin switch (norm) { case MatrixNorm.frobenius: // ignore: deprecated_member_use_from_same_package - return math.sqrt(reduceRows((sum, row) => sum + row.toIntegerPower(2)) + return math.sqrt(reduceRows((sum, row) => sum + row.pow(2)) .sum()); default: throw UnsupportedError('Unsupported matrix norm type - $norm'); @@ -293,15 +321,23 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin } @override - Matrix insertColumns(int targetIdx, List columns) { - final newColumns = List(columnsNum + columns.length) - ..setRange(targetIdx, targetIdx + columns.length, columns); - var i = 0; + Matrix insertColumns(int targetIndex, List columns) { + final columnsIterator = columns + .iterator; + final indices = count(0) + .take(columnsNum + columns.length) + .map((i) => i.toInt()); + final newColumns = indices.map((index) { + if (index < targetIndex) { + return getColumn(index); + } - for (final column in this.columns) { - if (i == targetIdx) i += columns.length; - newColumns[i++] = column; - } + if (index < (targetIndex + columns.length)) { + return (columnsIterator..moveNext()).current; + } + + return getColumn(index - columns.length); + }).toList(growable: false); return Matrix.fromColumns(newColumns, dtype: dtype); } @@ -333,16 +369,22 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin } @override - Iterable get rows => _dataManager.rowIndices.map(getRow); + Iterable get rows => _dataManager + .rowIndices + .map(getRow); @override - Iterable get columns => _dataManager.columnIndices.map(getColumn); + Iterable get columns => _dataManager + .columnIndices + .map(getColumn); @override - Iterable get rowIndices => _dataManager.rowIndices; + Iterable get rowIndices => _dataManager + .rowIndices; @override - Iterable get columnIndices => _dataManager.columnIndices; + Iterable get columnIndices => _dataManager + .columnIndices; @override Matrix fastMap(T Function(T element) mapper) { @@ -355,20 +397,24 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin @override Matrix pow(num exponent) => _dataManager.areAllRowsCached ? Matrix.fromRows(rows.map( - (row) => row.pow(exponent)).toList(), dtype: dtype) + (row) => row + .pow(exponent)).toList(), dtype: dtype) : Matrix.fromColumns(columns.map( - (column) => column.pow(exponent)).toList(), dtype: dtype); + (column) => column + .pow(exponent)).toList(), dtype: dtype); @override Matrix exp({bool skipCaching = false}) => - _cacheManager.retrieveValue(matrixLogKey, + _cacheManager.retrieveValue(matrixExpKey, () => _dataManager.areAllRowsCached ? Matrix.fromRows(rows.map( (row) => row.exp( - skipCaching: skipCaching)).toList(), dtype: dtype) + skipCaching: skipCaching, + )).toList(), dtype: dtype) : Matrix.fromColumns(columns.map( (column) => column.exp( - skipCaching: skipCaching)).toList(), dtype: dtype), + skipCaching: skipCaching, + )).toList(), dtype: dtype), skipCaching: skipCaching); @override @@ -377,10 +423,12 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin () => _dataManager.areAllRowsCached ? Matrix.fromRows(rows.map( (row) => row.log( - skipCaching: skipCaching)).toList(), dtype: dtype) + skipCaching: skipCaching, + )).toList(), dtype: dtype) : Matrix.fromColumns(columns.map( (column) => column.log( - skipCaching: skipCaching)).toList(), dtype: dtype), + skipCaching: skipCaching, + )).toList(), dtype: dtype), skipCaching: skipCaching); @override @@ -419,15 +467,13 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin } @override - Map toJson() => matrixToJson(this); + Map toJson() => matrixToJson(this)!; double _findExtrema(double Function(Vector vector) callback) { - var i = 0; - final minValues = List(rowsNum); - - for (final row in rows) { - minValues[i++] = callback(row); - } + final rowIterator = rows.iterator; + final minValues = List + .generate(rowsNum, + (i) => callback((rowIterator..moveNext()).current)); return callback(Vector.fromList(minValues, dtype: dtype)); } @@ -436,7 +482,9 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin Vector Function(Vector combine, Vector vector) combiner, int length, Vector Function(int index) getVector, - {Vector initValue}) { + { + Vector? initValue + }) { var reduced = initValue ?? getVector(0); final startIndex = initValue != null ? 0 : 1; @@ -534,6 +582,7 @@ class MatrixImpl with IterableMixin>, MatrixValidatorMixin // TODO: use then `fastMap` to accelerate computations final elementGenFn = (int i) => operation(scalar, getRow(i)); final source = List.generate(rowsNum, elementGenFn); + return Matrix.fromRows(source, dtype: dtype); } } diff --git a/lib/src/matrix/serialization/from_matrix_json.dart b/lib/src/matrix/serialization/from_matrix_json.dart index 3c18f4c0..17e03408 100644 --- a/lib/src/matrix/serialization/from_matrix_json.dart +++ b/lib/src/matrix/serialization/from_matrix_json.dart @@ -3,12 +3,12 @@ import 'package:ml_linalg/src/common/dtype_serializer/from_dtype_json.dart'; import 'package:ml_linalg/src/matrix/serialization/matrix_json_keys.dart'; /// Restores a matrix instance from given [json] -Matrix fromMatrixJson(Map json) { +Matrix? fromMatrixJson(Map? json) { if (json == null) { return null; } - final matrixSource = json[matrixDataJsonKey] as List; + final matrixSource = json[matrixDataJsonKey] as List?; if (matrixSource == null) { throw Exception('Provided json is missing `$matrixDataJsonKey` field'); @@ -21,7 +21,7 @@ Matrix fromMatrixJson(Map json) { .toList(growable: false)) .toList(growable: false); - final encodedDType = json[matrixDTypeJsonKey] as String; + final encodedDType = json[matrixDTypeJsonKey] as String?; if (encodedDType == null) { throw Exception('Provided json is missing `$matrixDTypeJsonKey` field'); @@ -29,5 +29,5 @@ Matrix fromMatrixJson(Map json) { final dType = fromDTypeJson(encodedDType); - return Matrix.fromList(double2dList, dtype: dType); + return Matrix.fromList(double2dList, dtype: dType!); } diff --git a/lib/src/matrix/serialization/matrix_to_json.dart b/lib/src/matrix/serialization/matrix_to_json.dart index 653fdcb2..7d6280b2 100644 --- a/lib/src/matrix/serialization/matrix_to_json.dart +++ b/lib/src/matrix/serialization/matrix_to_json.dart @@ -3,7 +3,7 @@ import 'package:ml_linalg/src/common/dtype_serializer/dtype_to_json.dart'; import 'package:ml_linalg/src/matrix/serialization/matrix_json_keys.dart'; /// Encodes a [matrix] to a json-serializable map -Map matrixToJson(Matrix matrix) { +Map? matrixToJson(Matrix? matrix) { if (matrix == null) { return null; } diff --git a/lib/src/vector/float32x4_vector.dart b/lib/src/vector/float32x4_vector.dart index da57a37b..6be965e0 100644 --- a/lib/src/vector/float32x4_vector.dart +++ b/lib/src/vector/float32x4_vector.dart @@ -17,23 +17,35 @@ const _bytesPerSimdElement = Float32x4List.bytesPerElement; const _bucketSize = Float32x4List.bytesPerElement ~/ Float32List.bytesPerElement; final _simdOnes = Float32x4.splat(1.0); -class Float32x4Vector with IterableMixin implements Vector { +class Float32x4Vector + with + IterableMixin + implements + Vector { + Float32x4Vector.fromList(List source, this._cacheManager) : length = source.length, _numOfBuckets = _getNumOfBuckets(source.length, _bucketSize), _buffer = ByteData(_getNumOfBuckets(source.length, _bucketSize) * _bytesPerSimdElement).buffer { + for (var i = 0; i < length; i++) { _buffer.asByteData().setFloat32(_bytesPerElement * i, source[i].toDouble(), Endian.host); } } - Float32x4Vector.randomFilled(this.length, int seed, this._cacheManager, { - num min = 0, num max = 1}) : + Float32x4Vector.randomFilled( + this.length, + int seed, + this._cacheManager, { + num min = 0, + num max = 1, + }) : _numOfBuckets = _getNumOfBuckets(length, _bucketSize), _buffer = ByteData(_getNumOfBuckets(length, _bucketSize) * _bytesPerSimdElement).buffer { + if (min >= max) { throw ArgumentError.value(min, 'Argument `min` should be less than `max`'); } @@ -47,10 +59,14 @@ class Float32x4Vector with IterableMixin implements Vector { } } - Float32x4Vector.filled(this.length, num value, this._cacheManager) : + Float32x4Vector.filled( + this.length, + num value, + this._cacheManager) : _numOfBuckets = _getNumOfBuckets(length, _bucketSize), _buffer = ByteData(_getNumOfBuckets(length, _bucketSize) * _bytesPerSimdElement).buffer { + for (var i = 0; i < length; i++) { _buffer.asByteData().setFloat32(_bytesPerElement * i, value.toDouble(), Endian.host); @@ -62,10 +78,13 @@ class Float32x4Vector with IterableMixin implements Vector { _buffer = ByteData(_getNumOfBuckets(length, _bucketSize) * _bytesPerSimdElement).buffer; - Float32x4Vector.fromSimdList(Float32x4List data, this.length, + Float32x4Vector.fromSimdList( + Float32x4List data, + this.length, this._cacheManager) : _numOfBuckets = _getNumOfBuckets(length, _bucketSize), _buffer = data.buffer { + _cachedInnerSimdList = data; } @@ -90,11 +109,11 @@ class Float32x4Vector with IterableMixin implements Vector { Float32x4List get _innerSimdList => _cachedInnerSimdList ??= _buffer.asFloat32x4List(); - Float32x4List _cachedInnerSimdList; + Float32x4List? _cachedInnerSimdList; List get _innerTypedList => _cachedInnerTypedList ??= _buffer.asFloat32List(0, length); - Float32List _cachedInnerTypedList; + Float32List? _cachedInnerTypedList; bool get _isLastBucketNotFull => length % _bucketSize > 0; @@ -132,22 +151,32 @@ class Float32x4Vector with IterableMixin implements Vector { @override Vector operator +(Object value) { if (value is Vector || value is Matrix) { - final other = (value is Matrix ? value.toVector() : value) as Float32x4Vector; + final other = (value is Matrix + ? value.toVector() + : value) as Float32x4Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] + other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); - } else if (value is num) { + } + + if (value is num) { final arg = Float32x4.splat(value.toDouble()); final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] + arg; } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -157,22 +186,31 @@ class Float32x4Vector with IterableMixin implements Vector { @override Vector operator -(Object value) { if (value is Vector || value is Matrix) { - final other = (value is Matrix ? value.toVector() : value) as Float32x4Vector; + final other = (value is Matrix + ? value.toVector() + : value) as Float32x4Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] - other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); + } - } else if (value is num) { + if (value is num) { final arg = Float32x4.splat(value.toDouble()); final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] - arg; } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -183,13 +221,17 @@ class Float32x4Vector with IterableMixin implements Vector { Vector operator *(Object value) { if (value is Vector) { final other = value as Float32x4Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] * other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -199,9 +241,11 @@ class Float32x4Vector with IterableMixin implements Vector { if (value is num) { final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].scale(value.toDouble()); } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -212,20 +256,28 @@ class Float32x4Vector with IterableMixin implements Vector { Vector operator /(Object value) { if (value is Vector) { final other = value as Float32x4Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] / other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); - } else if (value is num) { + } + + if (value is num) { final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].scale(1 / value); } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -236,9 +288,11 @@ class Float32x4Vector with IterableMixin implements Vector { Vector sqrt({bool skipCaching = false}) => _cacheManager.retrieveValue(vectorSqrtKey, () { final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].sqrt(); } + return Vector.fromSimdList(source, length, dtype: dtype); }, skipCaching: skipCaching); @@ -271,10 +325,6 @@ class Float32x4Vector with IterableMixin implements Vector { return Vector.fromList(source, dtype: dtype); }, skipCaching: skipCaching); - @override - @deprecated - Vector toIntegerPower(int power) => pow(power); - @override Vector abs({bool skipCaching = false}) => _cacheManager.retrieveValue(vectorAbsKey, () { @@ -311,10 +361,13 @@ class Float32x4Vector with IterableMixin implements Vector { switch (distance) { case Distance.euclidean: return (this - other).norm(Norm.euclidean); + case Distance.manhattan: return (this - other).norm(Norm.manhattan); + case Distance.cosine: return 1 - getCosine(other); + default: throw UnimplementedError('Unimplemented distance type - $distance'); } @@ -324,10 +377,12 @@ class Float32x4Vector with IterableMixin implements Vector { double getCosine(Vector other) { final cosine = (dot(other) / norm(Norm.euclidean) / other.norm(Norm.euclidean)); + if (cosine.isInfinite || cosine.isNaN) { throw Exception('It is impossible to find cosine of an angle of two ' 'vectors if at least one of the vectors is zero-vector'); } + return cosine; } @@ -336,6 +391,7 @@ class Float32x4Vector with IterableMixin implements Vector { if (isEmpty) { throw _emptyVectorException; } + return _cacheManager.retrieveValue(vectorMeanKey, () => sum() / length, skipCaching: skipCaching); } @@ -344,9 +400,11 @@ class Float32x4Vector with IterableMixin implements Vector { double norm([Norm normType = Norm.euclidean, bool skipCaching = false]) => _cacheManager.retrieveValue(getCacheKeyForNormByNormType(normType), () { final power = _getPowerByNormType(normType); + if (power == 1) { return abs().sum(); } + return math.pow(pow(power).sum(), 1 / power) as double; }, skipCaching: skipCaching); @@ -391,31 +449,41 @@ class Float32x4Vector with IterableMixin implements Vector { if (_isLastBucketNotFull) { var extrema = initialValue; final fullBucketsList = _innerSimdList.take(_numOfBuckets - 1); + if (fullBucketsList.isNotEmpty) { extrema = getExtremalLane(fullBucketsList.reduce(getExtremalBucket)); } + return _simdHelper.simdValueToList(_innerSimdList.last) .take(length % _bucketSize) .fold(extrema, getExtremalValue); - } else { - return getExtremalLane(_innerSimdList.reduce(getExtremalBucket)); } + + return getExtremalLane(_innerSimdList.reduce(getExtremalBucket)); } @override Vector sample(Iterable indices) { - final list = Float32List(indices.length); var i = 0; + final list = Float32List(indices.length); + for (final idx in indices) { list[i++] = this[idx]; } + return Vector.fromList(list, dtype: dtype); } @override Vector unique({bool skipCaching = false}) => - _cacheManager.retrieveValue(vectorUniqueKey, () => Vector.fromList( - Set.from(this).toList(growable: false), dtype: dtype), + _cacheManager.retrieveValue( + vectorUniqueKey, + () => Vector.fromList( + Set + .from(this) + .toList(growable: false), + dtype: dtype, + ), skipCaching: skipCaching); @override @@ -435,27 +503,32 @@ class Float32x4Vector with IterableMixin implements Vector { if (isEmpty) { throw _emptyVectorException; } + if (index >= length) { throw RangeError.index(index, this); } + return _innerTypedList[index]; } @override - Vector subvector(int start, [int end]) { + Vector subvector(int start, [int? end]) { if (start < 0) { throw RangeError.range(start, 0, length - 1, '`start` cannot' ' be negative'); } + if (end != null && start >= end) { throw RangeError.range(start, 0, length - 1, '`start` cannot be greater than or equal to `end`'); } + if (start >= length) { throw RangeError.range(start, 0, length - 1, '`start` cannot be greater than or equal to the vector' 'length'); } + final limit = end == null || end > length ? length : end; final collection = _innerTypedList.sublist(start, limit); @@ -473,11 +546,12 @@ class Float32x4Vector with IterableMixin implements Vector { _cacheManager.retrieveValue(vectorRescaleKey, () { final minValue = min(); final maxValue = max(); + return (this - minValue) / (maxValue - minValue); }, skipCaching: skipCaching); @override - Map toJson() => vectorToJson(this); + Map toJson() => vectorToJson(this)!; /// Returns exponent depending on vector norm type (for Euclidean norm - 2, /// Manhattan - 1) @@ -485,10 +559,12 @@ class Float32x4Vector with IterableMixin implements Vector { switch (norm) { case Norm.euclidean: return 2; + case Norm.manhattan: return 1; + default: - throw UnsupportedError('Unsupported norm type!'); + throw UnsupportedError('Unsupported norm type $norm'); } } @@ -498,17 +574,21 @@ class Float32x4Vector with IterableMixin implements Vector { Vector _elementWiseIntegerPow(int exponent) { final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _simdToIntPow(_innerSimdList[i], exponent); } + return Vector.fromSimdList(source, length, dtype: dtype); } Vector _elementWiseFloatPow(double exponent) { final source = Float32x4List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _simdToFloatPow(_innerSimdList[i], exponent); } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -538,8 +618,10 @@ class Float32x4Vector with IterableMixin implements Vector { 'vector length is not allowed: vector length: $length, matrix ' 'row number: ${matrix.rowsNum}'); } + final source = List.generate( matrix.columnsNum, (int i) => dot(matrix.getColumn(i))); + return Vector.fromList(source, dtype: dtype); } diff --git a/lib/src/vector/float64x2_vector.dart b/lib/src/vector/float64x2_vector.dart index 27a000e9..fb63f3fb 100644 --- a/lib/src/vector/float64x2_vector.dart +++ b/lib/src/vector/float64x2_vector.dart @@ -19,23 +19,35 @@ const _bytesPerSimdElement = Float64x2List.bytesPerElement; const _bucketSize = Float64x2List.bytesPerElement ~/ Float64List.bytesPerElement; final _simdOnes = Float64x2.splat(1.0); -class Float64x2Vector with IterableMixin implements Vector { +class Float64x2Vector + with + IterableMixin + implements + Vector { + Float64x2Vector.fromList(List source, this._cacheManager) : length = source.length, _numOfBuckets = _getNumOfBuckets(source.length, _bucketSize), _buffer = ByteData(_getNumOfBuckets(source.length, _bucketSize) * _bytesPerSimdElement).buffer { + for (var i = 0; i < length; i++) { _buffer.asByteData().setFloat64(_bytesPerElement * i, source[i].toDouble(), Endian.host); } } - Float64x2Vector.randomFilled(this.length, int seed, this._cacheManager, { - num min = 0, num max = 1}) : + Float64x2Vector.randomFilled( + this.length, + int seed, + this._cacheManager, { + num min = 0, + num max = 1, + }) : _numOfBuckets = _getNumOfBuckets(length, _bucketSize), _buffer = ByteData(_getNumOfBuckets(length, _bucketSize) * _bytesPerSimdElement).buffer { + if (min >= max) { throw ArgumentError.value(min, 'Argument `min` should be less than `max`'); } @@ -49,10 +61,14 @@ class Float64x2Vector with IterableMixin implements Vector { } } - Float64x2Vector.filled(this.length, num value, this._cacheManager) : + Float64x2Vector.filled( + this.length, + num value, + this._cacheManager) : _numOfBuckets = _getNumOfBuckets(length, _bucketSize), _buffer = ByteData(_getNumOfBuckets(length, _bucketSize) * _bytesPerSimdElement).buffer { + for (var i = 0; i < length; i++) { _buffer.asByteData().setFloat64(_bytesPerElement * i, value.toDouble(), Endian.host); @@ -64,10 +80,13 @@ class Float64x2Vector with IterableMixin implements Vector { _buffer = ByteData(_getNumOfBuckets(length, _bucketSize) * _bytesPerSimdElement).buffer; - Float64x2Vector.fromSimdList(Float64x2List data, this.length, + Float64x2Vector.fromSimdList( + Float64x2List data, + this.length, this._cacheManager) : _numOfBuckets = _getNumOfBuckets(length, _bucketSize), _buffer = data.buffer { + _cachedInnerSimdList = data; } @@ -92,11 +111,11 @@ class Float64x2Vector with IterableMixin implements Vector { Float64x2List get _innerSimdList => _cachedInnerSimdList ??= _buffer.asFloat64x2List(); - Float64x2List _cachedInnerSimdList; + Float64x2List? _cachedInnerSimdList; List get _innerTypedList => _cachedInnerTypedList ??= _buffer.asFloat64List(0, length); - Float64List _cachedInnerTypedList; + Float64List? _cachedInnerTypedList; bool get _isLastBucketNotFull => length % _bucketSize > 0; @@ -134,22 +153,32 @@ class Float64x2Vector with IterableMixin implements Vector { @override Vector operator +(Object value) { if (value is Vector || value is Matrix) { - final other = (value is Matrix ? value.toVector() : value) as Float64x2Vector; + final other = (value is Matrix + ? value.toVector() + : value) as Float64x2Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] + other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); - } else if (value is num) { + } + + if (value is num) { final arg = Float64x2.splat(value.toDouble()); final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] + arg; } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -159,22 +188,31 @@ class Float64x2Vector with IterableMixin implements Vector { @override Vector operator -(Object value) { if (value is Vector || value is Matrix) { - final other = (value is Matrix ? value.toVector() : value) as Float64x2Vector; + final other = (value is Matrix + ? value.toVector() + : value) as Float64x2Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] - other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); + } - } else if (value is num) { + if (value is num) { final arg = Float64x2.splat(value.toDouble()); final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] - arg; } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -185,13 +223,17 @@ class Float64x2Vector with IterableMixin implements Vector { Vector operator *(Object value) { if (value is Vector) { final other = value as Float64x2Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] * other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -201,9 +243,11 @@ class Float64x2Vector with IterableMixin implements Vector { if (value is num) { final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].scale(value.toDouble()); } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -214,20 +258,28 @@ class Float64x2Vector with IterableMixin implements Vector { Vector operator /(Object value) { if (value is Vector) { final other = value as Float64x2Vector; + if (other.length != length) { throw _mismatchLengthError; } + final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i] / other._innerSimdList[i]; } + return Vector.fromSimdList(source, length, dtype: dtype); - } else if (value is num) { + } + + if (value is num) { final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].scale(1 / value); } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -238,9 +290,11 @@ class Float64x2Vector with IterableMixin implements Vector { Vector sqrt({bool skipCaching = false}) => _cacheManager.retrieveValue(vectorSqrtKey, () { final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _innerSimdList[i].sqrt(); } + return Vector.fromSimdList(source, length, dtype: dtype); }, skipCaching: skipCaching); @@ -273,10 +327,6 @@ class Float64x2Vector with IterableMixin implements Vector { return Vector.fromList(source, dtype: dtype); }, skipCaching: skipCaching); - @override - @deprecated - Vector toIntegerPower(int power) => pow(power); - @override Vector abs({bool skipCaching = false}) => _cacheManager.retrieveValue(vectorAbsKey, () { @@ -313,10 +363,13 @@ class Float64x2Vector with IterableMixin implements Vector { switch (distance) { case Distance.euclidean: return (this - other).norm(Norm.euclidean); + case Distance.manhattan: return (this - other).norm(Norm.manhattan); + case Distance.cosine: return 1 - getCosine(other); + default: throw UnimplementedError('Unimplemented distance type - $distance'); } @@ -326,10 +379,12 @@ class Float64x2Vector with IterableMixin implements Vector { double getCosine(Vector other) { final cosine = (dot(other) / norm(Norm.euclidean) / other.norm(Norm.euclidean)); + if (cosine.isInfinite || cosine.isNaN) { throw Exception('It is impossible to find cosine of an angle of two ' 'vectors if at least one of the vectors is zero-vector'); } + return cosine; } @@ -338,6 +393,7 @@ class Float64x2Vector with IterableMixin implements Vector { if (isEmpty) { throw _emptyVectorException; } + return _cacheManager.retrieveValue(vectorMeanKey, () => sum() / length, skipCaching: skipCaching); } @@ -346,9 +402,11 @@ class Float64x2Vector with IterableMixin implements Vector { double norm([Norm normType = Norm.euclidean, bool skipCaching = false]) => _cacheManager.retrieveValue(getCacheKeyForNormByNormType(normType), () { final power = _getPowerByNormType(normType); + if (power == 1) { return abs().sum(); } + return math.pow(pow(power).sum(), 1 / power) as double; }, skipCaching: skipCaching); @@ -393,31 +451,41 @@ class Float64x2Vector with IterableMixin implements Vector { if (_isLastBucketNotFull) { var extrema = initialValue; final fullBucketsList = _innerSimdList.take(_numOfBuckets - 1); + if (fullBucketsList.isNotEmpty) { extrema = getExtremalLane(fullBucketsList.reduce(getExtremalBucket)); } + return _simdHelper.simdValueToList(_innerSimdList.last) .take(length % _bucketSize) .fold(extrema, getExtremalValue); - } else { - return getExtremalLane(_innerSimdList.reduce(getExtremalBucket)); } + + return getExtremalLane(_innerSimdList.reduce(getExtremalBucket)); } @override Vector sample(Iterable indices) { - final list = Float64List(indices.length); var i = 0; + final list = Float64List(indices.length); + for (final idx in indices) { list[i++] = this[idx]; } + return Vector.fromList(list, dtype: dtype); } @override Vector unique({bool skipCaching = false}) => - _cacheManager.retrieveValue(vectorUniqueKey, () => Vector.fromList( - Set.from(this).toList(growable: false), dtype: dtype), + _cacheManager.retrieveValue( + vectorUniqueKey, + () => Vector.fromList( + Set + .from(this) + .toList(growable: false), + dtype: dtype, + ), skipCaching: skipCaching); @override @@ -437,27 +505,32 @@ class Float64x2Vector with IterableMixin implements Vector { if (isEmpty) { throw _emptyVectorException; } + if (index >= length) { throw RangeError.index(index, this); } + return _innerTypedList[index]; } @override - Vector subvector(int start, [int end]) { + Vector subvector(int start, [int? end]) { if (start < 0) { throw RangeError.range(start, 0, length - 1, '`start` cannot' ' be negative'); } + if (end != null && start >= end) { throw RangeError.range(start, 0, length - 1, '`start` cannot be greater than or equal to `end`'); } + if (start >= length) { throw RangeError.range(start, 0, length - 1, '`start` cannot be greater than or equal to the vector' 'length'); } + final limit = end == null || end > length ? length : end; final collection = _innerTypedList.sublist(start, limit); @@ -475,11 +548,12 @@ class Float64x2Vector with IterableMixin implements Vector { _cacheManager.retrieveValue(vectorRescaleKey, () { final minValue = min(); final maxValue = max(); + return (this - minValue) / (maxValue - minValue); }, skipCaching: skipCaching); @override - Map toJson() => vectorToJson(this); + Map toJson() => vectorToJson(this)!; /// Returns exponent depending on vector norm type (for Euclidean norm - 2, /// Manhattan - 1) @@ -487,10 +561,12 @@ class Float64x2Vector with IterableMixin implements Vector { switch (norm) { case Norm.euclidean: return 2; + case Norm.manhattan: return 1; + default: - throw UnsupportedError('Unsupported norm type!'); + throw UnsupportedError('Unsupported norm type $norm'); } } @@ -500,17 +576,21 @@ class Float64x2Vector with IterableMixin implements Vector { Vector _elementWiseIntegerPow(int exponent) { final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _simdToIntPow(_innerSimdList[i], exponent); } + return Vector.fromSimdList(source, length, dtype: dtype); } Vector _elementWiseFloatPow(double exponent) { final source = Float64x2List(_numOfBuckets); + for (var i = 0; i < _numOfBuckets; i++) { source[i] = _simdToFloatPow(_innerSimdList[i], exponent); } + return Vector.fromSimdList(source, length, dtype: dtype); } @@ -540,8 +620,10 @@ class Float64x2Vector with IterableMixin implements Vector { 'vector length is not allowed: vector length: $length, matrix ' 'row number: ${matrix.rowsNum}'); } + final source = List.generate( matrix.columnsNum, (int i) => dot(matrix.getColumn(i))); + return Vector.fromList(source, dtype: dtype); } diff --git a/lib/src/vector/serialization/from_vector_json.dart b/lib/src/vector/serialization/from_vector_json.dart index 95c8af18..2c203d33 100644 --- a/lib/src/vector/serialization/from_vector_json.dart +++ b/lib/src/vector/serialization/from_vector_json.dart @@ -3,7 +3,7 @@ import 'package:ml_linalg/src/common/dtype_serializer/dtype_encoded_values.dart' import 'package:ml_linalg/src/vector/vector_json_keys.dart'; /// Restores a vector instance from the given [json] -Vector fromVectorJson(Map json) { +Vector? fromVectorJson(Map? json) { if (json == null) { return null; } @@ -12,7 +12,7 @@ Vector fromVectorJson(Map json) { .map((dynamic value) => double.parse(value.toString())) .toList(growable: false); - switch(json[vectorDTypeJsonKey] as String) { + switch(json[vectorDTypeJsonKey] as String?) { case dTypeFloat32EncodedValue: return Vector.fromList(source, dtype: DType.float32); diff --git a/lib/src/vector/serialization/vector_to_json.dart b/lib/src/vector/serialization/vector_to_json.dart index 9807326e..24e3fff0 100644 --- a/lib/src/vector/serialization/vector_to_json.dart +++ b/lib/src/vector/serialization/vector_to_json.dart @@ -3,7 +3,7 @@ import 'package:ml_linalg/src/vector/vector_json_keys.dart'; import 'package:ml_linalg/vector.dart'; /// Returns a json-serializable map for the [vector] -Map vectorToJson(Vector vector) { +Map? vectorToJson(Vector? vector) { if (vector == null) { return null; } diff --git a/lib/vector.dart b/lib/vector.dart index 150dba1c..32e98fff 100644 --- a/lib/vector.dart +++ b/lib/vector.dart @@ -3,15 +3,12 @@ import 'dart:typed_data'; import 'package:ml_linalg/distance.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/norm.dart'; -import 'package:ml_linalg/src/common/cache_manager/cache_manager_factory.dart'; -import 'package:ml_linalg/src/di/dependencies.dart'; +import 'package:ml_linalg/src/common/cache_manager/cache_manager_factory_impl.dart'; import 'package:ml_linalg/src/vector/float32x4_vector.dart'; import 'package:ml_linalg/src/vector/float64x2_vector.dart'; import 'package:ml_linalg/src/vector/serialization/from_vector_json.dart'; import 'package:ml_linalg/src/vector/vector_cache_keys.dart'; -final _cacheManagerFactory = dependencies.get(); - /// An algebraic vector with SIMD (single instruction, multiple data) /// architecture support and extended functionality, adapted for data science /// applications @@ -44,13 +41,15 @@ abstract class Vector implements Iterable { case DType.float32: return Float32x4Vector.fromList( source, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); case DType.float64: return Float64x2Vector.fromList( source, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); default: @@ -91,14 +90,16 @@ abstract class Vector implements Iterable { return Float32x4Vector.fromSimdList( source as Float32x4List, actualLength, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); case DType.float64: return Float64x2Vector.fromSimdList( source as Float64x2List, actualLength, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); default: @@ -131,14 +132,16 @@ abstract class Vector implements Iterable { return Float32x4Vector.filled( length, value, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); case DType.float64: return Float64x2Vector.filled( length, value, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); default: @@ -170,13 +173,15 @@ abstract class Vector implements Iterable { case DType.float32: return Float32x4Vector.zero( length, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); case DType.float64: return Float64x2Vector.zero( length, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); default: @@ -217,7 +222,8 @@ abstract class Vector implements Iterable { return Float32x4Vector.randomFilled( length, seed, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), max: max, min: min, ); @@ -226,7 +232,8 @@ abstract class Vector implements Iterable { return Float64x2Vector.randomFilled( length, seed, - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), max: max, min: min, ); @@ -257,12 +264,14 @@ abstract class Vector implements Iterable { switch (dtype) { case DType.float32: return Float32x4Vector.empty( - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); case DType.float64: return Float64x2Vector.empty( - _cacheManagerFactory.create(vectorCacheKeys), + CacheManagerFactoryImpl() + .create(vectorCacheKeys), ); default: @@ -270,7 +279,7 @@ abstract class Vector implements Iterable { } } - factory Vector.fromJson(Map json) => fromVectorJson(json); + factory Vector.fromJson(Map json) => fromVectorJson(json)!; /// Denotes a data type, used for representation of the vector's elements DType get dtype; @@ -298,12 +307,6 @@ abstract class Vector implements Iterable { /// divided by [scalar] Vector scalarDiv(num scalar); - /// Creates a new [Vector] containing elements of this [Vector] raised to - /// the integer [power] - /// Deprecated, use [pow] instead - @deprecated - Vector toIntegerPower(int power); - /// Creates a new [Vector] composed of elements of this [Vector] raised to /// the [exponent]. Avoid raising a vector to a float power, since it is /// a slow operation @@ -395,7 +398,7 @@ abstract class Vector implements Iterable { /// Returns a new vector composed of values whose indices are within the range /// [start] (inclusive) - [end] (exclusive) - Vector subvector(int start, [int end]); + Vector subvector(int start, [int? end]); /// Returns a json-serializable map Map toJson(); diff --git a/pubspec.yaml b/pubspec.yaml index fe51324b..13cbd96d 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,21 +1,19 @@ name: ml_linalg description: SIMD-based linear algebra and statistics for data science -version: 12.17.10 +version: 13.0.0-nullsafety.0 homepage: https://github.com/gyrdym/ml_linalg environment: - sdk: '>=2.7.0 <3.0.0' + sdk: '>=2.12.0-0 <3.0.0' dependencies: - injector: ^1.0.9 - quiver: '>=2.0.0 <3.0.0' - xrange: ^1.0.0 + quiver: ^3.0.0-nullsafety.3 dev_dependencies: - benchmark_harness: '>=1.0.0 <2.0.0' + benchmark_harness: ^2.0.0-nullsafety.0 build_runner: ^1.10.11 build_test: ^1.3.2 - mockito: ^3.0.0 - pedantic: ^1.9.2 - test: ^1.15.7 + mockito: ^5.0.0-nullsafety.5 + pedantic: ^1.10.0-nullsafety.3 + test: ^1.16.0-nullsafety.17 test_coverage: ^0.5.0 diff --git a/test/integration_test/matrix/methods/insert_columns/insert_columns_test_group_factory.dart b/test/integration_test/matrix/methods/insert_columns/insert_columns_test_group_factory.dart index f0b9a213..129237c2 100644 --- a/test/integration_test/matrix/methods/insert_columns/insert_columns_test_group_factory.dart +++ b/test/integration_test/matrix/methods/insert_columns/insert_columns_test_group_factory.dart @@ -22,10 +22,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - - expect(newMatrix.getColumn(1), equals(newCol)); - expect(newMatrix.getRow(0), equals([4.0, 100.0, 8.0, 12.0, 16.0, 34.0])); - expect(newMatrix.getRow(4), equals([112.0, 500.0, 10.0, 34.0, 2.0, 10.0])); expect(newMatrix, equals([ [4.0, 100.0, 8.0, 12.0, 16.0, 34.0], [20.0, 200.0, 24.0, 28.0, 32.0, 23.0], @@ -50,8 +46,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - - expect(newMatrix.getColumn(0), equals(newCol)); expect(newMatrix, equals([ [100.0, 4.0, 8.0, 12.0, 16.0, 34.0], [200.0, 20.0, 24.0, 28.0, 32.0, 23.0], @@ -75,8 +69,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - - expect(newMatrix.getColumn(4), equals(newCol)); expect(newMatrix, equals([ [4.0, 8.0, 12.0, 16.0, 100.0, 34.0], [20.0, 24.0, 28.0, 32.0, 200.0, 23.0], @@ -101,8 +93,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - - expect(newMatrix.getColumn(5), equals(newCol)); expect(newMatrix, equals([ [4.0, 8.0, 12.0, 16.0, 34.0, 100.0], [20.0, 24.0, 28.0, 32.0, 23.0, 200.0], @@ -137,10 +127,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - expect(newMatrix.getRow(0), - equals([4.0, 100.0, -100.0, 8.0, 12.0, 16.0, 34.0])); - expect(newMatrix.getRow(4), - equals([112.0, 500.0, -500.0, 10.0, 34.0, 2.0, 10.0])); expect(newMatrix, equals([ [4.0, 100.0, -100.0, 8.0, 12.0, 16.0, 34.0], [20.0, 200.0, -200.0, 24.0, 28.0, 32.0, 23.0], @@ -169,12 +155,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - - expect([ - newMatrix.getColumn(0), - newMatrix.getColumn(1), - ], equals(newCols)); - expect(newMatrix, equals([ [100.0, -100.0, 4.0, 8.0, 12.0, 16.0, 34.0], [200.0, -200.0, 20.0, 24.0, 28.0, 32.0, 23.0], @@ -199,8 +179,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - - expect(newMatrix.getColumn(4), equals(newCol)); expect(newMatrix, equals([ [4.0, 8.0, 12.0, 16.0, 100.0, 34.0], [20.0, 24.0, 28.0, 32.0, 200.0, 23.0], @@ -228,12 +206,6 @@ void matrixInsertColumnsTestGroupFactory(DType dtype) => expect(matrix.dtype, dtype); expect(newMatrix.dtype, dtype); - - expect([ - newMatrix.getColumn(5), - newMatrix.getColumn(6), - ], equals(newCols)); - expect(newMatrix, equals([ [4.0, 8.0, 12.0, 16.0, 34.0, 100.0, -100.0], [20.0, 24.0, 28.0, 32.0, 23.0, 200.0, -200.0], diff --git a/test/integration_test/vector/methods/to_integer_power/to_integer_power_test.dart b/test/integration_test/vector/methods/to_integer_power/to_integer_power_test.dart deleted file mode 100644 index a43dd772..00000000 --- a/test/integration_test/vector/methods/to_integer_power/to_integer_power_test.dart +++ /dev/null @@ -1,8 +0,0 @@ -import 'package:ml_linalg/dtype.dart'; - -import 'to_integer_power_test_group_factory.dart'; - -void main() { - vectorToIntegerPowerTestGroupFactory(DType.float32); - vectorToIntegerPowerTestGroupFactory(DType.float64); -} diff --git a/test/integration_test/vector/methods/to_integer_power/to_integer_power_test_group_factory.dart b/test/integration_test/vector/methods/to_integer_power/to_integer_power_test_group_factory.dart deleted file mode 100644 index 2de4e819..00000000 --- a/test/integration_test/vector/methods/to_integer_power/to_integer_power_test_group_factory.dart +++ /dev/null @@ -1,22 +0,0 @@ -import 'package:ml_linalg/dtype.dart'; -import 'package:ml_linalg/vector.dart'; -import 'package:test/test.dart'; - -import '../../../../dtype_to_title.dart'; - -void vectorToIntegerPowerTestGroupFactory(DType dtype) => - group(dtypeToVectorTestTitle[dtype], () { - group('toIntegerPower method', () { - test('should raise vector elements to the integer power', () { - final vector = Vector.fromList([1.0, 2.0, 3.0, 4.0, 5.0], - dtype: dtype); - // ignore: deprecated_member_use_from_same_package - final result = vector.toIntegerPower(3); - - expect(result, isNot(same(vector))); - expect(result.length, equals(5)); - expect(result, equals([1.0, 8.0, 27.0, 64.0, 125.0])); - expect(result.dtype, dtype); - }); - }); - }); diff --git a/test/unit_test/matrix/matrix_iterator/matrix_iterator_test_group_factory.dart b/test/unit_test/matrix/matrix_iterator/matrix_iterator_test_group_factory.dart index 51ef1341..eae15100 100644 --- a/test/unit_test/matrix/matrix_iterator/matrix_iterator_test_group_factory.dart +++ b/test/unit_test/matrix/matrix_iterator/matrix_iterator_test_group_factory.dart @@ -19,11 +19,12 @@ void matrixIteratorTestGroupFactory(DType dtype, final source = createSource([1.0, 2.0, 3.0, 10.0, 22.0, 31.0, 8.3, 3.4, 34.5]); - test('should initilize with `current` property equals `null`', () { + test('should throw an error if one accesses `current` property before ' + '`moveNext` call', () { final data = createByteData(source); final iterator = createIterator(data, 3, 3); - expect(iterator.current, isNull); + expect(() => iterator.current, throwsA(isA())); }); test('should return the next value on every `moveNext` method call (9 ' @@ -40,9 +41,19 @@ void matrixIteratorTestGroupFactory(DType dtype, iterator.moveNext(); expect(iterator.current, iterableAlmostEqualTo([8.3, 3.4, 34.5])); + }); - iterator.moveNext(); - expect(iterator.current, isNull); + test('should contain the last successful value in `current` field if ' + '`moveNext` returns `false`', () { + final data = createByteData(source); + + final iterator = createIterator(data, 3, 3) + ..moveNext() + ..moveNext() + ..moveNext() + ..moveNext(); + + expect(iterator.current, iterableAlmostEqualTo([8.3, 3.4, 34.5])); }); test('should return the next value on every `moveNext` method call (9 ' @@ -78,10 +89,6 @@ void matrixIteratorTestGroupFactory(DType dtype, iterator.current, iterableAlmostEqualTo( [1.0, 2.0, 3.0, 10.0, 22.0, 31.0, 8.3, 3.4, 34.5])); - - iterator.moveNext(); - - expect(iterator.current, isNull); }); test('should return a proper boolean indicator after each `moveNext` ' diff --git a/test/unit_test/matrix/serialization/from_matrix_json_test.dart b/test/unit_test/matrix/serialization/from_matrix_json_test.dart index 2c0a99dc..80a07bc8 100644 --- a/test/unit_test/matrix/serialization/from_matrix_json_test.dart +++ b/test/unit_test/matrix/serialization/from_matrix_json_test.dart @@ -14,7 +14,7 @@ void main() { [33, -987, 90, 732], ]; - final dataWithNull = >[ + final dataWithNull = >[ [123.0009863, null, 11.777209, 90003.112], [-93.5678, 12, null, -10e2], ]; @@ -66,13 +66,13 @@ void main() { }); test('should restore a float32 matrix instance from json', () { - final matrix = fromMatrixJson(validFloat32Json); + final matrix = fromMatrixJson(validFloat32Json)!; expect(matrix.dtype, DType.float32); expect(matrix, iterable2dAlmostEqualTo(data)); }); test('should restore a float64 matrix instance from json', () { - final matrix = fromMatrixJson(validFloat64Json); + final matrix = fromMatrixJson(validFloat64Json)!; expect(matrix.dtype, DType.float64); expect(matrix, iterable2dAlmostEqualTo(data)); }); diff --git a/tool/generate_class_from_template.dart b/tool/generate_class_from_template.dart index 5af0d428..824490f2 100644 --- a/tool/generate_class_from_template.dart +++ b/tool/generate_class_from_template.dart @@ -12,7 +12,7 @@ Future generateClassFromTemplate( String targetFileName, String templateFileName, { - Map mapping, + Map? mapping, String comment = '/* This file is auto generated, do not change it manually */\n\n', } ) async {