In [1]:
import pandas as pd
import numpy as np
import copy
from argparse import Namespace
from copy import deepcopy

import os
import torch
import torch.nn as nn
from sklearn import preprocessing
import torch.optim as optim
from sklearn.preprocessing import LabelEncoder
from folktables import ACSDataSource, ACSEmployment,ACSIncome
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import tqdm
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
from torchvision import models
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
from typing import Dict, List, OrderedDict, Tuple
import pickle
from collections import OrderedDict
from typing import OrderedDict
from typing import Dict, List, OrderedDict, Tuple, Union


In [2]:
from collections import OrderedDict
import os
# from argparse import Namespace
import random

import torch
from path import Path
from rich.console import Console
from rich.progress import track
import pickle
from tqdm import tqdm

import sys

from utils.models_folktable import  DeepNet
from utils.util import (
    LOG_DIR,
    TEMP_DIR,
    clone_parameters,
    fix_random_seed,
    get_client_id_indices,
)
# from client.base import ClientBase


In [3]:

class ClientBase:
    def __init__(
        self,
        backbone: torch.nn.Module,
        dataset: str,
        batch_size: int,
        valset_ratio: float,
        testset_ratio: float,
        local_epochs: int,
        local_lr: float,
        logger: Console,
        gpu: int,
    ):
        self.device = torch.device(
            "cuda" if gpu and torch.cuda.is_available() else "cpu"
        )
        self.client_id: int = None
        self.valset: DataLoader = None
        self.trainset: DataLoader = None
        self.testset: DataLoader = None
            
#         name_of_model = '../../WW_WM_BW.pth'
#         init_model = DeepNet()
#         init_model.load_state_dict(torch.load(name_of_model))
             
        # need to change
#         self.model: torch.nn.Module = init_model
        self.model: torch.nn.Module = deepcopy(backbone)
            
        self.optimizer: torch.optim.Optimizer = torch.optim.SGD(
            self.model.parameters(), lr=local_lr
        )
        self.batch_size = batch_size
        self.valset_ratio = valset_ratio
        self.testset_ratio = testset_ratio
        self.local_epochs = local_epochs
        self.local_lr = local_lr
        self.criterion = torch.nn.CrossEntropyLoss()
        self.logger = logger
        
    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        size = 0
        loss = 0
        correct = 0
        for x, y in self.testset:
            x, y = x.to(self.device), y.to(self.device)

            logits = self.model(x)

            loss += self.criterion(logits, y)

            pred = torch.softmax(logits, -1).argmax(-1)

            correct += (pred == y).int().sum()

            size += y.size(-1)

        acc = correct / size * 100.0
        loss = loss / len(self.testset)
        return loss, acc

    def train(self):
        pass

    def _train(self):
        pass

    def get_client_local_dataset(self):
        datasets = get_dataloader(
            self.dataset,
            self.client_id,
            self.batch_size,
            self.valset_ratio,
            self.testset_ratio,
        )
        self.trainset = datasets["train"]
        self.valset = datasets["val"]
        self.testset = datasets["test"]

    def _log_while_training(self, evaluate=True, verbose=False):
        def _log_and_train(*args, **kwargs):
            loss_before = 0
            loss_after = 0
            acc_before = 0
            acc_after = 0
            if evaluate:
                loss_before, acc_before = self.evaluate()

            res = self._train(*args, **kwargs)

            if evaluate:
                loss_after, acc_after = self.evaluate()

            if verbose:
                self.logger.log(
                    "client [{}]   [bold red]loss: {:.4f} -> {:.4f}    [bold blue]accuracy: {:.2f}% -> {:.2f}%".format(
                        self.client_id, loss_before, loss_after, acc_before, acc_after
                    )
                )

            eval_stats = {
                "loss_before": loss_before,
                "loss_after": loss_after,
                "acc_before": acc_before,
                "acc_after": acc_after,
            }
            return res, eval_stats

        return _log_and_train

    def set_parameters(self, model_params: OrderedDict):
        self.model.load_state_dict(model_params, strict=True)


In [4]:
from collections import OrderedDict
from typing import OrderedDict

