In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the project's .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import logging
import os
import random
import sys
import pdb
import numpy as np
import torch
import torch.nn as nn
#from fedml_api.model.cv.lenet5 import LeNet5
# sys.path.append('/data/users2/bthapaliya/DistributedFLExperiments/DistributedFL')
# sys.path.append('/data/users2/bthapaliya/DistributedFLExperiments/DistributedFL/fedml_api')

sys.path.insert(0, os.path.abspath("/data/users2/bthapaliya/DistributedFLExperiments/DistributedFL/data/"))
from fedml_api.model.cv.salient_models import AlexNet3D_Dropout, ResNet_l3

from fedml_api.data_preprocessing.cifar100.data_loader import load_partition_data_cifar100
from fedml_api.model.cv.vgg import vgg16, vgg11
from fedml_api.model.cv.cnn_cifar10 import cnn_cifar10, cnn_cifar100
from fedml_api.standalone.DisPFL.dispfl_api import dispflAPI
from fedml_api.standalone.sailentgrads.sailentgrads_api import SailentGradsAPI
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from fedml_api.data_preprocessing.cifar10.data_loader import load_partition_data_cifar10
from fedml_api.data_preprocessing.ABCD.data_loader import load_partition_data_abcd
from fedml_api.data_preprocessing.tiny_imagenet.data_loader import load_partition_data_tiny
from fedml_api.model.cv.resnet import  customized_resnet18, original_resnet18, tiny_resnet18
from fedml_api.standalone.sailentgrads.my_model_trainer import MyModelTrainer

from fedml_api.standalone.sailentgrads.client import Client

In [4]:
class Args:
    """Notebook stand-in for the command-line arguments of the training script.

    Every option is a plain class attribute, so ``Args()`` yields an object that
    can be handed to the project code wherever parsed arguments are expected.
    The exact meaning of most options is defined by the ``fedml_api`` code that
    consumes them, which is not part of this notebook.
    """

    # --- model and data selection ---
    model = '3DCNN'
    dataset = 'ABCD'
    data_dir = '/data/users2/bthapaliya/NeuroimageDistributedFL/DistributedFL'
    partition_method = 'dir'
    partition_alpha = 0.3

    # --- optimisation ---
    batch_size = 16
    client_optimizer = 'sgd'
    lr = 0.001
    lr_decay = 0.998
    wd = 5e-4
    epochs = 2
    momentum = 0

    # --- federation layout ---
    client_num_in_total = 6
    frac = 0.5  # fraction of clients per round; used to derive client_num_per_round
    comm_round = 200
    frequency_of_the_test = 1

    # --- runtime ---
    gpu = 0  # CUDA device index used when CUDA is available
    ci = 0

    # --- sparsity / masking options ---
    dense_ratio = 0.5
    anneal_factor = 0.5
    seed = 1024
    cs = 'v0'
    itersnip_iteration = 1
    # Boolean flag. It was previously the string 'store_true' (an argparse
    # action name pasted as a value); True keeps the same truthiness while
    # being an actual boolean.
    stratified_sampling = True
    active = 1.0
    public_portion = 0
    erk_power_scale = 1
    dis_gradient_check = False
    strict_avg = False
    static = False
    uniform = False
    save_masks = False
    different_initial = False
    record_mask_diff = False
    diff_spa = False
    global_test = False
    tag = "test"
    snip_mask = True

In [5]:
# python main_sailentgrads.py --model 'resnet18' \
# --dataset 'cifar10' \
# --partition_method 'dir' \
# --partition_alpha 0.3 \
# --batch_size 16 \
# --lr 0.1 \
# --lr_decay 0.998 \
# --epochs 5 \
# --dense_ratio 0.1 \
# --client_num_in_total 100 --frac 0.1 \
# --comm_round 500 \
# --seed 2022

In [6]:
# Instantiate the configuration; later cells attach derived fields
# (identity, client_num_per_round) to this instance.
args = Args()

