In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, accuracy_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from tensorflow.keras.layers import *
from tensorflow.keras import models, Model
import qkeras
from qkeras import *
from sparsepixels.layers import *
from sparsepixels.utils import *

# Explicit imports: `keras` and `to_categorical` were previously reachable only
# through the wildcard imports above, which silently breaks if a library stops
# re-exporting them.
from tensorflow import keras
from tensorflow.keras.utils import to_categorical

# MNIST: 28x28 grayscale digits with integer labels 0-9.
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# Hold out the first n_val training images as the validation set.
n_val = 10000
X_train, X_val = X_train[n_val:], X_train[:n_val]
y_train, y_val = y_train[n_val:], y_train[:n_val]

# Add a trailing channel axis and scale pixel values to [0, 1].
X_train = np.reshape(X_train, (-1, 28, 28, 1)) / 255.
X_val = np.reshape(X_val, (-1, 28, 28, 1)) / 255.
X_test = np.reshape(X_test, (-1, 28, 28, 1)) / 255.

# One-hot encode the 10 digit classes.
y_train = to_categorical(y_train, 10)
y_val = to_categorical(y_val, 10)
y_test = to_categorical(y_test, 10)

for split_name, split_array in [
    ("X_train", X_train), ("y_train", y_train),
    ("X_val", X_val), ("y_val", y_val),
    ("X_test", X_test), ("y_test", y_test),
]:
    print(split_name + ".shape: " + str(split_array.shape))

In [None]:
import os
import random
# Seed every RNG from one constant. numpy was previously seeded with 1211 while
# everything else used 1221 -- almost certainly a typo.
# Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects subprocesses launched later, not this kernel's hash randomisation.
SEED = 1221
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Preprocessing configuration.
noise_type = 'uniform'   # alternative: 'poisson'
noise_level = 0          # e.g. 0.42 for a noisy variant
inflate_factor = 3.5
threshold = 0.4          # pixels below this are zeroed after preprocessing
target_size_x = 48
target_size_y = target_size_x


def _prepare(images, noise_level, inflate_factor):
    """Call `pool_pad_noise_inflate` with this notebook's fixed pooling/size/noise-type settings."""
    return pool_pad_noise_inflate(
        images,
        pool_size=3,
        pool_type='avg',
        target_size=(target_size_x, target_size_y),
        noise_type=noise_type,
        noise_level=noise_level,
        inflate_factor=inflate_factor,
    )


# Intermediate stages kept only for plotting (no noise).
x_train_pooled = _prepare(X_train, noise_level=0, inflate_factor=1)
x_train_pooled_inflated = _prepare(X_train, noise_level=0, inflate_factor=inflate_factor)

# Final model inputs. Call order is kept fixed so any RNG draws stay reproducible.
x_train = _prepare(X_train, noise_level, inflate_factor)
x_val = _prepare(X_val, noise_level, inflate_factor)
x_test = _prepare(X_test, noise_level, inflate_factor)

# Zero out low-intensity pixels in place.
for split in (x_train, x_val, x_test):
    split[split < threshold] = 0.

for i in range(5):
    plot_sparsemnist(X_train, x_train_pooled, x_train_pooled_inflated, x_train, i, threshold=threshold)

In [None]:
def build_cnn(is_sparse, B=16, I=6, n_max_pixels=None, input_shape=None):
    """Build the (uncompiled) quantized CNN used throughout this notebook.

    Layer stack (Keras layer names in brackets)::

        [x_in] -> [input_reduce]                           (sparse only)
        -> [conv1] 1 filter, 7x7, 'same' -> [relu1] -> [pool1] avg, size 4
        -> [conv2] 3 filters, 5x5, 'same' -> [relu2] -> [pool2] avg, size 2
        -> [flatten] -> [dense1] 36 -> [relu3] -> [dense2] 10 -> [softmax]

    Kernels and biases use ``quantized_bits(B, I, alpha=1)``; activations use
    ``quantized_relu(B, I)``.

    Args:
        is_sparse: if True, prepend ``InputReduce`` (using the notebook-global
            ``threshold``) and use ``QConv2DSparse`` / ``AveragePooling2DSparse``.
            Both consume the keep-mask; the sparse pooling also returns an
            updated mask. Otherwise use the standard ``QConv2D`` /
            ``AveragePooling2D``.
        B: quantizer bit width.
        I: quantizer integer bits.
        n_max_pixels: forwarded to ``InputReduce``; ignored when not sparse.
        input_shape: (H, W, C) input shape. Defaults to ``x_train.shape[1:]``,
            read from the notebook global at call time (the previous behaviour).

    Returns:
        A ``keras.Model`` named ``'cnn_sparse'`` or ``'cnn_full'``.
    """
    quantizer = quantized_bits(B, I, alpha=1)
    quantized_relu = f'quantized_relu({B}, {I})'
    if input_shape is None:
        input_shape = x_train.shape[1:]

    x_in = keras.Input(shape=tuple(input_shape), name='x_in')

    if is_sparse:
        x, keep_mask = InputReduce(n_max_pixels=n_max_pixels, threshold=threshold, name='input_reduce')(x_in)

        x = QConv2DSparse(filters=1, kernel_size=7, use_bias=True, name='conv1', padding='same', strides=1,
                          kernel_quantizer=quantizer, bias_quantizer=quantizer)([x, keep_mask])
        x = QActivation(quantized_relu, name='relu1')(x)
        x, keep_mask = AveragePooling2DSparse(4, name='pool1')([x, keep_mask])

        x = QConv2DSparse(filters=3, kernel_size=5, use_bias=True, name='conv2', padding='same', strides=1,
                          kernel_quantizer=quantizer, bias_quantizer=quantizer)([x, keep_mask])
        x = QActivation(quantized_relu, name='relu2')(x)
        x, keep_mask = AveragePooling2DSparse(2, name='pool2')([x, keep_mask])
    else:
        x = QConv2D(filters=1, kernel_size=7, use_bias=True, name='conv1', padding='same', strides=1,
                    kernel_quantizer=quantizer, bias_quantizer=quantizer)(x_in)
        x = QActivation(quantized_relu, name='relu1')(x)
        x = AveragePooling2D(4, name='pool1')(x)

        x = QConv2D(filters=3, kernel_size=5, use_bias=True, name='conv2', padding='same', strides=1,
                    kernel_quantizer=quantizer, bias_quantizer=quantizer)(x)
        x = QActivation(quantized_relu, name='relu2')(x)
        x = AveragePooling2D(2, name='pool2')(x)

    x = Flatten(name='flatten')(x)

    x = QDense(36, kernel_quantizer=quantizer, bias_quantizer=quantizer, name='dense1')(x)
    x = QActivation(quantized_relu, name='relu3')(x)

    x = QDense(10, kernel_quantizer=quantizer, bias_quantizer=quantizer, name='dense2')(x)
    x = Activation('softmax', name='softmax')(x)

    model_name = 'cnn_sparse' if is_sparse else 'cnn_full'
    return keras.Model(x_in, x, name=model_name)

