<a href="https://colab.research.google.com/github/lenishu/IPA_using_Densenet/blob/main/testing_5_Densenet_121.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
DenseNet-121 Pruning Verification

This script provides tools to verify that parameter mask pruning is working correctly
in a DenseNet-121 model. It tracks pruned weight indices across multiple training batches
to ensure they remain zero throughout training.

The script runs two verification tests:
1. A tiny pruning test (exactly 10 parameters) for easy inspection
2. A medium pruning test (30% of weights) to test at scale

For each test, it saves the pruned indices after initial pruning and after batches 1, 2, and 3,
then compares them to verify that pruning is working correctly.

All verification files are saved to Google Drive.
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import csv
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Input, BatchNormalization, ReLU, Conv2D, Dense, MaxPool2D, AvgPool2D, GlobalAvgPool2D, Concatenate
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10, cifar100
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import Model
import math
import time
from datetime import datetime

# Mount Google Drive (this will work in Colab)
# The import doubles as the environment probe: outside Colab the
# `google.colab` package is missing and the ImportError branch selects
# local output instead.
# NOTE(review): only ImportError is handled here; any other exception raised
# by drive.mount() or os.makedirs() propagates and leaves
# GOOGLE_DRIVE_MOUNTED / GOOGLE_DRIVE_BASE_PATH undefined.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    GOOGLE_DRIVE_MOUNTED = True
    GOOGLE_DRIVE_BASE_PATH = "/content/drive/My Drive/DenseNet_Pruning_Verification"
    # Create base directory in Google Drive if it doesn't exist
    os.makedirs(GOOGLE_DRIVE_BASE_PATH, exist_ok=True)
    print(f"Google Drive mounted. Files will be saved to {GOOGLE_DRIVE_BASE_PATH}")
except ImportError:
    # get_output_dir() checks this flag and falls back to the local path.
    GOOGLE_DRIVE_MOUNTED = False
    GOOGLE_DRIVE_BASE_PATH = None
    print("Not running in Colab or Google Drive not available. Files will be saved locally.")

# Configuration
# Mini-batch size passed to model.fit(); also used in the IPA denominator
# (BATCH_SIZE_TRAIN * batches seen) computed by SimpleLogger.
BATCH_SIZE_TRAIN = 64
# Batch size used when evaluating on the test set.
BATCH_SIZE_TEST = 256
# Learning rate of the SGD optimizer built in run_single_experiment().
LEARNING_RATE = 0.1
# Number of epochs each experiment run trains for.
EPOCHS_PER_RUN = 1

# Function to get output directory path (either on Google Drive or locally)
def get_output_dir(local_dir, create=True):
    """
    Resolve where output for ``local_dir`` should be written.

    When Google Drive is mounted the result lives under
    GOOGLE_DRIVE_BASE_PATH: a relative ``local_dir`` is appended as-is, while
    an absolute one is reduced to its final path component first. Without
    Drive, ``local_dir`` is returned unchanged.

    Parameters:
    local_dir -- Local directory path
    create -- When true, create the resolved directory (parents included)
              and print its location

    Returns:
    The resolved directory path
    """
    if not GOOGLE_DRIVE_MOUNTED:
        target = local_dir
    else:
        # Absolute local paths cannot be nested under the Drive folder, so
        # only their last component is kept.
        relative = os.path.basename(local_dir) if os.path.isabs(local_dir) else local_dir
        target = os.path.join(GOOGLE_DRIVE_BASE_PATH, relative)

    if not create:
        return target

    os.makedirs(target, exist_ok=True)
    print(f"Created directory: {target}")
    return target

# Function to load and preprocess datasets
def load_dataset(dataset_name):
    """
    Load one of the supported Keras datasets and prepare it for training.

    MNIST and Fashion-MNIST images are zero-padded by 2 pixels on each side
    and their single channel is repeated three times; CIFAR-10/100 images are
    used as loaded. All images are cast to float32 and divided by 255, and
    labels are one-hot encoded.

    Parameters:
    dataset_name -- One of 'mnist', 'fashion_mnist', 'cifar10', 'cifar100'

    Returns:
    ((x_train, y_train), (x_test, y_test), num_classes)

    Raises:
    ValueError -- if ``dataset_name`` is not one of the supported names
    """
    grayscale_names = ('mnist', 'fashion_mnist')
    color_names = ('cifar10', 'cifar100')
    if dataset_name not in grayscale_names + color_names:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    def _pad_and_triplicate(images):
        # Pad height/width by 2 on each side, add a channel axis, then copy
        # that channel three times so the input has three channels.
        padded = np.pad(images, ((0, 0), (2, 2), (2, 2)), 'constant')
        return np.repeat(np.expand_dims(padded, axis=-1), 3, axis=-1)

    if dataset_name in grayscale_names:
        source = mnist if dataset_name == 'mnist' else fashion_mnist
        (x_train, y_train), (x_test, y_test) = source.load_data()
        x_train, x_test = (_pad_and_triplicate(images) for images in (x_train, x_test))
        num_classes = 10
    elif dataset_name == 'cifar10':
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        num_classes = 10
    else:
        (x_train, y_train), (x_test, y_test) = cifar100.load_data()
        num_classes = 100

    # Scale pixel values into [0, 1].
    x_train, x_test = (images.astype('float32') / 255.0 for images in (x_train, x_test))

    # One-hot encode the integer class labels.
    y_train, y_test = (to_categorical(labels, num_classes) for labels in (y_train, y_test))

    return (x_train, y_train), (x_test, y_test), num_classes

# Define DenseNet-121 architecture
def bn_rl_conv(x, filters, kernel_size):
    """Apply BatchNormalization, then ReLU, then a 'same'-padded Conv2D to ``x``.

    Parameters:
    x -- Input tensor
    filters -- Number of output filters of the convolution
    kernel_size -- Kernel size of the convolution

    Returns:
    The output tensor of the convolution
    """
    normalized = BatchNormalization()(x)
    activated = ReLU()(normalized)
    convolution = Conv2D(filters=filters, kernel_size=kernel_size, padding='same')
    return convolution(activated)

def dense_block(tensor, k, reps):
    """Stack ``reps`` bottleneck units, concatenating each unit's output onto its input.

    Each unit is a 1x1 BN-ReLU-Conv with ``4 * k`` filters followed by a 3x3
    BN-ReLU-Conv with ``k`` filters.

    Parameters:
    tensor -- Input tensor
    k -- Number of filters each unit adds (growth rate)
    reps -- Number of units in the block

    Returns:
    The input tensor concatenated with the outputs of all units
    """
    features = tensor
    for _unit in range(reps):
        bottleneck = bn_rl_conv(features, filters=4 * k, kernel_size=1)
        new_maps = bn_rl_conv(bottleneck, filters=k, kernel_size=3)
        features = Concatenate()([features, new_maps])
    return features

def transition_layer(x, theta):
    """Compress the channel count by ``theta`` with a 1x1 BN-ReLU-Conv, then average-pool 2x2 with stride 2.

    Parameters:
    x -- Input tensor
    theta -- Factor applied to the input's last-axis size to get the conv's filter count
             (the product is truncated to an int)

    Returns:
    The pooled output tensor
    """
    in_channels = tf.keras.backend.int_shape(x)[-1]
    compressed = bn_rl_conv(x, filters=int(in_channels * theta), kernel_size=1)
    pool = AvgPool2D(pool_size=2, strides=2, padding='same')
    return pool(compressed)

