### Import Packages and Dependencies

In [None]:
import os
import torch
import glob
import numpy as np
import fnmatch
import sys
from pathlib import Path
import plotly.graph_objects as go
import numpy as np

# Move one directory up (the project root, which is expected to hold the
# ``src`` package imported below and the ``data/`` folder used by the
# relative paths in later cells) and make it importable.
_notebook_dir = Path.cwd()
os.chdir(_notebook_dir.parents[0])
sys.path.append(os.getcwd())

from src.data.modelnet40_datamodule import ModelNet40DataModule
from src.data.shapenet_datamodule import ShapeNetDataModule
from src.data.coma_datamodule import CoMADataModule

### Import Datasets

In [None]:
# List the saved saliency-map archives oldest-first (by modification time);
# the index below picks which one (and therefore which dataset) to plot.
saliency_files = sorted(
    glob.glob(os.getcwd() + "/data/saliency_maps/point_cloud/*"),
    key=os.path.getmtime,
)
print(saliency_files)

file = saliency_files[0]  # selects dataset

# np.load on an .npz archive returns an NpzFile that keeps the file open;
# read the three arrays eagerly and close it straight away.
# NOTE(review): arr_0..arr_2 are presumably one entry per model (indexed by
# ``model`` in later cells) — confirm against the script that wrote them.
with np.load(file) as array:
    data = [array["arr_0"], array["arr_1"], array["arr_2"]]

# Infer the dataset from the archive's file name only. Matching the full
# absolute path could be fooled by a directory name that happens to contain
# e.g. "coma".
dataset_name = os.path.basename(file)
datasets_dir = os.getcwd() + "/data/datasets/"

if fnmatch.fnmatch(dataset_name, "*modelnet40*"):
    datamodule = ModelNet40DataModule(data_dir=datasets_dir, batch_size=20)
    classes = [
        "airplane",
        "bathtub",
        "bed",
        "bench",
        "bookshelf",
        "bottle",
        "bowl",
        "car",
        "chair",
        "cone",
        "cup",
        "curtain",
        "desk",
        "door",
        "dresser",
        "flower_pot",
        "glass_box",
        "guitar",
        "keyboard",
        "lamp",
        "laptop",
        "mantel",
        "monitor",
        "night_stand",
        "person",
        "piano",
        "plant",
        "radio",
        "range_hood",
        "sink",
        "sofa",
        "stairs",
        "stool",
        "table",
        "tent",
        "toilet",
        "tv_stand",
        "vase",
        "wardrobe",
        "xbox",
    ]
elif fnmatch.fnmatch(dataset_name, "*shapenet*"):
    datamodule = ShapeNetDataModule(data_dir=datasets_dir, batch_size=20)
    classes = [
        "Airplane",
        "Bag",
        "Cap",
        "Car",
        "Chair",
        "Earphone",
        "Guitar",
        "Knife",
        "Lamp",
        "Laptop",
        "Motorbike",
        "Mug",
        "Pistol",
        "Rocket",
        "Skateboard",
        "Table",
    ]
elif fnmatch.fnmatch(dataset_name, "*coma*"):
    datamodule = CoMADataModule(data_dir=datasets_dir, batch_size=20)
    classes = [
        "bareteeth",
        "cheeks_in",
        "eyebrow",
        "high_smile",
        "lips_back",
        "lips_up",
        "mouth_down",
        "mouth_extreme",
        "mouth_middle",
        "mouth_open",
        "mouth_side",
        "mouth_up",
    ]
else:
    # Fail here with a clear message instead of a NameError on ``datamodule``.
    raise ValueError(f"Cannot infer the dataset from saliency file name: {file!r}")


dataloader = datamodule.dataloader()

# Only the first batch is needed for plotting.
with torch.no_grad():
    x_batch, y_batch = next(iter(dataloader))

In [None]:
# XAI methods, in the order used to index axis 1 of the saliency arrays
# (``data[model][n, method_idx, ...]`` in the plotting cells).
methods = [
    "Occlusion",
    "LIME",
    "Kernel SHAP",
    "Saliency",
    "Input x Gradient",
    "Guided Backprop",
    "IG",
    "EG",
    "Deeplift",
    "Deeplift SHAP",
    "LRP",
    "Raw Attention",
    "Rollout Attention",
    "LRP Attention",
]
models = ["PointNet", "DGCNN", "Pointcloud Transformer"]
n = 0  # select observation
model = 2  # select model (index into ``models`` and ``data``)

