HarsanyiNet

This repository contains the Python implementation for HarsanyiNet, "HarsanyiNet: Computing Accurate Shapley Values in a Single Forward Propagation", ICML 2023.

HarsanyiNet is an interpretable network architecture that makes inferences on the input sample and simultaneously computes the exact Shapley values of the input variables in a single forward propagation (see the paper for details and citations).
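
As a rough illustration of the idea behind the architecture: a Shapley value can be written as a sum of Harsanyi interactions (dividends), with each interaction I(S) shared equally among the variables in S. The sketch below checks this identity on a tiny hand-made game by comparing it against the textbook Shapley formula. The game `toy_game` and all function names are hypothetical and are not part of this repository; the brute-force enumeration is exponential in the number of variables and is not how HarsanyiNet itself computes anything.

```python
from itertools import combinations
from math import factorial


def subsets(players):
    """Yield every subset of `players` as a tuple, including the empty one."""
    players = tuple(players)
    for r in range(len(players) + 1):
        yield from combinations(players, r)


def harsanyi_dividends(v, n):
    """I(S) = sum over T subset of S of (-1)^(|S|-|T|) * v(T)."""
    return {
        S: sum((-1) ** (len(S) - len(T)) * v(frozenset(T)) for T in subsets(S))
        for S in subsets(range(n))
    }


def shapley_from_dividends(v, n):
    """phi_i = sum over S containing i of I(S) / |S|."""
    phi = [0.0] * n
    for S, dividend in harsanyi_dividends(v, n).items():
        for i in S:
            phi[i] += dividend / len(S)
    return phi


def shapley_brute_force(v, n):
    """Textbook definition: weighted average of marginal contributions."""
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for S in subsets(others):
            weight = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
            phi[i] += weight * (v(frozenset(S) | {i}) - v(frozenset(S)))
    return phi


def toy_game(S):
    """Hypothetical 3-variable game: additive terms plus one pairwise interaction."""
    value = 1.0 * (0 in S) + 2.0 * (1 in S) + 3.0 * (2 in S)
    if 0 in S and 1 in S:
        value += 4.0  # interaction between variables 0 and 1
    return value


if __name__ == "__main__":
    print(shapley_from_dividends(toy_game, 3))
    print(shapley_brute_force(toy_game, 3))
```

In this toy game the only non-additive interaction is between variables 0 and 1, so its dividend is split equally between those two variables and variable 2 keeps only its additive term.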

Install

HarsanyiNet can be installed in a Python 3 environment:

pip3 install git+https://github.com/csluchen/harsanyinet

The torchtoolbox package also needs to be installed:

pip3 install torchtoolbox

You may also use a conda environment:

conda create --name harsanyinet python=3.9
conda activate harsanyinet
pip3 install -r requirements.txt

How to use

HarsanyiNet-CNN

To train the model, use commands like the following:

  • CIFAR-10 dataset
python train.py
  • MNIST dataset
python train.py --dataset='MNIST' --num_layers=4 --channels=32 --beta=100 --gamma=0.05

Alternatively, you can use the pre-trained HarsanyiNet from Google Drive. Download pretrained_model.zip and unzip it so that the model file sits at a path like ./pretrained_model/{DATASET}/.../model_pths/{DATASET}.pth.

To compute Shapley values using HarsanyiNet in a single forward propagation, use a command like the following:

python shapley.py --save_path='./pretrained_model' --model_path='model_pths/CIFAR10.pth' --num_layers=10 --channels=256 --beta=1000 --gamma=1 

HarsanyiNet-MLP

Datasets

We provide implementations for three tabular datasets from the UCI repository: Census, Yeast, and Commercial (TV News).

Getting Started

To get started, run python utils/tabular/data_preprocess.py to download and preprocess the data. The preprocessed data will be stored as an np.ndarray in data/{DATASET}/. Alternatively, you can use utils/data.py to load the dataloader directly; it already incorporates this preprocessing step.

To train the model, use the following code:

  • Census dataset
python train_tabular.py
  • Yeast dataset
python train_tabular.py --dataset Yeast --n_attributes 8
  • Commercial (TV News) dataset
python train_tabular.py --dataset Commercial --n_attributes 10

Note:

  • For the Census dataset, we provide the pretrained model under pretrained_model/Census.pth.
  • For the Yeast and Commercial datasets, we do not provide pretrained models, because neither dataset has an official data split. We randomly split the whole dataset into 80% training data and 20% testing data.
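
The 80%/20% random split described in the note can be sketched as follows. This is a minimal illustration using only the standard library, not the repository's own splitting code; the function name `split_indices` and the fixed `seed` are assumptions added here so that the same split can be reproduced between training and evaluation.

```python
import random


def split_indices(n_samples, train_fraction=0.8, seed=0):
    """Return (train_idx, test_idx): a random, non-overlapping split of range(n_samples)."""
    rng = random.Random(seed)          # fixed seed so the split is reproducible
    order = list(range(n_samples))
    rng.shuffle(order)
    n_train = int(round(train_fraction * n_samples))
    return sorted(order[:n_train]), sorted(order[n_train:])


if __name__ == "__main__":
    train_idx, test_idx = split_indices(1000)
    print(len(train_idx), len(test_idx))
```

The returned index lists can be used to index NumPy arrays directly, for example `X[train_idx]` and `X[test_idx]`. Because the split is random, a model trained on one split should be evaluated on the test part of that same split.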

To compute Shapley values using HarsanyiNet in a single forward propagation, use the following code:

  • Census dataset
    • using the provided pretrained model:
python shapley_tabular.py --model_path pretrained_model/Census.pth
    • using a model you have trained yourself:
python shapley_tabular.py
  • Yeast dataset
python shapley_tabular.py --dataset Yeast --n_attributes 8
  • Commercial (TV News) dataset
python shapley_tabular.py --dataset Commercial --n_attributes 10

More details

Comparing Shapley values computed by HarsanyiNet and other methods

To compute the root mean squared error (RMSE) between the Shapley values computed by HarsanyiNet and those estimated by the sampling method, use the following code:

python shapley.py --sampling=True --runs=20000

Note: increasing the number of iterations (runs) makes the sampling method more accurate, but the code takes longer to run.
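
To make the accuracy/runtime trade-off concrete, here is a minimal sketch of permutation sampling for Shapley values, together with an RMSE against known exact values. It is an illustration under stated assumptions rather than the sampling code in shapley.py: `toy_game`, `sampled_shapley`, and `rmse` are hypothetical names, and the exact values (3, 4, 3) hold only for this hand-made game.

```python
import math
import random


def toy_game(S):
    """Hypothetical 3-variable game whose exact Shapley values are (3, 4, 3)."""
    value = 1.0 * (0 in S) + 2.0 * (1 in S) + 3.0 * (2 in S)
    if 0 in S and 1 in S:
        value += 4.0  # interaction shared equally by variables 0 and 1
    return value


def sampled_shapley(v, n, runs, seed=0):
    """Estimate Shapley values by averaging marginal contributions over random orderings."""
    rng = random.Random(seed)
    phi = [0.0] * n
    players = list(range(n))
    for _ in range(runs):
        rng.shuffle(players)
        coalition = set()
        previous = v(frozenset(coalition))
        for i in players:
            coalition.add(i)
            current = v(frozenset(coalition))
            phi[i] += current - previous   # marginal contribution of i in this ordering
            previous = current
    return [total / runs for total in phi]


def rmse(estimate, reference):
    """Root mean squared error between two equally long lists of values."""
    return math.sqrt(
        sum((a - b) ** 2 for a, b in zip(estimate, reference)) / len(estimate)
    )


if __name__ == "__main__":
    exact = [3.0, 4.0, 3.0]
    for runs in (10, 100, 1000, 10000):
        print(runs, rmse(sampled_shapley(toy_game, 3, runs), exact))
```

Each run costs one pass over all variables, so runtime grows linearly with runs, while the error of a Monte Carlo average typically shrinks only in proportion to the inverse square root of runs.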

To compute the RMSE between the Shapley values computed by HarsanyiNet and the ground-truth Shapley values, use the following code:

python shapley.py --ground_truth=True

Sample notebooks

For image datasets, we provide Jupyter notebooks for calculating Shapley values via HarsanyiNet on the CIFAR-10 and MNIST datasets, under notebooks/CIFAR-10.ipynb and notebooks/MNIST.ipynb, respectively.

For tabular datasets, we provide a Jupyter notebook for calculating Shapley values via HarsanyiNet on the Census dataset, under notebooks/Census.ipynb.

Citations

@InProceedings{chen23,
  title = {HarsanyiNet: Computing Accurate Shapley Values in a Single Forward Propagation},
  author = {Chen, Lu and Lou, Siyu and Zhang, Keyan and Huang, Jin and Zhang, Quanshi},
  booktitle = {Proceedings of the 40th International Conference on Machine Learning},
  year = {2023}
}
