In [None]:
import numbers

import torch

from fl import fedsgd
from models import encoder
from attacks.idgla import IDGLA
from attacks.fgla import FGLA
from utils.datasets import get_dataset_info
from utils.utils import visualize, init_seed

# Seed first so that the model's weight initialisation below is reproducible.
init_seed(1)

# ---- Experiment configuration ----
test_dataset = "cifar100"  # dataset the attacks are evaluated on
image_size = 224           # input resolution passed to the FL model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- FL global model: ResNet-50 encoder followed by a fully-connected head ----
encoder_ = encoder.resnet50()
dataset_info = get_dataset_info(test_dataset)
fl_model = encoder.CNN_FC(
    encoder=encoder_,
    fc_layers=[1000],
    image_size=image_size,
    channels=3,
    record_test_info=True,
)
# Position of the first FC layer's weight gradient inside the model's gradient list;
# the later cells read the matching bias gradient at the following position.
fc_1_w_index = fl_model.get_fc_1_w_index()
fl_model = fl_model.to(device)

# ---- Attacks under evaluation ----
# Both attacks use a ResNet-50 decoder on the same device; they differ in decoder weights.
shared_attack_kwargs = dict(decoder_class="resnet50", device=device)
idgla = IDGLA(
    decoder_name="FGLA_resnet50_imagenet_std",  # name of the decoder weights to load
    ica_num_epochs=4000,                        # number of ICA iterations
    is_log=False,
    **shared_attack_kwargs,
)
fgla = FGLA(
    decoder_name="FGLA_resnet50_imagenet_none",  # name of the decoder weights to load
    **shared_attack_kwargs,
)

# Visualization: reconstruct a single small batch with each attack and compare it to the originals

In [None]:

# Reseed so the shuffled batch order (and hence the reconstructed images) is reproducible.
init_seed(1)
max_test_ep = 1  # number of batches (experiment repetitions) to average metrics over
batch_size = 8   # client batch size

sgd_iterator = fedsgd.data_iterator(
    model=fl_model,
    dataset_name=test_dataset,
    batch_size=batch_size,
    image_size=image_size,
    shuffle=True,
    device=device,
)

attacks = [fgla, idgla]
original_x = None  # first ground-truth batch, kept for the side-by-side visualization

print("attack_name, mse, psnr, ssim, feature_cos_sim, feature_mse, time")
# Per-attack running sums of numeric metrics; None until that attack's first result arrives.
attack_res_dict = {type(attack).__name__: None for attack in attacks}
ep = 0
for data in sgd_iterator:
    x = data["images"]
    y = data["labels"]
    grads = data["grads"]
    fc_1_input = data["fc_1_input"]
    fc_1_w_grad = grads[fc_1_w_index]
    fc_1_b_grad = grads[fc_1_w_index + 1]  # bias gradient sits right after the weight gradient
    if original_x is None:
        original_x = x

    for attack in attacks:
        attack_name = type(attack).__name__
        attack_res = attack.assess(
            grads,  # gradients, or parameter updates
            x,  # ground-truth samples, used for comparison
            y,  # ground-truth labels
            fl_model,  # FL global model
            fc_1_w_grad,  # weight gradient of the first FC layer
            fc_1_b_grad,  # bias gradient of the first FC layer
            fc_1_input,  # input to the first FC layer
            batch_size,  # batch size
        )
        running = attack_res_dict[attack_name]
        if running is None:
            # Copy instead of storing the returned dict: if an attack reuses its result dict
            # across calls, accumulating into it would corrupt the running sums.
            attack_res_dict[attack_name] = {
                **attack_res,
                "attack": attack_name,
                "dummy_x": attack.dummy_x,  # reconstruction of the first batch, for display
            }
        else:
            for k, v in attack_res.items():
                # numbers.Real also matches numpy scalar metrics, which (int, float) missed.
                if isinstance(v, numbers.Real):
                    running[k] += v
    ep += 1
    if ep >= max_test_ep:
        break

if ep == 0:
    raise RuntimeError(f"data iterator for {test_dataset!r} yielded no batches")

for attack_name, attack_res in attack_res_dict.items():
    # Convert running sums into per-batch means.
    for k, v in attack_res.items():
        if isinstance(v, numbers.Real):
            attack_res[k] = v / ep
    mse = attack_res["mse"]
    psnr = attack_res["psnr"]
    ssim = attack_res["ssim"]
    t = attack_res["time"]
    # Feature-space metrics are optional in an attack's result; report 0 when absent.
    feature_cos_sim = attack_res.get("feature_cos_sim", 0.0)
    feature_mse = attack_res.get("feature_mse", 0.0)
    print(f"{attack_name}, {mse:.4f}, {psnr:.4f}, {ssim:.4f}, {feature_cos_sim:.4f}, {feature_mse:.4f}, {t:.4f}")