cnn_full = build_cnn(is_sparse=False)
cnn_full.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

# Model.summary() prints the table itself and returns None; wrapping it in
# print() only appended a stray "None" line.
cnn_full.summary()

In [None]:
# Stop when val_loss has not improved by at least 1e-3 for 15 epochs and roll back
# to the best epoch's weights. The same callback object is reused by later cells.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=15,
                                              mode='min', restore_best_weights=True)

# Sparse CNN, "t" variant (n_max_pixels=8, the smallest budget).
cnn_sparse_t = build_cnn(is_sparse=True, n_max_pixels=8)
cnn_sparse_t.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

history = cnn_sparse_t.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=200,
    batch_size=128,
    callbacks=[early_stop],
)

# Learning curves in the top-left cell of a 2x2 grid.
fig = plt.figure(figsize=(15, 10))
axes = fig.add_subplot(2, 2, 1)
for history_key, legend_label in (('loss', 'train loss'), ('val_loss', 'val loss')):
    axes.plot(history.history[history_key], label=legend_label)
axes.legend(loc="upper right")
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')

In [None]:
# Sparse CNN, "s" variant (n_max_pixels=12).
cnn_sparse_s = build_cnn(is_sparse=True, n_max_pixels=12)
cnn_sparse_s.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

history = cnn_sparse_s.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=200,
    batch_size=128,
    callbacks=[early_stop],
)

# Train/validation loss per epoch.
fig = plt.figure(figsize=(15, 10))
axes = fig.add_subplot(2, 2, 1)
for history_key, legend_label in (('loss', 'train loss'), ('val_loss', 'val loss')):
    axes.plot(history.history[history_key], label=legend_label)
axes.legend(loc="upper right")
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')

In [None]:
# Sparse CNN, "m" variant (n_max_pixels=16).
cnn_sparse_m = build_cnn(is_sparse=True, n_max_pixels=16)
cnn_sparse_m.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

history = cnn_sparse_m.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=200,
    batch_size=128,
    callbacks=[early_stop],
)

# Train/validation loss per epoch.
fig = plt.figure(figsize=(15, 10))
axes = fig.add_subplot(2, 2, 1)
for history_key, legend_label in (('loss', 'train loss'), ('val_loss', 'val loss')):
    axes.plot(history.history[history_key], label=legend_label)
axes.legend(loc="upper right")
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')

In [None]:
# Sparse CNN, "l" variant (n_max_pixels=20, the largest budget).
cnn_sparse_l = build_cnn(is_sparse=True, n_max_pixels=20)
cnn_sparse_l.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

history = cnn_sparse_l.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=200,
    batch_size=128,
    callbacks=[early_stop],
)

# Train/validation loss per epoch.
fig = plt.figure(figsize=(15, 10))
axes = fig.add_subplot(2, 2, 1)
for history_key, legend_label in (('loss', 'train loss'), ('val_loss', 'val loss')):
    axes.plot(history.history[history_key], label=legend_label)
axes.legend(loc="upper right")
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')

In [None]:
# Standard (non-sparse) reference CNN, trained with the same recipe.
cnn_full = build_cnn(is_sparse=False)
cnn_full.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

history = cnn_full.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=200,
    batch_size=128,
    callbacks=[early_stop],
)

# Train/validation loss per epoch.
fig = plt.figure(figsize=(15, 10))
axes = fig.add_subplot(2, 2, 1)
for history_key, legend_label in (('loss', 'train loss'), ('val_loss', 'val loss')):
    axes.plot(history.history[history_key], label=legend_label)
