Skip to content

Commit

Permalink
naive bayes
Browse files Browse the repository at this point in the history
decision trees (first pass)
data tables
namespace fixes
  • Loading branch information
Jack Dermody committed Sep 16, 2016
1 parent 5786ecf commit ae96fa2
Show file tree
Hide file tree
Showing 51 changed files with 2,255 additions and 535 deletions.
97 changes: 97 additions & 0 deletions Bayesian/NaiveBayesClassifier.cs
@@ -0,0 +1,97 @@
using BrightWire.Models;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace BrightWire.Bayesian
{
public class NaiveBayesClassifier : IRowProcessor
{
interface IProbabilityProvider
{
double GetProbability(IRow row);
}
class CategoricalColumn : IProbabilityProvider
{
readonly int _columnIndex;
readonly Dictionary<string, double> _probability;
readonly double _nullValue;

public CategoricalColumn(NaiveBayes.CategorialColumn summary, double nullValue = 0)
{
_nullValue = nullValue;
_columnIndex = summary.ColumnIndex;
_probability = summary.Probability.ToDictionary(d => d.Category, d => d.LogProbability);
}

public double GetProbability(IRow row)
{
double ret;
var val = row.GetField<string>(_columnIndex);
if (_probability.TryGetValue(val, out ret))
return ret;
return _nullValue;
}
}
class ContinuousColumn : IProbabilityProvider
{
readonly NaiveBayes.ContinuousGaussianColumn _column;

public ContinuousColumn(NaiveBayes.ContinuousGaussianColumn column)
{
_column = column;
}

public double GetProbability(IRow row)
{
double x = row.GetField<double>(_column.ColumnIndex);
var exponent = Math.Exp(-1 * Math.Pow(x - _column.Mean, 2) / (2 * Math.Pow(_column.Variance, 2)));
return Math.Log(1.0 / Math.Sqrt(2 * Math.PI * _column.Variance) * exponent);
}
}
readonly List<Tuple<string, List<IProbabilityProvider>>> _classProbability = new List<Tuple<string, List<IProbabilityProvider>>>();
readonly List<Tuple<IRow, string>> _resultList = new List<Tuple<IRow, string>>();

public NaiveBayesClassifier(NaiveBayes model)
{
foreach (var cls in model.Class) {
List<IProbabilityProvider> list = new List<IProbabilityProvider>();
foreach (var col in cls.ColumnSummary) {
if (col.Type == NaiveBayes.ColumnType.Categorical)
list.Add(new CategoricalColumn(col as NaiveBayes.CategorialColumn));
else if (col.Type == NaiveBayes.ColumnType.ContinuousGaussian)
list.Add(new ContinuousColumn(col as NaiveBayes.ContinuousGaussianColumn));
}
_classProbability.Add(Tuple.Create(cls.Label, list));
}
}

public IEnumerable<KeyValuePair<string, double>> Classify(IRow row)
{
var ret = new Dictionary<string, double>();
foreach(var cls in _classProbability) {
double score = 0;
foreach (var item in cls.Item2)
score += item.GetProbability(row);
ret.Add(cls.Item1, score);
}
return ret.OrderByDescending(kv => kv.Value);
}

public bool Process(IRow row)
{
var classification = Classify(row).First().Key;
_resultList.Add(Tuple.Create(row, classification));
return true;
}

public void Clear()
{
_resultList.Clear();
}

public IReadOnlyList<Tuple<IRow, string>> Results { get { return _resultList; } }
}
}
78 changes: 78 additions & 0 deletions Bayesian/Training/NaiveBayesTrainer.cs
@@ -0,0 +1,78 @@
using BrightWire.Helper;
using BrightWire.Models;
using BrightWire.TabularData;
using BrightWire.TabularData.Analysis;
using BrightWire.TabularData.Helper;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace BrightWire.Bayesian.Training
{
public static class NaiveBayesTrainer
{
public static NaiveBayes Train(IDataTable table, int classColumnIndex)
{
// analyse the table to get the set of class values
var analysis = new DataTableAnalysis(table, classColumnIndex);
table.Process(analysis);

var classInfo = analysis.ColumnInfo.Single();
if (classInfo.DistinctValues == null)
throw new Exception("Too many class values");

// analyse the data per class
var classBasedFrequency = classInfo.DistinctValues.Select(cv => Tuple.Create<string, IRowProcessor>(cv.ToString(), new FrequencyAnalysis(table, classColumnIndex)));
var frequencyAnalysis = new ClassBasedRowProcessor(classBasedFrequency, classColumnIndex);
table.Process(frequencyAnalysis);

// create the per-class summaries from the frequency table
var classList = new List<NaiveBayes.ClassSummary>();
foreach(var classSummary in frequencyAnalysis.All) {
var classLabel = classSummary.Item1;
var frequency = classSummary.Item2 as FrequencyAnalysis;
var columnList = new List<NaiveBayes.IColumn>();
foreach(var column in frequency.ColumnInfo) {
var continuous = column as NumberCollector;
var categorical = column as FrequencyCollector;
if(categorical != null) {
var total = (double)categorical.Total;
if (total > 0) {
var list = new List<NaiveBayes.CategorialProbability>();
foreach (var item in categorical.Frequency) {
list.Add(new NaiveBayes.CategorialProbability {
Category = item.Key,
LogProbability = Math.Log(item.Value / total)
});
}
columnList.Add(new NaiveBayes.CategorialColumn {
ColumnIndex = categorical.ColumnIndex,
Probability = list
});
}
}else if(continuous != null) {
var variance = continuous.Variance;
if (variance.HasValue) {
var mean = continuous.Mean;
columnList.Add(new NaiveBayes.ContinuousGaussianColumn {
ColumnIndex = continuous.ColumnIndex,
Mean = mean,
Variance = variance.Value
});
}
}
}
classList.Add(new NaiveBayes.ClassSummary {
Label = classLabel,
ColumnSummary = columnList
});
}

return new NaiveBayes {
Class = classList
};
}
}
}
32 changes: 27 additions & 5 deletions BrightWire.Net4.csproj
Expand Up @@ -7,7 +7,7 @@
<ProjectGuid>{FD8AAEF6-2EDB-446C-BB19-5EBEE5CDE982}</ProjectGuid>
<OutputType>Library</OutputType>
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>BrightWire.Net4</RootNamespace>
<RootNamespace>BrightWire</RootNamespace>
<AssemblyName>BrightWire.Net4</AssemblyName>
<TargetFrameworkVersion>v4.6.1</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
Expand Down Expand Up @@ -87,6 +87,8 @@
<ItemGroup>
<Compile Include="Bayesian\HiddenMarkovModel2.cs" />
<Compile Include="Bayesian\HiddenMarkovModel3.cs" />
<Compile Include="Bayesian\NaiveBayesClassifier.cs" />
<Compile Include="Bayesian\Training\NaiveBayesTrainer.cs" />
<Compile Include="Connectionist\Activation\LeakyRELU.cs" />
<Compile Include="Connectionist\Activation\RELU.cs" />
<Compile Include="Connectionist\Activation\Sigmoid.cs" />
Expand All @@ -102,7 +104,9 @@
<Compile Include="Connectionist\Execution\RecurrentExecution.cs" />
<Compile Include="ErrorMetrics\Quadratic.cs" />
<Compile Include="Helper\MiniBatch.cs" />
<Compile Include="Helper\SequenceToSequenceTrainingDataProvider.cs" />
<Compile Include="Helper\TrainingData\DataTableTrainingDataProvider.cs" />
<Compile Include="Helper\TrainingData\SequenceToSequenceTrainingDataProvider.cs" />
<Compile Include="Models\DecisionTree.cs" />
<Compile Include="Models\ExecutionResults\FeedForwardOutput.cs" />
<Compile Include="Models\ExecutionResults\RecurrentOutput.cs" />
<Compile Include="Connectionist\Factory.cs" />
Expand Down Expand Up @@ -136,11 +140,11 @@
<Compile Include="ExtensionMethods.cs" />
<Compile Include="Helper\BigEndianBinaryReader.cs" />
<Compile Include="Helper\BoundMath.cs" />
<Compile Include="Helper\DenseSequentialTrainingDataProvider.cs" />
<Compile Include="Helper\DenseTrainingDataProvider.cs" />
<Compile Include="Helper\TrainingData\DenseSequentialTrainingDataProvider.cs" />
<Compile Include="Helper\TrainingData\DenseTrainingDataProvider.cs" />
<Compile Include="Helper\DisposableMatrixExecutionLine.cs" />
<Compile Include="Helper\SequenceWindowBuilder.cs" />
<Compile Include="Helper\SparseTrainingDataProvider.cs" />
<Compile Include="Helper\TrainingData\SparseTrainingDataProvider.cs" />
<Compile Include="Interface.cs" />
<Compile Include="LinearAlgebra\CpuMatrix.cs" />
<Compile Include="LinearAlgebra\CpuVector.cs" />
Expand All @@ -153,16 +157,33 @@
<Compile Include="Models\FeedForwardNetwork.cs" />
<Compile Include="Models\FloatArray.cs" />
<Compile Include="Models\HMM.cs" />
<Compile Include="Models\NaiveBayes.cs" />
<Compile Include="Models\NetworkLayer.cs" />
<Compile Include="Models\RecurrentLayer.cs" />
<Compile Include="Models\RecurrentNetwork.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="TabularData\Analysis\DataTableAnalysis.cs" />
<Compile Include="TabularData\Analysis\FrequencyAnalysis.cs" />
<Compile Include="TabularData\Analysis\FrequencyCollector.cs" />
<Compile Include="TabularData\Analysis\NumericCollector.cs" />
<Compile Include="TabularData\Analysis\StringCollector.cs" />
<Compile Include="TabularData\DataTable.cs" />
<Compile Include="TabularData\Helper\ClassBasedRowProcessor.cs" />
<Compile Include="TabularData\Helper\ColumnTypeClassifier.cs" />
<Compile Include="TabularData\Helper\CSVDataTableBuilder.cs" />
<Compile Include="TabularData\Helper\DataTableProjector.cs" />
<Compile Include="TabularData\Helper\DataTableRow.cs" />
<Compile Include="TabularData\Helper\DataTableWriter.cs" />
<Compile Include="TabularData\IndexedDataTable.cs" />
<Compile Include="TabularData\MemoryBasedDataTable.cs" />
<Compile Include="TabularData\MutableDataTable.cs" />
<Compile Include="TabularData\NumericTable.cs" />
<Compile Include="TrainingData\Artificial\BinaryIntegers.cs" />
<Compile Include="TrainingData\Artificial\ReberGrammar.cs" />
<Compile Include="TrainingData\Artificial\Xor.cs" />
<Compile Include="TrainingData\MNIST.cs" />
<Compile Include="TreeBased\DecisionTreeClassifier.cs" />
<Compile Include="TreeBased\Training\DecisionTreeTrainer.cs" />
<Compile Include="Unsupervised\Clustering\KMeans.cs" />
</ItemGroup>
<ItemGroup>
Expand All @@ -180,6 +201,7 @@
<Content Include="LinearAlgebra\cuda\kernel.c" />
<Content Include="LinearAlgebra\cuda\kernel.stub.c" />
</ItemGroup>
<ItemGroup />
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.
Expand Down
2 changes: 1 addition & 1 deletion Connectionist/Activation/Softmax.cs
Expand Up @@ -4,7 +4,7 @@
using System.Text;
using System.Threading.Tasks;

