In [None]:
import os
import os.path as osp

import torch

try:
    import pixpnet
except ImportError:
    import sys

    sys.path.append('..')

    import pixpnet
finally:
    from pixpnet.data import get_datasets
    from pixpnet.lightning.lightning_data import LitData
    from pixpnet.protonets.evaluate import consistency
    from pixpnet.protonets.evaluate import stability
    from pixpnet.protonets.evaluate import relevance_ordering_test
    from pixpnet.protonets.utils import load_config_and_best_model
    from pixpnet.utils_torch import unravel_index

# To Run
Set `logdir` in the next cell to the path of your log directory (either an absolute path or one relative to this notebook). The template below shows the expected format.

In [None]:
# Log directory of the training run to evaluate. Later cells load the config and
# model from it and write the relevance-ordering outputs beneath it.
logdir = '/path/to/logs/protonet/dataset/protonet/timestamp'

In [None]:
# Load the run's config and model from `logdir`
# (presumably the best checkpoint, going by the helper's name - confirm).
config, model = load_config_and_best_model(logdir)
# Show the feature extractor and feature layer configured for this run.
print(config.model.feature_extractor, config.model.feature_layer)

In [None]:
# Switch the model to eval mode and move it to the GPU.
# NOTE(review): the device is hard-coded to 'cuda', so this cell needs a CUDA device.
model = model.eval().to('cuda')

In [None]:
# Data module used by the consistency and stability cells below.
# NOTE(review): the yield_* / part_annotations flags presumably make the dataset
# expose image ids, original shapes and part annotations - confirm against LitData.
data_parts = LitData(
    config,
    yield_img_id=True,
    yield_orig_shape=True,
    part_annotations=True,
    num_workers=0,  # due to notebook multiprocessing issues
)
data_parts.setup()

## Choose one of these options to run

In [None]:
# Passed as `method=` to consistency() and stability() in the next two cells.
# Keep exactly one of the two assignments active.
pixel_space_method = 'bbox'
# pixel_space_method = 'upsample'

In [None]:
# Consistency evaluation on the test split, using the part annotations and
# metadata held by `data_parts`. The result is unpacked as (soft, hard).
consistency_score_soft, consistency_score_hard = consistency(
    model=model.model,
    data=data_parts.test_dataloader(),
    config=config,
    parts=data_parts.test.df_parts,
    metadata=data_parts.test.data,
    method=pixel_space_method,
)
# One score per line.
print(
    f'consistency_score_soft={consistency_score_soft}',
    f'consistency_score_hard={consistency_score_hard}',
    sep='\n',
)

In [None]:
# Stability evaluation on the test split, with the same inputs as the
# consistency evaluation. The result is unpacked as (soft, plain) scores.
stability_score_soft, stability_score = stability(
    model=model.model,
    data=data_parts.test_dataloader(),
    config=config,
    parts=data_parts.test.df_parts,
    metadata=data_parts.test.data,
    method=pixel_space_method,
)
# One score per line; the padding keeps the '=' signs aligned.
print(
    f'stability_score_soft={stability_score_soft}',
    f'stability_score     ={stability_score}',
    sep='\n',
)

In [None]:
# Second data module, built without the part-annotation flags; its test loader
# feeds the relevance ordering test below.
data = LitData(config, num_workers=0)
data.setup()

In [None]:
# Directory handed to relevance_ordering_test as `savedir_for_viz`.
rot_viz_dir = osp.join(logdir, 'rot_data_viz')

In [None]:
# Relevance ordering test (ROT) results, keyed by method name.
rot_results = {}

# Keyword arguments forwarded unchanged to relevance_ordering_test.
# NOTE: Reduce num_samples and/or prop_pixels to speed this up
rot_params = {
    'num_samples': 50,
    'normalized': False,
    'prop_pixels': 1.,
    'same_class': True,
    'zeros': False,  # if false, use a random image
    'seed': 4,  # for consistency between methods
}

# Cached ROT results are stored under the run's log directory.
rot_save_dir = osp.join(logdir, 'rot_data')
os.makedirs(rot_save_dir, exist_ok=True)


def _param_tag(value):
    """Render one ROT parameter value for use in the cache file name.

    Floats holding an integral value are written as ints, other floats with
    five decimals; any non-float value is returned untouched.
    """
    if not isinstance(value, float):
        return value
    return str(int(value)) if value.is_integer() else f'{value:.5f}'


# File-name stem: the dataset name followed by one `__<name>-<value>` tag per parameter.
save_basename = f'{config.dataset.name}' + ''.join(
    f'__{name}-{_param_tag(value)}' for name, value in rot_params.items()
)

# Collect ROT results for every method, using the per-method cache file when present.
for method in ('rf', 'upscale', 'random', 'prp'):
    print(f'Begin method = {method}')

    rot_save_path = osp.join(rot_save_dir, f'{save_basename}__method-{method}.pt')
    print(rot_save_path)

    # A cached result exists on disk: reuse it instead of recomputing.
    if osp.exists(rot_save_path):
        print('Load results from', rot_save_path)
        rot_results[method] = torch.load(rot_save_path)
        continue

    # 'prp' is never computed in this notebook; it can only come from a cached file.
    if method == 'prp':
        print(f'WARNING: prp result is missing from {rot_save_path}! Ensure you '
              f'have gathered it with the companion interpretability_evaluation_prp.ipynb notebook.')
        continue

    # No cache yet: run the test, then persist the result for later runs.
    cum_sims, cum_sims_agg = relevance_ordering_test(
        model=model.model,
        data=data.test_dataloader(),
        method=method,
        savedir_for_viz=rot_viz_dir,
        **rot_params,
    )
    method_result = {
        'cum_sims': cum_sims,
        'cum_sims_agg': cum_sims_agg,
    }
    rot_results[method] = method_result
    print('Save to', rot_save_path)
    torch.save(method_result, rot_save_path)