axes.legend(loc="upper right")
axes.set_xlabel('Epoch')
axes.set_ylabel('Loss')

In [None]:
# Test-set predictions of the reloaded 8-bit and 16-bit models. Each y_pred_*
# array holds softmax outputs (one row per test image); later cells reuse them.
y_pred_sparse_t_8b = cnn_sparse_t_8b.predict(x_test)
y_pred_sparse_s_8b = cnn_sparse_s_8b.predict(x_test)
y_pred_sparse_m_8b = cnn_sparse_m_8b.predict(x_test)
y_pred_sparse_l_8b = cnn_sparse_l_8b.predict(x_test)
y_pred_full_8b = cnn_full_8b.predict(x_test)

y_pred_sparse_t_16b = cnn_sparse_t_16b.predict(x_test)
y_pred_sparse_s_16b = cnn_sparse_s_16b.predict(x_test)
y_pred_sparse_m_16b = cnn_sparse_m_16b.predict(x_test)
y_pred_sparse_l_16b = cnn_sparse_l_16b.predict(x_test)
y_pred_full_16b = cnn_full_16b.predict(x_test)

# Top-1 accuracy. The precision is now part of each line; previously the 8-bit
# and 16-bit blocks printed identical labels and could not be told apart.
y_true_cls = np.argmax(y_test, axis=1)
for precision, results in (
    ("8b", [("sparse cnn-t", y_pred_sparse_t_8b), ("sparse cnn-s", y_pred_sparse_s_8b),
            ("sparse cnn-m", y_pred_sparse_m_8b), ("sparse cnn-l", y_pred_sparse_l_8b),
            ("full cnn", y_pred_full_8b)]),
    ("16b", [("sparse cnn-t", y_pred_sparse_t_16b), ("sparse cnn-s", y_pred_sparse_s_16b),
             ("sparse cnn-m", y_pred_sparse_m_16b), ("sparse cnn-l", y_pred_sparse_l_16b),
             ("full cnn", y_pred_full_16b)]),
):
    for model_label, y_pred_model in results:
        acc = accuracy_score(y_true_cls, np.argmax(y_pred_model, axis=1))
        print("acc ({}, {}) = {}".format(model_label, precision, acc))

In [None]:
def plot_roc(y_test, y_pred_sparse_s, y_pred_sparse_m, y_pred_sparse_l, y_pred_full, labels):
    """Overlay per-class ROC curves for four models (FPR on a log y-axis vs. TPR).

    Each class gets one colour; line style identifies the model
    (full '-', sparse-s '--', sparse-m dotted, sparse-l '-.'). Legend entries
    show the class label, the ROC AUC and the model tag.
    """
    variants = (
        ('full', y_pred_full, '-'),
        ('sparse-s', y_pred_sparse_s, '--'),
        ('sparse-m', y_pred_sparse_m, 'dotted'),
        ('sparse-l', y_pred_sparse_l, '-.'),
    )
    plt.figure(figsize=(8, 8))
    palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    for cls_idx, cls_label in enumerate(labels):
        color = palette[cls_idx % len(palette)]
        truth = y_test[:, cls_idx]
        # Compute every model's curve for this class before drawing any of them.
        curves = [(tag, style, roc_curve(truth, scores[:, cls_idx])) for tag, scores, style in variants]
        for tag, style, (fpr, tpr, _) in curves:
            plt.plot(tpr, fpr,
                     label='{0} ({1:.4f}), {2}'.format(cls_label, auc(fpr, tpr), tag),
                     linestyle=style, lw=1.5, color=color)
    plt.semilogy()
    plt.xlabel("tpr", size=12, loc='right')
    plt.ylabel("fpr", size=12, loc='top')
    plt.xlim(0., 1)
    plt.ylim(0.005, 1)
    plt.legend(loc='best', framealpha=0., prop={'size': 6})

#plot_roc(y_test, y_pred_sparse_s, y_pred_sparse_m, y_pred_sparse_l, y_pred_full, ['0','1','2','3','4','5','6','7','8','9'])

def plot_auc_vs_label(y_test, y_pred_sparse_t, y_pred_sparse_s, y_pred_sparse_m, y_pred_sparse_l, y_pred_full, labels,
                      ylim=(0.7, 1.02)):
    """Scatter per-class ROC AUC plus overall accuracy for the five CNN variants.

    Column k (k < n_classes) shows the one-vs-rest ROC AUC for class k. The last
    column, right of the dashed separator, shows top-1 accuracy.

    Args:
        y_test: one-hot labels, shape (n_samples, n_classes).
        y_pred_sparse_t, ..., y_pred_full: per-class scores, same shape as y_test.
        labels: n_classes + 1 x-tick labels.
        ylim: y-axis limits (previously hard-coded; the default is unchanged).
    """
    n_cls = y_test.shape[1]
    y_true_cls = np.argmax(y_test, axis=1)

    def scores_for(y_pred):
        # One AUC per class, then accuracy as the final entry.
        row = []
        for k in range(n_cls):
            fpr, tpr, _ = roc_curve(y_test[:, k], y_pred[:, k])
            row.append(auc(fpr, tpr))
        row.append(accuracy_score(y_true_cls, np.argmax(y_pred, axis=1)))
        return row

    series = [
        ("standard", y_pred_full, "o"),
        ("sparse-t", y_pred_sparse_t, "^"),
        ("sparse-s", y_pred_sparse_s, "D"),
        ("sparse-m", y_pred_sparse_m, "s"),
        ("sparse-l", y_pred_sparse_l, "v"),
    ]
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    x = np.arange(n_cls + 1)

    plt.figure(figsize=(10, 5))
    for idx, (series_name, y_pred, marker) in enumerate(series):
        # Cycle the palette instead of requiring at least five colours.
        plt.scatter(x, scores_for(y_pred), s=200, marker=marker, color=colors[idx % len(colors)],
                    label=series_name, alpha=0.5)

    plt.axvline(x=n_cls - 0.5, color="gray", linestyle="--", linewidth=1)
    plt.xticks(x, labels, fontsize=10)
    plt.ylim(*ylim)
    plt.grid(axis="y", alpha=0.2)
    plt.legend(fontsize=12)
    plt.tight_layout()

