## IID

### Retrain from scratch (full model vs. retrained baseline)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
from utils import label_to_onehot, cross_entropy_for_onehot
from models.vision import weights_init, LeNet
num_classes = 10

# Fall back to CPU when no GPU is available (federated_train uses the same rule),
# so this cell no longer crashes on CPU-only machines.
net = LeNet(num_classes).to("cuda" if torch.cuda.is_available() else "cpu")
compress_rate = 1.0  # not read by any of the visible cells
torch.manual_seed(1234)
net.apply(weights_init)
# One-hot loss; the training cells below reassign criterion = nn.CrossEntropyLoss().
criterion = cross_entropy_for_onehot


def federated_train(global_model, client_loaders, criterion, num_rounds=10, num_local_epochs=1, lr=0.001):
    """Train ``global_model`` with FedAvg and return it.

    In every communication round each client starts from a copy of the current
    global weights, runs ``num_local_epochs`` epochs of Adam on its own loader,
    and the server then replaces the global weights with the unweighted mean of
    the client state dicts.

    Args:
        global_model: model to train; updated in place and also returned.
        client_loaders: one iterable of ``(images, labels)`` batches per client.
        criterion: loss called as ``criterion(outputs, labels)``.
        num_rounds: number of communication rounds.
        num_local_epochs: local epochs per client per round.
        lr: learning rate of each client's Adam optimizer.

    Returns:
        The same ``global_model`` object, holding the averaged weights.

    Raises:
        ValueError: if ``client_loaders`` is empty.
    """
    import copy  # local import keeps this cell self-contained

    # Materialise once so every round sees the same clients.
    client_loaders = list(client_loaders)
    if not client_loaders:
        raise ValueError("federated_train needs at least one client loader")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model.to(device)

    for round_idx in range(num_rounds):
        print(f"Communication Round {round_idx+1}/{num_rounds}")
        client_states = []

        for loader in client_loaders:
            # Clone the global model itself rather than rebuilding a hard-coded
            # architecture. This works for any model or class count and avoids
            # size mismatches when another cell redefines this function.
            local_model = copy.deepcopy(global_model).to(device)
            optimizer = optim.Adam(local_model.parameters(), lr=lr)

            # Local training
            local_model.train()
            for _ in range(num_local_epochs):
                for images, labels in loader:
                    images, labels = images.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = local_model(images)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

            client_states.append(local_model.state_dict())

        # FedAvg aggregation: unweighted mean over clients, computed in float.
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are cast back
        # to their original dtype.
        averaged_state = {
            key: torch.stack([state[key].float() for state in client_states], 0)
            .mean(0)
            .to(value.dtype)
            for key, value in global_model.state_dict().items()
        }
        global_model.load_state_dict(averaged_state)

    return global_model

#### Data split and fixed forget set (测试：固定遗忘样本与数据划分)

In [5]:
import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Set global random seeds for reproducibility.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the CIFAR-10 training split, converted to tensors with ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
dst_train = datasets.CIFAR10(
    root="~/.torch", 
    train=True, 
    download=True, 
    transform=transform
)

# Number of clients and unlearning parameters.
CLIENT_NUM = 4
FORGOTTEN_CLIENT_IDX = 3  # index (0-based) of the client that loses part of its data
FORGET_SIZE = 1000        # number of samples to forget


