In [1]:
import os
from functools import wraps
from collections import defaultdict
from tqdm import tqdm

import numpy as np
# import matplotlib
# import matplotlib.pyplot as plt
import copy
import random
import time
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils, datasets
from argparse import ArgumentParser
from torchvision import transforms as tt
from PIL import Image
from utils import AverageMeter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from utils.utils import *
# from utils.training import *
# from utils.training_batch import *
from utils.model import *
from utils.BYOL_models import *

In [3]:
# Global RNG seed reused by the seeding calls in the next cell (Python, NumPy, PyTorch).
seed: int = 1234

In [4]:
# general reproducibility: seed Python's, NumPy's and PyTorch's global RNGs with the same value.
# The trailing ';' hides the Generator repr that torch.manual_seed returns.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

<torch._C.Generator at 0x7f2185d0f190>

In [5]:
# GPU-specific determinism: disable cuDNN benchmark-mode algorithm search and
# request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [6]:
# Data configuration.
data_path = "./data"
# dataset: one of 'mnist', 'femnist', 'fmnist', 'cifar10', 'cifar100', 'svhn'
dataset = "cifar10"
# partition: one of 'noniid-labeldir', 'iid', 'default' ('default' only for femnist)
partition = "iid"
# test batch size passed to get_dataloader
test_batch = 128

In [7]:
# Hyperparameters_List (H) = [rounds, client_fraction, number_of_clients, number_of_training_rounds_local, local_batch_size, lr_client]

stop_gradient = True
has_predictor = True
OneLayer = "1_layer"
TwoLayer = "2_layer"
predictor_network = TwoLayer
global_epochs = 1000
# global_epochs = 3
client_fraction = 1.0
client_num = 5
local_epoch = 5
batch_size = 32
lr = 3e-4
# NOTE: `partition` is configured once, in the data-configuration cell above.
# Re-assigning it here would silently override that setting (it is read by
# partition_data, save_path and plot_str).
norm = 'bn'
alpha_partition = 0.5
sch_flag = False
iid = False
# 8 batch_size avg_freq = 40(25)
# 16 batch_size avg_freq = 20(25)
# 32 batch_size avg_freq = 2(125)
# 32 batch_size avg_freq = 5(50)
# 32 batch_size avg_freq = 10(25)
# 32 batch_size avg_freq = 25(10)
# 32 batch_size avg_freq = 50(5)
# 64 batch_size avg_freq = 5(25)
# 128 batch_size avg_freq = 3(21)
avg_freq = 10


data_portion = 1.0
noniid_ratio = 1.0
# noniid_ratio = 0.55

# save_path = f"./model/SplitFSSLMaxpool_resnet18/resnet18Maxpooling_cifar10_{batch_size}_{noniid_ratio}_{client_num}"
save_path = f"./model/SplitFSSL_BYOL_Avg25times/resnet18Maxpooling_cifar10_{batch_size}_{avg_freq}_{partition}_{client_num}"
# save_path = f"./model/SplitFSSL_BYOL32_DifAvgtimes/resnet18Maxpooling_cifar10_{batch_size}_{avg_freq}_{noniid_ratio}_{client_num}"
H = [global_epochs, client_fraction, client_num, local_epoch, batch_size, lr]

In [8]:
def save_checkpoint(state, checkpoint, filename= 'checkpoint.pth.tar'):
    """Serialize ``state`` with ``torch.save`` to ``<checkpoint>/<filename>``.

    The ``checkpoint`` directory is created if it does not exist yet. ``state``
    must contain a ``"glepoch"`` entry, which is echoed to stdout.
    """
    os.makedirs(checkpoint, exist_ok=True)
    torch.save(state, os.path.join(checkpoint, filename))
    print(f'global epoch {state["glepoch"]} saved')

In [9]:

# partition
net_dataidx_map, net_dataidx_map_test, traindata_cls_counts, testdata_cls_counts = partition_data(dataset, data_path, partition, client_num)

# get dataloader: one train/test DataLoader per client
train_loader_list = []
test_loader_list = []
for idx in range(client_num):
    dataidxs = net_dataidx_map[idx]
    # `dataidxs_test` must be assigned on both branches; otherwise it is undefined
    # (first client) or silently reused from the previous client.
    dataidxs_test = None if net_dataidx_map_test is None else net_dataidx_map_test[idx]

    train_dl_local, test_dl_local, train_ds_local, test_ds_local = get_dataloader(
        dataset, data_path, batch_size, test_batch, dataidxs, dataidxs_test)
    train_loader_list.append(train_dl_local)
    test_loader_list.append(test_dl_local)

