In [1]:
# Standard library
import os
import shutil

# Third-party
import numpy as np
import wandb

# Mount Google Drive (force_remount=True remounts even if already mounted) and
# work from the project folder so the relative paths used below resolve.
from google.colab import drive

drive.mount('/content/drive', force_remount=True)
os.chdir('/content/drive/MyDrive/FL')

Mounted at /content/drive


In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset, DataLoader, random_split

from FederatedLearningProject.data.cifar100_loader import get_cifar100
import FederatedLearningProject.checkpoints.checkpointing as checkpointing
# from FederatedLearningProject.training.FL_training import train_server_model_editing
from FederatedLearningProject.experiments import models

In [3]:
### TRAIN SPLIT WITH EVALUATION ###

valid_split_perc = 0.2
# The splits can be regenerated with
#   train_set, val_set, test_set = get_cifar100(valid_split_perc=valid_split_perc)
# but loading the saved copies keeps the same partition across runs.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# load these files from a trusted location.
val_set = torch.load('FederatedLearningProject/masks/val_set.pth', weights_only=False)
train_set = torch.load('FederatedLearningProject/masks/train_set.pth', weights_only=False)
test_set = torch.load('FederatedLearningProject/masks/test_set.pth', weights_only=False)

# Only the training loader is shuffled. The evaluation loaders iterate in a
# fixed order so validation/test passes are reproducible batch-for-batch.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

In [4]:
### MODEL LOADING ###

# FedAvg checkpoints (DINO ViT-S/16, 4 local steps), one per client-partition
# scenario. Select the one to load with `model_scenario`.
# NOTE(review): the checkpoint/run-name/training cells further down are set up
# for "non_iid_50", while this cell loads "iid" — confirm which is intended.
MODEL_CHECKPOINTS = {
    "iid": "FederatedLearningProject/checkpoints/FL_IID_300round/dino_vits_16_iid_local_steps_4_checkpoint.pth",
    "non_iid_1": "FederatedLearningProject/checkpoints/FL_NON_IID(1)_bs128/dino_vits_16_non_iid(1)_local_steps_4_bs128_checkpoint.pth",
    "non_iid_5": "FederatedLearningProject/checkpoints/FL_NON_IID(5)_bs128/dino_vits_16_non_iid(5)_local_steps_4_bs128_checkpoint.pth",
    "non_iid_10": "FederatedLearningProject/checkpoints/FL_NON_IID(10)_bs128/dino_vits_16_non_iid(10)_local_steps_4_bs128_checkpoint.pth",
    "non_iid_50": "FederatedLearningProject/checkpoints/FL_NON_IID(50)_bs128/dino_vits_16_non_iid(50)_local_steps_4_bs128_checkpoint.pth",
}
model_scenario = "iid"

model = models.LinearFlexibleDino(num_layers_to_freeze=12)
# map_location="cpu" lets the checkpoint load regardless of which device it was
# saved from; the model is moved to CUDA right after the weights are copied in.
model_checkpoint = torch.load(MODEL_CHECKPOINTS[model_scenario], map_location="cpu")
model.load_state_dict(model_checkpoint['model_state_dict'])
model.to_cuda()

Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 277MB/s]


moving model to cuda


In [5]:
### CLIENT DATA LOADING ###
from FederatedLearningProject.data.cifar100_loader import create_non_iid_splits, create_iid_splits

num_clients = 100

# IID partition: used by the mask-computation cell below.
client_dataset_iid = create_iid_splits(train_set, num_clients = num_clients)

# Non-IID partition with 50 classes per client. The `train_server_model_editing`
# cell passes `client_dataset_non_iid_50`, so it must be defined here; otherwise
# that cell raises NameError on a fresh kernel.
# Other scenarios in this project use classes_per_client = 1, 5 or 10.
client_dataset_non_iid_50 = create_non_iid_splits(train_set, num_clients = num_clients, classes_per_client = 50)



Dataset has 40000 samples across 100 classes.
Creating 100 IID splits with 100 classes each.


Each of the 100 classes split into 100 shards.

