In [1]:
"""Script to run the baselines."""
import importlib
from pprint import pprint
import inspect
import numpy as np
import os

# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ['CUDA_VISIBLE_DEVICES'] = "3"
import random
import torch
import wandb
from datetime import datetime

import metrics.writer as metrics_writer
from baseline_constants import (
    MAIN_PARAMS,
    MODEL_PARAMS,
)
from utils.args import parse_args, check_args
from utils.cutout import Cutout
from utils.main_utils import *
from utils.model_utils import read_data, read_public_data

os.environ["WANDB_API_KEY"] = ""
os.environ["WANDB_MODE"] = "offline"

  from .autonotebook import tqdm as notebook_tqdm


In [49]:
from unittest.mock import Mock
args = Mock()
args.seed = 0
args.alpha = 0.0
args.device = "cuda:0"
args.clients_per_round = 5
args.num_rounds = 10
args.model = "destillation"
args.lr = 0.01
args.algorithm = "fedmd"
args.wandb_run_id = None
args.server_opt = "sgd"
args.server_momentum = 0
args.client_algorithm = None
args.dataset = "cifar100"
args.publicdataset = "cifar10"
args.t = 'large'
args.load = False
args.num_workers = 0
args.batch_size = 64

In [13]:
from main import 

SyntaxError: invalid syntax (3444482996.py, line 1)

In [50]:
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# CIFAR: obtain info on parameter alpha (Dirichlet's distribution)
alpha = args.alpha
if alpha is not None:
    alpha = "alpha_{:.2f}".format(alpha)
    print("Alpha:", alpha)

# Setup GPU
device = torch.device(args.device if torch.cuda.is_available else "cpu")
print(
    "Using device:",
    torch.cuda.get_device_name(device) if device != "cpu" else "cpu",
)

run, job_name = init_wandb(args, alpha, run_id=args.wandb_run_id)

# Obtain the path to client's model (e.g. cifar10/cnn.py), client class and servers class
model_path = "%s/%s.py" % (args.dataset, args.model)
dataset_path = "%s/%s.py" % (args.dataset, "dataloader")
server_path = "servers/%s.py" % (args.algorithm + "_server")
paths = [model_path, dataset_path, server_path]
client_sufix = (
    f"{args.client_algorithm}_client"
    if args.client_algorithm is not None
    else "client"
)
client_path = f"clients/{client_sufix}.py"
paths.append(client_path)
check_init_paths(paths)

# Experiment parameters (e.g. num rounds, clients per round, lr, etc)
tup = MAIN_PARAMS[args.dataset][args.t]
num_rounds = args.num_rounds if args.num_rounds != -1 else tup[0]
eval_every = args.eval_every if args.eval_every != -1 else tup[1]
clients_per_round = (
    args.clients_per_round if args.clients_per_round != -1 else tup[2]
)

model_path = "%s.%s" % (args.dataset, args.model)
dataset_path = "%s.%s" % (args.dataset, "dataloader")
server_path = "servers.%s" % (args.algorithm + "_server")
model_params = MODEL_PARAMS[model_path]
if args.lr != -1:
    model_params_list = list(model_params)
    model_params_list[0] = args.lr
    model_params = tuple(model_params_list)

# Load model and dataset
print(f"{'#' * 30} {model_path} {'#' * 30}")

checkpoint = {}
client_models = []
public_models = []
client_path = f"clients.{client_sufix}"
PublicDataset = None
mod = importlib.import_module(model_path)
dataset = importlib.import_module(dataset_path)
ClientDataset = getattr(dataset, "ClientDataset")

Alpha: alpha_0.00
Using device: NVIDIA GeForce GTX 1660 Ti




############################## cifar100.destillation ##############################


In [51]:
if args.model == "destillation":
    publicmodel_path = "%s.%s" % (args.publicdataset, args.model)
    publicdataset_path = "%s.%s" % (args.publicdataset, "dataloader")
    publicdataset = importlib.import_module(publicdataset_path)
    publicmod = importlib.import_module(publicmodel_path)
    PublicDataset = getattr(publicdataset, "ClientDataset")
    print("Running experiment with server", server_path, "and client", client_path)
    client_models = []
    Client, Server = get_clients_and_server(server_path, client_path)
    print("Verify client and server:", Client, Server)
    for model_number in range(5):
        ClientModel = getattr(mod, f"ClientModel{model_number}")
        PublicClientModel = getattr(publicmod, f"ClientModel{model_number}")
        client_model = ClientModel(*model_params, device)
        client_models.append(client_model)
        public_client_model = PublicClientModel(*model_params, device)
        public_models.append(public_client_model)
    assert not args.load, "Not implemented checkpoimws yet"
    client_models = [model.to(device) for model in client_models]
    public_models = [model.to(device) for model in public_models]
else:
    ClientModel = getattr(mod, "ClientModel")
    print("Running experiment with server", server_path, "and client", client_path)
    Client, Server = get_clients_and_server(server_path, client_path)
    # Load client and server
    print("Verify client and server:", Client, Server)
    client_model = ClientModel(*model_params, device)
    if args.load and wandb.run and wandb.run.resumed:  # load model from checkpoint
        [client_model], checkpoint, ckpt_path_resumed = resume_run(
            client_model, args, wandb.run
        )
        if args.restart:  # start new wandb run
            wandb.finish()
            print("Starting new run...")
            run = init_wandb(args, alpha, run_id=None)
    client_model = client_model.to(device)
    client_models.append(client_model)

Running experiment with server servers.fedmd_server and client clients.client
fedmd
Verify client and server: <class 'clients.client.Client'> <class 'servers.fedmd_server.FedMdServer'>


In [52]:
server_params = define_server_params(
    args,
    client_models,
    public_models,
    args.algorithm,
    opt_ckpt=args.load and checkpoint.get("opt_state_dict"),
    PublicDataset=PublicDataset,
)
server = Server(**server_params)

In [None]:
if args.model == "destillation":
    train_clients, _ = setup_clients(
        args,
        client_models,
        public_models,
        Client,
        ClientDataset,
        PublicDataset,
        run,
        device,
    )
    train_client_ids, train_client_num_samples = server.get_clients_info(train_clients)
    print("Clients in Total: %d" % len(train_clients))
    server.set_num_clients(len(train_clients))
else: 
    train_clients, test_clients = setup_clients(
        args,
        client_models,
        public_models,
        Client,
        ClientDataset,
        PublicDataset,
        run,
        device,
    )
    train_client_ids, train_client_num_samples = server.get_clients_info(train_clients)
    test_client_ids, test_client_num_samples = server.get_clients_info(test_clients)
    if set(train_client_ids) == set(test_client_ids):
        print("Clients in Total: %d" % len(train_clients))
    else:
        print(
            f"Clients in Total: {len(train_clients)} training clients and {len(test_clients)} test clients"
        )
    server.set_num_clients(len(train_clients))

In [56]:
args.num_workers = 0
args.batch_size = 64