In [1]:
%pip install zmq web3

Note: you may need to restart the kernel to use updated packages.


In [2]:
import zmq
import pickle
from web3 import Web3
from eth_account import Account

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.backends import default_backend

In [3]:
from web3 import Web3
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:9944"))
address = "0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9"
abi = '[{"inputs": [{"internalType": "address","name": "_admin","type": "address"},{"internalType": "address","name": "_server","type": "address"}],"stateMutability": "nonpayable","type": "constructor"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "oldAdmin","type": "address"},{"indexed": true,"internalType": "address","name": "newAdmin","type": "address"}],"name": "AdminTransferred","type": "event"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"}],"name": "ClientDeregistered","type": "event"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"}],"name": "ClientRegistered","type": "event"},{"inputs": [{"internalType": "address","name": "_client","type": "address"}],"name": "deregisterClient","outputs": [],"stateMutability": "nonpayable","type": "function"},{"anonymous": false,"inputs": [{"indexed": false,"internalType": "string","name": "modelHash","type": "string"},{"indexed": false,"internalType": "uint256","name": "round","type": "uint256"}],"name": "GlobalModelSubmitted","type": "event"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"},{"indexed": false,"internalType": "string","name": "modelHash","type": "string"},{"indexed": false,"internalType": "uint256","name": "round","type": "uint256"}],"name": "LocalModelSubmitted","type": "event"},{"inputs": [{"internalType": "address","name": "_client","type": "address"}],"name": "registerClient","outputs": [],"stateMutability": "nonpayable","type": "function"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "newServer","type": "address"}],"name": "ServerUpdated","type": "event"},{"inputs": [{"internalType": "string","name": "_modelHash","type": "string"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "submitGlobalModelHash","outputs": 
[],"stateMutability": "nonpayable","type": "function"},{"inputs": [{"internalType": "string","name": "_modelHash","type": "string"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "submitLocalModelHash","outputs": [],"stateMutability": "nonpayable","type": "function"},{"inputs": [{"internalType": "address","name": "_newAdmin","type": "address"}],"name": "transferAdmin","outputs": [],"stateMutability": "nonpayable","type": "function"},{"inputs": [{"internalType": "address","name": "_newServer","type": "address"}],"name": "updateServer","outputs": [],"stateMutability": "nonpayable","type": "function"},{"inputs": [],"name": "admin","outputs": [{"internalType": "address","name": "","type": "address"}],"stateMutability": "view","type": "function"},{"inputs": [{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "getGlobalModelHash","outputs": [{"internalType": "string","name": "","type": "string"}],"stateMutability": "view","type": "function"},{"inputs": [{"internalType": "address","name": "_client","type": "address"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "getLocalModelHashAtRound","outputs": [{"internalType": "string","name": "","type": "string"}],"stateMutability": "view","type": "function"},{"inputs": [],"name": "server","outputs": [{"internalType": "address","name": "","type": "address"}],"stateMutability": "view","type": "function"}]'

contract_instance = w3.eth.contract(address=address, abi=abi)

In [4]:
# Load the server's Ethereum signing key and derive the sender address used by
# `submitAggregateModelHash`.
# NOTE(review): the key lives in a plaintext file under `keys/`; keep that
# directory out of version control.
with open("keys/server", "r") as file:
    # Strip surrounding whitespace: a trailing newline in the key file would
    # otherwise be passed verbatim to `Account.from_key` and to transaction
    # signing, where the hex-key parser may reject it.
    private_key = file.read().strip()

account = Account.from_key(private_key)
sender_address = account.address

def submitAggregateModelHash(modelHash: str, round: int):
    """Record the hash of an aggregated (global) model on-chain.

    Builds a ``submitGlobalModelHash(modelHash, round)`` transaction on
    ``contract_instance``, signs it with the server key ``private_key``, sends
    it from ``sender_address`` and blocks until its receipt is available.

    Args:
        modelHash: Hash string of the aggregated model.
        round: Round number stored alongside the hash. The name shadows the
            builtin ``round`` inside this function; it is kept unchanged so
            keyword callers keep working.

    Returns:
        The transaction receipt, which is also printed. A receipt whose
        ``status`` is 0 means the call reverted; a warning is printed in that
        case so the failure is not silent.
    """
    transaction = contract_instance.functions.submitGlobalModelHash(modelHash, round).build_transaction({
        'from': sender_address,
        # Nonce queried from the node when the transaction is built.
        'nonce': w3.eth.get_transaction_count(sender_address)
    })

    signed_transaction = w3.eth.account.sign_transaction(transaction, private_key)

    tx_hash = w3.eth.send_raw_transaction(signed_transaction.raw_transaction)

    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)

    print(tx_receipt)
    # Waiting for the receipt does not raise for a mined-but-reverted
    # transaction, so surface it explicitly.
    if tx_receipt.get("status") == 0:
        print(f"WARNING: submitGlobalModelHash reverted for round {round} (tx {tx_hash.hex()})")
    return tx_receipt

In [5]:
def encrypt(message, public_key_path="keys/public.pem"):
    """Encrypt ``message`` with an RSA public key using OAEP (MGF1-SHA256, SHA-256).

    Args:
        message: Plaintext bytes. RSA-OAEP can only encrypt payloads smaller
            than the key size minus the OAEP padding overhead; large payloads
            (for example, serialized models) need hybrid encryption instead.
        public_key_path: PEM file holding the public key. Defaults to the
            previously hard-coded ``"keys/public.pem"``.

    Returns:
        The ciphertext returned by the key's ``encrypt`` method.
    """
    # The key is re-read from disk on every call.
    with open(public_key_path, "rb") as key_file:
        public_key = serialization.load_pem_public_key(
            key_file.read(),
            backend=default_backend()
        )

    return public_key.encrypt(
        message,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=SHA256()),
            algorithm=SHA256(),
            label=None
        )
    )

In [6]:
def decrypt(encrypted_message, private_key_path="keys/private.pem", password=None):
    """Decrypt an RSA-OAEP (MGF1-SHA256, SHA-256) ciphertext.

    Args:
        encrypted_message: Ciphertext bytes produced with the matching public
            key and the same OAEP parameters.
        private_key_path: PEM file holding the private key. Defaults to the
            previously hard-coded ``"keys/private.pem"``.
        password: Passphrase for an encrypted PEM key as bytes, or ``None``
            for an unencrypted key (the previous behaviour).

    Returns:
        The recovered plaintext bytes.
    """
    # Named `rsa_private_key` so it is not confused with the module-level
    # Ethereum `private_key` used for transaction signing.
    with open(private_key_path, "rb") as key_file:
        rsa_private_key = serialization.load_pem_private_key(
            key_file.read(),
            password=password,
            backend=default_backend()
        )

    return rsa_private_key.decrypt(
        encrypted_message,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=SHA256()),
            algorithm=SHA256(),
            label=None
        )
    )

