In [None]:
# Reload edited project modules (e.g. federated_inference) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import torch 
import random 
import numpy as np

def set_seed(seed=42):
    """Seed every random-number generator the notebook relies on.

    Seeds torch (CPU, current CUDA device, all CUDA devices), NumPy's global RNG
    and Python's ``random`` module with the same value. It also forces cuDNN into
    deterministic mode with autotuning disabled, so convolutions are reproducible.

    Args:
        seed: the integer seed shared by all generators (default 42).
    """
    seeders = (
        torch.manual_seed,         # CPU
        torch.cuda.manual_seed,    # current GPU
        torch.cuda.manual_seed_all,  # all GPUs
        np.random.seed,            # NumPy
        random.seed,               # Python random
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    

In [None]:
import torch.nn as nn
class IGV1Base(nn.Module):
    """Client-side feature extractor: two conv layers followed by a flatten.

    For a 1x14x14 input the spatial size goes 14 -> 10 (5x5 conv) -> 5 (2x2 pool)
    -> 5 (3x3 conv, padding 1). That gives 20 * 5 * 5 = 500 features, which matches
    the 500-input heads below. Assumes each client slice is 1x14x14; TODO confirm
    against the partitioning.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # Kept as one Sequential named `conv_layers` so state_dict keys stay stable.
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the flattened activation, shape (batch, channels * height * width)."""
        features = self.conv_layers(x)
        return torch.flatten(features, start_dim=1)

class IGV1ClassifierHead(nn.Module):
    """Local classification head: 500 client features -> 10 class logits (MLP with dropout)."""

    def __init__(self):
        super().__init__()
        hidden_units = 50
        self.classifier = nn.Sequential(
            nn.Linear(500, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, 10),
        )

    def forward(self, x):
        """Return unnormalised class logits."""
        logits = self.classifier(x)
        return logits

class InformationGateV1(nn.Module):
    """Information gate head: 500 client features -> a single probability in (0, 1).

    The final Sigmoid means the output is a probability, which is why the server
    trains it with ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 50
        self.classifier = nn.Sequential(
            nn.Linear(500, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the gate probability, shape (batch, 1)."""
        probability = self.classifier(x)
        return probability

class SplitServerCNN(nn.Module):
    """Server-side model: classifies the concatenated activations of all clients.

    The input width of 2000 corresponds to four 500-feature client activations.
    This presumes four clients, as hard-coded in the server; TODO confirm if the
    client count changes.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden_units, n_classes = 2000, 50, 10
        self.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, n_classes),
        )

    def forward(self, x_concat):
        """Return class logits for the concatenated client activations."""
        logits = self.classifier(x_concat)
        return logits
        

In [None]:
import os
from dotenv import load_dotenv
from federated_inference.dataset import MNISTDataset, FMNISTDataset, CIFAR10Dataset
from federated_inference.common.environment import DataSetEnum

class DataConfiguration():
    """Dataset selection plus per-dataset constants (paths, labels, instance size).

    The dataset comes from ``dataset_name`` or, when that is empty, from the
    ``DATASET_NAME`` environment variable (default ``'MNIST'``). A ``.env`` file is
    loaded first and overrides existing environment variables. If the
    ``DATASET_PATH`` environment variable is set, it replaces the default path of
    whichever dataset is selected.
    """

    #MNIST_FASHION_DATASET Configurations
    FMNIST_NAME = 'FMNIST'
    FMNIST_DATASET_PATH = os.path.join('./data/fmnist')
    FMNIST_LABELS = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker',  'Bag', 'Ankle Boot']

    #MNIST_DATASET Configurations
    MNIST_NAME = 'MNIST'
    MNIST_DATASET_PATH = os.path.join('./data/mnist')

    #CIFAR_DATASET Configurations
    CIFAR10_NAME = 'CIFAR10'
    CIFAR10_DATASET_PATH = os.path.join('./data/cifar10')
    CIFAR10_LABELS = ['Plane', 'Car', 'Bird', 'Cat','Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

    def __init__(self, dataset_name : str = None):
        """Resolve the dataset and populate DATASET, DATASET_PATH, INSTANCE_SIZE and LABELS.

        Raises:
            ValueError: if the resolved name matches none of MNIST, FMNIST or CIFAR10.
                Previously the object was left half-initialised, and the failure only
                surfaced later as an AttributeError.
        """
        load_dotenv(override=True)
        self.DATASET_NAME = dataset_name if dataset_name else os.getenv('DATASET_NAME', 'MNIST')
        if self.DATASET_NAME == DataSetEnum.MNIST.value:
            self.INSTANCE_SIZE = (1, 28, 28)
            self.DATASET = DataSetEnum.MNIST
            self.DATASET_PATH = os.path.join(os.getenv('DATASET_PATH', self.MNIST_DATASET_PATH))
            self.LABELS = range(10)
        elif self.DATASET_NAME == DataSetEnum.FMNIST.value:
            self.INSTANCE_SIZE = (1, 28, 28)
            self.DATASET = DataSetEnum.FMNIST
            self.DATASET_PATH = os.path.join(os.getenv('DATASET_PATH', self.FMNIST_DATASET_PATH))
            self.LABELS = self.FMNIST_LABELS
        elif self.DATASET_NAME == DataSetEnum.CIFAR10.value:
            # NOTE(review): CIFAR-10 images are RGB; (1, 32, 32) is only correct if
            # CIFAR10Dataset converts them to grayscale. TODO confirm.
            self.INSTANCE_SIZE = (1, 32, 32)
            self.DATASET = DataSetEnum.CIFAR10
            self.DATASET_PATH = os.path.join(os.getenv('DATASET_PATH', self.CIFAR10_DATASET_PATH))
            self.LABELS =  self.CIFAR10_LABELS
        else:
            supported = [DataSetEnum.MNIST.value, DataSetEnum.FMNIST.value, DataSetEnum.CIFAR10.value]
            raise ValueError(f"Unsupported dataset name {self.DATASET_NAME!r}; expected one of {supported}")

    # NOTE: defined as a method, so it shadows the usual `__dict__` attribute;
    # callers use `config.__dict__()` to get this JSON-friendly summary.
    def __dict__(self):
        """Return a serialisable summary: dataset name, label list and instance size."""
        return {
            "dataset_name": self.DATASET_NAME,
            "labels": list(self.LABELS),
            "size": self.INSTANCE_SIZE
        }

import torch
import torch.optim as optim
import torch.nn as nn

class ModelConfiguration():
    """Hyper-parameters and model factories for the split-learning server.

    The four model arguments are stored as given; the server instantiates them by
    calling them with no arguments, so classes such as ``IGV1Base`` are expected.
    ``test_suffle`` keeps its historical spelling for keyword callers.
    """

    # Prefer CUDA, then Apple MPS, then CPU; resolved once when the class is defined.
    DEVICE = torch.device(
        "cuda" if torch.cuda.is_available() else
        ("mps" if torch.backends.mps.is_available() else "cpu")
    )

    def __init__(self,
        server_model: nn.Module,
        client_base_model: nn.Module,
        client_classifier_head: nn.Module,
        client_ig_head: nn.Module,
        learning_rate: float = 0.001,
        batch_size_train: int = 64,
        batch_size_val: int = 64,
        batch_size_test: int = 64,
        train_shuffle: bool = False,
        val_shuffle: bool = False,
        test_suffle: bool = False,
        n_epochs: int = 30,
        train_ratio: float = 0.8):
        # Model factories
        self.SERVER_MODEL, self.CLIENT_BASE_MODEL = server_model, client_base_model
        self.CLIENT_CLASSIFIER_MODEL, self.CLIENT_IG_MODEL = client_classifier_head, client_ig_head
        # Optimisation schedule
        self.LEARNING_RATE = learning_rate
        self.N_EPOCH = n_epochs
        self.TRAIN_RATIO = train_ratio
        # Data loading
        self.BATCH_SIZE_TRAIN, self.BATCH_SIZE_VAL, self.BATCH_SIZE_TEST = (
            batch_size_train, batch_size_val, batch_size_test)
        self.TRAIN_SHUFFLE, self.VAL_SHUFFLE, self.TEST_SHUFFLE = (
            train_shuffle, val_shuffle, test_suffle)

    def __dict__(self):
        """Return the scalar hyper-parameters as a JSON-friendly dict (models excluded)."""
        return dict(
            learning_rate=self.LEARNING_RATE,
            train_ratio=self.TRAIN_RATIO,
            batch_size_train=self.BATCH_SIZE_TRAIN,
            batch_size_val=self.BATCH_SIZE_VAL,
            batch_size_test=self.BATCH_SIZE_TEST,
            train_shuffle=self.TRAIN_SHUFFLE,
            val_shuffle=self.VAL_SHUFFLE,
            test_shuffle=self.TEST_SHUFFLE,
            n_epoch=self.N_EPOCH,
        )



In [None]:
from torch.utils.data import Dataset as TorchDataset
from federated_inference.common.cost_calculator import CostCalculator
from federated_inference.common.environment import Member
from collections.abc import Iterable
import logging 

class IGV1Client():
    """Simulated client holding one vertical partition of the dataset.

    Every payload the client "sends" (training data, prediction requests) has its
    communication cost recorded in ``self.costs`` via ``CostCalculator``.

    Args:
        idx: client index.
        data_config: the shared DataConfiguration.
        dataset: object exposing ``train_dataset`` and ``test_dataset`` (used below).
        labels: human-readable class labels; their count defines the numeric label range.
        log, log_interval, save_interval: stored on the instance; not used by the
            methods shown here.
        seed: optional seed recorded on the instance. It used to be read from an
            undefined name, which raised NameError on a fresh kernel or silently
            picked up a notebook global.
    """

    def __init__(self,
            idx,
            data_config: DataConfiguration,
            dataset: TorchDataset,
            labels,
            log: bool = True,
            log_interval: int = 100,
            save_interval: int = 20,
            seed=None,
        ):
        self.idx = idx
        self.seed = seed
        self.data = dataset
        self.data_config = data_config
        self.labels = labels
        self.numerical_labels = range(len(labels))
        self.member_type = Member.CLIENT
        self.model = None
        self.costs = []

        self.log = log
        self.log_interval = log_interval
        self.save_interval = save_interval

    def select_subset(self, ids: Iterable[int], set_type: str = "train"):
        """Return a ``Subset`` of the test split (``set_type == "test"``) or otherwise the train split."""
        if set_type == "test":
            return Subset(self.data.test_dataset, ids)
        else:
            return Subset(self.data.train_dataset, ids)

    def send_all(self):
        """Send the whole training partition, recording its communication cost."""
        cost = CostCalculator.calculate_communication_cost_by(self.data.train_dataset)
        cost.set_cost_reason("send_all_training")
        self.costs.append(cost)
        return self.data.train_dataset

    def request_pred(self, idx: int|None = None, set_type: str = "test", pred_all: bool= False, keep_label: bool = False):
        """Send test data to the server for prediction, recording the communication cost.

        Args:
            idx: when given, send only this test sample (only ``set_type == "test"``
                is supported; any other set type returns None).
            pred_all: when ``idx`` is None and this is True, send the whole test split.
            keep_label: return ``(image, label)`` items instead of images only.

        Returns:
            The requested sample(s), or None if neither mode applies.
        """
        if idx is not None:
            if set_type != "test":
                return None
            # Index once, so the cost is measured on exactly the sample that is returned.
            sample = self.data.test_dataset[idx]
            cost = CostCalculator.calculate_communication_cost_by(sample)
            cost.set_cost_reason("send_testing_pred_request")
            self.costs.append(cost)
            return sample if keep_label else sample[0]
        if pred_all:
            cost = CostCalculator.calculate_communication_cost_by(self.data.test_dataset)
            cost.set_cost_reason("send_all_testing_pred_request")
            self.costs.append(cost)
            return self.data.test_dataset if keep_label else [img for img, label in self.data.test_dataset]
        return None

    def check(self, predicted_labels,  pred_all: bool= True):
        """Score predictions for the full test split against this client's labels.

        Only ``pred_all=True`` is implemented (predictions for every test sample, in
        order); otherwise nothing is computed. Prints accuracy/macro precision/macro
        recall and stores ``accuracy``, ``precision``, ``recall`` and a labelled
        confusion matrix ``cm`` (pandas DataFrame) on the instance.
        """
        from sklearn.metrics import accuracy_score, precision_score, recall_score,  confusion_matrix
        import pandas as pd
        if not pred_all:
            return
        true_labels = self.data.test_dataset.targets
        accuracy = accuracy_score(true_labels, predicted_labels)
        precision = precision_score(true_labels, predicted_labels, average='macro')  # or 'weighted'
        recall = recall_score(true_labels, predicted_labels, average='macro')

        print("\n=== Metrics ===")
        print(f"Accuracy : {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall   : {recall:.4f}")
        cm = confusion_matrix(true_labels, predicted_labels, labels=self.numerical_labels)
        self.cm = pd.DataFrame(cm, index=[f'True {l}' for l in self.labels],
                                columns=[f'Pred {l}' for l in self.labels])
        self.accuracy = accuracy
        self.precision = precision
        self.recall = recall


from torch.utils.data import Dataset as TorchDataset
from federated_inference.common.cost_calculator import CostCalculator
from federated_inference.common.environment import Member
from collections.abc import Iterable
import logging 
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Subset
from torch.utils.data import Dataset as TorchDataset
from collections.abc import Iterable
import os

class EarlyStopper():
    """Track a validation loss and raise ``early_stop`` after ``patience`` stagnant calls.

    A call is "stagnant" when ``val_loss > best_loss - min_delta``. Any other call
    records the new best loss and resets the counter. Once ``early_stop`` becomes
    True it stays True.
    """

    def __init__(self, patience = 5, min_delta = 0.00001):
        self.patience, self.min_delta = patience, min_delta
        self.counter, self.best_loss, self.early_stop = 0, None, False

    def __call__(self, val_loss):
        """Feed one validation loss; updates ``best_loss``, ``counter`` and ``early_stop``."""
        if self.best_loss is None:
            # The first observation only establishes the baseline.
            self.best_loss = val_loss
            return
        stagnant = val_loss > self.best_loss - self.min_delta
        if not stagnant:
            self.best_loss = val_loss
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        

from torch.utils.data import Dataset as TorchDataset
from federated_inference.common.cost_calculator import CostCalculator
from federated_inference.common.environment import Member
from collections.abc import Iterable
import logging 
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Subset
from torch.utils.data import Dataset as TorchDataset
from collections.abc import Iterable
import os
class IGV1Server():
    """Split-learning server for the vertically partitioned information-gate (IG) experiment.

    The server simulates both sides of the split network. Per client it owns a base
    model, a local classifier head and an information-gate head. It also owns the
    server model, which consumes the concatenation of all client activations.

    Training has three phases:

    1. ``run_training``: split training of client base models + server model, then
       ``run_head_training``.
    2. ``run_head_training``: train each client's local classifier head on top of its
       frozen base model.
    3. ``run_gate_training``: train each client's gate to predict "the local head is
       wrong but the server is right" for a sample.

    Per-client loaders are iterated with ``zip`` and labels are read from client 0.
    The code therefore assumes every client's partitions hold the same samples in the
    same order, which is why no loader is ever shuffled.
    """

    def __init__(self,
            idx,
            seed,
            model_config: ModelConfiguration,
            data_config: DataConfiguration,
            log: bool = True,
            log_interval: int = 100,
            save_interval: int = 20
        ):
        self.idx = idx
        self.seed = seed
        self.model_config = model_config
        self.data_config = data_config
        # NOTE(review): hard-coded; it must match the number of client partitions passed
        # to run_training/pred and the server model's input width.
        self.number_of_clients = 4
        self.n_epoch = model_config.N_EPOCH
        self.device = model_config.DEVICE
        self.member_type = Member.SERVER
        lr = model_config.LEARNING_RATE

        # Construction order (server, bases, heads, gates) is kept fixed so that seeded
        # runs initialise identical weights.
        self.server_model = model_config.SERVER_MODEL().to(self.device)
        self.client_base_models = [model_config.CLIENT_BASE_MODEL().to(self.device) for _ in range(self.number_of_clients)]
        self.client_classifier_models = [model_config.CLIENT_CLASSIFIER_MODEL().to(self.device) for _ in range(self.number_of_clients)]
        self.client_ig_models = [model_config.CLIENT_IG_MODEL().to(self.device) for _ in range(self.number_of_clients)]

        self.CRITERION = nn.CrossEntropyLoss()
        self.CRITERION_NAME = "CrossEntropyLoss"
        self.SERVER_OPTIMIZER = optim.Adam(self.server_model.parameters(), lr=lr)
        self.CLIENT_BASE_OPTIMIZERS = [optim.Adam(m.parameters(), lr=lr) for m in self.client_base_models]
        self.CLIENT_CLASSIFIER_OPTIMIZERS = [optim.Adam(m.parameters(), lr=lr) for m in self.client_classifier_models]
        self.CLIENT_IG_OPTIMIZERS = [optim.Adam(m.parameters(), lr=lr) for m in self.client_ig_models]
        self.OPTIMIZER_NAME = "Adam"
        self.LOCAL_CLASSIFIER_CRITERION = nn.CrossEntropyLoss()
        # BCELoss works on probabilities (InformationGateV1 ends in a Sigmoid).
        self.LOCAL_IG_CRITERION = nn.BCELoss()
        self.costs = []
        # run_training replaces this; it is created here so train() also works standalone.
        self.early_stopper = EarlyStopper()

        self.log = log
        self.log_interval = log_interval
        self.save_interval = save_interval

        # Always created (train() used to crash with log=False); they are only
        # appended to when log=True.
        self.train_losses = []
        self.test_losses = []
        self.accuracies = []

    # ------------------------------------------------------------------ helpers

    def _result_path(self):
        """Directory holding this run's checkpoints: ./results/split/<dataset>/<seed>."""
        return f"./results/split/{self.data_config.DATASET_NAME}/{self.seed}"

    def _unpack(self, batches):
        """Move one aligned batch per client to the device; labels come from client 0."""
        data_slices = [batch[0].to(self.device).float() for batch in batches]
        target = batches[0][1].to(self.device).long()
        return data_slices, target

    def _forward_clients(self, data_slices):
        """Run each client's base model on its own slice; returns the list of activations."""
        return [model(data) for data, model in zip(data_slices, self.client_base_models)]

    def _evaluate(self, loaders):
        """Evaluate base models + server model in eval mode over aligned per-client loaders.

        Returns:
            (mean per-sample loss, number of correct predictions, number of samples).
            The loss is NaN when the loaders are empty.
        """
        for model in self.client_base_models:
            model.eval()
        self.server_model.eval()
        total_loss, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for batches in zip(*loaders):
                data_slices, target = self._unpack(batches)
                output = self.server_model(torch.cat(self._forward_clients(data_slices), dim=1))
                # CrossEntropyLoss averages over the batch; re-weight by batch size so the
                # final division gives a true per-sample mean (also for a short last batch).
                total_loss += self.CRITERION(output, target).item() * target.size(0)
                correct += output.argmax(dim=1).eq(target).sum().item()
                seen += target.size(0)
        mean_loss = total_loss / seen if seen else float('nan')
        return mean_loss, correct, seen

    # ------------------------------------------------------------------ data

    def _to_loader(self, trainsets, testsets, batch_size_train, batch_size_val, batch_size_test, train_shuffle, val_shuffle, test_shuffle, train_ratio):
        """Split every client's trainset into train/val with shared indices and build loaders.

        ``train_shuffle`` shuffles the shared index permutation once, using NumPy's
        global RNG, so all clients stay aligned. The DataLoaders themselves never
        shuffle, because independent shuffling would misalign the clients;
        ``val_shuffle``/``test_shuffle`` are therefore ignored.
        """
        dataset_length = len(trainsets[0])
        assert all(len(trainset) == dataset_length for trainset in trainsets), "All trainsets must be the same length"

        indices = np.arange(dataset_length)
        if train_shuffle:
            np.random.shuffle(indices)
        train_end = round(train_ratio * dataset_length)
        # Plain Python ints keep dataset __getitem__ implementations happy.
        train_indices = indices[:train_end].tolist()
        val_indices = indices[train_end:].tolist()

        traindatas = [Subset(trainset, train_indices) for trainset in trainsets]
        valdatas = [Subset(trainset, val_indices) for trainset in trainsets]
        self.trainloader = [DataLoader(traindata, batch_size=batch_size_train, shuffle=False) for traindata in traindatas]
        self.valloader = [DataLoader(valdata, batch_size=batch_size_val, shuffle=False) for valdata in valdatas]
        self.testloader = [DataLoader(testdata, batch_size=batch_size_test, shuffle=False) for testdata in testsets]

    def _pred_loader(self, testsets, batch_size_test, test_shuffle):
        """Unshuffled loaders for prediction (``test_shuffle`` is ignored to keep clients aligned)."""
        return [DataLoader(testdata, batch_size=batch_size_test, shuffle=False) for testdata in testsets]

    # ------------------------------------------------------------------ split training

    def train(self, epoch):
        """One epoch of split training, then a validation pass that checkpoints on improvement."""
        for model in self.client_base_models:
            model.train()
        self.server_model.train()
        n_train = len(self.trainloader[0].dataset)
        n_batches = len(self.trainloader[0])

        for batch_idx, batches in enumerate(zip(*self.trainloader)):
            data_slices, target = self._unpack(batches)

            # Client feed-forward: each client's graph is kept for its own backward pass.
            client_activations = self._forward_clients(data_slices)

            # Server feed-forward on detached activations (the cut layer).
            server_input = torch.cat([a.detach() for a in client_activations], dim=1).requires_grad_()
            self.SERVER_OPTIMIZER.zero_grad()
            output = self.server_model(server_input)
            loss = self.CRITERION(output, target)
            loss.backward()
            self.SERVER_OPTIMIZER.step()

            # Client backprop: each client receives the gradient slice of its own activation,
            # applied to the activation computed from *its own* data slice.
            widths = [activation.shape[1] for activation in client_activations]
            activation_grads = torch.split(server_input.grad, widths, dim=1)
            for activation, grad, optimizer in zip(client_activations, activation_grads, self.CLIENT_BASE_OPTIMIZERS):
                optimizer.zero_grad()
                activation.backward(grad)
                optimizer.step()

            if self.log and batch_idx % self.log_interval == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * target.size(0)}/{n_train} '
                    f'({100. * batch_idx / n_batches:.0f}%)]\tLoss: {loss.item():.6f}')
            # Loss curve sampled every `save_interval` batches.
            if self.log and batch_idx % self.save_interval == 0:
                self.train_losses.append(loss.item())

        val_loss, _, _ = self._evaluate(self.valloader)
        if self.early_stopper.best_loss is None or val_loss < self.early_stopper.best_loss:
            print("Validation loss improved. Saving model...")
            self.save()
        self.early_stopper(val_loss)

    def test(self):
        """Evaluate the split model on the test loaders; prints and (if log) records loss/accuracy."""
        test_loss, correct, seen = self._evaluate(self.testloader)
        accuracy = 100. * correct / seen if seen else 0.0
        if self.log:
            self.test_losses.append(test_loss)
            self.accuracies.append(accuracy)
        print(f'\nTest set: Average loss per Sample: {test_loss:.4f}, Accuracy: {correct}/{seen} '
            f'({accuracy:.0f}%)\n')

    def run_training(self, trainset, testset):
        """Full pipeline: build loaders, split-train for N_EPOCH epochs, then train the local heads.

        Args:
            trainset: list of per-client training datasets (aligned, labelled).
            testset: list of per-client test datasets (aligned, labelled).
        """
        self.early_stopper = EarlyStopper()
        self._to_loader(trainset, testset,
            self.model_config.BATCH_SIZE_TRAIN,
            self.model_config.BATCH_SIZE_VAL,
            self.model_config.BATCH_SIZE_TEST,
            self.model_config.TRAIN_SHUFFLE,
            self.model_config.VAL_SHUFFLE,
            self.model_config.TEST_SHUFFLE,
            self.model_config.TRAIN_RATIO)
        self.test()
        for epoch in range(1, self.model_config.N_EPOCH + 1):
            self.train(epoch)
            self.test()
            # NOTE: early_stopper.early_stop is tracked but not acted on (training always
            # runs N_EPOCH epochs); the best-validation checkpoint is written by save().
        self.run_head_training()

    def pred(self, testset,  pred=True):
        """Predict class indices for aligned per-client test sets; returns a flat list of ints.

        Each item of every test set must be ``(input, label)``-like (only index 0 is used).
        """
        loader = self._pred_loader(testset, self.model_config.BATCH_SIZE_TEST, self.model_config.TEST_SHUFFLE)
        for model in self.client_base_models:
            model.eval()
        self.server_model.eval()
        predictions = []
        with torch.no_grad():
            for batches in zip(*loader):
                data_slices = [batch[0].to(self.device).float() for batch in batches]
                output = self.server_model(torch.cat(self._forward_clients(data_slices), dim=1))
                # view(-1) instead of squeeze(): a final batch of size 1 would otherwise
                # collapse to a scalar, and tolist() would return an int.
                predictions.extend(output.argmax(dim=1).view(-1).tolist())
        return predictions

    # ------------------------------------------------------------------ checkpoints

    def save(self):
        """Save server and client base model/optimizer state dicts under the run's result path."""
        result_path = self._result_path()
        os.makedirs(result_path, exist_ok=True)
        model_path = os.path.join(result_path, f'model_server_{self.idx}.pth').replace("\\", "/")
        optimizer_path = os.path.join(result_path, f'optimizer_server_{self.idx}.pth').replace("\\", "/")
        torch.save(self.server_model.state_dict(), model_path)
        torch.save(self.SERVER_OPTIMIZER.state_dict(), optimizer_path)

        client_result_path = os.path.join(result_path, "clients")
        os.makedirs(client_result_path, exist_ok=True)
        for i, client_base_model in enumerate(self.client_base_models):
            model_path = os.path.join(client_result_path, f'model_client_{i}.pth').replace("\\", "/")
            optimizer_path = os.path.join(client_result_path, f'optimizer_client_{i}.pth').replace("\\", "/")
            torch.save(client_base_model.state_dict(), model_path)
            torch.save(self.CLIENT_BASE_OPTIMIZERS[i].state_dict(), optimizer_path)

    def load(self):
        """Restore the checkpoints written by ``save``.

        Previously the client loop re-saved instead of loading and referenced an
        undefined name. Tensors are first loaded onto the CPU; ``load_state_dict``
        then copies or casts them onto the models' device, so checkpoints from a
        GPU machine also load elsewhere.
        """
        result_path = self._result_path()
        model_path = os.path.join(result_path, f'model_server_{self.idx}.pth').replace("\\", "/")
        optimizer_path = os.path.join(result_path, f'optimizer_server_{self.idx}.pth').replace("\\", "/")
        self.server_model.load_state_dict(torch.load(model_path, map_location="cpu"))
        self.SERVER_OPTIMIZER.load_state_dict(torch.load(optimizer_path, map_location="cpu"))

        client_result_path = os.path.join(result_path, "clients")
        for i, client_base_model in enumerate(self.client_base_models):
            model_path = os.path.join(client_result_path, f'model_client_{i}.pth').replace("\\", "/")
            optimizer_path = os.path.join(client_result_path, f'optimizer_client_{i}.pth').replace("\\", "/")
            client_base_model.load_state_dict(torch.load(model_path, map_location="cpu"))
            self.CLIENT_BASE_OPTIMIZERS[i].load_state_dict(torch.load(optimizer_path, map_location="cpu"))

    # ------------------------------------------------------------------ local heads

    def classifier_test(self):
        """Evaluate every client's (base + local classifier head) on the test loaders.

        Both models are put in eval mode; the head's Dropout was previously left
        active during testing. Prints, and when log is on records, per-client lists
        of mean per-sample loss and accuracy.
        """
        n_clients = len(self.client_base_models)
        loss_sums = [0.0] * n_clients
        correct_values = [0] * n_clients
        seen = 0
        for model in self.client_base_models + self.client_classifier_models:
            model.eval()
        with torch.no_grad():
            for batches in zip(*self.testloader):
                data_slices, target = self._unpack(batches)
                seen += target.size(0)
                for client_idx, (data, base, head) in enumerate(zip(data_slices, self.client_base_models, self.client_classifier_models)):
                    output = head(base(data))
                    loss_sums[client_idx] += self.LOCAL_CLASSIFIER_CRITERION(output, target).item() * target.size(0)
                    correct_values[client_idx] += output.argmax(dim=1).eq(target).sum().item()
        test_loss_values = [s / seen if seen else float('nan') for s in loss_sums]
        accuracy_values = [100. * c / seen if seen else 0.0 for c in correct_values]

        if self.log:
            self.test_losses.append(test_loss_values)
            self.accuracies.append(accuracy_values)
        print(f'\nTest set: Average loss per Sample: {test_loss_values}, Accuracy: {correct_values}/{seen}'
            f'({accuracy_values}%)\n')

    def classifier_train(self, epoch):
        """One epoch training each client's local classifier head on its frozen base model."""
        for base in self.client_base_models:
            base.eval()
        for head in self.client_classifier_models:
            head.train()
        n_train = len(self.trainloader[0].dataset)
        n_batches = len(self.trainloader[0])
        for batch_idx, batches in enumerate(zip(*self.trainloader)):
            data_slices, target = self._unpack(batches)
            clients = zip(data_slices, self.client_base_models, self.client_classifier_models, self.CLIENT_CLASSIFIER_OPTIMIZERS)
            for client_idx, (data, base, head, optimizer) in enumerate(clients):
                # The base is not updated in this phase, so no graph through it is needed.
                with torch.no_grad():
                    activation = base(data)
                optimizer.zero_grad()
                loss = self.LOCAL_CLASSIFIER_CRITERION(head(activation), target)
                loss.backward()
                optimizer.step()
                if self.log and batch_idx % self.log_interval == 0:
                    print(f'Train Epoch: {epoch} Client {client_idx} [{batch_idx * len(data)}/{n_train} '
                        f'({100. * batch_idx / n_batches:.0f}%)]\tLoss: {loss.item():.6f}')

    # ------------------------------------------------------------------ information gates

    def gate_train(self, epoch):
        """One epoch training each client's information gate.

        The gate target for a sample is 1 exactly when the client's local head
        mispredicts it while the server model predicts it correctly. Base models,
        local heads and the server model are only used for inference here (eval
        mode, no gradients); only the gates are updated.
        """
        for model in self.client_base_models + self.client_classifier_models:
            model.eval()
        self.server_model.eval()
        for gate in self.client_ig_models:
            gate.train()
        n_train = len(self.trainloader[0].dataset)
        n_batches = len(self.trainloader[0])

        for batch_idx, batches in enumerate(zip(*self.trainloader)):
            data_slices, target = self._unpack(batches)
            with torch.no_grad():
                activations = self._forward_clients(data_slices)
                server_right = self.server_model(torch.cat(activations, dim=1)).argmax(dim=1) == target

            clients = zip(activations, self.client_classifier_models, self.client_ig_models, self.CLIENT_IG_OPTIMIZERS)
            for client_idx, (activation, head, gate, optimizer) in enumerate(clients):
                with torch.no_grad():
                    client_wrong = head(activation).argmax(dim=1) != target
                ig_target = (client_wrong & server_right).float().unsqueeze(1)  # (batch_size, 1)

                optimizer.zero_grad()
                ig_loss = self.LOCAL_IG_CRITERION(gate(activation), ig_target)
                ig_loss.backward()
                optimizer.step()
                if self.log and batch_idx % self.log_interval == 0:
                    print(f'Gate Train Epoch: {epoch} Client {client_idx} [{batch_idx * target.size(0)}/{n_train} '
                        f'({100. * batch_idx / n_batches:.0f}%)]\tLoss: {ig_loss.item():.6f}')

    def gate_pred(self):
        """Run every client's gate over the whole test loaders.

        Returns:
            (ig_outputs, target): ``ig_outputs`` has one tensor of gate scores per client,
            concatenated over all test batches (previously only the first batch was
            returned). ``target`` holds the matching labels.
        """
        for model in self.client_base_models + self.client_ig_models:
            model.eval()
        per_client = [[] for _ in self.client_ig_models]
        targets = []
        with torch.no_grad():
            for batches in zip(*self.testloader):
                data_slices, target = self._unpack(batches)
                for outputs, activation, gate in zip(per_client, self._forward_clients(data_slices), self.client_ig_models):
                    outputs.append(gate(activation))
                targets.append(target)
        ig_outputs = [torch.cat(outputs) if outputs else torch.empty(0, 1, device=self.device) for outputs in per_client]
        target = torch.cat(targets) if targets else torch.empty(0, dtype=torch.long, device=self.device)
        return ig_outputs, target

    def run_head_training(self):
        """Train the local classifier heads for N_EPOCH epochs, testing before and after each."""
        self.classifier_test()
        for epoch in range(1, self.model_config.N_EPOCH + 1):
            self.classifier_train(epoch)
            self.classifier_test()

    def run_gate_training(self):
        """Train the information gates for N_EPOCH epochs (requires loaders from run_training)."""
        for epoch in range(1, self.model_config.N_EPOCH + 1):
            self.gate_train(epoch)

            