def create_densenet121(input_shape, num_classes, k=32, theta=0.5):
    """Build the DenseNet-121 style classifier used by the experiments.

    The network is a 7x7 stride-2 convolution stem (BN, ReLU, 3x3 stride-2
    max-pool), four dense blocks of 6, 12, 24 and 16 units, global average
    pooling and a softmax Dense layer. A transition layer follows every dense
    block, including the last one.

    Parameters:
    input_shape -- Shape of one input sample (without the batch axis)
    num_classes -- Number of units in the softmax output layer
    k -- Growth rate; the stem uses ``2 * k`` filters
    theta -- Channel compression factor of the transition layers

    Returns:
    An uncompiled Keras Model
    """
    block_sizes = (6, 12, 24, 16)

    inputs = Input(shape=input_shape)

    # Stem
    stem = Conv2D(2 * k, 7, strides=2, padding='same')(inputs)
    stem = BatchNormalization()(stem)
    stem = ReLU()(stem)
    features = MaxPool2D(3, strides=2, padding='same')(stem)

    # Dense block + transition pairs
    for unit_count in block_sizes:
        features = dense_block(features, k, unit_count)
        features = transition_layer(features, theta)

    # Classification head
    pooled = GlobalAvgPool2D()(features)
    outputs = Dense(num_classes, activation='softmax')(pooled)

    return Model(inputs, outputs)


# Create a custom callback to enforce the mask during training
# Modified MaskWeightsCallback with pre-mask zero indices tracking
class MaskWeightsCallback(tf.keras.callbacks.Callback):
    """Re-apply pruning masks after every training batch and record verification files.

    After each batch the kernel (first weight array) of every tracked layer is
    multiplied by its mask, forcing pruned weights back to zero. For batches
    1 .. ``max_batches_to_verify - 1`` the callback additionally writes into
    ``output_dir``:

    * ``pre-mask-zero-indices_batch-N.txt/.csv`` -- zero weights found before
      the mask is re-applied,
    * ``prune-indices_batch-N.txt`` -- whether every pruned weight is zero
      after the mask is re-applied,
    * ``all-zero-indices_batch-N.txt/.csv`` -- every zero weight after the
      mask is re-applied.

    ``all-zero-indices_batch-0`` is written when training starts.
    """

    def __init__(self, masks, layer_indices, pruned_indices_per_layer, all_pruned_flat_indices, output_dir):
        """
        Parameters:
        masks -- One 0/1 array per tracked layer, same shape as that layer's kernel
        layer_indices -- Index into ``model.layers`` for each tracked layer
        pruned_indices_per_layer -- Dict: position in ``layer_indices`` -> sorted
                                    flattened kernel indices that were pruned
        all_pruned_flat_indices -- Sorted pruned indices in the flat numbering
                                   that concatenates all tracked kernels
        output_dir -- Existing directory that receives the verification files
        """
        super().__init__()
        self.masks = masks
        self.layer_indices = layer_indices
        self.pruned_indices_per_layer = pruned_indices_per_layer
        self.all_pruned_flat_indices = all_pruned_flat_indices
        self.output_dir = output_dir
        self.batch_count = 0  # Number of training batches completed so far
        self.max_batches_to_verify = 4  # Verification files cover batches 1, 2 and 3

    def on_train_begin(self, logs=None):
        # The model is attached by the time training starts, so the state
        # right after the initial pruning can be recorded as "batch 0".
        self.save_all_zero_indices(0)

    def on_batch_end(self, batch, logs=None):
        """Record verification data (early batches only) and re-apply the masks."""
        # Count first so the batch that just finished is numbered from 1 and
        # never overwrites the batch-0 files written in on_train_begin.
        self.batch_count += 1
        verifying = self.batch_count < self.max_batches_to_verify

        if verifying:
            # Snapshot what the optimizer step left behind before it is masked.
            self.save_pre_mask_zero_indices(self.batch_count)

        for i, layer_idx in enumerate(self.layer_indices):
            layer = self.model.layers[layer_idx]
            weights = layer.get_weights()

            if verifying:
                pruned = self.pruned_indices_per_layer.get(i, [])
                if len(pruned) > 0:
                    pruned_values = weights[0].ravel()[np.asarray(pruned, dtype=np.intp)]
                    regrown = int(np.count_nonzero(pruned_values))
                    if regrown:
                        print(f"WARNING: Batch {self.batch_count}, Layer {i}: {regrown} pruned weights became non-zero")

            # Apply mask (set pruned weights back to zero)
            weights[0] = weights[0] * self.masks[i]
            layer.set_weights(weights)

        if verifying:
            self.verify_pruning(self.batch_count)
            self.save_all_zero_indices(self.batch_count)

    def _collect_zero_indices(self):
        """Return (per-layer zero indices, model-wide flat zero indices) of the tracked kernels.

        Both results hold plain Python ints in ascending order. The flat
        numbering offsets each layer by the total size of the kernels of the
        tracked layers that precede it.
        """
        zero_per_layer = {}
        flat_chunks = []
        offset = 0

        for i, layer_idx in enumerate(self.layer_indices):
            kernel = self.model.layers[layer_idx].get_weights()[0]
            zero_positions = np.flatnonzero(kernel.ravel() == 0)
            zero_per_layer[i] = zero_positions.tolist()
            flat_chunks.append(zero_positions + offset)
            offset += kernel.size

        # tolist() yields Python ints, so the lists print as plain numbers
        # independent of how NumPy formats its scalar types.
        zero_flat = np.concatenate(flat_chunks).tolist() if flat_chunks else []
        return zero_per_layer, zero_flat

    @staticmethod
    def _write_index_listing(f, indices_per_layer, label):
        """Write one summary line per layer plus its indices (first/last 10 when more than 20)."""
        for layer_pos, indices in indices_per_layer.items():
            f.write(f"Layer index {layer_pos}: {len(indices)} {label}\n")
            if len(indices) <= 20:
                f.write(f"  {indices}\n")
            else:
                f.write(f"  First 10: {indices[:10]}\n")
                f.write(f"  Last 10: {indices[-10:]}\n")

    def verify_pruning(self, batch_num):
        """Check that every pruned weight is zero and write ``prune-indices_batch-N.txt``.

        Intended to be called after the masks have been re-applied. The file
        lists, per layer, the pruned indices that are still zero and ends with
        either the line "All pruned indices are still zero" or a
        "VERIFICATION FAILED" line followed by the offending indices.

        Returns:
        True when all pruned weights are zero, False otherwise
        """
        still_zero_per_layer = {}
        regrown_per_layer = {}

        for i, pruned_list in self.pruned_indices_per_layer.items():
            pruned = np.asarray(pruned_list, dtype=np.intp)
            kernel = self.model.layers[self.layer_indices[i]].get_weights()[0].ravel()
            is_zero = kernel[pruned] == 0
            still_zero_per_layer[i] = pruned[is_zero].tolist()
            regrown = pruned[~is_zero].tolist()
            if regrown:
                regrown_per_layer[i] = regrown

        total_regrown = sum(len(indices) for indices in regrown_per_layer.values())
        all_still_zero = total_regrown == 0

        with open(os.path.join(self.output_dir, f"prune-indices_batch-{batch_num}.txt"), 'w') as f:
            f.write(f"Batch {batch_num} - Pruned Weight Verification (after mask application)\n")
            f.write(f"Total pruned: {len(self.all_pruned_flat_indices)}\n\n")

            f.write("Pruned indices per layer:\n")
            self._write_index_listing(f, still_zero_per_layer, "pruned indices")

            f.write("\nAll pruned flat indices:\n")
            f.write(f"{self.all_pruned_flat_indices}\n")

            f.write("\nVerification result:\n")
            if all_still_zero:
                f.write("All pruned indices are still zero\n")
            else:
                f.write(f"VERIFICATION FAILED: {total_regrown} pruned indices are non-zero\n")
                self._write_index_listing(f, regrown_per_layer, "non-zero pruned indices")

        return all_still_zero

    def save_pre_mask_zero_indices(self, batch_num):
        """Save all indices that are zero BEFORE applying the mask"""
        zero_per_layer, zero_flat = self._collect_zero_indices()

        # Relate the current zeros to the originally pruned set.
        original_pruned_set = set(self.all_pruned_flat_indices)
        current_zero_set = set(zero_flat)
        naturally_zero = current_zero_set - original_pruned_set
        still_pruned = current_zero_set & original_pruned_set
        no_longer_zero = original_pruned_set - current_zero_set

        with open(os.path.join(self.output_dir, f"pre-mask-zero-indices_batch-{batch_num}.txt"), 'w') as f:
            f.write(f"Batch {batch_num} - All Zero Weight Indices BEFORE Mask Application\n")
            f.write("=========================================================\n\n")

            f.write(f"Total zero weights before mask: {len(zero_flat)}\n\n")

            f.write("Zero indices per layer before mask:\n")
            self._write_index_listing(f, zero_per_layer, "zero indices")

            f.write("\nAll zero flat indices before mask:\n")
            f.write(f"{zero_flat}\n")

            f.write("\nSummary statistics:\n")
            f.write(f"Originally pruned weights: {len(self.all_pruned_flat_indices)}\n")
            f.write(f"Currently zero weights before mask: {len(zero_flat)}\n")
            f.write(f"Originally pruned weights that remained zero naturally: {len(still_pruned)}\n")

            f.write(f"Originally pruned weights that became non-zero (will be re-zeroed by mask): {len(no_longer_zero)}\n")
            f.write(f"Weights that became zero naturally during training: {len(naturally_zero)}\n")

        # Also save the full list to a separate CSV file for easier processing
        csv_path = os.path.join(self.output_dir, f"pre-mask-zero-indices_batch-{batch_num}.csv")
        with open(csv_path, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['flat_index', 'is_originally_pruned'])  # Header
            writer.writerows(
                [idx, 1 if idx in original_pruned_set else 0] for idx in zero_flat
            )

    def save_all_zero_indices(self, batch_num):
        """Save all indices that are currently zero to a separate file"""
        zero_per_layer, zero_flat = self._collect_zero_indices()
        total_zero_indices = len(zero_flat)

        with open(os.path.join(self.output_dir, f"all-zero-indices_batch-{batch_num}.txt"), 'w') as f:
            f.write(f"Batch {batch_num} - All Zero Weight Indices\n")
            f.write("===================================\n\n")

            f.write(f"Total zero weights: {total_zero_indices}\n\n")

            f.write("Zero indices per layer:\n")
            self._write_index_listing(f, zero_per_layer, "zero indices")

            f.write("\nAll zero flat indices:\n")
            f.write(f"{zero_flat}\n")

            orig_pruned_count = len(self.all_pruned_flat_indices)
            # Never report a negative surplus when fewer weights are zero
            # than were originally pruned.
            new_zero_count = max(total_zero_indices - orig_pruned_count, 0)

            f.write("\nSummary statistics:\n")
            f.write(f"Originally pruned weights: {orig_pruned_count}\n")
            f.write(f"Total zero weights: {total_zero_indices}\n")
            f.write(f"Additional zero weights (from training): {new_zero_count}\n")

        # Also save the full list to a separate CSV file for easier processing
        csv_path = os.path.join(self.output_dir, f"all-zero-indices_batch-{batch_num}.csv")
        with open(csv_path, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['flat_index'])  # Header
            writer.writerows([idx] for idx in zero_flat)
