# In [2]:
"""This module implements the Contrastive Explanation Method (CEM) in PyTorch.

Paper:  https://arxiv.org/abs/1802.07623
"""

import numpy as np
import torch

class ContrastiveExplanationMethod:
    """Loss terms of the Contrastive Explanation Method (CEM) for a classifier.

    The instance keeps two perturbation tensors, ``delta`` and ``y`` (the
    slack variable the losses are evaluated at). They cannot be sized in
    ``__init__`` because no sample is known there, so they are created as
    zeros shaped like the sample the first time a loss method is called.

    Paper: https://arxiv.org/abs/1802.07623
    """

    _MODES = ("PN", "PP")

    def __init__(
        self,
        classifier,
        mode: str,
        autoencoder = None,
        kappa: float = 0.,
        const: float = 10.,
        beta: float = .1,
        gamma: float = 0.,
        feature_range: tuple = (-1e10, 1e10)
    ):
        """
        Initialise the CEM model.

        classifier
            classification model to be explained; called as
            ``classifier(batch)`` and expected to return one score per class
            with shape (batch, n_classes).
        mode
            for pertinent negatives 'PN' or for pertinent positives 'PP'.
        autoencoder
            optional, autoencoder to be used for regularisation of the
            modifications to the explained samples.
        kappa
            confidence parameter used in the loss functions (eq. 2) and (eq. 4) in
            the original paper.
        const
            initial regularisation coefficient for the attack loss term.
        beta
            regularisation coefficient for the L1 term of the optimisation objective.
        gamma
            regularisation coefficient for the autoencoder term of the optimisation
            objective.
        feature_range
            range over which the features of the perturbed instances should be distributed.

        Raises ``ValueError`` when ``mode`` is neither 'PN' nor 'PP'.
        """
        if mode not in self._MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, mode)
            )

        self.classifier = classifier
        # Alias kept for any code that refers to the model by the attribute
        # name the original constructor tried to assign.
        self.explain_model = classifier
        self.mode = mode
        self.autoencoder = autoencoder
        self.kappa = kappa
        self.const = const
        self.beta = beta
        self.gamma = gamma
        self.feature_range = feature_range

        # Sized lazily by _ensure_state once a sample is available.
        self.delta = None
        self.y = None

    def _ensure_state(self, orig_sample):
        """Create ``delta`` / ``y`` as zeros when missing or wrongly shaped."""
        for name in ("delta", "y"):
            current = getattr(self, name, None)
            if current is None or current.shape != orig_sample.shape:
                # zeros_like keeps the sample's dtype and device.
                setattr(self, name, torch.zeros_like(orig_sample))

    def fista(self, orig_sample):
        """Fast Iterative Shrinkage Thresholding Algorithm implementation in pytorch

        Paper: https://doi.org/10.1137/080716542

        (Eq. 5) and (eq. 6) in https://arxiv.org/abs/1802.07623

        Not implemented yet: raises ``NotImplementedError`` instead of
        silently returning ``None``.
        """
        raise NotImplementedError("fista is not implemented yet")

    def shrink(self, z):
        """Element-wise shrinkage thresholding function.

        (Eq. 7) in https://arxiv.org/abs/1802.07623

        Entries with ``|z_i| <= beta`` become 0; every other entry is moved
        towards zero by ``beta``.

        z
            tensor to be thresholded.

        Returns a new tensor with the same shape as ``z``.
        """
        return torch.sign(z) * torch.clamp(z.abs() - self.beta, min=0)

    def optimisation_obj(self, orig_sample):
        """
        Optimisation objective for PN (eq. 1) and for PP (eq. 3).

        orig_sample
            the unperturbed original sample, batch size first.

        Returns a tensor of shape (batch,): for every sample
        ``const * loss_fn + beta * ||y||_1 + ||y||_2^2``, plus
        ``gamma * ||r - autoencoder(r)||_2^2`` when an autoencoder is set,
        where ``r`` is ``orig_sample + y`` in 'PN' mode and ``y`` in 'PP' mode.
        """
        attack_loss = self.loss_fn(orig_sample)  # also sizes self.y

        # Norms are taken per sample, over all non-batch dimensions.
        flat_y = self.y.reshape(self.y.shape[0], -1)
        l1_term = flat_y.abs().sum(dim=1)
        l2_term = flat_y.pow(2).sum(dim=1)

        out = self.const * attack_loss + self.beta * l1_term + l2_term

        if self.autoencoder is not None:
            if self.mode == "PN":
                reconstructed_from = orig_sample + self.y
            else:
                reconstructed_from = self.y
            residual = reconstructed_from - self.autoencoder(reconstructed_from)
            residual = residual.reshape(residual.shape[0], -1)
            out = out + self.gamma * residual.pow(2).sum(dim=1)
        return out

    def loss_fn(self, orig_sample):
        """
        Loss term f(x,d) for PN (eq. 2) and for PP (eq. 4).

        The target class t0 is the argmax of the classifier on the
        unperturbed sample. With ``p`` the classifier output on the perturbed
        input (``orig_sample + y`` for 'PN', ``y`` for 'PP'):

        * 'PN': ``max(p[t0] - max_{i != t0} p[i], -kappa)``
        * 'PP': ``max(max_{i != t0} p[i] - p[t0], -kappa)``

        orig_sample
            the unperturbed original sample, batch size first.

        Returns a tensor of shape (batch,).
        Raises ``ValueError`` when ``self.mode`` is neither 'PN' nor 'PP'.
        """
        self._ensure_state(orig_sample)

        # Only the predicted label is needed, so no graph is built here.
        with torch.no_grad():
            target = torch.argmax(
                self.classifier(orig_sample), dim=1, keepdim=True
            )

        if self.mode == "PN":
            pert_output = self.classifier(orig_sample + self.y)
        elif self.mode == "PP":
            pert_output = self.classifier(self.y)
        else:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, self.mode)
            )

        target_score = pert_output.gather(1, target).squeeze(1)

        # Exclude the target class with -inf rather than multiplying by a 0/1
        # mask: a zeroed entry would win the max whenever all other scores
        # are negative.
        class_ids = torch.arange(pert_output.shape[1], device=pert_output.device)
        is_target = class_ids.unsqueeze(0) == target
        best_other = pert_output.masked_fill(is_target, float("-inf")).max(dim=1).values

        if self.mode == "PN":
            margin = target_score - best_other
        else:
            margin = best_other - target_score
        return torch.clamp(margin, min=-self.kappa)
    