from federated_inference.common.environment import  DataMode, DataSetEnum, TransformType
from federated_inference.simulations.simulation import Simulation
import os
import pandas as pd
from federated_inference.simulations.simulation import Simulation
from federated_inference.simulations.utils import *
class IGSimulation(Simulation):
    """Vertically partitioned simulation with per-client information gates (IG).

    The dataset is loaded and split vertically (``DataMode.VERTICAL``) into one
    slice per client via ``transform_data``. Each slice is wrapped in an
    ``IGV1Client``. A single ``IGV1Server`` is built from a ``ModelConfiguration``
    holding the server / client-base / client-classifier / client-IG model classes.

    Args:
        seed: random seed; also used as the results folder name in ``test_inference``.
        data_config: dataset configuration (``LABELS`` and ``DATASET_NAME`` are read here).
        transform_config: configuration forwarded to ``transform_data``.
        server_model, client_base_model, client_classifier_model, client_ig_model:
            model classes forwarded to ``ModelConfiguration``.
        transform_type: how the data is partitioned across clients.
        exist: currently unused; kept for signature compatibility.
    """

    def __init__(self, seed, data_config, transform_config, server_model, client_base_model, client_classifier_model, client_ig_model, transform_type: TransformType = TransformType.FULL_STRIDE_PARTITION, exist=False):
        self.seed = seed
        self.data_config = data_config
        self.transform_config = transform_config
        self.data_mode = DataMode.VERTICAL
        self.transform_type = transform_type
        self.dataset = self.load_data(data_config)
        self.client_datasets, self.transformation = self.transform_data(self.dataset, data_mode=self.data_mode, transform_config=transform_config, transform_type=self.transform_type)
        self.clients = [IGV1Client(idx, data_config, dataset, data_config.LABELS) for idx, dataset in enumerate(self.client_datasets)]
        # Keep the configuration instance the server is actually built with (this
        # attribute used to hold the bare ModelConfiguration class). It is created
        # at the same point as before, just ahead of the server.
        self.server_model_config = ModelConfiguration(server_model, client_base_model, client_classifier_model, client_ig_model)
        self.server = IGV1Server(0, seed, self.server_model_config, self.data_config)

        self.cost_summary = {
            'overall_cost': 0,
            'reasons': []
        }

    def train(self):
        """Send every client's training data and labelled test data to the server and run training."""
        datasets = [client.send_all() for client in self.clients]
        testsets = [client.request_pred(pred_all=True, keep_label=True) for client in self.clients]
        self.server.run_training(datasets, testsets)

    def test_inference(self):
        """Predict on every client's test data, let client 0 check the predictions, then save results."""
        testsets = [client.request_pred(pred_all=True, keep_label=True) for client in self.clients]
        predictions = self.server.pred(testsets)
        self.clients[0].check(predictions)
        self.collect_results(self.seed, save=True)

    def collect_results(self, name: str, save: bool = True, figures: bool = False):
        """Gather server losses and client-0 metrics into ``self.results``.

        Args:
            name: run identifier stored under ``'seed'`` and used as the folder name.
            save: write ``simulation.json`` under ``./results/ig/<DATASET_NAME>/<name>``.
            figures: also render diagnostic figures inline.

        Returns:
            The ``self.results`` dict.
        """
        from IPython.display import display

        self.results = {
            'seed': name,
            'client': [],
            'server': {}
        }

        if figures:
            # Plot this simulation, not whatever the notebook-global `simulation` is.
            fig = create_simulation_image_subplots(self)
            display(fig)

        self.results['server'] = self._gather_server_results(self.server)

        # Only the first client's results are collected.
        client = self.clients[0]
        client_result = self._gather_client_result(client, figures)
        self.results['client'].append(client_result)

        if save:
            # IG results belong under ./results/ig (the _save_results default); they
            # were written to "naive", where another simulation type may write the
            # same simulation.json.
            self._save_results(str(name))
        return self.results

    def _gather_server_results(self, server):
        """Return the server's recorded training and test losses."""
        return {
            'training_losses': server.train_losses,
            'test_losses': server.test_losses
        }

    def _gather_client_result(self, client, figures):
        """Build the result dict for one client: confusion matrix, its analysis and metrics."""
        from IPython.display import display

        result = {
            'idx': client.idx,
            'cm': client.cm.to_json(orient='split')
        }

        analysis, df_cm_per = cm_analysis(client)
        result['cm_analysis'] = analysis

        # Class indices worth visualising: best class, most misclassified class,
        # and the most frequent confusion pair (deduplicated).
        indices = [
            analysis['correct']['most_correct_class'],
            analysis['wrong']['most_misclassified_class']
        ]
        indices += [i for i in [analysis['wrong']['wrong_from'], analysis['wrong']['wrong_to']] if i not in indices]

        result.update({
            'accuracy': analysis["performance"]["accuracy"],
            'precision': analysis["performance"]["precision"],
            'recall': analysis["performance"]["recall"]
        })

        if figures:
            display(plot_test_loss(client.test_losses, 1, client.idx, "Test"))
            fig, subplot_indices = create_client_image_subplots(self, [client.idx], 8, keys=indices)
            display(fig)
            result['client_image_subplots_ids'] = subplot_indices
            fig = print_cm_heat(df_cm_per, client.idx)
            display(fig)

        return result

    def _save_results(self, name, base_dir="ig"):
        """Write ``self.results`` to ``./results/<base_dir>/<DATASET_NAME>/<name>/simulation.json``."""
        import os
        import json

        result_path = os.path.join("./results", base_dir, self.data_config.DATASET_NAME, name)
        os.makedirs(result_path, exist_ok=True)
        file_path = os.path.join(result_path, "simulation.json")

        with open(file_path, "w") as f:
            json.dump(self.results, f, indent=4)
        print("Results saved to JSON.")

