In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import numpy as np
import pickle

sys.path.append('D:\\Program\\MyCode\\Round_robin_SL\\Round-Robin')
sys.path.append('D:\\Program\\MyCode\\Round_robin_SL\\Defence')
from models import *
from clients_datasets import *
from tqdm.notebook import tqdm
from utils import *
from AttFunc import *
from Fisher_LeNet import *
from Def import *

In [2]:
batch_size = 600
epochs = 30
NC = 10
dataset = 'cifar10'

clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
clients_testloader = load_clients_testsets(dataset, NC, batch_size)

server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
client_level = 1
server_level = 6

criterion = torch.nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


In [3]:
att_type = 'sign'
acc0 = 85.52
acc1_varying = []
drop_varying = []
iters = 10
mode = 'Fisher'


# 初始化历史权重
history_weights = {client_id: {param_name: None for param_name, _ in clients[client_id].named_parameters()} for client_id in range(NC)}


for iter in tqdm(range(iters), desc="Training", unit="iter"):
    batch_size = 600
    epochs = 30
    NC = 10
    dataset = 'cifar10'

    clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
    clients_testloader = load_clients_testsets(dataset, NC, batch_size)

    server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
    client_level = 1
    server_level = 6

    criterion = torch.nn.CrossEntropyLoss()
    # train
    mal_client_id = [8]
    server.train()
    for i in range(NC):
        clients[i].train()
    server.apply(init_weights)
    for client in clients:
        client.apply(init_weights)
    last_trained_params = clients[0].state_dict()

    for epoch in range(epochs):
        beta = 0.872058758893521
        for idx, client in enumerate(clients):
            # load sharing parameters
            client.load_state_dict(last_trained_params)
            # 防御：基于Fisher或Taylor信息矩阵检测和修复
            params_checked = defence(client, history_weights, idx, mode=mode, threshold_multiplier=3, percentile=0.33)
            # 更新参数
            for param_name, param in client.named_parameters():
                if param_name in params_checked:
                    param.data = params_checked[param_name]
            # 训练
            for j, data in enumerate(clients_trainloader[idx]):
                images, labels = data
                images = images.cuda()
                labels = labels.cuda()
                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                clients_opts[idx].zero_grad()
                server_opt.zero_grad()
                loss = criterion(output, labels)
                loss.backward()
                clients_opts[idx].step()
                server_opt.step()
            # weight sharing
            last_trained_params = client.state_dict()
            # attack part
            if idx in mal_client_id :
                benign_params = list(client.parameters())[:2]

                fisher_matrix = {}
                for param_name, param in client.named_parameters():
                    if param_name == 'conv1.0.weight':
                        grad = param.grad.cpu().detach().numpy()
                        if param_name not in fisher_matrix:
                            fisher_matrix[param_name] = grad ** 2
                        else:
                            fisher_matrix[param_name] += grad ** 2
                    if param_name == 'conv1.0.bias':
                        grad = param.grad.cpu().detach().numpy()
                        if param_name not in fisher_matrix:
                            fisher_matrix[param_name] = grad ** 2
                        else:
                            fisher_matrix[param_name] += grad ** 2
                weight_positions = []
                bias_positions = []
                weight_positions.append(find_positions(fisher_matrix['conv1.0.weight'], 0.33))
                bias_positions.append(find_positions(fisher_matrix['conv1.0.bias'], 0.33))

                mal_params = fisher_perturbation(client_level, beta, benign_params, weight_positions, bias_positions, type=att_type)
                last_trained_params['conv1.0.weight'] = mal_params[0]
                last_trained_params['conv1.0.bias'] = mal_params[1]

        # 更新历史权重
        for idx, client in enumerate(clients):
            params_to_update = {param_name: param.data.clone() for param_name, param in client.named_parameters()}
            update_history(history_weights, idx, params_to_update)

    for i in range(NC):
        clients[i].load_state_dict(last_trained_params)

    # test
    server.eval()
    for i in range(NC):
        clients[i].eval()
    with torch.no_grad():
        clients_acc1 = []
        clients_drop = []
        for idx, client in enumerate(clients):
            correct = 0
            total = 0
            for data in clients_testloader[idx]:
                images, labels = data
                images, labels = images.cuda(), labels.cuda()

                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                _, pre = torch.max(output.data, 1)
                total += images.shape[0]
                correct += (pre == labels).sum().item()
            acc1 = 100 * correct / total
            clients_acc1.append(acc1)
            drop = acc0 - acc1
            clients_drop.append(drop)
    acc1 = np.mean(clients_acc1)
    acc1_varying.append(acc1)
    drop = np.mean(clients_drop)
    drop_varying.append(drop)

Training:   0%|          | 0/10 [00:00<?, ?iter/s]

Files already downloaded and verified
Files already downloaded and verified
Anomalies detected in conv1.0.weight: 92 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 6 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 4 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 3 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 2 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 2 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 5 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 9 elem

In [4]:
print(acc1_varying)
print(drop_varying)
acc1_mean, acc1_range = data_process(acc1_varying)
print('acc1:')
print(acc1_mean)
print(acc1_range)
print('---------------------')
drop_mean, drop_range = data_process(drop_varying)
print('drop:')
print(drop_mean)
print(drop_range)

