Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainability doc #2901

Merged
merged 9 commits into from Apr 20, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 56 additions & 0 deletions docs/code/MlNetCookBook.md
Expand Up @@ -578,6 +578,60 @@ var biases = modelParameters.GetBiases();

```

## How do I get a model's weights to look at the global feature importance?
The below snippet shows how to get a model's weights to help determine the feature importance of the model for a linear model.

```csharp
var linearModel = model.LastTransformer.Model;

var weights = linearModel.Weights;
```

The below snipper shows how to get the weights for a fast tree model.

```csharp
var treeModel = model.LastTransformer.Model;

var weights = new VBuffer<float>();
treeModel.GetFeatureWeights(ref weights);
```

## How do I look at the global feature importance?
The below snippet shows how to get a glimpse of the the feature importance. Permutation Feature Importance works by computing the change in the evaluation metrics when each feature is replaced by a random value. In this case, we are investigating the change in the root mean squared error. For more information on permutation feature importance, review the [documentation](https://docs.microsoft.com/en-us/dotnet/machine-learning/how-to-guides/determine-global-feature-importance-in-model).

```csharp
var transformedData = model.Transform(data);

var featureImportance = context.Regression.PermutationFeatureImportance(model.LastTransformer, transformedData);

for (int i = 0; i < featureImportance.Count(); i++)
{
Console.WriteLine($"Feature {i}: Difference in RMS - {featureImportance[i].RootMeanSquaredError.Mean}");
}
```

## How do I look at the local feature importance per example?
The below snippet shows how to get feature importance for each example of data.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feature importance for each example of data [](start = 35, length = 43)

Can you link to the appropriate place in docs for more information for all of these? Maybe we don't actually need to go into major details here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The best doc I could find is this one. Is this ok to link to in each of these sections or would there be a doc for each of these?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I was thinking we could link to the code samples in the repo. But this is a moving target, so let's revisit later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good! Apologies for the misunderstanding. Was there anything else I missed for the PR? Just making sure no one is waiting for me to make more updates 😄


```csharp
var model = pipeline.Fit(data);
var transformedData = model.Transform(data);

var linearModel = model.LastTransformer;

var featureContributionCalculation = context.Transforms.CalculateFeatureContribution(linearModel, normalize: false);

var featureContributionData = featureContributionCalculation.Fit(transformedData).Transform(transformedData);

var shuffledSubset = context.Data.TakeRows(context.Data.ShuffleRows(featureContributionData), 10);
var scoringEnumerator = context.Data.CreateEnumerable<HousingData>(shuffledSubset, true);

foreach (var row in scoringEnumerator)
{
Console.WriteLine(row);
}
```

## What is normalization and why do I need to care?

In ML.NET we expose a number of [parametric and non-parametric algorithms](https://machinelearningmastery.com/parametric-and-nonparametric-machine-learning-algorithms/).
Expand Down Expand Up @@ -791,6 +845,7 @@ var transformedData = pipeline.Fit(data).Transform(data);
var embeddings = transformedData.GetColumn<float[]>(mlContext, "Embeddings").Take(10).ToArray();
var unigrams = transformedData.GetColumn<float[]>(mlContext, "BagOfWords").Take(10).ToArray();
```

## How do I train using cross-validation?

[Cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) is a useful technique for ML applications. It helps estimate the variance of the model quality from one run to another and also eliminates the need to extract a separate test set for evaluation.
Expand Down Expand Up @@ -841,6 +896,7 @@ var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
Console.WriteLine(microAccuracies.Average());

