

Feng, Jean, and Noah Simon. March 2022. “Ensembled Sparse‐input Hierarchical Networks for High‐dimensional Datasets.” Statistical Analysis and Data Mining.

Python code for fitting EASIER-nets and reproducing all results from the paper. The Python code uses PyTorch.

R code for fitting EASIER-net is available at


Set up a Python virtual environment (the code runs on Python 3.6) and install the packages listed in requirements.txt.

Simulate data by following the tutorial notebook, or load your own data from an npz file with x and y attributes. The tutorial also shows how to tune hyperparameters with GridSearchCV.
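As a rough illustration of the npz format mentioned above, the following sketch writes a toy regression dataset with `x` and `y` arrays and reads it back. The file name, the array shapes, and the choice to store `y` as a column vector are assumptions made for this example; check the tutorial notebook for the layout the fitting code actually expects.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)
n_obs, n_features = 100, 20

# Toy covariates; only the first two features influence the response.
x = rng.normal(size=(n_obs, n_features))
y = x[:, 0] - 2.0 * x[:, 1] + 0.1 * rng.normal(size=n_obs)
y = y.reshape(-1, 1)  # assumption: y stored as an (n_obs, 1) column

# Hypothetical file name, written to a temporary directory.
data_path = os.path.join(tempfile.mkdtemp(), "my_data.npz")
np.savez(data_path, x=x, y=y)

# Reload to confirm that both arrays are stored under the names x and y.
loaded = np.load(data_path)
print(sorted(loaded.files), loaded["x"].shape, loaded["y"].shape)
```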

To fit an EASIER-net, run

python --n-estimators <N_ESTIMATORS> --input-filter-layer <INPUT_FILTER_LAYER> --n-layers <N_LAYERS> --n-hidden <N_HIDDEN> --input-pen <INPUT_PEN> --full-tree-pen <FULL_TREE_PEN> --batch-size <BATCH_SIZE> --num-classes <NUM_CLASSES>  --weight <WEIGHT> --max-iters <MAX_ITERS> --max-prox-iters <MAX_PROX_ITERS> --model-fit-params-file <MODEL_FIT_PARAMS_FILE>


  • N_ESTIMATORS is the size of the ensemble, i.e. the number of SIER-nets being ensembled
  • INPUT_FILTER_LAYER specifies whether to scale the inputs by the parameter β
  • N_LAYERS is the number of hidden layers
  • N_HIDDEN is the number of hidden nodes per layer
  • INPUT_PEN specifies $\lambda_1$ in the paper; it controls the input sparsity
  • FULL_TREE_PEN specifies $\lambda_2$ in the paper; it controls the number of active layers and hidden nodes
  • BATCH_SIZE specifies the size of the mini-batches for Adam
  • NUM_CLASSES should be 0 for regression, and the number of classes for binary or multi-class classification
  • WEIGHT is a list of weights for the classes
  • MAX_ITERS is the number of epochs to run Adam
  • MAX_PROX_ITERS is the number of epochs to run batch proximal gradient descent
  • MODEL_FIT_PARAMS_FILE is a JSON file that specifies the hyperparameters. If given, its values override the arguments passed on the command line.
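The override behaviour of MODEL_FIT_PARAMS_FILE can be pictured with a small sketch. This is not the repository's parser: the defaults, the subset of flags, and the JSON key names (flag names with underscores) are all assumptions chosen for illustration. It only shows the general pattern of letting values from a JSON file replace parsed command-line arguments.

```python
import argparse
import json
import os
import tempfile


def parse_with_overrides(argv):
    """Parse a few flags, then let a JSON params file override them."""
    parser = argparse.ArgumentParser()
    # Defaults below are illustrative, not the project's actual defaults.
    parser.add_argument("--n-layers", type=int, default=2)
    parser.add_argument("--n-hidden", type=int, default=10)
    parser.add_argument("--input-pen", type=float, default=0.0)
    parser.add_argument("--full-tree-pen", type=float, default=0.0)
    parser.add_argument("--model-fit-params-file", type=str, default=None)
    args = parser.parse_args(argv)

    if args.model_fit_params_file is not None:
        with open(args.model_fit_params_file) as f:
            file_params = json.load(f)
        # Values from the file take precedence over command-line values.
        for key, value in file_params.items():
            setattr(args, key, value)
    return args


# Hypothetical params file with assumed key names.
params_path = os.path.join(tempfile.mkdtemp(), "fit_params.json")
with open(params_path, "w") as f:
    json.dump({"n_layers": 5, "input_pen": 0.01}, f)

args = parse_with_overrides(
    ["--n-layers", "3", "--model-fit-params-file", params_path]
)
print(args.n_layers, args.n_hidden, args.input_pen)
```

In this sketch the command line asks for 3 layers, but the file's value of 5 wins; flags absent from the file keep their command-line or default values.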

To perform cross-validation, run the fitting script separately for each candidate pair of penalty parameter values. Then select the best penalty parameter values using
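For the first step, one run per candidate penalty pair, a minimal sketch of enumerating the commands might look like the following. The candidate grids and the fixed arguments are made-up values, and `<FIT_SCRIPT>` is a placeholder for the fitting script; only the flag names come from the usage line above.

```python
import itertools

FIT_SCRIPT = "<FIT_SCRIPT>"  # placeholder: substitute the fitting script


def build_commands(input_pens, full_tree_pens, fixed_args):
    """Return one argument list per (input_pen, full_tree_pen) pair."""
    commands = []
    for input_pen, full_tree_pen in itertools.product(input_pens, full_tree_pens):
        cmd = [
            "python", FIT_SCRIPT,
            "--input-pen", str(input_pen),
            "--full-tree-pen", str(full_tree_pen),
        ] + list(fixed_args)
        commands.append(cmd)
    return commands


# Illustrative candidate grids for lambda_1 and lambda_2.
input_pens = [0.001, 0.01, 0.1]
full_tree_pens = [0.001, 0.01]
fixed_args = ["--n-estimators", "5", "--n-layers", "3", "--n-hidden", "20"]

commands = build_commands(input_pens, full_tree_pens, fixed_args)
for cmd in commands:
    print(" ".join(cmd))
```

Each argument list can be passed to `subprocess.run` or written to a job script, so the fits for different penalty pairs can run independently.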


EASIER-net: Ensemble by Averaging Sparse-Input Hierarchical Networks






