# 1. Environment Setup


In [1]:
import json
import torch
import numpy as np
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from pathlib import Path
from types import SimpleNamespace
from importlib import import_module

%matplotlib inline

In [2]:
from senn.models.losses import *
from senn.models.parameterizers import *
from senn.models.conceptizers import *
from senn.models.aggregators import *
from senn.models.senn import SENN, DiSENN

In [3]:
from senn.datasets.dataloaders import get_dataloader
from senn.utils.plot_utils import show_explainations, show_prototypes, plot_lambda_accuracy, get_comparison_plot

In [4]:
from rotation import rotate_resize_and_cast, rotate_image

## Utility functions

In [5]:
def get_config(filename):
    """Load a JSON experiment configuration from the ``configs`` directory.

    The file ``configs/<filename>`` is parsed as JSON and its top-level keys
    are exposed as attributes of the returned ``SimpleNamespace``.
    """
    with open(Path('configs') / filename, 'r') as handle:
        settings = json.load(handle)
    return SimpleNamespace(**settings)

In [6]:
def load_checkpoint(config):
    """Load the best-model checkpoint of the experiment named in ``config``.

    Reads ``results/<config.exp_name>/checkpoints/best_model.pt`` with
    ``torch.load``, using ``config.device`` as the map location, and returns
    whatever object was stored in that file.
    """
    checkpoint_dir = Path('results') / config.exp_name / "checkpoints"
    return torch.load(checkpoint_dir / "best_model.pt", map_location=config.device)

In [7]:
def accuracy(model, dataloader, config):
    """Evaluate ``model`` on ``dataloader`` and print its accuracy.

    The accuracy is the number of correctly classified samples divided by
    the total number of samples seen, so a smaller final batch does not get
    the same weight as a full one.

    Parameters
    ----------
    model
        Callable returning a 3-tuple whose first element holds per-class
        scores with classes along dimension 1.
    dataloader
        Iterable of ``(x, labels)`` batches.
    config
        Object with a ``device`` attribute that the batches are moved to.

    Returns
    -------
    float
        Fraction of correct predictions in ``[0, 1]``, or ``nan`` when the
        dataloader yields no samples.
    """
    n_correct = 0
    n_seen = 0
    model.eval()
    with torch.no_grad():
        for x, labels in dataloader:
            x = x.float().to(config.device)
            labels = labels.long().to(config.device)
            # Only the predictions are needed; the explanation outputs are ignored.
            y_pred, _, _ = model(x)
            n_correct += (y_pred.argmax(dim=1) == labels).sum().item()
            n_seen += labels.numel()
    # Avoid a division by zero on an empty dataloader.
    mean_accuracy = n_correct / n_seen if n_seen else float("nan")
    print(f"Test Mean Accuracy: {mean_accuracy * 100: .3f} %")
    return mean_accuracy

## Config

In [8]:
# Apply matplotlib's 'seaborn-v0_8-paper' style sheet to the figures created below.
plt.style.use('seaborn-v0_8-paper')

# 2. Basic Comparisons

In [9]:
# Read the MNIST experiment settings from configs/ and set the device to the CPU.
mnist_config = get_config("mnist_lambda1e-4_seed29.json")
mnist_config.device = "cpu"

In [10]:
# get_dataloader returns three values; only the third is kept.
# NOTE(review): presumably (train, validation, test) loaders — confirm against senn.datasets.dataloaders.
_, _, mnist_test_dl = get_dataloader(mnist_config)

In [11]:
# Every component receives the full config as keyword arguments.
conceptizer = ConvConceptizer(**mnist_config.__dict__)
parameterizer = ConvParameterizer(**mnist_config.__dict__)
aggregator = SumAggregator(**mnist_config.__dict__)

# Assemble the SENN model from its three components.
mnist_SENN = SENN(conceptizer, parameterizer, aggregator)

In [None]:
# Load the saved checkpoint and restore the weights stored under its 'model_state' key.
mnist_checkpoint = load_checkpoint(mnist_config)
mnist_SENN.load_state_dict(mnist_checkpoint['model_state'])

In [None]:
# NOTE(review): presumably plots explanations for samples from the test loader — see senn.utils.plot_utils.
show_explainations(mnist_SENN, mnist_test_dl, 'mnist')

In [None]:
# Imports
import json
from os import path
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import cv2

from rotation import rotate_resize_and_cast

from senn.utils.plot_utils import (
    create_barplot,
    plot_lambda_accuracy,
    show_explainations
)

# Apply matplotlib's 'seaborn-v0_8-paper' style sheet to the figures created below.
plt.style.use('seaborn-v0_8-paper')

