In [1]:
import copy
import os

import torch
import numpy as np
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from tqdm import tqdm
import cv2

from utils.options import args_parser
from utils.utils import exp_details, get_train_test, average_weights, get_model
from utils.update import LocalUpdate, test_inference
from utils.sampling import dominant_label_sampling, dirichlet_sampling

In [2]:
args = args_parser(default=True)
args.supervision = True
exp_details(args)

# exp_details reports a random seed, but nothing else in this notebook applies it.
# Seed the global RNGs so client sampling (np.random.choice) and model initialisation
# are reproducible across Restart & Run All.
# NOTE(review): assumes the seed is stored on `args.seed` — confirm against utils.options.
SEED = getattr(args, "seed", None)
if SEED is not None:
    np.random.seed(SEED)
    torch.manual_seed(SEED)


Experimental details:

Reinforcement Arguments:
    Steps Before PPO Update : 256
    PPO Learning Rate       : 0.0003
    PPO Discount Factor     : 0.9
    PPO Batch Size          : 16
    PPO Total Timesteps     : 15000
    Target Accuracy         : 0.95

Federated Arguments:
    Number of Users         : 100
    Fraction of Users       : 0.1
    Local Epochs            : 1
    Local Batch Size        : 10
    Learning Rate           : 0.001
    Momentum                : 0.5
    Optimizer               : adam

Model Arguments:
    Supervision             : True
    Architecture            : cnn

Misc. Arguments:
    Dataset                 : mnist
    Number of GPUs          : 1
    IID                     : 0
    Random Seed             : 1
    Test Fraction           : 1
    Save Path               : ../../save
    Data Path               : ../../data



In [3]:
def _flatten_params(model):
    """Return all of ``model``'s parameters concatenated into one detached 1-D numpy array on CPU."""
    return torch.cat([p.flatten() for p in model.parameters()]).detach().cpu().numpy()


def _snapshot_clients(local_models, dict_users):
    """Flatten the parameters of every client listed in ``dict_users``, in key order."""
    return [_flatten_params(local_models[usr]) for usr in dict_users]


def train_models(dataset, arch, sampling_type, optimal_epochs, client_epochs, device="cuda"):
    """Train a reference ("optimal") model and a FedAvg model, recording parameter snapshots.

    The reference model is trained on the first 50,000 training samples. The federated
    model is trained with FedAvg over non-IID client splits. Federated hyper-parameters
    (``num_users``, ``frac``, ``local_ep``, ``local_bs``, ``lr``, ``optimizer``,
    ``supervision``) are read from the notebook-level ``args`` object.

    Args:
        dataset: Dataset key passed to ``get_train_test`` / ``get_model`` (e.g. "mnist").
        arch: Architecture key passed to ``get_model`` (e.g. "mlp").
        sampling_type: Client data split; "dominant_label" or "dirichlet".
        optimal_epochs: Number of training passes for the reference model.
        client_epochs: Number of federated communication rounds.
        device: Torch device for models, training and evaluation.

    Returns:
        ``(epoch_optimal_params, epoch_client_params)``:
            epoch_optimal_params: list of ``optimal_epochs + 1`` flattened snapshots of the
                reference model, taken before training and after each epoch.
            epoch_client_params: list of ``client_epochs`` lists. Each inner list holds one
                snapshot per client taken after local training (before aggregation),
                followed by one snapshot per client taken after aggregation.

    Raises:
        ValueError: If ``sampling_type`` is not recognised.
    """
    train_dataset, test_dataset = get_train_test(dataset)

    base_model = get_model(arch, dataset, device)
    optimal_model = copy.deepcopy(base_model)
    global_model = copy.deepcopy(base_model)
    local_models = [copy.deepcopy(base_model) for _ in range(args.num_users)]

    if sampling_type == "dominant_label":
        dict_users = dominant_label_sampling(train_dataset, num_users=args.num_users, num_samples=50_000, gamma=0.8, print_labels=False)
    elif sampling_type == "dirichlet":
        dict_users = dirichlet_sampling(train_dataset, num_users=args.num_users, num_samples=30_000, alpha=0.2, print_labels=False)
    else:
        # Fail fast; otherwise `dict_users` is unbound and the run crashes much later.
        raise ValueError(f"Unknown sampling_type {sampling_type!r}; expected 'dominant_label' or 'dirichlet'")

    def make_update(idxs):
        # Every LocalUpdate in this function shares the same hyper-parameters.
        return LocalUpdate(
            train_dataset,
            idxs,
            args.local_ep,
            args.local_bs,
            args.lr,
            args.optimizer,
            args.supervision,
            device,
        )

    print("Training optimal model...")
    epoch_optimal_params = []
    for epoch in range(optimal_epochs + 1):
        # Evaluate and snapshot first, so index 0 is the untrained model.
        acc, _ = test_inference(supervision=True, device=device, model=optimal_model, test_dataset=test_dataset, test_fraction=1)
        print(f"Epoch {epoch}/{optimal_epochs} | Accuracy of optimal model: {acc}")
        epoch_optimal_params.append(_flatten_params(optimal_model))

        if epoch == optimal_epochs:
            break

        make_update(range(50_000)).update_weights(optimal_model)
    print()

    print("Training client fitting models...")
    # FedAvg convention: select at least one client per round.
    n_selected = max(int(args.num_users * args.frac), 1)
    epoch_client_params = []
    for _ in tqdm(range(client_epochs)):
        # Sample without replacement. Otherwise a client can be trained twice in one
        # round and counted twice by average_weights.
        curr_users = np.random.choice(args.num_users, n_selected, replace=False)
        for usr in curr_users:
            make_update(dict_users[usr]).update_weights(local_models[usr])

        # Snapshot after local training, before aggregation, so client drift is visible.
        # Before training, every local model still equals the previous aggregate.
        client_params = _snapshot_clients(local_models, dict_users)

        # Aggregate the selected clients and broadcast the result to every client.
        local_model_weights = [local_models[usr].state_dict() for usr in curr_users]
        avg_weights = average_weights(local_model_weights)
        global_model.load_state_dict(avg_weights)
        for model in local_models:
            model.load_state_dict(avg_weights)

        client_params.extend(_snapshot_clients(local_models, dict_users))
        epoch_client_params.append(client_params)

    return epoch_optimal_params, epoch_client_params

