-
-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #83 from gyrdym/softmax-unit-tests
DataFrame introduced, ml_linalg 6.0.0 supported, softmax regression unit tests added, optimizer api changed - Vector -> Matrix
- Loading branch information
Showing
114 changed files
with
1,462 additions
and
1,134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
language: dart | ||
dart: | ||
- "2.1.0" | ||
- "2.2.0" | ||
dart_task: | ||
- test: --platform vm | ||
- dartanalyzer: true | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,28 @@ | ||
import 'dart:async'; | ||
import 'dart:typed_data'; | ||
|
||
import 'package:ml_algo/ml_algo.dart'; | ||
|
||
Future main() async { | ||
final data = MLData.fromCsvFile('datasets/pima_indians_diabetes_database.csv', | ||
labelIdx: 8, dtype: Float32x4); | ||
final data = DataFrame.fromCsv('datasets/pima_indians_diabetes_database.csv', | ||
labelIdx: 8, | ||
categoryNameToEncoder: { | ||
'class variable (0 or 1)': CategoricalDataEncoderType.oneHot, | ||
}, | ||
); | ||
|
||
final features = await data.features; | ||
final labels = await data.labels; | ||
|
||
final validator = CrossValidator.kFold(numberOfFolds: 5, dtype: Float32x4); | ||
|
||
// lr=0.0102, randomSeed=134, minWeightsUpdate: 0.000000000001, iterationLimit: 100 => error = 0.3449 | ||
final validator = CrossValidator.kFold(numberOfFolds: 5); | ||
|
||
final logisticRegressor = LinearClassifier.logisticRegressor( | ||
initialLearningRate: 0.0102, | ||
initialLearningRate: 0.00001, | ||
iterationsLimit: 7000, | ||
learningRateType: LearningRateType.constant, | ||
randomSeed: 134); | ||
randomSeed: 150); | ||
|
||
final accuracy = validator.evaluate( | ||
logisticRegressor, features, labels, MetricType.accuracy); | ||
|
||
print('Accuracy is ${(accuracy * 100).toStringAsFixed(2)}%'); | ||
print('Accuracy is ${accuracy.toStringAsFixed(2)}'); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,23 @@ | ||
import 'package:ml_algo/src/predictor.dart'; | ||
import 'package:ml_algo/src/predictor/predictor.dart'; | ||
import 'package:ml_linalg/matrix.dart'; | ||
import 'package:ml_linalg/vector.dart'; | ||
|
||
/// An interface for any classifier (linear, non-linear, parametric, | ||
/// non-parametric, etc.) | ||
abstract class Classifier implements Predictor { | ||
/// A map, where each key is a class label and each value, associated with | ||
/// the key, is a set of weights (coefficients), specific for the class | ||
Map<double, MLVector> get weightsByClasses; | ||
/// A matrix, where each column is a vector of weights, associated with | ||
/// the specific class | ||
Matrix get weightsByClasses; | ||
|
||
/// A collection of encoded class labels. Can be transformed back to original | ||
/// A collection of class labels. Can be transformed back to original | ||
/// labels by a [MLData] instance, that was used previously to encode the | ||
/// labels | ||
Iterable<double> get classLabels; | ||
Matrix get classLabels; | ||
|
||
/// Returns predicted distribution of probabilities for each observation in | ||
/// the passed [features] | ||
MLMatrix predictProbabilities(MLMatrix features); | ||
Matrix predictProbabilities(Matrix features); | ||
|
||
/// Return a collection of predicted class labels for each observation in the | ||
/// passed [features] | ||
MLVector predictClasses(MLMatrix features); | ||
Matrix predictClasses(Matrix features); | ||
} |
This file was deleted.
Oops, something went wrong.
5 changes: 0 additions & 5 deletions
5
lib/src/classifier/labels_processor/labels_processor_factory.dart
This file was deleted.
Oops, something went wrong.
10 changes: 0 additions & 10 deletions
10
lib/src/classifier/labels_processor/labels_processor_factory_impl.dart
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.