Recall metric added (#151)
gyrdym committed Sep 11, 2020
1 parent e4f55d0 commit 8b041f2
Showing 9 changed files with 182 additions and 59 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
 # Changelog
 
+## 15.2.0
+- Recall metric added
+
 ## 15.1.0
 - MAPE metric: output range squeezed to [0, 1]
 
62 changes: 62 additions & 0 deletions lib/src/metric/classification/_helpers/divide_true_positive_by.dart
@@ -0,0 +1,62 @@
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';
import 'package:quiver/iterables.dart';

double divideTruePositiveBy(
  Vector divider,
  Matrix origLabels,
  Matrix predictedLabels,
) {
  // Say we have the following data:
  //
  // orig labels | predicted labels
  // -------------------------------
  //      1      |        1
  //      1      |        0
  //      0      |        1
  //      0      |        0
  //      1      |        1
  // -------------------------------
  //
  // To count the correctly predicted positive labels in matrix notation, we
  // can multiply the predicted labels by 2 and subtract the result from the
  // original labels:
  //
  // 1 - (1 * 2) = -1
  // 1 - (0 * 2) =  1
  // 0 - (1 * 2) = -2
  // 0 - (0 * 2) =  0
  // 1 - (1 * 2) = -1
  //
  // The subtraction yields -1 exactly when both the original and the
  // predicted label are positive, so the number of true positives is the
  // number of elements equal to -1 in the resulting matrix.
  final difference = origLabels - (predictedLabels * 2);
  final truePositiveCounts = difference
      .reduceRows(
          (counts, row) => counts + row.mapToVector((diff) => diff == -1
              ? 1 : 0),
          initValue: Vector.zero(
            origLabels.columnsNum,
            dtype: origLabels.dtype,
          ));
  final aggregatedScore = (truePositiveCounts / divider).mean();

  if (aggregatedScore.isFinite) {
    return aggregatedScore;
  }

  // The mean above is not finite when `divider` contains zeroes, so fall
  // back to aggregating column by column: a column with a zero divider
  // contributes 1 when it also has no true positives and 0 otherwise.
  return zip([
    truePositiveCounts,
    divider,
  ]).fold(0, (aggregated, pair) {
    final truePositiveCount = pair.first;
    final dividerElement = pair.last;

    if (dividerElement != 0) {
      return aggregated + truePositiveCount / dividerElement;
    }

    return aggregated + (truePositiveCount == 0 ? 1 : 0);
  });
}
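A minimal standalone sketch of the counting trick the helper relies on, written with plain Dart lists rather than ml_linalg types (all names here are illustrative, not library API):

void main() {
  // The example data from the comment above: one class column, five samples.
  final orig = [1, 1, 0, 0, 1];
  final predicted = [1, 0, 1, 0, 1];

  // orig[i] - 2 * predicted[i] == -1 exactly when both labels are 1,
  // so counting the -1 values counts the true positives.
  var truePositives = 0;
  for (var i = 0; i < orig.length; i++) {
    if (orig[i] - 2 * predicted[i] == -1) {
      truePositives++;
    }
  }

  print(truePositives); // 2
}

The helper then divides these per-class counts by whatever `divider` the caller passes in, which is exactly where precision and recall diverge, as the two metric files below show.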
62 changes: 4 additions & 58 deletions lib/src/metric/classification/precision.dart
@@ -1,72 +1,18 @@
 import 'package:ml_algo/src/helpers/normalize_class_labels.dart';
+import 'package:ml_algo/src/metric/classification/_helpers/divide_true_positive_by.dart';
 import 'package:ml_algo/src/metric/metric.dart';
 import 'package:ml_linalg/matrix.dart';
-import 'package:ml_linalg/vector.dart';
-import 'package:quiver/iterables.dart';
 
-/// TODO: add warning if predicted values are all zeroes
 class PrecisionMetric implements Metric {
   const PrecisionMetric();
 
   @override
   /// Accepts [predictedLabels] and [origLabels] whose entries are `1` for
   /// the positive label and `0` for the negative one
   double getScore(Matrix predictedLabels, Matrix origLabels) {
-    final allPredictedPositiveCounts = predictedLabels
+    final predictedTrueCounts = predictedLabels
         .reduceRows((counts, row) => counts + row);
-
-    // Let's say we have the following data:
-    //
-    // orig labels | predicted labels
-    // -------------------------------
-    //      1      |        1
-    //      1      |        0
-    //      0      |        1
-    //      0      |        0
-    //      1      |        1
-    // -------------------------------
-    //
-    // in order to count correctly predicted positive labels in matrix notation
-    // we may multiply predicted labels by 2, and then subtract the two
-    // matrices from each other:
-    //
-    // 1 - (1 * 2) = -1
-    // 1 - (0 * 2) = 1
-    // 0 - (1 * 2) = -2
-    // 0 - (0 * 2) = 0
-    // 1 - (1 * 2) = -1
-    //
-    // we see that matrices subtraction in case of original positive label and a
-    // predicted positive label gives -1, thus we need to count number of elements
-    // with value equals -1 in the resulting matrix
-    final difference = origLabels - (predictedLabels * 2);
-    final correctPositiveCounts = difference
-        .reduceRows(
-            (counts, row) => counts + row.mapToVector((diff) => diff == -1
-                ? 1 : 0),
-            initValue: Vector.zero(
-              origLabels.columnsNum,
-              dtype: origLabels.dtype,
-            ));
-    final aggregatedScore = (correctPositiveCounts / allPredictedPositiveCounts)
-        .mean();
-
-    if (aggregatedScore.isFinite) {
-      return aggregatedScore;
-    }
-
-    return zip([
-      correctPositiveCounts,
-      allPredictedPositiveCounts,
-    ]).fold(0, (aggregated, pair) {
-      final correctPositiveCount = pair.first;
-      final allPredictedPositiveCount = pair.last;
-
-      if (allPredictedPositiveCount != 0) {
-        return aggregated + correctPositiveCount / allPredictedPositiveCount;
-      }
-
-      return aggregated + (correctPositiveCount == 0 ? 1 : 0);
-    });
+    return divideTruePositiveBy(predictedTrueCounts, origLabels,
+        predictedLabels);
   }
 }
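The refactor does not change what precision computes: the `divider` handed to the helper is the per-class count of predicted positives, i.e. TP + FP, so for each class column c

    precision_c = TP_c / (TP_c + FP_c)

and the final score is the macro average of the per-column values.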
18 changes: 18 additions & 0 deletions lib/src/metric/classification/recall.dart
@@ -0,0 +1,18 @@
import 'package:ml_algo/src/metric/classification/_helpers/divide_true_positive_by.dart';
import 'package:ml_algo/src/metric/metric.dart';
import 'package:ml_linalg/matrix.dart';

class RecallMetric implements Metric {
  const RecallMetric();

  @override
  /// Accepts [predictedLabels] and [origLabels] whose entries are `1` for
  /// the positive label and `0` for the negative one
  double getScore(Matrix predictedLabels, Matrix origLabels) {
    final originalTrueCounts = origLabels
        .reduceRows((counts, row) => counts + row);

    return divideTruePositiveBy(originalTrueCounts, origLabels,
        predictedLabels);
  }
}
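Recall swaps in the other divider: the per-class count of actual positives, TP + FN, so for each class column c, recall_c = TP_c / (TP_c + FN_c), again macro-averaged. A minimal usage sketch assembled from the pieces in this commit (the label matrices are made up for illustration):

import 'package:ml_algo/src/metric/classification/recall.dart';
import 'package:ml_linalg/matrix.dart';

void main() {
  final metric = const RecallMetric();

  // One-hot label columns; `1` marks the positive label.
  final origLabels = Matrix.fromList([
    [1, 0],
    [0, 1],
    [0, 1],
  ]);
  final predictedLabels = Matrix.fromList([
    [1, 0],
    [0, 1],
    [1, 0],
  ]);

  // Per-class recalls are 1/1 and 1/2, so the macro average is 0.75.
  print(metric.getScore(predictedLabels, origLabels)); // 0.75
}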
4 changes: 4 additions & 0 deletions lib/src/metric/metric_factory_impl.dart
@@ -1,5 +1,6 @@
 import 'package:ml_algo/src/metric/classification/accuracy.dart';
 import 'package:ml_algo/src/metric/classification/precision.dart';
+import 'package:ml_algo/src/metric/classification/recall.dart';
 import 'package:ml_algo/src/metric/metric.dart';
 import 'package:ml_algo/src/metric/metric_factory.dart';
 import 'package:ml_algo/src/metric/metric_type.dart';
@@ -24,6 +25,9 @@ class MetricFactoryImpl implements MetricFactory {
       case MetricType.precision:
         return const PrecisionMetric();
 
+      case MetricType.recall:
+        return const RecallMetric();
+
       default:
         throw UnsupportedError('Unsupported metric type $type');
     }
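Selecting the new metric by its enum value goes through the factory like the existing metrics do; a hedged sketch, assuming the factory's no-argument constructor (its declaration sits outside this diff):

import 'package:ml_algo/src/metric/metric_factory_impl.dart';
import 'package:ml_algo/src/metric/metric_type.dart';

void main() {
  final factory = MetricFactoryImpl();
  final metric = factory.createByType(MetricType.recall); // a RecallMetric

  // `metric` is then used through the common Metric interface, e.g.
  // metric.getScore(predictedLabels, origLabels);
}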
5 changes: 5 additions & 0 deletions lib/src/metric/metric_type.dart
@@ -22,4 +22,9 @@ enum MetricType {
   /// better the prediction's quality is. The metric produces scores within the
   /// range [0, 1]
   precision,
+
+  /// A classification metric. The greater the score produced by the metric, the
+  /// better the prediction's quality is. The metric produces scores within the
+  /// range [0, 1]
+  recall,
 }
2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
 name: ml_algo
 description: Machine learning algorithms, Machine learning models performance evaluation functionality
-version: 15.1.0
+version: 15.2.0
 homepage: https://github.com/gyrdym/ml_algo
 
 environment:
80 changes: 80 additions & 0 deletions test/metric/classification/recall_test.dart
@@ -0,0 +1,80 @@
import 'package:ml_algo/src/metric/classification/precision.dart';
import 'package:ml_algo/src/metric/classification/recall.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:test/test.dart';

void main() {
  group('RecallMetric', () {
    final origLabels = Matrix.fromList([
      [1, 0, 0],
      [1, 0, 0],
      [0, 1, 0],
      [0, 0, 1],
      [0, 1, 0],
      [1, 0, 0],
      [0, 0, 1],
    ]);
    final origLabelsWithZeroColumn = Matrix.fromList([
      [1, 0, 0],
      [1, 0, 0],
      [0, 1, 0],
      [1, 0, 0],
      [0, 1, 0],
      [1, 0, 0],
      [1, 0, 0],
    ]);
    final predictedLabels = Matrix.fromList([
      [0, 1, 0],
      [0, 0, 0],
      [0, 1, 0],
      [0, 0, 1],
      [0, 1, 0],
      [1, 0, 0],
      [0, 0, 1],
    ]);
    final predictedLabelsWithZeroColumn = Matrix.fromList([
      [0, 1, 0],
      [1, 0, 0],
      [0, 1, 0],
      [0, 1, 0],
      [0, 1, 0],
      [1, 0, 0],
      [1, 0, 0],
    ]);
    final metric = const RecallMetric();

    test('should return a correct score', () {
      final score = metric.getScore(predictedLabels, origLabels);

      expect(score, closeTo((1 / 3 + 2 / 2 + 2 / 2) / 3, 1e-5));
    });

    test('should return a correct score if there is at least one column with '
        'all zeroes', () {
      final score = metric.getScore(predictedLabelsWithZeroColumn, origLabels);

      expect(score, closeTo((2 / 3 + 2 / 2 + 0) / 3, 1e-5));
    });

    test('should return a correct score if there is a zero column in the '
        'original labels', () {
      final score = metric.getScore(predictedLabels, origLabelsWithZeroColumn);

      expect(score, closeTo((1 / 5 + 2 / 2 + 0) / 3, 1e-5));
    });

    test('should return a correct score if both original labels and predicted '
        'labels have zero columns', () {
      final score = metric.getScore(predictedLabelsWithZeroColumn,
          origLabelsWithZeroColumn);

      expect(score, closeTo((3 / 5 + 2 / 2 + 0) / 3, 1e-5));
    });

    test('should return 1 if predicted labels are correct', () {
      final score = metric.getScore(origLabels, origLabels);

      expect(score, 1);
    });
  });
}
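For reference, the first expectation reads straight off the matrices: the three columns of origLabels contain 3, 2, and 2 positive labels, of which predictedLabels correctly recovers 1, 2, and 2, so the per-class recalls are 1/3, 2/2, and 2/2 and the macro-averaged score is (1 / 3 + 2 / 2 + 2 / 2) / 3 ≈ 0.778.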
5 changes: 5 additions & 0 deletions test/metric/metric_factory_impl_test.dart
@@ -1,5 +1,6 @@
 import 'package:ml_algo/src/metric/classification/accuracy.dart';
 import 'package:ml_algo/src/metric/classification/precision.dart';
+import 'package:ml_algo/src/metric/classification/recall.dart';
 import 'package:ml_algo/src/metric/metric_factory_impl.dart';
 import 'package:ml_algo/src/metric/metric_type.dart';
 import 'package:ml_algo/src/metric/regression/mape.dart';
@@ -26,6 +27,10 @@ void main() {
     expect(factory.createByType(MetricType.precision), isA<PrecisionMetric>());
   });
 
+  test('should create RecallMetric instance', () {
+    expect(factory.createByType(MetricType.recall), isA<RecallMetric>());
+  });
+
   test('should throw an error if null is passed as metric type', () {
     expect(() => factory.createByType(null), throwsUnsupportedError);
   });