In [None]:
# Early setup cell: mounts Google Drive and changes into the project folder.
# NOTE(review): the consolidated setup cell further down repeats these imports,
# the Drive mount and the chdir, so this cell looks redundant — consider removing it.
import numpy as np
import wandb
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import shutil
import os                              # Import the 'os' module for changing directories
os.chdir('/content/drive/MyDrive/FL')  # Change the directory
import datetime as datetime            # aliasing a module to its own name is a no-op; plain `import datetime` is equivalent
import copy
import json

Mounted at /content/drive
Mounted at /content/drive


In [None]:
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import time

In [None]:
# Authenticate with Weights & Biases; relogin=True asks wandb to prompt for
# credentials again even if a login already exists.
# NOTE(review): the consolidated setup cell below calls wandb.login(relogin=True)
# again, so this cell appears redundant.
import wandb
wandb.login(relogin = True)

In [None]:
# Consolidated setup cell: mount Google Drive, move into the project folder,
# import everything the notebook uses, authenticate with wandb, seed the RNGs
# and select the compute device.

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Core Python modules
import os
import random
import shutil
import json
import time
import datetime
import copy

# Change working directory. Relative paths (e.g. 'mask.pth') resolve here, and
# presumably this is also what makes the FederatedLearningProject package
# (which lives in this folder) importable below — keep it before those imports.
os.chdir('/content/drive/MyDrive/FL')

# Data & Plotting
import numpy as np
import matplotlib.pyplot as plt

# Torch and related libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR

# TorchVision datasets and transforms
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100

# wandb (for experiment tracking)
import wandb
wandb.login(relogin=True)

# Project-specific imports
from FederatedLearningProject.data.cifar100_loader import get_cifar100
from FederatedLearningProject.checkpoints.checkpointing import (
    save_checkpoint, save_checkpoint_test, load_checkpoint
)

from FederatedLearningProject.training.centralized_training import (
    train_and_validate, train_and_test, train_epoch, validate_epoch,
    test_epoch, log_to_wandb, log_to_wandb_test, generate_configs
)
from FederatedLearningProject.training.model_editing import (
    compute_mask, SparseSGDM, plot_changed_weights_percentage,
    plot_all_layers_mask_sparsity, plot_qkv_weight_bias_sparsity,
    get_n_examples_per_class_loader, print_info_dataloader
)
import FederatedLearningProject.experiments.models as models

# Reproducibility: seed the Python, NumPy and torch global RNGs so anything that
# draws from them (e.g. DataLoader shuffling) is repeatable. This matters because
# the hyperparameter search below is resumed across separate Colab sessions.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Import the CIFAR100 dataset as train_set, val_set, test_set.
# The transforms are applied before the datasets are returned (inside the module).
valid_split_perc = 0.2    # fraction of the 50000 training images held out for validation (40000 / 10000 per the output)
train_set, val_set, test_set = get_cifar100(valid_split_perc)

Number of images in Training Set:   40000
Number of images in Validation Set: 10000
Number of images in Test Set:       10000
✅ Datasets loaded successfully


In [3]:
# Build the DataLoaders for the training, validation and test sets.
# batch_size is a hyperparameter (64, 128, ...), and so is num_workers
# (2 or 4 is recommended on Colab). Only the training loader shuffles;
# evaluation loaders keep a fixed order.
_loader_kwargs = dict(batch_size=128, num_workers=2)
train_loader = DataLoader(train_set, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **_loader_kwargs)

In [None]:
# Build the reference model; later cells deep-copy it instead of mutating it.
o_model = models.LinearFlexibleDino()     # original model
o_model.freeze(12)   # NOTE(review): every parameter shown in the debug output has requires_grad=False afterwards — presumably freezes all 12 backbone blocks; confirm
o_model.to_cuda()    # project helper; its output reports "cuda not available" and the model stays on CPU when no GPU is present

Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 112MB/s]


cuda not available

--- Debugging Model ---
Model is primarily on device: cpu
Model overall mode: Train

Parameter Details (Name | Device | Requires Grad? | Inferred Block | Module Mode):
- backbone.cls_token                                 | cpu        | False           | N/A             | Train
- backbone.pos_embed                                 | cpu        | False           | N/A             | Train
- backbone.patch_embed.proj.weight                   | cpu        | False           | N/A             | Train
- backbone.patch_embed.proj.bias                     | cpu        | False           | N/A             | Train
- backbone.blocks.0.norm1.weight                     | cpu        | False           | Block 0         | Eval
- backbone.blocks.0.norm1.bias                       | cpu        | False           | Block 0         | Eval
- backbone.blocks.0.attn.qkv.weight                  | cpu        | False           | Block 0         | Eval
- backbone.blocks.0.attn.qkv.bias            

