In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import numpy as np
import pickle

sys.path.append('D:\\Program\\MyCode\\Round_robin_SL\\Round-Robin')
sys.path.append('D:\\Program\\MyCode\\Round_robin_SL\\Defence')
from models import *
from clients_datasets import *
from tqdm.notebook import tqdm
from utils import *
from AttFunc import *
from Fisher_LeNet import *
from Def import *

In [2]:
batch_size = 600
epochs = 30
NC = 10
dataset = 'mnist'

clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
clients_testloader = load_clients_testsets(dataset, NC, batch_size)

server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
client_level = 1
server_level = 4

criterion = torch.nn.CrossEntropyLoss()

In [3]:
att_type = 'unit'
acc0 = 98.88
acc1_varying = []
drop_varying = []
iters = 10
mode = 'Fisher'

# 初始化历史权重
history_weights = {client_id: {param_name: None for param_name, _ in clients[client_id].named_parameters()} for client_id in range(NC)}

for iter in tqdm(range(iters), desc="Training", unit="iter"):
    batch_size = 600
    epochs = 30
    NC = 10
    dataset = 'mnist'

    clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
    clients_testloader = load_clients_testsets(dataset, NC, batch_size)

    server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
    client_level = 1
    server_level = 4

    criterion = torch.nn.CrossEntropyLoss()
    # train
    mal_client_id = [8]
    server.train()
    for i in range(NC):
        clients[i].train()
    server.apply(init_weights)
    for client in clients:
        client.apply(init_weights)
    last_trained_params = clients[0].state_dict()

    for epoch in range(epochs):
        beta = 0.868213151538335
        for idx, client in enumerate(clients):
            # load sharing parameters
            client.load_state_dict(last_trained_params)
            # 防御：基于Fisher或Taylor信息矩阵检测和修复
            params_checked = defence(client, history_weights, idx, mode=mode, threshold_multiplier=3, percentile=0.33)
            # 更新参数
            for param_name, param in client.named_parameters():
                if param_name in params_checked:
                    param.data = params_checked[param_name]
            # 训练
            for j, data in enumerate(clients_trainloader[idx]):
                images, labels = data
                images = images.cuda()
                labels = labels.cuda()
                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                clients_opts[idx].zero_grad()
                server_opt.zero_grad()
                loss = criterion(output, labels)
                loss.backward()
                clients_opts[idx].step()
                server_opt.step()
            # weight sharing
            last_trained_params = client.state_dict()
            # attack part
            if idx in mal_client_id :
                benign_params = list(client.parameters())[:2]

                Taylor_scores = {}
                for param_name, param in client.named_parameters():
                    if param.grad is not None:
                        Taylor_scores[param_name] = torch.abs(param * param.grad)

                weight_positions = []
                bias_positions = []
                weight_positions.append(find_positions(Taylor_scores['conv1.0.weight'].cpu().detach().numpy(), 0.33))
                bias_positions.append(find_positions(Taylor_scores['conv1.0.bias'].cpu().detach().numpy(), 0.33))

                mal_params = fisher_perturbation(client_level, beta, benign_params, weight_positions, bias_positions, type=att_type)
                last_trained_params['conv1.0.weight'] = mal_params[0]
                last_trained_params['conv1.0.bias'] = mal_params[1]

        # 更新历史权重
        for idx, client in enumerate(clients):
            params_to_update = {param_name: param.data.clone() for param_name, param in client.named_parameters()}
            update_history(history_weights, idx, params_to_update)

    for i in range(NC):
        clients[i].load_state_dict(last_trained_params)

    # test
    server.eval()
    for i in range(NC):
        clients[i].eval()
    with torch.no_grad():
        clients_acc1 = []
        clients_drop = []
        for idx, client in enumerate(clients):
            correct = 0
            total = 0
            for data in clients_testloader[idx]:
                images, labels = data
                images, labels = images.cuda(), labels.cuda()

                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                _, pre = torch.max(output.data, 1)
                total += images.shape[0]
                correct += (pre == labels).sum().item()
            acc1 = 100 * correct / total
            clients_acc1.append(acc1)
            drop = acc0 - acc1
            clients_drop.append(drop)
    acc1 = np.mean(clients_acc1)
    acc1_varying.append(acc1)
    drop = np.mean(clients_drop)
    drop_varying.append(drop)

Training:   0%|          | 0/10 [00:00<?, ?iter/s]

Anomalies detected in conv1.0.weight: 7 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 18 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 18 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 2 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 7 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 2 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 2 elements, replacing with history values.
Anomalies detected in conv1.0.bias

In [4]:
print(acc1_varying)
print(drop_varying)
acc1_mean, acc1_range = data_process(acc1_varying)
print('acc1:')
print(acc1_mean)
print(acc1_range)
print('---------------------')
drop_mean, drop_range = data_process(drop_varying)
print('drop:')
print(drop_mean)
print(drop_range)

[94.94000000000001, 95.78999999999999, 94.03, 94.77, 94.77, 95.47999999999999, 94.86999999999999, 95.22999999999999, 94.34, 94.19]
[3.9399999999999964, 3.0899999999999963, 4.849999999999996, 4.109999999999998, 4.109999999999997, 3.399999999999994, 4.009999999999995, 3.649999999999996, 4.539999999999994, 4.689999999999996]
acc1:
94.8375
0.17000000000001592
---------------------
drop:
4.042499999999996
0.17000000000000126


