Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
decision trees (first pass) data tables namespace fixes
- Loading branch information
Jack Dermody
committed
Sep 16, 2016
1 parent
5786ecf
commit ae96fa2
Showing
51 changed files
with
2,255 additions
and
535 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
using BrightWire.Models; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using System.Threading.Tasks; | ||
|
||
namespace BrightWire.Bayesian | ||
{ | ||
/// <summary>
/// Classifies rows with a trained naive Bayes model. Each class is scored by summing
/// the log probabilities contributed by its column summaries; the highest score wins.
/// </summary>
public class NaiveBayesClassifier : IRowProcessor
{
    /// <summary>
    /// Something that can supply the log probability contribution of a single column for a row.
    /// </summary>
    interface IProbabilityProvider
    {
        /// <summary>
        /// Returns the log probability contribution for the given row.
        /// </summary>
        double GetProbability(IRow row);
    }

    /// <summary>
    /// Looks up the stored log probability of a categorical column's value.
    /// </summary>
    class CategoricalColumn : IProbabilityProvider
    {
        readonly int _columnIndex;
        readonly Dictionary<string, double> _probability;
        readonly double _nullValue;

        /// <summary>
        /// Builds the category lookup from a trained column summary.
        /// </summary>
        /// <param name="summary">The trained categorical column summary</param>
        /// <param name="nullValue">Value returned when the row's category is null or was not seen in training</param>
        public CategoricalColumn(NaiveBayes.CategorialColumn summary, double nullValue = 0)
        {
            _nullValue = nullValue;
            _columnIndex = summary.ColumnIndex;
            _probability = summary.Probability.ToDictionary(d => d.Category, d => d.LogProbability);
        }

        /// <summary>
        /// Returns the stored log probability for the row's category, or the null value
        /// if the category is missing or unknown.
        /// </summary>
        public double GetProbability(IRow row)
        {
            double ret;
            var val = row.GetField<string>(_columnIndex);

            // Dictionary.TryGetValue throws on a null key, so treat null as an unseen category
            if (val != null && _probability.TryGetValue(val, out ret))
                return ret;
            return _nullValue;
        }
    }

    /// <summary>
    /// Evaluates the log of the gaussian density described by a continuous column summary.
    /// </summary>
    class ContinuousColumn : IProbabilityProvider
    {
        // Floor applied to the variance so that a constant training column (variance of zero)
        // produces a finite score instead of NaN.
        const double MinVariance = 1e-9;

        readonly NaiveBayes.ContinuousGaussianColumn _column;

        public ContinuousColumn(NaiveBayes.ContinuousGaussianColumn column)
        {
            _column = column;
        }

        /// <summary>
        /// Returns log N(x | mean, variance) for the row's value in this column.
        /// </summary>
        public double GetProbability(IRow row)
        {
            double x = row.GetField<double>(_column.ColumnIndex);
            double variance = Math.Max(_column.Variance, MinVariance);
            double diff = x - _column.Mean;

            // log(1 / sqrt(2*pi*var) * exp(-(x - mean)^2 / (2*var))) evaluated directly in log space:
            // the exponent is divided by 2*var (not 2*var^2) and no intermediate Exp can underflow to zero
            return -0.5 * Math.Log(2 * Math.PI * variance) - (diff * diff) / (2 * variance);
        }
    }

    // one entry per class: the class label and the probability providers for its columns
    readonly List<Tuple<string, List<IProbabilityProvider>>> _classProbability = new List<Tuple<string, List<IProbabilityProvider>>>();

    // rows seen by Process paired with the label they were assigned
    readonly List<Tuple<IRow, string>> _resultList = new List<Tuple<IRow, string>>();

    /// <summary>
    /// Creates a classifier from a trained model. Column summaries whose type is neither
    /// categorical nor continuous gaussian are ignored.
    /// </summary>
    /// <param name="model">The trained naive Bayes model</param>
    /// <exception cref="ArgumentNullException">Thrown if the model is null</exception>
    public NaiveBayesClassifier(NaiveBayes model)
    {
        if (model == null)
            throw new ArgumentNullException("model");

        foreach (var cls in model.Class) {
            List<IProbabilityProvider> list = new List<IProbabilityProvider>();
            foreach (var col in cls.ColumnSummary) {
                // direct casts: a summary whose declared type does not match its runtime type
                // fails here with an InvalidCastException rather than a later NullReferenceException
                if (col.Type == NaiveBayes.ColumnType.Categorical)
                    list.Add(new CategoricalColumn((NaiveBayes.CategorialColumn)col));
                else if (col.Type == NaiveBayes.ColumnType.ContinuousGaussian)
                    list.Add(new ContinuousColumn((NaiveBayes.ContinuousGaussianColumn)col));
            }
            _classProbability.Add(Tuple.Create(cls.Label, list));
        }
    }