# The un-suffixed y_pred_sparse_* / y_pred_full names are never defined in this
# notebook (NameError on a fresh kernel); plot the 16-bit predictions instead.
plot_auc_vs_label(
    y_test,
    y_pred_sparse_t_16b, y_pred_sparse_s_16b, y_pred_sparse_m_16b, y_pred_sparse_l_16b, y_pred_full_16b,
    [f'AUC-{digit}' for digit in range(10)] + ['Accuracy'],
)

In [None]:
from matplotlib.lines import Line2D

def plot_auc_vs_label_8b16b(
    y_test,
    y_pred_sparse_t_8b, y_pred_sparse_s_8b, y_pred_sparse_m_8b, y_pred_sparse_l_8b, y_pred_full_8b,
    y_pred_sparse_t_16b, y_pred_sparse_s_16b, y_pred_sparse_m_16b, y_pred_sparse_l_16b, y_pred_full_16b,
    labels,
    marker_size=200,
    filled_alpha=0.6,
    hollow_lw=1.8,
    dx=0.1,
    ylim=(0.7, 1.02),
    figpath='plots/mnist_performance.png',
):
    """Compare per-class ROC AUC and accuracy of all models at 8 and 16 bits.

    For every x-tick (one per class, plus a final "accuracy" column), each model
    gets a hollow marker at ``x - dx`` (8-bit) and a filled marker at ``x + dx``
    (16-bit). Two legends identify the model (colour/marker) and the precision
    (hollow/filled).

    Args:
        y_test: one-hot labels, shape (n_samples, n_classes).
        y_pred_*_8b / y_pred_*_16b: per-class scores, same shape as ``y_test``.
        labels: n_classes + 1 x-tick labels.
        marker_size, filled_alpha, hollow_lw: marker styling.
        dx: horizontal offset separating the 8-bit and 16-bit markers.
        ylim: y-axis limits.
        figpath: where to save the figure; its directory is created if missing.
            Pass None to skip saving.
    """
    import os

    def aucs_plus_acc(y_true, y_pred):
        # One-vs-rest ROC AUC for each class, followed by top-1 accuracy.
        vals = []
        for k in range(y_true.shape[1]):
            fpr, tpr, _ = roc_curve(y_true[:, k], y_pred[:, k])
            vals.append(auc(fpr, tpr))
        vals.append(accuracy_score(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1)))
        return np.array(vals)

    n_cls = y_test.shape[1]
    x = np.arange(n_cls + 1)
    x_left, x_right = x - dx, x + dx

    vals_16b = {
        "standard": aucs_plus_acc(y_test, y_pred_full_16b),
        "sparse-t": aucs_plus_acc(y_test, y_pred_sparse_t_16b),
        "sparse-s": aucs_plus_acc(y_test, y_pred_sparse_s_16b),
        "sparse-m": aucs_plus_acc(y_test, y_pred_sparse_m_16b),
        "sparse-l": aucs_plus_acc(y_test, y_pred_sparse_l_16b),
    }
    vals_8b = {
        "standard": aucs_plus_acc(y_test, y_pred_full_8b),
        "sparse-t": aucs_plus_acc(y_test, y_pred_sparse_t_8b),
        "sparse-s": aucs_plus_acc(y_test, y_pred_sparse_s_8b),
        "sparse-m": aucs_plus_acc(y_test, y_pred_sparse_m_8b),
        "sparse-l": aucs_plus_acc(y_test, y_pred_sparse_l_8b),
    }

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    c_full, c_t, c_s, c_m, c_l = colors[:5]
    color_map = {"standard": c_full, "sparse-t": c_t, "sparse-s": c_s, "sparse-m": c_m, "sparse-l": c_l}
    marker_map = {"standard": "o", "sparse-t": "^", "sparse-s": "D", "sparse-m": "s", "sparse-l": "v"}

    plt.figure(figsize=(10, 5))
    model_order = ["standard", "sparse-t", "sparse-s", "sparse-m", "sparse-l"]

    handles = []
    for name in model_order:
        col = color_map[name]
        mrk = marker_map[name]
        # 16-bit: filled, right of the tick. Only these feed the "Model" legend.
        h = plt.scatter(x_right, vals_16b[name], s=marker_size, marker=mrk, color=col,
                        alpha=filled_alpha, zorder=3, label=name)
        handles.append(h)
        # 8-bit: hollow, left of the tick.
        plt.scatter(x_left, vals_8b[name], s=marker_size, marker=mrk, facecolors="none",
                    edgecolors=col, linewidths=hollow_lw, zorder=4)

    plt.axvline(x=n_cls - 0.5, color="gray", linestyle="--", linewidth=1)

    plt.xticks(x, labels, fontsize=10)
    plt.xlim(-0.5 - dx * 1.5, n_cls + 0.5 + dx * 1.5)
    plt.ylim(*ylim)
    plt.grid(axis="y", alpha=0.2)
    plt.tight_layout()

    # A second plt.legend() call replaces the first, so re-attach the first one explicitly.
    first_legend = plt.legend(handles=handles, loc="lower left", fontsize=12, title="Model")
    plt.gca().add_artist(first_legend)
    precision_handles = [
        Line2D([0], [0], marker='o', linestyle='None', markersize=10,
               markerfacecolor='none', markeredgecolor='k', linewidth=0, label='8-bit'),
        Line2D([0], [0], marker='o', linestyle='None', markersize=10,
               markerfacecolor='k', alpha=filled_alpha, markeredgecolor='k', label='16-bit'),
    ]
    plt.legend(handles=precision_handles, loc="lower right", fontsize=12, title="Precision")

    if figpath:
        # savefig raises if the target directory does not exist.
        os.makedirs(os.path.dirname(figpath) or ".", exist_ok=True)
        plt.savefig(figpath)

