Skip to content

Commit

Permalink
Added possibility to create diagonal matrices of different type
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Nov 4, 2019
1 parent b3f936a commit 624f3a5
Show file tree
Hide file tree
Showing 17 changed files with 141 additions and 16 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# Changelog

## 12.4.0
- `Matrix`:
- `Matrix.diagonal` constructor added (for creation diagonal matrices)
- `Matrix.scalar` constructor added (for creation scalar matrices)
- `Matrix.identity` constructor added (for creation identity matrices)

## 12.3.0
- `Matrix`:
- `mean` method added
Expand Down
25 changes: 24 additions & 1 deletion lib/matrix.dart
Expand Up @@ -74,7 +74,8 @@ abstract class Matrix implements Iterable<Iterable<double>> {
}
}

/// Creates a diagonal matrix with diagonal elements from [source]
/// Creates a matrix, where elements from [source] are the elements for the
/// matrix main diagonal
factory Matrix.diagonal(List<double> source, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
Expand All @@ -84,6 +85,28 @@ abstract class Matrix implements Iterable<Iterable<double>> {
}
}

/// Creates a matrix of [size] * [size] dimension, where the all main
/// diagonal elements are equal to [scalar]
factory Matrix.scalar(double scalar, int size, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return Float32Matrix.scalar(scalar, size);
default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
}
}

/// Creates a matrix of [size] * [size] dimension, where the all main
/// diagonal elements are equal to 1
factory Matrix.identity(int size, {DType dtype = DType.float32}) {
switch (dtype) {
case DType.float32:
return Float32Matrix.scalar(1.0, size);
default:
throw UnimplementedError('Matrix of type $dtype is not implemented yet');
}
}

/// A data type of [Matrix] elements
DType get dtype;

Expand Down
40 changes: 31 additions & 9 deletions lib/src/matrix/common/data_manager/matrix_data_manager_impl.dart
Expand Up @@ -100,13 +100,24 @@ class MatrixDataManagerImpl implements MatrixDataManager {
_rowsCache = List<Vector>(source.length),
_colsCache = List<Vector>(source.length),
_data = ByteData(source.length * source.length * bytesPerElement) {
for (int i = 0; i < rowsNum; i++) {
for (int j = 0; j < columnsNum; j++) {
final value = i == j ? source[i] : 0.0;
_typedListHelper
.setValue(_data, (i * columnsNum + j) * bytesPerElement, value);
}
}
_updateByteDataForDiagonalMatrix(bytesPerElement, (i) => source[i]);
}

MatrixDataManagerImpl.scalar(
double scalar,
int size,
int bytesPerElement,
this._dtype,
this._typedListHelper,
) :
rowsNum = size,
columnsNum = size,
rowIndices = getZeroBasedIndices(size),
columnIndices = getZeroBasedIndices(size),
_rowsCache = List<Vector>(size),
_colsCache = List<Vector>(size),
_data = ByteData(size * size * bytesPerElement) {
_updateByteDataForDiagonalMatrix(bytesPerElement, (i) => scalar);
}

@override
Expand Down Expand Up @@ -176,17 +187,28 @@ class MatrixDataManagerImpl implements MatrixDataManager {
_typedListHelper.getBufferAsList(_data.buffer).setAll(0, values);

void _updateByteDataBy2dimIterable(Iterable<Iterable<double>> rows,
int accessor(int i, int j), int bytesPerElement) {
int flatten2dIndices(int i, int j), int bytesPerElement) {
var i = 0;
var j = 0;
for (final row in rows) {
for (final value in row) {
_typedListHelper
.setValue(_data, accessor(i, j) * bytesPerElement, value);
.setValue(_data, flatten2dIndices(i, j) * bytesPerElement, value);
j++;
}
i++;
j = 0;
}
}

void _updateByteDataForDiagonalMatrix(int bytesPerElement,
double generateValue(int i)) {
for (int i = 0; i < rowsNum; i++) {
for (int j = 0; j < columnsNum; j++) {
final value = i == j ? generateValue(i) : 0.0;
_typedListHelper
.setValue(_data, (i * columnsNum + j) * bytesPerElement, value);
}
}
}
}
8 changes: 8 additions & 0 deletions lib/src/matrix/float32/float32_matrix.dart
Expand Up @@ -52,6 +52,14 @@ class Float32Matrix extends BaseMatrix {
DType.float32,
Float32ListHelper()));