# Enhanced Parameter Mask Pruning function with verification
def parameter_mask_pruning(model, prune_percentage, seed=None, output_dir="pruning_verification"):
    """
    Randomly select a percentage of weights across the entire network and set them to zero.

    Only the kernels (first weight array) of Conv2D and Dense layers are
    considered; biases are never pruned. The selected indices are written to
    ``prune-indices_batch-0.txt`` and ``pruned_indices.csv`` in ``output_dir``
    and a MaskWeightsCallback holding the masks is attached to the model as
    ``model.mask_callback``.

    Parameters:
    model -- The Keras model to prune (modified in place)
    prune_percentage -- Percentage of weights to prune (0-100). 0 prunes
                        nothing; a positive value below 0.0002 selects the
                        "tiny test" mode that prunes exactly 10 weights (or
                        all of them if the model has fewer than 10)
    seed -- Random seed for reproducibility; None draws from NumPy's global
            random state
    output_dir -- Directory to save verification files (created if missing)

    Returns:
    The pruned model with a custom weight mask

    Raises:
    ValueError -- if ``prune_percentage`` is outside the range 0-100
    """
    if not 0 <= prune_percentage <= 100:
        raise ValueError(f"prune_percentage must be between 0 and 100, got {prune_percentage}")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # A private generator yields the same draws as seeding the global one,
    # without disturbing NumPy's global random state for the rest of the program.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # Collect the kernels of all prunable layers
    all_weights = []
    all_layer_indices = []
    layer_info = []  # Store layer name and shape for reference

    for i, layer in enumerate(model.layers):
        if isinstance(layer, (Conv2D, Dense)) and len(layer.weights) > 0:
            # Only consider weight matrices, not biases
            weight = layer.get_weights()[0]
            all_weights.append(weight)
            all_layer_indices.append(i)
            layer_info.append(f"Layer {i} ({layer.name}): {weight.shape}")

    # Count total parameters
    total_params = sum(w.size for w in all_weights)

    if prune_percentage == 0:
        # Baseline run: report zero pruned weights rather than the tiny-test count.
        num_to_prune = 0
        print(f"No pruning requested (0% of {total_params:,} parameters)")
    elif prune_percentage < 0.0002:
        # Tiny test mode: a fixed handful of weights that is easy to inspect
        # by hand, clamped so very small models can still be sampled.
        num_to_prune = min(10, total_params)
        print(f"Test mode: Pruning exactly {num_to_prune} parameters (approx. {prune_percentage:.10f}% of {total_params:,} parameters)")
    else:
        num_to_prune = int(total_params * prune_percentage / 100)
        print(f"Pruning {num_to_prune:,} parameters ({prune_percentage:.2f}% of {total_params:,} parameters)")

    print(f"Total trainable parameters: {total_params:,}")
    print(f"Parameters to prune: {num_to_prune:,} ({prune_percentage}%)")

    # Create masks for each layer (initially all ones)
    masks = [np.ones_like(w) for w in all_weights]

    # Dictionary to track which indices were set to zero in each layer
    pruned_indices_per_layer = {}

    # Store flat indices for verification
    all_pruned_flat_indices = []

    # If pruning percentage is not 0, create masks and apply them
    if prune_percentage > 0:
        # Randomly select indices to prune across all parameters
        flat_indices = np.sort(rng.choice(total_params, size=num_to_prune, replace=False))
        all_pruned_flat_indices = flat_indices.tolist()

        # bounds[i] is the flat index at which layer i starts. Because
        # flat_indices is sorted, each layer owns one contiguous slice of it.
        bounds = np.cumsum([0] + [w.size for w in all_weights])
        cuts = np.searchsorted(flat_indices, bounds)

        for i, weight in enumerate(all_weights):
            indices_in_layer = flat_indices[cuts[i]:cuts[i + 1]] - bounds[i]

            # Store these indices for verification
            pruned_indices_per_layer[i] = indices_in_layer.tolist()

            # Flatten the mask, set the selected indices to zero, then reshape back
            flat_mask = masks[i].flatten()
            flat_mask[indices_in_layer] = 0
            masks[i] = flat_mask.reshape(weight.shape)

        # Apply masks to each layer's weights
        for i, layer_idx in enumerate(all_layer_indices):
            layer = model.layers[layer_idx]
            weights = layer.get_weights()
            weights[0] = weights[0] * masks[i]  # Apply mask to weights
            layer.set_weights(weights)

    # Write pruned indices to file
    with open(os.path.join(output_dir, "prune-indices_batch-0.txt"), 'w') as f:
        f.write(f"Total parameters: {total_params}\n")
        f.write(f"Total pruned: {num_to_prune}\n\n")
        f.write("Layer information:\n")
        for info in layer_info:
            f.write(f"{info}\n")
        f.write("\nPruned indices per layer:\n")
        for layer_idx, indices in pruned_indices_per_layer.items():
            f.write(f"Layer index {layer_idx}: {len(indices)} pruned indices\n")
            if len(indices) <= 20:  # Only print all indices if there are few
                f.write(f"  {indices}\n")
            else:
                f.write(f"  First 10: {indices[:10]}\n")
                f.write(f"  Last 10: {indices[-10:]}\n")

        f.write("\nAll pruned flat indices:\n")
        # Print all pruned indices, not just first/last 10
        f.write(f"{all_pruned_flat_indices}\n")

    # Also save the full list to a separate CSV file for easier processing
    csv_path = os.path.join(output_dir, "pruned_indices.csv")
    with open(csv_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['flat_index'])  # Header
        writer.writerows([idx] for idx in all_pruned_flat_indices)

    # Attach the mask callback to the model for later use
    model.mask_callback = MaskWeightsCallback(
        masks,
        all_layer_indices,
        pruned_indices_per_layer,
        all_pruned_flat_indices,
        output_dir
    )

    return model

