CurrentModule = CounterfactualExplanations 

Model Catalogue

While in general it is assumed that users will use this package to explain their pre-trained models, we provide out-of-the-box functionality to train various simple default models. In this tutorial, we will see how these models can be fitted to CounterfactualData.

Available Models

The standard_models_catalogue can be used to inspect the available default models:

standard_models_catalogue
Dict{Symbol, DataType} with 3 entries:
  :Linear       => Linear
  :DeepEnsemble => FluxEnsemble
  :MLP          => FluxModel

The dictionary keys correspond to the model names. The dictionary values are constructors that can be called on instances of type CounterfactualData to fit the corresponding model. In most cases, however, users will find it more convenient to use the fit_model API call instead, as in the sketch below.
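To make this concrete, below is a minimal sketch of both routes, assuming counterfactual_data is an instance of CounterfactualData (constructed as in the next section):

# Route 1: call the catalogue constructor directly on the data.
M = standard_models_catalogue[:Linear](counterfactual_data)

# Route 2 (usually more convenient): use fit_model with the corresponding dictionary key.
M = fit_model(counterfactual_data, :Linear)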

Fitting Models

Models from the standard model catalogue are a core part of the package and thus compatible with all offered counterfactual generators and functionalities.

The all_models_catalogue can be used to inspect all models offered by the package:

all_models_catalogue

However, when using models that are not included in the standard_models_catalogue, additional caution is advised: they may not be supported by all counterfactual generators, and they may not be native Julia models. A more thorough reading of their documentation may therefore be necessary to ensure they are used correctly.

Fitting Flux Models

First, let’s load one of the synthetic datasets. For this, we need to import the TaijaData.jl package:

using TaijaData

n = 500
data = TaijaData.load_multi_class(n)
counterfactual_data = DataPreprocessing.CounterfactualData(data...)

We could use a Deep Ensemble (Lakshminarayanan, Pritzel, and Blundell 2017) as follows:

M = fit_model(counterfactual_data, :DeepEnsemble)

The returned object is an instance of type FluxEnsemble <: AbstractModel and can be used in downstream tasks without further ado. For example, the resulting fit can be visualised using the generic plot() method as follows:

plts = []
for target in counterfactual_data.y_levels
    plt = plot(M, counterfactual_data; target=target, title="p(y=$(target)|x,θ)")
    push!(plts, plt)
end
plot(plts...)

Importing PyTorch models

The package supports generating counterfactuals for any neural network that has been previously defined and trained using PyTorch, regardless of the specific architectural details of the model. To generate counterfactuals for a PyTorch model, save the model inside a .pt file and call the following function:

model_loaded = TaijaInteroperability.pytorch_model_loader(
    "$(pwd())/docs/src/tutorials/miscellaneous",
    "neural_network_class",
    "NeuralNetwork",
    "$(pwd())/docs/src/tutorials/miscellaneous/pretrained_model.pt"
)

The method pytorch_model_loader requires four arguments:

  1. The path to the folder with a .py file where the PyTorch model is defined
  2. The name of the file where the PyTorch model is defined
  3. The name of the class of the PyTorch model
  4. The path to the Pickle file that holds the model weights

In the above case:

  1. The file defining the model is inside $(pwd())/docs/src/tutorials/miscellaneous
  2. The name of the .py file holding the model definition is neural_network_class
  3. The name of the model class is NeuralNetwork
  4. The Pickle file is located at $(pwd())/docs/src/tutorials/miscellaneous/pretrained_model.pt

Though the model file and Pickle file are inside the same directory in this tutorial, this does not necessarily have to be the case.

The model file and the Pickle file have to be provided separately because the package expects an already trained PyTorch model as input. It is also possible to define new PyTorch models within the package, but since this is not the intended use case, no special support is offered for it. A guide for defining Python and PyTorch classes in Julia through PythonCall.jl can be found here.

Once the PyTorch model has been loaded into the package, wrap it inside the PyTorchModel class:

model_pytorch = TaijaInteroperability.PyTorchModel(model_loaded, counterfactual_data.likelihood)

This model can now be passed into the generators like any other.
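For example, the wrapped model can be supplied to generate_counterfactual just like a native model. The sketch below is purely illustrative: the chosen factual instance, the target class and the use of GenericGenerator are assumptions rather than part of the PyTorch workflow above.

# Pick an arbitrary factual instance and an illustrative target class:
chosen = rand(1:size(counterfactual_data.X, 2))
x = select_factual(counterfactual_data, chosen)
target = 1
generator = GenericGenerator()

# Generate a counterfactual explanation for the wrapped PyTorch model:
ce = generate_counterfactual(x, target, counterfactual_data, model_pytorch, generator)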

Please note that the functionality for generating counterfactuals for Python models is only available if your Julia version is 1.8 or above. For Julia 1.7 users, we recommend upgrading to Julia 1.8 or 1.9 before loading a PyTorch model into the package.

Importing R torch models

!!! warning "Not fully tested"
    Due to the incompatibility between RCall and PythonCall, it is not feasible to test both the PyTorch and RTorch implementations within the same pipeline. While the RTorch implementation has been tested manually, we cannot guarantee its consistent functionality, as it may still contain bugs.

The CounterfactualExplanations package supports generating counterfactuals for neural networks that have been defined and trained using R torch. Regardless of the specific architectural details of the model, you can easily generate counterfactual explanations by following these steps.

Saving the R torch model

First, save your trained R torch model as a .pt file using the torch_save() function provided by the R torch library. This function allows you to serialize the model and save it to a file. For example:

torch_save(model, file = "$(pwd())/docs/src/tutorials/miscellaneous/r_model.pt")

Make sure to specify the correct file path where you want to save the model.

Loading the R torch model

To import the R torch model into the CounterfactualExplanations package, use the rtorch_model_loader() function. This function loads the model from the previously saved .pt file. Here is an example of how to load the R torch model:

model_loaded = TaijaInteroperability.rtorch_model_loader("$(pwd())/docs/src/tutorials/miscellaneous/r_model.pt")

The rtorch_model_loader() function requires only one argument:

  1. model_path: The path to the .pt file that contains the trained R torch model.

Wrapping the R torch model

Once the R torch model has been loaded into the package, wrap it inside the RTorchModel class. This step prepares the model to be used by the counterfactual generators. Here is an example:

model_R = TaijaInteroperability.RTorchModel(model_loaded, counterfactual_data.likelihood)

Generating counterfactuals with the R torch model

Now that the R torch model has been wrapped inside the RTorchModel class, you can pass it into the counterfactual generators as you would with any other model.
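As with the PyTorch example above, a minimal and purely illustrative sketch of this step could look as follows (reusing the same assumed factual instance, target and generator):

ce = generate_counterfactual(x, target, counterfactual_data, model_R, GenericGenerator())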

Please note that RCall is not fully compatible with PythonCall. It is therefore advisable not to import both R torch and PyTorch models within the same Julia session. Additionally, as noted above, the R torch integration has not been fully tested in the CounterfactualExplanations package.

Tuning Flux Models

By default, model architectures are very simple. Through optional arguments, users have some control over the neural network architecture and can choose to impose regularization through dropout. Let’s tackle a more challenging dataset: MNIST (LeCun 1998).

data = TaijaData.load_mnist(10000)
counterfactual_data = DataPreprocessing.CounterfactualData(data...)
train_data, test_data = 
    CounterfactualExplanations.DataPreprocessing.train_test_split(counterfactual_data)

In this case, we will use a Multi-Layer Perceptron (MLP) but we will adjust the model and training hyperparameters. Parameters related to training of Flux.jl models are currently stored in a mutable container:

flux_training_params
CounterfactualExplanations.FluxModelParams(:logitbinarycrossentropy, :Adam, 100, 1, false)
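Since these values are printed positionally, it can be helpful to list the corresponding field names of the container (the exact set of fields may vary across package versions):

fieldnames(typeof(flux_training_params))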

In cases like this one, where model training can be expected to take a few moments, it can be useful to activate verbosity, so let’s set the corresponding field value to true. We’ll also impose mini-batch training:

flux_training_params.verbose = true
flux_training_params.batchsize = round(Int, size(train_data.X, 2) / 10)

To account for the fact that this is a slightly more challenging task, we will use an appropriate number of hidden neurons per layer. We will also activate dropout regularization. To scale networks up further, it is also possible to adjust the number of hidden layers, which we will not do here.

model_params = (
    n_hidden = 32,
    dropout = true
)

The model_params can be supplied to the familiar API call:

M = fit_model(train_data, :MLP; model_params...)
CounterfactualExplanations.Models.Model(Chain(Dense(784 => 32, relu), Dropout(0.25, active=false), Dense(32 => 10)), :classification_multi, Chain(Dense(784 => 32, relu), Dropout(0.25, active=false), Dense(32 => 10)), MLP())

The model performance on our test set can be evaluated as follows:

model_evaluation(M, test_data)
1-element Vector{Float64}:
 0.9185

Finally, let’s restore the default training parameters:

CounterfactualExplanations.reset!(flux_training_params)

References

Lakshminarayanan, Balaji, Alexander Pritzel, and Charles Blundell. 2017. “Simple and Scalable Predictive Uncertainty Estimation Using Deep Ensembles.” Advances in Neural Information Processing Systems 30.

LeCun, Yann. 1998. “The MNIST Database of Handwritten Digits.” http://yann.lecun.com/exdb/mnist/.