In [None]:
import os
import sys

# Switch the working directory to the parent of the current one so that the
# relative paths used in later cells resolve from there. The target is
# relative, so executing this cell again moves up one more level; restart the
# kernel before re-running the notebook.
os.chdir('../')

In [None]:
import argparse
from pathlib import Path
import glob
from functools import partial

import numpy as np
import torch
from scipy.special import softmax
from tqdm.auto import tqdm

In [None]:
class ShapleyValues:
    '''Holds Shapley value estimates together with their standard deviations.'''
    def __init__(self, values, std):
        self.values, self.std = values, std

    def plot(self,
             feature_names=None,
             sort_features=True,
             max_features=np.inf,
             orientation='horizontal',
             error_bars=True,
             color='C0',
             title='Feature Importance',
             title_size=20,
             tick_size=16,
             tick_rotation=None,
             axis_label='',
             label_size=16,
             figsize=(10, 7),
             return_fig=False):
        '''
        Draw the stored Shapley values by delegating to plotting.plot.

        Every argument is forwarded positionally, in declaration order, after
        this object itself.

        Args:
          feature_names: list of feature names.
          sort_features: whether to sort features by their Shapley values.
          max_features: number of features to display.
          orientation: horizontal (default) or vertical.
          error_bars: whether to include standard deviation error bars.
          color: bar chart color.
          title: plot title.
          title_size: font size for title.
          tick_size: font size for feature names and numerical values.
          tick_rotation: tick rotation for feature names (vertical plots only).
          axis_label: label for the value axis.
          label_size: font size for label.
          figsize: figure size.
          return_fig: whether to return matplotlib figure object.

        Returns:
          Whatever plotting.plot returns.
        '''
        # NOTE(review): `plotting` is not imported in the cells visible here;
        # confirm it is brought into scope elsewhere before calling plot().
        forwarded = (
            feature_names, sort_features, max_features, orientation,
            error_bars, color, title, title_size, tick_size, tick_rotation,
            axis_label, label_size, figsize, return_fig)
        return plotting.plot(self, *forwarded)

def default_min_variance_samples(game):
    '''
    Return the default value for min_variance_samples.

    The `game` argument is accepted but not consulted; the default is a fixed
    constant.
    '''
    fixed_default = 5
    return fixed_default

def default_variance_batches(num_players, batch_size):
    '''
    Return the default value for variance_batches.

    The result is the smallest whole number of batches that together hold at
    least ten samples per player, which is intended to give enough samples
    for the A approximation to be non-singular.

    Args:
      num_players: number of players (features).
      batch_size: number of samples per batch.

    Returns:
      Number of batches as a Python int.
    '''
    samples_wanted = 10 * num_players
    batches_needed = np.ceil(samples_wanted / batch_size)
    return int(batches_needed)

def calculate_result(A, b, total):
    '''
    Calculate the regression coefficients.

    Solves the two linear systems A x = 1 and A x = b and combines the
    solutions so that the returned coefficients sum (over axis 0) to `total`.

    Args:
      A: square matrix with one row and column per player.
      b: right-hand side; either one value per player (1-D) or one row per
        player (2-D).
      total: value the coefficients must sum to along axis 0.

    Returns:
      Array of coefficients with the same shape as the solution of A x = b.

    Raises:
      ValueError: if A is singular. The underlying numpy LinAlgError is
        attached as the exception's cause.
    '''
    num_players = A.shape[1]

    # Give the all-ones right-hand side the same rank as b so that A_inv_one
    # broadcasts against A_inv_vec below.
    if len(b.shape) == 2:
        ones = np.ones((num_players, 1))
    else:
        ones = np.ones(num_players)

    try:
        A_inv_one = np.linalg.solve(A, ones)
        A_inv_vec = np.linalg.solve(A, b)
    except np.linalg.LinAlgError as err:
        # Chain explicitly so the original numpy error stays visible.
        raise ValueError('singular matrix inversion. Consider using larger '
                         'variance_batches') from err

    # Shift the unconstrained solution so its sum over players equals total.
    excess = np.sum(A_inv_vec, axis=0, keepdims=True) - total
    values = A_inv_vec - A_inv_one * excess / np.sum(A_inv_one)

    return values