# Split the training set into CLIENT_NUM equal shards; the seeded generator makes
# the split deterministic. random_split needs the lengths to sum to len(dst_train),
# so this assumes len(dst_train) is divisible by CLIENT_NUM.
client_datasets = torch.utils.data.random_split(
    dst_train,
    [len(dst_train)//CLIENT_NUM]*CLIENT_NUM,
    generator=torch.Generator().manual_seed(SEED)  # fixed split seed
)

# dst_train indices held by the target client before anything is removed.
target_dataset = client_datasets[FORGOTTEN_CLIENT_IDX]
original_indices = target_dataset.indices.copy()  # copy of the original index list

# Forget set: the FORGET_SIZE smallest dst_train indices of that client (deterministic).
fixed_forgotten_indices = sorted(original_indices)[:FORGET_SIZE]  # first FORGET_SIZE in ascending index order

# Replace the target client's shard with the remaining samples.
# NOTE(review): from here on client_datasets, and the client_loaders built below,
# no longer contain the forgotten samples. A "full" model must rebuild this client
# from original_indices.
remaining_indices = list(set(original_indices) - set(fixed_forgotten_indices))
client_datasets[FORGOTTEN_CLIENT_IDX] = Subset(dst_train, remaining_indices)

# Loader over the forgotten samples (fixed order, no shuffling).
forgotten_dataset = Subset(dst_train, fixed_forgotten_indices)
forgotten_loader = DataLoader(
    forgotten_dataset, 
    batch_size=128, 
    shuffle=False
)

# One loader per client, built from the updated shards (forgotten samples excluded).
client_loaders = [
    DataLoader(
        ds, 
        batch_size=128, 
        shuffle=True,  # shuffled, but each loader has its own seeded generator
        generator=torch.Generator().manual_seed(SEED))
    for ds in client_datasets ]

# Check that the forget-set loader yields exactly the same data on every pass.
def verify_fixed_samples(loader=None):
    """Assert that two passes over ``loader`` yield identical batches.

    Each batch is fingerprinted by the sum of its inputs and its list of labels.
    The two passes are compared batch by batch, so reordered or swapped batches
    are detected, not just a change in the overall total.

    Args:
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``forgotten_loader``.

    Raises:
        AssertionError: if the two passes differ.
    """
    if loader is None:
        loader = forgotten_loader

    def _fingerprints():
        # One (input sum, labels) pair per batch, in iteration order.
        return [
            (batch[0].sum().item(), torch.as_tensor(batch[1]).tolist())
            for batch in loader
        ]

    first_run = _fingerprints()
    second_run = _fingerprints()

    assert len(first_run) == len(second_run), "样本不固定!"
    for (sum_a, labels_a), (sum_b, labels_b) in zip(first_run, second_run):
        assert np.isclose(sum_a, sum_b) and labels_a == labels_b, "样本不固定!"
    print("验证通过：遗忘数据集样本保持固定")

# Sanity check on forgotten_loader: raises AssertionError if two passes differ.
verify_fixed_samples()

验证通过：遗忘数据集样本保持固定


In [3]:
client_batch_size = 128

# The data cell replaced client_datasets[FORGOTTEN_CLIENT_IDX] with the remaining
# samples, so client_loaders no longer contains the forgotten samples. The full
# model must see them; otherwise the full and retrained models are trained on
# identical data. Rebuild that client from its original shard.
assert set(fixed_forgotten_indices).issubset(original_indices)
assert set(fixed_forgotten_indices).isdisjoint(remaining_indices)
full_client_datasets = list(client_datasets)
full_client_datasets[FORGOTTEN_CLIENT_IDX] = Subset(dst_train, original_indices)
full_client_loaders = [
    DataLoader(
        ds,
        batch_size=client_batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED),
    )
    for ds in full_client_datasets
]

print("联邦训练完整模型...")
# Initialise the global model
full_net = LeNet(num_classes=10)
criterion = nn.CrossEntropyLoss()
global_round = 20

full_net = federated_train(
    full_net,
    full_client_loaders,  # includes the forgotten samples of the target client
    criterion,
    num_rounds=global_round,
    num_local_epochs=10,
    lr=0.001
)

# Retrain-from-scratch baseline: same clients, without the forgotten samples.
modified_client_loaders = [
    DataLoader(
        ds if idx != FORGOTTEN_CLIENT_IDX else Subset(dst_train, remaining_indices),
        batch_size=client_batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED),
    )
    for idx, ds in enumerate(client_datasets)
]

unlearned_net = LeNet(num_classes=10)
print("federated unlearning training...")
unlearned_net = federated_train(
    unlearned_net,
    modified_client_loaders,  # forgotten samples excluded
    criterion,
    num_rounds=global_round,
    num_local_epochs=10,
    lr=0.001
)

# Save both models
torch.save(full_net.state_dict(), "federated_full_sample_1000_round_20_partial.pth")
torch.save(unlearned_net.state_dict(), "federated_unlearned_sample_1000_round_20_partial.pth")

联邦训练完整模型...
Communication Round 1/20
Communication Round 2/20
Communication Round 3/20
Communication Round 4/20
Communication Round 5/20
Communication Round 6/20
Communication Round 7/20
Communication Round 8/20
Communication Round 9/20
Communication Round 10/20
Communication Round 11/20
Communication Round 12/20
Communication Round 13/20
Communication Round 14/20
Communication Round 15/20
Communication Round 16/20
Communication Round 17/20
Communication Round 18/20
Communication Round 19/20
Communication Round 20/20
federated unlearning training...
Communication Round 1/20
Communication Round 2/20
Communication Round 3/20
Communication Round 4/20
Communication Round 5/20
Communication Round 6/20
Communication Round 7/20
Communication Round 8/20
Communication Round 9/20
Communication Round 10/20
Communication Round 11/20
Communication Round 12/20
Communication Round 13/20
Communication Round 14/20
Communication Round 15/20
Communication Round 16/20
Communication Round 17/20
Communicati