In [None]:
from federated_inference.simulations.naive.configs.data_config import DataConfiguration
from federated_inference.simulations.naive.configs.transform_config import DataTransformConfiguration
seed = 1
set_seed(seed)  # seed every RNG before any data loading or model construction
data_config, transform_config = DataConfiguration(), DataTransformConfiguration()
simulation = IGSimulation(
    seed,
    data_config,
    transform_config,
    server_model=SplitServerCNN,
    client_base_model=IGV1Base,
    client_classifier_model=IGV1ClassifierHead,
    client_ig_model=InformationGateV1,
)

In [None]:
# Run IGSimulation.train: every client sends its training data and its labelled
# test data, and the server runs its training loop on them.
simulation.train()

In [None]:
def gate_train(self, epoch):
    """Train each client's information gate (IG) for one epoch on hard routing targets.

    For a sample, the IG target is 1 when the client's own classifier head
    mispredicts it while the server model (fed the concatenated activations of
    all clients) predicts it correctly, and 0 otherwise. Only the IG models are
    updated. Base models, classifier heads and the server model are used just to
    build the targets, so they run in eval mode without gradient tracking. Dropout
    in train mode would randomly flip "client wrong / server right" and add label
    noise.

    Args:
        self: server object (e.g. ``simulation.server``) exposing ``trainloader``
            (one loader per client), ``client_base_models``,
            ``client_classifier_models``, ``client_ig_models``,
            ``CLIENT_IG_OPTIMIZERS``, ``server_model``, ``LOCAL_IG_CRITERION`` and
            ``device``.
        epoch: epoch number (unused; kept for call-site compatibility).
    """
    for batches in zip(*self.trainloader):
        data_slices = [batch[0].to(self.device).float() for batch in batches]
        target = batches[0][1].to(self.device).long()

        # Build routing targets with frozen, deterministic models.
        with torch.no_grad():
            activations = []
            client_preds = []
            for data, base_model, classifier_model in zip(
                data_slices, self.client_base_models, self.client_classifier_models
            ):
                base_model.eval()
                classifier_model.eval()
                activation = base_model(data)
                activations.append(activation)
                client_preds.append(classifier_model(activation).argmax(dim=1))
            self.server_model.eval()
            server_right = self.server_model(torch.cat(activations, dim=1)).argmax(dim=1) == target

        for activation, client_pred, ig_model, ig_optimizer in zip(
            activations, client_preds, self.client_ig_models, self.CLIENT_IG_OPTIMIZERS
        ):
            ig_model.train()
            ig_optimizer.zero_grad()
            ig_target = ((client_pred != target) & server_right).float().unsqueeze(1)  # (batch, 1)
            ig_loss = self.LOCAL_IG_CRITERION(ig_model(activation), ig_target)
            ig_loss.backward()
            ig_optimizer.step()