In [4]:
def plot_pca(epoch_optimal_params, epoch_client_params, dataset, image_folder):
    """Project parameter snapshots onto 2 principal components and save one frame per round.

    A single 2-component PCA is fitted on all snapshots together: the reference-model
    trajectory plus every client snapshot. Both therefore share one coordinate system.
    Each saved frame shows:
      - the full reference trajectory, coloured by epoch;
      - the client snapshots of one federated round, in green.
    Axis limits are shared across frames so the frames can be stitched into a stable video.

    Args:
        epoch_optimal_params: Sequence of 1-D parameter arrays for the reference model.
        epoch_client_params: One entry per round, each a sequence of 1-D parameter arrays.
            Every round is assumed to hold the same number of snapshots, which is needed
            to stack the projections into one array.
        dataset: Dataset key, used in the plot title and frame filenames.
        image_folder: Output directory for ``pca_models_{dataset}_epoch_{N}.png`` frames.
            Created if missing.
    """
    print("Plotting principal components...")
    os.makedirs(image_folder, exist_ok=True)

    # Remove this dataset's frames from earlier runs. A previous, longer run would
    # otherwise leave extra epoch frames that end up in the video.
    frame_prefix = f"pca_models_{dataset}_epoch_"
    for name in os.listdir(image_folder):
        if name.startswith(frame_prefix) and name.endswith(".png"):
            os.remove(os.path.join(image_folder, name))

    if len(epoch_client_params) == 0:
        # No federated rounds means no frames to draw.
        return

    # Fit once on the union of both trajectories so they share one PCA basis.
    all_client_snapshots = [params for round_params in epoch_client_params for params in round_params]
    pca = PCA(n_components=2)
    pca.fit(np.vstack([*epoch_optimal_params, *all_client_snapshots]))

    optimal_pca = pca.transform(epoch_optimal_params)
    # Shape (rounds, snapshots_per_round, 2).
    client_pca = np.array([pca.transform(round_params) for round_params in epoch_client_params])

    # Shared limits so the view does not jump between frames.
    xs = np.concatenate([optimal_pca[:, 0], client_pca[:, :, 0].ravel()])
    ys = np.concatenate([optimal_pca[:, 1], client_pca[:, :, 1].ravel()])
    xlim = (xs.min(), xs.max())
    ylim = (ys.min(), ys.max())

    # Fall back to the raw key for datasets without a display name, instead of a NameError.
    dataset_name = {"mnist": "MNIST", "cifar": "CIFAR-10"}.get(dataset, dataset)
    epoch_colors = range(len(epoch_optimal_params))

    for epoch in tqdm(range(len(epoch_client_params))):
        fig, ax = plt.subplots()
        trajectory = ax.scatter(optimal_pca[:, 0], optimal_pca[:, 1], c=epoch_colors, cmap="plasma")
        fig.colorbar(trajectory, ax=ax, label="Optimal-model epoch")

        ax.scatter(client_pca[epoch, :, 0], client_pca[epoch, :, 1], c="g")
        ax.set_title(f"Principal Components of {dataset_name} Models")
        ax.set_xlabel("Principal Component 1")
        ax.set_ylabel("Principal Component 2")
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        fig.savefig(os.path.join(image_folder, f"{frame_prefix}{epoch}.png"))
        # Close each figure explicitly to avoid accumulating open figures in the loop.
        plt.close(fig)