import torch
from rich.console import Console
from utils.util import clone_parameters

# from client.base import ClientBase

class pFedLAClient(ClientBase):
    def __init__(
        self,
        backbone: torch.nn.Module,        
        dataset: str,
        batch_size: int,
        valset_ratio: float,
        testset_ratio: float,
        local_epochs: int,
        local_lr: float,
        logger: Console,
        gpu: int,
    ):
        super(pFedLAClient, self).__init__(
            backbone,
            dataset,
            batch_size,
            valset_ratio,
            testset_ratio,
            local_epochs,
            logger,
            local_lr,
            gpu,
        )
        
    def train(
        self,
        client_id: int,
        model_params: OrderedDict[str, torch.Tensor],
        verbose=True,
    ):
        self.client_id = client_id
        self.set_parameters(model_params)
        self.get_client_local_dataset()
        self.model.to(self.device)
        
        res, stats = self._log_while_training(evaluate=True, verbose=verbose)()
        self.model.cpu()
        return res, stats
        
    def _train(self):
        self.model.train()
        frz_model_params = clone_parameters(self.model)

        for _ in range(self.local_epochs):
            for x, y in self.trainset:
                x, y = x.to(self.device), y.to(self.device)

                logits = self.model(x)

                loss = self.criterion(logits, y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        delta = OrderedDict(
            {
                k: p1 - p0 
                for (k, p1), p0 in zip(
                    self.model.state_dict(keep_vars=True).items(),
                    frz_model_params.values(),
                )
            }
        )
        
        return delta
    
    def test(
        self, client_id: int, model_params: OrderedDict[str, torch.Tensor],
    ):
        self.client_id = client_id
        self.set_parameters(model_params)
        self.get_client_local_dataset()
        self.model.to(self.device)
        loss, acc = self.evaluate()
        dummy_diff = OrderedDict(
            {
                name: torch.zeros_like(param)
                for name, param in self.model.state_dict().items()
            }
        )
        self.model.cpu()
        stats = {"loss": loss, "acc": acc}
        return dummy_diff, stats


In [5]:
from collections import OrderedDict
import os
from argparse import Namespace
import random

import torch
from path import Path
from rich.console import Console
from rich.progress import track
import pickle
from tqdm import tqdm

from utils.util import (
    LOG_DIR,
    TEMP_DIR,
    clone_parameters,
    fix_random_seed,
    get_client_id_indices,
)

class ServerBase:
    def __init__(self, args: Namespace, algo: str):
        self.algo = algo
        self.args = args
        
        # default log file format
        self.log_name = "{}_{}_{}_{}.html".format(
            self.algo,
            self.args.dataset,
            self.args.global_epochs,
            self.args.local_epochs,
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        fix_random_seed(5)
        self.global_epochs=5
        
#         self.backbone= (DeepNet)   
        
        self.backbone = (
            CNNWithBatchNorm
            if self.args.dataset in ["cifar10", "cifar100"]
            else DeepNet
        )
        
        self.logger = Console(record=True, log_path=False, log_time=False,)
        self.client_id_indices, self.client_num_in_total = [0,1,2,3], 3
        
        self.temp_dir ="D:/Download/pythonProject/HiWi/pFedLA_Folktable/temp/my_model"


#         name_of_model = '../../WW_WM_BW.pth'
        #  self.temp_dir = TEMP_DIR / self.algo
        if not os.path.isdir(self.temp_dir):
            os.makedirs(self.temp_dir)
            
            
#         init_model = DeepNet()
#         init_model.load_state_dict(torch.load(name_of_model))
        
#         _dummy_model = init_model
        
#         _dummy_model = self.backbone(DeepNet)
        _dummy_model = DeepNet()
        print("_dummy_model:: chriag\n",_dummy_model)
#         print(type(_dummy_model))
        
        passed_epoch = 0
        self.global_params_dict: OrderedDict[str : torch.Tensor] = None
        
#         if os.listdir(self.temp_dir) != []:
#             if os.path.exists(os.path.join(self.temp_dir, "global_model.pt")):
# #             if os.path.exists(self.temp_dir / "global_model.pt"):
                
#                 self.global_params_dict = torch.load(self.temp_dir / "global_model.pt")
#                 self.logger.log("Find existed global model...")

# #           if os.path.exists(self.temp_dir / "epoch.pkl"):  
# #               with open(self.temp_dir / "epoch.pkl", "rb") as f:
#             if os.path.exists(os.path.join(self.temp_dir, "epoch.pkl")):  
#                 with open(os.path.join(self.temp_dir, "epoch.pkl"), "rb") as f:
# #               
#                     passed_epoch = pickle.load(f)
#                 self.logger.log(f"Have run {passed_epoch} epochs already.",)
#         else:
#             self.global_params_dict = OrderedDict(_dummy_model.state_dict())
            
        self.global_params_dict = OrderedDict(_dummy_model.state_dict())

#         self.global_epochs = self.args.global_epochs - passed_epoch
        self.global_epochs = self.global_epochs - passed_epoch
    
        self.logger.log("Backbone:", _dummy_model)

        self.trainer: ClientBase = None
        self.all_clients_stats = {i: {} for i in self.client_id_indices}
           
            
            
    def train(self):

        print("In server class \n")
        self.logger.log("=" * 30, "TRAINING", "=" * 30, style="bold green")
        progress_bar = (
            track(
                range(self.global_epochs),
                "[bold green]Training...",
                console=self.logger,
            )
            if not self.args.log
            else tqdm(range(self.global_epochs), "Training...")
        )
        for E in progress_bar:

#             if E % self.args.verbose_gap == 0:
#                 self.logger.log("=" * 30, f"ROUND: {E}", "=" * 30)

            selected_clients = random.sample(
                self.client_id_indices, self.args.client_num_per_round
            )
            
            updated_params_cache = []
            weights_cache = []

            for client_id in selected_clients:
                client_local_params = clone_parameters(self.global_params_dict)
                (updated_params, weight), stats = self.trainer.train(
                    client_id=client_id,
                    model_params=client_local_params,
                    verbose=(E % self.args.verbose_gap) == 0,
                )

                updated_params_cache.append(updated_params)
                weights_cache.append(weight)
                self.all_clients_stats[client_id][f"ROUND: {E}"] = (
                    f"{stats['loss_before']:.4f} -> {stats['loss_after']:.4f}",
                )

            self.aggregate_parameters(updated_params_cache, weights_cache)

            if E % self.args.save_period == 0:
                torch.save(
                    self.global_params_dict, self.temp_dir / "global_model.pt",
                )
                with open(self.temp_dir / "epoch.pkl", "wb") as f:
                    pickle.dump(E, f)
        self.logger.log(self.all_clients_stats)
    
    @torch.no_grad()
    def aggregate_parameters(self, updated_params_cache, weights_cache):
        weight_sum = sum(weights_cache)
        weights = torch.tensor(weights_cache, device=self.device) / weight_sum

        aggregated_params = []

        for params in zip(*updated_params_cache):
            aggregated_params.append(
                torch.sum(weights * torch.stack(params, dim=-1), dim=-1)
            )

        self.global_params_dict = OrderedDict(
            zip(self.global_params_dict.keys(), aggregated_params)
        )

        
        
    def test(self) -> None:
        self.logger.log("=" * 30, "TESTING", "=" * 30, style="bold blue")
        all_loss = []
        all_acc = []
        for client_id in track(
            self.client_id_indices,
            "[bold blue]Testing...",
            console=self.logger,
            disable=self.args.log,
        ):
            client_local_params = clone_parameters(self.global_params_dict)
            stats = self.trainer.test(
                client_id=client_id, model_params=client_local_params,
            )

            self.logger.log(
                f"client [{client_id}]  [red]loss: {stats['loss']:.4f}    [magenta]accuracy: {stats['acc']:.2f}%"
            )
            all_loss.append(stats["loss"])
            all_acc.append(stats["acc"])

        self.logger.log("=" * 20, "RESULTS", "=" * 20, style="bold green")
        self.logger.log(
            "loss: {:.4f}    accuracy: {:.2f}%".format(
                sum(all_loss) / len(all_loss), sum(all_acc) / len(all_acc),
            )
        )
        
    
    def run(self):
#         self.logger.log("Arguments:", dict(self.args._get_kwargs()))
        self.train()
        self.test()
        if self.args.log:
            if not os.path.isdir(LOG_DIR):
                os.mkdir(LOG_DIR)
            self.logger.save_html(LOG_DIR / self.log_name)

        # delete all temporary files
        if os.listdir(self.temp_dir) != []:
            os.system(f"rm -rf {self.temp_dir}")


In [6]:

class Linear(nn.Module):
    def __init__(self, in_features, out_features) -> None:
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))

        nn.init.uniform_(self.weight)
        nn.init.constant_(self.bias, 0.0)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    

