In [1]:
import gc
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tensorflow import keras
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, RandomSampler

In [2]:
# Fetch Fashion-MNIST through Keras; the loader returns a (train, test) pair of
# (images, labels) tuples.
train_split, test_split = keras.datasets.fashion_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [3]:
def _scale_and_add_channel(images):
    """Divide integer pixel data by 255 and append a trailing channel axis.

    Both steps are guarded so that re-running this cell on already processed
    arrays is a no-op: floating-point input is not rescaled again and 4-D
    input does not receive another axis.

    Args:
        images: Array-like image batch, either (N, H, W) or (N, H, W, 1).
            Integer input is assumed to hold 8-bit pixel values (0..255).

    Returns:
        NumPy array with a trailing channel axis.
    """
    images = np.asarray(images)
    if np.issubdtype(images.dtype, np.integer):
        images = images / 255.0
    if images.ndim == 3:
        images = images[..., np.newaxis]
    return images


x_train, x_test = _scale_and_add_channel(x_train), _scale_and_add_channel(x_test)

In [4]:
# Wrap the full training set in a shuffled mini-batch loader.
# permute(0, 3, 1, 2) moves the trailing channel axis in front of the two
# spatial axes, which is the layout Conv2d consumes.
train_images = torch.tensor(x_train, dtype=torch.float32).permute(0, 3, 1, 2)
train_labels = torch.tensor(y_train, dtype=torch.long)
train_dataset = torch.utils.data.TensorDataset(train_images, train_labels)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)

In [5]:
# NaiveFourierKANLayer: KAN-style layer whose per-edge functions are truncated Fourier series
class NaiveFourierKANLayer(nn.Module):
    """Layer that passes every input feature through a truncated Fourier series.

    For each (output, input) pair the layer stores one cosine and one sine
    coefficient per frequency ``k = 1..gridsize``. An output value is the sum,
    over all inputs and frequencies, of ``a * cos(k * x) + b * sin(k * x)``,
    plus an optional bias.

    Args:
        inputdim: Size of the last dimension of the input.
        outdim: Size of the last dimension of the output.
        initial_gridsize: Number of frequencies allocated per (output, input)
            pair; also the initial value of ``gridsize_param``.
        addbias: If True, add a learnable bias of shape ``(1, outdim)``.
    """

    def __init__(self, inputdim, outdim, initial_gridsize, addbias=True):
        super().__init__()
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim
        # Registered as a parameter so it travels with the state dict. forward()
        # rounds it to an integer, so no gradient ever reaches it.
        self.gridsize_param = nn.Parameter(torch.tensor(initial_gridsize, dtype=torch.float32))
        # Axis 0 selects cosine (0) or sine (1) coefficients.
        self.fouriercoeffs = nn.Parameter(torch.empty(2, outdim, inputdim, initial_gridsize))
        nn.init.xavier_uniform_(self.fouriercoeffs)
        if self.addbias:
            self.bias = nn.Parameter(torch.zeros(1, outdim))

    def forward(self, x):
        """Apply the layer to ``x``.

        Args:
            x: Tensor whose last dimension has size ``inputdim``; any leading
                dimensions are preserved.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``outdim``.
        """
        coeffs = self.fouriercoeffs
        # Convert to a plain int and bound it by the number of coefficients
        # actually allocated; otherwise the frequency vector and the sliced
        # coefficients could end up with different lengths.
        max_gridsize = coeffs.shape[-1]
        gridsize = int(
            torch.clamp(self.gridsize_param.detach(), min=1, max=max_gridsize).round().item()
        )

        outshape = x.shape[:-1] + (self.outdim,)
        # einsum does not type-promote its operands, so pick the common dtype up front.
        compute_dtype = torch.promote_types(x.dtype, coeffs.dtype)
        x = torch.reshape(x, (-1, self.inputdim)).to(compute_dtype)  # (batch, inputdim)
        coeffs = coeffs[..., :gridsize].to(compute_dtype)            # (2, outdim, inputdim, gridsize)

        k = torch.arange(1, gridsize + 1, device=x.device, dtype=compute_dtype)
        angles = x.unsqueeze(-1) * k                                 # (batch, inputdim, gridsize)

        # Contract over the input and frequency axes directly. A broadcast
        # multiply followed by sum would first build a
        # (batch, outdim, inputdim, gridsize) tensor.
        y = torch.einsum("big,oig->bo", torch.cos(angles), coeffs[0])
        y = y + torch.einsum("big,oig->bo", torch.sin(angles), coeffs[1])
        if self.addbias:
            y = y + self.bias
        return torch.reshape(y, outshape)