[15.74, 17.92, 20.08, 11.149999999999999, 16.959999999999997, 17.01, 10.889999999999999, 15.459999999999999, 20.490000000000002, 15.65]
[69.78, 67.6, 65.43999999999998, 74.36999999999999, 68.55999999999999, 68.51, 74.63, 70.05999999999997, 65.03, 69.86999999999999]
acc1:
16.163999999999998
1.5500000000000025
---------------------
drop:
69.356
1.5499999999999687


In [5]:
att_type = 'sign'
acc0 = 85.52
acc1_varying = []
drop_varying = []
iters = 10
mode = 'Taylor'


# 初始化历史权重
history_weights = {client_id: {param_name: None for param_name, _ in clients[client_id].named_parameters()} for client_id in range(NC)}


for iter in tqdm(range(iters), desc="Training", unit="iter"):
    batch_size = 600
    epochs = 30
    NC = 10
    dataset = 'cifar10'

    clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
    clients_testloader = load_clients_testsets(dataset, NC, batch_size)

    server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
    client_level = 1
    server_level = 6

    criterion = torch.nn.CrossEntropyLoss()
    # train
    mal_client_id = [8]
    server.train()
    for i in range(NC):
        clients[i].train()
    server.apply(init_weights)
    for client in clients:
        client.apply(init_weights)
    last_trained_params = clients[0].state_dict()

    for epoch in range(epochs):
        beta = 0.872058758893521
        for idx, client in enumerate(clients):
            # load sharing parameters
            client.load_state_dict(last_trained_params)
            # 防御：基于Fisher或Taylor信息矩阵检测和修复
            params_checked = defence(client, history_weights, idx, mode=mode, threshold_multiplier=3, percentile=0.33)
            # 更新参数
            for param_name, param in client.named_parameters():
                if param_name in params_checked:
                    param.data = params_checked[param_name]
            # 训练
            for j, data in enumerate(clients_trainloader[idx]):
                images, labels = data
                images = images.cuda()
                labels = labels.cuda()
                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                clients_opts[idx].zero_grad()
                server_opt.zero_grad()
                loss = criterion(output, labels)
                loss.backward()
                clients_opts[idx].step()
                server_opt.step()
            # weight sharing
            last_trained_params = client.state_dict()
            # attack part
            if idx in mal_client_id :
                benign_params = list(client.parameters())[:2]

                fisher_matrix = {}
                for param_name, param in client.named_parameters():
                    if param_name == 'conv1.0.weight':
                        grad = param.grad.cpu().detach().numpy()
                        if param_name not in fisher_matrix:
                            fisher_matrix[param_name] = grad ** 2
                        else:
                            fisher_matrix[param_name] += grad ** 2
                    if param_name == 'conv1.0.bias':
                        grad = param.grad.cpu().detach().numpy()
                        if param_name not in fisher_matrix:
                            fisher_matrix[param_name] = grad ** 2
                        else:
                            fisher_matrix[param_name] += grad ** 2
                weight_positions = []
                bias_positions = []
                weight_positions.append(find_positions(fisher_matrix['conv1.0.weight'], 0.33))
                bias_positions.append(find_positions(fisher_matrix['conv1.0.bias'], 0.33))

                mal_params = fisher_perturbation(client_level, beta, benign_params, weight_positions, bias_positions, type=att_type)
                last_trained_params['conv1.0.weight'] = mal_params[0]
                last_trained_params['conv1.0.bias'] = mal_params[1]

        # 更新历史权重
        for idx, client in enumerate(clients):
            params_to_update = {param_name: param.data.clone() for param_name, param in client.named_parameters()}
            update_history(history_weights, idx, params_to_update)

    for i in range(NC):
        clients[i].load_state_dict(last_trained_params)

    # test
    server.eval()
    for i in range(NC):
        clients[i].eval()
    with torch.no_grad():
        clients_acc1 = []
        clients_drop = []
        for idx, client in enumerate(clients):
            correct = 0
            total = 0
            for data in clients_testloader[idx]:
                images, labels = data
                images, labels = images.cuda(), labels.cuda()

                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                _, pre = torch.max(output.data, 1)
                total += images.shape[0]
                correct += (pre == labels).sum().item()
            acc1 = 100 * correct / total
            clients_acc1.append(acc1)
            drop = acc0 - acc1
            clients_drop.append(drop)
    acc1 = np.mean(clients_acc1)
    acc1_varying.append(acc1)
    drop = np.mean(clients_drop)
    drop_varying.append(drop)

Training:   0%|          | 0/10 [00:00<?, ?iter/s]

Files already downloaded and verified
Files already downloaded and verified
Anomalies detected in conv1.0.weight: 100 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 4 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 3 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 4 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 5 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 2 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.1.weight: 5 elements, replacing with history values.
Anomalies detected in conv1.1.bias: 6 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 2 e

In [6]:
print(acc1_varying)
print(drop_varying)
acc1_mean, acc1_range = data_process(acc1_varying)
print('acc1:')
print(acc1_mean)
print(acc1_range)
print('---------------------')
drop_mean, drop_range = data_process(drop_varying)
print('drop:')
print(drop_mean)
print(drop_range)

[16.27, 17.88, 14.729999999999999, 19.26, 12.419999999999998, 15.24, 20.61, 11.68, 13.870000000000001, 21.89]
[69.25, 67.64, 70.79, 66.25999999999999, 73.1, 70.28, 64.91, 73.84, 71.64999999999999, 63.629999999999995]
acc1:
16.463333333333335
2.639999999999999
---------------------
drop:
69.05666666666666
2.6400000000000006