class HyperNetwork(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        backbone: nn.Module,
        client_num: int=4,
        K: int=2,
        gpu=True,
    ):
    
        super(HyperNetwork, self).__init__()
        self.device = torch.device(
                "cuda" if gpu and torch.cuda.is_available() else "cpu"
            )
        
        self.K = K
        self.client_num = client_num
        self.embedding = nn.Embedding(client_num, embedding_dim, device=self.device)
        self.blocks_name = set(n.split(".")[0] for n, _ in backbone.named_parameters())
        self.cache_dir =  "pkl_files/hn"  # put dir here
        
        if os.listdir(self.cache_dir) != client_num:
            
            for client_id in range(client_num):
#                 with open(self.cache_dir / f"{client_id}.pkl", "wb") as f:
                with open(os.path.join(self.cache_dir, f"{client_id}.pkl"), "wb") as f:

                    pickle.dump(
                        {
                            "mlp": nn.Sequential(
                                nn.Linear(embedding_dim, hidden_dim),
                                nn.ReLU(),
                                nn.Linear(hidden_dim, hidden_dim),
                                nn.ReLU(),
                                nn.Linear(hidden_dim, hidden_dim),
                                nn.ReLU(),
                            ),
                            
                            "fc": {
                                name: Linear(hidden_dim, client_num)
                                for name in self.blocks_name
                            },
                        },
                        f,
                    )

        # for tracking the current client's hn parameters
        self.current_client_id: int = None
        self.mlp: nn.Sequential = None
        self.fc_layers: Dict[str, Linear] = {}
        self.retain_blocks: List[str] = []
            
        print("HypterNetwork")
        
    def mlp_parameters(self) -> List[nn.Parameter]:
            print("self.mlp.parameters():: ", self.mlp)            
            return list(filter(lambda p: p.requires_grad, self.mlp.parameters()))   
        
        
    def fc_layer_parameters(self) -> List[nn.Parameter]:
        params_list = []
        for block, fc in self.fc_layers.items():
            if block not in self.retain_blocks:
                params_list += list(filter(lambda p: p.requires_grad, fc.parameters()))

        return params_list

    def emd_parameters(self) -> List[nn.Parameter]:
        return list(self.embedding.parameters())
        
    
    def forward(self, client_id: int) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        self.current_client_id = client_id

        print(" self.current_client_id : ",self.current_client_id,"\n")

        self.retain_blocks = []
        emd = self.embedding(
            torch.tensor(client_id, dtype=torch.long, device=self.device)
        )
        self.load_hn()

        feature = self.mlp(emd)

        print("TEMP_DIR:: ",TEMP_DIR )

        # print(" features : ",feature,"\n")

        alpha = {
            block: F.relu(self.fc_layers[block](feature)) for block in self.blocks_name
        }

        # print("  alpha: ",alpha,"\n")

        default_weight = torch.tensor(
            [i == client_id for i in range(self.client_num)],
            dtype=torch.float,
            device=self.device,
        )

        print(" self.k ", self.K,"\n")
        # i set K =2 in arg.py

        if self.K > 0:  # HeurpFedLA
            
            blocks_name = []
            self_weights = []
            
            with torch.no_grad():
                for name, weight in alpha.items():

                    # print("-: NAME AND WEIGHT : ", name, weight,"\n")

                    blocks_name.append(name)
                    self_weights.append(weight[client_id])

                # not in the Loop
                _, topk_weights_idx = torch.topk(torch.tensor(self_weights), self.K)

                print("  topk_weights_idx ",topk_weights_idx ,"\n")
                
            for i in topk_weights_idx:
                # print(" topk_weights_idx I  ",i,"\n")
                print(" blocks_name[i] Topk ",blocks_name[i],"\n")
                # print(" default_weight[i] Topk ",default_weight,"\n")

                alpha[blocks_name[i]] = default_weight
                self.retain_blocks.append(blocks_name[i])

        return alpha, self.retain_blocks
    
    
    
    def save_hn(self):
            for block, param in self.fc_layers.items():
                self.fc_layers[block] = param.cpu()
            with open(self.cache_dir / f"{self.current_client_id}.pkl", "wb") as f:
                pickle.dump(
                    {"mlp": self.mlp.cpu(), "fc": self.fc_layers}, f,
                )
            self.mlp = None
            self.fc_layers = {}
            self.current_client_id = None

    def load_hn(self) -> Tuple[nn.Sequential, OrderedDict[str, Linear]]:
        with open(os.path.join(self.cache_dir, f"{self.current_client_id}.pkl"), "rb") as f:
        
