In [None]:
import os
from pickle import load
import pandas as pd
from random import sample, choice, seed
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.nn.utils import prune

from sklearn.metrics import f1_score, precision_score, recall_score

from pathlib import Path

def find_repo_root(marker="setup.py"):
    """Locate the project root by walking upward from the current working directory.

    Returns the first directory (starting with the CWD itself) that contains a
    file/folder named ``marker``. If no ancestor contains it, the filesystem
    root is returned.
    """
    start = Path.cwd()
    chain = [start, *start.parents]
    return next((candidate for candidate in chain if (candidate / marker).exists()), chain[-1])

project_root = find_repo_root()

# Paths. All folders are pathlib.Path objects: join them with `/` (or
# os.path.join), never with `+` — `Path + str` raises TypeError.
INPUT_FOLDER = project_root / "data/cnn/input/"
MODEL_PATH = INPUT_FOLDER / "model.pth"
OUTPUT_FOLDER = project_root / "data/cnn/output/"
RESULTS_FOLDER = project_root / "data/cnn/results/"
FIGURES_FOLDER = project_root / "data/figures/cnn/"

In [None]:
# MultiPEC results

# Target class specificity per class, taken from the MultiPEC run.
# There is no entry for "6 - six". A random subnet counts as a success for a class
# once its specificity for that class is >= the value listed here.
class_best_maxmean_map = dict([
    ("0 - zero", 0),
    ("1 - one", 63.327534),
    ("2 - two", 16.665190),
    ("3 - three", 1.334984),
    ("4 - four", 22.349010),
    ("5 - five", 8.334732),
    ("7 - seven", 4.728566),
    ("8 - eight", 54.599807),
    ("9 - nine", 49.677487),
])

In [None]:
# Set random seeds. SEED is defined here so the cell runs on a fresh kernel, and
# the same value seeds both torch and Python's `random` (which draws the random
# subnets further down), instead of `random` using a separate literal.
SEED = 42
torch.backends.cudnn.enabled = False
torch.manual_seed(SEED)
seed(SEED)

# Normalisation constants — presumably the MNIST training-set pixel mean/std.
transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

# NOTE(review): '/files/' is an absolute, root-level (Colab-style) directory and may
# not be writable on other machines — consider a folder under project_root.
trainset = MNIST('/files/', train=True, download=True, transform=transform)
testset = MNIST('/files/', train=False, download=True, transform=transform)

# Class index -> class name (the names are used as keys throughout the notebook).
IDX_TO_LABEL = {v: k for k, v in testset.class_to_idx.items()}
n_classes = len(IDX_TO_LABEL)

# NOTE(review): `device`, `BATCH_SIZE_TRAIN` and `BATCH_SIZE_TEST` are used by the
# code below but are not defined anywhere in this notebook; define them in a
# config cell so Restart & Run All works.

# number of subprocesses to use for data loading
NUM_WORKERS = 2


def data_loaders(trainset, testset, trainsize, testsize):
    """Build the ``(train_loader, test_loader)`` pair.

    The train loader iterates ``trainset`` in order with batches of ``trainsize``;
    the test loader shuffles ``testset`` with batches of ``testsize``. Both use
    ``NUM_WORKERS`` worker processes.
    """
    shared_opts = {"num_workers": NUM_WORKERS}
    train_loader = DataLoader(trainset, batch_size=trainsize, **shared_opts)
    test_loader = DataLoader(testset, batch_size=testsize, shuffle=True, **shared_opts)
    return train_loader, test_loader

def test_accuracy(model, testloader):
    """Evaluate ``model`` on ``testloader``.

    Returns:
        accuracy: 0-d tensor, number of correct predictions divided by
            ``len(testloader.dataset)``.
        f1, precision, recall: per-label arrays from sklearn (``average=None``,
            ``zero_division=0``).

    Uses the module-level ``device`` to place each batch.
    """
    correct = 0
    label_batches, pred_batches = [], []
    model.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)

            # calculate outputs by running images through the model
            outputs = model(images)

            # the class with the highest score is what we choose as prediction
            predicted = outputs.max(1)[1]
            correct += predicted.eq(labels).sum()

            # Keep whole batches (moved to the CPU once per batch). Extending a Python
            # list with a tensor splits it into one tensor per sample, which then had
            # to be converted back element by element.
            label_batches.append(labels.cpu())
            pred_batches.append(predicted.cpu())

    accuracy = correct / len(testloader.dataset)

    y_true = torch.cat(label_batches).numpy()
    y_pred = torch.cat(pred_batches).numpy()
    metric_kwargs = {"average": None, "zero_division": 0}
    f1 = f1_score(y_true, y_pred, **metric_kwargs)
    precision = precision_score(y_true, y_pred, **metric_kwargs)
    recall = recall_score(y_true, y_pred, **metric_kwargs)

    return accuracy, f1, precision, recall


