In [1]:
import os

# Later cells import `src` and open 'checkpoints' paths relative to the parent
# of the directory this notebook starts in, so move there first.
# os.chdir("..") is not idempotent: re-running this cell in the same kernel
# would climb one more level each time. The sentinel makes a re-run a no-op.
if not globals().get('_NOTEBOOK_CWD_MOVED', False):
    os.chdir("..")
    _NOTEBOOK_CWD_MOVED = True

In [2]:
import quantus
import torch
import numpy as np
from src import Trainable
from src.sverl import Shapley

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Checkpoint directories are named <env>-<agent>-<suffix>-final-<run>, where
# the suffix is 'ps' for the 'policy' entries and 'vs' for the 'value' entries.
EXPLANATION_SUFFIX = {'policy': 'ps', 'value': 'vs'}


def checkpoint_path(env, agent, kind, run, step=400):
    """Return the relative path of one saved checkpoint.

    Args:
        env: Environment name, e.g. 'breakout'.
        agent: Agent name, e.g. 'ppo'.
        kind: 'policy' or 'value' (mapped to the 'ps' / 'vs' directory suffix).
        run: Run number used in the directory name (1, 2 or 3 here).
        step: Checkpoint file stem; every entry in this notebook uses 400.

    Returns:
        A path such as ``checkpoints/<env>-<agent>-ps-final-1/400.pt`` built
        with the separator of the current OS. On Windows this is exactly the
        backslash-separated string that used to be hard-coded.

    Raises:
        KeyError: If ``kind`` is neither 'policy' nor 'value'.
    """
    run_dir = f'{env}-{agent}-{EXPLANATION_SUFFIX[kind]}-final-{run}'
    return os.path.join('checkpoints', run_dir, f'{step}.pt')


def checkpoint_loader(path):
    """Return a zero-argument callable that loads ``path`` when called.

    A factory is used instead of a lambda written inside the comprehension so
    that every loader keeps its own ``path`` (a lambda there would see only
    the last value of the loop variables).
    """
    return lambda: Trainable.load_checkpoint(path)


# MODELS[agent][env][kind] is a list of three lazy loaders (runs 1..3).
# Nothing is read from disk until a loader is called.
MODELS = {
    agent: {
        env: {
            kind: [
                checkpoint_loader(checkpoint_path(env, agent, kind, run))
                for run in (1, 2, 3)
            ]
            for kind in ('policy', 'value')
        }
        for env in ('breakout', 'pong', 'invaders')
    }
    for agent in ('ppo', 'ppg')
}

In [4]:
# Baseline values handed to the metrics as `perturb_baseline`.
# The evaluation loop picks one with MASKS[index], i.e. by position.
MASKS = [-0.5, 'mean', 0.5]

In [5]:
# Selectors for a single (agent, environment, explanation type, run) combination.
# NOTE(review): none of these are referenced by the cells visible in this
# notebook (the evaluation loop iterates over every combination itself) --
# confirm whether they are still needed.
AGENT: str = 'ppo'
ENV: str = 'breakout'
TYPE: str = 'policy'
INDEX: int = 0

In [6]:
def apply_metric(metric, model, x_batch, y_batch, explain_func, argmax=False, **_):
    """Run ``metric`` on one batch and return the first element of its result.

    Args:
        metric: Callable invoked with the keyword arguments ``model``,
            ``x_batch``, ``y_batch``, ``device`` and ``explain_func``; its
            return value must be indexable.
        model: Object exposing ``.target`` (forwarded as the metric's model)
            and ``.device``.
        x_batch: Inputs; moved to CPU and converted with ``.numpy()``.
        y_batch: Targets; moved to CPU and converted with ``.numpy()``.
        explain_func: Forwarded to ``metric`` unchanged.
        argmax: When true, ``y_batch`` is first reduced with
            ``argmax(dim=-1)``.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        ``metric(...)[0]``.
    """
    targets = y_batch.argmax(dim=-1) if argmax else y_batch
    call_kwargs = dict(
        model=model.target,
        x_batch=x_batch.cpu().numpy(),
        y_batch=targets.cpu().numpy(),
        device=model.device,
        explain_func=explain_func,
    )
    scores = metric(**call_kwargs)
    return scores[0]

def apply_metric_argmax(metric, model, x_batch, y_batch, explain_func, **_):
    """Same as ``apply_metric`` but always reduces ``y_batch`` with argmax.

    Extra keyword arguments are passed on to ``apply_metric``, which ignores
    them.
    """
    return apply_metric(
        metric=metric,
        model=model,
        x_batch=x_batch,
        y_batch=y_batch,
        explain_func=explain_func,
        argmax=True,
        **_,
    )

