<a href="https://colab.research.google.com/github/lenishu/IPA_using_Densenet/blob/main/testing_v_7pruning_0_gradient_updated_weights_densenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import csv
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Input, BatchNormalization, ReLU, Conv2D, Dense, MaxPool2D, AvgPool2D, GlobalAvgPool2D, Concatenate
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10, cifar100
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import Model
import math
import time
from datetime import datetime

# Mount Google Drive (this will work in Colab)
# Decide where output files go. When the google.colab package can be imported,
# mount Google Drive and use a folder on it as the base output directory;
# otherwise fall back to local paths. The two module-level names set here
# (GOOGLE_DRIVE_MOUNTED, GOOGLE_DRIVE_BASE_PATH) are read by get_output_dir().
try:
    from google.colab import drive
    drive.mount('/content/drive')
    GOOGLE_DRIVE_MOUNTED = True
    GOOGLE_DRIVE_BASE_PATH = "/content/drive/My Drive/DenseNet_Pruning_Verification"
    # Create base directory in Google Drive if it doesn't exist
    os.makedirs(GOOGLE_DRIVE_BASE_PATH, exist_ok=True)
    print(f"Google Drive mounted. Files will be saved to {GOOGLE_DRIVE_BASE_PATH}")
except ImportError:
    # Only ImportError is handled: any other failure raised while mounting or
    # creating the directory propagates to the caller.
    GOOGLE_DRIVE_MOUNTED = False
    GOOGLE_DRIVE_BASE_PATH = None
    print("Not running in Colab or Google Drive not available. Files will be saved locally.")

# Configuration
BATCH_SIZE_TRAIN = 64      # mini-batch size handed to model.fit
BATCH_SIZE_TEST = 256      # not referenced anywhere in the visible code
LEARNING_RATE = 0.1        # SGD learning rate; also the divisor of the gradient estimate
EPOCHS_PER_RUN = 1         # number of epochs for the single fit call
# Percentage (0-100 scale) of weights to prune. Values below 0.0002 make the
# pruning routine switch to its "test mode", which prunes 10 parameters.
prune_percentage = 0.00013514851

# Function to get output directory path (either on Google Drive or locally)
def get_output_dir(local_dir, create=True):
    """
    Resolve where output files for `local_dir` should be written.

    When Google Drive is mounted the directory is placed under
    GOOGLE_DRIVE_BASE_PATH (an absolute `local_dir` is reduced to its last
    path component first); otherwise `local_dir` is used unchanged.

    Parameters:
    local_dir -- Local directory path (relative or absolute)
    create -- Whether to create the directory if it doesn't exist

    Returns:
    Full path to the directory
    """
    if not GOOGLE_DRIVE_MOUNTED:
        resolved = local_dir
    else:
        # Absolute paths cannot be nested under the Drive folder as-is, so only
        # their final component is kept.
        leaf = os.path.basename(local_dir) if os.path.isabs(local_dir) else local_dir
        resolved = os.path.join(GOOGLE_DRIVE_BASE_PATH, leaf)

    if create:
        os.makedirs(resolved, exist_ok=True)
        print(f"Created directory: {resolved}")

    return resolved