# Train the gates with the hard-target gate_train defined in this cell.
# NOTE(review): range(1, 30) runs 29 epochs and ignores model_config.N_EPOCH — confirm intended.
for epoch in range(1, 30):
    gate_train(simulation.server, epoch)
    

In [None]:
def gate_train(self, epoch, local_loss_weight=0.002):
    """Train each client's information gate (IG) for one epoch with a soft-gating loss.

    For every batch:
      * each client's base model and classifier head produce a local loss;
      * the server model, fed the concatenated activations of all clients,
        produces a server loss.

    Each IG output ``g`` in (0, 1) is then trained to minimise

        (1 - g) * local_loss_weight * local_loss + g * server_loss

    A high gate value therefore means "send to the server" and a low value
    means "classify locally". This matches ``test_inference``, where
    IG >= threshold routes to the server.

    Only the IG models are updated. The losses do not depend on the IG
    parameters, so they are computed without gradient tracking, with all
    frozen models in eval mode (server dropout would only add noise to the
    gating signal).

    NOTE(review): if ``LOCAL_CLASSIFIER_CRITERION`` / ``CRITERION`` use the default
    ``reduction='mean'``, both losses are batch scalars. Every sample then gets
    the same gating signal; per-sample gating needs ``reduction='none'`` — confirm.

    Args:
        self: server object (e.g. ``simulation.server``) exposing ``trainloader``,
            ``client_base_models``, ``client_classifier_models``, ``client_ig_models``,
            ``CLIENT_IG_OPTIMIZERS``, ``server_model``, ``LOCAL_CLASSIFIER_CRITERION``,
            ``CRITERION`` and ``device``.
        epoch: epoch number, only used in the progress log.
        local_loss_weight: weight of the local classifier loss (previously hard-coded 0.002).
    """
    for batch_idx, batches in enumerate(zip(*self.trainloader)):
        data_slices = [batch[0].to(self.device).float() for batch in batches]
        target = batches[0][1].to(self.device).long()

        # Losses of the frozen models are constants w.r.t. the IG parameters.
        with torch.no_grad():
            activations = []
            classifier_losses = []
            for data, base_model, classifier_model in zip(
                data_slices, self.client_base_models, self.client_classifier_models
            ):
                base_model.eval()
                classifier_model.eval()
                activation = base_model(data)
                activations.append(activation)
                classifier_losses.append(self.LOCAL_CLASSIFIER_CRITERION(classifier_model(activation), target))
            self.server_model.eval()
            server_loss = self.CRITERION(self.server_model(torch.cat(activations, dim=1)), target)

        for activation, classifier_loss, ig_model, ig_optimizer in zip(
            activations, classifier_losses, self.client_ig_models, self.CLIENT_IG_OPTIMIZERS
        ):
            ig_model.train()
            ig_optimizer.zero_grad()
            ig_output = ig_model(activation).squeeze(1)  # (batch,), sigmoid output
            loss = ((1 - ig_output) * local_loss_weight * classifier_loss + ig_output * server_loss).mean()
            loss.backward()
            ig_optimizer.step()

        if batch_idx % 10 == 0:
            # Logs the last client's gate output.
            print(f"Epoch {epoch}, Batch {batch_idx}, IG output mean: {ig_output.mean().item():.4f}")