# ``.detach().numpy()`` shares storage with ``x_batch``. Copy it so that the
# axis swap below does not mutate the batch in place; otherwise every
# re-run of this cell would toggle the swap back.
img = x_batch[n].detach().numpy().copy()
# Swap the 2nd and 3rd coordinate rows (rows 0/1/2 are plotted as x/y/z).
img[[1, 2]] = img[[2, 1]]

titles = ["Original Class: " + str(classes[int(y_batch[n])]).title()] + methods

### X-Axis Translation Plot for the Continuity Metric

In [None]:
# PC Continuity x_transverse transformation: shift the cloud by -1.5 along x,
# then collapse every point that leaves the [-1, 1] plotting cube to the origin.
img2 = (img.T + np.array([-1.5, 0, 0], dtype="float32")).T  # position

# A point (column) is zeroed if any of its coordinates lies outside [-1, 1].
# Vectorized form of checking each column in a Python loop.
outside = np.any(img2 > 1, axis=0) | np.any(img2 < -1, axis=0)
img2[:, outside] = 0


fig = go.Figure(
    data=[
        go.Scatter3d(
            x=img2[0],
            y=img2[1],
            z=img2[2],
            mode="markers",
            marker_size=4,
        )
    ]
)

fig.update_scenes(
    xaxis_showticklabels=False,
    yaxis_showticklabels=False,
    zaxis_showticklabels=False,
    xaxis_title=" ",
    yaxis_title=" ",
    zaxis_title="",
    aspectmode="cube",
    camera=dict(eye=dict(x=0, y=0.3, z=1.5))
    if fnmatch.fnmatch(file, "*shapenet*")
    else dict(eye=dict(x=1.25, y=1.25, z=1.25)),
    xaxis=dict(range=[-1, 1], visible=False),
    yaxis=dict(range=[-1, 1], visible=False),
    zaxis=dict(range=[-1, 1], visible=False),
)

fig.update_layout(
    height=500,
    width=500,
)

# fig.write_image("data/figures/continuity_6.png", scale=2)
fig.show()

### KMeans Clustering Visualization for Point Cloud Data

In [None]:
# PC IROF Kmeans clustering/segmentation
import plotly.graph_objects as go
import numpy as np
from sklearn.cluster import KMeans

# NOTE: do not bind the estimator to ``model``. That name holds the selected
# network index used by the saliency-map cells (``data[model]``,
# ``models[model]``, ``model == 2``); overwriting it breaks them.
kmeans = KMeans(n_clusters=16, random_state=0)  # fixed seed -> reproducible figure

# KMeans expects (n_samples, n_features); ``img`` stores one coordinate per
# row, so the per-point matrix is its transpose.
points = np.moveaxis(img, 0, 1)
cluster_labels = kmeans.fit_predict(points)


fig = go.Figure(
    data=[
        go.Scatter3d(
            x=img[0],
            y=img[1],
            z=img[2],
            mode="markers",
            marker_size=4,
            marker_color=cluster_labels,
        )
    ]
)

fig.update_scenes(
    xaxis_showticklabels=False,
    yaxis_showticklabels=False,
    zaxis_showticklabels=False,
    xaxis_title=" ",
    yaxis_title=" ",
    zaxis_title="",
    aspectmode="cube",
    camera=dict(eye=dict(x=-0.1, y=0.1, z=1.5))
    if fnmatch.fnmatch(file, "*shapenet*")
    else dict(eye=dict(x=1.25, y=1.25, z=1.25)),
    xaxis=dict(range=[-1, 1], visible=False),
    yaxis=dict(range=[-1, 1], visible=False),
    zaxis=dict(range=[-1, 1], visible=False),
)

fig.update_layout(
    height=500,
    width=500,
)

# fig.write_image("data/figures/kmeans_cluster_pc.png", scale=2)
fig.show()

### Saliency Maps for all XAI Methods

