# 0. Configure the notebook for visualization

In [None]:
# Reload modified modules (such as pyment) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures as interactive widgets in the classic notebook.
%matplotlib notebook

# 1. Sanity-check the visualization function on VGG16 (ImageNet)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow.keras.applications import VGG16

from pyment.interpretability import maximize_feature_activation

# Seed both RNGs so the random starting images (and any TF-side randomness)
# are reproducible across runs.
np.random.seed(42)
tf.random.set_seed(42)

# Pretrained ImageNet classifier: maximizing a known class should yield a
# recognisable image, which sanity-checks maximize_feature_activation.
model = VGG16(weights='imagenet', include_top=True)
layer = 'predictions'

# ImageNet class indices, i.e. positions in the 1000-way 'predictions' output.
labels = {
    'goldfish': 1,
    'vulture': 23,
    'tarantula': 76,
    'zebra': 340
}

for label, index in labels.items():
    # Fresh small-amplitude random start for each class, centred on zero and
    # confined to [-0.125, 0.125).
    initial = np.random.uniform(size=(224, 224, 3)).astype(np.float32)
    initial = (initial - 0.5) * 0.25

    img = maximize_feature_activation(model, layer=layer, index=index,
                                      initial=initial, l2_decay=1e-3,
                                      blur_every=4, blur_width=1,
                                      norm_threshold=0.05,
                                      contribution_threshold=0.1)

    # Rescale to [0, 1] so imshow renders the float RGB data without
    # clipping. A constant image would make the peak 0; skip the division
    # then instead of producing an all-NaN image.
    img = img - np.amin(img)
    peak = np.amax(img)
    if peak > 0:
        img = img / peak

    plt.figure(figsize=(5, 5))
    # img[0]: assumes the result carries a leading batch axis — TODO confirm.
    plt.imshow(img[0])
    plt.suptitle(label)
    plt.axis('off')
    plt.show()


# 2. Run the visualization on the brain-age model

In [None]:
from pyment.models import RegressionSFCN
from pyment.interpretability import VolumeViewer

model = RegressionSFCN(weights='brain-age')

layer = 'Regression3DSFCN/block5/conv'

# Number of feature maps of `layer` to visualize.
# NOTE(review): assumed to equal the filter count of block5/conv — confirm
# against the RegressionSFCN definition.
num_features = 64

# A single random starting volume is reused for every feature.
# NOTE(review): unlike the VGG16 cell, this start is neither centred nor
# scaled and has no explicit channel axis — confirm this matches what
# maximize_feature_activation and the model expect.
initial = np.random.uniform(size=(167, 212, 160)).astype(np.float32)

for i in range(num_features):
    # Use the loop index as the feature index so each iteration renders a
    # different feature map, matching the 'Feature {i}' title.
    img = maximize_feature_activation(model, layer=layer, index=i,
                                      initial=initial, l2_decay=1e-3,
                                      blur_every=4, blur_width=1,
                                      norm_threshold=0.05,
                                      contribution_threshold=0.1)

    # img[0]: assumes the result carries a leading batch axis — TODO confirm.
    VolumeViewer(img[0], title=f'Feature {i}')