# Function to load and preprocess datasets
def load_dataset(dataset_name):
    """
    Load one of the supported Keras datasets and prepare it for the network.

    Grayscale datasets ('mnist', 'fashion_mnist') are zero-padded by 2 pixels on
    each spatial side and their single channel is repeated 3 times. All images
    are converted to float32 and divided by 255; labels are one-hot encoded.

    Parameters:
    dataset_name -- 'mnist', 'fashion_mnist', 'cifar10' or 'cifar100'

    Returns:
    (x_train, y_train), (x_test, y_test), num_classes

    Raises:
    ValueError -- if `dataset_name` is not one of the supported names
    """
    # name -> (number of classes, needs grayscale-to-3-channel conversion)
    specs = {
        'mnist': (10, True),
        'fashion_mnist': (10, True),
        'cifar10': (10, False),
        'cifar100': (100, False),
    }
    if dataset_name not in specs:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    num_classes, is_grayscale = specs[dataset_name]

    loaders = {
        'mnist': mnist,
        'fashion_mnist': fashion_mnist,
        'cifar10': cifar10,
        'cifar100': cifar100,
    }
    (x_train, y_train), (x_test, y_test) = loaders[dataset_name].load_data()

    def pad_and_triplicate(images):
        # Zero-pad height and width by 2 on each side, then add a channel axis
        # and repeat it to obtain 3 identical channels.
        padded = np.pad(images, ((0, 0), (2, 2), (2, 2)), 'constant')
        return np.repeat(np.expand_dims(padded, axis=-1), 3, axis=-1)

    if is_grayscale:
        x_train = pad_and_triplicate(x_train)
        x_test = pad_and_triplicate(x_test)

    # Scale pixel values to [0, 1]
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    # One-hot encode the labels
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    return (x_train, y_train), (x_test, y_test), num_classes

# Define DenseNet-121 architecture
def bn_rl_conv(x, filters, kernel_size):
    """Apply BatchNormalization, then ReLU, then a 'same'-padded Conv2D to `x`."""
    normalized = BatchNormalization()(x)
    activated = ReLU()(normalized)
    return Conv2D(filters=filters, kernel_size=kernel_size, padding='same')(activated)

def dense_block(tensor, k, reps):
    """
    Stack `reps` densely connected units on top of `tensor`.

    Each unit is a 1x1 BN-ReLU-Conv with 4*k filters followed by a 3x3
    BN-ReLU-Conv with k filters; its output is concatenated onto the running
    feature tensor, which is returned after the last unit.
    """
    features = tensor
    for _ in range(reps):
        bottleneck = bn_rl_conv(features, filters=4 * k, kernel_size=1)
        new_maps = bn_rl_conv(bottleneck, filters=k, kernel_size=3)
        features = Concatenate()([features, new_maps])
    return features

def transition_layer(x, theta):
    """
    Compress the channels of `x` by the factor `theta` and halve its spatial size.

    The channel count is int(channels * theta), produced by a 1x1 BN-ReLU-Conv,
    followed by 2x2 average pooling with stride 2.
    """
    in_channels = tf.keras.backend.int_shape(x)[-1]
    compressed = bn_rl_conv(x, filters=int(in_channels * theta), kernel_size=1)
    return AvgPool2D(pool_size=2, strides=2, padding='same')(compressed)

def create_densenet121(input_shape, num_classes, k=32, theta=0.5):
    """
    Build a DenseNet-121 style classifier.

    Parameters:
    input_shape -- Shape of one input image, passed to keras Input
    num_classes -- Number of units of the final softmax layer
    k -- Growth rate (filters added by every dense unit)
    theta -- Channel compression factor used by the transition layers

    Returns:
    An uncompiled keras Model mapping images to class probabilities
    """
    # DenseNet-121 has repetitions [6, 12, 24, 16]
    repetitions = [6, 12, 24, 16]

    inputs = Input(shape=input_shape)
    x = Conv2D(2 * k, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = MaxPool2D(3, strides=2, padding='same')(x)

    # Transition layers sit BETWEEN dense blocks only. Applying one after the
    # last block (as the previous version did) compresses the final feature
    # map by `theta` and adds a layer that DenseNet-121 does not have.
    last_block = len(repetitions) - 1
    for block_idx, reps in enumerate(repetitions):
        x = dense_block(x, k, reps)
        if block_idx < last_block:
            x = transition_layer(x, theta)

    # The dense units are pre-activation (BN-ReLU-Conv), so the concatenated
    # output of the last block is normalised and activated once more before
    # pooling, as in the reference DenseNet implementations.
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = GlobalAvgPool2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs)


