In [1]:
import numpy as np
import torch
import evaluate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, AutoModelForSequenceClassification, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorWithPadding, DataCollatorForLanguageModeling, ViTImageProcessor, ViTForImageClassification
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
from datasets import load_dataset
from torch.nn.functional import cross_entropy  # Assuming classification task
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from datasets import ClassLabel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- Federated fine-tuning setup -------------------------------------------
USERS = 10            # number of simulated clients
ROUNDS = 500          # number of federated rounds
LOCAL_STEPS = 5       # optimizer steps each client takes per round
EVAL_INTERVAL = 10    # evaluate every this many rounds
LORA_RANK = 8         # rank `r` handed to LoraConfig
ONLY_MERGE_ADAPTERS = False  # Whether we only merge LoRA adapters or all trainable parameters

# ---- Data ------------------------------------------------------------------
# Supported datasets per task:
#   text classification:  "SetFit/20_newsgroups", "imdb", "ag_news", "emotion" or "yelp_review_full"
#   image classification: "uoft-cs/cifar10"
#   next-word prediction: "wikitext" (+"wikitext-2-raw-v1" as config)
#   summarization:        "xsum"
DATASET = "ag_news"
DATASET_CONFIG = "wikitext-2-raw-v1"  # only passed to load_dataset for the "prediction" task
TASK = "txt_classification"  # "img_classification", "txt_classification", "prediction", "summarization"
DATASET_DISTRIBUTION = "uniform"  # "uniform" or "dirichlet"
ALPHA = 0.1  # concentration parameter of the Dirichlet split
FT_ALGORITHM = "lora"  # "lora" or "head" (the latter just fine-tunes the classification head)

# ---- Model -----------------------------------------------------------------
# Supported base models per task:
#   text classification:   "roberta-base", "distilbert-base-uncased"
#   image classification:  "google/vit-base-patch16-224"
#   next-token prediction: "gpt2", "distilgpt2"
#   summarization:         "t5-small"
BASE_MODEL = 'roberta-base'

# ---- Device: prefer CUDA, then Apple MPS, then CPU -------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Head-only fine-tuning is not wired up for multiple clients yet.
if USERS > 1 and FT_ALGORITHM == "head":
    raise RuntimeError("Head-only FT can only be done with a single user for now!")

Using device: mps


In [3]:
# Reject unknown tasks up front, then load the dataset; only the next-word
# prediction task passes an explicit dataset configuration name.
if TASK not in ("txt_classification", "img_classification", "summarization", "prediction"):
    raise RuntimeError("Unknown task %s" % TASK)

if TASK == "prediction":
    dataset = load_dataset(DATASET, DATASET_CONFIG, cache_dir="datasets")
else:
    dataset = load_dataset(DATASET, cache_dir="datasets")

In [4]:
# Dataset-specific fix-ups applied before tokenization / feature extraction.
if DATASET == "SetFit/20_newsgroups":
    # We need to do some small transformations
    # Cast the 'label' column to a ClassLabel feature built from the distinct
    # values found in the training split, and drop the 'label_text' column.
    unique_classes = sorted(set(dataset['train']['label']))
    label_feature = ClassLabel(names=unique_classes)
    dataset = dataset.cast_column('label', label_feature)
    dataset = dataset.remove_columns('label_text')
elif DATASET == "uoft-cs/cifar10":
    # Image processor loaded from the same checkpoint as the base model.
    feature_extractor = ViTImageProcessor.from_pretrained(BASE_MODEL)

    def transform(examples):
        """Run the image processor on the 'img' column and attach 'label' as a 'labels' tensor."""
        inputs = feature_extractor(examples['img'], return_tensors='pt')
        inputs['labels'] = torch.tensor(examples['label'])
        return inputs
    
    # Registered via with_transform rather than map; presumably applied lazily
    # on access -- confirm against the datasets documentation.
    dataset = dataset.with_transform(transform)

In [5]:
# Show the training split (its features and number of rows).
print(dataset["train"])

Dataset({
    features: ['text', 'label'],
    num_rows: 120000
})


