### Import Packages and Dependencies

In [None]:
import os
import torch
import glob
import numpy as np
import fnmatch
import sys
from pathlib import Path

# Presumably the notebook lives one level below the project root (the root is
# where the `src.*` packages imported below and the relative "data/..." paths
# resolve). Move there and make the root importable.
# Guarded so that re-running this cell does not keep climbing the directory
# tree or stack duplicate sys.path entries.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path(os.getcwd()).parents[0])
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from src.data.vesselmnist3d_datamodule import VesselMNSIT3DDataModule
from src.data.organmnist3d_datamodule import OrganMNSIT3DDataModule
from src.data.adrenalmnist3d_datamodule import AdrenalMNSIT3DDataModule

### Load Saliency Maps and a Data Batch

In [None]:
# Saliency-map archives, sorted by modification time in ascending order
# (oldest first, newest last).
saliency_dir = os.getcwd() + "/data/saliency_maps/volume/"
saliency_files = sorted(glob.glob(saliency_dir + "*"), key=os.path.getmtime)
print(saliency_files)
if not saliency_files:
    raise FileNotFoundError("No saliency maps found in " + saliency_dir)

# Ascending mtime order => the newest archive is the LAST entry.
file = saliency_files[-1]

# np.load returns a file-backed archive; the context manager closes the handle
# once the three arrays have been read into memory.
with np.load(file) as array:
    data = [array["arr_0"], array["arr_1"], array["arr_2"]]

# Choose the datamodule and class names from the dataset named in the file.
# Match on the base name only so that directory names (e.g. the project path)
# cannot accidentally trigger a dataset pattern.
file_name = os.path.basename(file)
dataset_dir = os.getcwd() + "/data/datasets/"
if fnmatch.fnmatch(file_name, "*organ*"):
    datamodule = OrganMNSIT3DDataModule(data_dir=dataset_dir, batch_size=20)
    classes = [
        "liver",
        "kidney-right",
        "kidney-left",
        "femur-right",
        "femur-left",
        "bladder",
        "heart",
        "lung-right",
        "lung-left",
        "spleen",
        "pancreas",
    ]
elif fnmatch.fnmatch(file_name, "*vessel*"):
    datamodule = VesselMNSIT3DDataModule(data_dir=dataset_dir, batch_size=20)
    classes = ["vessel", "aneurysm"]
elif fnmatch.fnmatch(file_name, "*adrenal*"):
    datamodule = AdrenalMNSIT3DDataModule(data_dir=dataset_dir, batch_size=20)
    classes = ["normal", "adrenal mass"]
else:
    raise ValueError(
        "Cannot infer the dataset (organ/vessel/adrenal) from file name: " + file_name
    )

dataloader = datamodule.dataloader()

# Only the first batch is needed for visualisation.
with torch.no_grad():
    x_batch, y_batch = next(iter(dataloader))

In [None]:
# Display names of the XAI methods, used as subplot titles. Presumably listed
# in the same order as the maps along axis 1 of each saliency array — verify
# against the export script. The last three ("... Attention") are only shown
# for model index 2 (Simple3DFormer).
methods = [
    "Occlusion",
    "LIME (Mask)",
    "Kernel SHAP (Mask)",
    "Saliency",
    "Input x Gradient",
    "Guided Backprop",
    "GradCAM",
    "ScoreCAM",
    "GradCAM++",
    "IG",
    "EG",
    "Deeplift",
    "Deeplift SHAP",
    "LRP",
    "Raw Attention",
    "Rollout Attention",
    "LRP Attention",
]
models = ["3DResNet18", "3DEfficientNetb0", "Simple3DFormer"]
n = 1  # index of the sample within the loaded batch to visualise
model = 2  # index into `models` and into `data` selecting whose maps to plot
# Move the channel axis last; assumes x_batch is (batch, channel, D, H, W) - TODO confirm.
img = x_batch[n].detach().numpy().transpose(1, 2, 3, 0)

