# 一、 背景设置


## 初始设置
模型架构：SplitFedV1  
DNN架构：ResNet18  
数据集：HAM10000  
任务类型：多分类问题  

In [1]:
#============================================================================
# SplitfedV1 (SFLV1) learning: ResNet18 on HAM10000
# HAM10000 dataset: Tschandl, P.: The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions (2018), doi:10.7910/DVN/DBW86T

# We have three versions of our implementations
# Version1: without using socket and no DP+PixelDP
# Version2: with using socket but no DP+PixelDP
# Version3: without using socket but with DP+PixelDP

# This program is Version1: Single program simulation 
# ============================================================================
import torch
from torch import nn
from torchvision import transforms   # 导入PyTorch中的图像变换模块。
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F  # 导入PyTorch中的常用函数模块，例如ReLU、softmax等。
import math
import os.path
import pandas as pd
from sklearn.model_selection import train_test_split
from PIL import Image
from glob import glob
from pandas import DataFrame

import random
import numpy as np
import os


import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy

## 创建随机数种子，保证程序结果可以重复

In [2]:
# Seed every random number generator used in this notebook with the same value
# so that repeated runs draw the same random sequences.
SEED = 1234
for _seed_fn in (
    random.seed,              # Python standard-library RNG
    np.random.seed,           # NumPy global RNG
    torch.manual_seed,        # PyTorch RNG
    torch.cuda.manual_seed,   # PyTorch CUDA RNG
):
    _seed_fn(SEED)
del _seed_fn

if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels, then report the first CUDA device.
    torch.backends.cudnn.deterministic = True
    print(torch.cuda.get_device_name(0))

NVIDIA GeForce GTX 1660 Ti


## 定义程序和变量

In [3]:
#===================================================================
# Name of this experiment; printed as a banner so the run can be identified
# in the slurm output files.
program = "SFLV1 ResNet18 on HAM10000"
print("---------" + program + "----------")

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Coloured console output, used for the client-side train/test log lines.
def prRed(skk):
    """Print ``skk`` to stdout wrapped in the ANSI escape codes for bright red."""
    print(f"\033[91m {skk}\033[00m")
def prGreen(skk):
    """Print ``skk`` to stdout wrapped in the ANSI escape codes for bright green."""
    print(f"\033[92m {skk}\033[00m")

#===================================================================
# Experiment hyper-parameters
num_users: int = 5     # number of clients taking part in training
epochs: int = 200      # number of global rounds run by the main loop
frac: int = 1          # participation of clients; if 1 then 100% clients participate in SFLV1
lr: float = 0.0001     # learning rate handed to the Adam optimizers

---------SFLV1 ResNet18 on HAM10000----------


# 模型定义

## Client-side Model Definition