Files already downloaded and verified
Files already downloaded and verified
partition: iid
Data statistics Train: {0: {0: 987, 1: 1003, 2: 982, 3: 1057, 4: 966, 5: 950, 6: 1009, 7: 1015, 8: 1027, 9: 1004}, 1: {0: 1002, 1: 999, 2: 982, 3: 1020, 4: 1035, 5: 992, 6: 1067, 7: 923, 8: 997, 9: 983}, 2: {0: 1009, 1: 1012, 2: 1037, 3: 974, 4: 998, 5: 1004, 6: 960, 7: 997, 8: 1003, 9: 1006}, 3: {0: 1030, 1: 938, 2: 978, 3: 993, 4: 1050, 5: 1032, 6: 969, 7: 1021, 8: 964, 9: 1025}, 4: {0: 972, 1: 1048, 2: 1021, 3: 956, 4: 951, 5: 1022, 6: 995, 7: 1044, 8: 1009, 9: 982}}
Data statistics Test:
 {0: {0: 1000, 1: 1000, 2: 1000, 3: 1000, 4: 1000, 5: 1000, 6: 1000, 7: 1000, 8: 1000, 9: 1000}, 1: {0: 1000, 1: 1000, 2: 1000, 3: 1000, 4: 1000, 5: 1000, 6: 1000, 7: 1000, 8: 1000, 9: 1000}, 2: {0: 1000, 1: 1000, 2: 1000, 3: 1000, 4: 1000, 5: 1000, 6: 1000, 7: 1000, 8: 1000, 9: 1000}, 3: {0: 1000, 1: 1000, 2: 1000, 3: 1000, 4: 1000, 5: 1000, 6: 1000, 7: 1000, 8: 1000, 9: 1000}, 4: {0: 1000, 1: 1000, 2: 1000,

In [10]:
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv + BatchNorm layers plus a residual connection.

    ``shortcut`` is an empty ``nn.Sequential`` (identity) unless the stride or the
    channel count changes. In that case it is a 1x1 conv + BatchNorm projection.
    """

    # output channels = expansion * planes
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        projection_layers = (
            [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
            if needs_projection
            else []
        )
        self.shortcut = nn.Sequential(*projection_layers)

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(residual + self.shortcut(x))


class ResNet(nn.Module):
    """CIFAR-style ResNet backbone.

    Differences from the official pytorch version:
    1. conv1 kernel size: 3x3 with stride 1 here (pytorch version uses kernel_size=7).
    2. average pooling: a fixed 4x4 ``AvgPool2d`` here (pytorch version uses
       AdaptiveAvgPool). ``fc`` therefore only receives ``512 * expansion`` features
       when layer4 outputs a 4x4 map, e.g. for 32x32 inputs.
    3. the max-pool after conv1 uses stride 1, so it does not downsample.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.feature_dim = 512 * block.expansion

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # after conv1 do max pooling
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        # layer1..layer4, registered in order; _make_layer advances self.in_planes
        stage_specs = zip((64, 128, 256, 512), (1, 2, 2, 2))
        for stage_idx, (planes, first_stride) in enumerate(stage_specs, start=1):
            stage = self._make_layer(block, planes, num_blocks[stage_idx - 1], first_stride)
            setattr(self, f"layer{stage_idx}", stage)
        self.avgpool = nn.AvgPool2d((4, 4))
        self.fc = nn.Linear(self.feature_dim, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        features = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.fc(pooled.view(pooled.size(0), -1))


def ResNet18(num_classes=10):
    """Build the 18-layer CIFAR ResNet: two BasicBlocks in each of the four stages."""
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)

In [11]:
class BaseModel(nn.Module):
    """Empty ``nn.Module`` base class (no parameters, no forward)."""

    def __init__(self):
        super().__init__()

In [12]:
# augmentation utils
class RandomApply(nn.Module):
    """Apply ``fn`` with probability ``p``; otherwise return the input unchanged.

    The coin flip uses Python's ``random`` module, so it follows ``random.seed``.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        should_apply = random.random() <= self.p
        return self.fn(x) if should_apply else x

In [13]:
class EMA:
    """Exponential moving average helper: ``old * beta + (1 - beta) * new``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``new`` into ``old``; if there is no ``old`` yet, return ``new`` as-is."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

In [14]:
def update_moving_average(ema_updater, ma_model, current_model):
    """Move each parameter of ``ma_model`` toward the matching one of ``current_model``.

    Parameters are paired positionally through ``.parameters()``. Each ``ma_model``
    parameter's ``.data`` is replaced by ``ema_updater.update_average(old, new)``.
    Only parameters are touched; buffers are not.
    """
    paired = zip(ma_model.parameters(), current_model.parameters())
    for target_param, online_param in paired:
        target_param.data = ema_updater.update_average(target_param.data, online_param.data)


def byol_loss_fn(x, y):
    """Per-sample BYOL regression loss ``2 - 2 * cos(x, y)`` over the last dimension.

    Both inputs are L2-normalised along ``dim=-1``; the result lies in [0, 4].
    """
    x_unit, y_unit = (F.normalize(t, dim=-1, p=2) for t in (x, y))
    cosine = torch.sum(x_unit * y_unit, dim=-1)
    return 2 - 2 * cosine

In [15]:
class MLP(nn.Module):
    """Projection / prediction head.

    * ``num_layer == OneLayer``: ``Linear(dim -> projection_size)``.
    * ``num_layer == TwoLayer``: ``Linear(dim -> hidden_size) -> BatchNorm1d ->
      ReLU -> Linear(hidden_size -> projection_size)``.
    * anything else raises ``NotImplementedError``.
    """

    def __init__(self, dim, projection_size, hidden_size=4096, num_layer=TwoLayer):
        super().__init__()
        self.in_features = dim
        if num_layer == OneLayer:
            layers = [nn.Linear(dim, projection_size)]
        elif num_layer == TwoLayer:
            layers = [
                nn.Linear(dim, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_size, projection_size),
            ]
        else:
            raise NotImplementedError(f"Not defined MLP: {num_layer}")
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

In [16]:
#client model
class BYOL_Client(nn.Module):
    """Client half of split BYOL: online encoder + projector and an EMA target encoder.

    The backbone's ``fc`` head is replaced by an ``MLP`` projector. The target
    encoder is created lazily, as a deep copy of the online encoder, on the first
    forward pass. After that it is only changed by ``update_moving_average``.

    Args:
        net: backbone module exposing ``fc`` (and optionally ``feature_dim``).
            Defaults to a fresh ``ResNet18()`` per instance.
        image_size: only referenced by the commented-out debug forward.
        projection_size: output size of the projector MLP.
        projection_hidden_size: hidden size of the projector MLP.
        moving_average_decay: EMA ``beta`` for the target encoder.
        stop_gradient: if True, target projections are computed under
            ``torch.no_grad()`` and detached.
        has_predictor: stored only; the predictor is not part of this module.
        predictor_network: accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        net=None,
        image_size=32,
        projection_size=2048,
        projection_hidden_size=4096,
        moving_average_decay=0.99,
        stop_gradient=True,
        has_predictor=True,
        predictor_network=TwoLayer,
    ):
        super().__init__()

        if net is None:
            # Build per instance: a `net=ResNet18()` default would be evaluated once at
            # class-definition time and shared (and mutated via `.fc`) by every instance.
            net = ResNet18()
        self.online_encoder = net
        if not hasattr(net, 'feature_dim'):
            feature_dim = list(net.children())[-1].in_features
        else:
            feature_dim = net.feature_dim
        self.online_encoder.fc = MLP(feature_dim, projection_size, projection_hidden_size)  # projector

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.stop_gradient = stop_gradient
        self.has_predictor = has_predictor

        # debug purpose
        # self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))
        # self.reset_moving_average()

    def _get_target_encoder(self):
        """Return a deep copy of the current online encoder."""
        target_encoder = copy.deepcopy(self.online_encoder)
        return target_encoder

    def reset_moving_average(self):
        """Drop the target encoder; the next forward pass re-creates it."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """EMA-update the target encoder's parameters from the online encoder."""
        assert (
                self.target_encoder is not None
        ), "target encoder has not been created yet"
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, image_one, image_two):
        """Encode both views with the online and target encoders.

        Returns:
            ``(online_proj_one, online_proj_two, target_proj_one, target_proj_two)``.
            The target projections are detached when ``stop_gradient`` is True.
        """
        online_proj_one = self.online_encoder(image_one)
        online_proj_two = self.online_encoder(image_two)

        if self.stop_gradient:
            with torch.no_grad():
                if self.target_encoder is None:
                    self.target_encoder = self._get_target_encoder()
                target_proj_one = self.target_encoder(image_one).detach()
                target_proj_two = self.target_encoder(image_two).detach()
        else:
            # Without stop-gradient the target projections keep their autograd graph.
            # Previously this branch was missing and forward raised UnboundLocalError.
            if self.target_encoder is None:
                self.target_encoder = self._get_target_encoder()
            target_proj_one = self.target_encoder(image_one)
            target_proj_two = self.target_encoder(image_two)

        return online_proj_one, online_proj_two, target_proj_one, target_proj_two

In [17]:
# server model
class BYOL_Server(nn.Module):
    """Server half of split BYOL: owns the online predictor and computes the loss.

    ``moving_average_decay`` is accepted but not used.
    """

    def __init__(
        self,
        projection_size=2048,
        projection_hidden_size=4096,
        moving_average_decay=0.99,
        predictor_network=TwoLayer,
    ):
        super().__init__()
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size, predictor_network)

    def forward(self, online_proj_one, online_proj_two, target_proj_one, target_proj_two):
        """Return the symmetric BYOL loss averaged over the batch.

        Each view's prediction is regressed onto the *other* view's target projection.
        """
        predictions = [self.online_predictor(proj) for proj in (online_proj_one, online_proj_two)]
        per_sample = byol_loss_fn(predictions[0], target_proj_two) + byol_loss_fn(predictions[1], target_proj_one)
        return per_sample.mean()

In [18]:
class TransformsSimCLR:
    """Stochastic augmentation producing two correlated views of one example.

    Calling the object on ``x`` returns ``(train_transform(x), train_transform(x))``,
    with each call drawing its own random augmentation. The two views, denoted
    x̃_i and x̃_j, are treated as a positive pair. ``test_transform`` (Resize +
    ToTensor) is built but not used by ``__call__``.
    """

    def __init__(self, size):
        strength = 1
        color_jitter = torchvision.transforms.ColorJitter(
            0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength
        )
        tvt = torchvision.transforms
        train_steps = [
            tvt.RandomResizedCrop(size=size),
            tvt.RandomHorizontalFlip(),  # with 0.5 probability
            tvt.RandomApply([color_jitter], p=0.8),
            tvt.RandomGrayscale(p=0.2),
            tvt.ToTensor(),
        ]
        self.train_transform = tvt.Compose(train_steps)

        # Optional normalisation for the test pipeline, currently disabled:
        #   tvt.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        #   tvt.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD) with
        #   CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
        #   CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
        test_steps = [tvt.Resize(size=size), tvt.ToTensor()]
        self.test_transform = tvt.Compose(test_steps)

    def __call__(self, x):
        first_view = self.train_transform(x)
        second_view = self.train_transform(x)
        return first_view, second_view


In [19]:
net = ResNet18()
client_model = BYOL_Client(net=net, stop_gradient=stop_gradient, has_predictor=has_predictor, predictor_network=predictor_network)
# The predictor lives on the server, so the `predictor_network` setting must be passed
# here (BYOL_Client accepts it but does not use it).
server_model = BYOL_Server(predictor_network=predictor_network)
server_model.cuda()

# Equal aggregation weight per client. NOTE: 1/5 matches the number of clients sampled
# per step in `training` for this configuration; revisit if that changes.
client_weights = [1/5 for i in range(client_num)]
# `client_model` stays on the CPU as the aggregation template; each client gets a CUDA copy.
client_models = [copy.deepcopy(client_model).cuda() for idx in range(client_num)]

# Optimizers are built before any forward pass, so the lazily created target encoders
# are not part of them (they are only changed by the EMA update).
optimizer_server = torch.optim.Adam(server_model.parameters(), lr = H[5])
optimizer_clients = [torch.optim.Adam(client_models[i].parameters(), lr = H[5]) for i in range(len(client_models))]

