In [None]:
import torch 
import random 
import numpy as np

    
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, current CUDA device, all CUDA devices), then puts cuDNN
    into deterministic mode with benchmarking (auto-tuning) switched off.

    Args:
        seed: Seed value handed to every generator. Defaults to 42.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current GPU, all GPUs.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    


In [None]:
# Scratch check: 32 * 7 * 7 = 1568, the input width used by the client heads'
# first linear layer further down in this notebook.
# NOTE(review): presumably 32 conv channels x 7 x 7 after 2x2 pooling of a
# 14x14 slice -- confirm against the transform configuration in use.
32 * 7 *7

In [None]:
import os
from dotenv import load_dotenv
from federated_inference.common.environment import DataSetEnum

class DataConfiguration():
    """Settings for the image datasets this notebook can work with.

    The class attributes hold, per dataset, its name, its default on-disk
    location and (where defined) its human-readable class labels. Creating an
    instance selects one dataset and exposes its settings as ``DATASET_NAME``,
    ``INSTANCE_SIZE``, ``DATASET``, ``DATASET_PATH`` and ``LABELS``.
    """

    # Fashion-MNIST
    FMNIST_NAME = 'FMNIST'
    FMNIST_DATASET_PATH = os.path.join('./data/fmnist')
    FMNIST_LABELS = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker',  'Bag', 'Ankle Boot']

    # MNIST
    MNIST_NAME = 'MNIST'
    MNIST_DATASET_PATH = os.path.join('./data/mnist')

    # CIFAR-10
    CIFAR10_NAME = 'CIFAR10'
    CIFAR10_DATASET_PATH = os.path.join('./data/cifar10')
    CIFAR10_LABELS = ['Plane', 'Car', 'Bird', 'Cat','Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

    def __init__(self, dataset_name : str = None):
        """Select a dataset and populate the per-instance settings.

        Args:
            dataset_name: Name of the dataset to use. When empty or ``None``
                the ``DATASET_NAME`` environment variable is used, falling
                back to ``'MNIST'``.

        The ``.env`` file is (re)loaded first with ``override=True``. If the
        ``DATASET_PATH`` environment variable is set it replaces the dataset's
        default path. When the name matches none of the three supported
        datasets, only ``DATASET_NAME`` is set on the instance.
        """
        load_dotenv(override=True)
        self.DATASET_NAME = dataset_name or os.getenv('DATASET_NAME', 'MNIST')

        chosen = self.DATASET_NAME
        if chosen == DataSetEnum.MNIST.value:
            self._select(DataSetEnum.MNIST, (1, 28, 28), self.MNIST_DATASET_PATH, range(10))
        if chosen == DataSetEnum.FMNIST.value:
            self._select(DataSetEnum.FMNIST, (1, 28, 28), self.FMNIST_DATASET_PATH, self.FMNIST_LABELS)
        if chosen == DataSetEnum.CIFAR10.value:
            # NOTE(review): CIFAR-10 images normally have 3 channels; (1, 32, 32)
            # presumably relies on a single-channel conversion upstream -- confirm.
            self._select(DataSetEnum.CIFAR10, (1, 32, 32), self.CIFAR10_DATASET_PATH, self.CIFAR10_LABELS)

    def _select(self, dataset, instance_size, default_path, labels):
        """Store the settings of the chosen dataset on the instance."""
        self.INSTANCE_SIZE = instance_size
        self.DATASET = dataset
        # The DATASET_PATH environment variable wins over the built-in default.
        self.DATASET_PATH = os.path.join(os.getenv('DATASET_PATH', default_path))
        self.LABELS = labels

    def __dict__(self):
        """Return the dataset name, label list and instance size as a dict.

        Note: because this is a method named ``__dict__``, it shadows the
        usual instance attribute of that name and has to be *called*
        (``config.__dict__()``).
        """
        summary = {}
        summary["dataset_name"] = self.DATASET_NAME
        summary["labels"] = list(self.LABELS)
        summary["size"] = self.INSTANCE_SIZE
        return summary

In [None]:

import torch
import torch.nn as nn

class HybridSplitModelConfiguration():
    """Model factories and training hyper-parameters for the hybrid split setup.

    ``DEVICE`` is resolved once, when the class is created: CUDA if available,
    otherwise Apple MPS if available, otherwise CPU.
    """

    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    elif torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
    else:
        DEVICE = torch.device("cpu")

    def __init__(self,  
        server_model: nn.Module, 
        client_base_model: nn.Module, 
        client_classifier_head: nn.Module, 
        client_ig_head: nn.Module,          
        learning_rate: float = 0.001, 
        batch_size_train: int = 64, 
        batch_size_val: int = 64, 
        batch_size_test: int = 64, 
        train_shuffle: bool = True, 
        val_shuffle: bool = False, 
        test_suffle: bool = False, 
        n_epochs: int = 60,
        train_ratio: float = 0.8):
        """Store the model factories and hyper-parameters.

        Args:
            server_model: Stored as ``SERVER_MODEL``. The client and server
                classes in this notebook call the four model attributes with
                no arguments to build model instances.
            client_base_model: Stored as ``CLIENT_BASE_MODEL``.
            client_classifier_head: Stored as ``CLIENT_CLASSIFIER_MODEL``.
            client_ig_head: Stored as ``CLIENT_IG_MODEL``.
            learning_rate: Stored as ``LEARNING_RATE``.
            batch_size_train: Stored as ``BATCH_SIZE_TRAIN``.
            batch_size_val: Stored as ``BATCH_SIZE_VAL``.
            batch_size_test: Stored as ``BATCH_SIZE_TEST``.
            train_shuffle: Stored as ``TRAIN_SHUFFLE``.
            val_shuffle: Stored as ``VAL_SHUFFLE``.
            test_suffle: Stored as ``TEST_SHUFFLE`` (the keyword keeps its
                historical spelling so existing keyword callers still work).
            n_epochs: Stored as ``N_EPOCH``.
            train_ratio: Stored as ``TRAIN_RATIO``.
        """
        # Model factories.
        self.SERVER_MODEL = server_model
        self.CLIENT_BASE_MODEL = client_base_model
        self.CLIENT_CLASSIFIER_MODEL = client_classifier_head
        self.CLIENT_IG_MODEL = client_ig_head

        # Optimisation settings.
        self.LEARNING_RATE, self.TRAIN_RATIO = learning_rate, train_ratio

        # Data loader settings.
        self.BATCH_SIZE_TRAIN, self.BATCH_SIZE_VAL, self.BATCH_SIZE_TEST = (
            batch_size_train, batch_size_val, batch_size_test)
        self.TRAIN_SHUFFLE, self.VAL_SHUFFLE, self.TEST_SHUFFLE = (
            train_shuffle, val_shuffle, test_suffle)

        self.N_EPOCH = n_epochs

    def __dict__(self):
        """Return the scalar hyper-parameters (not the models) as a dict.

        Note: being a method named ``__dict__``, this shadows the usual
        instance attribute and has to be called: ``config.__dict__()``.
        """
        return dict(
            learning_rate=self.LEARNING_RATE,
            train_ratio=self.TRAIN_RATIO,
            batch_size_train=self.BATCH_SIZE_TRAIN,
            batch_size_val=self.BATCH_SIZE_VAL,
            batch_size_test=self.BATCH_SIZE_TEST,
            train_shuffle=self.TRAIN_SHUFFLE,
            val_shuffle=self.VAL_SHUFFLE,
            test_shuffle=self.TEST_SHUFFLE,
            n_epoch=self.N_EPOCH,
        )



In [None]:
class DataTransformConfiguration(): 
    """Parameters describing how an input tensor is transformed/sliced.

    The values are only stored here; their interpretation is up to the code
    that consumes the configuration.
    """

    def __init__(self, tensor_size = [1,28,28], method_name = 'full', mask_size = [14,14], dimensions = [1,2], stride = 14, n_position = None, drop_p = None):
        """Store the transform parameters.

        Args:
            tensor_size: Stored as ``TENSOR_SIZE``.
            method_name: Stored as ``METHOD_NAME``.
            mask_size: Stored as ``MASK_SIZE``.
            dimensions: Stored as ``DIMENSIONS``.
            stride: Stored as ``STRIDE``.
            n_position: Stored as ``N_POSITION``.
            drop_p: Stored as ``DROP_P``.

        List arguments are copied before being stored. The defaults in the
        signature are mutable lists shared by every call, so storing them
        directly would let a mutation on one instance silently change the
        defaults seen by every later instance.
        """
        self.TENSOR_SIZE = self._own_copy(tensor_size)
        self.METHOD_NAME = method_name
        self.MASK_SIZE = self._own_copy(mask_size)
        self.DIMENSIONS = self._own_copy(dimensions)
        self.STRIDE = stride
        self.N_POSITION = n_position
        self.DROP_P = drop_p

    @staticmethod
    def _own_copy(value):
        """Return a shallow copy of ``value`` if it is a list, else ``value`` itself."""
        return list(value) if isinstance(value, list) else value

    def __dict__(self):
        """Return all stored parameters as a dict.

        Note: being a method named ``__dict__``, this shadows the usual
        instance attribute and has to be called: ``config.__dict__()``.
        """
        return {
            "tensor_size": self.TENSOR_SIZE, 
            "method_name": self.METHOD_NAME,
            "mask_size": self.MASK_SIZE,
            "dimensions": self.DIMENSIONS,
            "stride": self.STRIDE,
            "n_position": self.N_POSITION,
            "drop_p": self.DROP_P,

        }



In [None]:
import torch.nn as nn
class HybridSplitBase(nn.Module):
    """Client-side convolutional feature extractor.

    Two 3x3 convolution stages (1 -> 16 -> 32 channels, each followed by batch
    normalisation and ReLU), then 2x2 max pooling and dropout (p=0.25). The
    forward pass returns the result flattened to one vector per sample.
    """

    def __init__(self):
        super().__init__()
        stage_one = [nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16), nn.ReLU()]
        stage_two = [nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU()]
        reduce = [nn.MaxPool2d(2), nn.Dropout(0.25)]
        # Same layers in the same order, so state_dict keys stay conv_layers.0 .. conv_layers.7.
        self.conv_layers = nn.Sequential(*stage_one, *stage_two, *reduce)

    def forward(self, x):
        """Return the flattened feature map, shape ``(batch, features)``."""
        features = self.conv_layers(x)
        return features.view(features.shape[0], -1)