In [None]:


class FederatedServer:
    """ZeroMQ REP server that collects pickled model updates and averages them.

    Each request carries one pickled model: a mapping whose values support
    ``+`` and ``/``. Once ``num_clients`` updates have arrived:

    - they are averaged;
    - the pickled average is hashed with Keccak-256 and recorded on-chain via
      ``submitAggregateModelHash``;
    - the pickled average is sent as the reply to the request that completed
      the batch.

    Every other request is answered with ``b"ACK"``.

    NOTE(review): because this is a REQ/REP socket, only the client whose
    update completes a batch receives the aggregated model. Earlier clients
    only get ``b"ACK"``; confirm how they obtain the aggregate.
    """

    def __init__(self, address="tcp://*:9003", num_clients=2, start_round=1):
        """Create a ZeroMQ context and bind a REP socket on ``address``.

        Args:
            address: ZeroMQ endpoint to bind.
            num_clients: Number of updates to collect before aggregating
                (previously hard-coded to 2).
            start_round: Round number recorded on-chain for the first
                aggregation; it is incremented after every aggregation.

        Raises:
            ValueError: if ``num_clients`` is less than 1.
        """
        if num_clients < 1:
            raise ValueError("num_clients must be at least 1")
        self.num_clients = num_clients
        self.current_round = start_round
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)  # REPLY socket
        self.socket.bind(address)

    def close(self):
        """Close the socket and terminate the context, releasing the bound port."""
        # linger=0 discards any unsent reply instead of blocking shutdown.
        if not self.socket.closed:
            self.socket.close(linger=0)
        if not self.context.closed:
            self.context.term()

    def aggregate_models(self, model_list):
        """Aggregate model weights by averaging.

        Args:
            model_list: Non-empty sequence of mappings sharing the keys of the
                first mapping.

        Returns:
            A shallow copy of the first mapping (same type) whose values are
            the element-wise mean over all models. The input models are not
            modified.

        Raises:
            ValueError: if ``model_list`` is empty.
        """
        import copy

        if not model_list:
            raise ValueError("model_list must contain at least one model")
        first = model_list[0]
        count = len(model_list)
        # copy.copy keeps the container type (and attributes such as an
        # OrderedDict's instance dict), so no client's update is mutated.
        averaged = copy.copy(first)
        for key in first.keys():
            # Out-of-place `+` and `/` build new values. This also avoids the
            # error in-place true division raises on integer tensors.
            total = sum((model[key] for model in model_list[1:]), first[key])
            averaged[key] = total / count
        return averaged

    def run(self):
        """Serve requests forever, aggregating every ``num_clients`` updates.

        Never returns normally; it exits only via an exception such as a
        kernel interrupt.
        """
        print("Server is running...")
        model_updates = []

        while True:
            message = self.socket.recv()
            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload. Only expose this endpoint to trusted clients.
            client_update = pickle.loads(message)
            print("Received model update from TEE.")

            model_updates.append(client_update)

            if len(model_updates) >= self.num_clients:
                print("Aggregating models...")
                aggregated_model = self.aggregate_models(model_updates)

                aggregated_model_bytes = pickle.dumps(aggregated_model)
                # NOTE(review): whether `.hex()` includes a "0x" prefix depends
                # on the installed hexbytes version; clients verifying the hash
                # must format it the same way.
                aggregated_model_hash = Web3.keccak(aggregated_model_bytes).hex()
                # Record each aggregation under its own round (previously always 1).
                # The reply below is delayed until the transaction receipt arrives.
                submitAggregateModelHash(aggregated_model_hash, self.current_round)
                self.current_round += 1
                self.socket.send(aggregated_model_bytes)

                print("Aggregated model sent to TEE.")
                model_updates = []  # Reset after aggregation
            else:
                self.socket.send(b"ACK")  # Acknowledge reception

if __name__ == "__main__":
    server = FederatedServer()
    try:
        server.run()
    finally:
        # Release the port so re-running this cell after an interrupt does not
        # fail with "Address already in use".
        server.close()


Server is running...


  return torch.load(io.BytesIO(b))


Received model update from TEE.
Received model update from TEE.
Aggregating models...
AttributeDict({'transactionHash': HexBytes('0x7dd50fe291a76873183053b4286fbbe9c39508136e8665f7e7ef108fbd9b4160'), 'transactionIndex': 0, 'blockHash': HexBytes('0x8cabbf627c4ba8a76423db93c8394fd137c1b0b8f0593305bcd98223b40990f6'), 'from': '0xa0BBC1A2b77d499102e8836d7EBc12A8D4B351D4', 'to': '0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9', 'blockNumber': 3481, 'cumulativeGasUsed': 103474, 'gasUsed': 103474, 'contractAddress': None, 'logs': [AttributeDict({'address': '0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9', 'topics': [HexBytes('0x11c5701841d606d3ead07e128db0b6d300a6204eaf51cee07dcd443ec5811d76')], 'data': HexBytes('0x000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000040393964383330633536343064376136323132363035306163376637643462343834646665366264353934333861613866653