Skip to content

Commit

Permalink
Cache manager improvemets
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Nov 17, 2019
1 parent 0c584e0 commit f44cf4d
Show file tree
Hide file tree
Showing 12 changed files with 254 additions and 121 deletions.
119 changes: 89 additions & 30 deletions lib/matrix.dart
Expand Up @@ -4,11 +4,16 @@ import 'package:ml_linalg/axis.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix_norm.dart';
import 'package:ml_linalg/sort_direction.dart';
import 'package:ml_linalg/src/common/cache_manager/cache_manager_factory.dart';
import 'package:ml_linalg/src/di/dependencies.dart';
import 'package:ml_linalg/src/matrix/data_manager/float64_matrix_data_manager.dart';
import 'package:ml_linalg/src/matrix/matrix_cache_keys.dart';
import 'package:ml_linalg/src/matrix/matrix_impl.dart';
import 'package:ml_linalg/src/matrix/data_manager/float32_matrix_data_manager.dart';
import 'package:ml_linalg/vector.dart';

final _cacheManagerFactory = dependencies.getDependency<CacheManagerFactory>();

/// An algebraic matrix with extended functionality, adapted for data science
/// applications
abstract class Matrix implements Iterable<Iterable<double>> {
Expand Down Expand Up @@ -45,10 +50,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
{DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.fromList(source));
return MatrixImpl(
Float32MatrixDataManager.fromList(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.fromList(source));
return MatrixImpl(
Float64MatrixDataManager.fromList(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -87,10 +98,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
factory Matrix.fromRows(List<Vector> source, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.fromRows(source));
return MatrixImpl(
Float32MatrixDataManager.fromRows(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.fromRows(source));
return MatrixImpl(
Float64MatrixDataManager.fromRows(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -133,10 +150,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
{DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.fromColumns(source));
return MatrixImpl(
Float32MatrixDataManager.fromColumns(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.fromColumns(source));
return MatrixImpl(
Float64MatrixDataManager.fromColumns(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -165,10 +188,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
factory Matrix.empty({DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.fromList([]));
return MatrixImpl(
Float32MatrixDataManager.fromList([]),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.fromList([]));
return MatrixImpl(
Float64MatrixDataManager.fromList([]),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -203,12 +232,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
int columnsNum, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.fromFlattened(
source, rowsNum, columnsNum));
return MatrixImpl(
Float32MatrixDataManager.fromFlattened(source, rowsNum, columnsNum),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.fromFlattened(
source, rowsNum, columnsNum));
return MatrixImpl(
Float64MatrixDataManager.fromFlattened(source, rowsNum, columnsNum),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -241,10 +274,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
factory Matrix.diagonal(List<double> source, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.diagonal(source));
return MatrixImpl(
Float32MatrixDataManager.diagonal(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.diagonal(source));
return MatrixImpl(
Float64MatrixDataManager.diagonal(source),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -277,10 +316,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
factory Matrix.scalar(double scalar, int size, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.scalar(scalar, size));
return MatrixImpl(
Float32MatrixDataManager.scalar(scalar, size),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.scalar(scalar, size));
return MatrixImpl(
Float64MatrixDataManager.scalar(scalar, size),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -313,10 +358,16 @@ abstract class Matrix implements Iterable<Iterable<double>> {
factory Matrix.identity(int size, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.scalar(1.0, size));
return MatrixImpl(
Float32MatrixDataManager.scalar(1.0, size),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.scalar(1.0, size));
return MatrixImpl(
Float64MatrixDataManager.scalar(1.0, size),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
Expand Down Expand Up @@ -344,14 +395,18 @@ abstract class Matrix implements Iterable<Iterable<double>> {
factory Matrix.row(List<double> source, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.fromRows(
[Vector.fromList(source, dtype: dtype)],
));
return MatrixImpl(
Float32MatrixDataManager
.fromRows([Vector.fromList(source, dtype: dtype)]),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.fromRows(
[Vector.fromList(source, dtype: dtype)],
));
return MatrixImpl(
Float64MatrixDataManager
.fromRows([Vector.fromList(source, dtype: dtype)]),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError(
Expand Down Expand Up @@ -384,14 +439,18 @@ abstract class Matrix implements Iterable<Iterable<double>> {
factory Matrix.column(List<double> source, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return MatrixImpl(Float32MatrixDataManager.fromColumns((
[Vector.fromList(source, dtype: dtype)]),
));
return MatrixImpl(
Float32MatrixDataManager
.fromColumns(([Vector.fromList(source, dtype: dtype)])),
_cacheManagerFactory.create(matrixCacheKeys),
);

case DType.float64:
return MatrixImpl(Float64MatrixDataManager.fromColumns((
[Vector.fromList(source, dtype: dtype)]),
));
return MatrixImpl(
Float64MatrixDataManager
.fromColumns(([Vector.fromList(source, dtype: dtype)])),
_cacheManagerFactory.create(matrixCacheKeys),
);

default:
throw UnimplementedError(
Expand Down
2 changes: 1 addition & 1 deletion lib/src/common/cache_manager/cache_manager.dart
@@ -1,5 +1,5 @@
abstract class CacheManager {
T retrieveValue<T>(String cachedValueName, T calculateIfAbsent(), {
T retrieveValue<T>(String key, T calculateIfAbsent(), {
bool skipCaching,
});

Expand Down
2 changes: 1 addition & 1 deletion lib/src/common/cache_manager/cache_manager_factory.dart
@@ -1,5 +1,5 @@
import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart';

abstract class CacheManagerFactory {
CacheManager create();
CacheManager create(Set<String> keys);
}
Expand Up @@ -6,5 +6,5 @@ class CacheManagerFactoryImpl implements CacheManagerFactory {
const CacheManagerFactoryImpl();

@override
CacheManager create() => CacheManagerImpl();
CacheManager create(Set<String> keys) => CacheManagerImpl(keys);
}
13 changes: 10 additions & 3 deletions lib/src/common/cache_manager/cache_manager_impl.dart
@@ -1,19 +1,26 @@
import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart';

class CacheManagerImpl implements CacheManager {
CacheManagerImpl(this._keys);

final _cache = <String, dynamic>{};
final Set<String> _keys;

@override
T retrieveValue<T>(String cachedValueName, T Function() calculateIfAbsent, {
T retrieveValue<T>(String key, T Function() calculateIfAbsent, {
bool skipCaching = false,
}) {
var value = _cache[cachedValueName] as T;
if (!_keys.contains(key)) {
throw Exception('Cache key `$key` is not registered');
}

var value = _cache[key] as T;

if (value == null) {
value = calculateIfAbsent();

if (!skipCaching) {
_cache[cachedValueName] = value;
_cache[key] = value;
}
}

Expand Down
11 changes: 11 additions & 0 deletions lib/src/matrix/matrix_cache_keys.dart
@@ -0,0 +1,11 @@
const meansByRowsKey = 'means_by_rows';
const meansByColumnsKey = 'means_by_columns';
const deviationByRowsKey = 'deviation_by_rows';
const deviationByColumnsKey = 'deviation_by_columns';

final matrixCacheKeys = Set<String>.from(<String>[
meansByRowsKey,
meansByColumnsKey,
deviationByRowsKey,
deviationByColumnsKey,
]);
19 changes: 12 additions & 7 deletions lib/src/matrix/matrix_impl.dart
Expand Up @@ -6,19 +6,20 @@ import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/matrix_norm.dart';
import 'package:ml_linalg/sort_direction.dart';
import 'package:ml_linalg/src/common/cache_manager/cache_manager.dart';
import 'package:ml_linalg/src/matrix/data_manager/matrix_data_manager.dart';
import 'package:ml_linalg/src/matrix/matrix_cache_keys.dart';
import 'package:ml_linalg/src/matrix/mixin/matrix_validator_mixin.dart';
import 'package:ml_linalg/vector.dart';
import 'package:quiver/iterables.dart';

class MatrixImpl with IterableMixin<Iterable<double>>, MatrixValidatorMixin
implements Matrix {

MatrixImpl(this._dataManager);
MatrixImpl(this._dataManager, this._cacheManager);

final MatrixDataManager _dataManager;
final Map<Axis, Vector> _meansCache = {};
final Map<Axis, Vector> _deviationCache = {};
final CacheManager _cacheManager;

@override
DType get dtype => _dataManager.dtype;
Expand Down Expand Up @@ -172,10 +173,12 @@ class MatrixImpl with IterableMixin<Iterable<double>>, MatrixValidatorMixin

switch (axis) {
case Axis.columns:
return _meansCache[axis] ??= _mean(columns);
return _cacheManager
.retrieveValue(meansByColumnsKey, () => _mean(columns));

case Axis.rows:
return _meansCache[axis] ??= _mean(rows);
return _cacheManager
.retrieveValue(meansByRowsKey, () => _mean(rows));

default:
throw UnimplementedError('Mean values calculation for axis $axis is not '
Expand All @@ -198,10 +201,12 @@ class MatrixImpl with IterableMixin<Iterable<double>>, MatrixValidatorMixin

switch (axis) {
case Axis.columns:
return _deviationCache[axis] ??= _deviation(rows, means, rowsNum);
return _cacheManager.retrieveValue(deviationByColumnsKey,
() => _deviation(rows, means, rowsNum));

case Axis.rows:
return _deviationCache[axis] ??= _deviation(columns, means, columnsNum);
return _cacheManager.retrieveValue(deviationByRowsKey,
() => _deviation(columns, means, columnsNum));

default:
throw UnimplementedError('Deviation calculation for axis $axis is not '
Expand Down

0 comments on commit f44cf4d

Please sign in to comment.