In [1]:
from hf_support import SmoothGradInterpreter
from transformers import ViTForImageClassification, ViTImageProcessor

# Fine-tuned ViT pneumonia classifier, loaded from a local directory
# (presumably a download of https://huggingface.co/nickmuchi/vit-finetuned-chest-xray-pneumonia).
# The same directory supplies both the model weights and its image processor.
MODEL_DIR = "vit-finetuned-chest-xray-pneumonia/"
model = ViTForImageClassification.from_pretrained(MODEL_DIR)
processor = ViTImageProcessor.from_pretrained(MODEL_DIR)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from glob import glob
algo = SmoothGradInterpreter(model)

# Validation images from both classes: NORMAL first, then PNEUMONIA.
list_img_path = [
    *glob("./chest_xray_val-test/val/NORMAL/*.jpeg"),
    *glob("./chest_xray_val-test/val/PNEUMONIA/*.jpeg"),
]

# Raw SmoothGrad explanation for every image, keyed by its file path.
sg_raw_results = {}
for path in list_img_path:
    # visual=False presumably suppresses plotting; only the returned explanation is kept.
    sg_raw_results[path] = exp = algo.interpret(path, processor, visual=False)

100%|██████████| 50/50 [00:01<00:00, 35.80it/s]
100%|██████████| 50/50 [00:01<00:00, 34.35it/s]
100%|██████████| 50/50 [00:01<00:00, 32.63it/s]
100%|██████████| 50/50 [00:01<00:00, 36.30it/s]
100%|██████████| 50/50 [00:01<00:00, 38.43it/s]
100%|██████████| 50/50 [00:01<00:00, 35.45it/s]
100%|██████████| 50/50 [00:01<00:00, 36.72it/s]
100%|██████████| 50/50 [00:01<00:00, 36.04it/s]
100%|██████████| 50/50 [00:01<00:00, 34.23it/s]
100%|██████████| 50/50 [00:01<00:00, 32.50it/s]
100%|██████████| 50/50 [00:01<00:00, 32.14it/s]
100%|██████████| 50/50 [00:01<00:00, 34.65it/s]
100%|██████████| 50/50 [00:01<00:00, 38.24it/s]
100%|██████████| 50/50 [00:01<00:00, 37.88it/s]
100%|██████████| 50/50 [00:01<00:00, 37.90it/s]
100%|██████████| 50/50 [00:01<00:00, 38.78it/s]


In [3]:
from hf_support import IntGradInterpreter
from glob import glob
algo = IntGradInterpreter(model)

# Re-collect the same validation images (NORMAL first, then PNEUMONIA).
list_img_path = glob("./chest_xray_val-test/val/NORMAL/*.jpeg")
list_img_path = list_img_path + glob("./chest_xray_val-test/val/PNEUMONIA/*.jpeg")

# Raw Integrated-Gradients explanation for every image, keyed by its file path.
ig_raw_results = {}
for path in list_img_path:
    # visual=False presumably suppresses plotting; only the returned explanation is kept.
    exp = algo.interpret(path, processor, visual=False)
    ig_raw_results.update({path: exp})

100%|██████████| 50/50 [00:00<00:00, 50.12it/s]
100%|██████████| 50/50 [00:01<00:00, 48.63it/s]
100%|██████████| 50/50 [00:01<00:00, 47.57it/s]
100%|██████████| 50/50 [00:01<00:00, 49.52it/s]
100%|██████████| 50/50 [00:01<00:00, 48.61it/s]
100%|██████████| 50/50 [00:00<00:00, 50.06it/s]
100%|██████████| 50/50 [00:01<00:00, 49.17it/s]
100%|██████████| 50/50 [00:01<00:00, 49.92it/s]
100%|██████████| 50/50 [00:01<00:00, 49.82it/s]
100%|██████████| 50/50 [00:01<00:00, 48.83it/s]
100%|██████████| 50/50 [00:00<00:00, 50.01it/s]
100%|██████████| 50/50 [00:01<00:00, 48.76it/s]
100%|██████████| 50/50 [00:01<00:00, 49.52it/s]
100%|██████████| 50/50 [00:01<00:00, 49.36it/s]
100%|██████████| 50/50 [00:01<00:00, 49.64it/s]
100%|██████████| 50/50 [00:00<00:00, 50.07it/s]


