In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
import numpy as np
import matplotlib.pyplot as plt

# 使用与论文相同的架构
class ConvNet(nn.Module):
    """Three conv layers + three fully connected layers with ReLU activations.

    Assumes 3-channel 32x32 inputs (e.g. CIFAR). Each conv/ReLU/2x2-max-pool
    stage shrinks the feature map (32 -> 28 -> 14, 14 -> 12 -> 6, 6 -> 4 -> 2).
    The last conv stage therefore yields 128 x 2 x 2 = 512 features for fc1.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        # Spatial size (2x2) of conv3's pooled output for a 32x32 input.
        self.last_filter_output = 2 * 2
        self.num_conv_outputs = 128 * self.last_filter_output
        self.fc1 = nn.Linear(self.num_conv_outputs, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.pool = nn.MaxPool2d(2, 2)

        # Even indices hold the weighted layers, odd indices their ReLUs.
        # Index 10 (fc3) is the output layer and has no activation.
        self.layers = nn.ModuleList([
            self.conv1, nn.ReLU(),
            self.conv2, nn.ReLU(),
            self.conv3, nn.ReLU(),
            self.fc1, nn.ReLU(),
            self.fc2, nn.ReLU(),
            self.fc3
        ])
        self.act_type = 'relu'

    def forward(self, x):
        """Return the logits for batch `x`."""
        return self.predict(x)[0]

    def _hidden_activations(self, x):
        """Run the trunk and return the five hidden activations (x1, x2, x3, x4, x5).

        x1..x3 are the pooled conv1/conv2/conv3 outputs, still 4-D [B, C, H, W].
        x4 and x5 are the fc1/fc2 outputs, [B, 128].
        """
        x1 = self.pool(self.layers[1](self.layers[0](x)))
        x2 = self.pool(self.layers[3](self.layers[2](x1)))
        x3 = self.pool(self.layers[5](self.layers[4](x2)))
        # Flatten per sample. A wrongly sized input now fails loudly in fc1
        # instead of being silently re-batched by view(-1, num_conv_outputs).
        x4 = self.layers[7](self.layers[6](x3.flatten(1)))
        x5 = self.layers[9](self.layers[8](x4))
        return x1, x2, x3, x4, x5

    def predict(self, x):
        """Return (logits, [x1, x2, x3_flat, x4, x5]).

        x3_flat is conv3's pooled output flattened to [B, num_conv_outputs].
        """
        x1, x2, x3, x4, x5 = self._hidden_activations(x)
        return self.layers[10](x5), [x1, x2, x3.flatten(1), x4, x5]

    def count_dead_units(self, loader, device, threshold=1e-6):
        """Print per-layer dead-ReLU statistics over `loader`; return the overall percentage.

        A unit (a conv channel or an fc neuron) is dead when its activation
        magnitude stays below `threshold` for every sample yielded by `loader`.
        For conv layers this also covers every spatial position.

        Args:
            loader: iterable of (inputs, labels) batches; labels are ignored.
            device: device the inputs are moved to for the forward pass.
            threshold: activation magnitude below which a unit counts as inactive.

        Returns:
            Percentage (0-100) of dead units across conv1, conv2, conv3, fc1 and fc2.

        Raises:
            ValueError: if `loader` yields no batches.
        """
        layer_names = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2']
        # Per-unit running max of |per-sample peak activation|, reduced batch by
        # batch. Memory stays O(#units) instead of holding every activation.
        peaks = dict.fromkeys(layer_names)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, _ in loader:
                    acts = self._hidden_activations(x.to(device))
                    for name, act in zip(layer_names, acts):
                        # [B, C, ...] -> [B, C]: peak over spatial positions (no-op for fc layers).
                        per_sample = act.reshape(act.size(0), act.size(1), -1).amax(dim=2)
                        batch_peak = per_sample.abs().amax(dim=0)
                        prev = peaks[name]
                        peaks[name] = batch_peak if prev is None else torch.maximum(prev, batch_peak)
        finally:
            # Restore the caller's mode rather than leaving the model in eval mode.
            self.train(was_training)

        if peaks[layer_names[0]] is None:
            raise ValueError("loader yielded no batches")

        dead_counts = {name: int((peaks[name] < threshold).sum().item()) for name in layer_names}
        total_units = {name: peaks[name].numel() for name in layer_names}

        # Print the per-layer summary.
        print("Dead ReLU Unit Summary:")
        for name in layer_names:
            pct = 100.0 * dead_counts[name] / total_units[name]
            print(f"  Layer {name:5s} | Dead: {dead_counts[name]:3d}/{total_units[name]:3d} ({pct:.2f}%)")

        # Overall percentage across all five hidden layers.
        total_dead = sum(dead_counts.values())
        total_all = sum(total_units.values())
        return 100.0 * total_dead / total_all



# 载入数据
def _class_indices(targets, classes):
    """Return, in order, the positions in `targets` whose label is one of `classes`."""
    wanted = set(classes)
    return [i for i, label in enumerate(targets) if label in wanted]


def get_incremental_loader(classes, train=True, batch_size=128):
    """Build a DataLoader over the CIFAR-100 samples whose label is in `classes`.

    Args:
        classes: iterable of integer class labels to keep.
        train: use the training split (and shuffle) if True, else the test split.
        batch_size: DataLoader batch size.

    Returns:
        A DataLoader over a Subset of CIFAR-100 (stored under ./data, downloaded
        if missing), converted to tensors and normalized per channel.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
    ])
    dataset = torchvision.datasets.CIFAR100(root='./data', train=train, download=True, transform=transform)
    # Filter on the raw label list. Iterating `dataset` itself would decode and
    # transform every image just to read its label.
    subset = Subset(dataset, _class_indices(dataset.targets, classes))
    return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=train)

