<a href="https://colab.research.google.com/github/lenishu/IPA_using_Densenet/blob/main/Version_1_Pruning_Densenet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import csv
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Input, BatchNormalization, ReLU, Conv2D, Dense, MaxPool2D, AvgPool2D, GlobalAvgPool2D, Concatenate
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10, cifar100
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import Model
import math
import gspread
from google.colab import auth
from google.auth import default
import time

# Configuration
# Mini-batch size passed to model.fit(); it is also the per-batch sample count
# used in the IPA denominator inside GoogleSheetsLogger.
BATCH_SIZE_TRAIN = 5000
# Batch size used for model.evaluate() on the test set and as
# validation_batch_size in model.fit().
BATCH_SIZE_TEST = 5000
# Learning rate handed to the SGD optimizer for both baseline and pruned models.
LEARNING_RATE = 0.1
# Pruning percentages (0-100) to run; run_experiment trains one model per entry.
PRUNE_PERCENTAGES = [90]
# Account that a newly created results spreadsheet is shared with (writer role).
EMAIL = ""  # Enter your gmail here for datasheet

# Function to load and preprocess datasets
def load_dataset(dataset_name):
    """Load a Keras image dataset and prepare it for the DenseNet model.

    Parameters:
    dataset_name -- one of 'mnist', 'fashion_mnist', 'cifar10', 'cifar100'

    Returns:
    ((x_train, y_train), (x_test, y_test), num_classes) where the images are
    float32 arrays scaled by 1/255 and the labels are one-hot encoded.

    Raises:
    ValueError -- if dataset_name is not one of the supported names
    """
    if dataset_name in ('mnist', 'fashion_mnist'):
        source = mnist if dataset_name == 'mnist' else fashion_mnist
        (raw_train, y_train), (raw_test, y_test) = source.load_data()

        prepared = []
        for images in (raw_train, raw_test):
            # Zero-pad 2 pixels on each spatial side (28x28 -> 32x32).
            images = np.pad(images, ((0, 0), (2, 2), (2, 2)), 'constant')
            # Add a trailing channel axis, then copy it 3 times so the
            # grayscale images match a 3-channel input.
            images = np.expand_dims(images, axis=-1)
            prepared.append(np.repeat(images, 3, axis=-1))
        x_train, x_test = prepared
        num_classes = 10
    elif dataset_name in ('cifar10', 'cifar100'):
        is_cifar10 = dataset_name == 'cifar10'
        source = cifar10 if is_cifar10 else cifar100
        (x_train, y_train), (x_test, y_test) = source.load_data()
        num_classes = 10 if is_cifar10 else 100
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Scale pixel values by 1/255 as float32.
    x_train, x_test = (
        split.astype('float32') / 255.0 for split in (x_train, x_test)
    )

    # One-hot encode the integer class labels.
    y_train, y_test = (
        to_categorical(labels, num_classes) for labels in (y_train, y_test)
    )

    return (x_train, y_train), (x_test, y_test), num_classes

# Define DenseNet-121 architecture
def bn_rl_conv(x, filters, kernel_size):
    """Apply BatchNormalization, then ReLU, then a 'same'-padded Conv2D to x.

    Parameters:
    x -- input tensor
    filters -- number of output filters of the convolution
    kernel_size -- kernel size of the convolution

    Returns:
    The output tensor of the convolution.
    """
    normalized = BatchNormalization()(x)
    activated = ReLU()(normalized)
    return Conv2D(filters=filters, kernel_size=kernel_size, padding='same')(activated)

def dense_block(tensor, k, reps):
    """Grow `tensor` by `reps` densely connected bottleneck units.

    Each unit runs a 1x1 BN-ReLU-Conv with 4*k filters followed by a 3x3
    BN-ReLU-Conv with k filters, and its output is concatenated onto the
    running feature tensor.

    Parameters:
    tensor -- input feature tensor
    k -- growth rate (filters added per unit)
    reps -- number of units in the block

    Returns:
    The concatenated feature tensor after all units.
    """
    for _ in range(reps):
        bottleneck = bn_rl_conv(tensor, filters=4 * k, kernel_size=1)
        new_features = bn_rl_conv(bottleneck, filters=k, kernel_size=3)
        tensor = Concatenate()([tensor, new_features])
    return tensor

