In [None]:
!pip install transformers
!pip install torchvision
!pip install tqdm
!pip install numpy


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms

from transformers import CLIPModel, CLIPProcessor

import numpy as np
from tqdm import tqdm
import math
import copy
import random

DEVICE = "cpu"          # CPU-only mode (as requested)
torch.set_num_threads(4)

# ========== Reproducibility ==========================
# Prompt initialisation, data shuffling and the DP noise are all random;
# seed every RNG in use so "Restart & Run All" reproduces the same numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ========== Federated Learning Parameters ============
NUM_CLIENTS = 10
ROUNDS = 5
LOCAL_EPOCHS = 1
BATCH_SIZE = 16

# ========== Prompt Tuning Settings ===================
PROMPT_LEN = 16
EMBED_DIM = 512           # CLIP ViT-B/16 text embedding dim
RANK = 8                  # As used in the paper
LR_LOCAL = 1e-3
LR_GLOBAL = 1e-3

# ========== Differential Privacy Parameters ==========
EPSILONS = [0.4, 0.2, 0.1, 0.05, 0.01]   # Paper settings
DELTA = 1e-5                              # Paper setting
CLIP_THRESHOLD = 10                       # Paper setting

print("Config loaded.")


In [None]:
# Load the pretrained CLIP ViT-B/16 backbone (switched to inference mode)
# together with its matching preprocessor.
hf_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# The backbone stays frozen: no CLIP weight should ever receive a gradient.
hf_clip.requires_grad_(False)

print("Loaded HuggingFace CLIP ViT-B/16 (frozen).")


In [None]:
# Per-channel statistics CLIP was trained with. The frozen vision encoder
# expects inputs normalized with these, not with the ImageNet mean/std.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Resize to the 224x224 input used by ViT-B/16, convert to a [0, 1] float
# tensor, then apply the CLIP normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(CLIP_MEAN, CLIP_STD)
])

trainset = torchvision.datasets.Flowers102(
    root="./data",
    split="train",
    download=True,
    transform=transform
)

testset = torchvision.datasets.Flowers102(
    root="./data",
    split="test",
    download=True,
    transform=transform
)

# Class id of every training sample, used to build the per-client split.
# NOTE(review): `_labels` is a private torchvision attribute — confirm it is
# still present when upgrading torchvision.
labels = np.array(trainset._labels)
print("Train samples:", len(trainset))
print("Test samples:", len(testset))


In [None]:
# Label-skewed (non-IID) split: every client owns a disjoint, contiguous
# group of classes and receives all training samples of those classes.
num_classes = int(labels.max()) + 1
idx_by_class = {c: np.where(labels == c)[0] for c in range(num_classes)}

# Nominal group size. When num_classes is not a multiple of NUM_CLIENTS,
# np.array_split hands the leftover classes to the first clients (one extra
# each) instead of silently dropping them.
classes_per_client = num_classes // NUM_CLIENTS
class_groups = np.array_split(np.arange(num_classes), NUM_CLIENTS)

client_indices = []
for selected_classes in class_groups:
    indices = []
    for c in selected_classes:
        # .tolist() yields plain Python ints for torch.utils.data.Subset.
        indices.extend(idx_by_class[int(c)].tolist())

    client_indices.append(indices)

# Every training sample must be owned by exactly one client.
assert sum(len(ix) for ix in client_indices) == len(labels), "samples lost in client split"

print("Example: Client 0 samples =", len(client_indices[0]))


In [None]:
class PromptLearner(nn.Module):
    """Holds the two learnable prompt matrices: a global and a local one.

    Both parameters have shape ``(prompt_len, embed_dim)`` and are drawn from
    a standard normal distribution, the global prompt first.

    Args:
        prompt_len: Number of prompt rows. Defaults to ``PROMPT_LEN``.
        embed_dim: Width of each prompt row. Defaults to ``EMBED_DIM``.
    """

    def __init__(self, prompt_len=None, embed_dim=None):
        super().__init__()
        # Resolve the defaults at call time so edits to the config cell are
        # picked up without re-running this cell.
        if prompt_len is None:
            prompt_len = PROMPT_LEN
        if embed_dim is None:
            embed_dim = EMBED_DIM

        self.global_prompt = nn.Parameter(torch.randn(prompt_len, embed_dim))
        self.local_prompt  = nn.Parameter(torch.randn(prompt_len, embed_dim))

    def full_prompt(self):
        """Return the element-wise sum of the global and local prompts."""
        return self.global_prompt + self.local_prompt


