In [1]:
%pip install torch torchvision zmq web3 cryptography

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import zmq
import pickle
from web3 import Web3
from eth_account import Account

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.backends import default_backend

In [3]:
def encrypt(message, public_key_path="keys/public.pem"):
    """Encrypt ``message`` with a PEM-encoded public key using OAEP padding.

    The padding is OAEP with MGF1 over SHA-256, SHA-256 as the hash and no
    label; ``decrypt`` must use the same parameters to recover the plaintext.

    Args:
        message: Bytes to encrypt.
            NOTE(review): OAEP limits the plaintext to well under the key size
            (for an RSA key: key bytes minus 66 with SHA-256) -- confirm callers
            only pass short messages.
        public_key_path: Path of the PEM public key file. Defaults to the
            location that was previously hard-coded, so existing callers are
            unaffected.

    Returns:
        The ciphertext as bytes.
    """
    # The key is re-read on every call so a rotated key file is picked up.
    with open(public_key_path, "rb") as f:
        public_key = serialization.load_pem_public_key(
            f.read(),
            backend=default_backend()
        )

    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=SHA256()),
        algorithm=SHA256(),
        label=None
    )

    return public_key.encrypt(message, oaep)

In [4]:
def decrypt(encrypted_message, private_key_path="keys/private.pem"):
    """Decrypt ``encrypted_message`` with a PEM-encoded private key using OAEP padding.

    The padding is OAEP with MGF1 over SHA-256, SHA-256 as the hash and no
    label, matching what ``encrypt`` uses.

    Args:
        encrypted_message: Ciphertext bytes to decrypt.
        private_key_path: Path of the PEM private key file. Defaults to the
            location that was previously hard-coded, so existing callers are
            unaffected. The key is loaded with ``password=None``, i.e. the file
            is expected to be unencrypted.

    Returns:
        The recovered plaintext as bytes.
    """
    with open(private_key_path, "rb") as f:
        private_key = serialization.load_pem_private_key(
            f.read(),
            password=None,
            backend=default_backend()
        )

    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=SHA256()),
        algorithm=SHA256(),
        label=None
    )

    return private_key.decrypt(encrypted_message, oaep)

