# In [1]:
import matplotlib.pyplot as plt
import numpy as np
import saliency.core as saliency
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ExplainerBase:
    """Common base for explainers: keeps hold of a model and a data handle.

    Both methods below are placeholders that raise ``NotImplementedError``;
    subclasses are expected to provide the real implementations.
    """

    def __init__(self, model_interface, data_interface):
        """Store the two handles as attributes without inspecting them."""
        self.model_interface, self.data_interface = model_interface, data_interface

    def generate_counterfactuals(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def opposing_class_constraint(self):
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError
        
class VAE(nn.Module):
    """Small fully connected variational autoencoder.

    The encoder goes ``input_dim -> 128 -> 64 -> 2 * latent_dim`` and the
    decoder goes ``latent_dim -> 64 -> 128 -> input_dim``, with ReLU between
    the linear layers.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # The final encoder layer emits the mean and the log-variance together.
        encoder_layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2 * latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps`` standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(reconstructed_x, mu, logvar)``. The encoder output is viewed as
            ``(-1, 2, latent_dim)``, so ``mu`` and ``logvar`` always carry a
            leading batch dimension.
        """
        encoded = self.encoder(x).view(-1, 2, self.latent_dim)
        mu, logvar = encoded[:, 0, :], encoded[:, 1, :]
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar


        
class QUCE(ExplainerBase):
    def __init__(self, data_interface, model_interface, input_dim, latent_dim):
        """Create the explainer together with its VAE and the VAE's Adam optimizer.

        Args:
            data_interface: Forwarded as the first positional argument to the base class.
            model_interface: Forwarded as the second positional argument to the base class.
            input_dim: Number of input features of the VAE.
            latent_dim: Size of the VAE latent space.
        """
        # NOTE(review): ExplainerBase.__init__ names its parameters
        # (model_interface, data_interface), so with this positional call
        # self.model_interface ends up holding the first argument given here.
        # Left untouched because positional callers may rely on it - confirm.
        super().__init__(data_interface, model_interface)
        autoencoder = VAE(input_dim, latent_dim)
        self.vae = autoencoder
        self.vae_optimizer = optim.Adam(autoencoder.parameters(), lr=0.1)
        self.latent_dim = latent_dim
    
    def compute_integrated_gradients_uncertainty(self, instances_list, num_steps=1000):
        integrated_gradients = torch.zeros_like(torch.tensor(instances_list[0]))  # Convert to PyTorch tensor

        start_instance_tensor = torch.tensor(instances_list[0], dtype=torch.float32, requires_grad=True)
        end_instance_tensor = torch.tensor(instances_list[1], dtype=torch.float32, requires_grad=True)

            # Linear interpolation between start and end instances
        path_interp = torch.linspace(0, 1, num_steps).view(-1, 1)
        interp_instances = (1 - path_interp) * start_instance_tensor + path_interp * end_instance_tensor

            # Compute model output at each interpolated point
        outputs = self.model_interface(interp_instances)

            # Sum the outputs to make it a scalar (assuming the model returns a tensor)
        outputs_sum = torch.sum(outputs)

            # Compute gradients at each interpolated point
        gradients = torch.autograd.grad(outputs_sum, interp_instances, retain_graph=True)[0]

            # Clip gradients to a reasonable range
            #gradients = torch.clamp(gradients, min=-1.0, max=1.0)
            # Integrate gradients along the path
        integrated_gradients = torch.sum(gradients, dim=0)/len(gradients)*(end_instance_tensor.detach().numpy() - start_instance_tensor.detach().numpy())
        return integrated_gradients
    
    
    def compute_integrated_gradients_interpolation(self, instances_list, num_steps=100):
        integrated_gradients = torch.zeros_like(torch.tensor(instances_list[0]))  # Convert to PyTorch tensor

        for i in range(len(instances_list) - 1):
            start_instance_tensor = torch.tensor(instances_list[i], dtype=torch.float32, requires_grad=True)
            end_instance_tensor = torch.tensor(instances_list[i + 1], dtype=torch.float32, requires_grad=True)

            # Linear interpolation between start and end instances
            path_interp = torch.linspace(0, 1, num_steps).view(-1, 1)
            interp_instances = (1 - path_interp) * start_instance_tensor + path_interp * end_instance_tensor

            # Compute model output at each interpolated point
            outputs = self.model_interface(interp_instances)

            # Sum the outputs to make it a scalar (assuming the model returns a tensor)
            outputs_sum = torch.sum(outputs)

            # Compute gradients at each interpolated point
            gradients = torch.autograd.grad(outputs_sum, interp_instances, retain_graph=True)[0]

            # Clip gradients to a reasonable range
            #gradients = torch.clamp(gradients, min=-1.0, max=1.0)
            # Integrate gradients along the path
            integrated_gradients += torch.sum(gradients, dim=0)/len(gradients)*(end_instance_tensor - start_instance_tensor)

        return integrated_gradients
    
    def compute_blur_integrated_gradients(self, instances_list, num_steps=100, blur_sigma=0.15):
        integrated_gradients = torch.zeros_like(torch.tensor(instances_list[0]))  # Convert to PyTorch tensor

        for i in range(len(instances_list) - 1):
            start_instance_tensor = torch.tensor(instances_list[i], dtype=torch.float32, requires_grad=True)
            end_instance_tensor = torch.tensor(instances_list[i + 1], dtype=torch.float32, requires_grad=True)

            # Linear interpolation between start and end instances
            path_interp = torch.linspace(0, 1, num_steps).view(-1, 1)
            interp_instances = (1 - path_interp) * start_instance_tensor + path_interp * end_instance_tensor

            # Compute model output at each interpolated point
            outputs = self.model_interface(interp_instances)

            # Sum the outputs to make it a scalar (assuming the model returns a tensor)
            outputs_sum = torch.sum(outputs)

            # Compute gradients at each interpolated point
            gradients = torch.autograd.grad(outputs_sum, interp_instances, retain_graph=True)[0]

            # Apply blur to gradients
            blurred_gradients = self.apply_blur(gradients, sigma=blur_sigma)

            # Integrate blurred gradients along the path
            integrated_gradients += torch.sum(blurred_gradients, dim=0) / len(blurred_gradients) * (
                    end_instance_tensor - start_instance_tensor)

        return integrated_gradients
    
    def apply_blur(gradients, sigma=0.15):
        # Apply blur to gradients
        blurred_gradients = torch.nn.functional.gaussian_blur(gradients, kernel_size=1, sigma=sigma)

        return blurred_gradients
    def generate_counterfactuals(self, query_instance, feature_mask=None ,time_constant_index=None, time_constant_diff=None, ts_dist_weight=0.5, reconstruction_weight=1, proba_weight=1, optimizer=None, target_prob_threshold=None, lr=0.01, max_iter=1000):
        query_instance = torch.FloatTensor(query_instance)
        num_features = len(query_instance)
        cf_initialize = query_instance.clone().detach()  # Start with the query instance
        cf_instances_list = []
        integrated_gradients = torch.zeros_like(cf_initialize)
        
        if feature_mask is None:
            feature_mask = torch.ones(num_features)
        else:
            feature_mask = torch.FloatTensor(feature_mask)

        if optimizer == "adam":
            optimizer = torch.optim.Adam([cf_initialize], lr, betas=(0.9, 0.999))  # Adjust beta1 and beta2 as needed
        else:
            optimizer = torch.optim.SGD([cf_initialize], lr, momentum=0.9)  # Adjust the momentum parameter as needed
        
        if target_prob_threshold == None:
            if self.model_interface(query_instance) >= 0.5: 
                target_prob_threshold = 0.05  # Set the desired threshold
            else:
                target_prob_threshold = 0.95
         
        target_class_prob = torch.FloatTensor([target_prob_threshold])

        self.vae.eval()

        for i in range(max_iter):
            cf_initialize.requires_grad = True
            optimizer.zero_grad()
            cf_prob = self.model_interface(cf_initialize)
    
    # Pass cf_initialize through the VAE to get reconstructed_x, mu, and logvar
            reconstructed_x, mu, logvar = self.vae(cf_initialize)
            
    # Calculate the loss
            vae_reconstruction_loss = self.vae_loss(query_instance, reconstructed_x, mu, logvar)
            loss = self.opposing_class_constraint(cf_prob, target_class_prob, query_instance
                                                  , ts_dist_weight, proba_weight, cf_initialize) + reconstruction_weight*vae_reconstruction_loss
            loss.backward()
            
            optimizer.step()
            cf_instances_list.append(cf_initialize.clone().detach().numpy())
            # Convert cf_instances_list to PyTorch tensor
        # Compute the absolute difference vector
        diff_vector = reconstructed_x

    # Run integrated gradients for perturbed_plus
        perturbed_plus = cf_initialize + torch.abs(diff_vector)
        explanation_plus = self.compute_integrated_gradients_uncertainty([cf_initialize, perturbed_plus])

    # Run integrated gradients for perturbed_minus
        perturbed_minus = cf_initialize - torch.abs(diff_vector)
        explanation_minus = self.compute_integrated_gradients_uncertainty([cf_initialize, perturbed_minus])

        integrated_gradients = self.compute_integrated_gradients_interpolation(cf_instances_list)
        return cf_initialize, integrated_gradients, explanation_plus, explanation_minus, cf_instances_list
    
    
    def exepcted_attribution(self, query_instance, p=5, **kwargs):
        """
        Perform iterative counterfactual generation and calculate the mean feature attribution.

        Parameters:
        - query_instance: The input instance for which counterfactuals are generated.
        - p: The number of iterations.
        - **kwargs: Additional arguments to pass to the generate_counterfactuals method.

        Returns:
        - mean_attributions: The mean feature attributions over p iterations.
        """
        attributions_sum = torch.zeros_like(torch.tensor(query_instance), dtype=torch.float32)

        for _ in range(p):
            _, attributions, _, _, _ = self.generate_counterfactuals(query_instance, optimizer='adam', **kwargs)
            attributions_sum += attributions

        mean_attributions = attributions_sum / p

        return mean_attributions
    
    def plot_explanations(self, integrated_gradients, explanation_plus, explanation_minus, feature_names=None):
        """Draw a bar chart of the attributions with the two uncertainty bars stacked on top.

        Args:
            integrated_gradients: Attribution per feature (tensor or array-like).
            explanation_plus: Attribution of the +epsilon perturbation per feature.
            explanation_minus: Attribution of the -epsilon perturbation per feature.
            feature_names: Optional tick labels, one per feature.

        All three value arguments are flattened to one dimension, so both
        ``(F,)`` and ``(1, F)`` shaped inputs are plotted as ``F`` bars.
        """
        def _as_flat_array(values):
            # Tensors may still be attached to a graph; detach before converting.
            if isinstance(values, torch.Tensor):
                values = values.detach().cpu().numpy()
            return np.asarray(values).reshape(-1)

        attributions = _as_flat_array(integrated_gradients)
        plus = _as_flat_array(explanation_plus)
        minus = _as_flat_array(explanation_minus)
        feature_indices = np.arange(attributions.size)

        # Main attribution bars.
        plt.bar(feature_indices, attributions, label='Integrated Path Attribution', color='grey', alpha=1)

        # Uncertainty bars start at the tip of the attribution bar (bottom=...)
        # and sit at the same x position as the bar and tick of their feature.
        plt.bar(
            feature_indices,
            plus,
            bottom=attributions,
            color='blue',
            label='Uncertainty (+ε)',
            alpha=1
        )
        plt.bar(
            feature_indices,
            minus,
            bottom=attributions,
            color='red',
            label='Uncertainty (-ε)',
            alpha=1
        )
        plt.xticks(ticks=feature_indices, labels=feature_names, rotation='vertical')

        plt.legend()
        plt.xlabel('Feature Index')
        plt.ylabel('Attribution Value')
        plt.title('Path-Based Gradients with Uncertainty (ε)')
        plt.show()

    
    def train_vae(self, query_instance, num_epochs):
        self.vae.train()
        for epoch in range(num_epochs):
            instance = torch.FloatTensor(query_instance)
            reconstructed_x, mu, logvar = self.vae(instance)
            loss = self.vae_loss(instance, reconstructed_x, mu, logvar)
            self.vae_optimizer.zero_grad()
            loss.backward()
            self.vae_optimizer.step()

    def vae_loss(self, x, reconstructed_x, mu, logvar):
        # Calculate the VAE loss (reconstruction loss + KL divergence)
        reconstruction_loss = nn.MSELoss()(reconstructed_x, x)  # Change the loss function for non-binary data
        kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return reconstruction_loss + kl_divergence
    
    def reconstruction_loss(self, x, reconstructed_x):
        # Calculate the VAE loss (reconstruction loss + KL divergence)
        reconstruction_loss = nn.MSELoss()(reconstructed_x, x)  # Change the loss function for non-binary data
        #kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return reconstruction_loss 

    
    def opposing_class_constraint(self, cf_prob, target_class_prob, query_instance, ts_dist_weight, proba_weight, cf_initialize):
    # Calculate a loss to ensure that the generated instance belongs to the opposing class and; 
    # the distance between query instance and the counterfactual is minimal

        cf_prob = F.sigmoid(cf_prob)

        target_class_loss = F.mse_loss(cf_prob, target_class_prob)
   
        distances = np.sum(np.linalg.norm(query_instance.detach().numpy() - cf_initialize.detach().numpy(), axis=1))

        # Define the distance decay rate - latter points in time have a greater weighting
        #alpha = 0.1 + 0.1 * np.arange(len(query_instances.detach().numpy()))

        # Apply the exponential kernel to assign weights
        #weights = np.exp(-alpha * distances)

        # Normalize the weights (optional)
        #normalized_weights = weights / np.sum(weights)

        # Compute the weighted result
        #weighted_distance_list = np.dot(normalized_weights, query_instances.detach().numpy())
        #weighted_distance = np.sum(weighted_distance_list)
        return proba_weight*target_class_loss + ts_dist_weight*distances
        
        