# CompleteWeightTrackingCallback: mask-enforcing callback with complete weight tracking.
# For the pruned positions it records the weights before pruning, the weights after
# pruning (captured at the start of training, before any batch has run), the weights
# at the beginning of each tracked batch, the approximate gradients, and the weights
# after each update but before re-pruning.
# Batches are renumbered from 1 in the saved reports (see save_all_data_to_csv and
# save_all_flat_pruned_data_txt).

class CompleteWeightTrackingCallback(tf.keras.callbacks.Callback):
    """
    Keep pruned weights at zero during training and record how they evolve.

    After every training batch the pruning masks are re-applied to the tracked
    layers. For the first `max_batches_to_verify` batches the callback also
    snapshots the tracked kernels before the batch and after the optimizer
    update (before re-masking) and derives an approximate gradient from the
    difference. The snapshots can be written out with `save_all_data_to_csv`
    and `save_all_flat_pruned_data_txt`.

    Parameters:
    masks -- One array per tracked layer, 0 at pruned positions and 1 elsewhere
    layer_indices -- Indices into model.layers of the tracked layers
    pruned_indices_per_layer -- Dict: position in `layer_indices` -> list of
        pruned indices into that layer's flattened kernel
    all_pruned_flat_indices -- Pruned indices into the concatenation of all
        flattened tracked kernels
    output_dir -- Directory under which "weights_tracking" is created
    """

    def __init__(self, masks, layer_indices, pruned_indices_per_layer, all_pruned_flat_indices, output_dir):
        super().__init__()
        self.masks = masks
        self.layer_indices = layer_indices
        self.pruned_indices_per_layer = pruned_indices_per_layer
        self.all_pruned_flat_indices = all_pruned_flat_indices
        self.output_dir = output_dir
        self.batch_count = 0
        self.max_batches_to_verify = 4  # Track batches 0, 1, 2, 3 (will be displayed as 1, 2, 3, 4)

        # Store values for each batch
        self.pre_pruning_weights = None  # Will store weights before any pruning
        self.initial_weights = None      # Will store weights after pruning but before any training
        self.weight_values = {}          # Will store weights at the beginning of each batch
        self.gradient_values = {}        # Will store gradients for each batch
        self.updated_weight_values = {}  # Will store weights after update but before re-pruning

        # Create the weights directory
        self.weights_dir = os.path.join(output_dir, "weights_tracking")
        os.makedirs(self.weights_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _kernel(self, layer_idx):
        """Return the first weight array (the kernel) of model.layers[layer_idx]."""
        return self.model.layers[layer_idx].get_weights()[0]

    def _apply_masks(self):
        """Multiply every tracked kernel by its mask so pruned entries are zero again."""
        for mask, layer_idx in zip(self.masks, self.layer_indices):
            # A mask without zeros leaves the kernel unchanged; skip the
            # get_weights/set_weights round trip for those layers.
            if np.all(mask):
                continue
            layer = self.model.layers[layer_idx]
            weights = layer.get_weights()
            weights[0] = weights[0] * mask
            layer.set_weights(weights)

    @staticmethod
    def _locate_flat_indices(flat_indices, layer_sizes):
        """
        Map global flat indices to (flat_idx, layer_position, local_idx) triples.

        `layer_sizes` are the element counts of consecutive layers; indices that
        fall outside their combined range are dropped. The input order of
        `flat_indices` is preserved.
        """
        # boundaries[p] is the first global flat index that belongs to layer p.
        boundaries = np.cumsum([0] + list(layer_sizes))
        total = int(boundaries[-1])
        located = []
        for flat_idx in flat_indices:
            if 0 <= flat_idx < total:
                pos = int(np.searchsorted(boundaries, flat_idx, side='right')) - 1
                located.append((flat_idx, pos, int(flat_idx - boundaries[pos])))
        return located

    # ------------------------------------------------------------------
    # Keras hooks
    # ------------------------------------------------------------------
    def on_train_begin(self, logs=None):
        # Capture the weights after pruning but before any training
        self.initial_weights = [self._kernel(layer_idx).copy() for layer_idx in self.layer_indices]

        # Save the initial weights data
        self.save_initial_weights_data()

    def on_batch_begin(self, batch, logs=None):
        if self.batch_count < self.max_batches_to_verify:
            # Snapshot all tracked kernels before the optimizer touches them
            self.weight_values[self.batch_count] = [
                self._kernel(layer_idx) for layer_idx in self.layer_indices
            ]

    def on_batch_end(self, batch, logs=None):
        if self.batch_count < self.max_batches_to_verify:
            slot = self.batch_count

            # Kernels after the optimizer update, before the masks are re-applied
            updated_weights = [self._kernel(layer_idx).copy() for layer_idx in self.layer_indices]
            self.updated_weight_values[slot] = updated_weights

            weights_before = self.weight_values.get(slot)
            if weights_before is None or len(weights_before) != len(updated_weights):
                print(f"Couldn't extract gradients for batch {slot}")
                self.gradient_values[slot] = [None] * len(self.layer_indices)
            else:
                # Plain SGD performs w_new = w_old - lr * g, hence
                # g = (w_old - w_new) / lr. With momentum this is the effective
                # step divided by lr, which equals the raw gradient only on the
                # very first update.
                self.gradient_values[slot] = [
                    (before - after) / LEARNING_RATE
                    for before, after in zip(weights_before, updated_weights)
                ]

            self.batch_count += 1

        # Re-prune after EVERY batch, not only while data is being collected;
        # otherwise the pruned weights drift away from zero for the rest of
        # training.
        self._apply_masks()

    def set_pre_pruning_weights(self, weights):
        """
        Set the pre-pruning weights (called from the pruning function).
        """
        self.pre_pruning_weights = weights

    # ------------------------------------------------------------------
    # Reports
    # ------------------------------------------------------------------
    def save_initial_weights_data(self):
        """
        Save the initial weights data before and after pruning.

        Writes "initial_weights_data.txt" into the weights directory, listing for
        every pruned index its value before pruning and its value at the start
        of training.
        """
        filename = "initial_weights_data.txt"
        filepath = os.path.join(self.weights_dir, filename)

        with open(filepath, 'w') as f:
            f.write("Initial Weights Data Before and After Pruning\n")
            f.write("=" * 80 + "\n\n")

            if self.pre_pruning_weights is None:
                f.write("Pre-pruning weights not captured. Please make sure to set them using set_pre_pruning_weights().\n\n")
            else:
                f.write("Flat Index | Pre-Pruning Value | After Pruning Value\n")
                f.write("-" * 70 + "\n")

                # Running offset of the current layer inside the global flat index space
                params_so_far = 0
                for i, layer_idx in enumerate(self.layer_indices):
                    layer_name = self.model.layers[layer_idx].name

                    pre_weights = self.pre_pruning_weights[i]
                    flat_pre_weights = pre_weights.ravel()
                    flat_post_weights = self.initial_weights[i].ravel()

                    f.write(f"\nLayer {i} ({layer_name}):\n")

                    for idx in self.pruned_indices_per_layer.get(i, []):
                        flat_idx = params_so_far + idx
                        pre_val = flat_pre_weights[idx]
                        post_val = flat_post_weights[idx]
                        f.write(f"{flat_idx:10d} | {pre_val:15.10f} | {post_val:15.10f}\n")

                    params_so_far += pre_weights.size

        print(f"Saved initial weights data to {filepath}")

    def save_all_data_to_csv(self):
        """
        Save all tracked data to a comprehensive CSV file for easier analysis.

        One row is written per pruned index for the initial state (batch 0) and
        for every tracked batch (numbered from 1). Values that were not captured
        are written as 'N/A'.
        """
        csv_path = os.path.join(self.weights_dir, "pruned_weights_tracking.csv")
        n_layers = len(self.layer_indices)

        # Flatten the reference snapshots once instead of once per pruned index.
        flat_pre = [w.ravel() for w in self.pre_pruning_weights] if self.pre_pruning_weights is not None else []
        flat_init = [w.ravel() for w in self.initial_weights] if self.initial_weights is not None else []

        with open(csv_path, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)

            # Header with all states we're tracking
            writer.writerow([
                'batch', 'flat_index', 'layer_index',
                'pre_pruning_value', 'initial_value',
                'weight_before_batch', 'gradient', 'weight_after_update'
            ])

            # Pre-pruning and initial weights (batch 0)
            if self.initial_weights is not None:
                params_so_far = 0
                for i in range(n_layers):
                    flat_post_weights = flat_init[i]

                    for idx in self.pruned_indices_per_layer.get(i, []):
                        init_weight = flat_post_weights[idx]
                        pre_weight = flat_pre[i][idx] if i < len(flat_pre) else 'N/A'

                        # Use 0 as batch number to indicate initial state
                        writer.writerow([0, params_so_far + idx, i, pre_weight, init_weight, init_weight, 'N/A', 'N/A'])

                    params_so_far += flat_post_weights.size

            # Process each batch (using batch+1 for the renumbering)
            for batch in range(min(self.batch_count, self.max_batches_to_verify)):
                before_list = self.weight_values.get(batch)
                if before_list is None:
                    # No data for this batch
                    continue
                grads_list = self.gradient_values.get(batch, [])
                after_list = self.updated_weight_values.get(batch, [])

                params_so_far = 0
                for i in range(n_layers):
                    original_weights = before_list[i] if i < len(before_list) else None
                    updated_weights = after_list[i] if i < len(after_list) else None
                    gradients = grads_list[i] if i < len(grads_list) else None

                    if original_weights is None or updated_weights is None:
                        continue

                    flat_orig_weights = original_weights.ravel()
                    flat_updated_weights = updated_weights.ravel()
                    flat_gradients = gradients.ravel() if gradients is not None else None

                    for idx in self.pruned_indices_per_layer.get(i, []):
                        pre_weight = flat_pre[i][idx] if i < len(flat_pre) else 'N/A'
                        init_weight = flat_init[i][idx] if i < len(flat_init) else 'N/A'
                        gradient = flat_gradients[idx] if flat_gradients is not None else 'N/A'

                        # Use batch+1 for display (batch 0 becomes batch 1, etc.)
                        writer.writerow([
                            batch+1, params_so_far + idx, i,
                            pre_weight, init_weight,
                            flat_orig_weights[idx], gradient, flat_updated_weights[idx]
                        ])

                    params_so_far += original_weights.size

        print(f"Saved all pruned weights tracking data to {csv_path}")

    def save_all_flat_pruned_data_txt(self):
        """
        Save all batches' pruned data using only `all_pruned_flat_indices` to a single .txt file.
        Includes pre-pruning weights, initial weights, batch weights, gradients, and updated weights.
        """
        txt_path = os.path.join(self.weights_dir, "all_flat_pruned_weights.txt")
        n_layers = len(self.layer_indices)

        with open(txt_path, 'w') as f:
            f.write("Complete Flat Pruned Indices Tracking\n")
            f.write("=" * 120 + "\n\n")

            # First, add a section for pre-pruning and initial weights
            f.write("SECTION 1: INITIAL WEIGHTS (BEFORE AND AFTER PRUNING)\n")
            f.write("-" * 120 + "\n")

            if self.pre_pruning_weights is not None:
                f.write("Flat Index | Pre-Pruning Value (Original Random) | After Pruning Value\n")
                f.write("-" * 80 + "\n")
            else:
                f.write("Flat Index | After Pruning Value\n")
                f.write("-" * 50 + "\n")

            # Check if we have any pruned indices
            if not self.all_pruned_flat_indices:
                f.write("No pruned indices found. Please check if prune_percentage > 0.\n")
                print(f"Warning: No pruned indices found. Please check if prune_percentage > 0.")
                return

            # Group the pruned indices by layer with a single pass over them.
            init_snapshots = self.initial_weights[:n_layers] if self.initial_weights is not None else []
            flat_init = [w.ravel() for w in init_snapshots]
            flat_pre = [w.ravel() for w in self.pre_pruning_weights] if self.pre_pruning_weights is not None else []

            by_layer = {}
            for flat_idx, pos, local_idx in self._locate_flat_indices(
                    self.all_pruned_flat_indices, [w.size for w in init_snapshots]):
                by_layer.setdefault(pos, []).append((flat_idx, local_idx))

            for pos in sorted(by_layer):
                layer_name = self.model.layers[self.layer_indices[pos]].name
                f.write(f"\nLayer {pos} ({layer_name}):\n")
                for flat_idx, local_idx in by_layer[pos]:
                    init_val = flat_init[pos][local_idx]
                    if pos < len(flat_pre):
                        pre_val = flat_pre[pos][local_idx]
                        f.write(f"{flat_idx:10d} | {pre_val:30.10f} | {init_val:15.10f}\n")
                    else:
                        f.write(f"{flat_idx:10d} | {init_val:15.10f}\n")

            # Now add a section for batch training tracking
            f.write("\n\nSECTION 2: TRAINING BATCH TRACKING\n")
            f.write("-" * 120 + "\n")
            f.write("Flat Index | Batch | Weight Before | Gradient | Weight After Update\n")
            f.write("-" * 100 + "\n")

            for batch in range(min(self.batch_count, self.max_batches_to_verify)):
                before_list = self.weight_values.get(batch)
                if before_list is None:
                    # No data for this batch
                    continue
                before_list = before_list[:n_layers]
                grads_list = self.gradient_values.get(batch, [])
                after_list = self.updated_weight_values.get(batch, [])

                # Use batch+1 for display (batch 0 becomes batch 1, etc.)
                f.write(f"\nBatch {batch+1}:\n")
                f.write("-" * 50 + "\n")

                flat_before = [w.ravel() for w in before_list]
                flat_after = [w.ravel() for w in after_list]
                flat_grads = [g.ravel() if g is not None else None for g in grads_list]

                for flat_idx, pos, local_idx in self._locate_flat_indices(
                        self.all_pruned_flat_indices, [w.size for w in before_list]):
                    if pos >= len(flat_after):
                        continue

                    orig_val = flat_before[pos][local_idx]
                    updated_val = flat_after[pos][local_idx]

                    grads = flat_grads[pos] if pos < len(flat_grads) else None
                    grad_str = f"{grads[local_idx]:15.10f}" if grads is not None else "N/A"

                    f.write(f"{flat_idx:10d} | {batch+1:5d} | {orig_val:15.10f} | {grad_str:15} | {updated_val:15.10f}\n")

        print(f"Saved all pruned flat weight tracking to {txt_path}")

# Random global pruning: builds the masks, zeroes the selected weights, and hands the
# pre-pruning weights to the tracking callback.

def parameter_mask_pruning_with_tracking(model, prune_percentage, seed=None, output_dir="pruning_verification"):
    """
    Randomly select a percentage of weights across the entire network, set them to zero,
    and track their values during training.

    Only the kernels (first weight array) of Conv2D and Dense layers are
    considered. The selection, the per-layer indices and the flat indices are
    written to "prune-indices_batch-0.txt" and "pruned_indices.csv" in the
    output directory, and a CompleteWeightTrackingCallback holding the masks is
    attached to the model as `model.mask_callback`.

    Parameters:
    model -- The Keras model to prune
    prune_percentage -- Percentage of weights to prune (0-100). Positive values
        below 0.0002 trigger a test mode that prunes 10 parameters; values <= 0
        prune nothing.
    seed -- Random seed for reproducibility
    output_dir -- Directory to save verification files

    Returns:
    The pruned model with a custom weight mask and tracking callback

    Raises:
    ValueError -- if prune_percentage is greater than 100
    """
    if prune_percentage > 100:
        raise ValueError(f"prune_percentage must be at most 100, got {prune_percentage}")

    # Get the output directory (either on Google Drive or locally)
    output_dir = get_output_dir(output_dir)

    # A private generator yields the same draws as seeding the global NumPy RNG
    # with `seed`, without overwriting (and afterwards randomising) the global
    # state other code may rely on. Without a seed the global RNG is used.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # Collect the kernels of all prunable layers
    all_weights = []
    all_shapes = []
    all_layer_indices = []
    layer_info = []  # Store layer name and shape for reference

    for i, layer in enumerate(model.layers):
        if isinstance(layer, (Conv2D, Dense)) and len(layer.weights) > 0:
            # Only consider weight matrices, not biases
            weight = layer.get_weights()[0]
            all_weights.append(weight)
            all_shapes.append(weight.shape)
            all_layer_indices.append(i)
            layer_info.append(f"Layer {i} ({layer.name}): {weight.shape}")

    # Save a copy of the pre-pruning weights
    pre_pruning_weights = [w.copy() for w in all_weights]

    # Count total parameters
    total_params = sum(w.size for w in all_weights)

    # Test mode for tiny (but positive) pruning percentages. A percentage of 0
    # or less must not enter test mode: nothing is pruned in that case, so the
    # reported count has to be 0 as well.
    if 0 < prune_percentage < 0.0002:  # Using a threshold slightly higher than the exact percentage
        num_to_prune = min(10, total_params)
        print(f"Test mode: Pruning exactly {num_to_prune} parameters (approx. {prune_percentage:.10f}% of {total_params:,} parameters)")
    else:
        # Clamp to the valid sample-size range [0, total_params]
        num_to_prune = max(0, min(int(total_params * prune_percentage / 100), total_params))
        print(f"Pruning {num_to_prune:,} parameters ({prune_percentage:.2f}% of {total_params:,} parameters)")

    print(f"Total trainable parameters: {total_params:,}")
    print(f"Parameters to prune: {num_to_prune:,} ({prune_percentage}%)")

    # Create masks for each layer (initially all ones)
    masks = [np.ones_like(w) for w in all_weights]

    # Dictionary to track which indices were set to zero in each layer
    pruned_indices_per_layer = {}

    # Store flat indices for verification
    all_pruned_flat_indices = []

    # If pruning percentage is not 0, create masks and apply them
    if prune_percentage > 0:
        # Randomly select indices to prune across all parameters
        flat_indices = rng.choice(total_params, size=num_to_prune, replace=False)
        all_pruned_flat_indices = sorted(flat_indices.tolist())

        # Map flat indices back to per-layer flat indices
        params_so_far = 0
        for i, weight in enumerate(all_weights):
            size = weight.size
            # Get indices that fall within this layer
            in_layer = (flat_indices >= params_so_far) & (flat_indices < params_so_far + size)
            indices_in_layer = flat_indices[in_layer] - params_so_far

            # Store these indices for verification
            pruned_indices_per_layer[i] = sorted(indices_in_layer.tolist())

            # Flatten the mask, set the selected indices to zero, then reshape back
            flat_mask = masks[i].flatten()
            flat_mask[indices_in_layer] = 0
            masks[i] = flat_mask.reshape(all_shapes[i])

            params_so_far += size

        # Apply masks to each layer's weights
        for mask, layer_idx in zip(masks, all_layer_indices):
            layer = model.layers[layer_idx]
            weights = layer.get_weights()
            weights[0] = weights[0] * mask  # Apply mask to weights
            layer.set_weights(weights)

    # Write pruned indices to file
    with open(os.path.join(output_dir, "prune-indices_batch-0.txt"), 'w') as f:
        f.write(f"Total parameters: {total_params}\n")
        f.write(f"Total pruned: {num_to_prune}\n\n")
        f.write("Layer information:\n")
        for info in layer_info:
            f.write(f"{info}\n")
        f.write("\nPruned indices per layer:\n")
        for layer_idx, indices in pruned_indices_per_layer.items():
            f.write(f"Layer index {layer_idx}: {len(indices)} pruned indices\n")
            if len(indices) <= 20:  # Only print all indices if there are few
                f.write(f"  {indices}\n")
            else:
                f.write(f"  First 10: {indices[:10]}\n")
                f.write(f"  Last 10: {indices[-10:]}\n")

        f.write("\nAll pruned flat indices:\n")
        # Print all pruned indices
        f.write(f"{all_pruned_flat_indices}\n")

    # Also save the full list to a separate CSV file for easier processing
    csv_path = os.path.join(output_dir, "pruned_indices.csv")
    with open(csv_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['flat_index'])  # Header
        writer.writerows([idx] for idx in all_pruned_flat_indices)

    # Attach the mask callback to the model for weight tracking
    mask_callback = CompleteWeightTrackingCallback(
        masks,
        all_layer_indices,
        pruned_indices_per_layer,
        all_pruned_flat_indices,
        output_dir
    )

    # Set the pre-pruning weights in the callback
    mask_callback.set_pre_pruning_weights(pre_pruning_weights)

    model.mask_callback = mask_callback

    return model

#=======================================================================================================================


# Load dataset (CIFAR-10 images are already 32x32x3, so no padding or channel
# repetition is applied by load_dataset)
(x_train, y_train), (x_test, y_test), num_classes = load_dataset('cifar10')

# Create model
model = create_densenet121(input_shape=(32, 32, 3), num_classes=num_classes)

# Compile model: SGD with momentum, using the module-level LEARNING_RATE
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum=0.9),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Apply pruning with enhanced tracking. The call zeroes the selected weights in
# place and attaches the tracking callback to the model as `model.mask_callback`;
# the fixed seed makes the selection of pruned indices reproducible.
model = parameter_mask_pruning_with_tracking(model, prune_percentage=prune_percentage, seed=42)