In [None]:
# Location of the best model's state_dict saved on Drive, so it does not have
# to be retrained every session. The checkpoint folder is created if missing.
checkpoint_dir = "/content/drive/MyDrive/FL/FederatedLearningProject/checkpoints"
best_model_path = os.path.join(checkpoint_dir, "best_model.pth")
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Copy the base model so that loading weights into `model` leaves `o_model` untouched.
model = copy.deepcopy(o_model)

In [None]:
# Overwrite the copied model's weights with the trained ones.
# map_location=device lets a checkpoint that was saved on a GPU load on a
# CPU-only runtime (the model build output reports "cuda not available").
model.load_state_dict(torch.load(best_model_path, map_location=device))

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/FL/best_model_locale.pth'

In [None]:
# 1. Before training: keep a detached copy of every parameter so the weights can
# be compared after fine-tuning (presumably with plot_changed_weights_percentage;
# before_weights is not used in the visible cells).
before_weights = dict(
    (param_name, param.detach().clone())
    for param_name, param in model.named_parameters()
)

In [None]:
# ..............CENTRALIZED..................
# Compute the sparsity mask on the best model loaded above.
debugging = False

# 1. Stratified loader used to calibrate the mask: 100 classes, one example per class.
stratified_loader = get_n_examples_per_class_loader(train_loader, num_classes=100, n_per_class=1)
if debugging:
    print_info_dataloader(stratified_loader)

# 1.2. Mask hyperparameters
final_sparsity = 0.9    # target sparsity (hyperparameter)
tot_rounds = 10
soft_zero = 0.01
num_examples = None     # None: iterate over the whole (already reduced) loader passed in

# 2. Calibrate the mask on the 100 examples.
# Use the device selected in the setup cell rather than a hard-coded 'cuda',
# so this cell also runs on CPU-only runtimes.
start = time.time()
mask = compute_mask(
    model,
    stratified_loader,
    sparsity_target=final_sparsity,
    R=tot_rounds,
    num_examples=num_examples,
    soft_zero=soft_zero,
    device=str(device),
    enable_plot=0,
    debug=False,
)
end = time.time()
if debugging:
    plot_all_layers_mask_sparsity(mask)
    plot_qkv_weight_bias_sparsity(mask)
    print(f"Mask computation time: {end-start}")
    print(f"Sparsity target: {final_sparsity}")
    print(f"Soft Zero Value: {soft_zero}")
    print(f"Rounds: {tot_rounds}")
    print(f"Num_examples: {num_examples}")

# 3. Save the mask (relative path: resolves against the working directory set in setup).
torch.save(mask, 'mask.pth')

In [None]:
# Hyperparameter search space: only the number of training epochs is varied.
c = dict(epochs=dict(values=[20, 30, 40]))

configs = generate_configs(c)

In [None]:
# JSON bookkeeping files; they let the search resume across Colab sessions.
completed_combinations_path = os.path.join(checkpoint_dir, "completed_combinations2.json")
best_model_path_2 = os.path.join(checkpoint_dir, "best_model_2.pth")
best_combination_path_2 = os.path.join(checkpoint_dir, "best_combination_2.json")

# Load the combinations already completed in previous sessions
if os.path.exists(completed_combinations_path):
    with open(completed_combinations_path, "r") as f:
        completed_combinations2 = json.load(f)
    print(f"Completed combinations: {completed_combinations2}")
else:
    completed_combinations2 = []
    print("No completed combinations")

# Load the best combination saved so far, if it exists
if os.path.exists(best_combination_path_2):
    with open(best_combination_path_2, "r") as f:
        best_combination_info = json.load(f)
    best_val_accuracy = best_combination_info.get("best_val_accuracy", 0.0)
    best_index = best_combination_info.get("best_index", None)
else:
    best_val_accuracy = 0.0
    best_index = None