# Simple logger class for training metrics
class SimpleLogger(tf.keras.callbacks.Callback):
    """Evaluate the model on the test set after every training batch and log the results.

    Each batch appends one row (epoch, running batch number, train metrics,
    test accuracy, test cross-entropy, |test CE - baseline CE|, IPA) to
    ``training_log.csv`` when an output directory is given. IPA is
    ``ce_diff / (BATCH_SIZE_TRAIN * batches seen)`` and is only computed while
    the test cross-entropy is at or below ``ce_threshold``.
    """

    def __init__(self, test_data, batch_size, ce_threshold=None, baseline_ce=None, output_dir=None):
        """
        Initialize the Simple Logger

        Parameters:
        test_data -- tuple of (x_test, y_test)
        batch_size -- batch size for evaluation
        ce_threshold -- Threshold for CE (ln(10) or ln(100) depending on dataset)
        baseline_ce -- baseline cross-entropy from unpruned model; when None
                       it is measured on the test set when training begins
        output_dir -- directory to save logs (None disables file logging)
        """
        super().__init__()
        self.test_x, self.test_y = test_data
        self.batch_size = batch_size
        self.total_batches = 0  # Using total_batches for continuous numbering across epochs
        self.epoch = 0
        self.baseline_ce = baseline_ce
        self.ce_threshold = ce_threshold  # Store the CE threshold (ln(10) or ln(100))
        self.ipa_start_batch = None  # Track when CE first goes below threshold
        self.output_dir = output_dir

        # Create log file
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            self.log_file = os.path.join(output_dir, "training_log.csv")
            with open(self.log_file, 'w') as f:
                f.write("Epoch,Batch,Train_Accuracy,Train_CE,Test_Accuracy,Test_CE,CE_Diff,IPA\n")
        else:
            self.log_file = None

    def _evaluate_test(self, verbose=0):
        """Return (test accuracy, test cross-entropy) of the current model.

        Expects ``model.evaluate`` to return three values whose second and
        third entries are accuracy and cross-entropy.
        """
        _, test_acc, test_ce = self.model.evaluate(
            self.test_x, self.test_y,
            batch_size=self.batch_size,
            verbose=verbose
        )
        return test_acc, test_ce

    def _ce_below_threshold(self, test_ce):
        """True when a threshold is configured and ``test_ce`` does not exceed it."""
        return self.ce_threshold is not None and test_ce <= self.ce_threshold

    def _compute_ipa(self, ce_diff):
        """Return ce_diff per training sample seen so far (0 before any batch has run)."""
        if self.total_batches <= 0:
            return 0
        return ce_diff / (BATCH_SIZE_TRAIN * self.total_batches)

    def on_train_begin(self, logs=None):
        # If baseline CE is not provided, calculate it
        if self.baseline_ce is None:
            print("Calculating baseline CE...")
            _, self.baseline_ce = self._evaluate_test(verbose=1)
            print(f"Baseline CE (CEo): {self.baseline_ce}")

    def on_epoch_begin(self, epoch, logs=None):
        # Epochs are reported 1-based; total_batches is deliberately not
        # reset so batch numbers keep increasing across epochs.
        self.epoch = epoch + 1

    def on_train_batch_end(self, batch, logs=None):
        # `logs` defaults to None, so make the .get() lookups below safe.
        logs = logs or {}
        self.total_batches += 1

        # Calculate metrics on test data (silently)
        test_acc, test_ce = self._evaluate_test()

        # Calculate CE difference
        ce_diff = abs(test_ce - self.baseline_ce)
        below_threshold = self._ce_below_threshold(test_ce)

        # Remember the first batch at which CE reached the threshold
        if below_threshold and self.ipa_start_batch is None:
            self.ipa_start_batch = self.total_batches
            print(f"CE dropped below threshold ({self.ce_threshold:.4f}) at batch {self.ipa_start_batch} (CE = {test_ce:.4f})")

        # IPA uses the running batch number (not adjusted by ipa_start_batch);
        # it is reported as "N/A" while CE is above the threshold.
        ipa_value = self._compute_ipa(ce_diff) if below_threshold else "N/A"

        train_acc = logs.get('accuracy', 0)
        train_loss = logs.get('loss', 0)

        # Log to file if available
        if self.log_file:
            with open(self.log_file, 'a') as f:
                f.write(f"{self.epoch},{self.total_batches},{train_acc},{train_loss},{test_acc},{test_ce},{ce_diff},{ipa_value}\n")

        # Print progress
        if self.total_batches % 10 == 0:
            print(f"Batch {self.total_batches}: Train Acc={train_acc:.4f}, Test Acc={test_acc:.4f}, Test CE={test_ce:.4f}")

    def get_final_metrics(self):
        """Return the final IPA and test accuracy

        Returns a dict with 'Final Test Accuracy', 'IPA' (scaled by 1000, or 0
        when the test CE is above the threshold) and 'CE Below Threshold'.
        """
        test_acc, test_ce = self._evaluate_test()
        ce_diff = abs(test_ce - self.baseline_ce)

        if self._ce_below_threshold(test_ce):
            return {
                'Final Test Accuracy': test_acc,
                'IPA': self._compute_ipa(ce_diff) * 1000,  # Scale by 1000
                'CE Below Threshold': True
            }

        return {
            'Final Test Accuracy': test_acc,
            'IPA': 0,  # IPA is 0 if CE is above threshold
            'CE Below Threshold': False
        }

