In [None]:
%load_ext autoreload
%autoreload 2
from cifar import get_label, mapping, coarse_label

In [None]:
import tensorflow as tf

import tensorflow.keras as keras

import pandas as pd
import numpy as np
import gc
import random
import matplotlib.pyplot as plt

# Seed every RNG this notebook uses (Python `random` for sample picks, NumPy, and TensorFlow for
# weight init / fit shuffling) so Restart & Run All reproduces the same figures.
# NOTE(review): GPU kernels may still be nondeterministic even with a fixed TF seed.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Prefer the GPU when TensorFlow can see one; this context is entered around fit/evaluate.
device = tf.device("gpu" if tf.config.list_physical_devices("GPU") else "cpu")

In [None]:
# Informational only: list visible NVIDIA GPUs and their memory use (needs the NVIDIA driver CLI).
!nvidia-smi

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

# Sanity-check split sizes and layouts (RGB 32x32 images, one integer label per image).
assert (x_train.shape, y_train.shape) == ((50000, 32, 32, 3), (50000, 1))
assert (x_test.shape, y_test.shape) == ((10000, 32, 32, 3), (10000, 1))

In [None]:
# Convert images to single-channel float32 in [0, 1]. rgb_to_grayscale already returns shape
# (N, 32, 32, 1), so no reshape is needed; labels are left untouched.
# The channel check keeps this cell idempotent: on a re-run the tensors are already grayscale
# and rgb_to_grayscale would reject a 1-channel input.
if x_train.shape[-1] == 3:
    x_train = tf.cast(tf.image.rgb_to_grayscale(x_train), tf.float32) / 255.
if x_test.shape[-1] == 3:
    x_test = tf.cast(tf.image.rgb_to_grayscale(x_test), tf.float32) / 255.

In [None]:
# Show one random training image with its label. randrange excludes the upper bound, so `i`
# is always a valid index into x_train.
i = random.randrange(x_train.shape[0])
plt.title(get_label(y_train[i][0]))
plt.imshow(x_train[i][:, :, 0], cmap="gray");

In [None]:
# One conv layer (256 bias-free 3x3 ReLU filters) followed by a softmax read-out over 100 classes.
# No input shape is given, so the model is built on its first call (in fit).
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(256, (3, 3), activation="relu", use_bias=False))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(100, activation="softmax"))

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=["accuracy"],
)

In [None]:
# Two passes over the grayscale training set; keep the History object for later inspection.
with device:
    history = model.fit(x_train, y_train, batch_size=64, epochs=2)

In [None]:
# Held-out loss and accuracy on the grayscale test set.
with device:
    test_scores = model.evaluate(x_test, y_test)

In [None]:
# Conv kernel is (kh, kw, in_channels, out_channels) = (3, 3, 1, 256). Take the single input
# channel and move the filter axis first, so filters[k] is the k-th 3x3 kernel.
kernel = model.layers[0].weights[0]
filters = tf.transpose(kernel[:, :, 0, :], perm=[2, 0, 1])

In [None]:
# Display one random learned 3x3 kernel. randrange keeps the index in [0, number of filters).
i = random.randrange(filters.shape[0])
plt.title("filter {}".format(i))
plt.imshow(filters[i], cmap="gray");

In [None]:
# Reset Keras' global state and force a garbage-collection pass, presumably to free memory
# left over from training before the activation pass below.
keras.backend.clear_session()
gc.collect()

In [None]:
n = 3000  # number of test images used in the activation analysis below
ACTIVATION_BATCH = 500  # images per forward pass; bounds peak memory of the conv output

# Per-image spatial mean of every conv filter's (post-ReLU) response. It is computed in batches
# so the full (n, 30, 30, 256) float32 feature map (~2.8 GB for n=3000) is never held at once.
# Result: activation[k, j] = mean response of filter k on test image j, shape (256, n).
activation = tf.concat(
    [
        tf.reduce_mean(
            model.layers[0](x_test[start:min(start + ACTIVATION_BATCH, n)]), axis=(1, 2)
        )
        for start in range(0, n, ACTIVATION_BATCH)
    ],
    axis=0,
).numpy().T

In [None]:
# One row per filter, one column per analysed test image.
df = pd.DataFrame(activation)

In [None]:
# Rank-ordered mean usage of each filter on log-log axes. It is built directly from `activation`,
# so it never depends on whatever `df` currently holds. Ranks start at 1: on a log x-axis, x=0
# cannot be drawn, so 0-based positions would hide the top filter and shift every rank by one.
usage = pd.DataFrame(activation).mean(axis=1).sort_values(ascending=False).to_numpy()
n_filters = len(usage)
ranks = np.arange(1, n_filters + 1)

fig, ax = plt.subplots()
ax.plot(ranks, usage)
# Reference power law, usage ~ rank**-0.2, anchored at 0.005 for rank 1 (presumably a
# hand-picked visual guide).
ax.plot([1, n_filters], [0.005, 0.005 / n_filters ** 0.2])
ax.set_xlim(1, n_filters)
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("filters rank")
ax.set_ylabel("mean filter usage")

plt.show()

In [None]:
# Filters ranked by mean activation over the analysed test images (index = filter id).
# Built directly from `activation`, so re-running this cell never depends on what `df` holds.
f = pd.DataFrame(activation).mean(axis=1).sort_values(ascending=False)
f

In [None]:
# Coarse and fine label ids for the first n test images, plus the ids inside one coarse class.
# NOTE(review): assumes get_label(y) returns ((coarse_id, fine_id), (coarse_name, ...)), which is
# how its result is unpacked throughout this notebook. Confirm against cifar.py.
chosen_set = "flowers"  # e.g. "people", "household furniture"

test_labels = [get_label(y) for y in y_test.ravel()[:n]]  # one get_label pass, reused below
y_test_c = np.array([lab[0][0] for lab in test_labels])
y_test_f = np.array([lab[0][1] for lab in test_labels])

in_chosen_set = np.array([lab[1][0] == chosen_set for lab in test_labels])
assert in_chosen_set.any(), "no test images in coarse class {!r}".format(chosen_set)

c_numbers = np.unique(y_test_c[in_chosen_set])
f_numbers = np.unique(y_test_f[in_chosen_set])
print(c_numbers)
print(f_numbers)

In [None]:
# Consistency check: histogram the top-ranked filter's activation over images selected two ways.
#   - by coarse id, via get_label
#   - by raw CIFAR fine label being in f_numbers
# If get_label's fine ids equal the raw labels, both masks pick the same images and the curves
# coincide. Local names are used so the global `df` (full activation frame) is not overwritten.
fig, ax = plt.subplots()

filter_i = f.index[0]
by_coarse = y_test_c == c_numbers[0]
by_raw_fine = np.isin(y_test[:n].ravel(), f_numbers)

pd.Series(activation[filter_i, by_coarse]).hist(
    ax=ax, lw=5, histtype="step", label="coarse id {}".format(c_numbers[0]))
pd.Series(activation[filter_i, by_raw_fine]).hist(
    ax=ax, lw=5, ls=":", histtype="step", label="raw fine labels")

ax.set_title("filter {}".format(filter_i))
ax.set_xlabel("mean filter activation")
ax.legend()

plt.show()

In [None]:
def _first_n_test_labels():
    """Return (fine_ids, coarse_names) arrays for the first `n` test images.

    Only those images have columns in `activation`, so later labels are never needed.
    NOTE(review): assumes get_label(y) returns ((coarse_id, fine_id), (coarse_name, ...)),
    matching how the rest of this notebook unpacks it.
    """
    labels = [get_label(y) for y in y_test.ravel()[:n]]
    fine_ids = np.array([lab[0][1] for lab in labels])
    coarse_names = np.array([lab[1][0] for lab in labels])
    return fine_ids, coarse_names


def _step_hist(values, ax, bins, **kwargs):
    """Draw a thick, area-normalised step histogram of `values` on `ax`."""
    pd.Series(values).hist(ax=ax, bins=bins, lw=5, histtype="step", density=True, **kwargs)


def plot_filter(filter_obj=None, filter_idx=0, chosen_set="flowers", other_set="people"):
    """Plot how one conv filter's mean activation is distributed across CIFAR-100 classes.

    All panels share the same bins:
      0. all analysed test images
      1. images of the coarse classes `chosen_set` and `other_set`
      2. each fine class of `chosen_set`
      3. each fine class of `other_set`

    Parameters
    ----------
    filter_obj : int, optional
        Filter id (row of `activation`). If None, the filter ranked `filter_idx` in `f` is used.
    filter_idx : int
        Rank in `f` (0 = highest mean activation). Only used when `filter_obj` is None.
    chosen_set, other_set : str
        Coarse class names to compare. They must match the names returned by `get_label`.

    Raises
    ------
    ValueError
        If either coarse class has no images among the first `n` test images.
    """
    if filter_obj is None:
        filter_obj = f.index[filter_idx]
        title = "filter {} (rank {})".format(filter_obj, filter_idx)
    else:
        title = "filter {}".format(filter_obj)

    row = activation[filter_obj]  # this filter's mean activation for each analysed test image
    fine_ids, coarse_names = _first_n_test_labels()

    # Shared bins run from 0 (activations are post-ReLU means) to the max. Fall back to [0, 1]
    # for an all-zero (dead) filter, where linspace(0, 0, 20) would give non-increasing edges.
    top = row.max()
    bins = np.linspace(0, top if top > 0 else 1.0, 20)

    _, axs = plt.subplots(1, 4, figsize=(21, 7))
    _step_hist(row, axs[0], bins, color="gray", label="all_data")

    for set_name, fine_ax in ((chosen_set, axs[2]), (other_set, axs[3])):
        in_set = coarse_names == set_name
        if not in_set.any():
            raise ValueError("no test images in coarse class {!r}".format(set_name))
        _step_hist(row[in_set], axs[1], bins, label=set_name)
        for fine_id in np.unique(fine_ids[in_set]):
            # NOTE(review): coarse_label is indexed by fine id here, as in the original cell.
            _step_hist(row[fine_ids == fine_id], fine_ax, bins, ls="-", alpha=0.8,
                       label=coarse_label[fine_id])

    axs[0].set_title(title, fontsize=30)
    for ax in axs:
        ax.set_xlabel("mean filter activation", fontsize=25)
        ax.tick_params(labelsize=12, length=5, width=3)
        ax.legend(fontsize=15)
    plt.show()


plot_filter(filter_idx=44)