# Loop over all configurations
for i in range(len(configs)):
    if str(i) in completed_combinations2:
        print(f"Skipping combination {i} (already completed)")
        continue

    config_i = configs[i]

    learning_rate = 0.01
    weight_decay = 0.0001
    momentum = 0.9
    epochs = config_i["epochs"]

    # Fresh copy of the base model with the trained weights. map_location lets a
    # GPU-saved checkpoint load on a CPU-only runtime.
    model = copy.deepcopy(o_model)
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.freeze(0)   # NOTE(review): presumably unfreezes every block — confirm
    model.train()

    params_to_optimize = model.named_parameters()
    optimizer = SparseSGDM(params_to_optimize, masks=mask, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    model_name = "dino_vit-s-16-edited"
    project_name = "BaselineCentralized_Editing_definitiva"
    run_name = f"{model_name}_run_{i}"

    # The run id is deterministic so that a session restarted mid-run reuses it.
    # resume="allow" makes wandb resume an existing run with this id, or start a
    # new one if none exists.
    wandb.init(
        project=project_name,
        name=run_name,
        id=run_name,
        resume="allow",
        config={
            "model": model_name,
            "epochs": epochs,
            "batch_size": train_loader.batch_size,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "momentum": momentum,
            "architecture": model.__class__.__name__,
        }
    )

    config = wandb.config

    checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_run_{i}_checkpoint.pth")
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    start_epoch, checkpoint_data = load_checkpoint(model, optimizer, scheduler, run_name)

    # Train and get the best validation accuracy of this run. wandb.finish() is
    # in `finally` so a crash does not leave the wandb run open; in that case the
    # exception propagates and the combination is NOT marked as completed.
    try:
        val_accuracy = train_and_validate(
            start_epoch,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            scheduler=scheduler,
            optimizer=optimizer,
            criterion=criterion,
            device=device,
            checkpoint_path=checkpoint_path,
            num_epochs=epochs,
            checkpoint_interval=5
        )
    finally:
        wandb.finish()

    # Save the model if the validation accuracy improved
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_index = i
        torch.save(model.state_dict(), best_model_path_2)

        # Also persist the index and validation accuracy to the JSON file
        with open(best_combination_path_2, "w") as f:
            json.dump({"best_index": best_index, "best_val_accuracy": best_val_accuracy}, f)

        print(f" Best model updated! Combination: {best_index} | Accuracy: {best_val_accuracy:.4f}")
    else:
        print(f"Best combination is {best_index} with val accuracy {best_val_accuracy:.4f}")

    # Mark this combination as completed
    completed_combinations2.append(str(i))
    with open(completed_combinations_path, "w") as f:
        json.dump(completed_combinations2, f)

    print(f" Finished combination {i}")

In [None]:
# Visualize the per-layer sparsity of the computed mask (project plotting helper).
plot_all_layers_mask_sparsity(mask)

In [None]:
# Final run: with valid_split_perc = 0 no training images are held out, and the
# model is evaluated on the test set.
# NOTE(review): this unpacks two values, whereas the earlier call unpacked three —
# assumes get_cifar100 omits the validation set when the split is 0; confirm.
# val_set / val_loader from the earlier cells remain defined but are stale here.
valid_split_perc = 0   # of the 50000 training data
train_set, test_set = get_cifar100(valid_split_perc)

_final_loader_kwargs = dict(batch_size=128, num_workers=2)
train_loader = DataLoader(train_set, shuffle=True, **_final_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **_final_loader_kwargs)

In [None]:
# Same checkpoint locations as earlier, redefined so this final section can run
# on its own after a kernel restart. The folder is created if missing.
checkpoint_dir = "/content/drive/MyDrive/FL/FederatedLearningProject/checkpoints"
best_model_path = os.path.join(checkpoint_dir, "best_model.pth")
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Final training of the edited model on the training set, evaluated on the test set.
learning_rate = 0.01
weight_decay = 0.0001
momentum = 0.9
epochs = 40

# Fresh copy of the base model with the trained weights. map_location lets a
# GPU-saved checkpoint load on a CPU-only runtime.
model = copy.deepcopy(o_model)
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.freeze(0)   # NOTE(review): presumably unfreezes every block — confirm
model.train()

params_to_optimize = model.named_parameters()
optimizer = SparseSGDM(params_to_optimize, masks=mask, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

model_name = "dino_vit-s-16_"
project_name = "BaselineCentralized_editing_Best_Model_Test"
run_name = f"{model_name}_run"

# Deterministic run id reused across restarts; resume="allow" makes wandb resume
# an existing run with this id, or start a new one if none exists.
wandb.init(
    project=project_name,
    name=run_name,
    id=run_name,
    resume="allow",
    config={
        "model": model_name,
        "epochs": epochs,
        "batch_size": train_loader.batch_size,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "momentum": momentum,
        "architecture": model.__class__.__name__,
    }
)

config = wandb.config

checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_run_checkpoint_Test_Best.pth")
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

start_epoch, checkpoint_data = load_checkpoint(model, optimizer, scheduler, run_name)

# wandb.finish() is in `finally` so a crash during training does not leave the
# wandb run open; the exception still propagates.
try:
    test_accuracy = train_and_test(
        start_epoch,
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        scheduler=scheduler,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        checkpoint_path=checkpoint_path,
        num_epochs=epochs,
        checkpoint_interval=5
    )
finally:
    wandb.finish()