# Function to run a single experiment with a specific pruning percentage and seed
def _check_pruned_batches(verification_dir, report, batches=(1, 2, 3)):
    """Scan the per-batch prune-indices files and record the verdict in ``report``.

    Returns "passed" when every file contains the success marker, "failed"
    when at least one does not, and "pending" when the files cannot be read.
    """
    marker = "All pruned indices are still zero"
    failed_batches = []
    try:
        for batch in batches:
            batch_path = os.path.join(verification_dir, f"prune-indices_batch-{batch}.txt")
            with open(batch_path, 'r', encoding='utf-8') as batch_file:
                if marker not in batch_file.read():
                    failed_batches.append(batch)
    except OSError as e:
        report.write(f"Error during verification: {str(e)}\n")
        print(f"Error during verification: {str(e)}")
        return "pending"

    if not failed_batches:
        report.write("VERIFICATION SUCCESSFUL: All pruned indices remained zero across all batches.\n")
        print("VERIFICATION SUCCESSFUL: All pruned indices remained zero across all batches.")
        return "passed"

    for batch in failed_batches:
        report.write(f"VERIFICATION FAILED: Some pruned indices did not remain zero in batch {batch}.\n")
    print("VERIFICATION FAILED: Some pruned indices did not remain zero. Check log files.")
    return "failed"


def _report_zero_weight_counts(verification_dir, report, batches=(0, 1, 2, 3)):
    """Write the zero-weight count of each available batch file and the change between batches."""
    try:
        zero_counts = {}
        for batch in batches:
            zero_file = os.path.join(verification_dir, f"all-zero-indices_batch-{batch}.txt")
            if not os.path.exists(zero_file):
                continue
            with open(zero_file, 'r', encoding='utf-8') as zf:
                for line in zf:
                    if "Total zero weights:" in line:
                        count = int(line.split(":")[1].strip())
                        zero_counts[batch] = count
                        report.write(f"  Batch {batch}: {count} zero weights\n")
                        break

        # Calculate changes between batches
        if len(zero_counts) > 1:
            report.write("\nChanges in zero weights between batches:\n")
            ordered = sorted(zero_counts)
            for curr_batch, next_batch in zip(ordered, ordered[1:]):
                diff = zero_counts[next_batch] - zero_counts[curr_batch]
                if diff > 0:
                    report.write(f"  Batch {curr_batch} → {next_batch}: +{diff} weights became zero\n")
                elif diff < 0:
                    report.write(f"  Batch {curr_batch} → {next_batch}: {diff} zero weights became non-zero (unexpected!)\n")
                else:
                    report.write(f"  Batch {curr_batch} → {next_batch}: No change in zero weights\n")
    except (OSError, ValueError) as e:
        report.write(f"Error analyzing zero weights: {str(e)}\n")


def run_single_experiment(dataset_name, pruning_method, prune_percentage, run_number, baseline_ce,
                         x_train, y_train, x_test, y_test, num_classes, output_dir="pruning_verification"):
    """
    Run a single experiment with a specific pruning percentage and seed.

    Builds a fresh DenseNet-121, prunes it with parameter_mask_pruning(),
    trains it for EPOCHS_PER_RUN epochs with the mask and logging callbacks,
    and writes ``verification_summary.txt`` into the run's directory.

    Parameters:
    dataset_name -- Name of the dataset ('cifar100' selects the ln(100) CE
                    threshold, anything else ln(10))
    pruning_method -- Pruning method label (used for reporting only)
    prune_percentage -- Percentage to prune
    run_number -- Run number (the seed is 42 + run_number)
    baseline_ce -- Baseline cross-entropy
    x_train, y_train, x_test, y_test -- Training and test data
    num_classes -- Number of classes
    output_dir -- Directory to save verification files

    Returns:
    Dictionary with results. 'Verification' is "passed" or "failed" when the
    per-batch verification files could be read and "pending" otherwise.
    """
    # Get directory path (either on Google Drive or locally)
    verification_dir = get_output_dir(os.path.join(output_dir, f"P{prune_percentage}_run{run_number}"))

    print(f"\nRunning experiment: {dataset_name}, {pruning_method}, P% = {prune_percentage}, Run #{run_number}")
    print(f"Verification files will be saved to: {verification_dir}")

    # Use run number as seed for reproducibility
    seed = 42 + run_number
    tf.random.set_seed(seed)

    # Define CE threshold based on dataset
    if dataset_name == 'cifar100':
        ce_threshold = math.log(100)  # ln(100) for CIFAR-100
        print(f"Using CE threshold ln(100) = {ce_threshold:.4f} for CIFAR-100")
    else:
        ce_threshold = math.log(10)  # ln(10) for other datasets
        print(f"Using CE threshold ln(10) = {ce_threshold:.4f} for {dataset_name}")

    # Create model
    model = create_densenet121(input_shape=x_train.shape[1:], num_classes=num_classes)

    # Apply parameter mask pruning with seed
    model = parameter_mask_pruning(model, prune_percentage, seed=seed, output_dir=verification_dir)

    # Store the mask callback for use during training
    mask_callback = model.mask_callback

    # Compile model
    optimizer = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy', 'categorical_crossentropy']
    )

    # Setup logger
    simple_logger = SimpleLogger(
        test_data=(x_test, y_test),
        batch_size=BATCH_SIZE_TEST,
        ce_threshold=ce_threshold,
        baseline_ce=baseline_ce,
        output_dir=verification_dir
    )

    # Callbacks run in list order: the mask must be re-applied before the
    # logger evaluates, otherwise the per-batch test metrics are measured on
    # weights where pruned entries may have been updated by the optimizer.
    callbacks = [mask_callback, simple_logger]

    # Train for specified number of epochs
    model.fit(
        x_train, y_train,
        batch_size=BATCH_SIZE_TRAIN,
        validation_batch_size=BATCH_SIZE_TEST,
        epochs=EPOCHS_PER_RUN,
        validation_data=(x_test, y_test),
        callbacks=callbacks,
        verbose=0  # Less verbose output
    )

    # Get final metrics
    final_metrics = simple_logger.get_final_metrics()

    # Clear memory
    tf.keras.backend.clear_session()

    # Create a report file to summarize verification results.
    # UTF-8 is requested explicitly because the report contains "→".
    with open(os.path.join(verification_dir, "verification_summary.txt"), 'w', encoding='utf-8') as f:
        f.write("Pruning Verification Summary\n")
        f.write("==========================\n")
        f.write(f"Dataset: {dataset_name}\n")
        f.write(f"Pruning Method: {pruning_method}\n")
        f.write(f"Pruning Percentage: {prune_percentage}%\n")
        f.write(f"Run Number: {run_number}\n\n")

        f.write("Verification Results:\n")
        verification_status = _check_pruned_batches(verification_dir, f)

        # Add analysis of all zero indices
        f.write("\nAnalysis of All Zero Weights:\n")
        _report_zero_weight_counts(verification_dir, f)

    result = {
        'Pruning Percentage': prune_percentage,
        'Run Number': run_number,
        'Final Test Accuracy': final_metrics['Final Test Accuracy'],
        'IPA': final_metrics['IPA'],
        'CE Below Threshold': final_metrics['CE Below Threshold'],
        'Verification': verification_status
    }

    threshold_status = "CE below threshold" if final_metrics['CE Below Threshold'] else "CE above threshold (IPA set to 0)"
    print(f"P{prune_percentage} Run #{run_number} - Test Accuracy: {result['Final Test Accuracy']:.4f}, IPA × 1000: {result['IPA']:.6f}, {threshold_status}")

    return result