    /// <summary>
    /// Scores the row against every class by summing the per-column log probabilities.
    /// Only the column terms are summed; no separate class prior is added.
    /// </summary>
    /// <param name="row">The row to classify</param>
    /// <returns>Class label and score pairs, ordered from highest to lowest score</returns>
    public IEnumerable<KeyValuePair<string, double>> Classify(IRow row)
    {
        var ret = new Dictionary<string, double>();
        foreach (var cls in _classProbability) {
            double score = 0;
            foreach (var item in cls.Item2)
                score += item.GetProbability(row);
            ret.Add(cls.Item1, score);
        }
        return ret.OrderByDescending(kv => kv.Value);
    }

    /// <summary>
    /// Classifies the row and records it, together with its best scoring label, in <see cref="Results"/>.
    /// Throws InvalidOperationException (from First) if the model contained no classes.
    /// </summary>
    /// <param name="row">The row to classify</param>
    /// <returns>Always true</returns>
    public bool Process(IRow row)
    {
        var classification = Classify(row).First().Key;
        _resultList.Add(Tuple.Create(row, classification));
        return true;
    }

    /// <summary>
    /// Removes all recorded results.
    /// </summary>
    public void Clear()
    {
        _resultList.Clear();
    }

    /// <summary>
    /// The rows passed to <see cref="Process"/> since the last <see cref="Clear"/>, with their assigned labels.
    /// </summary>
    public IReadOnlyList<Tuple<IRow, string>> Results { get { return _resultList; } }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
using BrightWire.Helper; | ||
using BrightWire.Models; | ||
using BrightWire.TabularData; | ||
using BrightWire.TabularData.Analysis; | ||
using BrightWire.TabularData.Helper; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using System.Threading.Tasks; | ||
|
||
namespace BrightWire.Bayesian.Training | ||
{ | ||
/// <summary>
/// Builds naive Bayes models from data tables.
/// </summary>
public static class NaiveBayesTrainer
{
    /// <summary>
    /// Trains a naive Bayes model: for each distinct value in the class column, collects
    /// per-column statistics and converts them to log probabilities (categorical columns)
    /// or mean/variance pairs (continuous columns).
    /// </summary>
    /// <param name="table">The training data</param>
    /// <param name="classColumnIndex">Index of the column that holds the class label</param>
    /// <returns>The trained model, with one class summary per class label</returns>
    /// <exception cref="ArgumentNullException">Thrown if the table is null</exception>
    /// <exception cref="InvalidOperationException">Thrown if the class column's distinct values were not collected</exception>
    public static NaiveBayes Train(IDataTable table, int classColumnIndex)
    {
        if (table == null)
            throw new ArgumentNullException("table");

        // analyse the table to get the set of class values
        var analysis = new DataTableAnalysis(table, classColumnIndex);
        table.Process(analysis);

        var classInfo = analysis.ColumnInfo.Single();
        if (classInfo.DistinctValues == null)
            throw new InvalidOperationException("Too many class values");

        // analyse the data per class
        // (materialised once so that every enumeration sees the same FrequencyAnalysis instances
        // rather than the deferred query creating fresh, empty ones each time)
        IEnumerable<Tuple<string, IRowProcessor>> classBasedFrequency = classInfo.DistinctValues
            .Select(cv => Tuple.Create<string, IRowProcessor>(cv.ToString(), new FrequencyAnalysis(table, classColumnIndex)))
            .ToList()
        ;
        var frequencyAnalysis = new ClassBasedRowProcessor(classBasedFrequency, classColumnIndex);
        table.Process(frequencyAnalysis);

        // create the per-class summaries from the frequency table
        var classList = new List<NaiveBayes.ClassSummary>();
        foreach (var classSummary in frequencyAnalysis.All) {
            var classLabel = classSummary.Item1;

            // direct cast: an unexpected processor type fails here with an InvalidCastException
            // instead of a NullReferenceException on the next line
            var frequency = (FrequencyAnalysis)classSummary.Item2;
            var columnList = new List<NaiveBayes.IColumn>();
            foreach (var column in frequency.ColumnInfo) {
                var continuous = column as NumberCollector;
                var categorical = column as FrequencyCollector;
                if (categorical != null) {
                    // columns with no observations for this class are left out of the summary
                    var total = (double)categorical.Total;
                    if (total > 0) {
                        var list = new List<NaiveBayes.CategorialProbability>();
                        foreach (var item in categorical.Frequency) {
                            list.Add(new NaiveBayes.CategorialProbability {
                                Category = item.Key,
                                LogProbability = Math.Log(item.Value / total)
                            });
                        }
                        columnList.Add(new NaiveBayes.CategorialColumn {
                            ColumnIndex = categorical.ColumnIndex,
                            Probability = list
                        });
                    }
                }
                else if (continuous != null) {
                    // columns without a variance are left out of the summary
                    var variance = continuous.Variance;
                    if (variance.HasValue) {
                        var mean = continuous.Mean;
                        columnList.Add(new NaiveBayes.ContinuousGaussianColumn {
                            ColumnIndex = continuous.ColumnIndex,
                            Mean = mean,
                            Variance = variance.Value
                        });
                    }
                }
            }
            classList.Add(new NaiveBayes.ClassSummary {
                Label = classLabel,
                ColumnSummary = columnList
            });
        }

        return new NaiveBayes {
            Class = classList
        };
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.