Skip to content

Commit

Permalink
injector 1.0.9 supported (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Aug 28, 2020
1 parent 59eda7c commit ed1769e
Show file tree
Hide file tree
Showing 21 changed files with 83 additions and 80 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## 14.2.6
- `injector` lib 1.0.9 supported

## 14.2.5
- `pubspec`:
- `injector` dependency corrected
Expand Down
Expand Up @@ -45,8 +45,8 @@ LinearOptimizer createLogLikelihoodOptimizer(
.toList();
final points = splits[0].toMatrix(dtype);
final labels = splits[1].toMatrix(dtype);
final optimizerFactory = dependencies.getDependency<LinearOptimizerFactory>();
final costFunctionFactory = dependencies.getDependency<CostFunctionFactory>();
final optimizerFactory = dependencies.get<LinearOptimizerFactory>();
final costFunctionFactory = dependencies.get<CostFunctionFactory>();
final costFunction = costFunctionFactory.createByType(
CostFunctionType.logLikelihood,
linkFunction: linkFunction,
Expand Down
Expand Up @@ -27,6 +27,6 @@ DecisionTreeClassifier createDecisionTreeClassifier(
final treeRootNode = trainer.train(trainData.toMatrix(dtype));

return dependencies
.getDependency<DecisionTreeClassifierFactory>()
.get<DecisionTreeClassifierFactory>()
.create(treeRootNode, targetName, dtype);
}
Expand Up @@ -41,10 +41,10 @@ KnnClassifier createKnnClassifier(
.unique()
.toList(growable: false);

final kernelFactory = dependencies.getDependency<KernelFactory>();
final kernelFactory = dependencies.get<KernelFactory>();
final kernel = kernelFactory.createByType(kernelType);

final solverFactory = dependencies.getDependency<KnnSolverFactory>();
final solverFactory = dependencies.get<KnnSolverFactory>();

final solver = solverFactory.create(
trainFeatures,
Expand All @@ -55,7 +55,7 @@ KnnClassifier createKnnClassifier(
);

final knnClassifierFactory = dependencies
.getDependency<KnnClassifierFactory>();
.get<KnnClassifierFactory>();

return knnClassifierFactory.create(
targetName,
Expand Down
Expand Up @@ -47,7 +47,7 @@ LogisticRegressor createLogisticRegressor(
trainData.toMatrix(dtype).columnsNum - 1);
}

final linkFunction = dependencies.getDependency<LinkFunction>(
final linkFunction = dependencies.get<LinkFunction>(
dependencyName: dTypeToInverseLogitLinkFunctionToken[dtype]);
final optimizer = createLogLikelihoodOptimizer(
trainData,
Expand Down
Expand Up @@ -42,7 +42,7 @@ SoftmaxRegressor createSoftmaxRegressor(

validateTrainData(trainData, targetNames);

final linkFunction = dependencies.getDependency<LinkFunction>(
final linkFunction = dependencies.get<LinkFunction>(
dependencyName: dTypeToSoftmaxLinkFunctionToken[dtype]);

final optimizer = createLogLikelihoodOptimizer(
Expand Down Expand Up @@ -77,7 +77,7 @@ SoftmaxRegressor createSoftmaxRegressor(
: null;

final regressorFactory = dependencies
.getDependency<SoftmaxRegressorFactory>();
.get<SoftmaxRegressorFactory>();

return regressorFactory.create(
coefficientsByClasses,
Expand Down
76 changes: 38 additions & 38 deletions lib/src/di/dependencies.dart
Expand Up @@ -54,101 +54,101 @@ import 'package:ml_algo/src/tree_trainer/tree_trainer_factory_impl.dart';
Injector get dependencies =>
injector ??= Injector()
..registerSingleton<LinearOptimizerFactory>(
(_) => const LinearOptimizerFactoryImpl())
() => const LinearOptimizerFactoryImpl())

..registerSingleton<RandomizerFactory>(
(_) => const RandomizerFactoryImpl())
() => const RandomizerFactoryImpl())

..registerSingleton<LearningRateGeneratorFactory>(
(_) => const LearningRateGeneratorFactoryImpl())
() => const LearningRateGeneratorFactoryImpl())

..registerSingleton<InitialCoefficientsGeneratorFactory>(
(_) => const InitialCoefficientsGeneratorFactoryImpl())
() => const InitialCoefficientsGeneratorFactoryImpl())

..registerDependency<ConvergenceDetectorFactory>(
(_) => const ConvergenceDetectorFactoryImpl())
() => const ConvergenceDetectorFactoryImpl())

..registerSingleton<CostFunctionFactory>(
(_) => const CostFunctionFactoryImpl())
() => const CostFunctionFactoryImpl())

..registerSingleton<LinkFunction>(
(_) => const Float32InverseLogitLinkFunction(),
() => const Float32InverseLogitLinkFunction(),
dependencyName: float32InverseLogitLinkFunctionToken)

..registerSingleton<LinkFunction>(
(_) => const Float64InverseLogitLinkFunction(),
() => const Float64InverseLogitLinkFunction(),
dependencyName: float64InverseLogitLinkFunctionToken)

..registerSingleton<LinkFunction>(
(_) => const Float32SoftmaxLinkFunction(),
() => const Float32SoftmaxLinkFunction(),
dependencyName: float32SoftmaxLinkFunctionToken)

..registerSingleton<LinkFunction>(
(_) => const Float64SoftmaxLinkFunction(),
() => const Float64SoftmaxLinkFunction(),
dependencyName: float64SoftmaxLinkFunctionToken)

..registerSingleton<SplitIndicesProviderFactory>(
(_) => const SplitIndicesProviderFactoryImpl())
() => const SplitIndicesProviderFactoryImpl())

..registerSingleton<SoftmaxRegressorFactory>(
(_) => const SoftmaxRegressorFactoryImpl())
() => const SoftmaxRegressorFactoryImpl())

..registerSingleton<KernelFactory>(
(_) => const KernelFactoryImpl())
() => const KernelFactoryImpl())

..registerDependency<KnnSolverFactory>(
(_) => const KnnSolverFactoryImpl())
() => const KnnSolverFactoryImpl())

..registerSingleton<KnnClassifierFactory>(
(_) => const KnnClassifierFactoryImpl())
() => const KnnClassifierFactoryImpl())

..registerSingleton<KnnRegressorFactory>(
(injector) => KnnRegressorFactoryImpl(
injector.getDependency<KernelFactory>(),
injector.getDependency<KnnSolverFactory>(),
() => KnnRegressorFactoryImpl(
injector.get<KernelFactory>(),
injector.get<KnnSolverFactory>(),
))

..registerSingleton<SequenceElementsDistributionCalculatorFactory>(
(_) => const SequenceElementsDistributionCalculatorFactoryImpl())
() => const SequenceElementsDistributionCalculatorFactoryImpl())

..registerSingleton<NominalTreeSplitterFactory>(
(_) => const NominalTreeSplitterFactoryImpl())
() => const NominalTreeSplitterFactoryImpl())

