In [4]:
import glob
import os
import subprocess
import zipfile

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from torchvision.models import resnet18

# Device that the model and every training batch are moved to (CPU only).
DEVICE: torch.device = torch.device("cpu")

# HiddenDataset class definition
class HiddenDataset(Dataset):
    """MNIST wrapper yielding ``{'image': ..., 'age_group': ...}`` sample dicts.

    The dict keys match what the training loop reads (``sample["image"]`` and
    ``sample["age_group"]``); here ``age_group`` simply carries the MNIST label.

    Args:
        split: ``'train'`` selects the MNIST training split; any other value
            (e.g. ``'validation'``) selects the MNIST test split.
        root: Directory where MNIST is stored / downloaded to. Defaults to the
            path that was previously hard-coded.
        image_size: Side length every image is resized to. The default of 224
            preserves the previous behaviour; smaller values are much cheaper
            on CPU (NOTE(review): confirm the model accepts the smaller size).
    """

    def __init__(self, split='train',
                 root='/Users/billdeng/PycharmProjects/machine_unlearning/data',
                 image_size=224):
        super().__init__()
        self.transform = transforms.Compose([
            # Resize so the images match the input size chosen for ResNet.
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Single-channel normalisation; constants kept from the original
            # code (presumably the MNIST pixel mean / std).
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # 'train' -> training split, anything else -> test split.
        self.dataset = datasets.MNIST(
            root=root,
            train=(split == 'train'),
            download=True,
            transform=self.transform,
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with ``'image'`` and ``'age_group'``."""
        image, label = self.dataset[idx]
        return {'image': image, 'age_group': label}

# Build the data loaders
def get_dataset(batch_size):
    """Build the retain / forget / validation loaders used by this notebook.

    Args:
        batch_size: Batch size used for all three loaders.

    Returns:
        ``(retain_loader, forget_loader, validation_loader)``. The retain loader
        reads the ``'train'`` split. The forget and validation loaders both read
        the ``'validation'`` split of ``HiddenDataset``.
    """
    retain_ds = HiddenDataset(split='train')
    # Forget and validation use the same split, so construct it only once
    # instead of instantiating two identical datasets.
    # NOTE(review): in a typical unlearning setup the forget set is a subset of
    # the training data, not the held-out split; confirm this is intended.
    held_out_ds = HiddenDataset(split='validation')

    retain_loader = DataLoader(retain_ds, batch_size=batch_size, shuffle=True)
    forget_loader = DataLoader(held_out_ds, batch_size=batch_size, shuffle=True)
    # Evaluation does not benefit from shuffling; a fixed order keeps
    # validation passes reproducible.
    validation_loader = DataLoader(held_out_ds, batch_size=batch_size, shuffle=False)

    return retain_loader, forget_loader, validation_loader

# Neurogenesis-style unlearning routine
def neurogenesis_unlearning(net, retain_loader, forget_loader, val_loader,
                            turnover_rate=0.032, turnover_frequency=640,
                            epochs=1, lr=0.001):
    """Train ``net`` on the retain set while periodically re-initialising neurons.

    After each optimiser step, ``turnover_neurons(net, turnover_rate)`` is called
    whenever the global mini-batch counter (starting at 0, counted across
    epochs) is a multiple of ``turnover_frequency``. This means the first
    turnover happens right after the very first batch.

    Args:
        net: Model to train in place; it is moved to nothing here, so it must
            already live on ``DEVICE``.
        retain_loader: Iterable of dict batches with ``"image"`` and
            ``"age_group"`` tensors.
        forget_loader: Accepted for interface compatibility; not used.
        val_loader: Accepted for interface compatibility; not used.
        turnover_rate: Fraction passed to ``turnover_neurons``.
        turnover_frequency: Turnover period in mini-batches; values <= 0
            disable turnover (previously this raised ``ZeroDivisionError``).
        epochs: Number of passes over ``retain_loader`` (default 1, as before).
        lr: SGD learning rate (default 0.001, as before).

    The model is left in eval mode on return.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    # Cosine decay over the whole run, stepped once per epoch. T_max must be
    # positive, so guard the epochs == 0 case.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(epochs, 1))

    minibatch_count = 0
    for ep in range(epochs):
        net.train()
        for batch_idx, sample in enumerate(retain_loader):
            inputs = sample["image"].to(DEVICE)
            targets = sample["age_group"].to(DEVICE)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Neurogenesis: periodically re-initialise a subset of neurons.
            # NOTE(review): SGD momentum buffers for the re-initialised rows are
            # not reset, so the next steps still apply the old momentum to them.
            if turnover_frequency > 0 and minibatch_count % turnover_frequency == 0:
                turnover_neurons(net, turnover_rate)
            minibatch_count += 1

            # Print progress every 10 batches.
            if batch_idx % 10 == 0:
                print(f"Epoch {ep}, Batch {batch_idx}, Loss: {loss.item()}")

        scheduler.step()

    net.eval()

def turnover_neurons(net, turnover_rate):
    """Re-initialise a random subset of the output units of ``net.fc`` in place.

    Each selected unit (a row of ``net.fc.weight`` and, if present, the matching
    ``net.fc.bias`` entry) is redrawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
    This is the same range ``torch.nn.Linear`` uses at construction time. The
    previous U(0, 1) draw injected large, all-positive weights that are far
    outside the scale of the rest of the layer.

    Args:
        net: Model exposing ``net.fc`` with a ``weight`` whose first dimension
            indexes output units (assumed Linear-like; confirm for other layers).
        turnover_rate: Fraction of output units to re-initialise. Any positive
            rate replaces at least one unit; the count is clamped to the layer size.
    """
    layer = net.fc
    total_neurons = layer.weight.shape[0]
    turnover_count = int(total_neurons * turnover_rate)
    if turnover_count == 0 and turnover_rate > 0:
        # int() truncates; e.g. 10 units * 0.032 -> 0, which would make the
        # turnover a silent no-op. Replace at least one unit instead.
        turnover_count = 1
    turnover_count = min(turnover_count, total_neurons)
    if turnover_count <= 0:
        return

    # Pick distinct random units to re-initialise.
    turnover_indices = torch.randperm(total_neurons, device=layer.weight.device)[:turnover_count]

    fan_in = layer.weight[0].numel()
    bound = fan_in ** -0.5 if fan_in > 0 else 0.0

    # Edit the parameters in place without recording the change in autograd.
    with torch.no_grad():
        new_rows = torch.empty_like(layer.weight[turnover_indices]).uniform_(-bound, bound)
        layer.weight[turnover_indices] = new_rows
        if layer.bias is not None:
            new_bias = torch.empty_like(layer.bias[turnover_indices]).uniform_(-bound, bound)
            layer.bias[turnover_indices] = new_bias

# Create the directory that receives the checkpoint and the submission archive.
model_save_dir = '/Users/billdeng/PycharmProjects/machine_unlearning'
os.makedirs(model_save_dir, exist_ok=True)

# Build the data loaders used for training (batch size 64).
retain_loader, forget_loader, validation_loader = get_dataset(64)

# Initialise ResNet18 and replace the first conv so it takes 1-channel input.
net = resnet18(weights=None, num_classes=10)
net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # single-channel input
net.to(DEVICE)

print("Starting training...")
neurogenesis_unlearning(net, retain_loader, forget_loader, validation_loader)
state = net.state_dict()
torch.save(state, os.path.join(model_save_dir, 'unlearned_checkpoint.pth'))
print("Training completed.")

# Package every .pth checkpoint in model_save_dir into submission.zip.
# zipfile replaces the former `zip` shell command. This avoids interpolating
# paths into a shell string, needs no external `zip` binary, and raises on
# failure instead of ignoring the exit code. Entries are stored by file name
# only, whereas the CLI stores the path exactly as given on its command line.
# The archive is rewritten from scratch on every run rather than updated.
submission_path = os.path.join(model_save_dir, "submission.zip")
with zipfile.ZipFile(submission_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for checkpoint_path in sorted(glob.glob(os.path.join(model_save_dir, "*.pth"))):
        archive.write(checkpoint_path, arcname=os.path.basename(checkpoint_path))


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 34023066.75it/s]


Extracting /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/train-images-idx3-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 36552713.89it/s]

Extracting /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 16199919.41it/s]


Extracting /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4228752.22it/s]


Extracting /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/billdeng/PycharmProjects/machine_unlearning/MNIST/raw

Starting training...
Epoch 0, Batch 0, Loss: 2.3993215560913086


KeyboardInterrupt: 