# Train the gates with the soft-gating gate_train defined in this cell. The same
# simulation.server IG models are trained further from whatever state earlier cells left.
# NOTE(review): range(1, 30) runs 29 epochs and ignores model_config.N_EPOCH — confirm intended.
for epoch in range(1, 30):
    gate_train(simulation.server, epoch)

In [None]:
def gate_pred(self):
    """Collect every client's information-gate (IG) output over the whole test set.

    Args:
        self: server object (e.g. ``simulation.server``) exposing ``testloader``
            (one loader per client), ``client_base_models``, ``client_ig_models``,
            ``number_of_clients`` and ``device``.

    Returns:
        tuple ``(ig_outputs, targets)``:
          * ``ig_outputs[c]`` is a flat list with one IG output per test sample for
            client ``c``;
          * ``targets`` is the matching flat list of labels from client 0's loader.
    """
    ig_outputs = [[] for _ in range(self.number_of_clients)]
    targets = []
    with torch.no_grad():  # inference only: no autograd graphs needed
        for batches in zip(*self.testloader):
            data_slices = [batch[0].to(self.device).float() for batch in batches]
            target = batches[0][1].to(self.device).long()
            # reshape(-1) instead of squeeze(): a batch of one must still give a
            # list from tolist(), not a bare scalar (list + int raised TypeError).
            targets.extend(target.reshape(-1).tolist())
            for client_idx, (data, base_model, ig_model) in enumerate(
                zip(data_slices, self.client_base_models, self.client_ig_models)
            ):
                base_model.eval()
                ig_model.eval()
                ig_output = ig_model(base_model(data))
                ig_outputs[client_idx].extend(ig_output.reshape(-1).tolist())
    return ig_outputs, targets
