# This is an adaptation of the external PyTorchRelevancePropagation implementation, modified so that it can be run interactively.

In [1]:
import argparse
from argparse import Namespace
import time
import pathlib

import torch
from torchvision.models import vgg16, VGG16_Weights

from src.data import get_data_loader
from src.lrp import LRPModel

from projects.per_image_lrp.visualize import plot_relevance_scores

In [2]:
def _synchronize(device: torch.device) -> None:
    """Wait for all queued CUDA work on ``device`` to finish (no-op for non-CUDA devices)."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def per_image_lrp(config: argparse.Namespace) -> None:
    """Compute and plot LRP heatmaps for the images placed in the input folder.

    Images have to be placed in their corresponding class folders.

    For every batch, the throughput of the relevance computation is printed
    in images per second ("FPS").

    Args:
        config: Argparse namespace object. This function reads ``device``
            ("gpu" selects CUDA when available, any other value selects the
            CPU) and ``top_k``. The whole object is also forwarded to
            ``get_data_loader`` and ``plot_relevance_scores``.

    """
    if config.device == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    data_loader = get_data_loader(config)

    model = vgg16(weights=VGG16_Weights.DEFAULT)
    model.to(device)
    # A freshly constructed nn.Module is in training mode; switch to eval so
    # VGG's dropout layers do not randomize the relevance scores. (Harmless if
    # LRPModel also does this.)
    model.eval()

    lrp_model = LRPModel(model=model, top_k=config.top_k)

    # Labels are not used: the relevance computation is unsupervised.
    for i, (x, _) in enumerate(data_loader):
        x = x.to(device)

        # CUDA kernels are launched asynchronously. Drain pending work before
        # starting the clock and again before stopping it, so the measured
        # interval covers the actual computation.
        _synchronize(device)
        t0 = time.perf_counter()
        r = lrp_model.forward(x)
        _synchronize(device)
        elapsed = time.perf_counter() - t0

        # Count images, not batches. Assumes dim 0 of x is the batch dimension.
        n_images = x.size(0)
        fps = n_images / elapsed if elapsed > 0 else float("inf")
        print(f"{fps:.2f} FPS")

        plot_relevance_scores(x=x, r=r, name=str(i), config=config)

In [3]:
# Stand-in for the command-line arguments of the original script, so the
# pipeline can be driven from this notebook.
config = Namespace(
    input_dir="./input/",  # images, one sub-folder per class
    output_dir="./output/",  # created below if missing
    batch_size=1,
    device="gpu",  # CUDA when available, otherwise CPU
    top_k=0.02,
    resize=0,
)
output_path = pathlib.Path(config.output_dir)
output_path.mkdir(exist_ok=True, parents=True)

In [4]:
# Run the full pipeline three times with a 3-second pause between runs.
# Presumably this is meant to show warm-up effects in the printed FPS.
for run_idx in range(3):
    if run_idx > 0:
        time.sleep(3)
    per_image_lrp(config=config)

3.41 FPS
9.09 FPS
9.17 FPS
9.71 FPS
20.83 FPS
9.52 FPS
21.28 FPS
9.71 FPS
3.20 FPS
9.17 FPS
9.80 FPS
9.80 FPS
21.28 FPS
9.62 FPS
21.74 FPS
9.71 FPS
6.06 FPS
20.41 FPS
9.17 FPS
9.52 FPS
9.52 FPS
9.52 FPS
9.62 FPS
9.80 FPS