def ShapleyRegressionPrecomputed(
                      grand_value,
                      null_value,
                      model_outputs,
                      masks,
                      num_players,
                      batch_size=512,
                      detect_convergence=True,
                      thresh=0.01,
                      n_samples=None,
                      paired_sampling=True,
                      return_all=False,
                      min_variance_samples=None,
                      variance_batches=None,
                      bar=True,
                      verbose=False):
    '''
    Estimate Shapley values by regression from precomputed subset evaluations.

    Instead of sampling subsets and querying a model, this consumes `masks`
    and `model_outputs` in consecutive slices of `batch_size` rows and keeps
    running means of the regression statistics A and b.

    Args:
      grand_value: value of the grand coalition.
      null_value: value of the null coalition.
      model_outputs: precomputed outputs, one row per row of masks.
      masks: 2-D array of subsets, one row per subset and one column per
        player.
      num_players: number of players.
      batch_size: number of rows consumed per iteration.
      detect_convergence: whether to stop once the convergence ratio drops
        below thresh. Forced on when n_samples is None.
      thresh: convergence threshold, must lie strictly between 0 and 1.
      n_samples: upper bound on the number of rows to consume; None means
        "until convergence or until masks is exhausted".
      paired_sampling: accepted for interface compatibility; not used here.
      return_all: whether to also return a dictionary of intermediate
        results.
      min_variance_samples: minimum number of intermediate estimates before
        a variance is computed (default 5).
      variance_batches: number of batches aggregated into one intermediate
        estimate (default from default_variance_batches).
      bar: whether to display a tqdm progress bar.
      verbose: whether to print progress messages.

    Returns:
      ShapleyValues(values, std), or a tuple of that object and a tracking
      dictionary with keys 'values', 'std', 'iters' (and 'N_est' when
      convergence detection is on) if return_all is true.

    Raises:
      ValueError: if masks is empty, if n_samples is not positive, or if the
        regression matrix is singular.
    '''
    # Verify arguments.
    if min_variance_samples is None:
        min_variance_samples = 5
    else:
        assert isinstance(min_variance_samples, int)
        assert min_variance_samples > 1

    if variance_batches is None:
        variance_batches = default_variance_batches(num_players, batch_size)
    else:
        assert isinstance(variance_batches, int)
        assert variance_batches >= 1

    if len(masks) == 0:
        raise ValueError('masks is empty; at least one evaluated subset is '
                         'required')

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Difference between grand and null coalitions (regression constraint).
    total = grand_value - null_value

    # Set up bar.
    n_loops = int(np.ceil(n_samples / batch_size))
    if n_loops < 1:
        raise ValueError('n_samples must be positive')
    if bar:
        if detect_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size)

    # Setup.
    n = 0
    b = 0
    A = 0
    estimate_list = []

    # For variance estimation.
    A_sample_list = []
    b_sample_list = []

    # For tracking progress.
    var = np.nan * np.ones(num_players)
    if return_all:
        N_list = []
        std_list = []
        val_list = []

    # Begin consuming the precomputed samples.
    for it in range(n_loops):
        # Take the next slice of subsets and their outputs.
        start = batch_size * it
        stop = start + batch_size
        S = masks[start:stop]
        game_S = model_outputs[start:stop]

        # Per-sample outer products of the masks.
        A_sample = np.matmul(S[:, :, np.newaxis].astype(float),
                             S[:, np.newaxis, :].astype(float))

        # Per-sample mask weighted by the output relative to the null value.
        b_sample = (S.astype(float).T
                    * (game_S - null_value)[:, np.newaxis].T).T

        # Welford's algorithm. Count the rows actually present in this slice:
        # the last slice is shorter when len(masks) is not a multiple of
        # batch_size, and counting a full batch would bias the running means.
        n += len(S)
        b += np.sum(b_sample - b, axis=0) / n
        A += np.sum(A_sample - A, axis=0) / n

        # Calculate progress.
        values = calculate_result(A, b, total)
        A_sample_list.append(A_sample)
        b_sample_list.append(b_sample)
        if len(A_sample_list) == variance_batches:
            # Aggregate samples for intermediate estimate.
            A_sample = np.concatenate(A_sample_list, axis=0).mean(axis=0)
            b_sample = np.concatenate(b_sample_list, axis=0).mean(axis=0)
            A_sample_list = []
            b_sample_list = []

            # Add new estimate.
            estimate_list.append(calculate_result(A_sample, b_sample, total))

            # Estimate current var.
            if len(estimate_list) >= min_variance_samples:
                var = np.array(estimate_list).var(axis=0)

        # Convergence ratio (NaN until a variance estimate exists).
        std = np.sqrt(var * variance_batches / (it + 1))
        ratio = np.max(
            np.max(std, axis=0) / (values.max(axis=0) - values.min(axis=0)))

        # Print progress message.
        if verbose:
            if detect_convergence:
                print(f'StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})')
            else:
                print(f'StdDev Ratio = {ratio:.4f}')

        # Check for convergence.
        if detect_convergence:
            if ratio < thresh:
                if verbose:
                    print('Detected convergence')

                # Skip bar ahead.
                if bar:
                    bar.n = bar.total
                    bar.refresh()
                break

        # Forecast number of iterations required.
        if detect_convergence:
            N_est = (it + 1) * (ratio / thresh) ** 2
            if bar and not np.isnan(N_est):
                bar.n = np.around((it + 1) / N_est, 4)
                bar.refresh()
        elif bar:
            bar.update(len(S))

        # Save intermediate quantities.
        if return_all:
            val_list.append(values)
            std_list.append(std)
            if detect_convergence:
                N_list.append(N_est)

        # Stop once every precomputed sample has been consumed.
        if stop >= len(masks):
            break

    # The final ratio is diagnostic output only, so honour the verbose flag.
    if verbose:
        print(ratio)

    # Return results.
    if return_all:
        # Dictionary for progress tracking. Cap at n so a short final slice
        # is not reported as a full batch.
        iters = np.minimum((np.arange(it + 1) + 1) * batch_size, n)
        tracking_dict = {
            'values': val_list,
            'std': std_list,
            'iters': iters}
        if detect_convergence:
            tracking_dict['N_est'] = N_list

        return ShapleyValues(values, std), tracking_dict
    else:
        return ShapleyValues(values, std)