# 8-bit results are drawn hollow (left of each tick), 16-bit filled (right).
auc_labels = [f'AUC-{digit}' for digit in range(10)] + ['Accuracy']
plot_auc_vs_label_8b16b(
    y_test,
    y_pred_sparse_t_8b, y_pred_sparse_s_8b, y_pred_sparse_m_8b, y_pred_sparse_l_8b, y_pred_full_8b,
    y_pred_sparse_t_16b, y_pred_sparse_s_16b, y_pred_sparse_m_16b, y_pred_sparse_l_16b, y_pred_full_16b,
    auc_labels,
    ylim=(0.6, 1.02),
    dx=0.16,
)

In [None]:
import os

# Visualise intermediate activations of the 16-bit sparse-l model for one test
# image. Layers whose output is a list/tuple (the sparse layers returning
# (x, keep_mask)) contribute two tensors, tagged "(x_reduced)" and "(x_mask)".
layer_names = [
    'x_in', 'input_reduce',
    'conv1', 'relu1', 'pool1',
    'conv2', 'relu2', 'pool2',
]

plot_tensors = []
plot_names = []
for name in layer_names:
    output = cnn_sparse_l_16b.get_layer(name).output
    if isinstance(output, (list, tuple)):
        plot_tensors.append(output[0])
        plot_names.append(f'{name} (x_reduced)')
        plot_tensors.append(output[1])
        plot_names.append(f'{name} (x_mask)')
    else:
        plot_tensors.append(output)
        plot_names.append(name)

model_cnnpart = models.Model(inputs=cnn_sparse_l_16b.input, outputs=plot_tensors)
ii = 10  # test-image index (21 is another example used previously)
layers_pred = model_cnnpart.predict(x_test[ii:ii+1])
print(y_test[ii:ii+1])

cmap = 'viridis'  # 'gray' also works
os.makedirs('plots', exist_ok=True)

i = 0
while i < len(plot_names):
    name = plot_names[i]

    if "(x_reduced)" in name and i + 1 < len(plot_names) and "(x_mask)" in plot_names[i + 1]:
        # Sparse-layer output: draw only the activations. The keep-mask at
        # layers_pred[i + 1] is skipped.
        arr_r = layers_pred[i][0]  # drop the batch axis
        h, w, c = arr_r.shape
        # squeeze=False keeps `axes` 2-D for any c. The old code called
        # ax.imshow on the raw subplots() result, which is an ndarray when c > 1
        # (e.g. pool2, 3 filters) and raised AttributeError.
        fig, axes = plt.subplots(1, c, figsize=(c * 3, 3), squeeze=False)
        for ch in range(c):
            axes[0, ch].imshow(arr_r[..., ch], cmap=cmap)
            axes[0, ch].set_title(f"ch{ch}")
        plt.tight_layout()
        plt.savefig(f'plots/mnist_conv_in_{i}.png')

        i += 2
        continue

    arr = layers_pred[i][0]  # drop the batch axis

    if arr.ndim == 2:
        fig, ax = plt.subplots(1, 1, figsize=(3, 3))
        ax.imshow(arr, cmap=cmap)
        ax.set_title("ch0")
        plt.tight_layout()
        plt.savefig(f'plots/mnist_conv_out_{i}.png')

    elif arr.ndim == 3:
        h, w, c = arr.shape
        fig, axes = plt.subplots(1, c, figsize=(c * 3, 3), squeeze=False)
        for ch in range(c):
            axes[0, ch].imshow(arr[..., ch], cmap=cmap)
            axes[0, ch].set_title(f"ch{ch}")
        plt.tight_layout()
        plt.savefig(f'plots/mnist_conv_out__{i}.png')

    i += 1

