In [None]:
%pip install zmq web3

In [2]:
import copy
import pickle

import zmq
from eth_account import Account
from web3 import Web3

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.hashes import SHA256

In [3]:
# JSON-RPC connection to the node at 127.0.0.1:9944.
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:9944"))

# Contract address and ABI are read from plain-text files.
# .strip() drops a trailing newline (common in files written by editors or
# `echo`), which would otherwise fail web3's address validation.
with open("contract/address", "r") as file:
    address = file.read().strip()

with open("contract/abi", "r") as file:
    abi = file.read()

contract_instance = w3.eth.contract(address=address, abi=abi)

In [4]:
# Load the server's Ethereum signing key (stored as plain text; keep this file
# out of version control). submitAggregateModelHash sends from this account.
with open("keys/server", "r") as file:
    # .strip() removes a trailing newline, which is not a hex digit and would
    # make Account.from_key fail to parse the key.
    private_key = file.read().strip()

account = Account.from_key(private_key)
sender_address = account.address

def submitAggregateModelHash(modelHash: str, round: int):
    """Record an aggregated-model hash on-chain via ``submitGlobalModelHash``.

    Builds the transaction from ``sender_address``, signs it with the
    module-level ``private_key``, sends it, and blocks until it is mined.

    Args:
        modelHash: Hash string of the aggregated model, passed unchanged to
            the contract.
        round: Training-round number the hash belongs to.

    Returns:
        The transaction receipt. Check ``receipt["status"]`` to confirm the
        call succeeded; a revert is also reported with a printed warning.
    """
    transaction = contract_instance.functions.submitGlobalModelHash(modelHash, round).build_transaction({
        'from': sender_address,
        'nonce': w3.eth.get_transaction_count(sender_address)
    })

    signed_transaction = w3.eth.account.sign_transaction(transaction, private_key)

    tx_hash = w3.eth.send_raw_transaction(signed_transaction.raw_transaction)

    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)

    print(tx_receipt)

    # A reverted transaction is still mined and produces a receipt; status 0
    # means the contract call had no effect, so the hash was NOT recorded.
    if tx_receipt.get('status') == 0:
        print(f"WARNING: submitGlobalModelHash reverted for round {round}; "
              f"hash {modelHash} was not recorded (tx {tx_hash.hex()}).")

    return tx_receipt

In [5]:
def encrypt(message, public_key_path="keys/public.pem"):
    """Encrypt ``message`` with RSA-OAEP (MGF1-SHA256, SHA-256 digest).

    Args:
        message: Plaintext bytes. RSA-OAEP with SHA-256 can encrypt at most
            (key size in bytes - 66) bytes, so this suits small payloads,
            not whole serialized models.
        public_key_path: Path of the PEM-encoded public key to load.

    Returns:
        The ciphertext as bytes.
    """
    with open(public_key_path, "rb") as f:
        public_key = serialization.load_pem_public_key(
            f.read(),
            backend=default_backend()
        )

    return public_key.encrypt(
        message,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=SHA256()),
            algorithm=SHA256(),
            label=None
        )
    )

In [6]:
def decrypt(encrypted_message, private_key_path="keys/private.pem", password=None):
    """Decrypt RSA-OAEP (MGF1-SHA256, SHA-256 digest) ciphertext.

    Args:
        encrypted_message: Ciphertext bytes, as produced with the matching
            public key and the same OAEP parameters.
        private_key_path: Path of the PEM-encoded private key to load.
        password: Passphrase bytes if the PEM file is encrypted, else None.

    Returns:
        The decrypted plaintext bytes.
    """
    with open(private_key_path, "rb") as f:
        # Named rsa_private_key so it is not confused with the module-level
        # Ethereum `private_key` used for signing transactions.
        rsa_private_key = serialization.load_pem_private_key(
            f.read(),
            password=password,
            backend=default_backend()
        )

    return rsa_private_key.decrypt(
        encrypted_message,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=SHA256()),
            algorithm=SHA256(),
            label=None
        )
    )

