# PeFLL: Personalized Federated Learning by Learning to Learn
- https://arxiv.org/pdf/2306.05515.pdf
## Algorithm 1
> We start by describing the PeFLL-predict routine (Algorithm 1), which can predict a personalized model for any target client. First, the server sends the current embedding network, ηv, to the client (line 1), who evaluates it on all or a subset of its data to compute the client descriptor (line 3). Next, the client sends its descriptor to the server (line 4), who evaluates the hypernetwork on it (line 5). The resulting personalized model is sent back to the client (line 6), where it is ready for use. Overall, only two server-to-client and one client-to-server communication steps are required before the client has obtained a functioning personalized model (see Figure 1).
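
The message flow described above can be traced with a small pure-Python sketch. Everything in it is a hypothetical stand-in chosen for illustration: `embed` is a toy embedding "network" with a single scalar parameter, `hypernetwork` is a single linear map, and `send` only records who sent what to whom. It is meant to show the communication pattern, not the paper's actual networks.

```python
# Toy trace of the PeFLL-predict message flow; all names and sizes are illustrative.

def embed(batch, eta_v):
    """Stand-in embedding network: scale each feature by eta_v, then average over the batch."""
    dim = len(batch[0])
    return [eta_v * sum(sample[d] for sample in batch) / len(batch) for d in range(dim)]


def hypernetwork(descriptor, eta_h):
    """Stand-in hypernetwork: one linear map from the descriptor to model parameters."""
    return [sum(w * v for w, v in zip(row, descriptor)) for row in eta_h]


log = []  # one (sender, receiver, payload name) entry per message


def send(sender, receiver, name, payload):
    """Record the message and hand the payload to the receiver."""
    log.append((sender, receiver, name))
    return payload


# Server state: embedding parameter eta_v and hypernetwork weights eta_h
eta_v = 0.5
eta_h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Client state: private data that never leaves the client
client_data = [[2.0, 4.0], [4.0, 8.0]]

client_eta_v = send("server", "client", "eta_v", eta_v)                 # line 1
descriptor = embed(client_data, client_eta_v)                           # line 3
server_descriptor = send("client", "server", "descriptor", descriptor)  # line 4
theta = hypernetwork(server_descriptor, eta_h)                          # line 5
client_theta = send("server", "client", "theta", theta)                 # line 6

downlink = sum(1 for sender, _, _ in log if sender == "server")
uplink = sum(1 for sender, _, _ in log if sender == "client")
print(downlink, uplink, client_theta)  # 2 1 [1.5, 3.0, 4.5]
```

The log contains two server-to-client messages and one client-to-server message, and only the averaged descriptor, never the raw data, crosses the channel.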

In [None]:
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)

# Illustrative sizes (assumptions of this sketch, not values from the paper)
IN_DIM, EMBED_DIM = 2, 4

# Embedding network v(.; η_v): embeds each sample, then averages over the batch
class EmbeddingNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(IN_DIM, 5)
        self.fc2 = nn.Linear(5, EMBED_DIM)

    def forward(self, batch):
        z = self.fc2(torch.relu(self.fc1(batch)))
        return z.mean(dim=0)  # a single descriptor vector for the whole batch

# Personalized model: its parameters θ are produced by the hypernetwork
class PersonalizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(IN_DIM, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Hypernetwork h(.; η_h): maps a descriptor to a flat parameter vector θ
class HyperNetwork(nn.Module):
    def __init__(self, num_params):
        super().__init__()
        self.fc1 = nn.Linear(EMBED_DIM, 5)
        self.fc2 = nn.Linear(5, num_params)

    def forward(self, v):
        return self.fc2(torch.relu(self.fc1(v)))

# Target client with private dataset S (10 samples with IN_DIM features each)
S = torch.randn(10, IN_DIM)

# The server holds the embedding network (η_v) and the hypernetwork (η_h)
embedding_network = EmbeddingNetwork()
num_params = parameters_to_vector(PersonalizedModel().parameters()).numel()
hypernetwork = HyperNetwork(num_params)

with torch.no_grad():  # prediction only, no training
    # Server sends embedding network η_v to client
    # Client selects a data batch B ⊆ S (sampled without replacement)
    batch_size = 5
    batch = S[torch.randperm(len(S))[:batch_size]]

    # Client computes v = v(B; η_v)
    v = embedding_network(batch)

    # Client sends descriptor v to server
    # Server computes θ = h(v; η_h)
    theta = hypernetwork(v)

# Server sends personalized model θ to client, which loads it into its model
personalized_model = PersonalizedModel()
vector_to_parameters(theta, personalized_model.parameters())
print(personalized_model(S).shape)  # torch.Size([10, 1])

## Algorithm 2
> The training routine for PeFLL (Algorithm 2) mostly adopts a standard stochastic optimization pattern in a federated setting. In each iteration the server selects a batch of available clients (line 2) and broadcasts the embedding model, ηv, to all of them (line 3). Then, each client in parallel evaluates its descriptor, vi (line 6), sends it to the server (line 7), and receives a personalized model from the server in return (lines 8, 9). At this point the forward pass is over and backpropagation starts. To this end, each client performs local SGD for k steps on its personalized model and personal data (line 10). It sends the resulting update vector, ∆θi, to the server (line 11), where it acts as a proxy for ∂fi/∂θi. According to the chain rule, ... . The server can evaluate both expressions using backpropagation (line 12), because all required expressions are available to it now. Thereby, it obtains update vectors ∆η^(i)_h and ∆vi, the latter of which it sends to the client (line 13) as a proxy for ∂fi/∂vi. Again based on the chain rule, ..., the client computes an update vector for the embedding network, ∆η^(i)_v (line 14), and sends it back to the server (line 15). Finally, the server updates all network parameters from the average of the per-client contributions as well as the contributions from the server objective (lines 17, 18).
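
The split backward pass described above can be checked on a toy example. The sketch below assumes a linear embedding map and a linear hypernetwork (both hypothetical stand-ins, as are all sizes and the squared-error objective), and it uses the exact gradient ∂f_i/∂θ_i where the paper uses the local-SGD update ∆θi as a proxy. Error vectors are handed between "server" and "client" through the `grad_outputs` argument of `torch.autograd.grad`, and the result is compared with ordinary end-to-end backpropagation.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: linear embedding map (eta_v) and linear hypernetwork (eta_h)
embed = torch.nn.Linear(3, 2)
hyper = torch.nn.Linear(2, 4)
batch = torch.randn(5, 3)   # client data batch
target = torch.randn(4)     # defines a toy client objective


def client_loss(theta):
    """Toy objective f_i(theta); any differentiable loss would do here."""
    return ((theta - target) ** 2).sum()


# Reference: ordinary end-to-end backprop through both networks
theta_ref = hyper(embed(batch).mean(dim=0))
ref = torch.autograd.grad(
    client_loss(theta_ref), list(hyper.parameters()) + list(embed.parameters())
)

# Forward pass, split across the two parties
v_client = embed(batch).mean(dim=0)                   # client: descriptor
v_server = v_client.detach().requires_grad_(True)     # server: received copy
theta_server = hyper(v_server)                        # server: model parameters
theta_client = theta_server.detach().requires_grad_(True)  # client: received copy

# Backward pass, one error vector per message
(g_theta,) = torch.autograd.grad(client_loss(theta_client), theta_client)  # client
server_grads = torch.autograd.grad(
    theta_server, [v_server] + list(hyper.parameters()), grad_outputs=g_theta
)                                                     # server: backprop with g_theta
g_v, g_eta_h = server_grads[0], server_grads[1:]
g_eta_v = torch.autograd.grad(
    v_client, list(embed.parameters()), grad_outputs=g_v
)                                                     # client: backprop with g_v

split = list(g_eta_h) + list(g_eta_v)
match = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(ref, split))
print(match)
```

Neither party needs the other's computation graph: the server only sees `v_client` and the error vector for θ, and the client only sees θ and the error vector for its descriptor.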

In [None]:
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)

# Illustrative sizes (assumptions of this sketch, not values from the paper)
IN_DIM, EMBED_DIM = 2, 4

# Embedding network v(.; η_v): embeds each sample, then averages over the batch
class EmbeddingNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(IN_DIM, 5)
        self.fc2 = nn.Linear(5, EMBED_DIM)

    def forward(self, batch):
        z = self.fc2(torch.relu(self.fc1(batch)))
        return z.mean(dim=0)

# Personalized model: its parameters θ_i are produced by the hypernetwork
class PersonalizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(IN_DIM, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Hypernetwork h(.; η_h): maps a descriptor to a flat parameter vector θ_i
class HyperNetwork(nn.Module):
    def __init__(self, num_params):
        super().__init__()
        self.fc1 = nn.Linear(EMBED_DIM, 5)
        self.fc2 = nn.Linear(5, num_params)

    def forward(self, v):
        return self.fc2(torch.relu(self.fc1(v)))

# Number of training steps and clients to select per step
T = 10
c = 5

# Hyperparameters: server step size, regularization, local training settings
beta = 0.01
lambda_h = 0.01
lambda_v = 0.01
k = 5            # local SGD steps per client
local_lr = 0.05  # client learning rate (illustrative)
batch_size = 5

# Clients with private datasets S_i = (features, targets); synthetic toy data
clients = [(torch.randn(10, IN_DIM), torch.randn(10, 1)) for _ in range(10)]

# Server-side networks; the client model is a reusable template
embedding_network = EmbeddingNetwork()
personalized_model = PersonalizedModel()
num_params = parameters_to_vector(personalized_model.parameters()).numel()
hypernetwork = HyperNetwork(num_params)
eta_v = list(embedding_network.parameters())
eta_h = list(hypernetwork.parameters())
loss_fn = nn.MSELoss()

# Train the model
for t in range(T):
    # Server randomly samples c clients
    P = torch.randperm(len(clients))[:c].tolist()
    sum_delta_eta_h = [torch.zeros_like(p) for p in eta_h]
    sum_delta_eta_v = [torch.zeros_like(p) for p in eta_v]

    # Server broadcasts η_v to P; this loop stands in for the parallel clients
    for i in P:
        X, y = clients[i]

        # Client selects a data batch B ⊆ S_i and computes v_i = v(B; η_v)
        batch = X[torch.randperm(len(X))[:batch_size]]
        v_i = embedding_network(batch)

        # Client sends v_i to server; server computes θ_i = h(v_i; η_h)
        v_server = v_i.detach().requires_grad_(True)
        theta_i = hypernetwork(v_server)

        # Server sends θ_i to client; client runs k steps of local SGD on f_i(θ_i)
        vector_to_parameters(theta_i.detach().clone(), personalized_model.parameters())
        local_opt = torch.optim.SGD(personalized_model.parameters(), lr=local_lr)
        for _ in range(k):
            local_opt.zero_grad()
            loss_fn(personalized_model(X), y).backward()
            local_opt.step()

        # Client sends ∆θ_i := θ_i^new − θ_i to server
        theta_new = parameters_to_vector(personalized_model.parameters()).detach()
        delta_theta_i = theta_new - theta_i.detach()

        # Server runs backprop with error vector ∆θ_i, giving ∆v_i and ∆η_h^(i)
        grads = torch.autograd.grad(theta_i, [v_server] + eta_h, grad_outputs=delta_theta_i)
        delta_v_i, delta_eta_h_i = grads[0], grads[1:]

        # Server sends ∆v_i to client; client runs backprop with error vector ∆v_i
        delta_eta_v_i = torch.autograd.grad(v_i, eta_v, grad_outputs=delta_v_i)

        # Client sends ∆η_v^(i) to server; server accumulates both contributions
        for acc, d in zip(sum_delta_eta_h, delta_eta_h_i):
            acc += d
        for acc, d in zip(sum_delta_eta_v, delta_eta_v_i):
            acc += d

    # Server update (this sketch's reading of lines 17, 18): shrink by the
    # regularizer, then step along the averaged client updates
    with torch.no_grad():
        for p, d in zip(eta_h, sum_delta_eta_h):
            p.mul_(1 - 2 * beta * lambda_h).add_(d, alpha=beta / len(P))
        for p, d in zip(eta_v, sum_delta_eta_v):
            p.mul_(1 - 2 * beta * lambda_v).add_(d, alpha=beta / len(P))

# Output the network parameters (η_h, η_v)
print(hypernetwork.state_dict())
print(embedding_network.state_dict())