In [18]:
# Standard-library and third-party imports for the notebook session.
import os
import shutil

import numpy as np
import wandb
from google.colab import drive

# Mount Google Drive (forcing a fresh mount) and make the project folder the
# working directory, so the relative 'FederatedLearningProject/...' paths and
# package imports used by later cells resolve from there.
drive.mount('/content/drive', force_remount=True)
os.chdir('/content/drive/MyDrive/FL')

Mounted at /content/drive


In [16]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset, DataLoader, random_split

from FederatedLearningProject.data.cifar100_loader import get_cifar100, create_iid_splits, create_non_iid_splits
import FederatedLearningProject.checkpoints.checkpointing as checkpointing
from FederatedLearningProject.training.FedMETA import aggregate_with_task_arithmetic, aggregate_masks, distribution_function, train_server
from FederatedLearningProject.training.model_editing import plot_all_layers_mask_sparsity

from FederatedLearningProject.experiments import models

In [35]:
import importlib

# Import every project module that is going to be refreshed.
# The submodules `checkpointing` and `models` are imported explicitly:
# attribute access such as `checkpoints.checkpointing` only works when some
# other code has already imported that submodule, which would make this cell
# silently depend on an earlier cell having been executed.
from FederatedLearningProject import checkpoints, experiments
from FederatedLearningProject.data import cifar100_loader
from FederatedLearningProject.training import FedMETA, model_editing
import FederatedLearningProject.checkpoints.checkpointing as checkpointing
from FederatedLearningProject.experiments import models

# Reload only the project's own modules (never torch) so that edits made to
# the .py files on disk are picked up without restarting the kernel.
# importlib.reload re-executes the module in place, so the `checkpointing`
# and `models` names bound above keep pointing at the refreshed modules.
importlib.reload(cifar100_loader)
importlib.reload(checkpointing)
importlib.reload(FedMETA)
importlib.reload(model_editing)
importlib.reload(models)

# Re-bind the functions imported by name: `from module import f` copies a
# reference to the old function object, which a reload does not update.
from FederatedLearningProject.data.cifar100_loader import (
    get_cifar100, create_iid_splits, create_non_iid_splits
)

from FederatedLearningProject.training.FedMETA import (
    aggregate_with_task_arithmetic,
    aggregate_masks,
    distribution_function,
    train_server
)

from FederatedLearningProject.training.model_editing import (
    plot_all_layers_mask_sparsity
)


In [3]:
# Show which files are stored in the masks folder.
print(os.listdir('FederatedLearningProject/data/masks'))

# Load the three objects saved under that folder (train / validation / test).
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, so these files must come from a trusted source.
train_set = torch.load(
    'FederatedLearningProject/data/masks/train_set.pth',
    weights_only=False,
)
val_set = torch.load(
    'FederatedLearningProject/data/masks/val_set.pth',
    weights_only=False,
)
test_set = torch.load(
    'FederatedLearningProject/data/masks/test_set.pth',
    weights_only=False,
)

['train_set.pth', 'val_set.pth', 'test_set.pth', 'client_masks_iid.pth', 'client_masks_non_iid_1.pth']


In [4]:
# Masks previously saved to disk; the file name suggests one mask per client
# for a non-IID split (contents are not visible here -- TODO confirm).
local_masks = torch.load(
    'FederatedLearningProject/data/masks/client_masks_non_iid_1.pth'
)

# Instantiate the project's DINO-based model. Presumably
# num_layers_to_freeze=12 freezes twelve backbone layers -- verify against
# models.LinearFlexibleDino.
o_model = models.LinearFlexibleDino(num_layers_to_freeze=12)


Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 253MB/s]


In [9]:
# Combine the loaded client masks into one mask via the project's
# FedMETA.aggregate_masks helper (aggregation rule not visible here).
final_mask = aggregate_masks(local_masks)

In [23]:
# Derive the masks handed to train_server as `client_masks` from the
# aggregated mask, for 100 clients.
# NOTE(review): number_clients presumably has to match num_clients used when
# building the client datasets (also 100) -- keep the two in sync.
partition_masks = distribution_function(final_mask, number_clients=100)

Total parameters: 21293568
Masked parameters (zeros): 19902014
Unmasked parameters (ones): 1391554


In [12]:
# Split the training set into 100 non-IID client datasets, one class per
# client (the helper logs which classes each client receives).
client_dataset = create_non_iid_splits(train_set, num_clients=100, classes_per_client=1)

Dataset has 40000 samples across 100 classes.
Creating 100 non IID splits with 1 classes each.


Each of the 100 classes split into 1 shards.

Checking unique classes that each client sees:
Client 0 has samples from classes: {np.int64(0)}
Total: 1
Client 1 has samples from classes: {np.int64(1)}
Total: 1
Client 2 has samples from classes: {np.int64(2)}
Total: 1
Client 3 has samples from classes: {np.int64(3)}
Total: 1
Client 4 has samples from classes: {np.int64(4)}
Total: 1
Client 5 has samples from classes: {np.int64(5)}
Total: 1
Client 6 has samples from classes: {np.int64(6)}
Total: 1
Client 7 has samples from classes: {np.int64(7)}
Total: 1
Client 8 has samples from classes: {np.int64(8)}
Total: 1
Client 9 has samples from classes: {np.int64(9)}
Total: 1
Client 10 has samples from classes: {np.int64(10)}
Total: 1
Client 11 has samples from classes: {np.int64(11)}
Total: 1
Client 12 has samples from classes: {np.int64(12)}
Total: 1
Client 13 has samples from classes: {np.int64(13)}

In [None]:
# Optimizer hyper-parameters forwarded to train_server as `optimizer_config`.
optimizer_config = {
    'lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 0.0001
}

# Folder passed to train_server as `checkpoint_path`.
checkpoint_path = 'FederatedLearningProject/checkpoints/'

# The validation set is only evaluated, never trained on, so it is iterated
# in a fixed order: shuffling adds nothing and makes runs harder to compare.
val_loader = DataLoader(val_set, batch_size=128, shuffle=False)

# Use the GPU when one is attached and fall back to the CPU otherwise, so the
# cell does not crash on a runtime without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Run one training round, passing frac=0.1 and the per-client datasets and
# masks built in the earlier cells.
train_server(
    o_model,
    num_rounds=1,
    client_dataset=client_dataset,
    client_masks=partition_masks,
    optimizer_config=optimizer_config,
    device=device,
    frac=0.1,
    batch_size=128,
    val_loader=val_loader,
    checkpoint_path=checkpoint_path,
)