#### Add the source package to the import path

In [None]:
# switch to the project directory
# NOTE: `%cd ..` is relative, so re-running this cell moves up another level.
%cd ..
# working directory should be ../pdi, i.e. the project root, so that relative
# paths used below ('src', 'models/...', 'reports/...') resolve correctly

In [None]:
import os
import sys

# Resolve `src` against the current working directory (the project root once
# the directory-change cell has run) and register it at most once, so that
# `pdi.*` imports work without duplicating sys.path entries on re-runs.
module_path = os.path.abspath('src')

if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

#### Extract the file name from INPUT_PATH (used to name output folders)

In [None]:
from pdi.data.config import INPUT_PATH

# Drop the directory part and the final extension, e.g. "/data/run42.csv" -> "run42".
csv_name = os.path.splitext(os.path.basename(INPUT_PATH))[0]
print(csv_name)

#### Load data

In [None]:
from pdi.data.preparation import FeatureSetPreparation
from pdi.data.types import Split

# Only the test split is loaded and explained in this notebook.
prep = FeatureSetPreparation()
split = Split.TEST
# NOTE(review): this calls a non-public helper and discards its return value;
# confirm whether a failed load is reported through that value and should be checked.
prep._try_load_preprocessed_data([split])

In [None]:
# Convert the loaded split into a dict of per-group data (presumably DataFrames).
# The explanation loop decodes its keys with detector_unmask(), so they appear
# to be detector-availability masks -- TODO confirm.
groups = prep.data_to_df_dict(split)

#### Model info

In [None]:
import torch

from pdi.constants import PARTICLES_DICT
from pdi.models import AttentionModel
from pdi.data.config import MODEL_NAME

# One trained model per particle is loaded from models/Proposed/<MODEL_NAME>/.
model_dir = MODEL_NAME
model_load_dir = f"models/Proposed/{model_dir}"
model_class = AttentionModel
# Prefer the GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA (a hard-coded "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import torch

def import_model(load_path, model_class, device):
    """Load a saved checkpoint and return the model ready for inference.

    The checkpoint must be a dict holding ``"model_args"`` (positional
    constructor arguments for *model_class*) and ``"state_dict"``.

    Args:
        load_path: Path of the file written with ``torch.save``.
        model_class: ``torch.nn.Module`` subclass to instantiate.
        device: Device the model is moved to, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        The model with loaded weights, on *device*, in evaluation mode.
    """
    # Deserialize onto the CPU first so GPU-saved checkpoints load anywhere;
    # the module itself is moved to `device` afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    saved_model = torch.load(load_path, map_location=torch.device("cpu"))
    model = model_class(*saved_model["model_args"]).to(device)
    model.load_state_dict(saved_model["state_dict"])
    # The model is only used for inference here; eval() switches layers such
    # as Dropout/BatchNorm to inference behaviour so predictions (and thus
    # the explanations built on them) are deterministic.
    model.eval()
    return model

In [None]:
# wrapper for model, explainers don't allow passing tensors
def predict(input_data):
    """Evaluate the current global ``model`` on array input and return a NumPy array.

    Explainers work with arrays rather than tensors, so the input is converted
    to a tensor on ``device`` and the output is converted back. ``model`` and
    ``device`` are looked up as globals at call time, so reassigning ``model``
    changes which network this wrapper evaluates.
    """
    new_in = torch.tensor(input_data).to(device)
    # Gradients are never needed for explanation; skipping autograd graph
    # construction avoids wasting memory on every call.
    with torch.no_grad():
        output = model(new_in)
    return output.cpu().detach().numpy()

## Model explanation

In [None]:
from pdi.data.data_exploration import explain_model, plot_and_save_beeswarm
from pdi.data.detector_helpers import detector_unmask
from pdi.constants import TARGET_CODES

# Larger numbers of entries crashed the kernel, so the data is split into
# batches of `batch_size` entries. `batches` is passed to explain_model as well
# (presumably the number of batches to explain -- confirm in explain_model).
batch_size = 16
batches = 50
hide_progress_bars = False

cols = prep.load_columns()

# Particle codes to explain (presumably PDG IDs); leave the list empty to
# explain every code in TARGET_CODES.
particles_to_explain = [211, 2212, 321]

if not particles_to_explain:
    particles_to_explain = TARGET_CODES
else:
    # Report requested codes that cannot be explained instead of dropping them silently.
    skipped_codes = [p for p in particles_to_explain if p not in TARGET_CODES]
    if skipped_codes:
        print(f"Skipping codes not in TARGET_CODES: {skipped_codes}")
    particles_to_explain = [p for p in particles_to_explain if p in TARGET_CODES]

for target_code in particles_to_explain:
    particle_name = PARTICLES_DICT[target_code]
    print(particle_name)

    # predict() reads the global `model`, so it must be (re)assigned at top
    # level for each particle before explaining.
    load_path = os.path.join(model_load_dir, f"{particle_name}.pt")
    model = import_model(load_path, model_class, device)

    # Output directory depends only on the particle, not on the group.
    save_dir = f"reports/figures/feature_importance/{model_dir}/{csv_name}/{particle_name}"

    for key, group in groups.items():
        # Name each plot after the detectors encoded in this group's key.
        label = "_".join(d.name for d in detector_unmask(key))
        print(label)

        result, data_count = explain_model(predict, group, batch_size, batches, hide_progress_bars)
        result.feature_names = cols

        title = f"{particle_name}, entries: {data_count}"
        plot_and_save_beeswarm(result, save_dir, label, title)