In [None]:
def low_rank_factorization(mat, rank=None):
    """Randomized low-rank split of a 2-D tensor: ``mat == u @ v + residual``.

    A random orthonormal basis is refined with one power-iteration step and
    ``mat`` is projected onto it.

    Args:
        mat: 2-D tensor of shape ``(m, n)``.
        rank: Target rank. Defaults to the module-level ``RANK``, looked up at
            call time.

    Returns:
        Tuple ``(u, v, residual)`` where ``u`` is ``(m, k)``, ``v`` is
        ``(k, n)`` with orthonormal rows, ``residual`` is ``mat - u @ v`` and
        ``k = min(rank, n)``.
    """
    if rank is None:
        # Late binding: a changed RANK in the config cell takes effect
        # without having to re-run this definition.
        rank = RANK

    # Draw the test matrix with mat's dtype and device; a default float32 CPU
    # tensor would make the matmuls below fail for float64 or GPU inputs.
    Q = torch.randn(mat.size(1), rank, device=mat.device, dtype=mat.dtype)
    Q = torch.linalg.qr(Q).Q

    # 1-step power iteration
    Z = mat.T @ (mat @ Q)
    Q = torch.linalg.qr(Z).Q

    u = mat @ Q
    v = Q.T
    residual = mat - u @ v
    return u, v, residual


In [None]:
def compute_sigma(sensitivity, epsilon, delta):
    """
    Noise scale of the classical Gaussian mechanism:
    σ = sqrt(2 * log(1.25/delta)) * sensitivity / epsilon

    Args:
        sensitivity: L2 sensitivity of the released quantity (>= 0).
        epsilon: Privacy budget ε (> 0). The classical analysis of this
            formula covers 0 < ε < 1.
        delta: Failure probability δ, strictly between 0 and 1.

    Returns:
        Standard deviation of the Gaussian noise to add.

    Raises:
        ValueError: If any argument is outside its valid range. Without this
            check a bad ε/δ surfaces as ZeroDivisionError or a math domain
            error, or (for δ = 1.25) as σ = 0, i.e. no noise at all.
    """
    if sensitivity < 0:
        raise ValueError(f"sensitivity must be non-negative, got {sensitivity}")
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")
    if not 0 < delta < 1:
        raise ValueError(f"delta must lie in (0, 1), got {delta}")

    return math.sqrt(2 * math.log(1.25/delta)) * sensitivity / epsilon


In [None]:
class Client:
    """One federated participant with a private copy of the prompt learner.

    Args:
        cid: Integer client id.
        indices: Indices into ``trainset`` that belong to this client.
        global_model: Prompt learner to copy as this client's starting point.
    """

    def __init__(self, cid, indices, global_model):
        self.cid = cid
        self.indices = indices
        # Deep copy so local updates never alias the server's parameters.
        self.model = copy.deepcopy(global_model)

    def get_loader(self):
        """Return a shuffled DataLoader over this client's slice of ``trainset``."""
        ds = Subset(trainset, self.indices)
        return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

    def local_train(self, sigma_l):
        """Run ``LOCAL_EPOCHS`` epochs of noisy SGD on the client's prompts.

        After each backward pass the local-prompt gradient is clipped to an
        L2 norm of ``CLIP_THRESHOLD`` and perturbed with Gaussian noise.

        Args:
            sigma_l: Standard deviation of the noise added to the clipped
                local-prompt gradient.

        Returns:
            Mean training loss over all processed batches, or ``nan`` when
            the client has no data.
        """
        loader = self.get_loader()
        opt = torch.optim.SGD(self.model.parameters(), lr=LR_LOCAL)

        losses = []

        for _ in range(LOCAL_EPOCHS):
            for img, label in loader:
                # The dataset transform already resized and normalized the
                # images; running the HF processor on them again would
                # rescale and normalize a second time and corrupt the input.
                pixel_values = img

                opt.zero_grad()

                # low-rank factorization
                # (u @ v + r reproduces local_prompt by construction)
                u, v, r = low_rank_factorization(self.model.local_prompt)
                p_local = u @ v + r

                # encode text
                # NOTE(review): the first positional argument of CLIP's
                # text_model appears to be `input_ids` (integer token ids),
                # but a float prompt matrix is passed here. This likely needs
                # an embedding-level entry point and per-class prompts — confirm.
                text_features = hf_clip.text_model(
                    p_local.unsqueeze(0)
                ).last_hidden_state.mean(1)

                # encode images
                # NOTE(review): pooler_output is not passed through
                # hf_clip.visual_projection — verify that its width matches
                # text_features before the matmul below.
                img_features = hf_clip.vision_model(pixel_values).pooler_output

                # NOTE(review): logits has a leading dim of 1 (a single prompt)
                # while `label` holds one class id per image; cross_entropy
                # expects (batch, num_classes). Presumably one text feature
                # per class is intended — TODO confirm.
                logits = (text_features @ img_features.T) / 0.07
                loss = F.cross_entropy(logits, label)
                loss.backward()

                # LDP clipping + noise on the local prompt.
                # The Gaussian mechanism is calibrated to an L2 sensitivity,
                # so bound the gradient's L2 norm; an element-wise clamp only
                # bounds each coordinate.
                # NOTE(review): a sensitivity of CLIP_THRESHOLD / BATCH_SIZE
                # normally assumes per-example clipping; this clips the
                # batch gradient — confirm against the intended analysis.
                local_prompt = self.model.local_prompt
                if local_prompt.grad is not None:
                    nn.utils.clip_grad_norm_([local_prompt], max_norm=CLIP_THRESHOLD)
                    local_prompt.grad.add_(torch.randn_like(local_prompt.grad) * sigma_l)

                opt.step()
                losses.append(loss.item())

        # np.mean([]) would emit a RuntimeWarning; report nan explicitly.
        return np.mean(losses) if losses else float("nan")

    def get_global_grad(self, sigma_g):
        """Return this client's contribution to the global-prompt update.

        NOTE(review): the returned tensor is pure Gaussian noise with standard
        deviation ``sigma_g`` and does not depend on the training data;
        ``global_prompt`` is not used in the forward pass of ``local_train``,
        so it has no gradient to report. Presumably a placeholder — confirm.
        """
        # Global DP: return Gaussian noise gradient
        return torch.randn_like(self.model.global_prompt) * sigma_g


