In [2]:
import numpy as np
import wandb
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import shutil
import os                              # Import the 'os' module for changing directories
os.chdir('/content/drive/MyDrive/FL')  # Change the directory

Mounted at /content/drive


In [3]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset, DataLoader, random_split

from FederatedLearningProject.data.cifar100_loader import get_cifar100
import FederatedLearningProject.checkpoints.checkpointing as checkpointing
from FederatedLearningProject.training.FL_training_ME import train_server
from FederatedLearningProject.experiments import models

In [3]:
# wandb.login() # Ask for your APIw key for logging in to the wandb library.

In [4]:

# Import CIFAR100 dataset: train_set, val_set, test_set
# The transforms are applied before returning the dataset (in the module)

valid_split_perc = 0.2  # of the 50000 training data
train_set, val_set, test_set = get_cifar100(valid_split_perc=valid_split_perc)
# batch_size è in hyperparameter (64, 128, ..), anche num_workers (consigliato per colab 2 o 4)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)


Number of images in Training Set:   40000
Number of images in Validation Set: 10000
Number of images in Test Set:       10000
✅ Datasets loaded successfully


In [5]:
model = models.LinearFlexibleDino(num_layers_to_freeze=12)
model_checkpoint = torch.load("FederatedLearningProject/checkpoints/FL_IID_300round/dino_vits_16_iid_local_steps_4_checkpoint.pth")
model.load_state_dict(model_checkpoint['model_state_dict'])


Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 237MB/s]


<All keys matched successfully>

In [6]:
model

