# View a CNN for CIFAR-10 classification

In [12]:
import sys
from tensorflow import keras
import plotly.graph_objects as go
from ipywidgets import widgets
import numpy as np

In [6]:
from dnnviewer.Grapher import Grapher
from dnnviewer.TestData import TestData
import dnnviewer.layers
import dnnviewer.bridge.tensorflow as tf_bridge
from dnnviewer.bridge.tensorflow_datasets import load_test_data
from dnnviewer.bridge.KerasNetworkExtractor import KerasNetworkExtractor

# Load model

In [9]:
# Path to the trained CIFAR-10 CNN (HDF5 Keras model), relative to the notebook directory
MODEL_PATH = '../dnnviewer-data/models/CIFAR-10_CNN5.h5'

model0 = keras.models.load_model(MODEL_PATH)
model0.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_0 (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0         
_________________________________________________________________
conv_1 (Conv2D)              (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv_2 (Conv2D)              (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0

# Grapher test: plot the network layers and their top-N connections

In [21]:
# Load the CIFAR-10 test data into a TestData container
test_data = TestData()
load_test_data('cifar10', test_data)

# Prepare the test inputs using the model's input dtype and a 32x32x3 image shape
# (leading batch dimension left unspecified)
model_input_dtype = model0.input.dtype.as_numpy_dtype
cifar10_input_shape = [None, 32, 32, 3]
test_data.x_format = tf_bridge.keras_prepare_input(model_input_dtype, cifar10_input_shape, test_data.x)

In [22]:
# Plotly figure widget the grapher draws into; small margins leave more room for the network
fig_widget = go.FigureWidget()
fig_widget.update_layout(margin=dict(l=10, r=10, b=10, t=10))

grapher = Grapher()

# Create the grapher layers from the Keras Sequential model
extractor = KerasNetworkExtractor(grapher, model0, test_data)
extractor.process()

# Reference layer (index into grapher.layers) and unit whose top-N connections are drawn
ACTIVE_LAYER_IDX = 2
ACTIVE_UNIT = 10

# Slider selecting how many of the strongest connections to display
topn = widgets.IntSlider(
    value=3,
    min=1,
    max=4,
    step=1,
    description='Top N:',
    continuous_update=False
)

grapher.plot_layers(fig_widget)
# Both the initial render and the slider callback pass the same layer object, so they
# always refer to the same reference layer.
# NOTE(review): assumes plot_topn_connections expects a layer object (as the callback
# originally passed), not a bare index — confirm against Grapher's signature.
grapher.plot_topn_connections(fig_widget, topn.value, grapher.layers[ACTIVE_LAYER_IDX], ACTIVE_UNIT)


def set_topn(change):
    """Redraw the top-N connections when the slider value changes.

    :param change: ipywidgets change notification; ``change['new']`` is the new slider value
    """
    # batch_update groups the figure changes into a single update sent to the front end
    with fig_widget.batch_update():
        grapher.plot_topn_connections(fig_widget, change['new'], grapher.layers[ACTIVE_LAYER_IDX], ACTIVE_UNIT)


topn.observe(set_topn, names='value')

fig_widget.update_layout(barmode='overlay')
top_bar = widgets.HBox(children=[topn])
main_widget = widgets.VBox([top_bar, fig_widget])

main_widget

VBox(children=(HBox(children=(IntSlider(value=3, continuous_update=False, description='Top N:', max=4, min=1),…

In [None]:
l = model0.layers[7]
l

In [None]:
len(l.get_weights())