In [None]:


class FederatedServer:
    """ZeroMQ REP server that collects client model updates and averages them.

    Each request payload is unpickled and buffered. Once ``num_clients``
    updates are buffered they are averaged, the keccak hash of the pickled
    average is submitted on-chain for the current round, and the pickled
    average is sent as the reply to the request that completed the batch.
    Every other request is answered with ``b"ACK"``.
    """

    def __init__(self, address="tcp://*:9003", num_clients=2, start_round=1):
        """Create the ZeroMQ context and bind a REP socket.

        Args:
            address: ZeroMQ endpoint to bind.
            num_clients: Number of updates to collect before aggregating.
            start_round: Round number submitted with the first aggregate;
                incremented after every aggregation.

        Raises:
            ValueError: if ``num_clients`` is less than 1.
        """
        if num_clients < 1:
            raise ValueError("num_clients must be at least 1")
        self.num_clients = num_clients
        self.current_round = start_round
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)  # REPLY socket
        try:
            self.socket.bind(address)
        except zmq.ZMQError:
            # Don't leak the context/socket if the port is already taken.
            self.close()
            raise

    def aggregate_models(self, model_list):
        """Return the element-wise mean of a list of model state dicts.

        Keys are taken from the first model; every other model must provide
        them too. Values must support ``+`` and division by an int (floats,
        NumPy arrays, tensors, ...). The input models are not modified, and
        the result has the same mapping type as the first model.

        Raises:
            ValueError: if ``model_list`` is empty.
        """
        if not model_list:
            raise ValueError("aggregate_models() needs at least one model")
        first, *rest = model_list
        count = len(model_list)
        # Shallow copy keeps the mapping type (and attributes such as a torch
        # state_dict's _metadata) while leaving the caller's dict untouched.
        averaged = copy.copy(first)
        for key, value in first.items():
            total = value
            for model in rest:
                # Out-of-place `+` so the clients' arrays/tensors are not mutated.
                total = total + model[key]
            averaged[key] = total / count
        return averaged

    def run(self):
        """Serve requests forever, aggregating every ``num_clients`` updates.

        The socket and context are closed when the loop exits (e.g. on a
        KeyboardInterrupt from interrupting the notebook kernel), so a new
        server can bind the same port.
        """
        print("Server is running...")
        model_updates = []
        try:
            while True:
                message = self.socket.recv()
                # SECURITY: pickle.loads runs arbitrary code embedded in the
                # payload. Only expose this endpoint to trusted TEEs, or move to
                # a data-only serialization format.
                client_update = pickle.loads(message)
                print("Received model update from TEE.")
                model_updates.append(client_update)

                if len(model_updates) < self.num_clients:
                    # A REP socket must reply before the next recv().
                    self.socket.send(b"ACK")  # Acknowledge reception
                    continue

                print("Aggregating models...")
                aggregated_model = self.aggregate_models(model_updates)

                aggregated_model_bytes = pickle.dumps(aggregated_model)
                aggregated_model_hash = Web3.keccak(aggregated_model_bytes).hex()
                submitAggregateModelHash(aggregated_model_hash, self.current_round)
                # NOTE(review): REP replies only to the request just received, so
                # only the client that completed the batch gets the aggregate;
                # earlier senders only received b"ACK".
                self.socket.send(aggregated_model_bytes)

                print("Aggregated model sent to TEE.")
                self.current_round += 1
                model_updates = []  # Reset after aggregation
        finally:
            self.close()

    def close(self):
        """Close the socket (dropping unsent messages) and terminate the context."""
        self.socket.close(linger=0)
        if not self.context.closed:
            self.context.term()

if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is also "__main__", so running this cell
    # starts the server and blocks until the kernel is interrupted.
    server = FederatedServer(address="tcp://*:9003")
    server.run()