In [None]:
def read_eval_results(path):
    '''
    Load the saved evaluation results found in one sample directory.

    The directory is expected to contain `grand_null.pt` plus a contiguous run
    of `mask_eval_{begin}_{end}.pt` chunk files with a constant step; an
    existing `shapley_output.pt` is ignored. Chunks are concatenated in
    increasing order of their begin index.

    Args:
      path: directory holding the `.pt` files.

    Returns:
      Dictionary with keys "grand", "null" and "subsets", each mapping to a
      dictionary with "logits" and "masks". Entry 0 of `grand_null.pt` is
      taken as the grand coalition and entry 1 as the null coalition.

    Raises:
      FileNotFoundError: if `grand_null.pt` or every chunk file is missing.
      AssertionError: if `.pt` files remain that are not part of the
        contiguous chunk sequence.
    '''
    sample_dir = Path(path)
    # Path(...).name keeps the output-file filter independent of the path
    # separator.
    file_set = {p for p in glob.glob(str(sample_dir / "*.pt"))
                if Path(p).name != "shapley_output.pt"}

    path_grand_null = str(sample_dir / "grand_null.pt")
    if path_grand_null not in file_set:
        raise FileNotFoundError(f"missing {path_grand_null}")
    file_set.remove(path_grand_null)

    chunk_files = [p for p in file_set
                   if Path(p).name.startswith("mask_eval_")]
    if not chunk_files:
        raise FileNotFoundError(f"no mask_eval_*.pt files in {sample_dir}")

    # Compare begin indices numerically: as strings "1024" sorts before "512",
    # which would start the walk below at the wrong chunk.
    first_chunk = min(chunk_files,
                      key=lambda p: int(Path(p).stem.split("_")[-2]))
    first_parts = Path(first_chunk).stem.split("_")
    begin_idx = int(first_parts[-2])
    end_idx = int(first_parts[-1])
    step_size = end_idx - begin_idx

    # Walk the chunks in order until the sequence ends.
    idx = begin_idx
    path_eval_list = []
    while True:
        path_eval = str(sample_dir / f"mask_eval_{idx}_{idx+step_size}.pt")
        if path_eval not in file_set:
            break
        file_set.remove(path_eval)
        path_eval_list.append(path_eval)
        idx += step_size

    # Anything left is a gap in the sequence or an unrelated file.
    assert len(file_set) == 0, f"unexpected files: {sorted(file_set)}"

    # NOTE: torch.load unpickles its input; only use on trusted directories.
    grand_null = torch.load(path_grand_null)
    eval_list = [torch.load(path_eval) for path_eval in path_eval_list]

    grand_logits = grand_null["logits"][0]
    grand_masks = grand_null["masks"][0]
    null_logits = grand_null["logits"][1]
    null_masks = grand_null["masks"][1]

    eval_logits = np.concatenate(
        [eval_value["logits"] for eval_value in eval_list], axis=0)
    eval_masks = np.concatenate(
        [eval_value["masks"] for eval_value in eval_list], axis=0)

    return {
        "grand": {"logits": grand_logits, "masks": grand_masks},
        "null": {"logits": null_logits, "masks": null_masks},
        "subsets": {"logits": eval_logits, "masks": eval_masks},
    }