# Script to extract and compare pruned indices from files
def compare_pruned_indices_files(dir_path, batch_numbers=None):
    """
    Check the per-batch verification files of one run and summarize the results.

    The first entry of batch_numbers is the reference batch; its
    "prune-indices_batch-<n>.txt" file must exist. For every later batch the
    "prune-indices_batch-<n>.txt" file is scanned for the success message
    "All pruned indices are still zero" or for lines containing "ERROR:".
    In addition, the "Total zero weights: <count>" value is read from every
    existing "all-zero-indices_batch-<n>.txt" file and the change between
    consecutive batches is printed.

    Parameters:
    dir_path -- Directory containing the pruned indices files
    batch_numbers -- Batch numbers to check (reference batch first);
                     defaults to [0, 1, 2, 3]

    Returns:
    Dictionary with comparison results:
      'all_matching' -- True only if every non-reference batch reported success
      'comparisons' -- one dict per non-reference batch with 'batch', 'matching'
                       and one of 'message', 'error' or 'errors'
      'zero_indices_analysis' -- {batch: {'total_zero_weights': count}}
    If batch_numbers is empty or the reference file is missing, only
    {'all_matching': False, 'error': <description>} is returned.
    """
    import re

    # A None sentinel avoids sharing one mutable default list between calls.
    if batch_numbers is None:
        batch_numbers = [0, 1, 2, 3]
    else:
        batch_numbers = list(batch_numbers)

    print(f"\nComparing pruned indices in {dir_path}")

    # Without at least one batch there is no reference file to look for.
    if not batch_numbers:
        print("No batch numbers provided!")
        return {'all_matching': False, 'error': "No batch numbers provided"}

    # The reference file (first batch) must exist before anything is compared.
    ref_path = os.path.join(dir_path, f"prune-indices_batch-{batch_numbers[0]}.txt")
    if not os.path.exists(ref_path):
        print(f"Reference file {ref_path} not found!")
        return {'all_matching': False, 'error': f"Reference file not found: {ref_path}"}

    results = {
        'all_matching': True,
        'comparisons': [],
        'zero_indices_analysis': {}
    }

    # Collect the total number of zero weights reported for each batch.
    for batch in batch_numbers:
        zero_path = os.path.join(dir_path, f"all-zero-indices_batch-{batch}.txt")
        if not os.path.exists(zero_path):
            continue

        print(f"Analyzing all zero indices for batch {batch}...")
        with open(zero_path, 'r') as f:
            zero_content = f.read()

        total_zeros_match = re.search(r"Total zero weights: (\d+)", zero_content)
        if total_zeros_match:
            total_zeros = int(total_zeros_match.group(1))
            results['zero_indices_analysis'][batch] = {
                'total_zero_weights': total_zeros
            }
            print(f"  Batch {batch}: {total_zeros} total zero weights")

    # Inspect the verification outcome recorded in every non-reference batch file.
    for batch in batch_numbers[1:]:
        batch_path = os.path.join(dir_path, f"prune-indices_batch-{batch}.txt")
        if not os.path.exists(batch_path):
            print(f"Batch file {batch_path} not found!")
            results['all_matching'] = False
            results['comparisons'].append({
                'batch': batch,
                'matching': False,
                'error': f"File not found: {batch_path}"
            })
            continue

        with open(batch_path, 'r') as f:
            batch_content = f.read()

        if "All pruned indices are still zero" in batch_content:
            print(f"Batch {batch}: All pruned indices are still zero - VERIFICATION PASSED")
            results['comparisons'].append({
                'batch': batch,
                'matching': True,
                'message': "All pruned indices are still zero"
            })
            continue

        # No success message: anything other than success counts as a failure.
        results['all_matching'] = False
        error_lines = [line for line in batch_content.split('\n') if 'ERROR:' in line]
        if error_lines:
            print(f"Batch {batch}: Verification FAILED - some pruned indices are no longer zero")
            for line in error_lines:
                print(f"  {line}")
            results['comparisons'].append({
                'batch': batch,
                'matching': False,
                'errors': error_lines
            })
        else:
            print(f"Batch {batch}: Inconclusive - no clear success or error messages")
            results['comparisons'].append({
                'batch': batch,
                'matching': False,
                'message': "Verification inconclusive"
            })

    # Report how the total zero-weight count moves between consecutive batches.
    zero_analysis = results['zero_indices_analysis']
    if len(zero_analysis) > 1:
        print("\nZero indices comparison across batches:")
        batches = sorted(zero_analysis)
        for current, next_batch in zip(batches, batches[1:]):
            diff = (zero_analysis[next_batch]['total_zero_weights']
                    - zero_analysis[current]['total_zero_weights'])

            if diff > 0:
                print(f"  Batch {current} → {next_batch}: +{diff} new zero weights")
            elif diff < 0:
                print(f"  Batch {current} → {next_batch}: {diff} fewer zero weights (unexpected!)")
            else:
                print(f"  Batch {current} → {next_batch}: No change in zero weights")

    return results

# Additional utility to create Linux shell script for batch comparison
def create_comparison_script(output_dir):
    """
    Write "compare_indices.sh", a bash helper for checking the verification files.

    The generated script changes into the hard-coded tiny and medium run
    directories (relative to the directory it is started from), diffs/greps the
    "prune-indices_batch-<n>.txt" files and prints the "Total zero weights:" line
    of every "all-zero-indices_batch-<n>.txt" file. The script is made executable
    (mode 0o755).

    Parameters:
    output_dir -- Directory where verification tests were run
    """
    tiny_run_dir = "tiny_pruning/P0.00013514851_run1"
    medium_run_dir = "medium_pruning/P30_run1"
    later_batches = (1, 2, 3)
    all_batches = (0, 1, 2, 3)

    script_lines = [
        "#!/bin/bash\n\n",
        "# Script to compare pruned indices files for DenseNet pruning verification\n\n",
    ]

    # Tiny pruning: side-by-side diff of batch 0 against each later batch
    script_lines.append("echo '===== Comparing Tiny Pruning (10 parameters) ====='\n")
    script_lines.append(f"cd {tiny_run_dir}\n")
    for batch in later_batches:
        script_lines.append(f"echo 'Comparing batch 0 with batch {batch}:'\n")
        script_lines.append(
            f"diff -y prune-indices_batch-0.txt prune-indices_batch-{batch}.txt"
            " | grep -v 'NOTE:' | grep 'ERROR\\|missing'\n"
        )
    script_lines.append("cd ../..\n\n")

    # Medium pruning: only look for the recorded success / error messages
    script_lines.append("echo '===== Comparing Medium Pruning (30%) ====='\n")
    script_lines.append(f"cd {medium_run_dir}\n")
    for batch in later_batches:
        script_lines.append(f"echo 'Checking batch {batch} verification result:'\n")
        script_lines.append(
            f"grep 'All pruned indices are still zero\\|ERROR' prune-indices_batch-{batch}.txt\n"
        )
    script_lines.append("cd ../..\n\n")

    # Zero-weight totals for both runs; the last section ends without a blank line
    script_lines.append("echo '===== Analyzing All Zero Weights ====='\n")
    zero_sections = (
        ("Tiny Pruning", tiny_run_dir, "cd ../..\n\n"),
        ("Medium Pruning", medium_run_dir, "cd ../..\n"),
    )
    for label, run_dir, closing in zero_sections:
        script_lines.append(f"echo '{label} - Zero Weight Counts:'\n")
        script_lines.append(f"cd {run_dir}\n")
        for batch in all_batches:
            script_lines.append(f"echo 'Batch {batch}:'\n")
            script_lines.append(f"grep 'Total zero weights:' all-zero-indices_batch-{batch}.txt\n")
        script_lines.append(closing)

    script_path = os.path.join(output_dir, "compare_indices.sh")
    with open(script_path, 'w') as f:
        f.writelines(script_lines)

    # Make the script executable
    os.chmod(script_path, 0o755)

    print(f"Created comparison shell script: {script_path}")
    if GOOGLE_DRIVE_MOUNTED:
        print("This script is available in your Google Drive for download and use on a Linux machine.")
    else:
        print("Upload this script along with the verification directories to a Linux machine")
        print("and run it to quickly check if pruning is working correctly.")

