In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import numpy as np
import pickle

sys.path.append('D:\\Program\\MyCode\\Round_robin_SL\\Round-Robin')
sys.path.append('D:\\Program\\MyCode\\Round_robin_SL\\Defence')
from models import *
from clients_datasets import *
from tqdm.notebook import tqdm
from utils import *
from AttFunc import *
from Fisher_LeNet import *
from Def import *

In [2]:
batch_size = 600
epochs = 30
NC = 10
dataset = 'svhn'

clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
clients_testloader = load_clients_testsets(dataset, NC, batch_size)

server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
client_level = 1
server_level = 6

criterion = torch.nn.CrossEntropyLoss()

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat


In [3]:
att_type = 'sign'
acc0 = 95.48
acc1_varying = []
drop_varying = []
iters = 10
mode = 'Fisher'


# 初始化历史权重
history_weights = {client_id: {param_name: None for param_name, _ in clients[client_id].named_parameters()} for client_id in range(NC)}


for iter in tqdm(range(iters), desc="Training", unit="iter"):
    batch_size = 600
    epochs = 30
    NC = 10
    dataset = 'svhn'

    clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
    clients_testloader = load_clients_testsets(dataset, NC, batch_size)

    server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
    client_level = 1
    server_level = 6

    criterion = torch.nn.CrossEntropyLoss()
    # train
    mal_client_id = [8]
    server.train()
    for i in range(NC):
        clients[i].train()
    server.apply(init_weights)
    for client in clients:
        client.apply(init_weights)
    last_trained_params = clients[0].state_dict()

    for epoch in range(epochs):
        beta = 1.110609759342772
        for idx, client in enumerate(clients):
            # load sharing parameters
            client.load_state_dict(last_trained_params)
            # 防御：基于Fisher或Taylor信息矩阵检测和修复
            params_checked = defence(client, history_weights, idx, mode=mode, threshold_multiplier=3, percentile=0.33)
            # 更新参数
            for param_name, param in client.named_parameters():
                if param_name in params_checked:
                    param.data = params_checked[param_name]
            # 训练
            for j, data in enumerate(clients_trainloader[idx]):
                images, labels = data
                images = images.cuda()
                labels = labels.cuda()
                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                clients_opts[idx].zero_grad()
                server_opt.zero_grad()
                loss = criterion(output, labels)
                loss.backward()
                clients_opts[idx].step()
                server_opt.step()
            # weight sharing
            last_trained_params = client.state_dict()
            # attack part
            if idx in mal_client_id :
                benign_params = list(client.parameters())[:2]

                Taylor_scores = {}
                for param_name, param in client.named_parameters():
                    if param.grad is not None:
                        Taylor_scores[param_name] = torch.abs(param * param.grad)

                weight_positions = []
                bias_positions = []
                weight_positions.append(find_positions(Taylor_scores['conv1.0.weight'].cpu().detach().numpy(), 0.33))
                bias_positions.append(find_positions(Taylor_scores['conv1.0.bias'].cpu().detach().numpy(), 0.33))

                mal_params = fisher_perturbation(client_level, beta, benign_params, weight_positions, bias_positions, type=att_type)
                last_trained_params['conv1.0.weight'] = mal_params[0]
                last_trained_params['conv1.0.bias'] = mal_params[1]

        # 更新历史权重
        for idx, client in enumerate(clients):
            params_to_update = {param_name: param.data.clone() for param_name, param in client.named_parameters()}
            update_history(history_weights, idx, params_to_update)

    for i in range(NC):
        clients[i].load_state_dict(last_trained_params)

    # test
    server.eval()
    for i in range(NC):
        clients[i].eval()
    with torch.no_grad():
        clients_acc1 = []
        clients_drop = []
        for idx, client in enumerate(clients):
            correct = 0
            total = 0
            for data in clients_testloader[idx]:
                images, labels = data
                images, labels = images.cuda(), labels.cuda()

                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                _, pre = torch.max(output.data, 1)
                total += images.shape[0]
                correct += (pre == labels).sum().item()
            acc1 = 100 * correct / total
            clients_acc1.append(acc1)
            drop = acc0 - acc1
            clients_drop.append(drop)
    acc1 = np.mean(clients_acc1)
    acc1_varying.append(acc1)
    drop = np.mean(clients_drop)
    drop_varying.append(drop)

Training:   0%|          | 0/10 [00:00<?, ?iter/s]

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat
Anomalies detected in conv1.0.weight: 96 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 3 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 6 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 2 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 8 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 12 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 2 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 5 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 1 elements, replacing with history values.
Anomal

In [4]:
print(acc1_varying)
print(drop_varying)
acc1_mean, acc1_range = data_process(acc1_varying)
print('acc1:')
print(acc1_mean)
print(acc1_range)
print('---------------------')
drop_mean, drop_range = data_process(drop_varying)
print('drop:')
print(drop_mean)
print(drop_range)