namespace BrightWire.Net4.Connectionist.Activation
namespace BrightWire.Connectionist.Activation
{
public class Softmax : IActivationFunction
{
Expand Down
2 changes: 1 addition & 1 deletion Connectionist/Execution/BidirectionalExecution.cs
@@ -1,5 +1,5 @@
using BrightWire.Helper;
using BrightWire.Net4.Models.ExecutionResults;
using BrightWire.Models.ExecutionResults;
using System;
using System.Collections.Generic;
using System.Linq;
Expand Down
2 changes: 1 addition & 1 deletion Connectionist/Execution/RecurrentExecution.cs
@@ -1,5 +1,5 @@
using BrightWire.Helper;
using BrightWire.Net4.Models.ExecutionResults;
using BrightWire.Models.ExecutionResults;
using System;
using System.Collections.Generic;
using System.Linq;
Expand Down
15 changes: 9 additions & 6 deletions Connectionist/Factory.cs
Expand Up @@ -254,30 +254,33 @@ public IBidirectionalRecurrentExecution CreateBidirectional(BidirectionalNetwork
public IFeedForwardTrainingManager CreateFeedForwardManager(
INeuralNetworkTrainer trainer,
string dataFile,
ITrainingDataProvider testData
ITrainingDataProvider testData,
int? autoAdjustOnNoChangeCount = null
)
{
return new FeedForwardManager(trainer, dataFile, testData);
return new FeedForwardManager(trainer, dataFile, testData, autoAdjustOnNoChangeCount);
}

public IRecurrentTrainingManager CreateRecurrentManager(
INeuralNetworkRecurrentBatchTrainer trainer,
string dataFile,
ISequentialTrainingDataProvider testData,
int memorySize
int memorySize,
int? autoAdjustOnNoChangeCount = null
)
{
return new RecurrentManager(trainer, dataFile, testData, memorySize);
return new RecurrentManager(trainer, dataFile, testData, memorySize, autoAdjustOnNoChangeCount);
}

public IBidirectionalRecurrentTrainingManager CreateBidirectionalManager(
INeuralNetworkBidirectionalBatchTrainer trainer,
string dataFile,
ISequentialTrainingDataProvider testData,
int memorySize
int memorySize,
int? autoAdjustOnNoChangeCount = null
)
{
return new BidirectionalManager(_lap, trainer, dataFile, testData, memorySize);
return new BidirectionalManager(_lap, trainer, dataFile, testData, memorySize, autoAdjustOnNoChangeCount);
}
}
}
2 changes: 1 addition & 1 deletion Connectionist/Training/Batch/BatchTrainer.cs
Expand Up @@ -6,7 +6,7 @@
using System.Text;
using System.Threading.Tasks;
using BrightWire.Models;
using BrightWire.Net4.Models.ExecutionResults;
using BrightWire.Models.ExecutionResults;

namespace BrightWire.Connectionist.Training.Batch
{
Expand Down
19 changes: 2 additions & 17 deletions Connectionist/Training/Batch/RecurrentBatchTrainer.cs
Expand Up @@ -131,43 +131,28 @@ public void TrainOnMiniBatch(ISequentialMiniBatch miniBatch, float[] memory, IRe
// backpropagate, accumulating errors across the sequence
using (var updateAccumulator = new UpdateAccumulator(trainingContext)) {
IMatrix curr = null;
var bpt = 0;
while (updateStack.Any()) {
var update = updateStack.Pop();
var isT0 = !updateStack.Any();
var actionStack = update.Item1;

// calculate error
IMatrix previousError = null;
if (curr != null && bpt < context.BackpropagationThroughTime) {
using (var sum = curr.RowSums(1f / curr.ColumnCount))
previousError = sum.ToColumnMatrix();
}
var expectedOutput = update.Item2;
if (expectedOutput != null) {
if (expectedOutput != null)
curr = expectedOutput.Subtract(update.Item3);
if (previousError != null)
curr.AddInPlace(previousError);
}
else if (previousError != null)
curr = previousError;
else
continue;

// backpropagate
beforeBackProp?.Invoke(curr);
while (actionStack.Any()) {
var backpropagationAction = actionStack.Pop();
var shouldCalculateOutput = actionStack.Any() || isT0 || bpt < context.BackpropagationThroughTime;
var shouldCalculateOutput = actionStack.Any() || isT0;
curr = backpropagationAction.Execute(curr, trainingContext, true, updateAccumulator);
}
afterBackProp?.Invoke(curr);

// apply any filters
foreach (var filter in _filter)
filter.AfterBackPropagation(update.Item4, update.Item5, curr);

++bpt;
}

// adjust the initial memory against the error signal
Expand Down
5 changes: 1 addition & 4 deletions Connectionist/Training/Helper/RecurrentContext.cs
Expand Up @@ -13,17 +13,14 @@ public class RecurrentContext : IRecurrentTrainingContext
readonly ILinearAlgebraProvider _lap;
readonly ITrainingContext _trainingContext;
readonly List<INeuralNetworkRecurrentTrainerFilter> _filter = new List<INeuralNetworkRecurrentTrainerFilter>();
readonly int _backpropagationThroughTime;

public RecurrentContext(ILinearAlgebraProvider lap, ITrainingContext trainingContext, int backpropagationThroughTime = 0)
public RecurrentContext(ILinearAlgebraProvider lap, ITrainingContext trainingContext)
{
_lap = lap;
_trainingContext = trainingContext;
_backpropagationThroughTime = backpropagationThroughTime;
}

public ITrainingContext TrainingContext { get { return _trainingContext; } }
public int BackpropagationThroughTime { get { return _backpropagationThroughTime; } }

public void AddFilter(INeuralNetworkRecurrentTrainerFilter filter)
{
Expand Down

0 comments on commit ae96fa2

Please sign in to comment.