# Function to run the pruning verification tests
def run_pruning_verification():
    """
    Run the two focused pruning experiments and write a plain-text summary.

    Creates a timestamped output directory with "tiny_pruning" and
    "medium_pruning" subdirectories, evaluates an untrained baseline model to get
    the baseline cross-entropy, runs one 'parameter_mask' experiment per
    subdirectory, writes "verification_summary.txt" and finally generates the
    comparison shell script.

    Returns:
    The base output directory of this verification run
    """
    print("Running DenseNet-121 pruning verification tests")

    # Timestamped base directory plus one subdirectory per experiment
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_name = f"pruning_verification_{run_stamp}"
    output_dir = get_output_dir(base_name)

    print(f"Creating base output directory: {output_dir}")

    tiny_pruning_dir = get_output_dir(os.path.join(base_name, "tiny_pruning"))
    medium_pruning_dir = get_output_dir(os.path.join(base_name, "medium_pruning"))

    # MNIST keeps the verification run quick
    dataset_name = 'mnist'
    print(f"\nLoading {dataset_name} dataset for verification...")
    (x_train, y_train), (x_test, y_test), num_classes = load_dataset(dataset_name)

    # Work on a small slice of the data for faster testing
    x_train, y_train = x_train[:1000], y_train[:1000]
    x_test, y_test = x_test[:100], y_test[:100]

    # The baseline cross-entropy is computed once and shared by both experiments
    print("\nCreating baseline model to calculate baseline CE...")
    baseline_model = create_densenet121(input_shape=x_train.shape[1:], num_classes=num_classes)
    baseline_model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy', 'categorical_crossentropy']
    )
    _, _, baseline_ce = baseline_model.evaluate(x_test, y_test, batch_size=BATCH_SIZE_TEST, verbose=1)
    print(f"Baseline CE (CEo): {baseline_ce}")
    tf.keras.backend.clear_session()

    # Arguments that are identical for both experiments
    shared_args = dict(
        dataset_name=dataset_name,
        pruning_method='parameter_mask',
        run_number=1,
        baseline_ce=baseline_ce,
        x_train=x_train,
        y_train=y_train,
        x_test=x_test,
        y_test=y_test,
        num_classes=num_classes,
    )
    experiments = (
        # 0.00013514851 is the exact percentage to prune 10 parameters out of 7,399,616
        ("\n===== TEST 1: Tiny pruning (exactly 10 parameters) =====", 0.00013514851, tiny_pruning_dir),
        ("\n===== TEST 2: Medium pruning (30%) =====", 30, medium_pruning_dir),
    )
    for banner, percentage, experiment_dir in experiments:
        print(banner)
        run_single_experiment(
            prune_percentage=percentage,
            output_dir=experiment_dir,
            **shared_args
        )

    # Human-readable summary with instructions for manual inspection
    summary_lines = (
        "DenseNet-121 Pruning Verification Summary\n",
        "=======================================\n\n",
        "Two pruning tests were conducted:\n",
        "1. Tiny pruning: exactly 10 parameters pruned (0.00013514851% of 7,399,616 total parameters)\n",
        "2. Medium pruning: 30% of parameters pruned\n\n",
        "For each test, pruned indices were tracked across batches 0 (initial), 1, 2, and 3.\n",
        "See the subdirectories for detailed verification results.\n\n",
        "Instructions for comparing files with Linux 'diff' command:\n",
        "cd tiny_pruning/P0.00013514851_run1\n",
        "diff -y prune-indices_batch-0.txt prune-indices_batch-1.txt | grep -v 'NOTE:'\n",
        "diff -y prune-indices_batch-0.txt prune-indices_batch-2.txt | grep -v 'NOTE:'\n",
        "diff -y prune-indices_batch-0.txt prune-indices_batch-3.txt | grep -v 'NOTE:'\n\n",
        "cd ../../medium_pruning/P30_run1\n",
        "diff prune-indices_batch-0.txt prune-indices_batch-1.txt | grep -v 'NOTE:'\n",
        "diff prune-indices_batch-0.txt prune-indices_batch-2.txt | grep -v 'NOTE:'\n",
        "diff prune-indices_batch-0.txt prune-indices_batch-3.txt | grep -v 'NOTE:'\n\n",
        "To analyze ALL zero weights (not just pruned ones):\n",
        "cd tiny_pruning/P0.00013514851_run1\n",
        "grep 'Total zero weights:' all-zero-indices_batch-*.txt\n",
        "cd ../../medium_pruning/P30_run1\n",
        "grep 'Total zero weights:' all-zero-indices_batch-*.txt\n",
    )
    with open(os.path.join(output_dir, "verification_summary.txt"), 'w') as f:
        f.writelines(summary_lines)

    print(f"\nVerification tests complete! Results saved to {output_dir}")
    if GOOGLE_DRIVE_MOUNTED:
        print(f"Results are available in your Google Drive at: {GOOGLE_DRIVE_BASE_PATH}")
    print("Check the verification summary and individual test directories for detailed results.")

    # Shell helper for checking the files on a Linux machine
    create_comparison_script(output_dir)

    return output_dir

# Function to find verification files in the filesystem
def find_verification_files():
    """
    Utility to find verification files if they're not where expected.

    Walks the Google Drive folder (only when GOOGLE_DRIVE_MOUNTED is true) and
    then the current working directory, collecting files whose name contains one
    of the known verification file patterns. The search stops as soon as more
    than 100 files have been collected; the local search is skipped entirely if
    the Drive search already hit that limit. A file reachable through both
    search roots is listed only once.

    Returns:
    List of paths of the files found (possibly empty)
    """
    print("\nSearching for verification files...")

    # Substrings that identify verification-related file names
    patterns = ('prune-indices_batch', 'all-zero-indices_batch', 'pruned_indices.csv')
    max_files = 100  # stop once more than this many files were collected
    found_files = []
    seen_paths = set()  # resolved paths, so overlapping roots do not produce duplicates

    def collect_matches(search_root):
        """Append not-yet-seen matching files below search_root until the limit is exceeded."""
        for root, _dirs, files in os.walk(search_root):
            for file_name in files:
                if not any(pattern in file_name for pattern in patterns):
                    continue
                path = os.path.join(root, file_name)
                resolved = os.path.realpath(path)
                if resolved in seen_paths:
                    continue
                seen_paths.add(resolved)
                found_files.append(path)
                # Limit search to avoid taking too long
                if len(found_files) > max_files:
                    return

    if GOOGLE_DRIVE_MOUNTED:
        # Search Google Drive
        search_root = '/content/drive/My Drive'
        print(f"Searching in Google Drive: {search_root}")

        try:
            collect_matches(search_root)
        except Exception as e:
            # Best effort: a Drive problem must not prevent the local search
            print(f"Error searching Google Drive: {str(e)}")

    # Also search locally, unless the limit was already reached on Drive
    if len(found_files) <= max_files:
        local_search_root = '.'
        print(f"Searching locally: {local_search_root}")
        collect_matches(local_search_root)

    # Report findings
    if found_files:
        print(f"Found {len(found_files)} verification-related files.")
        print("First 10 files found:")
        for position, path in enumerate(found_files[:10], start=1):
            print(f"  {position}. {path}")
        if len(found_files) > 10:
            print(f"  ... and {len(found_files) - 10} more files.")
    else:
        print("No verification files found.")

    return found_files

