In [5]:
import os
import numpy as np
import torch

In [11]:
from Config import parse_arguments
from Models.model_utils import FEDMD_digest_revisit, save_checkpoints, select_model, get_logits, create_model_architectures_noblip
from Common.dataset_fedmd import FEDMD_load_dataloaders_noblip

In [12]:
args = parse_arguments()

# args.seed is parsed but no cell in this notebook seeds the global RNGs; do it
# here so model initialisation (select_model) and training are repeatable across
# Restart & Run All. torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [13]:
# Display the parsed run configuration used by all later cells.
args

Namespace(sample_fraction=1.0, min_num_clients=2, KD=False, algo_type='FEDMD_homogen', available_architectures_noblip=['NetS', 'DeepNN_Ram', 'DeepNN_Hanu', 'DeepNN_Lax'], available_architectures_blip=['MLP1', 'MLP2', 'MLP3', 'MLP4'], name='cifar100', partitioning='dirichlet', num_clients=2, num_rounds=2, BLIP=False, proximal_mu=0.02, scheduler_=False, val_ratio=0.2, seed=42, alpha=1.0, labels_per_client=2, similarity=0.5, batch_size=32, batch_size_ratio=0.01, server_epochs=5, num_cpus=1, num_gpus=1, device='cuda')

In [14]:
def fit_config(server_round: int, proximal_mu=None):
    """Return the training configuration dict for a given round.

    Args:
        server_round: Index of the current round; stored under ``"server_round"``.
        proximal_mu: Value stored under ``"proximal_mu"`` (presumably the
            FedProx proximal-term weight — confirm against the training helper).
            When ``None`` (the default) it falls back to the notebook-level
            ``args.proximal_mu``, so existing calls behave exactly as before.
            Pass it explicitly to avoid the hidden dependency on global ``args``.

    Returns:
        dict: The config passed to the training helpers (e.g. ``FEDMD_digest_revisit``).
        Apart from ``server_round`` and ``proximal_mu``, every value is a fixed
        constant: local epochs, client/server KD alpha and temperature, and the
        server learning rate.
    """
    if proximal_mu is None:
        # Resolved lazily so the function is callable without a global ``args``
        # whenever the caller supplies proximal_mu.
        proximal_mu = args.proximal_mu

    return {
        "server_round": server_round,
        "epochs": 2,
        "proximal_mu": proximal_mu,
        "client_kd_alpha": 0.3,
        "client_kd_temperature": 3,
        "server_kd_alpha": 0.7,
        "server_kd_temperature": 7,
        "server_lr": 0.001,
    }

In [18]:
    
def trasfer_learning_init(args, model_architectures, serverloader, trainloaders, num_classes=100):
    """
    Warm-start one model per client before the federated rounds begin.

    For each client ``i`` (in ``trainloaders`` order) this:
      1. builds ``model_architectures[i]`` via ``select_model``;
      2. trains it on the public/server data (``serverloader``), then on the
         client's private data (``trainloaders[i]``), sharing one Adam optimizer
         (lr=1e-3) across both phases;
      3. records the model's logits on the public data (the round-1 communication);
      4. checkpoints model and optimizer under
         ``client_checkpoints/<algo_type>/<alpha>/client_<i>``.

    Args:
        args: Parsed run configuration; reads ``device``, ``BLIP``, ``algo_type``
            and ``alpha``.
        model_architectures: Indexable by client index ``0..len(trainloaders)-1``,
            giving the architecture name passed to ``select_model``.
        serverloader: Loader over the public data shared by all clients.
        trainloaders: One private loader per client.
        num_classes: Output size of every model. Defaults to 100 (the previously
            hard-coded value, matching the CIFAR-100 run in this notebook).

    Returns:
        np.ndarray: Per-client public-data logits stacked along axis 0 in client
        order (client 0 first). Assuming ``get_logits`` returns one row per public
        sample, the shape is ``(num_clients * num_public_samples, num_classes)``.

    Raises:
        ValueError: If ``trainloaders`` is empty (there would be nothing to stack).
    """

    config = fit_config(server_round=0)
    # Kept as a list: appending in enumeration order already yields the
    # client-major layout that the averaging step reshapes on.
    client_logits = []

    for i, trainloader in enumerate(trainloaders):
        print(f"\n Transfer learning on client {i + 1}\n")

        net = select_model(model_architectures[i], args.device, args.BLIP, num_classes=num_classes)
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
        # Phase 1: train on the shared public data (no aggregated logits exist yet).
        FEDMD_digest_revisit(serverloader, net, optimizer, args.device, config, aggregated_logits=None, mode="revisit")
        # Phase 2: continue training the same model/optimizer on this client's private data.
        FEDMD_digest_revisit(trainloader, net, optimizer, args.device, config, aggregated_logits=None, mode="revisit")
        # Logits on the public data are what each client communicates for round 1.
        client_logits.append(get_logits(serverloader, net, args.device))
        # Checkpoint model + optimizer state for this client (round 0).
        checkpoint_dir = f"client_checkpoints/{args.algo_type}/{args.alpha}/client_{i}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        save_checkpoints(net, optimizer, 0, checkpoint_dir)

    if not client_logits:
        # np.vstack([]) would fail with an opaque "need at least one array" error.
        raise ValueError("trainloaders is empty: there are no client models to initialise.")

    return np.vstack(client_logits)

