# Pruning

In [8]:
from tensorflow.keras import applications as apps # type: ignore
import tensorflow as tf # type: ignore
tf.config.set_visible_devices([], 'GPU')

In [9]:
def save_model(model, filename):
    model.save(f"{filename}.keras")

In [10]:
def weight_pruning(w: tf.Variable, k: float) -> tf.Variable:
    k = tf.cast(
        tf.round(
            tf.cast(tf.size(w), tf.float32) * k
        ), dtype=tf.int32
    )

    w_reshaped = tf.reshape(w, [-1])

    _, indices = tf.nn.top_k(
        tf.negative(tf.abs(w_reshaped)),
        k=k
    )

    mask = tf.tensor_scatter_nd_update(
        tf.ones_like(w_reshaped, dtype=tf.float32),
        tf.reshape(indices, [-1, 1]),
        tf.zeros([k], dtype=tf.float32)
    )

    return w.assign(tf.reshape(w_reshaped * mask, tf.shape(w)))

In [11]:
def unit_pruning(w: tf.Variable, k: float) -> tf.Variable:
    norm = tf.norm(w, axis=0)

    num_cols = tf.cast(tf.shape(w)[1], dtype=tf.float32)
    k = tf.cast(tf.round(num_cols * k), dtype=tf.int32)

    _, col_indices = tf.nn.top_k(
        tf.negative(norm),
        k=k,
        sorted=True
    )

    row_indices = tf.range(tf.shape(w)[0])
    row_indices, col_indices = tf.meshgrid(row_indices, col_indices, indexing='ij')

    indices = tf.stack([tf.reshape(row_indices, [-1]), tf.reshape(col_indices, [-1])], axis=1)

    # Ensure update size matches indices
    num_updates = tf.shape(indices)[0]
    updates = tf.zeros([num_updates], dtype=tf.float32)

    return w.assign(
        tf.tensor_scatter_nd_update(
            w,
            indices,
            updates
        )
    )

In [12]:
model = apps.VGG16(weights='imagenet', include_top=True)
model.trainable = False
save_model(model, 'original_model')

for layer in model.layers:
    if isinstance(layer, tf.keras.layers.Conv2D):
        layer.trainable = True  # Enable training for weight modification
        unit_pruning(layer.kernel, 0.5)
        layer.trainable = False  # Re-freeze the layer after pruning

save_model(model, 'unit_pruned_model')

2025-02-10 00:27:17.682051: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: INVALID_ARGUMENT: Inner dimensions of output shape must match inner dimensions of updates shape. Output: [3,3,3,64] updates: [54]


InvalidArgumentError: {{function_node __wrapped__TensorScatterUpdate_device_/job:localhost/replica:0/task:0/device:CPU:0}} Inner dimensions of output shape must match inner dimensions of updates shape. Output: [3,3,3,64] updates: [54] [Op:TensorScatterUpdate]