### 梯度上升 (gradient-ascent unlearning)

In [None]:
# Loss passed to the gradient-ascent and fine-tuning phases below (class-index labels).
criterion = torch.nn.CrossEntropyLoss()

# Gradient-ascent unlearning; pass a loss that takes class-index labels (e.g. nn.CrossEntropyLoss).
def federated_unlearning_gradient_ascent(global_model, forgotten_loader, client_loaders, criterion, num_rounds=5, lr=0.001,
                                         finetune_rounds=5, finetune_local_epochs=1, finetune_lr=0.001):
    """Unlearn ``forgotten_loader`` by gradient ascent, then repair with FedAvg.

    Phase 1 maximises ``criterion`` on the forgotten samples for ``num_rounds``
    passes, using Adam on the negated loss. Phase 2 fine-tunes the result with
    ``federated_train`` on ``client_loaders``, which should hold only retained data.
    Cross-entropy is unbounded above, so long ascent phases can destroy the model.

    Args:
        global_model: model to unlearn; modified in place.
        forgotten_loader: iterable of ``(images, labels)`` batches to forget.
        client_loaders: per-client loaders for the fine-tuning phase.
        criterion: loss called as ``criterion(outputs, labels)``.
        num_rounds: number of ascent passes over ``forgotten_loader``.
        lr: Adam learning rate for the ascent phase.
        finetune_rounds: FedAvg rounds for the repair phase (0 skips it).
        finetune_local_epochs: local epochs per client during repair.
        finetune_lr: learning rate for the repair phase.

    Returns:
        The unlearned (and optionally fine-tuned) model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model.to(device)

    optimizer = optim.Adam(global_model.parameters(), lr=lr)

    print("Performing gradient ascent on forgotten data...")
    global_model.train()
    for _ in tqdm(range(num_rounds)):
        loss_sum, batch_count = 0.0, 0
        for images, labels in forgotten_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = global_model(images)
            loss = criterion(outputs, labels)
            # Ascend the loss by letting Adam minimise its negation. Do not combine
            # a manual `param += lr * grad` step with optimizer.step() on the same
            # gradient: Adam's normalised descent step (roughly lr per parameter)
            # would cancel most of the ascent.
            (-loss).backward()
            optimizer.step()
            loss_sum += loss.item()
            batch_count += 1
        # Mean loss over the pass. Skipped for an empty loader, where there is
        # no loss to report.
        if batch_count:
            print("loss:", loss_sum / batch_count)

    if finetune_rounds > 0:
        print("Fine-tuning on remaining data...")
        global_model = federated_train(
            global_model,
            client_loaders,
            criterion,
            num_rounds=finetune_rounds,
            num_local_epochs=finetune_local_epochs,
            lr=finetune_lr
        )

    return global_model

import os

client_batch_size = 128

# The full model is trained and saved by the "retrain" cell above, which writes
# federated_full_sample_1000_round_20_partial.pth to the working directory.
# This absolute path presumably points at a copy of that checkpoint.
full_net = LeNet(num_classes=10)
full_model_path = "/home/ecs-user/fgi/federated_weight/federated_full_sample_1000_round_20_partial.pth"
if not os.path.exists(full_model_path):
    raise FileNotFoundError(
        f"Full model checkpoint not found at '{full_model_path}'. "
        "Run the retrain cell first and place its checkpoint at this path."
    )
print(f"Found existing full model at '{full_model_path}', loading weights...")
# map_location="cpu": the checkpoint may contain CUDA tensors.
full_net.load_state_dict(torch.load(full_model_path, map_location="cpu"))

# Start unlearning from a copy of the full model's weights.
unlearned_net_ga = LeNet(num_classes=10)
unlearned_net_ga.load_state_dict(full_net.state_dict())
unlearned_net_ga = federated_unlearning_gradient_ascent(
    unlearned_net_ga,
    forgotten_loader,
    client_loaders,  # the data cell already removed the forgotten samples from these
    criterion,
    num_rounds=150,
    lr=0.001
)
torch.save(unlearned_net_ga.state_dict(), "efficient_federated_unlearned_gradient_sample_1000_round_20.pth")

## Non-IID (train on the first five classes only; 取前五个类别进行训练)


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
from utils import label_to_onehot, cross_entropy_for_onehot
from models.vision import weights_init, LeNet
num_classes = 5

# Fall back to CPU when no GPU is available (federated_train uses the same rule).
net = LeNet(num_classes).to("cuda" if torch.cuda.is_available() else "cpu")
compress_rate = 1.0  # not read by any of the visible cells
torch.manual_seed(1234)
net.apply(weights_init)
# One-hot loss; the training cell below reassigns criterion = nn.CrossEntropyLoss().
criterion = cross_entropy_for_onehot


def federated_train(global_model, client_loaders, criterion, num_rounds=10, num_local_epochs=1, lr=0.001):
    """Train ``global_model`` with FedAvg and return it.

    In every communication round each client starts from a copy of the current
    global weights, runs ``num_local_epochs`` epochs of Adam on its own loader,
    and the server then replaces the global weights with the unweighted mean of
    the client state dicts.

    Args:
        global_model: model to train; updated in place and also returned.
        client_loaders: one iterable of ``(images, labels)`` batches per client.
        criterion: loss called as ``criterion(outputs, labels)``.
        num_rounds: number of communication rounds.
        num_local_epochs: local epochs per client per round.
        lr: learning rate of each client's Adam optimizer.

    Returns:
        The same ``global_model`` object, holding the averaged weights.

    Raises:
        ValueError: if ``client_loaders`` is empty.
    """
    import copy  # local import keeps this cell self-contained

    # Materialise once so every round sees the same clients.
    client_loaders = list(client_loaders)
    if not client_loaders:
        raise ValueError("federated_train needs at least one client loader")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model.to(device)

    for round_idx in range(num_rounds):
        print(f"Communication Round {round_idx+1}/{num_rounds}")
        client_states = []

        for loader in client_loaders:
            # Clone the global model itself rather than rebuilding a hard-coded
            # architecture. This works for any model or class count and avoids
            # size mismatches when a later cell calls this with a 10-class model.
            local_model = copy.deepcopy(global_model).to(device)
            optimizer = optim.Adam(local_model.parameters(), lr=lr)

            # Local training
            local_model.train()
            for _ in range(num_local_epochs):
                for images, labels in loader:
                    images, labels = images.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = local_model(images)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

            client_states.append(local_model.state_dict())

        # FedAvg aggregation: unweighted mean over clients, computed in float.
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are cast back
        # to their original dtype.
        averaged_state = {
            key: torch.stack([state[key].float() for state in client_states], 0)
            .mean(0)
            .to(value.dtype)
            for key, value in global_model.state_dict().items()
        }
        global_model.load_state_dict(averaged_state)

    return global_model

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from collections import defaultdict
from utils import label_to_onehot, cross_entropy_for_onehot
from models.vision import weights_init, LeNet
num_classes = 5

# Fall back to CPU when no GPU is available (federated_train uses the same rule).
net = LeNet(num_classes).to("cuda" if torch.cuda.is_available() else "cpu")
compress_rate = 1.0  # not read by any of the visible cells
torch.manual_seed(1234)
net.apply(weights_init)
criterion = cross_entropy_for_onehot

# Set global random seeds for reproducibility.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the CIFAR-10 training split, converted to tensors with ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
dst_train = datasets.CIFAR10(
    root="~/.torch",
    train=True,
    download=True,
    transform=transform
)

# Keep only the first five CIFAR-10 classes (labels 0-4).
selected_classes = list(range(5))

# Group sample indices by class. The labels are read from dst_train.targets
# instead of iterating the dataset, which would decode and transform every image
# just to read its label. The resulting indices are identical.
class_indices = defaultdict(list)
for idx, label in enumerate(dst_train.targets):
    if label in selected_classes:
        class_indices[label].append(idx)

# Concatenate the per-class index lists (all of class 0, then class 1, ...).
selected_indices = []
for label in selected_classes:
    selected_indices.extend(class_indices[label])

# Dataset restricted to the selected classes, ordered class by class.
# Positions in it are NOT dst_train indices.
filtered_dataset = Subset(dst_train, selected_indices)

# Number of clients and unlearning parameters.
CLIENT_NUM = 4
FORGOTTEN_CLIENT_IDX = 3  # index (0-based) of the client that loses part of its data
FORGET_SIZE = 100        # number of samples to forget

# Split filtered_dataset into CLIENT_NUM equal shards; the seeded generator makes
# the split deterministic. Assumes len(filtered_dataset) is divisible by CLIENT_NUM
# (random_split requires the lengths to sum to the dataset size).
# NOTE: every index below is a POSITION in filtered_dataset, not a dst_train index.
client_datasets = torch.utils.data.random_split(
    filtered_dataset,
    [len(filtered_dataset)//CLIENT_NUM]*CLIENT_NUM,
    generator=torch.Generator().manual_seed(SEED)  # fixed split seed
)

# filtered_dataset positions held by the target client before anything is removed.
target_dataset = client_datasets[FORGOTTEN_CLIENT_IDX]
original_indices = target_dataset.indices.copy()  # copy of the original position list

# Forget set: the FORGET_SIZE smallest positions of that client (deterministic).
# NOTE(review): filtered_dataset is ordered class by class, so the smallest positions
# come from class 0 first; the forget set is therefore likely single-class. Confirm
# this is intended.
fixed_forgotten_indices = sorted(original_indices)[:FORGET_SIZE]  # first FORGET_SIZE (=100) in ascending order

# Replace the target client's shard with the remaining samples. From here on
# client_datasets no longer contains the forgotten samples.
remaining_indices = list(set(original_indices) - set(fixed_forgotten_indices))
client_datasets[FORGOTTEN_CLIENT_IDX] = Subset(filtered_dataset, remaining_indices)

# Loader over the forgotten samples (fixed order, no shuffling).
forgotten_dataset = Subset(filtered_dataset, fixed_forgotten_indices)
forgotten_loader = DataLoader(
    forgotten_dataset, 
    batch_size=32, 
    shuffle=False
)

# One loader per client from the updated shards (forgotten samples excluded).
# The training code later in this cell builds its own loaders instead of using these.
client_loaders = [
    DataLoader(
        ds, 
        batch_size=32, 
        shuffle=True,  # shuffled, but each loader has its own seeded generator
        generator=torch.Generator().manual_seed(SEED))
    for ds in client_datasets
]

# Check that the forget-set loader yields exactly the same data on every pass.
def verify_fixed_samples(loader=None):
    """Assert that two passes over ``loader`` yield identical batches.

    Each batch is fingerprinted by the sum of its inputs and its list of labels.
    The two passes are compared batch by batch, so reordered or swapped batches
    are detected, not just a change in the overall total.

    Args:
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``forgotten_loader``.

    Raises:
        AssertionError: if the two passes differ.
    """
    if loader is None:
        loader = forgotten_loader

    def _fingerprints():
        # One (input sum, labels) pair per batch, in iteration order.
        return [
            (batch[0].sum().item(), torch.as_tensor(batch[1]).tolist())
            for batch in loader
        ]

    first_run = _fingerprints()
    second_run = _fingerprints()

    assert len(first_run) == len(second_run), "样本不固定!"
    for (sum_a, labels_a), (sum_b, labels_b) in zip(first_run, second_run):
        assert np.isclose(sum_a, sum_b) and labels_a == labels_b, "样本不固定!"
    print("验证通过：遗忘数据集样本保持固定")

# Sanity check on forgotten_loader: raises AssertionError if two passes differ.
verify_fixed_samples()

# Keep only the samples whose label is one of the selected classes.
def filter_classes_for_training(data, selected_classes):
    """Return the ``(sample, label)`` items of ``data`` whose label is in ``selected_classes``.

    Items are returned in ``data``'s own iteration order.
    """
    kept = []
    for item in data:
        if item[1] in selected_classes:
            kept.append(item)
    return kept

# Training data for the non-IID experiment.
# All shard indices (client_datasets[i].indices, original_indices,
# remaining_indices, fixed_forgotten_indices) are positions in filtered_dataset,
# which is ordered class by class. They must be applied to filtered_dataset itself.
# filter_classes_for_training(dst_train, ...) returns items in dst_train order, so
# indexing that list with these positions selects different images than the shards,
# the forget set and the retrained baseline refer to.
assert set(fixed_forgotten_indices).issubset(original_indices)
assert set(fixed_forgotten_indices).isdisjoint(remaining_indices)

# Full model: every client's original shard, INCLUDING the forgotten samples.
# client_datasets[FORGOTTEN_CLIENT_IDX] was replaced with the remaining samples
# above, so rebuild that client from original_indices.
full_client_datasets = list(client_datasets)
full_client_datasets[FORGOTTEN_CLIENT_IDX] = Subset(filtered_dataset, original_indices)
modified_client_loaders = [
    DataLoader(ds, batch_size=128, shuffle=True)
    for ds in full_client_datasets
]

# Federated training of the full model
print("联邦训练完整模型...")
full_net = LeNet(num_classes=5)  # 5 output classes
criterion = nn.CrossEntropyLoss()
global_round = 2

full_net = federated_train(
    full_net,
    modified_client_loaders,  # first-five-class data, forgotten samples included
    criterion,
    num_rounds=global_round,
    num_local_epochs=1,
    lr=0.001
)

# Retrain-from-scratch baseline: same shards without the forgotten samples.
modified_client_loaders_unlearning = [
    DataLoader(
        ds if idx != FORGOTTEN_CLIENT_IDX else Subset(filtered_dataset, remaining_indices),
        batch_size=128,
        shuffle=True
    )
    for idx, ds in enumerate(client_datasets)
]

unlearned_net = LeNet(num_classes=5)
print("federated unlearning training...")
unlearned_net = federated_train(
    unlearned_net,
    modified_client_loaders_unlearning,  # forgotten samples excluded
    criterion,
    num_rounds=global_round,
    num_local_epochs=1,
    lr=0.001
)

# Save both models
torch.save(full_net.state_dict(), "/home/ecs-user/fgi/federated_weight/noniid_federated_full_sample_100_round_2_partial.pth")
torch.save(unlearned_net.state_dict(), "/home/ecs-user/fgi/federated_weight/noniid_federated_unlearned_sample_100_round_2_partial.pth")


验证通过：遗忘数据集样本保持固定
联邦训练完整模型...
Communication Round 1/2
Communication Round 2/2
federated unlearning training...
Communication Round 1/2
Communication Round 2/2


In [15]:
# Second half of the CIFAR-10 label space (classes 5-9).
selected_classes = [label for label in range(5, 10)]


[5, 6, 7, 8, 9]

## Client-level unlearning (forget an entire client)

In [None]:
# Data preparation
transform = transforms.Compose([transforms.ToTensor()])
dst_train = datasets.CIFAR10("~/.torch", train=True, download=True, transform=transform)
dst_test = datasets.CIFAR10("~/.torch", train=False, download=True, transform=transform)  # not used in this cell

# IID split of the training data across client_num clients. The seeded generator
# makes the split, and therefore which samples the forgotten client held,
# reproducible across runs.
client_num = 4
client_datasets = torch.utils.data.random_split(
    dst_train,
    [len(dst_train)//client_num]*client_num,
    generator=torch.Generator().manual_seed(SEED)
)

# One DataLoader per client
client_batch_size = 32
client_loaders = [
    DataLoader(
        ds,
        batch_size=client_batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED),
    )
    for ds in client_datasets
]

# Initialise the global model.
# NOTE: this calls whichever federated_train definition was executed last; it must
# build client models compatible with a 10-class LeNet.
full_net = LeNet(num_classes=10)
criterion = nn.CrossEntropyLoss()

# Federated training on all clients
print("联邦训练全局模型...")
full_net = federated_train(
    full_net,
    client_loaders,
    criterion,
    num_rounds=50,      # communication rounds
    num_local_epochs=1, # local epochs per client
    lr=0.001
)
torch.save(full_net.state_dict(), "federated_full_model.pth")

# Client-level unlearning: retrain without client index 3 (0-based, i.e. the 4th client).
forgotten_client_idx = 3
remaining_loaders = [loader for idx, loader in enumerate(client_loaders) if idx != forgotten_client_idx]

# Fresh global model for the retrained baseline
unlearned_net = LeNet(num_classes=10)

print("联邦未学习训练...")
unlearned_net = federated_train(
    unlearned_net,
    remaining_loaders,
    criterion,
    num_rounds=50,
    num_local_epochs=1,
    lr=0.001
)
torch.save(unlearned_net.state_dict(), "federated_unlearned_model.pth")

## ResNet-20

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
from utils import label_to_onehot, cross_entropy_for_onehot
from models.vision import weights_init, LeNet
from models.resnet import resnet20
num_classes = 10

# Fall back to CPU when no GPU is available (federated_train uses the same rule).
net = resnet20(num_classes).to("cuda" if torch.cuda.is_available() else "cpu")
compress_rate = 1.0  # not read by any of the visible cells
torch.manual_seed(1234)
net.apply(weights_init)
# One-hot loss; the training cell below reassigns criterion = nn.CrossEntropyLoss().
criterion = cross_entropy_for_onehot


def federated_train(global_model, client_loaders, criterion, num_rounds=10, num_local_epochs=1, lr=0.001):
    """Train ``global_model`` with FedAvg and return it.

    In every communication round each client starts from a copy of the current
    global weights, runs ``num_local_epochs`` epochs of Adam on its own loader,
    and the server then replaces the global weights with the unweighted mean of
    the client state dicts.

    Args:
        global_model: model to train; updated in place and also returned.
        client_loaders: one iterable of ``(images, labels)`` batches per client.
        criterion: loss called as ``criterion(outputs, labels)``.
        num_rounds: number of communication rounds.
        num_local_epochs: local epochs per client per round.
        lr: learning rate of each client's Adam optimizer.

    Returns:
        The same ``global_model`` object, holding the averaged weights.

    Raises:
        ValueError: if ``client_loaders`` is empty.
    """
    import copy  # local import keeps this cell self-contained

    # Materialise once so every round sees the same clients.
    client_loaders = list(client_loaders)
    if not client_loaders:
        raise ValueError("federated_train needs at least one client loader")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model.to(device)

    for round_idx in range(num_rounds):
        print(f"Communication Round {round_idx+1}/{num_rounds}")
        client_states = []

        for loader in client_loaders:
            # Clone the global model itself rather than rebuilding a hard-coded
            # architecture, so this works for any model or class count.
            local_model = copy.deepcopy(global_model).to(device)
            optimizer = optim.Adam(local_model.parameters(), lr=lr)

            # Local training
            local_model.train()
            for _ in range(num_local_epochs):
                for images, labels in loader:
                    images, labels = images.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = local_model(images)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

            client_states.append(local_model.state_dict())

        # FedAvg aggregation: unweighted mean over clients, computed in float.
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are cast back
        # to their original dtype.
        averaged_state = {
            key: torch.stack([state[key].float() for state in client_states], 0)
            .mean(0)
            .to(value.dtype)
            for key, value in global_model.state_dict().items()
        }
        global_model.load_state_dict(averaged_state)

    return global_model

In [6]:
import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Set global random seeds for reproducibility.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the CIFAR-10 training split, converted to tensors with ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
dst_train = datasets.CIFAR10(
    root="~/.torch", 
    train=True, 
    download=True, 
    transform=transform
)

# Number of clients and unlearning parameters.
CLIENT_NUM = 4
FORGOTTEN_CLIENT_IDX = 3  # index (0-based) of the client that loses part of its data
FORGET_SIZE = 256        # number of samples to forget


# Split the training set into CLIENT_NUM equal shards; the seeded generator makes
# the split deterministic. Assumes len(dst_train) is divisible by CLIENT_NUM
# (random_split requires the lengths to sum to the dataset size).
client_datasets = torch.utils.data.random_split(
    dst_train,
    [len(dst_train)//CLIENT_NUM]*CLIENT_NUM,
    generator=torch.Generator().manual_seed(SEED)  # fixed split seed
)

# dst_train indices held by the target client before anything is removed.
target_dataset = client_datasets[FORGOTTEN_CLIENT_IDX]
original_indices = target_dataset.indices.copy()  # copy of the original index list

# Forget set: the FORGET_SIZE smallest dst_train indices of that client (deterministic).
fixed_forgotten_indices = sorted(original_indices)[:FORGET_SIZE]  # first FORGET_SIZE (=256) in ascending index order

# Replace the target client's shard with the remaining samples.
# NOTE(review): from here on client_datasets, and the client_loaders built below,
# no longer contain the forgotten samples. A "full" model must rebuild this client
# from original_indices.
remaining_indices = list(set(original_indices) - set(fixed_forgotten_indices))
client_datasets[FORGOTTEN_CLIENT_IDX] = Subset(dst_train, remaining_indices)

# Loader over the forgotten samples (fixed order, no shuffling).
forgotten_dataset = Subset(dst_train, fixed_forgotten_indices)
forgotten_loader = DataLoader(
    forgotten_dataset, 
    batch_size=32, 
    shuffle=False
)

# One loader per client, built from the updated shards (forgotten samples excluded).
client_loaders = [
    DataLoader(
        ds, 
        batch_size=32, 
        shuffle=True,  # shuffled, but each loader has its own seeded generator
        generator=torch.Generator().manual_seed(SEED))
    for ds in client_datasets ]

# Check that the forget-set loader yields exactly the same data on every pass.
def verify_fixed_samples(loader=None):
    """Assert that two passes over ``loader`` yield identical batches.

    Each batch is fingerprinted by the sum of its inputs and its list of labels.
    The two passes are compared batch by batch, so reordered or swapped batches
    are detected, not just a change in the overall total.

    Args:
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``forgotten_loader``.

    Raises:
        AssertionError: if the two passes differ.
    """
    if loader is None:
        loader = forgotten_loader

    def _fingerprints():
        # One (input sum, labels) pair per batch, in iteration order.
        return [
            (batch[0].sum().item(), torch.as_tensor(batch[1]).tolist())
            for batch in loader
        ]

    first_run = _fingerprints()
    second_run = _fingerprints()

    assert len(first_run) == len(second_run), "样本不固定!"
    for (sum_a, labels_a), (sum_b, labels_b) in zip(first_run, second_run):
        assert np.isclose(sum_a, sum_b) and labels_a == labels_b, "样本不固定!"
    print("验证通过：遗忘数据集样本保持固定")

# Sanity check on forgotten_loader: raises AssertionError if two passes differ.
verify_fixed_samples()

Files already downloaded and verified
验证通过：遗忘数据集样本保持固定


In [7]:
import copy

client_batch_size = 128

# Full-model data. The data cell replaced client_datasets[FORGOTTEN_CLIENT_IDX]
# with the remaining samples, and client_loaders were built from that with
# batch_size=32. Rebuild that client from its original shard. Both runs below use
# client_batch_size, so they differ only in the forgotten samples.
assert set(fixed_forgotten_indices).issubset(original_indices)
assert set(fixed_forgotten_indices).isdisjoint(remaining_indices)
full_client_datasets = list(client_datasets)
full_client_datasets[FORGOTTEN_CLIENT_IDX] = Subset(dst_train, original_indices)
full_client_loaders = [
    DataLoader(
        ds,
        batch_size=client_batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED),
    )
    for ds in full_client_datasets
]

print("联邦训练完整模型...")
# Initialise the global model (CPU fallback when no GPU is present)
full_net = resnet20(num_classes).to("cuda" if torch.cuda.is_available() else "cpu")
compress_rate = 0.5  # not read by any of the visible cells
torch.manual_seed(1234)
full_net.apply(weights_init)
# Snapshot the initial weights so the retrained baseline starts from the same point.
initial_state = copy.deepcopy(full_net.state_dict())

criterion = nn.CrossEntropyLoss()
global_round = 50

full_net = federated_train(
    full_net,
    full_client_loaders,  # includes the forgotten samples of the target client
    criterion,
    num_rounds=global_round,
    num_local_epochs=1,
    lr=0.001
)

# Retrain-from-scratch baseline: same clients, without the forgotten samples.
modified_client_loaders = [
    DataLoader(
        ds if idx != FORGOTTEN_CLIENT_IDX else Subset(dst_train, remaining_indices),
        batch_size=client_batch_size,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED),
    )
    for idx, ds in enumerate(client_datasets)
]

unlearned_net = resnet20(num_classes=10)
unlearned_net.load_state_dict(initial_state)  # same initialisation as full_net
print("federated unlearning training...")
unlearned_net = federated_train(
    unlearned_net,
    modified_client_loaders,  # forgotten samples excluded
    criterion,
    num_rounds=global_round,
    num_local_epochs=1,
    lr=0.001
)

# Save both models
torch.save(full_net.state_dict(), "federated_full_resnet_sample_256_round_50_partial.pth")
torch.save(unlearned_net.state_dict(), "federated_unlearned_resnet_sample_256_round_50_partial.pth")

联邦训练完整模型...
Communication Round 1/50


Communication Round 2/50
Communication Round 3/50
Communication Round 4/50
Communication Round 5/50
Communication Round 6/50
Communication Round 7/50
Communication Round 8/50
Communication Round 9/50
Communication Round 10/50
Communication Round 11/50
Communication Round 12/50
Communication Round 13/50
Communication Round 14/50
Communication Round 15/50
Communication Round 16/50
Communication Round 17/50
Communication Round 18/50
Communication Round 19/50
Communication Round 20/50
Communication Round 21/50
Communication Round 22/50
Communication Round 23/50
Communication Round 24/50
Communication Round 25/50
Communication Round 26/50
Communication Round 27/50
Communication Round 28/50
Communication Round 29/50
Communication Round 30/50
Communication Round 31/50
Communication Round 32/50
Communication Round 33/50
Communication Round 34/50
Communication Round 35/50
Communication Round 36/50
Communication Round 37/50
Communication Round 38/50
Communication Round 39/50
Communication Round 

KeyboardInterrupt: 