In [1]:
import random
import torch
import decimal
from decimal import Decimal
from typing import List, Tuple, Dict, Set
import numpy as np
import time
import torch.nn as nn
from sklearn.datasets import load_iris, load_diabetes
from torch.utils.data import Dataset, Subset, DataLoader
import torch.nn.functional as F
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import copy
import itertools

# Exclusive upper bound for the random polynomial coefficients and the share
# x-coordinates drawn in coeff() and generate_shares().
# NOTE(review): nothing in the visible code reduces values modulo FIELD_SIZE, so
# shares are plain (unbounded) Python integers rather than finite-field elements.
FIELD_SIZE = 2147483647  # A large prime (here 2^31-1) can be used
# Float weights are multiplied by this factor and truncated to int before being
# shared; reconstructed values are divided by it again.
scale_factor = 10**8  # Increase scale factor for better precision

# Prefer the GPU when one is available; the choice is printed so it is recorded
# in the cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def polynom(x, coefficients):
    """
    Evaluate a polynomial at a single x-coordinate.

    The coefficient list is ordered from the highest-degree term down to the
    constant term, so ``coefficients[-1]`` is the value of the polynomial at 0.

    Args:
        x (int): The x-coordinate
        coefficients (List[int]): Coefficients, highest degree first
    Returns:
        int: The y-coordinate for the given x (0 when the list is empty)
    """
    # Walk the terms from the constant upwards so the position is the exponent.
    constant_first = reversed(coefficients)
    return sum(x ** exponent * term for exponent, term in enumerate(constant_first))


def coeff(t: int, secret: int) -> List[int]:
    """
    Build the coefficient list of a random polynomial that hides ``secret``.

    Args:
        t (int): Threshold for reconstruction (degree + 1)
        secret (int): The secret value to be shared

    Returns:
        List[int]: ``t - 1`` random coefficients drawn from ``[0, FIELD_SIZE)``
        followed by ``secret`` as the constant (last) term
    """
    random_terms = []
    for _ in range(t - 1):
        random_terms.append(random.randrange(0, FIELD_SIZE))
    # The secret goes last, matching polynom()'s highest-degree-first layout.
    return [*random_terms, secret]


def generate_shares(n: int, m: int, secret: int) -> List[Tuple[int, int]]:
    """
    Split ``secret`` into 'n' shares of a (m, n)-threshold scheme.

    A random polynomial with ``secret`` as its constant term is built via
    ``coeff`` and evaluated at ``n`` distinct random x-coordinates.

    Args:
        n (int): Total number of shares
        m (int): Minimum shares required for reconstruction
        secret (int): The secret value to be shared

    Returns:
        List[Tuple[int, int]]: A list of (x, y) share pairs
    """
    polynomial = coeff(m, secret)
    taken_x = set()
    points = []

    while len(points) < n:
        candidate_x = random.randrange(1, FIELD_SIZE)
        if candidate_x in taken_x:
            # Redraw: every share needs its own x-coordinate.
            continue
        taken_x.add(candidate_x)
        points.append((candidate_x, polynom(candidate_x, polynomial)))

    return points


