<a href="https://colab.research.google.com/github/lenishu/IPA_using_Densenet/blob/main/Version_2_Pruning_Densenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import csv
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Input, BatchNormalization, ReLU, Conv2D, Dense, MaxPool2D, AvgPool2D, GlobalAvgPool2D, Concatenate
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10, cifar100
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import Model
import math
import gspread
from google.colab import auth
from google.auth import default
import time
from datetime import datetime

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Mini-batch sizes for SGD updates and for test-set evaluation.
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 256

# Learning rate of the SGD optimizer.
LEARNING_RATE = 0.1

# Pruning levels to sweep, in percent: 0, 10, 20, ..., 100.
PRUNE_PERCENTAGES = list(range(0, 101, 10))

# Repetitions per pruning level, and training epochs per repetition.
NUMBER_OF_RUNS = 10
EPOCHS_PER_RUN = 1

# Google account the results spreadsheet is shared with (empty string skips sharing).
EMAIL = "lenishpandey@gmail.com"

# Function to load and preprocess datasets
def _grayscale_to_rgb32(images):
    """Zero-pad each image by 2 pixels per side and repeat it into 3 channels.

    For MNIST-style (N, 28, 28) input this yields (N, 32, 32, 3), matching the
    3-channel 32x32 shape of the CIFAR datasets.
    """
    padded = np.pad(images, ((0, 0), (2, 2), (2, 2)), 'constant')
    return np.repeat(np.expand_dims(padded, axis=-1), 3, axis=-1)


def load_dataset(dataset_name):
    """Load a Keras benchmark dataset and preprocess it for DenseNet.

    Grayscale datasets (MNIST, Fashion-MNIST) are padded to 32x32 and turned
    into 3-channel images. All images are scaled to [0, 1] float32, and labels
    are one-hot encoded.

    Parameters:
    dataset_name -- one of 'mnist', 'fashion_mnist', 'cifar10', 'cifar100'

    Returns:
    ((x_train, y_train), (x_test, y_test), num_classes)

    Raises:
    ValueError -- if dataset_name is not one of the supported names
    """
    # Validate the name (and pick loader/class count) before touching any data.
    if dataset_name == 'mnist':
        loader, num_classes, grayscale = mnist, 10, True
    elif dataset_name == 'fashion_mnist':
        loader, num_classes, grayscale = fashion_mnist, 10, True
    elif dataset_name == 'cifar10':
        loader, num_classes, grayscale = cifar10, 10, False
    elif dataset_name == 'cifar100':
        loader, num_classes, grayscale = cifar100, 100, False
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    (x_train, y_train), (x_test, y_test) = loader.load_data()

    if grayscale:
        x_train = _grayscale_to_rgb32(x_train)
        x_test = _grayscale_to_rgb32(x_test)

    # Normalize pixel values to [0, 1]
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    # Convert class vectors to one-hot (binary class) matrices
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    return (x_train, y_train), (x_test, y_test), num_classes

# Define DenseNet-121 architecture
def bn_rl_conv(x, filters, kernel_size):
    """Pre-activation conv unit: BatchNorm -> ReLU -> Conv2D with 'same' padding."""
    normalized = BatchNormalization()(x)
    activated = ReLU()(normalized)
    conv = Conv2D(filters=filters, kernel_size=kernel_size, padding='same')
    return conv(activated)

def dense_block(tensor, k, reps):
    """Stack `reps` bottleneck layers, concatenating each new k-channel output.

    Each layer applies a 1x1 conv with 4*k filters followed by a 3x3 conv with
    k filters. Its output is concatenated onto the running feature map, so the
    channel count grows by k per layer.
    """
    features = tensor
    for _layer_index in range(reps):
        bottleneck = bn_rl_conv(features, filters=4 * k, kernel_size=1)
        growth = bn_rl_conv(bottleneck, filters=k, kernel_size=3)
        features = Concatenate()([features, growth])
    return features

def transition_layer(x, theta):
    """Compress channels by `theta` with a 1x1 conv, then 2x2 average-pool (stride 2)."""
    in_channels = tf.keras.backend.int_shape(x)[-1]
    compressed = bn_rl_conv(x, filters=int(in_channels * theta), kernel_size=1)
    pool = AvgPool2D(pool_size=2, strides=2, padding='same')
    return pool(compressed)