# CNNFourierKAN model definition
class CNNFourierKAN(nn.Module):
    """Two conv/pool stages followed by two Fourier-KAN layers.

    All constructor arguments are optional; the defaults give a network for
    1-channel 28x28 images and 10 classes.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the output (one logit per class).
        image_size: Height and width of the (square) input images. Used to
            size the first KAN layer.
        gridsize: Number of Fourier frequencies used by both KAN layers.

    Raises:
        ValueError: If ``image_size`` is too small to survive two 2x2 poolings.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28, gridsize=100):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2)

        # The 3x3 convolutions use padding=1 and keep the spatial size; each
        # MaxPool2d(2) floor-halves it.
        pooled_size = (image_size // 2) // 2
        if pooled_size < 1:
            raise ValueError(f"image_size must be at least 4, got {image_size}")

        self.fourierkan1 = NaiveFourierKANLayer(
            32 * pooled_size * pooled_size, 128, initial_gridsize=gridsize
        )
        self.fourierkan2 = NaiveFourierKANLayer(128, num_classes, initial_gridsize=gridsize)

    def forward(self, x):
        """Map a batch of images (N, C, H, W) to class logits (N, num_classes)."""
        x = self.pool1(F.selu(self.conv1(x)))
        x = self.pool2(F.selu(self.conv2(x)))
        # Collapse channel and spatial axes into one feature vector per sample.
        x = torch.flatten(x, 1)
        x = self.fourierkan1(x)
        return self.fourierkan2(x)

In [6]:
def split_data(x_data, y_data, num_clients=5):
    """Partition a dataset into contiguous, near-equal shards, one per client.

    Samples are assigned in their existing order. When the number of samples
    is not divisible by ``num_clients`` the first ``N % num_clients`` shards
    receive one extra sample, so no sample is discarded. Shards that would be
    empty (more clients than samples) are omitted.

    Args:
        x_data: Sliceable sequence of inputs supporting ``len()``.
        y_data: Sliceable sequence of labels with the same length as ``x_data``.
        num_clients: Number of shards to create; must be at least 1.

    Returns:
        List of ``(client_x, client_y)`` tuples.

    Raises:
        ValueError: If ``num_clients`` is less than 1 or the inputs differ in length.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    num_items = len(x_data)
    if len(y_data) != num_items:
        raise ValueError(
            f"x_data and y_data must have the same length, got {num_items} and {len(y_data)}"
        )

    base, extra = divmod(num_items, num_clients)
    client_data = []
    start = 0
    for client_index in range(num_clients):
        # The first `extra` clients absorb the remainder, one sample each.
        end = start + base + (1 if client_index < extra else 0)
        if end > start:
            client_data.append((x_data[start:end], y_data[start:end]))
        start = end
    return client_data

In [7]:
def _train_client(model_constructor, global_state, inputs, targets, loss_function, epochs, batch_size):
    """Train a fresh copy of the global weights on one client and return its state dict."""
    local_model = model_constructor()
    local_model.load_state_dict(global_state)  # start from the current global weights
    optimizer = torch.optim.Adam(local_model.parameters(), lr=0.001)
    local_model.train()

    num_samples = inputs.shape[0]
    step = num_samples if batch_size is None else batch_size
    for _epoch in range(epochs):
        order = torch.randperm(num_samples)  # reshuffle the client's samples every epoch
        for start in range(0, num_samples, step):
            batch = order[start:start + step]
            optimizer.zero_grad()
            loss = loss_function(local_model(inputs[batch]), targets[batch])
            loss.backward()
            optimizer.step()
    return local_model.state_dict()


def _evaluate_accuracy(model, inputs, targets, batch_size):
    """Return the fraction of `inputs` whose arg-max prediction equals `targets`."""
    model.eval()
    total = targets.shape[0]
    step = max(total, 1) if batch_size is None else batch_size
    correct = 0
    with torch.no_grad():
        for start in range(0, total, step):
            logits = model(inputs[start:start + step])
            correct += (logits.argmax(dim=1) == targets[start:start + step]).sum().item()
    return correct / total


