Skip to content

Commit

Permalink
WIP: dtype parameter passed
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Apr 7, 2019
1 parent 49048b4 commit 3862458
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 46 deletions.
2 changes: 1 addition & 1 deletion lib/src/matrix/base_matrix.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import 'dart:math' as math;

import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/matrix_norm.dart';
import 'package:ml_linalg/src/matrix/byte_data_storage/data_manager.dart';
import 'package:ml_linalg/src/matrix/data_manager/data_manager.dart';
import 'package:ml_linalg/src/matrix/matrix_validator_mixin.dart';
import 'package:ml_linalg/vector.dart';
import 'package:xrange/zrange.dart';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ abstract class DataManager {
int get rowsNum;
int get columnsNum;
Iterator<Iterable<double>> get dataIterator;
List<Vector> get rowsCache;
List<Vector> get columnsCache;
Vector getColumn(int index, {bool tryCache = true, bool mutable = false});
void setColumn(int columnNum, Iterable<double> columnValues);
Vector getRow(int index, {bool tryCache = true, bool mutable = false});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,71 @@
import 'dart:typed_data';

import 'package:ml_linalg/src/matrix/byte_data_storage/data_manager.dart';
import 'package:ml_linalg/src/matrix/data_manager/data_manager.dart';
import 'package:ml_linalg/src/matrix/float32x4/float32_matrix_iterator.dart';
import 'package:ml_linalg/vector.dart';

class Float32DataManager implements DataManager {
Float32DataManager.from(
class DataManagerImpl implements DataManager {
DataManagerImpl.from(
Iterable<Iterable<double>> source,
int bytesPerElement,
this._dtype,
) :
rowsNum = source.length,
columnsNum = source.first.length,
rowsCache = List<Vector>(source.length),
columnsCache = List<Vector>(source.first.length),
_rowsCache = List<Vector>(source.length),
_columnsCache = List<Vector>(source.first.length),
_data = ByteData(source.length * source.first.length *
bytesPerElement) {
bytesPerElement),
_bytesPerElement = bytesPerElement {
final flattened = _flatten2dimList(source, (i, j) => i * columnsNum + j);
updateAll(0, flattened);
}

Float32DataManager.fromRows(
DataManagerImpl.fromRows(
Iterable<Vector> source,
int bytesPerElement,
this._dtype,
) :
rowsNum = source.length,
columnsNum = source.first.length,
rowsCache = source.toList(growable: false),
columnsCache = List<Vector>(source.first.length),
_rowsCache = source.toList(growable: false),
_columnsCache = List<Vector>(source.first.length),
_data = ByteData(source.length * source.first.length *
bytesPerElement) {
bytesPerElement),
_bytesPerElement = bytesPerElement {
final flattened = _flatten2dimList(source, (i, j) => i * columnsNum + j);
updateAll(0, flattened);
}

Float32DataManager.fromColumns(
DataManagerImpl.fromColumns(
Iterable<Vector> source,
int bytesPerElement,
this._dtype,
) :
rowsNum = source.first.length,
columnsNum = source.length,
rowsCache = List<Vector>(source.first.length),
columnsCache = source.toList(growable: false),
_rowsCache = List<Vector>(source.first.length),
_columnsCache = source.toList(growable: false),
_data = ByteData(source.length * source.first.length *
bytesPerElement) {
bytesPerElement),
_bytesPerElement = bytesPerElement {
final flattened = _flatten2dimList(source, (i, j) => j * columnsNum + i);
updateAll(0, flattened);
}

Float32DataManager.fromFlattened(
DataManagerImpl.fromFlattened(
Iterable<double> source,
int rowsNum,
int colsNum,
int bytesPerElement,
this._dtype,
) :
rowsNum = rowsNum,
columnsNum = colsNum,
rowsCache = List<Vector>(rowsNum),
columnsCache = List<Vector>(colsNum),
_data = ByteData(rowsNum * colsNum * bytesPerElement) {
_rowsCache = List<Vector>(rowsNum),
_columnsCache = List<Vector>(colsNum),
_data = ByteData(rowsNum * colsNum * bytesPerElement),
_bytesPerElement = bytesPerElement {
if (source.length != rowsNum * colsNum) {
throw Exception('Invalid matrix dimension has been provided - '
'$rowsNum x $colsNum, but given a collection of length '
Expand All @@ -72,13 +80,11 @@ class Float32DataManager implements DataManager {
@override
final int rowsNum;

@override
final List<Vector> rowsCache;

@override
final List<Vector> columnsCache;

final List<Vector> _rowsCache;
final List<Vector> _columnsCache;
final int _bytesPerElement;
final ByteData _data;
final Type _dtype;

@override
Iterator<Iterable<double>> get dataIterator =>
Expand All @@ -87,12 +93,12 @@ class Float32DataManager implements DataManager {
//TODO consider a check if the index is inside the _data
@override
Float32List getValues(int index, int length) =>
_data.buffer.asFloat32List(index * Float32List.bytesPerElement, length);
_data.buffer.asFloat32List(index * _bytesPerElement, length);

//TODO consider a check if the index is inside the _data
@override
void update(int idx, double value) =>
_data.setFloat32(idx * Float32List.bytesPerElement, value, Endian.host);
_data.setFloat32(idx * _bytesPerElement, value, Endian.host);

@override
void updateAll(int idx, Iterable<double> values) {
Expand All @@ -102,30 +108,30 @@ class Float32DataManager implements DataManager {
@override
Vector getRow(int index, {bool tryCache = true, bool mutable = false}) {
if (tryCache) {
rowsCache[index] ??= Vector.from(getValues(index * columnsNum,
columnsNum), isMutable: mutable, dtype: Float32x4);
return rowsCache[index];
_rowsCache[index] ??= Vector.from(getValues(index * columnsNum,
columnsNum), isMutable: mutable, dtype: _dtype);
return _rowsCache[index];
} else {
return Vector.from(getValues(index * columnsNum, columnsNum),
isMutable: mutable, dtype: Float32x4);
isMutable: mutable, dtype: _dtype);
}
}

@override
Vector getColumn(int index, {bool tryCache = true, bool mutable = false}) {
if (columnsCache[index] == null || !tryCache) {
if (_columnsCache[index] == null || !tryCache) {
final result = List<double>(rowsNum);
for (int i = 0; i < rowsNum; i++) {
//@TODO: find a more efficient way to get the single value
result[i] = getValues(i * columnsNum + index, 1).first;
}
final column = Vector.from(result, isMutable: mutable, dtype: Float32x4);
final column = Vector.from(result, isMutable: mutable, dtype: _dtype);
if (!tryCache) {
return column;
}
columnsCache[index] = column;
_columnsCache[index] = column;
}
return columnsCache[index];
return _columnsCache[index];
}

@override
Expand All @@ -139,8 +145,8 @@ class Float32DataManager implements DataManager {
'matrix rows number is $rowsNum');
}
// clear rows cache
rowsCache.fillRange(0, rowsNum, null);
columnsCache[columnNum] = columnValues is Vector
_rowsCache.fillRange(0, rowsNum, null);
_columnsCache[columnNum] = columnValues is Vector
? columnValues : Vector.from(columnValues);
final values = columnValues.toList(growable: false);
for (int i = 0, j = 0; i < rowsNum * columnsNum; i++) {
Expand Down
16 changes: 9 additions & 7 deletions lib/src/matrix/float32x4/float32x4_matrix.dart
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import 'dart:core';
import 'dart:typed_data';

import 'package:ml_linalg/src/matrix/byte_data_storage/float32_data_manager.dart';
import 'package:ml_linalg/src/matrix/base_matrix.dart';
import 'package:ml_linalg/src/matrix/data_manager/data_manager_impl.dart';
import 'package:ml_linalg/vector.dart';

class Float32x4Matrix extends BaseMatrix {
Float32x4Matrix.from(Iterable<Iterable<double>> source) :
super(Float32DataManager.from(source, Float32List.bytesPerElement));
super(DataManagerImpl.from(source, Float32List.bytesPerElement,
Float32x4));

Float32x4Matrix.columns(Iterable<Vector> source) :
super(Float32DataManager
.fromColumns(source, Float32List.bytesPerElement));
super(DataManagerImpl
.fromColumns(source, Float32List.bytesPerElement, Float32x4));

Float32x4Matrix.rows(Iterable<Vector> source) :
super(Float32DataManager.fromRows(source, Float32List.bytesPerElement));
super(DataManagerImpl.fromRows(source, Float32List.bytesPerElement,
Float32x4));

Float32x4Matrix.flattened(Iterable<double> source, int rowsNum,
int columnsNum) :
super(Float32DataManager.fromFlattened(source, rowsNum, columnsNum,
Float32List.bytesPerElement));
super(DataManagerImpl.fromFlattened(source, rowsNum, columnsNum,
Float32List.bytesPerElement, Float32x4));

@override
final Type dtype = Float32x4;
Expand Down

0 comments on commit 3862458

Please sign in to comment.