# 训练
def train(model, loader, optimizer, criterion, device, epochs=200):
    """Train `model` in place for `epochs` full passes over `loader`.

    Args:
        model: module to train; it is put into train mode.
        loader: iterable of (inputs, targets) batches.
        optimizer: optimizer over `model`'s parameters; stepped once per batch.
        criterion: loss function called as criterion(outputs, targets).
        device: device the batches are moved to.
        epochs: number of passes over `loader` (default 200, the previous
            hard-coded value).
    """
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()


def evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` over `loader` as a fraction in [0, 1].

    The model is switched to eval mode and left there. A loader that yields no
    samples raises ZeroDivisionError.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(dim=1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen


#  训练代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ConvNet(num_classes=100).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
criterion = nn.CrossEntropyLoss()

classes_seen = []
num_classes_seen = []  # x-axis for the plots: classes trained on at each completed step
accuracies = []
dead_counts = []  # overall dead-unit percentage (all five hidden layers) per step

# Class-incremental schedule: add 5 new CIFAR-100 classes per step and retrain
# on the union of every class seen so far.
for i in range(0, 100, 5):
    new_classes = list(range(i, i + 5))
    classes_seen.extend(new_classes)

    print(f"\n==> Training on classes: {classes_seen}")
    train_loader = get_incremental_loader(classes_seen, train=True)
    test_loader = get_incremental_loader(classes_seen, train=False)

    train(model, train_loader, optimizer, criterion, device)
    acc = evaluate(model, test_loader, device)

    accuracies.append(acc)
    dead_pct = model.count_dead_units(test_loader, device)
    dead_counts.append(dead_pct)
    # Recorded together with the results so the plot axes always line up,
    # even for a run that was stopped early.
    num_classes_seen.append(len(classes_seen))

    print(f"Test Accuracy: {acc*100:.2f}% | Dead units: {dead_pct:.3f}%")

# Plot results
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(num_classes_seen, accuracies, marker='o')
plt.title('Accuracy vs #Classes')
plt.xlabel('Number of Classes')
plt.ylabel('Accuracy')

plt.subplot(1, 2, 2)
plt.plot(num_classes_seen, dead_counts, marker='o', color='r')
plt.title('Dead ReLU Units vs #Classes')
plt.xlabel('Number of Classes')
# count_dead_units returns the percentage over all five hidden layers.
plt.ylabel('Dead Units (%)')

plt.tight_layout()
plt.show()



==> Training on classes: [0, 1, 2, 3, 4]
Files already downloaded and verified
Files already downloaded and verified
Dead ReLU Unit Summary:
  Layer conv1 | Dead:   0/ 32 (0.00%)
  Layer conv2 | Dead:   7/ 64 (10.94%)
  Layer conv3 | Dead:   0/128 (0.00%)
  Layer fc1   | Dead:  26/128 (20.31%)
  Layer fc2   | Dead:  37/128 (28.91%)
Test Accuracy: 55.60% | Dead units: 14.583%

==> Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Files already downloaded and verified
Files already downloaded and verified
Dead ReLU Unit Summary:
  Layer conv1 | Dead:   0/ 32 (0.00%)
  Layer conv2 | Dead:   1/ 64 (1.56%)
  Layer conv3 | Dead:   0/128 (0.00%)
  Layer fc1   | Dead:  16/128 (12.50%)
  Layer fc2   | Dead:  33/128 (25.78%)
Test Accuracy: 48.20% | Dead units: 10.417%

==> Training on classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
Files already downloaded and verified
Files already downloaded and verified
Dead ReLU Unit Summary:
  Layer conv1 | Dead:   0/ 32 (0.00%)
  Layer conv2

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
import numpy as np
import matplotlib.pyplot as plt
import random

# ----------------------------
# 模型定义
# ----------------------------
class ConvNet(nn.Module):
    """Three conv layers + three fully connected layers with ReLU activations.

    Assumes 3-channel 32x32 inputs (e.g. CIFAR). Each conv/ReLU/2x2-max-pool
    stage shrinks the feature map (32 -> 28 -> 14, 14 -> 12 -> 6, 6 -> 4 -> 2).
    The last conv stage therefore yields 128 x 2 x 2 = 512 features for fc1.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        # Spatial size (2x2) of conv3's pooled output for a 32x32 input.
        self.last_filter_output = 2 * 2
        self.num_conv_outputs = 128 * self.last_filter_output
        self.fc1 = nn.Linear(self.num_conv_outputs, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.pool = nn.MaxPool2d(2, 2)

        # Even indices hold the weighted layers, odd indices their ReLUs.
        # Index 10 (fc3) is the output layer and has no activation.
        self.layers = nn.ModuleList([
            self.conv1, nn.ReLU(),
            self.conv2, nn.ReLU(),
            self.conv3, nn.ReLU(),
            self.fc1, nn.ReLU(),
            self.fc2, nn.ReLU(),
            self.fc3
        ])

    def forward(self, x):
        """Return the logits for batch `x`."""
        return self.predict(x)[0]

    def predict(self, x):
        """Return (logits, [x1, x2, x3, x4, x5]).

        x1..x3 are the pooled conv outputs, kept 4-D [B, C, H, W].
        x4 and x5 are the fc1/fc2 outputs, [B, 128].
        """
        x1 = self.pool(self.layers[1](self.layers[0](x)))
        x2 = self.pool(self.layers[3](self.layers[2](x1)))
        x3 = self.pool(self.layers[5](self.layers[4](x2)))
        x4 = self.layers[7](self.layers[6](x3.flatten(1)))
        x5 = self.layers[9](self.layers[8](x4))
        x6 = self.layers[10](x5)
        return x6, [x1, x2, x3, x4, x5]

    def count_dead_units_all_layers(self, loader, device, threshold=1e-6):
        """Print per-layer dead-ReLU statistics over `loader`; return the overall percentage.

        A unit (a conv channel or an fc neuron) is dead when its activation
        magnitude stays below `threshold` for every sample yielded by `loader`.
        For conv layers this also covers every spatial position.

        Args:
            loader: iterable of (inputs, labels) batches; labels are ignored.
            device: device the inputs are moved to for the forward pass.
            threshold: activation magnitude below which a unit counts as inactive.

        Returns:
            Percentage (0-100) of dead units across conv1, conv2, conv3, fc1 and fc2.

        Raises:
            ValueError: if `loader` yields no batches.
        """
        layer_names = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2']
        # Per-unit running max of |per-sample peak activation|, reduced batch by
        # batch. Memory stays O(#units) instead of holding every activation.
        peaks = dict.fromkeys(layer_names)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, _ in loader:
                    _, acts = self.predict(x.to(device))
                    for name, act in zip(layer_names, acts):
                        # [B, C, ...] -> [B, C]: peak over spatial positions (no-op for fc layers).
                        per_sample = act.reshape(act.size(0), act.size(1), -1).amax(dim=2)
                        batch_peak = per_sample.abs().amax(dim=0)
                        prev = peaks[name]
                        peaks[name] = batch_peak if prev is None else torch.maximum(prev, batch_peak)
        finally:
            # Restore the caller's mode rather than leaving the model in eval mode.
            self.train(was_training)

        if peaks[layer_names[0]] is None:
            raise ValueError("loader yielded no batches")

        dead_counts = {name: int((peaks[name] < threshold).sum().item()) for name in layer_names}
        total_units = {name: peaks[name].numel() for name in layer_names}

        print("Dead ReLU Unit Summary:")
        for name in layer_names:
            pct = 100.0 * dead_counts[name] / total_units[name]
            print(f"  Layer {name:5s} | Dead: {dead_counts[name]:3d}/{total_units[name]:3d} ({pct:.2f}%)")

        # Overall percentage across all five hidden layers.
        total_dead = sum(dead_counts.values())
        total_all = sum(total_units.values())
        return 100.0 * total_dead / total_all


# ----------------------------
# 数据加载函数
# ----------------------------
def _class_indices(targets, classes):
    """Return, in order, the positions in `targets` whose label is one of `classes`."""
    wanted = set(classes)
    return [i for i, label in enumerate(targets) if label in wanted]


def get_loader_for_classes(classes, train=True, batch_size=90):
    """Build a DataLoader over the CIFAR-100 samples whose label is in `classes`.

    Args:
        classes: iterable of integer class labels to keep.
        train: use the training split (and shuffle) if True, else the test split.
        batch_size: DataLoader batch size.

    Returns:
        A DataLoader over a Subset of CIFAR-100 (stored under ./data, downloaded
        if missing), converted to tensors and normalized per channel.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
    ])
    dataset = torchvision.datasets.CIFAR100(root='./data', train=train, download=True, transform=transform)
    # Filter on the raw label list. Iterating `dataset` itself would decode and
    # transform every image just to read its label.
    subset = Subset(dataset, _class_indices(dataset.targets, classes))
    return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=train)


def get_incremental_loader(classes, train=True, batch_size=128):
    """Same as get_loader_for_classes, but with a default batch size of 128."""
    return get_loader_for_classes(classes, train=train, batch_size=batch_size)
# ----------------------------
# 训练与评估
# ----------------------------
def train(model, loader, optimizer, criterion, device, epochs=200):
    """Train `model` in place for `epochs` full passes over `loader`.

    Args:
        model: module to train; it is put into train mode.
        loader: iterable of (inputs, targets) batches.
        optimizer: optimizer over `model`'s parameters; stepped once per batch.
        criterion: loss function called as criterion(outputs, targets).
        device: device the batches are moved to.
        epochs: number of passes over `loader` (default 200, the previous
            hard-coded value; pass e.g. 10 for a quicker run).
    """
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()


def evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` over `loader` as a fraction in [0, 1].

    The model is switched to eval mode and left there. A loader that yields no
    samples raises ZeroDivisionError.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(dim=1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen


# ----------------------------
# 主流程
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ConvNet(num_classes=100).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0)
criterion = nn.CrossEntropyLoss()