```

## Can I mix and match static and dynamic pipelines?

Yes, we can have both of them in our codebase. The static pipelines are just a statically-typed way to build dynamic pipelines.
Expand Down
Expand Up @@ -250,6 +250,164 @@ private void NormalizationWorkout(string dataPath)
public void Normalization()
=> NormalizationWorkout(GetDataPath("iris.data"));

[Fact]
public void GlobalFeatureImportance()
{
var dataPath = GetDataPath("housing.txt");

var context = new MLContext();

IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
{
new TextLoader.Column("Label", DataKind.Single, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
new TextLoader.Column("PercentResidental", DataKind.Single, 2),
new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
new TextLoader.Column("NitricOxides", DataKind.Single, 5),
new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
new TextLoader.Column("TaxRate", DataKind.Single, 10),
new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
},
hasHeader: true);

var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
"RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(context.Regression.Trainers.FastTree());

var model = pipeline.Fit(data);

var transformedData = model.Transform(data);

var featureImportance = context.Regression.PermutationFeatureImportance(model.LastTransformer, transformedData);

for (int i = 0; i < featureImportance.Count(); i++)
{
Console.WriteLine($"Feature {i}: Difference in RMS - {featureImportance[i].RootMeanSquaredError.Mean}");
}
}

[Fact]
public void GetLinearModelWeights()
{
var dataPath = GetDataPath("housing.txt");

var context = new MLContext();

IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
{
new TextLoader.Column("Label", DataKind.Single, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
new TextLoader.Column("PercentResidental", DataKind.Single, 2),
new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
new TextLoader.Column("NitricOxides", DataKind.Single, 5),
new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
new TextLoader.Column("TaxRate", DataKind.Single, 10),
new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
},
hasHeader: true);

var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
"RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(context.Regression.Trainers.Sdca());

var model = pipeline.Fit(data);

var linearModel = model.LastTransformer.Model;

var weights = linearModel.Weights;
}

[Fact]
public void GetFastTreeModelWeights()
{
var dataPath = GetDataPath("housing.txt");

var context = new MLContext();

IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
{
new TextLoader.Column("Label", DataKind.Single, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
new TextLoader.Column("PercentResidental", DataKind.Single, 2),
new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
new TextLoader.Column("NitricOxides", DataKind.Single, 5),
new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
new TextLoader.Column("TaxRate", DataKind.Single, 10),
new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
},
hasHeader: true);

var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
"RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(context.Regression.Trainers.FastTree());

var model = pipeline.Fit(data);

var linearModel = model.LastTransformer.Model;

var weights = new VBuffer<float>();
linearModel.GetFeatureWeights(ref weights);
}

[Fact]
public void FeatureImportanceForEachRow()
{
var dataPath = GetDataPath("housing.txt");

var context = new MLContext();

IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
{
new TextLoader.Column("MedianHomeValue", DataKind.Single, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
new TextLoader.Column("PercentResidental", DataKind.Single, 2),
new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
new TextLoader.Column("NitricOxides", DataKind.Single, 5),
new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
new TextLoader.Column("TaxRate", DataKind.Single, 10),
new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
},
hasHeader: true);

var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
"RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(context.Regression.Trainers.FastTree(labelColumnName: "MedianHomeValue"));

var model = pipeline.Fit(data);

var transfomedData = model.Transform(data);

var linearModel = model.LastTransformer;

var featureContributionCalculation = context.Transforms.CalculateFeatureContribution(linearModel, normalize: false);

var featureContributionData = featureContributionCalculation.Fit(transfomedData).Transform(transfomedData);

var shuffledSubset = context.Data.TakeRows(context.Data.ShuffleRows(featureContributionData), 10);
var scoringEnumerator = context.Data.CreateEnumerable<HousingData>(shuffledSubset, true);

foreach (var row in scoringEnumerator)
{
Console.WriteLine(row);
}
}

private IEnumerable<CustomerChurnInfo> GetChurnInfo()
{
var r = new Random(454);
Expand Down Expand Up @@ -626,5 +784,20 @@ private class AdultData
public float Target { get; set; }
}

private class HousingData
{
public float MedianHomeValue { get; set; }
public float CrimesPerCapita { get; set; }
public float PercentResidental { get; set; }
public float PercentNonRetail { get; set; }
public float CharlesRiver { get; set; }
public float NitricOxides { get; set; }
public float RoomsPerDwelling { get; set; }
public float PercentPre40s { get; set; }
public float EmploymentDistance { get; set; }
public float HighwayDistance { get; set; }
public float TaxRate { get; set; }
public float TeacherRatio { get; set; }
}
}
}