# The final subplot shows the input volume, titled with its ground-truth class.
methods.append("Original Class: " + str(classes[int(y_batch[n])]).title())

### Saliency Maps for all XAI Methods

In [None]:
# Colorscale for the raw input volume: white at the lowest intensity, red from
# the midpoint upward.
cmap = [[0, "white"], [0.5, "red"], [1, "red"]]

# Subplot titles. Model index 2 gets every entry, including the three attention
# maps. The other models get the first 14 attribution methods plus the
# "Original Class" label, which is the entry appended last to `methods`.
titles = methods if model == 2 else methods[:14] + [methods[-1]]


def NormalizeData(data):
    """Min-max scale ``data`` to approximately [0, 1].

    A tiny epsilon (1e-11) is added to the value range so that a constant
    input maps to all zeros instead of dividing by zero; consequently the
    maximum maps to just under 1.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    return (data - lowest) / (value_range + 1e-11)


# Grid coordinates for a 28x28x28 volume: 28 evenly spaced points from 0 to 28
# (inclusive) along each axis, used as x/y/z positions of the volume traces.
axis_points = slice(0, 28, 28j)
X, Y, Z = np.mgrid[axis_points, axis_points, axis_points]

# Shared colorbar styling for the attribution-map traces.
colorbar = {
    "tickfont": {"family": "Helvetica", "size": 18},
    "outlinewidth": 0,
    "thickness": 20,
    "len": 0.8,
}

# Layout: two full rows of seven 3D scenes, then a partially filled third row.
# For model index 2 the third row holds four scenes (three attention maps plus
# the input volume); otherwise it holds only the input volume.
n_cols = 7
n_scenes_last_row = 4 if model == 2 else 1
subplot_specs = [[{"type": "scene"} for _ in range(n_cols)] for _ in range(2)]
subplot_specs.append(
    [{"type": "scene"} for _ in range(n_scenes_last_row)]
    + [None] * (n_cols - n_scenes_last_row)
)

fig = make_subplots(
    rows=3,
    cols=n_cols,
    specs=subplot_specs,
    subplot_titles=titles,
    vertical_spacing=0.05,
)

# Flattened grid coordinates, reused by every volume trace below.
x_flat, y_flat, z_flat = X.flatten(), Y.flatten(), Z.flatten()

# Rows 1-2: the first 14 attribution maps, seven per row, each rendered as
# |attribution| min-max normalised, with its own colorbar.
for method_idx in range(14):
    row_idx, col_idx = divmod(method_idx, 7)
    attribution = np.abs(data[model][n, method_idx, :, :, :, :]).flatten()
    fig.add_trace(
        go.Volume(
            x=x_flat,
            y=y_flat,
            z=z_flat,
            value=NormalizeData(attribution),
            isomin=0.1,
            isomax=1.0,
            opacity=0.1,  # low opacity lets inner iso-surfaces show through
            surface_count=21,  # many iso-surfaces for a smooth volume rendering
            colorscale="viridis",
            colorbar=colorbar,
        ),
        row=row_idx + 1,
        col=col_idx + 1,
    )

# Row 3 (model index 2 only): maps 14-16, drawn without colorbars and with a
# slightly higher isomin threshold.
if model == 2:
    for offset in range(3):
        attention = np.abs(data[model][n, 14 + offset, :, :, :, :]).flatten()
        fig.add_trace(
            go.Volume(
                x=x_flat,
                y=y_flat,
                z=z_flat,
                value=NormalizeData(attention),
                isomin=0.2,
                isomax=1.0,
                opacity=0.1,
                surface_count=21,
                colorscale="viridis",
                showscale=False,
            ),
            row=3,
            col=offset + 1,
        )

# Input volume (first channel), placed after the row-3 maps when they exist.
fig.add_trace(
    go.Volume(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        value=NormalizeData(img[:, :, :, 0].flatten()),
        isomin=0.02,
        isomax=1.0,
        opacity=0.95,  # nearly opaque: the input is shown as a solid shape
        surface_count=21,
        showscale=False,
        colorscale=cmap,
    ),
    row=3,
    col=4 if model == 2 else 1,
)

# Hide tick labels and axis titles on every 3D scene.
hidden_axes = {f"{axis}axis_showticklabels": False for axis in "xyz"}
hidden_axes.update(xaxis_title=" ", yaxis_title=" ", zaxis_title="")
fig.update_scenes(**hidden_axes)

# Enlarge the subplot titles (make_subplots stores them as annotations).
fig.update_annotations(font=dict(family="Helvetica", size=22))

figure_title = f"<b>3D Attribution and Attention for {models[model]} Model</b>"
fig.update_layout(
    title={
        "text": figure_title,
        "font": {"family": "Helvetica", "size": 28},
        "x": 0.03,
    },
    height=1200,
    width=2500,
    font={"family": "Helvetica", "color": "#000000"},
)

# Optional static export. A DataModule instance normally has no `__name__`,
# so the file name would need `type(datamodule).__name__` instead:
# fig.write_image("data/figures/3D_" + type(datamodule).__name__ + "_" + str(model) + "_Importance.png", scale=2)
fig.show()

### GIF Animation

In [None]:
# Starting camera eye position; the frame loop rotates it about the z-axis.
x_eye, y_eye, z_eye = -1.25, 2, 0.5


def rotate_z(x, y, z, theta):
    """Rotate the point (x, y, z) by ``theta`` radians about the z-axis.

    The xy-components are treated as the complex number ``x + iy`` and
    multiplied by ``exp(i*theta)``; ``z`` is passed through unchanged.

    Returns:
        Tuple ``(x_rotated, y_rotated, z)``.
    """
    rotated = np.exp(1j * theta) * (x + 1j * y)
    return np.real(rotated), np.imag(rotated), z


# Render one PNG per camera angle, sweeping the eye position just under one
# full turn about the z-axis (theta = 0, 0.2, ..., 6.2 rad).
frame_dir = "data/figures/gif"
os.makedirs(frame_dir, exist_ok=True)  # make sure the output directory exists
for frame_idx, t in enumerate(np.arange(0, 6.26, 0.2)):
    xe, ye, ze = rotate_z(x_eye, y_eye, z_eye, -t)

    fig.update_scenes(camera_eye=dict(x=xe, y=ye, z=ze))

    # A zero-padded frame index keeps the files in rotation order when sorted by
    # name, independent of float formatting of `t` or filesystem mtime resolution.
    fig.write_image(f"{frame_dir}/frame_{frame_idx:03d}_.png", scale=1)

In [None]:
from PIL import Image

# Assemble the exported frames into a looping GIF.
# Frames are ordered by file name rather than mtime: on filesystems with coarse
# timestamp resolution several frames can share an mtime, making the order
# arbitrary.
frame_paths = sorted(glob.glob("data/figures/gif/frame_*"))
if not frame_paths:
    raise FileNotFoundError(
        "No frames matching data/figures/gif/frame_* - run the frame-export cell first."
    )

frames = [Image.open(path) for path in frame_paths]
try:
    # Distinct names so the 3D input volume `img` used by the plotting cell is
    # not overwritten with a PIL image.
    first_frame, *other_frames = frames
    first_frame.save(
        fp="data/figures/gif/3D_" + str(model) + "_md.gif",
        format="GIF",
        append_images=other_frames,
        save_all=True,
        duration=120,
        loop=0,
    )
finally:
    for frame in frames:
        frame.close()  # release the file handles held open by Image.open

In [None]:
# Delete the intermediate PNG frames now that the GIF has been written.
stale_frames = glob.glob("data/figures/gif/frame_*")
for frame_path in stale_frames:
    os.remove(frame_path)