def create_densenet121(input_shape, num_classes, k=32, theta=0.5):
    """Build a DenseNet-121 classifier.

    Layout:
    - Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
    - Four dense blocks with [6, 12, 24, 16] bottleneck layers.
    - Transition layers (1x1 compression by `theta` + 2x2 average pooling)
      only *between* consecutive dense blocks.
    - Head: after the last block, BN -> ReLU -> global average pooling ->
      softmax Dense classifier.

    Parameters:
    input_shape -- shape of a single input image, e.g. (32, 32, 3)
    num_classes -- number of output classes
    k -- growth rate (channels added by each bottleneck layer)
    theta -- channel compression factor used by the transition layers

    Returns:
    An uncompiled keras Model.
    """
    # DenseNet-121 has repetitions [6, 12, 24, 16]
    repetitions = [6, 12, 24, 16]

    inputs = Input(shape=input_shape)
    x = Conv2D(2 * k, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = MaxPool2D(3, strides=2, padding='same')(x)

    # A transition follows every dense block except the last one; applying it
    # after the final block would add an extra compression/pooling stage that
    # is not part of DenseNet-121.
    for reps in repetitions[:-1]:
        x = dense_block(x, k, reps)
        x = transition_layer(x, theta)
    x = dense_block(x, k, repetitions[-1])

    # The last dense block ends in a raw concatenation of conv outputs, so
    # normalise and activate before pooling, as in the reference architecture.
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = GlobalAvgPool2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs)

# Parameter Mask Pruning function
def parameter_mask_pruning(model, prune_percentage, seed=None):
    """
    Randomly select a percentage of weights across the entire network and set them to zero.
    These weights will remain fixed at zero throughout training.

    Candidates are the kernels (first weight tensor; biases are never pruned)
    of every Conv2D and Dense layer. Indices are drawn uniformly without
    replacement from the concatenation of all candidate kernels, so the pruning
    is global rather than per-layer. A callback that re-zeroes the pruned
    weights after every training batch is attached as ``model.mask_callback``.
    Pass it to ``model.fit`` so the pruned weights stay at zero.

    Parameters:
    model -- The Keras model to prune (modified in place)
    prune_percentage -- Percentage of weights to prune (0-100)
    seed -- Random seed for reproducibility. A private RandomState is used,
            which yields the same indices as seeding the global NumPy RNG,
            but the global RNG state is left untouched.

    Returns:
    The same model, with masked kernels and a ``mask_callback`` attribute
    """
    # Private generator: same draws as np.random.seed(seed); np.random.choice(...)
    # but without clobbering (or re-randomising) the global NumPy RNG.
    rng = np.random.RandomState(seed)

    # Collect the kernel shapes of all prunable layers
    layer_indices = []
    shapes = []
    sizes = []
    for i, layer in enumerate(model.layers):
        if isinstance(layer, (Conv2D, Dense)) and len(layer.weights) > 0:
            kernel = layer.get_weights()[0]  # weight matrix only, not the bias
            layer_indices.append(i)
            shapes.append(kernel.shape)
            sizes.append(kernel.size)

    total_params = sum(sizes)
    num_to_prune = int(total_params * prune_percentage / 100)

    print(f"Total trainable parameters: {total_params}")
    print(f"Parameters to prune: {num_to_prune} ({prune_percentage}%)")

    # One global 0/1 mask over all kernels, then split back into per-layer masks.
    # This is a single O(N) pass instead of scanning all selected indices once per layer.
    flat_mask = np.ones(total_params, dtype=np.float32)
    if prune_percentage > 0:
        pruned = rng.choice(total_params, size=num_to_prune, replace=False)
        flat_mask[pruned] = 0
    boundaries = np.cumsum(sizes)[:-1]
    masks = [chunk.reshape(shape)
             for chunk, shape in zip(np.split(flat_mask, boundaries), shapes)]

    if prune_percentage > 0:
        # Apply masks to each layer's kernel
        for mask, layer_idx in zip(masks, layer_indices):
            layer = model.layers[layer_idx]
            weights = layer.get_weights()
            weights[0] = weights[0] * mask
            layer.set_weights(weights)

    class MaskWeightsCallback(tf.keras.callbacks.Callback):
        """Re-apply the pruning masks after every training batch."""

        def __init__(self, masks, layer_indices):
            super().__init__()
            self.masks = masks
            self.layer_indices = layer_indices

        def on_train_batch_end(self, batch, logs=None):
            # The optimizer also updates pruned weights; multiply them back to zero.
            for mask, layer_idx in zip(self.masks, self.layer_indices):
                layer = self.model.layers[layer_idx]
                weights = layer.get_weights()
                weights[0] = weights[0] * mask
                layer.set_weights(weights)

    # Attach the mask callback to the model for later use
    model.mask_callback = MaskWeightsCallback(masks, layer_indices)

    return model