In [4]:
#=====================================================================================================
#                           Client-side Model definition
#=====================================================================================================
# Model at client side
class ResNet18_client_side(nn.Module):
    """Client-side part of the split ResNet18.

    ``layer1`` is the stem (7x7 stride-2 convolution, batch norm, ReLU and a
    stride-2 max pooling); ``layer2`` is the convolutional branch of the first
    residual unit.  ``forward`` returns the activations after that unit.
    """

    def __init__(self):
        super().__init__()

        # Stem: the stride-2 convolution and the stride-2 pooling each halve
        # the spatial size.
        stem = [
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.layer1 = nn.Sequential(*stem)

        # Residual branch: both 3x3 convolutions keep the spatial size.
        branch = [
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        ]
        self.layer2 = nn.Sequential(*branch)

        # Re-initialise the weights: convolutions are drawn from
        # N(0, sqrt(2 / (kh * kw * out_channels))); batch-norm layers start
        # with weight 1 and bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                nn.init.normal_(module.weight, 0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return the activations of the first residual unit for input ``x``."""
        shortcut = F.relu(self.layer1(x))
        # Identity shortcut: no down-sampling is needed in this unit.
        summed = self.layer2(shortcut) + shortcut
        return F.relu(summed)

创建实例，打印

In [5]:
# Build the client-side global model; wrap it in DataParallel when more than
# one GPU is visible, then move it to the selected device and show its layout.
net_glob_client = ResNet18_client_side()
if torch.cuda.device_count() > 1:
    print(f"We use {torch.cuda.device_count()} GPUs")
    net_glob_client = nn.DataParallel(net_glob_client)

net_glob_client = net_glob_client.to(device)
print(net_glob_client)

ResNet18_client_side(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


## Server-side Model definition

### BaseBlock

In [6]:
class Baseblock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        input_planes: Number of input channels.
        planes: Number of output channels of both convolutions.
        stride: Stride of the first convolution.
        dim_change: Optional module applied to the shortcut before the
            addition; ``None`` means the input is added unchanged.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, input_planes, planes, stride = 1, dim_change = None):
        super().__init__()
        self.conv1 = nn.Conv2d(input_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.dim_change = dim_change

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.bn1(self.conv1(x))
        branch = F.relu(branch)
        branch = self.bn2(self.conv2(branch))

        shortcut = x
        if self.dim_change is not None:
            shortcut = self.dim_change(shortcut)

        branch += shortcut
        return F.relu(branch)

### Server-side

In [7]:
class ResNet18_server_side(nn.Module):
    """Server-side part of the split ResNet18.

    Consumes the 64-channel activations produced by the client-side model and
    returns class scores.

    Args:
        block: Residual block class; it must expose an ``expansion`` attribute
            and accept ``(input_planes, planes, stride=..., dim_change=...)``.
        num_layers: Sequence whose first three entries give the number of
            blocks in ``layer4``, ``layer5`` and ``layer6``.
        classes: Number of output classes.
    """

    def __init__(self, block, num_layers, classes):
        super().__init__()
        # Channel count fed to the next stage; updated by ``_layer``.
        self.input_planes = 64

        # Convolutional branch of a residual unit that keeps 64 channels and
        # the spatial size; the shortcut is added in ``forward``.
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )

        self.layer4 = self._layer(block, 128, num_layers[0], stride=2)
        self.layer5 = self._layer(block, 256, num_layers[1], stride=2)
        self.layer6 = self._layer(block, 512, num_layers[2], stride=2)
        # Registered but not called by ``forward``, which pools with
        # ``F.avg_pool2d`` instead.
        self.averagePool = nn.AvgPool2d(kernel_size=2, stride=1)
        self.fc = nn.Linear(512 * block.expansion, classes)

        # Re-initialise the weights: convolutions are drawn from
        # N(0, sqrt(2 / (kh * kw * out_channels))); batch-norm layers start
        # with weight 1 and bias 0.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                nn.init.normal_(module.weight, 0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _layer(self, block, planes, num_layers, stride = 2):
        """Build one stage of ``num_layers`` blocks and update ``self.input_planes``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1 convolution + batch norm projection
        for its shortcut.  The remaining blocks use the block defaults.
        """
        out_planes = planes * block.expansion

        dim_change = None
        if stride != 1 or planes != self.input_planes * block.expansion:
            dim_change = nn.Sequential(
                nn.Conv2d(self.input_planes, out_planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.input_planes, planes, stride=stride, dim_change=dim_change)]
        self.input_planes = out_planes
        for _ in range(1, num_layers):
            stage.append(block(self.input_planes, planes))
            self.input_planes = out_planes

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class scores for the client activations ``x``."""
        # Identity shortcut: no down-sampling is needed in this unit.
        features = F.relu(self.layer3(x) + x)

        for stage in (self.layer4, self.layer5, self.layer6):
            features = stage(features)

        pooled = F.avg_pool2d(features, 2)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

### 创建实例

In [9]:
# Build the server-side global model for the 7 lesion classes; wrap it in
# DataParallel when more than one GPU is visible, then move it to the
# selected device and show its layout.
net_glob_server = ResNet18_server_side(Baseblock, [2,2,2], 7)
if torch.cuda.device_count() > 1:
    print(f"We use {torch.cuda.device_count()} GPUs")
    net_glob_server = nn.DataParallel(net_glob_server)

net_glob_server = net_glob_server.to(device)
print(net_glob_server)

ResNet18_server_side(
  (layer3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer4): Sequential(
    (0): Baseblock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dim_change): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )

## Server端参数和聚合函数定义

In [11]:
#===================================================================================
# Server-side loss / accuracy bookkeeping.
# *_collect: history lists that receive one averaged value per federation.
loss_train_collect, acc_train_collect = [], []
loss_test_collect, acc_test_collect = [], []
# batch_*: per-batch values, averaged and cleared once a client has sent all
# of its batches.
batch_acc_train, batch_loss_train = [], []
batch_acc_test, batch_loss_test = [], []

# Loss function used on the server side.
criterion = nn.CrossEntropyLoss()

# Number of batches received so far: count1 for training, count2 for testing.
count1 = 0
count2 = 0

def calculate_accuracy(fx, y):
    """Return the percentage of rows of ``fx`` whose top-scoring index equals ``y``.

    Args:
        fx: Score tensor; the prediction is the index of the maximum along
            dimension 1.
        y: Target labels, one per row of ``fx``.

    Returns:
        A scalar float tensor in the range 0-100.
    """
    _, preds = fx.max(1, keepdim=True)
    hits = (preds == y.view_as(preds)).sum()
    return 100.00 * hits.float() / preds.shape[0]

# Module-level state shared by train_server() and evaluate_server() so that
# the train and test summaries of a round can be printed together.
acc_avg_all_user_train = 0
loss_avg_all_user_train = 0
loss_train_collect_user, acc_train_collect_user = [], []
loss_test_collect_user, acc_test_collect_user = [], []

# Weights of the server-side global model, and the per-client server-side
# weights gathered for the next aggregation.
w_glob_server = net_glob_server.state_dict()
w_locals_server = []

# Indices of the clients collected so far in the current round.
idx_collect = []
# Flags handed from train_server() to evaluate_server(): a local epoch has
# finished / a federation has just taken place.
l_epoch_check = False
fed_check = False

# One server-side model slot per client; every slot initially refers to the
# same global model object.
net_model_server = [net_glob_server] * num_users
net_server = copy.deepcopy(net_model_server[0]).to(device)
# The server-side optimizer is created inside train_server().

### 定义聚合规则

In [12]:
def FedAvg(w):
    """Return the element-wise mean of a list of state dicts.

    Args:
        w: Non-empty list of state dicts that share the same keys.

    Returns:
        A new dict with the keys of ``w[0]``; the inputs are not modified.
    """
    num_models = len(w)
    # Accumulate into a deep copy so that ``w[0]`` is left untouched.
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for other in w[1:]:
            w_avg[key] += other[key]
        w_avg[key] = torch.div(w_avg[key], num_models)
    return w_avg

# 训练&评估函数定义

## Server 训练和评估函数定义
### train_server

In [13]:
def train_server(fx_client, y, l_epoch_count, l_epoch, idx, len_batch):
    """Run one server-side training step on one batch sent by client ``idx``.

    A deep copy of ``net_model_server[idx]`` is run forward and backward on
    the client activations, updated with a freshly created Adam optimizer and
    stored back into ``net_model_server[idx]``.  The function also maintains
    the module-level bookkeeping: per-batch metrics, the server-side weights
    of each client at the end of its last local epoch, and the FedAvg
    aggregation once ``num_users`` distinct clients have been collected.

    Args:
        fx_client: Tensor of client-side activations.  ``fx_client.grad`` is
            read after the backward pass, so the tensor is expected to be a
            leaf with ``requires_grad=True`` (``Client.train`` passes
            ``fx.clone().detach().requires_grad_(True)``).
        y: Target labels of the batch.
        l_epoch_count: Index of the current local epoch, starting at 0.
        l_epoch: Total number of local epochs of the client.
        idx: Client index; selects the entry of ``net_model_server``.
        len_batch: Number of batches the client sends per local epoch (the
            caller passes ``len(self.ldr_train)``), not the batch size.

    Returns:
        A detached copy of the gradient of the loss with respect to
        ``fx_client``, to be back-propagated through the client-side model.
    """
    # The names declared global below are module-level state that this function
    # reads and rebinds; the values are consumed again outside of it.
    # NOTE(review): the triple-quoted block that follows is a no-op string
    # expression used as a comment.  Some of its descriptions do not match the
    # code: net_model_server is a list with one model per client, count1 counts
    # batches (not clients), and l_epoch_check is a boolean flag.
    """
    net_model_server: 全局模型。
    criterion: 损失函数，用于计算模型的误差。
    optimizer_server: 优化器，用于更新全局模型的参数。
    device: 设备（CPU或GPU）用于计算。
    batch_acc_train: 当前批次的准确度。
    batch_loss_train: 当前批次的损失。
    l_epoch_check: 在训练期间用于检查损失和准确度的训练周期数。
    fed_check: 用于检查训练周期是否已完成的标志。
    loss_train_collect: 用于收集所有客户端训练损失的列表。
    acc_train_collect: 用于收集所有客户端训练准确度的列表。
    count1: 计数器，用于跟踪当前已经训练的客户端数量。
    acc_avg_all_user_train: 所有客户端训练准确度的平均值。
    loss_avg_all_user_train: 所有客户端训练损失的平均值。
    idx_collect: 用于跟踪已经训练的客户端的索引列表。
    w_locals_server: 所有客户端本地模型参数的列表。
    w_glob_server: 全局模型参数的列表。
    net_server: 全局模型。
    """
    global net_model_server, criterion, optimizer_server, device, batch_acc_train, batch_loss_train, l_epoch_check, fed_check
    global loss_train_collect, acc_train_collect, count1, acc_avg_all_user_train, loss_avg_all_user_train, idx_collect, w_locals_server, w_glob_server, net_server
    global loss_train_collect_user, acc_train_collect_user, lr

    # Work on a deep copy of the server-side model stored for this client.
    net_server = copy.deepcopy(net_model_server[idx]).to(device)    # the stored entry is replaced after the update below
    # Switch the copy to training mode.
    net_server.train()
    # NOTE(review): a new Adam optimizer is created on every call, i.e. for every batch, so no optimizer state is carried over between batches.
    optimizer_server = torch.optim.Adam(net_server.parameters(), lr = lr)

    
    # 1.train and update
    # Clear the gradients before the forward/backward pass.
    optimizer_server.zero_grad()
    
    fx_client = fx_client.to(device)
    y = y.to(device)
    
    #---------forward prop-------------
    fx_server = net_server(fx_client)   # server-side predictions for the client activations
    
    # calculate loss
    loss = criterion(fx_server, y)
    # calculate accuracy
    acc = calculate_accuracy(fx_server, y)
    
    #--------backward prop--------------
    loss.backward()
    # Detached copy of the gradient with respect to the client activations; this is the value returned to the client.
    dfx_client = fx_client.grad.clone().detach()
    optimizer_server.step()
    
    batch_loss_train.append(loss.item())
    batch_acc_train.append(acc.item())
    
    # Update the server-side model for the current batch
    net_model_server[idx] = copy.deepcopy(net_server)
    
    # count1: to track the completion of the local batch associated with one client
    count1 += 1
    if count1 == len_batch:
        acc_avg_train = sum(batch_acc_train)/len(batch_acc_train)           # mean accuracy over the batches accumulated since the last reset
        loss_avg_train = sum(batch_loss_train)/len(batch_loss_train)    # mean loss over the same batches
        
        batch_acc_train = []    # reset the per-batch buffers and the batch counter
        batch_loss_train = []
        count1 = 0
        
        prRed('Client{} Train => Local Epoch: {} \tAcc: {:.3f} \tLoss: {:.4f}'.format(idx, l_epoch_count, acc_avg_train, loss_avg_train))
        
        # copy the last trained model in the batch
        # State dict of the model as trained on the last batch of this local epoch.
        w_server = net_server.state_dict()      
        
        # If one local epoch is completed, after this a new client will come
        if l_epoch_count == l_epoch-1:
            # l_epoch_count is the current local-epoch index and l_epoch the number of local epochs, so this branch runs at the end of the last local epoch.
            # Tell evaluate_server() that the local epochs of this client are done.
            l_epoch_check = True                # to evaluate_server function - to check local epoch has completed or not 
            # We store the state of the net_glob_server()
            # Keep a copy of this client's server-side weights for the FedAvg aggregation below.
            w_locals_server.append(copy.deepcopy(w_server))
            
            # we store the last accuracy in the last batch of the epoch and it is not the average of all local epochs
            # this is because we work on the last trained model and its accuracy (not earlier cases)
            
            acc_avg_train_all = acc_avg_train   # local aliases for the averages of the last local epoch
            loss_avg_train_all = loss_avg_train #
                        
            # accumulate accuracy and loss for each new user
            loss_train_collect_user.append(loss_avg_train_all)   # per-client loss of this round
            acc_train_collect_user.append(acc_avg_train_all)    # per-client accuracy of this round
            
            # collect the id of each new user                        
            if idx not in idx_collect:
                idx_collect.append(idx) 
                print(idx_collect)
        
        # This is for federation process--------------------
        if len(idx_collect) == num_users:
            # Every one of the num_users clients has been collected, so the server-side models can be aggregated.
            fed_check = True                                                  # to evaluate_server function  - to check fed check has hitted
            # Federation process at Server-Side------------------------- output print and update is done in evaluate_server()
            # for nicer display 
                                   
            w_glob_server = FedAvg(w_locals_server)  # average the collected server-side weights
            
            # server-side global model update and distribute that model to all clients ------------------------------
            net_glob_server.load_state_dict(w_glob_server)      # load the averaged weights into the global server-side model
            net_model_server = [net_glob_server for i in range(num_users)]  # every client slot refers to the updated global model again
            
            w_locals_server = []    # reset the collected weights
            idx_collect = []    # reset the collected client indices
            
            acc_avg_all_user_train = sum(acc_train_collect_user)/len(acc_train_collect_user)    # mean training accuracy / loss over the collected clients
            loss_avg_all_user_train = sum(loss_train_collect_user)/len(loss_train_collect_user)
            
            loss_train_collect.append(loss_avg_all_user_train)
            acc_train_collect.append(acc_avg_all_user_train)
            
            acc_train_collect_user = []
            loss_train_collect_user = []
            
    # send gradients to the client               
    return dfx_client

### eval_server

In [14]:
def evaluate_server(fx_client, y, idx, len_batch, ell):
    """Evaluate the server-side model of client ``idx`` on one test batch.

    The loss and accuracy of the batch are buffered.  Once ``len_batch``
    batches have been received, their averages are printed; if a local epoch
    has just finished they are also stored per client, and if a federation
    has just taken place the round summary is printed and appended to the
    test history lists.

    Args:
        fx_client: Client-side activations of the test batch.
        y: Target labels of the batch.
        idx: Client index; selects the entry of ``net_model_server``.
        len_batch: Number of test batches the client sends.
        ell: Round number, used only in the printed summary.
    """
    global batch_acc_test, batch_loss_test, count2
    global l_epoch_check, fed_check
    global loss_test_collect_user, acc_test_collect_user

    server_model = copy.deepcopy(net_model_server[idx]).to(device)
    server_model.eval()

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        activations = fx_client.to(device)
        targets = y.to(device)
        #---------forward prop-------------
        scores = server_model(activations)
        batch_loss = criterion(scores, targets)
        batch_acc = calculate_accuracy(scores, targets)

    batch_loss_test.append(batch_loss.item())
    batch_acc_test.append(batch_acc.item())

    count2 += 1
    if count2 != len_batch:
        # More test batches of this client are still to come.
        return

    acc_avg_test = sum(batch_acc_test) / len(batch_acc_test)
    loss_avg_test = sum(batch_loss_test) / len(batch_loss_test)

    batch_acc_test = []
    batch_loss_test = []
    count2 = 0

    prGreen('Client{} Test =>                   \tAcc: {:.3f} \tLoss: {:.4f}'.format(idx, acc_avg_test, loss_avg_test))

    # A local epoch has completed: keep this client's averages for the round.
    if l_epoch_check:
        l_epoch_check = False
        loss_test_collect_user.append(loss_avg_test)
        acc_test_collect_user.append(acc_avg_test)

    if not fed_check:
        return

    # A federation has just happened: summarise the round.
    fed_check = False
    print("------------------------------------------------")
    print("------ Federation process at Server-Side ------- ")
    print("------------------------------------------------")

    acc_avg_all_user = sum(acc_test_collect_user) / len(acc_test_collect_user)
    loss_avg_all_user = sum(loss_test_collect_user) / len(loss_test_collect_user)

    loss_test_collect.append(loss_avg_all_user)
    acc_test_collect.append(acc_avg_all_user)
    acc_test_collect_user = []
    loss_test_collect_user = []

    print("====================== SERVER V1==========================")
    print(' Train: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(ell, acc_avg_all_user_train, loss_avg_all_user_train))
    print(' Test: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(ell, acc_avg_all_user, loss_avg_all_user))
    print("==========================================================")

    return

## DataSplit

In [16]:
class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the sample indices in ``idxs``."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise the indices so that they can be addressed by position.
        self.idxs = [*idxs]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair of the ``item``-th selected sample."""
        source_index = self.idxs[item]
        image, label = self.dataset[source_index]
        return image, label

def dataset_iid(dataset, num_users):
    """Randomly split the sample indices of ``dataset`` into equal, disjoint shares.

    Each user receives ``int(len(dataset) / num_users)`` indices drawn without
    replacement via ``np.random.choice``; indices left over by the integer
    division are assigned to nobody.

    Returns:
        A dict mapping each user number ``0 .. num_users - 1`` to a set of
        indices.
    """
    total = len(dataset)
    share = int(total / num_users)
    remaining = list(range(total))
    dict_users = {}
    for user in range(num_users):
        picked = set(np.random.choice(remaining, share, replace=False))
        dict_users[user] = picked
        # Remove the picked indices so that the next user cannot draw them.
        remaining = list(set(remaining) - picked)
    return dict_users

## Client

In [17]:
class Client(object):
    """One client of the split-federated setup.

    Owns the data loaders for its share of the training and test data and
    drives the client-side model.  The activations are handed to
    ``train_server`` / ``evaluate_server``.

    Args:
        net_client_model: Not used by the constructor.
        idx: Integer index of the client.
        lr: Learning rate of the client-side Adam optimizer.
        device: Device on which the batches are placed.
        dataset_train: Full training dataset.
        dataset_test: Full test dataset.
        idxs: Indices of ``dataset_train`` that belong to this client.
        idxs_test: Indices of ``dataset_test`` that belong to this client.
    """

    def __init__(self, net_client_model, idx, lr, device, dataset_train = None, dataset_test = None, idxs = None, idxs_test = None):
        self.idx = idx
        self.device = device
        self.lr = lr
        # Number of local epochs per round.
        self.local_ep = 1

        train_split = DatasetSplit(dataset_train, idxs)
        self.ldr_train = DataLoader(train_split, batch_size=256, shuffle=True)
        test_split = DatasetSplit(dataset_test, idxs_test)
        self.ldr_test = DataLoader(test_split, batch_size=256, shuffle=True)

    def train(self, net):
        """Train ``net`` for ``self.local_ep`` local epochs and return its state dict.

        For every batch the client activations are sent to ``train_server``
        and the gradient it returns is back-propagated through ``net``.
        """
        net.train()
        optimizer_client = torch.optim.Adam(net.parameters(), lr=self.lr)

        for local_epoch in range(self.local_ep):
            num_batches = len(self.ldr_train)
            for images, labels in self.ldr_train:
                images = images.to(self.device)
                labels = labels.to(self.device)
                optimizer_client.zero_grad()

                #---------forward prop-------------
                activations = net(images)
                # Detached copy that tracks its own gradient; this is what
                # the server receives.
                client_fx = activations.clone().detach().requires_grad_(True)

                # Sending activations to server and receiving gradients from server
                grad_from_server = train_server(
                    client_fx, labels, local_epoch, self.local_ep, self.idx, num_batches
                )

                #--------backward prop -------------
                activations.backward(grad_from_server)
                optimizer_client.step()

        return net.state_dict()

    def evaluate(self, net, ell):
        """Send the activations of every test batch to ``evaluate_server``.

        ``ell`` is the round number and is passed through unchanged.
        """
        net.eval()

        with torch.no_grad():
            num_batches = len(self.ldr_test)
            for images, labels in self.ldr_test:
                images = images.to(self.device)
                labels = labels.to(self.device)
                #---------forward prop-------------
                activations = net(images)

                # Sending activations to server
                evaluate_server(activations, labels, self.idx, num_batches, ell)

        return

# 程序运行

## 导入数据并预处理

In [23]:
# Folder that holds HAM10000_metadata.csv and the image sub-folders.  It is a raw
# string because "\c", "\D", "\S", "\d" and "\H" are invalid escape sequences in a
# normal string literal (Python warns about them and plans to reject them).
DATA_DIR = r'D:\codes\DeepLearning\SplitFed-When-Federated-Learning-Meets-Split-Learning\data'

df = pd.read_csv(os.path.join(DATA_DIR, 'HAM10000_metadata.csv'))
# Metadata columns (see the printed head):
#   lesion_id    - lesion identifier; several images can share one lesion_id
#   image_id     - unique identifier of an image
#   dx           - diagnosis code (the keys of lesion_type below)
#   dx_type      - how the diagnosis was established
#   age, sex     - patient age and sex
#   localization - body site of the lesion
print(df.head())

# Maps each diagnosis code to its full lesion name.
lesion_type = {
    'nv': 'Melanocytic nevi',
    'mel': 'Melanoma',
    'bkl': 'Benign keratosis-like lesions ',
    'bcc': 'Basal cell carcinoma',
    'akiec': 'Actinic keratoses',
    'vasc': 'Vascular lesions',
    'df': 'Dermatofibroma'
}

# merging both folders of HAM1000 dataset -- part1 and part2 -- into a single directory
# Every *.jpg file one folder below DATA_DIR is indexed by its file name without
# the extension, giving a dict such as {'ISIC_0024433': '<path to ISIC_0024433.jpg>', ...}.
imageid_path = {os.path.splitext(os.path.basename(x))[0]: x
                for x in glob(os.path.join(DATA_DIR, '*', '*.jpg'))}

# Path of the image file for every row; an image_id without a file maps to None.
df['path'] = df['image_id'].map(imageid_path.get)
# Full lesion name for every diagnosis code.
df['cell_type'] = df['dx'].map(lesion_type.get)
# Integer class label: the categorical code of the lesion name.
df['target'] = pd.Categorical(df['cell_type']).codes
print(df['cell_type'].value_counts())
print(df['target'].value_counts())
print(df.head())

     lesion_id      image_id   dx dx_type   age   sex localization
0  HAM_0000118  ISIC_0027419  bkl   histo  80.0  male        scalp
1  HAM_0000118  ISIC_0025030  bkl   histo  80.0  male        scalp
2  HAM_0002730  ISIC_0026769  bkl   histo  80.0  male        scalp
3  HAM_0002730  ISIC_0025661  bkl   histo  80.0  male        scalp
4  HAM_0001466  ISIC_0031633  bkl   histo  75.0  male          ear
Melanocytic nevi                  6705
Melanoma                          1113
Benign keratosis-like lesions     1099
Basal cell carcinoma               514
Actinic keratoses                  327
Vascular lesions                   142
Dermatofibroma                     115
Name: cell_type, dtype: int64
4    6705
5    1113
2    1099
1     514
0     327
6     142
3     115
Name: target, dtype: int64
     lesion_id      image_id   dx dx_type   age   sex localization  \
0  HAM_0000118  ISIC_0027419  bkl   histo  80.0  male        scalp   
1  HAM_0000118  ISIC_0025030  bkl   histo  80.0  male     

### Custom dataset preparation in PyTorch format

In [24]:
class SkinData(Dataset):
    """Dataset that serves ``(image, label)`` pairs described by a DataFrame.

    The DataFrame must provide a ``path`` column (image file) and a ``target``
    column (integer class label).

    Args:
        df: DataFrame with one row per sample.
        transform: Optional callable applied to the resized PIL image.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows of the DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the image resized to 64x64 and its label as a tensor.

        The image is passed through ``self.transform`` when one was given.
        """
        image_path = self.df['path'][index]
        label = self.df['target'][index]

        X = Image.open(image_path).resize((64, 64))
        y = torch.tensor(int(label))

        if not self.transform:
            return X, y

        # Optional preprocessing / augmentation supplied by the caller.
        return self.transform(X), y

In [25]:
# =============================================================================
# Train-test split
# 80 % / 20 % split; the index is reset on both parts so that rows can be
# addressed by position 0..n-1 later on.
# NOTE(review): no random_state is passed, so the split presumably follows the
# NumPy global RNG seeded earlier - confirm.
train, test = (part.reset_index() for part in train_test_split(df, test_size=0.2))

# =============================================================================
#                         Data preprocessing
# =============================================================================
# Per-channel statistics used to normalise the image tensors.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training pipeline: random horizontal and vertical flips, 3-pixel padding,
# a random rotation (RandomRotation(10)), a 64-pixel centre crop, conversion
# to a tensor and normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.Pad(3),
        transforms.RandomRotation(10),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Test pipeline: the same padding, crop, tensor conversion and normalisation
# as for training, without the random augmentations.
test_transforms = transforms.Compose(
    [
        transforms.Pad(3),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# With augmentation
dataset_train = SkinData(train, transform=train_transforms)
dataset_test = SkinData(test, transform=test_transforms)

# -----------------------------------------------
# Share of the training and of the test samples owned by each client.
dict_users = dataset_iid(dataset_train, num_users)
dict_users_test = dataset_iid(dataset_test, num_users)

## Train and Testing

In [None]:
#------------ Training And Testing  -----------------
net_glob_client.train()
# Weights of the client-side global model
w_glob_client = net_glob_client.state_dict()
# Federation takes place after certain local epochs in train() client-side
# this epoch is global epoch, also known as rounds
# The loop variable is named round_idx rather than iter so that the builtin
# iter() is not shadowed at module level for the rest of the session.
for round_idx in range(epochs):
    # Sample the clients that take part in this round (at least one).
    m = max(int(frac * num_users), 1)
    idxs_users = np.random.choice(range(num_users), m, replace = False)
    w_locals_client = []

    for idx in idxs_users:
        local = Client(net_glob_client, idx, lr, device, dataset_train = dataset_train, dataset_test = dataset_test, idxs = dict_users[idx], idxs_test = dict_users_test[idx])
        # Training ------------------
        # Each client trains its own copy of the client-side global model.
        w_client = local.train(net = copy.deepcopy(net_glob_client).to(device))
        w_locals_client.append(copy.deepcopy(w_client))

        # Testing -------------------
        # NOTE(review): evaluation runs on a fresh copy of the global client
        # model, not on the weights w_client that were just trained - confirm
        # that this is intended.
        local.evaluate(net = copy.deepcopy(net_glob_client).to(device), ell= round_idx)

    # After serving all clients for its local epochs------------
    # Fed  Server: Federation process at Client-Side-----------
    print("-----------------------------------------------------------")
    print("------ FedServer: Federation process at Client-Side ------- ")
    print("-----------------------------------------------------------")
    w_glob_client = FedAvg(w_locals_client)

    # Update client-side global model
    net_glob_client.load_state_dict(w_glob_client)

#===================================================================================

print("Training and Evaluation completed!")



[91m Client4 Train => Local Epoch: 0 	Acc: 41.946 	Loss: 1.6383[00m
[4]
[92m Client4 Test =>                   	Acc: 67.339 	Loss: 1.6486[00m
[91m Client3 Train => Local Epoch: 0 	Acc: 40.299 	Loss: 1.6437[00m
[4, 3]
[92m Client3 Test =>                   	Acc: 64.453 	Loss: 1.7233[00m
[91m Client2 Train => Local Epoch: 0 	Acc: 42.734 	Loss: 1.6302[00m
[4, 3, 2]
[92m Client2 Test =>                   	Acc: 65.234 	Loss: 1.6943[00m
[91m Client1 Train => Local Epoch: 0 	Acc: 40.885 	Loss: 1.6761[00m
[4, 3, 2, 1]
[92m Client1 Test =>                   	Acc: 66.254 	Loss: 1.7079[00m
[91m Client0 Train => Local Epoch: 0 	Acc: 43.048 	Loss: 1.6146[00m
[4, 3, 2, 1, 0]
[92m Client0 Test =>                   	Acc: 66.059 	Loss: 1.6946[00m
------------------------------------------------
------ Federation process at Server-Side ------- 
------------------------------------------------
 Train: Round   0, Avg Accuracy 41.782 | Avg Loss 1.641
 Test: Round   0, Avg Accuracy 65.86

[91m Client3 Train => Local Epoch: 0 	Acc: 70.832 	Loss: 0.7975[00m
[2, 0, 3]
[92m Client3 Test =>                   	Acc: 71.680 	Loss: 0.7916[00m
[91m Client1 Train => Local Epoch: 0 	Acc: 71.151 	Loss: 0.7772[00m
[2, 0, 3, 1]
[92m Client1 Test =>                   	Acc: 68.772 	Loss: 0.8564[00m
[91m Client4 Train => Local Epoch: 0 	Acc: 71.809 	Loss: 0.7933[00m
[2, 0, 3, 1, 4]
[92m Client4 Test =>                   	Acc: 72.656 	Loss: 0.7759[00m
------------------------------------------------
------ Federation process at Server-Side ------- 
------------------------------------------------
 Train: Round   6, Avg Accuracy 71.854 | Avg Loss 0.774
 Test: Round   6, Avg Accuracy 71.011 | Avg Loss 0.818
-----------------------------------------------------------
------ FedServer: Federation process at Client-Side ------- 
-----------------------------------------------------------
[91m Client0 Train => Local Epoch: 0 	Acc: 74.530 	Loss: 0.7229[00m
[0]
[92m Client0 Test =>

In [None]:
#===============================================================================
# Save output data to .excel file (we use for comparision plots)
# One row per recorded round, numbered from 1, with the averaged train and
# test accuracy.
round_process = list(range(1, len(acc_train_collect) + 1))
df = DataFrame(
    {
        'round': round_process,
        'acc_train': acc_train_collect,
        'acc_test': acc_test_collect,
    }
)
file_name = f"{program}.xlsx"
df.to_excel(file_name, sheet_name="v1_test", index=False)

#=============================================================================
#                         Program Completed
#=============================================================================