📄 Paper | 📑 Supplemental | 📖 arXiv
If you find this work useful, please give us a star ⭐ on GitHub for the latest updates.
Unofficial PyTorch implementation. The code is adapted from the supplementary material provided with the paper submission.
- [2025.05] 🎉 Unofficial implementation released. Welcome to watch 👀 this repository for the latest updates.
- [2025.01] 🏆 FedMuscle accepted at ICLR 2026!
FedMuscle tackles the hard setting where clients have different model architectures and different tasks (image classification, multi-label classification, semantic segmentation, text classification). Unlike FedAvg-style methods, there is no shared model to aggregate — instead, clients learn a shared representation space.
The Muscle loss is a contrastive objective that aligns client representations on a small unlabeled public dataset, enabling knowledge transfer without sharing raw data or model weights:
```
Local Task Training ──► Generate Reps on Public Data ──► Server Aggregates ──► Muscle Loss Update
        ▲                                                                              │
        └───────────────────────────────── next round ◄────────────────────────────────┘
```
Each client independently fine-tunes a different pretrained model (ViT variants, SegFormer, BERT, DistilBERT) using LoRA adapters. The server never sees model weights — only representations of the public dataset.
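The exact adapter wiring lives in the client code (selected via `--adapter_method`), but the low-rank idea behind LoRA can be sketched in plain PyTorch. `LoRALinear`, `rank`, and `alpha` below are illustrative names, not the repo's API:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen base linear layer plus a trainable low-rank update (sketch only)."""

    def __init__(self, base: nn.Linear, rank: int = 16, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():   # backbone weights stay frozen
            p.requires_grad_(False)
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero-init: no update at start
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(768, 768), rank=16)
out = layer(torch.randn(4, 768))           # shape (4, 768)
```

Only `A` and `B` receive gradients, so each client fine-tunes a small fraction of its backbone's parameters.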
The Muscle loss uses cosine similarity between multi-client representations to compute per-sample weighting coefficients (α). Only representations of a small public dataset batch are communicated per round.
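To make the communication savings concrete, here is a back-of-envelope comparison. The public-batch size of 512 is an assumption for illustration (the projection dim 256 matches the `--proj_dim` value used in the example run below):

```python
# Back-of-envelope per-round upload, float32 (sizes assumed for illustration)
B, d = 512, 256                      # public-batch size (assumed) x projection dim
rep_bytes = B * d * 4                # representations sent by one client per round
vit_base_bytes = 86_000_000 * 4      # ~86M params if full ViT-Base weights were sent instead

print(f"representations: {rep_bytes / 1e6:.2f} MB")       # 0.52 MB
print(f"ViT-Base weights: {vit_base_bytes / 1e6:.0f} MB") # 344 MB
```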
The key function that computes and applies the Muscle loss on the client side:

```python
from itertools import combinations

import torch
from torch.nn import functional as F


def muscle_loss(rep_matrices, neg_expr, tau, device):
    """Compute the Muscle contrastive loss.

    Args:
        rep_matrices: list of M representation tensors, each of shape (B, d).
            rep_matrices[0] is the anchor (current client, requires grad).
        neg_expr: server-pre-computed α tensor (see compute_alpha),
            shape (B,) * M.
        tau: τ^(M+1), the anchor temperature (--tau in args).
        device: torch device.

    Returns:
        Scalar loss.
    """
    M = len(rep_matrices)
    B = rep_matrices[0].shape[0]

    # Add anchor-vs-others similarity terms to neg_expr.
    for i, j in combinations(range(M), 2):
        if i == 0:
            sim = F.cosine_similarity(
                rep_matrices[i][:, None, :],
                rep_matrices[j][None, :, :],
                dim=-1,
            )
            # Broadcast the (B, B) pair similarity into the (B,) * M grid.
            for dim in range(M):
                if dim != i and dim != j:
                    sim = sim.unsqueeze(dim)
            neg_expr = neg_expr + (1.0 / tau) * sim

    # Positives live on the main diagonal of the (B,) * M grid.
    mask = torch.zeros((B,) * M, device=device)
    mask.fill_diagonal_(1)

    log_prob = neg_expr - torch.log(
        torch.exp(neg_expr).sum(tuple(range(1, M)), keepdim=True)
    )
    loss = -(mask * log_prob).sum(tuple(range(1, M))).mean()
    return loss
```

The server-side helper that pre-computes the α tensor from the non-anchor pair similarities:

```python
def compute_alpha(rep_matrices, tau, tau_prime, device):
    """Pre-compute the server-side α tensor for the Muscle loss.

    τ^(M+1) = tau        (anchor temperature, --tau)
    τ^M     = tau_prime  (non-anchor temperature, --tau_prime)
    γ = 1/τ^M − 1/τ^(M+1) controls how strongly non-anchor pairs
    are down-weighted.
    """
    M = len(rep_matrices)
    B = rep_matrices[0].shape[0]
    gamma = 1.0 / tau_prime - 1.0 / tau

    neg_expr = torch.zeros((B,) * M, device=device)
    for i, j in combinations(range(M), 2):
        if i != 0:
            sim = F.cosine_similarity(
                rep_matrices[i][:, None, :],
                rep_matrices[j][None, :, :],
                dim=-1,
            )
            for dim in range(M):
                if dim != i and dim != j:
                    sim = sim.unsqueeze(dim)
            neg_expr += (-gamma) * sim
    return neg_expr
```

The example below shows one alignment step with three encoders on the same machine: a ViT (image classification), a SegFormer (segmentation), and a BERT (text).
```python
import torch
import torch.nn.functional as F

B, d = 32, 256
tau = 0.20        # τ^(M+1): anchor temperature
tau_prime = 0.15  # τ^M: non-anchor temperature
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── Step 1 · Freeze all encoders and collect representations ─────────────────
with torch.no_grad():
    rep_A = F.normalize(model_A(batch, projection=True), dim=-1)  # (B, d), ViT
    rep_B = F.normalize(model_B(batch, projection=True), dim=-1)  # (B, d), SegFormer
    rep_C = F.normalize(model_C(batch, projection=True), dim=-1)  # (B, d), BERT

# ── Step 2 · Compute α for each encoder (always put the anchor at index 0) ───
# For M=3, compute_alpha only sees pairs among the *non-anchor* encoders:
#   alpha_A encodes the B↔C similarity
#   alpha_B encodes the A↔C similarity
#   alpha_C encodes the A↔B similarity
alpha_A = compute_alpha([rep_A, rep_B, rep_C], tau, tau_prime, device)  # (B, B, B)
alpha_B = compute_alpha([rep_B, rep_A, rep_C], tau, tau_prime, device)
alpha_C = compute_alpha([rep_C, rep_A, rep_B], tau, tau_prime, device)

# ── Step 3 · Update each encoder with its Muscle loss ────────────────────────
# Encoder A: re-encode with grad; frozen reps of B and C stay from Step 1
rep_A_live = F.normalize(model_A(batch, projection=True), dim=-1)
loss_A = muscle_loss([rep_A_live, rep_B, rep_C], alpha_A, tau, device)
optimizer_A.zero_grad(); loss_A.backward(); optimizer_A.step()

# Encoder B
rep_B_live = F.normalize(model_B(batch, projection=True), dim=-1)
loss_B = muscle_loss([rep_B_live, rep_A, rep_C], alpha_B, tau, device)
optimizer_B.zero_grad(); loss_B.backward(); optimizer_B.step()

# Encoder C
rep_C_live = F.normalize(model_C(batch, projection=True), dim=-1)
loss_C = muscle_loss([rep_C_live, rep_A, rep_B], alpha_C, tau, device)
optimizer_C.zero_grad(); loss_C.backward(); optimizer_C.step()
```

Key invariants:
- `rep_matrices[0]` is the anchor encoder and must carry gradients; all other entries come from the frozen Step 1 tensors.
- `compute_alpha` and `muscle_loss` must use the same ordering: whichever encoder is at index 0 in `compute_alpha` must also be at index 0 in `muscle_loss`.
- When M = 2, `compute_alpha` returns all zeros (no non-anchor pairs) and `muscle_loss` reduces to standard InfoNCE.
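The M = 2 reduction can be checked with a standalone snippet. This inlines the two-client special case (zero α, a single anchor pair) rather than importing the functions above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d, tau = 8, 16, 0.2
anchor = F.normalize(torch.randn(B, d), dim=-1)
other = F.normalize(torch.randn(B, d), dim=-1)

# M = 2: alpha (neg_expr) is all zeros, so only the anchor-pair term remains.
sim = anchor @ other.T                                # (B, B) cosine similarities
neg_expr = (1.0 / tau) * sim
log_prob = neg_expr - torch.logsumexp(neg_expr, dim=1, keepdim=True)
muscle = -log_prob.diagonal().mean()                  # positives on the diagonal

# Standard InfoNCE: cross-entropy over logits with matching positives
infonce = F.cross_entropy(sim / tau, torch.arange(B))

print(torch.allclose(muscle, infonce, atol=1e-6))     # True
```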
Requirements: Python 3.10+, PyTorch 2.1+, CUDA GPU
```shell
git clone https://github.com/mahdibeit/FedMuscle
cd FedMuscle
python -m venv .venv
source .venv/bin/activate  # or .venv\Scripts\Activate.ps1 on Windows
pip install -r requirements.txt
```

Or, with conda:

```shell
conda create -n fedmuscle python=3.10 -y
conda activate fedmuscle
pip install -r requirements.txt
```

COCO / segmentation tasks only: install `pycocotools` separately with `pip install pycocotools`.
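A quick sanity check that the environment matches the requirements above:

```python
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```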
FedMuscle expects pre-partitioned client data. The default structure under --data_dir:
```
data_dir/
├── MLC_setup/
│   ├── train/           client_0.json, client_1.json, client_2.json (COCO annotations)
│   ├── test/            client_0.json, ...
│   └── public_Pascal/   unlabeled Pascal VOC images (public dataset)
├── semseg_setup/
│   ├── train/           client_0.json, ... (COCO-format segmentation annotations)
│   └── test/            client_0.json, ...
├── IC/
│   ├── CIFAR100/
│   │   ├── train/       client_0.pkl, client_1.pkl (dict with "data" and "labels")
│   │   ├── test/        client_0.pkl, ...
│   │   └── public/      unlabeled CIFAR-100 images (if using CIFAR100 public dataset)
│   └── CIFAR10/
│       ├── train/       client_0.pkl
│       ├── test/        client_0.pkl
│       └── public/      unlabeled CIFAR-10 images (if using CIFAR10 public dataset)
└── TC/
    ├── yahoo_qa/
    │   ├── train/       client_0.pkl (dict with "texts" and "labels")
    │   └── test/        client_0.pkl
    └── flicker_public/
        ├── images/      unlabeled Flickr images
        └── image_file_name_to_caption.pkl
```
Run the provided helper script to create this directory skeleton in one step:
```shell
python scripts/setup_data_dirs.py --data_dir /path/to/data/
```

After the directories exist, download and place each dataset's files according to the naming convention above.
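As an illustration of the pickle convention described above, this writes and reads back a dummy CIFAR-100 client shard. Only the `"data"`/`"labels"` keys come from the layout; the array shapes and the `fedmuscle_data_demo` path are assumptions:

```python
import pickle
from pathlib import Path

import numpy as np

# Illustrative dummy shard matching the "dict with 'data' and 'labels'" convention.
data_dir = Path("fedmuscle_data_demo")               # stand-in for --data_dir
train_dir = data_dir / "IC" / "CIFAR100" / "train"
train_dir.mkdir(parents=True, exist_ok=True)

shard = {
    "data": np.random.randint(0, 256, size=(100, 32, 32, 3), dtype=np.uint8),  # HWC assumed
    "labels": np.random.randint(0, 100, size=(100,)).tolist(),
}
with open(train_dir / "client_0.pkl", "wb") as f:
    pickle.dump(shard, f)

with open(train_dir / "client_0.pkl", "rb") as f:
    loaded = pickle.load(f)
print(sorted(loaded.keys()))                         # ['data', 'labels']
```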
Download models to --models_path with the following directory layout:
```
models_path/
├── google/
│   ├── vit-base-patch32-224-in21k/
│   ├── vit-base-patch16-224-in21k/
│   └── vit-large-patch16-224-in21k/
├── WinKawaks/
│   ├── vit-small-patch16-224/
│   └── vit-tiny-patch16-224/
├── nvidia/
│   ├── segformer-b0-finetuned-ade-512-512/
│   └── segformer-b1-finetuned-ade-512-512/
├── googlebert/
│   └── bert-base-uncased/
└── distilbert/
    └── distilbert-base-uncased/
```
Download with transformers:
```python
from transformers import AutoModel

AutoModel.from_pretrained(
    "google/vit-base-patch32-224-in21k",
    cache_dir="/path/to/models_path/google/",
)
```

```
FedMuscle/
├── fedmuscle/                         # Main package
│   ├── clients/                       # Task-specific client implementations
│   │   ├── base.py                    # Client dataclass
│   │   ├── image_classification.py
│   │   ├── multi_label_classification.py
│   │   ├── semantic_segmentation.py
│   │   ├── text_classification.py
│   │   └── text_generation.py
│   ├── methods/                       # Federated learning methods
│   │   ├── fedmuscle.py               # Muscle loss, aggregation, rep alignment
│   │   └── local_training.py          # Standard local SGD
│   └── utils/
│       ├── args.py                    # Argument parsing
│       ├── data.py                    # Dataset loading and public dataset setup
│       ├── experiment.py              # Client factory / experiment configuration
│       ├── log.py                     # CSV metric logging
│       └── model_ops.py               # Trainable parameter counting
├── main.py                            # Training entry point
└── scripts/
    ├── run_fedmuscle.sh               # Example end-to-end run
    └── setup_data_dirs.py             # Create the data directory skeleton
```
```shell
python main.py \
    --aggregation_method fedmuscle \
    --adapter_method lora \
    --rank 16 \
    --proj_dim 256 \
    --tau 0.2 \
    --tau_prime 0.15 \
    --public_dataset_name pascal \
    --data_dir /path/to/data/ \
    --root_train_image_folder /path/to/train/images/ \
    --root_val_image_folder /path/to/val/images/ \
    --models_path /path/to/hf_models/ \
    --num_MLC_clients 3 \
    --num_IC100_clients 2 \
    --num_IC10_clients 1 \
    --num_epochs 150 \
    --lr 0.001 \
    --device 0 \
    --seed 1
```

Or use the provided script:

```shell
bash scripts/run_fedmuscle.sh
```

To add text-classification clients with the multimodal Flickr public dataset:

```shell
python main.py \
    --public_dataset_name flicker_multi_modal \
    --num_MLC_clients 3 \
    --num_IC100_clients 2 \
    --num_IC10_clients 1 \
    --num_yahoo_topic_classification_clients 2 \
    --num_epochs 150 \
    --lr 0.001 \
    --device 0 \
    --seed 1
    # ... (other args as above)
```

Results are saved under `results/` with per-client CSV logs for each experiment seed.
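The CSV column names depend on the task, so a generic way to enumerate the logs is to walk the results tree and print each file's header. `list_result_logs` is a helper written here for illustration, not part of the repo:

```python
import csv
from pathlib import Path

def list_result_logs(root="results"):
    """Return (path, header) for every CSV log under root (helper, not repo API)."""
    logs = []
    for path in sorted(Path(root).glob("**/*.csv")):
        with open(path, newline="") as f:
            header = next(csv.reader(f), None)   # first row = column names
        logs.append((path, header))
    return logs

for path, header in list_result_logs():
    print(path, header)
```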
If you use this implementation, please cite the original paper:
```bibtex
@inproceedings{setayesh2026fedmuscle,
  title={Toward Enhancing Representation Learning in Federated Multi-Task Settings},
  author={Setayesh, Mehdi and Beitollahi, Mahdi and Khalil, Yasser H. and Li, Hongliang},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=nIOIfHHYzk}
}
```

This is an unofficial implementation provided for the research community.