# Google Sheets Logger class
class GoogleSheetsLogger(Callback):
    """Keras callback that records per-batch metrics and IPA to a Google Sheet.

    After every training batch, the model is evaluated on the full test set
    and one row is recorded:
    - epoch and global batch number;
    - train accuracy and CE (from the batch logs);
    - test accuracy and CE;
    - |CE(b) - CEo| and IPA.
    IPA is |CE(b) - CEo| / (BATCH_SIZE_TRAIN * b). It is reported only when the
    test CE is at or below ``ce_threshold``; otherwise the row stores "N/A".

    Place this callback *after* any weight-masking callback in the ``fit``
    callback list, so the per-batch evaluation sees the masked weights.

    Rows are buffered and uploaded with one ``append_rows`` request every
    ``flush_every`` batches, and at the end of each epoch and of training.
    Sending one request per batch exceeded the Sheets per-minute write quota
    (HTTP 429) and dropped rows. Rows whose upload fails now stay buffered and
    are retried at the next flush.
    """

    SPREADSHEET_TITLE = "DenseNet121_Pruning_Experiments"

    def __init__(self, test_data, batch_size, ce_threshold, spreadsheet=None,
                 sheet_name=None, baseline_ce=None, flush_every=50):
        """
        Initialize the Google Sheets Logger

        Parameters:
        test_data -- tuple of (x_test, y_test)
        batch_size -- batch size for evaluation
        ce_threshold -- Threshold for CE (ln(10) or ln(100) depending on dataset)
        spreadsheet -- existing Google spreadsheet object to use; when None the
                       spreadsheet is opened (or created and shared) here
        sheet_name -- name of the worksheet to create/use within the spreadsheet
        baseline_ce -- baseline cross-entropy from unpruned model (for IPA calculation);
                       evaluated in on_train_begin when None
        flush_every -- number of batches between buffered uploads (minimum 1)
        """
        super().__init__()
        self.test_x, self.test_y = test_data
        self.batch_size = batch_size
        self.total_batches = 0  # global batch counter, continuous across epochs
        self.epoch = 0
        self.baseline_ce = baseline_ce
        self.ce_threshold = ce_threshold
        self.ipa_start_batch = None  # first batch whose test CE reached the threshold
        self.spreadsheet = spreadsheet
        self.flush_every = max(1, int(flush_every))
        self._pending_rows = []  # rows not yet uploaded to the sheet
        self._last_metrics = None  # (test_acc, test_ce, ipa_value) of the latest batch

        # Use provided sheet name or generate a default
        if sheet_name is None:
            sheet_name = f"Training_Log_{int(time.time())}"
        self.sheet_name = sheet_name

        # Column headers for the sheet
        self.headers = [
            'Epoch', 'Batch Number',
            'Train Accur', 'Train CE', 'Test Accur', 'Test CE',
            '|CE(b)-CEo|', 'IPA'
        ]

        self._setup_google_sheets()

    def _setup_google_sheets(self):
        """Open (or create and share) the spreadsheet and prepare this run's worksheet."""
        if self.spreadsheet is None:
            auth.authenticate_user()
            # Get credentials (using google-auth instead of oauth2client)
            creds, _ = default()
            self.gc = gspread.authorize(creds)

            try:
                self.sheet = self.gc.open(self.SPREADSHEET_TITLE)
                print(f"Using existing spreadsheet: {self.SPREADSHEET_TITLE}")
            except gspread.exceptions.SpreadsheetNotFound:
                self.sheet = self.gc.create(self.SPREADSHEET_TITLE)
                print(f"Created new spreadsheet: {self.SPREADSHEET_TITLE}")
                if EMAIL:
                    self.sheet.share(EMAIL, perm_type='user', role='writer')
                    print(f"Shared spreadsheet with: {EMAIL}")

            self.spreadsheet = self.sheet
        else:
            self.sheet = self.spreadsheet

        # Get (and clear) or create the worksheet for this run
        try:
            self.worksheet = self.sheet.worksheet(self.sheet_name)
            self.worksheet.clear()
        except gspread.exceptions.WorksheetNotFound:
            self.worksheet = self.sheet.add_worksheet(
                title=self.sheet_name, rows=1000, cols=len(self.headers))

        # Keyword arguments work with both gspread 5 (range first) and gspread 6
        # (values first) and avoid the 6.x positional-order DeprecationWarning.
        self.worksheet.update(range_name='A1', values=[self.headers])

        print(f"Spreadsheet URL: {self.sheet.url}")

    def _evaluate_test(self, verbose=0):
        """Evaluate on the test set and return (test_accuracy, test_ce).

        Assumes the model was compiled with
        metrics=['accuracy', 'categorical_crossentropy'], so evaluate() returns
        [loss, accuracy, ce].
        """
        _, test_acc, test_ce = self.model.evaluate(
            self.test_x, self.test_y,
            batch_size=self.batch_size,
            verbose=verbose
        )
        return test_acc, test_ce

    def _compute_ipa(self, ce_diff):
        """IPA = |CE(b) - CEo| / (BATCH_SIZE_TRAIN * b); 0 before any batch has run."""
        if self.total_batches > 0:
            return ce_diff / (BATCH_SIZE_TRAIN * self.total_batches)
        return 0

    def _flush(self):
        """Upload all buffered rows in one request; keep them for retry on failure."""
        if not self._pending_rows:
            return
        try:
            self.worksheet.append_rows(self._pending_rows)
        except Exception as e:  # best-effort logging: never abort training
            print(f"Error writing to Google Sheet: {e}")
            return
        self._pending_rows = []

    def on_train_begin(self, logs=None):
        # If baseline CE is not provided, calculate it
        if self.baseline_ce is None:
            print("Calculating baseline CE...")
            _, self.baseline_ce = self._evaluate_test(verbose=1)
            print(f"Baseline CE (CEo): {self.baseline_ce}")

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1
        # total_batches is intentionally not reset: batch numbers are global

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.total_batches += 1

        test_acc, test_ce = self._evaluate_test()
        ce_diff = abs(test_ce - self.baseline_ce)
        below_threshold = test_ce <= self.ce_threshold

        if below_threshold and self.ipa_start_batch is None:
            self.ipa_start_batch = self.total_batches
            print(f"CE dropped below threshold ({self.ce_threshold:.4f}) at batch {self.ipa_start_batch} (CE = {test_ce:.4f})")

        # IPA uses the current global batch number (not offset by ipa_start_batch)
        ipa_value = self._compute_ipa(ce_diff) if below_threshold else "N/A"
        self._last_metrics = (test_acc, test_ce, ipa_value)

        # float() makes numpy/tensor scalars JSON-serialisable for the Sheets API
        self._pending_rows.append([
            self.epoch,
            self.total_batches,
            float(logs.get('accuracy', 0)),
            float(logs.get('loss', 0)),  # Train CE
            float(test_acc),
            float(test_ce),
            float(ce_diff),
            ipa_value,
        ])

        # Flush on a fixed cadence (not on buffer length) so a failing upload is
        # retried every flush_every batches instead of on every following batch.
        if self.total_batches % self.flush_every == 0:
            self._flush()

    def on_epoch_end(self, epoch, logs=None):
        self._flush()

    def on_train_end(self, logs=None):
        self._flush()

    def get_final_metrics(self):
        """Return final test accuracy, IPA (scaled by 1000) and threshold status.

        Uses the metrics recorded in memory for the most recent batch, so the
        result does not depend on the Google Sheet having received every row.
        If no batch has been logged yet, the model is evaluated once instead.
        IPA is 0 when the test CE is above the threshold.
        """
        self._flush()

        if self._last_metrics is not None:
            test_acc, _, ipa_value = self._last_metrics
        else:
            test_acc, test_ce = self._evaluate_test()
            ce_diff = abs(test_ce - self.baseline_ce)
            ipa_value = self._compute_ipa(ce_diff) if test_ce <= self.ce_threshold else "N/A"

        below_threshold = ipa_value != "N/A"
        return {
            'Final Test Accuracy': float(test_acc),
            'IPA': ipa_value * 1000 if below_threshold else 0,  # Scale by 1000
            'CE Below Threshold': below_threshold
        }

# Function to run a single experiment with a specific pruning percentage and seed
def run_single_experiment(dataset_name, pruning_method, prune_percentage, run_number, baseline_ce, spreadsheet, x_train, y_train, x_test, y_test, num_classes):
    """
    Run a single experiment with a specific pruning percentage and seed.

    Builds a fresh DenseNet-121, applies parameter-mask pruning, trains with
    SGD for EPOCHS_PER_RUN epochs while logging per-batch metrics to Google
    Sheets, and returns the final metrics.

    Parameters:
    dataset_name -- Name of the dataset
    pruning_method -- Pruning method to use (label only)
    prune_percentage -- Percentage to prune
    run_number -- Run number (seed is 42 + run_number)
    baseline_ce -- Baseline cross-entropy
    spreadsheet -- Google Sheets spreadsheet object
    x_train, y_train, x_test, y_test -- Training and test data
    num_classes -- Number of classes

    Returns:
    Dictionary with 'Pruning Percentage', 'Run Number', 'Final Test Accuracy',
    'IPA' (x1000) and 'CE Below Threshold'
    """
    print(f"\nRunning experiment: {dataset_name}, {pruning_method}, P% = {prune_percentage}, Run #{run_number}")

    # Use run number as seed for reproducibility
    seed = 42 + run_number
    tf.random.set_seed(seed)

    # ln(num_classes) is the CE of a uniform prediction (ln(10) for the 10-class
    # datasets, ln(100) for CIFAR-100); IPA is only reported below this level.
    ce_threshold = math.log(num_classes)
    print(f"Using CE threshold ln({num_classes}) = {ce_threshold:.4f} for {dataset_name}")

    # Create and prune the model
    model = create_densenet121(input_shape=x_train.shape[1:], num_classes=num_classes)
    model = parameter_mask_pruning(model, prune_percentage, seed=seed)
    mask_callback = model.mask_callback

    optimizer = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy', 'categorical_crossentropy']
    )

    sheet_name = f"Run{run_number}_P{prune_percentage}"
    google_sheet_logger = GoogleSheetsLogger(
        test_data=(x_test, y_test),
        batch_size=BATCH_SIZE_TEST,
        ce_threshold=ce_threshold,
        spreadsheet=spreadsheet,
        sheet_name=sheet_name,
        baseline_ce=baseline_ce
    )

    # Order matters: Keras invokes callbacks in list order, so the mask must be
    # re-applied *before* the logger evaluates the model after each batch.
    # Otherwise the logged test metrics include one SGD update of pruned weights.
    callbacks = [mask_callback, google_sheet_logger]

    model.fit(
        x_train, y_train,
        batch_size=BATCH_SIZE_TRAIN,
        validation_batch_size=BATCH_SIZE_TEST,
        epochs=EPOCHS_PER_RUN,
        validation_data=(x_test, y_test),
        callbacks=callbacks,
        verbose=2  # Less verbose output
    )

    final_metrics = google_sheet_logger.get_final_metrics()

    # Clear memory
    tf.keras.backend.clear_session()

    result = {
        'Pruning Percentage': prune_percentage,
        'Run Number': run_number,
        'Final Test Accuracy': final_metrics['Final Test Accuracy'],
        'IPA': final_metrics['IPA'],
        'CE Below Threshold': final_metrics['CE Below Threshold']
    }

    threshold_status = "CE below threshold" if final_metrics['CE Below Threshold'] else "CE above threshold (IPA set to 0)"
    print(f"P{prune_percentage} Run #{run_number} - Test Accuracy: {result['Final Test Accuracy']:.4f}, IPA × 1000: {result['IPA']:.6f}, {threshold_status}")

    return result

def _get_or_create_worksheet(sheet, title, rows, cols):
    """Return the worksheet named `title` (cleared), creating it if it does not exist."""
    try:
        worksheet = sheet.worksheet(title)
        worksheet.clear()
    except gspread.exceptions.WorksheetNotFound:
        worksheet = sheet.add_worksheet(title=title, rows=rows, cols=cols)
    return worksheet


def _safe_append_rows(worksheet, rows):
    """Append rows to a worksheet; report (but survive) Sheets API errors such as 429."""
    try:
        worksheet.append_rows(rows)
    except gspread.exceptions.APIError as e:
        print(f"Error writing to Google Sheet: {e}")


def _save_errorbar_plot(x, y, yerr, ylabel, title, path):
    """Save a mean +/- std line plot against pruning percentage to `path`."""
    plt.figure(figsize=(10, 6))
    plt.errorbar(x, y, yerr=yerr, fmt='o-', capsize=5, elinewidth=2, markeredgewidth=2)
    plt.xlabel('Pruning Percentage (P%)')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.savefig(path)
    plt.close()


def _save_summary_plots(summary_df, results_by_percentage, dataset_name, pruning_method, plots_dir):
    """Write the IPA, accuracy, convergence-count and IPA-heatmap figures to `plots_dir`."""
    suffix = f'for {dataset_name} using {pruning_method} pruning'
    percentages = summary_df['Pruning Percentage']

    _save_errorbar_plot(
        percentages, summary_df['Avg IPA'], summary_df['Std IPA'],
        ylabel='Average IPA × 1000',
        title=f'IPA vs P% {suffix}',
        path=os.path.join(plots_dir, "avg_ipa_vs_prune.png"),
    )
    _save_errorbar_plot(
        percentages, summary_df['Avg Test Accuracy'], summary_df['Std Test Accuracy'],
        ylabel='Average Test Accuracy',
        title=f'Accuracy vs P% {suffix}',
        path=os.path.join(plots_dir, "avg_accuracy_vs_prune.png"),
    )

    # Number of runs per pruning level whose final test CE reached the threshold
    plt.figure(figsize=(10, 6))
    plt.bar(percentages, summary_df['CE Below Threshold Count'], alpha=0.7)
    plt.xlabel('Pruning Percentage (P%)')
    plt.ylabel('Count of Runs with CE Below Threshold')
    plt.title(f'Convergence Rate vs P% {suffix}')
    plt.grid(True, axis='y')
    plt.savefig(os.path.join(plots_dir, "convergence_rate_vs_prune.png"))
    plt.close()

    # Heatmap: rows are runs, columns are pruning percentages
    heatmap_data = np.zeros((NUMBER_OF_RUNS, len(PRUNE_PERCENTAGES)))
    for col, prune_percentage in enumerate(PRUNE_PERCENTAGES):
        for row, result in enumerate(results_by_percentage[prune_percentage]):
            heatmap_data[row, col] = result['IPA']

    plt.figure(figsize=(12, 8))
    plt.imshow(heatmap_data, cmap='viridis', aspect='auto')
    plt.colorbar(label='IPA × 1000')
    plt.xticks(np.arange(len(PRUNE_PERCENTAGES)), [f'{p}%' for p in PRUNE_PERCENTAGES])
    plt.yticks(np.arange(NUMBER_OF_RUNS), [f'Run {i+1}' for i in range(NUMBER_OF_RUNS)])
    plt.xlabel('Pruning Percentage')
    plt.ylabel('Run Number')
    plt.title(f'IPA Heatmap {suffix}')
    plt.savefig(os.path.join(plots_dir, "ipa_heatmap.png"))
    plt.close()


# Function to run experiment for multiple pruning percentages with multiple runs
def run_multi_experiment(dataset_name, pruning_method):
    """
    Run experiments for multiple pruning percentages with multiple runs each.

    Runs form the outer loop: run #r sweeps every value in PRUNE_PERCENTAGES
    before run #r+1 starts.
    - Every single result is appended to ``all_runs_results.csv``.
    - Per-run IPAs go to a Google Sheets "runs" worksheet.
    - Per-percentage mean/std summaries go to ``summary_results.csv``, to a
      "summary" worksheet, and to figures under ``plots/``.

    Parameters:
    dataset_name -- Name of the dataset
    pruning_method -- Pruning method label (used in file/sheet names and titles)

    Returns:
    (summary_df, output_dir) -- DataFrame with per-percentage summary results
    and the directory all local outputs were written to
    """
    # Create output directories with timestamp (makedirs creates output_dir too)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"results_{dataset_name}_{pruning_method}_{timestamp}"
    plots_dir = os.path.join(output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    # CSV file that accumulates one row per (run, pruning percentage)
    all_runs_file = os.path.join(output_dir, "all_runs_results.csv")
    with open(all_runs_file, 'w', newline='') as f:
        csv.writer(f).writerow(['Run Number', 'Pruning Percentage', 'Final Test Accuracy', 'IPA', 'CE Below Threshold'])

    print(f"\nLoading {dataset_name} dataset...")
    (x_train, y_train), (x_test, y_test), num_classes = load_dataset(dataset_name)

    # Baseline CE (CEo) from a freshly initialised, unpruned model, computed once
    print("\nCreating baseline model to calculate baseline CE...")
    baseline_model = create_densenet121(input_shape=x_train.shape[1:], num_classes=num_classes)
    baseline_model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy', 'categorical_crossentropy']
    )
    _, _, baseline_ce = baseline_model.evaluate(
        x_test, y_test,
        batch_size=BATCH_SIZE_TEST,
        verbose=1
    )
    print(f"Baseline CE (CEo): {baseline_ce}")
    tf.keras.backend.clear_session()

    # Setup Google Sheets
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    try:
        sheet = gc.open("DenseNet121_Pruning_Experiments")
    except gspread.exceptions.SpreadsheetNotFound:
        sheet = gc.create("DenseNet121_Pruning_Experiments")
        if EMAIL:
            sheet.share(EMAIL, perm_type='user', role='writer')

    # Keyword arguments in update() work with gspread 5 and 6 without warnings.
    summary_name = f"{dataset_name}_{pruning_method}_summary_{timestamp}"
    summary_sheet = _get_or_create_worksheet(sheet, summary_name, rows=100, cols=6)
    summary_sheet.update(range_name='A1', values=[['Prune Percentage', 'Avg Test Accuracy', 'Std Test Accuracy',
                                                    'Avg IPA × 1000', 'Std IPA × 1000', 'CE Below Threshold Count']])

    run_summary_name = f"{dataset_name}_{pruning_method}_runs_{timestamp}"
    run_summary_sheet = _get_or_create_worksheet(
        sheet, run_summary_name, rows=100, cols=len(PRUNE_PERCENTAGES) + 1)
    run_headers = ['Run Number'] + [f'P{p}%' for p in PRUNE_PERCENTAGES]
    run_summary_sheet.update(range_name='A1', values=[run_headers])

    all_results = []
    results_by_percentage = {p: [] for p in PRUNE_PERCENTAGES}

    # Loop through runs first, then every pruning percentage within a run
    for run_number in range(1, NUMBER_OF_RUNS + 1):
        print(f"\n===== Starting Run #{run_number} =====")

        run_ipa_values = []
        for prune_percentage in PRUNE_PERCENTAGES:
            result = run_single_experiment(
                dataset_name=dataset_name,
                pruning_method=pruning_method,
                prune_percentage=prune_percentage,
                run_number=run_number,
                baseline_ce=baseline_ce,
                spreadsheet=sheet,
                x_train=x_train,
                y_train=y_train,
                x_test=x_test,
                y_test=y_test,
                num_classes=num_classes
            )

            all_results.append(result)
            results_by_percentage[prune_percentage].append(result)
            run_ipa_values.append(result['IPA'])

            # Persist immediately so partial results survive a crash
            with open(all_runs_file, 'a', newline='') as f:
                csv.writer(f).writerow([
                    result['Run Number'],
                    result['Pruning Percentage'],
                    result['Final Test Accuracy'],
                    result['IPA'],
                    result['CE Below Threshold']
                ])

        _safe_append_rows(run_summary_sheet, [[run_number] + run_ipa_values])
        print(f"\nCompleted Run #{run_number} across all pruning percentages")

    # Summary statistics for each pruning percentage
    summary_results = []
    summary_rows = []
    for prune_percentage, results in results_by_percentage.items():
        accuracies = [r['Final Test Accuracy'] for r in results]
        ipas = [r['IPA'] for r in results]
        threshold_count = sum(1 for r in results if r['CE Below Threshold'])

        avg_accuracy = np.mean(accuracies)
        std_accuracy = np.std(accuracies)
        avg_ipa = np.mean(ipas)
        std_ipa = np.std(ipas)

        summary_results.append({
            'Pruning Percentage': prune_percentage,
            'Avg Test Accuracy': avg_accuracy,
            'Std Test Accuracy': std_accuracy,
            'Avg IPA': avg_ipa,
            'Std IPA': std_ipa,
            'CE Below Threshold Count': threshold_count
        })
        summary_rows.append([
            prune_percentage,
            float(avg_accuracy),
            float(std_accuracy),
            float(avg_ipa),
            float(std_ipa),
            threshold_count
        ])

        print(f"\nSummary for P{prune_percentage}:")
        print(f"  Avg Test Accuracy: {avg_accuracy:.4f} ± {std_accuracy:.4f}")
        print(f"  Avg IPA × 1000: {avg_ipa:.6f} ± {std_ipa:.6f}")
        print(f"  Runs with CE below threshold: {threshold_count}/{NUMBER_OF_RUNS}")

    summary_df = pd.DataFrame(summary_results)
    summary_df.to_csv(os.path.join(output_dir, "summary_results.csv"), index=False)

    # One batched request instead of one per percentage (Sheets write quota)
    _safe_append_rows(summary_sheet, summary_rows)

    _save_summary_plots(summary_df, results_by_percentage, dataset_name, pruning_method, plots_dir)
    plt.close('all')

    return summary_df, output_dir