In [5]:
def make_video(dataset, image_folder, video_path, fps=1):
    """Stitch the per-epoch PCA frames for ``dataset`` into a video file.

    Only frames named ``pca_models_{dataset}_epoch_{N}.png`` are used, ordered by N.
    Frames of other datasets in the same folder are therefore not mixed in.

    Args:
        dataset: Dataset key whose frames should be stitched.
        image_folder: Directory containing the frames.
        video_path: Output video path. Its parent directory is created if missing.
        fps: Frames per second of the output video (default 1).

    Raises:
        FileNotFoundError: If no matching frames exist, or the first frame cannot be read.
        IOError: If OpenCV cannot open a writer for ``video_path``.
    """
    print("Generating PCA visualization video...")
    prefix = f"pca_models_{dataset}_epoch_"
    suffix = ".png"
    images = [img for img in os.listdir(image_folder) if img.startswith(prefix) and img.endswith(suffix)]
    if not images:
        raise FileNotFoundError(f"No frames matching '{prefix}*{suffix}' found in {image_folder}")
    # Sort numerically by epoch, so epoch_10 comes after epoch_9.
    images.sort(key=lambda name: int(name[len(prefix):-len(suffix)]))

    frame = cv2.imread(os.path.join(image_folder, images[0]))
    if frame is None:
        raise FileNotFoundError(f"Could not read frame {images[0]} in {image_folder}")
    height, width = frame.shape[:2]

    video_dir = os.path.dirname(video_path)
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)

    # fourcc=0 is kept from the original; the resulting codec is backend-dependent.
    video = cv2.VideoWriter(video_path, 0, fps, (width, height))
    if not video.isOpened():
        # cv2.VideoWriter does not raise on failure; check explicitly.
        raise IOError(f"Could not open video writer for {video_path}")
    try:
        for image in tqdm(images):
            video.write(cv2.imread(os.path.join(image_folder, image)))
    finally:
        video.release()

In [7]:
# Run configuration for this visualisation.
DATASET = "mnist"
ARCH = "mlp"
SAMPLING = "dominant_label"
OPTIMAL_EPOCHS = 10   # training passes for the reference model
CLIENT_EPOCHS = 80    # federated communication rounds (= number of video frames)

image_folder = os.path.join(args.save_path, "pca_visualization/images")
video_path = os.path.join(args.save_path, f"pca_visualization/videos/pca_video_{DATASET}.avi")

# Both output folders must exist: savefig raises on a missing directory,
# and cv2.VideoWriter fails silently.
os.makedirs(image_folder, exist_ok=True)
os.makedirs(os.path.dirname(video_path), exist_ok=True)

epoch_optimal_params, epoch_client_params = train_models(DATASET, ARCH, SAMPLING, OPTIMAL_EPOCHS, CLIENT_EPOCHS)

plot_pca(epoch_optimal_params, epoch_client_params, DATASET, image_folder)
make_video(DATASET, image_folder, video_path)

Plotting principal components...


100%|██████████| 80/80 [00:11<00:00,  6.99it/s]


Generating PCA visualization video...


100%|██████████| 80/80 [00:00<00:00, 347.30it/s]
