From ae96fa2dca3c73a1955c3b8741c8d766ec57349a Mon Sep 17 00:00:00 2001 From: Jack Dermody Date: Fri, 16 Sep 2016 19:19:29 +1000 Subject: [PATCH] naive bayes decision trees (first pass) data tables namespace fixes --- Bayesian/NaiveBayesClassifier.cs | 97 +++++++ Bayesian/Training/NaiveBayesTrainer.cs | 78 ++++++ BrightWire.Net4.csproj | 32 ++- Connectionist/Activation/Softmax.cs | 2 +- .../Execution/BidirectionalExecution.cs | 2 +- Connectionist/Execution/RecurrentExecution.cs | 2 +- Connectionist/Factory.cs | 15 +- Connectionist/Training/Batch/BatchTrainer.cs | 2 +- .../Training/Batch/RecurrentBatchTrainer.cs | 19 +- .../Training/Helper/RecurrentContext.cs | 5 +- .../Layer/Recurrent/SimpleRecurrent.cs | 1 + .../Training/Manager/BidirectionalManager.cs | 8 +- .../Training/Manager/FeedForwardManager.cs | 22 +- .../Training/Manager/RecurrentManager.cs | 9 +- .../Training/Manager/RecurrentManagerBase.cs | 21 +- .../Training/WeightInitialisation/Xavier.cs | 1 + Helper/MiniBatch.cs | 2 +- .../SequenceToSequenceTrainingDataProvider.cs | 69 ----- .../DataTableTrainingDataProvider.cs | 72 +++++ .../DenseSequentialTrainingDataProvider.cs | 2 +- .../DenseTrainingDataProvider.cs | 4 +- .../SequenceToSequenceTrainingDataProvider.cs | 98 +++++++ .../SparseTrainingDataProvider.cs | 4 +- Interface.cs | 76 ++++- LICENSE | 222 ++------------- LinearAlgebra/CpuVector.cs | 26 ++ LinearAlgebra/GpuVector.cs | 25 ++ Models/DecisionTree.cs | 72 +++++ Models/ExecutionResults/FeedForwardOutput.cs | 2 +- Models/ExecutionResults/RecurrentOutput.cs | 2 +- Models/NaiveBayes.cs | 74 +++++ TabularData/Analysis/DataTableAnalysis.cs | 44 +++ TabularData/Analysis/FrequencyAnalysis.cs | 37 +++ TabularData/Analysis/FrequencyCollector.cs | 58 ++++ TabularData/Analysis/NumericCollector.cs | 110 ++++++++ TabularData/Analysis/StringCollector.cs | 56 ++++ TabularData/DataTable.cs | 262 +++++------------- TabularData/Helper/CSVDataTableBuilder.cs | 237 ++++++++++++++++ 
TabularData/Helper/ClassBasedRowProcessor.cs | 41 +++ TabularData/Helper/ColumnTypeClassifier.cs | 48 ++++ TabularData/Helper/DataTableProjector.cs | 66 +++++ TabularData/Helper/DataTableRow.cs | 46 +++ TabularData/Helper/DataTableWriter.cs | 69 +++++ TabularData/IndexedDataTable.cs | 92 ++++++ TabularData/MemoryBasedDataTable.cs | 52 ++++ TabularData/MutableDataTable.cs | 168 +++++++++++ TabularData/NumericTable.cs | 2 +- TrainingData/Artificial/ReberGrammar.cs | 14 +- TreeBased/DecisionTreeClassifier.cs | 62 +++++ TreeBased/Training/DecisionTreeTrainer.cs | 258 +++++++++++++++++ Unsupervised/Clustering/KMeans.cs | 2 +- 51 files changed, 2255 insertions(+), 535 deletions(-) create mode 100644 Bayesian/NaiveBayesClassifier.cs create mode 100644 Bayesian/Training/NaiveBayesTrainer.cs delete mode 100644 Helper/SequenceToSequenceTrainingDataProvider.cs create mode 100644 Helper/TrainingData/DataTableTrainingDataProvider.cs rename Helper/{ => TrainingData}/DenseSequentialTrainingDataProvider.cs (98%) rename Helper/{ => TrainingData}/DenseTrainingDataProvider.cs (95%) create mode 100644 Helper/TrainingData/SequenceToSequenceTrainingDataProvider.cs rename Helper/{ => TrainingData}/SparseTrainingDataProvider.cs (97%) create mode 100644 Models/DecisionTree.cs create mode 100644 Models/NaiveBayes.cs create mode 100644 TabularData/Analysis/DataTableAnalysis.cs create mode 100644 TabularData/Analysis/FrequencyAnalysis.cs create mode 100644 TabularData/Analysis/FrequencyCollector.cs create mode 100644 TabularData/Analysis/NumericCollector.cs create mode 100644 TabularData/Analysis/StringCollector.cs create mode 100644 TabularData/Helper/CSVDataTableBuilder.cs create mode 100644 TabularData/Helper/ClassBasedRowProcessor.cs create mode 100644 TabularData/Helper/ColumnTypeClassifier.cs create mode 100644 TabularData/Helper/DataTableProjector.cs create mode 100644 TabularData/Helper/DataTableRow.cs create mode 100644 TabularData/Helper/DataTableWriter.cs create mode 100644 
TabularData/IndexedDataTable.cs
 create mode 100644 TabularData/MemoryBasedDataTable.cs
 create mode 100644 TabularData/MutableDataTable.cs
 create mode 100644 TreeBased/DecisionTreeClassifier.cs
 create mode 100644 TreeBased/Training/DecisionTreeTrainer.cs