def transition_layer(x, theta):
    """Compress the channels of x by `theta` and halve its spatial size.

    Parameters:
    x -- input feature tensor
    theta -- compression factor applied to the input channel count

    Returns:
    The tensor after a 1x1 BN-ReLU-Conv and a stride-2 2x2 average pooling.
    """
    in_channels = tf.keras.backend.int_shape(x)[-1]
    compressed = bn_rl_conv(x, filters=int(in_channels * theta), kernel_size=1)
    return AvgPool2D(pool_size=2, strides=2, padding='same')(compressed)

def create_densenet121(input_shape, num_classes, k=32, theta=0.5):
    """Build a DenseNet-121 classifier as a Keras functional Model.

    Parameters:
    input_shape -- shape of one input image (without the batch dimension)
    num_classes -- number of output classes of the softmax layer
    k -- growth rate (default 32)
    theta -- channel compression factor of the transition layers (default 0.5)

    Returns:
    An uncompiled tf.keras Model mapping images to class probabilities.
    """
    # DenseNet-121 has repetitions [6, 12, 24, 16]
    repetitions = [6, 12, 24, 16]

    inputs = Input(shape=input_shape)
    x = Conv2D(2 * k, 7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = MaxPool2D(3, strides=2, padding='same')(x)

    last_block = len(repetitions) - 1
    for block_idx, reps in enumerate(repetitions):
        x = dense_block(x, k, reps)
        # Transition layers sit only BETWEEN dense blocks. Applying one after
        # the final block would compress and pool the features a fourth time,
        # which is not part of the DenseNet architecture.
        if block_idx < last_block:
            x = transition_layer(x, theta)

    # Final BN-ReLU before the classification head, as in the reference
    # DenseNet design.
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = GlobalAvgPool2D()(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs)

# Parameter Mask Pruning function
def _apply_weight_masks(model, layer_indices, masks):
    """Multiply the first weight array of each indexed layer by its mask, in place."""
    for layer_idx, mask in zip(layer_indices, masks):
        layer = model.layers[layer_idx]
        weights = layer.get_weights()
        weights[0] = weights[0] * mask  # only the kernel; other arrays are untouched
        layer.set_weights(weights)


def parameter_mask_pruning(model, prune_percentage):
    """
    Randomly select a percentage of weights across the entire network and set them to zero.
    These weights are reset to zero after every training batch by the attached callback.

    Only the first weight array (the kernel) of Conv2D and Dense layers is
    considered; biases and all other layer types are left untouched. The
    selection is uniform over the pooled kernel entries of all such layers and
    is drawn with np.random.choice, so it follows NumPy's global random state.

    Parameters:
    model -- The Keras model to prune
    prune_percentage -- Percentage of weights to prune (0-100)

    Returns:
    The same model object, pruned in place, with the enforcing callback
    attached as `model.mask_callback` (pass it to `model.fit(callbacks=...)`).

    Raises:
    ValueError -- if prune_percentage is outside the range 0-100
    """
    if not 0 <= prune_percentage <= 100:
        raise ValueError(
            f"prune_percentage must be between 0 and 100, got {prune_percentage}"
        )

    # Collect the kernels of all Conv2D/Dense layers (not biases) together
    # with the index of the layer they belong to.
    all_weights = []
    all_layer_indices = []
    for i, layer in enumerate(model.layers):
        if isinstance(layer, (Conv2D, Dense)) and len(layer.weights) > 0:
            all_weights.append(layer.get_weights()[0])
            all_layer_indices.append(i)

    # Count prunable parameters (kernel entries only)
    total_params = sum(w.size for w in all_weights)
    num_to_prune = int(total_params * prune_percentage / 100)

    print(f"Total trainable parameters: {total_params}")
    print(f"Parameters to prune: {num_to_prune} ({prune_percentage}%)")

    # Create masks for each layer (initially all ones)
    masks = [np.ones_like(w) for w in all_weights]

    if num_to_prune > 0:
        # Draw the pruned positions once over the concatenation of all kernels,
        # mark them in a single flat keep-vector, then slice that vector back
        # into one mask per layer.
        pruned_indices = np.random.choice(total_params, size=num_to_prune, replace=False)
        keep = np.ones(total_params, dtype=bool)
        keep[pruned_indices] = False

        offset = 0
        for i, weight in enumerate(all_weights):
            layer_keep = keep[offset:offset + weight.size]
            masks[i] = layer_keep.reshape(weight.shape).astype(weight.dtype)
            offset += weight.size

        # Zero the selected weights before training starts
        _apply_weight_masks(model, all_layer_indices, masks)

    # Custom callback that enforces the mask during training
    class MaskWeightsCallback(tf.keras.callbacks.Callback):
        def __init__(self, masks, layer_indices):
            super().__init__()
            self.masks = masks
            self.layer_indices = layer_indices

        def on_train_batch_end(self, batch, logs=None):
            # The optimizer step may have moved pruned weights away from zero;
            # re-apply the masks after each batch update.
            _apply_weight_masks(self.model, self.layer_indices, self.masks)

    # Attach the mask callback to the model for later use
    model.mask_callback = MaskWeightsCallback(masks, all_layer_indices)

    return model

# Google Sheets Logger class
class GoogleSheetsLogger(Callback):
    """Keras callback that logs per-batch metrics and IPA to a Google Sheet.

    After every training batch the model is evaluated on the full test set and
    one row (epoch, batch, train accuracy/CE, test accuracy/CE, |CE - CEo|,
    IPA) is appended to the worksheet. IPA is reported as "N/A" until the test
    CE first drops to ln(10) or below; from then on it is |CE - CEo| divided by
    the number of training samples seen since that batch (BATCH_SIZE_TRAIN per
    batch).

    `model.evaluate` is unpacked as (loss, accuracy, cross-entropy), so the
    model must be compiled with exactly those two metrics in that order.
    """

    def __init__(self, test_data, batch_size, spreadsheet=None, sheet_name=None, baseline_ce=None):
        """
        Initialize the Google Sheets Logger

        Parameters:
        test_data -- tuple of (x_test, y_test)
        batch_size -- batch size for evaluation
        spreadsheet -- existing Google spreadsheet object to use
        sheet_name -- name of the worksheet to create/use within the spreadsheet
        baseline_ce -- baseline cross-entropy from unpruned model (for IPA calculation)
        """
        super().__init__()
        self.test_x, self.test_y = test_data
        self.batch_size = batch_size
        self.epoch = 0
        self.batch_num = 0
        self.total_batches = 0
        self.baseline_ce = baseline_ce
        self.ln10 = math.log(10)  # Approx 2.302
        self.ipa_start_batch = None
        self.spreadsheet = spreadsheet

        # Fall back to a timestamped worksheet name if none is provided
        if sheet_name is None:
            sheet_name = f"Training_Log_{int(time.time())}"
        self.sheet_name = sheet_name

        # Column headers for the sheet
        self.headers = [
            'Epoch', 'Batch Number',
            'Train Accur', 'Train CE', 'Test Accur', 'Test CE',
            '|CE(b)-CEo|', 'IPA'
        ]

        # Authenticate and setup Google Sheets
        self._setup_google_sheets()

    def _setup_google_sheets(self):
        """Open (or create) the spreadsheet and prepare an empty worksheet with headers."""
        # Authenticate to Google only when no spreadsheet object was handed in
        if self.spreadsheet is None:
            auth.authenticate_user()

            # Get credentials (using google-auth instead of oauth2client)
            creds, _ = default()

            # Create a gspread client
            self.gc = gspread.authorize(creds)

            try:
                # Try to open existing spreadsheet
                self.sheet = self.gc.open("DenseNet121_Pruning_Experiments")
                print(f"Using existing spreadsheet: DenseNet121_Pruning_Experiments")
            except gspread.exceptions.SpreadsheetNotFound:
                # Create new spreadsheet if it doesn't exist
                self.sheet = self.gc.create("DenseNet121_Pruning_Experiments")
                print(f"Created new spreadsheet: DenseNet121_Pruning_Experiments")

                # Share the spreadsheet with the configured email. EMAIL
                # defaults to an empty string, which is not a valid share
                # target, so skip sharing in that case.
                if EMAIL:
                    self.sheet.share(EMAIL, perm_type='user', role='writer')
                    print(f"Shared spreadsheet with: {EMAIL}")

            # Store for future use
            self.spreadsheet = self.sheet
        else:
            # Use the provided spreadsheet
            self.sheet = self.spreadsheet

        # Get or create the worksheet for this run
        try:
            self.worksheet = self.sheet.worksheet(self.sheet_name)
            # Clear existing content
            self.worksheet.clear()
        except gspread.exceptions.WorksheetNotFound:
            self.worksheet = self.sheet.add_worksheet(title=self.sheet_name, rows=1000, cols=len(self.headers))

        # Add headers to the worksheet. Keyword arguments are used because the
        # positional order of update() differs between gspread versions.
        self.worksheet.update(range_name='A1', values=[self.headers])

        # Print the URL to access the spreadsheet
        print(f"Spreadsheet URL: {self.sheet.url}")

    def _current_ipa(self, ce_diff):
        """Return ce_diff per training sample since the IPA start batch, or None if IPA has not started."""
        if self.ipa_start_batch is None:
            return None
        # Batches counted from (and including) the one that hit the ln10 threshold
        adjusted_batches = self.total_batches - self.ipa_start_batch + 1
        tl = BATCH_SIZE_TRAIN * adjusted_batches
        return ce_diff / tl if tl > 0 else 0

    def on_train_begin(self, logs=None):
        """Compute the baseline CE from the current model if none was supplied."""
        if self.baseline_ce is None:
            print("Calculating baseline CE...")
            _, _, self.baseline_ce = self.model.evaluate(
                self.test_x, self.test_y,
                batch_size=self.batch_size,
                verbose=1
            )
            print(f"Baseline CE (CEo): {self.baseline_ce}")

    def on_epoch_begin(self, epoch, logs=None):
        """Record the 1-based epoch number and restart the per-epoch batch counter."""
        self.epoch = epoch + 1
        self.batch_num = 0 # remove this to remove the reset of the batch

    def on_train_batch_end(self, batch, logs=None):
        """Evaluate on the test set and append one metrics row to the worksheet."""
        logs = logs or {}
        self.batch_num += 1
        self.total_batches += 1

        # Calculate metrics on test data (silently)
        _, test_acc, test_ce = self.model.evaluate(
            self.test_x, self.test_y,
            batch_size=self.batch_size,
            verbose=0
        )

        # Calculate CE difference
        ce_diff = abs(test_ce - self.baseline_ce)

        # Determine if we should start IPA calculation
        if self.ipa_start_batch is None and test_ce <= self.ln10:
            self.ipa_start_batch = self.total_batches
            print(f"Started IPA calculation at batch {self.ipa_start_batch} (CE = {test_ce:.4f})")

        # IPA is a number once the threshold has been reached, "N/A" before
        ipa = self._current_ipa(ce_diff)
        ipa_value = "N/A" if ipa is None else ipa

        # Prepare row to append; float() keeps the training metrics
        # serializable even if the framework hands over non-builtin scalars.
        row = [
            self.epoch,
            self.batch_num,
            float(logs.get('accuracy', 0)),
            float(logs.get('loss', 0)),  # Train CE
            test_acc,
            test_ce,
            ce_diff,
            ipa_value
        ]

        # Append row to Google Sheet (without printing each batch). Logging is
        # best effort: a failed write is reported but must not abort training.
        try:
            self.worksheet.append_row(row)
        except Exception as e:
            print(f"Error writing to Google Sheet: {e}")

    def get_final_metrics(self):
        """Return a dict with 'Final Test Accuracy' and 'IPA' (scaled by 1000).

        The last worksheet row is used when IPA has started and the row holds
        a numeric IPA; otherwise the metrics are recomputed from the model.
        'IPA' is 0 if the IPA calculation never started.
        """
        if self.ipa_start_batch is not None:
            # Try to get the latest values from the Google Sheet. Any failure
            # (API error, short row, unparsable cell) falls through to the
            # recomputation below.
            try:
                values = self.worksheet.get_all_values()
                if len(values) > 1:  # If we have data (header + at least one row)
                    last_row = values[-1]
                    ipa_value = last_row[7]  # Index 7 is IPA
                    if ipa_value != "N/A":
                        return {
                            'Final Test Accuracy': float(last_row[4]),  # Index 4 is Test Accur
                            'IPA': float(ipa_value) * 1000  # Scale by 1000
                        }
            except Exception as e:
                print(f"Error getting final metrics from Google Sheet: {e}")

        # Fall back to calculating final metrics
        _, test_acc, test_ce = self.model.evaluate(
            self.test_x, self.test_y,
            batch_size=self.batch_size,
            verbose=0
        )

        ipa = self._current_ipa(abs(test_ce - self.baseline_ce))
        return {
            'Final Test Accuracy': test_acc,
            # Scale by 1000; 0 if IPA calculation never started
            'IPA': 0 if ipa is None else ipa * 1000
        }

# Function to run experiment for a dataset and pruning method
def _save_metric_plot(results_df, column, ylabel, title, path):
    """Plot `column` against 'Prune Percentage' and save the figure to `path`."""
    plt.figure(figsize=(10, 6))
    plt.plot(results_df['Prune Percentage'], results_df[column], 'o-')
    plt.xlabel('Pruning Percentage (P%)')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.savefig(path)


def run_experiment(dataset_name, pruning_method, epochs=5):
    """Train one pruned DenseNet-121 per entry of PRUNE_PERCENTAGES and record the results.

    For each pruning percentage a fresh model is built, pruned with
    parameter_mask_pruning, trained with SGD, and logged batch by batch to a
    worksheet named "P<percentage>". Final accuracy and IPA x 1000 are written
    to a summary worksheet, a CSV file and two PNG plots under
    "results_<dataset_name>_<pruning_method>/".

    Parameters:
    dataset_name -- dataset key understood by load_dataset
    pruning_method -- label used only in output names and messages
    epochs -- number of training epochs per pruning percentage (default 5)

    Returns:
    A pandas DataFrame with columns 'Prune Percentage', 'Final Test Accuracy'
    and 'IPA' (IPA already scaled by 1000).
    """
    results = []

    # Load dataset
    (x_train, y_train), (x_test, y_test), num_classes = load_dataset(dataset_name)

    # Create output directory
    output_dir = f"results_{dataset_name}_{pruning_method}"
    os.makedirs(output_dir, exist_ok=True)

    # First, get baseline CE from an unpruned, untrained model
    print("\nCreating baseline model to calculate baseline CE...")
    baseline_model = create_densenet121(input_shape=x_train.shape[1:], num_classes=num_classes)

    # Compile baseline model
    optimizer = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE)
    baseline_model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy', 'categorical_crossentropy']
    )

    # Calculate baseline CE
    _, _, baseline_ce = baseline_model.evaluate(
        x_test, y_test,
        batch_size=BATCH_SIZE_TEST,
        verbose=1
    )
    print(f"Baseline CE (CEo): {baseline_ce}")

    # Clear memory; drop our own reference so the baseline model can be freed
    del baseline_model
    tf.keras.backend.clear_session()

    # Create a summary worksheet for overall results
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)

    try:
        # Try to open existing spreadsheet
        sheet = gc.open("DenseNet121_Pruning_Experiments")
    except gspread.exceptions.SpreadsheetNotFound:
        # Create new spreadsheet if it doesn't exist
        sheet = gc.create("DenseNet121_Pruning_Experiments")
        # EMAIL defaults to an empty string, which is not a valid share target
        if EMAIL:
            sheet.share(EMAIL, perm_type='user', role='writer')

    # Create or update summary worksheet
    summary_name = f"{dataset_name}_{pruning_method}_summary"
    try:
        summary_sheet = sheet.worksheet(summary_name)
        summary_sheet.clear()
    except gspread.exceptions.WorksheetNotFound:
        summary_sheet = sheet.add_worksheet(title=summary_name, rows=100, cols=3)

    # Add headers to summary sheet. Keyword arguments are used because the
    # positional order of update() differs between gspread versions.
    summary_sheet.update(
        range_name='A1',
        values=[['Prune Percentage', 'Final Test Accuracy', 'IPA × 1000']]
    )

    # Save spreadsheet object for reuse
    spreadsheet = sheet

    for prune_percentage in PRUNE_PERCENTAGES:
        print(f"\nRunning experiment: {dataset_name}, {pruning_method}, P% = {prune_percentage}")

        # Create model
        model = create_densenet121(input_shape=x_train.shape[1:], num_classes=num_classes)

        # Count total parameters before pruning
        total_params = model.count_params()
        print(f"Total parameters before pruning: {total_params:,}")

        # Apply parameter mask pruning
        model = parameter_mask_pruning(model, prune_percentage)
        print(f"Applied parameter mask pruning: P% = {prune_percentage}%")

        # Store the mask callback for use during training
        mask_callback = model.mask_callback

        # Compile model
        optimizer = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE)
        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy', 'categorical_crossentropy']
        )

        # Setup Google Sheets logger
        sheet_name = f"P{prune_percentage}"
        google_sheet_logger = GoogleSheetsLogger(
            test_data=(x_test, y_test),
            batch_size=BATCH_SIZE_TEST,
            spreadsheet=spreadsheet,
            sheet_name=sheet_name,
            baseline_ce=baseline_ce
        )

        # Callbacks run in list order. The mask callback must come first so
        # that the pruned weights are zeroed again BEFORE the logger evaluates
        # the model on the test set after each batch.
        callbacks = [mask_callback, google_sheet_logger]

        print(f"Training with batch size: {BATCH_SIZE_TRAIN}")
        print(f"Validation batch size: {BATCH_SIZE_TEST}")
        print(f"Logging to Google Sheet: {sheet_name}")

        # Configure TensorFlow to be less verbose during training
        # NOTE(review): TensorFlow is already imported at this point, so this
        # variable may be read too late to take effect — confirm.
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

        model.fit(
            x_train, y_train,
            batch_size=BATCH_SIZE_TRAIN,
            validation_batch_size=BATCH_SIZE_TEST,
            epochs=epochs,
            validation_data=(x_test, y_test),
            callbacks=callbacks,
            verbose=2  # Less verbose output
        )

        # Get final metrics (includes IPA * 1000)
        final_metrics = google_sheet_logger.get_final_metrics()

        # Add results to summary sheet
        summary_sheet.append_row([
            prune_percentage,
            final_metrics['Final Test Accuracy'],
            final_metrics['IPA']
        ])

        print(f"Pruning {prune_percentage}% - Test Accuracy: {final_metrics['Final Test Accuracy']:.4f}, IPA × 1000: {final_metrics['IPA']:.6f}")

        results.append({
            'Prune Percentage': prune_percentage,
            'Final Test Accuracy': final_metrics['Final Test Accuracy'],
            'IPA': final_metrics['IPA']
        })

        # Clear memory between runs
        tf.keras.backend.clear_session()

    # Create the DataFrame from the results list; explicit columns keep the
    # frame well-formed even when PRUNE_PERCENTAGES is empty.
    results_df = pd.DataFrame(
        results,
        columns=['Prune Percentage', 'Final Test Accuracy', 'IPA']
    )

    # Save results to CSV
    results_df.to_csv(os.path.join(output_dir, "experiment_results.csv"), index=False)

    # Plot IPA vs Pruning Percentage
    _save_metric_plot(
        results_df, 'IPA', 'IPA × 1000',
        f'IPA vs P% for {dataset_name} using {pruning_method} pruning',
        os.path.join(output_dir, "ipa_vs_prune.png")
    )

    # Also plot accuracy vs Pruning Percentage
    _save_metric_plot(
        results_df, 'Final Test Accuracy', 'Final Test Accuracy',
        f'Accuracy vs P% for {dataset_name} using {pruning_method} pruning',
        os.path.join(output_dir, "accuracy_vs_prune.png")
    )

    # Close all plots
    plt.close('all')

    return results_df