# Custom callback to save all tracking data at the end of training
class SaveTrackingDataCallback(tf.keras.callbacks.Callback):
    """Write the tracking reports of the model's `mask_callback` when training ends."""

    def on_train_end(self, logs=None):
        # Use the model this callback is attached to rather than a module-level
        # global, so the callback keeps working when the training model is held
        # under a different name.
        tracker = getattr(self.model, 'mask_callback', None)
        if tracker is not None:
            tracker.save_all_data_to_csv()
            tracker.save_all_flat_pruned_data_txt()


# Keep only the first 1000 training and 100 test samples so the verification
# run stays short (1000 samples at BATCH_SIZE_TRAIN = 64 gives 16 steps per epoch).
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:100]
y_test = y_test[:100]

# Train model. The mask/tracking callback is listed first, followed by the
# callback that writes the tracking reports at the end of training.
history = model.fit(
    x_train, y_train,
    batch_size=BATCH_SIZE_TRAIN,
    epochs=EPOCHS_PER_RUN,
    validation_data=(x_test, y_test),
    callbacks=[model.mask_callback, SaveTrackingDataCallback()]
)



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted. Files will be saved to /content/drive/My Drive/DenseNet_Pruning_Verification
Created directory: /content/drive/My Drive/DenseNet_Pruning_Verification/pruning_verification
Test mode: Pruning exactly 10 parameters (approx. 0.0001351485% of 7,399,616 parameters)
Total trainable parameters: 7,399,616
Parameters to prune: 10 (0.00013514851%)
Saved initial weights data to /content/drive/My Drive/DenseNet_Pruning_Verification/pruning_verification/weights_tracking/initial_weights_data.txt
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 2s/step - accuracy: 0.1045 - loss: 6.3766 - val_accuracy: 0.1100 - val_loss: 10571005437149184.0000
Saved all pruned weights tracking data to /content/drive/My Drive/DenseNet_Pruning_Verification/pruning_verification/weights_tracking/pruned_weights_tracking.csv
Saved all pruned flat weight trackin