class LocalHybridSplitClassifierHead(nn.Module):
    """Client-side classification head applied to the flattened base features.

    Linear -> ReLU -> Dropout -> Linear, producing one logit per class.

    Args:
        in_features: Width of the incoming feature vector. Defaults to 1568.
        hidden_features: Width of the hidden layer. Defaults to 64.
        num_classes: Number of output logits. Defaults to 10.
        dropout: Dropout probability after the hidden layer. Defaults to 0.5.

    All arguments are optional; calling the class with no arguments builds
    exactly the previously hard-coded 1568 -> 64 -> 10 network.
    """

    def __init__(self, in_features=1568, hidden_features=64, num_classes=10, dropout=0.5):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        return self.classifier(x)

class RouterHead(nn.Module):
    """Client-side routing head producing one sigmoid score per sample.

    Linear -> ReLU -> Dropout -> Linear(1) -> Sigmoid.

    Args:
        in_features: Width of the incoming feature vector. Defaults to 1568.
        hidden_features: Width of the hidden layer. Defaults to 64.
        dropout: Dropout probability after the hidden layer. Defaults to 0.5.

    All arguments are optional; calling the class with no arguments builds
    exactly the previously hard-coded 1568 -> 64 -> 1 network.
    """

    def __init__(self, in_features=1568, hidden_features=64, dropout=0.5):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return scores in [0, 1] with shape ``(batch, 1)``."""
        return self.classifier(x)

class GlobalHybridSplitCNN(nn.Module):
    """Server-side classifier applied to the concatenated client features.

    Linear -> ReLU -> Dropout -> Linear, producing one logit per class.

    Args:
        in_features: Width of the concatenated feature vector. Defaults to
            6272 (four client feature vectors of width 1568).
        hidden_features: Width of the hidden layer. Defaults to 128.
        num_classes: Number of output logits. Defaults to 10.
        dropout: Dropout probability after the hidden layer. Defaults to 0.5.

    All arguments are optional; calling the class with no arguments builds
    exactly the previously hard-coded 6272 -> 128 -> 10 network.
    """

    def __init__(self, in_features=6272, hidden_features=128, num_classes=10, dropout=0.5):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, num_classes)
        )

    def forward(self, x_concat):
        """Return class logits of shape ``(batch, num_classes)``."""
        return self.classifier(x_concat)
        

In [None]:
from torch.utils.data import Subset
from torch.utils.data import Dataset as TorchDataset
from federated_inference.common.environment import Member
from collections.abc import Iterable
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import pandas as pd
from torch.utils.data import DataLoader

class HybridSplitClient():
    """Client-side participant: owns a dataset slice and three local models.

    On construction the base, classifier and router models are built from the
    factories on ``model_config`` and moved to ``model_config.DEVICE``.
    """

    def __init__(self, 
            idx, 
            seed: int,
            data_config: DataConfiguration,
            model_config: HybridSplitModelConfiguration,
            dataset: TorchDataset,
            labels,
            log: bool = True, 
            log_interval: int = 100, 
            save_interval: int = 20
        ):
        """Store the configuration and instantiate the client models.

        Args:
            idx: Client identifier; used in checkpoint file names.
            seed: Run seed; used in the checkpoint directory path.
            data_config: Dataset configuration (``DATASET_NAME`` is read by ``load``).
            model_config: Provides ``DEVICE`` and the three client model factories.
            dataset: Object exposing ``train_dataset`` and ``test_dataset``.
            labels: Class labels; their count defines ``numerical_labels``.
            log: Stored as ``self.log``.
            log_interval: Stored as ``self.log_interval``.
            save_interval: Stored as ``self.save_interval``.
        """
        self.idx = idx
        self.seed = seed
        self.data = dataset
        self.data_config = data_config
        self.model_config = model_config
        self.device = model_config.DEVICE
        self.labels = labels
        self.numerical_labels = range(len(labels))
        self.member_type = Member.CLIENT
        self.model = None
        self.log = log
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.base_model = model_config.CLIENT_BASE_MODEL().to(self.device)
        self.classifier_model = model_config.CLIENT_CLASSIFIER_MODEL().to(self.device)
        self.router_model = model_config.CLIENT_IG_MODEL().to(self.device)

    def select_subset(self, ids: Iterable[int], set_type: str = "train"):
        """Return a ``Subset`` of the test set (``set_type == "test"``) or, for any other value, of the train set."""
        if set_type == "test":
            return Subset(self.data.test_dataset, ids)
        else: 
            return Subset(self.data.train_dataset, ids)

    def send_all(self):
        """Return the complete training dataset."""
        return self.data.train_dataset

    def request_pred(self, idx: int|None = None, set_type: str = "test", pred_all: bool= False, keep_label: bool = False): 
        """Return test data to predict on.

        Args:
            idx: Index of a single test sample, or ``None``.
            set_type: Only ``"test"`` is served when ``idx`` is given.
            pred_all: When ``idx`` is ``None``, return the whole test set.
            keep_label: Include labels; otherwise only the inputs are returned.

        Returns:
            A single sample, the whole test set (or a list of its inputs), or
            ``None`` when no branch applies (``idx`` given with a non-"test"
            ``set_type``, or ``idx`` is ``None`` and ``pred_all`` is false).
        """
        if idx is not None:
            if set_type == "test":
                sample = self.data.test_dataset[idx]
                return sample if keep_label else sample[0]
        elif pred_all:
            return self.data.test_dataset if keep_label else [img for img, label in self.data.test_dataset]

    def check(self, predicted_labels, pred_all: bool = True):
        """Score predictions against the test-set targets.

        When ``pred_all`` is true, prints accuracy and macro-averaged
        precision, recall and F1, and stores them together with a labelled
        confusion matrix on ``self.accuracy``, ``self.precision``,
        ``self.recall``, ``self.f1`` and ``self.cm``. Does nothing otherwise.
        """
        if pred_all:
            true_labels = self.data.test_dataset.targets
            accuracy = accuracy_score(true_labels, predicted_labels)
            precision = precision_score(true_labels, predicted_labels, average='macro')
            recall = recall_score(true_labels, predicted_labels, average='macro')
            f1 = f1_score(true_labels, predicted_labels, average='macro') 

            print("\n=== Metrics ===")
            print(f"Accuracy : {accuracy:.4f}")
            print(f"Precision: {precision:.4f}")
            print(f"Recall   : {recall:.4f}")
            print(f"F1 Score : {f1:.4f}") 

            cm = confusion_matrix(true_labels, predicted_labels, labels=self.numerical_labels)
            self.cm = pd.DataFrame(cm, index=[f'True {l}' for l in self.labels],
                                    columns=[f'Pred {l}' for l in self.labels])
            self.accuracy = accuracy
            self.precision = precision
            self.recall = recall
            self.f1 = f1 

    def _load_state_dict(self, model, file_name):
        """Load ``file_name`` from this run's client results folder into ``model``.

        ``map_location`` makes checkpoints written on another device (e.g. a
        GPU) loadable on the device this client actually runs on.
        """
        result_path = f"./results/hybrid/{self.data_config.DATASET_NAME}/{self.seed}"
        client_result_path = os.path.join(result_path, "clients") 
        model_path = os.path.join(client_result_path, file_name).replace("\\", "/")
        network_state_dict = torch.load(model_path, map_location=self.device)
        return model.load_state_dict(network_state_dict)

    def load(self):
        """Restore the base, classifier and router weights saved for this client.

        Files are read from ``./results/hybrid/<dataset>/<seed>/clients``.
        """
        # Base. `self.base` keeps receiving load_state_dict's return value, as before.
        self.base = self._load_state_dict(self.base_model, f'model_client_base_{self.idx}.pth')

        # Classifier Head
        self._load_state_dict(self.classifier_model, f'model_client_classifier_{self.idx}.pth')

        # Router Head
        self._load_state_dict(self.router_model, f'model_client_router_{self.idx}.pth')

    def to_loader(self, testdata):
        """Create ``self.testloader`` over this client's own test dataset.

        NOTE(review): the ``testdata`` argument is not used; the loader is
        always built from ``self.data.test_dataset`` -- confirm this is intended.
        """
        self.testloader = DataLoader(self.data.test_dataset, batch_size=self.model_config.BATCH_SIZE_TEST, shuffle=False) 

            


In [None]:
from federated_inference.common.environment import Member
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import os
from collections import defaultdict

class EarlyStopper():
    """Track a validation loss and flag when it has stopped improving.

    Attributes:
        patience: Number of non-improving calls after which ``early_stop`` is set.
        min_delta: Amount by which a loss must undercut the best loss to count
            as an improvement.
        counter: Non-improving calls since the last improvement.
        best_loss: Best loss seen so far (``None`` before the first call).
        early_stop: Becomes ``True`` once ``counter`` reaches ``patience``.
    """

    def __init__(self, patience = 20, min_delta = 0.00001):
        self.best_loss = None
        self.counter = 0
        self.early_stop = False
        self.patience = patience
        self.min_delta = min_delta

    def __call__(self, val_loss):
        """Record one validation loss and update the stopping state."""
        # The first observation simply becomes the reference value.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        no_improvement = val_loss > self.best_loss - self.min_delta
        if not no_improvement:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        

class HybridSplitServer():

    def __init__(self, 
            idx, 
            seed,
            model_config: HybridSplitModelConfiguration,
            data_config: DataConfiguration,
            log: bool = True, 
            log_interval: int = 100,
            save_interval: int = 20
        ):
        """Build the server model, the per-client models and their optimizers.

        Args:
            idx: Server identifier.
            seed: Run seed.
            model_config: Provides ``DEVICE``, ``N_EPOCH``, ``LEARNING_RATE``
                and the four model factories.
            data_config: Dataset configuration.
            log: When true, the ``train_losses``, ``test_losses`` and
                ``accuracies`` history lists are created.
            log_interval: Stored as ``self.log_interval``.
            save_interval: Stored as ``self.save_interval``.

        The number of clients is fixed to 4. Every model gets its own Adam
        optimizer with the configured learning rate and zero weight decay.
        """
        self.idx = idx
        self.seed = seed
        self.model_config = model_config
        self.data_config = data_config
        self.number_of_clients = 4
        self.n_epoch = model_config.N_EPOCH
        self.device = model_config.DEVICE
        self.member_type = Member.SERVER

        def build(factory):
            # Instantiate a model and place it on the configured device.
            return factory().to(self.device)

        # Construction order (server, bases, classifiers, IG heads) is kept so
        # that random weight initialisation is unchanged for a given seed.
        clients = range(self.number_of_clients)
        self.server_model = build(model_config.SERVER_MODEL)
        self.client_base_models = [build(model_config.CLIENT_BASE_MODEL) for _ in clients]
        self.client_classifier_models = [build(model_config.CLIENT_CLASSIFIER_MODEL) for _ in clients]
        self.client_ig_models = [build(model_config.CLIENT_IG_MODEL) for _ in clients]

        self.CRITERION = nn.CrossEntropyLoss()
        self.CRITERION_NAME = "CrossEntropyLoss"
        self.weight_decay = 0

        def adam(model):
            # One Adam optimizer per model, all sharing the same settings.
            return optim.Adam(model.parameters(), lr=model_config.LEARNING_RATE, weight_decay=self.weight_decay)

        self.SERVER_OPTIMIZER = adam(self.server_model)
        self.CLIENT_BASE_OPTIMIZERS = [adam(model) for model in self.client_base_models]
        self.CLIENT_CLASSIFIER_OPTIMIZERS = [adam(model) for model in self.client_classifier_models]
        self.CLIENT_IG_OPTIMIZERS = [adam(model) for model in self.client_ig_models]
        self.OPTIMIZER_NAME = "Adam"
        self.LOCAL_CLASSIFIER_CRITERION = nn.CrossEntropyLoss()
        self.LOCAL_IG_CRITERION = nn.BCELoss()

        self.log = log
        self.log_interval = log_interval
        self.save_interval = save_interval

        if self.log:
            self.train_losses, self.test_losses, self.accuracies = [], [], []


    def _to_loader(self, trainsets, testsets, batch_size_train, batch_size_val, batch_size_test, train_shuffle, val_shuffle, test_shuffle, train_ratio):
        # TODO refactoing to use self
        if True:
            dataset_length = len(trainsets[0])
            assert all(len(trainset) == dataset_length for trainset in trainsets), "All trainsets must be the same length"

            indices = np.arange(dataset_length)
            self.train_set_indices = np.arange(dataset_length)

            np.random.shuffle(indices)

            train_end = round(train_ratio * dataset_length)
            train_indices = indices[:train_end]
            val_indices = indices[train_end:]

            traindatas = [Subset(trainset, train_indices) for trainset in trainsets]
            valdatas = [Subset(trainset, val_indices) for trainset in trainsets]
        else:
            traindatas = [Subset(trainset, range(round(train_ratio*len(trainset)))) for trainset in trainsets]
            valdatas = [Subset(trainset, range(round(train_ratio*len(trainset)), len(trainset))) for trainset in trainsets]
        self.trainloader = [DataLoader(traindata, batch_size=batch_size_train, shuffle=False)  for traindata in traindatas]
        self.valloader = [DataLoader(valdata, batch_size=batch_size_val, shuffle=False) for valdata in valdatas]
        self.testloader = [DataLoader(testdata, batch_size=batch_size_test, shuffle=False) for testdata in testsets]

    def shuffle_loader(self, trainsets, batch_size_train, batch_size_val, train_ratio):
        # TODO refactoing to use self
        if True:
            dataset_length = len(trainsets[0])
            assert all(len(trainset) == dataset_length for trainset in trainsets), "All trainsets must be the same length"

            np.random.shuffle(self.train_set_indices)

            train_end = round(train_ratio * dataset_length)
            train_indices = self.train_set_indices[:train_end]
            val_indices = self.train_set_indices[train_end:]

            traindatas = [Subset(trainset, train_indices) for trainset in trainsets]
            valdatas = [Subset(trainset, val_indices) for trainset in trainsets]
        else:
            traindatas = [Subset(trainset, range(round(train_ratio*len(trainset)))) for trainset in trainsets]
            valdatas = [Subset(trainset, range(round(train_ratio*len(trainset)), len(trainset))) for trainset in trainsets]
        self.trainloader = [DataLoader(traindata, batch_size=batch_size_train, shuffle=False)  for traindata in traindatas]
        self.valloader = [DataLoader(valdata, batch_size=batch_size_val, shuffle=False) for valdata in valdatas]


    def _pred_loader(self, testsets, batch_size_test, test_shuffle):
        return  [DataLoader(testdata, batch_size=batch_size_test, shuffle=False) for testdata in testsets]
        
    def select_best_ig_threshold(self, ig_outputs, client_preds, server_preds, targets):
        import numpy as np
        """
        Args:
            ig_outputs (Tensor): shape (N,), sigmoid output from IG model
            client_preds (Tensor): shape (N,), predictions from client classifiers
            server_preds (Tensor): shape (N,), predictions from server
            targets (Tensor): shape (N,), ground truth labels
        """
        thresholds = np.arange(0.0, 1.01, 0.01)
        best_acc = 0
        best_threshold = 0.5
    
        for thresh in thresholds:
            use_server = ig_outputs >= thresh
            final_preds = torch.where(use_server, server_preds, client_preds)
            acc = (final_preds == targets).float().mean().item()
    
            if acc > best_acc:
                best_acc = acc
                best_threshold = thresh
    
            return best_threshold, best_acc

    def evaluate_and_select_ig_thresholds(self):
        """
        Evaluate and select the best IG threshold for each client based on validation performance.
        Stores best thresholds in self.client_ig_thresholds[i].
        """
        self.client_ig_thresholds = []
    
        for i, ig_model in enumerate(self.client_ig_models):
            ig_model.eval()
    
            all_ig_outputs = []
            all_client_preds = []
            all_server_preds = []
            all_targets = []
    
            with torch.no_grad():
                for batches in zip(*self.valloader):
                    data_slices = [batch[0].to(self.device).float() for batch in batches]
                    target = batches[0][1].to(self.device).long()
    
                    # Forward pass
                    client_activations = [base(data) for base, data in zip(self.client_base_models, data_slices)]
                    concat_activations = torch.cat(client_activations, dim=1)
                    server_output = self.server_model(concat_activations)
                    server_pred = server_output.argmax(dim=1)
    
                    # This client's outputs
                    activation = client_activations[i]
                    classifier_pred = self.client_classifier_models[i](activation).argmax(dim=1)
                    ig_output = self.client_ig_models[i](activation).squeeze(1)
    
                    all_ig_outputs.append(ig_output)
                    all_client_preds.append(classifier_pred)
                    all_server_preds.append(server_pred)
                    all_targets.append(target)
    
            # Stack for evaluation
            ig_outputs = torch.cat(all_ig_outputs)
            client_preds = torch.cat(all_client_preds)
            server_preds = torch.cat(all_server_preds)
            targets = torch.cat(all_targets)
    
            # Select best threshold
            best_threshold, best_acc = self.select_best_ig_threshold(
                ig_outputs, client_preds, server_preds, targets
            )
            self.client_ig_thresholds.append(best_threshold)
    
            print(f"[Client {i}] Best IG Threshold: {best_threshold:.2f} | Accuracy: {best_acc:.4f}")
        

    
    def train(self, epoch):
        for batch_idx, batches in enumerate(zip(*self.trainloader)):
            data_slices = [batch[0].to(self.device).float() for batch in batches]
            target = batches[0][1].to(self.device).long()
    
            # ====== Base forward passes ======
            client_activations = []
            classifier_grads_per_client = []
            classifier_losses = []
            classifier_preds = []
    
            for data, base_model, classifier_model, optimizer, early_stopper in zip(data_slices, self.client_base_models,  self.client_classifier_models, self.CLIENT_CLASSIFIER_OPTIMIZERS, self.client_early_stopper):
                base_model.train()
                activation = base_model(data)
                activation.requires_grad_()
                client_activations.append(activation)
                classifier_model.train()
                optimizer.zero_grad()
                classifier_output = classifier_model(activation)
                classifier_pred = classifier_output.argmax(dim=1)
                loss = self.LOCAL_CLASSIFIER_CRITERION(classifier_output, target)
                classifier_losses.append(loss.item())
                grads = torch.autograd.grad(loss, [activation] + list(classifier_model.parameters()), retain_graph=True)
                classifier_grads_per_client.append(grads[0])

                for param, grad in zip(classifier_model.parameters(), grads[1:]):
                    param.grad = grad
                optimizer.step()
                classifier_preds.append(classifier_pred)
            concat_activations = torch.cat(client_activations, dim=1)
    
            # ====== Server forward and backprop======
            self.server_model.train()
            self.SERVER_OPTIMIZER.zero_grad()
            server_output = self.server_model(concat_activations)
            server_loss = self.CRITERION(server_output, target)
            server_pred = server_output.argmax(dim=1)
                server_grads = torch.autograd.grad(server_loss, [concat_activations] + list(self.server_model.parameters()), retain_graph=True)
                #Server Backprop
                server_concat_activation_grad = server_grads[0]
                server_model_grads = server_grads[1:]
                
                # Apply gradients manually to model parameters
                for param, grad in zip(self.server_model.parameters(), server_model_grads):
                    param.grad = grad  # Set .grad for optimizer
        
                self.SERVER_OPTIMIZER.step()
    

            # ====== IG forward and backprop ======
            ig_grads_per_client = []
            for i, (activation, ig_model, ig_optimizer) in enumerate(zip(client_activations, self.client_ig_models, self.CLIENT_IG_OPTIMIZERS)):
                ig_model.train()
                ig_optimizer.zero_grad()
            
                # Get IG prediction
                ig_output = ig_model(activation).squeeze(1)  # (B,)
            
                client_wrong = (classifier_pred != target)
                server_right = (server_pred == target)
                ig_target = (client_wrong & server_right).float() # Shape: (batch_size, 1)
                ig_loss = self.LOCAL_IG_CRITERION(ig_output, ig_target)

                grads = torch.autograd.grad(ig_loss, [activation] + list(ig_model.parameters()))
                ig_grads_per_client.append(grads[0])
            
                for param, grad in zip(ig_model.parameters(), grads[1:]):
                    param.grad = grad
                ig_optimizer.step()
    
            # ====== Base Backprop combined gradients ======
            if not self.server_early_stopper.early_stop:
                activation_sizes = [act.shape[1] for act in client_activations]
                activation_grads_server = torch.split(server_concat_activation_grad, activation_sizes, dim=1)
        
                for i, (base_model, optimizer, data) in enumerate(zip(self.client_base_models, self.CLIENT_BASE_OPTIMIZERS, data_slices)):
                    optimizer.zero_grad()
        
                    # Combine gradients: server + classifier + IG
                    combined_grad = activation_grads_server[i] + classifier_grads_per_client[i] + ig_grads_per_client[i]
        
                    # Backward through base
                    activation = base_model(data)
                    activation.backward(combined_grad)
                    optimizer.step()
                
            if batch_idx % self.log_interval == 0:
                print(f"Epoch {epoch} | Batch {batch_idx} | Server Loss: {server_loss.item():.4f} | "
                    f"Classifier Losses: {[round(x, 4) for x in classifier_losses]}")

    def validate(self):
        self.server_model.eval()
        
        val_loss = 0
        with torch.no_grad():
            for batch_idx, batches in enumerate(zip(*self.valloader)):
                # batches is tuple of batch from each loader
                data_slices = [batch[0].to(self.device).float() for batch in batches]
                target = batches[0][1].to(self.device).long()
                client_activations = []
                client_val_loss = 0
                for data, client_base_model, classifier_model, early_stopper in zip(data_slices, self.client_base_models, self.client_classifier_models, self.client_early_stopper):
                    #Client FeedForward
                    client_base_model.eval()
                    activation = client_base_model(data)
                    client_activations.append(activation)
                    classifier_output = classifier_model(activation)
                    client_loss = self.LOCAL_CLASSIFIER_CRITERION(classifier_output, target).item()
                    client_val_loss += client_loss
                    client_loss /= len(self.valloader[0].dataset)
                    early_stopper(client_loss) 
                client_val_loss /= self.number_of_clients
                concat_activations = torch.cat(client_activations, dim=1)
                #Server FeedForward
                self.server_model.eval()
                output = self.server_model(concat_activations)
                server_val_loss = self.CRITERION(output, target).item()
                val_loss +=  0.5 * server_val_loss + 0.5 * client_val_loss
                server_val_loss /= len(self.valloader[0].dataset)
                self.server_early_stopper(server_val_loss)

        val_loss /= len(self.valloader[0].dataset)

        if self.group_early_stopper.best_loss is None or val_loss < self.group_early_stopper.best_loss:
            print("Validation loss improved. Saving model...")
            self.save()
        self.group_early_stopper(val_loss)

    def test(self):
        self.server_model.eval()
        
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for batch_idx, batches in enumerate(zip(*self.testloader)):
                # batches is tuple of batch from each loader
                data_slices = [batch[0].to(self.device).float() for batch in batches]
                target = batches[0][1].to(self.device).long()
                client_activations = []
                for data, client_base_model in zip(data_slices, self.client_base_models):
                    #Client FeedForward
                    client_base_model.eval()
                    activation = client_base_model(data)
                    client_activations.append(activation)
                concat_activations = torch.cat(client_activations, dim=1)
                #Server FeedForward
                self.server_model.eval()
                output = self.server_model(concat_activations)
                test_loss +=  self.CRITERION(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.testloader[0].dataset)
        accuracy = 100. * correct / len(self.testloader[0].dataset)
        if self.log: 
            self.test_losses.append(test_loss)
            self.accuracies.append(accuracy)
        print(f'\nTest set: Average loss per Sample: {test_loss:.4f}, Accuracy: {correct}/{len(self.testloader[0].dataset)} '
            f'({accuracy:.0f}%)\n')
            
        
    def test_inferences(self):
        self.server_model.eval()
        clients_preds = {i: [] for i in range(self.number_of_clients)}
        clients_router = {i: [] for i in range(self.number_of_clients)}
        test_loss = 0
        server_preds = []

        with torch.no_grad():
            for batch_idx, batches in enumerate(zip(*self.testloaders)):
                # batches is tuple of batch from each loader
                data_slices = [batch[0].to(self.device).float() for batch in batches]
                target = batches[0][1].to(self.device).long()
                client_activations = []
                for i, (data, client_base_model, classifier_model, router_model) in enumerate(zip(data_slices, self.client_base_models, self.client_classifier_models, self.client_ig_models)):
                    #Client FeedForward
                    client_base_model.eval()
                    classifier_model.eval()
                    router_model.eval()
                    activation = client_base_model(data)
                    client_class_logits = classifier_model(activation)
                    clients_preds[i] += client_class_logits.argmax(dim=1, keepdim=True)
                    client_router_logits = router_model(activation)
                    clients_router[i] += client_router_logits.view(-1).tolist()
                    client_activations.append(activation)
                    
                concat_activations = torch.cat(client_activations, dim=1)
                #Server FeedForward
                self.server_model.eval()
                output = self.server_model(concat_activations)
                test_loss +=  self.CRITERION(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                server_preds += pred.squeeze().tolist()

        return clients_preds, clients_router, server_preds
            
    def run_training(self, trainset, testset):
        """Train all models for up to ``N_EPOCH`` epochs with early stopping.

        Sequence:
          1. Fresh early stoppers (server, one per client, group) are created
             and the data loaders are built.
          2. ``test()`` and ``classifier_test()`` are run once before training.
          3. Each epoch runs ``train``, ``test``, ``classifier_test`` and
             ``validate``. If the group early stopper fires, the epoch is
             stored in ``self.early_stop_epoch`` and the loop ends; otherwise
             the train/validation loaders are reshuffled for the next epoch.
          4. ``self.load()`` is called.

        ``classifier_test`` and ``load`` are methods of this class that are
        defined outside the part of the notebook shown here.

        Args:
            trainset: One train dataset per client.
            testset: One test dataset per client.
        """
        cfg = self.model_config

        self.server_early_stopper = EarlyStopper()
        self.client_early_stopper = [EarlyStopper() for _ in range(self.number_of_clients)]
        self.group_early_stopper = EarlyStopper()

        self._to_loader(
            trainsets=trainset,
            testsets=testset,
            batch_size_train=cfg.BATCH_SIZE_TRAIN,
            batch_size_val=cfg.BATCH_SIZE_VAL,
            batch_size_test=cfg.BATCH_SIZE_TEST,
            train_shuffle=cfg.TRAIN_SHUFFLE,
            val_shuffle=cfg.VAL_SHUFFLE,
            test_shuffle=cfg.TEST_SHUFFLE,
            train_ratio=cfg.TRAIN_RATIO,
        )

        # Metrics of the untrained models.
        self.test()
        self.classifier_test()

        for epoch in range(1, cfg.N_EPOCH + 1):
            self.train(epoch)
            self.test()
            self.classifier_test()
            self.validate()

            if self.group_early_stopper.early_stop:
                self.early_stop_epoch = epoch
                print("early_stop_triggered")
                break

            # New sample order / split for the next epoch.
            self.shuffle_loader(
                trainsets=trainset,
                batch_size_train=cfg.BATCH_SIZE_TRAIN,
                batch_size_val=cfg.BATCH_SIZE_VAL,
                train_ratio=cfg.TRAIN_RATIO,
            )

        print("loading best model to server...")
        self.load()

    def run_router_threshold_sweep(self, clients_router, thresholds=np.linspace(1.0, 0.0, num=50)):
        """
        Measure, per threshold and client, the share of router scores above the threshold.

        Args:
            clients_router (dict): {client_id: list of router scores}.
            thresholds (iterable of float): cut-offs to sweep.

        Returns:
            dict: {threshold rounded to 4 decimals: {client_id: fraction of
            scores strictly greater than the threshold, rounded to 4 decimals}}.
            A client without any scores gets 0.0.
        """
        def _fraction_above(scores, cutoff):
            # Strict comparison: a score equal to the cut-off does not count.
            decisions = [bool(score > cutoff) for score in scores]
            total = len(decisions)
            if total > 0:
                return sum(decisions) / total
            return 0.0

        sweep = {}
        for cutoff in thresholds:
            per_client = {
                client_id: round(_fraction_above(scores, cutoff), 4)
                for client_id, scores in clients_router.items()
            }
            sweep[float(np.round(cutoff, 4))] = per_client
        return sweep
    
    def hybrid_f1_per_client_individual_threshold(
        self,
        clients_preds,
        clients_router,
        server_preds,
        targets,
        thresholds=np.linspace(1.0, 0.0, num=50)
    ):
        """
        For each client and each threshold, compute their own hybrid F1-score
        (use client prediction if router > threshold, else use server prediction).

        Clients are taken from the keys of ``clients_preds``, so ids do not
        have to be the contiguous integers ``0..n-1``.

        Args:
            clients_preds (dict): {client_id: list of predicted labels (tensors or ints)}
            clients_router (dict): {client_id: list of router confidence scores (floats)}
            server_preds (list): list of server predicted labels
            targets (list or tensor): true labels
            thresholds (np.array): thresholds to sweep for each client

        Returns:
            dict: {client_id: {threshold: f1_score (macro)}}, thresholds
            rounded to 4 decimals.

        Raises:
            KeyError: if a client in ``clients_preds`` has no router scores.
            IndexError: if a client has fewer predictions or router scores
                than there are server predictions.
        """
        def _to_labels(values):
            # Accept tensors / numpy scalars (via .item()) as well as plain numbers.
            return np.asarray([v.item() if hasattr(v, 'item') else int(v) for v in values])

        server_labels = _to_labels(server_preds)
        true_labels = _to_labels(targets)
        # The number of server predictions defines how many samples are scored.
        num_samples = len(server_labels)

        results = {}
        for client_id, raw_preds in clients_preds.items():
            client_labels = _to_labels(raw_preds)
            router_scores = np.asarray(clients_router[client_id], dtype=float)
            if len(client_labels) < num_samples or len(router_scores) < num_samples:
                # Checked explicitly so numpy cannot silently broadcast a short array.
                raise IndexError(
                    f"client {client_id!r} has fewer predictions/router scores "
                    f"than the {num_samples} server predictions"
                )
            client_labels = client_labels[:num_samples]
            router_scores = router_scores[:num_samples]

            client_f1 = {}
            for threshold in thresholds:
                # Strict '>' routes a score equal to the threshold to the server.
                hybrid = np.where(router_scores > threshold, client_labels, server_labels)
                client_f1[float(np.round(threshold, 4))] = f1_score(true_labels, hybrid, average='macro')

            results[client_id] = client_f1

        return results


    def hybrid_accuracy_per_client_individual_threshold(
        self,
        clients_preds,
        clients_router,
        server_preds,
        targets,
        thresholds=np.linspace(1.0, 0.0, num=50 )
    ):
        """
        For each client and each threshold, compute their own hybrid accuracy
        (use client prediction if router > threshold, else use server prediction).

        Clients are taken from the keys of ``clients_preds``, so ids do not
        have to be the contiguous integers ``0..n-1``.

        Args:
            clients_preds (dict): {client_id: list of predicted labels (tensors or ints)}
            clients_router (dict): {client_id: list of router confidence scores (floats)}
            server_preds (list): list of server predicted labels
            targets (list or tensor): true labels
            thresholds (np.array): thresholds to sweep for each client

        Returns:
            dict: {client_id: {threshold: accuracy}}, thresholds rounded to
            4 decimals.

        Raises:
            KeyError: if a client in ``clients_preds`` has no router scores.
            IndexError: if a client has fewer predictions or router scores
                than there are server predictions.
        """
        def _to_labels(values):
            # Accept tensors / numpy scalars (via .item()) as well as plain numbers.
            return np.asarray([v.item() if hasattr(v, 'item') else int(v) for v in values])

        server_labels = _to_labels(server_preds)
        true_labels = _to_labels(targets)
        # The number of server predictions defines how many samples are scored.
        num_samples = len(server_labels)

        results = {}
        for client_id, raw_preds in clients_preds.items():
            client_labels = _to_labels(raw_preds)
            router_scores = np.asarray(clients_router[client_id], dtype=float)
            if len(client_labels) < num_samples or len(router_scores) < num_samples:
                # Checked explicitly so numpy cannot silently broadcast a short array.
                raise IndexError(
                    f"client {client_id!r} has fewer predictions/router scores "
                    f"than the {num_samples} server predictions"
                )
            client_labels = client_labels[:num_samples]
            router_scores = router_scores[:num_samples]

            client_acc = {}
            for threshold in thresholds:
                # Strict '>' routes a score equal to the threshold to the server.
                hybrid = np.where(router_scores > threshold, client_labels, server_labels)
                client_acc[float(np.round(threshold, 4))] = accuracy_score(true_labels, hybrid)

            results[client_id] = client_acc

        return results
        

    def run_tests(self, testsets): 
        self.testloaders = [DataLoader(testdata, batch_size=self.model_config.BATCH_SIZE_TEST, shuffle=False) for testdata in testsets]
        clients_preds, clients_router, server_preds = self.test_inferences()
        coverage_sweep = self.run_router_threshold_sweep(clients_router)
        targets =  self.testloaders[0].dataset.targets  
        hybrid_sweep = self.hybrid_accuracy_per_client_individual_threshold(
            clients_preds=clients_preds,
            clients_router=clients_router,
            server_preds=server_preds,
            targets=self.testloaders[0].dataset.targets
        )
        hybrid_F1_sweep = self.hybrid_f1_per_client_individual_threshold(
            clients_preds=clients_preds,
            clients_router=clients_router,
            server_preds=server_preds,
            targets=self.testloaders[0].dataset.targets
        )
        return coverage_sweep, hybrid_sweep, hybrid_F1_sweep 
        
    def pred(self, testset,  pred=True):
        """
        Predict class labels for ``testset`` through the split client/server path.

        Each client base model encodes its own slice of a batch, the
        activations are concatenated along dim 1 and the server model produces
        the class scores; the arg-max over dim 1 is the prediction.

        Args:
            testset: passed to ``_pred_loader`` to obtain one loader per client.
            pred: unused; kept so existing callers keep working.

        Returns:
            list[int]: predicted labels in loader order.
        """
        loader = self._pred_loader(testset, self.model_config.BATCH_SIZE_TEST, self.model_config.TEST_SHUFFLE)

        # Inference only: switch all involved models to eval mode once up front.
        for client_base_model in self.client_base_models:
            client_base_model.eval()
        self.server_model.eval()

        predictions = []
        with torch.no_grad():
            # zip(*loader) yields one tuple per step holding a batch from every client loader.
            for batches in zip(*loader):
                data_slices = [batch[0].to(self.device).float() for batch in batches]
                # Client feed-forward.
                client_activations = [
                    client_base_model(data)
                    for data, client_base_model in zip(data_slices, self.client_base_models)
                ]
                concat_activations = torch.cat(client_activations, dim=1)
                # Server feed-forward.
                output = self.server_model(concat_activations)
                # argmax without keepdim stays 1-D, so .tolist() is always a
                # list, even for a final batch that holds a single sample.
                predictions.extend(output.argmax(dim=1).tolist())
        return predictions

    def save(self):
        """
        Write model and optimizer state dicts of the server and all clients to disk.

        Files go to ``./results/hybrid/v1/<DATASET_NAME>/<seed>`` (server) and
        its ``clients`` sub-directory (base, classifier and router models with
        their optimizers). Directories are created when missing.
        """
        run_dir = f"./results/hybrid/v1/{self.data_config.DATASET_NAME}/{self.seed}"
        os.makedirs(run_dir, exist_ok=True)

        def _target(folder, filename):
            # Normalise Windows separators so the stored paths use '/'.
            return os.path.join(folder, filename).replace("\\", "/")

        def _store(model, optimizers, position, model_target, optimizer_target):
            # Model first, then the optimizer that sits at the same position.
            torch.save(model.state_dict(), model_target)
            torch.save(optimizers[position].state_dict(), optimizer_target)

        # Server checkpoint.
        server_model_target = _target(run_dir, f'model_server_{self.idx}.pth')
        server_optimizer_target = _target(run_dir, f'optimizer_server_{self.idx}.pth')
        torch.save(self.server_model.state_dict(), server_model_target)
        torch.save(self.SERVER_OPTIMIZER.state_dict(), server_optimizer_target)

        clients_dir = os.path.join(run_dir, "clients")
        os.makedirs(clients_dir, exist_ok=True)

        # Client base models.
        for position, base_model in enumerate(self.client_base_models):
            _store(
                base_model,
                self.CLIENT_BASE_OPTIMIZERS,
                position,
                _target(clients_dir, f'model_client_base_{position}.pth'),
                _target(clients_dir, f'optimizer_client_base_{position}.pth'),
            )

        # Client classifier heads.
        for position, classifier_model in enumerate(self.client_classifier_models):
            _store(
                classifier_model,
                self.CLIENT_CLASSIFIER_OPTIMIZERS,
                position,
                _target(clients_dir, f'model_client_classifier_{position}.pth'),
                _target(clients_dir, f'optimizer_classifier_client_{position}.pth'),
            )

        # Client router (IG) heads.
        for position, router_model in enumerate(self.client_ig_models):
            _store(
                router_model,
                self.CLIENT_IG_OPTIMIZERS,
                position,
                _target(clients_dir, f'model_client_router_{position}.pth'),
                _target(clients_dir, f'optimizer_router_client_{position}.pth'),
            )
                   
    def load(self):
        """
        Restore model and optimizer state dicts written by ``save``.

        Reads the server checkpoint from
        ``./results/hybrid/v1/<DATASET_NAME>/<seed>`` and the client base,
        classifier and router checkpoints from its ``clients`` sub-directory.
        Tensors are mapped onto ``self.device`` while loading, so a checkpoint
        written on a GPU machine can also be restored on a CPU-only one.

        Raises:
            FileNotFoundError: if an expected checkpoint file is missing.
        """
        result_path = f"./results/hybrid/v1/{self.data_config.DATASET_NAME}/{self.seed}"
        client_result_path = os.path.join(result_path, "clients")

        def _restore(folder, model, optimizer, model_file, optimizer_file):
            # Load the model weights first, then the matching optimizer state.
            model_path = os.path.join(folder, model_file).replace("\\", "/")
            optimizer_path = os.path.join(folder, optimizer_file).replace("\\", "/")
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            optimizer.load_state_dict(torch.load(optimizer_path, map_location=self.device))

        # Server checkpoint.
        _restore(
            result_path,
            self.server_model,
            self.SERVER_OPTIMIZER,
            f'model_server_{self.idx}.pth',
            f'optimizer_server_{self.idx}.pth',
        )

        # Client base models.
        for i, (client_base_model, model_optimizer) in enumerate(zip(self.client_base_models, self.CLIENT_BASE_OPTIMIZERS)):
            _restore(
                client_result_path,
                client_base_model,
                model_optimizer,
                f'model_client_base_{i}.pth',
                f'optimizer_client_base_{i}.pth',
            )

        # Client classifier heads.
        for i, (client_classifier_model, model_optimizer) in enumerate(zip(self.client_classifier_models, self.CLIENT_CLASSIFIER_OPTIMIZERS)):
            _restore(
                client_result_path,
                client_classifier_model,
                model_optimizer,
                f'model_client_classifier_{i}.pth',
                f'optimizer_classifier_client_{i}.pth',
            )

        # Client router (IG) heads.
        for i, (client_router_model, model_optimizer) in enumerate(zip(self.client_ig_models, self.CLIENT_IG_OPTIMIZERS)):
            _restore(
                client_result_path,
                client_router_model,
                model_optimizer,
                f'model_client_router_{i}.pth',
                f'optimizer_router_client_{i}.pth',
            )
                   
                            
    def classifier_test(self):
        """
        Evaluate every client's local classifier head on the test loaders.

        For each client the summed criterion value and the number of correct
        predictions are accumulated over all batches, then divided by the
        size of the first test loader's dataset. The per-client lists are
        appended to ``self.test_losses`` / ``self.accuracies`` when
        ``self.log`` is truthy and are always printed.

        The models are put into eval mode for the forward passes so the
        reported metrics are not distorted by training-only layer behaviour.
        NOTE(review): models are left in eval mode afterwards; this assumes
        ``train`` switches them back with ``.train()`` - confirm.
        """
        test_loss_values = [0 for c in self.client_base_models]
        correct_values = [0 for c in self.client_base_models]
        with torch.no_grad():
            for batches in zip(*self.testloader):
                data_slices = [batch[0].to(self.device).float() for batch in batches]
                # Labels are taken from the first client's batch.
                target = batches[0][1].to(self.device).long()
                for client_idx, client in enumerate(zip(data_slices, self.client_base_models, self.client_classifier_models)):
                    data, client_base_model, classifier_model = client 
                    # Evaluation must not run in train mode (previously `.train()`).
                    client_base_model.eval()
                    classifier_model.eval()
                    activation = client_base_model(data)
                    classifier_output = classifier_model(activation)
                    test_loss_values[client_idx] +=  self.LOCAL_CLASSIFIER_CRITERION(classifier_output, target).item()
                    pred = classifier_output.argmax(dim=1, keepdim=True)
                    correct_values[client_idx]+= pred.eq(target.view_as(pred)).sum().item()

        num_samples = len(self.testloader[0].dataset)
        test_loss_values = [t/ num_samples for t in test_loss_values]
        accuracy_values = [100. * c / num_samples for c in correct_values]

        if self.log: 
            self.test_losses.append(test_loss_values)
            self.accuracies.append(accuracy_values)
        print(f'\nTest set: Average loss per Sample: {test_loss_values}, Accuracy: {correct_values}/{num_samples}'
            f'({accuracy_values}%)\n')

    def gate_pred(self): 
        """
        Compute the router (IG) outputs of every client for the first test batch.

        Only the first step of the zipped test loaders is processed: the
        method returns from inside the loop.

        Returns:
            tuple: ``(ig_outputs, target)`` where ``ig_outputs`` holds one
            router output per client and ``target`` the labels of the first
            client's batch. ``None`` if the loaders yield no batch.
        """
        for batches in zip(*self.testloader):
            data_slices = [batch[0].to(self.device).float() for batch in batches]
            target = batches[0][1].to(self.device).long()
            ig_outputs = []
            # The zip yields exactly three items per client; unpacking a fourth
            # name here used to raise ValueError on every call.
            for data, client_base_model, client_ig_model in zip(data_slices, self.client_base_models, self.client_ig_models):
                # Client feed-forward.
                client_base_model.eval()
                activation = client_base_model(data)
                ig_outputs.append(client_ig_model(activation))
            return ig_outputs, target
                

In [None]:
            
from federated_inference.common.environment import  DataMode, TransformType
from federated_inference.simulations.simulation import Simulation
from federated_inference.simulations.utils import *

class HybridSplitSimulation(Simulation): 
    """
    Simulation wiring several ``HybridSplitClient`` instances to one
    ``HybridSplitServer`` on vertically partitioned data.
    """

    def __init__(self, seed, data_config, transform_config, server_model, client_base_model, client_classifier_model, client_ig_model, transform_type: TransformType = TransformType.FULL_STRIDE_PARTITION, exist=False):
        """
        Load the dataset, split it between the clients and build clients and server.

        Args:
            seed: run seed, forwarded to the clients and the server.
            data_config: dataset configuration (``LABELS`` and ``DATASET_NAME`` are read here).
            transform_config: configuration forwarded to ``transform_data``.
            server_model, client_base_model, client_classifier_model,
            client_ig_model: model arguments bundled into a
                ``HybridSplitModelConfiguration``.
            transform_type: how the data is partitioned between clients.
            exist: unused in this class; kept for interface compatibility.
        """
        self.seed = seed
        self.data_config = data_config
        self.transform_config = transform_config
        self.server_model_config = HybridSplitModelConfiguration(server_model, client_base_model, client_classifier_model, client_ig_model)
        self.data_mode = DataMode.VERTICAL
        self.transform_type = transform_type
        self.dataset =  self.load_data(data_config)
        self.client_datasets, self.transformation = self.transform_data(self.dataset, data_mode = self.data_mode, transform_config = transform_config, transform_type = self.transform_type)
        # One client per partition returned by transform_data.
        self.clients = [HybridSplitClient(idx, seed, data_config, self.server_model_config, dataset, data_config.LABELS) for idx, dataset in enumerate(self.client_datasets)]
        self.server = HybridSplitServer(0, seed, self.server_model_config , self.data_config)

    def train(self): 
        """Collect training and test data from every client and train on the server."""
        datasets = [client.send_all() for client in self.clients]
        testsets = [client.request_pred(pred_all = True, keep_label = True) for client in self.clients]
        self.server.run_training(datasets, testsets)

    def test_inference(self):
        """Get server predictions for all clients' test data and hand them to the first client's ``check``."""
        testsets = [client.request_pred(pred_all = True, keep_label = True) for client in self.clients]
        predictions = self.server.pred(testsets)
        self.clients[0].check(predictions)
        #self.collect_results(self.seed, save = True)

    def collect_results(self, name: str, save: bool = True, figures: bool = False, base_dir: str = "naive"):
        """
        Gather server and client results into ``self.results`` and optionally save them.

        Args:
            name: stored under the ``'seed'`` key and used as the output sub-directory.
            save: write the results to JSON via ``_save_results`` when True.
            figures: display figures (requires IPython) when True.
            base_dir: directory below ``./results`` to save into. Defaults to
                ``"naive"`` to keep the previous output location.
                NOTE(review): the server checkpoints live under ``hybrid/v1``;
                confirm that ``"naive"`` is the intended place for this simulation.

        Returns:
            dict: the collected results.
        """
        self.results = {
            'seed': name,
            'client': [],
            'server': {}
        }

        if figures:
            # Imported lazily so collecting results works without IPython.
            from IPython.display import display

            fig = create_simulation_image_subplots(self)
            display(fig)

        # Collect server results
        self.results['server'] = self._gather_server_results(self.server)

        # Only the first client is collected; it is also the only client that
        # `test_inference` passes the predictions to.
        client = self.clients[0]
        client_result = self._gather_client_result(client, figures)
        self.results['client'].append(client_result)

        # Save results
        if save:
            self._save_results(str(name), base_dir=base_dir)
        return self.results

    def _gather_server_results(self, server):
        """Return the server's recorded training and test losses."""
        return {
            'training_losses': server.train_losses,
            'test_losses': server.test_losses
        }

    def _gather_client_result(self, client, figures):
        """
        Build the result entry of one client.

        Args:
            client: client exposing ``idx``, ``cm`` and (for figures) ``test_losses``.
            figures: display the loss curve, example images and the confusion
                matrix heat map when True.

        Returns:
            dict: client index, serialized confusion matrix, the analysis
            returned by ``cm_analysis`` and accuracy/precision/recall.
        """
        result = {
            'idx': client.idx,
            'cm': client.cm.to_json(orient='split')
        }

        analysis, df_cm_per = cm_analysis(client)
        result['cm_analysis'] = analysis

        # Extract relevant class indices
        indices = [
            analysis['correct']['most_correct_class'],
            analysis['wrong']['most_misclassified_class']
        ]
        # Add the most confused class pair unless already listed.
        indices += [i for i in [analysis['wrong']['wrong_from'], analysis['wrong']['wrong_to']] if i not in indices]

        # Add performance metrics
        result.update({
            'accuracy': analysis["performance"]["accuracy"],
            'precision': analysis["performance"]["precision"],
            'recall': analysis["performance"]["recall"]
        })

        # Generate figures if needed
        if figures:
            # Imported lazily so gathering results works without IPython.
            from IPython.display import display

            display(plot_test_loss(client.test_losses, 1, client.idx, "Test"))
            fig, subplot_indices = create_client_image_subplots(self, [client.idx], 8, keys=indices)
            display(fig)
            result['client_image_subplots_ids'] = subplot_indices
            fig = print_cm_heat(df_cm_per, client.idx)
            display(fig)

        return result

    def _save_results(self, name, base_dir="ig"):
        """Write ``self.results`` to ``./results/<base_dir>/<DATASET_NAME>/<name>/simulation.json``."""
        import os
        import json

        result_path = os.path.join("./results", base_dir , self.data_config.DATASET_NAME, name)
        
        os.makedirs(result_path, exist_ok=True)
        file_path = os.path.join(result_path, "simulation.json")

        with open(file_path, "w") as f:
            json.dump(self.results, f, indent=4)
        print("Results saved to JSON.")

