In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
import numpy as np
import matplotlib.pyplot as plt
import random


# ----------------------------
# ConvNet definition
# ----------------------------
class ConvNet(nn.Module):
    """Three-conv / three-linear classifier with continual-backprop (CBP) bookkeeping.

    The convolutional stack (5x5, 3x3, 3x3 kernels without padding, each
    followed by 2x2 max pooling) reduces a 3x32x32 image to 128 feature maps
    of size 2x2, which feed two 128-unit hidden layers and the output layer.

    For the hidden layers ``fc1`` and ``fc2`` the model tracks a running
    utility (exponential average of the mean absolute activation) and an age
    (number of utility updates since the unit was last reset). These are kept
    as non-persistent buffers so they follow ``.to(device)`` without changing
    the ``state_dict`` layout, and are exposed through the ``utility`` and
    ``age`` dict-like properties.

    Args:
        num_classes: Size of the output layer.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.last_filter_output = 2 * 2
        self.num_conv_outputs = 128 * self.last_filter_output
        self.fc1 = nn.Linear(self.num_conv_outputs, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.pool = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU()
        # Kept for callers that read it; the CBP state below no longer depends
        # on this guess because buffers move together with the module.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # persistent=False keeps state_dict keys identical to a model that
        # stores this state outside of the module's buffers.
        for name, layer in (("fc1", self.fc1), ("fc2", self.fc2)):
            self.register_buffer(f"utility_{name}", torch.zeros(layer.out_features), persistent=False)
            self.register_buffer(f"age_{name}", torch.zeros(layer.out_features), persistent=False)

    @property
    def utility(self):
        """Per-unit running utility tensors, keyed by layer name."""
        return {"fc1": self.utility_fc1, "fc2": self.utility_fc2}

    @property
    def age(self):
        """Per-unit update counters since the last reset, keyed by layer name."""
        return {"fc1": self.age_fc1, "fc2": self.age_fc2}

    def forward(self, x):
        """Return ``(logits, h1, h2)`` where h1/h2 are the post-ReLU hidden activations."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        # Flatten per sample: an unexpected input size then fails loudly in
        # fc1 instead of being silently folded into the batch dimension.
        x = torch.flatten(x, 1)
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        out = self.fc3(h2)
        return out, h1, h2

    def cbp_update_util(self, h1, h2):
        """Fold one batch of hidden activations into the running utilities.

        Args:
            h1: Activations of ``fc1``; averaged over dim 0.
            h2: Activations of ``fc2``; averaged over dim 0.
        """
        with torch.no_grad():
            for name, hidden in (("fc1", h1), ("fc2", h2)):
                util = self.utility[name]
                batch_util = torch.abs(hidden).mean(dim=0)
                util.copy_(0.99 * util + 0.01 * batch_util)
                self.age[name].add_(1)

    def cbp_reset(self, mature_age=500, reset_fraction=0.01):
        """Re-initialise the lowest-utility mature units of ``fc1`` and ``fc2``.

        A unit is eligible once its age exceeds ``mature_age``. Among the
        eligible units of a layer, ``max(1, int(reset_fraction * n_eligible))``
        units with the smallest utility are reset: incoming weights are drawn
        from U(-0.05, 0.05), the bias is zeroed, the outgoing weights in the
        following layer are zeroed, and utility and age restart at zero.

        Args:
            mature_age: Minimum age (exclusive) before a unit may be reset.
            reset_fraction: Fraction of eligible units to reset per call.
        """
        # Handle fc2 before fc1 so that zeroing the outgoing weights of a
        # reset fc1 unit is not undone by re-initialising an fc2 row.
        stages = (("fc2", self.fc2, self.fc3), ("fc1", self.fc1, self.fc2))
        with torch.no_grad():
            for name, layer, next_layer in stages:
                util = self.utility[name]
                age = self.age[name]
                eligible = (age > mature_age).nonzero().flatten()
                n_eligible = eligible.numel()
                if n_eligible == 0:
                    continue
                num_reset = max(1, int(reset_fraction * n_eligible))
                lowest = torch.topk(util[eligible], num_reset, largest=False).indices
                idx = eligible[lowest]

                fresh = layer.weight.new_empty((idx.numel(), layer.weight.size(1)))
                layer.weight[idx] = fresh.uniform_(-0.05, 0.05)
                layer.bias[idx] = 0.0
                # A freshly reset unit must not disturb what the network has
                # already learned, so it starts with no outgoing influence.
                next_layer.weight[:, idx] = 0.0
                util[idx] = 0.0
                age[idx] = 0.0