In [7]:
def create_model(args, model_name, class_num):
    """Build the network selected by ``model_name``.

    Args:
        args: Run configuration. Only ``args.dataset`` is read, and only when
            ``model_name`` is ``"resnet18"`` (dataset ``'tiny'`` selects
            ``tiny_resnet18``, anything else ``customized_resnet18``).
        model_name: One of ``"3DCNN"``, ``"cnn_cifar10"``, ``"cnn_cifar100"``,
            ``"resnet18"`` or ``"vgg11"``.
        class_num: Number of outputs, forwarded to the constructors that
            accept it.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
            Previously this case returned ``None`` and only failed later,
            at the first use of the model.
    """
    if model_name == "3DCNN":
        return AlexNet3D_Dropout(num_classes=class_num)
    if model_name == "cnn_cifar10":
        return cnn_cifar10()
    if model_name == "cnn_cifar100":
        return cnn_cifar100()
    if model_name == "resnet18":
        if args.dataset == 'tiny':
            return tiny_resnet18(class_num=class_num)
        return customized_resnet18(class_num=class_num)
    if model_name == "vgg11":
        return vgg11(class_num)
    raise ValueError("Unsupported model name: {!r}".format(model_name))


def custom_model_trainer(args, model, logger):
    """Wrap ``model`` in the project's ``MyModelTrainer``.

    The trainer constructor takes the model first, then the configuration and
    the logger, so the arguments are reordered here.
    """
    trainer = MyModelTrainer(model, args, logger)
    return trainer

def logger_config(log_path, logging_name):
    """Return the logger ``logging_name`` writing bare messages to ``log_path``.

    The log file is opened in ``'w+'`` mode, so it is truncated on every call.
    Calling this again for the same logger and file (for example when a
    notebook cell is re-run) replaces the previous file handler instead of
    stacking a second one, which would duplicate every log line and leak the
    old file handle.

    Args:
        log_path: Path of the log file to (re)create.
        logging_name: Name passed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger`` at DEBUG level.
    """
    logger = logging.getLogger(logging_name)
    logger.setLevel(level=logging.DEBUG)

    # Drop any handler already attached for this exact file; FileHandler
    # stores its target as an absolute path in ``baseFilename``.
    target_file = os.path.abspath(log_path)
    for attached in list(logger.handlers):
        if isinstance(attached, logging.FileHandler) and attached.baseFilename == target_file:
            logger.removeHandler(attached)
            attached.close()

    handler = logging.FileHandler(log_path, mode='w+', encoding='UTF-8')
    handler.setLevel(level=logging.DEBUG)
    handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(handler)
    return logger


In [8]:
def load_data(args, dataset_name):
    """Load ``dataset_name`` and partition it across the configured clients.

    Args:
        args: Run configuration. Reads ``data_dir``, ``partition_method``,
            ``partition_alpha``, ``client_num_in_total`` and ``batch_size``.
        dataset_name: One of ``"ABCD"``, ``"cifar10"``, ``"cifar100"`` or
            ``"tiny"``.

    Returns:
        An eight-element list, in the order produced by the partition loaders:
        ``[train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.

    Raises:
        ValueError: If ``dataset_name`` is not supported (previously this
            surfaced as an ``UnboundLocalError``).

    Side effects:
        ``args.data_dir`` is replaced by the dataset sub-directory, as before.
        The sub-directory is now joined with a path separator, so a base
        directory without a trailing slash no longer yields e.g.
        ``.../DistributedFLABCD``. Because ``args`` is mutated, calling this
        twice with the same ``args`` nests the sub-directory twice.

    Note:
        Relies on the notebook-level ``logger`` being defined before the call.
    """
    # The loader names are only looked up in the matching branch.
    if dataset_name == "ABCD":
        subdir, loader = "ABCD", load_partition_data_abcd
    elif dataset_name == "cifar10":
        subdir, loader = "cifar10", load_partition_data_cifar10
    elif dataset_name == "cifar100":
        subdir, loader = "cifar100", load_partition_data_cifar100
    elif dataset_name == "tiny":
        subdir, loader = "tiny_imagenet", load_partition_data_tiny
    else:
        raise ValueError("Unsupported dataset name: {!r}".format(dataset_name))

    args.data_dir = os.path.join(args.data_dir, subdir)

    (train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num) = loader(args.data_dir, args.partition_method,
                         args.partition_alpha, args.client_num_in_total,
                         args.batch_size, logger)

    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num]