def apply_metric_special(metric, model, x_batch, y_batch, explain_func, mask):
    """Run a mask-aware ``metric`` on one batch and return its first element.

    Unlike ``apply_metric``, the targets are always ``y_batch.argmax(dim=-1)``
    and ``mask`` is forwarded to ``metric`` as a keyword argument.

    Args:
        metric: Callable invoked with the keyword arguments ``model``,
            ``x_batch``, ``y_batch``, ``device``, ``explain_func`` and
            ``mask``; its return value must be indexable.
        model: Object exposing ``.target`` and ``.device``.
        x_batch: Inputs; moved to CPU and converted with ``.numpy()``.
        y_batch: Model outputs; the argmax over the last dimension is used
            as the target.
        explain_func: Forwarded to ``metric`` unchanged.
        mask: Forwarded to ``metric`` unchanged.

    Returns:
        ``metric(...)[0]``.
    """
    network = model.target
    inputs_np = x_batch.cpu().numpy()
    targets_np = y_batch.argmax(dim=-1).cpu().numpy()
    outcome = metric(
        model=network,
        x_batch=inputs_np,
        y_batch=targets_np,
        device=model.device,
        explain_func=explain_func,
        mask=mask,
    )
    return outcome[0]

In [7]:
def _masked_metric(metric_name):
    """Build the evaluator for the quantus metric class called ``metric_name``.

    The returned callable creates a fresh metric instance on every call, using
    ``mask`` as ``perturb_baseline``, and immediately evaluates it. The class
    is looked up on ``quantus`` only when the evaluator runs.
    """
    def evaluate(model, x_batch, y_batch, device, explain_func, mask):
        metric_cls = getattr(quantus, metric_name)
        scorer = metric_cls(
            perturb_baseline=mask,
            abs=True,
            normalise=True,
            aggregate_func=np.mean,
            return_aggregate=True,
            disable_warnings=True,
        )
        return scorer(
            model=model,
            x_batch=x_batch,
            y_batch=y_batch,
            device=device,
            explain_func=explain_func,
        )
    return evaluate


# metrics[category][name] -> (metric callable, function used to apply it).
# Commented-out entries are disabled metrics kept in place for reference.
metrics = {
    'Faithfulness': {
        # 'FaithfulnessCorrelation': (
        #     lambda model, x_batch, y_batch, device, explain_func, mask : 
        #     quantus.FaithfulnessCorrelation(
        #         nr_runs=1000,
        #         subset_size=84,
        #         perturb_baseline=mask,
        #         perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
        #         similarity_func=quantus.similarity_func.cosine,
        #         abs=True,
        #         normalise=True,
        #         aggregate_func=np.mean,
        #         return_aggregate=True,
        #         disable_warnings=True,
        #     )(model=model,x_batch=x_batch,y_batch=y_batch,device=device,explain_func=explain_func),
        #     apply_metric_special
        # ),
        'PixelFlipping': (_masked_metric('PixelFlipping'), apply_metric_special),
        'Selectivity': (_masked_metric('Selectivity'), apply_metric_special),
    },
    # 'Robustness': {
    #     'AvgSensitivity': (
    #         quantus.AvgSensitivity(
    #             nr_samples=10,
    #             lower_bound=0.2,
    #             norm_numerator=quantus.norm_func.fro_norm,
    #             norm_denominator=quantus.norm_func.fro_norm,
    #             perturb_func=quantus.perturb_func.uniform_noise,
    #             similarity_func=quantus.similarity_func.difference,
    #             abs=True,
    #             normalise=True,
    #             aggregate_func=np.mean,
    #             return_aggregate=True,
    #             disable_warnings=True,
    #         ),
    #         apply_metric
    #     ),
    #     'MaxSensitivity': (
    #         quantus.MaxSensitivity(
    #             nr_samples=10,
    #             lower_bound=0.2,
    #             norm_numerator=quantus.norm_func.fro_norm,
    #             norm_denominator=quantus.norm_func.fro_norm,
    #             perturb_func=quantus.perturb_func.uniform_noise,
    #             similarity_func=quantus.similarity_func.difference,
    #             abs=True,
    #             normalise=True,
    #             aggregate_func=np.mean,
    #             return_aggregate=True,
    #             disable_warnings=True,
    #         ),
    #         apply_metric
    #     ),
    #     'LocalLipschitzEstimate': (
    #         quantus.LocalLipschitzEstimate(
    #             abs=True,
    #             normalise=True,
    #             aggregate_func=np.mean,
    #             return_aggregate=True,
    #             disable_warnings=True,
    #         ),
    #         apply_metric
    #     )
    # }
}