def test_accuracy_per_class(model, testloader):
    """Return ``{class name: accuracy in percent}`` for ``model`` on ``testloader``.

    Class names and their order come from the module-level ``trainset.classes``;
    a label ``i`` is attributed to ``trainset.classes[i]``. A class with no
    samples in the loader gets 0.0 instead of raising ZeroDivisionError.
    Uses the module-level ``device`` to place each batch.
    """
    classnames = trainset.classes
    num_classes = len(classnames)
    correct_counts = torch.zeros(num_classes, dtype=torch.long)
    total_counts = torch.zeros(num_classes, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images.to(device))
            predicted = outputs.max(1)[1].cpu()
            labels = labels.cpu()

            # Count per class with one vectorised call per batch instead of a
            # per-sample Python loop (which synced with the device on every item).
            total_counts += torch.bincount(labels, minlength=num_classes)
            correct_counts += torch.bincount(labels[predicted == labels], minlength=num_classes)

    accuracy_per_class = {}
    for idx, classname in enumerate(classnames):
        total = int(total_counts[idx])
        correct = int(correct_counts[idx])
        accuracy_per_class[classname] = (100 * float(correct)) / total if total else 0.0

    return accuracy_per_class

class Net(nn.Module):
    """Two-convolution CNN for single-channel 28x28 digit images.

    conv(1->10, 5x5) -> maxpool 2 -> ReLU -> conv(10->20, 5x5) -> Dropout2d
    -> maxpool 2 -> ReLU -> flatten(320) -> Linear(320->50) -> ReLU -> dropout
    -> Linear(50->10) -> log_softmax. Returns per-class log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (attribute names must stay stable: they are state_dict keys)
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Classifier head
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))
        # 20 channels * 4 * 4 spatial positions = 320 features per sample
        flat = stage2.view(-1, 320)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=-1)


# Load unpruned model once and get baseline class accuracy
_, testloader = data_loaders(trainset, testset, BATCH_SIZE_TRAIN, BATCH_SIZE_TEST)
model = Net().to(device)
# map_location lets a checkpoint saved on a different device (e.g. a GPU) load onto `device`.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
baseline_class_accuracies = test_accuracy_per_class(model, testloader)


def compute_maxmean_class_specificity(deltas: dict, class_name: str):
    """Max-minus-mean spread of the absolute deltas of every class except ``class_name``.

    Args:
        deltas: mapping class name -> accuracy change (pruned - baseline).
        class_name: the class to leave out of the computation.

    Returns:
        ``max(|d|) - mean(|d|)`` over the remaining classes, or 0.0 when no
        other class is present.

    NOTE(review): ``class_name``'s own delta is excluded rather than compared
    against the others. Confirm this matches the MultiPEC "MaxMean" definition
    behind ``class_best_maxmean_map``, otherwise the targets are not comparable.
    """
    magnitudes = [abs(delta) for name, delta in deltas.items() if name != class_name]
    if len(magnitudes) == 0:
        return 0.0
    return max(magnitudes) - sum(magnitudes) / len(magnitudes)


# Track which classes have matched their target specificity
class_success_flags = {cls: False for cls in class_best_maxmean_map}
class_success_iters = {cls: None for cls in class_best_maxmean_map}

# Track all iterations (also the number of random subnets sampled below)
MAX_GLOBAL_ITERATIONS = 1000
detailed_results = []

# Load the reference subnets; here only their sizes are used (len of each entry's
# first element) to bound the sizes of the random subnets.
# NOTE: pickle.load can execute arbitrary code — only load nets.p from a trusted source.
# os.path.join works whether OUTPUT_FOLDER is a Path or a str (`Path + str` raises).
with open(os.path.join(OUTPUT_FOLDER, "nets.p"), "rb") as nets_file:
    load_subnets = load(nets_file)

group_sizes = [len(entry[0]) for entry in load_subnets]
min_group_size = min(group_sizes)
max_group_size = max(group_sizes)

# Inclusive upper bound: range(min, max) never sampled the largest size and was
# empty (so choice() raised) when every reference subnet had the same size.
size_options = list(range(min_group_size, max_group_size + 1))
print(size_options)

# 30 prunable kernels: 0-9 are conv1 output channels, 10-29 are conv2 output channels.
nodes = list(range(30))

rd_runs = MAX_GLOBAL_ITERATIONS
rd_subnets = [sample(nodes, choice(size_options)) for _ in range(rd_runs)]

print("Net: MNIST model (2 layers)")
print("Number of subnets:", len(rd_subnets))

# Loop-invariant work, done once: the test loader and the unpruned checkpoint.
_, testloader = data_loaders(trainset, testset, BATCH_SIZE_TRAIN, BATCH_SIZE_TEST)
base_state_dict = torch.load(MODEL_PATH, map_location=device)

for iteration_idx, subnet in enumerate(rd_subnets):
    # `>=`: indices run 0..N-1, so N evaluations is the cap (`>` allowed N+1).
    if iteration_idx >= MAX_GLOBAL_ITERATIONS:
        print("Reached global iteration limit.")
        break
    if all(class_success_flags.values()):
        print("All class specificities reached!")
        break

    print(f"\nIteration {iteration_idx}")
    torch.cuda.empty_cache()

    # Fresh copy of the unpruned model for every subnet (load_state_dict copies the
    # tensors, so base_state_dict itself is never modified by pruning).
    model = Net().to(device)
    model.load_state_dict(base_state_dict)

    if subnet:
        # Node ids below conv1's channel count are conv1 kernels; the rest index conv2.
        n_conv1 = model.conv1.out_channels
        conv1_part = [node for node in subnet if node < n_conv1]
        conv2_part = [node - n_conv1 for node in subnet if node >= n_conv1]

        # Zero the whole weight slice of each selected output channel via a fixed mask.
        mask1 = torch.ones_like(model.conv1.weight)
        for kernel in conv1_part:
            mask1[kernel] = 0
        prune.custom_from_mask(model.conv1, 'weight', mask1)

        mask2 = torch.ones_like(model.conv2.weight)
        for kernel in conv2_part:
            mask2[kernel] = 0
        prune.custom_from_mask(model.conv2, 'weight', mask2)

    # Evaluate. Only per-class accuracy feeds the specificity metric; the overall
    # accuracy/F1/precision/recall pass was computed and then discarded, so it is skipped.
    pruned_class_accuracies = test_accuracy_per_class(model, testloader)

    # Accuracy change (pruned - baseline, in percentage points) for every class
    deltas = {
        cname: pruned_class_accuracies.get(cname, 0) - baseline_class_accuracies.get(cname, 0)
        for cname in IDX_TO_LABEL.values()
    }

    # Specificity score for each class that has a MultiPEC target
    specificity_per_class = {
        cname: compute_maxmean_class_specificity(deltas, cname)
        for cname in class_best_maxmean_map
    }

    # The highest-scoring class is treated as the one this subnet targets
    specific_class = max(specificity_per_class, key=specificity_per_class.get)
    spec_value = specificity_per_class[specific_class]
    target_spec = class_best_maxmean_map[specific_class]
    # Only the first subnet to reach a class's target counts as that class's success
    is_success = spec_value >= target_spec and not class_success_flags[specific_class]

    # Mark class as matched if applicable
    if is_success:
        class_success_flags[specific_class] = True
        class_success_iters[specific_class] = iteration_idx
        print(f"Matched class specificity for '{specific_class}' at iteration {iteration_idx} ({spec_value:.4f})")

    # Record the iteration result
    detailed_results.append({
        "Iteration": iteration_idx,
        "Subnet": subnet,
        "Specific_Class": specific_class,
        "Specificity": spec_value,
        "Target_Specificity": target_spec,
        "Successful": is_success
    })


 # Save full iteration log
# The Excel writers below need the destination folder to exist.
os.makedirs(RESULTS_FOLDER, exist_ok=True)
detailed_df = pd.DataFrame(detailed_results)
detailed_df.to_excel(os.path.join(RESULTS_FOLDER, "global_random_pruning_log.xlsx"), index=False)

# Summary: whether, and at which iteration, each class reached its MultiPEC target
summary_df = pd.DataFrame([
    {"Class": cls, "Reached": success, "Iteration": class_success_iters[cls]}
    for cls, success in class_success_flags.items()
])
summary_df.to_excel(os.path.join(RESULTS_FOLDER, "class_specificity_summary.xlsx"), index=False)

print("\nSummary:")
print(summary_df)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
[4, 5, 6, 7]
Net: MNIST model (2 layers)
Number of subnets: 501

=== Iteration 1 ===

=== Iteration 2 ===
Matched class specificity for '2 - two' at iteration 2 (4.5239)

=== Iteration 3 ===

=== Iteration 4 ===
Matched class specificity for '5 - five' at iteration 4 (3.9343)

=== Iteration 5 ===

=== Iteration 6 ===

=== Iteration 7 ===

=== Iteration 8 ===

=== Iteration 9 ===

=== Iteration 10 ===
Matched class specificity for '4 - four' at iteration 10 (4.9573)

=== Iteration 11 ===

=== Iteration 12 ===

=== Iteration 13 ===

=== Iteration 14 ===

=== Iteration 15 ===

=== Iteration 16 ===

=== Iteration 17 ===

=== Iteration 18 ===

=== Iteration 19 ===

=== Iteration 20 ===

=== Iteration 21 ===

=== Iteration 22 ===

=== Iteration 23 ===

=== Iteration 24 ===

=== Iteration 25 ===

=== Iteration 26 ===

=== Iteration 27 ===

=== Iteration 28 ===

==

In [None]:
# MultiPEC Result
multipec_data = [
    {"Class": "0 - zero",  "MultiPEC Net": (), "MaxMean": 0.0},
    {"Class": "1 - one",   "MultiPEC Net": (20, 8, 9, 6), "MaxMean": 63.327534},
    {"Class": "2 - two",   "MultiPEC Net": (7, 6, 9, 4, 24, 5, 0), "MaxMean": 16.665190},
    {"Class": "3 - three", "MultiPEC Net": (0, 12, 21), "MaxMean": 1.334984},
    {"Class": "4 - four",  "MultiPEC Net": (6, 5, 9, 2), "MaxMean": 22.349010},
    {"Class": "5 - five",  "MultiPEC Net": (0, 9, 1, 18), "MaxMean": 8.334732},
    {"Class": "7 - seven", "MultiPEC Net": (12, 24, 2, 25, 18, 3, 7, 11, 6, 9), "MaxMean": 4.728566},
    {"Class": "8 - eight", "MultiPEC Net": (4, 2, 9, 6, 3, 5), "MaxMean": 54.599807},
    {"Class": "9 - nine",  "MultiPEC Net": (0, 5, 9, 29, 4, 7, 8, 2, 3), "MaxMean": 49.677487}
]

# MultiPEC reference table: node count, subnet as text, CS formatted to 2 decimals.
multipec_df = (
    pd.DataFrame(multipec_data)
    .astype({"Class": str})
    .rename(columns={"MaxMean": "MultiPEC CS"})
)
multipec_df["N(Nodes)"] = multipec_df["MultiPEC Net"].map(len)
multipec_df["MultiPEC Net"] = multipec_df["MultiPEC Net"].map(str)
multipec_df["MultiPEC CS"] = multipec_df["MultiPEC CS"].map("{:.2f}".format)

# Reload the random-pruning log and keep only the successful rows
detailed_df = pd.read_excel(os.path.join(RESULTS_FOLDER, "global_random_pruning_log.xlsx"))
detailed_df["Successful"] = detailed_df["Successful"].astype(str).str.upper().eq("TRUE")

random_best = (
    detailed_df.loc[detailed_df["Successful"]]
    .assign(Class=lambda frame: frame["Specific_Class"].astype(str))
    .rename(columns={"Subnet": "Random Net", "Specificity": "Random CS"})
)
# Subnets are expected to come back from Excel as list text ("[a, b]"); show them tuple-style
random_best["Random Net"] = random_best["Random Net"].map(lambda text: "({})".format(text[1:-1]))
random_best["Random CS"] = random_best["Random CS"].map("{:.2f}".format)

comparison_df = pd.merge(
    multipec_df,
    random_best,
    on="Class",
    how="inner"
)

import matplotlib.pyplot as plt
import textwrap

# Internal bookkeeping columns are not shown in the comparison table
table_df = comparison_df.drop(["Specific_Class", "Target_Specificity", "Successful"], axis=1)

# Wrap long strings in subnet columns
def wrap_text(s, width=25):
    """Stringify ``s`` and break it into lines of at most ``width`` characters, newline-joined."""
    wrapped_lines = textwrap.wrap(str(s), width=width)
    return "\n".join(wrapped_lines)

if table_df.empty:
    # No random subnet reached any class target: ax.table cannot lay out an empty
    # cellText and the figure height below would be 0, so report instead of crashing.
    print("No successful random-pruning subnets to compare with MultiPEC.")
else:
    # Wrap long strings in subnet columns so the table stays readable
    table_df["MultiPEC Net"] = table_df["MultiPEC Net"].map(wrap_text)
    table_df["Random Net"] = table_df["Random Net"].map(wrap_text)

    # Plot table
    fig, ax = plt.subplots(figsize=(22, len(table_df) * 0.8))
    ax.axis('off')

    table = ax.table(
        cellText=table_df.values,
        colLabels=table_df.columns,
        cellLoc='center',
        loc='center'
    )

    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 1.5)

    # Row 0 holds the column labels: style it as a header
    for (row, _col), cell in table.get_celld().items():
        if row == 0:
            cell.set_fontsize(11)
            cell.set_text_props(weight='bold')
            cell.set_facecolor('#cccccc')

    ax.set_title("MultiPEC vs Random Pruning Class Specificity", fontsize=14, weight='bold')
    fig.tight_layout()
    plt.show()


In [None]:
import matplotlib.lines as mlines

# Convert CS back to float (the previous cell formatted it as a 2-decimal string)
multipec_df["MultiPEC CS"] = multipec_df["MultiPEC CS"].astype(float)

# Desired class order 0-9; keep only classes that actually appear in the pruning log
ordered_classes = [
    "0 - zero", "1 - one", "2 - two", "3 - three", "4 - four",
    "5 - five", "6 - six", "7 - seven", "8 - eight", "9 - nine"
]
logged_classes = set(detailed_df['Specific_Class'].unique())
class_order = [cls for cls in ordered_classes if cls in logged_classes]

fig, ax = plt.subplots(figsize=(8, 7))

# One light-gray box per class with a black median line
per_class_specificity = [
    detailed_df.loc[detailed_df['Specific_Class'] == cls, 'Specificity'] for cls in class_order
]
box = ax.boxplot(per_class_specificity, tick_labels=class_order, patch_artist=True)
for median_line, box_body in zip(box['medians'], box['boxes']):
    median_line.set_color('black')
    box_body.set_facecolor('#d3d3d3')  # light gray

# Overlay the MultiPEC CS as a red 'X' (boxes sit at x = 1..n)
for position, cls in enumerate(class_order, start=1):
    match = multipec_df[multipec_df['Class'] == cls]
    if not match.empty:
        ax.scatter(position, match["MultiPEC CS"].values[0],
                   color='red', marker='x', s=100, linewidths=3, zorder=10)

# Legend needs a proxy artist for the red 'X'
red_x = mlines.Line2D([], [], color='red', marker='x', linestyle='None',
                      markersize=12, label='MultiPEC')
ax.legend(handles=[red_x], fontsize=14)

ax.set_xlabel('Class', fontsize=16)
ax.set_ylabel('Class Specificity', fontsize=16)
ax.set_title('Random Pruning: Class Specificity', fontsize=18)
ax.tick_params(axis='x', labelrotation=45, labelsize=14)
ax.tick_params(axis='y', labelsize=14)
ax.grid(True, linestyle="--", alpha=0.6)
fig.tight_layout()
fig.savefig(os.path.join(RESULTS_FOLDER, "Random_pruning_specificity.png"), bbox_inches='tight', dpi=300)
plt.show()
