In [None]:

import torch 
import random 
import numpy as np

    
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds torch (CPU, the current CUDA device and all CUDA devices), NumPy's
    legacy global RNG and Python's ``random`` module with the same value, then
    asks cuDNN for deterministic kernels and disables its autotuner so that
    reruns with an identical seed are as reproducible as the backend allows.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    seeders = (
        torch.manual_seed,           # CPU
        torch.cuda.manual_seed,      # current GPU
        torch.cuda.manual_seed_all,  # all GPUs
        np.random.seed,              # NumPy global RNG
        random.seed,                 # Python stdlib RNG
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    



In [None]:
import os
from dotenv import load_dotenv
from federated_inference.dataset import MNISTDataset, FMNISTDataset, CIFAR10Dataset
from federated_inference.common.environment import DataSetEnum

class DataConfiguration():
    """Resolve which dataset a simulation uses, plus its per-dataset constants.

    The dataset name comes from the ``dataset_name`` argument. If that is
    falsy, it comes from the ``DATASET_NAME`` environment variable (a ``.env``
    file is loaded first and overrides existing variables). Failing both, it
    defaults to ``'MNIST'``. If ``DATASET_PATH`` is set in the environment, it
    overrides the per-dataset default path defined below.

    Instance attributes set by ``__init__``: ``DATASET_NAME``, ``DATASET``
    (a ``DataSetEnum`` member), ``INSTANCE_SIZE`` (channels, height, width),
    ``DATASET_PATH`` and ``LABELS``.
    """

    # MNIST_FASHION_DATASET configurations
    FMNIST_NAME = 'FMNIST'
    FMNIST_DATASET_PATH = os.path.join('./data/fmnist')
    FMNIST_LABELS = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker',  'Bag', 'Ankle Boot']

    # MNIST_DATASET configurations
    MNIST_NAME = 'MNIST'
    MNIST_DATASET_PATH = os.path.join('./data/mnist')

    # CIFAR_DATASET configurations
    CIFAR10_NAME = 'CIFAR10'
    CIFAR10_DATASET_PATH = os.path.join('./data/cifar10')
    CIFAR10_LABELS = ['Plane', 'Car', 'Bird', 'Cat','Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

    def __init__(self, dataset_name : str = None):
        """Select the dataset and populate its size, enum member, path and labels.

        Args:
            dataset_name: value of one of ``DataSetEnum.MNIST/FMNIST/CIFAR10``.
                If falsy, the name is read from the environment as described
                on the class.

        Raises:
            ValueError: if the resolved name matches none of the supported
                ``DataSetEnum`` values.
        """
        load_dotenv(override=True)
        self.DATASET_NAME = dataset_name or os.getenv('DATASET_NAME', 'MNIST')

        if self.DATASET_NAME == DataSetEnum.MNIST.value:
            self.INSTANCE_SIZE = (1, 28, 28)
            self.DATASET = DataSetEnum.MNIST
            self.DATASET_PATH = os.path.join(os.getenv('DATASET_PATH', self.MNIST_DATASET_PATH))
            self.LABELS = range(10)
        elif self.DATASET_NAME == DataSetEnum.FMNIST.value:
            self.INSTANCE_SIZE = (1, 28, 28)
            self.DATASET = DataSetEnum.FMNIST
            self.DATASET_PATH = os.path.join(os.getenv('DATASET_PATH', self.FMNIST_DATASET_PATH))
            self.LABELS = self.FMNIST_LABELS
        elif self.DATASET_NAME == DataSetEnum.CIFAR10.value:
            # NOTE(review): torchvision's CIFAR-10 images are RGB (3 channels).
            # (1, 32, 32) is only correct if the project's CIFAR10Dataset
            # converts to grayscale; confirm, otherwise this should be (3, 32, 32).
            self.INSTANCE_SIZE = (1, 32, 32)
            self.DATASET = DataSetEnum.CIFAR10
            self.DATASET_PATH = os.path.join(os.getenv('DATASET_PATH', self.CIFAR10_DATASET_PATH))
            self.LABELS = self.CIFAR10_LABELS
        else:
            # Fail fast: without this, the object would lack DATASET/LABELS/...
            # and break much later with an AttributeError.
            supported = [DataSetEnum.MNIST.value, DataSetEnum.FMNIST.value, DataSetEnum.CIFAR10.value]
            raise ValueError(f"Unsupported dataset name {self.DATASET_NAME!r}; expected one of {supported}")

    def __dict__(self):
        """Return a JSON-serialisable summary of the configuration.

        Note: this is a method that shadows the built-in ``__dict__``
        attribute; callers invoke it as ``config.__dict__()``.
        """
        return {
            "dataset_name": self.DATASET_NAME, 
            "labels": list(self.LABELS),
            "size": self.INSTANCE_SIZE
        }

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn

class OnDeviceModelConfiguration():
    """Bundle a freshly built model with its loss, optimizer and training hyper-parameters.

    NOTE(review): a later cell of this notebook defines a class with the same
    name and different defaults. After "Run All", that later definition
    shadows this one.
    """

    # Pick the best available accelerator once, when the class is created.
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    elif torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
    else:
        DEVICE = torch.device("cpu")

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 0.001,
        batch_size_train: int = 64,
        batch_size_val: int = 64,
        batch_size_test: int = 64,
        train_shuffle: bool = True,
        val_shuffle: bool = False,
        test_suffle: bool = False,
        n_epochs: int = 60,
        train_ratio: float = 0.8,
    ):
        """Instantiate ``model`` on ``DEVICE`` and store the training settings.

        Args:
            model: the model *class* (it is called with no arguments).
            learning_rate: Adam learning rate.
            batch_size_train / batch_size_val / batch_size_test: loader batch sizes.
            train_shuffle / val_shuffle / test_suffle: loader shuffle flags
                (``test_suffle`` keeps its historical spelling).
            n_epochs: maximum number of training epochs.
            train_ratio: fraction of the training set used for training; the
                remainder is used for validation.
        """
        network = model().to(self.DEVICE)
        self.MODEL = network

        # Plain hyper-parameters.
        self.LEARNING_RATE = learning_rate
        self.TRAIN_RATIO = train_ratio
        self.N_EPOCH = n_epochs

        # Loader settings.
        self.BATCH_SIZE_TRAIN = batch_size_train
        self.BATCH_SIZE_VAL = batch_size_val
        self.BATCH_SIZE_TEST = batch_size_test
        self.TRAIN_SHUFFLE = train_shuffle
        self.VAL_SHUFFLE = val_shuffle
        self.TEST_SHUFFLE = test_suffle

        # Loss and optimizer, with human-readable names for reporting.
        self.set_criterion(nn.CrossEntropyLoss(), "CrossEntropyLoss")
        self.set_optimizer(
            optim.Adam(network.parameters(), lr=learning_rate, weight_decay=1e-4),
            "Adam",
        )

    def set_optimizer(self, optimizer, name):
        """Replace the optimizer and the name reported by ``__dict__()``."""
        self.OPTIMIZER, self.OPTIMIZER_NAME = optimizer, name

    def set_criterion(self, criterion, name):
        """Replace the loss function and the name reported by ``__dict__()``."""
        self.CRITERION, self.CRITERION_NAME = criterion, name

    def __dict__(self):
        """Return a JSON-serialisable summary (a method shadowing the built-in ``__dict__``)."""
        summary = dict(
            model_name=self.MODEL.name,
            learning_rate=self.LEARNING_RATE,
            train_ratio=self.TRAIN_RATIO,
            batch_size_train=self.BATCH_SIZE_TRAIN,
            batch_size_val=self.BATCH_SIZE_VAL,
            batch_size_test=self.BATCH_SIZE_TEST,
            train_shuffle=self.TRAIN_SHUFFLE,
            val_shuffle=self.VAL_SHUFFLE,
            test_shuffle=self.TEST_SHUFFLE,
            n_epoch=self.N_EPOCH,
            optimizer=self.OPTIMIZER_NAME,
            criterion=self.CRITERION_NAME,
        )
        return summary


In [None]:
import torch
import torch.optim as optim
import torch.nn as nn

class OnDeviceModelConfiguration():
    """Bundle a freshly built model with its loss, optimizer and training hyper-parameters.

    The Adam optimizer is built with ``learning_rate * lr_multiplier``. The
    multiplier defaults to 2.0, which reproduces the previous hard-coded
    ``lr=2*learning_rate``. The rate actually given to the optimizer is kept
    in ``EFFECTIVE_LEARNING_RATE`` and reported by ``__dict__()``, so saved
    configurations no longer understate the real step size.
    """
    DEVICE = torch.device("cuda" if torch.cuda.is_available()
                          else "mps" if torch.backends.mps.is_available()
                          else "cpu")

    def __init__(self,
                 model: nn.Module,
                 learning_rate: float = 0.001,
                 batch_size_train: int = 64,
                 batch_size_val: int = 64,
                 batch_size_test: int = 64,
                 train_shuffle: bool = True,
                 val_shuffle: bool = False,
                 test_suffle: bool = False,
                 n_epochs: int = 30,
                 train_ratio: float = 0.8,
                 lr_multiplier: float = 2.0):
        """Instantiate ``model`` on ``DEVICE`` and store the training settings.

        Args:
            model: the model *class* (it is called with no arguments).
            learning_rate: base learning rate (reported as ``learning_rate``).
            batch_size_train / batch_size_val / batch_size_test: loader batch sizes.
            train_shuffle / val_shuffle / test_suffle: loader shuffle flags
                (``test_suffle`` keeps its historical spelling for callers).
            n_epochs: maximum number of training epochs.
            train_ratio: fraction of the training set used for training; the
                remainder is used for validation.
            lr_multiplier: factor applied to ``learning_rate`` when building
                Adam. The default 2.0 preserves the previous behaviour.
        """
        self.MODEL = model().to(self.DEVICE)
        self.LEARNING_RATE = learning_rate
        self.TRAIN_RATIO = train_ratio
        self.BATCH_SIZE_TRAIN = batch_size_train
        self.BATCH_SIZE_VAL = batch_size_val
        self.BATCH_SIZE_TEST = batch_size_test
        self.TRAIN_SHUFFLE = train_shuffle
        self.VAL_SHUFFLE = val_shuffle
        self.TEST_SHUFFLE = test_suffle
        self.N_EPOCH = n_epochs
        self.CRITERION = nn.CrossEntropyLoss()
        self.CRITERION_NAME = "CrossEntropyLoss"
        # The rate Adam really uses; previously a hidden 2x not reflected in reports.
        self.EFFECTIVE_LEARNING_RATE = learning_rate * lr_multiplier
        self.OPTIMIZER = optim.Adam(self.MODEL.parameters(), lr=self.EFFECTIVE_LEARNING_RATE, weight_decay=0)
        self.OPTIMIZER_NAME = "Adam"

    def set_optimizer(self, optimizer, name):
        """Replace the optimizer and its reported name.

        If the optimizer exposes torch-style ``param_groups``, the first
        group's ``lr`` becomes the reported effective learning rate.
        """
        self.OPTIMIZER = optimizer
        self.OPTIMIZER_NAME = name
        groups = getattr(optimizer, "param_groups", None)
        if groups:
            self.EFFECTIVE_LEARNING_RATE = groups[0].get("lr", self.EFFECTIVE_LEARNING_RATE)

    def set_criterion(self, criterion, name):
        """Replace the loss function and its reported name."""
        self.CRITERION = criterion
        self.CRITERION_NAME = name

    def __dict__(self):
        """Return a JSON-serialisable summary (a method shadowing the built-in ``__dict__``)."""
        return {
            "model_name": self.MODEL.name,
            "learning_rate": self.LEARNING_RATE,
            "effective_learning_rate": self.EFFECTIVE_LEARNING_RATE,
            "train_ratio": self.TRAIN_RATIO,
            "batch_size_train": self.BATCH_SIZE_TRAIN,
            "batch_size_val": self.BATCH_SIZE_VAL,
            "batch_size_test": self.BATCH_SIZE_TEST,
            "train_shuffle": self.TRAIN_SHUFFLE,
            "val_shuffle": self.VAL_SHUFFLE,
            "test_shuffle": self.TEST_SHUFFLE,
            "n_epoch": self.N_EPOCH,
            "optimizer":  self.OPTIMIZER_NAME,
            "criterion": self.CRITERION_NAME
        }


In [None]:
class DataTransformConfiguration(): 
    """Parameters describing how each sample is split across clients.

    The values are stored as given (list arguments are copied) for use by the
    project's transform code, and are serialised by ``__dict__()``.
    """

    def __init__(self, tensor_size = [1,28,28], method_name = 'full', mask_size = [14,14], dimensions = [1,2], stride = 14, n_position = None, drop_p = None):
        """Store the transform parameters.

        The list defaults in the signature are evaluated only once and would
        otherwise be shared by every instance. List arguments are therefore
        shallow-copied, so mutating one configuration cannot leak into
        another. Non-list values (tuples, ints, None) are stored unchanged.
        """
        self.TENSOR_SIZE = self._own(tensor_size)
        self.METHOD_NAME = method_name
        self.MASK_SIZE = self._own(mask_size)
        self.DIMENSIONS = self._own(dimensions)
        self.STRIDE = stride
        self.N_POSITION = n_position
        self.DROP_P = drop_p

    @staticmethod
    def _own(value):
        """Return a shallow copy of ``value`` if it is a list, else ``value`` itself."""
        return list(value) if isinstance(value, list) else value

    def __dict__(self):
        """Return a JSON-serialisable summary (a method shadowing the built-in ``__dict__``)."""
        return {
            "tensor_size": self.TENSOR_SIZE, 
            "method_name": self.METHOD_NAME,
            "mask_size": self.MASK_SIZE,
            "dimensions": self.DIMENSIONS,
            "stride": self.STRIDE,
            "n_position": self.N_POSITION,
            "drop_p": self.DROP_P,
        }



In [None]:
import torch.nn as nn

class OnDeviceMNISTModel(nn.Module):
    """Small CNN classifier trained independently by each on-device client.

    Two 3x3 conv/BN/ReLU stages (spatial size preserved by padding=1) are
    followed by one 2x2 max-pool. The classifier expects 32*7*7 features,
    which corresponds to a 1x14x14 input patch. The attribute names
    (``features``, ``classifier``) and layer order define the ``state_dict``
    keys used by saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.name = "OnDeviceMNISTModel"

        conv_layers = [
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # -> 16x14x14
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> 32x14x14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 32x7x7
        ]
        self.features = nn.Sequential(*conv_layers)

        flat_dim = 32 * 7 * 7
        head_layers = [
            nn.Linear(flat_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 10),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images to 10 unnormalised class scores."""
        feats = self.features(x)
        batch = feats.shape[0]
        return self.classifier(feats.view(batch, -1))  # flatten all but the batch dim

In [None]:
import torch
from torch.utils.data import DataLoader, Subset
from federated_inference.common.environment import Member
from torch.utils.data import Dataset as TorchDataset
from collections.abc import Iterable
import numpy as np
import os
import pandas as pd

class EarlyStopper():
    """Patience-based early stopping driven by a scalar validation loss.

    Feed one loss per epoch by calling the instance. The first call only
    records the loss as ``best_loss``. After that, a loss that is *not*
    greater than ``best_loss - min_delta`` counts as an improvement: it
    becomes the new best and resets ``counter``. Any other loss increments
    ``counter``, and ``early_stop`` flips to True once ``counter`` reaches
    ``patience``.
    """

    def __init__(self, patience = 20, min_delta =  0.00001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        threshold = self.best_loss - self.min_delta
        if not val_loss > threshold:
            # Improved by at least min_delta: remember it and restart patience.
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        
class OnDeviceVerticalClient():
    """A vertically-partitioned client that trains and evaluates its own model locally.

    The model, optimizer and criterion come from ``model_config``. ``dataset``
    must expose ``train_dataset`` and ``test_dataset``. Checkpoints are
    written to ``./results/ondevice/<DATASET_NAME>/<seed>/`` whenever the
    validation loss improves.
    """

    def __init__(self, 
            idx, 
            seed,
            model_config: OnDeviceModelConfiguration,
            data_config: DataConfiguration,
            dataset: TorchDataset,
            labels,
            log: bool = True, 
            log_interval: int = 100,
            save_interval: int = 20
        ):
        """Store the configuration and pull the model/optimizer/criterion out of ``model_config``.

        Args:
            idx: client index, used in checkpoint file names.
            seed: run seed, used in the results directory.
            labels: display names of the classes (confusion-matrix labels).
            log: when True, keep loss/accuracy histories and print progress.
            log_interval: print a training line every this many batches.
            save_interval: record a training loss every this many batches
                (only when ``log`` is True).
        """
        self.idx = idx
        self.seed = seed
        self.data = dataset
        self.data_config = data_config
        self.labels = labels
        self.numerical_labels = range(len(labels))
        self.member_type = Member.CLIENT
        self.model_config = model_config
        self.n_epoch = model_config.N_EPOCH
        self.device = model_config.DEVICE
        self.model = model_config.MODEL
        self.optimizer = model_config.OPTIMIZER
        self.criterion = model_config.CRITERION
        self.costs = []

        self.log = log
        self.log_interval = log_interval
        self.save_interval = save_interval

        if self.log: 
            self.train_losses = []
            self.test_losses = []
            self.accuracies = []

    def select_subset(self, ids: Iterable[int], set_type: str = "train"):
        """Return a ``Subset`` of the test set (``set_type == "test"``) or else of the train set."""
        if set_type == "test":
            return Subset(self.data.test_dataset, ids)
        else: 
            return Subset(self.data.train_dataset, ids)

    def _to_loader(self, trainset, testset, batch_size_train, batch_size_val, batch_size_test, train_shuffle, val_shuffle, test_shuffle, train_ratio):
        """Build the initial train/val/test loaders.

        The first ``train_ratio`` fraction of ``trainset`` (in its stored
        order) is used for training and the rest for validation. The full
        index range is kept in ``train_set_indices`` for later reshuffles.
        """
        self.train_set_indices = np.arange(len(trainset))
        split = round(train_ratio * len(trainset))
        traindata = Subset(trainset, range(split))
        valdata = Subset(trainset, range(split, len(trainset)))
        self.trainloader = DataLoader(traindata, batch_size=batch_size_train, shuffle=train_shuffle)
        self.valloader = DataLoader(valdata, batch_size=batch_size_val, shuffle=val_shuffle)
        self.testloader = DataLoader(testset, batch_size=batch_size_test, shuffle=test_shuffle)

    def shuffle_loader(self, trainset, batch_size_train, batch_size_val, train_ratio):
        """Redraw the train/val split by shuffling ``train_set_indices`` in place.

        Uses NumPy's global RNG. The resulting loaders do not shuffle again;
        the order is already randomised by the index permutation.
        """
        np.random.shuffle(self.train_set_indices)
        dataset_length = len(self.train_set_indices)
        train_end = round(train_ratio * dataset_length)
        train_indices = self.train_set_indices[:train_end]
        val_indices = self.train_set_indices[train_end:]
        traindata = Subset(trainset, train_indices)
        valdata = Subset(trainset, val_indices)
        self.valloader = DataLoader(valdata, batch_size=batch_size_val, shuffle=False) 
        self.trainloader = DataLoader(traindata, batch_size=batch_size_train, shuffle=False) 

    def _pred_loader(self, testset, batch_size_test, test_shuffle):
        """Return a ``DataLoader`` over ``testset`` for prediction."""
        return DataLoader(testset, batch_size=batch_size_test, shuffle=test_shuffle)

    def _evaluate(self, loader):
        """Return ``(mean loss per sample, number of correct predictions)`` over ``loader``.

        Each batch loss is multiplied by the batch size before dividing by the
        dataset length. This assumes the criterion averages over the batch,
        as the default ``CrossEntropyLoss`` does; confirm if ``set_criterion``
        installs something else.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in loader:
                data = data.to(self.device).float()
                target = target.to(self.device).long()
                output = self.model(data)
                total_loss += self.criterion(output, target).item() * data.size(0)
                correct += output.argmax(dim=1).eq(target).sum().item()
        return total_loss / len(loader.dataset), correct

    def train(self, epoch):
        """Run one training epoch, then validate and checkpoint on improvement.

        Requires ``run_training`` to have created ``self.early_stopper`` and
        the loaders.
        """
        self.model.train()
        for batch_idx, (data, target) in enumerate(self.trainloader):
            data = data.to(self.device).float()
            target = target.to(self.device).long()

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # train_losses only exists when logging is enabled (see __init__).
            if self.log and batch_idx % self.save_interval == 0:
                self.train_losses.append(loss.item())
            if self.log and batch_idx % self.log_interval == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(self.trainloader.dataset)} '
                    f'({100. * batch_idx / len(self.trainloader):.0f}%)]\tLoss: {loss.item():.6f}')

        val_loss, _ = self._evaluate(self.valloader)

        # Save before updating the stopper so "improved" compares against the previous best.
        if self.early_stopper.best_loss is None or val_loss < self.early_stopper.best_loss:
            print("Validation loss improved. Saving model...")
            self.save()
        self.early_stopper(val_loss)

    def test(self):
        """Evaluate on the test loader, print the result and (if logging) record it."""
        test_loss, correct = self._evaluate(self.testloader)
        n_test = len(self.testloader.dataset)
        accuracy = 100. * correct / n_test
        if self.log: 
            self.test_losses.append(test_loss)
            self.accuracies.append(accuracy)
        print(f'\nTest set: Average loss per sample: {test_loss:.4f}, Accuracy: {correct}/{n_test} '
            f'({accuracy:.0f}%)\n')

    def run_training(self):
        """Train for up to ``N_EPOCH`` epochs with early stopping.

        The model is tested once before training and after every epoch. The
        train/val split is reshuffled between epochs. If early stopping
        triggers, ``early_stop_epoch`` is set and the loop ends.
        """
        trainset = self.data.train_dataset
        testset = self.data.test_dataset
        self.early_stopper = EarlyStopper()
        self._to_loader(trainset, testset, 
            self.model_config.BATCH_SIZE_TRAIN,
            self.model_config.BATCH_SIZE_VAL, 
            self.model_config.BATCH_SIZE_TEST, 
            self.model_config.TRAIN_SHUFFLE,
            self.model_config.VAL_SHUFFLE, 
            self.model_config.TEST_SHUFFLE,
            self.model_config.TRAIN_RATIO)
        self.test()
        for epoch in range(1, self.model_config.N_EPOCH + 1):
            self.train(epoch)
            self.test()
            if self.early_stopper.early_stop:
                self.early_stop_epoch = epoch
                print("early_stop_triggered")
                break
            self.shuffle_loader(trainset, 
                self.model_config.BATCH_SIZE_TRAIN,
                self.model_config.BATCH_SIZE_VAL, 
                self.model_config.TRAIN_RATIO)

    def _checkpoint_paths(self):
        """Return ``(result_dir, model_path, optimizer_path)`` for this client and seed."""
        result_path = f"./results/ondevice/{self.data_config.DATASET_NAME}/{self.seed}"
        model_path = os.path.join(result_path, f'model_client_{self.idx}.pth').replace("\\", "/")
        optimizer_path = os.path.join(result_path, f'optimizer_client_{self.idx}.pth').replace("\\", "/")
        return result_path, model_path, optimizer_path

    def save(self):
        """Write the model and optimizer state dicts to this client's checkpoint files."""
        result_path, model_path, optimizer_path = self._checkpoint_paths()
        os.makedirs(result_path, exist_ok=True)
        torch.save(self.model.state_dict(), model_path)
        torch.save(self.optimizer.state_dict(), optimizer_path)

    def load(self):
        """Restore the model and optimizer state dicts written by ``save``.

        ``map_location`` remaps tensors onto this client's device, so
        checkpoints saved on a GPU also load on a CPU/MPS-only machine.
        """
        _, model_path, optimizer_path = self._checkpoint_paths()
        network_state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(network_state_dict)
        optimizer_state_dict = torch.load(optimizer_path, map_location=self.device)
        self.optimizer.load_state_dict(optimizer_state_dict)

    def pred(self):
        """Return the predicted class index for every test sample, in dataset order.

        Shuffling is always off here, so the list lines up with
        ``test_dataset.targets`` as required by ``check``.
        """
        testset = self.data.test_dataset
        testloader = self._pred_loader(testset, self.model_config.BATCH_SIZE_TEST, False)
        self.model.eval()
        predictions = []
        with torch.no_grad():
            for data, _ in testloader:
                data = data.to(self.device).float()
                output = self.model(data)
                # reshape(-1) rather than squeeze(): a 1-sample batch must still yield a list.
                predictions.extend(output.argmax(dim=1).reshape(-1).tolist())
        return predictions

    def check(self, predicted_labels,  pred_all: bool= True):
        """Score ``predicted_labels`` against ``test_dataset.targets`` and store the results.

        Sets ``self.accuracy``, macro-averaged ``self.precision`` and
        ``self.recall``, and ``self.cm`` (a confusion-matrix DataFrame
        labelled with ``self.labels``). Predictions must be in test-set order,
        as returned by ``pred``. When ``pred_all`` is False nothing is
        computed.
        """
        from sklearn.metrics import accuracy_score, precision_score, recall_score,  confusion_matrix
        if not pred_all:
            return
        true_labels = self.data.test_dataset.targets
        accuracy = accuracy_score(true_labels, predicted_labels)
        precision = precision_score(true_labels, predicted_labels, average='macro')  # or 'weighted'
        recall = recall_score(true_labels, predicted_labels, average='macro')

        print("\n=== Metrics ===")
        print(f"Accuracy : {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall   : {recall:.4f}")
        cm = confusion_matrix(true_labels, predicted_labels, labels=self.numerical_labels)
        self.cm = pd.DataFrame(cm, index=[f'True {l}' for l in self.labels],
                                columns=[f'Pred {l}' for l in self.labels])
        self.accuracy = accuracy 
        self.precision = precision 
        self.recall = recall

In [None]:

import numpy as np
from federated_inference.simulations.simulation import Simulation
from federated_inference.common.environment import TransformType, DataMode  
import torch.nn as nn
import os
import json
import pandas as pd
from federated_inference.simulations.utils import *

class OnDeviceVerticalSimulation(Simulation): 
    """On-device baseline for vertically partitioned data: each client trains its own model.

    The dataset is loaded and split into per-client views by the inherited
    ``load_data`` / ``transform_data`` helpers (defined on ``Simulation``,
    not shown here). Each view gets an independent ``OnDeviceVerticalClient``
    with its own ``OnDeviceModelConfiguration(model)``. Results are stored
    under ``./results/ondevice/<DATASET_NAME>/<seed or name>/``.
    """
    def __init__(self, 
                 data_config: DataConfiguration, 
                 transform_config: DataTransformConfiguration, 
                 seed: int, 
                 model: nn.Module,
                 transform_type:  TransformType = TransformType.FULL_STRIDE_PARTITION, 
                 exist = False):
        """Build one client per partition.

        Args:
            model: the model *class*; each client instantiates its own copy.
            exist: if True, after construction load every client's checkpoint
                and the saved ``simulation.json`` results (if present).
        """
        self.data_config = data_config
        self.transform_config = transform_config
        self.seed = seed
        self.data_mode = DataMode.VERTICAL
        self.transform_type = transform_type
        self.dataset =  self.load_data(data_config)
        self.client_datasets, self.transformation = self.transform_data(self.dataset, data_mode = self.data_mode, transform_config = transform_config, transform_type = self.transform_type)
        self.clients = [OnDeviceVerticalClient(idx, 
                                               seed, 
                                               OnDeviceModelConfiguration(model), 
                                               data_config,
                                               dataset, 
                                               data_config.LABELS) for idx, dataset in enumerate(self.client_datasets)]

        if exist:
            for client in self.clients:
                client.load()
            self.load()

    def train(self): 
        """Train every client independently, one after another."""
        for client in self.clients:
            client.run_training()

    def test(self):
        """Predict on the test set with every client and compute its metrics."""
        for client in self.clients:
            predictions = client.pred()
            client.check(predictions)

    def load(self):
        """Restore saved per-client results from ``simulation.json`` if the file exists.

        Sets ``self.results`` and each matching client's ``cm``,
        ``train_losses`` and ``test_losses``. A missing file is silently
        ignored.
        """
        from io import StringIO
        result_path = f"./results/ondevice/{self.data_config.DATASET_NAME}/{self.seed}/simulation.json"
        if os.path.isfile(result_path):
            with open(result_path, 'r') as f:
                results = json.load(f)

            self.results = results 
            for client_data in results['clients']:
                idx = client_data['idx']
                client = next((c for c in self.clients if c.idx == idx), None)
                if client:
                    client.cm = pd.read_json(StringIO(client_data['cm']), orient='split')
                    client.train_losses = client_data['training_losses']
                    client.test_losses = client_data['test_losses']
                else:
                    print(f"[Warning] No client with idx={idx} found in self.clients.")

    def to_json(self):
        """Serialise the data and transformation configuration to a JSON string.

        The key ``"transoformation_config"`` keeps its historical spelling,
        since consumers may depend on it.
        """
        simulation_data = {
            "configs" : {
                "data": self.data_config.__dict__(),
                "data_mode" : self.data_mode.value,
                "transformation": {
                    "transformation_type": self.transform_type.value,
                    "transoformation_config": self.transform_config.__dict__()
                }
            }

        }
        return json.dumps(simulation_data)

    def collect_results(self, name: str, save: bool = True, figures: bool = False):
        """Gather every client's metrics into ``self.results``, optionally plotting and saving.

        Args:
            name: stored as ``results['seed']`` and used as the results subdirectory.
            save: write ``simulation.json`` when True.
            figures: display diagnostic figures when True.
        """
        from IPython.display import display

        self.results = {
            'seed': name,
            'clients': []
        }

        if figures:
            fig = create_simulation_image_subplots(self)
            display(fig)

        for client in self.clients:
            client_result = self._gather_client_results(client, figures)
            self.results['clients'].append(client_result)

        if save:
            self._save_results(name)
        return self.results

    def _gather_client_results(self, client, figures):
        """Build the result dict for one client (requires ``client.check`` to have run)."""
        result = {
            'idx': client.idx,
            'cm': client.cm.to_json(orient='split'),
            'training_losses': client.train_losses,
            'test_losses': client.test_losses
        }

        analysis, df_cm_per = cm_analysis(client)
        result['cm_analysis'] = analysis

        indices = [
            analysis['correct']['most_correct_class'],
            analysis['wrong']['most_misclassified_class']
        ]
        indices += [i for i in [analysis['wrong']['wrong_from'], analysis['wrong']['wrong_to']] if i not in indices]

        # Performance metrics
        result.update({
            'accuracy': analysis["performance"]["accuracy"],
            'precision': analysis["performance"]["precision"],
            'recall': analysis["performance"]["recall"]
        })

        # Generate and display figures if needed
        if figures:
            # Import locally: the import in collect_results is function-scoped
            # and does not reach this method outside IPython's builtins.
            from IPython.display import display
            display(plot_test_loss(client.test_losses, 1, client.idx, "Test"))
            fig, selected_indices = create_client_image_subplots(self, [client.idx], 8, keys=indices)
            display(fig)
            result['client_image_subplots_ids'] = selected_indices
            display(print_cm_heat(df_cm_per, client.idx))

        return result

    def _save_results(self, name):
        """Write ``self.results`` to ``./results/ondevice/<DATASET_NAME>/<name>/simulation.json``."""
        result_path = f"./results/ondevice/{self.data_config.DATASET_NAME}/{name}"
        os.makedirs(result_path, exist_ok=True)
        with open(os.path.join(result_path, "simulation.json"), "w") as f:
            json.dump(self.results, f, indent=4)
        print("Results saved to JSON.")



In [None]:
# Train one on-device vertical simulation per seed. set_seed(seed) is called
# first so the seed that names the results directory also seeds the torch /
# NumPy / Python RNGs (model initialisation, the per-epoch np.random.shuffle
# of the train/val split); previously the RNGs were never seeded.
for seed in [1]:
    set_seed(seed)
    data_config = DataConfiguration('MNIST')
    transform_config = DataTransformConfiguration()
    simulation = OnDeviceVerticalSimulation(data_config, transform_config, seed, OnDeviceMNISTModel, exist=False)
    simulation.train()
    #simulation.test()
    #on_device_result = simulation.collect_results(seed, save = True, figures=False)
