In [1]:
!pip install torch torchvision web3

Collecting web3
  Downloading web3-7.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting eth-abi>=5.0.1 (from web3)
  Downloading eth_abi-5.1.0-py3-none-any.whl.metadata (5.1 kB)
Collecting eth-account>=0.13.1 (from web3)
  Downloading eth_account-0.13.4-py3-none-any.whl.metadata (5.3 kB)
Collecting eth-hash>=0.5.1 (from eth-hash[pycryptodome]>=0.5.1->web3)
  Downloading eth_hash-0.7.0-py3-none-any.whl.metadata (5.4 kB)
Collecting eth-typing>=5.0.0 (from web3)
  Downloading eth_typing-5.0.0-py3-none-any.whl.metadata (5.1 kB)
Collecting eth-utils>=5.0.0 (from web3)
  Downloading eth_utils-5.0.0-py3-none-any.whl.metadata (5.4 kB)
Collecting hexbytes>=1.2.0 (from web3)
  Downloading hexbytes-1.2.1-py3-none-any.whl.metadata (3.7 kB)
Collecting types-requests>=2.0.0 (from web3)
  Downloading types_requests-2.32.0.20240914-py3-none-any.whl.metadata (1.9 kB)
Collecting websockets>=10.0.0 (from web3)
  Downloading websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from web3 import Web3
import threading

# Set up Ethereum
infura_url = 'https://sepolia.infura.io/v3/ba0531e055a744aba6851b78bdda21be'
web3 = Web3(Web3.HTTPProvider(infura_url))
contract_address = web3.to_checksum_address('0x4bd3de79afc02629a23d85dcc2692f4c0671eadd')

# ABI of the contract
contract_abi = [
    {
        "inputs": [],
        "name": "getLogs",
        "outputs": [
            {
                "components": [
                    {
                        "internalType": "address",
                        "name": "client",
                        "type": "address"
                    },
                    {
                        "internalType": "uint256",
                        "name": "timestamp",
                        "type": "uint256"
                    },
                    {
                        "internalType": "uint256",
                        "name": "accuracy",
                        "type": "uint256"
                    }
                ],
                "internalType": "struct AccuracyLogger.AccuracyLog[]",
                "name": "",
                "type": "tuple[]"
            }
        ],
        "stateMutability": "view",
        "type": "function"
    },
    {
        "inputs": [],
        "name": "getLogsLength",
        "outputs": [
            {
                "internalType": "uint256",
                "name": "",
                "type": "uint256"
            }
        ],
        "stateMutability": "view",
        "type": "function"
    },
    {
        "inputs": [
            {
                "internalType": "uint256",
                "name": "_accuracy",
                "type": "uint256"
            }
        ],
        "name": "logAccuracy",
        "outputs": [],
        "stateMutability": "nonpayable",
        "type": "function"
    },
    {
        "inputs": [
            {
                "internalType": "uint256",
                "name": "",
                "type": "uint256"
            }
        ],
        "name": "logs",
        "outputs": [
            {
                "internalType": "address",
                "name": "client",
                "type": "address"
            },
            {
                "internalType": "uint256",
                "name": "timestamp",
                "type": "uint256"
            },
            {
                "internalType": "uint256",
                "name": "accuracy",
                "type": "uint256"
            }
        ],
        "stateMutability": "view",
        "type": "function"
    }
]

# Load the contract
contract = web3.eth.contract(address=contract_address, abi=contract_abi)

# Ethereum accounts and private keys for each client
client_accounts = [
    {'account': '0xeC85984aB1f737979Ae3a640c66F49AB71aba490', 'private_key': 'a564e210c37c8569f90ec5c11d08a3733f68581ac333af3fbc1b48262fbf31a0'},
    {'account': '0x129FF84CdB84EC51ef6bC22c4fffEf63DfbCFDE5', 'private_key': '7dbab9c80d19cfd651ab20f32e72061c91b35a1dd5c56428a92d6108a0e669e4'},
    {'account': '0xD10f4E2DDd6072ae47444022AF6b1736A98ADE58', 'private_key': 'c8d0692b3692efc9185424f58f418ffe6b71c3e70e67cadcd30dd615e7aba4c0'},
    {'account': '0xbBc37445850f11907ECB5B9296991C4E8a0b950C', 'private_key': '49fcba5cc4db9bb763a830a93fc54592c28ee788183c810e74d918e9ca6fcd77'}
]

# Load CIFAR-100 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
train_size = len(trainset)
num_clients = 4
client_size = train_size // num_clients