In [None]:
def NormalizeData(data):
    """Min-max scale ``data`` to the range [0, 1].

    A tiny epsilon (1e-11) is added to the denominator. A constant input
    therefore yields all zeros instead of a division by zero.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    return (data - lowest) / (value_range + 1e-11)


# Single Map
# Plot one attribution/attention map (method ``xai``) for observation ``n``
# of the selected ``model``.
xai = 13  # index into ``methods``; 13 -> "LRP Attention"
# NOTE(review): methods 11-13 are only plotted for model == 2 in the
# multi-map cell; other models' arrays presumably lack them — confirm.

cmap = [
    [0, "rgba(255,255,255, 0.8)"],
    [1 / 10, "rgb(31,120,180)"],
    [1 / 2, "rgb(227,26,28)"],
    [1, "rgb(227,26,28)"],
]

# Alternative colour scale (not used by this cell).
cmap_2 = [
    [0, "rgba(255,255,255, 0.8)"],
    [1 / 3, "rgb(31,120,180)"],
    [1 / 1.3, "rgb(227,26,28)"],
    [1, "rgb(227,26,28)"],
]

fig = go.Figure(
    data=go.Scatter3d(
        x=img[0],
        y=img[1],
        z=img[2],
        mode="markers",
        marker=dict(
            size=6,
            # Absolute relevance, min-max normalized to [0, 1] for the colour scale.
            color=NormalizeData(np.abs(data[model][n, xai, :, :])).flatten(),
            colorscale=cmap,
            colorbar=dict(
                tickfont=dict(family="Helvetica", size=18),
                tickvals=[0, 1 / 10, 1 / 2, 1],
                outlinewidth=0,
                thickness=20,
                len=0.8,
            ),
        ),
    )
)

fig.update_scenes(
    xaxis_showticklabels=False,
    yaxis_showticklabels=False,
    zaxis_showticklabels=False,
    xaxis_title=" ",
    yaxis_title=" ",
    zaxis_title="",
    aspectmode="cube",
    camera=dict(eye=dict(x=-0.1, y=0.1, z=1.5))
    if fnmatch.fnmatch(file, "*shapenet*")
    else dict(eye=dict(x=1.25, y=1.25, z=1.25)),
    xaxis=dict(range=[-1, 1], visible=False),
    yaxis=dict(range=[-1, 1], visible=False),
    zaxis=dict(range=[-1, 1], visible=False),
)

fig.update_layout(
    height=1000,
    width=1000,
)

# fig.write_image("data/figures/aa_single_image.png", scale=2)
fig.show()

In [None]:
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots

# Multiple Maps
# Colour scales for the relevance maps. Values near 0 are drawn translucent
# white, then the scale ramps through blue to red. ``cmap`` reaches full red
# at 0.5; ``cmap_2`` (used for the attention row) places the blue and red
# stops later, at 1/3 and 1/1.3.
_white = "rgba(255,255,255, 0.8)"
_blue = "rgb(31,120,180)"
_red = "rgb(227,26,28)"

cmap = [[0, _white], [1 / 10, _blue], [1 / 2, _red], [1, _red]]
cmap_2 = [[0, _white], [1 / 3, _blue], [1 / 1.3, _red], [1, _red]]


def NormalizeData(data):
    """Rescale ``data`` linearly so its minimum maps to 0 and its maximum to ~1.

    The 1e-11 term in the denominator guards against division by zero when
    every element is equal; the result is then all zeros.
    """
    shifted = data - np.min(data)
    spread = np.max(data) - np.min(data)
    return shifted / (spread + 1e-11)


# Shared colour bar; it is attached to a single trace (row 2, col 2) below.
colorbar = dict(
    tickfont=dict(family="Helvetica", size=18),
    tickvals=[0, 1 / 10, 1 / 2, 1],
    outlinewidth=0,
    thickness=20,
    len=0.8,
)

# Grid layout (``titles`` = original + all methods, consumed in this order):
#   row 1: original cloud, methods 0-4
#   row 2: methods 5-10
#   row 3: methods 11-13 (attention maps), only for model == 2
has_attention = model == 2


def _scene_row(n_scenes, n_cols=6):
    """Return one ``make_subplots`` spec row: ``n_scenes`` 3D scenes padded with ``None``."""
    return [{"type": "scene"} for _ in range(n_scenes)] + [None] * (n_cols - n_scenes)


def _attribution_trace(method_idx, size, colorscale, colorbar=None):
    """Scatter3d of the point cloud coloured by |relevance| of ``methods[method_idx]``.

    Uses the globally selected observation ``n`` and ``model``; relevance is
    min-max normalized to [0, 1] before mapping through ``colorscale``.
    ``colorbar=None`` leaves the trace without a colour bar.
    """
    return go.Scatter3d(
        x=img[0],
        y=img[1],
        z=img[2],
        mode="markers",
        showlegend=False,
        marker=dict(
            size=size,
            color=NormalizeData(np.abs(data[model][n, method_idx, :, :])).flatten(),
            colorscale=colorscale,
            colorbar=colorbar,
        ),
    )


specs = [_scene_row(6), _scene_row(6)]
if has_attention:
    specs.append(_scene_row(3))

fig = make_subplots(
    rows=len(specs),
    cols=6,
    specs=specs,
    subplot_titles=titles,
    vertical_spacing=0.05,
)

# Row 1, cols 2-6: methods 0-4 (col 1 holds the original cloud, added last).
for col, method_idx in enumerate(range(0, 5), start=2):
    fig.add_trace(_attribution_trace(method_idx, 2.5, cmap), row=1, col=col)

# Row 2, cols 1-6: methods 5-10; the shared colour bar sits on col 2.
for col, method_idx in enumerate(range(5, 11), start=1):
    fig.add_trace(
        _attribution_trace(method_idx, 3, cmap, colorbar if col == 2 else None),
        row=2,
        col=col,
    )

# Row 3, cols 1-3 (transformer only): attention methods 11-13.
if has_attention:
    for col, method_idx in enumerate(range(11, 14), start=1):
        fig.add_trace(_attribution_trace(method_idx, 3, cmap_2), row=3, col=col)

# Original point cloud in black at row 1, col 1.
fig.add_trace(
    go.Scatter3d(
        x=img[0],
        y=img[1],
        z=img[2],
        mode="markers",
        showlegend=False,
        marker=dict(
            size=3,
            color="black",
        ),
    ),
    row=1,
    col=1,
)

fig.update_scenes(
    xaxis_showticklabels=False,
    yaxis_showticklabels=False,
    zaxis_showticklabels=False,
    xaxis_title=" ",
    yaxis_title=" ",
    zaxis_title="",
    aspectmode="cube",
    camera=dict(eye=dict(x=-0.1, y=0.1, z=1.5))
    if fnmatch.fnmatch(file, "*shapenet*")
    else dict(eye=dict(x=1.25, y=1.25, z=1.25)),
    xaxis=dict(range=[-1, 1], visible=False),
    yaxis=dict(range=[-1, 1], visible=False),
    zaxis=dict(range=[-1, 1], visible=False),
)

fig.update_annotations(font=dict(family="Helvetica", size=22))

fig.update_layout(
    title=dict(
        text="<b>3D Attribution and Attention for " + models[model] + " Model</b>",
        font=dict(family="Helvetica", size=28),
        x=0.03,
    ),
    height=1200 if has_attention else 750,
    width=2000,
    font=dict(
        family="Helvetica",
        color="#000000",
    ),
)


# A datamodule instance has no ``__name__``; use its class name.
# fig.write_image(f"data/figures/3DPC_{type(datamodule).__name__}_{model}_Importance.png", scale=2)
fig.show()

## GIF Animation

In [None]:
# Starting camera eye for the turntable animation; it is rotated about the
# z-axis frame by frame in the rendering loop.
x_eye, y_eye, z_eye = 1, 1, 0.7


def rotate_z(x, y, z, theta):
    """Rotate the point (x, y, z) by ``theta`` radians about the z-axis.

    The (x, y) pair is treated as the complex number x + iy and multiplied by
    e^{i*theta}; ``z`` is returned unchanged.
    """
    rotated = np.exp(1j * theta) * (x + 1j * y)
    return np.real(rotated), np.imag(rotated), z


# Render one PNG per camera angle. The eye is rotated about the z-axis by -t,
# for t = 0, 0.2, ..., 6.2 rad (just under 2*pi).
frame_dir = "data/figures/gif"
# Make sure the output directory exists before any frame is written.
os.makedirs(frame_dir, exist_ok=True)

for t in np.arange(0, 6.26, 0.2):
    xe, ye, ze = rotate_z(x_eye, y_eye, z_eye, -t)

    fig.update_scenes(camera_eye=dict(x=xe, y=ye, z=ze))

    fig.write_image(frame_dir + "/frame_" + str(t) + "_.png", scale=1)

In [None]:
from PIL import Image

from contextlib import ExitStack

# Assemble the rendered frames, oldest first (i.e. in the order they were
# written), into a looping GIF.
frame_paths = sorted(glob.glob("data/figures/gif/frame_*"), key=os.path.getmtime)
if not frame_paths:
    raise FileNotFoundError("No GIF frames found matching data/figures/gif/frame_*")

# ExitStack closes every opened frame once the GIF is written. A distinct
# name is used for the first frame so ``img`` (the point cloud used by the
# plotting cells) is not overwritten.
with ExitStack() as stack:
    frames = [stack.enter_context(Image.open(path)) for path in frame_paths]
    first_frame, *other_frames = frames
    first_frame.save(
        fp="data/figures/gif/3DPC_" + str(model) + "_md.gif",
        format="GIF",
        append_images=other_frames,
        save_all=True,
        duration=120,
        loop=0,
    )

In [None]:
# Delete the intermediate PNG frames once the GIF has been assembled.
for frame_path in Path("data/figures/gif").glob("frame_*"):
    frame_path.unlink()