# Split the 100 CIFAR-100 classes into disjoint tasks of `classes_per_task`
# classes each, in a fixed shuffled order (seed 42).
classes_per_task = 2
all_classes = list(range(100))
random.seed(42)
random.shuffle(all_classes)
task_classes = [
    all_classes[i:i + classes_per_task]
    for i in range(0, len(all_classes), classes_per_task)
]

accuracies = []
dead_percents = []

# Train the same model on each task in sequence, measuring accuracy and the
# overall dead-unit percentage on that task's test split.
for task_id, class_group in enumerate(task_classes):
    print(f"\n==> Task {task_id+1}: classes {class_group}")
    train_loader = get_loader_for_classes(class_group, train=True)
    test_loader = get_loader_for_classes(class_group, train=False)

    train(model, train_loader, optimizer, criterion, device)
    acc = evaluate(model, test_loader, device)
    dead_pct = model.count_dead_units_all_layers(test_loader, device)

    accuracies.append(acc)
    dead_percents.append(dead_pct)
    print(f"Accuracy: {acc*100:.2f}% | Dead Unit %: {dead_pct:.2f}%")

# ----------------------------
# Visualize results
# ----------------------------
# Derive the x-axis from the recorded results so it always matches their
# length, even after a run that stopped early.
steps = list(range(1, len(accuracies) + 1))

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(steps, [a * 100 for a in accuracies], marker='o')
plt.title('Accuracy vs Task ID')
plt.xlabel('Task #')
plt.ylabel('Test Accuracy (%)')