# def reconstruct_secret(shares):
#     sums = Decimal(0)
#     for j, share_j in enumerate(shares):
#         xj, yj = share_j
#         prod = Decimal(1)
#         valid_share = True
#         for i, share_i in enumerate(shares):
#             xi, _ = share_i
#             if i != j:
#                 try:
#                     prod *= Decimal(xi) / Decimal(xi - xj)
#                 except (ZeroDivisionError, decimal.DivisionByZero):
#                     valid_share = False
#                     break
#         if valid_share:
#             prod *= yj
#             sums += prod
#     return float(sums)  # Use float to preserve the fractional part  
def reconstruct_secret(shares: List[Tuple[int, float]], device='cpu') -> float:
    """
    Reconstructs a secret from shares using Lagrange interpolation at x = 0.

    The interpolation is carried out in exact rational arithmetic. Share
    x-coordinates are drawn from a range of about 2^31 and the y-values are
    far larger, which is well beyond what a float32 (24-bit significand) can
    represent, so a floating-point evaluation does not recover the secret
    reliably.

    Args:
        shares (List[Tuple[int, float]]): List of (x, y) points.
        device (str): Accepted for backward compatibility and ignored; the
            computation always runs on the CPU in pure Python.

    Returns:
        float: Reconstructed secret (0.0 when ``shares`` is empty).

    Raises:
        ValueError: If two shares have the same x-coordinate, which makes the
            interpolation undefined.
    """
    from fractions import Fraction  # local import keeps the block self-contained

    x_vals = [Fraction(s[0]) for s in shares]
    y_vals = [Fraction(s[1]) for s in shares]

    if len(set(x_vals)) != len(x_vals):
        raise ValueError("Cannot interpolate: shares contain duplicate x-coordinates")

    secret = Fraction(0)
    for i, (xi, yi) in enumerate(zip(x_vals, y_vals)):
        # Lagrange basis polynomial l_i evaluated at x = 0.
        li = Fraction(1)
        for j, xj in enumerate(x_vals):
            if i != j:
                li *= (0 - xj) / (xi - xj)
        secret += yi * li

    return float(secret)


def train_model(model, dataloader, epochs=10, lr=0.01):
    """
    Fit ``model`` in place with Adam on a mean-squared-error objective.

    Only the final epoch's sample-weighted average loss is printed.

    Args:
        model      : PyTorch model to train.
        dataloader : Iterable yielding (inputs, targets) batches.
        epochs     : Number of passes over ``dataloader``.
        lr         : Learning rate for Adam.
    Returns:
        The same ``model`` object, after training.
    """
    model.train()
    loss_fn = nn.MSELoss()
    adam = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch_number in range(1, epochs + 1):
        loss_sum = 0.0
        samples_seen = 0

        for batch_inputs, batch_targets in dataloader:
            batch_size = batch_inputs.size(0)

            adam.zero_grad()
            batch_loss = loss_fn(model(batch_inputs), batch_targets)
            batch_loss.backward()
            adam.step()

            # Weight the batch mean by its size so the epoch figure is per-sample.
            loss_sum += batch_loss.item() * batch_size
            samples_seen += batch_size

        mean_loss = loss_sum / samples_seen
        if epoch_number == epochs:
            print(f"Epoch {epoch_number}/{epochs}  Loss: {mean_loss:.4f}")

    return model

Using device: cpu


In [2]:
# Settings
num_clients = 8  # number of simulated clients (one buffer is created per client)
threshold = 3  # minimum number of shares combined to reconstruct a value
delay_constant = 1  # seconds slept before each client shares its model
# Attack configuration - specify which clients' shares should be tampered
attacked_clients = [4,6,9,1]  # IDs of compromised clients; NOTE(review): 9 is outside range(num_clients), so it matches no client
tampering_factor = 0.5  # How much to tamper (as a fraction of scale_factor)
tampering_probability = 1.0  # Probability of tampering with a specific share (1.0 = tamper with all shares)


# Function to partition the dataset into 'num_clients' subsets
def partition_dataset(dataset, num_clients):
    """
    Randomly split ``dataset`` into ``num_clients`` disjoint ``Subset`` objects.

    Every subset gets ``len(dataset) // num_clients`` samples; the last one
    additionally receives whatever remains after the even split.
    """
    total = len(dataset)
    shuffled = np.random.permutation(total)  # random assignment of samples to clients
    chunk = total // num_clients
    last_client = num_clients - 1

    return [
        Subset(
            dataset,
            shuffled[client * chunk: total if client == last_client else (client + 1) * chunk],
        )
        for client in range(num_clients)
    ]