[60.23358224853398, 78.16533622731187, 70.91665471661258, 80.56240937875651, 67.39785506856381, 48.46736573007748, 62.10826985051515, 70.21364041136512, 76.72868154610686, 47.36876627641626]
[35.246417751466026, 17.314663772688142, 24.563345283387424, 14.917590621243482, 28.082144931436197, 47.012634269922515, 33.371730149484854, 25.266359588634877, 18.751318453893155, 48.111233723583744]
acc1:
67.65910501176417
8.808384866097434
---------------------
drop:
27.820894988235843
8.80838486609743


In [5]:
att_type = 'sign'
acc0 = 95.48
acc1_varying = []
drop_varying = []
iters = 10
mode = 'Taylor'


# 初始化历史权重
history_weights = {client_id: {param_name: None for param_name, _ in clients[client_id].named_parameters()} for client_id in range(NC)}


for iter in tqdm(range(iters), desc="Training", unit="iter"):
    batch_size = 600
    epochs = 30
    NC = 10
    dataset = 'svhn'

    clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
    clients_testloader = load_clients_testsets(dataset, NC, batch_size)

    server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
    client_level = 1
    server_level = 6

    criterion = torch.nn.CrossEntropyLoss()
    # train
    mal_client_id = [8]
    server.train()
    for i in range(NC):
        clients[i].train()
    server.apply(init_weights)
    for client in clients:
        client.apply(init_weights)
    last_trained_params = clients[0].state_dict()

    for epoch in range(epochs):
        beta = 1.110609759342772
        for idx, client in enumerate(clients):
            # load sharing parameters
            client.load_state_dict(last_trained_params)
            # 防御：基于Fisher或Taylor信息矩阵检测和修复
            params_checked = defence(client, history_weights, idx, mode=mode, threshold_multiplier=3, percentile=0.33)
            # 更新参数
            for param_name, param in client.named_parameters():
                if param_name in params_checked:
                    param.data = params_checked[param_name]
            # 训练
            for j, data in enumerate(clients_trainloader[idx]):
                images, labels = data
                images = images.cuda()
                labels = labels.cuda()
                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                clients_opts[idx].zero_grad()
                server_opt.zero_grad()
                loss = criterion(output, labels)
                loss.backward()
                clients_opts[idx].step()
                server_opt.step()
            # weight sharing
            last_trained_params = client.state_dict()
            # attack part
            if idx in mal_client_id :
                benign_params = list(client.parameters())[:2]

                Taylor_scores = {}
                for param_name, param in client.named_parameters():
                    if param.grad is not None:
                        Taylor_scores[param_name] = torch.abs(param * param.grad)

                weight_positions = []
                bias_positions = []
                weight_positions.append(find_positions(Taylor_scores['conv1.0.weight'].cpu().detach().numpy(), 0.33))
                bias_positions.append(find_positions(Taylor_scores['conv1.0.bias'].cpu().detach().numpy(), 0.33))

                mal_params = fisher_perturbation(client_level, beta, benign_params, weight_positions, bias_positions, type=att_type)
                last_trained_params['conv1.0.weight'] = mal_params[0]
                last_trained_params['conv1.0.bias'] = mal_params[1]

        # 更新历史权重
        for idx, client in enumerate(clients):
            params_to_update = {param_name: param.data.clone() for param_name, param in client.named_parameters()}
            update_history(history_weights, idx, params_to_update)

    for i in range(NC):
        clients[i].load_state_dict(last_trained_params)

    # test
    server.eval()
    for i in range(NC):
        clients[i].eval()
    with torch.no_grad():
        clients_acc1 = []
        clients_drop = []
        for idx, client in enumerate(clients):
            correct = 0
            total = 0
            for data in clients_testloader[idx]:
                images, labels = data
                images, labels = images.cuda(), labels.cuda()

                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                _, pre = torch.max(output.data, 1)
                total += images.shape[0]
                correct += (pre == labels).sum().item()
            acc1 = 100 * correct / total
            clients_acc1.append(acc1)
            drop = acc0 - acc1
            clients_drop.append(drop)
    acc1 = np.mean(clients_acc1)
    acc1_varying.append(acc1)
    drop = np.mean(clients_drop)
    drop_varying.append(drop)

Training:   0%|          | 0/10 [00:00<?, ?iter/s]

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat
Anomalies detected in conv1.0.weight: 82 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 2 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 5 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 2 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 3 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 3 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 6 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 10 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 6 elements, replacing with history values.
Anom

In [6]:
print(acc1_varying)
print(drop_varying)
acc1_mean, acc1_range = data_process(acc1_varying)
print('acc1:')
print(acc1_mean)
print(acc1_range)
print('---------------------')
drop_mean, drop_range = data_process(drop_varying)
print('drop:')
print(drop_mean)
print(drop_range)

[71.4774382978874, 65.2198116553451, 64.20560613920014, 73.58249933758341, 77.93868943609318, 68.26983871262806, 54.79417875982633, 62.70369383548346, 70.46710843508582, 55.22824160707868]
[24.002561702112597, 30.26018834465491, 31.274393860799876, 21.897500662416583, 17.541310563906826, 27.210161287371953, 40.68582124017367, 32.776306164516555, 25.012891564914174, 40.251758392921325]
acc1:
65.89841883572443
4.064232573427915
---------------------
drop:
29.581581164275576
4.0642325734279225