# Main execution
if __name__ == "__main__":
    # Experiment selection: MNIST only, using the parameter-mask pruning method.
    dataset, pruning_method = 'mnist', 'parameter_mask'

    # Echo the run configuration before any work starts.
    for banner_line in (
        f"Running DenseNet-121 pruning experiments on {dataset} dataset",
        f"Training batch size: {BATCH_SIZE_TRAIN}, Testing batch size: {BATCH_SIZE_TEST}",
        f"Pruning percentages: {PRUNE_PERCENTAGES}",
        f"Results will be shared with: {EMAIL}",
    ):
        print(banner_line)

    # Run the experiment; the returned DataFrame is kept in `results`.
    print("\nStarting Parameter Mask Pruning experiment...")
    results = run_experiment(dataset, pruning_method)

    # Tell the user where the CSV and the plots were written.
    for closing_line in (
        "\nExperiment complete!",
        f"Results saved to results_{dataset}_{pruning_method}/experiment_results.csv",
        f"Plots saved to results_{dataset}_{pruning_method}/",
    ):
        print(closing_line)

Running DenseNet-121 pruning experiments on mnist dataset
Training batch size: 5000, Testing batch size: 5000
Pruning percentages: [90]
Results will be shared with: lenishpandey02@gmail.com

Starting Parameter Mask Pruning experiment...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step