In [5]:
att_type = 'unit'
acc0 = 98.88
acc1_varying = []
drop_varying = []
iters = 10
mode = 'Taylor'

# 初始化历史权重
history_weights = {client_id: {param_name: None for param_name, _ in clients[client_id].named_parameters()} for client_id in range(NC)}

for iter in tqdm(range(iters), desc="Training", unit="iter"):
    batch_size = 600
    epochs = 30
    NC = 10
    dataset = 'mnist'

    clients_trainloader = load_clients_trainsets(dataset, NC, batch_size)
    clients_testloader = load_clients_testsets(dataset, NC, batch_size)

    server, server_opt, clients, clients_opts = set_model_and_opt(dataset, NC)
    client_level = 1
    server_level = 4

    criterion = torch.nn.CrossEntropyLoss()
    # train
    mal_client_id = [8]
    server.train()
    for i in range(NC):
        clients[i].train()
    server.apply(init_weights)
    for client in clients:
        client.apply(init_weights)
    last_trained_params = clients[0].state_dict()

    for epoch in range(epochs):
        beta = 0.868213151538335
        for idx, client in enumerate(clients):
            # load sharing parameters
            client.load_state_dict(last_trained_params)
            # 防御：基于Fisher或Taylor信息矩阵检测和修复
            params_checked = defence(client, history_weights, idx, mode=mode, threshold_multiplier=3, percentile=0.33)
            # 更新参数
            for param_name, param in client.named_parameters():
                if param_name in params_checked:
                    param.data = params_checked[param_name]
            # 训练
            for j, data in enumerate(clients_trainloader[idx]):
                images, labels = data
                images = images.cuda()
                labels = labels.cuda()
                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                clients_opts[idx].zero_grad()
                server_opt.zero_grad()
                loss = criterion(output, labels)
                loss.backward()
                clients_opts[idx].step()
                server_opt.step()
            # weight sharing
            last_trained_params = client.state_dict()
            # attack part
            if idx in mal_client_id :
                benign_params = list(client.parameters())[:2]

                Taylor_scores = {}
                for param_name, param in client.named_parameters():
                    if param.grad is not None:
                        Taylor_scores[param_name] = torch.abs(param * param.grad)

                weight_positions = []
                bias_positions = []
                weight_positions.append(find_positions(Taylor_scores['conv1.0.weight'].cpu().detach().numpy(), 0.33))
                bias_positions.append(find_positions(Taylor_scores['conv1.0.bias'].cpu().detach().numpy(), 0.33))

                mal_params = fisher_perturbation(client_level, beta, benign_params, weight_positions, bias_positions, type=att_type)
                last_trained_params['conv1.0.weight'] = mal_params[0]
                last_trained_params['conv1.0.bias'] = mal_params[1]

        # 更新历史权重
        for idx, client in enumerate(clients):
            params_to_update = {param_name: param.data.clone() for param_name, param in client.named_parameters()}
            update_history(history_weights, idx, params_to_update)

    for i in range(NC):
        clients[i].load_state_dict(last_trained_params)

    # test
    server.eval()
    for i in range(NC):
        clients[i].eval()
    with torch.no_grad():
        clients_acc1 = []
        clients_drop = []
        for idx, client in enumerate(clients):
            correct = 0
            total = 0
            for data in clients_testloader[idx]:
                images, labels = data
                images, labels = images.cuda(), labels.cuda()

                smashed_data = client.forward(images, client_level=client_level)
                output = server.forward(smashed_data, server_level=server_level)
                _, pre = torch.max(output.data, 1)
                total += images.shape[0]
                correct += (pre == labels).sum().item()
            acc1 = 100 * correct / total
            clients_acc1.append(acc1)
            drop = acc0 - acc1
            clients_drop.append(drop)
    acc1 = np.mean(clients_acc1)
    acc1_varying.append(acc1)
    drop = np.mean(clients_drop)
    drop_varying.append(drop)

Training:   0%|          | 0/10 [00:00<?, ?iter/s]

Anomalies detected in conv1.0.weight: 7 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 5 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 19 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 2 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 15 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 9 elements, replacing with history values.
Anomalies detected in conv1.0.bias: 1 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 3 elements, replacing with history values.
Anomalies detected in conv1.0.weight: 21 elements, replacing with history values.
Anomalies detected in conv1.0.bia

In [6]:
print(acc1_varying)
print(drop_varying)
acc1_mean, acc1_range = data_process(acc1_varying)
print('acc1:')
print(acc1_mean)
print(acc1_range)
print('---------------------')
drop_mean, drop_range = data_process(drop_varying)
print('drop:')
print(drop_mean)
print(drop_range)

[95.36, 94.12, 95.34, 94.61, 94.96000000000001, 95.36, 94.94000000000001, 94.15, 94.46000000000001, 95.56]
[3.5199999999999974, 4.759999999999994, 3.5399999999999965, 4.269999999999994, 3.9199999999999946, 3.519999999999996, 3.939999999999995, 4.729999999999995, 4.419999999999995, 3.3199999999999976]
acc1:
94.95000000000002
0.01999999999999602
---------------------
drop:
3.929999999999995
0.020000000000000462