Checking unique classes that each client sees:
Client 0 has samples from classes: {np.int64(0), np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(12), np.int64(13), np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(20), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(41), np.int64(42), np.int64(43), np.int64(44), np.int64(45), np.int64(46), np.int64(47), np.int64(48), np.int64(49), np.int64(50), np.int64(51), np.int64(52), np.int64(53), np.int64(54), np.int64(55), 

In [6]:
from FederatedLearningProject.training.model_editing import (
    compute_mask_clients,
    plot_all_layers_mask_sparsity,
)

# Per-client masks for the IID partition. The meaning of each argument is defined
# by `compute_mask_clients` (not visible in this notebook).
# NOTE: the last recorded run of this cell was interrupted (KeyboardInterrupt).
masks_to_save = compute_mask_clients(
    model,
    client_dataset_iid,
    num_examples=25,
    num_classes=100,
    n_per_class=1,
    final_sparsity=0.8,
)

KeyboardInterrupt: 

In [None]:
from FederatedLearningProject.training.model_editing import convert_float_masks_to_bool

# Convert the float masks to booleans (presumably to shrink the saved file) and
# persist them to Drive.
iid_mask_path = "FederatedLearningProject/masks/client_masks_iid_sparsity_08.pth"
bool_masks = convert_float_masks_to_bool(masks_to_save)
torch.save(bool_masks, iid_mask_path)

In [None]:
### MASK LOADING ###
# The client masks were computed once with `compute_mask_clients` and saved to
# Google Drive; this cell only loads the ones needed for training.
from FederatedLearningProject.training.model_editing import compute_mask_clients, plot_all_layers_mask_sparsity

MASKS_DIR = "FederatedLearningProject/masks"
CLIENT_MASK_FILES = {
    "iid": "client_masks_iid.pth",
    "non_iid_1": "client_masks_non_iid_1.pth",
    "non_iid_5": "client_masks_non_iid_5.pth",
    "non_iid_10": "client_masks_non_iid_10.pth",
    "non_iid_50": "client_masks_non_iid_50.pth",
}

# To recompute instead of loading (IID example):
#   client_masks_iid = compute_mask_clients(model, client_dataset_iid, num_examples=100, num_classes=100, n_per_class=1)
#   torch.save(client_masks_iid, os.path.join(MASKS_DIR, CLIENT_MASK_FILES["iid"]))

# The `train_server_model_editing` cell consumes `client_masks_non_iid_50`;
# without this load it raises NameError on a fresh kernel. Only the scenario
# being trained is loaded.
client_masks_non_iid_50 = torch.load(os.path.join(MASKS_DIR, CLIENT_MASK_FILES["non_iid_50"]))


In [None]:
# --- OPTIMIZER AND LOSS FUNCTION ---

num_rounds = 300

# lr and weight_decay are the best values found for the centralized baseline.
optimizer_config = dict(lr=0.01, momentum=0.9, weight_decay=0.0001)

num_clients = 100

# Default FedAvg hyperparameters.
num_local_steps = 4   # fixed number of local steps per client
fraction = 0.1        # fraction of clients sampled each round
criterion = nn.CrossEntropyLoss()
model_name = "dino_vits16_J4"

checkpoint_dir = "/content/drive/MyDrive/FL/FederatedLearningProject/checkpoints/FL/"
os.makedirs(checkpoint_dir, exist_ok=True)

# One checkpoint file per partition scenario:
# iid, non_iid_1, non_iid_5, non_iid_10, non_iid_50.
checkpoint_scenario = "non_iid_50"
checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_model_editing_{checkpoint_scenario}.pth")

In [None]:
wandb.login()  # Log in to W&B; prompts for an API key if none is stored yet.

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdepetrofabio[0m ([33mdepetrofabio-politecnico-di-torino[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
device = "cuda"
project_name = "FedAvg_ModelEditing_Corretto"

# The run name identifies the partition scenario:
# iid, non_iid_1, non_iid_5, non_iid_10, non_iid_50.
run_name = f"{model_name}_FedAvg_model_editing_non_iid_50"

# Start a W&B run. The logged config reads the variables defined in the config
# cell instead of repeating literals, so the dashboard matches the real settings.
wandb.init(
    project=project_name,
    name=run_name,
    config={
        "model": model_name,
        "num_rounds": num_rounds,
        "batch_size": 128,  # client batch size used for training
        "num_clients": num_clients,
        "fraction": fraction,
        "num_local_steps": num_local_steps,
        "optimizer": optimizer_config,
    },
    reinit=True,  # allow wandb.init to be called again in the same kernel
)

# Handle to the run config (readable as attributes, e.g. config.batch_size).
config = wandb.config


0,1
client_avg_accuracy,▁▂▅▆▆▆▃▅▆█▇▇▇█▆▆█▇▇█▆▄▇▇▅▄█▇█▇▇▇▅▆▅█▅▇▇▇
client_avg_loss,▆▆▅▄▆▅▃▆▃▂▇▃▇▁▆▄▄▃▅▅▄▄▅█▅▄▃▄▆▄▄▆▄▆▆▅▂▇▅▃
round,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
server_val_accuracy,▇▇▇▇▇▇██▇█▂▅▅▆▃▂▄▄▅▃▃▁▂▃▅▂▇▄▆▆▅▆▅▅▆▅▇▇▆▇
server_val_loss,▄▃▃▃▃▃▃▂▂▃▂▄█▅▅▆▃█▅▅▅▅█▂▅▃▃▄▁▃▃▃▂▃▃▃▁▁▃▁

0,1
client_avg_accuracy,62.09475
client_avg_loss,1.36901
round,299.0
server_val_accuracy,45.09
server_val_loss,2.10083


In [None]:
# Unfreeze the backbone before model-editing training (the model was built with
# num_layers_to_freeze=12), then make sure it is on the GPU.
# NOTE(review): assumes unfreeze(12) unfreezes the 12 frozen blocks — confirm in
# models.LinearFlexibleDino.
model.unfreeze(12)
model.to_cuda()

moving model to cuda


In [None]:
# Run FedAvg with per-client masks (model editing). Hyperparameters come from
# the config cell and the W&B config, so the values used here and the values
# logged cannot drift apart.
# NOTE(review): this import is commented out in the imports cell, which leaves
# the name undefined on a fresh kernel; confirm the function still lives in FL_training.
from FederatedLearningProject.training.FL_training import train_server_model_editing

train_server_model_editing(model=model,
             num_clients = num_clients,
             num_rounds=num_rounds,
             client_dataset = client_dataset_non_iid_50,
             frac=fraction,
             batch_size=config.batch_size,
             client_masks = client_masks_non_iid_50,
             optimizer_config=optimizer_config,
             val_loader = val_loader,
             criterion = criterion,
             num_client_steps = num_local_steps,
             model_name = model_name,
             checkpoint_path = checkpoint_path,
             device = device)

Checkpoint salvato su: /content/drive/MyDrive/FL/FederatedLearningProject/checkpoints/FL/dino_vits16_J4_model_editing_non_iid_50.pth

Round 5/300
Selected Clients: [35 48 65 23 70 24 40 12 80 43]
Avg Client Loss: 2.1441 | Avg Client Accuracy: 44.36%
Evaluation Loss: 2.3034 | Val Accuracy: 43.10%
--------------------------------------------------
Checkpoint salvato su: /content/drive/MyDrive/FL/FederatedLearningProject/checkpoints/FL/dino_vits16_J4_model_editing_non_iid_50.pth

Round 10/300
Selected Clients: [ 8 27 46 87 33 29 61 96 44 15]
Avg Client Loss: 2.0426 | Avg Client Accuracy: 45.67%
Evaluation Loss: 2.2556 | Val Accuracy: 44.11%
--------------------------------------------------
Checkpoint salvato su: /content/drive/MyDrive/FL/FederatedLearningProject/checkpoints/FL/dino_vits16_J4_model_editing_non_iid_50.pth

Round 15/300
Selected Clients: [78 98 43 10 39  1 52 92 53 71]
Avg Client Loss: 1.8988 | Avg Client Accuracy: 47.89%
Evaluation Loss: 2.2080 | Val Accuracy: 45.06%
-----

{'model': LinearFlexibleDino(
   (backbone): VisionTransformer(
     (patch_embed): PatchEmbed(
       (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
     )
     (pos_drop): Dropout(p=0.0, inplace=False)
     (blocks): ModuleList(
       (0-11): 12 x Block(
         (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
         (attn): Attention(
           (qkv): Linear(in_features=384, out_features=1152, bias=True)
           (attn_drop): Dropout(p=0.0, inplace=False)
           (proj): Linear(in_features=384, out_features=384, bias=True)
           (proj_drop): Dropout(p=0.0, inplace=False)
         )
         (drop_path): Identity()
         (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
         (mlp): Mlp(
           (fc1): Linear(in_features=384, out_features=1536, bias=True)
           (act): GELU(approximate='none')
           (fc2): Linear(in_features=1536, out_features=384, bias=True)
           (drop): Dropout(p=0.0, inplace=False)
  