In [None]:
% matplotlib inline

In [None]:
from matplotlib import rc
import seaborn as sns
import numpy as np
import pandas as pd

# Seaborn theme: 'talk' context with the 'ticks' style, fonts scaled to half size.
sns.set(
    context='talk',
    style='ticks',
    font_scale=0.5,
)

# Use a serif font family and put Times first in the *serif* list.
# The list must match the selected family: assigning Times to 'sans-serif'
# while the family is 'serif' has no effect on the rendered font.
rc('font', **{
    'family': 'serif',
    'serif': ['Times']
})
# Typeset figure text with LaTeX (requires a working LaTeX installation).
rc('text', usetex=True)

In [None]:
from pixpnet.data import get_metadata

# Dataset metadata for this run; `metadata.input_size` is used by the cells below.
metadata = get_metadata(config)

In [None]:
from pixpnet.symbolic.models import compute_rf_data

# Receptive-field data for the configured feature extractor at the dataset's
# input size; the first return value is not needed here.
_, rf_data = compute_rf_data(
    config.model.feature_extractor,
    metadata.input_size,
    metadata.input_size,
    num_classes=1,
)
# Entry for the feature layer this run uses.
rf_layer = rf_data[config.model.feature_layer]
# Length of every element of the layer's receptive-field array
# (presumably the number of input pixels each location sees - confirm).
rf_hcc_lens = list(map(len, rf_layer.flat))
im_size = metadata.input_size ** 2

# Mean element length expressed as a percentage of input_size ** 2.
mean_rf_pct = np.mean(rf_hcc_lens) / im_size * 100

In [None]:
# One long-format DataFrame per method; they are concatenated for plotting afterwards.
dfs = []

# NumPy 2.0 renamed `np.trapz` to `np.trapezoid` and deprecated the old name;
# use whichever one this NumPy version provides.
trapezoid = getattr(np, 'trapezoid', None) or np.trapz

for method, scores in rot_results.items():
    cum_sims_agg = scores['cum_sims_agg']
    # x-axis: entry index expressed as a percentage of input_size ** 2
    pcts = np.arange(len(cum_sims_agg)) / (metadata.input_size ** 2) * 100

    # area above the baseline similarity value
    # [0] is the baseline similarity value (all random) - used as min
    # [-1] is the original similarity value - used as max
    # it is possible for area to be negative or positive
    # NOTE(review): if [0] == [-1] the normalisation divides by zero.
    area = trapezoid(
        x=pcts / 100,
        y=(cum_sims_agg - cum_sims_agg[0]) / (cum_sims_agg[-1] - cum_sims_agg[0]),
    )

    # pct at which we hit/surpass the original similarity score
    past_orig_sim_idx = np.where(cum_sims_agg >= cum_sims_agg[-1])[0][0]
    pct_orig_sim_point = pcts[past_orig_sim_idx]

    print(f'{method} area = {area} | pct@orig. sim = {pct_orig_sim_point:.3f}%')

    # Raw string: `\%` is the LaTeX-escaped percent sign (usetex is enabled) and
    # is not a valid Python escape sequence in a normal string literal.
    df_m = pd.DataFrame({
        'Mean Prototype Similarity': cum_sims_agg,
        r'\% Pixels Added Back': pcts,
        'Method': method,
    })
    dfs.append(df_m)

# Plot mean prototype similarity against the percentage of pixels added back,
# one line per method. Skipped entirely when no results were collected.
if dfs:
    df = pd.concat(dfs, ignore_index=True)

    # Map internal method keys to the display names used in the legend.
    df['Method'] = df['Method'].replace({
        'rf': 'Ours',
        'upscale': 'Upsample',
        'random': 'Random',
        'prp': 'PRP',
    })

    # PRP is only added to the legend order when its results are present.
    hue_order = ['Random', 'Ours', 'Upsample']
    if ('PRP' == df['Method']).any():
        hue_order.append('PRP')
    # Raw string for the x column: `\%` is a LaTeX escape, not a Python one.
    g = sns.relplot(
        data=df,
        x=r'\% Pixels Added Back',
        y='Mean Prototype Similarity',
        hue='Method',
        hue_order=hue_order,
        kind='line',
        aspect=1.4,
    )
    # Dashed vertical line at the mean receptive-field percentage; the y-limits
    # are captured first and restored afterwards so the line does not rescale the axis.
    # NOTE(review): the line carries a label, but nothing here adds it to the
    # legend - confirm whether it is meant to be shown.
    ylim = g.ax.get_ylim()
    g.ax.plot([mean_rf_pct, mean_rf_pct], [ylim[0], ylim[1]], 'k--', label='Mean Receptive Field of Layer')
    g.ax.set_ylim(ylim)

    # Uncomment to save figure
    # (uses its own name so the ROT cache file stem `save_basename` is not overwritten)
    # !mkdir -p ../figures
    # fig_basename = f'../figures/rot_{config.model.feature_extractor}_{config.model.feature_layer}'
    # g.savefig(f'{fig_basename}.pdf')