Creating baseline model to calculate baseline CE...
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 259ms/step - accuracy: 0.0908 - categorical_crossentropy: 2.3023 - loss: 2.3023
Baseline CE (CEo): 2.302340030670166


  summary_sheet.update('A1', [['Prune Percentage', 'Final Test Accuracy', 'IPA × 1000']])



Running experiment: mnist, parameter_mask, P% = 90
Total parameters before pruning: 7,577,674
Total trainable parameters: 7399616
Parameters to prune: 6659654 (90%)
Applied parameter mask pruning: P% = 90%


  self.worksheet.update('A1', [self.headers])


Spreadsheet URL: https://docs.google.com/spreadsheets/d/1LCi-L80tAcu2qfyrfDGcJzzbEkmec912eQBfn_uvaKY
Training with batch size: 5000
Validation batch size: 5000
Logging to Google Sheet: P90
Epoch 1/5
Started IPA calculation at batch 1 (CE = 2.3026)
12/12 - 224s - 19s/step - accuracy: 0.1135 - categorical_crossentropy: 2.3022 - loss: 2.3022 - val_accuracy: 0.1135 - val_categorical_crossentropy: 2.3022 - val_loss: 2.3022
Epoch 2/5
12/12 - 34s - 3s/step - accuracy: 0.1135 - categorical_crossentropy: 2.3024 - loss: 2.3024 - val_accuracy: 0.1135 - val_categorical_crossentropy: 2.3024 - val_loss: 2.3024
Epoch 3/5
12/12 - 37s - 3s/step - accuracy: 0.1009 - categorical_crossentropy: 2.3034 - loss: 2.3034 - val_accuracy: 0.0982 - val_categorical_crossentropy: 2.3033 - val_loss: 2.3033
Epoch 4/5
12/12 - 40s - 3s/step - accuracy: 0.1009 - categorical_crossentropy: 2.3047 - loss: 2.3047 - val_accuracy: 0.1009 - val_categorical_crossentropy: 2.3046 - val_loss: 2.3046
Epoch 5/5
12/12 - 40s - 3s/step 