In [None]:
import torch

from fl import fedavg
from models import encoder
from attacks.idgla import IDGLA
from attacks.fgla import FGLA
from utils.datasets import get_dataset_info
from utils.utils import visualize, init_seed
from utils.datasets import get_dataloader

# Seed every RNG first so model weight initialisation is reproducible.
init_seed(1)

# ---- Experiment configuration ----
test_dataset = "cifar100"  # dataset the attacks are evaluated on
batch_size = 8  # images per batch
image_size = 224  # input resolution the FL model is constructed for
max_test_ep = 1  # number of experiment repetitions (FL rounds attacked)

# Use the GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Federated-learning model under attack ----
encoder_ = encoder.resnet50()
# NOTE(review): `dataset_info` is not referenced anywhere else in this notebook —
# confirm whether `fc_layers` was meant to be derived from it.
dataset_info = get_dataset_info(test_dataset)
fl_model = encoder.CNN_FC(
    encoder=encoder_,
    fc_layers=[1000],
    image_size=image_size,
    channels=3,
    record_test_info=True,
)

# Evaluation split, in a fixed order so repeated runs attack the same batches.
dataloader = get_dataloader(
    test_dataset, batch_size, shuffle=False, train=False
)
# Position of the first FC layer's weight among the model's parameter updates
# (its bias is read from the next position in the attack loop).
fc_1_w_index = fl_model.get_fc_1_w_index()
fl_model = fl_model.to(device)

# ---- Gradient-leakage attacks under test ----
# Settings both attacks share: decoder architecture class and target device.
_decoder_common = {"decoder_class": "resnet50", "device": device}
idgla = IDGLA(
    decoder_name="FGLA_resnet50_imagenet_std",  # name of the decoder weights to load
    ica_num_epochs=4000,  # number of ICA iterations
    is_log=False,
    **_decoder_common,
)
fgla = FGLA(
    decoder_name="FGLA_resnet50_imagenet_none",  # name of the decoder weights to load
    **_decoder_common,
)

In [None]:
# Attacks to evaluate; results are aggregated per attack class name.
attacks = [fgla, idgla]
# Ground-truth images kept for the final visualization.
original_x = None

local_epochs = 2  # local training epochs per FL round
local_lr = 0.1  # local learning rate
step_item = 13  # forwarded to fedavg.data_iterator; semantics are defined there
# CSV header. Separators are bare commas to match the data rows printed after the run.
print("attack_name,mse,psnr,ssim,feature_cos_sim,feature_mse,time")
iterator = fedavg.data_iterator(
    model=fl_model,
    dataloader=dataloader,
    local_epochs=local_epochs,
    local_lr=local_lr,
    step_item=step_item,
    device=device,
)
# One accumulator slot per attack. It stays None until the attack's first result arrives.
attack_res_dict = {type(attack).__name__: None for attack in attacks}
ep = 0  # number of FL rounds processed so far
# Run every attack on each FL round's update, summing numeric metrics per attack.
for data in iterator:
    batch_images = data["batch_images"]
    batch_labels = data["batch_labels"]
    params_update = data["params_update"]
    batch_fc_1_inputs = data["batch_fc_1_inputs"]
    # The first FC layer's bias update sits right after its weight update.
    fc_1_w_update = params_update[fc_1_w_index]
    fc_1_b_update = params_update[fc_1_w_index + 1]
    fc_1_input = batch_fc_1_inputs[0]

    for attack in attacks:
        attack_name = type(attack).__name__
        if attack_name in ("IDGLA", "CPA"):
            # These attacks target the images of all local batches at once.
            x = torch.cat(batch_images, dim=0).to(device)
            y = torch.cat(batch_labels, dim=0).to(device)
        else:
            # Other attacks target only the first local batch.
            x = batch_images[0].to(device)
            y = batch_labels[0].to(device)
        # Derive the image count from the tensor actually passed, so it always
        # agrees with `x` (e.g. a short final batch or a different number of local batches).
        bs = x.shape[0]

        attack_res = attack.assess(
            params_update,  # parameter update
            x,  # ground-truth images, used for comparison
            y,  # ground-truth labels
            fl_model,  # FL global model
            fc_1_w_update,  # weight update of the first FC layer
            fc_1_b_update,  # bias update of the first FC layer
            fc_1_input,  # input to the first FC layer
            bs,  # number of input images
        )
        dummy_x = attack.dummy_x
        if attack_res_dict[attack_name] is None:
            # The first result becomes the accumulator. Copy it so later `+=` updates
            # never mutate an object the attack may still hold.
            first_res = dict(attack_res)
            first_res["attack"] = attack_name
            first_res["dummy_x"] = dummy_x
            attack_res_dict[attack_name] = first_res
            # NOTE: this is overwritten for each attack, so it ends up as the
            # ground truth of the last attack in `attacks`.
            original_x = x
        else:
            accumulated = attack_res_dict[attack_name]
            for k, v in attack_res.items():
                if isinstance(v, (int, float)):
                    accumulated[k] += v
    ep += 1
    if ep >= max_test_ep:
        break
import warnings

# Average the summed numeric metrics over the processed rounds and print one CSV row per attack.
for attack_name, attack_res in attack_res_dict.items():
    if attack_res is None:
        # No round produced a result for this attack (e.g. the iterator yielded nothing).
        warnings.warn(f"no results recorded for {attack_name}; skipping")
        continue
    # Average over rounds.
    for k, v in attack_res.items():
        if isinstance(v, (int, float)):
            attack_res[k] = v / ep
    mse = attack_res["mse"]
    psnr = attack_res["psnr"]
    ssim = attack_res["ssim"]
    t = attack_res["time"]
    feature_cos_sim = attack_res.get("feature_cos_sim", 0.0)
    feature_mse = attack_res.get("feature_mse", 0.0)
    print(f"{attack_name},{mse:.4f},{psnr:.4f},{ssim:.4f},{feature_cos_sim:.4f},{feature_mse:.4f},{t:.4f}")

# Show each attack's first-round reconstruction, then the ground-truth images.
for attack_name, attack_res in attack_res_dict.items():
    if attack_res is not None:
        visualize(attack_res["dummy_x"], attack_name)
if original_x is not None:
    visualize(original_x, "Original")