# Main execution - Select dataset and run experiment
def main():
    """Interactively pick a dataset and run the parameter-mask pruning sweep.

    Empty, invalid, or unreadable input falls back to MNIST.
    """
    dataset_options = {
        '1': 'mnist',
        '2': 'fashion_mnist',
        '3': 'cifar10',
        '4': 'cifar100'
    }

    print("Available datasets:")
    for key, value in dataset_options.items():
        print(f"{key}: {value}")

    try:
        # strip() so stray spaces (e.g. " 3") are not treated as invalid choices
        dataset_choice = input("Select dataset (1-4) or press Enter for MNIST: ").strip()
        if dataset_choice == "":
            dataset_choice = "1"  # Default to MNIST

        if dataset_choice not in dataset_options:
            print(f"Invalid choice '{dataset_choice}'. Using default (MNIST).")
            dataset_choice = "1"

        dataset = dataset_options[dataset_choice]
    except Exception as e:  # e.g. no stdin available in a non-interactive kernel
        print(f"Error: {e}. Using default dataset (MNIST).")
        dataset = 'mnist'

    pruning_method = 'parameter_mask'

    print(f"Running DenseNet-121 pruning experiments on {dataset} dataset")
    print(f"Training batch size: {BATCH_SIZE_TRAIN}, Testing batch size: {BATCH_SIZE_TEST}")
    print(f"Pruning percentages: {PRUNE_PERCENTAGES}")
    print(f"Number of runs per percentage: {NUMBER_OF_RUNS}")
    print(f"Epochs per run: {EPOCHS_PER_RUN}")
    if EMAIL:
        print(f"Results will be shared with: {EMAIL}")

    # TF_CPP_MIN_LOG_LEVEL is read when TensorFlow's native library loads, which
    # already happened at import time, so it only affects child processes here.
    # Lowering the Python-side TensorFlow logger makes the intent take effect now.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    tf.get_logger().setLevel('ERROR')

    print("\nStarting Parameter Mask Pruning experiment...")
    _, output_dir = run_multi_experiment(dataset, pruning_method)

    print("\nExperiment complete!")
    print(f"Results saved to {output_dir}")

