In [1]:
%pip install zmq web3 cryptography

Note: you may need to restart the kernel to use updated packages.


In [2]:
import zmq
from web3 import Web3

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.backends import default_backend

In [3]:
def encrypt(message, public_key_path="keys/public.pem"):
    """Encrypt ``message`` with the RSA public key stored in a PEM file.

    Uses RSA-OAEP with MGF1(SHA-256), SHA-256 and no label. ``decrypt`` uses
    the same parameters, so its output can be decrypted there with the matching
    private key.

    Args:
        message: Plaintext as ``bytes``. A ``str`` is also accepted and is
            encoded as UTF-8 before encryption.
        public_key_path: Path to the PEM-encoded public key. Defaults to the
            original location, ``keys/public.pem``.

    Returns:
        The ciphertext as ``bytes``.

    Note:
        RSA-OAEP only encrypts short messages. With SHA-256 the limit is
        ``key_size_in_bytes - 2 * 32 - 2`` bytes, which is 190 bytes for a
        2048-bit key, and ``cryptography`` rejects longer input. Larger payloads
        need a hybrid scheme, e.g. RSA-encrypting a symmetric key.
        NOTE(review): assumes the PEM holds an RSA key; other key types have no
        ``encrypt`` method.
    """
    if isinstance(message, str):
        # cryptography only accepts bytes-like plaintext.
        message = message.encode("utf-8")

    with open(public_key_path, "rb") as key_file:
        public_key = serialization.load_pem_public_key(
            key_file.read(),
            backend=default_backend()
        )

    return public_key.encrypt(
        message,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=SHA256()),
            algorithm=SHA256(),
            label=None
        )
    )

In [4]:
def decrypt(encrypted_message, private_key_path="keys/private.pem", password=None):
    """Decrypt an RSA-OAEP ciphertext with the private key stored in a PEM file.

    The padding must match ``encrypt``: OAEP with MGF1(SHA-256), SHA-256 and
    no label.

    Args:
        encrypted_message: Ciphertext ``bytes``, e.g. as produced by ``encrypt``.
        private_key_path: Path to the PEM-encoded private key. Defaults to the
            original location, ``keys/private.pem``.
        password: Passphrase (``bytes``) protecting the PEM, or ``None`` for an
            unencrypted key file (the original behaviour).

    Returns:
        The decrypted plaintext as ``bytes``.

    Raises:
        ValueError: Raised by ``cryptography`` when the ciphertext cannot be
            decrypted with this key, or when the key file/password is invalid.
    """
    with open(private_key_path, "rb") as key_file:
        private_key = serialization.load_pem_private_key(
            key_file.read(),
            password=password,
            backend=default_backend()
        )

    return private_key.decrypt(
        encrypted_message,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=SHA256()),
            algorithm=SHA256(),
            label=None
        )
    )

In [5]:
# Connect to the JSON-RPC node and bind the model-hash registry contract
# (client registration plus per-round local/global model hash submission).
# `Web3` comes from the imports cell; it is not re-imported here.
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:9944"))
address = "0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9"

# Contract ABI as a JSON string, built from one implicitly-concatenated literal
# per ABI entry so no single string literal has to span multiple lines.
abi = (
    '['
    '{"inputs": [{"internalType": "address","name": "_admin","type": "address"},{"internalType": "address","name": "_server","type": "address"}],"stateMutability": "nonpayable","type": "constructor"},'
    '{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "oldAdmin","type": "address"},{"indexed": true,"internalType": "address","name": "newAdmin","type": "address"}],"name": "AdminTransferred","type": "event"},'
    '{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"}],"name": "ClientDeregistered","type": "event"},'
    '{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"}],"name": "ClientRegistered","type": "event"},'
    '{"inputs": [{"internalType": "address","name": "_client","type": "address"}],"name": "deregisterClient","outputs": [],"stateMutability": "nonpayable","type": "function"},'
    '{"anonymous": false,"inputs": [{"indexed": false,"internalType": "string","name": "modelHash","type": "string"},{"indexed": false,"internalType": "uint256","name": "round","type": "uint256"}],"name": "GlobalModelSubmitted","type": "event"},'
    '{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"},{"indexed": false,"internalType": "string","name": "modelHash","type": "string"},{"indexed": false,"internalType": "uint256","name": "round","type": "uint256"}],"name": "LocalModelSubmitted","type": "event"},'
    '{"inputs": [{"internalType": "address","name": "_client","type": "address"}],"name": "registerClient","outputs": [],"stateMutability": "nonpayable","type": "function"},'
    '{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "newServer","type": "address"}],"name": "ServerUpdated","type": "event"},'
    '{"inputs": [{"internalType": "string","name": "_modelHash","type": "string"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "submitGlobalModelHash","outputs": [],"stateMutability": "nonpayable","type": "function"},'
    '{"inputs": [{"internalType": "string","name": "_modelHash","type": "string"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "submitLocalModelHash","outputs": [],"stateMutability": "nonpayable","type": "function"},'
    '{"inputs": [{"internalType": "address","name": "_newAdmin","type": "address"}],"name": "transferAdmin","outputs": [],"stateMutability": "nonpayable","type": "function"},'
    '{"inputs": [{"internalType": "address","name": "_newServer","type": "address"}],"name": "updateServer","outputs": [],"stateMutability": "nonpayable","type": "function"},'
    '{"inputs": [],"name": "admin","outputs": [{"internalType": "address","name": "","type": "address"}],"stateMutability": "view","type": "function"},'
    '{"inputs": [{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "getGlobalModelHash","outputs": [{"internalType": "string","name": "","type": "string"}],"stateMutability": "view","type": "function"},'
    '{"inputs": [{"internalType": "address","name": "_client","type": "address"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "getLocalModelHashAtRound","outputs": [{"internalType": "string","name": "","type": "string"}],"stateMutability": "view","type": "function"},'
    '{"inputs": [],"name": "server","outputs": [{"internalType": "address","name": "","type": "address"}],"stateMutability": "view","type": "function"}'
    ']'
)