# ----------------------------
# Utility functions
# ----------------------------
def get_loader_for_classes(classes, train=True, batch_size=64):
    """Build a CIFAR-100 DataLoader restricted to the given class labels.

    Args:
        classes: Iterable of integer class labels to keep.
        train: Selects the training split when True, the test split otherwise;
            also controls whether the loader shuffles.
        batch_size: Number of samples per batch.

    Returns:
        A DataLoader over the subset of samples whose label is in ``classes``.
        Labels are the original CIFAR-100 labels; they are not remapped.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
    ])
    dataset = torchvision.datasets.CIFAR100(root='./data', train=train, download=True, transform=transform)
    # Read labels straight from the dataset's label list: enumerating the
    # dataset itself would decode and transform every image just to look at
    # its label.
    wanted = {int(c) for c in classes}
    indices = [i for i, label in enumerate(dataset.targets) if label in wanted]
    subset = Subset(dataset, indices)
    return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=train)


def shrink_and_perturb(model, shrink=0.985, noise=1e-3):
    """Scale every parameter by ``shrink`` and add Gaussian noise in place.

    Args:
        model: Module whose parameters are modified.
        shrink: Multiplicative factor applied to each parameter.
        noise: Scale of the standard-normal noise added to each parameter.
    """
    with torch.no_grad():
        for param in model.parameters():
            jitter = torch.randn_like(param) * noise
            param.mul_(shrink)
            param.add_(jitter)


def train(model, loader, optimizer, criterion, device, method='vanilla', epochs=200):
    """Train ``model`` on ``loader`` for a fixed number of epochs.

    Args:
        model: Module whose forward pass returns ``(out, h1, h2)``.
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss applied to ``(out, targets)``.
        device: Device the batches are moved to.
        method: Training variant. ``'l2'`` adds a 1e-4 weighted squared-norm
            penalty to the loss, ``'shrink_perturb'`` calls
            ``shrink_and_perturb`` after every epoch, ``'cbp'`` updates the
            model's unit utilities after every step and calls
            ``model.cbp_reset()`` after every epoch. Any other value trains
            without extras.
        epochs: Number of passes over ``loader``.
    """
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out, h1, h2 = model(x)
            loss = criterion(out, y)
            if method == 'l2':
                loss = loss + 1e-4 * sum(p.pow(2).sum() for p in model.parameters())
            loss.backward()
            optimizer.step()
            if method == 'cbp':
                # Detached: utility tracking must not extend the autograd graph.
                model.cbp_update_util(h1.detach(), h2.detach())
        if method == 'shrink_perturb':
            shrink_and_perturb(model)
        if method == 'cbp':
            model.cbp_reset()


def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module whose forward pass returns ``(out, h1, h2)``; it is
            switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        Fraction of samples whose arg-max prediction equals the target, or
        0.0 when ``loader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out, _, _ = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return correct / total


def count_dead_units(model, loader, device, threshold=1e-6):
    """Return the mean percentage of dead hidden units across ``fc1`` and ``fc2``.

    A unit counts as dead when the absolute value of its activation is below
    ``threshold`` for every sample in ``loader``.

    Args:
        model: Module whose forward pass returns ``(out, h1, h2)``; it is
            switched to eval mode.
        loader: Iterable of ``(inputs, _)`` batches.
        device: Device the inputs are moved to.
        threshold: Activation magnitude below which a unit is considered
            inactive for a sample.

    Returns:
        Average over the two layers of ``100 * dead_units / layer_width``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    # Per-layer mask of units that have been inactive on every sample so far.
    # Keeping only this mask avoids storing all activations in memory.
    still_dead = {"fc1": None, "fc2": None}
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            _, h1, h2 = model(x)
            for name, hidden in (("fc1", h1), ("fc2", h2)):
                inactive = (hidden.abs() < threshold).all(dim=0)
                if still_dead[name] is None:
                    still_dead[name] = inactive
                else:
                    still_dead[name] = still_dead[name] & inactive
    if still_dead["fc1"] is None:
        raise ValueError("count_dead_units requires a loader with at least one batch")
    dead_pct = {}
    for name, mask in still_dead.items():
        dead_pct[name] = 100.0 * mask.sum().item() / mask.numel()
    return sum(dead_pct.values()) / len(dead_pct)


