Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
182 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
# Changelog | ||
|
||
## 15.2.0 | ||
- Recall metric added | ||
|
||
## 15.1.0 | ||
- MAPE metric: output range squeezed to [0, 1] | ||
|
||
|
62 changes: 62 additions & 0 deletions
62
lib/src/metric/classification/_helpers/divide_true_positive_by.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:ml_linalg/vector.dart'; | ||
import 'package:quiver/iterables.dart'; | ||
|
||
double divideTruePositiveBy( | ||
Vector divider, | ||
Matrix origLabels, | ||
Matrix predictedLabels, | ||
) { | ||
// Let's say we have the following data: | ||
// | ||
// orig labels | predicted labels | ||
// ------------------------------- | ||
// 1 | 1 | ||
// 1 | 0 | ||
// 0 | 1 | ||
// 0 | 0 | ||
// 1 | 1 | ||
//-------------------------------- | ||
// | ||
// in order to count correctly predicted positive labels in matrix notation | ||
// we may multiply predicted labels by 2, and then subtract the two | ||
// matrices from each other: | ||
// | ||
// 1 - (1 * 2) = -1 | ||
// 1 - (0 * 2) = 1 | ||
// 0 - (1 * 2) = -2 | ||
// 0 - (0 * 2) = 0 | ||
// 1 - (1 * 2) = -1 | ||
// | ||
// we see that matrices subtraction in case of original positive label and a | ||
// predicted positive label gives -1, thus we need to count number of elements | ||
// with value equals -1 in the resulting matrix | ||
final difference = origLabels - (predictedLabels * 2); | ||
final truePositiveCounts = difference | ||
.reduceRows( | ||
(counts, row) => counts + row.mapToVector((diff) => diff == -1 | ||
? 1 : 0), | ||
initValue: Vector.zero( | ||
origLabels.columnsNum, | ||
dtype: origLabels.dtype, | ||
)); | ||
final aggregatedScore = (truePositiveCounts / divider).mean(); | ||
|
||
if (aggregatedScore.isFinite) { | ||
return aggregatedScore; | ||
} | ||
|
||
return zip([ | ||
truePositiveCounts, | ||
divider, | ||
]).fold(0, (aggregated, pair) { | ||
final truePositiveCount = pair.first; | ||
final dividerElement = pair.last; | ||
|
||
if (dividerElement != 0) { | ||
return aggregated + truePositiveCount / dividerElement; | ||
} | ||
|
||
return aggregated + (truePositiveCount == 0 ? 1 : 0); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,18 @@ | ||
import 'package:ml_algo/src/helpers/normalize_class_labels.dart'; | ||
import 'package:ml_algo/src/metric/classification/_helpers/divide_true_positive_by.dart'; | ||
import 'package:ml_algo/src/metric/metric.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:ml_linalg/vector.dart'; | ||
import 'package:quiver/iterables.dart'; | ||
|
||
/// TODO: add warning if predicted values are all zeroes | ||
class PrecisionMetric implements Metric { | ||
const PrecisionMetric(); | ||
|
||
@override | ||
/// Accepts [predictedLabels] and [origLabels] with entries with `1` as | ||
/// positive label and `0` as negative one | ||
double getScore(Matrix predictedLabels, Matrix origLabels) { | ||
final allPredictedPositiveCounts = predictedLabels | ||
final predictedTrueCounts = predictedLabels | ||
.reduceRows((counts, row) => counts + row); | ||
|
||
// Let's say we have the following data: | ||
// | ||
// orig labels | predicted labels | ||
// ------------------------------- | ||
// 1 | 1 | ||
// 1 | 0 | ||
// 0 | 1 | ||
// 0 | 0 | ||
// 1 | 1 | ||
//-------------------------------- | ||
// | ||
// in order to count correctly predicted positive labels in matrix notation | ||
// we may multiply predicted labels by 2, and then subtract the two | ||
// matrices from each other: | ||
// | ||
// 1 - (1 * 2) = -1 | ||
// 1 - (0 * 2) = 1 | ||
// 0 - (1 * 2) = -2 | ||
// 0 - (0 * 2) = 0 | ||
// 1 - (1 * 2) = -1 | ||
// | ||
// we see that matrices subtraction in case of original positive label and a | ||
// predicted positive label gives -1, thus we need to count number of elements | ||
// with value equals -1 in the resulting matrix | ||
final difference = origLabels - (predictedLabels * 2); | ||
final correctPositiveCounts = difference | ||
.reduceRows( | ||
(counts, row) => counts + row.mapToVector((diff) => diff == -1 | ||
? 1 : 0), | ||
initValue: Vector.zero( | ||
origLabels.columnsNum, | ||
dtype: origLabels.dtype, | ||
)); | ||
final aggregatedScore = (correctPositiveCounts / allPredictedPositiveCounts) | ||
.mean(); | ||
|
||
if (aggregatedScore.isFinite) { | ||
return aggregatedScore; | ||
} | ||
|
||
return zip([ | ||
correctPositiveCounts, | ||
allPredictedPositiveCounts, | ||
]).fold(0, (aggregated, pair) { | ||
final correctPositiveCount = pair.first; | ||
final allPredictedPositiveCount = pair.last; | ||
|
||
if (allPredictedPositiveCount != 0) { | ||
return aggregated + correctPositiveCount / allPredictedPositiveCount; | ||
} | ||
|
||
return aggregated + (correctPositiveCount == 0 ? 1 : 0); | ||
}); | ||
return divideTruePositiveBy(predictedTrueCounts, origLabels, | ||
predictedLabels); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import 'package:ml_algo/src/metric/classification/_helpers/divide_true_positive_by.dart'; | ||
import 'package:ml_algo/src/metric/metric.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
|
||
class RecallMetric implements Metric { | ||
const RecallMetric(); | ||
|
||
@override | ||
/// Accepts [predictedLabels] and [origLabels] with entries with `1` as | ||
/// positive label and `0` as negative one | ||
double getScore(Matrix predictedLabels, Matrix origLabels) { | ||
final originalTrueCounts = origLabels | ||
.reduceRows((counts, row) => counts + row); | ||
|
||
return divideTruePositiveBy(originalTrueCounts, origLabels, | ||
predictedLabels); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import 'package:ml_algo/src/metric/classification/precision.dart'; | ||
import 'package:ml_algo/src/metric/classification/recall.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
void main() { | ||
group('RecallMetric', () { | ||
final origLabels = Matrix.fromList([ | ||
[1, 0, 0], | ||
[1, 0, 0], | ||
[0, 1, 0], | ||
[0, 0, 1], | ||
[0, 1, 0], | ||
[1, 0, 0], | ||
[0, 0, 1], | ||
]); | ||
final origLabelsWithZeroColumn = Matrix.fromList([ | ||
[1, 0, 0], | ||
[1, 0, 0], | ||
[0, 1, 0], | ||
[1, 0, 0], | ||
[0, 1, 0], | ||
[1, 0, 0], | ||
[1, 0, 0], | ||
]); | ||
final predictedLabels = Matrix.fromList([ | ||
[0, 1, 0], | ||
[0, 0, 0], | ||
[0, 1, 0], | ||
[0, 0, 1], | ||
[0, 1, 0], | ||
[1, 0, 0], | ||
[0, 0, 1], | ||
]); | ||
final predictedLabelsWithZeroColumn = Matrix.fromList([ | ||
[0, 1, 0], | ||
[1, 0, 0], | ||
[0, 1, 0], | ||
[0, 1, 0], | ||
[0, 1, 0], | ||
[1, 0, 0], | ||
[1, 0, 0], | ||
]); | ||
final metric = const RecallMetric(); | ||
|
||
test('should return a correct score', () { | ||
final score = metric.getScore(predictedLabels, origLabels); | ||
|
||
expect(score, closeTo((1 / 3 + 2 / 2 + 2 / 2) / 3, 1e-5)); | ||
}); | ||
|
||
test('should return a correct score if there is at least one column with ' | ||
'all zeroes ', () { | ||
final score = metric.getScore(predictedLabelsWithZeroColumn, origLabels); | ||
|
||
expect(score, closeTo((2 / 3 + 2 / 2 + 0) / 3, 1e-5)); | ||
}); | ||
|
||
test('should return a correct score if there is a zero column in the ' | ||
'original labels', () { | ||
final score = metric.getScore(predictedLabels, origLabelsWithZeroColumn); | ||
|
||
expect(score, closeTo((1 / 5 + 2 / 2 + 0) / 3, 1e-5)); | ||
}); | ||
|
||
test('should return a correct score if both original labels and predicted ' | ||
'labels have zero columns', () { | ||
final score = metric.getScore(predictedLabelsWithZeroColumn, | ||
origLabelsWithZeroColumn); | ||
|
||
expect(score, closeTo((3 / 5 + 2 / 2 + 0) / 3, 1e-5)); | ||
}); | ||
|
||
test('should return 1 if predicted labels are correct', () { | ||
final score = metric.getScore(origLabels, origLabels); | ||
|
||
expect(score, 1); | ||
}); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters