In [1]:
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.models import load_model
import numpy as np
from keras.models import Model
from keras.layers import Input, Conv2D
from keras.layers import Add, BatchNormalization, Activation
from sklearn.cluster import DBSCAN
from keras.models import Sequential
from keras.layers import Dense
from CGA.cluster_filters import cluster_filters

In [2]:
model = load_model('../Models/NN/model_mnist_renet50_4_categories.h5')
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 dense (Dense)               (None, 4)                 8196      
                                                                 
Total params: 23,595,908
Trainable params: 8,196
Non-trainable params: 23,587,712
_________________________________________________________________


In [3]:
resnet50_layers = model.layers[0].layers
for i, warstwa in enumerate(resnet50_layers):
    print(f'Warstwa ResNet50 {i + 1}: {warstwa.name}, Typ: {warstwa.__class__.__name__}, Shape: {warstwa.output_shape}')


Warstwa ResNet50 1: input_1, Typ: InputLayer, Shape: [(None, None, None, 3)]
Warstwa ResNet50 2: conv1_pad, Typ: ZeroPadding2D, Shape: (None, None, None, 3)
Warstwa ResNet50 3: conv1_conv, Typ: Conv2D, Shape: (None, None, None, 64)
Warstwa ResNet50 4: conv1_bn, Typ: BatchNormalization, Shape: (None, None, None, 64)
Warstwa ResNet50 5: conv1_relu, Typ: Activation, Shape: (None, None, None, 64)
Warstwa ResNet50 6: pool1_pad, Typ: ZeroPadding2D, Shape: (None, None, None, 64)
Warstwa ResNet50 7: pool1_pool, Typ: MaxPooling2D, Shape: (None, None, None, 64)
Warstwa ResNet50 8: conv2_block1_1_conv, Typ: Conv2D, Shape: (None, None, None, 64)
Warstwa ResNet50 9: conv2_block1_1_bn, Typ: BatchNormalization, Shape: (None, None, None, 64)
Warstwa ResNet50 10: conv2_block1_1_relu, Typ: Activation, Shape: (None, None, None, 64)
Warstwa ResNet50 11: conv2_block1_2_conv, Typ: Conv2D, Shape: (None, None, None, 64)
Warstwa ResNet50 12: conv2_block1_2_bn, Typ: BatchNormalization, Shape: (None, None, None,

In [4]:
resnet_model = model.get_layer('resnet50')

output_layer = 'conv4_block1_2_conv'

layer = resnet_model.get_layer(output_layer)
weights = layer.get_weights()[0]
biases = layer.get_weights()[1] 

In [5]:
weights.shape

(3, 3, 256, 256)

In [6]:
biases.shape

(256,)

In [7]:
def prune_filter(original_model, cut_off_layer_name, indexes):
    layer_to_prune = resnet_model.get_layer(cut_off_layer_name)

    all_layers = original_model.layers

    layer_to_prune_index = all_layers.index(layer_to_prune)

    previous_layer = next((layer for layer in all_layers[:layer_to_prune_index][::-1] if layer.name.endswith("out")), None)
    next_layer = next((layer for layer in all_layers[layer_to_prune_index + 1:] if layer.name.endswith("add")), None)
    
    model_prev = Model(inputs=original_model.input, outputs=previous_layer.output)          
    model_to_prune = Model(inputs=original_model.layers[all_layers.index(previous_layer) + 1].input, outputs=next_layer.output)       
    model_next = Model(inputs=original_model.layers[all_layers.index(next_layer) + 1].input, outputs=original_model.output)
    
    # for layer in model_prev.layers:
    #     layer.trainable = False
    #     
    # for layer in model_next.layers:
    #     layer.trainable = False  

    combined_input = Input(shape=(32, 32, 3))

    output_prev = model_prev(combined_input)

    config = layer_to_prune.get_config()
    config['filters'] = len(indexes)

    new_layer = Conv2D(**config)
    input_shape = layer_to_prune.input_shape
    new_layer.build(input_shape)

    weights = layer_to_prune.get_weights()[0][:, :, :, indexes]
    
    if len(layer_to_prune.get_weights()) > 1:
        biases = layer_to_prune.get_weights()[1][indexes]
        new_weights = [weights, biases]
    else:
        new_weights = [weights]
    new_layer.set_weights(new_weights)
    
    input_tensor = Input(shape=model_prev.output_shape)

    x = input_tensor
    
    for layer in model_to_prune.layers:
        if layer.name == cut_off_layer_name:
            output_prev = new_layer(output_prev)
            output_prev = Conv2D(filters=256, kernel_size=(1, 1), name='adaptation_conv')(output_prev)
        elif isinstance(layer, Conv2D):
            new_layer = Conv2D(
                filters=layer.filters,
                kernel_size=layer.kernel_size,
                strides=layer.strides,
                padding=layer.padding,
                activation=None,
                use_bias=layer.use_bias,
                kernel_initializer=layer.kernel_initializer,
                bias_initializer=layer.bias_initializer
            )
            # new_layer.trainable = False
            x = new_layer(x)
        elif isinstance(layer, BatchNormalization):
            new_layer = BatchNormalization()
            # new_layer.trainable = False
            x = new_layer(x)
        elif isinstance(layer, Activation):
            new_layer = Activation(layer.activation)
            # new_layer.trainable = False
            x = new_layer(x)
        elif isinstance(layer, Add):
            index_conv4_block1_add = model_to_prune.layers.index(layer)
            index_conv4_block1_0_bn = index_conv4_block1_add - 2
            index_conv4_block1_3_bn = index_conv4_block1_add - 1
            output_conv4_block1_0_bn = model_to_prune.layers[index_conv4_block1_0_bn].output
            output_conv4_block1_3_bn = model_to_prune.layers[index_conv4_block1_3_bn].output
            x = Add()([output_conv4_block1_0_bn, output_conv4_block1_3_bn]) 
    
    pruned_model = Model(inputs=model_prev.output, outputs=x)

    new_model = Sequential()
    new_model.add(model_prev)
    new_model.add(pruned_model)
    new_model.add(model_next)
    return new_model


In [8]:
weights.shape

(3, 3, 256, 256)

In [9]:
weights_list = weights.reshape(256, -1)
weights_list.shape

(256, 2304)

In [10]:
# from sklearn.cluster import KMeans
# 
# k = 100  # Liczba klastrów
# knn = KMeans(n_clusters=k)
# knn.fit(weights_list)
# klastry = knn.labels_
# 
# indexes = []
# for i in range(k):
#     indeksy_klastra = np.where(klastry == i)[0]
#     losowy_indeks = np.random.choice(indeksy_klastra)
#     indexes.append(losowy_indeks)
# 
# pruned_model = prune_filter(resnet_model, output_layer, indexes)

In [11]:
dbscan = DBSCAN(eps=0.9, min_samples=1)
dbscan.fit(weights_list)


cluster_indices = np.unique(dbscan.labels_)

indexes = []
for cluster_index in cluster_indices:
    cluster_points = np.where(dbscan.labels_ == cluster_index)[0]
    index = np.random.choice(cluster_points)
    indexes.append(index)
    
pruned_model = prune_filter(resnet_model, output_layer, indexes)



In [12]:
prune_input = Input(shape=(32, 32, 3))

# for layer in model.layers[1:]:
#     layer.trainable = False

new_model = Sequential()
new_model.add(prune_input)
new_model.add(pruned_model)
# new_model.add(model.layers[1])
new_model.add(Dense(10, activation='softmax'))

In [13]:
(train_X, train_y), (test_X, test_y) = mnist.load_data()
new_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
train_X = np.expand_dims(train_X, axis=-1)
train_X = np.repeat(train_X, 3, axis=-1)
train_X = np.pad(train_X, ((0, 0), (2, 2), (2, 2), (0, 0)), mode='constant')
train_y = to_categorical(train_y, 10)
new_model.fit(x = train_X, y= train_y, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x2288b859520>

In [14]:
test_X = np.expand_dims(test_X, axis=-1)
test_X = np.repeat(test_X, 3, axis=-1)
test_X = np.pad(test_X, ((0, 0), (2, 2), (2, 2), (0, 0)), mode='constant')
result = new_model.evaluate(test_X, to_categorical(test_y, 10))



In [15]:
new_model.save('../NN/pruned_DBSCAN_model_mnist_renet50_4_columns_10_epoch.h5')

In [16]:
model_final = load_model('../NN/pruned_DBSCAN_model_mnist_renet50_4_columns_10_epoch.h5')
result = model_final.evaluate(test_X, to_categorical(test_y, 10))
print(result)

[0.15806691348552704, 0.9532999992370605]