In [None]:
# Persist the trained models, all built with build_cnn's default B=16, I=6.
for trained_model, variant in (
    (cnn_full, 'full'),
    (cnn_sparse_t, 'sparse_t'),
    (cnn_sparse_s, 'sparse_s'),
    (cnn_sparse_m, 'sparse_m'),
    (cnn_sparse_l, 'sparse_l'),
):
    trained_model.save_weights(f'weights/mnist_{variant}-16b.h5')

In [None]:
# Rebuild each architecture at two precisions and restore the saved weights.
# NOTE(review): the -8b weight files are not written anywhere in this notebook;
# they must come from a separate 8-bit training run.
prec_8b = dict(B=8, I=0)
prec_16b = dict(B=16, I=6)

cnn_full_8b = build_cnn(is_sparse=False, **prec_8b)
cnn_sparse_t_8b = build_cnn(is_sparse=True, n_max_pixels=8, **prec_8b)
cnn_sparse_s_8b = build_cnn(is_sparse=True, n_max_pixels=12, **prec_8b)
cnn_sparse_m_8b = build_cnn(is_sparse=True, n_max_pixels=16, **prec_8b)
cnn_sparse_l_8b = build_cnn(is_sparse=True, n_max_pixels=20, **prec_8b)

for restored, weights_file in (
    (cnn_full_8b, 'weights/mnist_full-8b.h5'),
    (cnn_sparse_t_8b, 'weights/mnist_sparse_t-8b.h5'),
    (cnn_sparse_s_8b, 'weights/mnist_sparse_s-8b.h5'),
    (cnn_sparse_m_8b, 'weights/mnist_sparse_m-8b.h5'),
    (cnn_sparse_l_8b, 'weights/mnist_sparse_l-8b.h5'),
):
    restored.load_weights(weights_file)

cnn_full_16b = build_cnn(is_sparse=False, **prec_16b)
cnn_sparse_t_16b = build_cnn(is_sparse=True, n_max_pixels=8, **prec_16b)
cnn_sparse_s_16b = build_cnn(is_sparse=True, n_max_pixels=12, **prec_16b)
cnn_sparse_m_16b = build_cnn(is_sparse=True, n_max_pixels=16, **prec_16b)
cnn_sparse_l_16b = build_cnn(is_sparse=True, n_max_pixels=20, **prec_16b)

for restored, weights_file in (
    (cnn_full_16b, 'weights/mnist_full-16b.h5'),
    (cnn_sparse_t_16b, 'weights/mnist_sparse_t-16b.h5'),
    (cnn_sparse_s_16b, 'weights/mnist_sparse_s-16b.h5'),
    (cnn_sparse_m_16b, 'weights/mnist_sparse_m-16b.h5'),
    (cnn_sparse_l_16b, 'weights/mnist_sparse_l-16b.h5'),
):
    restored.load_weights(weights_file)

## HLS conversion (hls4ml)

In [None]:
def build_cnn_sparse_forhls(cnn_sparse):
    """Rebuild a trained sparse CNN as a plain Keras model that hls4ml can convert.

    Walks ``cnn_sparse.layers`` in order and re-applies them to a fresh input:

    * ``InputLayer``, ``InputReduce`` and any layer whose name starts with
      ``"mask_pool"`` are skipped, so no keep-mask plumbing remains.
    * Each ``QConv2DSparse`` becomes a regular ``QConv2D``. It is built from the
      config of the inner ``layer.conv``, with ``use_bias=True``, the sparse
      layer's name, and ``layer._bias_quant_cfg`` as bias quantizer. The kernel
      is copied from ``layer.conv`` and the bias from ``layer.bias``.
    * Each ``AveragePooling2DSparse`` is replaced by its inner ``layer.avg_pool``.
    * Every other layer object is called directly, so it is shared with (and
      uses the same weights as) ``cnn_sparse``.

    NOTE(review): this assumes the sparse model is a single linear chain in
    ``cnn_sparse.layers`` order. It also relies on sparsepixels internals
    (``conv``, ``bias``, ``_bias_quant_cfg``, ``avg_pool``) that may change
    between versions.

    Args:
        cnn_sparse: a sparse model, e.g. from ``build_cnn(is_sparse=True, ...)``.

    Returns:
        A ``keras.Model`` named ``'cnn_sparse_forhls'`` with the same input shape.
    """
    x_in = keras.Input(shape=cnn_sparse.input_shape[1:], name="x_in")
    x = x_in
    for layer in cnn_sparse.layers:
        if isinstance(layer, keras.layers.InputLayer):
            continue
        if isinstance(layer, InputReduce):
            continue
        # Presumably the mask-pooling layers created by the sparse pooling; the
        # dense model has no mask, so they are dropped.
        if layer.name.startswith("mask_pool"):
            continue

        if isinstance(layer, QConv2DSparse):
            cfg = layer.conv.get_config()
            # The sparse layer appears to keep its bias outside the inner conv
            # (it is read from layer.bias below), so enable it on the rebuilt conv.
            cfg["use_bias"] = True
            cfg["name"] = layer.name
            cfg["bias_quantizer"] = layer._bias_quant_cfg

            conv_full = QConv2D.from_config(cfg)
            # Calling the layer first builds it, so set_weights below has weights to fill.
            x = conv_full(x)

            kernel_w = layer.conv.get_weights()[0]  # assumes index 0 is the kernel
            bias_w = keras.backend.get_value(layer.bias)
            conv_full.set_weights([kernel_w, bias_w])
        elif isinstance(layer, AveragePooling2DSparse):
            x = layer.avg_pool(x)
        else:
            x = layer(x)

    return keras.Model(x_in, x, name='cnn_sparse_forhls')