diff --git a/Bayesian/NaiveBayesClassifier.cs b/Bayesian/NaiveBayesClassifier.cs
new file mode 100644
index 00000000..2234e49d
--- /dev/null
+++ b/Bayesian/NaiveBayesClassifier.cs
@@ -0,0 +1,97 @@
+using BrightWire.Models;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace BrightWire.Bayesian
+{
+    public class NaiveBayesClassifier : IRowProcessor // scores rows against every class of a trained NaiveBayes model
+    {
+        interface IProbabilityProvider
+        {
+            double GetProbability(IRow row); // log-space contribution of a single column for the given row
+        }
+        class CategoricalColumn : IProbabilityProvider
+        {
+            readonly int _columnIndex;
+            readonly Dictionary _probability;
+            readonly double _nullValue;
+
+            public CategoricalColumn(NaiveBayes.CategorialColumn summary, double nullValue = 0) // nullValue is the score used for categories missing from the summary
+            {
+                _nullValue = nullValue;
+                _columnIndex = summary.ColumnIndex;
+                _probability = summary.Probability.ToDictionary(d => d.Category, d => d.LogProbability);
+            }
+
+            public double GetProbability(IRow row) // stored log probability of the row's category, or _nullValue when it is not in the table
+            {
+                double ret;
+                var val = row.GetField(_columnIndex);
+                if (_probability.TryGetValue(val, out ret))
+                    return ret;
+                return _nullValue;
+            }
+        }
+        class ContinuousColumn : IProbabilityProvider
+        {
+            readonly NaiveBayes.ContinuousGaussianColumn _column;
+
+            public ContinuousColumn(NaiveBayes.ContinuousGaussianColumn column)
+            {
+                _column = column;
+            }
+
+            public double GetProbability(IRow row) // log of the Gaussian density with the column's Mean and Variance at the row's value
+            {
+                double x = row.GetField(_column.ColumnIndex);
+                var deviation = x - _column.Mean; // Variance is already sigma squared, so it must not be squared again in the exponent
+                return -0.5 * Math.Log(2 * Math.PI * _column.Variance) - deviation * deviation / (2 * _column.Variance); // evaluated in log space so a tiny density cannot underflow to log(0)
+            }
+        }
+        readonly List>> _classProbability = new List>>();
+        readonly List> _resultList = new List>();
+
+        public NaiveBayesClassifier(NaiveBayes model) // builds one probability provider per supported column summary of each class
+        {
+            foreach (var cls in model.Class) {
+                List list = new List();
+                foreach (var col in cls.ColumnSummary) {
+                    if (col.Type == NaiveBayes.ColumnType.Categorical)
+                        list.Add(new CategoricalColumn(col as NaiveBayes.CategorialColumn));
+                    else if (col.Type == NaiveBayes.ColumnType.ContinuousGaussian)
+                        list.Add(new ContinuousColumn(col as NaiveBayes.ContinuousGaussianColumn));
+                }
+                _classProbability.Add(Tuple.Create(cls.Label, list));
+            }
+        }
+
+        public IEnumerable> Classify(IRow row) // sums the per-column scores for each class and returns the classes ordered best first
+        {
+            var ret = new Dictionary();
+            foreach(var cls in _classProbability) {
+                double score = 0;
+                foreach (var item in cls.Item2)
+                    score += item.GetProbability(row);
+                ret.Add(cls.Item1, score);
+            }
+            return ret.OrderByDescending(kv => kv.Value);
+        }
+
+        public bool Process(IRow row) // records the top classification for the row; always returns true
+        {
+            var classification = Classify(row).First().Key;
+            _resultList.Add(Tuple.Create(row, classification));
+            return true;
+        }
+
+        public void Clear()
+        {
+            _resultList.Clear();
+        }
+
+        public IReadOnlyList> Results { get { return _resultList; } }
+    }
+}
diff --git a/Bayesian/Training/NaiveBayesTrainer.cs b/Bayesian/Training/NaiveBayesTrainer.cs
new file mode 100644
index 00000000..f5f76c56
--- /dev/null
+++ b/Bayesian/Training/NaiveBayesTrainer.cs
@@ -0,0 +1,78 @@
+using BrightWire.Helper;
+using BrightWire.Models;
+using BrightWire.TabularData;
+using BrightWire.TabularData.Analysis;
+using BrightWire.TabularData.Helper;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace BrightWire.Bayesian.Training
+{
+    public static class NaiveBayesTrainer
+    {
+        public static NaiveBayes Train(IDataTable table, int classColumnIndex) // builds per-class column summaries: categorical log frequencies and Gaussian mean/variance
+        {
+            // analyse the table to get the set of class values
+            var analysis = new DataTableAnalysis(table, classColumnIndex);
+            table.Process(analysis);
+
+            var classInfo = analysis.ColumnInfo.Single();
+            if (classInfo.DistinctValues == null)
+                throw new Exception("Too many class values");
+
+            // analyse the data per class
+            var classBasedFrequency = classInfo.DistinctValues.Select(cv => Tuple.Create(cv.ToString(), new FrequencyAnalysis(table, classColumnIndex)));
+            var frequencyAnalysis = new ClassBasedRowProcessor(classBasedFrequency, classColumnIndex);
+            table.Process(frequencyAnalysis);
+
+            // create the per-class summaries from the frequency table
+            var classList = new List();
+            foreach(var classSummary in frequencyAnalysis.All) {
+                var classLabel = classSummary.Item1;
+                var frequency = classSummary.Item2 as FrequencyAnalysis;
+                var columnList = new List();
+                foreach(var column in frequency.ColumnInfo) {
+                    var continuous = column as NumberCollector;
+                    var categorical = column as FrequencyCollector;
+                    if(categorical != null) {
+                        var total = (double)categorical.Total;
+                        if (total > 0) {
+                            var list = new List();
+                            foreach (var item in categorical.Frequency) {
+                                list.Add(new NaiveBayes.CategorialProbability {
+                                    Category = item.Key,
+                                    LogProbability = Math.Log(item.Value / total)
+                                });
+                            }
+                            columnList.Add(new NaiveBayes.CategorialColumn {
+                                ColumnIndex = categorical.ColumnIndex,
+                                Probability = list
+                            });
+                        }
+                    }else if(continuous != null) {
+                        var variance = continuous.Variance;
+                        if (variance.HasValue) {
+                            var mean = continuous.Mean;
+                            columnList.Add(new NaiveBayes.ContinuousGaussianColumn {
+                                ColumnIndex = continuous.ColumnIndex,
+                                Mean = mean,
+                                Variance = variance.Value
+                            });
+                        }
+                    }
+                }
+                classList.Add(new NaiveBayes.ClassSummary {
+                    Label = classLabel,
+                    ColumnSummary = columnList
+                });
+            }
+
+            return new NaiveBayes {
+                Class = classList
+            };
+        }
+    }
+}
diff --git a/BrightWire.Net4.csproj b/BrightWire.Net4.csproj
index 7ad84936..9f9892fe 100644
--- a/BrightWire.Net4.csproj
+++ b/BrightWire.Net4.csproj
@@ -7,7 +7,7 @@ {FD8AAEF6-2EDB-446C-BB19-5EBEE5CDE982} Library Properties - BrightWire.Net4 + BrightWire BrightWire.Net4 v4.6.1 512
@@ -87,6 +87,8 @@ + +
@@ -102,7 +104,9 @@ - + + +
@@ -136,11 +140,11 @@ - - + + - +
@@ -153,16 +157,33 @@ + + + + + + + + + + + + + + + + +
@@
-180,6 +201,7 @@ +