In [None]:
import numpy as np
import wandb
from google.colab import drive
# Mount Google Drive under /content/drive (force_remount=True is passed to the Colab helper).
drive.mount('/content/drive', force_remount=True)
import shutil
import os                              # Used right below to change the working directory
# Switch the working directory to the project folder on Drive.
# NOTE(review): presumably needed so the FederatedLearningProject package imports resolve — confirm.
os.chdir('/content/drive/MyDrive/FL')
import datetime as datetime
import copy
import json

In [None]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset, DataLoader, random_split

from FederatedLearningProject.data.cifar100_loader import get_cifar100
from FederatedLearningProject.checkpoints.checkpointing import save_checkpoint, load_checkpoint
from FederatedLearningProject.training.centralized_training import train_and_validate, train_epoch, validate_epoch, log_to_wandb, generate_configs
from FederatedLearningProject.training.model_editing import compute_mask, SparseSGDM


import FederatedLearningProject.experiments.models as models
# Select the CUDA device when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Load the CIFAR-100 splits: train_set, val_set, test_set.
# Per the original note, the transforms are applied inside the loader module
# before the datasets are returned.
valid_split_perc = 0.2    # validation fraction; per the original note, taken out of the 50000 training samples
train_set, val_set, test_set = get_cifar100(valid_split_perc)

In [None]:
# Create DataLoaders for training, validation, and test sets.
# batch_size is a hyperparameter (64, 128, ...), and so is num_workers (2 or 4 is suggested for Colab).
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

In [None]:
# Build the base model from the project's `models` module.
o_model = models.LinearFlexibleDino()     # original model
o_model.freeze(12)    # NOTE(review): presumably freezes 12 backbone blocks — confirm in models.py
o_model.to_cuda()     # NOTE(review): presumably moves the model to the GPU — confirm
o_model.debug()       # NOTE(review): presumably prints diagnostic information — confirm

In [None]:
# Path of the state_dict of the best model saved on Drive, so that the model
# does not have to be re-trained every time.
checkpoint_dir = "/content/drive/MyDrive/FL/FederatedLearningProject/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)  # create the folder if it is missing; no error if it already exists
best_model_path = os.path.join(checkpoint_dir, "best_model.pth")

In [None]:
# Deep-copy the base model so that o_model itself is left untouched.
model = copy.deepcopy(o_model)
# Overwrite the copy's weights with the trained ones stored in the checkpoint.
# NOTE(review): torch.load is called without map_location, so tensors are restored
# to the device they were saved from — confirm this is fine for CPU-only runtimes.
model.load_state_dict(torch.load(best_model_path))

In [None]:
# NOTE(review): presumably freeze(0) means "freeze zero blocks", i.e. makes the whole model trainable — confirm in models.py.
model.freeze(0)

In [None]:
# Total number of scalar parameters in the model.
total_params = sum(map(torch.numel, model.parameters()))
print(f"Numero totale di parametri: {total_params}")

# Number of scalar parameters that are currently trainable (requires_grad is True).
trainable_params = sum(map(torch.numel, filter(lambda t: t.requires_grad, model.parameters())))
print(f"Parametri attualmente allenabili (trainable): {trainable_params}")

In [None]:
# Show the parameters whose name contains one of the listed substrings.
for name, param in model.named_parameters():
    if any(tag in name for tag in ('embed', 'cls_token', 'backbone.norm', 'head')):
        print(f"{name} - shape: {param.shape}")

In [None]:
# Freeze the parameters whose name contains one of the listed substrings.
print("param.requires_grad = False: ")
for name, param in model.named_parameters():
    if not any(tag in name for tag in ('embed', 'cls_token', 'backbone.norm', 'head')):
        continue  # not one of the selected parameters: leave it as it is
    param.requires_grad = False
    print(f"FROZEN: {name}")


In [None]:
# Total number of scalar parameters in the model.
total_params = sum(map(torch.numel, model.parameters()))
print(f"Numero totale di parametri: {total_params}")

# Number of scalar parameters that are still trainable (requires_grad is True).
trainable_params = sum(map(torch.numel, filter(lambda t: t.requires_grad, model.parameters())))
print(f"Parametri attualmente allenabili (trainable): {trainable_params}")

In [None]:
# Call eval() on the model before computing the mask.
# NOTE(review): assuming `model` is a torch.nn.Module, this switches it to evaluation mode — confirm.
model.eval()

In [None]:
# Compute the mask with the project's compute_mask helper, fed from train_loader
# (sparsity_target=0.9, R=5, num_examples=200).
# NOTE(review): the device is hard-coded to 'cuda' instead of reusing `device` — confirm this is intended.
mascherina=compute_mask(model, train_loader, sparsity_target=0.9, R=5, num_examples=200, device='cuda')