# Dense-equivalent copies of the four trained sparse models, for hls4ml.
cnn_sparse_t_forhls = build_cnn_sparse_forhls(cnn_sparse_t)
cnn_sparse_s_forhls = build_cnn_sparse_forhls(cnn_sparse_s)
cnn_sparse_m_forhls = build_cnn_sparse_forhls(cnn_sparse_m)
cnn_sparse_l_forhls = build_cnn_sparse_forhls(cnn_sparse_l)
#cnn_sparse_s_forhls.summary()

In [None]:
import hls4ml
def write_sparse_hls(cnn_sparse_forhls, name, bit='16b', part='xcu250-figd2104-2L-e'):
    """Convert a dense-equivalent sparse CNN with hls4ml and write the HLS project.

    Uses the Vitis backend, ``io_parallel`` and a per-layer (``granularity='name'``)
    default config.

    Args:
        cnn_sparse_forhls: Keras model, e.g. from ``build_cnn_sparse_forhls``.
        name: project sub-directory under ``hls_proj/sparsemnist/model-{bit}/``.
        bit: precision tag used in the output path (previously hard-coded to '16b').
        part: FPGA part number passed to hls4ml.

    Returns:
        The hls4ml model, so callers can ``compile()`` or ``build()`` it.
        Previously it was discarded.
    """
    config = hls4ml.utils.config_from_keras_model(cnn_sparse_forhls, granularity='name', backend='Vitis')
    # Optional override: config['LayerName']['x_in']['Precision'] = 'ap_ufixed<8,1>'

    cnn_sparse_hls = hls4ml.converters.convert_from_keras_model(
        cnn_sparse_forhls,
        hls_config=config,
        project_name='myhls',
        backend='Vitis',
        output_dir=f'hls_proj/sparsemnist/model-{bit}/{name}',
        part=part,
        io_type='io_parallel',
    )
    cnn_sparse_hls.write()
    return cnn_sparse_hls

# A loop (not a bare call as the cell's last line) so the returned hls model is
# not echoed as cell output.
for forhls_model, project_dir in (
    (cnn_sparse_t_forhls, 'hls_sparse_t'),
    (cnn_sparse_s_forhls, 'hls_sparse_s'),
    (cnn_sparse_m_forhls, 'hls_sparse_m'),
    (cnn_sparse_l_forhls, 'hls_sparse_l'),
):
    write_sparse_hls(forhls_model, project_dir)

In [None]:
import hls4ml
# Convert the standard (non-sparse) CNN. Unlike the sparse projects (Vitis,
# io_parallel), this one uses the Vivado backend and io_stream.
full_hls_settings = dict(
    backend='Vivado',
    project_name='myhls',
    output_dir='hls_proj/sparsemnist/model-16b/hls_full',
    part='xcu250-figd2104-2L-e',
    io_type='io_stream',
)
config = hls4ml.utils.config_from_keras_model(cnn_full, granularity='name', backend=full_hls_settings['backend'])
# Optional per-layer overrides (disabled):
#config['LayerName']['x_in']['Precision'] = 'ap_ufixed<8,1>'
#config['LayerName']['conv1']['ParallelizationFactor'] = 100
#config['LayerName']['conv2']['ParallelizationFactor'] = 50

cnn_full_hls = hls4ml.converters.convert_from_keras_model(cnn_full, hls_config=config, **full_hls_settings)

#cnn_full_hls.compile()
cnn_full_hls.write()

## Test-bench data for HLS C simulation

In [None]:
import os

# Write hls4ml test-bench files for every HLS project:
#   tb_input_features.dat      - first n_tb test images, flattened, one per line
#   tb_output_predictions.dat  - the matching Keras predictions, one per line
# Values are space-separated.
n_tb = 100
bit = "16b"

x_tb = x_test[:n_tb]
x_tb_flat = x_tb.reshape(n_tb, -1)

# Reference predictions per HLS project, at the precision matching `bit`. The
# previous version referenced un-suffixed y_pred_sparse_* / y_pred_full names
# that are never defined in this notebook (NameError on a fresh kernel).
preds_by_bit = {
    "8b": {
        "hls_sparse_t": y_pred_sparse_t_8b,
        "hls_sparse_s": y_pred_sparse_s_8b,
        "hls_sparse_m": y_pred_sparse_m_8b,
        "hls_sparse_l": y_pred_sparse_l_8b,
        "hls_full": y_pred_full_8b,
    },
    "16b": {
        "hls_sparse_t": y_pred_sparse_t_16b,
        "hls_sparse_s": y_pred_sparse_s_16b,
        "hls_sparse_m": y_pred_sparse_m_16b,
        "hls_sparse_l": y_pred_sparse_l_16b,
        "hls_full": y_pred_full_16b,
    },
}
y_tb_by_project = {proj: y_pred[:n_tb] for proj, y_pred in preds_by_bit[bit].items()}