def federated_sgd(clients, model_constructor, x_test, y_test, rounds=10, epochs=1, batch_size=32):
    """Train a model with federated averaging and track test accuracy per round.

    In every round each client trains its own copy of the current global
    weights with Adam (lr=0.001). The resulting client weights are then
    averaged, weighted by client sample count, to form the next global model.
    All clients in a round start from the same global weights.

    Args:
        clients: Iterable of ``(client_x, client_y)`` pairs. ``client_x`` is a
            4-D channels-last image batch (N, H, W, C) and ``client_y`` holds
            integer class labels. NumPy arrays and tensors are both accepted.
        model_constructor: Zero-argument callable returning a new
            ``torch.nn.Module`` that maps (N, C, H, W) inputs to class logits.
        x_test: Test images, same layout as ``client_x``.
        y_test: Test labels.
        rounds: Number of communication rounds.
        epochs: Passes over each client's data per round.
        batch_size: Mini-batch size for local training and evaluation.
            ``None`` processes each client's (and the test set's) data as one
            batch.

    Returns:
        Tuple ``(global_model, accuracy_history)`` where ``accuracy_history``
        has one test-accuracy value per round.

    Raises:
        ValueError: If ``batch_size`` is given and is less than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1 or None, got {batch_size}")

    global_model = model_constructor()  # instantiate the global model
    loss_function = torch.nn.CrossEntropyLoss()
    accuracy_history = []

    # Convert every dataset once, outside the round loop. permute moves the
    # trailing channel axis in front of the spatial axes for Conv2d.
    client_tensors = []
    for client_x, client_y in clients:
        inputs = torch.as_tensor(client_x, dtype=torch.float32).permute(0, 3, 1, 2)
        targets = torch.as_tensor(client_y, dtype=torch.long)
        if inputs.shape[0] > 0:  # a client without samples contributes nothing
            client_tensors.append((inputs, targets))
    total_samples = sum(inputs.shape[0] for inputs, _ in client_tensors)

    test_inputs = torch.as_tensor(x_test, dtype=torch.float32).permute(0, 3, 1, 2)
    test_targets = torch.as_tensor(y_test, dtype=torch.long)

    for _round in range(rounds):
        global_state = global_model.state_dict()
        aggregate = None  # running sample-weighted sum of client weights

        for inputs, targets in client_tensors:
            local_state = _train_client(
                model_constructor, global_state, inputs, targets, loss_function, epochs, batch_size
            )
            weight = inputs.shape[0] / total_samples
            if aggregate is None:
                # Non-float entries (e.g. integer counters) cannot be averaged;
                # keep the first client's value for those.
                aggregate = {
                    key: (value * weight if value.is_floating_point() else value.clone())
                    for key, value in local_state.items()
                }
            else:
                for key, value in local_state.items():
                    if value.is_floating_point():
                        aggregate[key] += value * weight

            del local_state  # drop the client's weights before the next client trains
            gc.collect()

        # Install the averaged weights as the new global model.
        if aggregate is not None:
            global_model.load_state_dict(aggregate)

        # Evaluate the global model on the held-out test set.
        accuracy_history.append(
            _evaluate_accuracy(global_model, test_inputs, test_targets, batch_size)
        )

    return global_model, accuracy_history

In [None]:
# Seed PyTorch's RNG so weight initialisation (and any random sampling done
# during training) repeats on a fresh Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Shard the training set across clients, then run federated training and
# collect the per-round test accuracy.
clients = split_data(x_train, y_train)
global_model, accuracy_history = federated_sgd(clients, CNNFourierKAN, x_test, y_test)

In [None]:
# Draw the per-round test accuracy on an explicit Axes, then report the value
# reached in the last round.
fig, ax = plt.subplots()
ax.plot(accuracy_history)
ax.set_title('Test Accuracy Over Rounds of Training (KAN-CNN)')
ax.set_xlabel('Rounds')
ax.set_ylabel('Accuracy')
plt.show()
print(f"Final test accuracy: {accuracy_history[-1]}")