In [9]:
# Seed every RNG the notebook uses. args.seed is embedded in the run identity
# below, so the run should actually be driven by it; without this, a
# Restart & Run All gives different initial weights and batch order each time.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

print("torch version{}".format(torch.__version__))
device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")

# Partition tag: the method name, plus the alpha value for non-"homo" splits.
data_partition = args.partition_method
if data_partition != "homo":
    data_partition += str(args.partition_alpha)

# Number of clients per round, derived from the participation fraction.
args.client_num_per_round = int(args.client_num_in_total * args.frac)

# Human-readable run identifier; a later cell uses it as the logger name.
args.identity = "".join([
    "SailentGrads", "-", args.dataset, "-", data_partition,
    "-mdl", args.model, "customized", "lowbatch",
    "-cs", args.cs,
    "-cm", str(args.comm_round),
    "-total_clnt", str(args.client_num_in_total),
    "-neighbor", str(args.client_num_per_round),
    "-dr", str(args.dense_ratio),
    "-active", str(args.active),
    '-seed', str(args.seed),
    '-batchsize', str(args.batch_size),
    '-iteration', str(args.itersnip_iteration),
    '-stratified', str(args.stratified_sampling),
])


torch version1.12.1+cu116


In [10]:
# NOTE(review): log_path is not the file the logger writes to in this cell
# (that is 'LOG/<dataset>.log'); it is kept in case other cells read it.
log_path = os.path.join('./', 'LOG/test' + '.log')
main_log_path = os.path.join('LOG/' + args.dataset)
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
# Creating 'LOG/<dataset>' also guarantees that 'LOG/' exists for the log file.
os.makedirs(main_log_path, exist_ok=True)
logger = logger_config(log_path='LOG/' + args.dataset + '.log', logging_name=args.identity)

# Args defines no __repr__, so logging the object itself would only record its
# memory address. Log the actual settings (class-level and instance-level).
logger.info({name: getattr(args, name) for name in dir(args) if not name.startswith('_')})
logger.info(device)

In [11]:
# dataset = load_data(args, "ABCD")

In [12]:
# with open('abcd.pickle', 'wb') as handle:
#     pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [13]:
# Load the cached dataset bundle instead of re-partitioning the data.
# 'abcd.pickle' is the file the commented-out cells above would write
# (pickle.dump of the result of load_data(args, "ABCD")).
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load a file you produced yourself.
with open('abcd.pickle', 'rb') as handle:
    dataset = pickle.load(handle)

In [14]:
# Build the network with a single output unit, place it on `device`, and wrap
# it in the project trainer. The model's repr is written to the log file.
model = create_model(args, model_name=args.model, class_num=1).to(device)
model_trainer = custom_model_trainer(args, model, logger)
logger.info(model)

In [15]:
def setup_clients(train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer):
    """Create one ``Client`` per configured client and add it to ``client_list``.

    Client indices run from 0 to ``args.client_num_in_total - 1``; each client
    gets its own entries of the three per-client dictionaries and the shared
    ``model_trainer``.

    Relies on the notebook globals ``args``, ``device``, ``logger`` and
    ``client_list`` (the list is appended to, not returned).
    """
    client_list.extend(
        Client(idx, train_data_local_dict[idx], test_data_local_dict[idx],
               train_data_local_num_dict[idx], args, device, model_trainer, logger)
        for idx in range(args.client_num_in_total)
    )

In [16]:
# Unpack the eight-element dataset bundle loaded from the pickle.
(
    train_data_num, test_data_num, train_data_global, test_data_global,
    train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
    class_counts,
) = dataset

# Reset the list first: setup_clients appends to this global, so starting
# from an empty list keeps the cell safe to re-run.
client_list = []
setup_clients(train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer)

