Skip to content

Commit

Permalink
unit tests for Standardizer extended
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Oct 10, 2019
1 parent 0d084ed commit b1623cb
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 6 deletions.
17 changes: 16 additions & 1 deletion lib/src/standardizer/standardizer.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@ class Standardizer implements Pipeable {
}) :
_dtype = dtype,
_mean = fittingData.toMatrix(dtype).mean(),
_deviation = fittingData.toMatrix(dtype).deviation() {
_deviation = Vector.fromList(
// TODO: Consider SIMD-aware mapping
fittingData
.toMatrix(dtype)
.deviation()
.map((el) => el == 0 ? 1 : el)
.toList(),
dtype: dtype,
) {
if (!fittingData.toMatrix(dtype).hasData) {
throw Exception('No data provided');
}
Expand All @@ -21,6 +29,13 @@ class Standardizer implements Pipeable {
@override
DataFrame process(DataFrame input) {
final inputAsMatrix = input.toMatrix(_dtype);

if (inputAsMatrix.columnsNum != _deviation.length) {
throw Exception('Passed dataframe is differ than the one used during '
'creation of the Standardizer: expected columns number - '
'${_deviation.length}, given - ${inputAsMatrix.columnsNum}.');
}

final processedMatrix = inputAsMatrix
.mapRows((row) => (row - _mean) / _deviation);

Expand Down
132 changes: 127 additions & 5 deletions test/standardizer/standardizer_test.dart
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/linalg.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';
import 'package:ml_tech/unit_testing/matchers/iterable_2d_almost_equal_to.dart';
import 'package:test/test.dart';

void main() {
const dtype = DType.float32;

group('Standardizer', () {
test('should extract deviation and mean values from fitting data and apply '
'them to the same data in order to make the latter look like normally'
Expand All @@ -16,10 +19,10 @@ void main() {
[40, 33, 22, 20],
], headerExists: false);

final standardizer = Standardizer(fittingData, dtype: DType.float32);
final standardizer = Standardizer(fittingData, dtype: dtype);
final processed = standardizer.process(fittingData);

expect(processed.toMatrix(DType.float32), iterable2dAlmostEqualTo([
expect(processed.toMatrix(dtype), iterable2dAlmostEqualTo([
[-1.34164079, -1.28449611, 1.68894093, -0.72760688],
[-0.4472136, 1.25626543, -0.56298031, -0.24253563],
[ 0.4472136, 0.63519039, -0.87653896, 1.69774938],
Expand All @@ -46,10 +49,10 @@ void main() {
[88, -20, 36, 66],
], headerExists: false);

final standardizer = Standardizer(fittingData, dtype: DType.float32);
final standardizer = Standardizer(fittingData, dtype: dtype);
final processed = standardizer.process(testData);

expect(processed.toMatrix(DType.float32), iterable2dAlmostEqualTo([
expect(processed.toMatrix(dtype), iterable2dAlmostEqualTo([
[ 4.91934955, -1.34095748, -0.56298031, -6.54846188],
[ 5.81377674, -4.72863954, -0.106895, -1.69774938],
[-1.34164079, 0.01411534, 1.85997292, 4.12310563],
Expand All @@ -58,12 +61,131 @@ void main() {
]));
});

test('should process a dataframe with only one column', () {
final fittingData = DataFrame(<Iterable<num>>[
[10],
[20],
[30],
[40],
], headerExists: false);

final testData = DataFrame(<Iterable<num>>[
[80],
[90],
[10],
[50],
[88],
], headerExists: false);

final standardizer = Standardizer(fittingData, dtype: dtype);
final processed = standardizer.process(testData);

expect(processed.toMatrix(dtype), iterable2dAlmostEqualTo([
[ 4.91934955],
[ 5.81377674],
[-1.34164079],
[ 2.23606798],
[ 5.6348913 ],
]));
});

test('should process a dataframe with only one row', () {
final fittingData = DataFrame(<Iterable<num>>[
[10, 21, 90, 20],
], headerExists: false);

final testData = DataFrame(<Iterable<num>>[
[80, 20, 11, -100],
[90, -40, 27, 0],
[10, 44, 96, 120],
[50, -99, 73, 10],
[88, -20, 36, 66],
], headerExists: false);

final standardizer = Standardizer(fittingData, dtype: dtype);
final processed = standardizer.process(testData);

expect(processed.toMatrix(dtype), equals([
[70, -1, -79, -120],
[80, -61, -63, -20 ],
[ 0, 23, 6, 100],
[40, -120, -17, -10 ],
[78, -41, -54, 46 ],
]));
});

test('should make deviation of uniform columns equal to 1', () {
final uniformColumn = Matrix.fromList([
[10],
[10],
[10],
[10],
]);

final otherColumns = Matrix.fromList([
[21, 90, 20],
[66, 11, 30],
[55, 0, 70],
[33, 22, 20],
]);

final fittingData = DataFrame.fromMatrix(
Matrix.fromColumns([
...uniformColumn.columns,
...otherColumns.columns,
], dtype: dtype),
);

final testData = DataFrame(<Iterable<num>>[
[80, 20, 11, -100],
[90, -40, 27, 0],
[10, 44, 96, 120],
[50, -99, 73, 10],
[88, -20, 36, 66],
], headerExists: false);

final standardizer = Standardizer(fittingData, dtype: dtype);
final processed = standardizer.process(testData);

expect(processed.toMatrix(dtype), iterable2dAlmostEqualTo([
[ 70, -1.34095748, -0.56298031, -6.54846188],
[ 80, -4.72863954, -0.106895, -1.69774938],
[ 0, 0.01411534, 1.85997292, 4.12310563],
[ 40, -8.05986023, 1.20435028, -1.21267813],
[ 78, -3.59941219, 0.14965299, 1.50372088],
]));
});

test('should throw an exception if one tries to apply standardizer to a '
'dataframe of inappropriate dimension (columns number in the test '
'dataframe should be equal to a number of columns in the fitting '
'dataframe)', () {
final fittingData = DataFrame(<Iterable<num>>[
[10, 21, 90, 20],
[20, 66, 11, 30],
[30, 55, 0, 70],
[40, 33, 22, 20],
], headerExists: false);

final testData = DataFrame(<Iterable<num>>[
[80, 20, 11],
[90, -40, 27],
[10, 44, 96],
[50, -99, 73],
[88, -20, 36],
], headerExists: false);

final standardizer = Standardizer(fittingData, dtype: dtype);

expect(() => standardizer.process(testData), throwsException);
});

test('should throw an exception if one tries to create a standardizer '
'using empty dataframe', () {
final fittingData = DataFrame(<Iterable<num>>[[]], headerExists: false);

expect(
() => Standardizer(fittingData, dtype: DType.float32),
() => Standardizer(fittingData, dtype: dtype),
throwsException,
);
});
Expand Down

0 comments on commit b1623cb

Please sign in to comment.