fede

In [None]:
# Count the total, masked and unmasked entries of a mask
def count_masked_params(mask):
    """Print and return the entry counts of a mask dictionary.

    Args:
        mask: dict mapping names to tensors; an entry counts as masked
            when it compares equal to zero.

    Returns:
        tuple: (total_params, masked_params, unmasked_params).
    """
    tensors = list(mask.values())

    # Every element of every tensor contributes to the total.
    total_params = sum(t.numel() for t in tensors)
    # Masked entries are the ones equal to zero.
    masked_params = sum((t == 0).sum().item() for t in tensors)
    unmasked_params = total_params - masked_params

    print(f"Total parameters: {total_params}")
    print(f"Masked parameters (zeros): {masked_params}")
    print(f"Unmasked parameters (ones): {unmasked_params}")

    return total_params, masked_params, unmasked_params
# Report the counts for the mask computed above; as the last expression of the cell, the returned tuple is displayed too.
count_masked_params(mascherina)

In [None]:
import random

# Start from `mascherina` and add a little noise to obtain similar per-client masks

def crea_local_masks(mascherina, num_clients=100, noise_level=0.1, seed=42):
    """Build `num_clients` noisy copies of a mask.

    Args:
        mascherina: dict mapping names to mask tensors (converted to bool).
        num_clients: number of noisy masks to generate.
        noise_level: probability with which each entry is flipped.
        seed: value used to seed both torch and `random`.

    Returns:
        list of dict: one dict of bool tensors per client, same keys as `mascherina`.
    """
    # Seed both RNGs up front so that the generated masks are reproducible.
    torch.manual_seed(seed)
    random.seed(seed)

    def _perturb(tensor):
        # Work on a boolean view of the mask.
        as_bool = tensor.bool()
        # Each entry is selected for flipping with probability `noise_level`.
        flip = torch.rand_like(as_bool, dtype=torch.float32) < noise_level
        # XOR flips exactly the selected entries.
        return torch.logical_xor(as_bool, flip)

    return [
        {name: _perturb(tensor) for name, tensor in mascherina.items()}
        for _client in range(num_clients)
    ]

# Generate 100 noisy copies of the mask (each entry flipped with probability 0.1, seed 42).
local_masks = crea_local_masks(mascherina, num_clients=100, noise_level=0.1, seed=42)

In [None]:
## The output is our global mask

def aggregate_masks(local_masks, threshold_ratio=0.8):
    """
    Aggregate n local masks into one global mask by per-entry voting.

    An entry of the global mask is True when at least `threshold_ratio`
    of the clients have that entry set in their local mask.

    Args:
        local_masks (list of dict): local masks, one dict per client, each
            mapping a parameter name to a bool tensor.
        threshold_ratio (float): minimum fraction of clients that must agree
            (e.g. 0.8 means at least 80% of the clients).

    Returns:
        dict: global mask (bool tensors) with the same parameter names as
        the first local mask.

    Raises:
        ValueError: if `local_masks` is empty.
    """
    import math  # local import so the cell does not depend on another cell's imports

    if not local_masks:
        raise ValueError("La lista delle maschere locali è vuota!")

    num_clients = len(local_masks)
    # Round the vote count UP so the ratio is a true minimum: truncating would
    # accept e.g. 2 votes out of 3 for a 0.8 ratio. The small tolerance absorbs
    # floating-point noise such as 10 * 0.7 == 7.000000000000001.
    threshold = math.ceil(num_clients * threshold_ratio - 1e-9)

    # Integer accumulator of votes, one tensor per parameter name.
    agg_mask = {
        name: torch.zeros_like(mask, dtype=torch.int32)
        for name, mask in local_masks[0].items()
    }

    # Add up the local masks: each True entry counts as one vote.
    for cm in local_masks:
        for name in agg_mask:
            agg_mask[name] += cm[name].int()

    # Keep the entries that reached the required number of votes.
    final_mask = {
        name: (agg_mask[name] >= threshold)
        for name in agg_mask
    }

    return final_mask


In [None]:
# Aggregate the noisy local masks into the global mask (default threshold_ratio).
global_mask = aggregate_masks(local_masks)
total_params, masked_params, unmasked_params = count_masked_params(global_mask)
print( (unmasked_params / total_params)*100 ) # percentage of UNMASKED (active) entries in the global mask



# Explanation of the `distribution_function`

The `distribution_function` distributes the **unmasked** (i.e., active, True) parameters from a global mask `final_mask` exclusively among a given number of
clients.