# Create client datasets
client_datasets = [Subset(trainset, range(i * client_size, (i + 1) * client_size)) for i in range(num_clients)]

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 100)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Function to log aggregated accuracy on Ethereum for each client with a different account
def log_aggregated_accuracy_to_ethereum(client_id, accuracy, account, private_key):
    # Get the nonce for the account
    nonce = web3.eth.get_transaction_count(account)

    # Convert accuracy to an integer format (e.g., scale it by 10000)
    accuracy_int = int(accuracy * 10000)

    # Build the transaction to call the smart contract's logAccuracy method
    transaction = contract.functions.logAccuracy(accuracy_int).build_transaction({
        'from': account,
        'gas': 4000000,
        'gasPrice': web3.to_wei('150', 'gwei'),
        'nonce': nonce
    })

    # Sign the transaction with the client's private key
    signed_txn = web3.eth.account.sign_transaction(transaction, private_key)

    # Send the transaction to the Ethereum network
    tx_hash = web3.eth.send_raw_transaction(signed_txn.raw_transaction)

    # Print the transaction hash
    print(f'Client {client_id + 1}: Transaction sent with hash: {tx_hash.hex()}')

# Define a function for client training
def train_client(client_id, dataset):
    client_data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    total_accuracy = 0
    total_samples = 0

    for epoch in range(30):  # Number of epochs
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in client_data_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Calculate accuracy for the epoch
        accuracy = correct / total
        print(f'Client {client_id + 1}, Epoch {epoch + 1}, Loss: {running_loss/len(client_data_loader)}, Accuracy: {accuracy:.4f}')

        # Update aggregated accuracy and sample count
        total_accuracy += accuracy * total
        total_samples += total

    # Log final epoch accuracy to Ethereum after all epochs
    if total_samples > 0:  # Check to avoid division by zero
        final_accuracy = total_accuracy / total_samples
        log_aggregated_accuracy_to_ethereum(client_id, final_accuracy, client_accounts[client_id]['account'], client_accounts[client_id]['private_key'])

    return model.state_dict()  # Return the local model's weights

# Global model for aggregation
global_model = SimpleCNN()
global_weights = global_model.state_dict()

# Threads for client training
threads = []
local_models_weights = []

# Start training for each client in a separate thread
for client_id in range(num_clients):
    thread = threading.Thread(target=lambda id=client_id: local_models_weights.append(train_client(id, client_datasets[id])))
    threads.append(thread)
    thread.start()

# Wait for all threads to complete
for thread in threads:
    thread.join()

# Aggregate the weights from local models into the global model
for key in global_weights.keys():
    global_weights[key] = torch.mean(torch.stack([torch.tensor(local_models_weights[i][key], dtype=torch.float32) for i in range(num_clients)]), dim=0)

# Load the aggregated weights back into the global model
global_model.load_state_dict(global_weights)

print('Federated Learning Finished. Global Model is Ready.')


Files already downloaded and verified
Client 2, Epoch 1, Loss: 4.131866967891488, Accuracy: 0.0745
Client 1, Epoch 1, Loss: 4.11047167546304, Accuracy: 0.0746
Client 4, Epoch 1, Loss: 4.081154227561658, Accuracy: 0.0790
Client 3, Epoch 1, Loss: 4.080475634626111, Accuracy: 0.0806
Client 2, Epoch 2, Loss: 3.4776703555260777, Accuracy: 0.1758
Client 1, Epoch 2, Loss: 3.4633395385254375, Accuracy: 0.1770
Client 4, Epoch 2, Loss: 3.420239216226446, Accuracy: 0.1801
Client 3, Epoch 2, Loss: 3.4551418095903323, Accuracy: 0.1810
Client 2, Epoch 3, Loss: 3.0828543788636735, Accuracy: 0.2426
Client 4, Epoch 3, Loss: 3.0161626881650645, Accuracy: 0.2507
Client 1, Epoch 3, Loss: 3.0424661361957757, Accuracy: 0.2498
Client 3, Epoch 3, Loss: 3.0883688524251096, Accuracy: 0.2396
Client 2, Epoch 4, Loss: 2.7917135722192046, Accuracy: 0.2952
Client 4, Epoch 4, Loss: 2.667445073042379, Accuracy: 0.3246
Client 1, Epoch 4, Loss: 2.730954771456511, Accuracy: 0.3098
Client 3, Epoch 4, Loss: 2.7706125114884

  global_weights[key] = torch.mean(torch.stack([torch.tensor(local_models_weights[i][key], dtype=torch.float32) for i in range(num_clients)]), dim=0)
