# Imports

In [None]:
import os
import sys

# Make the `src.*` package importable from this notebook.
# NOTE(review): dirname(abspath("../")) is the *grandparent* of the current
# working directory, so this assumes the notebook runs two levels below the
# directory that contains `src/` -- confirm.
# abspath() is idempotent, so a single call suffices. The path is appended only
# when missing, which avoids the duplicate sys.path entry the old code created.
_project_root = os.path.dirname(os.path.abspath("../"))
if _project_root not in sys.path:
    sys.path.append(_project_root)

from src.explainibility.visualization import dislay_all_explainibility, display_sae_features, display_grad_cam_explanattions
from src.explainibility.sae_explainibility import explain_model_with_sae, sae_statistics, get_minimal_tree_from_sae_model
from src.model_architecture.cnn_clasifier.cnn_clasifier import CnnKneeClassifier
import torch
from src.model_training.training_helpers.knee_datasets import KneeScans3DDataset
import torchio as tio
from pathlib import Path

# Model and Dataset

In [None]:
# Build the 3-class knee classifier (single-channel input) and load trained weights.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the model is moved to the selected device later, via `model.to(device)`.
model = CnnKneeClassifier(num_classes=3, input_channels=1)
model.load_state_dict(
    torch.load(
        "/home/mikic202/semestr_9/knee_scaner/models/basic_clasifier_model_1766343254.9682245.pth",
        map_location="cpu",
    )
)
# Explanations should reflect inference-time behaviour, so switch to eval mode.
# This matters if the architecture contains dropout or normalisation layers.
model.eval()

# Every scan is resized to a 64x64x64 volume before it reaches the model.
dataset_transform = tio.transforms.Compose(
    [
        tio.transforms.Resize((64, 64, 64)),
    ]
)

# `datset_filepath` (sic) is the keyword name expected by KneeScans3DDataset.
dataset = KneeScans3DDataset(
    datset_filepath="/media/mikic202/Nowy1/uczelnia/semestr_9/SIWY/datasets/kneemri",
    transform=dataset_transform,
)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL = model.to(device)
# Layer whose activations are explained by Grad-CAM and by the SAE.
LAYER_TO_EXPLAIN = model.last_feature
# Flattened size of LAYER_TO_EXPLAIN's output.
# NOTE(review): hard-coded as 64 channels x 16^3 voxels. This presumably
# matches the 64^3 input resize used for the dataset; update it if the resize
# or the architecture changes.
LAYER_SIZE = 64 * 16 * 16 * 16

DATASET = dataset
# Fetch the first sample once. Each DATASET[0] call presumably re-reads and
# re-transforms the scan, so indexing it twice would do that work twice.
_first_sample = DATASET[0]
EXAMPLE = _first_sample[0]
EXAMPLE_CLASS = _first_sample[1]

# Basic Gradient Methods

In [None]:
# Presumably renders the basic gradient-based attributions (per this section's
# heading) for EXAMPLE with respect to its label EXAMPLE_CLASS.
dislay_all_explainibility(
    MODEL,
    EXAMPLE,
    EXAMPLE_CLASS,
    device,
)

In [None]:
# Same attributions as above, with in_slices=True.
# Presumably this renders them slice by slice -- confirm in the helper.
dislay_all_explainibility(
    MODEL,
    EXAMPLE,
    EXAMPLE_CLASS,
    device,
    in_slices=True,
)

# Grad-CAM

In [None]:
# Grad-CAM explanation of EXAMPLE at LAYER_TO_EXPLAIN for class EXAMPLE_CLASS.
display_grad_cam_explanattions(
    MODEL,
    LAYER_TO_EXPLAIN,
    EXAMPLE,
    EXAMPLE_CLASS,
    device,
)

In [None]:
# Grad-CAM again, with in_slices=True.
# Presumably a per-slice rendering -- confirm in the helper.
display_grad_cam_explanattions(
    MODEL,
    LAYER_TO_EXPLAIN,
    EXAMPLE,
    EXAMPLE_CLASS,
    device,
    in_slices=True,
)

# Sparse Autoencoder (SAE)

In [None]:
# Fit the sparse autoencoder for LAYER_TO_EXPLAIN over DATASET.
# The hyperparameters are passed straight through to the helper.
sae_model = explain_model_with_sae(
    MODEL,
    DATASET,
    LAYER_TO_EXPLAIN,
    LAYER_SIZE,
    hidden_size=8 * 4096,
    max_number_of_hidden_features=4 * 4096,
    num_of_epochs=15,
    learning_rate=0.007,
)

# Gather SAE feature statistics; the helper returns three values.
# The names suggest per-class popularity orderings and counts.
(
    sae_features,
    feature_popularity_order_per_class,
    feature_counts_per_class,
) = sae_statistics(sae_model, MODEL, LAYER_TO_EXPLAIN, DATASET)

# Save a TorchScript-compiled copy of the SAE to ~/sae_explainibility_model.pt.
torch.jit.save(
    torch.jit.script(sae_model),
    str(Path.home().joinpath("sae_explainibility_model.pt")),
)


In [None]:
# Build the minimal tree(s) from the trained SAE and print them.
# Presumably these are decision trees over SAE features -- confirm in the helper.
generated_trees = get_minimal_tree_from_sae_model(
    sae_model,
    MODEL,
    LAYER_TO_EXPLAIN,
    DATASET,
)
print(generated_trees)

In [None]:
# Show the first ten entries of each class's feature-popularity ordering.
print("Most popular features per class:")
for class_label in feature_popularity_order_per_class:
    feature_order = feature_popularity_order_per_class[class_label]
    print(f"Class {class_label}: Features {feature_order[:10]}")

In [None]:
# Path("") is the relative path "." (the current working directory).
# NOTE(review): presumably the output location for the rendered features -- confirm.
display_sae_features(
    sae_features,
    Path(""),
)