In [20]:
# if using checkpoint to train: set `resume_training = True` to continue from the
# checkpoint that `training` writes to `save_path` via `save_checkpoint`.
resume_training = False
epoch = 0  # first global round run by `training`
if resume_training:
    checkpath = os.path.join(save_path, "checkpoint.pth.tar")
    checkpoint = torch.load(checkpath)
    epoch = checkpoint['glepoch']
    print(epoch)
    optimizer_server.load_state_dict(checkpoint['optimizer'][0])
    # Keep the aggregation template in sync with the clients; it is what gets checkpointed.
    client_model.online_encoder.load_state_dict(checkpoint['state_dict'])
    for localmodel in client_models:
        localmodel.online_encoder.load_state_dict(checkpoint['state_dict'])
    for clientidx in range(client_num):
        optimizer_clients[clientidx].load_state_dict(checkpoint['optimizer'][clientidx + 1])

In [21]:
# Display the template client model's architecture (the cell's output is its repr).
client_model

BYOL_Client(
  (online_encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [22]:
class create_iterator():
    """Thin wrapper around an iterator that forwards ``next()`` calls to it.

    ``__iter__`` returns ``self``, so the wrapper satisfies the full iterator
    protocol and can be used in ``for`` loops, ``list()``, etc.
    """

    def __init__(self, iterator) -> None:
        self.iterator = iterator

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iterator)

In [23]:
# One fresh iterator per client over that client's training DataLoader.
client_iterator_list = [
    create_iterator(iter(train_loader_list[client_id])) for client_id in range(client_num)
]

In [24]:
def train_server(online_proj_one, online_proj_two, target_proj_one, target_proj_two, server_model):
    """Run one server-side forward/backward pass of split BYOL.

    The online projections are re-wrapped as fresh leaf tensors, so the server
    can compute the loss gradient with respect to them and send it back to the
    clients. ``loss.backward()`` also accumulates gradients into ``server_model``.

    Args:
        online_proj_one, online_proj_two: online projections received from the clients.
        target_proj_one, target_proj_two: target projections received from the clients.
        server_model: module called as ``server_model(op1, op2, tp1, tp2)`` that
            returns a scalar loss.

    Returns:
        ``(grad_one, grad_two, loss)``: detached copies of d(loss)/d(online_proj_*)
        (zeros if the loss does not depend on that input) and the loss tensor.
    """
    # Detach into new leaves. This works for non-leaf inputs too (setting
    # `requires_grad` on a non-leaf raises) and leaves the caller's tensors untouched.
    online_proj_one = online_proj_one.detach().requires_grad_(True)
    online_proj_two = online_proj_two.detach().requires_grad_(True)

    server_model.train()

    # forward prop
    loss = server_model(online_proj_one, online_proj_two, target_proj_one, target_proj_two)

    # backward prop
    loss.backward()

    def _grad_of(leaf):
        # A leaf the loss does not depend on has no .grad; return zeros instead of crashing.
        if leaf.grad is None:
            return torch.zeros_like(leaf)
        return leaf.grad.detach().clone()

    return _grad_of(online_proj_one), _grad_of(online_proj_two), loss

In [25]:
def optimizer_zero_grads(optimizer_server, optimizer_clients):  # This needs to be called
    """Zero the gradients of the server optimizer and of every client optimizer.

    Iterates over ``optimizer_clients`` itself instead of relying on a global
    client count, so any number of client optimizers is handled.
    """
    optimizer_server.zero_grad()
    for client_optimizer in optimizer_clients:
        client_optimizer.zero_grad()

In [26]:
# read one batch of image pair
def next_data_batch(client_id):
    """Return the next ``(img1, img2)`` batch for ``client_id``.

    If the client's iterator is exhausted, or it yields a batch whose first
    dimension differs from the global ``batch_size``, the client's entry in
    ``client_iterator_list`` is replaced by a fresh iterator over its DataLoader,
    and the first batch of that new pass is returned.
    """
    def _fresh_iterator():
        client_iterator_list[client_id] = create_iterator(iter((train_loader_list[client_id])))
        return client_iterator_list[client_id]

    try:
        first_view, second_view = next(client_iterator_list[client_id])
        if first_view.size(0) != batch_size:
            # Incomplete batch: try to advance the old iterator once, then start a new pass.
            try:
                next(client_iterator_list[client_id])
            except StopIteration:
                pass
            first_view, second_view = next(_fresh_iterator())
    except StopIteration:
        first_view, second_view = next(_fresh_iterator())
    return first_view, second_view

In [27]:
def _aggregate_client_models(client_models, selected):
    """Weighted-average the selected clients into ``client_model``, then broadcast to all.

    Every key of the template ``client_model.state_dict()`` is averaged, except
    keys containing ``"running"`` or ``"num_batches"`` (BatchNorm bookkeeping),
    which stay local to each client. The result is copied into the template and
    then into every model in ``client_models``. Weights are ``client_weights[idx]``,
    renormalised to sum to 1 over ``selected``.

    Args:
        client_models: all client modules, indexed by client id.
        selected: ids of the clients whose weights are averaged.
    """
    import math

    raw_weights = [client_weights[idx] for idx in selected]
    total = math.fsum(raw_weights)
    weights = [w / total for w in raw_weights]

    # state_dict() returns tensors that share storage with the live parameters, so
    # the in-place copy_ calls below update the models themselves.
    global_state = client_model.state_dict()
    client_states = [model.state_dict() for model in client_models]
    with torch.no_grad():
        for key, global_tensor in global_state.items():
            if "running" in key or "num_batches" in key:
                continue
            averaged = torch.zeros_like(global_tensor).to('cuda')
            for weight, idx in zip(weights, selected):
                averaged += weight * client_states[idx][key]
            global_tensor.copy_(averaged)
            for state in client_states:
                state[key].copy_(global_tensor)


def training(client_models, server_model, optimizer_server, optimizer_clients, rounds, batch_size, lr, C, K, local_epochs, plt_title, plt_color, cifar_data_test=None,
             test_batch_size=None, criterion=None, num_classes=None, classes_test=None, sch_flag=None):
    """Split federated BYOL training loop.

    Runs global rounds ``epoch .. rounds`` (inclusive; ``epoch`` is the notebook global
    set by the resume cell). Each round performs ``len(train_loader_list[0])`` steps:

    1. sample ``round(C * len(client_models))`` clients (clamped to ``[1, len(client_models)]``);
    2. each sampled client encodes one batch of image pairs;
    3. the server computes the loss on the concatenated projections, steps its
       optimizer and returns the gradient w.r.t. the online projections;
    4. each sampled client back-propagates its slice of that gradient, steps its
       optimizer and EMA-updates its target encoder.

    Every ``avg_freq`` steps, and on the last step of a round, the sampled clients are
    averaged into ``client_model`` and broadcast to all clients. A checkpoint is
    written on rounds whose index is a multiple of 5, and the encoder weights on
    multiples of 100.

    Args:
        client_models: per-client BYOL_Client modules.
        server_model: BYOL_Server module.
        optimizer_server: optimizer over ``server_model`` parameters.
        optimizer_clients: one optimizer per client.
        rounds: index of the last global round to run.
        batch_size: used in the TensorBoard log directory name.
        C: fraction of clients sampled per step.
        lr, K, local_epochs, plt_title, plt_color, cifar_data_test, test_batch_size,
        criterion, num_classes, classes_test, sch_flag: accepted for interface
            compatibility; not used.

    Returns:
        ``(client_model, train_loss)``: the aggregated template model and the list of
        per-round average training losses.
    """
    train_loss = []
    avg_times = 0
    # `train_start` measures the whole run; `start` is reset after every step.
    train_start = time.time()
    start = train_start

    num_batch = len(train_loader_list[0])
    num_clients = len(client_models)
    num_selected = max(1, min(num_clients, int(round(C * num_clients))))

    # writer = SummaryWriter(f'logs/SplitFSSL_BYOL32_DifAvgtimes/resnet18Maxpooling_cifar10_{batch_size}_{avg_freq}_{noniid_ratio}_{client_num}')
    writer = SummaryWriter(f'logs/SplitFSSL_BYOL_Avg25times/resnet18Maxpooling_cifar10_{batch_size}_{avg_freq}_{noniid_ratio}_{client_num}')
    global_step = 0
    try:
        for curr_round in range(epoch, rounds + 1):
            metrics = defaultdict(list)
            print(f"Global Round:", curr_round)
            local_loss = []

            batch_time = AverageMeter()
            data_time = AverageMeter()
            p_bar = tqdm(total=num_batch)

            for batch in range(num_batch):
                optimizer_zero_grads(optimizer_server, optimizer_clients)

                # client forward on the clients sampled for this step
                s_clients = random.sample(range(num_clients), num_selected)
                online_proj_one_list, online_proj_two_list = [], []
                target_proj_one_list, target_proj_two_list = [], []
                step_data_time = 0.0
                for client_id in s_clients:
                    data_start = time.time()
                    img1, img2 = next_data_batch(client_id)
                    img1 = img1.cuda()
                    img2 = img2.cuda()
                    step_data_time += time.time() - data_start

                    client_models[client_id].train()
                    online_proj_one, online_proj_two, target_proj_one, target_proj_two = client_models[client_id](img1, img2)
                    online_proj_one_list.append(online_proj_one)
                    online_proj_two_list.append(online_proj_two)
                    target_proj_one_list.append(target_proj_one)
                    target_proj_two_list.append(target_proj_two)
                data_time.update(step_data_time)

                # stack representations and let the server compute loss + gradients
                stack_online_proj_one = torch.cat(online_proj_one_list, dim=0).cuda()
                stack_online_proj_two = torch.cat(online_proj_two_list, dim=0).cuda()
                stack_target_proj_one = torch.cat(target_proj_one_list, dim=0).cuda()
                stack_target_proj_two = torch.cat(target_proj_two_list, dim=0).cuda()

                online_proj_one_grad, online_proj_two_grad, loss = train_server(
                    stack_online_proj_one.detach(), stack_online_proj_two.detach(),
                    stack_target_proj_one, stack_target_proj_two, server_model)
                loss_value = loss.item()
                local_loss.append(loss_value)
                optimizer_server.step()

                # Route each slice of the server gradient back to the client that produced
                # it, using the actual per-client batch sizes.
                split_sizes = [proj.size(0) for proj in online_proj_one_list]
                grads_one = torch.split(online_proj_one_grad, split_sizes, dim=0)
                grads_two = torch.split(online_proj_two_grad, split_sizes, dim=0)
                for i, client_id in enumerate(s_clients):
                    online_proj_one_list[i].backward(grads_one[i])
                    online_proj_two_list[i].backward(grads_two[i])
                    optimizer_clients[client_id].step()
                    client_models[client_id].update_moving_average()

                del img1, img2
                writer.add_scalar("Loss/train_step", loss_value, global_step)
                metrics["Loss/train"].append(loss_value)
                global_step += 1

                batch_time.update(time.time() - start)
                start = time.time()
                p_bar.set_description("Train Epoch: {epoch}/{epochs:4}. Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}.".format(
                        epoch=curr_round,
                        epochs=rounds+1,
                        batch=batch + 1,
                        iter=num_batch,
                        data=data_time.avg,
                        bt=batch_time.avg,
                        loss=loss_value))
                p_bar.update()

                # e.g. batch size 32 -> 250 batches; aggregating every 10 batches gives 25 aggregations per epoch
                if batch == num_batch - 1 or ((batch + 1) % avg_freq == 0):
                    print("aggregate batch", batch)
                    avg_times += 1
                    _aggregate_client_models(client_models, s_clients)

            p_bar.close()
            for k, v in metrics.items():
                writer.add_scalar(k, np.array(v).mean(), curr_round)

            loss_avg = sum(local_loss) / len(local_loss) if local_loss else float("nan")
            train_loss.append(loss_avg)
            if curr_round % 5 == 0:
                optimizer_dict = [optimizer_server.state_dict()]
                optimizer_dict.extend(opt.state_dict() for opt in optimizer_clients)
                # NOTE(review): BatchNorm running statistics are skipped by aggregation and
                # `client_model` is never run forward here, so its running_mean/var appear to
                # stay at their initial values in this checkpoint — confirm for evaluation.
                state_dict = client_model.online_encoder.cpu().state_dict()
                save_checkpoint({
                    'glepoch': curr_round+1,
                    'state_dict': state_dict,
                    'optimizer': optimizer_dict,
                }, save_path)
            if curr_round % 100 == 0:
                torch.save(client_model.online_encoder.cpu().state_dict(), save_path + f"_{curr_round}_epoch.pt")

            print(f"Global round: {curr_round} | Average loss: {loss_avg}")
    finally:
        writer.close()

    end = time.time()

    print("Training Done!")
    print("Total time taken to Train: {}".format(end - train_start))
    print(f"Total average times : {avg_times}")

    return client_model, train_loss


In [28]:
# Run identifier summarising the main configuration values.
plot_str = (
    f"{partition}_{norm}_comm_rounds_{global_epochs}"
    f"_clientfr_{client_fraction}_numclients_{client_num}"
    f"_clientepochs_{local_epoch}_clientbs_{batch_size}_clientLR_{lr}"
)
print(plot_str)

iid_bn_comm_rounds_1000_clientfr_1.0_numclients_5_clientepochs_5_clientbs_32_clientLR_0.0003


In [29]:
# H = [rounds, client_fraction, number_of_clients, local_epochs, batch_size, lr]
trained_model, train_loss = training(
    client_models, server_model, optimizer_server, optimizer_clients,
    rounds=H[0], batch_size=H[4], lr=H[5], C=H[1], K=H[2], local_epochs=H[3],
    plt_title=plot_str, plt_color="green",
)

Global Round: 0


Train Epoch: 0/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0533.: 100%|█| 313/313 [01:55<00:00,  2.71it/s]


global epoch 1 saved
Global round: 0 | Average loss: 0.4955434874414255
Global Round: 1


Train Epoch: 1/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.0218.: 100%|█| 313/313 [01:54<00:00,  2.73it/s]


Global round: 1 | Average loss: 0.030354390051751473
Global Round: 2


Train Epoch: 2/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0206.: 100%|█| 313/313 [01:54<00:00,  2.72it/s]


Global round: 2 | Average loss: 0.02068751166089655
Global Round: 3


Train Epoch: 3/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.0192.: 100%|█| 313/313 [01:54<00:00,  2.73it/s]


Global round: 3 | Average loss: 0.02040001550040687
Global Round: 4


Train Epoch: 4/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.0210.: 100%|█| 313/313 [01:54<00:00,  2.73it/s]


Global round: 4 | Average loss: 0.02113972971447931
Global Round: 5


Train Epoch: 5/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.363s. Loss: 0.0226.: 100%|█| 313/313 [01:53<00:00,  2.76it/s]


global epoch 6 saved
Global round: 5 | Average loss: 0.022188148363091693
Global Round: 6


Train Epoch: 6/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.0246.: 100%|█| 313/313 [01:55<00:00,  2.72it/s]


Global round: 6 | Average loss: 0.024013914733220593
Global Round: 7


Train Epoch: 7/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0275.: 100%|█| 313/313 [01:55<00:00,  2.72it/s]


Global round: 7 | Average loss: 0.0258891683357497
Global Round: 8


Train Epoch: 8/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0320.: 100%|█| 313/313 [01:55<00:00,  2.72it/s]


Global round: 8 | Average loss: 0.02772150679042164
Global Round: 9


Train Epoch: 9/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0312.: 100%|█| 313/313 [01:54<00:00,  2.73it/s]


Global round: 9 | Average loss: 0.02979683962921365
Global Round: 10


Train Epoch: 10/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0332.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


global epoch 11 saved
Global round: 10 | Average loss: 0.03180097802854574
Global Round: 11


Train Epoch: 11/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.369s. Loss: 0.0381.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 11 | Average loss: 0.0371008581115891
Global Round: 12


Train Epoch: 12/1001. Iter:  313/ 313. Data: 0.143s. Batch: 0.364s. Loss: 0.0437.: 100%|█| 313/313 [01:53<00:00,  2.75it/s


Global round: 12 | Average loss: 0.041350888201413444
Global Round: 13


Train Epoch: 13/1001. Iter:  313/ 313. Data: 0.140s. Batch: 0.361s. Loss: 0.0529.: 100%|█| 313/313 [01:53<00:00,  2.77it/s


Global round: 13 | Average loss: 0.048228031018385874
Global Round: 14


Train Epoch: 14/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.0724.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 14 | Average loss: 0.060982056176319674
Global Round: 15


Train Epoch: 15/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.0748.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


global epoch 16 saved
Global round: 15 | Average loss: 0.07264801507559829
Global Round: 16


Train Epoch: 16/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.369s. Loss: 0.0450.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 16 | Average loss: 0.06335882046304572
Global Round: 17


Train Epoch: 17/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0365.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 17 | Average loss: 0.042775724815151184
Global Round: 18


Train Epoch: 18/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.0406.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 18 | Average loss: 0.04019675153893785
Global Round: 19


Train Epoch: 19/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0333.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 19 | Average loss: 0.040213515452397895
Global Round: 20


Train Epoch: 20/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.0299.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


global epoch 21 saved
Global round: 20 | Average loss: 0.03711519934260807
Global Round: 21


Train Epoch: 21/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.0256.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 21 | Average loss: 0.030320856065605396
Global Round: 22


Train Epoch: 22/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0247.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 22 | Average loss: 0.02794617185363183
Global Round: 23


Train Epoch: 23/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0306.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 23 | Average loss: 0.032948506049835645
Global Round: 24


Train Epoch: 24/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0494.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


Global round: 24 | Average loss: 0.052476764332276944
Global Round: 25


Train Epoch: 25/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0540.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


global epoch 26 saved
Global round: 25 | Average loss: 0.058863625097008175
Global Round: 26


Train Epoch: 26/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.0557.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


Global round: 26 | Average loss: 0.05274144882639757
Global Round: 27


Train Epoch: 27/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.0531.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 27 | Average loss: 0.056249092883481004
Global Round: 28


Train Epoch: 28/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0488.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 28 | Average loss: 0.048628322042215365
Global Round: 29


Train Epoch: 29/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0608.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 29 | Average loss: 0.053716746239235606
Global Round: 30


Train Epoch: 30/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0668.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


global epoch 31 saved
Global round: 30 | Average loss: 0.06461565626172212
Global Round: 31


Train Epoch: 31/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.369s. Loss: 0.0681.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 31 | Average loss: 0.06989932134033391
Global Round: 32


Train Epoch: 32/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.0877.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 32 | Average loss: 0.08145216057380548
Global Round: 33


Train Epoch: 33/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.0985.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 33 | Average loss: 0.09075157256267322
Global Round: 34


Train Epoch: 34/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0931.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


Global round: 34 | Average loss: 0.09180952770451006
Global Round: 35


Train Epoch: 35/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0962.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


global epoch 36 saved
Global round: 35 | Average loss: 0.08890249989569758
Global Round: 36


Train Epoch: 36/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1014.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


Global round: 36 | Average loss: 0.0950369490697361
Global Round: 37


Train Epoch: 37/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1036.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 37 | Average loss: 0.10529869118818459
Global Round: 38


Train Epoch: 38/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0913.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 38 | Average loss: 0.10509801489381364
Global Round: 39


Train Epoch: 39/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0972.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 39 | Average loss: 0.10649693397858653
Global Round: 40


Train Epoch: 40/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.0842.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


global epoch 41 saved
Global round: 40 | Average loss: 0.10394819135578295
Global Round: 41


Train Epoch: 41/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.0808.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 41 | Average loss: 0.10011187505226928
Global Round: 42


Train Epoch: 42/1001. Iter:  313/ 313. Data: 0.141s. Batch: 0.361s. Loss: 0.0873.: 100%|█| 313/313 [01:52<00:00,  2.77it/s


Global round: 42 | Average loss: 0.09850095074397687
Global Round: 43


Train Epoch: 43/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1041.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


Global round: 43 | Average loss: 0.09894818303208001
Global Round: 44


Train Epoch: 44/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0915.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


Global round: 44 | Average loss: 0.09959874713954073
Global Round: 45


Train Epoch: 45/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1068.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


global epoch 46 saved
Global round: 45 | Average loss: 0.10240142302105602
Global Round: 46


Train Epoch: 46/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.369s. Loss: 0.1119.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 46 | Average loss: 0.10946399830400753
Global Round: 47


Train Epoch: 47/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.1377.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 47 | Average loss: 0.12161170024746143
Global Round: 48


Train Epoch: 48/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1395.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 48 | Average loss: 0.1319999461785292
Global Round: 49


Train Epoch: 49/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1276.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 49 | Average loss: 0.1328471545784618
Global Round: 50


Train Epoch: 50/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.369s. Loss: 0.1366.: 100%|█| 313/313 [01:55<00:00,  2.71it/s


global epoch 51 saved
Global round: 50 | Average loss: 0.13211414379814562
Global Round: 51


Train Epoch: 51/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1123.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 51 | Average loss: 0.1294277618392207
Global Round: 52


Train Epoch: 52/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.0964.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 52 | Average loss: 0.12492209908585199
Global Round: 53


Train Epoch: 53/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.363s. Loss: 0.0955.: 100%|█| 313/313 [01:53<00:00,  2.75it/s


Global round: 53 | Average loss: 0.1190369943507944
Global Round: 54


Train Epoch: 54/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1114.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 54 | Average loss: 0.11567059102149817
Global Round: 55


Train Epoch: 55/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.369s. Loss: 0.0910.: 100%|█| 313/313 [01:55<00:00,  2.71it/s


global epoch 56 saved
Global round: 55 | Average loss: 0.11062311204953697
Global Round: 56


Train Epoch: 56/1001. Iter:  313/ 313. Data: 0.151s. Batch: 0.371s. Loss: 0.0601.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 56 | Average loss: 0.10085447971670392
Global Round: 57


Train Epoch: 57/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.0740.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 57 | Average loss: 0.0876329104407146
Global Round: 58


Train Epoch: 58/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0850.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 58 | Average loss: 0.07919709886700962
Global Round: 59


Train Epoch: 59/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.0883.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 59 | Average loss: 0.0804178919821692
Global Round: 60


Train Epoch: 60/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.0866.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


global epoch 61 saved
Global round: 60 | Average loss: 0.07860702698746809
Global Round: 61


Train Epoch: 61/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0716.: 100%|█| 313/313 [01:53<00:00,  2.75it/s


Global round: 61 | Average loss: 0.07759445428419799
Global Round: 62


Train Epoch: 62/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.365s. Loss: 0.0718.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 62 | Average loss: 0.07788342424332144
Global Round: 63


Train Epoch: 63/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0637.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 63 | Average loss: 0.07539525426948032
Global Round: 64


Train Epoch: 64/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0579.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 64 | Average loss: 0.06942357652532026
Global Round: 65


Train Epoch: 65/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0682.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


global epoch 66 saved
Global round: 65 | Average loss: 0.06104446979709707
Global Round: 66


Train Epoch: 66/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.370s. Loss: 0.0502.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 66 | Average loss: 0.055289459030944316
Global Round: 67


Train Epoch: 67/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0510.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 67 | Average loss: 0.054796022562363655
Global Round: 68


Train Epoch: 68/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.0554.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 68 | Average loss: 0.05621924745246244
Global Round: 69


Train Epoch: 69/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0503.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 69 | Average loss: 0.05567222511092314
Global Round: 70


Train Epoch: 70/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0566.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


global epoch 71 saved
Global round: 70 | Average loss: 0.055891902885212306
Global Round: 71


Train Epoch: 71/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.0495.: 100%|█| 313/313 [01:55<00:00,  2.71it/s


Global round: 71 | Average loss: 0.056180649672072536
Global Round: 72


Train Epoch: 72/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.0421.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 72 | Average loss: 0.05732399971483234
Global Round: 73


Train Epoch: 73/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0687.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 73 | Average loss: 0.058848859272159326
Global Round: 74


Train Epoch: 74/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.0561.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 74 | Average loss: 0.061070779546762044
Global Round: 75


Train Epoch: 75/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0656.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


global epoch 76 saved
Global round: 75 | Average loss: 0.06419579958477721
Global Round: 76


Train Epoch: 76/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.0710.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 76 | Average loss: 0.0668249249148864
Global Round: 77


Train Epoch: 77/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0729.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 77 | Average loss: 0.07072902899294996
Global Round: 78


Train Epoch: 78/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0734.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 78 | Average loss: 0.0776711470974139
Global Round: 79


Train Epoch: 79/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0808.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 79 | Average loss: 0.0820811312205304
Global Round: 80


Train Epoch: 80/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.0850.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


global epoch 81 saved
Global round: 80 | Average loss: 0.08686635188591747
Global Round: 81


Train Epoch: 81/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0797.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 81 | Average loss: 0.09130640001818775
Global Round: 82


Train Epoch: 82/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.0957.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 82 | Average loss: 0.09116897326165115
Global Round: 83


Train Epoch: 83/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.364s. Loss: 0.1324.: 100%|█| 313/313 [01:53<00:00,  2.75it/s


Global round: 83 | Average loss: 0.09494050549337277
Global Round: 84


Train Epoch: 84/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1064.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 84 | Average loss: 0.09841030609969514
Global Round: 85


Train Epoch: 85/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.0870.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


global epoch 86 saved
Global round: 85 | Average loss: 0.10044533485612168
Global Round: 86


Train Epoch: 86/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.370s. Loss: 0.1088.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 86 | Average loss: 0.10172320971378503
Global Round: 87


Train Epoch: 87/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0919.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 87 | Average loss: 0.1044745417163014
Global Round: 88


Train Epoch: 88/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1318.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 88 | Average loss: 0.1056202746951542
Global Round: 89


Train Epoch: 89/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0918.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 89 | Average loss: 0.10863206140435161
Global Round: 90


Train Epoch: 90/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.0918.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


global epoch 91 saved
Global round: 90 | Average loss: 0.11110914758028695
Global Round: 91


Train Epoch: 91/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1155.: 100%|█| 313/313 [01:54<00:00,  2.74it/s


Global round: 91 | Average loss: 0.1118335048563945
Global Round: 92


Train Epoch: 92/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1323.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 92 | Average loss: 0.11275476838548343
Global Round: 93


Train Epoch: 93/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0853.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 93 | Average loss: 0.11575361976798731
Global Round: 94


Train Epoch: 94/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1079.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 94 | Average loss: 0.11537897658233826
Global Round: 95


Train Epoch: 95/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1330.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


global epoch 96 saved
Global round: 95 | Average loss: 0.11957604723711746
Global Round: 96


Train Epoch: 96/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1441.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 96 | Average loss: 0.12362343299027068
Global Round: 97


Train Epoch: 97/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1373.: 100%|█| 313/313 [01:54<00:00,  2.73it/s


Global round: 97 | Average loss: 0.12390720451506562
Global Round: 98


Train Epoch: 98/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1436.: 100%|█| 313/313 [01:55<00:00,  2.72it/s


Global round: 98 | Average loss: 0.1286648199343072
Global Round: 99


Train Epoch: 99/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1040.: 100%|█| 313/313 [01:54<00:00,  2.72it/s


Global round: 99 | Average loss: 0.12982953156526097
Global Round: 100


Train Epoch: 100/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1116.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 101 saved
Global round: 100 | Average loss: 0.13031324729942284
Global Round: 101


Train Epoch: 101/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.369s. Loss: 0.1146.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 101 | Average loss: 0.13386736222254203
Global Round: 102


Train Epoch: 102/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1614.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 102 | Average loss: 0.13213895620724644
Global Round: 103


Train Epoch: 103/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1559.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 103 | Average loss: 0.13479068082647203
Global Round: 104


Train Epoch: 104/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1582.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 104 | Average loss: 0.13560814998401238
Global Round: 105


Train Epoch: 105/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1390.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 106 saved
Global round: 105 | Average loss: 0.1376175798309116
Global Round: 106


Train Epoch: 106/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.1392.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 106 | Average loss: 0.13633622498081896
Global Round: 107


Train Epoch: 107/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1549.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 107 | Average loss: 0.1391324886260703
Global Round: 108


Train Epoch: 108/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1573.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 108 | Average loss: 0.14166162174921065
Global Round: 109


Train Epoch: 109/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1510.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 109 | Average loss: 0.13943702075332878
Global Round: 110


Train Epoch: 110/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1307.: 100%|█| 313/313 [01:54<00:00,  2.73it/


global epoch 111 saved
Global round: 110 | Average loss: 0.1450707249986097
Global Round: 111


Train Epoch: 111/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1109.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 111 | Average loss: 0.1454539433978617
Global Round: 112


Train Epoch: 112/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1855.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 112 | Average loss: 0.14485665129872557
Global Round: 113


Train Epoch: 113/1001. Iter:  313/ 313. Data: 0.142s. Batch: 0.362s. Loss: 0.2094.: 100%|█| 313/313 [01:53<00:00,  2.76it/


Global round: 113 | Average loss: 0.1430902193767575
Global Round: 114


Train Epoch: 114/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1143.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 114 | Average loss: 0.1461367686859335
Global Round: 115


Train Epoch: 115/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1316.: 100%|█| 313/313 [01:54<00:00,  2.73it/


global epoch 116 saved
Global round: 115 | Average loss: 0.14339757803529976
Global Round: 116


Train Epoch: 116/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1408.: 100%|█| 313/313 [01:53<00:00,  2.75it/


Global round: 116 | Average loss: 0.146780922985115
Global Round: 117


Train Epoch: 117/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1211.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 117 | Average loss: 0.1439843864296191
Global Round: 118


Train Epoch: 118/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1226.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 118 | Average loss: 0.14610459348454644
Global Round: 119


Train Epoch: 119/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1393.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 119 | Average loss: 0.1472057714915504
Global Round: 120


Train Epoch: 120/1001. Iter:  313/ 313. Data: 0.143s. Batch: 0.363s. Loss: 0.1467.: 100%|█| 313/313 [01:53<00:00,  2.75it/


global epoch 121 saved
Global round: 120 | Average loss: 0.14679158667025094
Global Round: 121


Train Epoch: 121/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1871.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 121 | Average loss: 0.1459521619370951
Global Round: 122


Train Epoch: 122/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1686.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 122 | Average loss: 0.14862660892283955
Global Round: 123


Train Epoch: 123/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1388.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 123 | Average loss: 0.14819821386862866
Global Round: 124


Train Epoch: 124/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1793.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 124 | Average loss: 0.1450671628593637
Global Round: 125


Train Epoch: 125/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1316.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 126 saved
Global round: 125 | Average loss: 0.14677090400133652
Global Round: 126


Train Epoch: 126/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1296.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 126 | Average loss: 0.14586893628580502
Global Round: 127


Train Epoch: 127/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1672.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 127 | Average loss: 0.1449612300997725
Global Round: 128


Train Epoch: 128/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1981.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 128 | Average loss: 0.14821751841817038
Global Round: 129


Train Epoch: 129/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1280.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 129 | Average loss: 0.1460356321007299
Global Round: 130


Train Epoch: 130/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1589.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 131 saved
Global round: 130 | Average loss: 0.14533943004501512
Global Round: 131


Train Epoch: 131/1001. Iter:  313/ 313. Data: 0.151s. Batch: 0.371s. Loss: 0.1543.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 131 | Average loss: 0.14613098113205486
Global Round: 132


Train Epoch: 132/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1309.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 132 | Average loss: 0.14476585326293787
Global Round: 133


Train Epoch: 133/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1459.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 133 | Average loss: 0.14843407818398918
Global Round: 134


Train Epoch: 134/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1484.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 134 | Average loss: 0.1472091527935415
Global Round: 135


Train Epoch: 135/1001. Iter:  313/ 313. Data: 0.143s. Batch: 0.363s. Loss: 0.1844.: 100%|█| 313/313 [01:53<00:00,  2.76it/


global epoch 136 saved
Global round: 135 | Average loss: 0.14511992599065313
Global Round: 136


Train Epoch: 136/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.369s. Loss: 0.1570.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 136 | Average loss: 0.14698540959685755
Global Round: 137


Train Epoch: 137/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.1207.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 137 | Average loss: 0.14764523277648342
Global Round: 138


Train Epoch: 138/1001. Iter:  313/ 313. Data: 0.143s. Batch: 0.363s. Loss: 0.1633.: 100%|█| 313/313 [01:53<00:00,  2.75it/


Global round: 138 | Average loss: 0.146488520641106
Global Round: 139


Train Epoch: 139/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1354.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 139 | Average loss: 0.1448368269462174
Global Round: 140


Train Epoch: 140/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1369.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 141 saved
Global round: 140 | Average loss: 0.14654252413934032
Global Round: 141


Train Epoch: 141/1001. Iter:  313/ 313. Data: 0.151s. Batch: 0.371s. Loss: 0.1668.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 141 | Average loss: 0.14640364393639488
Global Round: 142


Train Epoch: 142/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1449.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 142 | Average loss: 0.14519063154824627
Global Round: 143


Train Epoch: 143/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1333.: 100%|█| 313/313 [01:55<00:00,  2.71it/


Global round: 143 | Average loss: 0.14602175926248107
Global Round: 144


Train Epoch: 144/1001. Iter:  313/ 313. Data: 0.141s. Batch: 0.358s. Loss: 0.1376.: 100%|█| 313/313 [01:52<00:00,  2.79it/


Global round: 144 | Average loss: 0.14764932185982743
Global Round: 145


Train Epoch: 145/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1284.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 146 saved
Global round: 145 | Average loss: 0.14639687866639026
Global Round: 146


Train Epoch: 146/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.368s. Loss: 0.1748.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 146 | Average loss: 0.14443591379890808
Global Round: 147


Train Epoch: 147/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.1454.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 147 | Average loss: 0.14526340651055114
Global Round: 148


Train Epoch: 148/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1605.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 148 | Average loss: 0.14706560593253126
Global Round: 149


Train Epoch: 149/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1486.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 149 | Average loss: 0.14569939025484335
Global Round: 150


Train Epoch: 150/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1735.: 100%|█| 313/313 [01:54<00:00,  2.73it/


global epoch 151 saved
Global round: 150 | Average loss: 0.14753698307675675
Global Round: 151


Train Epoch: 151/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1430.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 151 | Average loss: 0.14621449169069053
Global Round: 152


Train Epoch: 152/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1354.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 152 | Average loss: 0.14782395192418998
Global Round: 153


Train Epoch: 153/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1429.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 153 | Average loss: 0.14591246681472364
Global Round: 154


Train Epoch: 154/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.365s. Loss: 0.1408.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 154 | Average loss: 0.14637578402559598
Global Round: 155


Train Epoch: 155/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1487.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 156 saved
Global round: 155 | Average loss: 0.1458265387688201
Global Round: 156


Train Epoch: 156/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.1394.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 156 | Average loss: 0.14391405918537237
Global Round: 157


Train Epoch: 157/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.365s. Loss: 0.1511.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 157 | Average loss: 0.14484273456632138
Global Round: 158


Train Epoch: 158/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1471.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 158 | Average loss: 0.14451690280018523
Global Round: 159


Train Epoch: 159/1001. Iter:  313/ 313. Data: 0.142s. Batch: 0.363s. Loss: 0.1367.: 100%|█| 313/313 [01:53<00:00,  2.75it/


Global round: 159 | Average loss: 0.14358722701811563
Global Round: 160


Train Epoch: 160/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.369s. Loss: 0.1121.: 100%|█| 313/313 [01:55<00:00,  2.71it/


global epoch 161 saved
Global round: 160 | Average loss: 0.1438170889505563
Global Round: 161


Train Epoch: 161/1001. Iter:  313/ 313. Data: 0.151s. Batch: 0.371s. Loss: 0.1473.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 161 | Average loss: 0.14452227185995054
Global Round: 162


Train Epoch: 162/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1397.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 162 | Average loss: 0.14637216537619552
Global Round: 163


Train Epoch: 163/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.369s. Loss: 0.1561.: 100%|█| 313/313 [01:55<00:00,  2.71it/


Global round: 163 | Average loss: 0.14564128115344732
Global Round: 164


Train Epoch: 164/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1588.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 164 | Average loss: 0.14367276784806207
Global Round: 165


Train Epoch: 165/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1296.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 166 saved
Global round: 165 | Average loss: 0.1440596900428065
Global Round: 166


Train Epoch: 166/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1442.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 166 | Average loss: 0.1438609165743517
Global Round: 167


Train Epoch: 167/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1621.: 100%|█| 313/313 [01:55<00:00,  2.71it/


Global round: 167 | Average loss: 0.14577690626200016
Global Round: 168


Train Epoch: 168/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1423.: 100%|█| 313/313 [01:55<00:00,  2.71it/


Global round: 168 | Average loss: 0.1459478373630359
Global Round: 169


Train Epoch: 169/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1828.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 169 | Average loss: 0.14590952750116873
Global Round: 170


Train Epoch: 170/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.364s. Loss: 0.1398.: 100%|█| 313/313 [01:53<00:00,  2.75it/


global epoch 171 saved
Global round: 170 | Average loss: 0.14448218726026366
Global Round: 171


Train Epoch: 171/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.364s. Loss: 0.1439.: 100%|█| 313/313 [01:52<00:00,  2.77it/


Global round: 171 | Average loss: 0.14592088075777213
Global Round: 172


Train Epoch: 172/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.364s. Loss: 0.1382.: 100%|█| 313/313 [01:53<00:00,  2.75it/


Global round: 172 | Average loss: 0.1473901109478344
Global Round: 173


Train Epoch: 173/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.1430.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 173 | Average loss: 0.1462501203909088
Global Round: 174


Train Epoch: 174/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.1405.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 174 | Average loss: 0.14496604312723055
Global Round: 175


Train Epoch: 175/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.1648.: 100%|█| 313/313 [01:54<00:00,  2.74it/


global epoch 176 saved
Global round: 175 | Average loss: 0.14613479249679243
Global Round: 176


Train Epoch: 176/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.370s. Loss: 0.1631.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 176 | Average loss: 0.14637474075388224
Global Round: 177


Train Epoch: 177/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1237.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 177 | Average loss: 0.14721219686749645
Global Round: 178


Train Epoch: 178/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1384.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 178 | Average loss: 0.14633123533794293
Global Round: 179


Train Epoch: 179/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1048.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 179 | Average loss: 0.1444119464475126
Global Round: 180


Train Epoch: 180/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1269.: 100%|█| 313/313 [01:54<00:00,  2.72it/


global epoch 181 saved
Global round: 180 | Average loss: 0.14464467697250197
Global Round: 181


Train Epoch: 181/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1518.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 181 | Average loss: 0.1436201801029638
Global Round: 182


Train Epoch: 182/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.1415.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 182 | Average loss: 0.1426376876549218
Global Round: 183


Train Epoch: 183/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1437.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 183 | Average loss: 0.14252939656043587
Global Round: 184


Train Epoch: 184/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.369s. Loss: 0.1479.: 100%|█| 313/313 [01:55<00:00,  2.71it/


Global round: 184 | Average loss: 0.1365563375548052
Global Round: 185


Train Epoch: 185/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1247.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 186 saved
Global round: 185 | Average loss: 0.13711748048425101
Global Round: 186


Train Epoch: 186/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.1242.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 186 | Average loss: 0.13467476714533358
Global Round: 187


Train Epoch: 187/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1247.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 187 | Average loss: 0.13533117710211026
Global Round: 188


Train Epoch: 188/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1154.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 188 | Average loss: 0.1338854574452574
Global Round: 189


Train Epoch: 189/1001. Iter:  313/ 313. Data: 0.142s. Batch: 0.360s. Loss: 0.1590.: 100%|█| 313/313 [01:52<00:00,  2.78it/


Global round: 189 | Average loss: 0.13515577184411284
Global Round: 190


Train Epoch: 190/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1532.: 100%|█| 313/313 [01:54<00:00,  2.73it/


global epoch 191 saved
Global round: 190 | Average loss: 0.1352988506515567
Global Round: 191


Train Epoch: 191/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.1303.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 191 | Average loss: 0.13446270217434667
Global Round: 192


Train Epoch: 192/1001. Iter:  313/ 313. Data: 0.143s. Batch: 0.361s. Loss: 0.1289.: 100%|█| 313/313 [01:53<00:00,  2.77it/


Global round: 192 | Average loss: 0.13654597441609295
Global Round: 193


Train Epoch: 193/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.1075.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 193 | Average loss: 0.13493647181187957
Global Round: 194


Train Epoch: 194/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1220.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 194 | Average loss: 0.134960181225603
Global Round: 195


Train Epoch: 195/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1244.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 196 saved
Global round: 195 | Average loss: 0.13408566892337495
Global Round: 196


Train Epoch: 196/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1330.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 196 | Average loss: 0.13397828051552604
Global Round: 197


Train Epoch: 197/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1359.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 197 | Average loss: 0.13629318016786546
Global Round: 198


Train Epoch: 198/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.1414.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 198 | Average loss: 0.1371325848106378
Global Round: 199


Train Epoch: 199/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.365s. Loss: 0.1347.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 199 | Average loss: 0.13560214472083618
Global Round: 200


Train Epoch: 200/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1213.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 201 saved
Global round: 200 | Average loss: 0.1371625479513083
Global Round: 201


Train Epoch: 201/1001. Iter:  313/ 313. Data: 0.149s. Batch: 0.369s. Loss: 0.1174.: 100%|█| 313/313 [01:54<00:00,  2.74it/


Global round: 201 | Average loss: 0.13696037849393516
Global Round: 202


Train Epoch: 202/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.0941.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 202 | Average loss: 0.13704318196152726
Global Round: 203


Train Epoch: 203/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1404.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 203 | Average loss: 0.1371554778025935
Global Round: 204


Train Epoch: 204/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1129.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 204 | Average loss: 0.13588168012638824
Global Round: 205


Train Epoch: 205/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1341.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 206 saved
Global round: 205 | Average loss: 0.13870454160645365
Global Round: 206


Train Epoch: 206/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.371s. Loss: 0.1651.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 206 | Average loss: 0.13723496595225015
Global Round: 207


Train Epoch: 207/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1511.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 207 | Average loss: 0.140440371922982
Global Round: 208


Train Epoch: 208/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1400.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 208 | Average loss: 0.1395604042009043
Global Round: 209


Train Epoch: 209/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.367s. Loss: 0.1748.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 209 | Average loss: 0.14166628309903434
Global Round: 210


Train Epoch: 210/1001. Iter:  313/ 313. Data: 0.139s. Batch: 0.360s. Loss: 0.1291.: 100%|█| 313/313 [01:52<00:00,  2.78it/


global epoch 211 saved
Global round: 210 | Average loss: 0.13976060320584538
Global Round: 211


Train Epoch: 211/1001. Iter:  313/ 313. Data: 0.142s. Batch: 0.363s. Loss: 0.1305.: 100%|█| 313/313 [01:52<00:00,  2.78it/


Global round: 211 | Average loss: 0.14154623846371714
Global Round: 212


Train Epoch: 212/1001. Iter:  313/ 313. Data: 0.144s. Batch: 0.364s. Loss: 0.1270.: 100%|█| 313/313 [01:54<00:00,  2.75it/


Global round: 212 | Average loss: 0.14187937434584189
Global Round: 213


Train Epoch: 213/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1213.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 213 | Average loss: 0.14075300992487338
Global Round: 214


Train Epoch: 214/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1478.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 214 | Average loss: 0.1439375103757785
Global Round: 215


Train Epoch: 215/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1516.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 216 saved
Global round: 215 | Average loss: 0.1404930204867174
Global Round: 216


Train Epoch: 216/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.1242.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 216 | Average loss: 0.1425914754168675
Global Round: 217


Train Epoch: 217/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1279.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 217 | Average loss: 0.1411907997089453
Global Round: 218


Train Epoch: 218/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1315.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 218 | Average loss: 0.1403715743567235
Global Round: 219


Train Epoch: 219/1001. Iter:  313/ 313. Data: 0.145s. Batch: 0.363s. Loss: 0.1135.: 100%|█| 313/313 [01:53<00:00,  2.75it/


Global round: 219 | Average loss: 0.14043804408071903
Global Round: 220


Train Epoch: 220/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1505.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 221 saved
Global round: 220 | Average loss: 0.14036320549801898
Global Round: 221


Train Epoch: 221/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.366s. Loss: 0.1297.: 100%|█| 313/313 [01:53<00:00,  2.76it/


Global round: 221 | Average loss: 0.13632562687507452
Global Round: 222


Train Epoch: 222/1001. Iter:  313/ 313. Data: 0.142s. Batch: 0.359s. Loss: 0.1164.: 100%|█| 313/313 [01:52<00:00,  2.78it/


Global round: 222 | Average loss: 0.13488881585125725
Global Round: 223


Train Epoch: 223/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1430.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 223 | Average loss: 0.13362513423061217
Global Round: 224


Train Epoch: 224/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1359.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 224 | Average loss: 0.13298352860128537
Global Round: 225


Train Epoch: 225/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1200.: 100%|█| 313/313 [01:54<00:00,  2.73it/


global epoch 226 saved
Global round: 225 | Average loss: 0.13390570210096553
Global Round: 226


Train Epoch: 226/1001. Iter:  313/ 313. Data: 0.143s. Batch: 0.363s. Loss: 0.1287.: 100%|█| 313/313 [01:52<00:00,  2.78it/


Global round: 226 | Average loss: 0.1335067911841237
Global Round: 227


Train Epoch: 227/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1251.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 227 | Average loss: 0.13385269514764078
Global Round: 228


Train Epoch: 228/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1252.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 228 | Average loss: 0.13414703259548058
Global Round: 229


Train Epoch: 229/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1354.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 229 | Average loss: 0.13571715454895275
Global Round: 230


Train Epoch: 230/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1421.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 231 saved
Global round: 230 | Average loss: 0.13640823355688456
Global Round: 231


Train Epoch: 231/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.1165.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 231 | Average loss: 0.13647571027564545
Global Round: 232


Train Epoch: 232/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1214.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 232 | Average loss: 0.13820319613233542
Global Round: 233


Train Epoch: 233/1001. Iter:  313/ 313. Data: 0.140s. Batch: 0.359s. Loss: 0.1246.: 100%|█| 313/313 [01:52<00:00,  2.79it/


Global round: 233 | Average loss: 0.13657185809014324
Global Round: 234


Train Epoch: 234/1001. Iter:  313/ 313. Data: 0.146s. Batch: 0.366s. Loss: 0.1426.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 234 | Average loss: 0.1373657630369686
Global Round: 235


Train Epoch: 235/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.368s. Loss: 0.1328.: 100%|█| 313/313 [01:55<00:00,  2.72it/


global epoch 236 saved
Global round: 235 | Average loss: 0.13740767783726365
Global Round: 236


Train Epoch: 236/1001. Iter:  313/ 313. Data: 0.151s. Batch: 0.371s. Loss: 0.1405.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 236 | Average loss: 0.1367306513622546
Global Round: 237


Train Epoch: 237/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1386.: 100%|█| 313/313 [01:55<00:00,  2.72it/


Global round: 237 | Average loss: 0.13581218165806688
Global Round: 238


Train Epoch: 238/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.368s. Loss: 0.1219.: 100%|█| 313/313 [01:55<00:00,  2.71it/


Global round: 238 | Average loss: 0.1391177409515975
Global Round: 239


Train Epoch: 239/1001. Iter:  313/ 313. Data: 0.143s. Batch: 0.364s. Loss: 0.1631.: 100%|█| 313/313 [01:53<00:00,  2.75it/


Global round: 239 | Average loss: 0.1382516888193429
Global Round: 240


Train Epoch: 240/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1398.: 100%|█| 313/313 [01:54<00:00,  2.73it/


global epoch 241 saved
Global round: 240 | Average loss: 0.1378928375796388
Global Round: 241


Train Epoch: 241/1001. Iter:  313/ 313. Data: 0.150s. Batch: 0.370s. Loss: 0.1563.: 100%|█| 313/313 [01:54<00:00,  2.72it/


Global round: 241 | Average loss: 0.13951686379342035
Global Round: 242


Train Epoch: 242/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1176.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 242 | Average loss: 0.1402803946274538
Global Round: 243


Train Epoch: 243/1001. Iter:  313/ 313. Data: 0.147s. Batch: 0.367s. Loss: 0.1181.: 100%|█| 313/313 [01:54<00:00,  2.73it/


Global round: 243 | Average loss: 0.13862319885732266
Global Round: 244


Train Epoch: 244/1001. Iter:  313/ 313. Data: 0.148s. Batch: 0.374s. Loss: 0.1524.: 100%|█| 313/313 [01:56<00:00,  2.68it/


Global round: 244 | Average loss: 0.1397564835346545
Global Round: 245


Train Epoch: 245/1001. Iter:  138/ 313. Data: 0.144s. Batch: 0.365s. Loss: 0.1674.:  44%|▍| 138/313 [00:49<00:59,  2.97it/

KeyboardInterrupt: 

In [None]:
# Persist the final online-encoder weights to "<save_path>_final.pt".
# Calling `.cpu()` on the module would move `client_model.online_encoder`
# to the CPU *in place*, silently changing the device of the live model for
# every later cell. Instead, copy each state-dict tensor to the CPU so the
# saved file is device-agnostic while the in-memory model stays where it is.
# NOTE(review): this saves `client_model`, which may be a client's local copy
# rather than the aggregated global model -- confirm which one is intended.
# If the training cell was interrupted mid-round, these weights reflect that
# partially trained state.
final_encoder_state = {
    name: tensor.detach().cpu()
    for name, tensor in client_model.online_encoder.state_dict().items()
}
torch.save(final_encoder_state, save_path + "_final.pt")

In [None]:
import matplotlib.pyplot as plt

# Plot the recorded training-loss history (one point per entry of
# `train_loss`, indexed from 0) using the explicit Figure/Axes interface.
fig, ax = plt.subplots()
ax.plot(train_loss)
ax.set(xlabel='epochs', ylabel='Loss', title='Training Loss Curve')
plt.show()