## Inputs:
- `final_mask`: a dictionary where keys correspond to model parameter names and values are boolean tensors indicating whether a parameter is active (True) or masked (False).
- `unmasked_params`: the total number of active (True) parameters in `final_mask`.
- `number_clients`: the number of clients to split the active parameters among.

## Output:
- `client_masks`: a list of dictionaries (one per client), each with the same keys as `final_mask`. Each dictionary contains a mask where only a subset of parameters is active. These active positions are unique and do not overlap across clients.

## How it works:
1. Computes how many active parameters to assign to each client (equally dividing the parameters with any remainder given to the last client).
2. Creates a list of all active positions in the global mask, recording for each position the parameter name and the flattened index.
3. Randomly shuffles this list to ensure random distribution.
4. Assigns a mutually exclusive subset of active parameters to each client, creating local masks with True only where assigned.
5. Ensures each client mask has all keys from the global mask, filling missing keys with all-False masks.

## Result:
- Active parameters are partitioned exclusively and evenly (up to remainder) among clients.
- Summing all client masks recreates the original global mask.
- The last client receives any remaining active parameters to handle uneven splits.

---





In [None]:
def distribution_function(final_mask, unmasked_params, number_clients):
    '''
    Partition the active entries of a global mask among the clients.

    final_mask: dict of tensors with 1s and 0s (global mask)
    unmasked_params: total number of 1s in final_mask
    number_clients: number of clients to partition the unmasked parameters

    Every client receives unmasked_params // number_clients active entries;
    the last client additionally receives the remainder, so that the union
    of the client masks is the global mask.

    Note: the global torch RNG is re-seeded with 0 to make the shuffle
    reproducible.

    Returns:
        client_masks: list of length = number_clients
                      each element is a dict (same keys as final_mask)
                      of bool tensors with True in unique positions
                      (disjoint among clients)
    '''
    base_params_per_client = unmasked_params // number_clients
    remainder = unmasked_params % number_clients

    keys = list(final_mask.keys())

    if keys:
        # Row-major flat indices of the active entries of every tensor (on the CPU).
        per_key_indices = [
            torch.nonzero(final_mask[key].flatten(), as_tuple=False).reshape(-1).cpu()
            for key in keys
        ]
        per_key_counts = torch.tensor(
            [indices.numel() for indices in per_key_indices], dtype=torch.long
        )
        # Two aligned vectors describe every active position:
        # the flat index inside its tensor and the ordinal of the key it belongs to.
        flat_indices = torch.cat(per_key_indices)
        key_ids = torch.repeat_interleave(torch.arange(len(keys)), per_key_counts)
    else:
        flat_indices = torch.empty(0, dtype=torch.long)
        key_ids = torch.empty(0, dtype=torch.long)

    # Shuffle the positions (fixed seed for reproducibility).
    torch.manual_seed(0)
    perm = torch.randperm(flat_indices.numel())
    flat_indices = flat_indices[perm]
    key_ids = key_ids[perm]

    client_masks = []
    start_idx = 0
    for client_id in range(number_clients):
        # The last client also takes the whole remainder, so no active entry is dropped.
        is_last = client_id == number_clients - 1
        count = base_params_per_client + (remainder if is_last else 0)
        chunk_indices = flat_indices[start_idx:start_idx + count]
        chunk_keys = key_ids[start_idx:start_idx + count]
        start_idx += count

        client_mask = {}
        for key_id, key in enumerate(keys):
            reference = final_mask[key]
            # Fresh contiguous all-False mask, so a flat view is always valid.
            mask = torch.zeros(reference.shape, dtype=torch.bool, device=reference.device)
            chosen = chunk_indices[chunk_keys == key_id]
            if chosen.numel() > 0:
                mask.view(-1)[chosen.to(mask.device)] = True
            client_mask[key] = mask
        client_masks.append(client_mask)

    return client_masks


In [None]:
# Split the active entries of the global mask among 100 clients.
local_masks_parititioned = distribution_function(global_mask, unmasked_params, 100)

In [None]:
# Inspect the mask of one client (index 50).
# NOTE(review): this rebinds total_params / masked_params / unmasked_params, so re-running the
# partitioning cell afterwards would pass this client's count instead of the global one.
total_params, masked_params, unmasked_params = count_masked_params(local_masks_parititioned[50])
print( (unmasked_params / total_params)*100 ) # percentage of UNMASKED entries in one local mask; per the original note, about 0.1% is expected if it works

**TODO:** manca l'aggregazione dei task vector (the task-vector aggregation step still has to be implemented).