In [17]:
# Pick the client at index 1 and grab its local data loaders and model.
# NOTE(review): `model` is rebound here. Every Client was constructed with the
# same model_trainer, so this is presumably the network created earlier —
# confirm that Client does not copy the trainer or model.
clnt = client_list[1]
clnt_data = clnt.local_training_data
clnt_testdata = clnt.local_test_data
model = clnt.model_trainer.model

In [18]:
# Pull one batch from the client's training loader. Each batch is a 3-tuple;
# the third element (z) is not used by the cells shown here.
x,labels,z = next(iter(clnt_data))

In [19]:
# Prepare the single batch for a forward pass.
# NOTE: this cell is not idempotent — it rebinds `x`, so re-running it inserts
# another axis at dim 1. Re-run the previous cell first to get a fresh batch.
x = x.to(device)  # move the inputs to the target device
x = x.unsqueeze(1)  # insert a new axis at dim 1 before feeding the model

# `x` is already on `device`; only `labels` is actually moved here.
x, labels = x.to(device), labels.to(device)
model.zero_grad()

In [20]:
# Forward pass on the batch prepared above; `output` is not used by the later
# cells shown here.
output = model(x)

In [21]:
# NOTE(review): train() below never reads args.epochs and is called with an
# explicit epochs=100, so this only matters to project code that reads
# args.epochs.
args.epochs=40

In [22]:
# Plain SGD over the trainable parameters only. The learning rate is fixed at
# 0.01 here rather than taken from args.lr; momentum and weight decay come
# from the configuration.
optimizer = torch.optim.SGD(
    (param for param in model.parameters() if param.requires_grad),
    lr=0.01,
    momentum=args.momentum,
    weight_decay=args.wd,
)

In [23]:
def train(model, optimizer, train_data,  device,  args, round, epochs):
    """Train ``model`` on ``train_data`` with a BCE-with-logits loss.

    Args:
        model: Network to train. It is moved to ``device`` and put in training
            mode; each input batch gets an extra axis inserted at dim 1 before
            being passed to it.
        optimizer: Optimizer stepped once per batch.
        train_data: Iterable of ``(inputs, labels, extra)`` batches; the third
            element is ignored.
        device: Device on which training runs.
        args: Unused; kept so existing calls keep working.
        round: Unused; kept so existing calls keep working.
        epochs: Number of passes over ``train_data``.

    Returns:
        None. The mean batch loss and the current learning rate of the first
        parameter group are printed after every epoch. An epoch that yields no
        batches reports ``nan`` instead of raising ``ZeroDivisionError``.
    """
    model.to(device)
    model.train()
    criterion = nn.BCEWithLogitsLoss().to(device)

    for epoch in range(epochs):
        batch_losses = []
        for x, labels, _ in train_data:
            # Move to the device, then insert the extra axis at dim 1.
            x = x.to(device).unsqueeze(1)
            labels = labels.to(device)

            model.zero_grad()
            # Call the module rather than .forward() so registered hooks run.
            logits = model(x)
            # Labels are reshaped to (batch, 1) floats for the loss.
            loss = criterion(logits, labels.unsqueeze(1).float())
            loss.backward()
            # to avoid nan loss
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
            optimizer.step()
            batch_losses.append(loss.item())

        # Guard against an empty loader.
        mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        print('Epoch: {}\tLoss: {:.6f}\t LR: {:.4f}'.format(epoch, mean_loss, optimizer.param_groups[0]['lr']))

In [24]:
# Train on the selected client's local training data for 100 epochs; the
# epoch count is given explicitly and args.epochs is not consulted by train().
train(model, optimizer, clnt_data, device, args, round=0, epochs=100)

