#### Include source package

In [None]:
# Switch to the project directory: this notebook presumably lives one level
# below the project root — TODO confirm if the notebook is moved.
%cd ..
# The working directory should now be the project root (../pdi), so that the
# relative paths used in later cells ('src', 'results/...') resolve correctly.

In [None]:
import sys
import os
# Make the project's `src` directory importable. The path is resolved against
# the current working directory, and it is only appended once so re-running
# this cell does not grow sys.path.
module_path = os.path.abspath('src')
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

#### Load preprocessed test data

In [None]:
from pdi.data.data_preparation import DataPreparation
from pdi.data.types import Split
from pdi.config import Config
from pdi.constants import PART_NAME_TO_TARGET_CODE, TARGET_CODE_TO_PART_NAME
import json

RUN_DIR_PATH = "results/attention_hyperparameter_tuning/kaon/run_23"
target_code = PART_NAME_TO_TARGET_CODE["kaon"]

# Read the run's saved configuration. An explicit UTF-8 encoding keeps decoding
# independent of the platform's default locale encoding.
with open(os.path.join(RUN_DIR_PATH, "config.json"), "r", encoding="utf-8") as f:
    config_data = json.load(f)
config = Config.from_dict(config_data)

# To explain on a different device, override it here, before the model is
# built from `config` in the next cell.
config.training.device = "cpu"

# Only the TEST split is prepared; `groups` maps each group id to its data.
data_prep = DataPreparation(config.data, config.sim_dataset_paths, config.seed)
groups = data_prep.get_prepared_data([Split.TEST])[Split.TEST]
print(groups.keys())

In [None]:
import torch
from pdi.models import build_model

# Rebuild the architecture for the same group ids as the data, then load the
# best checkpoint saved for this run.
model = build_model(config.model, group_ids=list(groups.keys()))
dirpath = os.path.join(RUN_DIR_PATH, "model_weights")
weights_path = os.path.join(dirpath, "best.pt")
model.load_state_dict(torch.load(weights_path, weights_only=True, map_location=config.training.device))
# `predict` sends its inputs to config.training.device, so the parameters must
# live there too; otherwise a non-CPU override would raise a device mismatch.
model.to(config.training.device)
# Switch layers such as Dropout/BatchNorm (if the model has any) to inference
# behaviour, so repeated explainer calls see deterministic predictions.
model.eval()

with open(os.path.join(dirpath, "metadata.json"), "r", encoding="utf-8") as metadata_file:
    metadata = json.load(metadata_file)
# Decision threshold stored alongside the weights at training time.
threshold = metadata["threshold"]

In [None]:
# Wrapper for the model: explainers pass arrays, not tensors.
def predict(input_data):
    """Run the global ``model`` on a batch of inputs and return a NumPy array.

    The input is converted to a tensor and moved to ``config.training.device``.
    The output is copied back to the CPU as a NumPy array, because the
    explainers work on arrays rather than tensors.

    Args:
        input_data: array-like batch accepted by ``torch.tensor`` (presumably a
            2-D array of shape (n_samples, n_features) — TODO confirm).

    Returns:
        numpy.ndarray holding the model output for the batch.
    """
    inputs = torch.tensor(input_data).to(config.training.device)
    # Explanation never needs gradients. no_grad avoids building (and then
    # discarding) an autograd graph on every one of the explainer's calls.
    with torch.no_grad():
        outputs = model(inputs)
    return outputs.cpu().numpy()

## Model explanation

In [None]:
from pdi.data.data_exploration import explain_model, plot_and_save_beeswarm
from pdi.data.constants import COLUMNS_FOR_TRAINING
from pdi.data.group_id_helpers import group_id_to_detectors_available
from pdi.data.types import InputTarget

batch_size = 16  # larger batches crash the kernel, so the data is explained in batches
batches = 50
hide_progress_bars = False

cols = COLUMNS_FOR_TRAINING

# Loop-invariant values, computed once rather than once per group.
save_dir = f"{RUN_DIR_PATH}/feature_importance"
particle_name = TARGET_CODE_TO_PART_NAME[target_code]
# Ensure the output directory exists up front. plot_and_save_beeswarm may or
# may not create it (not visible here); without this, a missing folder would
# only surface after the slow explanation of the first group has finished.
os.makedirs(save_dir, exist_ok=True)

for key, input_target_unstandardized in groups.items():
    group_input = input_target_unstandardized[InputTarget.INPUT]
    # Name each group after the detectors available for it.
    detectors = [d.name for d in group_id_to_detectors_available(key)]
    label = "_".join(detectors)
    print(label)

    result, data_count = explain_model(predict, group_input, batch_size, batches, hide_progress_bars)
    result.feature_names = cols

    # The leading "/" is kept as before; presumably plot_and_save_beeswarm
    # joins save_dir and file_name by plain concatenation — TODO confirm.
    file_name = f"/{label}"
    title = f"{label}: {particle_name}, entries: {data_count}"
    plot_and_save_beeswarm(result, save_dir, file_name, title)