In [None]:
import os
import os.path as osp

import torch

try:
    import pixpnet
except ImportError:
    import sys

    sys.path.append('..')

    import pixpnet
finally:
    from pixpnet.lightning.lightning_data import LitData
    from pixpnet.protonets.evaluate import relevance_ordering_test
    from pixpnet.protonets.utils import load_config_and_best_model
    from pixpnet.protonets.prp.prp import prp_canonized_model

# To Run
Replace `logdir` with the path to your log directory (either absolute, or relative to the notebook). A template format is shown below.

This notebook generates data that can then be visualized in the `interpretability_evaluation.ipynb` notebook. It needs
to be run separately because the PRP codebase modifies the ProtoPartNN objects in place.

In [None]:
# Placeholder only -- point this at the run's log directory before executing.
logdir = (
    '/path/to/logs/protonet/'
    'dataset/protonet/timestamp'
)

In [None]:
# Restore the run's config together with its best checkpointed model.
config, model = load_config_and_best_model(logdir)

# Show the feature-extractor settings from the loaded config as a quick sanity check.
model_cfg = config.model
print(model_cfg.feature_extractor, model_cfg.feature_layer)

In [None]:
# Put the model in evaluation mode, then move it onto the GPU.
model.eval()
model = model.to('cuda')

In [None]:
# Build the PRP-canonized version of the underlying network (`model.model`).
# NOTE(review): `prp_model` is not referenced in the cells below; the call appears to be
# kept for its in-place modifications to `model.model` (the markdown at the top says the
# PRP codebase modifies the ProtoPartNN objects in place) -- confirm before removing it.
prp_model = prp_canonized_model(model.model, config)

In [None]:
# Build the data module from the run's config; its test_dataloader() feeds the
# relevance ordering test below.
# NOTE(review): num_workers=0 presumably keeps data loading in the main process -- confirm
# against LitData.
data = LitData(config, num_workers=0)
# presumably prepares the dataset splits before the dataloaders are requested -- verify
data.setup()

In [None]:
# Results of the relevance ordering test (ROT), keyed by attribution method.
rot_results = {}

# Keyword arguments forwarded to relevance_ordering_test; they are also encoded into
# the cache filename so the file name reflects the settings that produced it.
# NOTE: Reduce num_samples and/or prop_pixels to speed this up
rot_params = {
    'num_samples': 50,
    'normalized': False,
    'prop_pixels': 1.,
    'same_class': True,
    'zeros': False,  # if false, use a random image
    'seed': 4,  # for consistency between methods
}

rot_save_dir = osp.join(logdir, 'rot_data')
os.makedirs(rot_save_dir, exist_ok=True)


def _rot_param_token(value):
    """Return ``value`` in the form used inside the cache filename.

    Floats with no fractional part collapse to an integer string (``1.0 -> '1'``),
    other floats are fixed to five decimal places, and any other value is returned
    unchanged (the f-string that consumes it does the stringification).
    """
    if not isinstance(value, float):
        return value
    if value.is_integer():
        return str(int(value))
    return f'{value:.5f}'


# e.g. '<dataset>__num_samples-50__normalized-False__prop_pixels-1__...__seed-4'
save_basename = f'{config.dataset.name}' + ''.join(
    f'__{key}-{_rot_param_token(val)}' for key, val in rot_params.items()
)

# Run the relevance ordering test for each attribution method, or reuse cached results.
# Results are cached per method as a .pt file under rot_save_dir (the name encodes
# rot_params); delete that file to force a recomputation.
for method in [
    'prp',
]:
    if method != 'prp':
        print('You aren\'t running prp...make sure you know what you\'re doing...')

    print(f'Begin method = {method}')

    rot_save_path = osp.join(rot_save_dir, save_basename + f'__method-{method}.pt')

    if osp.exists(rot_save_path):
        # Cached from a previous run: load it and do NOT write it back out.
        print('Load results from', rot_save_path)
        rot_results[method] = torch.load(rot_save_path)
    else:
        cum_sims, cum_sims_agg = relevance_ordering_test(
            model=model.model,
            data=data.test_dataloader(),
            method=method,
            config=config,
            **rot_params,
        )
        rot_results[method] = {
            'cum_sims': cum_sims,
            'cum_sims_agg': cum_sims_agg,
        }

        # Persist only freshly computed results; re-saving a just-loaded file was
        # redundant and made the log claim a save that produced nothing new.
        print('Save to', rot_save_path)
        torch.save(rot_results[method], rot_save_path)