In [None]:
import plotly.graph_objects as go
import numpy as np

# Legend labels per client index, used when naming the plot traces.
client_mapping = dict(enumerate((
    "left top (LT)",
    "right top (RT)",
    "left bottom (LB)",
    "right bottom (RB)",
)))

def plot_accuracy_vs_coverage(coverage, hybrid_accuracy):
    """
    Plot hybrid accuracy against remote coverage, one trace per client.

    ``coverage`` holds the local coverage (fraction of samples whose router
    score exceeds the threshold). The x-axis shows ``1 - coverage``, i.e. the
    fraction of samples that is handed to the server instead.

    A dashed horizontal line marks the best accuracy any client reaches at a
    local coverage of 1.0 (all samples predicted locally), if such a point exists.

    Args:
        coverage (dict): {threshold: {client_id: coverage_value}}
        hybrid_accuracy (dict): {client_id: {threshold: accuracy_value}}

    Returns:
        None. The figure is shown with ``fig.show()``.
    """

    # Both inputs must use the same threshold keys; sweep them in ascending order.
    thresholds = sorted(coverage.keys())
    clients = sorted(hybrid_accuracy.keys())

    fig = go.Figure()

    # Accuracy of each client at local coverage == 1.0.
    best_acc_at_cov1 = []

    for client_id in clients:
        # Coverage and accuracy arrays aligned with `thresholds`.
        covs = np.array([coverage[t][client_id] for t in thresholds])
        accs = np.array([hybrid_accuracy[client_id][t] for t in thresholds])
        remote_covs = (1-covs)

        # Fall back to the raw id for clients without a display name.
        client_label = client_mapping.get(client_id, client_id)

        fig.add_trace(
            go.Scatter(
                x=remote_covs,
                y=accs,
                mode='lines+markers',
                name=f'Client {client_label}'
            )
        )

        # np.isclose guards against floating point noise in the coverage values.
        mask = np.isclose(covs, 1.0)
        if np.any(mask):
            best_acc_at_cov1.append(accs[mask][0])

    # Horizontal line at the best all-local accuracy (max over clients).
    if best_acc_at_cov1:
        max_acc = max(best_acc_at_cov1)
        fig.add_hline(
            y=max_acc,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Best local accuracy: {max_acc:.3f}",
            annotation_position="top right"
        )

    fig.update_layout(
        # The x-axis shows remote coverage, so the title names it as well.
        title='Hybrid Accuracy vs. Remote Coverage per Client',
        xaxis=dict(
            title='Remote Coverage (Fraction of samples processed by the Cloud)',
            range=[0, 1],
            autorange=False
        ),
        yaxis=dict(
            title='Hybrid Accuracy',
            range=[0, 1],
            autorange=False
        ),
        legend_title='Clients',
        width=800,
        height=500,
        template='plotly_white'
    )

    fig.show()

