Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainability doc #2901

Merged
merged 9 commits into from Apr 20, 2019
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 44 additions & 0 deletions docs/code/MlNetCookBook.md
Expand Up @@ -578,6 +578,48 @@ var biases = modelParameters.GetBiases();

```

## How do I look at the global feature importance?
The below snippet shows how to get a glimpse of the the feature importance, or how much each column of data impacts the performance of the model.
Copy link
Contributor

@rogancarr rogancarr Mar 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

column of data [](start = 93, length = 14)

"feature" rather than "column of data". The end features in the model might not be exactly the input columns. #Resolved


```csharp
var transformedData = model.Transform(data);

var featureImportance = context.Regression.PermutationFeatureImportance(model.LastTransformer, transformedData);

foreach (var metricsStatistics in featureImportance)
{
Console.WriteLine($"Root Mean Squared - {metricsStatistics.Rms.Mean}");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Console.WriteLine($"Root Mean Squared - {metricsStatistics.Rms.Mean}"); [](start = 4, length = 71)

Explain a bit above about what this is calculating. It's not giving the RMS, but the difference in RMS for each feature if the feature were to be replaced with a random value.

Also, I would print "Feature I: Difference in RMS" rather than just the RMS.

}
```

## How do I get a model's weights to look at the global feature importance?
Copy link
Contributor

@rogancarr rogancarr Mar 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this above PFI, as it's the most naïve way we have to ask this question. #Resolved

The below snippet shows how to get a model's weights to help determine the feature importance of the model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The below [](start = 0, length = 9)

Note that for a linear model, the weights are only an approximation. It helps to standardize the variables before the fit, so that they are all on the same scale, and even then, the linear regression solution does not account for correlations between the variables, and therefore this isn't a great measure of explainability.


```csharp
var linearModel = model.LastTransformer.Model;

var weights = new VBuffer<float>();
linearModel.GetFeatureWeights(ref weights);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add for trees as well -- see the functional tests.

```

## How do I look at the feature importance per row?
Copy link
Contributor

@rogancarr rogancarr Mar 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Local feature importance"

"row" => "example" (the language we've shifted to using) #Resolved

The below snippet shows how to get feature importance for each row.

```csharp
var model = pipeline.Fit(data);
var transfomedData = model.Transform(data);

var linearModel = model.LastTransformer;

var featureContributionCalculation = context.Transforms.CalculateFeatureContribution(linearModel, normalize: false);

var featureContributionData = featureContributionCalculation.Fit(transfomedData).Transform(transfomedData);

var shuffledSubset = context.Data.TakeRows(context.Data.ShuffleRows(featureContributionData), 10);

var preview = shuffledSubset.Preview();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

manually print the rows of data after casting to an enumerable. You can copy / paste from the FCC Sample.

```

## What is normalization and why do I need to care?

In ML.NET we expose a number of [parametric and non-parametric algorithms](https://machinelearningmastery.com/parametric-and-nonparametric-machine-learning-algorithms/).
Expand Down Expand Up @@ -791,6 +833,7 @@ var transformedData = pipeline.Fit(data).Transform(data);
var embeddings = transformedData.GetColumn<float[]>(mlContext, "Embeddings").Take(10).ToArray();
var unigrams = transformedData.GetColumn<float[]>(mlContext, "BagOfWords").Take(10).ToArray();
```

## How do I train using cross-validation?

[Cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) is a useful technique for ML applications. It helps estimate the variance of the model quality from one run to another and also eliminates the need to extract a separate test set for evaluation.
Expand Down Expand Up @@ -841,6 +884,7 @@ var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
Console.WriteLine(microAccuracies.Average());