# Entry point: start the interactive experiment only when this code runs as the
# top-level script (or notebook cell); importing it as a module does not run main().
if __name__ == "__main__":
    main()

Available datasets:
1: mnist
2: fashion_mnist
3: cifar10
4: cifar100
Select dataset (1-4) or press Enter for MNIST: 1
Running DenseNet-121 pruning experiments on mnist dataset
Training batch size: 64, Testing batch size: 256
Pruning percentages: [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
Number of runs per percentage: 10
Epochs per run: 1
Results will be shared with: lenishpandey@gmail.com

Starting Parameter Mask Pruning experiment...

Loading mnist dataset...

Creating baseline model to calculate baseline CE...
[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 226ms/step - accuracy: 0.0767 - categorical_crossentropy: 2.3030 - loss: 2.3030
Baseline CE (CEo): 2.3030436038970947


  summary_sheet.update('A1', [['Prune Percentage', 'Avg Test Accuracy', 'Std Test Accuracy',
  run_summary_sheet.update('A1', [run_headers])



===== Starting Run #1 =====

Running experiment: mnist, parameter_mask, P% = 0, Run #1
Using CE threshold ln(10) = 2.3026 for mnist
Total trainable parameters: 7399616
Parameters to prune: 0 (0%)


  self.worksheet.update('A1', [self.headers])


Spreadsheet URL: https://docs.google.com/spreadsheets/d/1LCi-L80tAcu2qfyrfDGcJzzbEkmec912eQBfn_uvaKY
CE dropped below threshold (2.3026) at batch 56 (CE = 2.2935)
Error writing to Google Sheet: APIError: [429]: Quota exceeded for quota metric 'Write requests' and limit 'Write requests per minute' of service 'sheets.googleapis.com' for consumer 'project_number:522309567947'.
Error writing to Google Sheet: APIError: [429]: Quota exceeded for quota metric 'Write requests' and limit 'Write requests per minute' of service 'sheets.googleapis.com' for consumer 'project_number:522309567947'.
938/938 - 1952s - 2s/step - accuracy: 0.9853 - categorical_crossentropy: 0.0489 - loss: 0.0489 - val_accuracy: 0.9853 - val_categorical_crossentropy: 0.0489 - val_loss: 0.0489
P0 Run #1 - Test Accuracy: 0.9853, IPA × 1000: 0.037549, CE below threshold

Running experiment: mnist, parameter_mask, P% = 10, Run #1
Using CE threshold ln(10) = 2.3026 for mnist
Total trainable parameters: 7399616
Parameters to pr