def show_explanation_with_rotation(model, image, x_rotation=30, y_rotation=30, z_rotation=30, save_path=None):
    """Plot SENN explanations for an image next to those of a rotated copy.

    Draws a 2x3 grid: the top row shows the original image with bar plots of
    its relevances and concepts, the bottom row shows the same for the
    rotated image. Both concept plots share one symmetric x-limit so they
    can be compared directly.

    Parameters
    ----------
    model : torch.nn.Module
        Model called as ``y_pred, (concepts, relevances), _ = model(x)``.
    image : torch.Tensor
        A single image without a batch dimension.
        NOTE(review): assumed to squeeze to a 2-D array so it can be stacked
        into three identical channels for the rotation helper — confirm.
    x_rotation, y_rotation, z_rotation : int
        Angles forwarded to ``rotate_resize_and_cast`` and shown in the
        title of the rotated image.
    save_path : str or path-like, optional
        When given, the figure is written to this path before it is shown.
    """
    # Follow the model's actual device instead of assuming 'cuda:0'.
    device = next(model.parameters()).device
    model.eval()

    # Prepare original image (adds the batch dimension)
    original_img = image.unsqueeze(0).float().to(device)

    # Create rotated version: the helper is fed three identical channels and
    # only the first channel of its output is used.
    image_np = image.squeeze().cpu().numpy()
    image_bgr = cv2.merge([image_np] * 3)
    _, rotated_img = rotate_resize_and_cast(
        image_bgr,
        x_rotation=x_rotation,
        y_rotation=y_rotation,
        z_rotation=z_rotation
    )
    rotated_tensor = torch.tensor(rotated_img[:, :, 0]).unsqueeze(0).float().to(device)
    # The single-channel slice dropped any channel dimension ``image`` had;
    # add leading dimensions until both model inputs have the same rank.
    while rotated_tensor.dim() < original_img.dim():
        rotated_tensor = rotated_tensor.unsqueeze(0)

    # Get SENN explanations for both
    with torch.no_grad():
        y_pred_orig, (concepts_orig, relevances_orig), _ = model(original_img)
        y_pred_rot, (concepts_rot, relevances_rot), _ = model(rotated_tensor)

    # Reduce per-class scores to a predicted class index
    if y_pred_orig.dim() > 1:
        y_pred_orig = y_pred_orig.argmax(1)
        y_pred_rot = y_pred_rot.argmax(1)

    # Get concept limits for consistent plotting: one symmetric bound that
    # covers the concept values of both images.
    concepts = torch.cat([concepts_orig, concepts_rot])
    concept_lim = max(abs(concepts.min().item()), abs(concepts.max().item()))

    # Create figure
    fig = plt.figure(figsize=(15, 6))
    gridsize = (2, 3)

    # Original image and its explanations
    ax1 = plt.subplot2grid(gridsize, (0, 0))
    ax2 = plt.subplot2grid(gridsize, (0, 1))
    ax3 = plt.subplot2grid(gridsize, (0, 2))

    # Rotated image and its explanations
    ax4 = plt.subplot2grid(gridsize, (1, 0))
    ax5 = plt.subplot2grid(gridsize, (1, 1))
    ax6 = plt.subplot2grid(gridsize, (1, 2))

    # Plot original image and explanations
    ax1.imshow(original_img.squeeze().cpu(), cmap='gray')
    ax1.set_axis_off()
    ax1.set_title(f'Original - Pred: {y_pred_orig.item()}', fontsize=12)

    create_barplot(ax2, relevances_orig[0], y_pred_orig[0], x_label='Relevances (θ)')
    ax2.xaxis.set_label_position('top')
    ax2.tick_params(which='major', labelsize=10)

    create_barplot(ax3, concepts_orig[0], y_pred_orig[0],
                   x_lim=concept_lim, x_label='Concepts (h)')
    ax3.xaxis.set_label_position('top')
    ax3.tick_params(which='major', labelsize=10)

    # Plot rotated image and explanations
    ax4.imshow(rotated_tensor.squeeze().cpu(), cmap='gray')
    ax4.set_axis_off()
    ax4.set_title(f'Rotated ({x_rotation}°,{y_rotation}°,{z_rotation}°) - Pred: {y_pred_rot.item()}', fontsize=12)

    create_barplot(ax5, relevances_rot[0], y_pred_rot[0], x_label='Relevances (θ)')
    ax5.tick_params(which='major', labelsize=10)

    create_barplot(ax6, concepts_rot[0], y_pred_rot[0],
                   x_lim=concept_lim, x_label='Concepts (h)')
    ax6.tick_params(which='major', labelsize=10)

    plt.tight_layout()
    # Save before show()/close(), which may discard the figure contents.
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
    plt.close('all')

In [None]:
print("Moderate Rotations (-45° to 45°)")
print("---------------------------------")
for i in range(10):
    # iter() builds a fresh iterator each time, so this is the first batch it yields.
    test_batch, _ = next(iter(mnist_test_dl))
    random_idx = np.random.randint(0, len(test_batch))
    sample_image = test_batch[random_idx]
    
    print(f"\nModerate Explanation {i+1}/10")
    # np.random.randint excludes its upper bound; 46 makes +45° reachable,
    # matching the range announced above.
    show_explanation_with_rotation(
        model=mnist_SENN,
        image=sample_image,
        x_rotation=np.random.randint(-45, 46),
        y_rotation=np.random.randint(-45, 46),
        z_rotation=np.random.randint(-45, 46)
    )

In [None]:
print("\nExtreme Rotations (-90° to 90°)")
print("---------------------------------")
for i in range(10):
    # iter() builds a fresh iterator each time, so this is the first batch it yields.
    test_batch, _ = next(iter(mnist_test_dl))
    random_idx = np.random.randint(0, len(test_batch))
    sample_image = test_batch[random_idx]
    
    print(f"\nExtreme Explanation {i+1}/10")
    # np.random.randint excludes its upper bound; 91 makes +90° reachable,
    # matching the range announced above.
    show_explanation_with_rotation(
        model=mnist_SENN,
        image=sample_image,
        x_rotation=np.random.randint(-90, 91),
        y_rotation=np.random.randint(-90, 91),
        z_rotation=np.random.randint(-90, 91)
    )