In [None]:
# Number of states requested from the sampler, and how many of them are scored.
N_SAMPLED_STATES = 5000
N_EVAL_STATES = 1


def make_explain_func(explainer):
    """Return the explain function handed to the metrics for ``explainer``.

    The returned function converts ``inputs`` to a tensor on the explainer's
    device, calls ``explainer.infer`` and sums the result over the last
    dimension, returning a NumPy array. Any extra keyword arguments are
    ignored (presumably the metric passes some along -- TODO confirm).

    The explainer is bound through this factory, and deliberately not through
    a keyword parameter of the returned function: a caller-supplied keyword
    with the same name would otherwise silently replace it.
    """
    def explain(inputs, **_):
        attributions = explainer.infer(torch.Tensor(inputs).to(explainer.device))
        return attributions.sum(dim=-1).cpu().numpy()
    return explain


# Maps (agent, env, explanation type, index, metric category, metric name) -> score.
results = {}
for env in ['breakout', 'pong', 'invaders']:
    # The first PPO policy model of each environment is loaded only to draw
    # the states that every model of that environment is evaluated on.
    x_models = MODELS['ppo'][env]['policy']  # zero-argument loaders, not models
    x_model: Shapley = x_models[0]()

    # Take the first batch without materialising any further ones.
    first_batch = next(iter(
        x_model.state_sampler.sample(N_SAMPLED_STATES, batch_size=N_SAMPLED_STATES)
    ))
    x_batch = first_batch.to(x_model.device)[:N_EVAL_STATES]

    for agent in ['ppo', 'ppg']:
        # `explanation_type` rather than `type`: at notebook scope the loop
        # variable would rebind the builtin `type` for every later cell.
        for explanation_type in ['policy', 'value']:
            for index in range(3):

                models = MODELS[agent][env][explanation_type]  # zero-argument loaders
                model: Shapley = models[index]()

                model.target.eval()

                with torch.no_grad():
                    y_batch = model.target(x_batch)

                get_shapley_values = make_explain_func(model)

                for metric_type in metrics:
                    for metric in metrics[metric_type]:
                        metric_method, apply = metrics[metric_type][metric]
                        # NOTE(review): the mask is chosen by the model's
                        # position in its list, so model k is always scored
                        # with MASKS[k] -- confirm this pairing is intended.
                        value = apply(metric_method, model, x_batch, y_batch, get_shapley_values, mask=MASKS[index])
                        results[(agent, env, explanation_type, index, metric_type, metric)] = value
                        print((agent, env, explanation_type, index, metric_type, metric), value)


  logger.warn(
  logger.warn(


In [None]:
# Sanity check: number of (agent, env, type, index, metric) entries collected.
# NOTE(review): the evaluation loop covers 3 envs x 2 agents x 2 types x 3 indices
# = 36 runs; with the two metrics currently enabled that gives 72 entries, so a
# saved output of 216 would correspond to six metrics per run -- re-run to confirm.
print(len(results))

216


In [None]:
import pandas as pd

In [None]:
# Flatten `results` into one row per entry: the six components of each key
# become the first six columns and the metric score becomes `value`.
df = pd.DataFrame(
    data=[(*run_key, score) for run_key, score in results.items()],
    columns=[
        'agent',
        'environment',
        'explanation_type',
        'method_index',
        'metric_type',
        'metric_name',
        'value',
    ],
)

In [None]:
# Show the results table.
# NOTE(review): the saved output of this cell has an 'Unnamed: 0' column and rows
# for metrics that are currently commented out, neither of which the cell above
# produces -- the output looks stale; re-run the notebook to refresh it.
df

Unnamed: 0,agent,environment,explanation_type,method_index,metric_type,metric_name,value
0,ppo,breakout,policy,0,Faithfulness,FaithfulnessCorrelation,0.437206
1,ppo,breakout,policy,0,Faithfulness,PixelFlipping,0.222820
2,ppo,breakout,policy,0,Faithfulness,Selectivity,0.284400
3,ppo,breakout,policy,0,Robustness,AvgSensitivity,0.246340
4,ppo,breakout,policy,0,Robustness,MaxSensitivity,0.255735
...,...,...,...,...,...,...,...
211,ppg,invaders,value,2,Faithfulness,PixelFlipping,1.000000
212,ppg,invaders,value,2,Faithfulness,Selectivity,1.000000
213,ppg,invaders,value,2,Robustness,AvgSensitivity,0.741953
214,ppg,invaders,value,2,Robustness,MaxSensitivity,0.745698


In [None]:
# Write the table without the row index. The path is relative to the current
# working directory, which the first cell changes with os.chdir("..").
df.to_csv('fastsverl_metrics2.csv', index=False)