class DiabetesNet(nn.Module):
    """
    Fully connected regressor: input_dim -> 128 -> 64 -> 32 -> 1.

    The three hidden layers use ReLU; the output layer is linear. The default
    ``input_dim`` of 10 matches the Diabetes dataset's feature count.
    """
    def __init__(self, input_dim=10):
        super().__init__()
        # Attribute names fc1..fc4 double as the state-dict key prefixes.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 1)  # Regression output

    def forward(self, x):
        """Return the network output of shape (batch, 1) for a batch ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)


class DiabetesDataset(Dataset):
    """
    In-memory dataset of float32 feature rows and targets.

    Targets get an extra trailing axis (via ``unsqueeze(1)``) so that each
    sample's target has the same shape as the model's one-value output.
    """
    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.float32)
        self.X = features
        self.y = targets.unsqueeze(1)

    def __len__(self):
        """Number of samples (rows of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair at ``idx``."""
        sample = (self.X[idx], self.y[idx])
        return sample


class Client:
    """
    One participant in the simulated secret-sharing round.

    A client splits every entry of its model's weight matrices into secret
    shares, hands one share bundle to each other client, and later uploads the
    bundles it holds to the buffer of the client that owns them.

    Relies on the module-level settings ``num_clients``, ``threshold``,
    ``scale_factor``, ``attacked_clients``, ``tampering_factor`` and
    ``tampering_probability``.
    """

    def __init__(self, client_id: int, model: nn.Module) -> None:
        self.client_id = client_id
        self.model = model
        self.shares = {}  # owner_id -> share bundle received from that owner
        self.own_shares = {}  # own client_id -> the bundle this client kept for itself

        # Only parameters whose name contains "weight" are shared; biases are not.
        # NOTE: for a CPU model, Tensor.numpy() returns arrays that share memory
        # with the parameters, so in-place training after construction is still
        # reflected here. For a model on another device, .cpu() makes a copy and
        # these arrays would stay at their construction-time values.
        self.weights = [
            param.cpu().numpy()
            for name, param in self.model.state_dict().items()
            if "weight" in name
        ]

    def get_model_shares(self) -> Dict[int, Dict[str, list]]:
        """
        Generate secret shares of every weight for all clients (including self).

        Returns:
            Dict mapping recipient_id to ``{"weights": layers}``, where
            ``layers[layer][row][column]`` is a dict describing one share
            (owner, recipient, position and the ``(x, y)`` share itself).
        """
        shared_model = {recipient_id: {"weights": []} for recipient_id in range(num_clients)}

        for layer_idx, layer_weights in enumerate(self.weights):
            # Open a new (still empty) layer for every recipient.
            for bundle in shared_model.values():
                bundle["weights"].append([])

            for row_idx, row in enumerate(layer_weights):
                # Open a new (still empty) row for every recipient.
                for bundle in shared_model.values():
                    bundle["weights"][layer_idx].append([])

                for col_idx, weight in enumerate(row):
                    # Fixed-point encode the float weight, then split it into
                    # one share per recipient.
                    shares = generate_shares(num_clients, threshold, int(weight * scale_factor))

                    for recipient_id, share in enumerate(shares):
                        shared_model[recipient_id]["weights"][layer_idx][row_idx].append({
                            "owner_id": self.client_id,  # This is the owner of the original weight
                            "recipient_id": recipient_id,  # This is who receives the share
                            "layer": layer_idx,
                            "row": row_idx,
                            "column": col_idx,
                            "share": share
                        })

        return shared_model

    def share_model(self, clients: List["Client"]) -> None:
        """
        Share model updates with all clients including self.

        The client's own bundle is stored in ``own_shares``; every other
        client receives its bundle in ``client.shares[self.client_id]``.
        """
        shared_model = self.get_model_shares()

        # If this client is compromised, tamper with the shares before distributing
        if self.client_id in attacked_clients:
            print(f"ATTACK: Tampering witth shares from Client {self.client_id}")
            shared_model = self.tamper_with_shares(shared_model)

        # Store own shares
        self.own_shares[self.client_id] = shared_model[self.client_id]

        # Share with other clients
        for client in clients:
            recipient_id = client.client_id
            if recipient_id != self.client_id:  # Own bundle is already in own_shares
                client.shares[self.client_id] = shared_model[recipient_id]

    def tamper_with_shares(self, shared_model):
        """
        Tamper with the shares using configurable parameters.

        Each share's y-coordinate is increased by
        ``int(scale_factor * tampering_factor)`` with probability
        ``tampering_probability``. ``shared_model`` is modified in place and
        also returned.
        """
        tampering_amount = int(scale_factor * tampering_factor)
        tampered_count = 0
        total_count = 0

        for bundle in shared_model.values():
            for layer in bundle["weights"]:
                for row in layer:
                    for share_entry in row:
                        total_count += 1
                        # Only tamper with probability defined by tampering_probability
                        if random.random() < tampering_probability:
                            x, y = share_entry["share"]
                            share_entry["share"] = (x, y + tampering_amount)
                            tampered_count += 1

        # Guard the percentage: a model without shareable weights yields zero shares.
        tampered_ratio = tampered_count / total_count if total_count else 0.0
        print(f"ATTACK: Client {self.client_id} has tampered with {tampered_count}/{total_count} shares ({tampered_ratio:.2%})")
        return shared_model

    def upload_to_buffer(self, buffers: List["Buffer"]) -> None:
        """
        Upload the held share bundle to each buffer whose id matches its owner.

        For buffer ``b`` this client contributes the bundle it received from
        owner ``b``, or its own bundle when it *is* client ``b``. Buffers for
        which the client holds nothing are skipped.
        """
        for buffer in buffers:
            owner_id = buffer.buffer_id
            shares_for_buffer = {}

            # Bundle received from the owner this buffer belongs to
            if owner_id in self.shares:
                shares_for_buffer[self.client_id] = self.shares[owner_id]

            # Also add own shares if this client is the owner
            if self.client_id == owner_id and owner_id in self.own_shares:
                shares_for_buffer[self.client_id] = self.own_shares[owner_id]

            if shares_for_buffer:
                buffer.add_shares(shares_for_buffer)


In [3]:
class Buffer:
    """
    Collects the share bundles that clients upload for one owner.

    ``buffer_id`` identifies the owner; ``buffer_shares`` maps the uploading
    client's id to the bundle it contributed.
    """
    def __init__(self, buffer_id: int) -> None:
        self.buffer_id = buffer_id
        self.buffer_shares = dict()  # uploading client_id -> share bundle

    def add_shares(self, shares_data: Dict) -> None:
        """
        Merge ``shares_data`` into the buffer; an existing entry for the same
        client id is overwritten.
        """
        self.buffer_shares.update(shares_data.items())

    def get_shares(self) -> Dict:
        """
        Return the buffer's internal client_id -> bundle dictionary (not a copy).
        """
        return self.buffer_shares


def _group_shares_by_position(buffer_shares):
    """Index one buffer's bundles as ``{(layer, row, col): {client_id: share}}``."""
    grouped = {}
    for client_id, client_data in buffer_shares.items():
        for layer_idx, layer in enumerate(client_data["weights"]):
            for row_idx, row in enumerate(layer):
                for col_idx, share_entry in enumerate(row):
                    position = (layer_idx, row_idx, col_idx)
                    grouped.setdefault(position, {})[client_id] = share_entry["share"]
    return grouped


def _vote_on_position(position, client_shares, threshold):
    """Reconstruct one value from every threshold-sized client subset and return the most frequent result, or None if every subset failed."""
    reconstruction_votes = []

    for combo in itertools.combinations(client_shares.keys(), threshold):
        shares_for_reconstruction = [client_shares[client_id] for client_id in combo]
        try:
            rec_val = reconstruct_secret(shares_for_reconstruction) / scale_factor
            # Round so that reconstructions differing only by float noise
            # count as the same vote.
            reconstruction_votes.append(round(rec_val, 5))
        except Exception as e:
            # One bad subset must not abort the vote; report it and move on.
            print(f"  - Reconstruction failed for combo {combo}: {str(e)}")

    if not reconstruction_votes:
        print(f" Failed to reconstruct for position {position}")
        return None

    # most_common() keeps first-seen order among ties.
    winner, _count = Counter(reconstruction_votes).most_common(1)[0]
    return winner


def _assemble_layers(winners):
    """Turn ``{(layer, row, col): value}`` into a list of zero-filled float32 matrices ordered by layer index."""
    shapes = {}
    for layer_idx, row_idx, col_idx in winners:
        rows, cols = shapes.get(layer_idx, (0, 0))
        shapes[layer_idx] = (max(rows, row_idx + 1), max(cols, col_idx + 1))

    arrays = {
        layer_idx: np.zeros(shape, dtype=np.float32)
        for layer_idx, shape in shapes.items()
    }
    for (layer_idx, row_idx, col_idx), value in winners.items():
        arrays[layer_idx][row_idx, col_idx] = value

    return [arrays[layer_idx] for layer_idx in sorted(arrays)]


def reconstruct_weights_with_voting(buffers: List["Buffer"], threshold: int) -> Dict[int, List[np.ndarray]]:
    """
    Reconstructs models buffer-wise using threshold-compliant shares from contributing clients
    with a resilience mechanism to handle compromised shares through majority voting.

    For every weight position, the value is reconstructed once per
    ``threshold``-sized subset of the clients that supplied a share, each
    result is rounded to 5 decimals, and the most frequent result is kept.
    Positions with fewer than ``threshold`` shares, or for which every
    reconstruction failed, are left at 0 in the output matrices.

    Args:
        buffers: Buffers exposing ``buffer_id`` and ``get_shares()``.
        threshold: Number of shares combined per reconstruction.

    Returns:
        Dictionary mapping buffer IDs to reconstructed models (one float32
        matrix per layer). Buffers with no shares, or for which nothing could
        be reconstructed, are omitted.
    """
    reconstructed_models = {}

    for buffer in buffers:
        print(f"\n{'='*40}\nReconstructing Buffer {buffer.buffer_id}\n{'='*40}")
        buffer_shares = buffer.get_shares()

        if not buffer_shares:
            print(f"No shares available for Buffer {buffer.buffer_id}")
            continue

        winners = {}
        for position, client_shares in _group_shares_by_position(buffer_shares).items():
            if len(client_shares) < threshold:
                continue  # not enough shares to reconstruct this position

            winner = _vote_on_position(position, client_shares, threshold)
            if winner is not None:
                winners[position] = winner

        final_model = _assemble_layers(winners)
        if final_model:
            reconstructed_models[buffer.buffer_id] = final_model

    return reconstructed_models

In [4]:
class Server:
    """
    Central aggregator: collects client state dicts and evaluates models.
    """
    def __init__(self) -> None:
        # State dicts submitted by clients, in arrival order.
        self.model_weights = []
        # Result of the most recent aggregation (None until fedSGD runs).
        self.aggregated_model = None

    def fedSGD(self, learning_rate=0.01):
        """
        Average the submitted state dicts and subtract ``learning_rate`` times
        that average from the aggregated model.

        On the first call the aggregated model starts as a deep copy of the
        first submitted state dict; later calls keep updating the stored one.

        NOTE(review): the submitted tensors are applied as if they were
        gradients. The visible caller appends reconstructed *weights*, so the
        result is "first client's weights minus lr * mean weights" — confirm
        this is the intended aggregation rule.

        Args:
            learning_rate (float): Step size applied to the averaged tensors.

        Returns:
            dict: The updated aggregated state dict (same object as
            ``self.aggregated_model``).

        Raises:
            ValueError: If no state dicts have been submitted.
        """
        if not self.model_weights:
            raise ValueError("No model weights received from clients.")

        submissions = self.model_weights
        first_state = submissions[0]
        num_clients = len(submissions)

        # Per-key running sum over all submissions, then divided into a mean.
        mean_update = {}
        for key, template in first_state.items():
            running = torch.zeros_like(template)
            for state in submissions:
                running += state[key]
            running /= num_clients
            mean_update[key] = running

        if self.aggregated_model is None:
            self.aggregated_model = copy.deepcopy(first_state)

        for key in self.aggregated_model.keys():
            self.aggregated_model[key] -= learning_rate * mean_update[key]  # SGD-style step

        return self.aggregated_model

    def evaluate_rmse(self, model, loader, device="cpu"):
        """
        Root-mean-squared error of ``model`` over every sample in ``loader``.

        Predictions and targets are flattened per batch before comparison.
        """
        model.to(device).eval()
        predicted_batches = []
        target_batches = []
        with torch.no_grad():
            for features, targets in loader:
                features = features.to(device)
                targets = targets.to(device)
                predicted_batches.append(model(features).cpu().numpy().flatten())
                target_batches.append(targets.cpu().numpy().flatten())
        residuals = np.concatenate(predicted_batches) - np.concatenate(target_batches)
        return np.sqrt(np.mean(residuals**2))


    def evaluate_accuracy(self, model, test_loader, device="cpu"):
        """
        Fraction of samples whose highest-scoring output index equals the label.

        Args:
            model: The PyTorch model to be evaluated.
            test_loader: Iterable of (inputs, labels) batches.
            device: Device to run the model on (CPU/GPU).

        Returns:
            float: Accuracy of the model on the test dataset.
        """
        model.to(device)
        model.eval()

        hits = 0
        seen = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                predicted = torch.max(model(inputs), 1)[1]  # index of the max score per row
                hits += (predicted == labels).sum().item()
                seen += labels.size(0)

        return hits / seen



def send_reconstructed_weights_to_server(reconstructed_model: Dict[int, List[np.ndarray]], server: "Server"):
    """
    Transfers the reconstructed model updates to the server and performs FedSGD.

    Each client's list of reconstructed layer matrices is turned into a state
    dict with keys ``fc1.weight``, ``fc1.bias``, ``fc2.weight``, ... (one pair
    per layer, in list order). Biases are not part of the reconstruction, so
    every bias is a zero vector sized to its layer's output dimension.

    Args:
        reconstructed_model (dict): Mapping of buffer_id to a list of reconstructed layers (NumPy arrays).
        server (Server): The server instance that will aggregate the model updates.

    Returns:
        dict: Updated model state dictionary after applying FedSGD.
    """
    for client_id, layers in reconstructed_model.items():
        state_dict = {}

        # Layer k of the list maps onto the model's "fc{k}" linear layer (1-based).
        for position, layer in enumerate(layers, start=1):
            weight = torch.tensor(layer, dtype=torch.float32)
            state_dict[f"fc{position}.weight"] = weight
            state_dict[f"fc{position}.bias"] = torch.zeros(weight.shape[0], dtype=torch.float32)

        # Add the state dict to the server's list
        server.model_weights.append(state_dict)
        print(f"Transferred reconstructed weights for Client {client_id}.")

    # Perform FedSGD aggregation
    updated_state = server.fedSGD()

    print("Updated Model Weights via FedSGD:")
    for key, value in updated_state.items():
        print(f"  {key}:\n{value}\n")

    return updated_state







# def send_reconstructed_weights_to_server(reconstructed_model: Dict[int, List[np.ndarray]], server: Server):
#     """
#     Transfers the reconstructed model updates to the server and performs FedAvg.

#     Args:
#         reconstructed_model (dict): Mapping of buffer_id to a list of reconstructed layers (NumPy arrays).
#         server (Server): The server instance that will aggregate the model updates.

#     Returns:
#         dict: Aggregated model state dictionary.
#     """
#     for client_id, layers in reconstructed_model.items():
#         state_dict = {}

#         # Map layer names to match the IrisNet model
#         state_dict["fc1.weight"] = torch.tensor(layers[0], dtype=torch.float32)  # First layer weights
#         state_dict["fc1.bias"] = torch.zeros(state_dict["fc1.weight"].shape[0], dtype=torch.float32)  # Bias initialized to zeros
#         state_dict["fc2.weight"] = torch.tensor(layers[1], dtype=torch.float32)  # Second layer weights
#         state_dict["fc2.bias"] = torch.zeros(state_dict["fc2.weight"].shape[0], dtype=torch.float32)  # Bias initialized to zeros
#         state_dict["fc3.weight"] = torch.tensor(layers[2], dtype=torch.float32)  # First layer weights
#         state_dict["fc3.bias"] = torch.zeros(state_dict["fc3.weight"].shape[0], dtype=torch.float32)  # Bias initialized to zeros
#         state_dict["fc4.weight"] = torch.tensor(layers[3], dtype=torch.float32)  # First layer weights
#         state_dict["fc4.bias"] = torch.zeros(state_dict["fc4.weight"].shape[0], dtype=torch.float32)  # Bias initialized to zeros
        
#         # Add the state dict to the server's list
#         server.model_weights.append(state_dict)
#         print(f"Transferred reconstructed weights for Client {client_id}.")

#     # Perform FedAvg aggregation
#     aggregated_state = server.fedSGD()

#     print("Aggregated Model Weights via FedAvg:")
#     for key, value in aggregated_state.items():
#         print(f"  {key}:\n{value}\n")

#     return aggregated_state

In [None]:
import numpy as np
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset, DataLoader
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# ---------- Data Augmentation ----------
def augment_data(X: np.ndarray, y: np.ndarray, factor: int = 2, noise_std: float = 0.05):
    """
    Enlarge a dataset with Gaussian-jittered copies of the features.

    Returns the original rows followed by ``factor`` noisy copies of ``X``
    (zero-mean noise with standard deviation ``noise_std``), together with
    the targets ``y`` repeated once per block.
    """
    noisy_copies = [X + np.random.normal(0, noise_std, X.shape) for _ in range(factor)]
    feature_blocks = [X, *noisy_copies]
    # Targets are unchanged by the jitter, so each block reuses y as-is.
    target_blocks = [y] * len(feature_blocks)
    return np.vstack(feature_blocks), np.hstack(target_blocks)

# ---------- Main Routine ----------
def main(seed=None):
    """
    Run the full attack simulation end to end.

    Steps: load and scale the Diabetes data, partition the training set across
    ``num_clients`` clients, train each client locally, secret-share the
    trained weights (attacked clients tamper with their shares), reconstruct
    every client's weights with majority voting, aggregate on the server and
    report the global model's RMSE on the (scaled) test targets.

    Args:
        seed: Optional integer. When given, Python's ``random``, NumPy and
            PyTorch are seeded with it so a run can be repeated. The default
            (None) leaves all generators untouched.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    print(f"Starting attack simulation with {num_clients} clients")
    print(f" Clients {attacked_clients} will have their shares tampered")
    print(f" Using threshold {threshold} for secret sharing")

    # Load the Diabetes dataset from sklearn.
    diabetes = load_diabetes()
    X, y = diabetes.data, diabetes.target

    # Split dataset into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42)

    # === AUGMENTATION STEP ===
    # NOTE(review): the augmented arrays are only used for the size printout
    # below; training further down uses the original X_train / y_train.
    print(f"Original training samples: {X_train.shape[0]}")
    X_train_aug, y_train_aug = augment_data(
        X_train, y_train, factor=2, noise_std=0.1)
    print(f"Augmented training samples: {X_train_aug.shape[0]}")
    # =========================

    # 1) Fit the scalers on the training split only
    scaler_X = StandardScaler().fit(X_train)
    scaler_y = StandardScaler().fit(y_train.reshape(-1,1))

    # 2) Transform both train and test
    X_train_scaled = scaler_X.transform(X_train)
    X_test_scaled  = scaler_X.transform(X_test)

    y_train_scaled = scaler_y.transform(y_train.reshape(-1,1)).flatten()
    y_test_scaled  = scaler_y.transform(y_test.reshape(-1,1)).flatten()

    # 3) Wrap the scaled arrays as PyTorch datasets
    train_dataset = DiabetesDataset(X_train_scaled, y_train_scaled)
    test_dataset = DiabetesDataset(X_test_scaled,  y_test_scaled)

    # Partition data for clients (the last client receives any remainder)
    partitions = partition_dataset(train_dataset, num_clients)
    for i, partition in enumerate(partitions):
        print(f"Client partition {i} size: {len(partition)}")

    # Initialize clients and buffers
    clients = [Client(i, DiabetesNet(input_dim=X.shape[1])) for i in range(num_clients)]
    buffers = [Buffer(i) for i in range(num_clients)]

    # Local Training: Train each client on its partition
    for i, client in enumerate(clients):
        print(f"Training Client {client.client_id}")
        train_loader = DataLoader(
            partitions[i], batch_size=8, shuffle=True)
        train_model(client.model, train_loader, epochs=20, lr=0.01)

    # Clients share their models with each other
    for client in clients:
        time.sleep(delay_constant)  # simulated delay before each client shares
        client.share_model(clients)

    # Clients upload their shares to buffers
    for client in clients:
        client.upload_to_buffer(buffers)

    try:
        # Reconstruct weights for each buffer using voting-based resilience
        print("\n Starting resilient reconstruction with majority voting...")
        reconstructed_model = reconstruct_weights_with_voting(buffers, threshold)

        # Print or process the reconstructed weights
        for buffer_id, layers in reconstructed_model.items():
            print(f"Reconstructed Weights for Buffer/Client {buffer_id}:")
            for i, layer in enumerate(layers):
                print(f"Layer {i} shape: {layer.shape}")
                print(f"Sample weights: {layer.flatten()[:5]}...\n")
    except ValueError as e:
        print("Error:", e)
        return  # Exit early if reconstruction fails

    print("Now transferring the reconstructed model updates to the server for aggregation...")

    # Create a server instance
    server = Server()

    # Transfer the reconstructed model updates to the server and run its FedSGD step
    aggregated_state = send_reconstructed_weights_to_server(reconstructed_model, server)
    global_model = DiabetesNet(input_dim=X.shape[1])
    global_model.load_state_dict(aggregated_state)
    print("\n Aggregated Global Model Weights (Layer-wise):\n")
    for name, param in aggregated_state.items():
        print(f"Layer: {name}")
        print(f"Shape: {param.shape}")
        print(f"Sample weights: {param.flatten()[:5].tolist()}...\n")
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, shuffle=False)

    # RMSE is measured against the standardised test targets.
    rmse = server.evaluate_rmse(global_model, test_loader)
    print(f"Global Model RMSE: {rmse:.4f}")
if __name__ == "__main__":
    main()


Starting attack simulation with 8 clients
 Clients [4, 6, 9, 1] will have their shares tampered
 Using threshold 3 for secret sharing
Original training samples: 353
Augmented training samples: 1059
Client partition 0 size: 44
Client partition 1 size: 44
Client partition 2 size: 44
Client partition 3 size: 44
Client partition 4 size: 44
Client partition 5 size: 44
Client partition 6 size: 44
Client partition 7 size: 45
Training Client 0
Epoch 20/20  Loss: 0.0693
Training Client 1
Epoch 20/20  Loss: 0.0207
Training Client 2
Epoch 20/20  Loss: 0.0763
Training Client 3
Epoch 20/20  Loss: 0.1011
Training Client 4
Epoch 20/20  Loss: 0.0556
Training Client 5
Epoch 20/20  Loss: 0.0406
Training Client 6
Epoch 20/20  Loss: 0.0269
Training Client 7
Epoch 20/20  Loss: 0.0136
ATTACK: Tampering witth shares from Client 1
ATTACK: Client 1 has tampered with 92416/92416 shares (100.00%)
ATTACK: Tampering witth shares from Client 4
ATTACK: Client 4 has tampered with 92416/92416 shares (100.00%)
ATTACK: 