mahdibeit/FedMuscle

📄 Paper  |  📑 Supplemental  |  📖 arXiv

If you find this work useful, please give us a star ⭐ on GitHub.


Unofficial PyTorch implementation. The code is adapted from the supplementary material provided with the paper submission.


📰 News

  • [2025.05] 🎉 Unofficial implementation released. Feel free to watch 👀 this repository for the latest updates.
  • [2025.01] 🏆 FedMuscle accepted at ICLR 2026!

🌟 Highlights

💡 Federated multi-task learning with heterogeneous models

FedMuscle tackles the hard setting where clients have different model architectures and different tasks (image classification, multi-label classification, semantic segmentation, text classification). Unlike FedAvg-style methods, there is no shared model to aggregate — instead, clients learn a shared representation space.

The Muscle loss is a contrastive objective that aligns client representations on a small unlabeled public dataset, enabling knowledge transfer without sharing raw data or model weights:

Local Task Training ──► Generate Reps on Public Data ──► Server Aggregates ──► Muscle Loss Update
       ▲                                                                               │
       └───────────────────────────── next round ◄────────────────────────────────────┘
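For orientation, here is a minimal sketch of one such round. train_local_task, encode, and muscle_update are hypothetical client methods used purely for illustration; compute_alpha and muscle_loss are the real functions shown below.

def fedmuscle_round(clients, public_batch, tau, tau_prime, device):
    # 1. Local task training on each client's private, task-specific data.
    for c in clients:
        c.train_local_task()                          # hypothetical client API

    # 2. Every client encodes the same public batch (no labels required).
    reps = [c.encode(public_batch) for c in clients]  # each (B, d), L2-normalized

    # 3 + 4. The server pre-computes α per anchor ordering, then each client
    # takes a Muscle-loss step against the other clients' frozen representations.
    for k, c in enumerate(clients):
        ordering = [reps[k]] + [r for i, r in enumerate(reps) if i != k]
        alpha = compute_alpha(ordering, tau, tau_prime, device)
        c.muscle_update(ordering, alpha, tau)         # hypothetical client API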

🔀 Model and task heterogeneity

Each client independently fine-tunes a different pretrained model (ViT variants, SegFormer, BERT, DistilBERT) using LoRA adapters. The server never sees model weights — only representations of the public dataset.
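For illustration, attaching a rank-16 LoRA adapter (matching --rank 16 in the run commands below) to one client's ViT could look like the following sketch using the Hugging Face peft library; the repository's own adapter wiring may differ, and lora_alpha and target_modules here are illustrative choices.

from transformers import AutoModel
from peft import LoraConfig, get_peft_model

base = AutoModel.from_pretrained("google/vit-base-patch32-224-in21k")
lora_cfg = LoraConfig(
    r=16,                               # LoRA rank, cf. --rank 16
    lora_alpha=32,                      # illustrative scaling choice
    target_modules=["query", "value"],  # ViT attention projections
)
client_model = get_peft_model(base, lora_cfg)
client_model.print_trainable_parameters()  # only LoRA weights are trainable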

📡 Communication-efficient alignment

The Muscle loss uses cosine similarity between multi-client representations to compute per-sample weighting coefficients (α). Only representations of a small public dataset batch are communicated per round.
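To make this concrete: with the batch size and projection dimension used in the example below (B = 32, d = 256) and fp32 representations, each client uploads about 32 KiB per round.

B, d, bytes_per_float = 32, 256, 4
payload_kib = B * d * bytes_per_float / 1024
print(f"{payload_kib:.0f} KiB per client per round")  # 32 KiB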


✅ Takeaway: Core Muscle Loss

The key function that computes and applies the Muscle loss on the client side:

from itertools import combinations
import torch
from torch.nn import functional as F

def muscle_loss(rep_matrices, neg_expr, tau, device):
    """
    Compute the Muscle contrastive loss.

    Args:
        rep_matrices: list of M representation tensors, each of shape (B, d).
                      rep_matrices[0] is the anchor (current client, requires grad).
        neg_expr:     server-pre-computed α tensor (see compute_alpha),
                      shape (B, ..., B) with M dimensions.
        tau:          τ^(M+1), the anchor temperature (--tau in args).
        device:       torch device.

    Returns:
        Scalar loss tensor.
    """
    M = len(rep_matrices)
    B = rep_matrices[0].shape[0]

    # Add the anchor-vs-others similarities to the server-side α term
    for i, j in combinations(range(M), 2):
        if i == 0:
            sim = F.cosine_similarity(
                rep_matrices[i][:, None, :],
                rep_matrices[j][None, :, :],
                dim=-1,
            )
            # Broadcast the (B, B) similarity into the M-dimensional logit tensor
            for dim in range(M):
                if dim != i and dim != j:
                    sim = sim.unsqueeze(dim)
            neg_expr = neg_expr + (1.0 / tau) * sim

    # Positive tuples: the same public sample as seen by every client
    mask = torch.zeros((B,) * M, device=device)
    mask.fill_diagonal_(1)

    # logsumexp over all non-anchor dims (numerically stabler than log(sum(exp)))
    log_prob = neg_expr - torch.logsumexp(neg_expr, dim=tuple(range(1, M)), keepdim=True)
    loss = -(mask * log_prob).sum(tuple(range(1, M))).mean()
    return loss

🪐 How the server-side weight (α) is computed

from itertools import combinations
import torch
from torch.nn import functional as F

def compute_alpha(rep_matrices, tau, tau_prime, device):
    """
    Pre-compute the server-side α tensor for the Muscle loss.

    τ^(M+1) = tau  (anchor temperature, --tau)
    τ^M     = tau_prime  (non-anchor temperature, --tau_prime)
    γ = 1/τ^M − 1/τ^(M+1) controls how strongly non-anchor pairs are down-weighted.
    """
    M = len(rep_matrices)
    B = rep_matrices[0].shape[0]
    gamma = 1.0 / tau_prime - 1.0 / tau

    neg_expr = torch.zeros((B,) * M, device=device)
    for i, j in combinations(range(M), 2):
        if i != 0:
            sim = F.cosine_similarity(
                rep_matrices[i][:, None, :],
                rep_matrices[j][None, :, :],
                dim=-1,
            )
            for dim in range(M):
                if dim != i and dim != j:
                    sim = sim.unsqueeze(dim)
            neg_expr += (-gamma) * sim
    return neg_expr
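With the temperatures used in the examples below (τ = 0.20, τ′ = 0.15), γ works out to about 1.67, so non-anchor agreement is subtracted from the logits:

tau, tau_prime = 0.20, 0.15       # defaults from the run commands below
gamma = 1 / tau_prime - 1 / tau   # ≈ 1.67 (> 0)
# neg_expr accumulates -gamma * sim, so sample tuples on which the non-anchor
# encoders already agree get lower logits, i.e. they act as softer negatives.
print(round(gamma, 2))            # 1.67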

🧩 3-Modality Example (M = 3)

The example below shows one alignment step with 3 encoders on the same machine — a ViT (image classification), SegFormer (segmentation), and BERT (text).

import torch
import torch.nn.functional as F

# model_A / model_B / model_C (each exposing a `projection=True` projection
# head), the public-data `batch`, and the three optimizers are assumed to be
# defined elsewhere.
B, d      = 32, 256
tau       = 0.20   # τ^(M+1): anchor temperature
tau_prime = 0.15   # τ^M:     non-anchor temperature
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── Step 1 · Freeze all encoders and collect representations ─────────────────
with torch.no_grad():
    rep_A = F.normalize(model_A(batch, projection=True), dim=-1)  # (B, d) — ViT
    rep_B = F.normalize(model_B(batch, projection=True), dim=-1)  # (B, d) — SegFormer
    rep_C = F.normalize(model_C(batch, projection=True), dim=-1)  # (B, d) — BERT

# ── Step 2 · Compute α for each encoder (always put the anchor at index 0) ───
# For M=3, compute_alpha only sees pairs among the *non-anchor* encoders:
#   alpha_A encodes the B↔C similarity
#   alpha_B encodes the A↔C similarity
#   alpha_C encodes the A↔B similarity
alpha_A = compute_alpha([rep_A, rep_B, rep_C], tau, tau_prime, device)  # (B, B, B)
alpha_B = compute_alpha([rep_B, rep_A, rep_C], tau, tau_prime, device)
alpha_C = compute_alpha([rep_C, rep_A, rep_B], tau, tau_prime, device)

# ── Step 3 · Update each encoder with its Muscle loss ────────────────────────
# Encoder A: re-encode with grad; frozen reps of B and C stay from Step 1
rep_A_live = F.normalize(model_A(batch, projection=True), dim=-1)
loss_A = muscle_loss([rep_A_live, rep_B, rep_C], alpha_A, tau, device)
optimizer_A.zero_grad(); loss_A.backward(); optimizer_A.step()

# Encoder B
rep_B_live = F.normalize(model_B(batch, projection=True), dim=-1)
loss_B = muscle_loss([rep_B_live, rep_A, rep_C], alpha_B, tau, device)
optimizer_B.zero_grad(); loss_B.backward(); optimizer_B.step()

# Encoder C
rep_C_live = F.normalize(model_C(batch, projection=True), dim=-1)
loss_C = muscle_loss([rep_C_live, rep_A, rep_B], alpha_C, tau, device)
optimizer_C.zero_grad(); loss_C.backward(); optimizer_C.step()

Key invariants:

  • rep_matrices[0] is the anchor encoder and must carry gradients; all other entries come from the frozen Step 1 tensors.
  • compute_alpha and muscle_loss must use the same ordering — whichever encoder is at index 0 in compute_alpha must also be at index 0 in muscle_loss.
  • When M = 2, compute_alpha returns all zeros (no non-anchor pairs) and muscle_loss reduces to standard InfoNCE; the sketch below checks this numerically.
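A quick numerical check of the last invariant, assuming muscle_loss and compute_alpha from above are in scope:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, d, tau, tau_prime = 8, 16, 0.20, 0.15
device = torch.device("cpu")

anchor = F.normalize(torch.randn(B, d), dim=-1)
other  = F.normalize(torch.randn(B, d), dim=-1)

alpha = compute_alpha([anchor, other], tau, tau_prime, device)
assert torch.all(alpha == 0)  # no non-anchor pairs when M = 2

# Standard InfoNCE: cross-entropy over cosine-similarity logits
logits = F.cosine_similarity(anchor[:, None], other[None, :], dim=-1) / tau
infonce = F.cross_entropy(logits, torch.arange(B))
assert torch.allclose(muscle_loss([anchor, other], alpha, tau, device), infonce, atol=1e-5)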

🔧 Environment Setup

Requirements: Python 3.10+, PyTorch 2.1+, CUDA GPU

Option A — pip + venv

git clone https://github.com/mahdibeit/FedMuscle
cd FedMuscle

python -m venv .venv
source .venv/bin/activate        # or .venv\Scripts\Activate.ps1 on Windows

pip install -r requirements.txt

Option B — conda

conda create -n fedmuscle python=3.10 -y
conda activate fedmuscle
pip install -r requirements.txt

For COCO / segmentation tasks only, install pycocotools separately:

pip install pycocotools

📂 Data & Model Setup

Data layout

FedMuscle expects pre-partitioned client data. The default structure under --data_dir:

data_dir/
├── MLC_setup/
│   ├── train/          client_0.json, client_1.json, client_2.json  (COCO annotations)
│   ├── test/           client_0.json, ...
│   └── public_Pascal/  unlabeled Pascal VOC images (public dataset)
├── semseg_setup/
│   ├── train/          client_0.json, ...  (COCO-format segmentation annotations)
│   └── test/           client_0.json, ...
├── IC/
│   ├── CIFAR100/
│   │   ├── train/      client_0.pkl, client_1.pkl  (dict with "data" and "labels")
│   │   ├── test/       client_0.pkl, ...
│   │   └── public/     unlabeled CIFAR-100 images (if using CIFAR100 public dataset)
│   └── CIFAR10/
│       ├── train/      client_0.pkl
│       ├── test/       client_0.pkl
│       └── public/     unlabeled CIFAR-10 images (if using CIFAR10 public dataset)
└── TC/
    ├── yahoo_qa/
    │   ├── train/      client_0.pkl  (dict with "texts" and "labels")
    │   └── test/       client_0.pkl
    └── flicker_public/
        ├── images/     unlabeled Flickr images
        └── image_file_name_to_caption.pkl

Run the provided helper script to create this directory skeleton in one step:

python scripts/setup_data_dirs.py --data_dir /path/to/data/

After the directories exist, download and place each dataset's files according to the naming convention above.
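For reference, a minimal loader for one image-classification shard in this layout might look as follows. This is illustrative only; the repository's actual loaders live in fedmuscle/utils/data.py.

import pickle
from pathlib import Path

def load_ic_shard(data_dir, client_id, split="train"):
    """Load one CIFAR-100 client shard: a dict with "data" and "labels"."""
    path = Path(data_dir) / "IC" / "CIFAR100" / split / f"client_{client_id}.pkl"
    with open(path, "rb") as f:
        shard = pickle.load(f)
    return shard["data"], shard["labels"]

images, labels = load_ic_shard("/path/to/data/", client_id=0)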

HuggingFace models

Download models to --models_path with the following directory layout:

models_path/
├── google/
│   ├── vit-base-patch32-224-in21k/
│   ├── vit-base-patch16-224-in21k/
│   └── vit-large-patch16-224-in21k/
├── WinKawaks/
│   ├── vit-small-patch16-224/
│   └── vit-tiny-patch16-224/
├── nvidia/
│   ├── segformer-b0-finetuned-ade-512-512/
│   └── segformer-b1-finetuned-ade-512-512/
├── google-bert/
│   └── bert-base-uncased/
└── distilbert/
    └── distilbert-base-uncased/

Download with transformers; use save_pretrained so the on-disk layout matches the tree above (passing cache_dir would create the Hub cache structure instead):

from transformers import AutoModel

model = AutoModel.from_pretrained("google/vit-base-patch32-224-in21k")
model.save_pretrained("/path/to/models_path/google/vit-base-patch32-224-in21k")

📁 Repository Structure

FedMuscle/
├── fedmuscle/                    # Main package
│   ├── clients/                  # Task-specific client implementations
│   │   ├── base.py               # Client dataclass
│   │   ├── image_classification.py
│   │   ├── multi_label_classification.py
│   │   ├── semantic_segmentation.py
│   │   ├── text_classification.py
│   │   └── text_generation.py
│   ├── methods/                  # Federated learning methods
│   │   ├── fedmuscle.py          # Muscle loss, aggregation, rep alignment
│   │   └── local_training.py     # Standard local SGD
│   └── utils/
│       ├── args.py               # Argument parsing
│       ├── data.py               # Dataset loading and public dataset setup
│       ├── experiment.py         # Client factory / experiment configuration
│       ├── log.py                # CSV metric logging
│       └── model_ops.py          # Trainable parameter counting
├── main.py                       # Training entry point
└── scripts/
    ├── run_fedmuscle.sh          # Example end-to-end run
    └── setup_data_dirs.py        # Create the data directory skeleton

🚀 Running Experiments

Vision-only (3 MLC + 2 IC100 + 1 IC10 clients, Pascal public dataset)

python main.py \
  --aggregation_method fedmuscle \
  --adapter_method lora \
  --rank 16 \
  --proj_dim 256 \
  --tau 0.2 \
  --tau_prime 0.15 \
  --public_dataset_name pascal \
  --data_dir /path/to/data/ \
  --root_train_image_folder /path/to/train/images/ \
  --root_val_image_folder /path/to/val/images/ \
  --models_path /path/to/hf_models/ \
  --num_MLC_clients 3 \
  --num_IC100_clients 2 \
  --num_IC10_clients 1 \
  --num_epochs 150 \
  --lr 0.001 \
  --device 0 \
  --seed 1

Or use the provided script:

bash scripts/run_fedmuscle.sh

Mixed vision + language (add Yahoo topic classification)

python main.py \
  --public_dataset_name flicker_multi_modal \
  --num_MLC_clients 3 \
  --num_IC100_clients 2 \
  --num_IC10_clients 1 \
  --num_yahoo_topic_classification_clients 2 \
  --num_epochs 150 \
  --lr 0.001 \
  --device 0 \
  --seed 1
  # ... (other args as above)

Results are saved under results/ with per-client CSV logs for each experiment seed.
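If you want to inspect the logs programmatically, something like the following works; the exact column schema is defined in fedmuscle/utils/log.py and should be checked against one file first.

from pathlib import Path
import pandas as pd

# Gather every per-client CSV that an experiment wrote under results/
for csv_path in sorted(Path("results").rglob("*.csv")):
    df = pd.read_csv(csv_path)
    print(csv_path, "->", df.columns.tolist(), f"({len(df)} rows)")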


📝 Citation

If you use this implementation, please cite the original paper:

@inproceedings{setayesh2026fedmuscle,
  title={Toward Enhancing Representation Learning in Federated Multi-Task Settings},
  author={Setayesh, Mehdi and Beitollahi, Mahdi and Khalil, Yasser H. and Li, Hongliang},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=nIOIfHHYzk}
}



Disclaimer

This is an unofficial implementation provided for the research community. The code is adapted from the supplementary material provided with the paper submission.