In [6]:
# Task-specific preprocessing. Every branch produces `processed_dataset`, from
# which the training split is taken at the end of the cell.
if TASK == "txt_classification":
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

    def preprocess(examples):
        """Tokenize a batch of raw texts, truncating over-long inputs."""
        # No padding here: padding inside a batched map pads every row to the
        # longest text of its map batch. The notebook's DataCollatorWithPadding
        # already pads each training/evaluation batch to its own longest row.
        return tokenizer(examples['text'], truncation=True)

    processed_dataset = dataset.map(preprocess, batched=True, remove_columns=["text"])
elif TASK == "prediction":
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

    def preprocess(examples):
        """Tokenize a batch of raw texts without truncation or padding."""
        return tokenizer(examples["text"])

    tokenized_dataset = dataset.map(preprocess, batched=True, num_proc=4, remove_columns=["text"])

    block_size = 128

    def group_texts(examples):
        """Concatenate a batch of tokenized texts and cut it into `block_size` chunks."""
        # Concatenate all texts.
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        # Labels are a copy of the inputs for this task.
        result["labels"] = result["input_ids"].copy()
        return result

    processed_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
        batch_size=1000,
        num_proc=4,
    )
elif TASK == "summarization":
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)

    max_input_length = 1024
    max_target_length = 128
    # T5 checkpoints get a task prefix prepended to every document.
    # ("t5-large" was previously misspelled "t5-larg" and never matched.)
    if BASE_MODEL in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
        prefix = "summarize: "
    else:
        prefix = ""

    def preprocess(examples):
        """Tokenize documents as model inputs and summaries as labels."""
        inputs = [prefix + doc for doc in examples["document"]]
        model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

        # Setup the tokenizer for targets
        labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    processed_dataset = dataset.map(preprocess, num_proc=8, batched=True, remove_columns=["document", "summary", "id"])
else:
    # Remaining tasks use the dataset as prepared by the earlier cells.
    processed_dataset = dataset

train_dataset = processed_dataset['train']

In [7]:
def split_dataset_uniform(dataset, n_users):
    """Shuffle `dataset` (seed 42) and cut it into `n_users` contiguous subsets.

    Each of the first `n_users - 1` subsets holds `len(dataset) // n_users`
    samples; the last subset also receives whatever remainder is left.

    Args:
        dataset: object providing `shuffle(seed=...)`, `select(indices)` and `len()`.
        n_users: number of subsets to produce.

    Returns:
        A list with one subset per user, in user order.
    """
    shuffled = dataset.shuffle(seed=42)
    total = len(shuffled)
    per_user = total // n_users

    # Start offset of every user's slice, closed off by the end of the data so
    # that the last user absorbs the remainder.
    boundaries = [user * per_user for user in range(n_users)] + [total]

    return [
        shuffled.select(range(lo, hi))
        for lo, hi in zip(boundaries, boundaries[1:])
    ]


def split_dataset_dirichlet(dataset, n_users, alpha):
    """Partition `dataset` across users with per-class Dirichlet proportions.

    For every distinct label, the indices of that class are shuffled and cut
    into `n_users` consecutive pieces whose sizes follow one draw from a
    symmetric Dirichlet(`alpha`) distribution.

    The classes are taken from the labels that actually occur, so label ids do
    not have to be the contiguous range 0..C-1. For contiguous labels the
    random draws happen in the same order as before.

    Args:
        dataset: object whose `dataset['label']` yields one label per sample
            and which provides `select(indices)`.
        n_users: number of subsets to produce.
        alpha: Dirichlet concentration parameter.

    Returns:
        A list with one subset per user (a subset may be empty).
    """
    # Convert once; the per-class mask below reuses this array.
    labels = np.asarray(dataset['label'])

    # One index list per user
    user_indices = [[] for _ in range(n_users)]

    # Seed for reproducibility (this seeds NumPy's global generator)
    np.random.seed(42)

    # np.unique returns the distinct labels in sorted order
    for cls in np.unique(labels):
        # Indices of all samples of this class, in random order
        cls_indices = np.where(labels == cls)[0]
        np.random.shuffle(cls_indices)
        num_samples = len(cls_indices)

        # Share of this class that each user receives
        proportions = np.random.dirichlet([alpha] * n_users)
        # Ensure the proportions sum to 1
        proportions = proportions / proportions.sum()

        # Turn cumulative proportions into cut points; the last one is dropped
        # so np.split yields exactly n_users pieces.
        split = (np.cumsum(proportions) * num_samples).astype(int)[:-1]

        for user, indices in enumerate(np.split(cls_indices, split)):
            user_indices[user].extend(indices.tolist())

    # Create datasets for each user
    return [dataset.select(indices) for indices in user_indices]