In [5]:
# HTTP JSON-RPC connection to a node on localhost port 9944, and the address of
# the contract that `contract_instance` is bound to below.
w3 = Web3(Web3.HTTPProvider("http://127.0.0.1:9944"))
address = "0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9"
abi = '[{"inputs": [{"internalType": "address","name": "_admin","type": "address"},{"internalType": "address","name": "_server","type": "address"}],"stateMutability": "nonpayable","type": "constructor"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "oldAdmin","type": "address"},{"indexed": true,"internalType": "address","name": "newAdmin","type": "address"}],"name": "AdminTransferred","type": "event"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"}],"name": "ClientDeregistered","type": "event"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"}],"name": "ClientRegistered","type": "event"},{"inputs": [{"internalType": "address","name": "_client","type": "address"}],"name": "deregisterClient","outputs": [],"stateMutability": "nonpayable","type": "function"},{"anonymous": false,"inputs": [{"indexed": false,"internalType": "string","name": "modelHash","type": "string"},{"indexed": false,"internalType": "uint256","name": "round","type": "uint256"}],"name": "GlobalModelSubmitted","type": "event"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "client","type": "address"},{"indexed": false,"internalType": "string","name": "modelHash","type": "string"},{"indexed": false,"internalType": "uint256","name": "round","type": "uint256"}],"name": "LocalModelSubmitted","type": "event"},{"inputs": [{"internalType": "address","name": "_client","type": "address"}],"name": "registerClient","outputs": [],"stateMutability": "nonpayable","type": "function"},{"anonymous": false,"inputs": [{"indexed": true,"internalType": "address","name": "newServer","type": "address"}],"name": "ServerUpdated","type": "event"},{"inputs": [{"internalType": "string","name": "_modelHash","type": "string"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "submitGlobalModelHash","outputs": 
[],"stateMutability": "nonpayable","type": "function"},{"inputs": [{"internalType": "string","name": "_modelHash","type": "string"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "submitLocalModelHash","outputs": [],"stateMutability": "nonpayable","type": "function"},{"inputs": [{"internalType": "address","name": "_newAdmin","type": "address"}],"name": "transferAdmin","outputs": [],"stateMutability": "nonpayable","type": "function"},{"inputs": [{"internalType": "address","name": "_newServer","type": "address"}],"name": "updateServer","outputs": [],"stateMutability": "nonpayable","type": "function"},{"inputs": [],"name": "admin","outputs": [{"internalType": "address","name": "","type": "address"}],"stateMutability": "view","type": "function"},{"inputs": [{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "getGlobalModelHash","outputs": [{"internalType": "string","name": "","type": "string"}],"stateMutability": "view","type": "function"},{"inputs": [{"internalType": "address","name": "_client","type": "address"},{"internalType": "uint256","name": "_round","type": "uint256"}],"name": "getLocalModelHashAtRound","outputs": [{"internalType": "string","name": "","type": "string"}],"stateMutability": "view","type": "function"},{"inputs": [],"name": "server","outputs": [{"internalType": "address","name": "","type": "address"}],"stateMutability": "view","type": "function"}]'
# Contract handle built from the address and ABI above; submitLocalModelHash()
# builds its transactions through this object.
contract_instance = w3.eth.contract(address=address, abi=abi)

In [6]:
# Load this client's signing key from disk. The content is stripped so that a
# trailing newline or stray whitespace in the key file does not end up inside
# the hex string handed to Account.from_key() and sign_transaction().
with open("keys/client1", "r") as file:
    private_key = file.read().strip()

# Derive the account from the key; its address is used as the sender ('from')
# and for the nonce lookup in submitLocalModelHash().
account = Account.from_key(private_key)
sender_address = account.address

def submitLocalModelHash(modelHash: str, round: int):
    """Submit a local model hash for a round to the contract and wait for it to be mined.

    Builds a call to the contract's ``submitLocalModelHash`` function, signs it
    with the module-level ``private_key``, broadcasts the raw transaction and
    blocks until a receipt is available. The receipt is printed.

    Args:
        modelHash: Hash string to store on the contract.
        round: Round number the hash belongs to. (The name shadows the builtin
            ``round``; it is kept so keyword callers keep working.)

    Returns:
        The transaction receipt returned by the node.

    Raises:
        RuntimeError: If the receipt reports ``status == 0``, i.e. the
            transaction was mined but reverted, so the hash was not recorded.
    """
    transaction = contract_instance.functions.submitLocalModelHash(modelHash, round).build_transaction({
        'from': sender_address,
        'nonce': w3.eth.get_transaction_count(sender_address)
    })

    signed_transaction = w3.eth.account.sign_transaction(transaction, private_key)

    tx_hash = w3.eth.send_raw_transaction(signed_transaction.raw_transaction)

    tx_receipt = w3.eth.wait_for_transaction_receipt(tx_hash)

    print(tx_receipt)

    # A mined transaction can still have reverted; without this check the
    # caller would go on to send a model whose hash never reached the chain.
    # `.get` is used because a receipt is not guaranteed to carry a status.
    if tx_receipt.get("status") == 0:
        raise RuntimeError(
            f"submitLocalModelHash reverted for round {round} (tx {tx_hash.hex()})"
        )

    return tx_receipt

In [7]:
class SimpleNN(nn.Module):
    """Two-layer fully connected network: 784 inputs -> 128 hidden units (ReLU) -> 10 outputs."""

    def __init__(self):
        super().__init__()
        pixels = 28 * 28
        self.fc1 = nn.Linear(pixels, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Reshape ``x`` into rows of 784 values and return 10 raw scores per row."""
        flat = x.view(-1, 28 * 28)  # one row per sample
        hidden = self.fc1(flat).relu()
        return self.fc2(hidden)

def train(model, data_loader, optimizer, criterion, epochs=1):
    """Fit ``model`` on ``data_loader`` for ``epochs`` full passes.

    For every ``(inputs, labels)`` pair yielded by ``data_loader`` the gradients
    are cleared, the loss from ``criterion`` is back-propagated and
    ``optimizer`` takes one step. The model is left in training mode.
    """
    model.train()
    for _ in range(epochs):
        for batch, targets in data_loader:
            optimizer.zero_grad()
            batch_loss = criterion(model(batch), targets)
            batch_loss.backward()
            optimizer.step()

def send_model_weights(socket, model, round_num):
    """Submit the hash of the pickled weights for ``round_num``, then send the weights.

    The model's state dict is pickled once; the keccak hash of those bytes is
    passed to ``submitLocalModelHash`` before the same bytes go out on ``socket``.
    """
    payload = pickle.dumps(model.state_dict())

    digest_hex = Web3.keccak(payload).hex()
    submitLocalModelHash(digest_hex, round_num)
    socket.send(payload)

def receive_aggregated_model(socket, model):
    """Receive one reply from ``socket`` and load it into ``model`` if it is a state dict.

    Three kinds of reply are distinguished, each reported with a printed line:
    the literal ``b"ACK"``, a pickled state dict (loaded into ``model``), and
    anything that cannot be unpickled.

    Args:
        socket: Object with a ``recv()`` method returning the reply bytes.
        model: Module whose ``load_state_dict`` receives the unpickled weights.

    Returns:
        True if new weights were loaded into ``model``; False for an ACK or an
        unrecognised reply.

    Raises:
        Whatever ``model.load_state_dict`` raises when the unpickled object does
        not match the model; that is not treated as an "unexpected response".
    """
    message = socket.recv()

    # Check the acknowledgment first instead of relying on the unpickler to
    # reject it.
    if message == b"ACK":
        print("Acknowledgment received from server.")
        return False

    try:
        # SECURITY: pickle.loads can execute arbitrary code contained in the
        # payload. This is only acceptable while the peer on the socket is
        # trusted; prefer a data-only format if that ever stops being true.
        new_weights = pickle.loads(message)
    except (pickle.UnpicklingError, EOFError, AttributeError, ImportError, IndexError):
        # Malformed payloads do not always raise UnpicklingError: an empty or
        # truncated reply raises EOFError, and the pickle docs list the others.
        print("Unexpected response from server.")
        return False

    model.load_state_dict(new_weights)
    print("Received aggregated model from server.")
    return True

if __name__ == "__main__":
    # REQ socket: every send must be answered by one recv before the next
    # send, which is the send/receive pairing each round below follows.
    context = zmq.Context()
    socket = context.socket(zmq.REQ)  # REQUEST socket
    socket.connect("tcp://127.0.0.1:9002")  # Connect to TEE

    try:
        # MNIST images as tensors, normalized with mean 0.5 and std 0.5.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

        model = SimpleNN()
        optimizer = optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        # Five rounds: train locally for one epoch, send the weights, then wait
        # for the server's reply. round_num is zero-based when passed on; only
        # the progress message is one-based.
        for round_num in range(5):
            print(f"Training round {round_num + 1}")
            train(model, train_loader, optimizer, criterion, epochs=1)

            print("Sending model to TEE...")
            send_model_weights(socket, model, round_num)

            print("Waiting for aggregated model from TEE...")
            # receive_aggregated_model() reports what actually arrived (model,
            # ACK or unexpected reply). An unconditional "Received aggregated
            # model" line here would duplicate or contradict that report.
            receive_aggregated_model(socket, model)
    finally:
        # Release the socket and context even if a round fails; linger=0 keeps
        # close()/term() from blocking on an unsent message.
        socket.close(linger=0)
        context.term()


Training round 1
Sending model to TEE...
AttributeDict({'transactionHash': HexBytes('0xcb570c36f163d78aeb23f671e1ab8fef5fd084b69a05cbcf6e19319895ec76d0'), 'transactionIndex': 0, 'blockHash': HexBytes('0x1daeef9e6dc504753f354edf22d1835c795042402636d0922e52560dcf5e6029'), 'from': '0x5b90B3A89B4E973C7171174cBD2DB4288bF2c2B7', 'to': '0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9', 'blockNumber': 3479, 'cumulativeGasUsed': 103474, 'gasUsed': 103474, 'contractAddress': None, 'logs': [AttributeDict({'address': '0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9', 'topics': [HexBytes('0x6cc860100eb340319074c21d48e83cf486916acfbdef8d03e78360c1dfa019ff'), HexBytes('0x0000000000000000000000005b90b3a89b4e973c7171174cbd2db4288bf2c2b7')], 'data': HexBytes('0x0000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000406632383932633365383531376363356631376539616262613131383938396132

  return torch.load(io.BytesIO(b))


Received aggregated model from server.
Received aggregated model from server.
Training round 3
Sending model to TEE...
AttributeDict({'transactionHash': HexBytes('0xbedc481cc61e52563013037a5badca68b4d57de6cdea4c382f5e8030e81d3eeb'), 'transactionIndex': 0, 'blockHash': HexBytes('0x768fc98d968e8e883ae2cf8b75ddb08307be8a54546c15b5249c5ac55ac5e94e'), 'from': '0x5b90B3A89B4E973C7171174cBD2DB4288bF2c2B7', 'to': '0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9', 'blockNumber': 3482, 'cumulativeGasUsed': 103474, 'gasUsed': 103474, 'contractAddress': None, 'logs': [AttributeDict({'address': '0x59A939E2a21CC79073cDe58Bf80108a7401e6Bd9', 'topics': [HexBytes('0x6cc860100eb340319074c21d48e83cf486916acfbdef8d03e78360c1dfa019ff'), HexBytes('0x0000000000000000000000005b90b3a89b4e973c7171174cbd2db4288bf2c2b7')], 'data': HexBytes('0x0000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000