plt.subplot(1, 2, 2)
plt.plot(steps, dead_percents, marker='o', color='red')
plt.title('Dead ReLU Units vs Task ID')
plt.xlabel('Task #')
plt.ylabel('Dead Units (%)')

plt.tight_layout()
plt.show()



==> Task 1: classes [42, 41]
Files already downloaded and verified
Files already downloaded and verified
Dead ReLU Unit Summary:
  Layer conv1 | Dead:   0/ 32 (0.00%)
  Layer conv2 | Dead:   0/ 64 (0.00%)
  Layer conv3 | Dead:   1/128 (0.78%)
  Layer fc1   | Dead:  14/128 (10.94%)
  Layer fc2   | Dead:  35/128 (27.34%)
Accuracy: 94.50% | Dead Unit %: 10.42%

==> Task 2: classes [91, 9]
Files already downloaded and verified
Files already downloaded and verified
Dead ReLU Unit Summary:
  Layer conv1 | Dead:   0/ 32 (0.00%)
  Layer conv2 | Dead:   0/ 64 (0.00%)
  Layer conv3 | Dead:   1/128 (0.78%)
  Layer fc1   | Dead:  11/128 (8.59%)
  Layer fc2   | Dead:  35/128 (27.34%)
Accuracy: 87.50% | Dead Unit %: 9.79%

==> Task 3: classes [65, 50]
Files already downloaded and verified
Files already downloaded and verified
Dead ReLU Unit Summary:
  Layer conv1 | Dead:   0/ 32 (0.00%)
  Layer conv2 | Dead:   0/ 64 (0.00%)
  Layer conv3 | Dead:   0/128 (0.00%)
  Layer fc1   | Dead:   9/128 (7.03%)