In [None]:
def get_args(argv=None):
    """Parse the command line arguments.

    Args:
      argv: optional list of argument strings. When None (the default),
        argparse reads the arguments from sys.argv, as before.

    Returns:
      argparse.Namespace with batch_size, input_path, normalize_function and
      num_players.
    """
    parser = argparse.ArgumentParser(description='Process some inputs.')

    # Argument for batch size without default
    parser.add_argument('--batch_size', type=int, required=True,
                        help='The batch size for processing.')

    # Argument for input path without default
    parser.add_argument('--input_path', type=str, required=True,
                        help='Path to the input directory.')

    # Argument for normalization function
    parser.add_argument('--normalize_function', type=str, choices=['softmax'], required=True,
                        help='The normalization function to be used. Options: softmax.')

    parser.add_argument('--num_players', type=int, required=True,
                        help='The number of players')

    return parser.parse_args(argv)


if __name__ == "__main__":
    # For every sample directory under input_path: load the precomputed
    # evaluations, normalise the logits, run the Shapley regression and store
    # the tracking dictionary next to the inputs.

    args = get_args()

    # Accessing the arguments
    batch_size = args.batch_size
    input_path = args.input_path
    if args.normalize_function == "softmax":
        normalize_function = softmax
    else:
        raise ValueError("Unsupported normalization function")
    num_players = args.num_players

    # Sample directories are the entries whose name starts with a digit.
    sample_list = glob.glob(str(Path(input_path)/"[0-9]*"))

    pbar = tqdm(sample_list)

    for sample_path in pbar:
        eval_results = read_eval_results(path=sample_path)

        # Grand and null logits are single vectors: normalise over axis 0.
        grand_value = eval_results["grand"]["logits"]
        if len(grand_value.shape) == 1:
            grand_value = normalize_function(grand_value, axis=0)
        else:
            raise RuntimeError(f"Not supported grand shape {grand_value.shape}")

        null_value = eval_results["null"]["logits"]
        if len(null_value.shape) == 1:
            null_value = normalize_function(null_value, axis=0)
        else:
            raise RuntimeError(f"Not supported null shape {null_value.shape}")

        # Subset logits hold one row per subset, so they must be 2-D to be
        # normalised along axis 1 (and to satisfy the shape[1] check below).
        subsets_output = eval_results["subsets"]["logits"]
        if len(subsets_output.shape) == 2:
            subsets_output = normalize_function(subsets_output, axis=1)
        else:
            raise RuntimeError(f"Not supported subset outputs shape {subsets_output.shape}")

        subsets = eval_results["subsets"]["masks"]

        assert subsets_output.shape[1] == grand_value.shape[0] == null_value.shape[0]
        assert subsets.shape[1] == num_players

        shapley_values, tracking_dict = ShapleyRegressionPrecomputed(
            grand_value=grand_value,
            null_value=null_value,
            model_outputs=subsets_output,
            masks=subsets,
            batch_size=batch_size,
            num_players=num_players,
            variance_batches=2,
            min_variance_samples=2,
            return_all=True,
            bar=False,
        )

        torch.save(obj=tracking_dict, f=str(Path(sample_path)/"shapley_output.pt"))

        # Recompute the final convergence ratio from the returned estimate;
        # the regression's own `ratio` is local to that function.
        value_range = (shapley_values.values.max(axis=0)
                       - shapley_values.values.min(axis=0))
        ratio = float(np.max(np.max(shapley_values.std, axis=0) / value_range))
        pbar.set_postfix(ratio=ratio, refresh=True)

In [None]:
# Walk every sample directory, load its saved Shapley output and report the
# directories whose final recorded sample count differs from the previous one.
path_parent="logs/vitbase_imagenette_surrogate_eval_test/extract_output/test/*"
print(len(glob.glob(path_parent)))
iter_prev=None
for idx, path in enumerate(glob.glob(path_parent)):
    loaded=torch.load(f"{path}/shapley_output.pt")

    if iter_prev is None:
        # First directory: show its full list of sample counts as a reference.
        print(loaded["iters"])
    elif iter_prev!=loaded["iters"][-1]:
        print(idx, iter_prev, loaded["iters"])

    iter_prev=loaded["iters"][-1]
    

In [None]:
# `loaded` is the last file read by the loop in the previous cell.
# Presumably "values" holds one intermediate estimate per recorded iteration
# (see the tracking dictionary saved by the main block) - TODO confirm.
len(loaded["values"])

In [None]:
# List the entries stored in the last loaded shapley_output.pt.
loaded.keys()

In [None]:
# Show the "iters" entry of the last loaded shapley_output.pt.
loaded["iters"]