In [19]:
# Build the dataloaders (names suggest per-client train/val loaders, a test loader and
# the public/server loader) and the per-client architecture assignment.
trainloaders, valloaders, testloader, server_loader = FEDMD_load_dataloaders_noblip(args)
results_save_path = f"./results/NO_BLIP/{args.algo_type}/{args.alpha}"
model_architectures = create_model_architectures_noblip(args.available_architectures_noblip, args.algo_type, args.num_clients)

# trasfer_learning_init indexes model_architectures[i] for every train loader, so the
# two must line up one-to-one; fail here rather than mid-way through training.
assert len(model_architectures) == len(trainloaders), (
    f"{len(model_architectures)} architectures vs {len(trainloaders)} client train loaders"
)

In [20]:
# Inspect the architecture name assigned to each client index.
model_architectures

{0: 'NetS', 1: 'NetS'}

In [21]:
# Warm-start every client (public data first, then its private data) and gather the
# round-1 public-data logits, stacked client by client.
logits = trasfer_learning_init(
    args,
    model_architectures,
    serverloader=server_loader,
    trainloaders=trainloaders,
)


 Transfer learning on client 1

 Round 0 --> Epoch 1/2 	 Loss: 0.0656 	 Accuracy: 27.96%
 Round 0 --> Epoch 2/2 	 Loss: 0.0546 	 Accuracy: 40.04%
 Round 0 --> Epoch 1/2 	 Loss: 0.1243 	 Accuracy: 11.63%
 Round 0 --> Epoch 2/2 	 Loss: 0.1074 	 Accuracy: 21.39%

 Transfer learning on client 2

 Round 0 --> Epoch 1/2 	 Loss: 0.0669 	 Accuracy: 27.04%
 Round 0 --> Epoch 2/2 	 Loss: 0.0549 	 Accuracy: 39.44%
 Round 0 --> Epoch 1/2 	 Loss: 0.1253 	 Accuracy: 11.01%
 Round 0 --> Epoch 2/2 	 Loss: 0.1095 	 Accuracy: 19.60%


In [26]:
def average_client_logits(stacked_array: np.ndarray, weighted: bool = False, weights: np.ndarray = None, num_clients: int = 2, num_samples: int = None):
    """
    Averages logits from multiple clients, sample by sample.

    The input must be client-major: rows ``[k * num_samples, (k + 1) * num_samples)``
    belong to client ``k``. This is the layout ``np.vstack`` produces over a list of
    per-client logit arrays.

    Args:
        stacked_array (np.ndarray): Stacked logits of shape (num_clients * num_samples, num_classes).
        weighted (bool): Whether to perform weighted averaging. Default is False.
        weights (array-like): Length-``num_clients`` client weights. Required if
            weighted=True; must not sum to zero.
        num_clients (int): Number of stacked client blocks. Default is 2.
        num_samples (int, optional): Rows per client. When None (default) it is
            inferred as ``stacked_array.shape[0] // num_clients``. This gives the same
            result as the old hard-coded 5000 whenever that value was valid.

    Returns:
        np.ndarray: Averaged logits of shape (num_samples, num_classes).

    Raises:
        ValueError: If ``stacked_array`` is not 2-D, ``num_clients`` is not positive,
            the row count is not ``num_clients * num_samples``, or ``weighted=True``
            with missing, mismatched or zero-sum weights.
    """
    if stacked_array.ndim != 2:
        raise ValueError(f"stacked_array must be 2-D (rows, num_classes); got shape {tuple(stacked_array.shape)}.")
    if num_clients <= 0:
        raise ValueError(f"num_clients must be positive; got {num_clients}.")

    num_rows, num_classes = stacked_array.shape
    if num_samples is None:
        num_samples = num_rows // num_clients
    if num_rows != num_clients * num_samples:
        # Catch misaligned input here instead of letting reshape fail cryptically
        # (e.g. rows not divisible by num_clients).
        raise ValueError(
            f"stacked_array has {num_rows} rows; expected num_clients * num_samples = "
            f"{num_clients} * {num_samples} = {num_clients * num_samples}."
        )

    # (num_clients, num_samples, num_classes): axis 0 indexes the client.
    reshaped = stacked_array.reshape(num_clients, num_samples, num_classes)

    if not weighted:
        return reshaped.mean(axis=0)

    if weights is None or len(weights) != num_clients:
        raise ValueError("Weights must be provided and match number of clients.")
    weights = np.asarray(weights)
    total_weight = np.sum(weights)
    if total_weight == 0:
        # Would otherwise silently yield NaN/inf logits.
        raise ValueError("Weights must not sum to zero.")
    # Broadcast one scalar weight per client over its (num_samples, num_classes) block.
    weights = weights.reshape(num_clients, 1, 1)
    return np.sum(reshaped * weights, axis=0) / total_weight

In [27]:
# Stacked client logits; expected (num_clients * num_public_samples, num_classes).
logits.shape

(10000, 100)

In [28]:
# Collapse the client-major stacked logits into one consensus row per public sample.
# Pass the layout explicitly: the helper's old defaults (num_clients=2, num_samples=5000)
# only matched this particular run. trasfer_learning_init stacks one block per train loader.
averaged = average_client_logits(
    logits,
    num_clients=len(trainloaders),
    num_samples=logits.shape[0] // len(trainloaders),
)

In [29]:
# Consensus logits: expected (num_public_samples, num_classes).
averaged.shape

(5000, 100)