Neural-Decision-Forests

An implementation of Deep Neural Decision Forests (dNDF) in PyTorch.

Features

  • Two-stage optimization as in the original paper Deep Neural Decision Forests: fix the neural network parameters $\Theta$ and optimize the leaf distributions $\pi$, then fix the class probability distribution $\pi$ in each leaf node and optimize $\Theta$ (see the sketch after this list)
  • Joint training of $\pi$ and $\Theta$, as proposed by chrischoy in his work Fully Differentiable Deep Neural Decision Forest
  • Shallow Neural Decision Forest (sNDF)
  • Deep Neural Decision Forest (dNDF)
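
Below is a minimal sketch (not this repository's code) of the two-stage optimization: with the network parameters $\Theta$ frozen, the leaf distributions $\pi$ are updated by the fixed-point iteration from the dNDF paper; $\Theta$ is then trained by ordinary SGD while $\pi$ stays fixed. All tensor names and shapes here are illustrative assumptions.

    import torch

    def update_pi(pi, mu, targets, n_class, n_iter=20):
        # pi:      (n_leaf, n_class) leaf class distributions, rows sum to 1
        # mu:      (batch, n_leaf) routing probabilities mu_l(x), computed with Theta fixed
        # targets: (batch,) integer class labels
        y_onehot = torch.nn.functional.one_hot(targets, n_class).float()
        for _ in range(n_iter):
            prob = mu @ pi                                   # forest output P[y|x], (batch, n_class)
            p_true = prob.gather(1, targets.view(-1, 1))     # P[y_x|x], (batch, 1)
            # responsibility of leaf l for sample x: pi[l, y_x] * mu[x, l] / P[y_x|x]
            gamma = (mu.unsqueeze(2) * pi.unsqueeze(0) * y_onehot.unsqueeze(1)
                     / p_true.unsqueeze(2).clamp_min(1e-12))
            new_pi = gamma.sum(dim=0)                        # accumulate over the batch
            pi = new_pi / new_pi.sum(dim=1, keepdim=True).clamp_min(1e-12)
        return pi

A training epoch would then alternate the two stages: run update_pi over the training data with the network frozen, then take standard SGD steps on $\Theta$ (the split functions and any feature extractor) with the returned $\pi$ held fixed.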

Datasets

The MNIST, UCI_Adult, UCI_Letter, and UCI_Yeast datasets are available. For datasets other than MNIST, go to the corresponding directory and run its get_data.sh script.
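
For example, for UCI_Adult (the exact directory name under data/ is an assumption here):

cd data/uci_adult && sh get_data.sh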

Requirements

  • Python 3.x
  • PyTorch >= 1.0.0
  • numpy
  • scikit-learn (sklearn)

Usage

python train.py -ARG VALUE

For example, to train the sNDF on MNIST with alternating optimization:

python train.py -dataset mnist -n_class 10 -gpuid 0 -n_tree 80 -tree_depth 10 -batch_size 1000 -epochs 100

Results

Without spending much time on hyperparameter tuning and without bells and whistles, I obtained the following accuracies (with $\pi$ and $\Theta$ trained separately):

Dataset      sNDF    dNDF
MNIST        0.9794  0.9963
UCI_Adult    0.8558  NA
UCI_Letter   0.9507  NA
UCI_Yeast    0.6031  NA

By adding a nonlinearity to the routing function, the accuracies can reach 0.6502 on UCI_Yeast and 0.9753 on UCI_Letter (a sketch of such a change follows).
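
Here is a hypothetical sketch of what adding such a nonlinearity could look like: instead of computing the split probabilities from a single linear projection of the features, a hidden layer with a ReLU is inserted before the sigmoid. The module and dimension names are assumptions, not this repository's API.

    import torch.nn as nn

    class Router(nn.Module):
        # Maps features to split probabilities d_n(x) for the tree's inner nodes.
        def __init__(self, in_dim, n_inner, hidden_dim=64):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),                        # the added nonlinearity
                nn.Linear(hidden_dim, n_inner),
                nn.Sigmoid(),                     # split probabilities in (0, 1)
            )

        def forward(self, x):
            return self.net(x)                    # (batch, n_inner)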

Note

Some users may run into a 'loss is NaN' situation, which can happen when the forest's output probability for the true class reaches zero (so its log-likelihood becomes -inf). Make sure you have normalized your data and used a large enough number of trees and tree depth. If you want to keep your tree settings, a workaround is to clamp the output probability away from zero, as in the sketch below.
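
A minimal sketch of that workaround, assuming `prob` holds the forest's (batch, n_class) class probabilities; the toy tensors and the epsilon value are illustrative.

    import torch
    import torch.nn.functional as F

    prob = torch.rand(4, 10)
    prob = prob / prob.sum(dim=1, keepdim=True)   # toy normalized predictions
    target = torch.randint(0, 10, (4,))

    prob = prob.clamp(min=1e-8)                   # keep log() away from -inf / NaN
    loss = F.nll_loss(torch.log(prob), target)    # NaN-safe negative log-likelihood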
