Skip to content

Commit

Permalink
DI logic: conditional dependency registering added (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
gyrdym committed Dec 12, 2020
1 parent 7a20255 commit 318fe3a
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 41 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 15.3.4
- `DI logic`:
- conditional dependency registering added

## 15.3.3
- FUNDING.yml created

Expand Down
22 changes: 11 additions & 11 deletions lib/src/classifier/decision_tree_classifier/_init_module.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_cl
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory.dart';
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory_impl.dart';
import 'package:ml_algo/src/di/common/init_common_module.dart';
import 'package:ml_algo/src/extensions/injector.dart';
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart';
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory_impl.dart';
import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label_factory_factory.dart';
Expand All @@ -25,50 +26,49 @@ void initDecisionTreeModule() {
initCommonModule();

decisionTreeInjector
..clearAll()
..registerSingleton<DistributionCalculatorFactory>(
..registerSingletonIf<DistributionCalculatorFactory>(
() => const DistributionCalculatorFactoryImpl())

..registerSingleton<NominalTreeSplitterFactory>(
..registerSingletonIf<NominalTreeSplitterFactory>(
() => const NominalTreeSplitterFactoryImpl())

..registerSingleton<NumericalTreeSplitterFactory>(
..registerSingletonIf<NumericalTreeSplitterFactory>(
() => const NumericalTreeSplitterFactoryImpl())

..registerSingleton<TreeSplitAssessorFactory>(
..registerSingletonIf<TreeSplitAssessorFactory>(
() => const TreeSplitAssessorFactoryImpl())

..registerSingleton<TreeSplitterFactory>(
..registerSingletonIf<TreeSplitterFactory>(
() => TreeSplitterFactoryImpl(
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
decisionTreeInjector.get<NominalTreeSplitterFactory>(),
decisionTreeInjector.get<NumericalTreeSplitterFactory>(),
))

..registerSingleton<TreeSplitSelectorFactory>(
..registerSingletonIf<TreeSplitSelectorFactory>(
() => TreeSplitSelectorFactoryImpl(
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
decisionTreeInjector.get<TreeSplitterFactory>(),
))

..registerSingleton<TreeLeafDetectorFactory>(
..registerSingletonIf<TreeLeafDetectorFactory>(
() => TreeLeafDetectorFactoryImpl(
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
))

..registerSingleton<TreeLeafLabelFactoryFactory>(
..registerSingletonIf<TreeLeafLabelFactoryFactory>(
() => TreeLeafLabelFactoryFactoryImpl(
decisionTreeInjector
.get<DistributionCalculatorFactory>(),
))

..registerSingleton<TreeTrainerFactory>(
..registerSingletonIf<TreeTrainerFactory>(
() => TreeTrainerFactoryImpl(
decisionTreeInjector.get<TreeLeafDetectorFactory>(),
decisionTreeInjector.get<TreeLeafLabelFactoryFactory>(),
decisionTreeInjector.get<TreeSplitSelectorFactory>(),
))

..registerSingleton<DecisionTreeClassifierFactory>(
..registerSingletonIf<DecisionTreeClassifierFactory>(
() => const DecisionTreeClassifierFactoryImpl());
}
10 changes: 5 additions & 5 deletions lib/src/classifier/knn_classifier/_init_module.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import 'package:ml_algo/src/classifier/knn_classifier/_injector.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory_impl.dart';
import 'package:ml_algo/src/di/common/init_common_module.dart';
import 'package:ml_algo/src/extensions/injector.dart';
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart';
import 'package:ml_algo/src/knn_kernel/kernel_factory_impl.dart';
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart';
Expand All @@ -13,17 +14,16 @@ void initKnnClassifierModule() {
initCommonModule();

knnClassifierInjector
..clearAll()
..registerSingleton<KernelFactory>(
..registerSingletonIf<KernelFactory>(
() => const KernelFactoryImpl())

..registerDependency<KnnSolverFactory>(
..registerSingletonIf<KnnSolverFactory>(
() => const KnnSolverFactoryImpl())

..registerSingleton<KnnClassifierFactory>(
..registerSingletonIf<KnnClassifierFactory>(
() => const KnnClassifierFactoryImpl())

..registerSingleton<KnnRegressorFactory>(
..registerSingletonIf<KnnRegressorFactory>(
() => KnnRegressorFactoryImpl(
knnClassifierInjector.get<KernelFactory>(),
knnClassifierInjector.get<KnnSolverFactory>(),
Expand Down
6 changes: 3 additions & 3 deletions lib/src/classifier/logistic_regressor/_init_module.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import 'package:ml_algo/src/classifier/logistic_regressor/_injector.dart';
import 'package:ml_algo/src/di/common/init_common_module.dart';
import 'package:ml_algo/src/extensions/injector.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/link_function/link_function_dependency_tokens.dart';
import 'package:ml_algo/src/link_function/logit/float32_inverse_logit_function.dart';
Expand All @@ -9,12 +10,11 @@ void initLogisticRegressorModule() {
initCommonModule();

logisticRegressorInjector
..clearAll()
..registerSingleton<LinkFunction>(
..registerSingletonIf<LinkFunction>(
() => const Float32InverseLogitLinkFunction(),
dependencyName: float32InverseLogitLinkFunctionToken)

..registerSingleton<LinkFunction>(
..registerSingletonIf<LinkFunction>(
() => const Float64InverseLogitLinkFunction(),
dependencyName: float64InverseLogitLinkFunctionToken);
}
8 changes: 4 additions & 4 deletions lib/src/classifier/softmax_regressor/_init_module.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import 'package:ml_algo/src/classifier/softmax_regressor/_injector.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory_impl.dart';
import 'package:ml_algo/src/di/common/init_common_module.dart';
import 'package:ml_algo/src/extensions/injector.dart';
import 'package:ml_algo/src/link_function/link_function.dart';
import 'package:ml_algo/src/link_function/link_function_dependency_tokens.dart';
import 'package:ml_algo/src/link_function/softmax/float32_softmax_link_function.dart';
Expand All @@ -11,15 +12,14 @@ void initSoftmaxRegressorModule() {
initCommonModule();

softmaxRegressorInjector
..clearAll()
..registerSingleton<LinkFunction>(
..registerSingletonIf<LinkFunction>(
() => const Float32SoftmaxLinkFunction(),
dependencyName: float32SoftmaxLinkFunctionToken)

..registerSingleton<LinkFunction>(
..registerSingletonIf<LinkFunction>(
() => const Float64SoftmaxLinkFunction(),
dependencyName: float64SoftmaxLinkFunctionToken)

..registerSingleton<SoftmaxRegressorFactory>(
..registerSingletonIf<SoftmaxRegressorFactory>(
() => const SoftmaxRegressorFactoryImpl());
}
26 changes: 13 additions & 13 deletions lib/src/di/common/init_common_module.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import 'package:ml_algo/src/cost_function/cost_function_factory.dart';
import 'package:ml_algo/src/cost_function/cost_function_factory_impl.dart';
import 'package:ml_algo/src/di/dependency_keys.dart';
import 'package:ml_algo/src/di/injector.dart';
import 'package:ml_algo/src/extensions/injector.dart';
import 'package:ml_algo/src/helpers/features_target_split.dart';
import 'package:ml_algo/src/helpers/features_target_split_interface.dart';
import 'package:ml_algo/src/helpers/normalize_class_labels.dart';
Expand Down Expand Up @@ -30,40 +31,39 @@ typedef EncoderFactory = Encoder Function(DataFrame, Iterable<String>);

void initCommonModule() {
injector
..clearAll()
..registerSingleton<EncoderFactory>(
..registerSingletonIf<EncoderFactory>(
() => (DataFrame data, Iterable<String> targetNames) =>
Encoder.oneHot(data, featureNames: targetNames),
dependencyName: oneHotEncoderFactoryKey)

..registerSingleton<RandomizerFactory>(
..registerSingletonIf<RandomizerFactory>(
() => const RandomizerFactoryImpl())

..registerDependency<FeaturesTargetSplit>(
..registerSingletonIf<FeaturesTargetSplit>(
() => featuresTargetSplit)

..registerSingleton<MetricFactory>(
..registerSingletonIf<MetricFactory>(
() => const MetricFactoryImpl())

..registerDependency<NormalizeClassLabels>(
..registerSingletonIf<NormalizeClassLabels>(
() => normalizeClassLabels)

..registerSingleton<LinearOptimizerFactory>(
..registerSingletonIf<LinearOptimizerFactory>(
() => const LinearOptimizerFactoryImpl())

..registerSingleton<LearningRateGeneratorFactory>(
..registerSingletonIf<LearningRateGeneratorFactory>(
() => const LearningRateGeneratorFactoryImpl())

..registerSingleton<InitialCoefficientsGeneratorFactory>(
..registerSingletonIf<InitialCoefficientsGeneratorFactory>(
() => const InitialCoefficientsGeneratorFactoryImpl())

..registerDependency<ConvergenceDetectorFactory>(
..registerSingletonIf<ConvergenceDetectorFactory>(
() => const ConvergenceDetectorFactoryImpl())

..registerSingleton<CostFunctionFactory>(
..registerSingletonIf<CostFunctionFactory>(
() => const CostFunctionFactoryImpl())

..registerSingleton<ModelAssessor<Classifier>>(() =>
..registerSingletonIf<ModelAssessor<Classifier>>(() =>
ClassifierAssessor(
injector.get<MetricFactory>(),
injector.get<EncoderFactory>(
Expand All @@ -72,7 +72,7 @@ void initCommonModule() {
normalizeClassLabels,
))

..registerSingleton<ModelAssessor<Predictor>>(() =>
..registerSingletonIf<ModelAssessor<Predictor>>(() =>
RegressorAssessor(
injector.get<MetricFactory>(),
featuresTargetSplit,
Expand Down
19 changes: 19 additions & 0 deletions lib/src/extensions/injector.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import 'package:injector/injector.dart';

extension InjectorExtension on Injector {
/// Registers a dependency only if it doesn't exist
void registerSingletonIf<T>(Builder<T> builder, {
bool override = false,
String dependencyName = '',
}) {
if (exists<T>(dependencyName: dependencyName)) {
return;
}

registerSingleton<T>(
builder,
override: override,
dependencyName: dependencyName,
);
}
}
8 changes: 4 additions & 4 deletions lib/src/regressor/knn_regressor/_init_module.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import 'package:ml_algo/src/di/common/init_common_module.dart';
import 'package:ml_algo/src/extensions/injector.dart';
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart';
import 'package:ml_algo/src/knn_kernel/kernel_factory_impl.dart';
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart';
Expand All @@ -11,14 +12,13 @@ void initKnnRegressorModule() {
initCommonModule();

knnRegressorInjector
..clearAll()
..registerSingleton<KernelFactory>(
..registerSingletonIf<KernelFactory>(
() => const KernelFactoryImpl())

..registerDependency<KnnSolverFactory>(
..registerSingletonIf<KnnSolverFactory>(
() => const KnnSolverFactoryImpl())

..registerSingleton<KnnRegressorFactory>(
..registerSingletonIf<KnnRegressorFactory>(
() => KnnRegressorFactoryImpl(
knnRegressorInjector.get<KernelFactory>(),
knnRegressorInjector.get<KnnSolverFactory>(),
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: ml_algo
description: Machine learning algorithms, Machine learning models performance evaluation functionality
version: 15.3.3
version: 15.3.4
homepage: https://github.com/gyrdym/ml_algo

environment:
Expand Down

0 comments on commit 318fe3a

Please sign in to comment.