In [1]:
#r "/home/jovyan/Notebooks/extensions/Microsoft.DotNet.Interactive.XPlot/Microsoft.DotNet.Interactive.XPlot.dll"
#extend "/home/jovyan/Notebooks/extensions/Microsoft.DotNet.Interactive.XPlot/Microsoft.DotNet.Interactive.XPlot.dll"
#r "nuget:Microsoft.ML.FastTree,1.4.0-preview2"

using Microsoft.DotNet.Interactive.XPlot;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;

In [2]:
/// <summary>
/// Builds the display data for a tree ensemble. Only the first tree of the
/// ensemble is rendered for now; the work is delegated to the single-tree overload.
/// </summary>
/// <param name="ensemble">The trained ensemble whose first tree is displayed.</param>
/// <param name="featureNames">Optional feature names, forwarded unchanged to the single-tree overload.</param>
/// <returns>The display data produced by the single-tree overload for the first tree.</returns>
private DecisionTreeData DisplayTree(RegressionTreeEnsemble ensemble, in VBuffer<ReadOnlyMemory<char>> featureNames = default)
    => DisplayTree(ensemble.Trees.FirstOrDefault(), featureNames);

/// <summary>
/// Converts a single regression tree into the node graph consumed by the tree display.
/// Every split node is labeled "&lt;feature&gt; &gt; &lt;threshold&gt;" and every leaf is labeled
/// with its output value; both numbers are formatted with two decimals.
/// </summary>
/// <param name="tree">The tree to convert. A null tree produces display data without a root.</param>
/// <param name="featureNames">
/// Optional feature names indexed by feature index. When a name is missing, empty, or the
/// buffer is shorter than the feature index (including the default, empty buffer), the
/// label falls back to "f&lt;index&gt;".
/// </param>
/// <returns>Display data whose root is the tree's first split node (or its only leaf).</returns>
private DecisionTreeData DisplayTree(RegressionTree tree, in VBuffer<ReadOnlyMemory<char>> featureNames = default)
{
    DecisionTreeData treeData = new DecisionTreeData();

    // Nothing to draw, e.g. FirstOrDefault() on an empty ensemble.
    if (tree is null)
    {
        return treeData;
    }

    // NOTE(review): every split is labeled from the numerical split arrays; categorical
    // splits are not treated specially here - confirm that is acceptable for the models shown.
    NodeData[] nodes = new NodeData[tree.NumberOfNodes];
    StringBuilder labelBuilder = new StringBuilder();
    for (int node = 0; node < nodes.Length; node++)
    {
        int featureIndex = tree.NumericalSplitFeatureIndexes[node];
        float splitThreshold = tree.NumericalSplitThresholds[node];

        // VBuffer.GetItemOrDefault validates the index against Length and throws when it is
        // out of range, so the lookup must be guarded. Without the guard, the default
        // (empty) featureNames buffer makes every call fail before the fallback is reached.
        ReadOnlyMemory<char> featureName = default;
        if (featureIndex >= 0 && featureIndex < featureNames.Length)
        {
            featureName = featureNames.GetItemOrDefault(featureIndex);
        }

        labelBuilder.Clear();
        if (!featureName.IsEmpty)
        {
            labelBuilder.Append(featureName);
        }
        else
        {
            labelBuilder.Append('f');
            labelBuilder.Append(featureIndex);
        }
        labelBuilder.Append(" > ");
        labelBuilder.Append(splitThreshold.ToString("n2"));

        nodes[node] = new NodeData();
        nodes[node].Label = labelBuilder.ToString();
    }

    NodeData[] leaves = new NodeData[tree.NumberOfLeaves];
    for (int leaf = 0; leaf < leaves.Length; leaf++)
    {
        leaves[leaf] = new NodeData();
        leaves[leaf].Label = tree.LeafValues[leaf].ToString("n2");
    }

    // Child references: a non-negative value is an index into the split nodes,
    // a negative value is the bitwise complement of an index into the leaves.
    NodeData GetNodeData(int child)
    {
        return child >= 0
            ? nodes[child]
            : leaves[~child];
    }

    // hook the nodes up
    for (int node = 0; node < nodes.Length; node++)
    {
        // the RightChild is the 'greater than' path, so put that first
        nodes[node].Children.Add(GetNodeData(tree.RightChild[node]));
        nodes[node].Children.Add(GetNodeData(tree.LeftChild[node]));
    }

    if (nodes.Length > 0)
    {
        treeData.Root = nodes[0];
    }
    else if (leaves.Length > 0)
    {
        // A tree without any split still has a constant output; show its leaf as the root.
        treeData.Root = leaves[0];
    }

    return treeData;
}

In [3]:
// Load the trained pipeline together with the schema of the input data it was saved with.
MLContext mlContext = new MLContext();
ITransformer model = mlContext.Model.Load("HousingTreeModel.zip", out DataViewSchema inputSchema);

DataViewSchema outputSchema = model.GetOutputSchema(inputSchema);

// The FastTree predictor is expected to be the last step of the loaded transformer chain.
// Safe casts plus an explicit check report an unexpected model shape right here, instead of
// surfacing later as a NullReferenceException when the predictor is first dereferenced.
RegressionPredictionTransformer<FastTreeRegressionModelParameters> predictor =
    (model as TransformerChain<ITransformer>)?.LastTransformer as RegressionPredictionTransformer<FastTreeRegressionModelParameters>;
if (predictor is null)
{
    throw new InvalidOperationException(
        "HousingTreeModel.zip does not end with a FastTree regression prediction transformer.");
}

// The slot names of the predictor's feature column are used as feature labels when drawing trees.
VBuffer<ReadOnlyMemory<char>> featureNames = default;
outputSchema[predictor.FeatureColumnName].GetSlotNames(ref featureNames);

In [4]:
// Render one tree of the trained ensemble (zero-based index 3, i.e. the fourth tree).
// The expression is intentionally left without a trailing semicolon so the notebook displays its value.
DisplayTree(
    tree: predictor.Model.TrainedTreeEnsemble.Trees[3],
    featureNames: featureNames)