In [None]:
# Train one hybrid split simulation on MNIST for each seed.
# `simulation2`, `data_config` and `transform_config` are rebound on every
# iteration, so after the loop they refer to the last seed's run only.
for seed in [1,2,3]:
    # Seed torch, NumPy and Python's random before anything is constructed.
    set_seed(seed)
    data_config = DataConfiguration('MNIST')
    transform_config = DataTransformConfiguration()
    simulation2 = HybridSplitSimulation(seed, data_config, transform_config, GlobalHybridSplitCNN, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead)
    # Collects the clients' data and runs the server-side training.
    simulation2.train()

In [None]:
import torch.nn as nn

class OnDeviceMNISTModel(nn.Module):
    """
    Small CNN classifier with 10 output classes for single-channel images.

    Two 3x3 convolutions (padding 1, so the spatial size is kept) with batch
    norm and ReLU are followed by a 2x2 max pool and a two-layer classifier.
    The first linear layer expects 32 * 7 * 7 features, which corresponds to
    a 14x14 input image.
    """

    def __init__(self):
        super().__init__()
        self.name = "OnDeviceMNISTModel"

        # Shapes in the comments are for a 1x14x14 input.
        conv_stack = [
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # -> 16x14x14
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> 32x14x14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 32x7x7
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class scores of shape (batch, 10) for the input batch ``x``."""
        feature_maps = self.features(x)
        # Collapse everything except the batch dimension.
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)
        return self.classifier(flattened)

In [None]:
_coverage