for attack_name, attack_res in attack_res_dict.items():
    visualize(attack_res["dummy_x"], attack_name)
visualize(original_x, "Original")

# Table: per-attack metrics averaged over many batches, for increasing batch sizes

In [None]:

# Reseed so every sweep run sees identical data and attack randomness.
init_seed(1)
max_test_ep = 100  # maximum number of batches (experiment repetitions) averaged per batch size
attacks = [fgla, idgla]
batch_sizes = [64, 128, 256, 512, 1000]
SHOW_RECONSTRUCTIONS = False  # set True to plot the first batch's reconstructions per batch size

# CSV output: header and rows share the same "," separator so the table parses cleanly.
print("batch_size,attack_name,psnr,ssim,feature_cos_sim,time")
for batch_size in batch_sizes:
    # The iterator runs on the CPU to avoid exhausting GPU memory with large batches.
    # NOTE(review): fl_model was moved to `device` in the setup cell; confirm data_iterator
    # handles a model on another device and does not move fl_model itself.
    sgd_iterator = fedsgd.data_iterator(
        model=fl_model,
        dataset_name=test_dataset,
        batch_size=batch_size,
        image_size=image_size,
        shuffle=False,
        device="cpu",
    )

    original_x = None  # first ground-truth batch, only used by the optional visualization
    # Per-attack running sums of numeric metrics; None until that attack's first result arrives.
    attack_res_dict = {type(attack).__name__: None for attack in attacks}
    ep = 0
    for data in sgd_iterator:
        # Move the batch tensors the attacks consume onto the attack device.
        x = data["images"].to(device)
        y = data["labels"].to(device)
        # NOTE(review): the full `grads` collection stays on the CPU; only the first FC
        # layer's gradients are moved. Confirm each attack's assess() copes with that.
        grads = data["grads"]
        fc_1_input = data["fc_1_input"].to(device)
        fc_1_w_grad = grads[fc_1_w_index].to(device)
        fc_1_b_grad = grads[fc_1_w_index + 1].to(device)  # bias gradient follows the weight gradient
        if original_x is None:
            original_x = x

        for attack in attacks:
            attack_name = type(attack).__name__
            attack_res = attack.assess(
                grads,  # gradients, or parameter updates
                x,  # ground-truth samples, used for comparison
                y,  # ground-truth labels
                fl_model,  # FL global model
                fc_1_w_grad,  # weight gradient of the first FC layer
                fc_1_b_grad,  # bias gradient of the first FC layer
                fc_1_input,  # input to the first FC layer
                batch_size,  # batch size
            )
            running = attack_res_dict[attack_name]
            if running is None:
                # Copy instead of storing the returned dict: if an attack reuses its result
                # dict across calls, accumulating into it would corrupt the running sums.
                attack_res_dict[attack_name] = {
                    **attack_res,
                    "attack": attack_name,
                    "dummy_x": attack.dummy_x,  # first-batch reconstruction, for visualization
                }
            else:
                for k, v in attack_res.items():
                    # numbers.Real also matches numpy scalar metrics, which (int, float) missed.
                    if isinstance(v, numbers.Real):
                        running[k] += v
        ep += 1
        if ep >= max_test_ep:
            break

    if ep == 0:
        raise RuntimeError(
            f"data iterator for {test_dataset!r} yielded no batches (batch_size={batch_size})"
        )

    for attack_name, attack_res in attack_res_dict.items():
        # Convert running sums into per-batch means.
        for k, v in attack_res.items():
            if isinstance(v, numbers.Real):
                attack_res[k] = v / ep
        psnr = attack_res["psnr"]
        ssim = attack_res["ssim"]
        t = attack_res["time"]
        # Feature-space metric is optional in an attack's result; report 0 when absent.
        feature_cos_sim = attack_res.get("feature_cos_sim", 0.0)
        print(f"{batch_size},{attack_name},{psnr:.4f},{ssim:.4f},{feature_cos_sim:.4f},{t:.4f}")

    if SHOW_RECONSTRUCTIONS:
        for attack_name, attack_res in attack_res_dict.items():
            visualize(attack_res["dummy_x"], attack_name)
        visualize(original_x, "Original")