Float32Matrix.scalar(double scalar, int size) :
super(MatrixDataManagerImpl.scalar(
scalar,
size,
Float32List.bytesPerElement,
DType.float32,
Float32ListHelper()));

@override
final DType dtype = DType.float32;
}
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: ml_linalg
description: SIMD-based linear algebra (1 operation on 4 float32 values, 1 operation on 2 float64 values)
version: 12.3.0
version: 12.4.0
author: Ilia Gyrdymov <ilgyrd@gmail.com>
homepage: https://github.com/gyrdym/ml_linalg

Expand Down
Expand Up @@ -2,12 +2,12 @@ import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

import '../../dtype_to_class_name_mapping.dart';
import '../../../dtype_to_class_name_mapping.dart';

void matrixDiagonalConstructorTestGroupFactory(DType dtype) =>
group(dtypeToMatrixClassName[dtype], () {
group('diagonal constructor', () {
test('should create a matrix with all zero elements but diagonal '
test('should create a matrix with all zero elements but main diagonal '
'ones', () {
final source = [1.0, 2.0, 3.0, 4.0, 5.0];
final matrix = Matrix.diagonal(source, dtype: dtype);
Expand Down
Expand Up @@ -2,7 +2,7 @@ import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

import '../../dtype_to_class_name_mapping.dart';
import '../../../dtype_to_class_name_mapping.dart';

void matrixEmptyConstructorTestGroupFactory(DType dtype) =>
group(dtypeToMatrixClassName[dtype], () {
Expand Down
@@ -0,0 +1,7 @@
import 'package:ml_linalg/dtype.dart';

import 'matrix_identity_constructor_test_group_factory.dart';

void main() {
matrixIdentityConstructorTestGroupFactory(DType.float32);
}
@@ -0,0 +1,26 @@
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

import '../../../dtype_to_class_name_mapping.dart';

void matrixIdentityConstructorTestGroupFactory(DType dtype) =>
group(dtypeToMatrixClassName[dtype], () {
group('identity constructor', () {
test('should create a matrix with main diagonal elements equal to 1, '
'and the rest elements equal to 0', () {

final matrix = Matrix.identity(7, dtype: dtype);

expect(matrix, equals([
[1, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1],
]));
});
});
});
@@ -0,0 +1,7 @@
import 'package:ml_linalg/dtype.dart';

import 'matrix_scalar_constructor_test_group_factory.dart';

void main() {
matrixScalarConstructorTestGroupFactory(DType.float32);
}
@@ -0,0 +1,26 @@
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

import '../../../dtype_to_class_name_mapping.dart';

void matrixScalarConstructorTestGroupFactory(DType dtype) =>
group(dtypeToMatrixClassName[dtype], () {
group('scalar constructor', () {
test('should create a matrix with all zero elements except for main '
'diagonal ones - they should be equal to the given scalar value', () {

final matrix = Matrix.scalar(-3.0, 7, dtype: dtype);

expect(matrix, equals([
[-3, 0, 0, 0, 0, 0, 0],
[ 0, -3, 0, 0, 0, 0, 0],
[ 0, 0, -3, 0, 0, 0, 0],
[ 0, 0, 0, -3, 0, 0, 0],
[ 0, 0, 0, 0, -3, 0, 0],
[ 0, 0, 0, 0, 0, -3, 0],
[ 0, 0, 0, 0, 0, 0, -3],
]));
});
});
});
Expand Up @@ -3,7 +3,7 @@ import 'package:ml_linalg/linalg.dart';
import 'package:ml_tech/unit_testing/matchers/iterable_almost_equal_to.dart';
import 'package:test/test.dart';

import '../../dtype_to_class_name_mapping.dart';
import '../../../dtype_to_class_name_mapping.dart';

void matrixDeviationTestGroupFactory(DType dtype) =>
group(dtypeToMatrixClassName[dtype], () {
Expand Down
File renamed without changes.
Expand Up @@ -4,7 +4,7 @@ import 'package:ml_linalg/matrix.dart';
import 'package:ml_tech/unit_testing/matchers/iterable_almost_equal_to.dart';
import 'package:test/test.dart';

import '../../dtype_to_class_name_mapping.dart';
import '../../../dtype_to_class_name_mapping.dart';

void matrixMeanTestGroupFactory(DType dtype) =>
group(dtypeToMatrixClassName[dtype], () {
Expand Down

0 comments on commit 624f3a5

Please sign in to comment.