LinearFlexibleDino(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm

In [7]:
model.debug()


--- Debugging Model ---
Model is primarily on device: cpu
Model overall mode: Train

Parameter Details (Name | Device | Requires Grad? | Inferred Block | Module Mode):
- backbone.cls_token                                 | cpu        | False           | N/A             | Train
- backbone.pos_embed                                 | cpu        | False           | N/A             | Train
- backbone.patch_embed.proj.weight                   | cpu        | False           | N/A             | Train
- backbone.patch_embed.proj.bias                     | cpu        | False           | N/A             | Train
- backbone.blocks.0.norm1.weight                     | cpu        | False           | Block 0         | Eval
- backbone.blocks.0.norm1.bias                       | cpu        | False           | Block 0         | Eval
- backbone.blocks.0.attn.qkv.weight                  | cpu        | False           | Block 0         | Eval
- backbone.blocks.0.attn.qkv.bias                    | cpu      

In [8]:
# --- OPTIMIZER AND LOSS FUNCTION ---

lr = 0.01  # best hyperparameter of the centralized
momentum = 0.9
weight_decay = 0.0001 # best hyperparameter of the centralized

num_clients = 100

# Default hyperparameters for FedAvg
num_local_steps = 4 #
# num_local_steps = 8 #
# num_local_steps = 16 #
fraction = 0.1
criterion = nn.CrossEntropyLoss()

model_name = "prova_model_editing"

checkpoint_dir = "/content/drive/MyDrive/FL/FederatedLearningProject/checkpoints/FL/"
os.makedirs(checkpoint_dir, exist_ok=True)
# Make checkpoint path unique to the run if you want to store separate checkpoints
checkpoint_path = os.path.join(checkpoint_dir, "prova_checkpoint.pth")

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [10]:
model.to_cuda()

moving model to cuda


In [11]:
from FederatedLearningProject.data.cifar100_loader import create_iid_splits
client_dataset = create_iid_splits(train_set, num_clients = num_clients)

Dataset has 40000 samples across 100 classes.
Creating 100 IID splits with 100 classes each.


Each of the 100 classes split into 100 shards.

Checking unique classes that each client sees:
Client 0 has samples from classes: {np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(12), np.int64(13), np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(20), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(41), np.int64(42), np.int64(43), np.int64(44), np.int64(45), np.int64(46), np.int64(47), np.int64(48), np.int64(49), np.int64(50), np.int64(51), np.int64(52), np.int64(53), np.int64(54), np.int64(55), 

In [13]:
train_server(model=model,
             num_clients = 10,
             num_rounds=1,
             client_dataset = client_dataset,
             frac=fraction,
             lr = lr,
             val_loader = val_loader,
             criterion = criterion,
             num_client_steps = 4,
             model_name = model_name,
             weight_decay = weight_decay,
             momentum = momentum,
             checkpoint_path = checkpoint_path,
             device = device)

--- Starting Round 1/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.3690
Round Target: 0.3690 (7858234 weights)
Total considered params: 21293568
Already pruned: 0. Active: 21293568
Need to prune 7858234 more weights from the active set.
threshold:2.3732052056857356e-07
Achieved cumulative sparsity in final mask: 0.3690
-------------------------

--- Starting Round 2/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.6019
Round Target: 0.6019 (12816445 weights)
Total considered params: 21293568
Already pruned: 7858235. Active: 13435333
Need to prune 4958210 more weights from the active set.
threshold:9.539007532882419e-19
Achieved cumulative sparsity in final mask: 0.6019
-------------------------

--- Starting Round 3/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.7488
Round Target: 0.7488 (15944865 weights)
Total considered params: 21293568
Already pruned: 12816446. Active: 8477122
Need to pru

{'model': LinearFlexibleDino(
   (backbone): VisionTransformer(
     (patch_embed): PatchEmbed(
       (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
     )
     (pos_drop): Dropout(p=0.0, inplace=False)
     (blocks): ModuleList(
       (0-11): 12 x Block(
         (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
         (attn): Attention(
           (qkv): Linear(in_features=384, out_features=1152, bias=True)
           (attn_drop): Dropout(p=0.0, inplace=False)
           (proj): Linear(in_features=384, out_features=384, bias=True)
           (proj_drop): Dropout(p=0.0, inplace=False)
         )
         (drop_path): Identity()
         (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
         (mlp): Mlp(
           (fc1): Linear(in_features=384, out_features=1536, bias=True)
           (act): GELU(approximate='none')
           (fc2): Linear(in_features=1536, out_features=384, bias=True)
           (drop): Dropout(p=0.0, inplace=False)
  

In [5]:
mask = torch.load("FederatedLearningProject/masks/client_masks.pth")

In [6]:
for i in range(len(mask)):
  if (mask[i] is not None):
    print(i)

5


In [7]:
from FederatedLearningProject.training.model_editing import plot_all_layers_mask_sparsity

In [None]:
plot_all_layers_mask_sparsity(mask[5])

In [10]:
def count_masked_params(mask):
    total_params = 0
    masked_params = 0

    for key, mask_tensor in mask.items():
        total_params += mask_tensor.numel()          # total number of elements in this parameter
        masked_params += (mask_tensor == 0).sum().item()  # count how many elements are zero (masked)

    unmasked_params = total_params - masked_params

    print(f"Total parameters: {total_params}")
    print(f"Masked parameters (zeros): {masked_params}")
    print(f"Unmasked parameters (ones): {unmasked_params}")

    return total_params, masked_params, unmasked_params
count_masked_params(mask[5])

Total parameters: 21293568
Masked parameters (zeros): 21293568
Unmasked parameters (ones): 0


(21293568, 21293568, 0)

In [14]:
from FederatedLearningProject.training.model_editing import compute_mask

In [15]:
model.debug()


--- Debugging Model ---
Model is primarily on device: cuda:0
Model overall mode: Train

Parameter Details (Name | Device | Requires Grad? | Inferred Block | Module Mode):
- backbone.cls_token                                 | cuda:0     | False           | N/A             | Train
- backbone.pos_embed                                 | cuda:0     | False           | N/A             | Train
- backbone.patch_embed.proj.weight                   | cuda:0     | False           | N/A             | Train
- backbone.patch_embed.proj.bias                     | cuda:0     | False           | N/A             | Train
- backbone.blocks.0.norm1.weight                     | cuda:0     | False           | Block 0         | Eval
- backbone.blocks.0.norm1.bias                       | cuda:0     | False           | Block 0         | Eval
- backbone.blocks.0.attn.qkv.weight                  | cuda:0     | False           | Block 0         | Eval
- backbone.blocks.0.attn.qkv.bias                    | cuda:0

In [12]:
def compute_mask_clients(model, client_dataset):
  client_mask = []
  for i in range(2): # range(len(client_dataset))
    client_loader = DataLoader(client_dataset[i], batch_size=128, shuffle=True, num_workers=2)
    new_mask = compute_mask(model, dataloader=client_loader)
    client_mask.append(new_mask)
  return client_mask

In [None]:
masks = compute_mask_clients(model, client_dataset)

--- Starting Round 1/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.3690
Round Target: 0.3690 (7858234 weights)
Total considered params: 21293568
Already pruned: 0. Active: 21293568
Need to prune 7858234 more weights from the active set.
threshold:0.008814557455480099
Achieved cumulative sparsity in final mask: 0.3690
-------------------------

--- Starting Round 2/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.6019
Round Target: 0.6019 (12816445 weights)
Total considered params: 21293568
Already pruned: 7858235. Active: 13435333
Need to prune 4958210 more weights from the active set.
threshold:1.6433132259408012e-05
Achieved cumulative sparsity in final mask: 0.6019
-------------------------

--- Starting Round 3/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.7488
Round Target: 0.7488 (15944865 weights)
Total considered params: 21293568
Already pruned: 12816446. Active: 8477122
Need to prun

In [41]:
a = compute_mask(model, dataloader=DataLoader(client_dataset[0], batch_size=128, shuffle=True, num_workers=2))

--- Starting Round 1/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.3690
Round Target: 0.3690 (7858234 weights)
Total considered params: 21293568
Already pruned: 0. Active: 21293568
Need to prune 7858234 more weights from the active set.
threshold:7.881801866460592e-05
Achieved cumulative sparsity in final mask: 0.3690
-------------------------

--- Starting Round 2/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.6019
Round Target: 0.6019 (12816445 weights)
Total considered params: 21293568
Already pruned: 7858235. Active: 13435333
Need to prune 4958210 more weights from the active set.
threshold:2.6095978000739706e-07
Achieved cumulative sparsity in final mask: 0.6019
-------------------------

--- Starting Round 3/5 ---
Computing Fisher diagonal with current mask...
Current sparsity level: 0.7488
Round Target: 0.7488 (15944865 weights)
Total considered params: 21293568
Already pruned: 12816446. Active: 8477122
Need to pru

In [54]:
count_masked_params(masks[0])

Total parameters: 21293568
Masked parameters (zeros): 19164212
Unmasked parameters (ones): 2129356


(21293568, 19164212, 2129356)

In [1]:
plot_all_layers_mask_sparsity(masks[1])

NameError: name 'plot_all_layers_mask_sparsity' is not defined

In [None]:
torch.save()