In [None]:
## Implements:

# Federated Prompt Learning (10 clients)

# OpenAI CLIP (ViT-B/32, frozen encoder)

# Local & Global prompts

# Low-rank factorization + residual

# Simulated DP noise (fast!)

# Personalization (local) vs Generalization (neighbor) evaluation

## --------------------------------

In [1]:
!pip install git+https://github.com/openai/CLIP.git
!pip install torchvision
!pip install tqdm


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /private/var/folders/bg/wsgmhpr97ms7nhlct5rkgzn40000gn/T/pip-req-build-y8vdpjpc
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /private/var/folders/bg/wsgmhpr97ms7nhlct5rkgzn40000gn/T/pip-req-build-y8vdpjpc
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
Building wheels for collected packages: clip
[33m  DEPRECATION: Building 'clip' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-iso

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import clip
import numpy as np
from tqdm import tqdm
import random
import copy

# Reproducibility: prompts, random projections, DP noise and loader shuffling all
# draw from these RNGs, so seed each of them once, before anything else runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cpu"   # CPU ONLY as requested
torch.set_num_threads(4)

# Federated setup
NUM_CLIENTS = 10
ROUNDS = 5
LOCAL_EPOCHS = 1
BATCH_SIZE = 16

# Prompt settings
PROMPT_LEN = 16
EMBED_DIM = 512   # CLIP ViT-B/32 text dim

# DP Simulation noise levels (fast)
SIGMA_L = 0.3    # "local" DP noise
SIGMA_G = 0.1    # "global" DP noise

# Low-rank factorization
RANK = 4
LR_LOCAL = 1e-3
LR_GLOBAL = 1e-3


In [None]:
# Fetch the pretrained ViT-B/32 CLIP checkpoint together with its matching
# image preprocessing pipeline.
clip_model, preprocess = clip.load("ViT-B/32", device=DEVICE)

# The backbone is never trained here: switch off gradients for every weight.
clip_model.requires_grad_(False)

print("Loaded CLIP on", DEVICE)


In [None]:
# Images go through CLIP's own preprocessing.
transform = preprocess

# Build the train and test splits with identical settings; only the split name differs.
dataset, testset = (
    torchvision.datasets.Flowers102(
        root="./data",
        split=split_name,
        download=True,
        transform=transform,
    )
    for split_name in ("train", "test")
)

num_samples = len(dataset)
print("Training samples:", num_samples)


In [None]:
labels = np.array(dataset._labels)

# Derive the class count from the data rather than hard-coding it.
num_classes = int(labels.max()) + 1

# Nominal shard size. When num_classes is not a multiple of NUM_CLIENTS, the first
# (num_classes % NUM_CLIENTS) clients hold one extra class (see array_split below).
classes_per_client = num_classes // NUM_CLIENTS

# Sample positions grouped by class label.
idx_by_class = {c: np.where(labels == c)[0] for c in range(num_classes)}

client_classes = {}   # cid -> list of class ids owned by that client
client_indices = []   # cid -> list of training-sample positions owned by that client

# array_split hands out contiguous class ranges and spreads the remainder over the
# first clients, so no class is left unassigned (plain floor division dropped the
# last num_classes % NUM_CLIENTS classes).
for cid, cls_chunk in enumerate(np.array_split(np.arange(num_classes), NUM_CLIENTS)):
    cls_range = cls_chunk.tolist()
    client_classes[cid] = cls_range

    indices = []
    for c in cls_range:
        indices.extend(idx_by_class[c].tolist())

    client_indices.append(indices)

# Every training sample must belong to exactly one client.
assert sum(len(ix) for ix in client_indices) == len(labels), "partition lost or duplicated samples"

print("Client 0 has", len(client_indices[0]), "samples")


In [None]:
class PromptLearner(nn.Module):
    """Holds the two learnable prompt matrices of one participant.

    ``global_prompt`` is the part shared through the server, ``local_prompt`` is the
    personal part; both have shape ``(prompt_len, embed_dim)``.

    Args:
        prompt_len: number of prompt vectors; defaults to the module-level PROMPT_LEN.
        embed_dim: width of each prompt vector; defaults to the module-level EMBED_DIM.
        init_std: standard deviation of the Gaussian initialisation. The default of
            0.02 follows the scale CLIP uses to initialise its token embeddings (and
            CoOp uses for context vectors); unit-variance prompts are far outside
            the range the frozen text encoder was trained on.
    """

    def __init__(self, prompt_len=None, embed_dim=None, init_std=0.02):
        super().__init__()
        # Resolve the defaults at call time so re-running the config cell takes effect.
        if prompt_len is None:
            prompt_len = PROMPT_LEN
        if embed_dim is None:
            embed_dim = EMBED_DIM

        self.global_prompt = nn.Parameter(
            torch.randn(prompt_len, embed_dim) * init_std
        )
        self.local_prompt = nn.Parameter(
            torch.randn(prompt_len, embed_dim) * init_std
        )

    def full_prompt(self):
        """Return the effective prompt: element-wise sum of global and local parts."""
        return self.global_prompt + self.local_prompt


In [None]:
def low_rank_factorization(mat, rank=None):
    """
    Split a matrix into a rank-``rank`` approximation plus the leftover residual.

    mat: 2-D tensor (PROMPT_LEN x EMBED_DIM in this notebook)
    rank: number of directions to keep; ``None`` means the module-level RANK,
          looked up at call time so that re-running the config cell takes effect
    returns: u (rows x rank), v (rank x cols), residual (same shape as mat),
             with ``u @ v + residual`` reproducing ``mat``
    """
    if rank is None:
        rank = RANK

    # random projection for power iteration; created on mat's device and dtype so the
    # matmuls below do not fail for non-CPU / non-float32 inputs
    Q = torch.randn(mat.shape[1], rank, device=mat.device, dtype=mat.dtype)
    Q = torch.linalg.qr(Q).Q

    # one power iteration (fast)
    Z = mat.T @ (mat @ Q)
    Q = torch.linalg.qr(Z).Q

    # Q has orthonormal columns, so u @ v is the projection of mat onto their span
    u = mat @ Q
    v = Q.T
    approx = u @ v
    residual = mat - approx
    return u, v, residual



In [None]:
class Client:
    """A federated participant: a private data shard plus its own copy of the prompt model.

    The CLIP backbone (module-level ``clip_model``) is only used for forward passes;
    the tensors being learned are the prompts stored in ``self.model``.
    """

    def __init__(self, cid, indices, global_model):
        """
        cid: client id
        indices: positions into the module-level training ``dataset`` owned by this client
        global_model: prompt model that is deep-copied as this client's starting point
        """
        self.cid = cid
        self.indices = indices
        self.model = copy.deepcopy(global_model)
        # Mean global-prompt gradient of the latest local_train() call (None until then).
        self._global_grad = None
        # Lazily built (token_ids, token_embeddings) of the per-class sentences.
        self._class_tokens = None

    def get_loader(self):
        """Return a shuffled DataLoader over this client's slice of the training set."""
        ds = Subset(dataset, self.indices)
        return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

    def _class_token_cache(self):
        """Tokenize and embed one '<placeholders> <class name>.' sentence per class (cached).

        The placeholder positions are later overwritten by the learnable prompt, so
        only their count matters. Follows the CoOp recipe, which relies on each "X"
        mapping to exactly one token.
        """
        if getattr(self, "_class_tokens", None) is None:
            n_ctx = self.model.global_prompt.shape[0]
            # Some torchvision releases expose human-readable class names; fall back
            # to generic numbered names when they are not available.
            names = getattr(dataset, "classes", None)
            if not names:
                n_cls = int(max(dataset._labels)) + 1
                names = [f"flower type {i}" for i in range(n_cls)]
            placeholder = " ".join(["X"] * n_ctx)
            token_ids = clip.tokenize(
                [f"{placeholder} {name}." for name in names]
            ).to(DEVICE)
            with torch.no_grad():
                embedded = clip_model.token_embedding(token_ids).type(clip_model.dtype)
            self._class_tokens = (token_ids, embedded)
        return self._class_tokens

    def _text_features(self, prompt):
        """Encode every class with ``prompt`` as its learnable context.

        ``clip_model.encode_text`` only accepts token ids, so its steps are replayed
        here on embeddings: the prompt vectors replace the placeholder token
        embeddings, then positional embedding, transformer, final layer norm and
        projection at the end-of-text position are applied.

        prompt: (n_ctx, embed_dim) tensor
        returns: (n_classes, feature_dim) L2-normalised text features
        """
        token_ids, embedded = self._class_token_cache()
        dtype = clip_model.dtype
        n_cls = embedded.shape[0]
        n_ctx = prompt.shape[0]

        ctx = prompt.type(dtype).unsqueeze(0).expand(n_cls, -1, -1)
        # [start-of-text] + learnable context + [class name tokens, end-of-text, padding]
        x = torch.cat([embedded[:, :1], ctx, embedded[:, 1 + n_ctx:]], dim=1)
        x = x + clip_model.positional_embedding.type(dtype)
        x = clip_model.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
        x = clip_model.ln_final(x).type(dtype)

        # The end-of-text token has the largest id in every sequence.
        eot = token_ids.argmax(dim=-1)
        rows = torch.arange(n_cls, device=x.device)
        feats = x[rows, eot] @ clip_model.text_projection
        return F.normalize(feats.float(), dim=-1)

    def local_train(self):
        """Run LOCAL_EPOCHS passes over the client's data and return the mean loss.

        Only the local prompt is stepped here (with simulated LDP noise added to its
        gradient). The gradient w.r.t. the global prompt is averaged over the steps
        and kept for get_global_gradient(); the server owns the global update.
        Note: each step encodes every class through the CLIP text transformer, which
        dominates the run time on CPU.
        """
        loader = self.get_loader()
        opt = torch.optim.SGD([self.model.local_prompt], lr=LR_LOCAL)

        self.model.train()

        losses = []
        global_grad_sum = torch.zeros_like(self.model.global_prompt)
        steps = 0

        for _ in range(LOCAL_EPOCHS):
            for img, label in loader:
                img = img.to(DEVICE)
                label = label.to(DEVICE)

                # Clears both prompt gradients (the optimizer only knows the local one).
                self.model.zero_grad(set_to_none=True)

                # Factorize local prompt into low rank + residual
                u, v, residual = low_rank_factorization(self.model.local_prompt)

                # Reconstructed local prompt, combined with the shared global prompt
                p_local = u @ v + residual
                prompt = self.model.global_prompt + p_local

                text_embed = self._text_features(prompt)
                with torch.no_grad():
                    img_embed = F.normalize(
                        clip_model.encode_image(img).float(), dim=-1
                    )

                # (batch, n_classes) cosine similarities with temperature 0.07
                logits = (img_embed @ text_embed.T) / 0.07
                loss = F.cross_entropy(logits, label)
                loss.backward()

                with torch.no_grad():
                    # Simulated LDP noise on the local prompt gradient
                    local_grad = self.model.local_prompt.grad
                    if local_grad is not None:
                        local_grad += torch.randn_like(local_grad) * SIGMA_L
                    if self.model.global_prompt.grad is not None:
                        global_grad_sum += self.model.global_prompt.grad

                opt.step()
                losses.append(loss.item())
                steps += 1

        self._global_grad = global_grad_sum / steps if steps else None

        # An empty shard produces no batches; report NaN instead of averaging nothing.
        return float(np.mean(losses)) if losses else float("nan")

    def get_global_gradient(self):
        """Return the global-prompt gradient to send to the server, with simulated GDP noise.

        Uses the mean gradient recorded by the latest local_train() call; if no
        training step has run yet, only the noise term is returned.
        """
        noise = torch.randn_like(self.model.global_prompt) * SIGMA_G
        recorded = getattr(self, "_global_grad", None)
        if recorded is None:
            return noise
        return recorded + noise


In [None]:
class Server:
    """Coordinator that owns the shared prompt model and the list of clients."""

    def __init__(self):
        """Create the global prompt model and one Client per entry of client_indices."""
        self.global_model = PromptLearner().to(DEVICE)
        self.clients = []
        for cid in range(NUM_CLIENTS):
            self.clients.append(
                Client(cid, client_indices[cid], self.global_model)
            )

    def aggregate(self, grads):
        """Average the clients' gradients and take one descent step on the global prompt."""
        stacked = torch.stack(grads)
        mean_grad = stacked.mean(dim=0)
        # In-place parameter update must not be tracked by autograd.
        with torch.no_grad():
            self.global_model.global_prompt -= LR_GLOBAL * mean_grad

    def run(self):
        """Run ROUNDS rounds: broadcast, local training on every client, aggregate."""
        for r in range(ROUNDS):
            print(f"\n=== Round {r+1}/{ROUNDS} ===")
            losses = []
            collected = []

            for client in tqdm(self.clients):
                # Hand the client a private copy of the current global prompt.
                shared = self.global_model.global_prompt.data.clone()
                client.model.global_prompt.data = shared

                losses.append(client.local_train())
                collected.append(client.get_global_gradient())

            self.aggregate(collected)
            print(f"Avg Local Loss: {np.mean(losses):.4f}")


In [None]:
# Build the server (its constructor also creates every client), then run all
# federated rounds. `server` is reused by the evaluation cell below.
server = Server()
server.run()


In [None]:
def _encode_class_prompts(prompt):
    """Return L2-normalised CLIP text features, one row per class, using ``prompt`` as context.

    ``clip_model.encode_text`` only accepts token ids, so its steps are replayed on
    embeddings: each class sentence is '<placeholders> <class name>.', and the
    placeholder embeddings are replaced by the rows of ``prompt`` (CoOp recipe, which
    relies on each "X" mapping to exactly one token).
    """
    n_ctx = prompt.shape[0]
    # Some torchvision releases expose human-readable class names; fall back to
    # generic numbered names when they are not available.
    names = getattr(testset, "classes", None)
    if not names:
        names = [f"flower type {i}" for i in range(int(max(testset._labels)) + 1)]
    placeholder = " ".join(["X"] * n_ctx)
    token_ids = clip.tokenize([f"{placeholder} {name}." for name in names]).to(DEVICE)

    dtype = clip_model.dtype
    embedded = clip_model.token_embedding(token_ids).type(dtype)
    ctx = prompt.type(dtype).unsqueeze(0).expand(embedded.shape[0], -1, -1)
    # [start-of-text] + learned context + [class name tokens, end-of-text, padding]
    x = torch.cat([embedded[:, :1], ctx, embedded[:, 1 + n_ctx:]], dim=1)
    x = x + clip_model.positional_embedding.type(dtype)
    x = clip_model.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
    x = clip_model.ln_final(x).type(dtype)

    # The end-of-text token has the largest id in every sequence.
    eot = token_ids.argmax(dim=-1)
    rows = torch.arange(x.shape[0], device=x.device)
    feats = x[rows, eot] @ clip_model.text_projection
    return F.normalize(feats.float(), dim=-1)


def evaluate_client(client, mode="local"):
    """Top-1 accuracy of a client's full prompt on (a subset of) the test split.

    client: object with ``model`` (providing ``full_prompt()``) and ``indices``
            (positions into the training ``dataset``)
    mode: "local"    -> only test images whose class the client trained on
          "neighbor" -> only test images of classes the client did NOT train on
          any other value -> the whole test split
    returns: fraction of correctly classified images (0.0 if the subset is empty)
    """
    client.model.eval()

    # Classes this client owns, recovered from the labels of its training samples.
    train_labels = np.asarray(dataset._labels)
    own_classes = np.unique(train_labels[np.asarray(client.indices, dtype=int)])

    test_labels = np.asarray(testset._labels)
    if mode == "local":
        keep = np.isin(test_labels, own_classes)
    elif mode == "neighbor":
        keep = ~np.isin(test_labels, own_classes)
    else:
        keep = np.ones(len(test_labels), dtype=bool)

    eval_indices = np.flatnonzero(keep).tolist()
    if not eval_indices:
        return 0.0

    test_loader = DataLoader(Subset(testset, eval_indices), batch_size=64, shuffle=False)

    correct = 0
    total = 0

    with torch.no_grad():
        # The prompt is fixed during evaluation, so encode the classes once.
        text_embed = _encode_class_prompts(client.model.full_prompt())

        for img, label in test_loader:
            img = img.to(DEVICE)
            img_emb = F.normalize(clip_model.encode_image(img).float(), dim=-1)

            # (batch, n_classes): pick the best class for every image
            logits = (img_emb @ text_embed.T) / 0.07
            pred = logits.argmax(dim=1)

            correct += (pred.cpu() == label).sum().item()
            total += len(label)

    return correct / total

# Report the accuracy of every client's prompt, one line per client.
for cid, client in enumerate(server.clients):
    personalization_acc = evaluate_client(client)
    print(f"Client {cid} | Personalization Acc = {personalization_acc:.4f}")