contract_instance = w3.eth.contract(address=address, abi=abi)

In [None]:
class TEE:
    """Relay messages between a client and a server over ZeroMQ REQ/REP sockets.

    A REP socket is bound at ``client_address`` for the client, and a REQ
    socket is connected to the server at ``server_address``. Each client
    message is forwarded unchanged to the server. The server's reply is then
    sent back to the client, strictly alternating as the REQ/REP pattern
    requires.

    ``close()`` releases the sockets and the ZeroMQ context. ``run()`` calls it
    automatically when it stops, and the object also works as a context
    manager. Releasing the bound port matters in a notebook: while the old
    socket stays open, re-running the cell would typically fail with "Address
    already in use".

    Args:
        client_address: Endpoint the client-facing REP socket binds to.
        server_address: Endpoint of the server the REQ socket connects to.
    """

    def __init__(self, client_address="tcp://*:9002", server_address="tcp://127.0.0.1:9003"):
        self.context = zmq.Context()
        try:
            # Client communication: REPLY to client requests.
            self.client_socket = self.context.socket(zmq.REP)
            self.client_socket.bind(client_address)

            # Server communication: REQUEST to the server.
            self.server_socket = self.context.socket(zmq.REQ)
            self.server_socket.connect(server_address)
        except Exception:
            # Don't leak the context (and any socket already created) if the
            # bind or connect fails, e.g. because the port is still in use.
            self.context.destroy(linger=0)
            raise

    def forward_to_server(self, data):
        """Forward the model received from the client to the server.

        Sends ``data`` on the REQ socket and blocks until the server replies.

        Returns:
            The server's raw reply bytes.
        """
        print("Forwarding model to server...")
        self.server_socket.send(data)
        return self.server_socket.recv()  # Receive response from server

    def forward_to_client(self, data):
        """Forward the aggregated model from the server to the client.

        Must follow a ``client_socket.recv()``: a REP socket only allows one
        send per received request.
        """
        print("Forwarding aggregated model to client...")
        self.client_socket.send(data)

    def close(self, linger=0):
        """Close both sockets and terminate the ZeroMQ context.

        Args:
            linger: Milliseconds to keep unsent messages before discarding
                them. ``0`` drops them immediately so shutdown never blocks.

        pyzmq's ``Context.destroy`` returns early on an already-closed context,
        so calling this more than once is harmless.
        """
        self.context.destroy(linger=linger)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never suppress exceptions

    def run(self):
        """Relay client requests to the server until interrupted.

        Blocks forever in the receive loop. An interrupt (KeyboardInterrupt,
        e.g. the notebook's stop button) ends the loop cleanly. The sockets are
        closed on any exit, so the ports can be bound again afterwards.
        """
        print("TEE is running...")
        try:
            while True:
                # Receive model from client
                client_message = self.client_socket.recv()
                print("Received model from client.")

                # Forward to server and get the response (aggregated model or ACK)
                server_response = self.forward_to_server(client_message)

                # Send the server response back to the client
                self.forward_to_client(server_response)
        except KeyboardInterrupt:
            print("TEE stopped.")
        finally:
            self.close()

if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so this cell starts the
    # relay right away and keeps the kernel busy in run()'s receive loop.
    tee_processor = TEE(
        client_address="tcp://*:9002",
        server_address="tcp://127.0.0.1:9003",
    )
    tee_processor.run()


TEE is running...
Received model from client.
Forwarding model to server...
Forwarding aggregated model to client...
Received model from client.
Forwarding model to server...
Forwarding aggregated model to client...
Received model from client.
Forwarding model to server...
Forwarding aggregated model to client...
Received model from client.
Forwarding model to server...
Forwarding aggregated model to client...
Received model from client.
Forwarding model to server...
Forwarding aggregated model to client...