# Partition the training data across the users according to the configured
# distribution; anything other than the two known modes is rejected first.
if DATASET_DISTRIBUTION not in ("uniform", "dirichlet"):
    raise RuntimeError("Unknown dataset distribution")

if DATASET_DISTRIBUTION == "dirichlet":
    split_datasets = split_dataset_dirichlet(train_dataset, USERS, ALPHA)
else:
    split_datasets = split_dataset_uniform(train_dataset, USERS)

In [8]:
# Report how many samples each user received, then show the first user's subset.
for idx, data in enumerate(split_datasets):
    print("%s: %s samples" % (idx, len(data)))

print(split_datasets[0])

0: 12000 samples
1: 12000 samples
2: 12000 samples
3: 12000 samples
4: 12000 samples
5: 12000 samples
6: 12000 samples
7: 12000 samples
8: 12000 samples
9: 12000 samples
Dataset({
    features: ['label', 'input_ids', 'attention_mask'],
    num_rows: 12000
})


In [17]:
def get_pretrained_model():
    """Load the pretrained base model for the configured TASK and move it to DEVICE.

    Reads the notebook globals TASK, BASE_MODEL, DEVICE and, for text
    classification, `dataset` (to derive the label names).

    Returns:
        The model returned by the matching `from_pretrained` call, after `.to(DEVICE)`.

    Raises:
        RuntimeError: if TASK is not one of the supported tasks.
    """
    if TASK == "txt_classification":
        # Extract the number of classes and their names
        label_feature = dataset['train'].features['label']
        num_labels = label_feature.num_classes
        class_names = label_feature.names
        print(f"number of labels: {num_labels}")
        print(f"the labels: {class_names}")

        # Label mappings for the classifier config. Both directions are passed
        # so that label2id agrees with id2label.
        id2label = {i: label for i, label in enumerate(class_names)}
        label2id = {label: i for i, label in id2label.items()}

        pretrained_model = AutoModelForSequenceClassification.from_pretrained(
            BASE_MODEL, id2label=id2label, label2id=label2id, cache_dir="models"
        )
    elif TASK == "prediction":
        pretrained_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, cache_dir="models")
    elif TASK == "img_classification":
        # NOTE(review): num_labels is hard-coded to 10, which matches the only
        # image dataset listed in the configuration cell -- revisit if more are added.
        pretrained_model = ViTForImageClassification.from_pretrained(BASE_MODEL, num_labels=10, ignore_mismatched_sizes=True, cache_dir="models")
    elif TASK == "summarization":
        pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL, cache_dir="models")
    else:
        # Fail with the same error the other cells use instead of an
        # UnboundLocalError on the return below.
        raise RuntimeError("Unknown task %s" % TASK)
    return pretrained_model.to(DEVICE)

In [18]:
# Cut the processed test split into two shards: shard 0 becomes the evaluation
# set and shard 1 the test set.
eval_dataset = processed_dataset['test'].shard(num_shards=2, index=0)
test_dataset = processed_dataset['test'].shard(num_shards=2, index=1)

# Pick the batch collation strategy for the task.
if TASK == "txt_classification":
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
elif TASK == "prediction":
    def text_collate(examples):
        """Stack the per-example token lists into batch tensors."""
        input_ids = torch.stack([torch.tensor(d["input_ids"]) for d in examples])
        labels = torch.stack([torch.tensor(d["labels"]) for d in examples])
        attention_mask = torch.stack([torch.tensor(d["attention_mask"]) for d in examples])
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

    data_collator = text_collate
elif TASK == "img_classification":
    data_collator = None  # Use the default one