In [None]:
class Server:
    """Coordinator that owns the global prompt learner and all clients."""

    def __init__(self):
        self.global_model = PromptLearner()
        # Each client is constructed from the same global model.
        self.clients = []
        for cid in range(NUM_CLIENTS):
            self.clients.append(Client(cid, client_indices[cid], self.global_model))

    def aggregate(self, grads, sigma_g):
        """Average the client tensors, add noise, and step the global prompt.

        Args:
            grads: List of equally shaped tensors, one per client.
            sigma_g: Standard deviation of the Gaussian noise added to the
                averaged tensor.
        """
        averaged = torch.stack(grads).mean(dim=0)
        # Server-side Gaussian perturbation of the mean update.
        perturbed = averaged + torch.randn_like(averaged) * sigma_g

        # Plain gradient-descent step, kept out of the autograd graph.
        with torch.no_grad():
            self.global_model.global_prompt -= LR_GLOBAL * perturbed

    def run(self, epsilon):
        """Train for ``ROUNDS`` federated rounds under privacy budget ``epsilon``.

        Each round, every client receives the current global prompt, trains
        locally, and returns a tensor for the global update; the server then
        aggregates those tensors.
        """
        print(f"\n=== Training with ε = {epsilon} ===")

        # Noise scales derived from the local and global sensitivities.
        local_sensitivity = CLIP_THRESHOLD / BATCH_SIZE
        global_sensitivity = CLIP_THRESHOLD / (NUM_CLIENTS * BATCH_SIZE)
        sigma_l = compute_sigma(local_sensitivity, epsilon, DELTA)
        sigma_g = compute_sigma(global_sensitivity, epsilon, DELTA)

        print(f"sigma_L={sigma_l:.4f}, sigma_G={sigma_g:.4f}")

        for round_no in range(1, ROUNDS + 1):
            print(f"\n--- Round {round_no}/{ROUNDS} ---")

            round_grads = []
            round_losses = []

            for client in tqdm(self.clients):
                # Broadcast: overwrite the client's global prompt with a copy
                # of the server's current one.
                client.model.global_prompt.data = self.global_model.global_prompt.data.clone()

                round_losses.append(client.local_train(sigma_l))
                round_grads.append(client.get_global_grad(sigma_g))

            self.aggregate(round_grads, sigma_g)
            print("Avg local loss:", np.mean(round_losses))


In [None]:
# Train an independent model for every privacy budget. Reusing one Server
# across the loop would start each run from the previous run's weights, so
# the results for different ε values would not be comparable.
servers = {}
for eps in EPSILONS:
    server = Server()
    server.run(eps)
    servers[eps] = server

# `server` stays bound to the run with the last ε in EPSILONS.


In [None]:
def evaluate_client(client):
    """Compute a client's accuracy on the full ``testset``.

    Args:
        client: Object whose ``model`` provides ``eval()`` and ``full_prompt()``.

    Returns:
        Fraction of test samples whose prediction equals the label, or
        ``0.0`` when the test loader yields no samples.
    """
    client.model.eval()
    loader = DataLoader(testset, batch_size=32)

    correct = 0
    total = 0

    # Evaluation needs no gradients; computing the text features inside
    # no_grad avoids building an autograd graph through the text encoder.
    with torch.no_grad():
        prompt = client.model.full_prompt()
        # NOTE(review): a float prompt matrix is passed where CLIP's
        # text_model appears to expect integer `input_ids` — confirm the
        # intended embedding-level entry point.
        text_features = hf_clip.text_model(
            prompt.unsqueeze(0)
        ).last_hidden_state.mean(1)

        for img, label in loader:
            # The dataset transform already resized and normalized the
            # images; the HF processor must not be applied a second time.
            pixel_values = img
            img_features = hf_clip.vision_model(pixel_values).pooler_output

            logits = (text_features @ img_features.T) / 0.07
            # NOTE(review): dim 0 of logits has size 1 (a single prompt), so
            # this argmax is always 0. Presumably one text feature per class
            # and an argmax over classes is intended — TODO confirm.
            pred = logits.argmax(dim=0)

            correct += (pred == label).sum().item()
            total += len(label)

    # Guard against an empty test loader.
    return correct / total if total else 0.0

# Report the test accuracy of every client of the current `server`.
for cid, client in enumerate(server.clients):
    accuracy = evaluate_client(client)
    print(f"Client {cid} Accuracy = {accuracy:.4f}")