# A more comprehensive verification function that both runs the tests and analyzes results
def verify_densenet_pruning_comprehensive():
    """
    Run the pruning verification tests and analyze the results.

    Steps:
    1. Run both experiments via run_pruning_verification().
    2. Analyze the tiny-pruning run directory with compare_pruned_indices_files().
    3. Write "pruning_verification_report.txt" into the run's output directory.
    4. If "all-zero-indices_batch-0.txt" is missing from either run directory,
       call find_verification_files() to help locate the files.

    NOTE(review): only the tiny-pruning run is analyzed here. The medium-pruning
    analysis, the report's CONCLUSION section and the overall PASS/FAIL console
    summary were disabled (commented out) in the original code -- confirm whether
    they should be restored.
    """
    # Run the verification tests
    output_dir = run_pruning_verification()

    # The warnings are emitted during training, before the mask callback re-applies the zeros
    print("\n===== NOTE ABOUT WARNINGS =====")
    print("You may see warnings about 'pruned weights became non-zero'. This is normal and expected.")
    print("These warnings indicate that gradient updates tried to change pruned weights,")
    print("but our MaskWeightsCallback detected this and re-applied the mask to set them back to zero.")
    print("The verification files will confirm that all pruned weights remained zero after the callback ran.\n")

    # Run the comparison analysis (tiny pruning only, see the docstring note)
    tiny_pruning_dir = os.path.join(output_dir, "tiny_pruning", "P0.00013514851_run1")
    medium_pruning_dir = os.path.join(output_dir, "medium_pruning", "P30_run1")

    tiny_results = compare_pruned_indices_files(tiny_pruning_dir)

    # Create a final verification report. The report contains the "→" character,
    # so the encoding is fixed instead of relying on the platform default.
    report_path = os.path.join(output_dir, "pruning_verification_report.txt")
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("DenseNet Pruning Verification Report\n")
        f.write("==================================\n\n")

        f.write("TINY PRUNING TEST (10 parameters, 0.00013514851% of 7,399,616 total parameters)\n")
        f.write("------------------------------------------------------\n")
        if tiny_results['all_matching']:
            f.write("PASS: All pruned indices remain zero across all batch updates\n")
        else:
            f.write("FAIL: Some pruned indices changed during training\n")
            # A missing reference file is reported through a top-level 'error' entry
            if 'error' in tiny_results:
                f.write(f"  {tiny_results['error']}\n")
            for comp in tiny_results.get('comparisons', []):
                if comp.get('matching', False):
                    continue
                f.write(f"  Batch {comp.get('batch')}: Verification failed\n")
                # Missing files carry 'error', inconclusive checks carry 'message'
                detail = comp.get('error') or comp.get('message')
                if detail:
                    f.write(f"    {detail}\n")
                for error in comp.get('errors', []):
                    f.write(f"    {error}\n")

        # Add analysis of zero indices across batches for tiny pruning
        zero_analysis = tiny_results.get('zero_indices_analysis', {})
        if zero_analysis:
            batches = sorted(zero_analysis)

            f.write("\nAnalysis of ALL zero weights across batches:\n")
            for batch in batches:
                f.write(f"  Batch {batch}: {zero_analysis[batch]['total_zero_weights']} total zero weights\n")

            # Add trends between consecutive batches
            if len(batches) > 1:
                f.write("\nTrends in zero weights:\n")
                for curr_batch, next_batch in zip(batches, batches[1:]):
                    diff = (zero_analysis[next_batch]['total_zero_weights']
                            - zero_analysis[curr_batch]['total_zero_weights'])

                    if diff > 0:
                        f.write(f"  Batch {curr_batch} → {next_batch}: +{diff} weights became zero\n")
                    elif diff < 0:
                        f.write(f"  Batch {curr_batch} → {next_batch}: {diff} zero weights became non-zero\n")
                    else:
                        f.write(f"  Batch {curr_batch} → {next_batch}: No change in zero weights\n")
        f.write("\n")

        f.write("\nFor a detailed analysis:\n")
        if GOOGLE_DRIVE_MOUNTED:
            f.write(f"1. Download the files from Google Drive at {GOOGLE_DRIVE_BASE_PATH}\n")
        else:
            f.write(f"1. Upload the {output_dir} directory to a Linux machine\n")
        f.write("2. Run the compare_indices.sh script\n")
        f.write("3. Examine the all-zero-indices_batch-*.txt files to see all weights that are zero at each batch\n")
        f.write("4. CSV files of all zero indices are also available for easier processing in other tools\n")

    print(f"\nVerification report created: {report_path}")

    # Final check: If files weren't found where expected, run a file search
    expected_files = [
        os.path.join(tiny_pruning_dir, "all-zero-indices_batch-0.txt"),
        os.path.join(medium_pruning_dir, "all-zero-indices_batch-0.txt"),
    ]
    if not all(os.path.exists(path) for path in expected_files):
        print("\nWARNING: Some expected files were not found where expected.")
        print("Running a file search to locate verification files...")
        find_verification_files()


# Main function
if __name__ == "__main__":
    # Tell the user where the verification files will end up
    if not GOOGLE_DRIVE_MOUNTED:
        print("Google Drive is not mounted. Files will be saved locally.")
        print("To save files to Google Drive, run this script in Google Colab.")
    else:
        print("Google Drive is mounted. All verification files will be saved to:")
        print(f"  {GOOGLE_DRIVE_BASE_PATH}")

        # Probe write access by creating and then deleting a throwaway file
        probe_path = os.path.join(GOOGLE_DRIVE_BASE_PATH, "write_test.txt")
        try:
            with open(probe_path, 'w') as probe:
                probe.write("This is a test to verify write permissions to Google Drive")
            print("Successfully wrote test file to Google Drive. File access is working correctly.")
            os.remove(probe_path)  # Clean up the probe file
        except Exception as exc:
            # Warn only; the verification run below is still attempted
            print(f"WARNING: Failed to write test file to Google Drive: {str(exc)}")
            print("Files may not be properly saved. Consider running locally or fixing Drive permissions.")

    verify_densenet_pruning_comprehensive()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted. Files will be saved to /content/drive/My Drive/DenseNet_Pruning_Verification
Google Drive is mounted. All verification files will be saved to:
  /content/drive/My Drive/DenseNet_Pruning_Verification
Successfully wrote test file to Google Drive. File access is working correctly.
Running DenseNet-121 pruning verification tests
Created directory: /content/drive/My Drive/DenseNet_Pruning_Verification/pruning_verification_20250519_210609
Creating base output directory: /content/drive/My Drive/DenseNet_Pruning_Verification/pruning_verification_20250519_210609
Created directory: /content/drive/My Drive/DenseNet_Pruning_Verification/pruning_verification_20250519_210609/tiny_pruning
Created directory: /content/drive/My Drive/DenseNet_Pruning_Verification/pruning_verification_20250519_210609/medium_pruning

Loading mnist dataset for verification..