elif TASK == "summarization":
    # The collator's `model` argument is optional and is omitted here: no model
    # object exists yet when this cell runs (the previous `model=model`
    # referenced an undefined name).
    data_collator = DataCollatorForSeq2Seq(tokenizer)
else:
    raise RuntimeError("Unknown task %s" % TASK)

In [19]:
def print_trainable_parameters(model):
    """Print every trainable parameter of `model` plus overall parameter counts.

    Args:
        model: object exposing `named_parameters()`, yielding `(name, param)`
            pairs where `param` has `requires_grad`, `shape` and `numel()`.

    Prints one line per trainable parameter, followed by the total count, the
    trainable count and the trainable percentage (two decimals).
    """
    counts = {"all": 0, "trainable": 0}
    for param_name, tensor in model.named_parameters():
        size = tensor.numel()
        counts["all"] += size
        if not tensor.requires_grad:
            continue
        print(f"Trainable parameter: {param_name}, shape: {tensor.shape}")
        counts["trainable"] += size

    share = 100 * counts["trainable"] / counts["all"]
    print(f"Total parameters: {counts['all']}")
    print(f"Trainable parameters: {counts['trainable']}")
    print(f"Percentage of trainable parameters: {share:.2f}%")

# Build the LoRA configuration for the task. All branches share rank
# LORA_RANK, lora_alpha=16 and lora_dropout=0.1; they differ in task_type and
# in which modules are targeted.
# NOTE(review): there is no `else` branch, so an unsupported TASK leaves
# `peft_config` undefined for the code below.
if TASK == "txt_classification":
    # target_modules=None leaves the module choice to PEFT; DistilBERT gets an
    # explicit list instead.
    target_modules = None
    if BASE_MODEL == "distilbert-base-uncased":
        target_modules = {"q_lin", "v_lin"}
    peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=LORA_RANK, lora_alpha=16, lora_dropout=0.1, target_modules=target_modules)
elif TASK == "img_classification":
    # No task_type is passed for the image model.
    peft_config = LoraConfig(inference_mode=False, r=LORA_RANK, lora_alpha=16, lora_dropout=0.1, target_modules=["attention.query", "attention.key"])
elif TASK == "prediction":
    peft_config = LoraConfig(task_type="CAUSAL_LM", inference_mode=False, r=LORA_RANK, lora_alpha=16, lora_dropout=0.1)
elif TASK == "summarization":
    peft_config = LoraConfig(task_type="SEQ_2_SEQ_LM", inference_mode=False, r=LORA_RANK, lora_alpha=16, lora_dropout=0.1)

if FT_ALGORITHM == "lora":
    # A single base model is wrapped once; every client and the global model
    # are separate named adapters on that one wrapper.
    base_model = get_pretrained_model()
    # NOTE(review): the printed model in the next cell also lists a `default`
    # adapter, which comes from this call and is not used by the client/global
    # logic in this notebook.
    peft_model = get_peft_model(base_model, peft_config).to(DEVICE)

    # Create adapters for each user
    for adapter_name in ["client_%d" % i for i in range(USERS)]:
        # The membership check makes re-running this cell skip adapters that
        # already exist.
        if adapter_name not in peft_model.peft_config:
            peft_model.add_adapter(adapter_name, peft_config)
            print("Adding LoRA adapter %s" % adapter_name)
        # set_adapter is called for every client, so the last client is the
        # one left selected after the loop.
        peft_model.set_adapter(adapter_name)

    # Create a global adapter
    if "global" not in peft_model.peft_config:
        peft_model.add_adapter("global", peft_config)
        print("Adding LoRA adapter global")
elif FT_ALGORITHM == "head":
    # TODO fix this
    # Head-only fine-tuning is not implemented: no model is created here.
    pass
else:
    raise RuntimeError("Unknown FT algorithm %s" % FT_ALGORITHM)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


number of labels: 4
the labels: ['World', 'Sports', 'Business', 'Sci/Tech']
Adding LoRA adapter client_0
Adding LoRA adapter client_1
Adding LoRA adapter client_2
Adding LoRA adapter client_3
Adding LoRA adapter client_4
Adding LoRA adapter client_5
Adding LoRA adapter client_6
Adding LoRA adapter client_7
Adding LoRA adapter client_8
Adding LoRA adapter client_9
Adding LoRA adapter global


