Confidence Intervals for Permutation Feature Importance #1844

Merged · 10 commits · Dec 20, 2018
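Summary of the change: PermutationFeatureImportance now permutes each feature numPermutations times instead of once, and reports the mean metric change per feature together with a 95% confidence interval. A self-contained sketch of that permute-score-aggregate loop follows; every name in it (PfiLoopSketch, PermutationDeltas, the metric delegate) is a hypothetical stand-in rather than ML.NET API, and the real implementation is the PermutationFeatureImportance.cs diff further down.

using System;
using System.Linq;

// Conceptual sketch only: permute one feature column several times, score the
// already-trained model after each permutation, and summarize the metric change
// as a mean plus a spread. Not the ML.NET implementation.
public static class PfiLoopSketch
{
    public static (double Mean, double StdDev) PermutationDeltas(
        double[][] rows,                  // rows[i][j] = value of feature j for example i
        int featureIndex,                 // column to permute
        Func<double[][], double> metric,  // e.g. R-Squared of the trained model on this data
        int numPermutations,
        Random rand)
    {
        double baseline = metric(rows);
        var deltas = new double[numPermutations];

        for (int p = 0; p < numPermutations; p++)
        {
            // Copy the data and shuffle only the chosen column.
            var permuted = rows.Select(r => (double[])r.Clone()).ToArray();
            var shuffledColumn = permuted.Select(r => r[featureIndex])
                                         .OrderBy(_ => rand.Next())
                                         .ToArray();
            for (int i = 0; i < permuted.Length; i++)
                permuted[i][featureIndex] = shuffledColumn[i];

            deltas[p] = metric(permuted) - baseline;
        }

        double mean = deltas.Average();
        double stdDev = numPermutations > 1
            ? Math.Sqrt(deltas.Sum(d => (d - mean) * (d - mean)) / (numPermutations - 1))
            : 0.0;
        return (mean, stdDev);
    }
}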
@@ -68,7 +68,7 @@ public static void PFI_Regression()
// Compute the permutation metrics using the properly-featurized data.
var transformedData = model.Transform(data);
var permutationMetrics = mlContext.Regression.PermutationFeatureImportance(
linearPredictor, transformedData, label: labelName, features: "Features");
linearPredictor, transformedData, label: labelName, features: "Features", numPermutations: 3);

// Now let's look at which features are most important to the model overall
// First, we have to prepare the data:
@@ -80,23 +80,27 @@ public static void PFI_Regression()

// Get the feature indices sorted by their impact on R-Squared
var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.RSquared })
.OrderByDescending(feature => Math.Abs(feature.RSquared))
.OrderByDescending(feature => Math.Abs(feature.RSquared.Mean))
.Select(feature => feature.index);

// Print out the permutation results, with the model weights, in order of their impact:
// Expected console output:
// Feature Model Weight Change in R - Squared
// RoomsPerDwelling 50.80 -0.3695
// EmploymentDistance -17.79 -0.2238
// TeacherRatio -19.83 -0.1228
// TaxRate -8.60 -0.1042
// NitricOxides -15.95 -0.1025
// HighwayDistance 5.37 -0.09345
// CrimesPerCapita -15.05 -0.05797
// PercentPre40s -4.64 -0.0385
// PercentResidental 3.98 -0.02184
// CharlesRiver 3.38 -0.01487
// PercentNonRetail -1.94 -0.007231
// Feature Model Weight Change in R-Squared 95% Confidence
// RoomsPerDwelling 50.96 -0.4094 0.04344
// EmploymentDistance -17.55 -0.235 0.02501
// TeacherRatio -19.99 -0.1042 0.02287
// NitricOxides -15.75 -0.1017 0.006257
// HighwayDistance 5.44 -0.09583 0.01006
// TaxRate -8.55 -0.08898 0.03211
// CrimesPerCapita -14.97 -0.05299 0.01215
// PercentPre40s -4.64 -0.04206 0.008414
// PercentResidental 4.06 -0.02143 0.008526
// CharlesRiver 3.71 -0.01802 0.004324
// PercentNonRetail -1.91 -0.007466 0.001664
//
// Let's dig into these results a little bit. First, if you look at the weights of the model, they generally correlate
// with the results of PFI, but there are some significant misorderings. For example, "Tax Rate" is weighted lower than
@@ -110,12 +114,14 @@ public static void PFI_Regression()
// variables in this dataset. The reason why the linear model weights don't reflect the same feature importance as PFI
// is that the solution to the linear model redistributes weights between correlated variables in unpredictable ways, so
// that the weights themselves are no longer a good measure of feature importance.
Console.WriteLine("Feature\tModel Weight\tChange in R-Squared");
Console.WriteLine("Feature\tModel Weight\tChange in R-Squared\t95% Confidence");
var rSquared = permutationMetrics.Select(x => x.RSquared).ToArray(); // Fetch r-squared as an array
foreach (int i in sortedIndices)
{
Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i]:G4}");
Console.WriteLine($"{featureNames[i]}\t{weights[i]:0.00}\t{rSquared[i].Mean:G4}\t{1.96 * rSquared[i].StandardDeviation:G4}");
}

}

private static float[] GetLinearModelWeights(LinearRegressionPredictor linearModel)
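A note on the new "95% Confidence" column in the sample: it prints 1.96 * StandardDeviation next to the mean change in R-Squared, i.e. the half-width of a normal-approximation 95% interval. Below is a minimal sketch of that arithmetic with made-up per-permutation deltas; whether the reported spread is a standard deviation across permutations or a standard error of the mean is not visible in this hunk.

using System;
using System.Linq;

public static class PfiConfidenceSketch
{
    public static void Main()
    {
        // Hypothetical per-permutation changes in R-Squared for a single feature
        // (three values, matching numPermutations: 3 in the sample above).
        double[] deltas = { -0.42, -0.40, -0.41 };

        double mean = deltas.Average();

        // Sample standard deviation across the permutations.
        double stdDev = Math.Sqrt(deltas.Sum(d => (d - mean) * (d - mean)) / (deltas.Length - 1));

        // 1.96 * standard deviation is what the "95% Confidence" column prints.
        Console.WriteLine($"Change in R-Squared: {mean:G4} +/- {1.96 * stdDev:G4}");
    }
}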
75 changes: 44 additions & 31 deletions src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs
@@ -15,21 +15,22 @@

namespace Microsoft.ML.Transforms
{
internal static class PermutationFeatureImportance<TResult>
internal static class PermutationFeatureImportance<TMetric, TResult> where TResult : MetricsStatisticsBase<TMetric>, new()
{
public static ImmutableArray<TResult>
GetImportanceMetricsMatrix(
IHostEnvironment env,
IPredictionTransformer<IPredictor> model,
IDataView data,
Func<IDataView, TResult> evaluationFunc,
Func<TResult, TResult, TResult> deltaFunc,
Func<IDataView, TMetric> evaluationFunc,
Func<TMetric, TMetric, TMetric> deltaFunc,
string features,
int numPermutations,
bool useFeatureWeightFilter = false,
int? topExamples = null)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register(nameof(PermutationFeatureImportance<TResult>));
var host = env.Register(nameof(PermutationFeatureImportance<TMetric, TResult>));
host.CheckValue(model, nameof(model));
host.CheckValue(data, nameof(data));
host.CheckNonEmpty(features, nameof(features));
@@ -168,7 +169,7 @@ internal static class PermutationFeatureImportance<TResult>
// Now iterate through all the working slots, do permutation and calc the delta of metrics.
int processedCnt = 0;
int nextFeatureIndex = 0;
int shuffleSeed = host.Rand.Next();
var shuffleRand = RandomUtils.Create(host.Rand.Next());
using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup"))
{
pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));
@@ -178,27 +179,9 @@
if (processedCnt < workingFeatureIndices.Count - 1)
nextFeatureIndex = workingFeatureIndices[processedCnt + 1];

// Used for pre-caching the next feature
int nextValuesIndex = 0;

Utils.Shuffle<float>(RandomUtils.Create(shuffleSeed), featureValuesBuffer);

Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
(src, dst, state) =>
{
src.Features.CopyTo(ref dst.Features);
VBufferUtils.ApplyAt(ref dst.Features, workingIndx,
(int ii, ref float d) =>
d = featureValuesBuffer[state.SampleIndex++]);

if (processedCnt < workingFeatureIndices.Count - 1)
{
// This is the reason I need PermuterState in LambdaTransform.CreateMap.
nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex);
if (nextValuesIndex < valuesRowCount - 1)
nextValuesIndex++;
}
};

SchemaDefinition input = SchemaDefinition.Create(typeof(FeaturesBuffer));
Contracts.Assert(input.Count == 1);
input[0].ColumnName = features;
@@ -208,15 +191,45 @@
output[0].ColumnName = features;
output[0].ColumnType = featuresColumn.Type;

IDataView viewPermuted = LambdaTransform.CreateMap(
host, data, permuter, null, input, output);
if (valuesRowCount == topExamples)
viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeArguments() { Count = valuesRowCount }, viewPermuted);
// Perform multiple permutations for one feature to build a confidence interval
var metricsDeltaForFeature = new TResult();
for (int permutationIteration = 0; permutationIteration < numPermutations; permutationIteration++)
{
Utils.Shuffle<float>(shuffleRand, featureValuesBuffer);

var metrics = evaluationFunc(model.Transform(viewPermuted));
Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter =
(src, dst, state) =>
{
src.Features.CopyTo(ref dst.Features);
VBufferUtils.ApplyAt(ref dst.Features, workingIndx,
(int ii, ref float d) =>
d = featureValuesBuffer[state.SampleIndex++]);

// Is it time to pre-cache the next feature?
if (permutationIteration == numPermutations - 1 &&
processedCnt < workingFeatureIndices.Count - 1)
{
// Fill out the featureValuesBuffer for the next feature while permuting the current feature
// This is the reason I need PermuterState in LambdaTransform.CreateMap.
nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex);
if (nextValuesIndex < valuesRowCount - 1)
nextValuesIndex++;
}
};

IDataView viewPermuted = LambdaTransform.CreateMap(
host, data, permuter, null, input, output);
if (valuesRowCount == topExamples)
viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeArguments() { Count = valuesRowCount }, viewPermuted);

var metrics = evaluationFunc(model.Transform(viewPermuted));

var delta = deltaFunc(metrics, baselineMetrics);
metricsDeltaForFeature.Add(delta);
}

var delta = deltaFunc(metrics, baselineMetrics);
metricsDelta.Add(delta);
// Add the metrics delta to the list
metricsDelta.Add(metricsDeltaForFeature);

// Swap values for next iteration of permutation.
Array.Clear(featureValuesBuffer, 0, featureValuesBuffer.Length);
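The new generic constraint TResult : MetricsStatisticsBase<TMetric>, new() and the metricsDeltaForFeature.Add(delta) call imply an accumulator that turns the per-permutation deltas into the Mean and StandardDeviation the sample reads. That type is not part of this diff; the following is only a sketch of what such an accumulator could look like for a single scalar metric, using Welford's online algorithm.

using System;

// Illustrative stand-in for a per-metric statistics accumulator. The actual
// MetricsStatisticsBase<TMetric> in ML.NET is not shown in this diff.
public sealed class RunningStatistic
{
    private long _count;
    private double _mean;
    private double _m2; // sum of squared deviations from the running mean

    // Welford's online update: one call per permutation with the metric delta.
    public void Add(double value)
    {
        _count++;
        double delta = value - _mean;
        _mean += delta / _count;
        _m2 += delta * (value - _mean);
    }

    public double Mean => _mean;

    // Sample standard deviation; zero until at least two values have been added.
    public double StandardDeviation => _count > 1 ? Math.Sqrt(_m2 / (_count - 1)) : 0.0;
}

Usage would mirror the loop in the diff: one accumulator per working feature, Add called once per permutation, and the sample then reports Mean and 1.96 * StandardDeviation.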