# ----------------------------
# Main loop comparing methods
# ----------------------------
# Split the 100 CIFAR-100 labels into 20 tasks of 5 classes each, in a fixed
# shuffled order so every method sees the same task sequence.
all_classes = list(range(100))
random.seed(42)
random.shuffle(all_classes)
tasks = [all_classes[i:i+5] for i in range(0, 100, 5)]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# method name -> (per-task accuracies, per-task dead-unit percentages)
results = {}

# The last entry must be "cbp": that is the name train() checks before it
# updates utilities and resets units. Any other spelling trains exactly like
# "vanilla".
for method in ["vanilla", "l2", "shrink_perturb", "cbp"]:
    print(f"\n=== Training with method: {method} ===")
    # One model and one optimizer per method, carried across all tasks.
    model = ConvNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    accs, deads = [], []
    for task_id, cls in enumerate(tasks):
        print(f"Task {task_id+1}, Classes: {cls}")
        train_loader = get_loader_for_classes(cls, train=True)
        test_loader = get_loader_for_classes(cls, train=False)
        train(model, train_loader, optimizer, criterion, device, method=method)
        # Both metrics are measured on the current task's test split.
        acc = evaluate(model, test_loader, device)
        dead = count_dead_units(model, test_loader, device)
        accs.append(acc)
        deads.append(dead)
        print(f"Accuracy: {acc*100:.2f}%, Dead ReLU: {dead:.2f}%")

    results[method] = (accs, deads)

# ----------------------------
# Plotting results
# ----------------------------
# One x position per task; derived from the task list so the plot stays
# consistent if the number of tasks changes.
tasks_x = list(range(1, len(tasks) + 1))
plt.figure(figsize=(12, 5))

# Left panel: accuracy on each task right after training on it.
plt.subplot(1, 2, 1)
for method_name, (method_accs, _unused) in results.items():
    plt.plot(tasks_x, [a * 100 for a in method_accs], label=method_name)
plt.title("Accuracy vs Task")
plt.xlabel("Task ID")
plt.ylabel("Accuracy (%)")
plt.legend()

# Right panel: mean percentage of dead hidden units after each task.
plt.subplot(1, 2, 2)
for method_name, (_unused, method_deads) in results.items():
    plt.plot(tasks_x, method_deads, label=method_name)
plt.title("Dead Units vs Task")
plt.xlabel("Task ID")
plt.ylabel("Dead ReLU Units (%)")
plt.legend()
plt.tight_layout()
plt.show()



=== Training with method: vanilla ===
Task 1, Classes: [42, 41, 91, 9, 65]
Files already downloaded and verified
Files already downloaded and verified
Accuracy: 72.60%, Dead ReLU: 4.30%
Task 2, Classes: [50, 1, 70, 15, 78]
Files already downloaded and verified
Files already downloaded and verified
Accuracy: 63.80%, Dead ReLU: 23.05%
Task 3, Classes: [73, 10, 55, 56, 72]
Files already downloaded and verified
Files already downloaded and verified
Accuracy: 64.60%, Dead ReLU: 18.36%
Task 4, Classes: [45, 48, 92, 76, 37]
Files already downloaded and verified
Files already downloaded and verified
Accuracy: 77.60%, Dead ReLU: 20.31%
Task 5, Classes: [30, 21, 32, 96, 80]
Files already downloaded and verified
Files already downloaded and verified
Accuracy: 69.80%, Dead ReLU: 21.09%
Task 6, Classes: [49, 83, 26, 87, 33]
Files already downloaded and verified
Files already downloaded and verified
Accuracy: 81.80%, Dead ReLU: 30.86%
Task 7, Classes: [8, 47, 59, 63, 74]
Files already downloaded an