In [20]:
# Print the module structure of the PEFT-wrapped model.
print(peft_model)

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): RobertaForSequenceClassification(
      (roberta): RobertaModel(
        (embeddings): RobertaEmbeddings(
          (word_embeddings): Embedding(50265, 768, padding_idx=1)
          (position_embeddings): Embedding(514, 768, padding_idx=1)
          (token_type_embeddings): Embedding(1, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): RobertaEncoder(
          (layer): ModuleList(
            (0-11): 12 x RobertaLayer(
              (attention): RobertaAttention(
                (self): RobertaSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                      (client_0): Dropout(p=0.1, inplac

In [32]:
def aggregate_adapters():
    """Average all client adapters into the 'global' adapter and save it to disk.

    Operates on the notebook globals `peft_model` and `USERS`. Every key
    present in the global adapter state and in all client adapter states is
    replaced by the unweighted mean over the clients; all other keys keep
    their current global value. The result is written back into the 'global'
    adapter and serialized under "adapters/global".

    Raises:
        ValueError: if a client state or the global state comes back empty, or
            if no key is shared between the global and all client states.
    """
    # Gather the state dict of every client adapter.
    client_states = []
    for idx in range(USERS):
        name = f"client_{idx}"
        state = get_peft_model_state_dict(peft_model, adapter_name=name)
        if not state:
            raise ValueError(f"Adapter '{name}' not found on the model.")
        client_states.append(state)

    # The current global state provides the target dtype and the fallback
    # values for keys that are not averaged.
    global_state = get_peft_model_state_dict(peft_model, adapter_name="global")
    if not global_state:
        raise ValueError("Adapter 'global' not found on the model.")

    # Keys that every client shares with the global adapter.
    common_keys = set(global_state.keys()).intersection(
        *[state.keys() for state in client_states]
    )
    if not common_keys:
        raise ValueError("No common LoRA parameter keys across client adapters to aggregate.")

    # Start from the global values, then overwrite each shared key with the
    # client mean, computed on CPU in float32 and cast back to the global dtype.
    agg_state = dict(global_state)
    for key in common_keys:
        per_client = [
            state[key].detach().to("cpu", dtype=torch.float32)
            for state in client_states
        ]
        mean_value = torch.stack(per_client, dim=0).mean(dim=0)
        agg_state[key] = mean_value.to(dtype=global_state[key].dtype)

    # Load the aggregate into the 'global' adapter, then re-apply the device
    # of the model's first parameter to the whole model.
    set_peft_model_state_dict(peft_model, agg_state, adapter_name="global")
    current_device = next(peft_model.parameters()).device
    peft_model.to(current_device)

    # Serialize the global adapter to disk
    peft_model.save_pretrained(
        "adapters/global",
        selected_adapters=["global"],  # save just this adapter
        safe_serialization=True,  # writes adapter_model.safetensors
    )

In [14]:
if TASK == "prediction":
    # Tokenize the whole raw test split as one long sequence; compute_perplexity
    # slides a window over it.
    encodings = tokenizer("\n\n".join(dataset["test"]["text"]), return_tensors="pt")


def compute_perplexity(eval_model):
    """Compute the perplexity of `eval_model` on the tokenized test text.

    Walks over the global `encodings` in windows of `max_length` tokens
    advancing by `stride` tokens, collects the model loss of each window and
    returns `exp(mean(losses))`. `encodings` is only defined when TASK is
    "prediction".

    Args:
        eval_model: model called as `eval_model(input_ids, labels=...)` whose
            output exposes `.loss`; expected to live on DEVICE.

    Returns:
        A scalar tensor holding the perplexity.
    """
    max_length = 512  # eval_model.config.n_positions
    stride = 512
    seq_len = encodings.input_ids.size(1)

    # Evaluation mode is set once up front rather than on every window.
    eval_model.eval()

    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride), desc='Batch', leave=False):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        # Inputs go to the notebook-wide DEVICE (the previous code referenced an
        # undefined lower-case `device`).
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(DEVICE)
        target_ids = input_ids.clone()
        # Only the last trg_len tokens of the window count towards the loss.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            # Score with the model passed in (the previous code called an
            # undefined global `model`).
            outputs = eval_model(input_ids, labels=target_ids)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    return torch.exp(torch.stack(nlls).mean())

# Both classification tasks are scored with the accuracy metric.
if TASK in ("txt_classification", "img_classification"):
    metric = evaluate.load('accuracy')


def evaluate_classification_model(inference_model, dataset):
    """Score `inference_model` on `dataset` with the global `metric`.

    For every task except image classification the 'label' column is renamed
    to 'labels' before batching. The model is moved to DEVICE and put in
    evaluation mode; the arg-max of the logits of each batch is added to
    `metric` together with the batch labels.

    Args:
        inference_model: model whose output exposes `.logits`.
        dataset: dataset to evaluate on.

    Returns:
        The result of `metric.compute()`.
    """
    if TASK != "img_classification":
        loader_input = dataset.rename_column("label", "labels")
    else:
        loader_input = dataset
    eval_dataloader = DataLoader(loader_input, batch_size=512, collate_fn=data_collator)

    inference_model.to(DEVICE)
    inference_model.eval()
    for batch in tqdm(eval_dataloader, desc='Batch', leave=False):
        # Non-tensor entries are dropped; tensors are moved to DEVICE.
        on_device = {
            key: val.to(DEVICE)
            for key, val in batch.items()
            if isinstance(val, torch.Tensor)
        }
        with torch.no_grad():
            outputs = inference_model(**on_device)
        metric.add_batch(
            predictions=outputs.logits.argmax(dim=-1),
            references=on_device["labels"],
        )

    return metric.compute()


def compute_metrics(predictions, labels):
    """Decode predictions and labels, score them with `metric`, and add the mean generated length.

    NOTE(review): `nltk` is not imported in the notebook's import cell, and the
    global `metric` is only assigned for the classification tasks (accuracy),
    while the keyword arguments used below look ROUGE-specific -- confirm the
    intended metric before relying on this function.

    Args:
        predictions: batch of generated token ids.
        labels: batch of reference token ids; -100 entries are replaced by the
            tokenizer's pad token id before decoding.

    Returns:
        Dict of metric values scaled by 100 plus "gen_len", all rounded to
        4 decimals.
    """
    def _one_sentence_per_line(texts):
        # Rouge expects a newline after each sentence
        return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

    pad_id = tokenizer.pad_token_id

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # -100 cannot be decoded, so it is swapped for the pad token id first.
    decodable_labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)

    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    scores = metric.compute(
        predictions=_one_sentence_per_line(decoded_preds),
        references=_one_sentence_per_line(decoded_labels),
        use_stemmer=True,
        use_aggregator=True,
    )
    result = {key: value * 100 for key, value in scores.items()}

    # Mean number of non-pad tokens per prediction
    result["gen_len"] = np.mean(
        [np.count_nonzero(pred != pad_id) for pred in predictions]
    )

    return {k: round(v, 4) for k, v in result.items()}


def evaluate_summarization_model(inference_model, dataset):
    """Generate summaries for `dataset` with `inference_model` and score them with ROUGE.

    The model is moved to the notebook-wide DEVICE and put in evaluation mode.
    Each batch is generated from `input_ids` and `attention_mask`; the
    generated ids and the labels (with -100 replaced by the pad token id) are
    decoded and added to the ROUGE metric.

    Args:
        inference_model: model providing `.generate(...)`.
        dataset: dataset whose batches contain 'input_ids', 'attention_mask'
            and 'labels'.

    Returns:
        The 'rougeL' entry of the computed ROUGE scores (the full score dict
        is printed as well).
    """
    rouge = evaluate.load('rouge')
    eval_dataloader = DataLoader(dataset, batch_size=256, collate_fn=data_collator)

    # Use DEVICE like the rest of the notebook. The previous local cuda/cpu
    # choice ignored "mps" and could move the shared model to another device.
    inference_model.to(DEVICE)
    inference_model.eval()
    for batch in tqdm(eval_dataloader, desc='Batch', leave=False):
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        with torch.no_grad():
            # Generate with the model passed in (the previous code called an
            # undefined global `model`).
            outputs = inference_model.generate(input_ids=input_ids, attention_mask=attention_mask)
        prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # -100 cannot be decoded, so it is swapped for the pad token id first.
        labels = batch['labels'].cpu().numpy()
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        reference = tokenizer.batch_decode(labels, skip_special_tokens=True)
        rouge.add_batch(predictions=prediction, references=reference)

    rouge_score = rouge.compute()
    print(rouge_score)
    return rouge_score['rougeL']

In [15]:
# Federated training: in every round each client takes LOCAL_STEPS optimizer
# steps on its own adapter, then the client adapters are averaged into the
# 'global' adapter, which is evaluated every EVAL_INTERVAL rounds.
# NOTE(review): aggregate_adapters() only writes the 'global' adapter; nothing
# in this loop copies it back into the client adapters, so clients continue
# from their own weights in the next round -- confirm this is intended.
for round_nr in tqdm(range(1, ROUNDS + 1), desc='Training Rounds'):
    for user_idx in tqdm(range(USERS), desc='Users', leave=False):
        if len(split_datasets[user_idx]) == 0:
            # A Dirichlet split can leave a client without samples; there is
            # nothing to train on, so its adapter is left untouched.
            continue

        peft_model.set_adapter("client_%d" % user_idx)
        # A fresh optimizer is created for every client in every round.
        optimizer = torch.optim.AdamW(peft_model.parameters(), lr=5e-5)
        peft_model.train()  # Set the model to training mode
        train_dataloader = DataLoader(split_datasets[user_idx], batch_size=16, shuffle=True, collate_fn=data_collator)
        train_set_it = iter(train_dataloader)

        for local_step in range(LOCAL_STEPS):
            try:
                batch = next(train_set_it)
            except StopIteration:
                # The client has fewer than LOCAL_STEPS batches: start another
                # pass over its data instead of aborting the whole run.
                train_set_it = iter(train_dataloader)
                batch = next(train_set_it)

            optimizer.zero_grad()
            if TASK == "txt_classification" or TASK == "img_classification":
                inputs = {k: v.to(DEVICE) for k, v in batch.items() if k != 'labels'}
                labels = batch['labels'].to(DEVICE)
                outputs = peft_model(**inputs)
                loss = cross_entropy(outputs.logits, labels)  # Calculate loss
            elif TASK == "prediction" or TASK == "summarization":
                input_ids = batch["input_ids"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)
                masks = batch["attention_mask"].to(DEVICE)
                outputs = peft_model(input_ids=input_ids, labels=labels, attention_mask=masks)
                loss = outputs[0]
            else:
                # Without this branch `loss` would be unbound below.
                raise RuntimeError("Unknown task %s" % TASK)

            loss.backward()
            optimizer.step()

    # Aggregate models
    aggregate_adapters()

    if round_nr % EVAL_INTERVAL == 0:
        print("Evaluating model at round %d" % round_nr)
        peft_model.set_adapter("global")  # Switch to the global adapter for evaluation
        if TASK in ["txt_classification", "img_classification"]:
            eval_res = evaluate_classification_model(peft_model, test_dataset)
        elif TASK == "prediction":
            eval_res = compute_perplexity(peft_model)
        elif TASK == "summarization":
            eval_res = evaluate_summarization_model(peft_model, test_dataset)
        print("Round %d: %s" % (round_nr, eval_res))

  return forward_call(*args, **kwargs)
Training Rounds:   0%|          | 2/500 [02:42<11:13:48, 81.18s/it]


KeyboardInterrupt: 

In [None]:
# `peft_models` is not defined anywhere in this notebook; all adapters live on
# the single `peft_model`. Select the first client's adapter as the closest
# equivalent of the old `peft_models[0]`.
# NOTE(review): switch to "global" if the aggregated model was meant.
# compute_perplexity needs `encodings`, which only exists for TASK == "prediction".
peft_model.set_adapter("client_0")
compute_perplexity(peft_model)