In [4]:
from hf_support import Perturbation

In [6]:
perturbation = Perturbation(model)

# Perturbation-based evaluation of each SmoothGrad explanation, keyed by image path.
eval_sg_raw_results = {}
for path in list_img_path:
    # NOTE(review): takes element [0] of the interpreter output and averages over
    # axis 0 (presumably the channel axis) — TODO confirm the layout.
    exp = sg_raw_results[path][0].mean(0)
    eval_sg_raw_results[path] = eval_result = perturbation.evaluate(
        path, exp, limit_number_generated_samples=100
    )

In [8]:
perturbation = Perturbation(model)

# Perturbation-based evaluation of each Integrated-Gradients explanation, keyed by image path.
eval_ig_raw_results = {}
for path in list_img_path:
    # NOTE(review): takes element [0] of the interpreter output and averages over
    # axis 0 (presumably the channel axis) — TODO confirm the layout.
    exp = ig_raw_results[path][0].mean(0)
    eval_result = perturbation.evaluate(
        path,
        exp,
        limit_number_generated_samples=100,
    )
    eval_ig_raw_results.update({path: eval_result})

In [16]:
from tqdm import tqdm

# Original author's note: the raw `MoRF_probas` from `Perturbation` is not the
# desired score as-is and needs converting. The conversion applied below is
# `probas[0] - probas`, i.e. the drop relative to the curve's first entry.
def get_eval_results(raw_results, list_paths, return_lerf=False):
    """Average MoRF, LeRF and ABPC scores over a set of evaluated images.

    For each path, ``raw_results[path]`` must provide ``'MoRF_probas'`` and
    ``'LeRF_probas'`` sequences supporting ``[0]``, elementwise subtraction
    and ``.mean()`` (e.g. numpy arrays or torch tensors). Per image:

      * MoRF score = mean(p_MoRF[0] - p_MoRF)
      * LeRF score = mean(p_LeRF[0] - p_LeRF)
      * ABPC       = mean(p_LeRF - p_MoRF)

    The first entry of each curve is presumably the unperturbed prediction.

    Args:
        raw_results: mapping from image path to the dict returned by
            ``Perturbation.evaluate``.
        list_paths: iterable of paths to aggregate over; each must be a key of
            ``raw_results``.
        return_lerf: if True, also return the LeRF score as a third element.
            Defaults to False so existing two-value unpacking keeps working.

    Returns:
        ``(MoRF_score, ABPC_score)``, or ``(MoRF_score, ABPC_score, LeRF_score)``
        when ``return_lerf`` is True. Each value is the mean over ``list_paths``.

    Raises:
        ValueError: if ``list_paths`` is empty.
        KeyError: if a path is missing from ``raw_results``.
    """
    # Materialize once so generators work and len() is well defined.
    paths = list(list_paths)
    if not paths:
        raise ValueError("list_paths must contain at least one image path")

    # Only the caller's `raw_results` is read; no notebook globals (the previous
    # version indexed with a leaked global `path`).
    morf_total = 0.0
    lerf_total = 0.0
    abpc_total = 0.0
    for img_path in paths:
        probas = raw_results[img_path]
        morf = probas['MoRF_probas']
        lerf = probas['LeRF_probas']
        morf_total += (morf[0] - morf).mean()
        lerf_total += (lerf[0] - lerf).mean()
        abpc_total += (lerf - morf).mean()

    n_paths = len(paths)
    morf_score = morf_total / n_paths
    lerf_score = lerf_total / n_paths
    abpc_score = abpc_total / n_paths
    if return_lerf:
        return morf_score, abpc_score, lerf_score
    return morf_score, abpc_score

In [17]:
# SmoothGrad summary: (MoRF score, ABPC score)
get_eval_results(raw_results=eval_sg_raw_results, list_paths=list_img_path)

(0.3663406816194765, 0.15824285882990807)

In [18]:
# Integrated-Gradients summary: (MoRF score, ABPC score)
get_eval_results(raw_results=eval_ig_raw_results, list_paths=list_img_path)

(0.5070528648793697, 0.42603587871417403)