Skip to content

Commit

Permalink
Vector: prod method added, Matrix: sum and prod methods added
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed May 31, 2020
1 parent 13db4d9 commit 3c7f234
Show file tree
Hide file tree
Showing 19 changed files with 366 additions and 39 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 12.13.0
- `Matrix`:
- `Matrix.sum` method added
- `Matrix.prod` method added
- `Vector`:
- `Vector.prod` method added

## 12.12.0
- `Matrix`:
- `Matrix.exp` method added
Expand Down
43 changes: 42 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- [Manhattan norm](#manhattan-norm)
- [Mean value](#mean-value)
- [Sum of all vector elements](#sum-of-all-vector-elements)
- [Product of all vector elements](#product-of-all-vector-elements)
- [Element-wise power](#element-wise-power)
- [Element-wise exp](#element-wise-exp)
- [Dot product](#dot-product-of-two-vectors)
Expand Down Expand Up @@ -57,6 +58,8 @@
- [Getting min value of the matrix](#getting-min-value-of-the-matrix)
- [Matrix element-wise power](#matrix-element-wise-power)
- [Matrix element-wise exp](#matrix-element-wise-exp)
- [Sum of all matrix elements](#sum-of-all-matrix-elements)
- [Product of all matrix elements](#product-of-all-matrix-elements)
- [Matrix indexing and sampling](#matrix-indexing-and-sampling)
- [Add new columns to a matrix](#add-new-columns-to-a-matrix)
- [Matrix serialization/deserialization](#matrix-serializationdeserialization)
Expand Down Expand Up @@ -238,7 +241,17 @@ the difference is significant.
final vector = Vector.fromList([2.0, 3.0, 4.0, 5.0, 6.0]);
final result = vector.sum();
print(result); // 2 + 3 + 4 + 5 + 6 = 20.0 (equivalent to Manhattan norm)
print(result); // 2 + 3 + 4 + 5 + 6 = 20.0
````

#### Product of all vector elements
````Dart
import 'package:ml_linalg/linalg.dart';
final vector = Vector.fromList([2.0, 3.0, 4.0, 5.0, 6.0]);
final result = vector.prod();
print(result); // 2 * 3 * 4 * 5 * 6 = 720
````

### Element-wise power
Expand Down Expand Up @@ -759,6 +772,34 @@ print(matrix1 - matrix2);
// [e ^ 7, e ^ 8, e ^ 9]
````

#### Sum of all matrix elements
````Dart
import 'package:ml_linalg/linalg.dart';
final matrix = Matrix.fromList([
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
]);
final result = matrix.sum();
print(result); // 1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0 + 9.0
````

#### Product of all matrix elements
````Dart
import 'package:ml_linalg/linalg.dart';
final matrix = Matrix.fromList([
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
]);
final result = matrix.product();
print(result); // 1.0 * 2.0 * 3.0 * 4.0 * 5.0 * 6.0 * 7.0 * 8.0 * 9.0
````

#### Matrix indexing and sampling
    To access a certain row vector of the matrix one may use `[]` operator:

Expand Down
6 changes: 6 additions & 0 deletions lib/matrix.dart
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,12 @@ abstract class Matrix implements Iterable<Iterable<double>> {
/// are the elements of this [Matrix]
Matrix exp();

/// Returns the sum of all the matrix elements
double sum();

/// Returns the product of all the matrix elements
double prod();

/// Returns a serializable map
Map<String, dynamic> toJson();
}
24 changes: 16 additions & 8 deletions lib/src/matrix/matrix_cache_keys.dart
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
const meansByRowsKey = 'means_by_rows';
const meansByColumnsKey = 'means_by_columns';
const deviationByRowsKey = 'deviation_by_rows';
const deviationByColumnsKey = 'deviation_by_columns';
const matrixMeansByRowsKey = 'means_by_rows';
const matrixMeansByColumnsKey = 'means_by_columns';
const matrixDeviationByRowsKey = 'deviation_by_rows';
const matrixDeviationByColumnsKey = 'deviation_by_columns';
const matrixPowKey = 'pow';
const matrixExpKey = 'exp';
const matrixSumKey = 'sum';
const matrixProdKey = 'prod';

final matrixCacheKeys = Set<String>.from(<String>[
meansByRowsKey,
meansByColumnsKey,
deviationByRowsKey,
deviationByColumnsKey,
matrixMeansByRowsKey,
matrixMeansByColumnsKey,
matrixDeviationByRowsKey,
matrixDeviationByColumnsKey,
matrixPowKey,
matrixExpKey,
matrixSumKey,
matrixProdKey,
]);
58 changes: 43 additions & 15 deletions lib/src/matrix/matrix_impl.dart
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,20 @@ class MatrixImpl with IterableMixin<Iterable<double>>, MatrixValidatorMixin
final rowsNumber = rowIndices.isEmpty
? rowsNum
: rowIndices.length;

final targetMatrixSource = List<Vector>(rowsNumber);

for (final indexed in enumerate(rowIndices.isEmpty
? count(0).take(rowsNum).map((i) => i.toInt())
: rowIndices)
) {
final targetRowIndex = indexed.index;
final sourceRowIndex = indexed.value;
final sourceRow = getRow(sourceRowIndex);

targetMatrixSource[targetRowIndex] = columnIndices.isEmpty
? sourceRow : sourceRow.sample(columnIndices);
}

return Matrix.fromRows(targetMatrixSource, dtype: dtype);
}

Expand Down Expand Up @@ -181,11 +183,11 @@ class MatrixImpl with IterableMixin<Iterable<double>>, MatrixValidatorMixin
switch (axis) {
case Axis.columns:
return _cacheManager
.retrieveValue(meansByColumnsKey, () => _mean(columns));
.retrieveValue(matrixMeansByColumnsKey, () => _mean(columns));

case Axis.rows:
return _cacheManager
.retrieveValue(meansByRowsKey, () => _mean(rows));
.retrieveValue(matrixMeansByRowsKey, () => _mean(rows));

default:
throw UnimplementedError('Mean values calculation for axis $axis is not '
Expand All @@ -208,11 +210,11 @@ class MatrixImpl with IterableMixin<Iterable<double>>, MatrixValidatorMixin

switch (axis) {
case Axis.columns:
return _cacheManager.retrieveValue(deviationByColumnsKey,
return _cacheManager.retrieveValue(matrixDeviationByColumnsKey,
() => _deviation(rows, means, rowsNum));

case Axis.rows:
return _cacheManager.retrieveValue(deviationByRowsKey,
return _cacheManager.retrieveValue(matrixDeviationByRowsKey,
() => _deviation(columns, means, columnsNum));

default:
Expand Down Expand Up @@ -330,18 +332,44 @@ class MatrixImpl with IterableMixin<Iterable<double>>, MatrixValidatorMixin
}

@override
Matrix pow(num exponent) => _dataManager.areAllRowsCached
? Matrix.fromRows(rows.map(
(row) => row.pow(exponent)).toList(), dtype: dtype)
: Matrix.fromColumns(columns.map(
(column) => column.pow(exponent)).toList(), dtype: dtype);
Matrix pow(num exponent) => _cacheManager.retrieveValue(matrixPowKey,
() => _dataManager.areAllRowsCached
? Matrix.fromRows(rows.map(
(row) => row.pow(exponent)).toList(), dtype: dtype)
: Matrix.fromColumns(columns.map(
(column) => column.pow(exponent)).toList(), dtype: dtype));

@override
Matrix exp() => _cacheManager.retrieveValue(matrixExpKey,
() => _dataManager.areAllRowsCached
? Matrix.fromRows(rows.map(
(row) => row.exp()).toList(), dtype: dtype)
: Matrix.fromColumns(columns.map(
(column) => column.exp()).toList(), dtype: dtype));

@override
double sum() {
if (!hasData) {
return double.nan;
}

return _cacheManager.retrieveValue(matrixSumKey,
() => _dataManager.areAllRowsCached
? rows.fold(0, (result, row) => result + row.sum())
: columns.fold(0, (result, column) => result + column.sum()));
}

@override
Matrix exp() => _dataManager.areAllRowsCached
? Matrix.fromRows(rows.map(
(row) => row.exp()).toList(), dtype: dtype)
: Matrix.fromColumns(columns.map(
(column) => column.exp()).toList(), dtype: dtype);
double prod() {
if (!hasData) {
return double.nan;
}

return _cacheManager.retrieveValue(matrixProdKey,
() => _dataManager.areAllRowsCached
? rows.fold(0, (result, row) => result * row.prod())
: columns.fold(0, (result, column) => result * column.prod()));
}

@override
Map<String, dynamic> toJson() => matrixToJson(this);
Expand Down
38 changes: 35 additions & 3 deletions lib/src/vector/float32x4_vector.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import 'package:ml_linalg/vector.dart';
const _bytesPerElement = Float32List.bytesPerElement;
const _bytesPerSimdElement = Float32x4List.bytesPerElement;
const _bucketSize = Float32x4List.bytesPerElement ~/ Float32List.bytesPerElement;
final _simdOnes = Float32x4.splat(1.0);

class Float32x4Vector with IterableMixin<double> implements Vector {
Float32x4Vector.fromList(List<num> source, this._cacheManager) :
Expand Down Expand Up @@ -273,9 +274,19 @@ class Float32x4Vector with IterableMixin<double> implements Vector {
double dot(Vector vector) => (this * vector).sum();

@override
double sum({bool skipCaching = false}) => _cacheManager.retrieveValue(sumKey,
() => _simdHelper.sumLanes(_innerSimdList.reduce((a, b) => a + b)),
skipCaching: skipCaching);
double sum({bool skipCaching = false}) {
if (isEmpty) {
return double.nan;
}

return _cacheManager.retrieveValue(sumKey,
() => _simdHelper.sumLanes(_innerSimdList.reduce((a, b) => a + b)),
skipCaching: skipCaching);
}

@override
double prod({bool skipCaching = false}) => _cacheManager.retrieveValue(sumKey,
_findProduct, skipCaching: skipCaching);

@override
double distanceTo(Vector other, {
Expand Down Expand Up @@ -335,6 +346,27 @@ class Float32x4Vector with IterableMixin<double> implements Vector {
_findExtrema(double.infinity, _simdHelper.getMinLane,
(a, b) => a.min(b), math.min), skipCaching: skipCaching);

double _findProduct() {
if (length == 0) {
return double.nan;
}

if (_isLastBucketNotFull) {
final fullBucketsList = _innerSimdList.take(_numOfBuckets - 1);
final product = fullBucketsList.isNotEmpty
? fullBucketsList.reduce((result, value) => result * value)
: _simdOnes;

return _simdHelper.simdValueToList(_innerSimdList.last)
.take(length % _bucketSize)
.fold(_simdHelper.multLanes(product),
(result, value) => result * value);
}

return _simdHelper.multLanes(
_innerSimdList.reduce((result, value) => result * value));
}

double _findExtrema(double initialValue,
double getExtremalLane(Float32x4 bucket),
Float32x4 getExtremalBucket(Float32x4 first, Float32x4 second),
Expand Down
38 changes: 35 additions & 3 deletions lib/src/vector/float64x2_vector.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import 'package:ml_linalg/vector.dart';
const _bytesPerElement = Float64List.bytesPerElement;
const _bytesPerSimdElement = Float64x2List.bytesPerElement;
const _bucketSize = Float64x2List.bytesPerElement ~/ Float64List.bytesPerElement;
final _simdOnes = Float64x2.splat(1.0);

class Float64x2Vector with IterableMixin<double> implements Vector {
Float64x2Vector.fromList(List<num> source, this._cacheManager) :
Expand Down Expand Up @@ -275,9 +276,19 @@ class Float64x2Vector with IterableMixin<double> implements Vector {
double dot(Vector vector) => (this * vector).sum();

@override
double sum({bool skipCaching = false}) => _cacheManager.retrieveValue(sumKey,
() => _simdHelper.sumLanes(_innerSimdList.reduce((a, b) => a + b)),
skipCaching: skipCaching);
double sum({bool skipCaching = false}) {
if (isEmpty) {
return double.nan;
}

return _cacheManager.retrieveValue(sumKey,
() => _simdHelper.sumLanes(_innerSimdList.reduce((a, b) => a + b)),
skipCaching: skipCaching);
}

@override
double prod({bool skipCaching = false}) => _cacheManager.retrieveValue(sumKey,
_findProduct, skipCaching: skipCaching);

@override
double distanceTo(Vector other, {
Expand Down Expand Up @@ -337,6 +348,27 @@ class Float64x2Vector with IterableMixin<double> implements Vector {
_findExtrema(double.infinity, _simdHelper.getMinLane,
(a, b) => a.min(b), math.min), skipCaching: skipCaching);

double _findProduct() {
if (length == 0) {
return double.nan;
}

if (_isLastBucketNotFull) {
final fullBucketsList = _innerSimdList.take(_numOfBuckets - 1);
final product = fullBucketsList.isNotEmpty
? fullBucketsList.reduce((result, value) => result * value)
: _simdOnes;

return _simdHelper.simdValueToList(_innerSimdList.last)
.take(length % _bucketSize)
.fold(_simdHelper.multLanes(product),
(result, value) => result * value);
}

return _simdHelper.multLanes(
_innerSimdList.reduce((result, value) => result * value));
}

double _findExtrema(double initialValue,
double getExtremalLane(Float64x2 bucket),
Float64x2 getExtremalBucket(Float64x2 first, Float64x2 second),
Expand Down
9 changes: 4 additions & 5 deletions lib/src/vector/simd_helper/float32x4_helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ class Float32x4Helper implements SimdHelper<Float32x4> {
bool areLanesEqual(Float32x4 a, Float32x4 b) => a.equal(b).signMask == 15;

@override
double sumLanes(Float32x4 a) =>
(a.x.isNaN ? 0.0 : a.x) +
(a.y.isNaN ? 0.0 : a.y) +
(a.z.isNaN ? 0.0 : a.z) +
(a.w.isNaN ? 0.0 : a.w);
double sumLanes(Float32x4 a) => a.x + a.y + a.z + a.w;

@override
double multLanes(Float32x4 a) => a.x * a.y * a.z * a.w;

@override
double sumLanesForHash(Float32x4 a) =>
Expand Down
5 changes: 3 additions & 2 deletions lib/src/vector/simd_helper/float64x2_helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ class Float64x2Helper implements SimdHelper<Float64x2> {
bool areLanesEqual(Float64x2 a, Float64x2 b) => a.x == b.x && a.y == b.y;

@override
double sumLanes(Float64x2 a) =>
(a.x.isNaN ? 0.0 : a.x) + (a.y.isNaN ? 0.0 : a.y);
double sumLanes(Float64x2 a) => a.x + a.y;

@override
double sumLanesForHash(Float64x2 a) =>
(a.x.isNaN || a.x.isInfinite ? 0.0 : a.x) +
(a.y.isNaN || a.y.isInfinite ? 0.0 : a.y);

double multLanes(Float64x2 a) => a.x * a.y;

@override
double getMaxLane(Float64x2 a) => math.max(a.x, a.y);

Expand Down
3 changes: 3 additions & 0 deletions lib/src/vector/simd_helper/simd_helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ abstract class SimdHelper<E> {
/// performs summation of all components of passed simd value [a]
double sumLanes(E a);

/// performs multiplication of all components of passed simd value [a]
double multLanes(E a);

/// Performs summation of the lanes of the given simd value to provide a
/// hashcode. The method handles infinite and NaN values.
double sumLanesForHash(E a);
Expand Down
Loading

0 comments on commit 3c7f234

Please sign in to comment.