csinva/transformation-importance

Official code for using / reproducing TRIM from the paper Transformation Importance with Applications to Cosmology (ICLR 2020 Workshop). This code shows examples and provides useful wrappers for calculating importance in a transformed feature space.

This repo is actively maintained. For any questions please file an issue.
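
The core idea behind such a wrapper can be sketched in a few lines. If a transform maps an input x to s = T(x) and an inverse maps s back, then prepending the inverse to a trained model gives a model whose inputs are the transformed features s, and any attribution method applied to it scores those features. The sketch below is a minimal stand-in, not the package's TrimModel: the class name `InverseTransformWrapper` is hypothetical, and a random orthogonal matrix is assumed as the transform so that its inverse is just the transpose.

```python
import torch
import torch.nn as nn


class InverseTransformWrapper(nn.Module):
    """Hypothetical minimal wrapper: maps transformed features s back to
    input space with inv_transform, then runs the original model."""

    def __init__(self, model, inv_transform):
        super().__init__()
        self.model = model
        self.inv_transform = inv_transform

    def forward(self, s):
        return self.model(self.inv_transform(s))


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

# assumption for illustration: an orthogonal change of basis Q,
# so the inverse transform is multiplication by Q^T
Q, _ = torch.linalg.qr(torch.randn(10, 10))


def transform(x):
    return x @ Q


def inv_transform(s):
    return s @ Q.T


wrapped = InverseTransformWrapper(model, inv_transform)

x = torch.randn(1, 10)
s = transform(x).detach().requires_grad_(True)

out = wrapped(s)  # same prediction as model(x), now a function of s
out.backward()
attribution = (s.grad * s).detach()  # input-times-gradient per transformed feature
```

Because the wrapper changes only what the model's input is, not what the model computes, any gradient-based attribution method can be applied to it unchanged.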

trim

examples/documentation

  • dependencies: depends on the pip-installable acd package
  • examples: different folders (e.g. ex_cosmology, ex_fake_news, ex_mnist, ex_urban_sound) contain examples for using TRIM in different settings
  • src: the core code is in the trim folder, containing wrappers and code for different transformations
  • requirements: tested with python 3.7 and pytorch > 1.0
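
As an example of a transformation for attribution to different scales (the cosmology setting), an image can be split into radial frequency bands. This is an illustrative sketch, not the repo's own transform: the helper `split_into_bands`, the band edges, and the toy model on flattened pixels are assumptions. Because the band masks partition Fourier space, the band-limited images sum back to the original image, so the inverse transform is a plain sum over bands.

```python
import torch
import torch.nn as nn


def split_into_bands(x, edges):
    """Hypothetical helper: split a 2D image x of shape (H, W) into
    len(edges) + 1 band-limited images using radial masks in Fourier space.
    The masks partition frequency space, so the bands sum back to x."""
    H, W = x.shape
    fy = torch.fft.fftfreq(H).reshape(-1, 1)
    fx = torch.fft.fftfreq(W).reshape(1, -1)
    radius = torch.sqrt(fy ** 2 + fx ** 2)  # radial frequency of each coefficient
    bounds = [0.0, *edges, float("inf")]
    F = torch.fft.fft2(x)
    bands = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        mask = ((radius >= lo) & (radius < hi)).float()
        bands.append(torch.fft.ifft2(F * mask).real)
    return torch.stack(bands)  # shape (n_bands, H, W)


torch.manual_seed(0)
H = W = 16
# toy stand-in model acting on a flattened image (assumption for illustration)
model = nn.Sequential(nn.Linear(H * W, 32), nn.ReLU(), nn.Linear(32, 1))


def inv_transform(s):
    # inverse transform: sum the bands, then flatten for the toy model
    return s.sum(dim=0).reshape(1, -1)


x = torch.randn(H, W)
s = split_into_bands(x, edges=(0.1, 0.25)).requires_grad_(True)
out = model(inv_transform(s))
out.backward()
# input-times-gradient, summed over pixels: one score per band
band_scores = (s.grad * s).sum(dim=(1, 2)).detach()
```

Since every band receives the same gradient, the band scores add up to the ordinary input-times-gradient total for x; the transform redistributes that total across scales.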
Figure: attribution to different scales in cosmological images; fake news attribution to different topics; attribution to different NMF components in MNIST classification; attribution to different frequencies in audio classification.
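
For a transform that is not invertible, such as projecting onto a handful of NMF-style components (as in the MNIST example), one option is to hold the reconstruction residual fixed so that the model still sees exactly the original input. The sketch below is an assumption-laden illustration, not the repo's implementation: a random nonnegative basis stands in for fitted NMF components, and the coefficients come from a pseudo-inverse (least squares) rather than a true nonnegative fit.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, k = 10, 3
model = nn.Sequential(nn.Linear(d, 10), nn.ReLU(), nn.Linear(10, 1))

# assumption: a fixed nonnegative basis standing in for fitted NMF components
components = torch.rand(k, d)  # shape (k, d), one component per row

x = torch.randn(1, d)
# coefficients via least squares (pseudo-inverse), not a true NMF fit
coeffs = x @ torch.linalg.pinv(components)  # shape (1, k)
residual = x - coeffs @ components  # the part of x the k components miss


def inv_transform(s):
    # reconstruct from the components, then add back the fixed residual
    return s @ components + residual


s = coeffs.clone().requires_grad_(True)
out = model(inv_transform(s))
out.backward()
component_scores = (s.grad * s).detach()  # one score per component, shape (1, k)
```

Without the residual term the wrapped model would see only the rank-3 reconstruction of x, so its output, and therefore the attributions, could differ from the original prediction.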

sample usage

import torch
import torch.nn as nn
from trim import TrimModel

# setup a trim model (uses the torch.fft module, PyTorch 1.8+)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1)) # orig model

def transform(x):
    # fft, returned as a real tensor of shape (..., 10, 2) holding (real, imag) parts
    return torch.view_as_real(torch.fft.fft(x))

def inv_transform(s):
    # inverse fft; keep the real part since the original signal is real
    return torch.fft.ifft(torch.view_as_complex(s)).real

model_trim = TrimModel(model=model, inv_transform=inv_transform) # trim model

# get a data point
x = torch.randn(1, 10)
s = transform(x).clone()

# can now use any attribution method on the trim model
# get (input_x_gradient) attribution in the fft space
s.requires_grad = True
model_trim(s).backward()
input_x_gradient = s.grad * s
  • see notebooks for more detailed usage
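
The (real, imaginary) attribution from the sample can be collapsed into one score per frequency bin. To stay runnable without the trim package, this sketch composes the inverse FFT with the model by hand instead of calling TrimModel, and it uses the torch.fft API (PyTorch 1.8+). Summing over the real and imaginary parts is one reasonable aggregation choice, assumed here, not necessarily the one used in the paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))


def transform(x):
    # full FFT, stored as a real tensor of shape (..., n, 2) = (real, imag)
    return torch.view_as_real(torch.fft.fft(x))


def inv_transform(s):
    # inverse FFT; the input signal is real, so keep only the real part
    return torch.fft.ifft(torch.view_as_complex(s)).real


x = torch.randn(1, 10)
s = transform(x).clone().requires_grad_(True)
out = model(inv_transform(s))  # equals model(x) up to float error
out.backward()

input_x_gradient = s.grad * s  # shape (1, 10, 2)
per_frequency = input_x_gradient.sum(dim=-1).detach()  # shape (1, 10)
```

For a real-valued signal, bins k and n - k are complex conjugates, and with this aggregation they receive identical scores, so pairing them gives one score per physical frequency.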

related work

  • ACD (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • CDEP (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

reference

  • feel free to use/share this code openly
  • if you find this code useful for your research, please cite the following:
@article{singh2020transformation,
    title={Transformation Importance with Applications to Cosmology},
    author={Singh, Chandan and Ha, Wooseok and Lanusse, Francois and Boehm, Vanessa and Liu, Jia and Yu, Bin},
    journal={arXiv preprint arXiv:2003.01926},
    year={2020},
    url={https://arxiv.org/abs/2003.01926},
}