Epoch: 0	Loss: 0.828396	 LR: 0.0100
Epoch: 1	Loss: 0.645710	 LR: 0.0100
Epoch: 2	Loss: 0.506960	 LR: 0.0100
Epoch: 3	Loss: 0.911434	 LR: 0.0100
Epoch: 4	Loss: 0.714987	 LR: 0.0100
Epoch: 5	Loss: 0.841519	 LR: 0.0100
Epoch: 6	Loss: 0.645194	 LR: 0.0100
Epoch: 7	Loss: 0.714814	 LR: 0.0100
Epoch: 8	Loss: 0.612401	 LR: 0.0100
Epoch: 9	Loss: 0.623211	 LR: 0.0100
Epoch: 10	Loss: 0.759901	 LR: 0.0100
Epoch: 11	Loss: 0.676069	 LR: 0.0100
Epoch: 12	Loss: 0.605578	 LR: 0.0100
Epoch: 13	Loss: 0.586054	 LR: 0.0100
Epoch: 14	Loss: 0.625682	 LR: 0.0100
Epoch: 15	Loss: 0.524836	 LR: 0.0100
Epoch: 16	Loss: 0.664885	 LR: 0.0100
Epoch: 17	Loss: 0.815500	 LR: 0.0100
Epoch: 18	Loss: 0.539739	 LR: 0.0100
Epoch: 19	Loss: 0.593757	 LR: 0.0100
Epoch: 20	Loss: 0.740088	 LR: 0.0100
Epoch: 21	Loss: 0.673250	 LR: 0.0100
Epoch: 22	Loss: 0.556431	 LR: 0.0100
Epoch: 23	Loss: 0.505106	 LR: 0.0100
Epoch: 24	Loss: 0.545667	 LR: 0.0100
Epoch: 25	Loss: 0.669968	 LR: 0.0100
Epoch: 26	Loss: 0.453069	 LR: 0.0100
Epoch: 27	L

In [25]:
# Lower the learning rate of the existing optimizer in place (first parameter
# group), then echo the value to confirm the change.
optimizer.param_groups[0]['lr'] = 0.001
optimizer.param_groups[0]['lr']

0.001

In [26]:
def test(model, test_data, device):
    model.to(device)
    model.eval()

    metrics = {
        'test_correct': 0,
        'test_acc':0.0,
        'test_loss': 0,
        'test_total': 0
    }

    criterion = nn.BCEWithLogitsLoss().to(device)

    with torch.no_grad():
        #for batch_idx, (x, target) in enumerate(test_data):
        for x, target, _ in test_data:
            #For 3DConv Network
            #x = torch.tensor(x, dtype=torch.float32)  # Convert to tensor
            x = x.to(device)  # Convert to tensor
            x = x.unsqueeze(1)

            #x = x.to(device)
            target = target.to(device)
            pred = model(x)
            loss = criterion(pred, target.unsqueeze(1).float())

            _, predicted = torch.max(pred, -1)
            correct = predicted.eq(target).sum()

            metrics['test_correct'] += correct.item()
            metrics['test_loss'] += loss.item() * target.size(0)
            metrics['test_total'] += target.size(0)
            metrics['test_acc'] = metrics['test_correct'] / metrics['test_total']
    return metrics

In [27]:
# Evaluate the trained model on the same client's local test data; the
# returned metrics dict is displayed as the cell output.
test(model, clnt_testdata, device)

{'test_correct': 5,
 'test_acc': 0.625,
 'test_loss': 19.62701416015625,
 'test_total': 8}

In [28]:
# Print the index of every batch in the client's test loader.
# The bare `data` inside the loop displays nothing (only a cell's last
# top-level expression is echoed). After the loop, `data` still holds the
# final batch, which the next cell inspects.
for i, data in enumerate(clnt_testdata):
    print(i)
    data

0


In [29]:
# Shape of the inputs in the last test batch; relies on the loop variable
# `data` left over from the previous cell.
data[0].shape

torch.Size([8, 121, 145, 121])

In [30]:
# Display the constructed Client objects (one per configured client).
client_list

[<fedml_api.standalone.sailentgrads.client.Client at 0x7f661ea59450>,
 <fedml_api.standalone.sailentgrads.client.Client at 0x7f661ea5b640>,
 <fedml_api.standalone.sailentgrads.client.Client at 0x7f661ea58910>,
 <fedml_api.standalone.sailentgrads.client.Client at 0x7f661ea5ae60>,
 <fedml_api.standalone.sailentgrads.client.Client at 0x7f661ea59360>,
 <fedml_api.standalone.sailentgrads.client.Client at 0x7f661ea5a0e0>]