Skip to content

Commit

Permalink
Merge 9348dbc into 05310cd
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed May 6, 2020
2 parents 05310cd + 9348dbc commit b06bcb9
Show file tree
Hide file tree
Showing 148 changed files with 2,053 additions and 1,014 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
language: dart
dart:
- "2.4.1"
- "2.7.2"
script: pub run grinder start
after_success: pub run grinder finish
14 changes: 14 additions & 0 deletions build.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
targets:
$default:
builders:
json_serializable:
options:
any_map: false
checked: true
create_to_json: true
disallow_unrecognized_keys: true
explicit_to_json: false
field_rename: none
ignore_unannotated: false
include_if_null: true
nullable: true
2 changes: 0 additions & 2 deletions lib/src/classifier/classifier.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import 'package:ml_dataframe/ml_dataframe.dart';
/// An interface for any classifier (linear, non-linear, parametric,
/// non-parametric, etc.)
abstract class Classifier extends Predictor {
List<String> get classNames;

/// Returns predicted distribution of probabilities for each observation in
/// the passed [testFeatures]
DataFrame predictProbabilities(DataFrame testFeatures);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import 'package:ml_algo/src/di/dependencies.dart';
import 'package:ml_algo/src/helpers/validate_train_data.dart';
import 'package:ml_algo/src/helpers/validate_tree_solver_max_depth.dart';
import 'package:ml_algo/src/helpers/validate_tree_solver_min_samples_count.dart';
import 'package:ml_algo/src/helpers/validate_tree_solver_minimal_error.dart';
import 'package:ml_algo/src/tree_solver/_helpers/create_decision_tree_solver.dart';
import 'package:ml_algo/src/helpers/validate_tree_solver_min_error.dart';
import 'package:ml_algo/src/tree_trainer/_helpers/create_decision_tree_trainer.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';

Expand All @@ -18,20 +18,15 @@ DecisionTreeClassifier createDecisionTreeClassifier(
DType dtype,
) {
validateTrainData(trainData, [targetName]);
validateTreeSolverMinimalError(minError);
validateTreeSolverMinError(minError);
validateTreeSolversMinSamplesCount(minSamplesCount);
validateTreeSolverMaxDepth(maxDepth);

final solver = createDecisionTreeSolver(
trainData,
targetName,
minError,
minSamplesCount,
maxDepth,
dtype,
);
final trainer = createDecisionTreeTrainer(trainData, targetName, minError,
minSamplesCount, maxDepth);
final treeRootNode = trainer.train(trainData.toMatrix(dtype));

return dependencies
.getDependency<DecisionTreeClassifierFactory>()
.create(solver, targetName, dtype);
.create(treeRootNode, targetName, dtype);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import 'dart:convert';

import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart';

/// Restores a [DecisionTreeClassifier] from its JSON-encoded representation.
///
/// [json] must be a JSON document whose top-level value is an object; it is
/// decoded with [jsonDecode] and handed to
/// [DecisionTreeClassifierImpl.fromJson].
///
/// Throws an [Exception] if [json] is an empty string, and a
/// [FormatException] if [json] is not valid JSON or if its top-level value is
/// not a JSON object.
DecisionTreeClassifier createDecisionTreeClassifierFromJson(String json) {
  if (json.isEmpty) {
    throw Exception('Provided JSON object is empty, please provide a proper '
        'JSON object');
  }

  final dynamic decodedJson = jsonDecode(json);

  // A bare cast would surface a valid but non-object payload (e.g. `[]` or
  // `42`) as an opaque type error, so the shape is checked explicitly.
  if (decodedJson is Map<String, dynamic>) {
    return DecisionTreeClassifierImpl.fromJson(decodedJson);
  }

  throw FormatException(
      'Provided JSON must contain an object at the top level', json);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import 'package:ml_algo/src/classifier/classifier.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/_helpers/create_decision_tree_classifier.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/_helper/create_decision_tree_classifier.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/_helper/create_decision_tree_classifier_from_json.dart';
import 'package:ml_algo/src/common/serializable/serializable.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
Expand All @@ -10,21 +12,24 @@ import 'package:ml_linalg/dtype.dart';
/// subsets until the subsets conforming certain stop criteria are found.
///
/// Process of forming such a recursive subsets structure is called
/// decision tree learning. Once a decision tree learned, it may use to classify
/// new samples with the same features, as were used to learn the tree.
abstract class DecisionTreeClassifier implements Classifier, Assessable {
/// decision tree learning. Once a decision tree is learned, it may be used to
/// classify new samples with the same features that were used to learn the
/// tree.
abstract class DecisionTreeClassifier implements
Classifier, Assessable, Serializable {
/// Parameters:
///
/// [trainData] A [DataFrame] with observations, that will be used by the
/// [trainData] A [DataFrame] with observations that will be used by the
/// classifier to learn a decision tree. Must contain [targetName] column.
///
/// [targetName] A name of a column in [trainData] that contains class
/// labels
///
/// [minError] A value from range 0..1 (both inclusive). The value denotes a
/// minimal error on a single decision tree node and is used as a stop
/// criteria to avoid farther decision's tree node splitting: if the node is
/// good enough, there is no need to split it and thus it will become a leaf.
/// [minError] A value within the range 0..1 (both inclusive). The value
/// denotes a minimal error on a single decision tree node and is used as a
/// stop criterion to avoid further splitting of the decision tree's nodes: if
/// the node is good enough, there is no need to split it and thus it will
/// become a leaf.
///
/// [minSamplesCount] A minimal number of samples (observations) on the
/// decision's tree node. The value is used as a stop criteria to avoid
Expand All @@ -47,4 +52,7 @@ abstract class DecisionTreeClassifier implements Classifier, Assessable {
maxDepth,
dtype,
);

factory DecisionTreeClassifier.fromJson(String json) =>
createDecisionTreeClassifierFromJson(json);
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_algo/src/tree_solver/tree_solver.dart';
import 'package:ml_algo/src/tree_trainer/tree_node/tree_node.dart';
import 'package:ml_linalg/dtype.dart';

abstract class DecisionTreeClassifierFactory {
DecisionTreeClassifier create(
TreeSolver solver,
TreeNode root,
String targetName,
DType dtype,
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart';
import 'package:ml_algo/src/tree_solver/tree_solver.dart';
import 'package:ml_algo/src/tree_trainer/tree_node/tree_node.dart';
import 'package:ml_linalg/dtype.dart';

/// A [DecisionTreeClassifierFactory] implementation that produces
/// [DecisionTreeClassifierImpl] instances.
class DecisionTreeClassifierFactoryImpl implements
DecisionTreeClassifierFactory {

const DecisionTreeClassifierFactoryImpl();

/// Creates a [DecisionTreeClassifierImpl], forwarding the received arguments
/// to its constructor.
@override
DecisionTreeClassifier create(
TreeSolver solver,
TreeNode root,
String targetName,
DType dtype,
) => DecisionTreeClassifierImpl(solver, targetName, dtype);
) => DecisionTreeClassifierImpl(
root,
targetName,
dtype,
);
}
Original file line number Diff line number Diff line change
@@ -1,31 +1,65 @@
import 'package:json_annotation/json_annotation.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_json_keys.dart';
import 'package:ml_algo/src/common/dtype_serializer/dtype_to_json.dart';
import 'package:ml_algo/src/common/dtype_serializer/from_dtype_json.dart';
import 'package:ml_algo/src/common/serializable/serializable_mixin.dart';
import 'package:ml_algo/src/predictor/assessable_predictor_mixin.dart';
import 'package:ml_algo/src/tree_solver/tree_solver.dart';
import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label.dart';
import 'package:ml_algo/src/tree_trainer/tree_node/_helper/from_tree_node_json.dart';
import 'package:ml_algo/src/tree_trainer/tree_node/_helper/tree_node_to_json.dart';
import 'package:ml_algo/src/tree_trainer/tree_node/tree_node.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
import 'package:ml_linalg/vector.dart';

class DecisionTreeClassifierImpl with AssessablePredictorMixin
implements DecisionTreeClassifier {
part 'decision_tree_classifier_impl.g.dart';

DecisionTreeClassifierImpl(this._solver, String className, this.dtype)
: classNames = [className];
@JsonSerializable()
class DecisionTreeClassifierImpl
with
AssessablePredictorMixin,
SerializableMixin
implements
DecisionTreeClassifier {

DecisionTreeClassifierImpl(
this.treeRootNode,
this.targetColumnName,
this.dtype,
);

factory DecisionTreeClassifierImpl.fromJson(Map<String, dynamic> json) =>
_$DecisionTreeClassifierImplFromJson(json);

@override
Map<String, dynamic> toJson() => _$DecisionTreeClassifierImplToJson(this);

@override
@JsonKey(
name: dTypeJsonKey,
toJson: dTypeToJson,
fromJson: fromDTypeJson,
)
final DType dtype;

final TreeSolver _solver;
@JsonKey(name: targetColumnNameJsonKey)
final String targetColumnName;

@override
final List<String> classNames;
@JsonKey(
name: treeRootNodeJsonKey,
toJson: treeNodeToJson,
fromJson: fromTreeNodeJson,
)
final TreeNode treeRootNode;

@override
DataFrame predict(DataFrame features) {
final predictedLabels = features
.toMatrix(dtype)
.rows
.map(_solver.getLabelForSample);
.map((sample) => _getLabelForSample(sample, treeRootNode));

if (predictedLabels.isEmpty) {
return DataFrame([<num>[]]);
Expand All @@ -38,27 +72,51 @@ class DecisionTreeClassifierImpl with AssessablePredictorMixin

return DataFrame.fromMatrix(
Matrix.fromColumns([outcomeVector], dtype: dtype),
header: classNames,
header: [
targetColumnName,
],
);
}

@override
DataFrame predictProbabilities(DataFrame features) {
final probabilities = Matrix.fromColumns([
Vector.fromList(
features
.toMatrix(dtype)
.rows
.map(_solver.getLabelForSample)
.map((label) => label.probability)
.toList(growable: false),
dtype: dtype,
),
final sampleVectors = features
.toMatrix(dtype)
.rows;

final probabilities = sampleVectors
.map((sample) => _getLabelForSample(sample, treeRootNode))
.map((label) => label.probability)
.toList(growable: false);

final probabilitiesVector = Vector.fromList(
probabilities,
dtype: dtype,
);

final probabilitiesMatrixColumn = Matrix.fromColumns([
probabilitiesVector,
], dtype: dtype);

return DataFrame.fromMatrix(
probabilities,
header: classNames,
probabilitiesMatrixColumn,
header: [
targetColumnName,
],
);
}

/// Descends from [node] to a leaf and returns that leaf's label.
///
/// At each non-leaf node the first child (in the iteration order of
/// `children`) whose `isSamplePassed` accepts [sample] is followed.
///
/// Throws an [Exception] if no child of a non-leaf node accepts [sample].
TreeLeafLabel _getLabelForSample(Vector sample, TreeNode node) {
  var currentNode = node;

  while (!currentNode.isLeaf) {
    var hasDescended = false;

    for (final candidate in currentNode.children) {
      if (candidate.isSamplePassed(sample)) {
        // Only the first matching child is ever followed.
        currentNode = candidate;
        hasDescended = true;
        break;
      }
    }

    if (!hasDescended) {
      throw Exception('Given sample does not conform any splitting condition');
    }
  }

  return currentNode.label;
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/// JSON key under which the decision tree classifier's `dtype` field is
/// serialized.
const dTypeJsonKey = 'DT';
/// JSON key under which the decision tree classifier's `targetColumnName`
/// field is serialized.
const targetColumnNameJsonKey = 'T';
/// JSON key under which the decision tree classifier's `treeRootNode` field
/// is serialized.
const treeRootNodeJsonKey = 'R';
9 changes: 4 additions & 5 deletions lib/src/classifier/knn_classifier/knn_classifier_impl.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,16 @@ import 'package:ml_linalg/vector.dart';

class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier {
KnnClassifierImpl(
String targetName,
this._targetColumnName,
this._classLabels,
this._kernel,
this._solver,
this.dtype,
) : classNames = [targetName] {
) {
validateClassLabelList(_classLabels);
}

@override
final List<String> classNames;
final String _targetColumnName;

@override
final DType dtype;
Expand All @@ -47,7 +46,7 @@ class KnnClassifierImpl with AssessablePredictorMixin implements KnnClassifier {

return DataFrame.fromMatrix(
Matrix.fromColumns([outcomesAsVector], dtype: dtype),
header: classNames,
header: [_targetColumnName],
);
}

Expand Down
4 changes: 3 additions & 1 deletion lib/src/classifier/linear_classifier.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import 'package:ml_linalg/matrix.dart';

/// An interface for all types of linear classifiers
abstract class LinearClassifier implements Classifier {
Iterable<String> get classNames;

/// A function that is used for converting learned coefficients into
/// probabilities
LinkFunction get linkFunction;
Expand All @@ -19,4 +21,4 @@ abstract class LinearClassifier implements Classifier {
/// A matrix, where each column is a vector of coefficients, associated with
/// the specific class
Matrix get coefficientsByClasses;
}
}
Loading

0 comments on commit b06bcb9

Please sign in to comment.