```

## Can I mix and match static and dynamic pipelines?

Yes, we can have both of them in our codebase. The static pipelines are just a statically-typed way to build dynamic pipelines.
Expand Down
Expand Up @@ -250,6 +250,125 @@ private void NormalizationWorkout(string dataPath)
public void Normalization()
=> NormalizationWorkout(GetDataPath("iris.data"));

[Fact]
public void GlobalFeatureImportance()
{
var dataPath = GetDataPath("housing.txt");

var context = new MLContext();

IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
{
new TextLoader.Column("Label", DataKind.Single, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
new TextLoader.Column("PercentResidental", DataKind.Single, 2),
new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
new TextLoader.Column("NitricOxides", DataKind.Single, 5),
new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
new TextLoader.Column("TaxRate", DataKind.Single, 10),
new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
},
hasHeader: true);

var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
"RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(context.Regression.Trainers.FastTree());

var model = pipeline.Fit(data);

var transformedData = model.Transform(data);

var featureImportance = context.Regression.PermutationFeatureImportance(model.LastTransformer, transformedData);

foreach (var metricsStatistics in featureImportance)
{
Console.WriteLine($"Root Mean Squared - {metricsStatistics.RootMeanSquaredError.Mean}");
}
}

[Fact]
public void GetModelWeights()
{
var dataPath = GetDataPath("housing.txt");

var context = new MLContext();

IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
{
new TextLoader.Column("Label", DataKind.Single, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
new TextLoader.Column("PercentResidental", DataKind.Single, 2),
new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
new TextLoader.Column("NitricOxides", DataKind.Single, 5),
new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
new TextLoader.Column("TaxRate", DataKind.Single, 10),
new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
},
hasHeader: true);

var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
"RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(context.Regression.Trainers.FastTree());

var model = pipeline.Fit(data);

var linearModel = model.LastTransformer.Model;

var weights = new VBuffer<float>();
linearModel.GetFeatureWeights(ref weights);
}

[Fact]
public void FeatureImportanceForEachRow()
{
var dataPath = GetDataPath("housing.txt");

var context = new MLContext();

IDataView data = context.Data.LoadFromTextFile(dataPath, new[]
{
new TextLoader.Column("Label", DataKind.Single, 0),
new TextLoader.Column("CrimesPerCapita", DataKind.Single, 1),
new TextLoader.Column("PercentResidental", DataKind.Single, 2),
new TextLoader.Column("PercentNonRetail", DataKind.Single, 3),
new TextLoader.Column("CharlesRiver", DataKind.Single, 4),
new TextLoader.Column("NitricOxides", DataKind.Single, 5),
new TextLoader.Column("RoomsPerDwelling", DataKind.Single, 6),
new TextLoader.Column("PercentPre40s", DataKind.Single, 7),
new TextLoader.Column("EmploymentDistance", DataKind.Single, 8),
new TextLoader.Column("HighwayDistance", DataKind.Single, 9),
new TextLoader.Column("TaxRate", DataKind.Single, 10),
new TextLoader.Column("TeacherRatio", DataKind.Single, 11)
},
hasHeader: true);

var pipeline = context.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides",
"RoomsPerDwelling", "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio")
.Append(context.Regression.Trainers.FastTree());

var model = pipeline.Fit(data);

var transfomedData = model.Transform(data);

var linearModel = model.LastTransformer;

var featureContributionCalculation = context.Transforms.CalculateFeatureContribution(linearModel, normalize: false);

var featureContributionData = featureContributionCalculation.Fit(transfomedData).Transform(transfomedData);

var shuffledSubset = context.Data.TakeRows(context.Data.ShuffleRows(featureContributionData), 10);

var preview = shuffledSubset.Preview();
}

private IEnumerable<CustomerChurnInfo> GetChurnInfo()
{
var r = new Random(454);
Expand Down Expand Up @@ -626,5 +745,20 @@ private class AdultData
public float Target { get; set; }
}

private class HousingData
{
public float MedianHomeValue { get; set; }
public float CrimesPerCapita { get; set; }
public float PercentResidental { get; set; }
public float PercentNonRetail { get; set; }
public float CharlesRiver { get; set; }
public float NitricOxides { get; set; }
public float RoomsPerDwelling { get; set; }
public float PercentPre40s { get; set; }
public float EmploymentDistance { get; set; }
public float HighwayDistance { get; set; }
public float TaxRate { get; set; }
public float TeacherRatio { get; set; }
}
}
}