Skip to content

Commit

Permalink
feat: add a pytorch wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucashervier committed Aug 22, 2023
1 parent a7d5aa6 commit 4788b94
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 4 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
extras_require={
"tests": ["pytest", "pylint"],
"docs": ["mkdocs", "mkdocs-material", "numkdoc"],
"pytorch": ["torch"]
},
packages=find_packages(),
python_requires=">=3.6",
Expand Down
11 changes: 7 additions & 4 deletions xplique/attributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ def __init__(self, model: Callable, batch_size: Optional[int] = 64,
operator: Optional[Callable[[tf.keras.Model, tf.Tensor, tf.Tensor], float]] = None):

if isinstance(model, tf.keras.Model):
model_key = (id(model.input), id(model.output))
if model_key not in BlackBoxExplainer._cache_models:
BlackBoxExplainer._cache_models[model_key] = model
self.model = BlackBoxExplainer._cache_models[model_key]
try:
model_key = (id(model.input), id(model.output))
if model_key not in BlackBoxExplainer._cache_models:
BlackBoxExplainer._cache_models[model_key] = model
self.model = BlackBoxExplainer._cache_models[model_key]
except AttributeError:
self.model = model
else:
self.model = model

Expand Down
5 changes: 5 additions & 0 deletions xplique/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Wrappers available
"""

from .pytorch import TorchWrapper
147 changes: 147 additions & 0 deletions xplique/wrappers/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
Module for having a wrapper for PyTorch's model
"""
try:
# use PyTorch functionality
import torch
from torch import nn
from torch import from_numpy
except ImportError as exc:
raise ImportError(
"PyTorch is required to use this feature. Please install PyTorch using 'pip install torch'."
) from exc

import tensorflow as tf
import numpy as np

from ..types import Union, Optional, Tuple, Callable

class TorchWrapper(tf.keras.Model):
"""
A wrapper for PyTorch's model so that they can be used in Xplique framework
for most attributions method
Parameters
----------
torch_model
A nn.Module PyTorch model
device
If we are on GPU or CPU
is_channel_first
A boolean that is true if the torch's model expect a channel dim and if this one come first
"""

def __init__(self, torch_model: nn.Module, device: Union['torch.device', str],
is_channel_first: Optional[bool] = None):
super().__init__()
self.model = torch_model.to(device)
self.device = device
# with torch, the convention for CNN is (N, C, H, W)
if is_channel_first is None:
self.channel_first = self._check_conv_layers()
else:
self.channel_first = is_channel_first
# deactivate all tf.function
tf.config.run_functions_eagerly(True)

# pylint: disable=arguments-differ
@tf.custom_gradient
def call(self, inputs: np.ndarray) -> Tuple[tf.Tensor, Callable]:
"""
A method that allow to call the PyTorch model wrapped on inputs that are
numpy arrays reshaped to match Xplique's explainers expectations (i.e. channel-last)
, get the gradients from PyTorch framework and recast outputs and gradients
as Tensorflow tensors.
Parameters
----------
inputs
Processed inputs as numpy arrays
Returns
-------
outputs
Outputs of the PyTorch model converted to Tensorflow tensor
grad
The function that allow to compute the gradient of the PyTorch model and
broadcast it for Tensorflow
"""
# transform your numpy inputs to torch
torch_inputs = self.transform_np_inputs(inputs).to(self.device)
torch_inputs.requires_grad_(True)

# make predictions
self.model.zero_grad()
outputs = self.model(torch_inputs)
output_tensor = tf.constant(outputs.cpu().detach().numpy())

def grad(upstream):
upstream_tensor = tf.constant(upstream.numpy())
torch.autograd.backward(
outputs.cpu(),
grad_tensors=from_numpy(upstream_tensor.numpy()),
retain_graph=True
)
dx_torch = torch_inputs.grad

dx_np = dx_torch.cpu().detach().numpy()
if self.channel_first:
# (N, C, H, W) -> (N, H, W, C) for explainer
dx_np = np.moveaxis(dx_np, [1, 2, 3], [3, 1, 2])

gradient = tf.constant(dx_np)
return gradient

return output_tensor, grad

def transform_np_inputs(self, np_inputs: np.ndarray):
"""
Methods that transform inputs as expected by the explainer to inputs expected
by your PyTorch model.
Notes
-----
If your model (or your data pipeline) has more specifities, change this function to make
the bridge
Parameters
----------
np_inputs
Inputs as numpy arrays (simpler for conversion between PyTorch
and Tensorflow)
Returns
-------
torch_inputs
Torch Tensor with the format expected by the PyTorch model
"""
if self.channel_first:
# go from channel last to channel first for the torch model
np_inputs = np.moveaxis(np_inputs, [3, 1, 2], [1, 2, 3])
else:
np_inputs = np.asarray(np_inputs)

# convert numpy array to torch tensor
torch_inputs = torch.Tensor(np_inputs)

return torch_inputs

def _check_conv_layers(self):
"""
A method that checks if the PyTorch's model has 2D convolutional layer.
Indeed, convolution with PyTorch expects inputs in the shape (N, C, H, W)
where TF expect (N, H, W, C).
Returns
-------
has_conv_layers
A boolean that says if the PyTorch model has Conv2d layers
"""
has_conv_layers = False

for module in self.model.modules():
if isinstance(module, nn.Conv2d):
has_conv_layers = True
break

return has_conv_layers

0 comments on commit 4788b94

Please sign in to comment.