ig_outputs, targets = gate_pred(simulation.server)
import numpy as np
from collections import Counter

# For every client: the class distribution of test samples whose IG output is
# strictly above the 0.25 cutoff, and the fraction of the test set they represent.
# NOTE(review): test_inference routes IG >= 0.25 to the server; this uses > 0.25.
IG_THRESHOLD = 0.25
targets_arr = np.array(targets)
for i, client_scores in enumerate(ig_outputs):  # one entry per client, not a hard-coded 4
    filtered_targets = targets_arr[np.array(client_scores) > IG_THRESHOLD]
    counts = Counter(filtered_targets.tolist())  # plain int keys for readable output
    print(counts, len(filtered_targets) / len(client_scores))

In [None]:
def test_inference(self):
    for i in range(self.number_of_clients):
        correct = 0
        client_correct = 0
        server_correct = 0
        client_instance = 0
        with torch.no_grad():
            for batch_idx, batches in enumerate(zip(*self.testloader)):
                    # Step 1: Get data and target
                    data_slices = [batch[0].to(self.device).float() for batch in batches]
                    target = batches[0][1].to(self.device).long()
                    # Step 2: Run base model and IG for active client (client 0)
                    active_client = list(zip(data_slices, self.client_base_models, self.client_classifier_models, self.client_ig_models))[i]
                    active_data, base_model, classifier_model, ig_model = active_client
                    
                    base_model.eval()
                    activation = base_model(active_data)
                    
                    ig_model.eval()
                    ig_output = ig_model(activation).squeeze()  # Make sure it's shape [batch_size]
        
                    # Step 3: Split by IG threshold
                    low_mask = (ig_output < 0.25).to(self.device)
                    high_mask = (ig_output >= 0.25).to(self.device)
        
                    # Step 4: Local classifier inference for low IG samples
                    low_target = target[low_mask]
                    client_instance += len(low_target)
                    activation_low = activation[low_mask]
                    if len(low_target) > 0:
                        classifier_model.eval()
                        local_output = classifier_model(activation_low)
                        classifier_pred = local_output.argmax(dim=1)
                        correct += classifier_pred.eq(low_target).sum().item()
                        client_correct += classifier_pred.eq(low_target).sum().item()
            
                    # Step 5: Prepare activations from all clients for high IG samples
                    high_data_slices = [batch[0].to(self.device)[high_mask].to(self.device).float() for batch in batches]
                    high_target = target[high_mask]
        
                    if len(high_target) > 0:
                        client_activations = []
                        for data, base_model, classifier_model, ig_model in zip(high_data_slices, self.client_base_models, self.client_classifier_models, self.client_ig_models):
                            if len(data) > 0:
                                base_model.eval()
                                activation = base_model(data)
                                client_activations.append(activation)
                
                        if client_activations:  # Avoid torch.cat error on empty list
                            concat_activations = torch.cat(client_activations, dim=1)
            
                            # Step 6: Server-side inference
                            self.server_model.eval()
                            output = self.server_model(concat_activations)
                            pred = output.argmax(dim=1)
                            correct += pred.eq(high_target).sum().item()
                            server_correct += pred.eq(high_target).sum().item()
        
            print(correct / len(self.testloader[0].dataset))
            print(client_correct/client_instance, client_instance)
            print(server_correct/(len(self.testloader[0].dataset) - client_instance), (len(self.testloader[0].dataset) - client_instance))