..registerSingleton<NumericalTreeSplitterFactory>(
(_) => const NumericalTreeSplitterFactoryImpl())
() => const NumericalTreeSplitterFactoryImpl())

..registerSingleton<TreeSplitAssessorFactory>(
(_) => const TreeSplitAssessorFactoryImpl())
() => const TreeSplitAssessorFactoryImpl())

..registerSingleton<TreeSplitterFactory>(
(injector) => TreeSplitterFactoryImpl(
injector.getDependency<TreeSplitAssessorFactory>(),
injector.getDependency<NominalTreeSplitterFactory>(),
injector.getDependency<NumericalTreeSplitterFactory>(),
() => TreeSplitterFactoryImpl(
injector.get<TreeSplitAssessorFactory>(),
injector.get<NominalTreeSplitterFactory>(),
injector.get<NumericalTreeSplitterFactory>(),
))

..registerSingleton<TreeSplitSelectorFactory>(
(injector) => TreeSplitSelectorFactoryImpl(
injector.getDependency<TreeSplitAssessorFactory>(),
injector.getDependency<TreeSplitterFactory>(),
() => TreeSplitSelectorFactoryImpl(
injector.get<TreeSplitAssessorFactory>(),
injector.get<TreeSplitterFactory>(),
))

..registerSingleton<TreeLeafDetectorFactory>(
(injector) => TreeLeafDetectorFactoryImpl(
injector.getDependency<TreeSplitAssessorFactory>(),
() => TreeLeafDetectorFactoryImpl(
injector.get<TreeSplitAssessorFactory>(),
))

..registerSingleton<TreeLeafLabelFactoryFactory>(
(injector) => TreeLeafLabelFactoryFactoryImpl(
injector.getDependency<SequenceElementsDistributionCalculatorFactory>(),
() => TreeLeafLabelFactoryFactoryImpl(
injector.get<SequenceElementsDistributionCalculatorFactory>(),
))

..registerSingleton<TreeTrainerFactory>(
(injector) => TreeTrainerFactoryImpl(
injector.getDependency<TreeLeafDetectorFactory>(),
injector.getDependency<TreeLeafLabelFactoryFactory>(),
injector.getDependency<TreeSplitSelectorFactory>(),
() => TreeTrainerFactoryImpl(
injector.get<TreeLeafDetectorFactory>(),
injector.get<TreeLeafLabelFactoryFactory>(),
injector.get<TreeSplitSelectorFactory>(),
))

..registerSingleton<DecisionTreeClassifierFactory>(
(injector) => const DecisionTreeClassifierFactoryImpl());
() => const DecisionTreeClassifierFactoryImpl());
Expand Up @@ -24,11 +24,11 @@ class CoordinateDescentOptimizer implements LinearOptimizer {
_lambda = lambda ?? 0.0,

_initialCoefficientsGenerator = dependencies
.getDependency<InitialCoefficientsGeneratorFactory>()
.get<InitialCoefficientsGeneratorFactory>()
.fromType(initialWeightsType, dtype),

_convergenceDetector = dependencies
.getDependency<ConvergenceDetectorFactory>()
.get<ConvergenceDetectorFactory>()
.create(minCoefficientsUpdate, iterationsLimit),

_costFn = costFunction,
Expand Down
Expand Up @@ -36,19 +36,19 @@ class GradientOptimizer implements LinearOptimizer {
_dtype = dtype,

_initialCoefficientsGenerator = dependencies
.getDependency<InitialCoefficientsGeneratorFactory>()
.get<InitialCoefficientsGeneratorFactory>()
.fromType(initialCoefficientsType, dtype),

_learningRateGenerator = dependencies
.getDependency<LearningRateGeneratorFactory>()
.get<LearningRateGeneratorFactory>()
.fromType(learningRateType),

_convergenceDetector = dependencies
.getDependency<ConvergenceDetectorFactory>()
.get<ConvergenceDetectorFactory>()
.create(minCoefficientsUpdate, iterationLimit),

_randomizer = dependencies
.getDependency<RandomizerFactory>()
.get<RandomizerFactory>()
.create(randomSeed) {
if (batchSize < 1 || batchSize > points.rowsNum) {
throw RangeError.range(batchSize, 1, points.rowsNum, 'Invalid batch size '
Expand Down
4 changes: 2 additions & 2 deletions lib/src/model_selection/cross_validator/cross_validator.dart
Expand Up @@ -38,7 +38,7 @@ abstract class CrossValidator {
DType dtype = DType.float32,
}) {
final dataSplitterFactory = dependencies
.getDependency<SplitIndicesProviderFactory>();
.get<SplitIndicesProviderFactory>();
final dataSplitter = dataSplitterFactory
.createByType(SplitIndicesProviderType.kFold, numberOfFolds: numberOfFolds);

Expand Down Expand Up @@ -72,7 +72,7 @@ abstract class CrossValidator {
DType dtype = DType.float32,
}) {
final dataSplitterFactory = dependencies
.getDependency<SplitIndicesProviderFactory>();
.get<SplitIndicesProviderFactory>();
final dataSplitter = dataSplitterFactory
.createByType(SplitIndicesProviderType.lpo, p: p);

Expand Down
Expand Up @@ -40,10 +40,10 @@ LinearOptimizer createSquaredCostOptimizer(
final labels = splits[1].toMatrix(dtype);

final optimizerFactory = dependencies
.getDependency<LinearOptimizerFactory>();
.get<LinearOptimizerFactory>();

final costFunctionFactory = dependencies
.getDependency<CostFunctionFactory>();
.get<CostFunctionFactory>();

final costFunction = costFunctionFactory.createByType(
CostFunctionType.leastSquare,
Expand Down
2 changes: 1 addition & 1 deletion lib/src/regressor/knn_regressor/knn_regressor.dart
Expand Up @@ -51,6 +51,6 @@ abstract class KnnRegressor implements Assessable, Predictor {
DType dtype = DType.float32,
}
) => dependencies
.getDependency<KnnRegressorFactory>()
.get<KnnRegressorFactory>()
.create(fittingData, targetName, k, kernel, distance, dtype);
}
Expand Up @@ -35,7 +35,7 @@ TreeTrainer createDecisionTreeTrainer(
),
);

final trainerFactory = dependencies.getDependency<TreeTrainerFactory>();
final trainerFactory = dependencies.get<TreeTrainerFactory>();

return trainerFactory.createByType(
TreeTrainerType.decision,
Expand Down
6 changes: 3 additions & 3 deletions pubspec.yaml
@@ -1,17 +1,17 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 14.2.5
version: 14.2.6
homepage: https://github.com/gyrdym/ml_algo

environment:
sdk: '>=2.7.0 <3.0.0'

dependencies:
injector: 1.0.8
injector: ^1.0.9
json_annotation: ^3.0.1
json_serializable: ^3.3.0
ml_dataframe: ^0.2.0
ml_linalg: ^12.17.1
ml_linalg: ^12.17.3
quiver: ^2.0.2
xrange: ^0.0.8

Expand Down
6 changes: 3 additions & 3 deletions test/classifier/knn_classifier/knn_classifier_test.dart
Expand Up @@ -50,9 +50,9 @@ void main() {
knnClassifierMock);

injector = Injector()
..registerSingleton<KernelFactory>((_) => kernelFactoryMock)
..registerSingleton<KnnSolverFactory>((_) => solverFactoryMock)
..registerSingleton<KnnClassifierFactory>((_) => knnClassifierFactoryMock);
..registerSingleton<KernelFactory>(() => kernelFactoryMock)
..registerSingleton<KnnSolverFactory>(() => solverFactoryMock)
..registerSingleton<KnnClassifierFactory>(() => knnClassifierFactoryMock);
});

tearDown(() {
Expand Down
Expand Up @@ -55,12 +55,12 @@ void main() {

injector = Injector()
..registerSingleton<LinkFunction>(
(_) => linkFunctionMock,
() => linkFunctionMock,
dependencyName: float32InverseLogitLinkFunctionToken)
..registerDependency<CostFunctionFactory>(
(_) => costFunctionFactoryMock)
() => costFunctionFactoryMock)
..registerSingleton<LinearOptimizerFactory>(
(_) => optimizerFactoryMock);
() => optimizerFactoryMock);

when(optimizerMock.findExtrema(
initialCoefficients: anyNamed('initialCoefficients'),
Expand Down
Expand Up @@ -97,13 +97,13 @@ void main() {
softmaxRegressorMock);

injector = Injector()
..registerSingleton<LinkFunction>((_) => linkFunctionMock,
..registerSingleton<LinkFunction>(() => linkFunctionMock,
dependencyName: float32SoftmaxLinkFunctionToken)
..registerDependency<CostFunctionFactory>(
(_) => costFunctionFactoryMock)
..registerSingleton<LinearOptimizerFactory>((_) => optimizerFactoryMock)
() => costFunctionFactoryMock)
..registerSingleton<LinearOptimizerFactory>(() => optimizerFactoryMock)
..registerSingleton<SoftmaxRegressorFactory>(
(_) => softmaxRegressorFactoryMock);
() => softmaxRegressorFactoryMock);

when(optimizerMock.findExtrema(
initialCoefficients: anyNamed('initialCoefficients'),
Expand Down
8 changes: 4 additions & 4 deletions test/linear_optimizer/gradient/gradient_optimizer_test.dart
Expand Up @@ -81,13 +81,13 @@ void main() {

injector = Injector()
..registerDependency<LearningRateGeneratorFactory>(
(_) => learningRateGeneratorFactoryMock)
() => learningRateGeneratorFactoryMock)
..registerDependency<InitialCoefficientsGeneratorFactory>(
(_) => initialWeightsGeneratorFactoryMock)
() => initialWeightsGeneratorFactoryMock)
..registerDependency<ConvergenceDetectorFactory>(
(_) => convergenceDetectorFactoryMock)
() => convergenceDetectorFactoryMock)
..registerDependency<RandomizerFactory>(
(_) => randomizerFactoryMock);
() => randomizerFactoryMock);

when(initialCoefficientsGeneratorMock.generate(argThat(anything)))
.thenReturn(autoGeneratedInitialCoefficients.toVector());
Expand Down
Expand Up @@ -26,7 +26,7 @@ void main() {
dataSplitterFactory = createDataSplitterFactoryMock(dataSplitter);

injector = Injector()
..registerDependency<SplitIndicesProviderFactory>((_) => dataSplitterFactory);
..registerDependency<SplitIndicesProviderFactory>(() => dataSplitterFactory);
});

tearDown(() => injector = null);
Expand Down

0 comments on commit ed1769e

Please sign in to comment.