In [1]:
# default_exp interface

In [2]:
from ipynb_path import *

In [3]:
# export
from counterfactual.import_essentials import *

In [4]:
# export
class Clamp(torch.autograd.Function):
    """
    Straight-through clamp of a tensor to the closed interval [0, 1].

    The forward pass limits every element to [0, 1]. The backward pass
    returns a copy of the incoming gradient unchanged, so elements that were
    clipped still receive gradient. Call it as ``Clamp.apply(t)``.

    code from: https://discuss.pytorch.org/t/regarding-clamped-learnable-parameter/58474/4
    """

    @staticmethod
    def forward(ctx, input):
        # Nothing is saved on ``ctx``: the backward pass never looks at the input.
        return torch.clamp(input, min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity (straight-through) gradient; a fresh copy is returned rather
        # than the tensor autograd handed in.
        passthrough = torch.clone(grad_output)
        return passthrough

class ExplainerBase(ABC):
    """Abstract interface for counterfactual (cf) explainers.

    The class cannot be instantiated until a subclass overrides
    :meth:`generate_cf`.
    """

    @abstractmethod
    def generate_cf(self, x: torch.Tensor, *args, **kwargs):
        """Generate a counterfactual explanation for ``x``.

        Args:
            x (torch.Tensor): input instance to explain.
            *args, **kwargs: extra explainer-specific options; the base
                signature does not interpret them.

        Raises:
            NotImplementedError: always, because the base version has no
                implementation.
        """
        raise NotImplementedError

In [5]:
# export
class LocalExplainerBase(nn.Module, ExplainerBase):
    """Base class for explainers built around a single input instance ``x``.

    The wrapped ``model`` is frozen on construction. Subclasses are expected
    to define ``self.cf`` (the tensor that :meth:`configure_optimizers`
    optimises), implement :meth:`forward`, and implement ``generate_cf``.
    """

    def __init__(self,
                 x: torch.Tensor,
                 model: nn.Module):
        """
        Args:
            x (torch.Tensor): the input instance to explain.
            model (nn.Module): the model to explain. If it exposes a callable
                ``freeze()`` (e.g. a Lightning module), that method is used.
                Otherwise its parameters are set to ``requires_grad=False``
                and it is switched to eval mode.
        """
        super().__init__()
        self.model = model
        self._freeze_model(self.model)
        self.x = x

    @staticmethod
    def _freeze_model(model: nn.Module) -> None:
        """Disable gradients for ``model``'s parameters and put it in eval mode."""
        freeze = getattr(model, "freeze", None)
        if callable(freeze):
            # Prefer the model's own freeze() when it provides one.
            freeze()
        else:
            # A plain nn.Module has no freeze(); do the equivalent by hand.
            model.requires_grad_(False)
            model.eval()

    def forward(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def configure_optimizers(self, lr: float = 0.001):
        """Return an Adam optimiser over the counterfactual tensor ``self.cf``.

        Args:
            lr (float): learning rate passed to Adam (default 0.001).

        Raises:
            AttributeError: if the subclass has not defined ``self.cf``.
        """
        if not hasattr(self, "cf"):
            raise AttributeError(
                f"{type(self).__name__} must define `self.cf` before "
                "configure_optimizers() is called"
            )
        return torch.optim.Adam([self.cf], lr=lr)

In [6]:
# export
class GlobalExplainerBase(ExplainerBase):
    """Base class for global explainers.

    It adds nothing to ``ExplainerBase``; subclasses must still implement
    ``generate_cf``.
    """

In [7]:
# export
class ABCBaseModule(ABC):
    """Abstract interface requiring ``model_forward``, ``forward`` and ``predict``.

    All three are abstract, so a subclass must override every one of them
    before it can be instantiated. Calling a base version raises
    ``NotImplementedError``.
    """

    @abstractmethod
    def model_forward(self, *inputs):
        """Abstract hook; subclasses define its semantics."""
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs):
        """Abstract forward pass; subclasses define its semantics."""
        raise NotImplementedError

    @abstractmethod
    def predict(self, *inputs):
        """Abstract prediction hook; subclasses define its semantics."""
        raise NotImplementedError