# Keep the per-model slices available under their previous names.
y_tb_t = y_tb_by_project["hls_sparse_t"]
y_tb_s = y_tb_by_project["hls_sparse_s"]
y_tb_m = y_tb_by_project["hls_sparse_m"]
y_tb_l = y_tb_by_project["hls_sparse_l"]
y_tb_full = y_tb_by_project["hls_full"]


def _write_rows(path, rows):
    """Write each row as space-separated str() values followed by a newline."""
    with open(path, "w") as f:
        for row in rows:
            f.write(" ".join(str(v) for v in row))
            f.write("\n")


for proj, y_tb in y_tb_by_project.items():
    tb_dir = "hls_proj/sparsemnist/model-" + bit + "/" + proj + "/tb_data"
    os.makedirs(tb_dir, exist_ok=True)
    _write_rows(tb_dir + "/tb_input_features.dat", x_tb_flat)
    _write_rows(tb_dir + "/tb_output_predictions.dat", y_tb)

## Information leak at input reduction

In [None]:
# Re-plot training example 4 through the raw, pooled, pooled+inflated and final (thresholded) images.
example_index = 4
plot_sparsemnist(X_train, x_train_pooled, x_train_pooled_inflated, x_train, example_index, threshold=threshold)

In [None]:
def plot_sparsemnist_p(x_original, x_modified1, x_modified2, x_modified3, n_example, threshold, figname=None,
                       offset=1011):
    """Plot one image from ``x_original`` next to the same-index image from ``x_modified2``.

    The index used is ``n_example + offset``. ``x_modified1``, ``x_modified3`` and
    ``threshold`` are currently unused (their panels were disabled). They are kept
    so the signature matches the other ``plot_sparsemnist*`` helpers.

    Args:
        n_example: example number; ``offset`` is added to form the array index.
        figname: if given, save to ``plots/{figname}.png``, creating ``plots/``
            if it is missing.
        offset: index offset (previously the hard-coded magic number 1011).
    """
    idx = n_example + offset
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(x_original[idx])
    axes[1].imshow(x_modified2[idx])
    plt.tight_layout()
    if figname is not None:
        import os
        os.makedirs('plots', exist_ok=True)
        plt.savefig(f'plots/{figname}.png')

plot_sparsemnist_p(
    X_train, x_train_pooled, x_train_pooled_inflated, x_train,
    4,
    threshold=threshold,
    figname='mnist_input',
)

In [None]:
def keep_first_k_nonzeros(img2d, k):
    """Zero all but the first ``k`` non-zero entries of ``img2d`` (row-major order).

    If ``img2d`` has fewer than ``k`` non-zero entries, all of them are kept.
    ``k <= 0`` gives an all-zero image of the same shape.
    """
    flat = img2d.ravel(order='C')
    nz_idx = np.flatnonzero(flat)
    k = min(k, nz_idx.size)
    if k <= 0:
        # A negative k previously slipped through and sliced nz_idx[:k] from the end.
        return np.zeros_like(img2d)
    keep = np.zeros(flat.shape, dtype=bool)
    keep[nz_idx[:k]] = True
    return np.where(keep.reshape(img2d.shape), img2d, 0)


def plot_sparsemnist_leak(x_original, x_modified1, x_modified2, x_modified3, n_example, threshold, figname=None,
                          ks=(4, 8, 12, 16), offset=1011):
    """Show how much of one digit survives when only its first k non-zero pixels are kept.

    Takes image ``n_example + offset`` from ``x_modified3`` and zeroes pixels that
    are not strictly above ``threshold``. It then draws one panel per ``k`` in ``ks``
    using ``keep_first_k_nonzeros``. All panels share the colour range
    ``[0, max]`` of the thresholded image (``[0, 1]`` if it is all zero).

    NOTE(review): the row-major "first k" selection presumably mirrors the
    truncation done by ``InputReduce`` with ``n_max_pixels=k``; confirm.

    ``x_original``, ``x_modified1`` and ``x_modified2`` are unused. They are kept so
    the signature matches the other ``plot_sparsemnist*`` helpers.

    Args:
        figname: if given, save to ``plots/{figname}.png``.
        ks: pixel budgets to show, one panel each (previously fixed at 4, 8, 12, 16).
        offset: index offset (previously the hard-coded magic number 1011).
    """
    img = x_modified3[n_example + offset]
    img_2d = np.where(img > threshold, img, 0).squeeze()

    peak = img_2d.max()
    vmax = float(peak) if peak > 0 else 1.0

    # squeeze=False so a single k still yields an indexable row of axes.
    fig, axes = plt.subplots(1, len(ks), figsize=(5 * len(ks), 5), squeeze=False)
    for ax, k in zip(axes[0], ks):
        ax.imshow(keep_first_k_nonzeros(img_2d, k), vmin=0, vmax=vmax)
        ax.axis('off')

    plt.tight_layout()
    if figname is not None:
        import os
        os.makedirs("plots", exist_ok=True)
        plt.savefig(f"plots/{figname}.png")

plot_sparsemnist_leak(
    X_train, x_train_pooled, x_train_pooled_inflated, x_train,
    4,
    threshold=threshold,
    figname='mnist_leak',
)