In [1]:
import os
import subprocess
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18
import torch.nn as nn
import torch.optim as optim

# Run everything on the CPU.
DEVICE = torch.device('cpu')

# Dataset wrapper around torchvision's MNIST that yields dict samples.
class HiddenDataset(Dataset):
    """MNIST exposed as ``{'image': ..., 'age_group': ...}`` samples.

    Each image is resized to 224x224, converted to a tensor and normalized
    with a single-channel mean of 0.1307 and std of 0.3081.  The label is
    returned under the key ``'age_group'``.

    Args:
        split: ``'train'`` selects the MNIST training set; any other value
            (e.g. ``'validation'``) selects the MNIST test set.
        root: Directory where MNIST is stored; it is downloaded there if
            missing.  Defaults to the path the notebook originally used.
    """

    def __init__(self, split: str = 'train',
                 root: str = '/Users/macbook/Downloads/data'):
        super().__init__()
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # resize to the input size used with ResNet
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        # A single call replaces the two near-identical branches; only the
        # ``train`` flag depends on the requested split.
        self.dataset = datasets.MNIST(
            root=root,
            train=(split == 'train'),
            download=True,
            transform=self.transform,
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys 'image' and 'age_group'."""
        image, label = self.dataset[idx]
        return {'image': image, 'age_group': label}

# Build the retain / forget / validation data loaders.
def get_dataset(batch_size):
    """Return ``(retain_loader, forget_loader, validation_loader)``.

    The retain loader is built from ``HiddenDataset(split='train')``; the
    forget and validation loaders are each built from their own
    ``HiddenDataset(split='validation')`` instance.  All three loaders use
    ``batch_size`` and shuffle their data.
    """
    split_names = ('train', 'validation', 'validation')
    retain_ds, forget_ds, val_ds = [HiddenDataset(split=name) for name in split_names]

    def _shuffled(ds):
        # Every loader shares the same batch size and shuffling policy.
        return DataLoader(ds, batch_size=batch_size, shuffle=True)

    return _shuffled(retain_ds), _shuffled(forget_ds), _shuffled(val_ds)

# Unlearning routine: trains `net` in place on the retain set.
def unlearning(net, retain_loader, forget_loader, val_loader, epochs=1):
    """Train ``net`` in place on ``retain_loader`` and leave it in eval mode.

    A fresh SGD optimizer (lr=0.001, momentum=0.9, weight_decay=5e-4) and a
    cosine-annealing schedule with ``T_max=epochs`` are created on every
    call, so optimizer and scheduler state do not carry over between calls.

    Args:
        net: Model to update; its parameters are modified in place.
        retain_loader: Iterable of batches, each a mapping with an
            ``"image"`` tensor and an ``"age_group"`` target tensor.
        forget_loader: Accepted for interface compatibility; not used here.
        val_loader: Accepted for interface compatibility; not used here.
        epochs: Number of passes over ``retain_loader`` (default 1, the
            value that used to be hard-coded).

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Nothing below switches the model out of training mode, so setting it
    # once before the loop is sufficient.
    net.train()

    for ep in range(epochs):
        for batch_idx, sample in enumerate(retain_loader):
            inputs = sample["image"].to(DEVICE)
            targets = sample["age_group"].to(DEVICE)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:  # report progress every 10 batches
                print(f"Epoch {ep}, Batch {batch_idx}, Loss: {loss.item()}")

        # The learning-rate schedule advances once per epoch.
        scheduler.step()

    net.eval()

# Directory that receives the per-iteration checkpoints and the final archive.
model_save_dir = '/Users/macbook/Downloads/models'
os.makedirs(model_save_dir, exist_ok=True)

# Build the data loaders with a batch size of 64.
retain_loader, forget_loader, validation_loader = get_dataset(64)

# ResNet-18 without pretrained weights and with 10 output classes; the first
# convolution is replaced so the network accepts single-channel input.
net = resnet18(weights=None, num_classes=10)
net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net.to(DEVICE)

# Number of unlearning rounds, i.e. number of checkpoints to produce.
num_iterations = 20

# The exact checkpoint files this run is expected to write.
checkpoint_paths = [
    os.path.join(model_save_dir, f'unlearned_checkpoint_{i}.pth')
    for i in range(num_iterations)
]

# NOTE(review): the same `net` object is trained further in every round, so
# checkpoint i continues from checkpoint i-1 rather than from a common
# starting model -- confirm this is the intended protocol.
for i, checkpoint_path in enumerate(checkpoint_paths):
    print(f"Starting iteration {i+1}/{num_iterations}")
    unlearning(net, retain_loader, forget_loader, validation_loader)
    state = net.state_dict()
    torch.save(state, checkpoint_path)
    print(f"Completed iteration {i+1}/{num_iterations}")

# Verify that every expected checkpoint exists.  Only the expected files are
# considered, so unrelated directory entries (an earlier submission.zip,
# .DS_Store, leftovers from other runs) no longer cause a spurious failure.
unlearned_ckpts = [
    os.path.basename(path) for path in checkpoint_paths if os.path.isfile(path)
]
if len(unlearned_ckpts) != num_iterations:
    raise RuntimeError(f'Expected exactly {num_iterations} checkpoints, but found {len(unlearned_ckpts)}.')

# Package the checkpoints.  `zip` adds to an existing archive, so a stale one
# is removed first; the argument list avoids shell interpolation of the paths
# and check=True turns a failing `zip` into an exception instead of a silent
# non-zero exit status.
zip_path = os.path.join(model_save_dir, "submission.zip")
if os.path.exists(zip_path):
    os.remove(zip_path)
subprocess.run(['zip', zip_path, *checkpoint_paths], check=True)


Starting iteration 1/20
Epoch 0, Batch 0, Loss: 2.3712334632873535
Epoch 0, Batch 10, Loss: 2.3384757041931152
Epoch 0, Batch 20, Loss: 2.242420196533203
Epoch 0, Batch 30, Loss: 2.134303331375122
Epoch 0, Batch 40, Loss: 2.0908191204071045
Epoch 0, Batch 50, Loss: 2.010789155960083
Epoch 0, Batch 60, Loss: 2.0256361961364746
Epoch 0, Batch 70, Loss: 1.9019819498062134
Epoch 0, Batch 80, Loss: 1.7030620574951172
Epoch 0, Batch 90, Loss: 1.7157833576202393
Epoch 0, Batch 100, Loss: 1.7372769117355347
Epoch 0, Batch 110, Loss: 1.7002419233322144
Epoch 0, Batch 120, Loss: 1.6217774152755737
Epoch 0, Batch 130, Loss: 1.5312106609344482
Epoch 0, Batch 140, Loss: 1.5232921838760376
Epoch 0, Batch 150, Loss: 1.4916285276412964
Epoch 0, Batch 160, Loss: 1.2463301420211792
Epoch 0, Batch 170, Loss: 1.1994186639785767
Epoch 0, Batch 180, Loss: 1.0945963859558105
Epoch 0, Batch 190, Loss: 1.1047077178955078
Epoch 0, Batch 200, Loss: 0.9857563972473145
Epoch 0, Batch 210, Loss: 1.015592098236084
E

Epoch 0, Batch 850, Loss: 0.08436258137226105
Epoch 0, Batch 860, Loss: 0.23008093237876892
Epoch 0, Batch 870, Loss: 0.25425252318382263
Epoch 0, Batch 880, Loss: 0.13638131320476532
Epoch 0, Batch 890, Loss: 0.08450151234865189
Epoch 0, Batch 900, Loss: 0.09500515460968018
Epoch 0, Batch 910, Loss: 0.04323992133140564
Epoch 0, Batch 920, Loss: 0.11432589590549469
Epoch 0, Batch 930, Loss: 0.033766262233257294
Completed iteration 2/20
Starting iteration 3/20
Epoch 0, Batch 0, Loss: 0.10821042209863663
Epoch 0, Batch 10, Loss: 0.10602517426013947
Epoch 0, Batch 20, Loss: 0.03781226649880409
Epoch 0, Batch 30, Loss: 0.17589157819747925
Epoch 0, Batch 40, Loss: 0.04018794745206833
Epoch 0, Batch 50, Loss: 0.06625708937644958
Epoch 0, Batch 60, Loss: 0.039134521037340164
Epoch 0, Batch 70, Loss: 0.07853473722934723
Epoch 0, Batch 80, Loss: 0.09167983382940292
Epoch 0, Batch 90, Loss: 0.0690116211771965
Epoch 0, Batch 100, Loss: 0.0745081752538681
Epoch 0, Batch 110, Loss: 0.02372818253934

KeyboardInterrupt: 