#         with open(self.cache_dir / f"{self.current_client_id}.pkl", "rb") as f:
            parameters = pickle.load(f)
        self.mlp = parameters["mlp"].to(self.device)
        for block, param in parameters["fc"].items():
            self.fc_layers[block] = param.to(self.device)

    def clean_models(self):
        if os.path.isdir(self.cache_dir):
            os.system(f"rm -rf {self.cache_dir}")
            
            
            
class DeepNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 14 : input shape
        self.layer1 = nn.Linear(14, 512)
        self.act1 = nn.ReLU()
        self.dropout1 = nn.Dropout(p=0.5)
        self.layer2 = nn.Linear(512, 256)
        self.act2 = nn.ReLU()
        self.layer3 = nn.Linear(256, 60)
        self.act3 = nn.ReLU()
        self.output = nn.Linear(60, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.act1(self.layer1(x))
        x = self.dropout1(x)
        x = self.act2(self.layer2(x))
        x = self.act3(self.layer3(x))
        x = self.sigmoid(self.output(x))
        return x
    

In [7]:
args1 = Namespace(
    k=2,
    global_epochs=5,
    local_epochs=5,
    local_lr=1e-2,
    hn_lr=5e-3,
    verbose_gap=20,
    embedding_dim=10,
    hidden_dim=10,
    dataset="no",
    batch_size=32,
    valset_ratio=0.0,
    testset_ratio=0.3,
    gpu=1,
    log=0,
    seed=5,
    client_num_per_round=4,
    save_period=5,
)


In [8]:
import os
import pickle
import random
from collections import OrderedDict
from typing import List, OrderedDict, Tuple
from rich.console import Console
from rich.progress import track

# from server.base import ServerBase
# from client.pFedLA import pFedLAClient
from tqdm import tqdm
# from utils.args import get_pFedLA_args


class pFedLAServer(ServerBase):
    def __init__(self):
        super(pFedLAServer, self).__init__(args1, "pFedLA")
        
        self.log_name = "{}_{}_{}_{}_{}.html".format(
            self.algo,
            self.args.dataset,
            self.args.global_epochs,
            self.args.local_epochs,
            self.args.k,
        )
        
        bm="..\\1_BM_WM_WW.pth"
        wm="..\\2_WM_BM_BW.pth"
        ww="..\\3_WW_BW_BM.pth"
        bw="..\\4_WW_WM_BW.pth"

        bm_state_dict_values = torch.load(bm).values()
        wm_state_dict_values = torch.load(wm).values()
        ww_state_dict_values = torch.load(ww).values()
        bw_state_dict_values = torch.load(bw).values()


        bm_values_list = list(bm_state_dict_values)
        wm_values_list = list(wm_state_dict_values)
        ww_values_list = list(ww_state_dict_values)
        bw_values_list = list(bw_state_dict_values)

        self.client_model_params_list = [bm_values_list, wm_values_list, ww_values_list, bw_values_list]
        
#         if self.global_params_dict is not None:
#             del self.global_params_dict  # pFedLA don't have global model

#         if os.listdir(self.temp_dir) != []:
                        
# #             if os.path.exists(self.temp_dir / "clients_model.pt"):
#             if os.path.exists(os.path.join(self.temp_dir, "clients_model.pt")):
#                 self.client_model_params_list = torch.load(
#                     self.temp_dir / "clients_model.pt"
#                 )
#                 self.logger.log("Find existed clients model...")
#         else:
#             self.logger.log("Initializing clients model...")
            
#             self.client_model_params_list = [
#                 list(self.backbone(self.args.dataset).state_dict().values())
#                 for _ in range(self.client_num_in_total)
#             ]
            
#         _dummy_model = self.backbone(DeepNet)


        _dummy_model = DeepNet()
        print(type(_dummy_model))
        
#         _dummy_model = self.backbone(self.args.dataset)

        
#         name_of_model = '../../WW_WM_BW.pth'
#         init_model = DeepNet()
#         init_model.load_state_dict(torch.load(name_of_model))
        
        #need to change
#         _dummy_model = init_model
#         self.client_model_params_list = init_model.load_state_dict(torch.load(name_of_model))
        
        self.hypernet = HyperNetwork(
            client_num=4,
            backbone=_dummy_model,
            embedding_dim=10,
            hidden_dim=10,
            K=2,
            gpu=True,
        )
        
        self.trainer = pFedLAClient(
            backbone=_dummy_model,
            dataset=self.args.dataset,
            batch_size=32,
            valset_ratio=0.0,
            testset_ratio=0.3,
            local_epochs=5,
            local_lr=1e-2,
            logger=0,
            gpu=True,
        )
        
        self.all_params_name = [name for name in _dummy_model.state_dict().keys()]
        
        self.trainable_params_name = [
                name
                for name, param in _dummy_model.state_dict(keep_vars=True).items()
                if param.requires_grad
            ]

    def train(self) -> None:
        self.logger.log("=" * 30, "TRAINING", "=" * 30, style="bold green")
        progress_bar = (
            track(
                range(self.global_epochs),
                "[bold green]Training...",
                console=self.logger,
            )
            if not self.args.log
            else tqdm(range(self.global_epochs), "Training...")
        )        
        
        for E in progress_bar:

            if E % self.args.verbose_gap == 0:
                self.logger.log("=" * 30, f"ROUND: {E}", "=" * 30)

            selected_clients = random.sample(
                self.client_id_indices, self.args.client_num_per_round
            )
            for client_id in selected_clients:
                (
                    client_local_params,
                    retain_blocks,
                ) = self.generate_client_model_parameters(client_id)

                diff, stats = self.trainer.train(
                    client_id=client_id,
                    model_params=client_local_params,
                    verbose=(E % self.args.verbose_gap) == 1,
                )
                
                self.all_clients_stats[client_id][f"ROUND: {E}"] = (
                    f"retain {retain_blocks}, {stats['loss_before']:.4f} -> {stats['loss_after']:.4f}",
                )

                self.update_hypernetwork(client_id, diff, retain_blocks)

                self.update_client_model_parameters(client_id, diff)

            if E % self.args.save_period == 0:
                torch.save(
                    self.client_model_params_list, self.temp_dir / "clients_model.pt",
                )
                with open(self.temp_dir / "epoch.pkl", "wb") as f:
                    pickle.dump(E, f)
        self.logger.log(self.all_clients_stats)
    
    def generate_client_model_parameters(
            self, client_id: int
        ) -> Tuple[OrderedDict[str, torch.Tensor], List[str]]:
        
        
            layer_params_dict = dict(
                zip(self.all_params_name, list(zip(*self.client_model_params_list)))
            )

            alpha, retain_blocks = self.hypernet(client_id)

            aggregated_parameters = {}
            default_weight = torch.tensor(
                [i == client_id for i in range(self.client_num_in_total)],
                dtype=torch.float,
                device=self.device,
            )


            for name in self.all_params_name:

                if name in self.trainable_params_name:
                    a = alpha[name.split(".")[0]]

                else:
                    a = default_weight

#                 print("From hypernet:  alpha ", alpha,"\n")

                if a.sum() == 0:
                    self.logger.log(self.all_clients_stats)
                    raise RuntimeError(
                        f"client [{client_id}]'s {name.split('.')[0]} alpha is a all 0 vector"
                    )


                # print("layer_params_dict[name]", name," :: ",layer_params_dict[name][0],"\n")

#                 aggregated_parameters[name] = torch.sum(
#                     a
#                     / a.sum()
#                     * torch.stack(layer_params_dict[name], dim=-1).to(self.device),
#                     dim=-1,
#                 )
                
                aggregated_parameters[name] = torch.sum(a/ a.sum(),dim=-1)

            # in the begining, its  client_model_params_list.value 
            ##putting values to  particular client's weights
            self.client_model_params_list[client_id] = list(aggregated_parameters.values())

            print("retain_blocks : " , retain_blocks,"\n")
            
            return aggregated_parameters, retain_blocks
        
    def update_hypernetwork(
        self,
        client_id: int,
        diff: OrderedDict[str, torch.Tensor],
        retain_blocks: List[str] = [],
    ) -> None:
        # calculate gradients
        print( "self.client_model_params_list[client_id] ",client_id,"  \n")
        # print(self.client_model_params_list[client_id])

        hn_grads = torch.autograd.grad(
            outputs=list(
                filter(
                    lambda param: param.requires_grad,
                    self.client_model_params_list[client_id],
                )
            ),
            inputs=self.hypernet.mlp_parameters()
            + self.hypernet.fc_layer_parameters()
            + self.hypernet.emd_parameters(),

            grad_outputs=list(
                map(
                    lambda tup: tup[1],
                    filter(
                        lambda tup: tup[1].requires_grad
                        and tup[0].split(".")[0] not in retain_blocks,
                        diff.items(),
                    ),
                )
            ),
            allow_unused=True,
        )
        
        mlp_grads = hn_grads[: len(self.hypernet.mlp_parameters())]
        fc_grads = hn_grads[
            len(self.hypernet.mlp_parameters()) : len(
                self.hypernet.mlp_parameters() + self.hypernet.fc_layer_parameters()
            )
        ]
        emd_grads = hn_grads[
            len(self.hypernet.mlp_parameters() + self.hypernet.fc_layer_parameters()) :
        ]

        for param, grad in zip(self.hypernet.fc_layer_parameters(), fc_grads):
            if grad is not None:
                param.data -= self.args.hn_lr * grad

        for param, grad in zip(self.hypernet.mlp_parameters(), mlp_grads):
            param.data -= self.args.hn_lr * grad

        for param, grad in zip(self.hypernet.emd_parameters(), emd_grads):
            param.data -= self.args.hn_lr * grad

        self.hypernet.save_hn()
        
        
    def run(self):
        super().run()
        # clean out all HNs
        self.hypernet.clean_models()

if __name__ == "__main__":
    server = pFedLAServer()
    server.run()


_dummy_model:: chriag
 DeepNet(
  (layer1): Linear(in_features=14, out_features=512, bias=True)
  (act1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (layer2): Linear(in_features=512, out_features=256, bias=True)
  (act2): ReLU()
  (layer3): Linear(in_features=256, out_features=60, bias=True)
  (act3): ReLU()
  (output): Linear(in_features=60, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)


<class '__main__.DeepNet'>
HypterNetwork


Output()

RuntimeError: Error(s) in loading state_dict for DeepNet:
	size mismatch for layer1.weight: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([512, 14]).
	size mismatch for layer1.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for layer2.weight: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for layer2.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for layer3.weight: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([60, 256]).
	size mismatch for layer3.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([60]).
	size mismatch for output.weight: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([1, 60]).
	size mismatch for output.bias: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([1]).