# Evaluate IG-gated routing on the test set; prints one three-line report per client.
test_inference(simulation.server)

In [None]:
from federated_inference.simulations.utils import *
# Run IGSimulation.test_inference: server.pred on all clients' test data, client 0
# checks the predictions, then collect_results stores (and saves) the metrics.
simulation.test_inference()

In [None]:
def _posix_join(*parts):
    """Join path parts and normalise Windows separators to forward slashes."""
    return os.path.join(*parts).replace("\\", "/")


def _save_client_models(directory, models, optimizers):
    """Save each model/optimizer pair as model_client_<i>.pth / optimizer_client_<i>.pth in ``directory``."""
    os.makedirs(directory, exist_ok=True)
    for i, (model, optimizer) in enumerate(zip(models, optimizers)):
        torch.save(model.state_dict(), _posix_join(directory, f'model_client_{i}.pth'))
        torch.save(optimizer.state_dict(), _posix_join(directory, f'optimizer_client_{i}.pth'))


def save(self, base_dir="ig"):
    """Persist the server and all per-client models together with their optimizer states.

    Layout under ``./results/<base_dir>/<DATASET_NAME>/<seed>/``::

        model_server_<idx>.pth, optimizer_server_<idx>.pth
        clients/base/{model,optimizer}_client_<i>.pth
        clients/ig/{model,optimizer}_client_<i>.pth
        clients/classifier/{model,optimizer}_client_<i>.pth

    Args:
        self: server object (e.g. ``simulation.server``) exposing ``server_model``,
            ``SERVER_OPTIMIZER``, ``client_base_models`` / ``CLIENT_BASE_OPTIMIZERS``,
            ``client_ig_models`` / ``CLIENT_IG_OPTIMIZERS``,
            ``client_classifier_models`` / ``CLIENT_CLASSIFIER_OPTIMIZERS``,
            ``data_config``, ``seed`` and ``idx``.
        base_dir: results sub-folder. Defaults to ``"ig"``, matching
            ``IGSimulation._save_results``. The previous hard-coded ``"split"``
            looks like a copy-paste leftover and may collide with another
            simulation's checkpoints.
    """
    result_path = os.path.join("./results", base_dir, self.data_config.DATASET_NAME, str(self.seed))
    os.makedirs(result_path, exist_ok=True)
    torch.save(self.server_model.state_dict(), _posix_join(result_path, f'model_server_{self.idx}.pth'))
    torch.save(self.SERVER_OPTIMIZER.state_dict(), _posix_join(result_path, f'optimizer_server_{self.idx}.pth'))

    clients_path = os.path.join(result_path, "clients")
    _save_client_models(os.path.join(clients_path, "base"), self.client_base_models, self.CLIENT_BASE_OPTIMIZERS)
    _save_client_models(os.path.join(clients_path, "ig"), self.client_ig_models, self.CLIENT_IG_OPTIMIZERS)
    # This group previously iterated client_ig_models, so IG weights were
    # written into the classifier folder instead of the classifier heads.
    _save_client_models(os.path.join(clients_path, "classifier"), self.client_classifier_models, self.CLIENT_CLASSIFIER_OPTIMIZERS)
                

                        