Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
71 lines (59 sloc) 3.72 KB
from typing import Callable, Optional, List
import numpy as np
from keras import models, layers
from sklearn import cluster, metrics
from ridurre import base_filter_pruning
class KMeansFilterPruning(base_filter_pruning.BasePruning):
    """Conv2D filter pruning that keeps one representative filter per KMeans cluster.

    The output filters of a convolution kernel are clustered with KMeans. For
    every cluster the filter closest to the cluster center is kept; all other
    filters are reported as pruneable.
    """

    def __init__(self,
                 clustering_factor: float,
                 model_compile_fn: Callable[[models.Model], None],
                 model_finetune_fn: Optional[Callable[[models.Model, int, int], None]],
                 nb_finetune_epochs: int,
                 maximum_prune_iterations: Optional[int] = None,
                 maximum_pruning_percent: float = 0.9,
                 nb_trained_for_epochs: int = 0):
        """Forward the configuration to the base pruning class.

        ``clustering_factor`` is handed to the base class as its
        ``pruning_factor``; every other argument is forwarded unchanged under
        the same name.
        """
        super().__init__(pruning_factor=clustering_factor,
                         model_compile_fn=model_compile_fn,
                         model_finetune_fn=model_finetune_fn,
                         nb_finetune_epochs=nb_finetune_epochs,
                         nb_trained_for_epochs=nb_trained_for_epochs,
                         maximum_prune_iterations=maximum_prune_iterations,
                         maximum_pruning_percent=maximum_pruning_percent)

    def run_pruning_for_conv2d_layer(self, pruning_factor: float, layer: layers.Conv2D, layer_weight_mtx) -> List[int]:
        """Select the output channels of a Conv2D kernel that can be pruned.

        Args:
            pruning_factor: Forwarded to ``_calculate_number_of_channels_to_keep``
                together with the channel count to obtain the number of clusters.
            layer: The layer being pruned (not read by this method).
            layer_weight_mtx: 4D kernel laid out as
                (height, width, input_channels, output_channels).

        Returns:
            Indices of the output channels that should be pruned.

        Raises:
            ValueError: If, after rebalancing, the number of kept channels
                still differs from the number of clusters.
        """
        _, _, _, nb_channels = layer_weight_mtx.shape

        # Initialize KMeans; the first value returned by the helper is used as
        # the number of clusters (one kept channel per cluster).
        nb_of_clusters, _ = self._calculate_number_of_channels_to_keep(pruning_factor, nb_channels)
        # Keyword arguments: recent scikit-learn releases accept everything
        # after ``n_clusters`` as keyword-only.
        kmeans = cluster.KMeans(n_clusters=nb_of_clusters, init="k-means++")

        # Fit with the flattened weight matrix
        # (height, width, input_channels, output_channels) -> (output_channels, flattened features)
        layer_weight_mtx_reshaped = layer_weight_mtx.transpose(3, 0, 1, 2).reshape(nb_channels, -1)

        # Apply some fuzz to the weights, to avoid duplicates
        # NOTE(review): the return value is ignored, so ``_apply_fuzz`` is
        # presumably in-place - confirm against the base class.
        self._apply_fuzz(layer_weight_mtx_reshaped)
        kmeans.fit(layer_weight_mtx_reshaped)

        # If a cluster has only a single member, then that should not be pruned
        # so that point will always be the closest to the cluster center
        closest_point_to_cluster_center_indices = metrics.pairwise_distances_argmin(kmeans.cluster_centers_,
                                                                                    layer_weight_mtx_reshaped)

        # Compute filter indices which can be pruned
        channel_indices = set(np.arange(len(layer_weight_mtx_reshaped)))
        channel_indices_to_keep = set(closest_point_to_cluster_center_indices)
        channel_indices_to_prune = list(channel_indices.difference(channel_indices_to_keep))
        channel_indices_to_keep = list(channel_indices_to_keep)

        if len(channel_indices_to_keep) > nb_of_clusters:
            print("Number of selected channels for pruning is less than expected")
            diff = len(channel_indices_to_keep) - nb_of_clusters
            print("Randomly adding {0} channels for pruning".format(diff))
            np.random.shuffle(channel_indices_to_keep)
            # Move exactly ``diff`` entries in one step. Popping at an
            # increasing index from a shrinking list skips elements and can
            # run past the end of the list.
            channel_indices_to_prune.extend(channel_indices_to_keep[:diff])
            del channel_indices_to_keep[:diff]
        elif len(channel_indices_to_keep) < nb_of_clusters:
            # Happens when several cluster centers share the same closest filter.
            print("Number of selected channels for pruning is greater than expected. Leaving too few channels.")
            diff = nb_of_clusters - len(channel_indices_to_keep)
            print("Discarding {0} pruneable channels".format(diff))
            channel_indices_to_keep.extend(channel_indices_to_prune[:diff])
            del channel_indices_to_prune[:diff]

        if len(channel_indices_to_keep) != nb_of_clusters:
            # Report the quantity that is actually being compared (kept channels).
            raise ValueError(
                "Number of clusters {0} is not equal with the number of "
                "kept channels {1}".format(nb_of_clusters, len(channel_indices_to_keep)))

        return channel_indices_to_prune
You can’t perform that action at this time.