In [None]:
## Lightweight PyTorch Implementation of DP-FPL

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import random

# -------------------------------
# Hyperparameters
# -------------------------------
NUM_CLIENTS = 5
ROUNDS = 5
LOCAL_EPOCHS = 2
BATCH_SIZE = 32
EMBED_DIM = 128       # prompt dimension
PROMPT_LEN = 16
RANK = 4
LR_GLOBAL = 1e-3
LR_LOCAL = 1e-3
SIGMA_G = 0.2         # GDP noise std
SIGMA_L = 0.4         # LDP noise std
C_TH = 1.0            # gradient clipping threshold
DEVICE = "cpu"


In [2]:
class FrozenCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_encoder = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
        self.image_encoder = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
        # freeze parameters
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text_prompt, image):
        t_feat = F.normalize(self.text_encoder(text_prompt), dim=-1)
        i_feat = F.normalize(self.image_encoder(image), dim=-1)
        return t_feat, i_feat


In [3]:
class PromptLearner(nn.Module):
    def __init__(self):
        super().__init__()
        self.global_prompt = nn.Parameter(torch.randn(PROMPT_LEN, EMBED_DIM))
        self.local_prompt = nn.Parameter(torch.randn(PROMPT_LEN, EMBED_DIM))

    def get_prompt(self):
        return self.global_prompt + self.local_prompt


In [4]:
def factorize_low_rank(mat, rank=RANK):
    """Return u, v, residual for matrix mat (PROMPT_LEN x EMBED_DIM)."""
    n = mat.size(1)
    q = torch.randn(n, rank)
    for _ in range(1):  # one iteration of power method
        q = torch.linalg.qr(mat.T @ (mat @ q)).Q
    u = mat @ q
    v = q.T
    approx = u @ v
    residual = mat - approx
    return u, v, residual


In [5]:
class Client:
    def __init__(self, cid, model, data_size=1000):
        self.cid = cid
        self.model = copy.deepcopy(model)
        self.clip = FrozenCLIP()
        self.data_size = data_size

    def local_train(self):
        opt = torch.optim.SGD(self.model.parameters(), lr=LR_LOCAL)
        for _ in range(LOCAL_EPOCHS):
            # Simulate random "images" and "labels"
            x = torch.randn(BATCH_SIZE, EMBED_DIM)
            y = torch.randint(0, PROMPT_LEN, (BATCH_SIZE,))

            # Factorize local prompt
            u, v, r = factorize_low_rank(self.model.local_prompt)
            p_local = u @ v + r

            # Forward through frozen CLIP
            t_feat, i_feat = self.clip(p_local, x)
            logits = (i_feat @ t_feat.T) / 0.07
            loss = F.cross_entropy(logits, y)

            # Compute gradients
            opt.zero_grad()
            loss.backward()

            # Clip and add LDP noise
            for name, p in self.model.named_parameters():
                if "local" in name and p.grad is not None:
                    grad = p.grad
                    grad = grad.clamp(-C_TH, C_TH)
                    grad += torch.normal(0, SIGMA_L, grad.shape)
                    p.data -= LR_LOCAL * grad
        return loss.item()

    def get_global_grad(self):
        # Return noisy gradient for global prompt only
        grad = torch.randn_like(self.model.global_prompt)
        grad = grad.clamp(-C_TH, C_TH)
        grad += torch.normal(0, SIGMA_G, grad.shape)
        return grad




In [6]:
class Server:
    def __init__(self):
        self.global_model = PromptLearner()
        self.clients = [Client(i, self.global_model) for i in range(NUM_CLIENTS)]

    def aggregate(self, grads):
        mean_grad = torch.stack(grads).mean(0)
        self.global_model.global_prompt.data -= LR_GLOBAL * mean_grad

    def run(self):
        for rnd in range(ROUNDS):
            grads = []
            losses = []
            for client in self.clients:
                # Broadcast global prompt
                client.model.global_prompt.data = self.global_model.global_prompt.data.clone()
                loss = client.local_train()
                grads.append(client.get_global_grad())
                losses.append(loss)
            self.aggregate(grads)
            print(f"Round {rnd+1}: mean loss = {sum(losses)/len(losses):.4f}")


In [7]:
if __name__ == "__main__":
    server = Server()
    server.run()


Round 1: mean loss = 3.4621
Round 2: mean loss = 3.4364
Round 3: mean loss = 3.4776
Round 4: mean loss = 3.2800
Round 5: mean loss = 3.3831
